269 changes: 142 additions & 127 deletions gpt_oss/metal/scripts/create-local-model.py
@@ -14,7 +14,7 @@
from tqdm import tqdm
from openai_harmony import load_harmony_encoding, HarmonyEncodingName

parser = argparse.ArgumentParser(prog='check-mxfp4-weights.py', description='Validated MXFP4 weights')
parser = argparse.ArgumentParser(prog='create-local-model.py', description='Convert a checkpoint directory to a local model file')
parser.add_argument('-s', '--src', metavar='DIR', type=str, required=True, help='Path to the input checkpoint directory')
parser.add_argument('-d', '--dst', metavar='FILE', type=str, required=True, help='Path to the output model file')
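
For reference, with the renamed parser the script takes an input checkpoint directory and an output model path, e.g. (paths are placeholders):

    python gpt_oss/metal/scripts/create-local-model.py -s /path/to/gpt-oss-checkpoint -d gpt-oss.bin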

@@ -204,140 +204,155 @@ def main(args):
num_included_tokens = 200013 + 1
print(f"Tokenizer: {num_included_tokens} tokens")

tensors = {}
# Read from all files ending with .safetensors in the checkpoint directory
safetensor_files = [
os.path.join(options.src, fname)
for fname in os.listdir(options.src)
if fname.endswith(".safetensors")
]
# Build a mapping from tensor name to filepath
tensor_name_to_file = {}
for safetensor_file in safetensor_files:
with safe_open(safetensor_file, framework="pt", device="cpu") as src:
for key in src.keys():
tensor_name_to_file[key] = safetensor_file

def get_tensor(name):
with safe_open(tensor_name_to_file[name], framework="pt", device="cpu") as src:
return src.get_tensor(name)

with open(options.dst, "wb") as dst:
with safe_open(os.path.join(options.src, "model.safetensors"), framework="pt", device="cpu") as src:
write_file_header(dst)

yarn_low = (
head_dim / 2
* math.log(initial_context_length / (rope_ntk_beta * 2 * math.pi))
/ math.log(rope_theta)
)
yarn_high = (
head_dim / 2
* math.log(initial_context_length / (rope_ntk_alpha * 2 * math.pi))
/ math.log(rope_theta)
)

write_model_header(dst,
context_length=int(initial_context_length * rope_scaling_factor),
num_blocks=num_blocks,
num_experts=num_experts,
num_active_experts=num_active_experts,
embedding_dim=embedding_dim,
mlp_dim=mlp_dim,
swiglu_limit=swiglu_limit,
head_dim=head_dim,
num_heads=num_q_heads,
num_kv_heads=num_kv_heads,
attention_window=attention_window,
rope_theta=rope_theta,
interpolation_scale=1.0 / rope_scaling_factor,
yarn_offset=-yarn_low / (yarn_high - yarn_low),
yarn_scale=1.0 / (yarn_high - yarn_low),
yarn_multiplier=0.1 * math.log(rope_scaling_factor) + 1.0,
rmsnorm_epsilon=1.0e-5)

write_tokenizer_header(dst,
num_special_tokens=num_included_tokens - num_text_tokens,
num_text_tokens=num_text_tokens,
regex_size=len(o200k_gptoss._pat_str.encode("ascii")) + 1,
tokens_size=tokens_size)

### Tokenizer
# Special tokens
for token_idx in range(num_text_tokens, num_included_tokens):
token = o200k_gptoss.decode_single_token_bytes(token_idx).decode('ascii')
if token in INCLUDE_SPECIAL_TOKENS:
dst.write(SPECIAL_TOKEN_UUID[token])
else:
dst.write(bytes(16))
# Regex
dst.write(o200k_gptoss._pat_str.encode("ascii"))
dst.write(struct.pack('B', 0))
# Text tokens
tokenizer_bytes_written = 0
for t in range(num_text_tokens):
token_bytes = o200k_gptoss.decode_single_token_bytes(t)
assert len(token_bytes) > 0
dst.write(struct.pack('<H', len(token_bytes)))
dst.write(token_bytes)
tokenizer_bytes_written += len(token_bytes) + 2
assert(tokenizer_bytes_written == tokens_size), (tokenizer_bytes_written, tokens_size)
write_file_header(dst)

yarn_low = (
head_dim / 2
* math.log(initial_context_length / (rope_ntk_beta * 2 * math.pi))
/ math.log(rope_theta)
)
yarn_high = (
head_dim / 2
* math.log(initial_context_length / (rope_ntk_alpha * 2 * math.pi))
/ math.log(rope_theta)
)

write_model_header(dst,
context_length=int(initial_context_length * rope_scaling_factor),
num_blocks=num_blocks,
num_experts=num_experts,
num_active_experts=num_active_experts,
embedding_dim=embedding_dim,
mlp_dim=mlp_dim,
swiglu_limit=swiglu_limit,
head_dim=head_dim,
num_heads=num_q_heads,
num_kv_heads=num_kv_heads,
attention_window=attention_window,
rope_theta=rope_theta,
interpolation_scale=1.0 / rope_scaling_factor,
yarn_offset=-yarn_low / (yarn_high - yarn_low),
yarn_scale=1.0 / (yarn_high - yarn_low),
yarn_multiplier=0.1 * math.log(rope_scaling_factor) + 1.0,
rmsnorm_epsilon=1.0e-5)

write_tokenizer_header(dst,
num_special_tokens=num_included_tokens - num_text_tokens,
num_text_tokens=num_text_tokens,
regex_size=len(o200k_gptoss._pat_str.encode("ascii")) + 1,
tokens_size=tokens_size)

### Tokenizer
# Special tokens
for token_idx in range(num_text_tokens, num_included_tokens):
token = o200k_gptoss.decode_single_token_bytes(token_idx).decode('ascii')
if token in INCLUDE_SPECIAL_TOKENS:
dst.write(SPECIAL_TOKEN_UUID[token])
else:
dst.write(bytes(16))
# Regex
dst.write(o200k_gptoss._pat_str.encode("ascii"))
dst.write(struct.pack('B', 0))
# Text tokens
tokenizer_bytes_written = 0
for t in range(num_text_tokens):
token_bytes = o200k_gptoss.decode_single_token_bytes(t)
assert len(token_bytes) > 0
dst.write(struct.pack('<H', len(token_bytes)))
dst.write(token_bytes)
tokenizer_bytes_written += len(token_bytes) + 2
assert(tokenizer_bytes_written == tokens_size), (tokenizer_bytes_written, tokens_size)
write_padding(dst)

embedding_weight = get_tensor("embedding.weight")
# Filter out unused tokens
embedding_weight = embedding_weight[:num_included_tokens, :]
write_embedding_weight(dst, embedding_weight)

for n in tqdm(range(num_blocks)):
write_rmsnorm_gain(dst, get_tensor(f"block.{n}.attn.norm.scale"))

attn_qkv_weight = get_tensor(f"block.{n}.attn.qkv.weight")
attn_qkv_bias = get_tensor(f"block.{n}.attn.qkv.bias")
for qkv in (attn_qkv_weight, attn_qkv_bias):
qk = qkv[:head_dim * (num_q_heads + num_kv_heads), ...].contiguous()
v = qkv[head_dim * (num_q_heads + num_kv_heads):, ...].contiguous()
qk = qk.view(num_q_heads + num_kv_heads, 2, head_dim // 2, -1).transpose(1, 2).reshape(num_q_heads + num_kv_heads, head_dim, -1)
q = qk[:num_q_heads, ...]
k = qk[num_q_heads:, ...]
# Factor multiplication by 1/sqrt(64) = 0.125 = 0.5 * 0.25 in SDPA into Q and K projections
assert head_dim == 64
q *= 0.5
k *= 0.25
v = v.view(num_kv_heads, head_dim, -1)
qkv.copy_(torch.cat((q, k, v), dim=0).reshape(*qkv.shape))

write_linear_weight(dst, attn_qkv_weight, attn_qkv_bias)

write_attn_sink(dst, get_tensor(f"block.{n}.attn.sinks"))

write_linear_weight(dst, get_tensor(f"block.{n}.attn.out.weight"), get_tensor(f"block.{n}.attn.out.bias"))

write_rmsnorm_gain(dst, get_tensor(f"block.{n}.mlp.norm.scale"))

write_linear_weight(dst, get_tensor(f"block.{n}.mlp.gate.weight"), get_tensor(f"block.{n}.mlp.gate.bias"))

write_rmsnorm_gain(dst, get_tensor("norm.scale"))

unembedding_weight = get_tensor("unembedding.weight")
unembedding_weight = unembedding_weight[:num_included_tokens, :]
write_linear_weight(dst, unembedding_weight)

for n in tqdm(range(num_blocks)):
mlp1_blocks = get_tensor(f"block.{n}.mlp.mlp1_weight.blocks")
mlp1_scales = get_tensor(f"block.{n}.mlp.mlp1_weight.scales")
assert mlp1_scales.min().item() < 254 - UE8_OFFSET
mlp1_bias = get_tensor(f"block.{n}.mlp.mlp1_bias")

mlp2_blocks = get_tensor(f"block.{n}.mlp.mlp2_weight.blocks")
mlp2_scales = get_tensor(f"block.{n}.mlp.mlp2_weight.scales")
assert mlp2_scales.min().item() < 254 - UE8_OFFSET
mlp2_bias = get_tensor(f"block.{n}.mlp.mlp2_bias")

# Write MoE weights grouped by expert
write_padding(dst)

embedding_weight = src.get_tensor("embedding.weight")
# Filter out unused tokens
embedding_weight = embedding_weight[:num_included_tokens, :]
write_embedding_weight(dst, embedding_weight)

for n in tqdm(range(num_blocks)):
write_rmsnorm_gain(dst, src.get_tensor(f"block.{n}.attn.norm.scale"))

attn_qkv_weight = src.get_tensor(f"block.{n}.attn.qkv.weight")
attn_qkv_bias = src.get_tensor(f"block.{n}.attn.qkv.bias")
for qkv in (attn_qkv_weight, attn_qkv_bias):
qk = qkv[:head_dim * (num_q_heads + num_kv_heads), ...].contiguous()
v = qkv[head_dim * (num_q_heads + num_kv_heads):, ...].contiguous()
qk = qk.view(num_q_heads + num_kv_heads, 2, head_dim // 2, -1).transpose(1, 2).reshape(num_q_heads + num_kv_heads, head_dim, -1)
q = qk[:num_q_heads, ...]
k = qk[num_q_heads:, ...]
# Factor multiplication by 1/sqrt(64) = 0.125 = 0.5 * 0.25 in SDPA into Q and K projections
assert head_dim == 64
q *= 0.5
k *= 0.25
v = v.view(num_kv_heads, head_dim, -1)
qkv.copy_(torch.cat((q, k, v), dim=0).reshape(*qkv.shape))

write_linear_weight(dst, attn_qkv_weight, attn_qkv_bias)

write_attn_sink(dst, src.get_tensor(f"block.{n}.attn.sinks"))

write_linear_weight(dst, src.get_tensor(f"block.{n}.attn.out.weight"), src.get_tensor(f"block.{n}.attn.out.bias"))

write_rmsnorm_gain(dst, src.get_tensor(f"block.{n}.mlp.norm.scale"))

write_linear_weight(dst, src.get_tensor(f"block.{n}.mlp.gate.weight"), src.get_tensor(f"block.{n}.mlp.gate.bias"))

write_rmsnorm_gain(dst, src.get_tensor("norm.scale"))

unembedding_weight = src.get_tensor("unembedding.weight")
unembedding_weight = unembedding_weight[:num_included_tokens, :]
write_linear_weight(dst, unembedding_weight)

for n in tqdm(range(num_blocks)):
mlp1_blocks = src.get_tensor(f"block.{n}.mlp.mlp1_weight.blocks")
mlp1_scales = src.get_tensor(f"block.{n}.mlp.mlp1_weight.scales")
assert mlp1_scales.min().item() < 254 - UE8_OFFSET
mlp1_bias = src.get_tensor(f"block.{n}.mlp.mlp1_bias")

mlp2_blocks = src.get_tensor(f"block.{n}.mlp.mlp2_weight.blocks")
mlp2_scales = src.get_tensor(f"block.{n}.mlp.mlp2_weight.scales")
assert mlp2_scales.min().item() < 254 - UE8_OFFSET
mlp2_bias = src.get_tensor(f"block.{n}.mlp.mlp2_bias")

# Write MoE weights grouped by expert
write_padding(dst)

for e in range(num_experts):
write_padding(dst, alignment_multiple=16)
dst.write(mlp1_blocks[e, ...].view(torch.uint8).numpy().tobytes())
for e in range(num_experts):
write_padding(dst, alignment_multiple=16)
dst.write(mlp1_blocks[e, ...].view(torch.uint8).numpy().tobytes())

write_padding(dst, alignment_multiple=16)
dst.write((mlp1_scales + UE8_OFFSET)[e, ...].view(torch.uint8).numpy().tobytes())
write_padding(dst, alignment_multiple=16)
dst.write((mlp1_scales + UE8_OFFSET)[e, ...].view(torch.uint8).numpy().tobytes())

write_padding(dst, alignment_multiple=16)
dst.write(mlp1_bias[e, ...].view(torch.uint8).numpy().tobytes())
write_padding(dst, alignment_multiple=16)
dst.write(mlp1_bias[e, ...].view(torch.uint8).numpy().tobytes())

write_padding(dst, alignment_multiple=16)
dst.write(mlp2_blocks[e, ...].view(torch.uint8).numpy().tobytes())
write_padding(dst, alignment_multiple=16)
dst.write(mlp2_blocks[e, ...].view(torch.uint8).numpy().tobytes())

write_padding(dst, alignment_multiple=16)
dst.write((mlp2_scales + UE8_OFFSET)[e, ...].view(torch.uint8).numpy().tobytes())
write_padding(dst, alignment_multiple=16)
dst.write((mlp2_scales + UE8_OFFSET)[e, ...].view(torch.uint8).numpy().tobytes())

write_padding(dst, alignment_multiple=16)
dst.write(mlp2_bias[e, ...].view(torch.uint8).numpy().tobytes())
write_padding(dst, alignment_multiple=16)
dst.write(mlp2_bias[e, ...].view(torch.uint8).numpy().tobytes())

if __name__ == "__main__":
main(sys.argv[1:])
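
A few hedged sketches follow, restating pieces of the diff above in standalone form. First, the central change: the converter no longer opens a single model.safetensors file but indexes every *.safetensors shard in the checkpoint directory and reopens the owning shard on demand. A minimal sketch of that pattern (function names are mine; the safe_open calls match the ones in the diff):

import os
from safetensors import safe_open

def build_tensor_index(checkpoint_dir):
    # Map each tensor name to the shard file that contains it.
    index = {}
    for fname in os.listdir(checkpoint_dir):
        if fname.endswith(".safetensors"):
            path = os.path.join(checkpoint_dir, fname)
            with safe_open(path, framework="pt", device="cpu") as shard:
                for key in shard.keys():
                    index[key] = path
    return index

def load_tensor(index, name):
    # Reopen only the shard that holds the tensor.
    with safe_open(index[name], framework="pt", device="cpu") as shard:
        return shard.get_tensor(name)

Reopening per tensor means only one materialized tensor needs to be resident at a time, at the cost of some file-handle churn.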
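
The YaRN-related fields written by write_model_header are four scalars derived from the RoPE hyperparameters. Restating that arithmetic as a helper makes the relationship explicit (the function name and dict packaging are mine; the formulas are copied from the diff, and the ramp identity in the comment follows from simple algebra):

import math

def yarn_header_fields(head_dim, rope_theta, initial_context_length,
                       rope_scaling_factor, rope_ntk_alpha, rope_ntk_beta):
    # Rotary-dimension indices bracketing the NTK-by-parts transition band.
    yarn_low = (head_dim / 2
                * math.log(initial_context_length / (rope_ntk_beta * 2 * math.pi))
                / math.log(rope_theta))
    yarn_high = (head_dim / 2
                 * math.log(initial_context_length / (rope_ntk_alpha * 2 * math.pi))
                 / math.log(rope_theta))
    return {
        "context_length": int(initial_context_length * rope_scaling_factor),
        "interpolation_scale": 1.0 / rope_scaling_factor,
        # For a rotary dimension index d, d * yarn_scale + yarn_offset equals
        # (d - yarn_low) / (yarn_high - yarn_low): the usual YaRN linear ramp.
        "yarn_offset": -yarn_low / (yarn_high - yarn_low),
        "yarn_scale": 1.0 / (yarn_high - yarn_low),
        "yarn_multiplier": 0.1 * math.log(rope_scaling_factor) + 1.0,
    }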
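
The tokenizer section has a small fixed layout: one 16-byte UUID per retained special token (zero bytes for excluded ones), the NUL-terminated ASCII split regex, then each text token as a little-endian uint16 length followed by its raw bytes. A hedged reader-side sketch mirroring the writes above (the function is illustrative; header parsing and the trailing write_padding alignment are out of scope):

import struct

def read_tokenizer_section(f, num_special_tokens, num_text_tokens, regex_size):
    # 16 bytes per special token; all-zero entries mark excluded tokens.
    special_uuids = [f.read(16) for _ in range(num_special_tokens)]
    # regex_size includes the single NUL terminator appended by the writer.
    regex = f.read(regex_size)[:-1].decode("ascii")
    # Text tokens: <uint16 LE length> followed by the token bytes.
    text_tokens = []
    for _ in range(num_text_tokens):
        (length,) = struct.unpack("<H", f.read(2))
        text_tokens.append(f.read(length))
    return special_uuids, regex, text_tokens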
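
The comment about folding 1/sqrt(64) into the Q and K projections rests on the identity 1/sqrt(64) = 0.125 = 0.5 * 0.25: scaling Q by 0.5 and K by 0.25 up front lets the attention kernel skip the usual SDPA scale. A quick self-contained check (random values; head_dim fixed at 64 as asserted in the diff):

import math
import torch

head_dim = 64
q = torch.randn(head_dim)
k = torch.randn(head_dim)

scaled = torch.dot(q, k) / math.sqrt(head_dim)  # standard scaled dot-product term
folded = torch.dot(q * 0.5, k * 0.25)           # scale pre-baked into the projections
assert torch.allclose(scaled, folded)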
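
Finally, the per-expert MoE section writes MXFP4 block data, UE8-biased scales, and biases for each expert, each chunk padded to a 16-byte boundary. A sketch of that aligned-write pattern, assuming write_padding zero-fills the stream up to the requested alignment (align_to and write_expert_mlp are illustrative names; UE8_OFFSET is taken as given from the script):

import numpy as np

def align_to(f, multiple=16):
    # Assumed behaviour of write_padding: zero-fill up to the next multiple.
    pad = -f.tell() % multiple
    if pad:
        f.write(bytes(pad))

def write_expert_mlp(f, blocks, scales, bias, ue8_offset):
    # blocks, scales and bias are one expert's tensors as uint8 numpy arrays,
    # matching the .view(torch.uint8).numpy() conversions in the diff.
    align_to(f)
    f.write(blocks.tobytes())
    align_to(f)
    # Scales are stored with UE8_OFFSET added, as in (scales + UE8_OFFSET).view(torch.uint8).
    f.write((scales.astype(np.int32) + ue8_offset).astype(np.uint8).tobytes())
    align_to(f)
    f.write(bias.tobytes())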