In [None]:
!pip install torch transformer-lens datasets

In [None]:
from datasets import load_dataset
import json

import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import gc
from transformer_lens import HookedTransformer
from huggingface_hub import login
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# Read the Hugging Face access token from the HF_TOKEN environment variable so the
# secret is never stored in the notebook. If the variable is unset, hf_token is None
# and login() falls back to its interactive prompt.
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)

In [None]:
# --- Run configuration ---
MODEL_NAME = "google/gemma-2-2b-it"  # Or any supported model; replace with your desired model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_EPOCHS = 2
CHECKPOINT_EVERY = 1  # Save probe checkpoint every N epochs

# Load the model (its tokenizer is attached to the returned object)
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device=DEVICE,
    torch_dtype=torch.float16,
    center_unembed=False,
)

# The model is only used as a frozen feature extractor: switch to eval mode and
# disable gradients on every parameter.
model.eval()
model.requires_grad_(False)

tokenizer = model.tokenizer

In [None]:
# Model dimensions read from the loaded model's config and reused by later cells.
n_layers, d_model, d_vocab = (
    model.cfg.n_layers,
    model.cfg.d_model,
    model.cfg.d_vocab,
)

In [None]:
# Dataset on which the tuned-lens probes are trained: the English C4 train split,
# streamed so that only the first NUM_EXAMPLES non-empty documents are fetched.
NUM_EXAMPLES = 39470
dataset = load_dataset("allenai/c4", name="en", split="train", streaming=True)

text_samples = []
for record in dataset:
    cleaned = record.get("text", "").strip()
    if cleaned:
        text_samples.append(cleaned)
    # Checked after the append so the stream is abandoned as soon as the quota is met.
    if len(text_samples) >= NUM_EXAMPLES:
        break
print(f"Loaded {len(text_samples)} clean text samples")

# Concatenate every document (blank-line separated) and tokenize the result once,
# giving a single flat stream of token ids.
joined_text = "\n\n".join(text_samples)
encoded_batch = tokenizer.encode(joined_text, return_tensors='pt')
all_token_ids = encoded_batch[0]  # shape: (total_tokens,)
print(f"Token sequence length: {len(all_token_ids)} tokens")


In [None]:
# Special strings added to the start and end of each text sample so that the input follows the format the model expects.
# We train the probes on dataset samples wrapped in these special tokens because we want to use the lens on examples that contain them.
# TODO: finish this note - it was meant to compare this setup with the original tuned lens implementation.
prefix_text = "<start_of_turn>user\n"
suffix_text = "<end_of_turn>"

# Tokenize both wrappers without letting the tokenizer add its own special tokens,
# so each list holds exactly the ids of the literal text above.
prefix_tokens, suffix_tokens = (
    tokenizer.encode(wrapper_text, add_special_tokens=False)
    for wrapper_text in (prefix_text, suffix_text)
)

print(f"Prefix tokens count: {len(prefix_tokens)}")
print(f"Suffix tokens count: {len(suffix_tokens)}")

In [None]:
# Chunking / batching sizes.
SEQ_LEN = 128      # total length of each training sequence passed to create_dataset_from_tokens
SUBSET_SIZE = 128  # number of sequences run through the model (and saved) per activation file
BATCH_SIZE = 128

