# Exploration of OthelloGPT

In [2]:
import os

# Set before the library imports below so the flag is already in the environment when they load.
# NOTE(review): presumably read by `accelerate` to switch off its `rich` integration — confirm.
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
# Optional CUDA debugging switches, left disabled.
# os.environ["TORCH_USE_CUDA_DSA"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from functools import partial
from typing import Tuple, Union

from circuitsvis.attention import attention_patterns
import einops
import torch as t
import transformer_lens.utils as utils
import wandb
from jaxtyping import Float, Int
from neel_plotly import line
from torch import Tensor
from transformer_lens import (
    HookedTransformer,
)
from transformer_lens.hook_points import HookPoint

from plotly_utils import imshow

from utils import *
from plotting import *
from probes import get_probe, get_neels_probe, ProbeTrainingArgs, LitLinearProbe

from othello_world.mechanistic_interpretability.mech_interp_othello_utils import (
    plot_single_board, )

# pytorch_lightning is treated as optional: the import cell must still finish when the
# package cannot be loaded. In that case `pl`, `CSVLogger` and `WandbLogger` stay undefined.
try:
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import CSVLogger, WandbLogger
except (ImportError, ValueError) as err:
    # ImportError (which includes ModuleNotFoundError) is what a missing or broken install
    # raises; ValueError is kept because the original handler caught it.
    # The reason is printed so the failure is not silently swallowed.
    print("pytorch_lightning not working:", repr(err))

%load_ext autoreload
%autoreload 2


pytorch_lightning working


NameError: name 'pl' is not defined

## Setup

Things that you probably always want to run.

In [None]:
# Disable gradient tracking globally for everything that follows.
t.set_grad_enabled(False)

# Use the GPU when one is available, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# get_othello_gpt hands back the config and the model as a pair.
cfg, model = get_othello_gpt(device)

In [None]:
# %% Loading sample data
full_games_tokens, full_games_board_index = load_sample_games()

# The "focus" set is simply the first `num_games` games of the sample.
num_games = 50
focus_games_tokens, focus_games_board_index = (
    games[:num_games] for games in (full_games_tokens, full_games_board_index)
)

# Sanity check: mapping the tokens through TOKENS_TO_BOARD must reproduce the board indices.
assert (TOKENS_TO_BOARD[focus_games_tokens] == focus_games_board_index).all()

# Same helper twice: default mode for the states, mode="valid" for the valid moves.
focus_valid_moves = move_sequence_to_state(focus_games_board_index, mode="valid")
focus_states = move_sequence_to_state(focus_games_board_index)

for label, tensor in (
    ("focus states:", focus_states),
    ("focus_valid_moves", focus_valid_moves),
):
    print(label, tensor.shape)


## Ablation Study

In [None]:
# Evaluation set: the last `n_games` games of the sample, moved to the compute device.
n_games = 50
tokens = full_games_tokens[-n_games:].to(device)
board_index = full_games_board_index[-n_games:].to(device)

# Passed straight to zero_ablation; the plotting cell keeps only the "All Heads" and "MLP"
# columns when this is False.
individual_heads = False


def get_metrics(model):
    """Evaluate `model` on the games above and return the metrics as a tensor.

    NOTE(review): the meaning of the trailing 5 and -5 is defined by get_loss — confirm.
    """
    return get_loss(model, tokens, board_index, 5, -5).to_tensor()


# Order kept as in the original: ablated metrics first, then the un-ablated baseline.
zero_ablation_metrics = zero_ablation(model, get_metrics, individual_heads).cpu()
base_metrics = get_metrics(model)

In [None]:
# Plotting the results
# Columns: one per head, then the two aggregate components; rows: one per layer.
x_labels = [f"Head {i}" for i in range(model.cfg.n_heads)] + ["All Heads", "MLP"]
y_labels = [f"Layer {i}" for i in range(model.cfg.n_layers)]
if not individual_heads:
    # Without per-head ablation only the aggregate columns exist.
    x_labels = x_labels[-2:]

