<a href="https://colab.research.google.com/github/gg2001/transformer-circuits/blob/master/puzzles/1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [Mech Interp Puzzle 1: Suspiciously Similar Embeddings in GPT-Neo](https://www.alignmentforum.org/posts/eLNo7b56kQQerCzp2/mech-interp-puzzle-1-suspiciously-similar-embeddings-in-gpt)

In [1]:
%pip install transformer_lens plotly nbformat

Collecting transformer_lens
  Downloading transformer_lens-2.7.0-py3-none-any.whl.metadata (12 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting fancy-einsum>=0.0.3 (from transformer_lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer_lens)
  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)
Collecting wandb>=0.13.5 (from transformer_lens)
  Downloading wandb-0.18.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.7 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.7.1->transformer_lens)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxha

In [2]:
from tqdm import tqdm
from typing import Callable
from transformer_lens import HookedTransformer
import plotly.express as px
import pandas as pd
import torch
import heapq

In [3]:
model = HookedTransformer.from_pretrained("gpt-neo-small")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/526M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

Loaded pretrained model gpt-neo-small into HookedTransformer


In [4]:
W_E = model.W_E
W_E.requires_grad = False

In [5]:
# subsample = torch.randperm(model.cfg.d_vocab)[:5000].to(model.cfg.device)
# W_E = model.W_E[subsample]  # Take a random subset of 5,000 for memory reasons

In [6]:
W_E_normed = W_E / W_E.norm(dim=-1, keepdim=True) # [d_vocab, d_model]
# cosine_sims = W_E_normed @ W_E_normed.T # [d_vocab, d_vocab]

In [7]:
# px.histogram(
#    cosine_sims.flatten().detach().cpu().numpy(),
#    title="Pairwise cosine sims of embedding",
#)

In [8]:
TokenHeap = list[tuple[float, int, int, str, str]]

def push_token_heap(
    heap: TokenHeap,
    values: list[float],
    token0: list[int],
    token1: list[int],
    compare: Callable[[float, float], bool],
    top_k: int
):
    for value, t0, t1 in zip(values, token0, token1):
        if len(heap) < top_k:
            heapq.heappush(heap, (value, t0, t1))
        else:
            if compare(value, heap[0][0]):
                heapq.heappushpop(heap, (value, t0, t1))

In [9]:
num_tokens = W_E_normed.size(0)
batch = 1000
top_k = 1000
top_heap: TokenHeap = []
bottom_heap: TokenHeap = []
sum_cosine_sims = 0.0
count_cosine_sims = 0


