In [40]:
import json
import random
import sys
from dataclasses import dataclass, field
from functools import partial

# Make the repository root and its sub-folders importable before the local imports below.
sys.path.append('..')
sys.path.append('../src')
sys.path.append('../data')

import torch
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
from transformer_lens.utils import get_act_name
from transformer_lens import patching
import plotly.express as px

from src.patching import get_act_patch_mlp_out
from src.model import WrapHookedTransformer
from src.dataset import Dataset
from src.utils import (
    embs_to_tokens_ids,
    patch_resid_pre,
    patch_attn_head_out_all_pos,
    patch_attn_head_by_pos,
    patch_per_block_all_poss,
    patch_attn_head_all_pos_every,
    logit_lens,
    list_of_dicts_to_dict_of_lists,
)
%load_ext autoreload
%autoreload 2

@dataclass
class Config:
    """Tunable settings for this notebook.

    Every field has a default, so ``Config()`` reproduces the original values;
    individual settings can be overridden by keyword, e.g. ``Config(num_samples=10)``.
    The list-valued fields use ``default_factory`` so each instance owns its own
    list instead of sharing one mutable class-level object.
    """

    # Number of examples requested from ``Dataset.random_sample``.
    num_samples: int = 100
    # Batch size handed to the DataLoader.
    batch_size: int = 100
    # Noise positions / multiplier passed to ``model.add_noise`` for the "pos" batch.
    mem_win_noise_position: list = field(default_factory=lambda: [1, 2, 3, 9, 10, 11])
    mem_win_noise_mlt: int = 2
    # Noise positions / multiplier passed to ``model.add_noise`` for the "neg" batch.
    cp_win_noise_position: list = field(default_factory=lambda: [8])
    cp_win_noise_mlt: int = 2


config = Config()

def dict_of_lists_to_dict_of_tensors(dict_of_lists):
    """Convert a dict of per-example lists into a dict of stacked tensors.

    Args:
        dict_of_lists: mapping from key to a list of per-example values. Each
            list must hold either ``torch.Tensor`` objects or Python lists
            (which are converted with ``torch.tensor``).

    Returns:
        A new dict with the same keys. Values are stacked along a new leading
        dimension with ``torch.stack``, except for the keys
        ``"example_str_token"`` and ``"logit_lens"``, whose values are passed
        through untouched.

    Raises:
        ValueError: if a list is empty (its element type cannot be inferred)
            or its first element is neither a tensor nor a list.
    """
    # Keys whose values are kept exactly as given instead of being stacked.
    passthrough_keys = ("example_str_token", "logit_lens")

    dict_of_tensors = {}
    for key, tensor_list in dict_of_lists.items():
        if key in passthrough_keys:
            dict_of_tensors[key] = tensor_list
            continue

        # Guard against indexing an empty list below, which previously
        # surfaced as an uninformative IndexError.
        if len(tensor_list) == 0:
            raise ValueError(f"Cannot build a tensor for key {key}: the list is empty")

        # The first element decides how the whole list is interpreted.
        first = tensor_list[0]
        if isinstance(first, torch.Tensor):
            dict_of_tensors[key] = torch.stack(tensor_list)
        elif isinstance(first, list):
            dict_of_tensors[key] = torch.stack([torch.tensor(item) for item in tensor_list])
        else:
            raise ValueError(f"Unsupported data type for key {key}: {type(first)}")
    return dict_of_tensors

# Gradients are switched off globally for everything that follows.
torch.set_grad_enabled(False)

MODEL_NAME = "gpt2small"
MAX_LEN = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = WrapHookedTransformer.from_pretrained("gpt2", device=DEVICE)

# Read both dataset files through context managers so the file handles are
# closed as soon as the JSON has been parsed.
with open("../data/target_win_dataset_{}_filtered.json".format(MODEL_NAME)) as target_file:
    target_data = json.load(target_file)
with open("../data/orthogonal_win_dataset_{}_filtered.json".format(MODEL_NAME)) as orthogonal_file:
    orthogonal_data = json.load(orthogonal_file)