imshow(
    # zero_ablation_metrics was moved to the CPU when it was computed, base_metrics was not.
    # Bring the baseline slice to the CPU as well so the subtraction does not mix devices
    # (a no-op when it already lives on the CPU).
    zero_ablation_metrics[:3] - base_metrics[:3, None, None].cpu(),
    title="Metric increase after zeroing each component",
    x=x_labels,
    y=y_labels,
    facet_col=0,
    facet_labels=["Loss", "Cell accuracy", "Board accuracy"],
)

In [None]:
# Ablate all attention outputs from layer `n` onwards (i.e. every layer >= `n`)

def filter(name: str, start_layer: int = 0):
    """Hook-name predicate: select attention outputs of layer `start_layer` and later.

    Args:
        name: Hook point name, e.g. "blocks.3.hook_attn_out".
        start_layer: First layer (inclusive) whose attention output is selected.

    Returns:
        True when `name` belongs to a block with index >= `start_layer` and contains
        "attn_out"; False otherwise, including for every name outside "blocks."
        ('hook_embed', 'hook_pos_embed', 'ln_final.hook_scale', 'ln_final.hook_normalized').

    Note: this deliberately keeps the original name, which shadows the builtin `filter`
    for the rest of the notebook.
    """
    if name.startswith("blocks."):
        # Names look like "blocks.<layer>.<rest>"; the second component is the layer index.
        layer_index = int(name.split(".", 2)[1])
        return "attn_out" in name and layer_index >= start_layer

    return False

# For each starting layer, zero every attention output from that layer onwards and record
# the metrics; hooks are only active inside the `with` block.
metrics_per_layer = []
for start_layer in range(model.cfg.n_layers):
    layer_filter = partial(filter, start_layer=start_layer)
    with model.hooks(fwd_hooks=[(layer_filter, zero_ablation_hook)]):
        layer_metrics = get_metrics(model)
    metrics_per_layer.append(layer_metrics)

In [None]:
# %% Plot
# Stack the per-layer results along a new axis 1 and move them to the CPU for plotting.
lines = t.stack(metrics_per_layer, dim=1).cpu()

# One tick per starting layer: "≥ k" means layers k and above were ablated.
layer_ticks = [f"≥ {layer}" for layer in range(model.cfg.n_layers)]

# Only the first three metrics are drawn. The original notes list further ones that are
# not plotted: 'False Positive', 'False Negative', 'True Positive', 'True Negative'.
shown_metrics = ["Loss", "Cell accuracy", "Board accuracy"]

line(
    lines[:3],
    x=layer_ticks,
    facet_col=0,
    # facet_col_wrap=3,
    facet_labels=shown_metrics,
    title="Metrics after zeroing all attention heads above a layer",
)

## Exploration of the probe

In [None]:
# Exploration of the probe

# Shape (d_model, rows, cols, options)
# options: 0: blank, 1: my piece, 2: their piece
neel_probe = get_neels_probe(device)

# Split the trailing options axis into one tensor per option, in the order listed above.
blank_probe, my_probe, their_probe = t.unbind(neel_probe, dim=-1)

# "my" minus "their": separates the two kinds of occupied square.
my_direction = my_probe - their_probe
# "blank" minus the mean of the two occupied options.
blank_direction = blank_probe - (their_probe + my_probe) / 2


### Accuracy and cosine similarity

In [None]:
# Evaluate Neel's probe against the model on the focus games and plot the result.
# NOTE(review): per_move="board_accuracy" presumably selects which metric is broken down
# per move — confirm against plot_probe_accuracy.
plot_probe_accuracy(
    model,
    neel_probe,
    focus_games_tokens,
    focus_games_board_index,
    # per_option=True,
    per_move="board_accuracy",
    name="Neel's probe",
)


In [None]:
# One similarity plot per option of Neel's probe, with squares on both axes.
for name, probe in (("blank", blank_probe), ("their", their_probe), ("my", my_probe)):
    # Flatten the board so that every square contributes one d_model-sized row.
    probe = einops.rearrange(probe, "d_model rows cols -> (rows cols) d_model")
    plot_similarities(
        probe,
        title=f"Similarity between {name} vectors",
        x=full_board_labels,
        y=full_board_labels,
    )

