# Causal Analysis of Facts
We analyze where facts might be locateed in an LLM. A fact is represented by a tuple of subject, relation, object (s,r,o). We also investigate where an inverse fact mis located. An inverse fact is represented by the tuple (o, r^-1, s). For example, the fact "Paris is the capital of France" has the inverse fact "France's capital is France". 

## Setup

In [1]:
try:
    import google.colab
    IN_COLAB = True
    # %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    renderer = "colab"
except:
    IN_COLAB = False
    from IPython import get_ipython
    %load_ext autoreload
    %autoreload 2
    renderer = "jupyterlab"

In [2]:
%%bash
cd ../
pip install poetry
poetry install
cd notebooks

Collecting platformdirs<4.0.0,>=3.0.0 (from poetry)
  Obtaining dependency information for platformdirs<4.0.0,>=3.0.0 from https://files.pythonhosted.org/packages/14/51/fe5a0d6ea589f0d4a1b97824fb518962ad48b27cd346dcdfa2405187997a/platformdirs-3.10.0-py3-none-any.whl.metadata
  Using cached platformdirs-3.10.0-py3-none-any.whl.metadata (11 kB)
Using cached platformdirs-3.10.0-py3-none-any.whl (17 kB)
Installing collected packages: platformdirs
  Attempting uninstall: platformdirs
    Found existing installation: platformdirs 3.8.0
    Uninstalling platformdirs-3.8.0:
      Successfully uninstalled platformdirs-3.8.0
Successfully installed platformdirs-3.10.0
Installing dependencies from lock file

Package operations: 0 installs, 1 update, 0 removals

  • Downgrading platformdirs (3.10.0 -> 3.8.0)

Installing the current project: transformer-lens (0.0.0)


In [3]:

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = renderer

In [4]:


# Import stuff
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader
import json
from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy
import subprocess

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
from tqdm import tqdm

In [176]:
import ast
import contextlib
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import torch
from matplotlib.cm import ScalarMappable
import pickle

In [6]:
if IN_COLAB: 
    import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens.utilities import devices
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [7]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f6eb71d1ab0>

Plotting helper functions:

In [8]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def cuda():
    return torch.cuda.is_available()

def get_device(): 
    return "cuda" if cuda() else "cpu"

def save_pickle(obj, filename):
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)

# Function to load a pickle object from a file
def load_pickle(filename):
    with open(filename, 'rb') as f:
        obj = pickle.load(f)
    return obj

def save_json(data, file_path):
    with open(file_path, "w") as json_file:
        json.dump(data, json_file, indent=4)

def load_json(file_path):
    with open(file_path, "r") as json_file:
        loaded_data = json.load(json_file)
    return loaded_data

def get_gpu_usage(): 
    subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

device = get_device()
device

'cuda'

## Model

- fold_ln: Whether to fold in the LayerNorm weights to the subsequent linear layer. This does not change the computation.
- center_writing_weights: Whether to center weights writing to the residual stream (ie set mean to be zero). Due to LayerNormthis doesn't change the computation.
- center_unembed : Whether to center W_U (ie set mean to be zero). Softmax is translation invariant so this doesn't affect log probs or loss, but does change logits. Defaults to True.
- refactor_factored_attn_matrices: Whether to convert the factoredmatrices (W_Q & W_K, and W_O & W_V) to be "even". 

In [9]:
MODEL_NAME = "gpt2-small"
model = HookedTransformer.from_pretrained(
        MODEL_NAME,
        center_unembed=True,  
        center_writing_weights=True,              # Whether to center weights writing to the residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the computation.      
        fold_ln=True,                             # Whether to  fold in the LayerNorm weights to the subsequent linear layer.
        refactor_factored_attn_matrices=True,
    )


 

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


## Dataset
We use the CounterFact Dataset

In [10]:
dataset = load_json("data/fact_dataset.json")

