In [1]:
!ls

README.md		    miniconda.sh  quick_start_pytorch.ipynb   wandb
eliciting-latent-sentiment  miniconda3	  quick_start_pytorch_images


In [2]:
%cd eliciting-latent-sentiment

/notebooks/eliciting-latent-sentiment


In [3]:
#!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
!pip install fancy_einsum==0.0.3
!pip install transformer_lens
!pip install jaxtyping==0.2.13
!pip install einops
!pip install protobuf==3.20.*
!pip install plotly
!pip install torchtyping
!pip install git+https://github.com/neelnanda-io/neel-plotly.git
!pip install circuitsvis
# !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
# %pip install git+https://github.com/neelnanda-io/PySvelte.git
# %pip install typeguard==2.13.3

Collecting fancy_einsum==0.0.3
  Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)


Installing collected packages: fancy_einsum
Successfully installed fancy_einsum-0.0.3
[0mCollecting transformer_lens
  Downloading transformer_lens-1.6.1-py3-none-any.whl (109 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.9/109.9 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
Collecting transformers>=4.25.1
  Downloading transformers-4.33.2-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m104.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting einops>=0.6.0
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
Collecting beartype<0.15.0,>=0.14.1
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m91.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.7.1
  Downloading datasets-2

In [4]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

In [5]:
import einops
from functools import partial
import torch
import datasets
from torch import Tensor
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets
from jaxtyping import Float, Int, Bool
from typing import Dict, Iterable, List, Tuple, Union
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_dataset, tokenize_and_concatenate, get_act_name, test_prompt
from transformer_lens.hook_points import HookPoint
from tqdm.notebook import tqdm
import pandas as pd
from circuitsvis.activations import text_neuron_activations
from utils.store import load_array, save_html, save_array, is_file, get_model_name, clean_label, save_text
from utils.circuit_analysis import get_logit_diff

In [6]:
def find_positions(tensor, token_ids=[11, 13]):
    positions = []
    for batch_item in tensor:
        token_positions = {token_id: [] for token_id in token_ids}
        for position, token in enumerate(batch_item):
            if token.item() in token_ids:
                token_positions[token.item()].append(position)
        positions.append([token_positions[token_id] for token_id in token_ids])
    return positions

In [7]:
def zero_attention_pos_hook(
    pattern: Float[Tensor, "batch head seq_Q seq_K"], hook: HookPoint,
    pos_by_batch: List[List[int]], layer: int = 0, head_idx: int = 0,
) -> Float[Tensor, "batch head seq_Q seq_K"]:
    """Zero-ablates an attention pattern tensor at a particular position"""
    assert 'pattern' in hook.name

    batch_size = pattern.shape[0]
    assert len(pos_by_batch) == batch_size

    for i in range(batch_size):
        for p in pos_by_batch[i]:
            pattern[i, head_idx, p, p] = 0
            
    return pattern

In [8]:
def names_filter(name: str):
    """Filter for the names of the activations we want to keep to study the resid stream."""
    return name.endswith('resid_post') or name == get_act_name('resid_pre', 0)

def get_layerwise_token_mean_activations(model: HookedTransformer, data_loader: DataLoader, token_id: int = 13) -> Float[Tensor, "layer d_model"]:
    """Get the mean value of a token across layers"""
    num_layers = model.cfg.n_layers
    d_model = model.cfg.d_model
    
    activation_sums = torch.stack([torch.zeros(d_model) for _ in range(num_layers)]).to(device)
    comma_counts = [0] * num_layers

    print(activation_sums.shape)

    token_mean_values = torch.zeros((num_layers, d_model))
    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):
        
        batch_tokens = batch_value['tokens'].to(device)

        # get positions of all 11 and 13 token ids in batch
        punct_pos = find_positions(batch_tokens, token_ids=[13])

        _, cache = model.run_with_cache(
            batch_tokens, 
            names_filter=names_filter
        )

        
        for i in range(batch_tokens.shape[0]):
            for p in punct_pos[i][0]:
                for layer in range(num_layers):
                    activation_sums[layer] += cache[f"blocks.{layer}.hook_resid_post"][i, p, :]
                    comma_counts[layer] += 1

    for layer in range(num_layers):
        token_mean_values[layer] = activation_sums[layer] / comma_counts[layer]

    return token_mean_values

In [9]:
def compute_zeroed_attn_modified_loss(model: HookedTransformer, data_loader: DataLoader) -> float:
    total_loss = 0
    loss_list = []
    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):
        batch_tokens = batch_value['tokens'].to(device)

        # get positions of all 11 and 13 token ids in batch
        punct_pos = find_positions(batch_tokens, token_ids=[13])

        # get the loss for each token in the batch
        initial_loss = model(batch_tokens, return_type="loss", prepend_bos=False, loss_per_token=True)
        
        # add hooks for the activations of the 11 and 13 tokens
        for layer, head in heads_to_ablate:
            ablate_punct = partial(zero_attention_pos_hook, pos_by_batch=punct_pos, layer=layer, head_idx=head)
            model.blocks[layer].attn.hook_pattern.add_hook(ablate_punct)

        # get the loss for each token when run with hooks
        hooked_loss = model(batch_tokens, return_type="loss", prepend_bos=False, loss_per_token=True)

        # compute the percent difference between the two losses
        loss_diff = (hooked_loss - initial_loss) / initial_loss

        loss_list.append(loss_diff)

    model.reset_hooks()
    return loss_list, batch_tokens

In [10]:
from utils.ablation import ablate_resid_with_precalc_mean

def compute_mean_ablation_modified_loss(model: HookedTransformer, data_loader: DataLoader, cached_means, target_token_ids) -> float:
    total_loss = 0
    loss_diff_list = []
    orig_loss_list = []
    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):
        if isinstance(batch_value['tokens'], list):
            batch_tokens = torch.stack(batch_value['tokens']).to(device)
        else:
            batch_tokens = batch_value['tokens'].to(device)

        batch_tokens = einops.rearrange(batch_tokens, 'seq batch -> batch seq')
        punct_pos = batch_value['positions']
        print(f"punct_pos: {punct_pos}")

        # get the loss for each token in the batch
        initial_loss = model(batch_tokens, return_type="loss", prepend_bos=False, loss_per_token=True)
        print(f"initial loss shape: {initial_loss.shape}")
        orig_loss_list.append(initial_loss)
        
        # add hooks for the activations of the 11 and 13 tokens
        for layer, head in heads_to_ablate:
            mean_ablate_comma = partial(ablate_resid_with_precalc_mean_no_batch, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)
            model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)

        # get the loss for each token when run with hooks
        print(f"batch tokens shape: {batch_tokens.shape}")
        
        hooked_loss = model(batch_tokens, return_type="loss", prepend_bos=False, loss_per_token=True)
        print(f"hooked loss shape: {hooked_loss.shape}")

        # compute the difference between the two losses
        loss_diff = hooked_loss - initial_loss
        
        # set all positions right after punct_pos to zero
        for p in punct_pos:
            print(f"zeroing {p}")
            loss_diff[0, p] = 0

        loss_diff_list.append(loss_diff)

    model.reset_hooks()
    return loss_diff_list, orig_loss_list

In [11]:
def ablate_resid_with_precalc_mean_no_batch(
    component: Float[Tensor, "batch ..."],
    hook: HookPoint,
    cached_means: Float[Tensor, "layer ..."],
    pos_by_batch: List[Tensor],
    layer: int = 0,
) -> Float[Tensor, "batch ..."]:
    """
    Mean-ablates a batch tensor

    :param component: the tensor to compute the mean over the batch dim of
    :return: the mean over the cache component of the tensor
    """
    assert 'resid' in hook.name

    #print(f"batch size: {batch_size} pos_by_batch: {len(pos_by_batch)}")

    for p in pos_by_batch:
        component[:, p] = cached_means[layer]
            
    return component