In [None]:
def create_dataset_from_tokens(token_ids, seq_len, prefix=None, suffix=None):
    """Cut a flat token stream into fixed-length sequences wrapped in prefix/suffix tokens.

    Each returned row is ``prefix + chunk + suffix`` and has exactly ``seq_len``
    tokens. Chunks are consecutive and non-overlapping; trailing tokens that do
    not fill a whole chunk are dropped.

    Args:
        token_ids: 1-D tensor (or sequence) of token ids.
        seq_len: total length of every returned sequence, wrappers included.
        prefix: token ids placed before every chunk. Defaults to the notebook-level
            ``prefix_tokens``.
        suffix: token ids placed after every chunk. Defaults to the notebook-level
            ``suffix_tokens``.

    Returns:
        ``torch.long`` tensor of shape ``(num_chunks, seq_len)``; ``num_chunks`` is 0
        when the stream is shorter than one chunk.

    Raises:
        ValueError: if ``seq_len`` leaves no room for any content tokens.
    """
    if prefix is None:
        prefix = prefix_tokens
    if suffix is None:
        suffix = suffix_tokens

    token_ids = torch.as_tensor(token_ids, dtype=torch.long).reshape(-1)
    prefix = torch.as_tensor(prefix, dtype=torch.long, device=token_ids.device).reshape(-1)
    suffix = torch.as_tensor(suffix, dtype=torch.long, device=token_ids.device).reshape(-1)

    # Room left for content is derived from the actual wrapper lengths (previously a
    # hard-coded 4), so rows stay exactly seq_len long for any prefix/suffix.
    body_len = seq_len - (prefix.numel() + suffix.numel())
    if body_len <= 0:
        raise ValueError(
            f"seq_len={seq_len} leaves no room for content after "
            f"{prefix.numel()} prefix and {suffix.numel()} suffix tokens"
        )

    # Floor division keeps every complete chunk, including the last one when the
    # stream length is an exact multiple of body_len (the old range() bound dropped it).
    num_chunks = token_ids.numel() // body_len
    bodies = token_ids[: num_chunks * body_len].reshape(num_chunks, body_len)

    return torch.cat(
        [prefix.expand(num_chunks, -1), bodies, suffix.expand(num_chunks, -1)],
        dim=1,
    )

In [None]:
TARGET_NUM_SEQUENCES = 130351

# Chunk the token stream into wrapped sequences and keep at most TARGET_NUM_SEQUENCES of them.
input_ids = create_dataset_from_tokens(all_token_ids, seq_len=SEQ_LEN)[:TARGET_NUM_SEQUENCES]
print(f"Total sequences: {len(input_ids)}")

# Number of complete SUBSET_SIZE-sized groups; a trailing partial group is not counted.
N = len(input_ids)
num_subsets = N // SUBSET_SIZE
print(num_subsets)

In [None]:
def stable_kl_loss(probe_logits, final_logits, top_k=100, eps=1e-8):
    """KL(final || probe) restricted to the union of both models' top-k tokens.

    For every position, the top-``top_k`` vocabulary entries of ``final_logits``
    and of ``probe_logits`` are merged into one candidate set. Both logit vectors
    are renormalized (softmax) over that set only, and the KL divergence from the
    final distribution to the probe distribution is computed on it.

    Args:
        probe_logits: tensor of shape (B, S, V); gradients flow through this input.
        final_logits: tensor of shape (B, S, V) providing the target distribution.
        top_k: number of top entries taken from each input; clamped to V.
        eps: small constant added to the token count to avoid division by zero.

    Returns:
        Scalar tensor: the per-position KL summed over all positions and divided
        by ``B * S + eps``.
    """
    B, S, V = final_logits.size()
    top_k = min(top_k, V)  # torch.topk raises when k exceeds the vocabulary size

    _, topk_final = torch.topk(final_logits, top_k, dim=-1)
    _, topk_probe = torch.topk(probe_logits, top_k, dim=-1)

    # Sorting the concatenated indices puts duplicates next to each other, so the
    # union can be built for all positions at once: an entry is a repeat exactly
    # when it equals its left neighbour.
    union_indices, _ = torch.sort(torch.cat([topk_final, topk_probe], dim=-1), dim=-1)
    is_repeat = torch.zeros_like(union_indices, dtype=torch.bool)
    is_repeat[..., 1:] = union_indices[..., 1:] == union_indices[..., :-1]

    probe_subset_logits = torch.gather(probe_logits, 2, union_indices).float()
    final_subset_logits = torch.gather(final_logits, 2, union_indices).float()

    # Repeated entries must not take part in the softmax, otherwise their mass
    # deflates every real probability. A large finite fill (rather than -inf)
    # keeps the differences below free of inf - inf = nan.
    fill_value = torch.finfo(torch.float32).min / 2
    probe_subset_logits = probe_subset_logits.masked_fill(is_repeat, fill_value)
    final_subset_logits = final_subset_logits.masked_fill(is_repeat, fill_value)

    target_log_probs = F.log_softmax(final_subset_logits, dim=-1)
    pred_log_probs = F.log_softmax(probe_subset_logits, dim=-1)

    kl_elements = target_log_probs.exp() * (target_log_probs - pred_log_probs)
    kl_elements = kl_elements.masked_fill(is_repeat, 0.0)

    # Sum over the candidate set, then average over all B * S positions.
    kl_per_token = kl_elements.sum(dim=-1)
    loss = kl_per_token.sum() / (B * S + eps)
    return loss


