In [1]:
!ls

README.md		    miniconda.sh  quick_start_pytorch.ipynb   wandb
eliciting-latent-sentiment  miniconda3	  quick_start_pytorch_images


In [2]:
%cd eliciting-latent-sentiment

/notebooks/eliciting-latent-sentiment


In [3]:
#!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
!pip install fancy_einsum==0.0.3
!pip install transformer_lens
!pip install jaxtyping==0.2.13
!pip install einops
!pip install protobuf==3.20.*
!pip install plotly
!pip install torchtyping
!pip install git+https://github.com/neelnanda-io/neel-plotly.git
!pip install circuitsvis
!pip install imgkit
# !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
# %pip install git+https://github.com/neelnanda-io/PySvelte.git
# %pip install typeguard==2.13.3

[0mCollecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-9inesure
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-9inesure
  Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7
  Preparing metadata (setup.py) ... [?25ldone
[0m

In [4]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

In [5]:
import einops
from functools import partial
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import random
from datasets import concatenate_datasets
import torch
import datasets
from torch import Tensor
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets
from jaxtyping import Float, Int, Bool
from typing import Dict, Iterable, List, Tuple, Union
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_dataset, tokenize_and_concatenate, get_act_name, test_prompt
from transformer_lens.hook_points import HookPoint
from datasets import Dataset
from tqdm.notebook import tqdm
import pandas as pd
from circuitsvis.activations import text_neuron_activations
from utils.store import load_array, save_html, save_array, is_file, get_model_name, clean_label, save_text
from utils.circuit_analysis import get_logit_diff
import plotly.express as px

from utils.tokenwise_ablation import (
    compute_mean_ablation_modified_logit_diff,
    compute_directional_ablation_modified_logit_diff,
    compute_last_position_logit_diff,
    compute_directional_ablation_modified_logit_diff_all_pos,
    compute_mean_ablation_modified_loss,
    load_directions,
    get_random_directions,
    get_zeroed_dir_vector
)

## Comma Ablation on Natural Text

### Data Preparation

In [7]:
# RUN BELOW IF NOT ALREADY DONE
from datasets import load_from_disk


In [8]:
from datasets import load_from_disk

In [9]:

def filter_function(example, model):
    prompt = model.to_tokens(example['text'] + " Review Sentiment:", prepend_bos=False)
    answer = torch.tensor([29071, 32725]).unsqueeze(0).unsqueeze(0).to(device) if example['label'] == 1 else torch.tensor([32725, 29071]).unsqueeze(0).unsqueeze(0).to(device)
    logits = model(prompt, return_type="logits")
    logit_diff = get_logit_diff(logits, answer)
    
    # Determine if the top answer (index 0) token is in top 10 logits
    _, top_indices = logits.topk(10, dim=-1)  # Get indices of top 10 logits
    top_answer_token = answer[0, 0, 0]  # Assuming answer is of shape (1, 1, 2) and the top answer token is at index 0
    is_top_answer_in_top_10_logits = (top_indices == top_answer_token).any()
    
    # Add a new field 'keep_example' to the example
    example['keep_example'] = (logit_diff > 0.0) and is_top_answer_in_top_10_logits.item()
    return example


def concatenate_classification_prompts(examples):
    return {"text": (examples['text'] + " Review Sentiment:")}


def tokenize_function(examples, tokenizer):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64)


def find_dataset_positions(example, token_id=13):
    # Create a tensor of zeros with the same shape as example['tokens']
    positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)

    # Find positions where tokens match the given token_id
    positions[example['tokens'] == token_id] = 1
    has_token = True if positions.sum() > 0 else False

    return {'positions': positions, 'has_token': has_token}


def convert_answers(example, pos_answer_id=29071, neg_answer_id=32725):
    if example['label'] == 1:
        answers = torch.tensor([pos_answer_id, neg_answer_id])
    else:
        answers = torch.tensor([neg_answer_id, pos_answer_id])

    return {'answers': answers}


def get_random_subset(dataset, n):
    total_size = len(dataset)
    random_indices = random.sample(range(total_size), n)
    return dataset.select(random_indices)

In [43]:
import re
def prepare_sst_for_model(
        model: HookedTransformer,
        dataset_name: str = "sst2", 
        batch_size: int = 5,
        pad_token_id: int = 1, 
        pos_answer_id: int = 29071, 
        neg_answer_id: int = 32725
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
    # Define the batch size
    BATCH_SIZE = batch_size

    sst_data = load_from_disk(dataset_name)

    # Use the map function to apply the filter_function
    filter_function_for_model = partial(filter_function, model=model)
    sst_data_with_flag_train = sst_data['train'].map(filter_function_for_model, keep_in_memory=True)
    sst_data_with_flag_dev = sst_data['dev'].map(filter_function_for_model, keep_in_memory=True)
    sst_data_with_flag_test = sst_data['test'].map(filter_function_for_model, keep_in_memory=True)
    sst_data_with_flag = concatenate_datasets([sst_data_with_flag_train, sst_data_with_flag_dev, sst_data_with_flag_test])
    #sst_data_with_flag = sst_data_with_flag_dev

    # Use the filter function to keep only the examples where 'keep_example' is True
    sst_zero_shot = sst_data_with_flag.filter(lambda x: x['keep_example'])
    # print number of items in dataset
    print(f"Number of items in dataset: {len(sst_zero_shot)}")
    # save dataset
    #new model name without slashes
    model_abbr = re.sub(r'/', '_', model.name)
    sst_zero_shot.save_to_disk(f"sst_zero_shot_{model_abbr}")

    # Load a tokenizer (you'll need to specify the appropriate model)
    tokenizer = AutoTokenizer.from_pretrained(model.name)
    # set padding token
    tokenizer.pad_token = model.to_string([pad_token_id])

    dataset = sst_zero_shot.map(concatenate_classification_prompts, batched=False)
    tokenizer_function_for_model = partial(tokenize_function, tokenizer=tokenizer)
    dataset = dataset.map(tokenizer_function_for_model, batched=False)
    convert_answers_for_model = partial(convert_answers, pos_answer_id=pos_answer_id, neg_answer_id=neg_answer_id)
    dataset = dataset.map(convert_answers_for_model, batched=False)
    dataset = dataset.rename_column("input_ids", "tokens")
    dataset.set_format(type="torch", columns=["tokens", "attention_mask", "label", "answers"])
    dataset = dataset.map(find_dataset_positions, batched=False)
    dataset = dataset.filter(lambda example: example['has_token']==True)

    # create a subset with only positive labels
    pos_dataset = dataset.filter(lambda example: example['label']==1)
    neg_dataset = dataset.filter(lambda example: example['label']==0)
    len(pos_dataset), len(neg_dataset)

    subset_size = (min(len(pos_dataset), len(neg_dataset)) // BATCH_SIZE) * BATCH_SIZE

    pos_subset = get_random_subset(pos_dataset, subset_size)
    neg_subset = get_random_subset(neg_dataset, subset_size)
    balanced_subset = concatenate_datasets([pos_subset, neg_subset])
    # randomize the order of balanced_subset
    balanced_subset = balanced_subset.shuffle(len(balanced_subset))

    # Create a new dataloader from the subset, converting the data to tensors
    pos_data_loader = DataLoader(
        pos_subset, batch_size=5, shuffle=False, drop_last=True
    )
    neg_data_loader = DataLoader(
        neg_subset, batch_size=5, shuffle=False, drop_last=True
    )
    balanced_data_loader = DataLoader(
        balanced_subset, batch_size=5, shuffle=False, drop_last=True
    )

    print(f"Number of items in pos dataset: {len(pos_subset)}")
    print(f"Number of items in neg dataset: {len(neg_subset)}")
    print(f"Number of items in balanced dataset: {len(balanced_subset)}")
    return pos_data_loader, neg_data_loader, balanced_data_loader


In [51]:
def compute_average_comma_distances(dataset: Dataset) -> torch.Tensor:
    avg_distances = []

    for example in dataset:
        positions = example['positions']
        attention_mask = example['attention_mask']

        # Find the end of the prompt (the last token before padding starts)
        prompt_length = sum(attention_mask)

        # Calculate the distance of each comma from the end of the prompt
        comma_indices = [i for i, value in enumerate(positions) if value == 1]
        distances = [prompt_length - i for i in comma_indices]

        # Calculate the average distance for this example
        avg_distance = sum(distances) / len(comma_indices) if comma_indices else 0
        avg_distances.append(avg_distance)

    # Convert the list of average distances to a tensor
    return torch.tensor(avg_distances)

def compute_first_comma_distances(dataset: Dataset) -> torch.Tensor:
    first_comma_distances = []

    for example in dataset:
        positions = torch.tensor(example['positions'])
        attention_mask = torch.tensor(example['attention_mask'])

        # Find the end of the prompt (the last token before padding starts)
        prompt_length = sum(attention_mask)

        # Find the indices where positions are equal to 1 (comma positions)
        comma_indices = (positions == 1).nonzero(as_tuple=True)[0]

        if len(comma_indices) > 0:
            # Get the first comma index and calculate the distance
            first_comma_index = comma_indices[0]
            distance = prompt_length - first_comma_index
        else:
            # If there's no comma, set distance to -1 or another placeholder value
            distance = -1

        first_comma_distances.append(distance)

    # Convert the list of distances to a tensor
    return torch.tensor(first_comma_distances)

### Pythia-1.4b Directional Ablation

#### Model

In [7]:
torch.set_grad_enabled(False)
device = "cuda"
MODEL_NAME = "EleutherAI/pythia-1.4b"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    device=device,
)
model.name = MODEL_NAME

Loaded pretrained model EleutherAI/pythia-1.4b into HookedTransformer


In [10]:
pos_data_loader, neg_data_loader, balanced_data_loader = prepare_sst_for_model(model, "sst2", 5, 1, 29071, 32725)

Parameter 'function'=functools.partial(<function filter_function at 0x7fcd6878c700>, model=HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_i

Map:   0%|          | 0/7864 [00:00<?, ? examples/s]

Map:   0%|          | 0/1007 [00:00<?, ? examples/s]

Map:   0%|          | 0/2058 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10929 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1066 [00:00<?, ? examples/s]

Map:   0%|          | 0/1066 [00:00<?, ? examples/s]

Map:   0%|          | 0/1066 [00:00<?, ? examples/s]

Map:   0%|          | 0/1066 [00:00<?, ? examples/s]

Map:   0%|          | 0/1066 [00:00<?, ? examples/s]

  positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)


Filter:   0%|          | 0/1066 [00:00<?, ? examples/s]

Filter:   0%|          | 0/446 [00:00<?, ? examples/s]

Filter:   0%|          | 0/446 [00:00<?, ? examples/s]

Number of items in pos dataset: 115
Number of items in neg dataset: 115
Number of items in balanced dataset: 230


In [11]:
directions = load_directions(model, "data/directions/pythia-1.4b", direction_prefix="das_simple_train_ADJ_layer")
random_directions = get_random_directions(model)
zeroed_directions = get_zeroed_dir_vector(model)

#### Punctuation Only

In [12]:
model.reset_hooks()

##### Balanced Results

In [13]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_directional_ablation_modified_logit_diff(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    directions, 
    1.0, 
)
orig_ld_list_rand, ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    random_directions, 
    1.0, 
)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/46 [00:00<?, ?it/s]

  0%|          | 0/46 [00:00<?, ?it/s]

Original mean logit diff: 1.2518
Original accuracy: 1.0000


Comma-ablated mean logit diff: 1.2197
Comma-ablated accuracy: 0.9435
Percent drop in logit diff with comma ablation: 2.56%
Percent drop in accuracy with comma ablation: 5.65%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.2475
Percent drop in logit diff with comma ablation: 0.34%


#### Complete Ablation

In [14]:
model.reset_hooks()

##### Balanced Results

In [15]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_directional_ablation_modified_logit_diff_all_pos(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    directions, 
    1.0, 
)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"All-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"All-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with all-ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with all-ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"All-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with all-ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/46 [00:00<?, ?it/s]

Original mean logit diff: 1.2518
Original accuracy: 1.0000


All-ablated mean logit diff: 0.6993
All-ablated accuracy: 0.5304
Percent drop in logit diff with all-ablation: 44.13%
Percent drop in accuracy with all-ablation: 46.96%
---------------------------------------------------------
Random direction ablation results:
All-ablated mean logit diff: 1.2475
Percent drop in logit diff with all-ablation: 0.34%


### Pythia-2.8b Directional Ablation

#### Model

In [6]:
torch.set_grad_enabled(False)
device = "cuda"
MODEL_NAME = "EleutherAI/pythia-2.8b"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    device=device,
)
model.name = MODEL_NAME

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