In [11]:
def sample_dataset(dataset, idx = None): 
    if idx is None: 
        prompt = "The {} is located in"
        subject = "Eiffel Tower"
        target = "Paris"
    else: 
        sample = dataset[idx]
        prompt = sample["requested_rewrite"]["prompt"]
        subject = sample["requested_rewrite"]["subject"]
        target = sample["requested_rewrite"]["target_true"]["str"]
    return prompt, subject, target

In [12]:
sample = sample_dataset(dataset)
prompt, subject, target = sample
prompt = prompt.format(subject)

In [13]:
model.to_tokens(prompt).shape

torch.Size([1, 9])

## Ablation Methods: 
- Noise ablation 
- Resample ablation

In [22]:
def noise_ablation(prompt, subject, target, n_noise_samples=5, vx=3):
    subject_tokens = model.to_tokens(subject)
    
    #shape: batch, n_tokens, embedding_dim
    subject_embedding = model.embed(subject_tokens)
    _, n_tokens, embedding_dim = subject_embedding.shape
    
    #noise: N(0,v), v = 3*std(embedding)
    embedding = model.W_E
    v = vx*torch.std(embedding, dim=0) #for each v in V
    noise = torch.randn(
        (n_noise_samples, n_tokens, embedding_dim)
    ).to(device) + v
    
    subject_embedding_w_noise = subject_embedding + noise
    
    #shape: batch, n_tokens, vocab_size (logits)
    unembedded_subject = model.unembed(subject_embedding_w_noise)

    noisy_subject_tokens = torch.argmax(unembedded_subject, dim=-1)
    noisy_subject_str = [
        model.to_string(nst) for nst in noisy_subject_tokens
    ]
    true_prompt = prompt.format(subject)
    corrupted_prompts = [
        prompt.format(nss.strip()) for nss in noisy_subject_str
    ]
    return true_prompt, corrupted_prompts, target

In [23]:
prompt, subject, target = sample_dataset(dataset)
noise_ablation(prompt, subject, target)

('The Eiffel Tower is located in',
 ['The tremendenghyde newspmonary is located in',
  'The semblyireziffsel Blades is located in',
  'The teinhartSPONSOREDirisEngineDebug is located in',
  'The [/ Pradeshiffelaordable is located in',
  'The sey destrodriversGB pancakes is located in'],
 'Paris')

In [33]:
def pad_from_left(tokens : torch.tensor, maxlen:int):
    pad_token = model.tokenizer.pad_token_id
    padded_tokenized_inputs = torch.zeros(tokens.shape[0], maxlen)
    
    n_pads = maxlen - tokens.shape[-1]
    padded_tokenized_inputs[:,n_pads] = pad_token
    padded_tokenized_inputs[:,n_pads:] = tokens
    return padded_tokenized_inputs.long()

def pad_to_same_length(clean_tokens, corrupted_tokens): 
    
    maxlen = max([clean_tokens.shape[-1], corrupted_tokens.shape[-1]])
    
    if clean_tokens.shape[-1] > corrupted_tokens.shape[-1]: 
        corrupted_tokens = pad_from_left(corrupted_tokens, maxlen)
    elif clean_tokens.shape[-1] < corrupted_tokens.shape[-1]: 
        clean_tokens = pad_from_left(clean_tokens, maxlen)
    return clean_tokens, corrupted_tokens

