In [None]:
!pip install torch transformer-lens datasets accelerate wandb

In [None]:

from datasets import load_dataset
import json

import os
# Reduce CUDA allocator fragmentation; must be set before torch initializes CUDA.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous. It is a debugging
# aid (accurate error locations) that slows training considerably, so opt in only
# when chasing a CUDA error.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm 
import gc
from transformer_lens import HookedTransformer
from huggingface_hub import login
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
import wandb


In [None]:

from kaggle_secrets import UserSecretsClient
import wandb

# Authenticate W&B with the key stored as a Kaggle secret (never hard-coded).
# `user_secrets` is reused later to fetch the Hugging Face token.
user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key=wandb_api_key)

In [None]:

# Number of micro-batches (one subset of config["subset_size"] sequences each)
# whose gradients are accumulated before every optimizer step. Raising it only
# enlarges the effective batch; peak memory is set by the micro-batch size.
ACCUMULATION_STEPS = 16

# Accelerate handles device placement, loss scaling for accumulation and W&B logging.
accelerator = Accelerator(
    gradient_accumulation_steps=ACCUMULATION_STEPS,
    log_with="wandb",
)


In [None]:
# Run hyper-parameters; the whole dict is also logged to W&B by init_trackers.
config = {
    "model_name": "google/gemma-2-2b-it",
    "num_epochs": 2,
    "seq_len": 128,  # Total tokens per sequence, chat-turn markers included
    "subset_size": 8,  # Number of sequences to process at a time
    "learning_rate": 1e-4,
    "accumulation_steps": ACCUMULATION_STEPS,
    "torch_dtype": torch.float16,
    "soft_cap": 30.0,  # Final-logit soft-capping constant used by Gemma 2
    "seed": 0,
}

# Seed torch's RNG so the random probe initialisation is reproducible.
torch.manual_seed(config["seed"])

In [None]:
# Open the W&B run through Accelerate and record the hyper-parameters with it.
accelerator.init_trackers(project_name="Tuned-Lens-Gemma-2B", config=config)

In [None]:

# Scratch space for per-subset activation files and the final probe weights.
ACTIVATION_DIR = "/kaggle/working/activations"
PROBE_DIR = "/kaggle/working/probes"

os.makedirs(ACTIVATION_DIR, exist_ok=True)
os.makedirs(PROBE_DIR, exist_ok=True)

# Device chosen by Accelerate; every tensor below is placed here.
DEVICE = accelerator.device

In [None]:
hf_token = user_secrets.get_secret("HF_TOKEN")
login(token=hf_token)

# Load on CPU first, then move to the accelerator device once frozen.
model = HookedTransformer.from_pretrained(
    config["model_name"],
    device="cpu",
    torch_dtype=config["torch_dtype"],
    center_unembed=False,
)

# The base model only produces activations (under torch.no_grad), so freeze it
# and keep it in eval mode. No activation checkpointing is configured:
# TransformerLens blocks do not consult a `gradient_checkpointing` attribute,
# and no gradients flow through the frozen model, so there is nothing to
# checkpoint.
model.eval()
for param in model.parameters():
    param.requires_grad = False

tokenizer = model.tokenizer

# Move the model to the correct device
model = model.to(DEVICE)

print(f"Model '{config['model_name']}' loaded onto device: {model.cfg.device}")

In [None]:
# Architecture sizes used to build and index the per-layer probes.
n_layers, d_model, d_vocab = (
    model.cfg.n_layers,
    model.cfg.d_model,
    model.cfg.d_vocab,
)

In [None]:

# Module-level aliases. SEQ_LEN and SUBSET_SIZE are read from `config` so they
# cannot drift from the values the pipeline actually uses.
# BATCH_SIZE is not referenced by any visible cell.
BATCH_SIZE = 128
SEQ_LEN = config["seq_len"]
SUBSET_SIZE = config["subset_size"]

In [None]:
NUM_EXAMPLES = 39470  # Number of non-empty C4 documents to pull from the stream
dataset = load_dataset("allenai/c4", name="en", split="train", streaming=True)
text_samples = []
for example in dataset:
    # `or ""` also covers an explicit None value, not just a missing key.
    text = (example.get("text") or "").strip()
    if text:
        text_samples.append(text)
    if len(text_samples) >= NUM_EXAMPLES:
        break
print(f"Loaded {len(text_samples)} clean text samples")