In [44]:
pos_data_loader, neg_data_loader, balanced_data_loader = prepare_sst_for_model(model, "sst2", 5, 1, 29071, 32725)

Map:   0%|          | 0/7864 [00:00<?, ? examples/s]

Map:   0%|          | 0/1007 [00:00<?, ? examples/s]

Map:   0%|          | 0/2058 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10929 [00:00<?, ? examples/s]

Number of items in dataset: 6169


Saving the dataset (0/1 shards):   0%|          | 0/6169 [00:00<?, ? examples/s]

Map:   0%|          | 0/6169 [00:00<?, ? examples/s]

Map:   0%|          | 0/6169 [00:00<?, ? examples/s]

Map:   0%|          | 0/6169 [00:00<?, ? examples/s]

Map:   0%|          | 0/6169 [00:00<?, ? examples/s]


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Filter:   0%|          | 0/6169 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3318 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3318 [00:00<?, ? examples/s]

Number of items in pos dataset: 695
Number of items in neg dataset: 695
Number of items in balanced dataset: 1390


In [47]:
directions = load_directions(model, "data/pythia-2.8b-das", direction_prefix="das_simple_train_ADJ_layer")
random_directions = get_random_directions(model)
zeroed_directions = get_zeroed_dir_vector(model)

#### Punctuation Only