In [None]:
# %% Similarity between mine and theirs (for each square)
# Both arguments are the per-option slices unbound from Neel's probe.
plot_similarities_2(my_probe, their_probe)

### Using UMAP

In [None]:

import umap
import umap.plot
import pandas as pd

# `linear_probe` is otherwise only assigned in the PCA section further down, so this cell
# used to fail on a fresh kernel. Load it here so the notebook runs top to bottom.
linear_probe = get_probe(device)

# One row per probe vector; the options axis varies slowest, then rows, then cols.
vectors = einops.rearrange(linear_probe,
                            "d_model rows cols options -> (options rows cols) d_model")

# 2-D embedding of the probe vectors using cosine distance.
mapper = umap.UMAP(metric="cosine").fit(vectors.cpu().numpy())

# Each option name is repeated once per square, matching the (options rows cols) layout.
# NOTE(review): this assumes the options axis is ordered blank, their, my. Neel's probe is
# documented above as blank, my, their — confirm the order used by get_probe.
labels = [probe_name for probe_name in ["blank", "their", "my"] for _ in full_board_labels]
hover_data = pd.DataFrame({
    "square": full_board_labels * 3,
    "probe": labels,
})

umap.plot.show_interactive(
    umap.plot.interactive(mapper, labels=labels, hover_data=hover_data, theme="inferno"))


### Using PCA

In [None]:
linear_probe = get_probe(device)

# %% Run PCA on the vectors of the probe
# Flatten to one d_model-sized row per (option, row, col) combination.
vectors = einops.rearrange(linear_probe, "d_model rows cols options -> (options rows cols) d_model")
plot_PCA(vectors, "the probe vectors")

# %% The same but per option
for i in range(3):
    # Slice out a single option; d_model stays the leading axis, hence flip_dim_order=True.
    option_probe = linear_probe.select(dim=-1, index=i)
    plot_PCA(option_probe, f"the probe vectors for option {i}", flip_dim_order=True, absolute=True)

# %% Normalise the probe then run PCA
# Each probe vector lies along d_model, which is axis 0 of the
# (d_model, rows, cols, options) tensor, so the norm is taken over that axis.
# The previous dim=-1 normalised across the three options instead, which does not give
# unit-length probe vectors.
normalised_probe = linear_probe / linear_probe.norm(dim=0, keepdim=True)
plot_PCA(normalised_probe, "the normalised probe vectors", flip_dim_order=True)

# %% Same PCA but with the unembeddings
# W_U is stored as (d_model, d_vocab); transpose it so every row is one token's vector,
# matching the (n_vectors, d_model) layout used for the probe vectors and for W_U.T in
# the joint PCA of this notebook.
plot_PCA(model.W_U.T, "the unembeddings")

# %%
# W_pos holds the positional embeddings (the token embeddings are W_E), so label it so.
plot_PCA(model.W_pos, "the positional embeddings")
# %%
# Standardise each family of vectors separately (zero mean, unit std per d_model
# coordinate) before pooling them, then run one PCA over all of them together.
all_vectors = [
    (matrix - matrix.mean(dim=0)) / matrix.std(dim=0)
    for matrix in (model.W_U.T, model.W_E, model.W_pos, vectors)
]

plot_PCA(t.cat(all_vectors, dim=0), "the embeddings and unembeddings")
# %%
# "my" and "blank" directions together: stacked along the rows axis, then one row per vector.
plot_PCA(
    einops.rearrange(
        t.cat([my_direction, blank_direction], dim=1),
        "d_model rows cols -> (rows cols) d_model",
    ),
    "the direction vectors",
)
# %%
# "my" direction only.
plot_PCA(
    einops.rearrange(my_direction, "d_model rows cols -> (rows cols) d_model"),
    "the direction vectors",
)
# %%
# "blank" direction only.
plot_PCA(
    einops.rearrange(blank_direction, "d_model rows cols -> (rows cols) d_model"),
    "the direction vectors",
)