In [12]:
def convert_to_tensors(dataset, column_name='tokens'):
    token_buffer = []
    final_batches = []
    
    for batch in dataset:
        trimmed_batch = batch[column_name] #[batch[column_name][0]] + [token for token in batch[column_name] if token != 0]
        final_batches.append(trimmed_batch)
    
    # Convert list of batches to tensors
    final_batches = [torch.tensor(batch, dtype=torch.long) for batch in final_batches]
    # Create a new dataset with specified features
    features = Features({"tokens": Sequence(Value("int64"))})
    final_dataset = Dataset.from_dict({"tokens": final_batches}, features=features)

    final_dataset.set_format(type="torch", columns=["tokens"])
    
    return final_dataset

In [13]:
def plot_neuroscope(
    tokens: Int[Tensor, "batch pos"], centred: bool = False, activations: Float[Tensor, "pos layer 1"] = None,
    verbose=False,
):
    
    str_tokens = model.to_str_tokens(tokens, prepend_bos=False)

    if verbose:
        print(f"Tokens shape: {tokens.shape}")
  
    if centred:
        if verbose:
            print("Centering activations")
        layer_means = einops.reduce(activations, "pos layer 1 -> 1 layer 1", reduction="mean")
        layer_means = einops.repeat(layer_means, "1 layer 1 -> pos layer 1", pos=activations.shape[0])
        activations -= layer_means
    elif verbose:
        print("Activations already centered")
    assert (
        activations.ndim == 3
    ), f"activations must be of shape [tokens x layers x neurons], found {activations.shape}"
    assert len(str_tokens) == activations.shape[0], (
        f"tokens and activations must have the same length, found tokens={len(str_tokens)} and acts={activations.shape[0]}, "
        f"tokens={str_tokens}, "
        f"activations={activations.shape}"

    )
    return text_neuron_activations(
        tokens=str_tokens, 
        activations=activations,
        first_dimension_name="Layer (resid_pre)",
        second_dimension_name="Model",
        second_dimension_labels=["pythia-2.8b"],
    )

## Comma Ablation on Natural Text

In [14]:
torch.set_grad_enabled(False)
device = "cuda"
MODEL_NAME = "EleutherAI/pythia-2.8b"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    device=device,
)
model.name = MODEL_NAME

Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


In [22]:
from datasets import load_from_disk
sst_data = load_from_disk("sst2")
sst_data

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 7864
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2058
    })
    dev: Dataset({
        features: ['text', 'label'],
        num_rows: 1007
    })
})

In [23]:
sst_data['train'][0]['text'], sst_data['train'][0]['label']

("The Rock is destined to be the 21st Century's new ''Conan'' and that he's going to make a splash even greater than Arnold Schwarzenegger, Jean-Claud Van Damme or Steven Segal.",
 1)

In [146]:
def filter_function(example):
    prompt = model.to_tokens(example['text'] + " Review Sentiment:", prepend_bos=False)
    answer = torch.tensor([29071, 32725]).unsqueeze(0).unsqueeze(0).to(device) if example['label'] == 1 else torch.tensor([32725, 29071]).unsqueeze(0).unsqueeze(0).to(device)
    logits, cache = model.run_with_cache(prompt)
    logit_diff = get_logit_diff(logits, answer)
    
    # Determine if the top answer (index 0) token is in top 10 logits
    _, top_indices = logits.topk(10, dim=-1)  # Get indices of top 10 logits
    top_answer_token = answer[0, 0, 0]  # Assuming answer is of shape (1, 1, 2) and the top answer token is at index 0
    is_top_answer_in_top_10_logits = (top_indices == top_answer_token).any()
    
    # Add a new field 'keep_example' to the example
    example['keep_example'] = (logit_diff > 0.0) and is_top_answer_in_top_10_logits.item()
    return example

# Use the map function to apply the filter_function
sst_data_with_flag_train = sst_data['train'].map(filter_function, keep_in_memory=True)
sst_data_with_flag_val = sst_data['dev'].map(filter_function, keep_in_memory=True)
sst_data_with_flag_test = sst_data['test'].map(filter_function, keep_in_memory=True)
sst_data_with_flag = concatenate_datasets([sst_data_with_flag_train, sst_data_with_flag_val, sst_data_with_flag_test])

# Use the filter function to keep only the examples where 'keep_example' is True
sst_zero_shot = sst_data_with_flag.filter(lambda x: x['keep_example'])

# save dataset
sst_zero_shot.save_to_disk("sst_zero_shot")


Map:   0%|          | 0/7864 [00:00<?, ? examples/s]

Map:   0%|          | 0/1007 [00:00<?, ? examples/s]

Map:   0%|          | 0/2058 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10929 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6169 [00:00<?, ? examples/s]

In [147]:
sst_zero_shot

Dataset({
    features: ['text', 'label', 'keep_example'],
    num_rows: 6169
})

In [21]:
from datasets import load_from_disk
sst_zero_shot = load_from_disk("sst_zero_shot")

In [22]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer


# Define the batch size
BATCH_SIZE = 5

# Load a tokenizer (you'll need to specify the appropriate model)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b")
# set padding token
tokenizer.pad_token = 1

#dataset = text_dataset.map(lambda x: tokenize_and_concatenate(x, tokenizer))

def concatenate_classification_prompts(examples):
    return {"text": (examples['text'] + " Review Sentiment:")}


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64)


def find_dataset_positions(example, token_id=13):
    # Create a tensor of zeros with the same shape as example['tokens']
    positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)

    # Find positions where tokens match the given token_id
    positions[example['tokens'] == token_id] = 1
    has_token = True if positions.sum() > 0 else False

    return {'positions': positions, 'has_token': has_token}

def convert_answers(example, pos_answer_id=29071, neg_answer_id=32725):
    if example['label'] == 1:
        answers = torch.tensor([pos_answer_id, neg_answer_id])
    else:
        answers = torch.tensor([neg_answer_id, pos_answer_id])

    return {'answers': answers}


dataset = sst_zero_shot.map(concatenate_classification_prompts, batched=False)
dataset = dataset.map(tokenize_function, batched=False)
dataset = dataset.map(convert_answers, batched=False)
dataset = dataset.rename_column("input_ids", "tokens")
dataset.set_format(type="torch", columns=["tokens", "attention_mask", "label", "answers"])
dataset = dataset.map(find_dataset_positions, batched=False)
dataset = dataset.filter(lambda example: example['has_token']==True)
dataset


Dataset({
    features: ['text', 'label', 'keep_example', 'tokens', 'attention_mask', 'answers', 'positions', 'has_token'],
    num_rows: 3318
})

In [23]:
model.to_string(dataset[0]['tokens']), model.to_str_tokens(dataset[0]['answers'][0])

("The Rock is destined to be the 21st Century's new ''Conan'' and that he's going to make a splash even greater than Arnold Schwarzenegger, Jean-Claud Van Damme or Steven Segal. Review Sentiment:1111111111111",
 [' Positive'])

In [24]:
from utils.circuit_analysis import get_logit_diff
logits, cache = model.run_with_cache(dataset['tokens'][0])
get_logit_diff(logits, dataset['answers'][0].unsqueeze(0).unsqueeze(0).to(device))

tensor(1.4662, device='cuda:0')

In [25]:
# create a subset with only positive labels
pos_dataset = dataset.filter(lambda example: example['label']==1)
neg_dataset = dataset.filter(lambda example: example['label']==0)
len(pos_dataset), len(neg_dataset)

(2623, 695)

In [26]:
import random
from datasets import concatenate_datasets

def get_random_subset(dataset, n):
    total_size = len(dataset)
    random_indices = random.sample(range(total_size), n)
    return dataset.select(random_indices)

pos_subset = get_random_subset(pos_dataset, 695)
neg_subset = get_random_subset(neg_dataset, 695)
balanced_subset = concatenate_datasets([pos_subset, neg_subset])
# randomize the order of balanced_subset
balanced_subset = balanced_subset.shuffle(len(balanced_subset))