# Down-sample the orthogonal set to the size of the target set.
# NOTE(review): no random seed is set, so the sample (and every number printed
# further down) changes between runs; random.sample also raises ValueError if
# the orthogonal set is smaller than the target set.
orthogonal_data = random.sample(orthogonal_data, len(target_data))
dataset = Dataset(target_data, orthogonal_data, model)
dataset.random_sample(config.num_samples, MAX_LEN)


dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=False)


# Accumulators. Only the two *_clean_caches lists are appended to in this cell;
# the result lists are left empty here.
pos_result = []
pos_result_var = []
neg_result = []
neg_result_var = []
pos_clean_caches = []
neg_clean_caches = []
# Only the first batch of the dataloader is used.
batch = next(iter(dataloader))

pos_batch = batch["pos_dataset"]
neg_batch = batch["neg_dataset"]
# Token ids of the "target" and "orthogonal" strings for each batch, tokenized
# without a BOS token; later used as gather indices into the vocabulary axis.
pos_target_ids = {
    "target": model.to_tokens(pos_batch["target"], prepend_bos=False),
    "orthogonal": model.to_tokens(pos_batch["orthogonal_token"], prepend_bos=False),
}
neg_target_ids = {
    "target": model.to_tokens(neg_batch["target"], prepend_bos=False),
    "orthogonal": model.to_tokens(neg_batch["orthogonal_token"], prepend_bos=False),
}
# Token ids of the premises, this time with a BOS token prepended.
pos_input_ids = model.to_tokens(pos_batch["premise"], prepend_bos=True)
neg_input_ids = model.to_tokens(neg_batch["premise"], prepend_bos=True)
# Corrupted inputs for each batch. The return value is fed to
# run_with_cache_from_embed below, so it is presumably a tensor of noised
# input embeddings — TODO confirm against WrapHookedTransformer.add_noise.
# NOTE(review): target_win=8 is a magic number; its meaning is not visible here.
pos_embs_corrupted = model.add_noise(
    pos_batch["premise"],
    noise_index = torch.tensor(config.mem_win_noise_position),
    target_win=8,
    noise_mlt=config.mem_win_noise_mlt
)
neg_embs_corrupted = model.add_noise(
    neg_batch["premise"],
    noise_index = torch.tensor(config.cp_win_noise_position),
    target_win=8,
    noise_mlt=config.cp_win_noise_mlt
)

# Corrupted runs start from the noised embeddings; clean runs start from the
# raw premise strings. Each call returns (logits, activation cache).
pos_corrupted_logit, pos_corrupted_cache = model.run_with_cache_from_embed(pos_embs_corrupted)
pos_clean_logit, pos_clean_cache = model.run_with_cache(pos_batch["premise"])
neg_corrupted_logit, neg_corrupted_cache = model.run_with_cache_from_embed(neg_embs_corrupted)
neg_clean_logit, neg_clean_cache = model.run_with_cache(neg_batch["premise"])



# Keep the clean caches around for later cells.
pos_clean_caches.append(pos_clean_cache)
neg_clean_caches.append(neg_clean_cache)

def check_reversed_probs(corrupted_logits, target_pos, orthogonal_pos):
    """Mean gap between target and orthogonal token probabilities at the last position.

    Args:
        corrupted_logits: logits indexed as ``[batch, position, vocab]``.
        target_pos: integer index tensor gathered along the vocabulary axis for
            the target token(s).
        orthogonal_pos: integer index tensor gathered along the vocabulary axis
            for the orthogonal token(s).

    Returns:
        A 0-d tensor: the mean of ``P(target) - P(orthogonal)`` taken at the
        final sequence position.
    """
    # Softmax over the vocabulary, then keep only the final position.
    probabilities = corrupted_logits.softmax(dim=-1)
    final_step = probabilities[:, -1, :]

    p_target = torch.gather(final_step, -1, target_pos).squeeze(-1)
    p_orthogonal = torch.gather(final_step, -1, orthogonal_pos).squeeze(-1)

    gap = p_target - p_orthogonal
    return gap.mean()