for i in tqdm(range(0, num_tokens, batch)):
    # Get the cosine similarities of the batch
    token_cosine_sims = (W_E_normed[i:i+batch] @ W_E_normed[i + 1:].T)

    # Get the top diagonal of the batch
    mask = torch.tril(torch.ones(token_cosine_sims.size(0), token_cosine_sims.size(1), dtype=torch.bool, device=token_cosine_sims.device), diagonal=-1)
    cosine_sims_top = token_cosine_sims.masked_fill(mask, float('-inf'))
    flattened = cosine_sims_top.view(-1)

    # Incrementally calculate the mean
    valid_mask = flattened != float('-inf')
    sum_cosine_sims += flattened[valid_mask].sum().item()
    count_cosine_sims += valid_mask.sum().item()

    # Get the highest top_k cosine similarities
    top_cosine_sims = torch.topk(flattened, top_k)
    top_values = top_cosine_sims.values.tolist()
    top_token0 = ((top_cosine_sims.indices // (num_tokens - i - 1)) + i).tolist()
    top_token1 = ((top_cosine_sims.indices % (num_tokens - i - 1)) + i + 1).tolist()
    push_token_heap(top_heap, top_values, top_token0, top_token1, lambda x, y: x > y, top_k)

    # Get the lowest top_k cosine similarities
    flattened[valid_mask] *= -1
    bottom_cosine_sims = torch.topk(flattened, top_k)
    bottom_values = (bottom_cosine_sims.values.clone() * -1).tolist()
    bottom_token0 = ((bottom_cosine_sims.indices // (num_tokens - i - 1)) + i).tolist()
    bottom_token1 = ((bottom_cosine_sims.indices % (num_tokens - i - 1)) + i + 1).tolist()
    push_token_heap(bottom_heap, bottom_values, bottom_token0, bottom_token1, lambda x, y: x < y, top_k)


# Sort heaps
top_heap.sort(key=lambda x: x[0], reverse=True)
bottom_heap.sort(key=lambda x: x[0], reverse=False)

# Get the token strings
for i in range(top_k):
    top_heap[i] += (model.tokenizer.convert_ids_to_tokens(top_heap[i][1]), model.tokenizer.convert_ids_to_tokens(top_heap[i][2]))
    bottom_heap[i] += (model.tokenizer.convert_ids_to_tokens(bottom_heap[i][1]), model.tokenizer.convert_ids_to_tokens(bottom_heap[i][2]))

# Compute mean
mean = sum_cosine_sims / count_cosine_sims

print()
print(f"Mean cosine similarity: {mean}")
print(f"Highest cosine similarity: {top_heap[0][0]} (token0={top_heap[0][3]}, token1={top_heap[0][4]})")
print(f"Lowest cosine similarity: {bottom_heap[0][0]} (token0={bottom_heap[0][3]}, token1={bottom_heap[0][4]})")

100%|██████████| 51/51 [00:05<00:00,  9.47it/s]



Mean cosine similarity: 0.9293529306070797
Highest cosine similarity: 0.9999352097511292 (token0=TPPStreamerBot, token1=EStreamFrame)
Lowest cosine similarity: 0.6320329904556274 (token0=ĠC, token1=ĠShinji)


In [10]:
columns = ["cosine_sum", "token0_id", "token1_id", "token0_str", "token1_str"]
top_df = pd.DataFrame(top_heap, columns=columns)
bottom_df = pd.DataFrame(bottom_heap, columns=columns)

In [21]:
top_df.head(n=20)

Unnamed: 0,cosine_sum,token0_id,token1_id,token0_str,token1_str
0,0.999935,37579,43177,TPPStreamerBot,EStreamFrame
1,0.999935,37574,43177,StreamerBot,EStreamFrame
2,0.999934,28666,43177,PsyNetMessage,EStreamFrame
3,0.999934,187,36173,ÿ,ĠRandomRedditor
4,0.999933,178,187,ö,ÿ
5,0.999933,36173,43177,ĠRandomRedditor,EStreamFrame
6,0.999933,181,182,ù,ú
7,0.999933,30905,37579,rawdownload,TPPStreamerBot
8,0.999933,181,35207,ù,ĠattRot
9,0.999932,181,43177,ù,EStreamFrame


In [22]:
top_token_counts = pd.concat([top_df['token0_str'], top_df['token1_str']]).value_counts()
top_token_counts.head(n=20)

Unnamed: 0,count
ÿ,51
EStreamFrame,51
TPPStreamerBot,49
Á,49
PsyNetMessage,49
ù,49
ĠRandomRedditor,48
StreamerBot,47
ĠexternalToEVAOnly,47
cloneembedreportprint,47


In [23]:
bottom_df.head(n=20)

Unnamed: 0,cosine_sum,token0_id,token1_id,token0_str,token1_str
0,0.632033,327,33891,ĠC,ĠShinji
1,0.63607,327,34926,ĠC,ĠGors
2,0.64181,564,34926,ĠâĢ,ĠGors
3,0.641955,406,33891,ĠL,ĠShinji
4,0.642201,327,40627,ĠC,ĠMakoto
5,0.642831,337,33891,ĠM,ĠShinji
6,0.645438,564,33891,ĠâĢ,ĠShinji
7,0.645917,360,33891,ĠD,ĠShinji
8,0.646062,347,34926,ĠB,ĠGors
9,0.647432,350,33891,ĠP,ĠShinji


In [24]:
bottom_token_counts = pd.concat([bottom_df['token0_str'], bottom_df['token1_str']]).value_counts()
bottom_token_counts.head(n=20)

Unnamed: 0,count
ĠC,94
ĠGors,89
ĠShinji,79
ĠMakoto,70
ĠP,57
ĠâĢ,57
ĠL,45
ĠM,41
ĠD,40
ĠS,39
