In [2]:
import torch
from tqdm import tqdm, trange
from nnsight import LanguageModel
import plotly.graph_objects as go
#from pyvene import BoundlessRotatedSpaceIntervention
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from datasets import Dataset as hf_Dataset
from transformers import get_linear_schedule_with_warmup, AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import numpy as np
import gc
import pickle
import matplotlib.pyplot as plt
import glob
from einops import einsum, repeat, reduce

In [3]:
# Load GPT-J 6B through nnsight in bfloat16, placed on the GPU when one is available.
full_model_name = 'EleutherAI/gpt-j-6B'
MODEL_NAME = full_model_name.rsplit('/', 1)[-1]  # 'gpt-j-6B'
device = "cuda" if torch.cuda.is_available() else "cpu"  # plain string is passed as device_map
model = LanguageModel(
    full_model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
    dispatch=True,  # presumably loads/dispatches weights now rather than lazily — confirm in nnsight docs
    # NOTE(review): this flag does not suppress warnings; it allows the hub repo to run its
    # own modeling code. GPT-J is a built-in transformers architecture, so it is likely
    # unnecessary — confirm before keeping it.
    trust_remote_code=True,
)
remote = False  # NOTE(review): not read in the visible cells; presumably an nnsight remote-execution toggle
NLAYERS = model.config.num_hidden_layers

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

In [4]:
def get_correct_df(mina = 0, maxa = 99):
    """Load the cached addition dataset identified by ``mina``/``maxa``.

    Only the rows whose ``correct`` column equals 1 are kept. The original index
    labels are preserved, so the index is not reset.
    """
    source_path = f'data_addition/gen_math/data_addition_correct_{mina}_{maxa}.pkl'
    problems = pd.read_pickle(source_path)
    keep_mask = problems['correct'] == 1
    return problems[keep_mask]

def gen_intervention(samplesize = 1000, run = False, mina = 0, maxa = 99):
    """Build (when ``run``) and then load a cached sample of (original, intervened) pairs.

    With ``run=True``:
      * the correct problems from ``get_correct_df(mina, maxa)`` are cross-joined with
        themselves (columns suffixed ``_original`` / ``_intervened``);
      * pairs whose answers are equal are discarded;
      * ``samplesize`` pairs are drawn with ``random_state=42`` and pickled.
    pandas raises if fewer than ``samplesize`` pairs remain.

    In every case the pickle is read back from disk and returned.

    NOTE(review): the cross join materialises len(df)**2 rows before sampling, which can be
    very memory-hungry for large operand ranges.
    """
    cache_path = f'data_addition/data_intervention_{mina}_{maxa}_sample{samplesize}.pkl'
    if run:
        correct_rows = get_correct_df(mina = mina, maxa = maxa)
        print(len(correct_rows))
        # Every remaining row is correct, so the flag carries no information.
        correct_rows = correct_rows.drop('correct', axis=1)
        all_pairs = correct_rows.merge(correct_rows, how='cross', suffixes=('_original', '_intervened'))
        # An intervention is only informative if it changes the expected answer.
        answer_changes = all_pairs['answer_original'] != all_pairs['answer_intervened']
        all_pairs = all_pairs[answer_changes]
        sampled_pairs = all_pairs.sample(n=samplesize, random_state=42).reset_index(drop = True)
        sampled_pairs.to_pickle(cache_path)
    return pd.read_pickle(cache_path)

def get_sorted_neuron_df():
    """Load the per-neuron attribution-patching table, sorted by ``logit_difference``.

    Rows are ordered from largest to smallest and re-indexed 0..n-1, so each
    row's index label is its rank.
    """
    return (
        pd.read_csv('data_addition/neuron_att_patching.csv')
        .sort_values('logit_difference', ascending=False)
        .reset_index(drop=True)
    )

def get_top_percentile_neurondf(fraction):
    """Return the top ``fraction`` of neurons by ``logit_difference``.

    ``fraction`` is a fraction, not a percentile (e.g. 0.1 keeps the top 10%).
    The row count is ``int(len * fraction)``, i.e. truncated, and the result
    is re-indexed 0..k-1.
    """
    ranked = get_sorted_neuron_df()
    keep_count = int(len(ranked) * fraction)
    return ranked.iloc[:keep_count].reset_index(drop=True)


