In [1]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")

from spd.data import DatasetConfig, create_data_loader
from spd.experiments.lm.configs import LMTaskConfig
from spd.models.component_model import ComponentModel, SPDRunInfo
from spd.utils.distributed_utils import get_device
from spd.utils.general_utils import replace_pydantic_model, set_seed
import torch
import numpy as np
import pandas as pd

# Configuration: edit these to analyse a different SPD run or seed.
SEED = 0
WANDB_RUN_PATH = "wandb:goodfire/spd/runs/9d313yrl"

set_seed(SEED)
device = get_device()

# Load the SPD run: its config, and the ComponentModel built from it
# (which exposes the original network as `model.target_model`).
run_info = SPDRunInfo.from_path(WANDB_RUN_PATH)
config = run_info.config
model = ComponentModel.from_run_info(run_info)
model.to(device)
# Inference only: freeze the target model's parameters and switch to eval mode.
model.target_model.requires_grad_(False)
model.eval()

print("Setup complete")

  from .autonotebook import tqdm as notebook_tqdm


Setup complete


In [2]:
def _build_data_loader(n_ctx, batch_size, train_data_split, seed):
    """Build a single-process data loader over the run's dataset at sequence length `n_ctx`.

    Starts from the notebook-level `config.task_config` and overrides only the
    maximum sequence length and the data split.
    """
    task_config = replace_pydantic_model(
        config.task_config,
        {"max_seq_len": n_ctx, "train_data_split": train_data_split},
    )
    data_config = DatasetConfig(
        name=task_config.dataset_name,
        hf_tokenizer_path=config.tokenizer_name,
        split=task_config.train_data_split,
        n_ctx=task_config.max_seq_len,
        is_tokenized=task_config.is_tokenized,
        streaming=task_config.streaming,
        column_name=task_config.column_name,
        shuffle_each_epoch=task_config.shuffle_each_epoch,
        seed=seed,
    )
    data_loader, _tokenizer = create_data_loader(
        dataset_config=data_config,
        batch_size=batch_size,
        buffer_size=task_config.buffer_size,
        global_seed=seed,
        ddp_rank=0,
        ddp_world_size=1,
    )
    return data_loader


def _next_token_losses(output, input_ids, loss_fct):
    """Return per-token next-token cross-entropy of shape (batch, seq_len - 1).

    `output` is either a logits tensor or an object carrying a `.logits` attribute.
    """
    logits = output.logits if hasattr(output, "logits") else output
    # The prediction made at position t is scored against the token at t + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    per_token_loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
    return per_token_loss.view(shift_logits.shape[0], shift_logits.shape[1])


