# Checkpoint Inference Demo

This notebook demonstrates how to load a trained checkpoint and run inference on examples from the fixed dataset to see how the model performs on individual cases.


In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F
from scipy.stats import entropy
import json
import os
from pathlib import Path

# Import our modules
from config import ExperimentConfig
from model_training import ModelManager
from dataset_setup import DatasetManager
from evaluation import Evaluator

# Choose the compute device: the GPU when torch can see one, otherwise the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Notebook-wide matplotlib defaults: figure size (inches) and base font size
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
})


Using device: cuda


In [2]:
# Experiment settings; a later cell passes this object to DatasetManager
config = ExperimentConfig()

# Checkpoint to inspect (relative path, resolved against the working directory)
checkpoint_path = "results/checkpoints/model_depth_12_heads_6_20251023_223242.pt"
print(f"Loading checkpoint: {checkpoint_path}")

# load_checkpoint hands back the model plus a metadata mapping describing it
model, checkpoint_info = ModelManager.load_checkpoint(checkpoint_path, device)

print("Model loaded successfully!")
print("Model configuration:")

# Architecture section of the checkpoint metadata
model_cfg = checkpoint_info['model_config']
print(f"  Depth: {model_cfg['depth']}")
print(f"  Heads: {model_cfg['heads']}")
print(f"  Embedding dim: {model_cfg['n_embd']}")

# Training section of the checkpoint metadata
train_info = checkpoint_info['training_info']
print(f"  Training iterations: {train_info['iterations']}")
print(f"  Final loss: {train_info['final_loss']:.5f}")

# Counts every parameter tensor; the loader prints its own figure while loading
# and the two can differ slightly (presumably it excludes some embeddings — TODO confirm)
n_params = sum(p.numel() for p in model.parameters())
print(f"  Number of parameters: {n_params/1e6:.2f}M")


Loading checkpoint: results/checkpoints/model_depth_12_heads_6_20251023_223242.pt
number of parameters: 5.38M
Model loaded successfully!
Model configuration:
  Depth: 12
  Heads: 6
  Embedding dim: 192
  Training iterations: 30000
  Final loss: 1.57309
  Number of parameters: 5.40M


In [None]:
# Build the dataset manager from the experiment configuration
dataset_manager = DatasetManager(config)

# The mixed dataset is created for its dictionary, which a later cell prints
mixed_dataset = dataset_manager.create_mixed_dataset()

# Loader over the fixed dataset (uses dictionary from config)
fixed_loader = dataset_manager.create_fixed_dataset_loader()

# Take the first batch and unpack it into inputs (a) and targets (b)
example = next(iter(fixed_loader))
a, b = example

print(f"Example shape: input={a.shape}, target={b.shape}")
first_input = a[0].tolist()
print(f"Input sequence: {first_input}")
first_target = b[0].tolist()
print(f"Target sequence: {first_target}")

# Read the first input as alternating key, value tokens and list each pair.
# zip stops at the shorter slice, so a trailing unpaired key is skipped.
sequence = first_input
print(f"\nKey-Value pairs:")
for pair_idx, (key, value) in enumerate(zip(sequence[0::2], sequence[1::2])):
    print(f"  Position {pair_idx}: Key={key}, Value={value}")


In [None]:
# Run inference on the example loaded above.
# Defines the notebook-level names `probs` and `predictions` that the
# following analysis cells read.
model.eval()  # evaluation mode for inference
with torch.no_grad():
    # The model returns a pair; only the logits (first element) are used here
    logits, _ = model(a.to(device))
    probs = F.softmax(logits, dim=-1)

    # Most probable token at each position
    predictions = torch.argmax(probs, dim=-1)

# `predictions` lives on `device`, whereas `b` is used exactly as the loader
# produced it. Comparing tensors on different devices raises a RuntimeError,
# so bring the targets over first (a no-op when they already match).
targets = b.to(predictions.device)

print(f"Model predictions:")
print(f"Input:  {a[0].tolist()}")
print(f"Target: {b[0].tolist()}")
print(f"Pred:   {predictions[0].tolist()}")

# 1.0 where the prediction equals the target, 0.0 elsewhere (first sequence only)
correct = (predictions[0] == targets[0]).float()
accuracy = correct.mean().item()

