In [1]:
import re
import json
import pickle as pk
from pathlib import Path as pt

In [2]:
# Collect the runtime parameters interactively.
# - Every answer is stripped of surrounding whitespace so stray spaces do not
#   break the string comparisons in later cells.
# - The uuid and extract type are lower-cased (the model-file regex in the
#   extraction cell only matches lower-case hex).
# - A user-supplied path keeps its original case because file systems are often
#   case-sensitive; only the keyword "same" is normalised to lower case.
runtime_path_raw = input("Enter the runtime path ('same', '<path>'): ").strip()
runtime_path_inp = "same" if runtime_path_raw.lower() == "same" else runtime_path_raw
runtime_uuid_inp = input("Enter the model configuration ('<uuid>'/'all'): ").strip().lower()
runtime_type_inp = input("Enter the extract type ('from_pkl_obj', 'from_pkl_wgt'): ").strip().lower()

In [3]:
########################
# Runtime variables
########################

# Root directory for this run: the literal answer "same" means the current
# working directory; any other answer is used as the path itself.
runtime_path = "." if runtime_path_inp == "same" else runtime_path_inp

# Selected model id (or "all") and extract type, passed through unchanged.
runtime_uuid, runtime_type = runtime_uuid_inp, runtime_type_inp

# Standard sub-directories beneath the runtime root.
config_path = f"{runtime_path}/config"
inputs_path = f"{runtime_path}/inputs"
outputs_path = f"{runtime_path}/outputs"
models_path = f"{runtime_path}/models"

In [4]:
########################
# Extract model components
########################

# For each selected model id, read `{models_path}/model_<uuid>.pkl` and write its
# weights as JSON to `{models_path}/model_<uuid>.weights`.
#
# SECURITY NOTE(review): `pickle.load` can execute arbitrary code while loading;
# only run this cell on pickle files you created or otherwise trust.

# Matches "model_<36-char lower-case hex/dash id>.pkl" and captures the id.
MODEL_FILE_PATTERN = re.compile(r'model_([a-f0-9\-]{36})\.pkl')
SUPPORTED_RUNTIME_TYPES = ("from_pkl_obj", "from_pkl_wgt")


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    with open(path, "rb") as f:
        return pk.load(f)


def _to_jsonable(weights):
    """Return a copy of ``weights`` with array-like values converted to lists.

    Values exposing ``tolist()`` (e.g. numpy arrays) are converted so that
    ``json.dump`` can serialize them; ``None`` and other values pass through.
    """
    return {
        key: value.tolist() if hasattr(value, "tolist") else value
        for key, value in weights.items()
    }


def _write_weights_json(path, weights):
    """Write ``weights`` to ``path`` as indented JSON (arrays become lists)."""
    with open(path, "w") as f:
        json.dump(_to_jsonable(weights), f, indent=4)


def _extract_weights_from_model(model):
    """Collect the weight arrays of a loaded model object into a flat dict.

    Reads the attribute layout this notebook expects: ``wte``/``wpe``
    embeddings, ``ln_f`` final norm, ``lm_head``, and for every entry of
    ``model.blocks`` its ``ln_1``/``ln_2`` norms, ``mha`` projections and
    ``mlp`` layers. Optional biases are stored as-is (possibly ``None``).
    """
    weights = {
        # Token and position embeddings
        'wte_weight': model.wte.weight,
        'wpe_weight': model.wpe.weight,

        # Final layer norm
        'ln_f_gamma': model.ln_f.gamma,
        'ln_f_beta': model.ln_f.beta,

        # Language model head (bias may be None)
        'lm_head_weight': model.lm_head.weight,
        'lm_head_bias': model.lm_head.bias,
    }

    for i, block in enumerate(model.blocks):
        # Layer norms
        weights[f'block_{i}_ln1_gamma'] = block.ln_1.gamma
        weights[f'block_{i}_ln1_beta'] = block.ln_1.beta
        weights[f'block_{i}_ln2_gamma'] = block.ln_2.gamma
        weights[f'block_{i}_ln2_beta'] = block.ln_2.beta

        # Multi-head attention
        weights[f'block_{i}_mha_q_weight'] = block.mha.q_proj.weight
        weights[f'block_{i}_mha_k_weight'] = block.mha.k_proj.weight
        weights[f'block_{i}_mha_v_weight'] = block.mha.v_proj.weight
        weights[f'block_{i}_mha_c_weight'] = block.mha.c_proj.weight
        weights[f'block_{i}_mha_c_bias'] = block.mha.c_proj.bias

        # MLP
        weights[f'block_{i}_mlp_fc_weight'] = block.mlp.c_fc.weight
        weights[f'block_{i}_mlp_fc_bias'] = block.mlp.c_fc.bias
        weights[f'block_{i}_mlp_proj_weight'] = block.mlp.c_proj.weight
        weights[f'block_{i}_mlp_proj_bias'] = block.mlp.c_proj.bias

    return weights


# Fail fast on a typo instead of silently producing no output.
if runtime_type not in SUPPORTED_RUNTIME_TYPES:
    raise ValueError(
        f"Unsupported extract type {runtime_type!r}; "
        f"expected one of {SUPPORTED_RUNTIME_TYPES}"
    )

# Resolve which model ids to process: every matching file for "all", otherwise
# the single id that was entered. Sorted so the processing order is stable.
if runtime_uuid == "all":
    runtime_uuids = sorted(
        match.group(1)
        for match in (
            MODEL_FILE_PATTERN.fullmatch(file.name)
            for file in pt(models_path).glob('model_*.pkl')
        )
        if match
    )
else:
    runtime_uuids = [runtime_uuid]

# A distinct loop variable keeps `runtime_uuid` at its configured value, so this
# cell gives the same result when it is re-run.
written_weight_files = []
for model_uuid in runtime_uuids:
    models_path_inp = f"{models_path}/model_{model_uuid}.pkl"
    models_path_out = f"{models_path}/model_{model_uuid}.weights"

    if runtime_type == "from_pkl_obj":
        # The pickle holds the whole model object; pull its weights out.
        model_ = _load_pickle(models_path_inp)
        weights_dict = _extract_weights_from_model(model_)
    else:  # "from_pkl_wgt"
        # The pickle already holds a {name: array} weights dict.
        # NOTE(review): this reads the same `model_<uuid>.pkl` name as the
        # object path; confirm weight-dict pickles really use that name.
        weights_dict = _load_pickle(models_path_inp)

    _write_weights_json(models_path_out, weights_dict)
    written_weight_files.append(models_path_out)

written_weight_files
