In [None]:
pip install huggingface_hub==0.36.1 transformers==4.57.6 transformer_lens==2.17.0

In [None]:
import huggingface_hub

# Guard against running the rest of the notebook on an unpinned huggingface_hub.
# A mismatch presumably means the kernel still holds a copy imported before the
# install cell ran, hence the restart instruction in the message.
installed_hub_version = huggingface_hub.__version__
assert installed_hub_version == "0.36.1", (
    f"RESTART KERNEL! Got {installed_hub_version}, need 0.36.1"
)

In [None]:
# Cell 1: Imports
import gc
import time

import torch
from huggingface_hub import list_repo_refs
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Cell 2: Check GPU
# Report the CUDA device (index 0) that later cells will use. Every output line
# carries its label, including when no GPU is present.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("GPU: None")
    print("GPU memory: N/A")

In [None]:
# Cell 3: Verify checkpoint availability
MODELS = [
    "EleutherAI/pythia-160m-deduped",
    "EleutherAI/pythia-1b-deduped",
    "EleutherAI/pythia-2.8b-deduped",
]
# Indices into Pythia's 154-checkpoint schedule (0..153, the same indexing as
# TransformerLens' `checkpoint_index` for Pythia), NOT multiples of 1000 steps.
# Index 153 is the final checkpoint, step143000.
CHECKPOINT_INDICES = [0, 15, 30, 60, 90, 120, 140, 150, 152, 153]

# Pythia's published checkpoint schedule is step0, then log-spaced step1..step512,
# then every 1000 steps from step1000 through step143000 (1 + 10 + 143 = 154).
PYTHIA_CHECKPOINT_STEPS = (
    [0] + [2**k for k in range(10)] + [k * 1000 for k in range(1, 144)]
)
assert len(PYTHIA_CHECKPOINT_STEPS) == 154

# Hub branch names to look for (e.g. index 15 -> "step5000"). This list is the
# same for every model, so it is built once outside the loop.
REQUIRED_REVISIONS = [f"step{PYTHIA_CHECKPOINT_STEPS[i]}" for i in CHECKPOINT_INDICES]

for model_name in MODELS:
    print(f"\n=== {model_name} ===")
    refs = list_repo_refs(model_name)
    available = {r.name for r in refs.branches if "step" in r.name}
    # Report in REQUIRED_REVISIONS order so the output is deterministic.
    missing = [rev for rev in REQUIRED_REVISIONS if rev not in available]
    print(f"Available: {len(available)} step revisions")
    print(f"Missing required: {missing if missing else 'None'}")

In [None]:
# Cell 4: Test model loading (smallest first)
# Smoke-test a plain Hugging Face load of the smallest model in MODELS before the
# larger loads below. Use the deduped variant, because every other cell in this
# notebook uses the -deduped repos.
hf_model_160m = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-deduped")
print("model loaded successfully!")

In [None]:
# Cell 5: Test checkpoint loading (corrected — uses revision parameter)
# Load an intermediate training checkpoint by passing the Hub branch name as the
# HF `revision`, then wrap the resulting HF model in a HookedTransformer.
CKPT_TEST_MODEL = "EleutherAI/pythia-160m-deduped"
CKPT_TEST_REVISION = "step30000"

hf_model = model_step = None
try:
    hf_model = AutoModelForCausalLM.from_pretrained(
        CKPT_TEST_MODEL,
        revision=CKPT_TEST_REVISION,
        torch_dtype=torch.float32,
    )
    model_step = HookedTransformer.from_pretrained(
        CKPT_TEST_MODEL,
        hf_model=hf_model,
        tokenizer=AutoTokenizer.from_pretrained(
            CKPT_TEST_MODEL, revision=CKPT_TEST_REVISION
        ),
    )
    print("✅ Checkpoint loading works (via HF revision parameter)!")
except Exception as e:
    # Best-effort smoke test: report the failure and let the remaining cells run.
    print(f"❌ Checkpoint loading failed: {e}")
finally:
    # Free both copies even when loading failed partway through. Also release
    # cached CUDA blocks so the 2.8B memory test does not inherit them.
    del hf_model, model_step
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Cell 6: Memory test with 2.8B
# Time the 2.8B load and, when CUDA is present, report how much GPU memory it
# occupies (current and peak during loading).
print("\nLoading pythia-2.8b (this is the critical test)...")
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()  # monotonic clock, suited to measuring intervals
model_2_8b = HookedTransformer.from_pretrained("pythia-2.8b-deduped")
load_time = time.perf_counter() - start
print(f"Loaded in {load_time:.1f}s")
if torch.cuda.is_available():
    print(
        f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB "
        f"(peak {torch.cuda.max_memory_allocated() / 1e9:.2f} GB)"
    )

In [None]:
# Cell 7: Attention extraction speed test
# Time one forward pass with activation caching on a short prompt, then report the
# total size of the cached tensors. Note that run_with_cache stores every hooked
# activation, not only attention patterns.
test_input = "A screen reader is"
tokens = model_2_8b.to_tokens(test_input)

use_cuda = torch.cuda.is_available()
# Inference only: skipping autograd avoids retaining a backward graph, which
# would inflate both the time and the memory being measured.
with torch.no_grad():
    if use_cuda:
        torch.cuda.synchronize()  # don't charge previously queued GPU work to this timing
    start = time.perf_counter()
    logits, cache = model_2_8b.run_with_cache(tokens)
    if use_cuda:
        torch.cuda.synchronize()  # CUDA kernels launch asynchronously; wait before reading the clock
    cache_time = time.perf_counter() - start
print(f"Cache extraction: {cache_time:.2f}s")
print(
    f"Cache memory: {sum(v.element_size() * v.nelement() for v in cache.values()) / 1e6:.1f} MB"
)