In [98]:
def resample_ablation(prompt, subject, target, n_noise_samples=20):
    subject = " " + subject
    subject_tokens = model.to_tokens(subject)
    n_subject_tokens = subject_tokens.shape[-1] - 1
    noisy_subject_tokens = []
    for i in range(n_subject_tokens):     
        embedding = model.W_E
        permutations = torch.randperm(embedding.size(0))[:n_noise_samples]
        random_samples = embedding[permutations]
        random_samples = random_samples.unsqueeze(dim=1)
        #we de-embed these rows 
        random_embeddings = model.unembed(random_samples)
        random_tokens = torch.argmax(random_embeddings, dim=-1)
        noisy_subject_tokens.append(random_tokens)
        
    noisy_subject_tokens = torch.stack(noisy_subject_tokens, dim=0)
    noisy_subject_tokens = noisy_subject_tokens.transpose(1,0)
    
    corrupted_facts = []
    for random_prompt in range(noisy_subject_tokens.shape[0]): 
        random_tokens = noisy_subject_tokens[random_prompt]
        random_subject = [
            model.to_string(t) for t in random_tokens
        ]
        print(len(random_subject))
        random_subject = " " + "".join(random_subject)
        corrupted_facts.append(prompt.format(random_subject.strip()))
        
        
    true_fact = prompt.format(subject.strip())
    
    fact_tokens = model.to_tokens(true_fact)
    subject_mask = torch.zeros_like(fact_tokens, dtype=torch.bool, device=device)

    for value in subject_tokens[0, 1:]:
        subject_mask |= (fact_tokens == value)    
        
    
    print(subject_mask)
    print(len(model.to_str_tokens(true_fact)))
    for p in corrupted_facts: 
        print(len(model.to_str_tokens(p)))
        
    
    
    return true_fact, corrupted_facts, target, subject_mask.to("cpu")
    
    
prompt, subject, target = sample_dataset(dataset, idx=1)
resample_ablation(prompt, subject, target, n_noise_samples=7)
    

9
9
9
9
9
9
9
tensor([[False, False, False, False, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True, False]], device='cuda:0')
15
16
15
15
15
16
15
15


('The mother tongue of Thomas Joannes Stieltjes is',
 ['The mother tongue of ordableHEAD toddler intuitivekas jaw haz software Fired is',
  'The mother tongue of regained Arlingtonipesuminum fadingperorsavery rubbedjoy is',
  'The mother tongue of NPR Entrepreneowment cognography Beats Lines judge specificity is',
  'The mother tongue of Desktop tradersdocumented bombings puppy boasting Princeton subtractLife is',
  'The mother tongue of abruptazaarochond Wage packingendalengthResultsmanship is',
  'The mother tongue of segregationfolk publish Winds summoning vesselsForgeModLoaderoren politically is',
  'The mother tongue of package quotesStrong nitria learns prog14 anticipate is'],
 'Dutch',
 tensor([[False, False, False, False, False,  True,  True,  True,  True,  True,
           True,  True,  True,  True, False]]))

In [182]:
def get_mask(prompt_tokens, subject_tokens): 
    A_flattened = prompt_tokens.view(-1)
    B_flattened = subject_tokens.view(-1)

    mask = torch.zeros_like(A_flattened, dtype=torch.bool)

    for i in range(len(A_flattened) - len(B_flattened) + 1):
        if torch.all(A_flattened[i:i+len(B_flattened)] == B_flattened):
            mask[i:i+len(B_flattened)] = True
            
    return mask


def resample_ablation(prompt, subject, target, n_noise_samples=20, temperature=0.85):
    subject = " " + subject
    subject_tokens = model.to_tokens(subject)[:,1:]
    n_subject_tokens = subject_tokens.shape[-1]
    prompt_start = prompt.split("{")[0]
    prompt_end = prompt.split("}")[-1]
    clean_fact = prompt.format(subject)
    
    fact_tokens = model.to_tokens(clean_fact)
    clean_subject_mask = get_mask(fact_tokens, subject_tokens)


    corrupted_facts = []        
    while len(corrupted_facts) < n_noise_samples: 
        generated = model.generate(prompt_start, max_new_tokens=n_subject_tokens,temperature=temperature,verbose=False)
        corrupted_subject = generated.split(prompt_start)[-1].strip()
        corrupted_fact = prompt.format(corrupted_subject)
        corrupted_subject_tokens = model.to_tokens(corrupted_subject)[:,1:]
        corrupted_fact_tokens = model.to_tokens(corrupted_fact)
        
        corrupted_mask = get_mask(prompt_tokens=corrupted_fact_tokens, subject_tokens=corrupted_subject_tokens)
        if corrupted_mask.shape != clean_subject_mask.shape: 
            continue
        if all(corrupted_mask==clean_subject_mask):
            corrupted_facts.append(corrupted_fact)

    return clean_fact, corrupted_facts, target, clean_subject_mask
        
        
        
    
        
 