# Create a new dataloader from the subset, converting the data to tensors
pos_data_loader = DataLoader(
    pos_subset, batch_size=5, shuffle=False, drop_last=True
)
neg_data_loader = DataLoader(
    neg_subset, batch_size=5, shuffle=False, drop_last=True
)
balanced_data_loader = DataLoader(
    balanced_subset, batch_size=5, shuffle=False, drop_last=True
)

# save datasets
pos_subset.save_to_disk("sst_pos_subset")
neg_subset.save_to_disk("sst_neg_subset")
balanced_subset.save_to_disk("sst_balanced_subset")

len(pos_data_loader), len(neg_data_loader)


Saving the dataset (0/1 shards):   0%|          | 0/695 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/695 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1390 [00:00<?, ? examples/s]

(139, 139)

### Apply Mean Ablation

In [39]:
comma_mean_bal_values = get_layerwise_token_mean_activations(model, balanced_data_loader, token_id=13)
comma_mean_neg_values = get_layerwise_token_mean_activations(model, neg_data_loader, token_id=13)
comma_mean_pos_values = get_layerwise_token_mean_activations(model, pos_data_loader, token_id=13)
save_array(comma_mean_bal_values, 'comma_balanced_mean_values.npy', model)
save_array(comma_mean_neg_values, 'comma_neg_mean_values.npy', model)
save_array(comma_mean_pos_values, 'comma_pos_mean_values.npy', model)

torch.Size([32, 2560])


  0%|          | 0/278 [00:00<?, ?it/s]

torch.Size([32, 2560])


  0%|          | 0/139 [00:00<?, ?it/s]

torch.Size([32, 2560])


  0%|          | 0/139 [00:00<?, ?it/s]

'data/pythia-2.8b/comma_pos_mean_values.npy'

In [40]:
# period_mean_values = get_layerwise_token_mean_activations(model, train_data_loader, token_id=15)
# save_array(period_mean_values, 'period_mean_values.npy', model)

In [41]:
# load the files
owt_mean_values = torch.from_numpy(load_array('comma_mean_values.npy', model)).to(device)
comma_mean_bal_values = torch.from_numpy(load_array('comma_balanced_mean_values.npy', model)).to(device)
comma_mean_neg_values = torch.from_numpy(load_array('comma_neg_mean_values.npy', model)).to(device)
comma_mean_pos_values = torch.from_numpy(load_array('comma_pos_mean_values.npy', model)).to(device)
#period_mean_values = torch.from_numpy(load_array('period_mean_values.npy', model)).to(device)

In [32]:
def compute_last_position_logit_diff(logits, mask, answer):
    """
    Parameters:
    - logits: Tensor of shape (batch, sequence_position, logits)
    - mask: Tensor of shape (batch, sequence_position)
    - answer: Tensor of shape (batch, 2)

    Returns:
    - logit_diff: Tensor of shape (batch,)
    """
    # Find the last unmasked sequence position for each item in the batch
    last_unmasked_positions = mask.sum(dim=1) - 1  # Subtract 1 to get zero-based index
    #print(last_unmasked_positions)

    # Extract the logits for the last unmasked positions
    last_logits = logits[torch.arange(logits.size(0)), last_unmasked_positions]
    #print(f"last logits: {last_logits.shape}")
    #print(f"last logits shape: {last_logits.shape}")

    # Extract the logits for the correct and incorrect answers
    correct_logits = last_logits[torch.arange(last_logits.size(0)), answer[:, 0]]
    #print(f"correct logits shape: {correct_logits.shape}")
    incorrect_logits = last_logits[torch.arange(last_logits.size(0)), answer[:, 1]]

    # Compute the logit differences
    logit_diff = correct_logits - incorrect_logits
    #print(f"logit diff shape: {logit_diff.shape}")

    return logit_diff


In [172]:
def ablate_resid_with_precalc_mean(
    component: Float[Tensor, "batch ..."],
    hook: HookPoint,
    cached_means: Float[Tensor, "layer ..."],
    pos_by_batch: Float[Tensor, "batch ..."],
    layer: int = 0,
) -> Float[Tensor, "batch ..."]:
    """
    Mean-ablates a batch tensor

    :param component: the tensor to compute the mean over the batch dim of
    :return: the mean over the cache component of the tensor
    """
    assert 'resid' in hook.name

    # Identify the positions where pos_by_batch is 1
    batch_indices, sequence_positions = torch.where(pos_by_batch == 1)

    # Replace the corresponding positions in component with cached_means[layer]
    component[batch_indices, sequence_positions] = cached_means[layer]

    return component

In [173]:
def compute_mean_ablation_modified_logit_diff(model: HookedTransformer, data_loader: DataLoader, cached_means, target_token_ids) -> float:
    
    orig_ld_list = []
    ablated_ld_list = []
    freeze_ablated_ld_list = []
    
    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):
        batch_tokens = batch_value['tokens'].to(device)
        punct_pos = batch_value['positions'].to(device)

        # get the logit diff for the last token in each sequence
        orig_logits, clean_cache = model.run_with_cache(batch_tokens, return_type="logits", prepend_bos=False)
        orig_ld = compute_last_position_logit_diff(orig_logits, batch_value['attention_mask'], batch_value['answers'])
        orig_ld_list.append(orig_ld)
        
        # repeat with commas ablated
        for layer in layers_to_ablate:
            mean_ablate_comma = partial(ablate_resid_with_precalc_mean, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)
            model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)
       
        ablated_logits = model(batch_tokens, return_type="logits", prepend_bos=False)
        ablated_ld = compute_last_position_logit_diff(ablated_logits, batch_value['attention_mask'], batch_value['answers'])
        ablated_ld_list.append(ablated_ld)
        
        model.reset_hooks()

        # repeat with attention frozen and commas ablated
        for layer, head in heads_to_freeze:
            freeze_attn = partial(freeze_attn_pattern_hook, cache=clean_cache, layer=layer, head_idx=head)
            model.blocks[layer].attn.hook_pattern.add_hook(freeze_attn)

        for layer in layers_to_ablate:
            mean_ablate_comma = partial(ablate_resid_with_precalc_mean, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)
            model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)
       
        freeze_ablated_logits = model(batch_tokens, return_type="logits", prepend_bos=False)
        freeze_ablated_ld = compute_last_position_logit_diff(freeze_ablated_logits, batch_value['attention_mask'], batch_value['answers'])
        freeze_ablated_ld_list.append(freeze_ablated_ld)
        
        model.reset_hooks()

    return torch.cat(orig_ld_list), torch.cat(ablated_ld_list), torch.cat(freeze_ablated_ld_list)

#### Comma Mean Ablation

##### Positive Prompt Results

In [194]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, pos_data_loader, comma_mean_bal_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 1.5375
Original accuracy: 1.0000


Comma-ablated mean logit diff: 1.5158
Comma-ablated accuracy: 0.9971
Percent drop in logit diff with comma ablation: 1.41%
Percent drop in accuracy with comma ablation: 0.29%


Attn frozen, comma-ablated mean logit diff: 1.4686
Attn frozen, comma-ablated accuracy: 0.9971
Percent drop in logit diff with attn frozen, comma ablation: 4.48%
Percent drop in accuracy with attn frozen, comma ablation: 0.29%


##### Negative Prompt Results

In [195]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, neg_data_loader, comma_mean_bal_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 0.5159
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.1785
Comma-ablated accuracy: 0.6259
Percent drop in logit diff with comma ablation: 65.39%
Percent drop in accuracy with comma ablation: 37.41%


Attn frozen, comma-ablated mean logit diff: 0.3691
Attn frozen, comma-ablated accuracy: 0.8619
Percent drop in logit diff with attn frozen, comma ablation: 28.45%
Percent drop in accuracy with attn frozen, comma ablation: 13.81%