In [None]:
def get_activations(batch_input):
    """Run the model on a batch and collect every block's residual-stream output.

    Uses the notebook-level ``model`` and ``n_layers``. The forward pass runs under
    ``torch.no_grad()`` and only the ``blocks.{i}.hook_resid_post`` activations are
    cached.

    Args:
        batch_input: input accepted by ``model.run_with_cache`` (the notebook passes
            a batch of token ids).

    Returns:
        List of ``n_layers`` tensors; element ``i`` is the cached
        ``blocks.{i}.hook_resid_post`` activation.
    """
    hooks_to_cache = [f'blocks.{i}.hook_resid_post' for i in range(n_layers)]

    with torch.no_grad():
        # Only the cache is needed. return_type=None asks the model not to produce
        # logits, so the full-vocabulary output tensor is never materialized.
        _, cache = model.run_with_cache(
            batch_input,
            names_filter=hooks_to_cache,
            return_type=None,
        )

    return [cache[hook_name] for hook_name in hooks_to_cache]

In [None]:
def compute_logits_from_resid(resid, soft_cap=30.0):
    """Turn residual-stream activations into soft-capped vocabulary logits.

    Applies the notebook-level ``ln_final``, projects with the unembedding
    ``W_U`` / ``b_U`` (cast to the device and dtype of the normalized activations),
    and squashes the result with ``soft_cap * tanh(logits / soft_cap)``, the logit
    soft-capping done in Gemma 2.

    Args:
        resid: activations indexed as (batch, seq, d_model) by the einsum below.
        soft_cap: value of the tanh soft cap.

    Returns:
        Logits indexed as (batch, seq, vocab), bounded in magnitude by ``soft_cap``.
    """
    hidden = ln_final(resid)

    unembed_weight = W_U.to(hidden.device).type_as(hidden)
    unembed_bias = b_U.to(hidden.device).type_as(hidden)
    raw_logits = torch.einsum('bsd,dk->bsk', hidden, unembed_weight) + unembed_bias

    return soft_cap * torch.tanh(raw_logits / soft_cap)