# Mean P(target) - P(orthogonal) at the last position of the corrupted runs.
# The labels name the batch so the two printed values can be told apart.
print("Target - Orthogonal (pos batch)", check_reversed_probs(pos_corrupted_logit, pos_target_ids["target"], pos_target_ids["orthogonal"]))
print("Target - Orthogonal (neg batch)", check_reversed_probs(neg_corrupted_logit, neg_target_ids["target"], neg_target_ids["orthogonal"]))



def indirect_effect(logits, corrupted_logits, first_ids_pos, return_type="mean"):
    """Per-example log-probability difference between two runs at the last position.

    Args:
        logits: logits of the run being evaluated, indexed as
            ``[batch, position, vocab]``.
        corrupted_logits: logits of the reference (corrupted) run, same layout.
        first_ids_pos: integer index tensor gathered along the vocabulary axis
            (dimension 1 of the last-position slice), e.g. shape ``[batch, 1]``.
        return_type: accepted for signature compatibility but not used; the
            function always returns the un-reduced per-example differences.

    Returns:
        ``log_softmax(logits) - log_softmax(corrupted_logits)`` evaluated at the
        last position and at the gathered token ids. Only the trailing gather
        dimension is squeezed, so the batch dimension survives even when the
        batch holds a single example.
    """
    logits = torch.nn.functional.log_softmax(logits, dim=-1)
    corrupted_logits = torch.nn.functional.log_softmax(corrupted_logits, dim=-1)
    # squeeze(-1) rather than a bare squeeze(): a bare squeeze would also drop
    # the batch axis for a batch of one and return a 0-d tensor.
    logits_values = torch.gather(logits[:, -1, :], 1, first_ids_pos).squeeze(-1)
    corrupted_logits_values = torch.gather(corrupted_logits[:, -1, :], 1, first_ids_pos).squeeze(-1)
    delta_value = logits_values - corrupted_logits_values

    return delta_value
        
# Baselines: per-example difference between the clean and the corrupted run.
# The pos batch tracks its "target" token ids, the neg batch its "orthogonal"
# token ids. In the visible cells these are only referenced from the
# commented-out normalisation lines inside pos_metric / neg_metric.
POS_BASELINE = indirect_effect(
    logits=pos_clean_logit,
    corrupted_logits=pos_corrupted_logit,
    first_ids_pos=pos_target_ids["target"]
)
NEG_BASELINE = indirect_effect(
    logits=neg_clean_logit,
    corrupted_logits=neg_corrupted_logit,
    first_ids_pos=neg_target_ids["orthogonal"]
)


def _summarize_effect(effect, return_type):
    """Reduce a tensor of per-example effects to a single summary statistic.

    Args:
        effect: tensor of per-example effects.
        return_type: ``"mean"`` for the mean, ``"var"`` for the standard
            deviation (the name is historical; it has never been the variance),
            or ``"mad"`` for the median absolute deviation from the median.

    Raises:
        ValueError: if ``return_type`` is not one of the values above.
    """
    if return_type == "mean":
        return effect.mean()
    if return_type == "var":
        return effect.std()
    if return_type == "mad":
        return (effect - effect.median()).abs().median()
    # An unknown value used to fall through and silently return None.
    raise ValueError(
        f"Unknown return_type {return_type!r}; expected 'mean', 'var' or 'mad'"
    )


def pos_metric(logits, return_type="mean"):
    """Summarise the effect of ``logits`` on the pos batch.

    The effect is measured against the corrupted pos run, on the pos batch's
    "target" token ids.

    Args:
        logits: logits of the run being evaluated.
        return_type: summary statistic, see ``_summarize_effect``.

    Returns:
        The requested summary statistic as a tensor.
    """
    improved = indirect_effect(
        logits=logits,
        corrupted_logits=pos_corrupted_logit,
        first_ids_pos=pos_target_ids["target"]
    )
    # improved = improved/POS_BASELINE
    return _summarize_effect(improved, return_type)


