In [1]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

In [3]:
import einops
from functools import partial
import torch
import datasets
from torch import Tensor
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets
from jaxtyping import Float, Int, Bool
from typing import Dict, Iterable, List, Tuple, Union
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_dataset, tokenize_and_concatenate, get_act_name, test_prompt
from transformer_lens.hook_points import HookPoint
from tqdm.notebook import tqdm
import pandas as pd
from circuitsvis.activations import text_neuron_activations
from utils.store import load_array, save_html, save_array, is_file, get_model_name, clean_label, save_text
from utils.circuit_analysis import get_logit_diff

from utils.tokenwise_ablation import (
    compute_ablation_modified_logit_diff,
    load_directions,
    get_random_directions,
    get_zeroed_dir_vector
)

## Comma Ablation on Natural Text

### Model Setup

In [4]:
torch.set_grad_enabled(False)
device = "cuda"
MODEL_NAME = "EleutherAI/pythia-2.8b"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    device=device,
)
model.name = MODEL_NAME

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


### Data

In [9]:
from datasets import load_from_disk
sst_data = load_from_disk("sst2")
sst_data

FileNotFoundError: Directory sst2 not found

In [7]:
sst_data['train'][0]['text'], sst_data['train'][0]['label']

("The Rock is destined to be the 21st Century's new ''Conan'' and that he's going to make a splash even greater than Arnold Schwarzenegger, Jean-Claud Van Damme or Steven Segal.",
 1)

In [8]:
def filter_function(example):
    prompt = model.to_tokens(example['text'] + " Review Sentiment:", prepend_bos=False)
    answer = torch.tensor([29071, 32725]).unsqueeze(0).unsqueeze(0).to(device) if example['label'] == 1 else torch.tensor([32725, 29071]).unsqueeze(0).unsqueeze(0).to(device)
    logits, cache = model.run_with_cache(prompt)
    logit_diff = get_logit_diff(logits, answer)
    
    # Determine if the top answer (index 0) token is in top 10 logits
    _, top_indices = logits.topk(10, dim=-1)  # Get indices of top 10 logits
    top_answer_token = answer[0, 0, 0]  # Assuming answer is of shape (1, 1, 2) and the top answer token is at index 0
    is_top_answer_in_top_10_logits = (top_indices == top_answer_token).any()
    
    # Add a new field 'keep_example' to the example
    example['keep_example'] = (logit_diff > 0.0) and is_top_answer_in_top_10_logits.item()
    return example

# Use the map function to apply the filter_function
sst_data_with_flag_train = sst_data['train'].map(filter_function, keep_in_memory=True)
sst_data_with_flag_val = sst_data['dev'].map(filter_function, keep_in_memory=True)
sst_data_with_flag_test = sst_data['test'].map(filter_function, keep_in_memory=True)
sst_data_with_flag = concatenate_datasets([sst_data_with_flag_train, sst_data_with_flag_val, sst_data_with_flag_test])

# Use the filter function to keep only the examples where 'keep_example' is True
sst_zero_shot = sst_data_with_flag.filter(lambda x: x['keep_example'])

# save dataset
sst_zero_shot.save_to_disk("sst_zero_shot")




Map:   0%|          | 0/7864 [00:00<?, ? examples/s]

Map:   0%|          | 0/1007 [00:00<?, ? examples/s]

Map:   0%|          | 0/2058 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10929 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6169 [00:00<?, ? examples/s]

In [9]:
sst_zero_shot

Dataset({
    features: ['text', 'label', 'keep_example'],
    num_rows: 6169
})

In [11]:
from datasets import load_from_disk
sst_zero_shot = load_from_disk("sst_zero_shot")

In [13]:
from torch.utils.data import DataLoader
from transformers import AutoTokenizer


# Define the batch size
BATCH_SIZE = 5

# Load a tokenizer (you'll need to specify the appropriate model)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b")
# set padding token
tokenizer.pad_token = model.to_string([1])

