In [2]:
import torch
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
import json
import pandas as pd
import plotly.express as px
from tqdm.auto import tqdm
import gc

# --- CONFIGURATION ---
# Manually select a GPU that has free memory.
# Based on your nvidia-smi, cuda:2 or cuda:3 are good choices.
TARGET_DEVICE = "cuda:2" 

# --- Load the MODELING dataset ---
# Make sure you have a file named 'modeling_dataset.json' in the same directory
# or provide the correct full path to the file.
try:
    with open('modeling_dataset.json', 'r') as f:
        modeling_dataset = json.load(f)
except FileNotFoundError:
    print("Error: 'modeling_dataset.json' not found. Please ensure the dataset file is in the correct directory.")
    exit()

# --- Load the model using TransformerLens ---
print(f"Loading model with TransformerLens onto device: {TARGET_DEVICE}...")
model = HookedTransformer.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct",
    device=TARGET_DEVICE,
    trust_remote_code=True,
    fold_ln=False, 
    center_unembed=False,
    center_writing_weights=False
)
print("Model loaded successfully.")


# --- FULL DATASET EXPERIMENT ---

all_results = []
# Use tqdm for a progress bar over the problems
for problem in tqdm(modeling_dataset['problems'], desc="Processing Problems"):
    problem_name = problem['name']
    
    # Inner loop for each question in the current problem
    for question_index in range(len(problem['questions'])):
        
        # --- Prepare prompts for the current question ---
        question = problem['questions'][question_index]
        correct_answer = problem['answers'][question_index]

        if not question or not correct_answer:
            continue

        examples = "\n".join(problem['data'])
        clean_prompt_text = f"Translate based on the examples:\n{examples}\n\n{question}"
        
        corrupted_prompt_text = f"Translate based on the examples:\n{examples}\n\n{question} Note: this is a test."

        # --- Tokenize all inputs ---
        # THE FIX: Explicitly move the token tensors to the TARGET_DEVICE.
        clean_tokens = model.to_tokens(clean_prompt_text).to(TARGET_DEVICE)
        corrupted_tokens = model.to_tokens(corrupted_prompt_text).to(TARGET_DEVICE)
        
        try:
            correct_answer_tokens = model.to_tokens(correct_answer, prepend_bos=False)
            if correct_answer_tokens.nelement() == 0: continue
            correct_answer_first_token_id = correct_answer_tokens[0, 0]
            correct_answer_first_token_str = model.to_string([correct_answer_first_token_id])
        except:
            continue

        # --- 1. Get Baseline Logits and Clean Cache ---
        clean_logits, clean_cache = model.run_with_cache(clean_tokens)
        corrupted_logits = model(corrupted_tokens)

        clean_logit_val = clean_logits[0, -1, correct_answer_first_token_id].item()
        corrupted_logit_val = corrupted_logits[0, -1, correct_answer_first_token_id].item()
        
        if abs(clean_logit_val - corrupted_logit_val) < 1e-6:
            continue

        # --- 2. Run the Patching Experiment Across All Layers ---
        hook_name_template = "blocks.{layer}.attn.hook_z"

        for layer in range(model.cfg.n_layers):
            hook_name = hook_name_template.format(layer=layer)
            
            def patching_hook(activation_at_hook, hook):
                clean_activation = clean_cache[hook.name]
                clean_len = clean_activation.shape[1]
                activation_at_hook[:, :clean_len, :] = clean_activation
                return activation_at_hook

            patched_logits = model.run_with_hooks(
                corrupted_tokens,
                fwd_hooks=[(hook_name, patching_hook)]
            )
            
            patched_logit_val = patched_logits[0, -1, correct_answer_first_token_id].item()
            
            recovery_percentage = ((patched_logit_val - corrupted_logit_val) / (clean_logit_val - corrupted_logit_val)) * 100
            
            all_results.append({
                'problem_name': problem_name,
                'question_index': question_index,
                'layer': layer, 
                'recovery': recovery_percentage
            })
            
        # Clear memory after processing each question
        del clean_cache, clean_logits, corrupted_logits
        gc.collect()
        torch.cuda.empty_cache()


# --- 3. Aggregate and Visualize the Results ---
print("\n--- Experiment Complete. Aggregating and Visualizing Results... ---")
if not all_results:
    print("No results to visualize. The experiment may have skipped all problems.")
else:
    df = pd.DataFrame(all_results)
    
    agg_df = df.groupby(['problem_name', 'layer']).agg(mean_recovery=('recovery', 'mean')).reset_index()

    agg_df['layer'] = agg_df['layer'].astype(str).str.zfill(2)

    fig = px.bar(
        agg_df, 
        x='layer', 
        y='mean_recovery', 
        facet_col='problem_name',
        facet_col_wrap=4,
        title='Mean Logit Recovery Percentage by Layer Across Different Problems',
        labels={'layer': 'Transformer Layer', 'mean_recovery': 'Mean Logit Recovery (%)'},
        template='plotly_white',
        height=300 * ((len(agg_df['problem_name'].unique()) // 4) + 1)
    )
    fig.update_yaxes(matches=None)
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    fig.show()



Loading model with TransformerLens onto device: cuda:2...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.34it/s]


Loaded pretrained model Qwen/Qwen2.5-3B-Instruct into HookedTransformer
Model loaded successfully.


Processing Problems:   0%|          | 0/48 [00:00<?, ?it/s]


RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)