In [1]:
%load_ext autoreload
%autoreload 2

import os
from os.path import expanduser
# User's home directory; used further down to build the path to the saved steering tensors.
home = expanduser("~")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import pandas as pd
from neel_plotly import line, imshow, scatter
from jaxtyping import Float
import plotly.io as pio

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching
# Turn autograd off globally: the cells in this notebook only run forward passes
# (clean, zero-ablated and steered), so no gradients need to be tracked.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f8b525cd910>

In [None]:
from load_dataset import get_batched_dataset


In [None]:
# Default device as reported by the TransformerLens helper.
device: torch.device = utils.get_device()
n_devices = torch.cuda.device_count()

# Load Gemma-2B with only the unembedding centred; no LayerNorm folding,
# no writing-weight centring and no value-bias folding.
model = HookedTransformer.from_pretrained(
    "gemma-2b",
    center_unembed=True,
    center_writing_weights=False,
    fold_ln=False,
    fold_value_biases=False,
    #n_devices=2
)
# Enable the per-head attention result activations; the cells below hook the
# 'result' activation of a single head.
model.set_use_attn_result(True)

In [None]:
from load_dataset import load_sva_dataset

# Dataset configuration
dataset_type = 'both' # singular / plural / both
language = 'english' # english / spanish / both
num_samples = 200
batch_size = 10
start_at = 0

# Load the subject-verb-agreement examples, then split them into batches of
# base tokens, source tokens and answer token indices.
dataset = load_sva_dataset(model, language, dataset_type, num_samples)
batches_base_tokens, batches_src_tokens, batches_answer_token_indices = get_batched_dataset(
    model,
    dataset['base_list'],
    dataset['src_list'],
    dataset['answers'],
    batch_size=batch_size,
)

In [None]:
# Getting steering vectors (PCA directions or difference in means)
# `safe_open` was used here without ever being imported (NameError on a fresh kernel).
from safetensors import safe_open

tensors_dir = f"{home}/mats/sva_tensors"
attn_layer_index = 13
attn_head_index = 7
hook_name = utils.get_act_name('result', attn_layer_index)
method = 'pca' # diff_means / pca

tensors_path = f"{tensors_dir}/{model.cfg.model_name}_both_english_{hook_name}_{method}_singular_plural.safetensors"
# NOTE(review): device=0 is handed straight to safetensors (presumably CUDA device 0,
# so this cell needs a GPU); `direction` is moved to `device` again when the hooks are built.
with safe_open(tensors_path, framework="pt", device=0) as f:
    tensors_dict_ = {k: f.get_tensor(k) for k in f.keys()}

if method == 'pca':
    # pca_component, head_index, d_repre -> first component of the selected head
    direction = tensors_dict_['directions'][0][attn_head_index]
elif method == 'diff_means':
    direction = tensors_dict_['direction'][attn_head_index]
else:
    # Fail loudly instead of leaving `direction` undefined (or stale from a previous run).
    raise ValueError(f"Unknown steering method {method!r}; expected 'pca' or 'diff_means'")


# adding 8*(PCA 1st component) flips prediction in singular (+) and plural (-)
# adding 8*(diff_means) flips prediction in singular (-) and plural (+)

In [None]:
# Per-batch logit-difference accumulators (clean / zero-ablated / steered runs);
# the loop further down appends one value per batch to each of them.
clean_logit_list, zeroed_logit_list, patched_logit_list = [], [], []

def patch_hook(acts, hook, attn_head_index, steering_vector):
    """Steer one attention head at the final position along a fixed direction.

    The steering vector is normalised to unit length, scaled by the batch-mean
    L2 norm of the selected head's last-position output, and subtracted from
    that head's last-position output. The incoming tensor is not modified; a
    steered clone is returned.

    Args:
        acts: hook activation, indexed here as [batch, pos, head_index, d_model].
        hook: hook point passed in by the hooking machinery (unused).
        attn_head_index: index of the attention head to steer.
        steering_vector: direction to subtract; normalised inside this function.

    Returns:
        A clone of ``acts`` whose ``[:, -1, attn_head_index]`` slice has been steered.
    """
    unit_direction = steering_vector / steering_vector.norm()
    steered = acts.clone()

    # Use the head's own typical output magnitude (mean over the batch) as the
    # step size, so the push is on the same scale as what the head writes.
    head_norms = steered[:, -1, attn_head_index].norm(dim=-1)
    push_scale = head_norms.mean().item()

    steered[:, -1, attn_head_index] -= push_scale * unit_direction.unsqueeze(0)
    return steered

def zero_hook(acts, hook, attn_head_index):
    """Zero-ablate one attention head's output at the final position.

    Args:
        acts: hook activation, indexed here as [batch, pos, head_index, d_model].
        hook: hook point passed in by the hooking machinery (unused).
        attn_head_index: index of the attention head to ablate.

    Returns:
        A clone of ``acts`` with ``[:, -1, attn_head_index, :]`` set to zero;
        the input tensor itself is left untouched.
    """
    new_acts = acts.clone()
    # Assigning a scalar broadcasts over the slice, so this works for any
    # d_model, dtype and device (previously a hard-coded 2048-wide zero tensor
    # was built on the notebook-global `device`).
    new_acts[:, -1, attn_head_index, :] = 0
    return new_acts

# Both interventions act on the same activation: the per-head attention result
# of the chosen layer. (Earlier experiments also tried utils.get_act_name('resid_mid', 13).)
hook_to_steer = hook_name
hook_to_zero = hook_name

