In [2]:
import sys
sys.path.insert(0, "..")

import json
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from src.model import SparseGPT


In [3]:
# Pull the pretrained SparseGPT checkpoint from the HuggingFace Hub.
# from_pretrained is given the repo id and the target device.
repo_id = "jacobcd52/ss_d128_f1"

# Prefer the GPU when one is visible to torch; otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = SparseGPT.from_pretrained(repo_id, device=device)

# Short summary of what was loaded, one item per line.
print(
    f"Model loaded! Parameters: {model.get_num_params():,}",
    f"Activation sparsity enabled: {model.sparsity_config.enable_activation_sparsity}",
    f"Using device: {device}",
    sep="\n",
)


Model loaded! Parameters: 1,444,752
Activation sparsity enabled: False
Using device: cuda


In [4]:
# The tokenizer name is stored in the repo's config.json under
# training_config -> tokenizer_name; read it from there and load that tokenizer.
config_path = hf_hub_download(filename="config.json", repo_id=repo_id)

with open(config_path) as config_file:
    config_dict = json.loads(config_file.read())

tokenizer_name = config_dict["training_config"]["tokenizer_name"]
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

print(f"Tokenizer: {tokenizer_name}")

Tokenizer: SimpleStories/SimpleStories-1.25M


In [5]:
# Smoke-test text generation from a fixed prompt.
prompt = "John and Mary went hiking in the woods, when they came across a"

# Encode the prompt without special tokens and place the ids on the model's device.
input_ids = tokenizer.encode(
    prompt,
    add_special_tokens=False,
    return_tensors="pt",
).to(device)

print(f"Prompt: {prompt}", "-" * 50, sep="\n")

# Generate a single continuation at one (very low) temperature.
# NOTE(review): presumably this makes sampling close to greedy — confirm
# against SparseGPT.generate.
temp = 0.01
output_ids = model.generate(input_ids, max_new_tokens=50, temperature=temp)

Prompt: John and Mary went hiking in the woods, when they came across a
--------------------------------------------------


In [6]:
# Decode the first generated sequence (prompt + continuation), keeping any
# special tokens in the text, and show it under the temperature that produced it.
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
print(f"\nTemperature {temp}:", generated_text, sep="\n")


Temperature 0.01:
john and mary went hiking in the woods, when they came across a big tree. the squirrel was very happy. " i ' m so happy! " he said. the squirrel laughed and said, " you are very good! " the squirrel thought for a moment. " i can ' t wait to eat! "


In [None]:
# Inspect the model's next-token distribution for a fixed prompt.
import torch.nn.functional as F

test_prompt = "John and Mary went hiking in the woods, when they came across a"

# Encode without special tokens, as the generation and loss-evaluation cells do,
# so the last position really is the final prompt word. Move the ids with
# .to(device) rather than .cuda(): `device` already falls back to "cpu" when no
# GPU is available, and .cuda() would raise there.
tokens = tokenizer.encode(
    test_prompt,
    return_tensors="pt",
    add_special_tokens=False,
).to(device)
print(f"Prompt: '{test_prompt}'")
print(f"Tokens: {tokens.tolist()}")

# Run forward pass
top_k = 5
with torch.no_grad():
    # `logits` holds the model's full return value; its first element is the
    # logits tensor, indexed here at batch item 0, last sequence position.
    logits = model(tokens)
    probs = F.softmax(logits[0][0, -1], dim=-1)
    top_tokens = probs.topk(top_k)

print(f"\nTop {top_k} predictions:")
for prob, tok_id in zip(top_tokens.values.tolist(), top_tokens.indices.tolist()):
    tok_str = tokenizer.decode([tok_id])
    print(f"  {tok_str!r}: {prob:.3f}")


Prompt: 'John and Mary went hiking in the woods, when they came across a'
Tokens: [[1804, 72, 60, 94, 634, 64, 1088, 3241, 289, 108, 85, 1223, 13, 335, 111, 624, 1216, 32]]