prompt, subject, target = sample_dataset(dataset, idx=3)
resample_ablation(prompt, subject, target, n_noise_samples=7)
    

('The headquarter of  Monell Chemical Senses Center is located in',
 ['The headquarter of ileanary association Hong Kong Min is located in',
  'The headquarter of ṣiṣn Day is located in',
  'The headquarter of _____ buildings; the also property of is located in',
  'The headquarter of urns and cubic feet at the is located in',
  'The headquarter of état bien-de- is located in',
  'The headquarter of été arabia is located is located in',
  'The headquarter of irc.wa.gov told ABC is located in'],
 'Philadelphia',
 tensor([False, False, False, False, False, False,  True,  True,  True,  True,
          True,  True,  True, False, False, False], device='cuda:0'))

In [179]:
def pad_from_left(tokens : torch.tensor, maxlen:int):
    pad_token = model.tokenizer.pad_token_id
    padded_tokenized_inputs = torch.zeros(tokens.shape[0], maxlen)
    
    n_pads = maxlen - tokens.shape[-1]
    padded_tokenized_inputs[:,n_pads] = pad_token
    padded_tokenized_inputs[:,n_pads:] = tokens
    return padded_tokenized_inputs.long()

def pad_to_same_length(clean_tokens, corrupted_tokens): 
    
    maxlen = max([clean_tokens.shape[-1], corrupted_tokens.shape[-1]])
    
    if clean_tokens.shape[-1] > corrupted_tokens.shape[-1]: 
        corrupted_tokens = pad_from_left(corrupted_tokens, maxlen)
    elif clean_tokens.shape[-1] < corrupted_tokens.shape[-1]: 
        clean_tokens = pad_from_left(clean_tokens, maxlen)
    return clean_tokens, corrupted_tokens

## Single Patch Restoration

In [209]:
prompt, subject, target = sample_dataset(dataset, idx=0)
true_fact, corrupted_facts, target, subject_mask = resample_ablation(prompt, subject, target, n_noise_samples=7)


In [252]:
def unembedding_function(residual_stack, cache) -> float:
    #we are only interested in applying the layer norm of the final layer on the final token
    #shape: [74, 5, 10, 1280] = n_layers, prompts, tokens, d_model
    z = cache.apply_ln_to_stack(residual_stack, layer = -1)
    z = z @ model.W_U
    return z

def get_accumulated_residual(cache, mle_token_idx, target_token_idx): 
    accumulated_residual = cache.accumulated_resid(layer=-1)
    accumulated_residual = unembedding_function(accumulated_residual, cache)
    #shape: torch.Size([layer, batch, pos, vocab])
    accumulated_residual = accumulated_residual.permute(0,2,1,3)
    
    print(accumulated_residual.shape)
    print(mle_token_idx.shape)
    
    # mle_token_idx = mle_token_idx.unsqueeze(dim=0).expand(accumulated_residual.shape[1],1,1)
    print(mle_token_idx)
    
    
    
    
    
          
                                        
    raise
    
    
    mle_idx_expanded = mle_token_idx.repeat(accumulated_residual.shape[0],1,accumulated_residual.shape[2],1)
    
    print(accumulated_residual.shape)
    mle_residual_logits = (accumulated_residual.gather(dim=-1, index=mle_idx_expanded) - accumulated_residual.mean(dim=-1, keepdim=True)).to("cpu")
    target_residual_logits = (accumulated_residual.gather(dim=-1, index=target_idx_expanded) - accumulated_residual.mean(dim=-1, keepdim=True)).to("cpu")

    
    
    raise
    
    