In [196]:
model.reset_hooks()

##### Balanced Results

In [197]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, balanced_data_loader, comma_mean_bal_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0267
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.8472
Comma-ablated accuracy: 0.8115
Percent drop in logit diff with comma ablation: 17.48%
Percent drop in accuracy with comma ablation: 18.85%


Attn frozen, comma-ablated mean logit diff: 0.9188
Attn frozen, comma-ablated accuracy: 0.9295
Percent drop in logit diff with attn frozen, comma ablation: 10.50%
Percent drop in accuracy with attn frozen, comma ablation: 7.05%


#### Comma OWT Mean Ablation

##### Positive Prompt Results

In [174]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, pos_data_loader, owt_mean_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 1.5375
Original accuracy: 1.0000


Comma-ablated mean logit diff: 1.3962
Comma-ablated accuracy: 0.9942
Percent drop in logit diff with comma ablation: 9.18%
Percent drop in accuracy with comma ablation: 0.58%


Attn frozen, comma-ablated mean logit diff: 1.4692
Attn frozen, comma-ablated accuracy: 0.9957
Percent drop in logit diff with attn frozen, comma ablation: 4.44%
Percent drop in accuracy with attn frozen, comma ablation: 0.43%


##### Negative Prompt Results

In [175]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, neg_data_loader, owt_mean_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 0.5159
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.3262
Comma-ablated accuracy: 0.7942
Percent drop in logit diff with comma ablation: 36.78%
Percent drop in accuracy with comma ablation: 20.58%


Attn frozen, comma-ablated mean logit diff: 0.3294
Attn frozen, comma-ablated accuracy: 0.8288
Percent drop in logit diff with attn frozen, comma ablation: 36.14%
Percent drop in accuracy with attn frozen, comma ablation: 17.12%


In [176]:
model.reset_hooks()

##### Balanced Results

In [177]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, balanced_data_loader, owt_mean_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0267
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.8612
Comma-ablated accuracy: 0.8942
Percent drop in logit diff with comma ablation: 16.12%
Percent drop in accuracy with comma ablation: 10.58%


Attn frozen, comma-ablated mean logit diff: 0.8993
Attn frozen, comma-ablated accuracy: 0.9122
Percent drop in logit diff with attn frozen, comma ablation: 12.40%
Percent drop in accuracy with attn frozen, comma ablation: 8.78%


#### Comma Patch-Ablation

##### Positive Prompt Results

In [178]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, pos_data_loader, comma_mean_neg_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 1.5375
Original accuracy: 1.0000


Comma-ablated mean logit diff: 1.5633
Comma-ablated accuracy: 0.9986
Percent drop in logit diff with comma ablation: -1.68%
Percent drop in accuracy with comma ablation: 0.14%


Attn frozen, comma-ablated mean logit diff: 1.3832
Attn frozen, comma-ablated accuracy: 0.9942
Percent drop in logit diff with attn frozen, comma ablation: 10.03%
Percent drop in accuracy with attn frozen, comma ablation: 0.58%


##### Negative Prompt Results

In [179]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, neg_data_loader, comma_mean_pos_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 0.5159
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.2403
Comma-ablated accuracy: 0.6806
Percent drop in logit diff with comma ablation: 53.43%
Percent drop in accuracy with comma ablation: 31.94%


Attn frozen, comma-ablated mean logit diff: 0.2797
Attn frozen, comma-ablated accuracy: 0.7597
Percent drop in logit diff with attn frozen, comma ablation: 45.78%
Percent drop in accuracy with attn frozen, comma ablation: 24.03%


In [198]:
model.reset_hooks()

##### Balanced Results

In [181]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_mean_ablation_modified_logit_diff(model, balanced_data_loader, comma_mean_bal_values, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0267
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.8472
Comma-ablated accuracy: 0.8115
Percent drop in logit diff with comma ablation: 17.48%
Percent drop in accuracy with comma ablation: 18.85%


Attn frozen, comma-ablated mean logit diff: 0.9188
Attn frozen, comma-ablated accuracy: 0.9295
Percent drop in logit diff with attn frozen, comma ablation: 10.50%
Percent drop in accuracy with attn frozen, comma ablation: 7.05%


### Apply Directional Ablation

In [61]:
import numpy as np

def load_directions(model, direction_folder):
    directions = []
    for i in range(model.cfg.n_layers):
        dir = np.load(f"{direction_folder}/das_simple_train_ADJ_layer{i}.npy")
        if len(dir.shape) == 2:
            dir = dir[:, 0]
        directions.append(torch.tensor(dir))

    # convert to tensor
    directions = torch.stack(directions).to(device)

    return directions

def get_random_directions(model, num_layers):
    directions = []
    for i in range(num_layers):
        dir = torch.randn(model.cfg.d_model).to(device)
        directions.append(dir)

    # convert to tensor
    directions = torch.stack(directions).to(device)

    return directions

def get_zeroed_dir_vector(model, directions):
    zeroed_directions = []
    for i in range(model.cfg.n_layers):
        dir = torch.zeros(model.cfg.d_model).to(device)
        zeroed_directions.append(dir)

    # convert to tensor
    zeroed_directions = torch.stack(zeroed_directions).to(device)

    return zeroed_directions

directions = load_directions(model, "data/pythia-2.8b-das")
random_directions = get_random_directions(model, model.cfg.n_layers)
zeroed_directions = get_zeroed_dir_vector(model, directions)

#### Punctuation Only

In [33]:
def ablate_resid_with_direction(
    component: torch.Tensor,
    hook: HookPoint,
    direction_vector: torch.Tensor,
    labels: torch.Tensor,
    multiplier: float = 1.0,
    pos_by_batch: torch.Tensor = None,
    layer: int = 0,
) -> torch.Tensor:
    """
    Ablates a batch tensor by removing the influence of a direction vector from it.

    Args:
        component: the tensor to compute the mean over the batch dim of
        direction_vector: the direction vector to remove from the component
        multiplier: the multiplier to apply to the direction vector
        pos_by_batch: the positions to ablate
        layer: the layer to ablate

    Returns:
        the ablated component
    """
    assert 'resid' in hook.name

    # Normalize the direction vector to make sure it's a unit vector
    D_normalized = direction_vector[layer] / torch.norm(direction_vector[layer])

    # Calculate the projection of component onto direction_vector
    proj = einops.einsum(component, D_normalized, "b s d, d -> b s").unsqueeze(-1) * D_normalized
    

    # Ablate the direction from component
    component_ablated = component.clone()  # Create a copy to ensure original is not modified
    if pos_by_batch is not None:
        batch_indices, sequence_positions = torch.where(pos_by_batch == 1)
        component_ablated[batch_indices, sequence_positions] = component[batch_indices, sequence_positions] - multiplier * proj[batch_indices, sequence_positions]
        
        # Print the (batch, pos) coordinates of all d_model vectors that were ablated
        # for b, s in zip(batch_indices, sequence_positions):
        #     print(f"(batch, pos) = ({b.item()}, {s.item()})")

        # Check that positions not in (batch_indices, sequence_positions) were not ablated
        check_mask = torch.ones_like(component, dtype=torch.bool)
        check_mask[batch_indices, sequence_positions] = 0
        if not torch.all(component[check_mask] == component_ablated[check_mask]):
            raise ValueError("Positions outside of specified (batch_indices, sequence_positions) were ablated!")

    return component_ablated

In [34]:
def compute_directional_ablation_modified_logit_diff(model: HookedTransformer, data_loader: DataLoader, direction_vectors, multiplier=1.0, target_token_ids=13) -> float:
    
    orig_ld_list = []
    ablated_ld_list = []
    freeze_ablated_ld_list = []
    
    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):
        batch_tokens = batch_value['tokens'].to(device)
        labels = batch_value['label'].to(device)
        punct_pos = batch_value['positions'].to(device)

        # get the logit diff for the last token in each sequence
        orig_logits, clean_cache = model.run_with_cache(batch_tokens, return_type="logits", prepend_bos=False)
        orig_ld = compute_last_position_logit_diff(orig_logits, batch_value['attention_mask'], batch_value['answers'])
        orig_ld_list.append(orig_ld)
        
        # repeat with commas ablated
        for layer in layers_to_ablate:
            dir_ablate_comma = partial(ablate_resid_with_direction, labels=labels, direction_vector=direction_vectors, multiplier=multiplier, pos_by_batch=punct_pos, layer=layer)
            model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)
       
        ablated_logits = model(batch_tokens, return_type="logits", prepend_bos=False)
        # check to see if ablated_logits has any nan values
        if torch.isnan(ablated_logits).any():
            print("ablated logits has nan values")
        ablated_ld = compute_last_position_logit_diff(ablated_logits, batch_value['attention_mask'], batch_value['answers'])
        ablated_ld_list.append(ablated_ld)
        
        model.reset_hooks()

        # repeat with attention frozen and commas ablated
        for layer, head in heads_to_freeze:
            freeze_attn = partial(freeze_attn_pattern_hook, cache=clean_cache, layer=layer, head_idx=head)
            model.blocks[layer].attn.hook_pattern.add_hook(freeze_attn)

        for layer in layers_to_ablate:
            dir_ablate_comma = partial(ablate_resid_with_direction, labels=labels, direction_vector=direction_vectors, multiplier=multiplier, pos_by_batch=punct_pos, layer=layer)
            model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)
       
        freeze_ablated_logits = model(batch_tokens, return_type="logits", prepend_bos=False)
        freeze_ablated_ld = compute_last_position_logit_diff(freeze_ablated_logits, batch_value['attention_mask'], batch_value['answers'])
        freeze_ablated_ld_list.append(freeze_ablated_ld)
        
        model.reset_hooks()

    return torch.cat(orig_ld_list), torch.cat(ablated_ld_list), torch.cat(freeze_ablated_ld_list)