# Bind the head index (and, for steering, the direction moved onto `device`)
# so the hooks have the (acts, hook) signature expected by run_with_hooks.
zero_hook_fn = partial(zero_hook, attn_head_index=attn_head_index)
patch_hook_fn = partial(
    patch_hook,
    attn_head_index=attn_head_index,
    steering_vector=direction.to(device),
)

# Run every batch three times (clean, zero-ablated head, steered head) and
# record the logit difference of each run.
#
# NOTE(review): `get_logit_diff` is neither defined nor imported in the visible
# cells; it has to be brought into scope before this loop can run on a fresh kernel.

# The loop bound used to be the undefined name `batchs` (NameError on a fresh
# kernel); derive it from the batched dataset instead.
n_batches = len(batches_base_tokens)

for batch in range(n_batches):
    batch_tokens = batches_base_tokens[batch]

    model.reset_hooks(including_permanent=True)
    # Clean run
    base_logits = model(batch_tokens)
    # Top-20 next-token predictions for the first prompt of the batch
    base_logits_idxs = base_logits[0, -1, :].topk(20, dim=-1).indices
    logit_toks = [model.to_str_tokens(x) for x in base_logits_idxs]
    print('before', logit_toks)

    # Zero ablation run (zero ablate attention head)
    zeroed_logits = model.run_with_hooks(
        batch_tokens,
        return_type="logits",
        fwd_hooks=[(hook_to_zero, zero_hook_fn)],
    )

    model.reset_hooks(including_permanent=True)
    # Patched (steered) run
    patched_logits = model.run_with_hooks(
        batch_tokens,
        return_type="logits",
        fwd_hooks=[(hook_to_steer, patch_hook_fn)],
    )
    patched_logits_idxs = patched_logits[0, -1, :].topk(20, dim=-1).indices
    patched_toks = [model.to_str_tokens(x) for x in patched_logits_idxs]
    print('after', patched_toks)

    # `batch`, `base_logits`, `patched_logits` and `answer_token_indices` keep
    # their last-iteration values and are read again by the results-table cell.
    answer_token_indices = batches_answer_token_indices[batch]
    answer_token_indices = answer_token_indices.to(base_logits.device)

    base_logit_diff = get_logit_diff(base_logits, answer_token_indices).item()
    clean_logit_list.append(base_logit_diff)

    zeroed_logit_diff = get_logit_diff(zeroed_logits, answer_token_indices).item()
    zeroed_logit_list.append(zeroed_logit_diff)

    patched_logit_diff = get_logit_diff(patched_logits, answer_token_indices, mean=True).item()
    patched_logit_list.append(patched_logit_diff)

print(f"Base logit diff: {np.array(clean_logit_list).mean():.4f}")
#print(f"Zero ablation logit diff: {np.array(zeroed_logit_list).mean():.4f}")
print(f"Patched logit diff: {np.array(patched_logit_list).mean():.4f}")



In [None]:
# model.reset_hooks(including_permanent=True)
# model.add_perma_hook(hook_to_steer, patch_hook_fn)

In [None]:
from rich import print as rprint
from rich.table import Table, Column

# This cell inspects the LAST batch processed by the steering loop: it reads
# `batch`, `base_logits`, `patched_logits` and `answer_token_indices` as left
# behind by the final loop iteration.

# Column 0 of the answer indices is used as the correct-answer token id.
correct_token_ids = answer_token_indices[:, 0].unsqueeze(1)

# Select the last position before the softmax (same result, far less work than
# normalising every position), and squeeze only the gathered dim so a batch of
# size 1 still yields a 1-D tensor that can be zipped below.
clean_probs = base_logits[:, -1].softmax(dim=-1)
clean_answer_probs = clean_probs.gather(-1, correct_token_ids).squeeze(-1)

patched_probs = patched_logits[:, -1].softmax(dim=-1)
patched_answer_probs = patched_probs.gather(-1, correct_token_ids).squeeze(-1)

# The names `batches_base_list` / `batches_answers` used here previously are not
# defined anywhere in the notebook. Rebuild the prompts/answers of this batch
# from the flat dataset lists instead.
# NOTE(review): assumes get_batched_dataset splits the lists into consecutive,
# in-order chunks of `batch_size` — confirm against load_dataset.
batch_slice = slice(batch * batch_size, (batch + 1) * batch_size)
batch_prompts = dataset['base_list'][batch_slice]
batch_answers = dataset['answers'][batch_slice]
assert len(batch_prompts) == len(batch_answers) == clean_answer_probs.shape[0], \
    "prompts/answers of this batch do not line up with the logits"

cols = [
    "Prompt",
    Column("Correct", style="rgb(0,200,0) bold"),
    #Column("Incorrect", style="rgb(255,0,0) bold"),
    Column("Prob correct", style="bold"),
    Column("Patched prob correct", style="bold")
]
# The table reports probabilities, not logit differences, so title it accordingly.
table = Table(*cols, title="Correct-answer probabilities (clean vs. patched)")

for prompt, answer, prob_clean_answer, prob_patched_answer in zip(batch_prompts, batch_answers, clean_answer_probs, patched_answer_probs):
    table.add_row(prompt, repr(answer[0]), f"{prob_clean_answer.item():.3f}", f"{prob_patched_answer.item():.3f}")

rprint(table)
