# llama model checkpointing project - phase 1 & 2

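This notebook implements phase 1 (checkpointing with PyTorch's `torch.save` / `torch.load`) and phase 2 (checkpointing with TensorStore using the zarr format), then compares the two approaches on save time, load time, and on-disk size.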

In [None]:
# import required libraries
import torch
import time
import os
from transformers import LlamaForCausalLM, LlamaTokenizer
import gc
import tensorstore as ts
import numpy as np
import matplotlib.pyplot as plt
import asyncio
import json

In [None]:
# setup device and check cuda availability
# `device` is reused by later cells (input placement, summary printout)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"using device: {device}")

if torch.cuda.is_available():
    print(f"cuda device: {torch.cuda.get_device_name(0)}")
    print(f"cuda memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} gb")
    # mem_get_info returns (free_bytes, total_bytes) for the device;
    # memory_reserved() would report what pytorch's caching allocator holds,
    # which is not the amount of free memory
    print(f"cuda memory free: {torch.cuda.mem_get_info(0)[0] / 1e9:.2f} gb")

In [None]:
# create saved_models directory if it doesn't exist
# exist_ok=True makes this a no-op when the directory is already there,
# so the cell can be re-run without raising
os.makedirs('saved_models', exist_ok=True)
# note: this message is printed even when the directory already existed
print("created saved_models directory")

In [None]:
# load the pretrained openllama-3b weights and the matching tokenizer
model_name = "openlm-research/open_llama_3b"
print(f"loading model: {model_name}")

# tokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_name)
print("tokenizer loaded successfully")

# model: collect the loading options first, then hand them over in one call
load_options = {
    "dtype": torch.float16,  # use half precision for memory efficiency
    # let the library place the weights only when a gpu is present
    "device_map": "auto" if torch.cuda.is_available() else None,
    "low_cpu_mem_usage": True,
}
model = LlamaForCausalLM.from_pretrained(model_name, **load_options)

print("model loaded successfully")

# report the parameter count in millions
param_total = sum(weight.numel() for weight in model.parameters())
print(f"model parameters: {param_total / 1e6:.1f}m")

if torch.cuda.is_available():
    allocated_bytes = torch.cuda.memory_allocated(0)
    print(f"cuda memory allocated: {allocated_bytes / 1e9:.2f} gb")

In [None]:
# smoke test: generate a short continuation to confirm the loaded model works
test_prompt = "the future of artificial intelligence is"
inputs = tokenizer(test_prompt, return_tensors="pt")

# when a gpu is in use, move every tokenizer output tensor onto `device`
if torch.cuda.is_available():
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

print(f"testing model with prompt: '{test_prompt}'")

# sampling settings for the smoke test
generation_settings = {
    "max_length": 50,
    "do_sample": True,
    "temperature": 0.7,
    # reuse the eos token id as the padding id for generation
    "pad_token_id": tokenizer.eos_token_id,
}

# no gradients are needed for generation
with torch.no_grad():
    outputs = model.generate(**inputs, **generation_settings)

# only the first returned sequence is decoded
first_sequence = outputs[0]
generated_text = tokenizer.decode(first_sequence, skip_special_tokens=True)
print(f"generated text: {generated_text}")
print("model inference test successful")

In [None]:
# phase 1: write the checkpoint with torch.save and time the call
pytorch_save_path = "saved_models/openllama_3b_pytorch.pth"

print("=== phase 1: pytorch saving ===")

# only the state dict is written (not the model object) so the file stays
# loadable with weights_only=True
start_time = time.time()
torch.save(model.state_dict(), pytorch_save_path)
end_time = time.time()

# elapsed wall-clock seconds; later cells read pytorch_save_time and pytorch_file_size
pytorch_save_time = end_time - start_time

# on-disk size in units of 1024**3 bytes
size_in_bytes = os.path.getsize(pytorch_save_path)
pytorch_file_size = size_in_bytes / (1024**3)

print(f"pytorch save completed in {pytorch_save_time*1000:.1f} ms")
print(f"file size: {pytorch_file_size:.2f} gb")
print(f"saved to: {pytorch_save_path}")

In [None]:
# phase 1: test pytorch loading
print("\n=== phase 1: pytorch loading ===")
start_time = time.time()

# load the saved state dict to cpu to verify integrity
# weights_only=True: the checkpoint was written as a plain state dict for this
# purpose, so ask torch.load for the restricted weights-only loading mode
# instead of full pickle loading
state_dict = torch.load(pytorch_save_path, map_location='cpu', weights_only=True)

# elapsed wall-clock seconds; read again by the comparison and summary cells
pytorch_load_time = time.time() - start_time

print(f"pytorch load completed in {pytorch_load_time*1000:.1f} ms")
# the count is the number of entries in the loaded state dict
print(f"loaded {len(state_dict)} parameters successfully")

# cleanup: drop the cpu copy of the weights before the next phase
del state_dict
gc.collect()

In [None]:
# phase 2: prepare the tensorstore save (output directory, timer, tensors to write)
tensorstore_save_dir = "saved_models/openllama_3b_tensorstore/"
os.makedirs(tensorstore_save_dir, exist_ok=True)

print("\n=== phase 2: tensorstore saving ===")
# the timer starts here; it is stopped after the save function has run
start_time = time.time()

# collect the model's named parameters, leaving out any that live on the
# 'meta' device
model_state = {
    name: param
    for name, param in model.named_parameters()
    if param.device.type != 'meta'
}
param_count = len(model_state)

print(f"processing {param_count} non-meta parameters...")

# save each parameter tensor using tensorstore with zarr format
def save_tensorstore_simple(chunk_size=64):
    """write every tensor in the global ``model_state`` to its own zarr array.

    each parameter is stored under ``tensorstore_save_dir`` as
    ``<safe_name>.zarr``, where ``safe_name`` is the parameter name with
    '.' and '/' replaced by '_'. tensors keep their own dtype on disk when
    numpy can represent it; otherwise they are upcast to float32.

    args:
        chunk_size: maximum chunk edge length along every dimension
            (default 64, the value that used to be hard-coded).

    returns:
        the number of parameters written successfully. a parameter whose
        write fails is reported on stdout and skipped.

    raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    saved_count = 0
    for param_name, param_tensor in model_state.items():
        try:
            cpu_tensor = param_tensor.detach().cpu()
            try:
                # keep the tensor's dtype (e.g. float16) so the array on disk
                # is not inflated to float32
                param_np = cpu_tensor.numpy()
            except TypeError:
                # dtypes numpy cannot represent (e.g. bfloat16) are upcast
                param_np = cpu_tensor.float().numpy()

            # create safe filename by replacing dots and slashes
            safe_name = param_name.replace('.', '_').replace('/', '_')

            # zarr spec; the dtype string comes from the numpy array so the
            # metadata always matches the data that is written
            spec = {
                'driver': 'zarr',
                'kvstore': {
                    'driver': 'file',
                    'path': f"{tensorstore_save_dir}{safe_name}.zarr"
                },
                'metadata': {
                    'shape': list(param_np.shape),
                    'dtype': param_np.dtype.str,
                    # one chunk entry per dimension (none for a 0-d tensor, so
                    # the rank matches 'shape'); never below 1 even when a
                    # dimension has length 0
                    'chunks': [max(1, min(chunk_size, s)) for s in param_np.shape]
                }
            }

            # create and write tensor synchronously
            store = ts.open(spec, create=True, delete_existing=True).result()
            store.write(param_np).result()
            saved_count += 1

        except Exception as e:
            # best effort: report the failure and continue with the next tensor
            print(f"skipping parameter {param_name}: {e}")
            continue

    return saved_count

# write the list of parameter names next to the arrays so the loader knows
# which zarr stores to open
metadata = dict(
    param_names=list(model_state),
    total_params=len(model_state),
)
with open(f"{tensorstore_save_dir}metadata.json", 'w') as f:
    json.dump(metadata, f)

# run the save, then stop the timer that was started before the parameters
# were collected
num_params = save_tensorstore_simple()
tensorstore_save_time = time.time() - start_time

# total bytes of every file below the tensorstore directory
# (this includes metadata.json)
tensorstore_size = sum(
    os.path.getsize(os.path.join(folder, filename))
    for folder, _subdirs, filenames in os.walk(tensorstore_save_dir)
    for filename in filenames
)
# same unit as the pytorch file size: 1024**3 bytes
tensorstore_file_size = tensorstore_size / (1024**3)

print(f"tensorstore save completed in {tensorstore_save_time*1000:.1f} ms")
print(f"saved {num_params} parameters")
print(f"total size: {tensorstore_file_size:.2f} gb")
print(f"saved to: {tensorstore_save_dir}")

In [None]:
# phase 2: start the tensorstore load test
print("\n=== phase 2: tensorstore loading ===")
# reading the metadata file is part of the timed section
start_time = time.time()

# metadata.json holds the parameter names written by the save step
with open(f"{tensorstore_save_dir}metadata.json", 'r') as f:
    metadata = json.loads(f.read())

# load parameters using tensorstore
def load_tensorstore_simple():
    """read back every parameter listed in the global ``metadata``.

    for each name in ``metadata['param_names']`` the matching
    ``<safe_name>.zarr`` store under ``tensorstore_save_dir`` is opened, read
    in full and converted to a half-precision torch tensor.

    returns:
        a ``(loaded_state, loaded_count)`` tuple: a dict mapping parameter
        name to tensor, and the number of parameters that were loaded.
        parameters whose store is missing or cannot be read are reported on
        stdout and left out of the result.
    """
    loaded_state = {}
    loaded_count = 0

    for param_name in metadata['param_names']:
        try:
            # same name mangling as the save step
            safe_name = param_name.replace('.', '_').replace('/', '_')
            zarr_path = f"{tensorstore_save_dir}{safe_name}.zarr"

            if not os.path.exists(zarr_path):
                # report the gap instead of silently returning fewer tensors
                print(f"missing tensorstore array for parameter {param_name}: {zarr_path}")
                continue

            # load tensor from tensorstore
            spec = {
                'driver': 'zarr',
                'kvstore': {
                    'driver': 'file',
                    'path': zarr_path
                }
            }

            store = ts.open(spec).result()
            param_np = store.read().result()
            # convert back to torch tensor and half precision
            loaded_state[param_name] = torch.from_numpy(param_np.copy()).half()
            loaded_count += 1

        except Exception as e:
            # best effort: report the failure and continue with the next tensor
            print(f"failed to load parameter {param_name}: {e}")
            continue

    return loaded_state, loaded_count

# execute the tensorstore load and stop the timer started before the
# metadata file was read
loaded_state_dict, loaded_count = load_tensorstore_simple()
end_time = time.time()
tensorstore_load_time = end_time - start_time

print(f"tensorstore load completed in {tensorstore_load_time*1000:.1f} ms")
print(f"loaded {loaded_count} parameters successfully")

# cleanup: release the loaded tensors and the dict of parameter references
del loaded_state_dict
del model_state
gc.collect()

In [None]:
# side-by-side comparison of both approaches: text table plus bar charts
print("\n=== performance comparison ===")

# one entry per method, in the same order as `methods`
methods = ['PyTorch', 'TensorStore']
save_times = [seconds * 1000 for seconds in (pytorch_save_time, tensorstore_save_time)]  # ms
load_times = [seconds * 1000 for seconds in (pytorch_load_time, tensorstore_load_time)]  # ms
file_sizes = [pytorch_file_size, tensorstore_file_size]  # in gb

# text table
print(f"{'Method':<12} {'Save (ms)':<10} {'Load (ms)':<10} {'Size (GB)':<10}")
print("-" * 50)
for method, save_ms, load_ms, size_gb in zip(methods, save_times, load_times, file_sizes):
    print(f"{method:<12} {save_ms:<10.1f} {load_ms:<10.1f} {size_gb:<10.2f}")

# three bar charts in one row: save time, load time, file size
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# (axis, values, title, y label, value-label format) for each panel
panels = [
    (ax1, save_times, 'save time comparison', 'time (ms)', '{:.1f}ms'),
    (ax2, load_times, 'load time comparison', 'time (ms)', '{:.1f}ms'),
    (ax3, file_sizes, 'file size comparison', 'size (gb)', '{:.2f}gb'),
]
for axis, values, title, y_label, label_format in panels:
    tallest = max(values)
    axis.bar(methods, values, color=['blue', 'orange'])
    axis.set_title(title)
    axis.set_ylabel(y_label)
    # 10% headroom above the tallest bar so the value labels fit
    axis.set_ylim(0, tallest * 1.1)
    for bar_index, value in enumerate(values):
        axis.text(bar_index, value + tallest * 0.02, label_format.format(value), ha='center')

plt.tight_layout()
plt.savefig('saved_models/performance_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nperformance chart saved to: saved_models/performance_comparison.png")

In [None]:
# final report for both phases
print("\n=== phase 1 & 2 summary ===")
print("model: openllama-3b")
print(f"device: {device}")

# per-approach numbers: (heading, save seconds, load seconds, size in gb)
approach_rows = (
    ("pytorch approach:", pytorch_save_time, pytorch_load_time, pytorch_file_size),
    ("tensorstore approach:", tensorstore_save_time, tensorstore_load_time, tensorstore_file_size),
)
for heading, save_seconds, load_seconds, size_gb in approach_rows:
    print(f"\n{heading}")
    print(f"  save time: {save_seconds*1000:.1f} ms")
    print(f"  load time: {load_seconds*1000:.1f} ms")
    print(f"  file size: {size_gb:.2f} gb")

# relative difference of tensorstore against the pytorch baseline, in percent
# (positive means tensorstore took longer / is larger)
save_diff = ((tensorstore_save_time - pytorch_save_time) / pytorch_save_time) * 100
load_diff = ((tensorstore_load_time - pytorch_load_time) / pytorch_load_time) * 100
size_diff = ((tensorstore_file_size - pytorch_file_size) / pytorch_file_size) * 100

print("\nperformance differences (tensorstore vs pytorch):")
for label, percent in (("save time", save_diff), ("load time", load_diff), ("file size", size_diff)):
    print(f"  {label}: {percent:+.1f}%")

print("\nphase 1 & 2 completed successfully!")

In [None]:
# release the model and tokenizer, then give memory back
del model
del tokenizer
gc.collect()

cuda_in_use = torch.cuda.is_available()
if cuda_in_use:
    # hand cached gpu blocks back to the driver
    torch.cuda.empty_cache()
    print("cuda memory cleared")

print("memory cleanup completed")