print(f"\nPosition-wise accuracy: {correct.tolist()}")
print(f"Overall accuracy: {accuracy:.3f}")


In [None]:
# Per-position report for the first sequence of the batch: target token,
# predicted token, hit/miss mark, entropy of the predictive distribution and
# the most probable tokens with their probabilities.
print(f"\nDetailed Analysis:")
print(f"{'Position':<10} {'True':<8} {'Pred':<8} {'Correct':<8} {'Entropy':<10} {'Top-3 Probs'}")
print("-" * 80)

# Copy the probabilities to the CPU once rather than once per position
seq_probs = probs[0].cpu()
# torch.topk raises when k exceeds the size of the dimension, so cap it
top_k = min(3, seq_probs.shape[-1])
num_positions = min(len(a[0]), len(b[0]), len(predictions[0]))

for i in range(num_positions):
    true_val = b[0][i].item()
    pred_val = predictions[0][i].item()
    # Deliberately not named `correct`: that name holds the per-position
    # accuracy tensor computed in the inference cell.
    mark = "✓" if true_val == pred_val else "✗"

    # Shannon entropy of this position's distribution (scipy default: natural log)
    pos_probs = seq_probs[i].numpy()
    pos_entropy = entropy(pos_probs)

    # Most probable tokens, formatted as token:probability
    top_probs, top_indices = torch.topk(seq_probs[i], top_k)
    top3_str = ", ".join(
        f"{idx.item()}:{prob.item():.3f}" for idx, prob in zip(top_indices, top_probs)
    )

    print(f"{i:<10} {true_val:<8} {pred_val:<8} {mark:<8} {pos_entropy:<10.3f} {top3_str}")


In [None]:
# Bar charts of the most probable tokens at the first six value positions.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

# Show probability distributions for the first 6 value positions (odd indices: 1, 3, 5, 7, 9, 11)
value_positions = [1, 3, 5, 7, 9, 11]

# One device-to-host copy for the whole sequence
seq_probs = probs[0].cpu()
# torch.topk raises when k exceeds the size of the dimension, so cap it
top_k = min(10, seq_probs.shape[-1])

