fix weight loading for GQA with TP (vllm-project#2379)
zhangch9 committed Jan 15, 2024
1 parent 65276fa commit 8c6aa5e
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/linear.py
@@ -423,7 +423,10 @@ def weight_loader(self,
                 shard_offset = shard_offset // param.pack_factor
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            shard_id = tp_rank // self.num_kv_head_replicas
+            if loaded_shard_id == "q":
+                shard_id = tp_rank
+            else:
+                shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
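
The effect of this change can be illustrated with a small standalone sketch of the index arithmetic. This is not vLLM's code: `shard_start` is a hypothetical helper, the shard sizes are made-up row counts, and the sketch assumes that the replica count equals `tp_size // total_kv_heads` when there are more tensor-parallel ranks than KV heads (the diff itself does not show how `num_kv_head_replicas` is computed).

```python
def shard_start(loaded_shard_id: str, tp_rank: int, shard_size: int,
                num_kv_head_replicas: int) -> int:
    """Return the start offset of this rank's slice in the full weight.

    Hypothetical helper mirroring the branch added in the diff.
    """
    if loaded_shard_id == "q":
        # Query shards are indexed directly by rank: one distinct slice each.
        shard_id = tp_rank
    else:
        # K/V shards may be shared, so replica ranks map to the same slice.
        shard_id = tp_rank // num_kv_head_replicas
    return shard_id * shard_size


# Illustrative setup (assumed values): 4 ranks, 2 KV heads, so each KV head
# is replicated on 2 ranks.
tp_size, total_kv_heads = 4, 2
replicas = tp_size // total_kv_heads
q_shard_size, kv_shard_size = 8, 1  # rows per rank, chosen for readability

q_starts = [shard_start("q", r, q_shard_size, replicas) for r in range(tp_size)]
k_starts = [shard_start("k", r, kv_shard_size, replicas) for r in range(tp_size)]

# Behaviour before the fix: the replica division was applied to "q" as well.
old_q_starts = [(r // replicas) * q_shard_size for r in range(tp_size)]

print(q_starts, k_starts, old_q_starts)
```

Under these assumptions the old arithmetic gives ranks 0 and 1 the same query slice and never reads the second half of the query rows, while the new branch gives each rank its own query slice and still lets replica ranks share a K/V slice.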