In [48]:
model.reset_hooks()

##### Balanced Results

In [49]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_directional_ablation_modified_logit_diff(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    directions, 
    1.0, 
)
orig_ld_list_rand, ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    random_directions, 
    1.0, 
)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0520
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.8669
Comma-ablated accuracy: 0.8266
Percent drop in logit diff with comma ablation: 17.60%
Percent drop in accuracy with comma ablation: 17.34%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.0551
Percent drop in logit diff with comma ablation: -0.30%


In [52]:
ld_deltas = (orig_ld_list - ablated_ld_list).cpu()
avg_distances = compute_average_comma_distances(balanced_data_loader.dataset)
first_comma_distances_tensor = compute_first_comma_distances(balanced_data_loader.dataset)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [53]:
# graph the average distance vs. the logit diff in plotly
fig = px.scatter(x=avg_distances, y=ld_deltas, labels={"x": "Average distance", "y": "Logit diff"}, title="Average distance vs. logit diff", width=800, height=800)
fig.show()

In [54]:
fig = px.scatter(x=first_comma_distances_tensor, y=ld_deltas, labels={"x": "Distance to first comma", "y": "Logit diff"}, title="Distance to first comma vs. logit diff", width=800, height=800)
fig.show()

#### Complete Ablation

In [None]:
model.reset_hooks()