In [35]:
model.reset_hooks()

##### Positive Prompt Results

In [186]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_directional_ablation_modified_logit_diff(model, pos_data_loader, directions, 1.0, 13)
orig_ld_list_rand, ablated_ld_list_rand, freeze_ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(model, pos_data_loader, random_directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 1.5375
Original accuracy: 1.0000


Comma-ablated mean logit diff: 1.4088
Comma-ablated accuracy: 0.9597
Percent drop in logit diff with comma ablation: 8.37%
Percent drop in accuracy with comma ablation: 4.03%


Attn frozen, comma-ablated mean logit diff: 1.5050
Attn frozen, comma-ablated accuracy: 0.9971
Percent drop in logit diff with attn frozen, comma ablation: 2.11%
Percent drop in accuracy with attn frozen, comma ablation: 0.29%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.5306
Percent drop in logit diff with comma ablation: 0.44%
Attn frozen, comma-ablated mean logit diff: 1.5360
Percent drop in logit diff with attn frozen, comma ablation: 0.09%


##### Negative Prompt Results

In [187]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_directional_ablation_modified_logit_diff(model, neg_data_loader, directions, 1.0, 13)
orig_ld_list_rand, ablated_ld_list_rand, freeze_ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(model, neg_data_loader, random_directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")


  0%|          | 0/139 [00:00<?, ?it/s]

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 0.5159
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.2726
Comma-ablated accuracy: 0.6820
Percent drop in logit diff with comma ablation: 47.16%
Percent drop in accuracy with comma ablation: 31.80%


Attn frozen, comma-ablated mean logit diff: 0.4758
Attn frozen, comma-ablated accuracy: 0.9669
Percent drop in logit diff with attn frozen, comma ablation: 7.77%
Percent drop in accuracy with attn frozen, comma ablation: 3.31%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 0.5198
Percent drop in logit diff with comma ablation: -0.76%
Attn frozen, comma-ablated mean logit diff: 0.5162
Percent drop in logit diff with attn frozen, comma ablation: -0.06%


##### Balanced Results

In [36]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_directional_ablation_modified_logit_diff(model, balanced_data_loader, directions, 1.0, 13)
orig_ld_list_rand, ablated_ld_list_rand, freeze_ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(model, balanced_data_loader, random_directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0282
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.8556
Comma-ablated accuracy: 0.8216
Percent drop in logit diff with comma ablation: 16.79%
Percent drop in accuracy with comma ablation: 17.84%


Attn frozen, comma-ablated mean logit diff: 0.9921
Attn frozen, comma-ablated accuracy: 0.9835
Percent drop in logit diff with attn frozen, comma ablation: 3.52%
Percent drop in accuracy with attn frozen, comma ablation: 1.65%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.0258
Percent drop in logit diff with comma ablation: 0.23%
Attn frozen, comma-ablated mean logit diff: 1.0265
Percent drop in logit diff with attn frozen, comma ablation: 0.17%


#### Punctuation - Inverse Directional Ablation

In [62]:
def ablate_resid_with_mean_and_direction(
    component: torch.Tensor,
    hook: HookPoint,
    direction_vector: torch.Tensor,
    cached_means: torch.Tensor,
    pos_by_batch: torch.Tensor,
    layer: int = 0,
    multiplier: float = 1.0,
) -> torch.Tensor:
    """
    Ablates a batch tensor by first removing the influence of a direction vector,
    then replacing with the mean, and finally adding the projection back.

    Args:
        component: the tensor to compute the mean over the batch dim of
        direction_vector: the direction vector to remove from the component
        cached_means: the cached means tensor
        pos_by_batch: the positions to ablate
        layer: the layer to ablate
        multiplier: the multiplier to apply to the direction vector

    Returns:
        the ablated component
    """
    assert 'resid' in hook.name

    # Normalize the direction vector to make sure it's a unit vector
    D_normalized = direction_vector[layer] / torch.norm(direction_vector[layer])

    # Calculate the projection of component onto direction_vector
    proj = einops.einsum(component, D_normalized, "b s d, d -> b s").unsqueeze(-1) * D_normalized

    # Create a copy to ensure the original is not modified
    component_ablated = component.clone()

    # Identify the positions where pos_by_batch is 1
    batch_indices, sequence_positions = torch.where(pos_by_batch == 1)

    # Step 1: Replace the corresponding positions in component_ablated with cached_means[layer]
    component_ablated[batch_indices, sequence_positions] = cached_means[layer]

    # Step 2: Add the projection back to the component at specified positions
    component_ablated[batch_indices, sequence_positions] += proj[batch_indices, sequence_positions]

    # Check that positions not in (batch_indices, sequence_positions) were not ablated
    check_mask = torch.ones_like(component, dtype=torch.bool)
    check_mask[batch_indices, sequence_positions] = 0
    if not torch.all(component[check_mask] == component_ablated[check_mask]):
        raise ValueError("Positions outside of specified (batch_indices, sequence_positions) were ablated!")

    return component_ablated


In [58]:
def compute_inverse_directional_ablation_modified_logit_diff(model: HookedTransformer, data_loader: DataLoader, direction_vectors, cached_means, multiplier=1.0, target_token_ids=13) -> float:
    
    orig_ld_list = []
    ablated_ld_list = []
    freeze_ablated_ld_list = []
    
    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):
        batch_tokens = batch_value['tokens'].to(device)
        labels = batch_value['label'].to(device)
        punct_pos = batch_value['positions'].to(device)

        # get the logit diff for the last token in each sequence
        orig_logits, clean_cache = model.run_with_cache(batch_tokens, return_type="logits", prepend_bos=False)
        orig_ld = compute_last_position_logit_diff(orig_logits, batch_value['attention_mask'], batch_value['answers'])
        orig_ld_list.append(orig_ld)
        
        # repeat with modified ablation
        for layer in layers_to_ablate:
            dir_ablate_comma = partial(ablate_resid_with_mean_and_direction, direction_vector=direction_vectors, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer, multiplier=multiplier)
            model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)
       
        ablated_logits = model(batch_tokens, return_type="logits", prepend_bos=False)
        if torch.isnan(ablated_logits).any():
            print("ablated logits has nan values")
        ablated_ld = compute_last_position_logit_diff(ablated_logits, batch_value['attention_mask'], batch_value['answers'])
        ablated_ld_list.append(ablated_ld)
        
        model.reset_hooks()

    return torch.cat(orig_ld_list), torch.cat(ablated_ld_list)

In [59]:
model.reset_hooks()

##### Positive Prompt Results

In [None]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_directional_ablation_modified_logit_diff(model, pos_data_loader, directions, 1.0, 13)
orig_ld_list_rand, ablated_ld_list_rand, freeze_ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(model, pos_data_loader, random_directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 1.5375
Original accuracy: 1.0000


Comma-ablated mean logit diff: 1.4088
Comma-ablated accuracy: 0.9597
Percent drop in logit diff with comma ablation: 8.37%
Percent drop in accuracy with comma ablation: 4.03%


Attn frozen, comma-ablated mean logit diff: 1.5050
Attn frozen, comma-ablated accuracy: 0.9971
Percent drop in logit diff with attn frozen, comma ablation: 2.11%
Percent drop in accuracy with attn frozen, comma ablation: 0.29%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.5306
Percent drop in logit diff with comma ablation: 0.44%
Attn frozen, comma-ablated mean logit diff: 1.5360
Percent drop in logit diff with attn frozen, comma ablation: 0.09%


##### Negative Prompt Results

In [None]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_directional_ablation_modified_logit_diff(model, neg_data_loader, directions, 1.0, 13)
orig_ld_list_rand, ablated_ld_list_rand, freeze_ablated_ld_list_rand = compute_directional_ablation_modified_logit_diff(model, neg_data_loader, random_directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")


  0%|          | 0/139 [00:00<?, ?it/s]

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 0.5159
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.2726
Comma-ablated accuracy: 0.6820
Percent drop in logit diff with comma ablation: 47.16%
Percent drop in accuracy with comma ablation: 31.80%


Attn frozen, comma-ablated mean logit diff: 0.4758
Attn frozen, comma-ablated accuracy: 0.9669
Percent drop in logit diff with attn frozen, comma ablation: 7.77%
Percent drop in accuracy with attn frozen, comma ablation: 3.31%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 0.5198
Percent drop in logit diff with comma ablation: -0.76%
Attn frozen, comma-ablated mean logit diff: 0.5162
Percent drop in logit diff with attn frozen, comma ablation: -0.06%


##### Balanced Results

In [63]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list = compute_inverse_directional_ablation_modified_logit_diff(model, balanced_data_loader, comma_mean_bal_values, directions, 1.0, 13)
orig_ld_list_rand, ablated_ld_list_rand = compute_inverse_directional_ablation_modified_logit_diff(model, balanced_data_loader, comma_mean_bal_values, zeroed_directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()

ablated_accuracy_rand = (ablated_ld_list_rand > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
# print("\n")
# print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
# print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
# print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
# print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy_rand:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy_rand) / orig_accuracy * 100:.2f}%")
# print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
# print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0282
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.8603
Comma-ablated accuracy: 0.8079
Percent drop in logit diff with comma ablation: 16.34%
Percent drop in accuracy with comma ablation: 19.21%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 0.8529
Comma-ablated accuracy: 0.8065
Percent drop in logit diff with comma ablation: 17.05%
Percent drop in accuracy with comma ablation: 19.35%


In [55]:
ablated_accuracy_rand = (ablated_ld_list_rand > 0).float().mean()
print(f"Comma-ablated accuracy: {ablated_accuracy_rand:.4f}")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy_rand) / orig_accuracy * 100:.2f}%")

Comma-ablated accuracy: 0.8288
Percent drop in accuracy with comma ablation: 17.12%


#### Complete Ablation

In [189]:
def compute_directional_ablation_modified_logit_diff_all_pos(model: HookedTransformer, data_loader: DataLoader, direction_vectors, multiplier=1.0, target_token_ids=13) -> float:
    
    orig_ld_list = []
    ablated_ld_list = []
    freeze_ablated_ld_list = []
    
    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):
        batch_tokens = batch_value['tokens'].to(device)
        labels = batch_value['label'].to(device)
        punct_pos = batch_value['attention_mask'].to(device)

        # get the logit diff for the last token in each sequence
        orig_logits, clean_cache = model.run_with_cache(batch_tokens, return_type="logits", prepend_bos=False)
        orig_ld = compute_last_position_logit_diff(orig_logits, batch_value['attention_mask'], batch_value['answers'])
        orig_ld_list.append(orig_ld)
        
        # repeat with commas ablated
        for layer in layers_to_ablate:
            dir_ablate_comma = partial(ablate_resid_with_direction, labels=labels, direction_vector=direction_vectors, multiplier=multiplier, pos_by_batch=punct_pos, layer=layer)
            model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)
       
        ablated_logits = model(batch_tokens, return_type="logits", prepend_bos=False)
        # check to see if ablated_logits has any nan values
        if torch.isnan(ablated_logits).any():
            print("ablated logits has nan values")
        ablated_ld = compute_last_position_logit_diff(ablated_logits, batch_value['attention_mask'], batch_value['answers'])
        ablated_ld_list.append(ablated_ld)
        
        model.reset_hooks()

        # repeat with attention frozen and commas ablated
        for layer, head in heads_to_freeze:
            freeze_attn = partial(freeze_attn_pattern_hook, cache=clean_cache, layer=layer, head_idx=head)
            model.blocks[layer].attn.hook_pattern.add_hook(freeze_attn)

        for layer in layers_to_ablate:
            dir_ablate_comma = partial(ablate_resid_with_direction, labels=labels, direction_vector=direction_vectors, multiplier=multiplier, pos_by_batch=punct_pos, layer=layer)
            model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)
       
        freeze_ablated_logits = model(batch_tokens, return_type="logits", prepend_bos=False)
        freeze_ablated_ld = compute_last_position_logit_diff(freeze_ablated_logits, batch_value['attention_mask'], batch_value['answers'])
        freeze_ablated_ld_list.append(freeze_ablated_ld)
        
        model.reset_hooks()

    return torch.cat(orig_ld_list), torch.cat(ablated_ld_list), torch.cat(freeze_ablated_ld_list)

