In [1]:
from IPython import get_ipython
ipython = get_ipython()
# get_ipython() returns None when this code is not running inside an IPython
# kernel (e.g. executed as a plain script), so only configure autoreload when
# a shell is actually available instead of crashing with AttributeError.
if ipython is not None:
    # Mode 2 reloads imported modules before executing code, so edits to the
    # local helper modules are picked up without restarting the kernel.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [4]:
import einops
from functools import partial
import numpy as np
import torch
import datasets
from torch import Tensor
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets
from jaxtyping import Float, Int, Bool
from typing import Dict, Iterable, List, Tuple, Union
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_dataset, tokenize_and_concatenate, get_act_name, test_prompt
from transformer_lens.hook_points import HookPoint
from tqdm.notebook import tqdm
import pandas as pd
from circuitsvis.activations import text_neuron_activations
from utils.circuit_analysis import get_logit_diff
from utils.store import load_array, save_array

from utils.tokenwise_ablation import (
    compute_ablation_modified_loss,
    compute_ablation_modified_logit_diff,
    load_directions,
    get_random_directions,
    get_zeroed_dir_vector,
    get_layerwise_token_mean_activations
)
from utils.datasets import OWTData

## Comma Ablation on Natural Text

### Model Setup

In [5]:
# Fall back to CPU when no GPU is visible so the notebook can still start on
# machines without CUDA (a 2.8B-parameter model will be slow there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "EleutherAI/pythia-2.8b"
# Token id to ablate; the decoded outputs below show id 13 is "," for this tokenizer.
TOKEN_ID = 13
BATCH_SIZE = 1

In [4]:
# Inference only: turn autograd off globally for every later cell.
torch.set_grad_enabled(False)
# Load the model into TransformerLens with LayerNorm folded and the writing /
# unembedding weights centered; attention matrices are left unrefactored.
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=False,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


### Data

In [6]:
# Build the OpenWebText wrapper for this model, then preprocess it with
# TOKEN_ID as the token to ablate (adds the 'positions'/'has_token' columns
# seen in the split display below).
owt_data = OWTData.from_model(model)
owt_data.preprocess_datasets(token_to_ablate=TOKEN_ID)

Downloading builder script:   0%|          | 0.00/3.08k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/951 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data:   0%|          | 0.00/14.7M [00:00<?, ?B/s]



Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map (num_proc=10):   0%|          | 0/10000 [00:00<?, ? examples/s]

  table = cls._concat_blocks(blocks, axis=0)


Map:   0%|          | 0/5475 [00:00<?, ? examples/s]

Map:   0%|          | 0/5475 [00:00<?, ? examples/s]

In [27]:
# Mapping of split name -> preprocessed Dataset (indexed by 'train' below).
# NOTE(review): this rebinds `datasets`, shadowing the `datasets` library
# imported at the top of the notebook. Later cells read this name, so a rename
# (e.g. to `dataset_dict`) must update all of them together.
datasets = owt_data.dataset_dict

In [28]:
# Inspect the train split's feature columns and row count.
datasets["train"]

Dataset({
    features: ['tokens', 'attention_mask', 'positions', 'has_token'],
    num_rows: 5475
})

In [29]:
# Raw token ids for positions 30..39 of the first training example.
datasets["train"][0]["tokens"][30:40]

tensor([  249, 46882,    71,   668,   275,  6176,    13,   533,   253,  2208])

In [30]:
# Decode the same window to strings; display the value directly rather than print().
model.to_str_tokens(datasets['train'][0]['tokens'][30:40])

['in', ' Kamp', 'f', '”', ' in', ' Germany', ',', ' but', ' the', ' government']


In [31]:
# 'positions' for the same window; in the recorded output it is 1 exactly where
# the ablated token (",") sits. Displayed directly rather than via print().
datasets['train'][0]['positions'][30:40]

tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0])


In [39]:
# Full-size train loader. Later cells (`len(dataloader)` and the mean-activation
# computation) depend on it, so it must be defined for Restart & Run All.
dataloader = owt_data.get_dataloaders(batch_size=BATCH_SIZE)['train']

# Subsample the train split for faster ablation experiments.
N_SUBSAMPLE = 50
smaller_dataset = datasets['train'].select(range(min(N_SUBSAMPLE, len(datasets['train']))))
# NOTE(review): from_model appears to reload the dataset (see this cell's output),
# only for its train split to be replaced by the subsample.
owt_small = OWTData.from_model(model)
owt_small.dataset_dict['train'] = smaller_dataset
smaller_dataloader = owt_small.get_dataloaders(batch_size=BATCH_SIZE)['train']