def neg_metric(logits, return_type="mean"):
    """Summarise the effect of ``logits`` on the neg batch.

    The effect is measured against the corrupted neg run, on the neg batch's
    "orthogonal" token ids.

    Args:
        logits: logits of the run being evaluated.
        return_type: summary statistic, see ``_summarize_effect``.

    Returns:
        The requested summary statistic as a tensor.
    """
    improved = indirect_effect(
        logits=logits,
        corrupted_logits=neg_corrupted_logit,
        first_ids_pos=neg_target_ids["orthogonal"]
    )
    # improved = improved/NEG_BASELINE
    return _summarize_effect(improved, return_type)

# Spread counterparts of the two metrics. Despite the "_var" suffix these are
# bound to return_type="mad", i.e. the median absolute deviation.
pos_metric_var = partial(pos_metric, return_type="mad")
neg_metric_var = partial(neg_metric, return_type="mad")


# Check on the clean logits: mean effect followed by its MAD (printed as "var").
print("pos metric", pos_metric(logits=pos_clean_logit), "var", pos_metric_var(logits=pos_clean_logit))
print("neg metric", neg_metric(logits=neg_clean_logit), "var", neg_metric_var(logits=neg_clean_logit))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer
possible_lengths for sampling 100: [12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0]
Traget - Orthogonal tensor(0.0652)
Target - orthogonal tensor(-0.0983)
pos metric tensor(0.0086) var tensor(0.0670)
neg metric tensor(0.0260) var tensor(0.1569)


In [37]:
# Per-example effect (clean run minus corrupted run) for each batch.
# Each batch is compared with its own corrupted run and its own token ids, the
# same pairing used for POS_BASELINE / NEG_BASELINE.
# NOTE(review): this cell used to read an undefined name `pos` and crossed the
# pos logits with the neg batch's corrupted logits/ids (and vice versa);
# confirm that the per-batch pairing below is the intended comparison.
pos_effect = indirect_effect(
    logits=pos_clean_logit,
    corrupted_logits=pos_corrupted_logit,
    first_ids_pos=pos_target_ids["target"]
)
neg_effect = indirect_effect(
    logits=neg_clean_logit,
    corrupted_logits=neg_corrupted_logit,
    first_ids_pos=neg_target_ids["orthogonal"]
)

In [39]:
# Plot the distribution of the per-example effect on the pos batch.
# Move the values to CPU / plain Python numbers first: plotly cannot consume
# tensors that live on a CUDA device.
pos_effect_cpu = pos_effect.detach().cpu().reshape(-1)
pos_mean = pos_effect_cpu.mean().item()
pos_std = pos_effect_cpu.std().item()
pos_median = pos_effect_cpu.median().item()
# Median absolute deviation, computed with torch so the median convention is unchanged.
mad = (pos_effect_cpu - pos_effect_cpu.median()).abs().median().item()

fig = px.histogram(
    x=pos_effect_cpu.numpy(),
    nbins=50,
    title="Distribution of the effect (pos batch)",
    labels={"x": "effect (log-prob difference)"},
)
# mean
fig.add_vline(x=pos_mean, line_width=3, line_dash="dash", line_color="green")
# mean +/- std
fig.add_vline(x=pos_mean + pos_std, line_width=3, line_dash="dash", line_color="red")
fig.add_vline(x=pos_mean - pos_std, line_width=3, line_dash="dash", line_color="red")
# median +/- mad
fig.add_vline(x=pos_median + mad, line_width=3, line_dash="dash", line_color="purple")
fig.add_vline(x=pos_median - mad, line_width=3, line_dash="dash", line_color="purple")
# Legend entries for the lines. The traces carry no data (None), so nothing is
# drawn on the plot itself; markers placed at (0, 0) would show up as points.
fig.add_scatter(x=[None], y=[None], mode="markers", marker=dict(size=10, color="green"), name="mean")
fig.add_scatter(x=[None], y=[None], mode="markers", marker=dict(size=10, color="red"), name="std")
fig.add_scatter(x=[None], y=[None], mode="markers", marker=dict(size=10, color="purple"), name="mad")
fig.show()

In [17]:
# Plot the distribution of the per-example effect on the neg batch.
# Move the values to CPU / plain Python numbers first: plotly cannot consume
# tensors that live on a CUDA device.
neg_effect_cpu = neg_effect.detach().cpu().reshape(-1)
neg_mean = neg_effect_cpu.mean().item()
neg_std = neg_effect_cpu.std().item()
neg_median = neg_effect_cpu.median().item()
# Median absolute deviation, computed with torch so the median convention is unchanged.
mad = (neg_effect_cpu - neg_effect_cpu.median()).abs().median().item()

