In [1]:
import torch
import safetensors.torch as st
import os
import json

In [2]:


# Convert a PyTorch training checkpoint (out/ckpt.pt) into a SafeTensors file
# (out/model_state_dict.safetensors), carrying the training metadata along in
# the SafeTensors header.

# Define the output directory and checkpoint file name
out_dir = 'out'
checkpoint_file = 'ckpt.pt'

# Load the checkpoint, mapping every tensor onto the CPU.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only run this cell on checkpoints from a trusted source.
checkpoint_path = os.path.join(out_dir, checkpoint_file)
checkpoint_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Extract the model's state dictionary and additional metadata
model_state_dict = checkpoint_dict['model']
model_args = checkpoint_dict['model_args']
config = checkpoint_dict['config']
iter_num = str(checkpoint_dict['iter_num'])
best_val_loss = str(checkpoint_dict['best_val_loss'])

# The metadata handed to SafeTensors is a flat mapping of string values, so the
# structured entries are JSON-encoded. `default=str` keeps the conversion from
# failing on values json cannot encode natively; values that are already
# JSON-serializable are encoded exactly as before.
metadata = {
    'model_args': json.dumps(model_args, default=str),
    'config': json.dumps(config, default=str),
    'iter_num': iter_num,
    'best_val_loss': best_val_loss
}

# SafeTensors refuses to serialize tensors that share memory or that are not
# contiguous. The names listed here are the tensors known to share memory and
# are always cloned. In addition, any tensor whose data starts at an address
# that was already seen is cloned, so checkpoints that store the shared weights
# under different key names still convert.
shared_tensors = ['_orig_mod.output.weight', '_orig_mod.tok_embeddings.weight']
seen_data_ptrs = set()
for tensor_name, tensor in list(model_state_dict.items()):
    if not isinstance(tensor, torch.Tensor):
        # Leave unexpected entries alone; st.save reports them itself.
        continue
    # No-op for tensors that are already contiguous.
    tensor = tensor.contiguous()
    if tensor_name in shared_tensors or tensor.data_ptr() in seen_data_ptrs:
        tensor = tensor.clone()
    seen_data_ptrs.add(tensor.data_ptr())
    model_state_dict[tensor_name] = tensor

# Serialize the model's state dictionary along with metadata to SafeTensors format
serialized_data = st.save(model_state_dict, metadata=metadata)

# Write the serialized data to a file
output_file_path = os.path.join(out_dir, 'model_state_dict.safetensors')
with open(output_file_path, 'wb') as f:
    f.write(serialized_data)

print(f"Model state dictionary with metadata saved to SafeTensors format at: {output_file_path}")


Model state dictionary with metadata saved to SafeTensors format at: out/model_state_dict.safetensors