#dataset = text_dataset.map(lambda x: tokenize_and_concatenate(x, tokenizer))

def concatenate_classification_prompts(examples):
    return {"text": (examples['text'] + " Review Sentiment:")}


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64)


def find_dataset_positions(example, token_id=13):
    # Create a tensor of zeros with the same shape as example['tokens']
    positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)

    # Find positions where tokens match the given token_id
    positions[example['tokens'] == token_id] = 1
    has_token = True if positions.sum() > 0 else False

    return {'positions': positions, 'has_token': has_token}

def convert_answers(example, pos_answer_id=29071, neg_answer_id=32725):
    if example['label'] == 1:
        answers = torch.tensor([pos_answer_id, neg_answer_id])
    else:
        answers = torch.tensor([neg_answer_id, pos_answer_id])

    return {'answers': answers}


dataset = sst_zero_shot.map(concatenate_classification_prompts, batched=False)
dataset = dataset.map(tokenize_function, batched=False)
dataset = dataset.map(convert_answers, batched=False)
dataset = dataset.rename_column("input_ids", "tokens")
dataset.set_format(type="torch", columns=["tokens", "attention_mask", "label", "answers"])
dataset = dataset.map(find_dataset_positions, batched=False)
dataset = dataset.filter(lambda example: example['has_token']==True)
dataset


Dataset({
    features: ['text', 'label', 'keep_example', 'tokens', 'attention_mask', 'answers', 'positions', 'has_token'],
    num_rows: 3318
})

In [14]:
model.to_string(dataset[0]['tokens']), model.to_str_tokens(dataset[0]['answers'][0])

("The Rock is destined to be the 21st Century's new ''Conan'' and that he's going to make a splash even greater than Arnold Schwarzenegger, Jean-Claud Van Damme or Steven Segal. Review Sentiment:<|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|>",
 [' Positive'])

In [15]:
from utils.circuit_analysis import get_logit_diff
logits, cache = model.run_with_cache(dataset['tokens'][0])
get_logit_diff(logits, dataset['answers'][0].unsqueeze(0).unsqueeze(0).to(device))

tensor(-2.8829, device='cuda:0')

In [16]:
# create a subset with only positive labels
pos_dataset = dataset.filter(lambda example: example['label']==1)
neg_dataset = dataset.filter(lambda example: example['label']==0)
len(pos_dataset), len(neg_dataset)

(2623, 695)

In [17]:
import random
from datasets import concatenate_datasets

def get_random_subset(dataset, n):
    total_size = len(dataset)
    random_indices = random.sample(range(total_size), n)
    return dataset.select(random_indices)

pos_subset = get_random_subset(pos_dataset, 695)
neg_subset = get_random_subset(neg_dataset, 695)
balanced_subset = concatenate_datasets([pos_subset, neg_subset])
# randomize the order of balanced_subset
balanced_subset = balanced_subset.shuffle(len(balanced_subset))

# Create a new dataloader from the subset, converting the data to tensors
pos_data_loader = DataLoader(
    pos_subset, batch_size=5, shuffle=False, drop_last=True
)
neg_data_loader = DataLoader(
    neg_subset, batch_size=5, shuffle=False, drop_last=True
)
balanced_data_loader = DataLoader(
    balanced_subset, batch_size=5, shuffle=False, drop_last=True
)

# save datasets
pos_subset.save_to_disk("sst_pos_subset")
neg_subset.save_to_disk("sst_neg_subset")
balanced_subset.save_to_disk("sst_balanced_subset")

len(pos_data_loader), len(neg_data_loader)


Saving the dataset (0/1 shards):   0%|          | 0/695 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/695 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1390 [00:00<?, ? examples/s]

(139, 139)

### OWT Data Prep

In [28]:
owt_data = load_dataset("stas/openwebtext-10k")

