## Part 1

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [10]:
DEVICE = "cuda" if torch.cuda.is_available else "cpu"
MODEL_NAME = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_hidden_states = True).to(DEVICE).eval()


In [None]:
def logitlens(model, tokenizer, prompt, top_k=5, device=DEVICE):
    #token -> tensor
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    ##forward pass to get all hidden states
    with torch.no_grad():
        outputs = model(**inputs)   
    hidden_states = outputs.hidden_states

    W_U = model.lm_head.weight
    final_norm = model.transformer.ln_f
    results = []

    for layer_idx in range(1, len(hidden_states)):
        h = hidden_states[layer_idx]
        h_last = h[0, -1]   
        
        h_norm = final_norm(h_last)

        logits = torch.matmul(W_U, h_norm)
        probs = torch.softmax(logits, dim=-1)

        top_probs, top_indices = torch.topk(probs, k=top_k)
        top_tokens = [tokenizer.decode([idx.item()]) for idx in top_indices]

        results.append({
            "layer": layer_idx,
            "tokens": top_tokens,
            "probs": top_probs.cpu().tolist()
        })

    return results


In [7]:
def print_logitlens_results(results, prompt):
    print(f"\nPrompt: {prompt}\n")
    for r in results:
        layer = r["layer"]
        toks = r["tokens"]
        probs = r["probs"]

        pretty = [f"{tok.strip() or repr(tok)} ({p:.3f})"
                  for tok, p in zip(toks, probs)]
        print(f"Layer {layer:2d}: " + " | ".join(pretty))


In [9]:
prompt = "The capital of France is"
results = logitlens(model, tokenizer, prompt, top_k=5)
print_logitlens_results(results, prompt)



Prompt: The capital of France is

Layer  1: not (0.332) | now (0.128) | still (0.082) | also (0.073) | a (0.038)
Layer  2: now (0.271) | not (0.256) | also (0.156) | still (0.081) | unlikely (0.017)
Layer  3: now (0.366) | not (0.172) | also (0.152) | still (0.092) | currently (0.028)
Layer  4: now (0.515) | also (0.145) | not (0.126) | currently (0.040) | still (0.038)
Layer  5: now (0.627) | not (0.080) | also (0.053) | still (0.047) | already (0.033)
Layer  6: now (0.552) | still (0.109) | not (0.088) | also (0.068) | already (0.030)
Layer  7: now (0.572) | not (0.123) | still (0.056) | also (0.050) | a (0.016)
Layer  8: now (0.689) | also (0.063) | still (0.059) | not (0.048) | already (0.025)
Layer  9: now (0.423) | located (0.145) | also (0.071) | not (0.070) | still (0.049)
Layer 10: France (0.621) | Paris (0.182) | now (0.039) | located (0.029) | French (0.011)
Layer 11: France (0.242) | Paris (0.139) | now (0.103) | the (0.042) | located (0.023)
Layer 12: the (0.038) | a (0.0