def get_logit_attributions(cache, mle_token_idx, target_token_idx): 
    residual_stack = cache.decompose_resid(layer=-1, mode="all", return_labels=False)
    mle_logit_attributions = cache.logit_attrs(residual_stack, mle_token_idx)
    target_logit_attributions = cache.logit_attrs(residual_stack, target_token_idx)
    raise

def get_decomposed_residual(cache, mle_token_idx, target_token_idx, return_labels=False):
    if return_labels: 
        residual_stack, labels = cache.decompose_resid(layer=-1, mode="all", return_labels=True)
    else: 
        residual_stack = cache.decompose_resid(layer=-1, mode="all", return_labels=False)
    raise
    
def get_stack_head_results(cache, mle_token_idx, target_token_idx): 
    head_results = cache.stack_head_results(layer=-1)
    raise
    
def get_neurons_stack(cache, mle_token_idx, target_token_idx): 
    activations = cache.stack_neuron_results(layer=-1)
    raise

def patch_layer(corrupted_residual_component,hook,cache):
    corrupted_residual_component[:, :, :] = cache[hook.name][:, :, :]
    return corrupted_residual_component

def patch_position(corrupted_residual_component, hook,pos,cache):
    corrupted_residual_component[:, pos, :] = cache[hook.name][:, pos, :]
    return corrupted_residual_component

def extract_logit(logits, mle_token_idx, target_token_idx):     
    mle_logit = (logits.gather(dim=-1, index=mle_token_idx) - logits.mean(dim=-1, keepdim=True)).to("cpu")
    target_logit = (logits.gather(dim=-1, index=target_token_idx) - logits.mean(dim=-1, keepdim=True)).to("cpu")
    mle_logit = mle_logit.mean(dim=0)
    target_logit = target_logit.mean(dim=0)
    return mle_logit, target_logit