# Load the cached 1000-pair intervention sample from disk (run=False: no regeneration).
df_int = gen_intervention(samplesize=1000, run=False)

In [18]:
def calc_logit_diff(patched_logits, og_logits, cor_answer_tokens):
    """Per-example change in a chosen token's logit caused by patching.

    Args:
        patched_logits: 2-D (batch, vocab) logits from the patched run.
        og_logits: 2-D (batch, vocab) logits from the unpatched run.
        cor_answer_tokens: one token id per row, selecting which logit to compare.

    Returns:
        1-D tensor whose entry b is
        ``patched_logits[b, tok_b] - og_logits[b, tok_b]``.
    """
    row_ids = torch.arange(len(cor_answer_tokens))
    return patched_logits[row_ids, cor_answer_tokens] - og_logits[row_ids, cor_answer_tokens]

def log_diff_metrics(log_diff):
    """Summarise a tensor of per-example logit differences.

    Returns a dict of 0-d tensors:
        'mean': mean over all entries (NaN for an empty tensor, as torch returns).
        'nonzero_mean': mean over entries that are != 0, i.e. examples where the patch
            moved the logit at all; 0 when there are no such entries.
    The fallback zero has the same dtype/device as ``log_diff``.
    """
    mean = log_diff.mean()
    nonzero = log_diff[log_diff != 0]
    # Previously this fallback was a CPU float32 literal regardless of the input's
    # dtype/device. Match the input so both dict entries are consistent.
    nonzero_mean = nonzero.mean() if nonzero.numel() > 0 else log_diff.new_zeros(())
    return {'mean': mean, 'nonzero_mean': nonzero_mean}

def path_patch_neuron(layer, neuron_idx, num_to_use = 100, batch_size=100):
    """Path-patch one MLP neuron's direct contribution at the final token position.

    For each (clean, corrupt) prompt pair from the cached intervention sample:
      * the neuron's output-space contribution at the last position
        (activation * its ``fc_out`` column) is computed on both prompts;
      * the clean prompt is re-run with (corrupt - clean) contribution added to the
        output of the LAST transformer block at the last position.
    The change therefore reaches the logits only through what follows the last
    block, not through any intermediate layer.

    Args:
        layer: index into ``model.transformer.h`` of the neuron's block.
        neuron_idx: index of the neuron along the MLP ``act`` / ``fc_out`` input dimension.
        num_to_use: number of pairs sampled (``random_state=42``) from the intervention
            set; ``None`` uses all of them.
        batch_size: pairs per trace. The last batch may be smaller.

    Returns:
        1-D CPU tensor with one value per pair: patched minus clean logit of the first
        token of the intervened answer, at the final position.
    """
    df = gen_intervention(run = False)
    if num_to_use is not None:
        df = df.sample(n=num_to_use, random_state=42)
    # NOTE(review): torch.stack requires every prompt to tokenise to the same length.
    clean_tokens = torch.stack([torch.tensor(x) for x in df['q_tok_original'].values])
    corrupt_tokens = torch.stack([torch.tensor(x) for x in df['q_tok_intervened'].values])
    corrupt_answer_tokens = torch.stack([torch.tensor(x[0]) for x in df['answer_tok_intervened'].values])
    with torch.no_grad():
        all_log_diffs = []
        for i in range(0, len(clean_tokens), batch_size):
            # Use the notebook-level `device` (CPU fallback) instead of hard-coding 'cuda'.
            batch_clean = clean_tokens[i:i+batch_size].to(device)
            batch_corrupt = corrupt_tokens[i:i+batch_size].to(device)
            batch_corrupt_answers = corrupt_answer_tokens[i:i+batch_size]
            # Real size of this batch: the final slice can be shorter than batch_size,
            # so the logits must be split by n, not by batch_size.
            n = len(batch_clean)
            with model.trace() as tracer:
                # Invoke 1: unpatched clean run (first n rows of the combined output).
                with tracer.invoke(batch_clean):
                    pass
                # Invoke 2: corrupt run; record the neuron's contribution in model space.
                with tracer.invoke(batch_corrupt):
                    neuron_val = model.transformer.h[layer].mlp.act.output[:,-1, neuron_idx].unsqueeze(1).save() #n, 1
                    neuron_projection = model.transformer.h[layer].mlp.fc_out.weight[:, neuron_idx].unsqueeze(0)  #1, model_dim
                    neuron_act_corrupt = (neuron_val * neuron_projection).save()
                # Invoke 3: clean run with the neuron's direct path swapped to its corrupt
                # value (last n rows of the combined output).
                with tracer.invoke(batch_clean):
                    neuron_val = model.transformer.h[layer].mlp.act.output[:,-1, neuron_idx].unsqueeze(1).save() #n, 1
                    neuron_projection = model.transformer.h[layer].mlp.fc_out.weight[:, neuron_idx].unsqueeze(0)  #1, model_dim
                    neuron_act_clean = (neuron_val * neuron_projection).save()
                    model.transformer.h[-1].output[0][:,-1] += neuron_act_corrupt - neuron_act_clean

                output = model.output.save()
            clean_logits, patched_logits = output.logits[:n,-1].cpu(), output.logits[-n:,-1].cpu()
            log_diff = calc_logit_diff(patched_logits, clean_logits, batch_corrupt_answers)
            all_log_diffs.append(log_diff)
        logdiffs = torch.cat(all_log_diffs)
        return logdiffs


