diff --git a/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py b/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
index 950cefc26..e3b96d813 100644
--- a/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
+++ b/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
@@ -1934,25 +1934,11 @@ def __init__(self, module, config, sdp_module_ref, distributed=False):
                 or isinstance(module.v_proj, WeightOnlyQuantizedLinear)
             )
         ) and not (hasattr(self, "use_qk_layernorm") and self.use_qk_layernorm):
-
-            def get_weight_shape(mod):
-                if hasattr(mod, "in_features") and hasattr(mod, "out_features"):
-                    return [mod.in_features, mod.out_features]
-                elif hasattr(mod, "weight") and hasattr(mod.weight, "shape"):
-                    return list(mod.weight.shape)
-                return None
-
-            weight_shapes = [
-                get_weight_shape(mod)
-                for mod in [module.q_proj, module.k_proj, module.v_proj]
-            ]
-            if weight_shapes[0] is not None and all(
-                weight_shapes[0] == shape for shape in weight_shapes[1:]
-            ):
-                self.concat_qkv = _IPEXConcatLinearRef(
-                    [module.q_proj, module.k_proj, module.v_proj]
-                )
-                del module.q_proj, module.k_proj, module.v_proj
+            # we support MHA, GQA, MQA for concat linear
+            self.concat_qkv = _IPEXConcatLinearRef(
+                [module.q_proj, module.k_proj, module.v_proj]
+            )
+            del module.q_proj, module.k_proj, module.v_proj
         self._IPEXScaleDotProduct = _IPEXScaleDotProductRef(module, config)
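
The equal-shape check is dropped because concatenating q/k/v projections does not require identical output sizes: in GQA/MQA the k/v projections have fewer heads than q, so their out_features differ, yet the weights can still be stacked along the output dimension and split back after one matmul. The sketch below illustrates that idea only; it is not the _IPEXConcatLinearRef implementation, and the shapes are made-up GQA-style values.

```python
# Minimal sketch (assumed behavior, not the IPEX implementation): fuse q/k/v
# Linear layers with different out_features into a single Linear, then split
# the fused output back into q/k/v. Weight layout is (out_features, in_features),
# so concatenation happens along dim 0.
import torch
import torch.nn as nn

hidden = 64
q = nn.Linear(hidden, 64, bias=False)  # e.g. 8 query heads * head_dim 8
k = nn.Linear(hidden, 16, bias=False)  # e.g. 2 kv heads * head_dim 8 (GQA)
v = nn.Linear(hidden, 16, bias=False)

fused = nn.Linear(hidden, 64 + 16 + 16, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))

x = torch.randn(2, hidden)
q_out, k_out, v_out = fused(x).split([64, 16, 16], dim=-1)
assert torch.allclose(q_out, q(x), atol=1e-6)
assert torch.allclose(k_out, k(x), atol=1e-6)
assert torch.allclose(v_out, v(x), atol=1e-6)
```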