From 57ed901141042fe208d67589e99cd16ee39353a6 Mon Sep 17 00:00:00 2001
From: Marat Dukhan
Date: Mon, 22 Sep 2025 19:28:50 -0700
Subject: [PATCH] Metal: support sharded checkpoints in the converter

---
 gpt_oss/metal/scripts/create-local-model.py | 269 +++++++++++---------
 1 file changed, 142 insertions(+), 127 deletions(-)

diff --git a/gpt_oss/metal/scripts/create-local-model.py b/gpt_oss/metal/scripts/create-local-model.py
index c0de8bdf..7f16ba99 100644
--- a/gpt_oss/metal/scripts/create-local-model.py
+++ b/gpt_oss/metal/scripts/create-local-model.py
@@ -14,7 +14,7 @@
 from tqdm import tqdm
 from openai_harmony import load_harmony_encoding, HarmonyEncodingName

-parser = argparse.ArgumentParser(prog='check-mxfp4-weights.py', description='Validated MXFP4 weights')
+parser = argparse.ArgumentParser(prog='create-local-model.py', description='Convert a checkpoint directory to a local model file')
 parser.add_argument('-s', '--src', metavar='DIR', type=str, required=True, help='Path to the input checkpoint directory')
 parser.add_argument('-d', '--dst', metavar='FILE', type=str, required=True, help='Path to the output model file')

@@ -204,140 +204,155 @@ def main(args):
     num_included_tokens = 200013 + 1
     print(f"Tokenizer: {num_included_tokens} tokens")

-    tensors = {}
+    # Read from all files ending with .safetensors in the checkpoint directory
+    safetensor_files = [
+        os.path.join(options.src, fname)
+        for fname in os.listdir(options.src)
+        if fname.endswith(".safetensors")
+    ]
+    # Build a mapping from tensor name to filepath
+    tensor_name_to_file = {}
+    for safetensor_file in safetensor_files:
+        with safe_open(safetensor_file, framework="pt", device="cpu") as src:
+            for key in src.keys():
+                tensor_name_to_file[key] = safetensor_file
+
+    def get_tensor(name):
+        with safe_open(tensor_name_to_file[name], framework="pt", device="cpu") as src:
+            return src.get_tensor(name)
+
     with open(options.dst, "wb") as dst:
-        with safe_open(os.path.join(options.src, "model.safetensors"), framework="pt", device="cpu") as src:
-            write_file_header(dst)
-
-            yarn_low = (
-                head_dim / 2
-                * math.log(initial_context_length / (rope_ntk_beta * 2 * math.pi))
-                / math.log(rope_theta)
-            )
-            yarn_high = (
-                head_dim / 2
-                * math.log(initial_context_length / (rope_ntk_alpha * 2 * math.pi))
-                / math.log(rope_theta)
-            )
-
-            write_model_header(dst,
-                               context_length=int(initial_context_length * rope_scaling_factor),
-                               num_blocks=num_blocks,
-                               num_experts=num_experts,
-                               num_active_experts=num_active_experts,
-                               embedding_dim=embedding_dim,
-                               mlp_dim=mlp_dim,
-                               swiglu_limit=swiglu_limit,
-                               head_dim=head_dim,
-                               num_heads=num_q_heads,
-                               num_kv_heads=num_kv_heads,
-                               attention_window=attention_window,
-                               rope_theta=rope_theta,
-                               interpolation_scale=1.0 / rope_scaling_factor,
-                               yarn_offset=-yarn_low / (yarn_high - yarn_low),
-                               yarn_scale=1.0 / (yarn_high - yarn_low),
-                               yarn_multiplier=0.1 * math.log(rope_scaling_factor) + 1.0,
-                               rmsnorm_epsilon=1.0e-5)
-
-            write_tokenizer_header(dst,
-                                   num_special_tokens=num_included_tokens - num_text_tokens,
-                                   num_text_tokens=num_text_tokens,
-                                   regex_size=len(o200k_gptoss._pat_str.encode("ascii")) + 1,
-                                   tokens_size=tokens_size)
-
-            ### Tokenizer
-            # Special tokens
-            for token_idx in range(num_text_tokens, num_included_tokens):
-                token = o200k_gptoss.decode_single_token_bytes(token_idx).decode('ascii')
-                if token in INCLUDE_SPECIAL_TOKENS:
-                    dst.write(SPECIAL_TOKEN_UUID[token])
-                else:
-                    dst.write(bytes(16))
-
-            # Regex
-            dst.write(o200k_gptoss._pat_str.encode("ascii"))
-            dst.write(struct.pack('B', 0))
-
-            # Text tokens
-            tokenizer_bytes_written = 0
-            for t in range(num_text_tokens):
-                token_bytes = o200k_gptoss.decode_single_token_bytes(t)
-                assert len(token_bytes) > 0
-                dst.write(struct.pack('<H', len(token_bytes)))
+            assert len(token_bytes) > 0
+            dst.write(struct.pack('
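
The core of the change is a two-pass pattern for sharded safetensors checkpoints: scan every shard's header once to record which file owns each tensor, then reopen only the owning shard whenever a tensor is requested. A minimal standalone sketch of that pattern follows; the names index_shards and load_tensor are illustrative, not part of the patch, and only the safetensors package is assumed.

    import os
    from safetensors import safe_open

    def index_shards(checkpoint_dir):
        # Map each tensor name to the shard file that contains it.
        # keys() come from the safetensors file header, so this scan
        # is cheap and never materializes tensor data.
        index = {}
        for fname in os.listdir(checkpoint_dir):
            if fname.endswith(".safetensors"):
                path = os.path.join(checkpoint_dir, fname)
                with safe_open(path, framework="pt", device="cpu") as shard:
                    for name in shard.keys():
                        index[name] = path
        return index

    def load_tensor(index, name):
        # Reopen the owning shard and materialize just this one tensor.
        with safe_open(index[name], framework="pt", device="cpu") as shard:
            return shard.get_tensor(name)

Because only one tensor is resident at a time, peak memory stays near the size of the largest single tensor rather than the whole checkpoint, and a single-file checkpoint is handled as the one-shard special case.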