In [None]:
class Probe(nn.Module):
    """Learned affine map from ``d_model`` to ``d_model`` applied to one layer's activations.

    The weight matrix starts from a Xavier-normal draw and the bias from zeros.
    """

    def __init__(self, d_model):
        super().__init__()
        # Attribute name `linear` determines the state_dict keys used when saving.
        self.linear = nn.Linear(d_model, d_model)
        nn.init.xavier_normal_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Return the affine transform of ``x`` along its last dimension."""
        translated = self.linear(x)
        return translated

In [None]:
# Start GPU monitoring in background
#the interval is in seconds, adjust as needed
import subprocess
import threading
import time

def monitor_gpu(interval=10, stop_event=None):
    """Print the output of ``nvidia-smi`` every ``interval`` seconds.

    Args:
        interval: pause between polls, in seconds.
        stop_event: optional ``threading.Event``; once it is set the loop exits.
            When omitted, polling continues until the thread is torn down.

    Returns:
        None. The function also returns on its own if the ``nvidia-smi``
        executable cannot be found, since retrying can never succeed.
    """
    if stop_event is None:
        # Private event that is never set: keeps the original run-forever behaviour.
        stop_event = threading.Event()

    while not stop_event.is_set():
        try:
            output = subprocess.check_output(['nvidia-smi'], encoding='utf-8')
        except FileNotFoundError as e:
            # No nvidia-smi binary available: report once and stop polling.
            print("[nvidia-smi error]:", e)
            return
        except Exception as e:
            # Other failures are reported and polling continues.
            print("[nvidia-smi error]:", e)
        else:
            print("\n[GPU USAGE]")
            print(output)
        # Sleeps for `interval` seconds but wakes up early if a stop is requested.
        stop_event.wait(interval)

# Poll every 15 seconds from a daemon thread, so the monitor runs in the background
# and does not keep the interpreter alive on shutdown.
monitor_thread = threading.Thread(
    target=monitor_gpu,
    kwargs={"interval": 15},
    daemon=True,
)
monitor_thread.start()

In [None]:
# Optional helper cell: drop large leftovers of a previous run so the loops can be
# executed again. At notebook top level locals() is globals(), so removing each name
# from globals() (ignoring names that are absent) covers every case.
for var_name in ['post_outs', 'final_logits', 'post_logits', 'post_probes','new_probes', 'resid_post_outs', 'subset_input', 'optimizers']:
    globals().pop(var_name, None)

# Collect the now-unreferenced objects, then release cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Soft-cap applied to the probe logits in the training loops (same value as the
# default soft_cap of compute_logits_from_resid).
soft_cap_value = 30.0

In [None]:
# Run the model over the data one subset at a time and store each subset's
# activations on disk, so that probe training can later proceed without the model.
os.makedirs('activations', exist_ok=True)  # torch.save does not create missing directories

for subset_idx in range(num_subsets):
    start = subset_idx * SUBSET_SIZE
    end = start + SUBSET_SIZE
    print(f"\n--- Processing subset {subset_idx+1} / {num_subsets} -- Sequences {start} to {end-1}")

    subset_input = input_ids[start:end].to(DEVICE)
    # Extract activations for this subset
    resid_post_outs = get_activations(subset_input)
    # Move activations to CPU before saving (dtype is left unchanged) so the GPU
    # copies can be released once this iteration's references are dropped.
    resid_post_outs = [layer_acts.cpu() for layer_acts in resid_post_outs]
    # Files are numbered from 1, matching the indices used by the training loops.
    torch.save(resid_post_outs, f'activations/activations_{subset_idx + 1}.pth')

In [None]:
import copy

# Keep standalone copies of the only model pieces used after this point: the
# unembedding weight and bias, and the final normalization layer.
W_U, b_U = (
    model.state_dict()[key].to(DEVICE).clone()
    for key in ('unembed.W_U', 'unembed.b_U')
)
ln_final = copy.deepcopy(model.ln_final).to(DEVICE)

# Drop the model itself, collect garbage and release cached GPU memory.
del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Number of saved subsets concatenated into one training batch
# (set to 2 for 2 subsets at once, or 3 for 3, etc.).
BATCH_MULTIPLIER = 2
# Ceiling division: the last combined batch may contain fewer subsets.
num_combined_batches = -(-num_subsets // BATCH_MULTIPLIER)


In [None]:
# Layer-major training: each layer's probe is trained over all combined batches
# before moving to the next layer, so every activation file is read once per layer.
# Running this cell creates fresh probes and optimizers.

# Initialize one probe per layer (FP32 on DEVICE)
probes = [Probe(d_model).to(DEVICE).to(torch.float32) for _ in range(n_layers)]
optimizers = [torch.optim.AdamW(probe.parameters(), lr=1e-5) for probe in probes]

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    epoch_loss_per_layer = [0.0 for _ in range(n_layers)]

    for layer in range(n_layers):
        probe = probes[layer]
        opt = optimizers[layer]

        for combined_idx in range(num_combined_batches):
            # Activation files are numbered from 1; the last group may be shorter.
            start_subset = combined_idx * BATCH_MULTIPLIER + 1
            end_subset = min(start_subset + BATCH_MULTIPLIER - 1, num_subsets)

            # Gather and concatenate only this layer's activations for all needed subsets
            activ_list = []
            final_layer_activ_list = []
            for subset_idx in range(start_subset, end_subset + 1):
                acts = torch.load(f'activations/activations_{subset_idx}.pth', map_location='cpu')
                activ_list.append(acts[layer])
                # Use the last layer's activations as the reference for logits
                final_layer_activ_list.append(acts[-1])

            activ_batch = torch.cat(activ_list, dim=0).to(DEVICE).to(torch.float32)
            final_layer_activ_batch = torch.cat(final_layer_activ_list, dim=0).to(DEVICE).to(torch.float32)

            # Target logits come from the stored last-layer activations; no gradient needed.
            with torch.no_grad():
                final_logits_batch = compute_logits_from_resid(final_layer_activ_batch, soft_cap=soft_cap_value)

            # Probe output goes through the same norm -> unembed -> soft-cap path as the target.
            probe_logits = compute_logits_from_resid(probe(activ_batch), soft_cap=soft_cap_value)

            loss = stable_kl_loss(probe_logits, final_logits_batch)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(probe.parameters(), max_norm=1.0)
            opt.step()

            epoch_loss_per_layer[layer] += loss.item()

            # Cleanup
            del activ_batch, final_layer_activ_batch
            del probe_logits, loss, final_logits_batch
            gc.collect()
            torch.cuda.empty_cache()

    # Each layer takes exactly one optimizer step per combined batch, so that is
    # the number of losses accumulated per layer in this epoch.
    print("Average loss per layer:")
    for layer in range(n_layers):
        avg_loss = epoch_loss_per_layer[layer] / max(num_combined_batches, 1)
        print(f"  Layer {layer}: {avg_loss:.4f}")


In [None]:
# Batch-major training: every combined batch is loaded from disk once and then
# used to take one optimizer step for each layer's probe.
# Running this cell creates fresh probes and optimizers.

# Initialize one probe per layer (FP32 on DEVICE)
probes = [Probe(d_model).to(DEVICE).to(torch.float32) for _ in range(n_layers)]
optimizers = [torch.optim.AdamW(probe.parameters(), lr=1e-5) for probe in probes]

# Number of optimizer steps each probe takes per epoch (one per combined batch);
# used as the denominator of the average loss.
num_batches = num_combined_batches

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    epoch_loss_per_layer = [0.0 for _ in range(n_layers)]

    for combined_idx in range(num_combined_batches):
        # Activation files are numbered from 1; the last group may be shorter.
        start_subset = combined_idx * BATCH_MULTIPLIER + 1
        end_subset = min(start_subset + BATCH_MULTIPLIER - 1, num_subsets)

        # For each layer (and final layer), build list of tensors to concatenate
        batch_resid_post_outs = [[] for _ in range(n_layers)]

        # Load each subset activation and append per-layer activations
        for subset_idx in range(start_subset, end_subset + 1):
            acts = torch.load(f'activations/activations_{subset_idx}.pth', map_location='cpu')
            for layer_idx in range(n_layers):
                batch_resid_post_outs[layer_idx].append(acts[layer_idx])

        batch_resid_post_outs = [
            torch.cat(layer_acts_list, dim=0) for layer_acts_list in batch_resid_post_outs
        ]

        # The target logits depend only on the last layer's activations, so they are
        # computed once per combined batch instead of once per layer.
        final_layer_activ_batch = batch_resid_post_outs[-1].to(DEVICE).to(torch.float32)
        with torch.no_grad():
            final_logits_batch = compute_logits_from_resid(final_layer_activ_batch, soft_cap=soft_cap_value)
        del final_layer_activ_batch

        for layer in range(n_layers):
            probe = probes[layer]
            opt = optimizers[layer]

            activ_batch = batch_resid_post_outs[layer].to(DEVICE).to(torch.float32)  # activations for this layer

            # Probe output goes through the same norm -> unembed -> soft-cap path as the target.
            probe_logits = compute_logits_from_resid(probe(activ_batch), soft_cap=soft_cap_value)

            loss = stable_kl_loss(probe_logits, final_logits_batch)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(probe.parameters(), max_norm=1.0)
            opt.step()

            epoch_loss_per_layer[layer] += loss.item()

            # Cleanup
            del activ_batch, probe_logits, loss
            gc.collect()
            torch.cuda.empty_cache()

        # Release this combined batch before loading the next one.
        del final_logits_batch, batch_resid_post_outs
        gc.collect()
        torch.cuda.empty_cache()

    print("Average loss per layer:")
    for layer in range(n_layers):
        avg_loss = epoch_loss_per_layer[layer] / max(num_batches, 1)
        print(f"  Layer {layer}: {avg_loss:.4f}")

# Save final trained probes (one per layer) after all epochs
print("\nSaving final trained probes per layer...")
save_dir = 'probes'
# Create the output directory first: torch.save does not create missing directories,
# and failing here would lose the result of the whole training run.
os.makedirs(save_dir, exist_ok=True)
for layer in range(n_layers):
    save_path = f'{save_dir}/probe_{layer}.pt'
    torch.save(probes[layer].state_dict(), save_path)
    print(f"Saved layer {layer} final probe to {save_path}")
