Skip to content

Commit

Permalink
support concat linear for gqa (#2733)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianan-gu committed Apr 7, 2024
1 parent 7bc3869 commit f5b941c
Showing 1 changed file with 5 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1934,25 +1934,11 @@ def __init__(self, module, config, sdp_module_ref, distributed=False):
or isinstance(module.v_proj, WeightOnlyQuantizedLinear)
)
) and not (hasattr(self, "use_qk_layernorm") and self.use_qk_layernorm):

def get_weight_shape(mod):
if hasattr(mod, "in_features") and hasattr(mod, "out_features"):
return [mod.in_features, mod.out_features]
elif hasattr(mod, "weight") and hasattr(mod.weight, "shape"):
return list(mod.weight.shape)
return None

weight_shapes = [
get_weight_shape(mod)
for mod in [module.q_proj, module.k_proj, module.v_proj]
]
if weight_shapes[0] is not None and all(
weight_shapes[0] == shape for shape in weight_shapes[1:]
):
self.concat_qkv = _IPEXConcatLinearRef(
[module.q_proj, module.k_proj, module.v_proj]
)
del module.q_proj, module.k_proj, module.v_proj
# we support MHA, GQA, MQA for concat linear
self.concat_qkv = _IPEXConcatLinearRef(
[module.q_proj, module.k_proj, module.v_proj]
)
del module.q_proj, module.k_proj, module.v_proj

self._IPEXScaleDotProduct = _IPEXScaleDotProductRef(module, config)

Expand Down

0 comments on commit f5b941c

Please sign in to comment.