In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, "..")
sys.path.insert(0, "../..")

import torch
from transformers import AutoTokenizer
from sparse_pretrain.src.pruning.run_pruning import load_model

# Load model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "jacobcd52/ss_bridges_d1024_f0.015625"
tokenizer_name = "SimpleStories/SimpleStories-1.25M"

print(f"Loading model: {model_path}")
model, _ = load_model(model_path, device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    
print(f"Model loaded on {device}")


Loading model: jacobcd52/ss_bridges_d1024_f0.015625
Loading sparse model from HuggingFace Hub bridge checkpoint: jacobcd52/ss_bridges_d1024_f0.015625
Note: Bridge weights are ignored for pruning
Model loaded on cuda


In [30]:
# ========================================
# EDIT YOUR PROMPT HERE
# ========================================
prompt = 'when jose wrote a letter, he'

# Tokenize
input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(device)
print(f"Prompt: '{prompt}'")
print(f"Token IDs: {input_ids[0].tolist()}")
print(f"Tokens: {[tokenizer.decode([t]) for t in input_ids[0].tolist()]}")


Prompt: 'when jose wrote a letter, he'
Token IDs: [335, 730, 1050, 32, 776, 13, 103]
Tokens: ['when', 'jose', 'wrote', 'a', 'letter', ',', 'he']


In [31]:
# Forward pass
with torch.no_grad():
    output = model(input_ids)
    logits = output[0] if isinstance(output, tuple) else output

# Get logits at final position
final_logits = logits[0, -2, :]  # (vocab_size,)
print(f"Logits shape: {logits.shape}")
print(f"Final position logits shape: {final_logits.shape}")

# Get top 10 predictions
probs = torch.softmax(final_logits, dim=-1)
top_probs, top_indices = torch.topk(probs, 10)

print(f"\nTop 10 next token predictions:")
print("-" * 40)
for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
    token = tokenizer.decode([idx.item()])
    print(f"{i+1}. '{token}' (id={idx.item()}) - prob={prob.item():.4f}")


Logits shape: torch.Size([1, 7, 4096])
Final position logits shape: torch.Size([4096])

Top 10 next token predictions:
----------------------------------------
1. '"' (id=3) - prob=0.6352
2. 'he' (id=103) - prob=0.0656
3. 'saying' (id=1929) - prob=0.0283
4. 'telling' (id=1868) - prob=0.0162
5. 'asking' (id=2393) - prob=0.0147
6. 'and' (id=94) - prob=0.0117
7. 'a' (id=32) - prob=0.0102
8. 'kim' (id=415) - prob=0.0102
9. 'the' (id=85) - prob=0.0066
10. 'it' (id=120) - prob=0.0056


In [33]:
# ========================================
# DUMMY PRONOUN TASK EXAMPLES
# ========================================
from sparse_pretrain.src.pruning.tasks import get_task

# Load the task
task = get_task("dummy_pronoun", tokenizer, seed=42, split="val")
print(f"Task: {task.name}")
print(f"Templates: {len(task.templates)}")

# Generate examples and show predictions
n_examples = 30
print(f"\n{'='*80}")
print(f"Showing {n_examples} examples from dummy pronoun task")
print(f"{'='*80}\n")

for i in range(n_examples):
    ex = task.generate_example()
    
    # Get the prompt (everything before the final token)
    prompt_ids = ex.positive_ids[:]  # All but last token
    prompt_text = tokenizer.decode(prompt_ids)
    
    # Get correct and incorrect tokens
    correct_token = tokenizer.decode([ex.correct_token])
    incorrect_token = tokenizer.decode([ex.incorrect_token])
    
    # Run forward pass
    input_ids = prompt_ids.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_ids)
        logits = output[0] if isinstance(output, tuple) else output
        final_logits = logits[0, -1, :]
    
    probs = torch.softmax(final_logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, 3)
    
    # Get logits for correct/incorrect
    correct_logit = final_logits[ex.correct_token].item()
    incorrect_logit = final_logits[ex.incorrect_token].item()
    logit_diff = correct_logit - incorrect_logit
    
    print(f"Example {i+1}:")
    print(f"  Prompt: \"{prompt_text}\"")
    print(f"  Correct: '{correct_token}' | Incorrect: '{incorrect_token}'")
    print(f"  Top 3 predictions:")
    for j, (prob, idx) in enumerate(zip(top_probs, top_indices)):
        token = tokenizer.decode([idx.item()])
        marker = "✓" if idx.item() == ex.correct_token else ("✗" if idx.item() == ex.incorrect_token else " ")
        print(f"    {j+1}. '{token}' - {prob.item():.3f} {marker}")
    print(f"  Logit diff (correct - incorrect): {logit_diff:.2f} {'✓' if logit_diff > 0 else '✗'}")
    print()


Task: dummy_pronoun_val
Templates: 24

Showing 30 examples from dummy pronoun task

Example 1:
  Prompt: "when rita helped the lost puppy,"
  Correct: 'she' | Incorrect: 'he'
  Top 3 predictions:
    1. 'she' - 0.319 ✓
    2. 'leaving' - 0.065  
    3. 'the' - 0.051  
  Logit diff (correct - incorrect): 2.64 ✓

Example 2:
  Prompt: "when mia fed the hungry cat,"
  Correct: 'she' | Incorrect: 'he'
  Top 3 predictions:
    1. 'she' - 0.213 ✓
    2. 'the' - 0.052  
    3. 'it' - 0.028  
  Logit diff (correct - incorrect): 2.28 ✓

Example 3:
  Prompt: "when leo cheered for the team,"
  Correct: 'he' | Incorrect: 'she'
  Top 3 predictions:
    1. 'he' - 0.369 ✓
    2. 'they' - 0.086  
    3. 'knowing' - 0.041  
  Logit diff (correct - incorrect): 3.52 ✓

Example 4:
  Prompt: "when maria swam in the lake,"
  Correct: 'she' | Incorrect: 'he'
  Top 3 predictions:
    1. 'she' - 0.528 ✓
    2. 'the' - 0.081  
    3. 'they' - 0.057  
  Logit diff (correct - incorrect): 3.25 ✓

Example 5:
  Promp