Repo card metadata block was not found. Setting CardData to empty.


In [40]:
# Confirm the subsample keeps the same features with the reduced row count.
smaller_dataset

Dataset({
    features: ['tokens', 'attention_mask', 'positions', 'has_token'],
    num_rows: 50
})

In [13]:
# Rows in the full train split vs. number of batches in its loader.
(len(datasets["train"]), len(dataloader))

(5475, 5475)

### Calculate Means

In [13]:
# Per-layer mean activations at TOKEN_ID positions over the full train loader;
# presumably used as the replacement value when ablating (loaded again below).
# NOTE: long-running over the full split; the recorded run was interrupted.
comma_mean_values = get_layerwise_token_mean_activations(model, dataloader, token_id=TOKEN_ID, device=device)
save_array(comma_mean_values, 'comma_mean_values.npy', model)

  0%|          | 0/685 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [22]:
# Load the precomputed per-layer means. The directory is derived from MODEL_NAME
# (e.g. "EleutherAI/pythia-2.8b" -> "data/pythia-2.8b/") so it follows the
# configured model instead of being hard-coded.
# NOTE(review): assumes save_array writes to data/<model short name>/ — confirm.
MEANS_PATH = f"data/{MODEL_NAME.split('/')[-1]}/comma_mean_values.npy"
owt_mean_values = torch.from_numpy(np.load(MEANS_PATH)).to(device)

### Get Loss Results

In [52]:
# Every (layer, head) pair in the model.
# NOTE(review): neither head list is passed to the ablation call below.
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = list(heads_to_ablate)  # independent copy with the same pairs
# Ablate at every layer.
layers_to_ablate = list(range(model.cfg.n_layers))
# Per-token loss with and without ablation on the subsampled loader; a later
# cell reads output[0] as the original loss and output[1] as the loss change.
output = compute_ablation_modified_loss(
    model,
    smaller_dataloader,
    layers_to_ablate,
    owt_mean_values,
    cached=False,
)

#ablated_loss = orig_loss + ablated_loss_diff

# orig_accuracy = (orig_ld_list > 0).float().mean()
# ablated_accuracy = (ablated_ld_list > 0).float().mean()
# freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

# print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
# print(f"Original accuracy: {orig_accuracy:.4f}")
# print("\n")
# print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
# print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
# print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
# print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
# print("\n")
# print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
# print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
# print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
# print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
# print("---------------------------------------------------------")
# print("Random direction ablation results:")
# print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
# print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
# print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
# print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/50 [00:00<?, ?it/s]

In [56]:
# Original per-token loss and the change in loss caused by ablation.
orig_loss = output[0]
ablated_loss_diff = output[1]
# Reconstruct the absolute loss under ablation.
ablated_loss = ablated_loss_diff + orig_loss

In [67]:
# Positions 35..44 of the first sequence (the comma sits at 36). The 4 decimals
# come from torch's default print precision, not from explicit formatting.
orig_loss[0][35:45], ablated_loss_diff[0][35:45], ablated_loss[0][35:45]

(tensor([ 0.2679,  0.9327,  0.7349,  1.6650,  2.8340,  4.7207,  2.5980,  0.0159,
          2.0307, 12.9145]),
 tensor([ 0.0000,  0.0000,  0.0000,  0.1548,  0.5523, -1.1104, -0.0629,  0.0478,
         -0.4134,  0.0000]),
 tensor([ 0.2679,  0.9327,  0.7349,  1.8198,  3.3863,  3.6103,  2.5350,  0.0637,
          1.6174, 12.9145]))

In [75]:
# Loss values at the comma position (index 36) of the first sequence.
orig_loss[0, 36], ablated_loss_diff[0, 36], ablated_loss[0, 36]

(tensor(0.9327), tensor(0.), tensor(0.9327))

In [69]:
# Fetch the first batch so its tokens can be decoded next.
# NOTE(review): matches datasets['train'][0] only if the loader does not
# shuffle — the decoded tokens below are consistent with that, but confirm.
batch = next(iter(smaller_dataloader))

In [74]:
# Decode positions 36..42 of the first sequence, starting at the comma.
model.to_str_tokens(batch["tokens"][0][36:43])

[',', ' but', ' the', ' government', ' of', ' Bav', 'aria']

In [None]:
# Clear any hooks that may still be attached to the model after the ablation runs.
model.reset_hooks()