Repo card metadata block was not found. Setting CardData to empty.


In [29]:
owt_data

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 10000
    })
})

In [30]:
from utils.datasets import OWTData
owt_data = OWTData(owt_data, model)

In [31]:
owt_data.preprocess_datasets()

Map:   0%|          | 0/10959 [00:00<?, ? examples/s]

In [32]:
owt_data.find_dataset_positions()

Map:   0%|          | 0/10959 [00:00<?, ? examples/s]

TypeError: ExperimentData._find_dataset_positions() got multiple values for argument 'token_id'

In [8]:
type(dataset['tokens'][0])

torch.Tensor

### Apply Mean Ablation

In [39]:
comma_mean_bal_values = get_layerwise_token_mean_activations(model, balanced_data_loader, token_id=13)
comma_mean_neg_values = get_layerwise_token_mean_activations(model, neg_data_loader, token_id=13)
comma_mean_pos_values = get_layerwise_token_mean_activations(model, pos_data_loader, token_id=13)
save_array(comma_mean_bal_values, 'comma_balanced_mean_values.npy', model)
save_array(comma_mean_neg_values, 'comma_neg_mean_values.npy', model)
save_array(comma_mean_pos_values, 'comma_pos_mean_values.npy', model)

torch.Size([32, 2560])


  0%|          | 0/278 [00:00<?, ?it/s]

torch.Size([32, 2560])


  0%|          | 0/139 [00:00<?, ?it/s]

torch.Size([32, 2560])


  0%|          | 0/139 [00:00<?, ?it/s]

'data/pythia-2.8b/comma_pos_mean_values.npy'

In [40]:
# period_mean_values = get_layerwise_token_mean_activations(model, train_data_loader, token_id=15)
# save_array(period_mean_values, 'period_mean_values.npy', model)

In [18]:
# load the files
owt_mean_values = torch.from_numpy(load_array('comma_mean_values.npy', model)).to(device)
comma_mean_bal_values = torch.from_numpy(load_array('comma_balanced_mean_values.npy', model)).to(device)
comma_mean_neg_values = torch.from_numpy(load_array('comma_neg_mean_values.npy', model)).to(device)
comma_mean_pos_values = torch.from_numpy(load_array('comma_pos_mean_values.npy', model)).to(device)
#period_mean_values = torch.from_numpy(load_array('period_mean_values.npy', model)).to(device)

### Get Loss Results

In [37]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
ablated_loss_diff, orig_loss = compute_mean_ablation_modified_loss(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    comma_mean_bal_values,
    debug=False
)

ablated_loss = orig_loss + ablated_loss_diff

# orig_accuracy = (orig_ld_list > 0).float().mean()
# ablated_accuracy = (ablated_ld_list > 0).float().mean()
# freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

# print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
# print(f"Original accuracy: {orig_accuracy:.4f}")
# print("\n")
# print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
# print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
# print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
# print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
# print("\n")
# print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
# print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
# print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
# print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
# print("---------------------------------------------------------")
# print("Random direction ablation results:")
# print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
# print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
# print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
# print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

In [39]:
ablated_loss_diff[0]

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.1838, -0.0903,  0.0681,  0.1680, -0.0285,  0.2364,  0.4690,  0.3611,
         0.2591, -0.3020,  0.3572,  0.0501,  0.2772,  0.1820,  0.0000,  0.4224,
         1.6662,  0.6940,  0.1109,  0.0439, -0.2660,  0.2348, -0.6084, -0.7712,
        -0.1148, -0.2007,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
       device='cuda:0')

In [None]:
ablated_loss_diff[0][33], orig_loss[0][33], ablated_loss[0][33]

tensor(-0.2007, device='cuda:0')

In [32]:
batch = next(iter(balanced_data_loader))

In [42]:
model.to_str_tokens(batch['tokens'][0][33])

[':']

In [None]:
batch

In [38]:
model.reset_hooks()