# Join all text into one long sequence of tokens. add_special_tokens=False: the
# stream is later cut into many windows, so any special token the tokenizer
# inserts (e.g. BOS) would land inside the first window rather than at the
# start of each sequence.
joined_text = "\n\n".join(text_samples)
all_token_ids = tokenizer.encode(joined_text, add_special_tokens=False, return_tensors='pt')[0]  # shape: (total_tokens,)
print(f"Token sequence length: {len(all_token_ids)} tokens")


In [None]:
# Let's define your special strings exactly:
prefix_text = "<start_of_turn>user\n"
suffix_text = "<end_of_turn>"

# Encode prefix and suffix tokens with the tokenizer:
prefix_tokens = tokenizer.encode(prefix_text, add_special_tokens=False)  # e.g. [....]
suffix_tokens = tokenizer.encode(suffix_text, add_special_tokens=False)  # e.g. [....]

print(f"Prefix tokens count: {len(prefix_tokens)}")
print(f"Suffix tokens count: {len(suffix_tokens)}")

In [None]:
def create_dataset_from_tokens(token_ids, seq_len, prefix_ids=None, suffix_ids=None, bos_token_id=None):
    """Cut a flat token stream into fixed-length, chat-wrapped training sequences.

    The stream is split into consecutive, non-overlapping windows. Each window
    becomes ``[bos] + prefix_ids + window + suffix_ids`` (BOS only if given), so
    every row has exactly ``seq_len`` tokens. A trailing remainder shorter than
    one window is dropped.

    Args:
        token_ids: 1-D tensor (or sequence) of token ids.
        seq_len: total length of each output row, wrappers included.
        prefix_ids: ids placed before every window; defaults to the
            notebook-level ``prefix_tokens``.
        suffix_ids: ids placed after every window; defaults to ``suffix_tokens``.
        bos_token_id: optional id prepended to every row (none by default).
            NOTE(review): Gemma is normally run with a leading BOS; consider
            passing ``tokenizer.bos_token_id``.

    Returns:
        LongTensor of shape (num_windows, seq_len); (0, seq_len) if the stream
        is shorter than one window.

    Raises:
        ValueError: if the wrapper tokens leave no room for content.
    """
    if prefix_ids is None:
        prefix_ids = prefix_tokens
    if suffix_ids is None:
        suffix_ids = suffix_tokens

    head = ([] if bos_token_id is None else [bos_token_id]) + list(prefix_ids)
    tail = list(suffix_ids)

    # Content length per row follows from the actual wrapper lengths.
    body_len = seq_len - len(head) - len(tail)
    if body_len <= 0:
        raise ValueError(
            f"seq_len={seq_len} leaves no room for content after "
            f"{len(head) + len(tail)} wrapper tokens"
        )

    tokens = torch.as_tensor(token_ids, dtype=torch.long).flatten()
    # Every full window, including one ending exactly at the end of the stream.
    num_windows = tokens.numel() // body_len
    bodies = tokens[: num_windows * body_len].view(num_windows, body_len)

    head_cols = torch.tensor(head, dtype=torch.long).expand(num_windows, -1)
    tail_cols = torch.tensor(tail, dtype=torch.long).expand(num_windows, -1)
    return torch.cat([head_cols, bodies, tail_cols], dim=1)

In [None]:
# Wrap the token stream into fixed-length sequences, keep at most the target
# number, and derive how many subsets of config["subset_size"] that yields.
TARGET_NUM_SEQUENCES = 130351
input_ids = create_dataset_from_tokens(all_token_ids, seq_len=config["seq_len"])[:TARGET_NUM_SEQUENCES]
N = len(input_ids)
print(f"Total sequences: {N}")
num_subsets = N // config["subset_size"]  # any remainder sequences are unused
print(num_subsets)