##### Balanced Results

In [None]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_directional_ablation_modified_logit_diff_all_pos(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    directions, 
    1.0, 
)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"All-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"All-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with all-ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with all-ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"All-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with all-ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0286
Original accuracy: 1.0000


All-ablated mean logit diff: 0.2870
All-ablated accuracy: 0.6237
Percent drop in logit diff with all-ablation: 72.10%
Percent drop in accuracy with all-ablation: 37.63%
---------------------------------------------------------
Random direction ablation results:
All-ablated mean logit diff: 1.0268
Percent drop in logit diff with all-ablation: 0.17%


### Pythia-6.9b Directional Ablation

#### Model

In [30]:
torch.set_grad_enabled(False)
device = "cuda"
MODEL_NAME = "EleutherAI/pythia-6.9b"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    device=device,
)
model.name = MODEL_NAME

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/42.0k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.91G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/3.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Loaded pretrained model EleutherAI/pythia-6.9b into HookedTransformer


In [31]:
pos_data_loader, neg_data_loader, balanced_data_loader = prepare_sst_for_model(model, "sst2", 5, 1, 29071, 32725)

Map:   0%|          | 0/1007 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1007 [00:00<?, ? examples/s]

Number of items in dataset: 243


Saving the dataset (0/1 shards):   0%|          | 0/243 [00:00<?, ? examples/s]

