# Misdirection

Exploring indirect object identification (IOI) inputs with misdirected outputs, e.g.:

> John was two years older than Mary. Who was born first? Mary was born before

This uses the same setup as the standard IOI task and can be measured in the same way, but allows us to compare any differences that exist between the activations for outputs that logically follow the input vs those that contradict it.

In [None]:
import torch
from transformer_lens import HookedTransformer 
import matplotlib.pyplot as plt
import pandas as pd

from utils import *

# Ask PyTorch to release any cached, unused CUDA memory before the model is loaded.
torch.cuda.empty_cache()
# Disable gradient tracking globally for the rest of the session.
torch.set_grad_enabled(False)

# None = no limit on displayed DataFrame columns / let pandas pick the display width.
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
# Turn matplotlib interactive mode off (a later cell turns it back on with plt.ion()).
plt.ioff()

import warnings
# NOTE(review): this suppresses *every* warning for the rest of the session,
# including deprecation and numerical warnings -- confirm that is intended.
warnings.filterwarnings('ignore')

# The marker below is the nbval directive that skips output comparison for this
# cell; keep it verbatim.
# NBVAL_IGNORE_OUTPUT
# Load GPT-2 small through TransformerLens. The keyword flags are
# weight-processing options of HookedTransformer.from_pretrained; see the
# TransformerLens documentation for their exact effect.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

In [None]:
# Re-enable matplotlib interactive mode for the exploratory cells that follow
# (it was switched off in the setup cell).
plt.ion()

The cell below creates a set of IOI inputs with templated subject, indirect object and object variables.

In [None]:
from itertools import permutations, product

# Each name carries its own leading space, so the placeholders sit directly
# against the preceding word or punctuation ("than{1}", "first,{2}").
# NOTE(review): the example in the notebook introduction reads
# "Who was born first? ..." whereas this template uses a comma -- confirm which
# punctuation is intended before comparing results against that example.
template = "{0} was two years older than{1}. Who was born first,{2} was born before{3}"

# Only one name set is ever used. This cell previously assigned
# (" Mary", " John", " Alice", " Bob") and immediately overwrote it with the
# tuple below; that dead assignment has been removed and is recorded here in
# case the alternative set is wanted again.
names = (" Phil", " Bob", " James", " Paul")

# Clean prompts: for every ordered (subject, indirect object) pair, the closing
# clause repeats the names in the same order as the opening sentence.
prompts = [
    template.format(S, IO, S, IO)
    for S, IO in permutations(names, 2)
]

# Cell output: number of prompts and a small sample.
len(prompts), prompts[:5]

In [None]:
# Misdirected ("corrupted") prompts: the opening sentence keeps the
# (subject, indirect object) order, while the closing clause receives the same
# two names reversed, so it contradicts the opening sentence.
corrupted_prompts = [
    template.format(*pair, *reversed(pair))
    for pair in permutations(names, 2)
]

# Cell output: number of prompts and a small sample.
len(corrupted_prompts), corrupted_prompts[:5]

In [None]:
# Run the model over the clean and the misdirected prompt sets. `run_prompts`
# comes from utils; its return value is what `calculate_attns` consumes below
# -- presumably an activation cache, confirm in utils.
cache = run_prompts(model, *prompts)
corrupted_cache = run_prompts(model, *corrupted_prompts)

In [None]:
# Pick one attention head to inspect. `l`, `h` are presumably (layer, head)
# indices -- confirm against `calculate_attns` in utils. Later cells read these
# two globals, so they must be set before running them.
l, h = 3, 11
attn_data = calculate_attns(cache, l, h)
# Cell output: shape of the result for the clean prompts.
attn_data.shape

In [None]:
# Same computation for the misdirected prompts, reusing the `l`, `h` selected
# in the previous cell.
corrupted_attn_data = calculate_attns(corrupted_cache, l, h)
# Cell output: shape of the result for the misdirected prompts.
corrupted_attn_data.shape

In [None]:
# Build the four panels in a fixed order -- clean then misdirected data for
# feature_index 0, followed by the same pair for feature_index 3 -- and combine
# them into one titled figure.
a, b, c, d = [
    plot_attn(model, attns, feature_index=feature, show_grid_labels=False)
    for feature in (0, 3)
    for attns in (attn_data, corrupted_attn_data)
]

figure(a, b, c, d, title="Attention tokens for clean vs misdirected IOI")

In [None]:
# Switch matplotlib interactive mode back off before the next cell, which
# generates and saves a whole batch of figures.
plt.ioff()

In [None]:
from pathlib import Path

# savefig does not create missing directories, so make sure the output folder
# exists before the loop starts writing into it.
Path('./images').mkdir(parents=True, exist_ok=True)

data = []
corrupted_data = []
# NOTE(review): the 90..143 range is hard-coded. It presumably covers a tail of
# flat head indices that `get_head_index` (utils) maps to an (l, h) pair --
# confirm the mapping and why the range starts at 90.
# The loop rebinds the notebook globals `l`, `h`, `attn_data` and
# `corrupted_attn_data`; after this cell they refer to the last head processed.
for i in range(90, 144):
    l, h = get_head_index(i)
    attn_data = calculate_attns(cache, l, h)
    corrupted_attn_data = calculate_attns(corrupted_cache, l, h)
    data.append(attn_data)
    corrupted_data.append(corrupted_attn_data)
    # Panel order: clean / misdirected for feature_index 0, then for feature_index 3.
    a = plot_attn(model, attn_data, feature_index=0, show_grid_labels=False)
    b = plot_attn(model, corrupted_attn_data, feature_index=0, show_grid_labels=False)
    c = plot_attn(model, attn_data, feature_index=3, show_grid_labels=False)
    d = plot_attn(model, corrupted_attn_data, feature_index=3, show_grid_labels=False)
    fig = figure(a, b, c, d, title=f'Attention tokens for clean vs misdirected IOI ({l}.{h})')
    fig.savefig(f'./images/attention_tokens_clean_vs_misdirected_IOI_{l}_{h}.png')
    # Release the figure once it is on disk; otherwise pyplot keeps every one of
    # the 54 figures alive for the rest of the session.
    # Assumes `figure` (utils) returns a matplotlib Figure, as the savefig call
    # suggests -- confirm in utils.
    plt.close(fig)

# Stack both per-head lists so the clean and misdirected results end up in the
# same form (previously only `data` was stacked and `corrupted_data` stayed a list).
data = torch.stack(data)
corrupted_data = torch.stack(corrupted_data)
# Cell output: shape of the stacked clean data.
data.shape

## t-SNE plots for clean vs corrupted

## token counts across input dataset

## shared tokens across heads

## logit contribution heads