Top 5 predictions:
  'big': 0.095
  'giant': 0.055
  'small': 0.044
  'sparkling': 0.030
  'dark': 0.023


In [16]:
# Dimensions of the first element of the model output (the logits tensor).
# NOTE(review): relies on `logits` still being the full model return value from
# the forward-pass cell; the loss-evaluation cell rebinds `logits` to a tensor.
logits[0].size()

torch.Size([1, 19, 4096])

In [6]:
# Load the SimpleStories held-out split used for loss evaluation
from datasets import load_dataset
import torch.nn.functional as F

# Keep the split name in one place so the log line always names the split
# that was actually loaded.
eval_split = "test"
dataset = load_dataset("SimpleStories/SimpleStories", split=eval_split)
print(f"Loaded {len(dataset)} {eval_split} examples")

Loaded 21371 validation examples


In [8]:
# Prepare batches for loss evaluation
ctx_len = 512
batch_size = 16
num_batches = 10

# Tokenize every story (without special tokens) and concatenate the ids into
# one long token stream.
# NOTE(review): stories are joined with no separator token between them —
# confirm this matches how the training data was packed.
all_tokens = []
for example in dataset:
    all_tokens.extend(tokenizer.encode(example["story"], add_special_tokens=False))

# Cut the stream into non-overlapping sequences of exactly ctx_len tokens; a
# trailing partial chunk is dropped. Counting full chunks with floor division
# keeps the final chunk when the stream length is an exact multiple of ctx_len
# (stepping a range up to len - ctx_len would skip it).
num_sequences = len(all_tokens) // ctx_len
sequences = [
    all_tokens[k * ctx_len:(k + 1) * ctx_len]
    for k in range(num_sequences)
]

print(f"Total tokens: {len(all_tokens):,}")
print(f"Created {len(sequences)} sequences of length {ctx_len}")


Total tokens: 6,126,675
Created 11966 sequences of length 512


In [9]:
# Compute next-token cross-entropy loss on the first `num_batches` full batches
model.eval()  # put the model in evaluation mode before measuring loss
losses = []

with torch.no_grad():
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size

        # Stop as soon as a full batch can no longer be formed.
        if end_idx > len(sequences):
            break

        # Create batch tensor of token ids, one row per sequence
        batch_sequences = sequences[start_idx:end_idx]
        input_ids = torch.tensor(batch_sequences, dtype=torch.long, device=device)

        # Forward pass: the model returns a 3-tuple; only the logits are used here.
        logits, _, _ = model(input_ids)

        # Position t predicts token t + 1, so drop the last position's logits
        # and the first token's label before comparing.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()

        # Flatten to (positions, vocab) vs (positions,) and average over all positions.
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        losses.append(loss.item())
        print(f"Batch {batch_idx + 1}/{num_batches}: Loss = {loss.item():.4f}")

# Without at least one full batch there is nothing to average; fail with a
# clear message instead of a ZeroDivisionError below.
if not losses:
    raise ValueError(
        f"No full batch could be formed: {len(sequences)} sequences available, "
        f"batch_size={batch_size}, num_batches={num_batches}"
    )

# Summary statistics: mean of the per-batch losses (all batches are the same
# size) and the corresponding perplexity exp(mean loss).
avg_loss = sum(losses) / len(losses)
print(f"\n{'='*50}")
print(f"Average Loss over {len(losses)} batches: {avg_loss:.4f}")
print(f"Perplexity: {torch.exp(torch.tensor(avg_loss)).item():.2f}")


Batch 1/10: Loss = 2.0954
Batch 2/10: Loss = 2.1700
Batch 3/10: Loss = 2.1475
Batch 4/10: Loss = 2.1774
Batch 5/10: Loss = 2.1813
Batch 6/10: Loss = 2.0844
Batch 7/10: Loss = 2.1094
Batch 8/10: Loss = 2.1027
Batch 9/10: Loss = 2.1449
Batch 10/10: Loss = 2.0504

Average Loss over 10 batches: 2.1263
Perplexity: 8.38
