In [1]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

In [34]:
import einops
from functools import partial
import torch
import datasets
from torch import Tensor
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets
from jaxtyping import Float, Int, Bool
from typing import Dict, Iterable, List, Tuple, Union
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_dataset, tokenize_and_concatenate, get_act_name, test_prompt
from transformer_lens.hook_points import HookPoint
from tqdm.notebook import tqdm
import pandas as pd
from circuitsvis.activations import text_neuron_activations
from utils.store import load_array, save_html, save_array, is_file, get_model_name, clean_label, save_text
from utils.circuit_analysis import get_logit_diff

from utils.tokenwise_ablation import (
    compute_ablation_modified_logit_diff,
    load_directions,
    get_random_directions,
    get_zeroed_dir_vector,
    get_layerwise_token_mean_activations
)

## Comma Ablation on Natural Text

### Model Setup

In [3]:
torch.set_grad_enabled(False)
device = "cuda"
MODEL_NAME = "EleutherAI/pythia-2.8b"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    device=device,
)
model.name = MODEL_NAME

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


### Data

### OWT Data Prep

In [16]:
owt_data = load_dataset("stas/openwebtext-10k")

Repo card metadata block was not found. Setting CardData to empty.


In [17]:
owt_data

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 10000
    })
})

In [18]:
from utils.datasets import OWTData
owt_data = OWTData(owt_data, model)

In [19]:
owt_data.preprocess_datasets(token_to_ablate=13)

Map:   0%|          | 0/10959 [00:00<?, ? examples/s]

  positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)


Map:   0%|          | 0/10959 [00:00<?, ? examples/s]

  attention_mask = torch.ones_like(torch.tensor(example['tokens']), dtype=torch.int)


In [20]:
datasets = owt_data.get_datasets()

In [21]:
datasets['train']

Dataset({
    features: ['tokens', 'positions', 'has_token', 'attention_mask'],
    num_rows: 10959
})

In [31]:
print(model.to_str_tokens(datasets['train'][0]['tokens'][30:40]))

['in', ' Kamp', 'f', '”', ' in', ' Germany', ',', ' but', ' the', ' government']


In [30]:
print(datasets['train'][0]['positions'][30:40])

tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0])


In [32]:
dataloader = owt_data.get_dataloaders(batch_size=5)['train']

### Calculate Means

In [35]:
comma_mean_values = get_layerwise_token_mean_activations(model, dataloader, token_id=13)
save_array(comma_mean_values, 'comma_mean_values.npy', model)

  0%|          | 0/2192 [00:00<?, ?it/s]

'data/pythia-2.8b/comma_mean_values.npy'

In [18]:
# load the files
owt_mean_values = torch.from_numpy(load_array('comma_mean_values.npy', model)).to(device)

### Get Loss Results

In [37]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
ablated_loss_diff, orig_loss = compute_mean_ablation_modified_loss(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    comma_mean_bal_values,
    debug=False
)

ablated_loss = orig_loss + ablated_loss_diff

# orig_accuracy = (orig_ld_list > 0).float().mean()
# ablated_accuracy = (ablated_ld_list > 0).float().mean()
# freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

# print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
# print(f"Original accuracy: {orig_accuracy:.4f}")
# print("\n")
# print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
# print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
# print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
# print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
# print("\n")
# print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
# print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
# print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
# print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
# print("---------------------------------------------------------")
# print("Random direction ablation results:")
# print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
# print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
# print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
# print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

In [39]:
ablated_loss_diff[0]

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.1838, -0.0903,  0.0681,  0.1680, -0.0285,  0.2364,  0.4690,  0.3611,
         0.2591, -0.3020,  0.3572,  0.0501,  0.2772,  0.1820,  0.0000,  0.4224,
         1.6662,  0.6940,  0.1109,  0.0439, -0.2660,  0.2348, -0.6084, -0.7712,
        -0.1148, -0.2007,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
       device='cuda:0')

In [None]:
ablated_loss_diff[0][33], orig_loss[0][33], ablated_loss[0][33]

tensor(-0.2007, device='cuda:0')

In [32]:
batch = next(iter(balanced_data_loader))

In [42]:
model.to_str_tokens(batch['tokens'][0][33])

[':']

In [None]:
batch

In [38]:
model.reset_hooks()