In [None]:
def stable_kl_loss(probe_logits, final_logits, top_k=100, eps=1e-8):
    """KL(final || probe) restricted to the union of both top-k token sets.

    For each position, the candidate vocabulary is the union of the top-``top_k``
    token ids under ``final_logits`` and under ``probe_logits``. Both logit
    vectors are renormalized (softmax) over exactly that union. The resulting KL
    divergence is averaged over all positions.

    Args:
        probe_logits: (batch, seq, vocab) logits derived from a probe; gradients
            flow through this tensor.
        final_logits: (batch, seq, vocab) target logits from the final layer.
        top_k: tokens taken from each distribution (clamped to the vocab size).
        eps: floor for target probabilities before the log, and guard on the
            denominator of the mean.

    Returns:
        Scalar float32 tensor: mean per-position KL divergence.
    """
    _, _, vocab_size = final_logits.size()
    k = min(top_k, vocab_size)

    topk_final = torch.topk(final_logits, k, dim=-1).indices
    topk_probe = torch.topk(probe_logits, k, dim=-1).indices

    # Sorting the concatenated candidates makes duplicates adjacent. Keeping only
    # the first of each run gives exactly the per-position set union, with no
    # Python loop over positions.
    union_indices, _ = torch.sort(torch.cat([topk_final, topk_probe], dim=-1), dim=-1)  # (B, S, 2k)
    valid_mask = torch.ones_like(union_indices, dtype=torch.bool)
    valid_mask[..., 1:] = union_indices[..., 1:] != union_indices[..., :-1]

    # Duplicate slots must not enter either softmax normalizer. Push their
    # logits far below any real logit (exp underflows to exactly 0).
    masked_logit = -1e9
    probe_subset_logits = torch.gather(probe_logits, 2, union_indices).float().masked_fill(~valid_mask, masked_logit)
    final_subset_logits = torch.gather(final_logits, 2, union_indices).float().masked_fill(~valid_mask, masked_logit)

    target_probs = F.softmax(final_subset_logits, dim=-1).clamp(min=eps)
    target_log_probs = torch.log(target_probs)
    pred_log_probs = F.log_softmax(probe_subset_logits, dim=-1)

    # Zero the duplicate slots explicitly (not by multiplication), so no
    # large-magnitude term from a masked slot leaks into the sum.
    kl_elements = torch.where(
        valid_mask,
        target_probs * (target_log_probs - pred_log_probs),
        torch.zeros_like(target_probs),
    )
    kl_per_token = kl_elements.sum(dim=-1)  # (B, S)

    # Every position has at least one valid candidate, so this is the mean over B*S.
    token_mask = valid_mask.any(dim=-1).float()
    loss = kl_per_token.sum() / (token_mask.sum() + eps)
    return loss


In [None]:


def get_activations(model, batch_input):
    """Return the post-block residual stream of every layer for ``batch_input``.

    Args:
        model: HookedTransformer-like model exposing ``cfg.n_layers`` and
            ``run_with_cache``.
        batch_input: token ids accepted by ``model.run_with_cache``.

    Returns:
        List of ``model.cfg.n_layers`` tensors; element ``i`` is the cached
        ``blocks.{i}.hook_resid_post`` activation.
    """
    num_layers = model.cfg.n_layers
    hook_names = [f'blocks.{i}.hook_resid_post' for i in range(num_layers)]
    with torch.no_grad():
        # stop_at_layer is exclusive, so n_layers still runs every block. The
        # forward pass then returns the residual stream instead of applying the
        # final norm and unembedding, so the (batch, seq, d_vocab) logits, which
        # are not needed here, are never materialized.
        _, cache = model.run_with_cache(
            batch_input,
            names_filter=hook_names,
            stop_at_layer=num_layers,
        )
        return [cache[name] for name in hook_names]

In [None]:
def compute_logits_from_resid(resid, soft_cap=30.0, ln=None, unembed_W=None, unembed_b=None):
    """Map a residual-stream tensor to soft-capped vocabulary logits.

    Applies the final norm, then the unembedding ``normed @ W_U + b_U``, then the
    final-logit soft-capping done in Gemma 2: ``soft_cap * tanh(logits / soft_cap)``.

    Args:
        resid: (batch, seq, d_model) residual stream.
        soft_cap: soft-capping constant.
        ln: normalization module; defaults to the notebook-level ``ln_final``.
        unembed_W: (d_model, d_vocab) unembedding matrix; defaults to ``W_U``.
        unembed_b: (d_vocab,) unembedding bias; defaults to ``b_U``.

    Returns:
        (batch, seq, d_vocab) logits in the dtype/device of the normalized residual.
    """
    # Globals are looked up only when no explicit value is passed.
    ln = ln_final if ln is None else ln
    unembed_W = W_U if unembed_W is None else unembed_W
    unembed_b = b_U if unembed_b is None else unembed_b

    normed = ln(resid)

    # Match the normalized residual's device/dtype (no copy if already matching).
    W_U_casted = unembed_W.to(normed.device).type_as(normed)
    b_U_casted = unembed_b.to(normed.device).type_as(normed)
    logits = torch.einsum('bsd,dk->bsk', normed, W_U_casted) + b_U_casted

    return soft_cap * torch.tanh(logits / soft_cap)