Map:   0%|          | 0/243 [00:00<?, ? examples/s]

Map:   0%|          | 0/243 [00:00<?, ? examples/s]

Map:   0%|          | 0/243 [00:00<?, ? examples/s]

Map:   0%|          | 0/243 [00:00<?, ? examples/s]

  positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)


Filter:   0%|          | 0/243 [00:00<?, ? examples/s]

Filter:   0%|          | 0/153 [00:00<?, ? examples/s]

Filter:   0%|          | 0/153 [00:00<?, ? examples/s]

Number of items in pos dataset: 5
Number of items in neg dataset: 5
Number of items in balanced dataset: 10


In [None]:
directions = load_directions(model, "data/pythia-2.8b-das", direction_prefix="das_simple_train_ADJ_layer")
random_directions = get_random_directions(model)
zeroed_directions = get_zeroed_dir_vector(model)

#### Punctuation Only

In [None]:
model.reset_hooks()

##### Balanced Results

In [None]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_directional_ablation_modified_logit_diff(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    directions, 
    1.0, 
)
orig_ld_list_rand, ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    random_directions, 
    1.0, 
)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0286
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.8599
Comma-ablated accuracy: 0.8273
Percent drop in logit diff with comma ablation: 16.40%
Percent drop in accuracy with comma ablation: 17.27%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.0268
Percent drop in logit diff with comma ablation: 0.17%


In [None]:
ld_deltas = (orig_ld_list - ablated_ld_list).cpu()
avg_distances = compute_average_comma_distances(balanced_data_loader.dataset)
first_comma_distances_tensor = compute_first_comma_distances(balanced_data_loader.dataset)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [None]:
# graph the average distance vs. the logit diff in plotly
fig = px.scatter(x=avg_distances, y=ld_deltas, labels={"x": "Average distance", "y": "Logit diff"}, title="Average distance vs. logit diff", width=800, height=800)
fig.show()

In [None]:
fig = px.scatter(x=first_comma_distances_tensor, y=ld_deltas, labels={"x": "Distance to first comma", "y": "Logit diff"}, title="Distance to first comma vs. logit diff", width=800, height=800)
fig.show()

#### Complete Ablation

In [None]:
model.reset_hooks()

##### Balanced Results

In [None]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_directional_ablation_modified_logit_diff_all_pos(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    directions, 
    1.0, 
)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"All-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"All-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with all-ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with all-ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"All-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with all-ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0286
Original accuracy: 1.0000