# Sanity check on one neuron: mean patched-minus-clean logit over 300 pairs.
path_patch_neuron(layer=20, neuron_idx=7741, num_to_use=300, batch_size=300).mean()

tensor(0.1292)

In [19]:
def patch_patch_all(frac = 0.01, num_to_use=300, batch_size=300,
                    out_path='data_addition/neuron_path_patching.csv'):
    """Path-patch the top ``frac`` of neurons ranked by attribution-patching ``logit_difference``.

    For each selected neuron, ``path_patch_neuron`` is run, and the mean and non-zero
    mean of its per-pair logit differences are stored in ``path_patch`` and
    ``path_patch_nonzero``. Results are checkpointed to ``out_path`` whenever the rank is
    a multiple of 1000, and written once more at the end. Both writes use the same
    format (no index column).

    The function name is kept for existing callers; it presumably means "path_patch_all".

    Args:
        frac: fraction of neurons to process, taken from the top of the ranking.
        num_to_use: forwarded to ``path_patch_neuron`` (defaults match the original run).
        batch_size: forwarded to ``path_patch_neuron`` (defaults match the original run).
        out_path: CSV destination for checkpoints and the final table.

    Returns:
        The DataFrame of processed neurons with the two result columns filled in.
    """
    ndf = get_sorted_neuron_df()
    ndf['path_patch'] = float('nan')
    ndf['path_patch_nonzero'] = float('nan')
    # Keep the top `frac` rows. get_sorted_neuron_df re-indexes by rank and nlargest
    # keeps index labels, so `i` below is the neuron's rank (0 = largest).
    ndf = ndf.nlargest(int(len(ndf) * frac), 'logit_difference')
    bar = tqdm(ndf.iterrows(), total=len(ndf))
    for i, row in bar:
        layer, neuron_idx = int(row['layer']), int(row['neuron_idx'])
        log_diff = path_patch_neuron(layer, neuron_idx, num_to_use=num_to_use, batch_size=batch_size)
        metrics = log_diff_metrics(log_diff)
        ndf.loc[i, 'path_patch'] = metrics['mean'].item()
        ndf.loc[i, 'path_patch_nonzero'] = metrics['nonzero_mean'].item()

        if i % 1000 == 0:
            ndf.to_csv(out_path, index=False)
        bar.set_postfix({'rank':i, 'log_diff':f"{row['logit_difference']:.4f}", 'path_patch':f"{ndf.loc[i, 'path_patch']:.4f}"})
    # Final save uses index=False, like the checkpoints. It previously wrote an extra
    # unnamed index column, so the final file's schema differed from the checkpoints.
    ndf.to_csv(out_path, index=False)
    return ndf

#patch_patch_all()

100%|██████████| 4587/4587 [2:00:35<00:00,  1.58s/it, rank=4586, log_diff=0.0010, path_patch=-0.0002]