In [None]:
class Probe(nn.Module):
    """Learned affine map d_model -> d_model applied to one layer's residual stream.

    The weight is Xavier-normal initialised and the bias starts at zero. The
    parameters live in ``self.linear`` (state-dict keys ``linear.weight`` and
    ``linear.bias``).
    """

    def __init__(self, d_model):
        super().__init__()
        affine = nn.Linear(d_model, d_model, bias=True)
        nn.init.xavier_normal_(affine.weight)
        nn.init.zeros_(affine.bias)
        self.linear = affine

    def forward(self, x):
        """Apply the affine transformation to the last dimension of ``x``."""
        return self.linear(x)

In [None]:
import copy

# Detached, device-resident copies of the unembedding weights and of the final
# norm, used to turn (probe-transformed) residuals into logits.
W_U = model.unembed.W_U.detach().to(DEVICE).clone()
b_U = model.unembed.b_U.detach().to(DEVICE).clone()
ln_final = copy.deepcopy(model.ln_final).to(DEVICE)


In [None]:


import os
import torch


checkpoint_dir = "/kaggle/working/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

resume_epoch = 0
resume_chunk = 0

# Resume only when the progress tracker exists. The training loop writes it
# after accelerator.save_state(), so an interruption between the two leaves a
# non-empty directory without a tracker; treating that as resumable would crash
# on the torch.load below.
progress_tracker_path = os.path.join(checkpoint_dir, "progress_tracker.pt")
resume_from_checkpoint = os.path.isfile(progress_tracker_path)

if resume_from_checkpoint:
    print("Resuming from checkpoint...")
    # NOTE(review): the probes and optimizers are created and prepared in a later
    # cell, so this call presumably cannot restore their state. It must be
    # repeated after accelerator.prepare().
    accelerator.load_state(checkpoint_dir)

    progress_tracker = torch.load(progress_tracker_path)
    resume_epoch = progress_tracker["epoch"]
    resume_chunk = progress_tracker["chunk"] + 1  # We start from the *next* chunk
    print(f"Resuming from Epoch {resume_epoch}, Chunk {resume_chunk}")

In [None]:
import time

CHUNK_SIZE = 32  # Subsets cached to disk (then trained on) per chunk
num_chunks = (num_subsets + CHUNK_SIZE - 1) // CHUNK_SIZE

probes = nn.ModuleList([Probe(d_model) for _ in range(n_layers)])
optimizers = [torch.optim.AdamW(probe.parameters(), lr=config["learning_rate"]) for probe in probes]
# Unpack the optimizers: prepare() only wraps objects passed as individual
# arguments and returns a plain list unchanged. An unwrapped optimizer is not
# registered, so its state would be missing from save_state()/load_state().
probes, *optimizers = accelerator.prepare(probes, *optimizers)

# Probe weights and optimizer state can only be restored once they are
# registered with the accelerator, i.e. after prepare().
if resume_from_checkpoint:
    accelerator.load_state(checkpoint_dir)

W_U = W_U.to(DEVICE)
b_U = b_U.to(DEVICE)
ln_final = ln_final.to(DEVICE)


def _activation_path(subset_idx):
    """Return the file holding one subset's cached residuals (1-based numbering)."""
    return f'{ACTIVATION_DIR}/activations_{subset_idx + 1}.pth'


def _cache_chunk_activations(start_subset, end_subset):
    """Run the frozen base model over each subset of the chunk and save its residuals.

    `model` is used as-is (already on DEVICE, never trained). It is deliberately
    not passed to accelerator.prepare_model(): every prepared model is written by
    accelerator.save_state(), so preparing it once per chunk made each checkpoint
    carry full copies of the base-model weights.
    """
    for subset_idx in tqdm(range(start_subset, end_subset), desc="Caching"):
        start = subset_idx * config["subset_size"]
        subset_input = input_ids[start:start + config["subset_size"]].to(DEVICE)
        with torch.no_grad():
            resid_post_outs = get_activations(model, subset_input)
        torch.save([act.to("cpu") for act in resid_post_outs], _activation_path(subset_idx))

        del subset_input, resid_post_outs
        gc.collect()
        torch.cuda.empty_cache()