def run_all_single_patches(
                        clean_prompt: str,
                        corrupted_prompts: List[str],
                        target: str):
    #-----------------------------prepare inputs--------------------------------------
    clean_tokens = model.to_tokens(clean_prompt, prepend_bos=True) 
    corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
    assert clean_tokens.shape[-1] == corrupted_tokens.shape[-1]
    n_tokens = clean_tokens.shape[-1]
    clean_tokens = clean_tokens.expand(corrupted_tokens.shape[0], -1)
    
    target_token = model.to_single_token(target)
    
    clean_logits, clean_cache = model.run_with_cache(clean_tokens, return_type="logits")
    corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
    
    mle_token = torch.argmax(clean_logits[:,-1,:], dim=-1)
    target_token = torch.ones_like(mle_token).long().to(device) * target_token
    target_token = target_token.unsqueeze(dim=-1)
    mle_token = mle_token.unsqueeze(dim=-1)
    
    clean_logits = clean_logits[:,-1,:]
    corrupted_logits = corrupted_logits[:,-1,:]
    
    clean_mle_logit, clean_target_logit = extract_logit(clean_logits, mle_token, target_token)
    corrupted_mle_logit, corrupted_target_logit = extract_logit(corrupted_logits, mle_token, target_token)
    
    #---------------------------calculating base results---------------------------------------
    get_accumulated_residual(clean_cache, mle_token, target_token)

    
    
    raise
    get_logit_attributions(cache, mle_token_idx, target_token_idx)
    get_decomposed_residual(cache, mle_token_idx, target_token_idx, return_labels=False)   
    get_stack_head_results(cache, mle_token_idx, target_token_idx)
    get_neurons_stack(cache, mle_token_idx, target_token_idx)

    
    
    
    
    layer_names, clean_mle_residual_logits, clean_target_residual_logits = delta_ablate(clean_cache,mle_token_idx,target_token_idx, return_labels=True)
    
    corrupted_mle_residual_logits, corrupted_target_residual_logits = delta_ablate(corrupted_cache,mle_token_idx,target_token_idx)
    
    total_ablated_mle_residual_logits = torch.zeros(model.cfg.n_layers * 2,len(layer_names), device="cpu", dtype=torch.float32)
    total_ablated_target_residual_logits = torch.zeros(model.cfg.n_layers * 2,len(layer_names),device="cpu", dtype=torch.float32)
    total_ablated_mle_logits = torch.zeros(model.cfg.n_layers * 2,device="cpu", dtype=torch.float32)
    total_ablated_target_logits = torch.zeros(model.cfg.n_layers * 2,device="cpu", dtype=torch.float32)
    
    total_ablated_mle_residual_position_logits = torch.zeros(model.cfg.n_layers * 2, n_tokens, len(layer_names), device="cpu", dtype=torch.float32)
    total_ablated_target_residual_position_logits = torch.zeros(model.cfg.n_layers * 2, n_tokens, len(layer_names),device="cpu", dtype=torch.float32)
    total_ablated_mle_position_logits = torch.zeros(model.cfg.n_layers * 2, n_tokens, device="cpu", dtype=torch.float32)
    total_ablated_target_position_logits = torch.zeros(model.cfg.n_layers * 2, n_tokens, device="cpu", dtype=torch.float32)

    for layer in range(model.cfg.n_layers * 2):
        if layer % 2 == 0: 
            patch_name =  f"blocks.{layer//2}.hook_attn_out"
        else: 
            patch_name = f"blocks.{layer//2}.hook_mlp_out"
        
        hook_fn = partial(patch_layer, cache=clean_cache)            
        with model.hooks(
            fwd_hooks = [(patch_name, hook_fn)]
        ) as hooked_model:
            restored_logits, ablated_cache = hooked_model.run_with_cache(corrupted_tokens, return_type="logits")

            ablated_mle_residual_logits, ablated_target_residual_logits  = delta_ablate(ablated_cache,
                                                                                          mle_token_idx,
                                                                                          target_token_idx)
            hooked_model.reset_hooks()
            model.reset_hooks()
        
        restored_logits = restored_logits[:,-1,:]
        mle_logit = (restored_logits.gather(dim=-1, index=mle_token_idx) - restored_logits.mean(dim=-1, keepdim=True)).to("cpu")
        target_logit = (restored_logits.gather(dim=-1, index=target_token_idx) - restored_logits.mean(dim=-1, keepdim=True)).to("cpu")
        mle_logit = mle_logit.mean(0) ; target_logit = target_logit.mean(0)
    
        total_ablated_mle_logits[layer] = mle_logit
        total_ablated_target_logits[layer] = target_logit
        total_ablated_mle_residual_logits[layer] = ablated_mle_residual_logits.to("cpu")
        total_ablated_target_residual_logits[layer] = ablated_target_residual_logits.to("cpu")
        
        for pos in range(n_tokens):
            hook_fn = partial(patch_position, pos=pos, clean_cache=clean_cache)
            restored_logits = model.run_with_hooks(corrupted_tokens,fwd_hooks = [(patch_name,hook_fn)],return_type="logits")
            
            restored_logits = restored_logits[:,-1,:]
            mle_logit = (restored_logits.gather(dim=-1, index=mle_token_idx) - restored_logits.mean(dim=-1, keepdim=True)).to("cpu")
            target_logit = (restored_logits.gather(dim=-1, index=target_token_idx) - restored_logits.mean(dim=-1, keepdim=True)).to("cpu")
            mle_logit = mle_logit.mean(0) ; target_logit = target_logit.mean(0)

            total_ablated_mle_position_logits[layer, pos] = mle_logit
            total_ablated_target_position_logits[layer, pos] = target_logit
            total_ablated_mle_residual_position_logits[layer, pos] = ablated_mle_residual_logits.to("cpu")
            total_ablated_target_residual_position_logits[layer, pos] = ablated_target_residual_logits.to("cpu")
                
    return {
        "layer_names" : layer_names,
        "clean_mle_residual_logits" : clean_mle_residual_logits.to("cpu"),
        "clean_target_residual_logits" : clean_target_residual_logits.to("cpu"),
        "corrupted_mle_residual_logits" : corrupted_mle_residual_logits.to("cpu"),
        "corrupted_target_residual_logits" : corrupted_target_residual_logits.to("cpu"),
        "ablated_mle_residual_logits" : total_ablated_mle_residual_logits.to("cpu"),
        "ablated_target_residual_logits" : total_ablated_target_residual_logits.to("cpu"),
        "total_ablated_mle_logits" : total_ablated_mle_logits.to("cpu"),
        "total_ablated_target_logits" : total_ablated_target_logits.to("cpu"),
        
         "clean_mle_logit":clean_mle_logit,
         "clean_target_logit":clean_target_logit,
         "corrupted_mle_logit":corrupted_mle_logit,
         "corrupted_target_logit":corrupted_target_logit,
        
        "total_ablated_mle_position_logits":total_ablated_mle_position_logits, 
        "total_ablated_target_position_logits":total_ablated_target_position_logits, 
        "total_ablated_mle_residual_position_logits":total_ablated_mle_residual_position_logits, 
        "total_ablated_target_residual_position_logits":total_ablated_target_residual_position_logits, 
    
    }
    