In [190]:
model.reset_hooks()

##### Positive Prompt Results

In [191]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_directional_ablation_modified_logit_diff_all_pos(model, pos_data_loader, directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 1.5375
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.9126
Comma-ablated accuracy: 0.9094
Percent drop in logit diff with comma ablation: 40.64%
Percent drop in accuracy with comma ablation: 9.06%


Attn frozen, comma-ablated mean logit diff: 1.3112
Attn frozen, comma-ablated accuracy: 0.9871
Percent drop in logit diff with attn frozen, comma ablation: 14.72%
Percent drop in accuracy with attn frozen, comma ablation: 1.29%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.0252
Percent drop in logit diff with comma ablation: 0.14%
Attn frozen, comma-ablated mean logit diff: 1.0261
Percent drop in logit diff with attn frozen, comma ablation: 0.05%


##### Negative Prompt Results

In [192]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_directional_ablation_modified_logit_diff_all_pos(model, neg_data_loader, directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/139 [00:00<?, ?it/s]

Original mean logit diff: 0.5159
Original accuracy: 1.0000


Comma-ablated mean logit diff: -0.3245
Comma-ablated accuracy: 0.3295
Percent drop in logit diff with comma ablation: 162.89%
Percent drop in accuracy with comma ablation: 67.05%


Attn frozen, comma-ablated mean logit diff: -0.5284
Attn frozen, comma-ablated accuracy: 0.1669
Percent drop in logit diff with attn frozen, comma ablation: 202.44%
Percent drop in accuracy with attn frozen, comma ablation: 83.31%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.0252
Percent drop in logit diff with comma ablation: 0.14%
Attn frozen, comma-ablated mean logit diff: 1.0261
Percent drop in logit diff with attn frozen, comma ablation: 0.05%


##### Balanced Results

In [193]:
from utils.ablation import freeze_attn_pattern_hook
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
heads_to_freeze = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
layers_to_ablate = [layer for layer in range(model.cfg.n_layers)]
orig_ld_list, ablated_ld_list, freeze_ablated_ld_list = compute_directional_ablation_modified_logit_diff_all_pos(model, balanced_data_loader, directions, 1.0, 13)

orig_accuracy = (orig_ld_list > 0).float().mean()
ablated_accuracy = (ablated_ld_list > 0).float().mean()
freeze_ablated_accuracy = (freeze_ablated_ld_list > 0).float().mean()

print(f"Original mean logit diff: {orig_ld_list.mean():.4f}")
print(f"Original accuracy: {orig_accuracy:.4f}")
print("\n")
print(f"Comma-ablated mean logit diff: {ablated_ld_list.mean():.4f}")
print(f"Comma-ablated accuracy: {ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list.mean() - ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with comma ablation: {(orig_accuracy - ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("\n")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list.mean():.4f}")
print(f"Attn frozen, comma-ablated accuracy: {freeze_ablated_accuracy:.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list.mean() - freeze_ablated_ld_list.mean()) / orig_ld_list.mean() * 100:.2f}%")
print(f"Percent drop in accuracy with attn frozen, comma ablation: {(orig_accuracy - freeze_ablated_accuracy) / orig_accuracy * 100:.2f}%")
print("---------------------------------------------------------")
print("Random direction ablation results:")
print(f"Comma-ablated mean logit diff: {ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with comma ablation: {(orig_ld_list_rand.mean() - ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")
print(f"Attn frozen, comma-ablated mean logit diff: {freeze_ablated_ld_list_rand.mean():.4f}")
print(f"Percent drop in logit diff with attn frozen, comma ablation: {(orig_ld_list_rand.mean() - freeze_ablated_ld_list_rand.mean()) / orig_ld_list_rand.mean() * 100:.2f}%")

  0%|          | 0/278 [00:00<?, ?it/s]

Original mean logit diff: 1.0267
Original accuracy: 1.0000


Comma-ablated mean logit diff: 0.2941
Comma-ablated accuracy: 0.6194
Percent drop in logit diff with comma ablation: 71.36%
Percent drop in accuracy with comma ablation: 38.06%


Attn frozen, comma-ablated mean logit diff: 0.3914
Attn frozen, comma-ablated accuracy: 0.5770
Percent drop in logit diff with attn frozen, comma ablation: 61.88%
Percent drop in accuracy with attn frozen, comma ablation: 42.30%
---------------------------------------------------------
Random direction ablation results:
Comma-ablated mean logit diff: 1.0252
Percent drop in logit diff with comma ablation: 0.14%
Attn frozen, comma-ablated mean logit diff: 1.0261
Percent drop in logit diff with attn frozen, comma ablation: 0.05%


### Get Loss Results

In [140]:
heads_to_ablate = [(layer, head) for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]

loss_change_by_token, orig_loss = compute_mean_ablation_modified_loss(model, subset_data_loader, comma_mean_values, target_token_ids=[13])

  0%|          | 0/100 [00:00<?, ?it/s]

punct_pos: [tensor([14]), tensor([19])]
initial loss shape: torch.Size([1, 31])
batch tokens shape: torch.Size([1, 32])
hooked loss shape: torch.Size([1, 31])
zeroing tensor([14])
zeroing tensor([19])
punct_pos: [tensor([3]), tensor([9]), tensor([19]), tensor([23])]
initial loss shape: torch.Size([1, 31])
batch tokens shape: torch.Size([1, 32])
hooked loss shape: torch.Size([1, 31])
zeroing tensor([3])
zeroing tensor([9])
zeroing tensor([19])
zeroing tensor([23])
punct_pos: [tensor([12]), tensor([17])]
initial loss shape: torch.Size([1, 31])
batch tokens shape: torch.Size([1, 32])
hooked loss shape: torch.Size([1, 31])
zeroing tensor([12])
zeroing tensor([17])
punct_pos: [tensor([7])]
initial loss shape: torch.Size([1, 31])
batch tokens shape: torch.Size([1, 32])
hooked loss shape: torch.Size([1, 31])
zeroing tensor([7])
punct_pos: [tensor([5])]
initial loss shape: torch.Size([1, 31])
batch tokens shape: torch.Size([1, 32])
hooked loss shape: torch.Size([1, 31])
zeroing tensor([5])
pun

In [141]:
for i in range(len(loss_change_by_token)):
    # add one column of zeros to the loss change tensor
    loss_change_by_token[i] = torch.cat([torch.zeros(loss_change_by_token[i].shape[0], 1).to(device), loss_change_by_token[i]], dim=1)

loss_change_by_token = torch.stack(loss_change_by_token).cpu()


In [142]:
subset_dataset_tokens = convert_to_tensors(subset_dataset)

In [143]:
subset_data_loader_tkns = DataLoader(
    subset_dataset_tokens, batch_size=1, shuffle=False, drop_last=True
)
len(subset_data_loader_tkns)

100

In [144]:
loss_change_by_token_by_row = einops.rearrange(loss_change_by_token, "batch item token -> (batch item) token")
loss_change_by_token_by_row = loss_change_by_token_by_row.unsqueeze(2)
loss_change_by_token_by_row.shape

torch.Size([100, 32, 1])

In [145]:
save_array(loss_change_by_token_by_row, 'loss_change_by_token_by_row_sst.npy', model)

'data/pythia-2.8b/loss_change_by_token_by_row_sst.npy'

In [146]:
loss_change_by_token_by_row.sum()

tensor(8.0657)

In [153]:
loss_change_by_token_by_row[1]

tensor([[ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.4284],
        [ 0.4954],
        [ 0.3777],
        [ 0.7570],
        [ 0.6130],
        [ 0.0000],
        [ 0.4558],
        [ 0.1512],
        [-0.5029],
        [-0.4370],
        [ 0.0000],
        [ 0.5531],
        [-0.4186],
        [ 0.0285],
        [-0.0923],
        [ 0.0000],
        [ 0.0777],
        [ 0.0282],
        [ 0.0304],
        [ 0.0000],
        [ 0.3828],
        [-0.1881],
        [-0.0471],
        [-0.1300],
        [-0.0969],
        [ 0.0932],
        [ 0.0637]])

In [150]:
from utils.neuroscope import plot_topk
loss_change_by_token = torch.from_numpy(load_array('loss_change_by_token_by_row_sst.npy', model))
plot_topk(
    activations=loss_change_by_token_by_row, 
    dataloader=subset_data_loader_tkns, 
    window_size=64, 
    model=model, 
    k=40, 
    centred=False, 
    #exclusions=[" '", " ,", ",", "."," ."," Fig", "'t", " Pinterest", " Kampf", "m", "uk", " Kamp", "com", "edu", "S", "youtube", "twitter", "0", "js", "py", " Protein", " Fiber", " Carbohydrates", " Sugar", " Grant", " Pub", ","]
)

Top 40 most positive examples:
Example:  Yas, Activation: 2.0176, Batch: 76, Pos: 31
Example: ., Activation: 1.3744, Batch: 33, Pos: 29
Example:  II, Activation: 1.2075, Batch: 7, Pos: 28
Example: ,, Activation: 1.1415, Batch: 14, Pos: 17
Example:  weight, Activation: 0.8663, Batch: 18, Pos: 17
Example:  roles, Activation: 0.7570, Batch: 1, Pos: 8
Example:  echoes, Activation: 0.7151, Batch: 2, Pos: 14
Example:  buy, Activation: 0.6960, Batch: 35, Pos: 9
Example: 'm, Activation: 0.6842, Batch: 48, Pos: 31
Example: .'', Activation: 0.6383, Batch: 10, Pos: 12
Example: ., Activation: 0.6022, Batch: 25, Pos: 26
Example: Min, Activation: 0.5999, Batch: 4, Pos: 7
Example: <|endoftext|>, Activation: 0.5333, Batch: 0, Pos: 25
Example:  as, Activation: 0.4545, Batch: 39, Pos: 29
Example:  it, Activation: 0.3246, Batch: 53, Pos: 23
Example:  string, Activation: 0.2914, Batch: 3, Pos: 12
Example:  the, Activation: 0.2676, Batch: 27, Pos: 26


Top 40 most negative examples:
Example:  part, Activation: -1.3371, Batch: 0, Pos: 16
Example:  ultimately, Activation: -0.9171, Batch: 14, Pos: 19
Example:  Report, Activation: -0.7714, Batch: 4, Pos: 9
Example:  Cris, Activation: -0.5618, Batch: 2, Pos: 22
Example:  and, Activation: -0.5029, Batch: 1, Pos: 13
Example:  still, Activation: -0.4266, Batch: 10, Pos: 7
Example:  seem, Activation: -0.3844, Batch: 24, Pos: 7
Example:  dialogue, Activation: -0.3544, Batch: 18, Pos: 23
Example: erving, Activation: -0.3078, Batch: 25, Pos: 23
Example: perhaps, Activation: -0.3037, Batch: 53, Pos: 26
Example:  protagon, Activation: -0.2978, Batch: 39, Pos: 31
Example: itus, Activation: -0.2772, Batch: 33, Pos: 26
Example:  camp, Activation: -0.2713, Batch: 7, Pos: 23
Example: earth, Activation: -0.2532, Batch: 27, Pos: 30
Example: ache, Activation: -0.1927, Batch: 76, Pos: 29
Example:  the, Activation: -0.1866, Batch: 16, Pos: 23


## Comma Ablation for Classification

In [58]:
import torch

import random

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer
from transformer_lens import HookedTransformer
from datasets import load_dataset, Dataset, DatasetDict
from tqdm.notebook import tqdm
from utils.store import load_pickle, load_array
from utils.ablation import ablate_resid_with_precalc_mean

In [8]:
dataset = load_dataset("imdb")
dataset["train"][100]

Downloading builder script:   0%|          | 0.00/4.31k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.59k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

{'text': "Terrible movie. Nuff Said.<br /><br />These Lines are Just Filler. The movie was bad. Why I have to expand on that I don't know. This is already a waste of my time. I just wanted to warn others. Avoid this movie. The acting sucks and the writing is just moronic. Bad in every way. The only nice thing about the movie are Deniz Akkaya's breasts. Even that was ruined though by a terrible and unneeded rape scene. The movie is a poorly contrived and totally unbelievable piece of garbage.<br /><br />OK now I am just going to rag on IMDb for this stupid rule of 10 lines of text minimum. First I waste my time watching this offal. Then feeling compelled to warn others I create an account with IMDb only to discover that I have to write a friggen essay on the film just to express how bad I think it is. Totally unnecessary.",
 'label': 0}

### GPT-2 Classifier

In [9]:
MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(8000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(2000))

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [15]:
model = AutoModelForCausalLM.from_pretrained("./gpt2_imdb_classifier")
class_layer_weights = load_pickle("gpt2_imdb_classifier_classification_head_weights", 'gpt2')

model = HookedTransformer.from_pretrained(
    "gpt2",
    hf_model=model,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [45]:
def get_classification_prediction(eval_dataset, dataset_idx, verbose=False):

    logits, cache = model.run_with_cache(small_eval_dataset[dataset_idx]['text'])
    last_token_act = cache['ln_final.hook_normalized'][0, -1, :]
    res = torch.softmax(torch.tensor(class_layer_weights['score.weight']) @ last_token_act.cpu(), dim=-1)
    if verbose:
        print(f"Sentence: {small_eval_dataset[dataset_idx]['text']}")
        print(f"Prediction: {res.argmax()} Label: {small_eval_dataset[dataset_idx]['label']}")

    return res.argmax(), small_eval_dataset[dataset_idx]['label'], res

In [47]:
def get_accuracy(eval_dataset, n=100):
    correct = 0
    for idx in range(min(len(eval_dataset), n)):
        pred, label, _ = get_classification_prediction(eval_dataset, idx)
        if pred == label:
            correct += 1
    return correct / n

get_accuracy(small_eval_dataset)

0.89

In [56]:
def compute_mean_ablation_modified_accuracy(model: HookedTransformer, dataset: Dataset, target_token_ids) -> float:

    for _, item in tqdm(enumerate(dataset), total=len(dataset)):
        batch_tokens = model.to_tokens(item['text'], prepend_bos=False)
        print(batch_tokens.shape)

    #     # get positions of all 13 and 15 token ids in batch
    #     punct_pos = find_positions(batch_tokens, token_ids=target_token_ids)

    #     # get the loss for each token in the batch
    #     initial_loss = model(batch_tokens, return_type="loss", prepend_bos=False, loss_per_token=True)
    #     orig_loss_list.append(initial_loss)
        
    #     # add hooks for the activations of the 13 and 15 tokens
    #     for layer, head in heads_to_ablate:
    #         mean_ablate_comma = partial(ablate_resid_with_precalc_mean, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)
    #         model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)

    #     # get the loss for each token when run with hooks
    #     hooked_loss = model.run_with_cache(batch_tokens, return_type="loss", prepend_bos=False, loss_per_token=True)

    #     # compute the percent difference between the two losses
    #     loss_diff = hooked_loss - initial_loss
    #     loss_diff_list.append(loss_diff)

    # model.reset_hooks()
    # return loss_diff_list, orig_loss_list

In [54]:
# load the files
comma_mean_values = torch.from_numpy(load_array('comma_mean_values.npy', model)).to(device)
period_mean_values = torch.from_numpy(load_array('period_mean_values.npy', model)).to(device)

FileNotFoundError: [Errno 2] No such file or directory: 'data/gpt2-small/comma_mean_values.npy'

In [59]:
compute_mean_ablation_modified_accuracy(model, small_eval_dataset, target_token_ids=[15])

  0%|          | 0/2000 [00:00<?, ?it/s]

torch.Size([1, 413])
torch.Size([1, 230])
torch.Size([1, 157])
torch.Size([1, 139])
torch.Size([1, 108])
torch.Size([1, 171])
torch.Size([1, 147])
torch.Size([1, 164])
torch.Size([1, 177])
torch.Size([1, 158])
torch.Size([1, 173])
torch.Size([1, 154])
torch.Size([1, 274])
torch.Size([1, 340])
torch.Size([1, 606])
torch.Size([1, 232])
torch.Size([1, 217])
torch.Size([1, 344])
torch.Size([1, 193])
torch.Size([1, 229])
torch.Size([1, 175])
torch.Size([1, 581])
torch.Size([1, 634])
torch.Size([1, 206])
torch.Size([1, 161])
torch.Size([1, 296])
torch.Size([1, 225])
torch.Size([1, 163])
torch.Size([1, 245])
torch.Size([1, 166])
torch.Size([1, 438])
torch.Size([1, 153])
torch.Size([1, 114])
torch.Size([1, 94])
torch.Size([1, 209])
torch.Size([1, 75])
torch.Size([1, 156])
torch.Size([1, 385])
torch.Size([1, 209])
torch.Size([1, 666])
torch.Size([1, 67])
torch.Size([1, 137])
torch.Size([1, 423])
torch.Size([1, 151])
torch.Size([1, 142])
torch.Size([1, 208])
torch.Size([1, 270])
torch.Size([1, 2