# Causal Analysis of Symmetric Facts
We analyze where facts might be located in an LLM. A fact is represented by a tuple of subject, relation, and object, (s, r, o). We also investigate where the corresponding inverse fact is located. An inverse fact is represented by the tuple (o, r^-1, s). For example, the fact "Paris is the capital of France" has the inverse fact "France's capital is Paris".
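
As a minimal sketch of this notation, a fact and its inverse can be written as a small data structure. The `Fact` class and the relation names `capital_of` / `has_capital` are illustrative assumptions and are not part of the notebook's code or dataset.

```python
from dataclasses import dataclass

# Hypothetical mapping from a relation r to its inverse r^-1 (illustrative names only).
INVERSE_RELATION = {"capital_of": "has_capital", "has_capital": "capital_of"}


@dataclass(frozen=True)
class Fact:
    subject: str
    relation: str
    obj: str

    def inverse(self) -> "Fact":
        # (s, r, o) -> (o, r^-1, s)
        return Fact(self.obj, INVERSE_RELATION[self.relation], self.subject)


fact = Fact("Paris", "capital_of", "France")
inverse_fact = fact.inverse()
```

Inverting twice returns the original tuple, which is the symmetry the title refers to.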

## Setup

In [4]:
try:
    import google.colab
    IN_COLAB = True
    # %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    renderer = "colab"
except ImportError:
    IN_COLAB = False
    from IPython import get_ipython
    %load_ext autoreload
    %autoreload 2
    renderer = "jupyterlab"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
%%bash
cd ../
pip install poetry
poetry install
cd notebooks

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
Collecting urllib3<2.0.0,>=1.26.0 (from poetry)
  Using cached urllib3-1.26.16-py2.py3-none-any.whl (143 kB)
Collecting platformdirs<4.0.0,>=3.0.0 (from poetry)
  Using cached platformdirs-3.10.0-py3-none-any.whl (17 kB)
Installing collected packages: urllib3, platformdirs
  Attempting uninstall: urllib3
    Found existing installation: urllib3 2.0.3
    Uninstalling urllib3-2.0.3:
      Successfully uninstalled urllib3-2.0.3
  Attempting uninstall: platformdirs
    Found existing installation: platformdirs 3.8.0
    Uninstalling platformdirs-3.8.0:
      Successfully uninstalled platformdirs-3.8.0
Successfully installed platformdirs-3.10.0 urllib3-1.26.16
Installing dependencies from lock file

Package operations: 0 installs, 2 updates, 0 removals

  • Updating urllib3 (1.26.16 -> 2.0.3)
  • Updating platformdirs (3.10.0 -> 3.8.0)

Installing the current project: transformer-lens (0.0.0)


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sagemaker 2.167.0 requires importlib-metadata<5.0,>=1.4.0, but you have importlib-metadata 5.2.0 which is incompatible.
sagemaker 2.167.0 requires protobuf<4.0,>=3.1, but you have protobuf 4.23.3 which is incompatible.
sparkmagic 0.20.5 requires nest-asyncio==1.5.5, but you have nest-asyncio 1.5.6 which is incompatible.


In [6]:

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = renderer

In [7]:


# Import stuff
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader
import json
from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy
import subprocess

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
from tqdm import tqdm

In [8]:
import ast
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.cm import ScalarMappable

Matplotlib is building the font cache; this may take a moment.


In [9]:
if IN_COLAB: 
    import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens.utilities import devices
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off to save GPU memory, since this notebook focuses on model inference rather than training.

In [10]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fbb22d548e0>

Plotting helper functions:

In [11]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def cuda():
    return torch.cuda.is_available()

def get_device(): 
    return "cuda" if cuda() else "cpu"

def save_pickle(obj, filename):
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)

# Function to load a pickle object from a file
def load_pickle(filename):
    with open(filename, 'rb') as f:
        obj = pickle.load(f)
    return obj

def save_json(data, file_path):
    with open(file_path, "w") as json_file:
        json.dump(data, json_file, indent=4)

def load_json(file_path):
    with open(file_path, "r") as json_file:
        loaded_data = json.load(json_file)
    return loaded_data

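def get_gpu_usage(): 
    #run nvidia-smi and return its text output so the caller can print or inspect it
    result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    return result.stdout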

device = get_device()
device

'cuda'

## Model

- fold_ln: Whether to fold the LayerNorm weights into the subsequent linear layer. This does not change the computation.
- center_writing_weights: Whether to center the weights that write to the residual stream (i.e. set their mean to zero). Due to LayerNorm, this doesn't change the computation.
- center_unembed: Whether to center W_U (i.e. set its mean to zero). Softmax is translation invariant, so this doesn't affect log probs or loss, but it does change the logits. Defaults to True.
- refactor_factored_attn_matrices: Whether to convert the factored matrices (W_Q & W_K, and W_O & W_V) to be "even".
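
The claim about `center_unembed` rests on softmax being translation invariant. The following is a minimal, standard-library sketch of that property on made-up logits; it does not call the model, and subtracting the mean here only mimics the effect that centering W_U has on the logits.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax over a plain Python list.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


logits = [2.0, -1.0, 0.5]  # made-up values for illustration
mean = sum(logits) / len(logits)
centered = [x - mean for x in logits]  # shift every logit by the same constant

original_log_probs = log_softmax(logits)
centered_log_probs = log_softmax(centered)

# The logits differ, but the log probabilities agree up to floating-point error.
max_abs_diff = max(abs(a - b) for a, b in zip(original_log_probs, centered_log_probs))
```

This is why the logits reported later in the notebook are best read relative to their mean rather than as absolute values.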

In [12]:
MODEL_NAME = "gpt2-large"
model = HookedTransformer.from_pretrained(
        MODEL_NAME,
        center_unembed=True,  
        center_writing_weights=True,              # Whether to center weights writing to the residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the computation.      
        fold_ln=True,                             # Whether to  fold in the LayerNorm weights to the subsequent linear layer.
        refactor_factored_attn_matrices=True,
    )


 

Downloading (…)lve/main/config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


## Dataset
We use the CounterFact dataset.
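
Judging only from the fields that `sample_dataset` reads, each record is expected to carry a prompt template, a subject, and the true target under `requested_rewrite`. The record below is a hypothetical example built from the notebook's own default values; real entries may contain additional fields.

```python
# Hypothetical record showing only the fields this notebook reads.
example_record = {
    "requested_rewrite": {
        "prompt": "The {} is located in",
        "subject": "Eiffel Tower",
        "target_true": {"str": "Paris"},
    }
}

rewrite = example_record["requested_rewrite"]
# The prompt is a template; the subject is substituted for the "{}" placeholder.
filled_prompt = rewrite["prompt"].format(rewrite["subject"])
true_target = rewrite["target_true"]["str"]
```

Keeping the prompt as a template is what lets the ablation functions substitute a corrupted subject into the same sentence frame.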

In [13]:

dataset = load_json("data/fact_dataset.json")

In [14]:
def sample_dataset(dataset, idx = None): 
    if idx is None: 
        prompt = "The {} is located in"
        subject = "Eiffel Tower"
        target = "Paris"
    else: 
        sample = dataset[idx]
        prompt = sample["requested_rewrite"]["prompt"]
        subject = sample["requested_rewrite"]["subject"]
        target = sample["requested_rewrite"]["target_true"]["str"]
    return prompt, subject, target

In [15]:
sample = sample_dataset(dataset)
prompt, subject, target = sample
prompt = prompt.format(subject)

In [16]:
model.to_tokens(prompt).shape

torch.Size([1, 9])

## Ablation Methods: 
- Noise ablation 
- Resample ablation
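
The two corruption strategies can be illustrated on a toy vocabulary. This is a sketch under stated assumptions: the 2-d embeddings are made up, and decoding uses the largest dot product with the toy embeddings, whereas the notebook decodes with `model.unembed` followed by `argmax`. All names here are hypothetical.

```python
import random

# Toy vocabulary with made-up 2-d embeddings (illustrative values only).
TOY_EMBEDDINGS = {
    "Paris": [1.0, 0.0],
    "Tower": [0.0, 1.0],
    "cozy": [-1.0, 0.0],
    "curls": [0.0, -1.0],
}


def decode(vector):
    # Return the token whose embedding has the largest dot product with `vector`.
    def score(token):
        return sum(a * b for a, b in zip(TOY_EMBEDDINGS[token], vector))
    return max(TOY_EMBEDDINGS, key=score)


def noise_corrupt(token, scale, rng):
    # Noise ablation: add zero-mean Gaussian noise to the embedding, then decode.
    noisy = [x + rng.gauss(0.0, scale) for x in TOY_EMBEDDINGS[token]]
    return decode(noisy)


def resample_corrupt(rng):
    # Resample ablation: replace the token with one drawn at random from the vocabulary.
    return rng.choice(list(TOY_EMBEDDINGS))


rng = random.Random(0)
unchanged = noise_corrupt("Paris", scale=0.0, rng=rng)  # zero noise decodes to the same token
noisy_token = noise_corrupt("Paris", scale=3.0, rng=rng)
resampled_token = resample_corrupt(rng)
```

Noise ablation stays in the neighbourhood of the original embedding when the scale is small, while resampling ignores the original subject entirely.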

In [17]:
def noise_ablation(prompt, subject, target, n_noise_samples=5, vx=3):
    subject_tokens = model.to_tokens(subject)
    
    #shape: batch, n_tokens, embedding_dim
    subject_embedding = model.embed(subject_tokens)
    _, n_tokens, embedding_dim = subject_embedding.shape
    
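    #noise: zero-mean Gaussian with standard deviation v = vx*std(embedding), per embedding dimension
    embedding = model.W_E
    v = vx*torch.std(embedding, dim=0) #one value per embedding dimension
    #scale the unit-variance noise by v (adding v would shift the mean instead)
    noise = torch.randn(
        (n_noise_samples, n_tokens, embedding_dim)
    ).to(device) * v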
    
    subject_embedding_w_noise = subject_embedding + noise
    
    #shape: batch, n_tokens, vocab_size (logits)
    unembedded_subject = model.unembed(subject_embedding_w_noise)

    noisy_subject_tokens = torch.argmax(unembedded_subject, dim=-1)
    noisy_subject_str = [
        model.to_string(nst) for nst in noisy_subject_tokens
    ]
    true_prompt = prompt.format(subject)
    corrupted_prompts = [
        prompt.format(nss) for nss in noisy_subject_str
    ]
    return true_prompt, corrupted_prompts, target

In [19]:
prompt, subject, target = sample_dataset(dataset)
noise_ablation(prompt, subject, target)

('The Eiffel Tower is located in',
 ['The  Zan agre¶Unionfacts is located in',
  'The benefjadOntfenande is located in',
  'The Sharp smackitsch MAGeper is located in',
  'The  Canceiasmrothdescriptionzek is located in',
  'The idth nominal subsequ unemploy Conserv is located in'],
 'Paris')

In [20]:
def resample_ablation(prompt, subject, target, n_noise_samples=20):
    subject_tokens = model.to_tokens(subject)
    embedding = model.W_E
    #we select n random rows from the embedding matrix
    permutations = torch.randperm(embedding.size(0))[:n_noise_samples]
    random_samples = embedding[permutations]
    #unsqueeze a token dimension between batch and embedding dims
    random_samples = random_samples.unsqueeze(dim=1)
    #we de-embed these rows
    random_embeddings = model.unembed(random_samples)
    random_tokens = torch.argmax(random_embeddings, dim=-1)
    random_subject_str = [
        model.to_string(t) for t in random_tokens
    ]
    corrupted_facts = [
        prompt.format(s) for s in random_subject_str
    ]
    true_fact = prompt.format(subject)
    return true_fact, corrupted_facts, target
    

    
    

In [21]:
resample_ablation(prompt, subject, target, n_noise_samples=5)

('The Eiffel Tower is located in',
 ['The  CE is located in',
  'The Graphics is located in',
  'The  advancing is located in',
  'The  curls is located in',
  'The  cozy is located in'],
 'Paris')

In [22]:
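def pad_from_left(tokens : torch.Tensor, maxlen:int):
    pad_token = model.tokenizer.pad_token_id
    n_pads = maxlen - tokens.shape[-1]
    #start from a tensor filled with the pad token, on the same device as the input
    padded_tokenized_inputs = torch.full(
        (tokens.shape[0], maxlen), pad_token, dtype=torch.long, device=tokens.device
    )
    #copy the original tokens into the right-hand side, leaving the pads on the left
    padded_tokenized_inputs[:,n_pads:] = tokens
    return padded_tokenized_inputs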

def pad_to_same_length(clean_tokens, corrupted_tokens): 
    
    maxlen = max([clean_tokens.shape[-1], corrupted_tokens.shape[-1]])
    
    if clean_tokens.shape[-1] > corrupted_tokens.shape[-1]: 
        corrupted_tokens = pad_from_left(corrupted_tokens, maxlen)
    elif clean_tokens.shape[-1] < corrupted_tokens.shape[-1]: 
        clean_tokens = pad_from_left(clean_tokens, maxlen)
    return clean_tokens, corrupted_tokens

In [23]:
import pickle

def save(tensor, filename):
    with open(filename, 'wb') as f:
        pickle.dump(tensor, f)

def load(filename):
    with open(filename, 'rb') as f:
        loaded_tensor = pickle.load(f)
    return loaded_tensor


In [30]:
def unembedding_function(residual_stack, cache, mlp=False) -> torch.Tensor:
    #we are only interested in applying the layer norm of the final layer on the final token
    #shape: [74, 5, 10, 1280] = n_layers, prompts, tokens, d_model
    z = cache.apply_ln_to_stack(residual_stack, layer = -1, mlp_input=mlp)
    z = z @ model.W_U
    return z


def delta_ablate(cache,
                  mle_token_idx,
                  target_token_idx, 
                  mlp_input, 
                return_labels=False):
    if return_labels: 
        residual_ablated_stack, labels = cache.decompose_resid(layer=-1, mlp_input=mlp_input, mode="all", return_labels=True)
    else: 
        residual_ablated_stack = cache.decompose_resid(layer=-1, mlp_input=mlp_input, mode="all", return_labels=False)
    
    #shape: [74, 4, 9, 50257] = n_layers, batch, tokens, vocab_size
    residual_logits = unembedding_function(residual_ablated_stack, cache)
    residual_logits = residual_logits[:,:,-1,:]

    target_idx_expanded = target_token_idx.repeat(residual_logits.shape[0],1,1)
    mle_idx_expanded = mle_token_idx.repeat(residual_logits.shape[0],1,1)

    target_residual_logits = residual_logits.gather(index=target_idx_expanded, dim=-1) - residual_logits.mean(dim=-1, keepdim=True)
    mle_residual_logits = residual_logits.gather(index=mle_idx_expanded, dim=-1) - residual_logits.mean(dim=-1, keepdim=True)
    
    if return_labels: 
        return labels, mle_residual_logits, target_residual_logits
    return mle_residual_logits, target_residual_logits

def patch_layer(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    cache):
    """
    Overwrite the activation at this hook point with the activation stored in `cache` for the same hook.
    """
    corrupted_residual_component[:, :, :] = cache[hook.name][:, :, :]
    return corrupted_residual_component


def run_all(clean_prompt: str,
                         corrupted_prompts: List[str],
                         target: str, 
                         corrupted_ablation=True, 
                         activation_to_ablate = "attn_out", 
                         mlp_input=False):
    """
    activation_to_ablate: activation (resid_pre, attn_out, or mlp_out to ablate). resid_pre corresponds to the value of the residual
    stream at a layer and a position. attn_out and mlp_out correspond to the quantities added into the residual stream at a given 
    attention or mlp head. The values in the residual stream are preserved but the quantities being injected are corrupted. 
    
    mlp_input: Whether to include attn_out for the current
    layer - essentially decomposing the residual stream that's input to the MLP input rather than the Attn input.
    
    mode (str): Values are "all", "mlp" or "attn". "all" returns all
    components, "mlp" returns only the MLP components, and "attn" returns only the attention components.
    
    Here we get the compensatory effect of downstream layers following an ablation on the MLE token logits and the target token logits.
    """
    
    clean_tokens = model.to_tokens(clean_prompt, prepend_bos=True) 
    corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
    #we use only the first token of the target, which is not ideal for multi-token targets
    target_token_idx = model.to_tokens(target)[:,1] #slice out the prepended BOS token

    #we pad the clean and corrupted tokens to the same length (pad from left)
    clean_tokens, corrupted_tokens = pad_to_same_length(clean_tokens, corrupted_tokens)
    #repeat clean_tokens to match corrupted_tokens which have multiple noise samples. 
    clean_tokens = clean_tokens.expand(corrupted_tokens.shape[0], -1)
    target_token_idx = target_token_idx.expand(corrupted_tokens.shape[0], -1)
    
    #run the model on the clean and corrupted tokens, saving the model states for each
    #logits_shape: batch, n_tokens, vocab
    
    clean_logits, clean_cache = model.run_with_cache(clean_tokens, return_type="logits")
    corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
    #The difference in states is attributed to the corrupted input. The difference in logit attributions
    #at each layer is the direct effect on that layer. 
    
    #To invert the corruption, set corrupted_ablation to False. This runs the corrupted prompts
    #and tracks the effect of restoring clean states. To keep the code simple, we swap the variables,
    #including the tokens, so that the patched runs below use the corrupted prompts as their input. 
    if not corrupted_ablation: 
        clean_cache, corrupted_cache = corrupted_cache, clean_cache
        clean_logits, corrupted_logits = corrupted_logits, clean_logits
        clean_tokens, corrupted_tokens = corrupted_tokens, clean_tokens
    
    mle_token_idx = torch.argmax(clean_logits[:,-1,:], dim=-1).unsqueeze(-1)
    # layers, batch, 1 = direct effect at each layer due to the ablation
    
    #keep only the logits at the final position: shape [batch, vocab]
    clean_logits = clean_logits[:,-1,:]
    corrupted_logits = corrupted_logits[:,-1,:]
    
    #mean-centered logits of the MLE token and the target token: shape [batch, 1]
    clean_mle_logit = clean_logits.gather(index=mle_token_idx, dim=-1) - clean_logits.mean(dim=-1, keepdim=True)
    clean_target_logit = clean_logits.gather(index=target_token_idx, dim=-1) - clean_logits.mean(dim=-1, keepdim=True)
    corrupted_mle_logit = corrupted_logits.gather(index=mle_token_idx, dim=-1) - corrupted_logits.mean(dim=-1, keepdim=True)
    corrupted_target_logit = corrupted_logits.gather(index=target_token_idx, dim=-1) - corrupted_logits.mean(dim=-1, keepdim=True)
    
    clean_mle_logit = clean_mle_logit.mean(dim=0, keepdim=True).to("cpu")
    clean_target_logit = clean_target_logit.mean(dim=0, keepdim=True).to("cpu")
    corrupted_mle_logit = corrupted_mle_logit.mean(dim=0, keepdim=True).to("cpu")
    corrupted_target_logit = corrupted_target_logit.mean(dim=0, keepdim=True).to("cpu")
    
    layer_names, clean_mle_residual_logits, clean_target_residual_logits = delta_ablate(clean_cache,
                                                                          mle_token_idx,
                                                                          target_token_idx, 
                                                                          mlp_input, 
                                                                         return_labels=True)
    
    corrupted_mle_residual_logits, corrupted_target_residual_logits = delta_ablate(corrupted_cache,
                                                                              mle_token_idx,
                                                                              target_token_idx, 
                                                                              mlp_input)
    #average over batch
    clean_mle_residual_logits = clean_mle_residual_logits.squeeze().mean(dim=-1)
    clean_target_residual_logits = clean_target_residual_logits.squeeze().mean(dim=-1)
    
    corrupted_mle_residual_logits = corrupted_mle_residual_logits.squeeze().mean(dim=-1)
    corrupted_target_residual_logits = corrupted_target_residual_logits.squeeze().mean(dim=-1)
        
    
    total_ablated_mle_residual_logits = torch.zeros(model.cfg.n_layers, len(layer_names), 
                                          device="cpu", dtype=torch.float32)
    total_ablated_target_residual_logits = torch.zeros(model.cfg.n_layers,len(layer_names),
                                             device="cpu", dtype=torch.float32)
    total_ablated_mle_logits = torch.zeros(model.cfg.n_layers,device="cpu", dtype=torch.float32)
    
    total_ablated_target_logits = torch.zeros(model.cfg.n_layers,device="cpu", dtype=torch.float32)

    for layer in range(model.cfg.n_layers):
        #here we run the prompt, add a hook point from the corrupted_cache to ablate 
        #a position and layer. We then run with cache to get a saved state of the activations. 
        hook_fn = partial(patch_layer, cache=corrupted_cache)            
        with model.hooks(
            fwd_hooks = [(utils.get_act_name(activation_to_ablate, layer),hook_fn)]
        ) as hooked_model:
            restored_logits, ablated_cache = hooked_model.run_with_cache(clean_tokens, return_type="logits")
            # layers, batch, 1 = direct effect at each layer due to the ablation
            ablated_mle_residual_logits, ablated_target_residual_logits  = delta_ablate(ablated_cache,
                                                                                          mle_token_idx,
                                                                                          target_token_idx, 
                                                                                          mlp_input
            )
            hooked_model.reset_hooks()
            model.reset_hooks()
        
        
        #keep the final-position logits of the patched run: shape [batch, vocab]
        restored_logits = restored_logits[:,-1,:]
        mle_logit = restored_logits.gather(index=mle_token_idx, dim=-1) - restored_logits.mean(dim=-1, keepdim=True)
        target_logit = restored_logits.gather(index=target_token_idx, dim=-1) - restored_logits.mean(dim=-1, keepdim=True)
        
        #average over the batch and store one value per ablated layer
        total_ablated_mle_logits[layer] = mle_logit.mean().to("cpu")
        total_ablated_target_logits[layer] = target_logit.mean().to("cpu")
        
        total_ablated_mle_residual_logits[layer] = ablated_mle_residual_logits.squeeze().mean(dim=-1).to("cpu")
        total_ablated_target_residual_logits[layer] = ablated_target_residual_logits.squeeze().mean(dim=-1).to("cpu")


    return {
        "layer_names" : layer_names,
        "clean_mle_residual_logits" : clean_mle_residual_logits.to("cpu"),
        "clean_target_residual_logits" : clean_target_residual_logits.to("cpu"),
        "corrupted_mle_residual_logits" : corrupted_mle_residual_logits.to("cpu"),
        "corrupted_target_residual_logits" : corrupted_target_residual_logits.to("cpu"),
        "ablated_mle_residual_logits" : total_ablated_mle_residual_logits.to("cpu"),
        "ablated_target_residual_logits" : total_ablated_target_residual_logits.to("cpu"),
        "total_ablated_mle_logits" : total_ablated_mle_logits.to("cpu"),
        "total_ablated_target_logits" : total_ablated_target_logits.to("cpu"),
        
         "clean_mle_logit":clean_mle_logit,
         "clean_target_logit":clean_target_logit,
         "corrupted_mle_logit":corrupted_mle_logit,
         "corrupted_target_logit":corrupted_target_logit,
    
    }
    

In [31]:
def run_experiment(indices, activation_to_ablate): 
    # indices = list(range(len(dataset)))
    # random.shuffle(indices)
    # indices = indices[:n]
    
    results = []
    
    for idx in indices: 
        prompt, subject, target = sample_dataset(dataset, idx=idx)
        true_fact, corrupted_facts, target = resample_ablation(prompt, subject, target, n_noise_samples=10)
        print(true_fact)
        result = run_all(
            clean_prompt=true_fact, 
            corrupted_prompts=corrupted_facts, 
            target=target, 
            activation_to_ablate = activation_to_ablate,
        )
        results.append(result)
    return results

# results_attn_CE = run_experiment_on_delta_ablate(dataset, n=2, activation_to_ablate="attn_out") 
# results_mlp = run_experiment_on_delta_ablate(dataset, n=10, activation_to_ablate="mlp_out") 
# results_resid = run_experiment_on_delta_ablate(dataset, n=10, activation_to_ablate="attn_out") 

In [32]:
index_batches = [np.arange(i, 100+i) for i in [0, 100, 200, 300, 400]]


for indices in index_batches: 
    filename = f"results/{indices.min()}-{indices.max()}_results.pickle"
    results = run_experiment(indices, activation_to_ablate="attn_out")
    save(results, filename)
    print("saved ", filename)

# results_attn_CE = run_experiment_on_delta_ablate(dataset, n=2, activation_to_ablate="attn_out") 


RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`

### load results

In [None]:
import os

In [None]:
pickle_files = [f for f in os.listdir("results") if f.split(".")[-1]=="pickle"]

In [None]:



results = []
for file in pickle_files: 
    result = load("results/" + file)
    results.append(result)

In [None]:
def flatten_result(result): 
    layer_names = result[0]["layer_names"]
    clean_mle_residual_logits = []
    clean_target_residual_logits = []
    corrupted_mle_residual_logits = []
    corrupted_target_residual_logits = []
    ablated_mle_residual_logits = []
    ablated_target_residual_logits = []

    for sample_result in result: 
        clean_mle_residual_logits.append(sample_result["clean_mle_residual_logits"])
        clean_target_residual_logits.append(sample_result["clean_target_residual_logits"])
        corrupted_mle_residual_logits.append(sample_result["corrupted_mle_residual_logits"])
        corrupted_target_residual_logits.append(sample_result["corrupted_target_residual_logits"])
        ablated_mle_residual_logits.append(sample_result["ablated_mle_residual_logits"])
        ablated_target_residual_logits.append(sample_result["ablated_target_residual_logits"])
        
    clean_mle_residual_logits = torch.stack(clean_mle_residual_logits)
    clean_target_residual_logits = torch.stack(clean_target_residual_logits)
    corrupted_mle_residual_logits = torch.stack(corrupted_mle_residual_logits)
    corrupted_target_residual_logits = torch.stack(corrupted_target_residual_logits)
    ablated_mle_residual_logits = torch.stack(ablated_mle_residual_logits)
    ablated_target_residual_logits = torch.stack(ablated_target_residual_logits)
    
    return (layer_names, 
            clean_mle_residual_logits,
            clean_target_residual_logits,
            corrupted_mle_residual_logits,
            corrupted_target_residual_logits,
            ablated_mle_residual_logits,
            ablated_target_residual_logits)

        
def flatten_all_results():
    pickle_files = [f for f in os.listdir("results") if f.split(".")[-1]=="pickle"]
    results = []
    for file in pickle_files: 
        result = load("results/" + file)
        result = flatten_result(result)
        results.append(result)
                
    layer_names = results[0][0]
    clean_mle_residual_logits = torch.cat([r[1] for r in results])
    clean_target_residual_logits = torch.cat([r[2] for r in results])
    corrupted_mle_residual_logits = torch.cat([r[3] for r in results])
    corrupted_target_residual_logits = torch.cat([r[4] for r in results])
    ablated_mle_residual_logits  = torch.cat([r[5] for r in results])
    ablated_target_residual_logits  = torch.cat([r[6] for r in results])
    
    return (layer_names, 
            clean_mle_residual_logits,
            clean_target_residual_logits,
            corrupted_mle_residual_logits,
            corrupted_target_residual_logits,
            ablated_mle_residual_logits,
            ablated_target_residual_logits)
    

        
(layer_names, 
clean_mle_residual_logits,
clean_target_residual_logits,
corrupted_mle_residual_logits,
corrupted_target_residual_logits,
ablated_mle_residual_logits,
ablated_target_residual_logits) = flatten_all_results()

In [None]:
def plot_delta_unembed_delta_ablate(layer_names, clean_residual_logits, corrupted_residual_logits, ablated_residual_logits, ablate_type="attn", analysis_type="attn"): 
    ablate_idx = torch.tensor([ablate_type in l for l in layer_names])
    analysis_idx = torch.tensor([analysis_type in l for l in layer_names])
    
    clean_residual_logits = clean_residual_logits[:,analysis_idx]
    corrupted_residual_logits = corrupted_residual_logits[:,analysis_idx]
    ablated_residual_logits = ablated_residual_logits[:,:,analysis_idx]
    
    batch, ablate_idx, layer_idx = ablated_residual_logits.shape

    delta_ablate_l = torch.zeros((batch, ablate_idx))
    direct_effect = torch.zeros((batch, layer_idx))

    for idx in range(ablate_idx): 
        delta_ablate_l[:,idx] = ablated_residual_logits[:,idx,:].sum(dim=-1) - clean_residual_logits.sum(dim=-1)
        direct_effect[:,idx] = ablated_residual_logits[:,idx,idx] - clean_residual_logits[:,idx]
    delta_unembed = clean_residual_logits

    mean_delta_ablate = delta_ablate_l.mean(dim=0)
    mean_direct_effect = direct_effect.mean(dim=0)
    mean_delta_unembed = delta_unembed.mean(dim=0)
    
    
    
    # Create a 2x2 grid of subplots
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))
    # Remove both right-hand subplots; a single tall axis is added in their place below
    fig.delaxes(axes[1, 1])
    fig.delaxes(axes[0,1])

    
    for exp_idx in range(batch): 
        axes[0,0].plot(delta_ablate_l[exp_idx], color="lightgrey")
        axes[1,0].plot(delta_unembed[exp_idx], color="lightgrey")
    axes[0,0].plot(mean_delta_ablate, color="blue", label="Mean")
    axes[1,0].plot(mean_delta_unembed, color="blue", label="Mean")
    
    axes[0, 0].set_title("Delta Ablate")
    axes[1, 0].set_title("Delta Unembed")

    ax_third = fig.add_subplot(1, 2, 2)
    # One colour per layer, shared by every example in the batch
    layer_colors = plt.cm.viridis(np.linspace(0, 1, delta_unembed.shape[-1]))
    for exp_idx in range(batch):
        da = delta_ablate_l[exp_idx]
        du = delta_unembed[exp_idx]
        ax_third.scatter(x=du, y=da, c=layer_colors, s=10, marker='o', alpha=0.7)

    # Symmetric axis limits around zero
    magnitude_axis = 2
    ax_third.set_xlim(-magnitude_axis, magnitude_axis)
    ax_third.set_ylim(-magnitude_axis, magnitude_axis)

    # Dashed y = x reference line through the origin, extending past the visible range
    line_min, line_max = -2, 6
    ax_third.plot([line_min, line_max], [line_min, line_max], color='black', linestyle='dashed', alpha=0.5)
    ax_third.set_title("Delta Ablate vs Delta Unembed")
    cax = fig.add_axes([0.99, 0.15, 0.02, 0.7])  # Adjust the position and size as needed
    cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=plt.cm.viridis), cax=cax)
    # 8 evenly spaced ticks labelled 0, 5, ..., 35; the labels are hard-coded and assume a 36-layer model
    tick_locations = np.linspace(0, 1, 8)
    tick_labels = np.arange(0, 36, 5)
    cbar.set_ticks(tick_locations)
    cbar.set_ticklabels(tick_labels)
    cbar.set_label('Layer')
    


    plt.tight_layout()
    plt.show()

    
    

In [None]:
plot_delta_unembed_delta_ablate(layer_names=layer_names,
                    clean_residual_logits=clean_mle_residual_logits,
                    corrupted_residual_logits=corrupted_mle_residual_logits,
                    ablated_residual_logits=ablated_mle_residual_logits,
                    ablate_type="attn")
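
The per-ablation loop in `plot_delta_unembed_delta_ablate` can also be written as one broadcasted expression. The following is a minimal sketch on random toy tensors, not the notebook's data: the sizes and the names `clean` and `ablated` are hypothetical stand-ins for `clean_residual_logits` and `ablated_residual_logits` after the layer mask has been applied.

```python
import torch

torch.manual_seed(0)
batch, n_ablations, n_layers = 4, 3, 3  # toy sizes (assumption), not the notebook's

# Hypothetical stand-ins for the masked per-layer logits
clean = torch.randn(batch, n_layers)                 # (batch, layer)
ablated = torch.randn(batch, n_ablations, n_layers)  # (batch, ablation, layer)

# Broadcasted form: sum over layers, then subtract the clean total per example
delta_ablate = ablated.sum(dim=-1) - clean.sum(dim=-1, keepdim=True)

# Reference loop, written the same way as in the plotting cell
delta_ablate_loop = torch.zeros(batch, n_ablations)
for idx in range(n_ablations):
    delta_ablate_loop[:, idx] = ablated[:, idx, :].sum(dim=-1) - clean.sum(dim=-1)

# The unembed delta is simply the clean per-layer logits
delta_unembed = clean
```

Both forms give a `(batch, n_ablations)` tensor. The scatter plot pairs it elementwise with `delta_unembed`, so it only makes sense when the number of ablations equals the number of selected layers.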


In [None]:
def plot_compensatory_effect_by_layer(layer_names, clean_residual_logits, corrupted_residual_logits, ablated_residual_logits, ablate_type="attn", analysis_type="attn"):
    # labels, target_residual_clean_logit, mle_residual_corrupted_logit, target_residual_clean_logit, mle_residual_clean_logit
    # ablate_type is currently unused; the ablation axis is taken from the tensor shape below
    # Boolean mask over layer_names: a substring match selects the layers to analyse
    analysis_idx = torch.tensor([analysis_type in l for l in layer_names])
    disp_layer_names = [l for l in layer_names if analysis_type in l]
    n_layers = int(analysis_idx.sum())

    clean_residual_logits = clean_residual_logits[:,analysis_idx]
    corrupted_residual_logits = corrupted_residual_logits[:,analysis_idx]
    ablated_residual_logits = ablated_residual_logits[:,:,analysis_idx]

    batch, n_ablations, _ = ablated_residual_logits.shape

    all_compensatory_effect = torch.zeros((batch, n_ablations, n_layers))

    for i in range(n_ablations):
        for j in range(n_layers):
            # all_compensatory_effect[:,i,j] = ablated_residual_logits[:,i,j] - (corrupted_residual_logits[:,j] - clean_residual_logits[:,j])
            all_compensatory_effect[:,i,j] = ablated_residual_logits[:,i,j] - corrupted_residual_logits[:,j]

    # Average over the batch
    all_compensatory_effect = all_compensatory_effect.mean(dim=0)

    # Transpose so that rows are layers (y) and columns are ablation indices (x)
    imshow(all_compensatory_effect.T,y=disp_layer_names,labels={"x":"Ablation Index", "y":"Layer"}, title="PLOT A")

In [None]:
plot_compensatory_effect_by_layer(layer_names=layer_names,
                    clean_residual_logits=clean_mle_residual_logits,
                    corrupted_residual_logits=corrupted_mle_residual_logits,
                    ablated_residual_logits=ablated_mle_residual_logits,
                    ablate_type="attn", 
                    analysis_type="_")
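
The double loop over ablations and layers in the function above amounts to a single broadcasted subtraction. A rough sketch, assuming random toy tensors in place of the real logits: the sizes and the names `corrupted`, `ablated`, and `heatmap` are illustrative and not part of the notebook.

```python
import torch

torch.manual_seed(0)
batch, n_ablations, n_layers = 5, 4, 6  # toy sizes (assumption)

# Hypothetical stand-ins for the masked logits
corrupted = torch.randn(batch, n_layers)             # (batch, layer)
ablated = torch.randn(batch, n_ablations, n_layers)  # (batch, ablation, layer)

# Insert an ablation axis into `corrupted` so it broadcasts against `ablated`
effect = ablated - corrupted[:, None, :]  # (batch, ablation, layer)

# Average over the batch, then transpose so rows are layers
heatmap = effect.mean(dim=0).T            # (layer, ablation)

# Reference double loop, as in the plotting function
loop = torch.zeros(batch, n_ablations, n_layers)
for i in range(n_ablations):
    for j in range(n_layers):
        loop[:, i, j] = ablated[:, i, j] - corrupted[:, j]
```

With rows as layers, `heatmap` has one row per entry of `disp_layer_names`, which is the orientation the `y=` labels of the heatmap expect.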


In [None]:
def plot_compensatory_effect_by_layer(results, attribution_type="attn", ablation_type="attn", use_target=True, subtract_clean_run=True):
    # labels, target_residual_clean_logit, mle_residual_corrupted_logit, target_residual_clean_logit, mle_residual_clean_logit
    layer_names = results[0]["layer_names"]
    disp_layer_names = [l for l in layer_names if attribution_type in l]
    # Boolean masks over layer_names, selected by substring match
    idx_of_interest = torch.tensor([attribution_type in l for l in layer_names])
    attn_idx = torch.tensor(["attn" in l for l in layer_names])
    
    n_ablations = results[0]["ablate_direct_effect_on_target"].shape[0]
    
    
    all_compensatory_effect = torch.zeros((len(results), n_ablations, len(disp_layer_names)))    
    
    
    for i,result in enumerate(results):
        
        if use_target: 
            ablate_direct_effect_all_layers = result["ablate_direct_effect_on_target"]
            unembed_direct_effect_all_layers = result["target_direct_effect"]
        else: 
            ablate_direct_effect_all_layers = result["ablate_direct_effect_on_mle"]
            unembed_direct_effect_all_layers = result["mle_direct_effect"] 
            
            
        ablate_direct_effect = ablate_direct_effect_all_layers[:,idx_of_interest]
        unembed_direct_effect = unembed_direct_effect_all_layers[idx_of_interest]

        # Calculate the compensatory response for each ablation
        compensatory_response = torch.zeros((n_ablations, all_compensatory_effect.shape[-1]))

        for idx in range(n_ablations):
            if subtract_clean_run:
                # Difference between the ablated run and the clean run, per layer
                response = ablate_direct_effect[idx] - unembed_direct_effect
            else:
                response = ablate_direct_effect[idx]

            compensatory_response[idx] = response

        all_compensatory_effect[i] = compensatory_response
    
    all_compensatory_effect = all_compensatory_effect.mean(dim=0)
    
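    # Transpose so that rows are layers (y) and columns are ablation indices (x), matching the axis labels
    imshow(all_compensatory_effect.T,y=disp_layer_names,labels={"x":"Ablation Index", "y":"Layer"}, title="add later")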
    
plot_compensatory_effect_by_layer(results_attn_CE, attribution_type="attn", ablation_type="attn", use_target=True, subtract_clean_run=True)
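
The layer masks in the function above are built by substring matching on the layer names. The sketch below illustrates this on a hypothetical list of names; the real `layer_names` format is not shown in this section, so the `"<index>_<type>_out"` pattern is an assumption. It also shows how Python evaluates a chained membership test such as `a in l in l`.

```python
# Hypothetical layer names; the notebook's actual format may differ
layer_names = ["0_attn_out", "0_mlp_out", "1_attn_out", "1_mlp_out"]
attribution_type = "attn"

# Boolean mask and matching display names, selected by substring match
mask = [attribution_type in name for name in layer_names]
disp_layer_names = [name for name, keep in zip(layer_names, mask) if keep]

# Chained form: `a in b in b` is evaluated as `(a in b) and (b in b)`.
# A string always contains itself, so the result equals the plain test.
chained = [attribution_type in name in name for name in layer_names]
```

The chained form gives the same mask here, but the plain `attribution_type in name` states the intent directly.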