In [253]:

result = run_all_single_patches(
            clean_prompt=true_fact, 
            corrupted_prompts=corrupted_facts, 
            target=target, 
        ) 

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`

In [254]:
def run_experiment(indices): 
    # indices = list(range(len(dataset)))
    # random.shuffle(indices)
    # indices = indices[:n]
    
    results = []
    
    for idx in indices: 
        prompt, subject, target = sample_dataset(dataset, idx=idx)
        true_fact, corrupted_facts, target, subject_mask = resample_ablation(prompt, subject, target, n_noise_samples=10)
        print(true_fact)
        result = run_all_restoration(
            clean_prompt=true_fact, 
            corrupted_prompts=corrupted_facts, 
            target=target, 
        )
        result["prompt"] = true_fact
        result["subject_mask"] = subject_mask
        results.append(result)
    return results


index_batches = [np.arange(i, 100+i) for i in [200]]

for indices in index_batches: 
    filename = f"results/{indices.min()}-{indices.max()}_restoration_results.pickle"
    results = run_experiment(indices)
    save(results, filename)
    print("saved ", filename)



RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`

In [None]:
    

def delta_ablate(cache,
                  mle_token_idx,
                  target_token_idx, 
                return_labels=False):
    if return_labels: 
        residual_ablated_stack, labels = cache.decompose_resid(layer=-1, mode="all", return_labels=True)
    else: 
        residual_ablated_stack = cache.decompose_resid(layer=-1, mode="all", return_labels=False)
    
    residual_logits = unembedding_function(residual_ablated_stack, cache)
    residual_logits = residual_logits[:,:,-1,:]

    target_idx_expanded = target_token_idx.repeat(residual_logits.shape[0],1,1)
    mle_idx_expanded = mle_token_idx.repeat(residual_logits.shape[0],1,1)

    target_residual_logits = residual_logits.gather(index=target_idx_expanded, dim=-1) - residual_logits.mean(dim=-1, keepdim=True)
    mle_residual_logits = residual_logits.gather(index=mle_idx_expanded, dim=-1) - residual_logits.mean(dim=-1, keepdim=True)
    
    target_residual_logits = target_residual_logits.squeeze().mean(dim=-1)
    mle_residual_logits = mle_residual_logits.squeeze().mean(dim=-1)
    
    if return_labels: 
        return labels, mle_residual_logits, target_residual_logits
    return mle_residual_logits, target_residual_logits