def evaluate_loss_with_position_tracking(
    n_ctx, n_batches=10, batch_size=32, train_data_split="train[:5000]", seed=0
):
    """Evaluate the target model's next-token loss at sequence length `n_ctx`.

    Uses the notebook-level `config`, `model` and `device` from the setup cell.

    Args:
        n_ctx: input sequence length; must be >= 2 so there is at least one target.
        n_batches: maximum number of batches to evaluate.
        batch_size: sequences per batch.
        train_data_split: dataset split string passed through to the data config.
        seed: seed used for `set_seed`, the dataset config and the data loader.

    Returns:
        dict with
          - 'mean_loss', 'std_loss', 'min_loss', 'max_loss': statistics over the
            per-batch mean losses;
          - 'losses_per_position': numpy array of length n_ctx - 1, where entry t is
            the mean loss of predicting token t + 1;
          - 'num_predictions': n_ctx - 1;
          - 'input_seq_len': n_ctx;
          - 'num_batches': number of batches actually evaluated (can be < n_batches
            if the loader runs out).

    Raises:
        ValueError: if n_ctx < 2, if a batch's sequence length differs from n_ctx,
            or if the data loader yields no batches.
    """
    if n_ctx < 2:
        raise ValueError(f"n_ctx must be >= 2 to have a next-token target, got {n_ctx}")

    set_seed(seed)
    data_loader = _build_data_loader(n_ctx, batch_size, train_data_split, seed)

    num_predictions = n_ctx - 1
    position_losses = torch.zeros(num_predictions, device=device)
    position_counts = torch.zeros(num_predictions, device=device)
    batch_losses = []
    # Stateless, so one instance serves every batch.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            if i >= n_batches:
                break

            input_ids = batch["input_ids"].to(device)
            # Guard before the forward pass: a mismatched length would either crash on
            # the accumulation below or (seq_len == 2) silently broadcast into every position.
            if input_ids.shape[1] != n_ctx:
                raise ValueError(
                    f"expected batches of sequence length {n_ctx}, got {input_ids.shape[1]}"
                )

            output = model.target_model(input_ids)
            per_token_loss = _next_token_losses(output, input_ids, loss_fct)

            position_losses += per_token_loss.sum(dim=0)
            position_counts += per_token_loss.shape[0]
            batch_losses.append(per_token_loss.mean().item())

    if not batch_losses:
        # Without this, np.min below fails with an opaque "zero-size array" error.
        raise ValueError(
            f"data loader yielded no batches for n_ctx={n_ctx}, split={train_data_split!r}"
        )

    avg_position_losses = (position_losses / position_counts.clamp(min=1)).cpu().numpy()

    return {
        'mean_loss': np.mean(batch_losses),
        'std_loss': np.std(batch_losses),
        'min_loss': np.min(batch_losses),
        'max_loss': np.max(batch_losses),
        'losses_per_position': avg_position_losses,
        'num_predictions': num_predictions,
        'input_seq_len': n_ctx,
        'num_batches': len(batch_losses),
    }

print("Function defined")

Function defined


In [3]:
# Evaluate the target model at every context length of interest. The sizes
# clustered around 512 probe behaviour just below/above that length
# (presumably the model's training context — TODO confirm against config).
# The evaluation function reseeds on each call, so every entry is reproducible on its own.
context_sizes = [8, 32, 128, 505, 511, 512, 513, 515, 550, 800]
results = {
    seq_len: evaluate_loss_with_position_tracking(seq_len, n_batches=10, batch_size=32)
    for seq_len in context_sizes
}

In [4]:
# Create the comprehensive summary table: one row per evaluated context size.
# Sequence-length bookkeeping is read from each result dict rather than being
# recomputed here, so the table always agrees with what the evaluation reported.
summary_data = [
    {
        'input_seq_len': res['input_seq_len'],
        'predicts_positions': f"1-{res['num_predictions']}",
        'num_predictions': res['num_predictions'],
        'mean_loss': res['mean_loss'],
        'std_loss': res['std_loss'],
        'min_loss': res['min_loss'],
        'max_loss': res['max_loss'],
    }
    for res in results.values()
]

df = pd.DataFrame(summary_data).sort_values('input_seq_len')

print("=" * 90)
print("COMPREHENSIVE SUMMARY: Target Model Loss vs Context Size (With Position Tracking)")
print("=" * 90)
print(df.to_string(index=False))
print("=" * 90)

COMPREHENSIVE SUMMARY: Target Model Loss vs Context Size (With Position Tracking)
 input_seq_len predicts_positions  num_predictions  mean_loss  std_loss  min_loss  max_loss
             8                1-7                7   3.432429  0.170315  3.158330  3.762946
            32               1-31               31   2.735031  0.084271  2.597227  2.897082
           128              1-127              127   2.451201  0.043478  2.362967  2.514810
           505              1-504              504   2.326909  0.026657  2.290959  2.374750
           511              1-510              510   2.336479  0.046317  2.275845  2.423364
           512              1-511              511   2.325966  0.042399  2.248553  2.391512
           513              1-512              512   2.342385  0.030618  2.265809  2.392439
           515              1-514              514   2.338477  0.019998  2.309671  2.381780
           550              1-549              549   2.345897  0.050434  2.281770  2.44571