def _train_probes_on_chunk(start_subset, end_subset):
    """Train every layer's probe on the chunk's cached activations.

    Each activation file is loaded, and its final-layer target logits computed,
    once per subset and shared by all layers. Every probe has its own optimizer,
    so interleaving the layers leaves each probe's update sequence unchanged.

    Gradients are accumulated per probe. accelerator.accumulate() is not used
    because its step counter is shared by all probes and would let partial
    accumulation spill from one layer (or chunk) into the next. Instead,
    accelerator.backward() scales each loss by 1/gradient_accumulation_steps,
    and every optimizer steps after config["accumulation_steps"] subsets and at
    the end of the chunk.
    """
    accumulation_steps = config["accumulation_steps"]
    soft_cap = config["soft_cap"]
    num_chunk_subsets = end_subset - start_subset
    for probe in probes:
        probe.train()

    for local_idx, subset_idx in enumerate(tqdm(range(start_subset, end_subset), desc="Training")):
        acts = torch.load(_activation_path(subset_idx), map_location='cpu')
        with torch.no_grad():
            final_logits_batch = compute_logits_from_resid(
                acts[-1].to(DEVICE).to(torch.float32), soft_cap=soft_cap
            )
        is_last_in_chunk = local_idx + 1 == num_chunk_subsets
        should_step = (local_idx + 1) % accumulation_steps == 0 or is_last_in_chunk

        layer_losses = {}
        for layer, (probe, opt) in enumerate(zip(probes, optimizers)):
            activ_batch = acts[layer].to(DEVICE).to(torch.float32)
            with accelerator.autocast():
                probe_logits = compute_logits_from_resid(probe(activ_batch), soft_cap=soft_cap)
                loss = stable_kl_loss(probe_logits, final_logits_batch)

            accelerator.backward(loss)
            if should_step:
                opt.step()
                opt.zero_grad()
            # backward() divides its own copy of the loss; `loss` itself is the
            # unscaled KL, so it is logged as-is, under one key per layer.
            layer_losses[f"loss/layer_{layer}"] = loss.item()
            del activ_batch, probe_logits, loss
        accelerator.log(layer_losses)
        del acts, final_logits_batch


for epoch in range(resume_epoch, config["num_epochs"]):
    print(f"--- Starting Epoch {epoch + 1}/{config['num_epochs']} ---")

    for chunk_idx in tqdm(range(resume_chunk, num_chunks), initial=resume_chunk, total=num_chunks, desc="Chunk Progress"):
        start_subset = chunk_idx * CHUNK_SIZE
        end_subset = min(start_subset + CHUNK_SIZE, num_subsets)
        print(f"\n-- Processing Chunk {chunk_idx + 1}/{num_chunks} (Subsets {start_subset} to {end_subset-1}) --")

        # === STEP 1: CACHE Activations for the current chunk ===
        caching_start_time = time.time()
        print("Caching activations for this chunk...")
        _cache_chunk_activations(start_subset, end_subset)
        caching_duration = time.time() - caching_start_time

        # === STEP 2: TRAIN Probes on this chunk's activations ===
        training_start_time = time.time()
        print("Training probes on this chunk's activations...")
        _train_probes_on_chunk(start_subset, end_subset)
        training_duration = time.time() - training_start_time

        accelerator.log({
            "caching_duration_sec": caching_duration,
            "training_duration_sec": training_duration
        })

        # === STEP 3: CLEAN UP the cached files for this chunk ===
        print("Cleaning up cached activation files for this chunk...")
        for subset_idx in range(start_subset, end_subset):
            file_path = _activation_path(subset_idx)
            if os.path.exists(file_path):
                os.remove(file_path)

        print(f"Chunk {chunk_idx} complete. Saving checkpoint...")
        accelerator.save_state(checkpoint_dir)
        torch.save({"epoch": epoch, "chunk": chunk_idx}, os.path.join(checkpoint_dir, "progress_tracker.pt"))
        # === END OF CHUNK ===

    resume_chunk = 0

# --- Save Final Probes ---
accelerator.wait_for_everyone()
print("\nSaving final trained probes...")
for i, probe in enumerate(probes):
    unwrapped_probe = accelerator.unwrap_model(probe)
    accelerator.save(unwrapped_probe.state_dict(), f'{PROBE_DIR}/probe_{i}.pt')

accelerator.end_training()

In [None]:
print("Training finished. Deleting the base model from VRAM...")
# `del model` alone frees nothing while other references remain. Drop the
# alias an earlier version of the training cell created (`model_for_caching`,
# if present) and release every object registered with the accelerator.
globals().pop("model_for_caching", None)
accelerator.free_memory()
del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Cleanup complete.")