All-ablated mean logit diff: 0.2870
All-ablated accuracy: 0.6237
Percent drop in logit diff with all-ablation: 72.10%
Percent drop in accuracy with all-ablation: 37.63%
---------------------------------------------------------
Random direction ablation results:
All-ablated mean logit diff: 1.0268
Percent drop in logit diff with all-ablation: 0.17%


### GPT-2-XL Directional Ablation

#### Model

In [16]:
torch.set_grad_enabled(False)
device = "cuda"
MODEL_NAME = "gpt2-xl"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    device=device,
)
model.name = MODEL_NAME

Downloading (…)lve/main/config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-xl into HookedTransformer


In [19]:
model.to_tokens(" Negative", prepend_bos=False), model.to_tokens(" Positive", prepend_bos=False)

(tensor([[36183]], device='cuda:0'), tensor([[33733]], device='cuda:0'))

In [24]:
tokenizer("this is test text")

{'input_ids': [5661, 318, 1332, 2420], 'attention_mask': [1, 1, 1, 1]}

In [29]:
from transformer_lens.utils import test_prompt
test_prompt("This movie sucks! Review Sentiment:", " Negative", model, top_k=10)

Tokenized prompt: ['<|endoftext|>', 'This', ' movie', ' sucks', '!', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Negative']


Top 0th token. Logit: 14.84 Prob:  8.39% Token: | I|
Top 1th token. Logit: 14.66 Prob:  6.99% Token: | 9|
Top 2th token. Logit: 14.47 Prob:  5.81% Token: | 5|
Top 3th token. Logit: 14.17 Prob:  4.29% Token: | L|
Top 4th token. Logit: 14.00 Prob:  3.63% Token: | 4|
Top 5th token. Logit: 13.94 Prob:  3.39% Token: | 6|
Top 6th token. Logit: 13.84 Prob:  3.07% Token: | The|
Top 7th token. Logit: 13.81 Prob:  3.00% Token: | 8|
Top 8th token. Logit: 13.77 Prob:  2.88% Token: | 0|
Top 9th token. Logit: 13.54 Prob:  2.29% Token: | 7|


In [26]:
pos_data_loader, neg_data_loader, balanced_data_loader = prepare_sst_for_model(model, "sst2", 5, 1, 33733, 36183)

Map:   0%|          | 0/1007 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1007 [00:00<?, ? examples/s]

Number of items in dataset: 0


Saving the dataset (0/1 shards): 0 examples [00:00, ? examples/s]

ValueError: Original column name input_ids not in the dataset. Current columns in the dataset: ['text', 'label', 'keep_example']

In [None]:
directions = load_directions(model, "data/pythia-2.8b-das", direction_prefix="das_simple_train_ADJ_layer")
random_directions = get_random_directions(model)
zeroed_directions = get_zeroed_dir_vector(model)

#### Punctuation Only

In [None]:
model.reset_hooks()

##### Balanced Results

In [None]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_directional_ablation_modified_logit_diff(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    directions, 
    1.0, 
)
orig_ld_list_rand, ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    random_directions, 
    1.0, 
)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0286
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.8599
Comma-ablated accuracy: 0.8273
Percent drop in logit diff with comma ablation: 16.40%
Percent drop in accuracy with comma ablation: 17.27%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.0268
Percent drop in logit diff with comma ablation: 0.17%


#### Complete Ablation

In [None]:
model.reset_hooks()

##### Balanced Results

In [None]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_directional_ablation_modified_logit_diff_all_pos(
    model, 
    balanced_data_loader,
    layers_to_ablate,
    heads_to_freeze,
    directions, 
    1.0, 
)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"All-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"All-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with all-ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with all-ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"All-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with all-ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0286
Original accuracy: 1.0000


All-ablated mean logit diff: 0.2870
All-ablated accuracy: 0.6237
Percent drop in logit diff with all-ablation: 72.10%
Percent drop in accuracy with all-ablation: 37.63%
---------------------------------------------------------
Random direction ablation results:
All-ablated mean logit diff: 1.0268
Percent drop in logit diff with all-ablation: 0.17%