for ax, pos in zip(axes, value_positions):
    if pos >= len(seq_probs):
        # Sequence too short for this panel: hide it instead of leaving blank axes
        ax.set_visible(False)
        continue

    # Get probabilities for this position
    pos_probs = seq_probs[pos].numpy()

    # Plot the top-k probabilities, highest first
    top_probs, top_indices = torch.topk(seq_probs[pos], top_k)
    top_tokens = top_indices.tolist()
    bars = ax.bar(range(len(top_tokens)), top_probs.numpy())

    # Highlight the correct answer in red when it is among the plotted tokens
    true_val = b[0][pos].item()
    if true_val in top_tokens:
        correct_bar = bars[top_tokens.index(true_val)]
        correct_bar.set_color('red')
        correct_bar.set_alpha(0.8)

    ax.set_title(f'Position {pos} (Value)\nTrue: {true_val}, Pred: {predictions[0][pos].item()}')
    ax.set_xlabel('Token Index')
    ax.set_ylabel('Probability')
    ax.set_xticks(range(len(top_tokens)))
    ax.set_xticklabels([str(token) for token in top_tokens], rotation=45)

    # Add entropy as text in the top-left corner of the panel
    pos_entropy = entropy(pos_probs)
    ax.text(0.02, 0.98, f'Entropy: {pos_entropy:.3f}',
            transform=ax.transAxes, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

fig.tight_layout()
# y > 1 places the title above the panels after they have been laid out
fig.suptitle('Probability Distributions for Value Positions', y=1.02, fontsize=16)
plt.show()


In [None]:
# Compare the model's entropy at each value position with a hand-drawn
# reference curve for a model that has memorised the fixed mapping.
print(f"\nTesting Fixed Pattern Learning:")
print(f"Dictionary: {mixed_dataset.dictionary.tolist()}")

# Show how the model should behave for the fixed pattern
print(f"\nExpected behavior for fixed dataset:")
print(f"- Early positions should have high confidence (low entropy)")
print(f"- Later positions should have very high confidence (very low entropy)")
print(f"- The model should learn the fixed key-value mapping")

# Calculate entropies for value positions only (odd indices of the first sequence)
seq_probs = probs[0].cpu()  # single device-to-host copy
value_positions = list(range(1, len(seq_probs), 2))
value_entropies = [entropy(seq_probs[pos].numpy()) for pos in value_positions]

print(f"\nEntropies for value positions: {[f'{e:.3f}' for e in value_entropies]}")

# Plot entropy vs position
fig, ax = plt.subplots(figsize=(12, 6))
# Labelled so the legend identifies the measured curve as well as the reference
ax.plot(range(len(value_entropies)), value_entropies, 'o-', linewidth=2, markersize=6,
        label='Model Entropy')
ax.set_xlabel('Value Position Index')
ax.set_ylabel('Entropy')
ax.set_title('Model Entropy vs Position (Fixed Dataset)')
ax.grid(True, alpha=0.3)
ax.set_xticks(range(0, len(value_entropies), 5))

# Add expected pattern (should decrease). These are hand-picked illustrative
# values, not derived from data: a rapidly decaying head padded with zeros and
# trimmed to the number of value positions.
reference_head = [0.7, 0.02, 0.001, 0.0001, 0.00001, 0.000001]
zero_tail = [0.0] * max(0, len(value_entropies) - len(reference_head))
expected_entropies = (reference_head + zero_tail)[:len(value_entropies)]
ax.plot(range(len(expected_entropies)), expected_entropies, 'r--', linewidth=2, label='Expected Pattern')
ax.legend()
plt.show()

print(f"\nAnalysis:")
print(f"- If the model learned the fixed pattern, entropies should decrease rapidly")
print(f"- High entropies (~3.6-3.8) suggest the model is not learning the pattern")
print(f"- The model should show high confidence for later positions in fixed dataset")


In [None]:
# Repeat the single-example analysis over several batches from the fixed
# loader and aggregate accuracy and value-position entropy.
# Loop-local names (ex_*) are used on purpose so the single example held in
# `a`, `b`, `probs` and `predictions` is not overwritten; the earlier analysis
# cells can then be re-run after this one and still inspect the same example.
print(f"\nTesting Multiple Examples:")

num_examples = 5  # how many batches to draw from the loader

# zip stops as soon as range is exhausted, so at most `num_examples` batches are drawn
examples = [batch for _, batch in zip(range(num_examples), fixed_loader)]
assert examples, "fixed_loader yielded no batches"

accuracies = []
entropies_by_position = []

model.eval()  # keep this cell correct even if it is run on its own
with torch.no_grad():
    for ex_idx, (ex_inputs, ex_targets) in enumerate(examples):
        ex_logits, _ = model(ex_inputs.to(device))
        ex_probs = F.softmax(ex_logits, dim=-1)
        ex_preds = torch.argmax(ex_probs, dim=-1)

        # Calculate accuracy on the first sequence of the batch. The targets are
        # moved to the predictions' device because comparing tensors that live
        # on different devices raises a RuntimeError.
        hits = (ex_preds[0] == ex_targets[0].to(ex_preds.device)).float()
        ex_accuracy = hits.mean().item()
        accuracies.append(ex_accuracy)

        # Calculate entropies for value positions (odd indices)
        seq_probs = ex_probs[0].cpu()
        ex_entropies = [entropy(seq_probs[pos].numpy()) for pos in range(1, len(seq_probs), 2)]
        entropies_by_position.append(ex_entropies)

        print(f"Example {ex_idx+1}: Accuracy = {ex_accuracy:.3f}, Mean Entropy = {np.mean(ex_entropies):.3f}")

print(f"\nSummary across {len(examples)} examples:")
print(f"Mean accuracy: {np.mean(accuracies):.3f} ± {np.std(accuracies):.3f}")
print(f"Mean entropy: {np.mean([np.mean(ents) for ents in entropies_by_position]):.3f}")

# Plot average entropy by position across examples
# (the axis=0 mean requires every example to have the same number of value positions)
avg_entropies = np.mean(entropies_by_position, axis=0)
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(range(len(avg_entropies)), avg_entropies, 'o-', linewidth=2, markersize=6)
ax.set_xlabel('Value Position Index')
ax.set_ylabel('Average Entropy')
# Report the number of examples actually evaluated, not a hardcoded 5
ax.set_title(f'Average Entropy vs Position ({len(examples)} Examples)')
ax.grid(True, alpha=0.3)
plt.show()
