Skip to content

Commit

Permalink
Don't split MoE weights.
Browse files (browse the repository at this point in the history)
As per ggerganov#7058 (comment).
This helps avoid a memory copy at run time.
  • Loading branch information
heiner committed May 3, 2024
1 parent 7c54f47 commit 72920b1
Showing 1 changed file with 13 additions and 22 deletions.
35 changes: 13 additions & 22 deletions convert_grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ def maybe_quantize_tensor(tensor, ggml_type):


def get_dtype_and_ggml_type(tensor, ggml_type):
if tensor.ndim == 2:
if tensor.ndim in (2, 3):
if tensor.shape[1] % GGML_QK8_0 == 0:
return np.int8, ggml_type
else:
return np.float16, gguf.GGMLQuantizationType.F16
else:
# 1d weight: convert it to float32
assert tensor.ndim == 1
assert tensor.ndim == 1, tensor
return np.float32, gguf.GGMLQuantizationType.F32


Expand Down Expand Up @@ -236,15 +236,15 @@ def dump_state_dict(f, weight_names, model_files, ggml_type, config):
cache.update(state_dict)
tensor = cache.pop(key)
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
tensor = maybe_quantize_tensor(tensor, tensor_ggml_type)
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()

array = tensor.numpy()
print(
f"dumping {key}: {tensor_ggml_type.name}/{array.dtype}, {array.shape}, {array.nbytes} bytes"
f"dumping {key}:",
f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
)
f.write_tensor_data(array)

tensor_info.append((key, tensor.shape, tensor_ggml_type.name))
tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name))

try:
print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
Expand Down Expand Up @@ -282,15 +282,10 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de
if len(weight.shape) >= 2 and "token_embd" not in tensor_name:
weight = weight.transpose(-1, -2)

if tensor_name.endswith("ffn_gate_inp.weight"):
if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"):
result[tensor_name] = weight[experts] # gather.
elif "experts" not in tensor_name:
result[tensor_name] = weight
else:
# split moe
for i, expert in enumerate(experts):
key = tensor_name.replace("experts", str(i))
result[key] = weight[expert]

return result

Expand Down Expand Up @@ -328,14 +323,10 @@ def extract_vocabulary_from_model(vocab):
def get_weight_names(config):
weight_names = ["token_embd.weight"]
for i in range(config.num_hidden_layers):
for j in range(config.num_experts):
weight_names += [
f"blk.{i}.ffn_gate.{j}.weight",
f"blk.{i}.ffn_down.{j}.weight",
f"blk.{i}.ffn_up.{j}.weight",
]

weight_names += [
f"blk.{i}.ffn_gate_exps.weight",
f"blk.{i}.ffn_down_exps.weight",
f"blk.{i}.ffn_up_exps.weight",
f"blk.{i}.attn_k.weight",
f"blk.{i}.attn_output.weight",
f"blk.{i}.attn_q.weight",
Expand Down Expand Up @@ -399,9 +390,9 @@ def ffn_size(emb_size, widening_factor):
]
for i in range(config.num_hidden_layers):
tensor_names += [
f"blk.{i}.ffn_gate.experts.weight",
f"blk.{i}.ffn_down.experts.weight",
f"blk.{i}.ffn_up.experts.weight",
f"blk.{i}.ffn_gate_exps.weight",
f"blk.{i}.ffn_down_exps.weight",
f"blk.{i}.ffn_up_exps.weight",
f"blk.{i}.attn_k.weight",
f"blk.{i}.attn_output.weight",
f"blk.{i}.attn_q.weight",
Expand Down

0 comments on commit 72920b1

Please sign in to comment.