fig = px.histogram(
    x=neg_effect_cpu.numpy(),
    nbins=50,
    title="Distribution of the effect (neg batch)",
    labels={"x": "effect (log-prob difference)"},
)
# mean
fig.add_vline(x=neg_mean, line_width=3, line_dash="dash", line_color="green")
# mean +/- std
fig.add_vline(x=neg_mean + neg_std, line_width=3, line_dash="dash", line_color="red")
fig.add_vline(x=neg_mean - neg_std, line_width=3, line_dash="dash", line_color="red")
# median +/- mad
fig.add_vline(x=neg_median + mad, line_width=3, line_dash="dash", line_color="purple")
fig.add_vline(x=neg_median - mad, line_width=3, line_dash="dash", line_color="purple")
# Legend entries for the lines; the traces carry no data, so nothing is drawn.
fig.add_scatter(x=[None], y=[None], mode="markers", marker=dict(size=10, color="green"), name="mean")
fig.add_scatter(x=[None], y=[None], mode="markers", marker=dict(size=10, color="red"), name="std")
fig.add_scatter(x=[None], y=[None], mode="markers", marker=dict(size=10, color="purple"), name="mad")
fig.show()

In [23]:
# Inspect the shape of one cached activation from the corrupted pos run.
pos_corrupted_cache["blocks.4.hook_resid_post"].shape

torch.Size([100, 16, 768])

In [30]:
# NOTE(review): this import repeats the one in the first cell and could be dropped.
import transformer_lens.utils as utils
# Show the hook-point name that get_act_name builds for "resid_pre" at layer 4.
utils.get_act_name("resid_pre", 4)

'blocks.4.hook_resid_pre'

In [36]:
def residual_hooks(activation, hook, corrupted_cache, position):
    """Forward hook that overwrites one sequence position with cached values.

    Args:
        activation: activation tensor handed to the hook; modified in place.
        hook: hook-point object supplied by the hooking machinery (unused).
        corrupted_cache: tensor indexed the same way as ``activation`` that
            supplies the replacement values.
        position: index along dimension 1 selecting what gets replaced.

    Returns:
        The same ``activation`` tensor, after the in-place overwrite.
    """
    # Same selection on both sides: every batch row, the chosen position, all features.
    patch_slot = (slice(None), position, slice(None))
    activation[patch_slot] = corrupted_cache[patch_slot]
    return activation
def embed_hook(cache, hook, corrupted_embeddings):
    """Forward hook that replaces the whole activation with the given embeddings.

    Args:
        cache: activation tensor handed to the hook; overwritten in place.
        hook: hook-point object supplied by the hooking machinery (unused).
        corrupted_embeddings: values copied into ``cache``.

    Returns:
        The same ``cache`` tensor, now holding ``corrupted_embeddings``.
    """
    # Full slice over the three leading dimensions, i.e. every element.
    every_element = (slice(None), slice(None), slice(None))
    cache[every_element] = corrupted_embeddings
    return cache
# Hook that swaps the embedding output for the corrupted pos embeddings.
embeds_hook = partial(embed_hook, corrupted_embeddings=pos_embs_corrupted)


# Hook that overwrites sequence position 15 with values from the corrupted cache.
# NOTE(review): the values are read from "blocks.4.hook_resid_post" but written
# at the layer-4 "resid_pre" hook point below — confirm this mismatch is intended.
# NOTE(review): 15 is presumably the last position (the cached activation shown
# above has 16 positions) — consider deriving it instead of hard-coding it.
hooks_fn = partial(residual_hooks, corrupted_cache = pos_corrupted_cache["blocks.4.hook_resid_post"], position = 15)

# Run the pos premises with both hooks attached.
logit = model.run_with_hooks(pos_input_ids,
    fwd_hooks=[("hook_embed", embeds_hook), (utils.get_act_name("resid_pre", 4),hooks_fn)]
)