# Exploration of OthelloGPT

In [1]:
%load_ext autoreload
%autoreload 2

import os

os.environ["ACCELERATE_DISABLE_RICH"] = "1"
# os.environ["TORCH_USE_CUDA_DSA"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from functools import partial
from typing import Tuple, Union, Dict

from circuitsvis.attention import attention_patterns
import einops
import pandas as pd
import plotly.express as px
import torch as t
import transformer_lens.utils as utils
import wandb
from jaxtyping import Float, Int
from neel_plotly import line
from torch import Tensor
from transformer_lens import (
    HookedTransformer,
)
from transformer_lens.hook_points import HookPoint

from plotly_utils import imshow

import circuitsvis

import utils as u
from utils import *
from plotting import *
from probes import get_probe, get_neels_probe, ProbeTrainingArgs, LitLinearProbe

from othello_world.mechanistic_interpretability.mech_interp_othello_utils import (
    plot_single_board, )




pytorch_lightning not working. Cannot train probes.


## Setup

Things that you probably always want to run.

In [2]:
# Inference only: autograd is switched off for every cell below.
t.set_grad_enabled(False)
device = "cuda" if t.cuda.is_available() else "cpu"

# `get_othello_gpt` is not defined in this notebook (presumably star-imported from `utils`).
cfg, model = get_othello_gpt(device)

Moving model to device:  cpu
Moving model to device:  cpu


In [89]:
# Small validation set of 100 games. The second positional argument is presumably the seed
# (a later cell calls `generate_training_data(1_000, seed=69)`) — TODO confirm.
val_tokens, val_valid = generate_training_data(100, 69)

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# %% Loading sample data
full_games_tokens, full_games_board_index = load_sample_games()
# Keep only the first `num_games` games for the cells that look at individual positions.
num_games = 50
focus_games_tokens = full_games_tokens[:num_games]
focus_games_board_index = full_games_board_index[:num_games]
# Sanity check: the token -> board-index lookup agrees with the loaded board indices.
assert (TOKENS_TO_BOARD[focus_games_tokens] == focus_games_board_index).all()

# Same helper called twice: default mode, then mode="valid" (presumably the legal-move mask).
focus_states = move_sequence_to_state(focus_games_board_index)
focus_valid_moves = move_sequence_to_state(focus_games_board_index, mode="valid")

print("focus states:", focus_states.shape)
print("focus_valid_moves", focus_valid_moves.shape)


## Exploration of the model

In [5]:
# Gram matrix of the positional embeddings: entry (i, j) is W_pos[i] . W_pos[j].
pos = model.W_pos

pos_to_pos = pos @ pos.T
# First plot drops position 0; second plot keeps only the even positions starting at 2.
imshow((pos_to_pos)[1:, 1:], title="Dot product between positional embeddings")
imshow((pos_to_pos)[2::2, 2::2], title="Dot product between even positional embeddings")

In [None]:
# Rows of W_E scaled to unit norm, so `sim` holds cosine similarities between token embeddings.
we_dir = model.W_E / model.W_E.norm(dim=-1, keepdim=True)
sim = we_dir @ we_dir.T
# remove the diagonal elements
sim = sim - sim.diag().diag()
sim = logits_to_board(sim, 'logits')  # token, row, col
# Zero rows/cols 3-4 (the four centre squares) so they do not affect the max/min and the plot.
sim[:, 3:5, 3:5] = 0
print("Max correlation:", sim.max())
print("Min correlation:", sim.min())
plot_square_as_board(
    sim,
    animation_frame=0,
    title="Dot product between token embeddings",
    height=700,
)

In [124]:
def plot_attention_from_pos_emb(model: HookedTransformer, layer: int = 0,
                                use_visibility_factor: bool = False,
                                softmax: bool = True,
                                hide_position_0: bool = False,
                                use_animation: bool = False,
                                ) -> None:
    """
    Plot the attention of one layer when the residual stream holds only positional embeddings.

    Queries and keys are computed from `model.W_pos` alone (biases included), so the plot
    shows how much each head attends from one position to another regardless of the tokens.

    Args:
        model: the transformer whose `W_pos`, `W_Q`, `b_Q`, `W_K`, `b_K` are read.
        layer: index of the attention layer.
        use_visibility_factor: if True, multiply the row of query position `q` by `q`.
        softmax: if True, plot attention probabilities; otherwise raw (causally masked) scores.
        hide_position_0: if True, replace the row and column of position 0 by NaN.
        use_animation: if True, one animation frame per head; otherwise one facet per head.
    """
    pos = model.W_pos

    positional_queries = einops.einsum(
        pos, model.W_Q[layer],
        "pos dmodel, head dmodel dhead -> pos head dhead",
    ) + model.b_Q[layer]
    positional_keys = einops.einsum(
        pos, model.W_K[layer],
        "pos dmodel, head dmodel dhead -> pos head dhead",
    ) + model.b_K[layer]

    print("positional_queries", positional_queries.shape)
    positional_attention = einops.einsum(
        positional_queries, positional_keys,
        "pos_q head dhead, pos_k head dhead -> head pos_q pos_k",
    )

    # Causal mask: a query position cannot attend to later key positions.
    mask = t.triu(t.ones_like(positional_attention), diagonal=1).bool()
    positional_attention.masked_fill_(mask, float("-inf"))

    if softmax:
        positional_attention = positional_attention.softmax(-1)

    if use_visibility_factor:
        # Build the factor on the tensor's own device/dtype rather than on the notebook-level
        # `device`, so the function also works for a model living on another device.
        visibility_factor = t.arange(
            positional_attention.shape[-1],
            device=positional_attention.device,
            dtype=positional_attention.dtype,
        )[:, None]
        positional_attention = positional_attention * visibility_factor

    if hide_position_0:
        positional_attention[:, 0, :] = float('nan')
        positional_attention[:, :, 0] = float('nan')

    if use_animation:
        kwargs = dict(
            animation_frame=0,
            height=700,
        )
    else:
        kwargs = dict(
            facet_col=0,
            facet_col_wrap=3,
            height=1200,
            facet_labels=["Head {}".format(i) for i in range(model.cfg.n_heads)],
        )

    imshow(
        positional_attention,
        title=f"Attention from matching of positional embeddings of layer {layer}",
        **kwargs,
    )

# Raw (pre-softmax) positional attention scores of layer 0.
plot_attention_from_pos_emb(model, layer=0, softmax=False)

positional_queries torch.Size([59, 8, 64])


In [185]:
from dataclasses import field

def wrap_print(f, verbose: bool = False):
    """
    Debugging decorator that prints the name and arguments of each call to `f`.

    Tensors are shown by their shape, everything else by its value.

    Args:
        f: the function to decorate.
        verbose: when False (default) `f` is returned untouched, so `@wrap_print`
            is a no-op; pass `verbose=True` to enable tracing.

    Returns:
        `f` itself, or a wrapper around it that prints before delegating.
    """
    if not verbose:
        return f

    # Local import: the notebook only imports `partial` from functools.
    import functools

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        print()
        print(f.__name__)
        print("args:", *[a.shape if isinstance(a, Tensor) else a for a in args])
        print("kwargs:", *[f"{k}={v.shape if isinstance(v, Tensor) else v}" for k, v in kwargs.items()])
        return f(*args, **kwargs)
    return wrapped

@dataclass
class Circuit:
    """
    A tensor with named dimensions, built by chaining weights of a transformer.

    `_shape[i]` is the name of dimension `i` of `value`. Most methods return a new
    `Circuit`, so calls can be chained. Two operands are contracted over a dimension
    they share by name (see `_mul`). The default instance (scalar 1, no dimension)
    is the starting point of a chain: multiplying it by a weight yields that weight.
    """
    model: HookedTransformer
    value: Tensor = field(default_factory=lambda: t.tensor(1))
    _shape: List[str] = field(default_factory=lambda: [])

    # When two operands share several dimension names, the first name of this list
    # that is shared is the one contracted.
    PREFERED_DIMS = ['dhead', 'dmodel']

    def __repr__(self):
        return f"Circuit({', '.join(f'{a}={b}' for a, b in self.shape.items())})"

    def by(self, *dims: str):
        """Return a circuit whose leading dimensions are `dims`, the others keeping their order."""
        assert all(dim in self._shape for dim in dims), f"Dims {dims} not in {self._shape}"

        new_shape = [*dims] + [d for d in self._shape if d not in dims]
        value = einops.rearrange(self.value, f"{' '.join(self._shape)} -> {' '.join(new_shape)}")

        return Circuit(self.model, value, new_shape)

    def _prefered_dim(self, shape: List[str]):
        """Pick the dimension to contract among the candidate names in `shape`."""
        if len(shape) == 1:
            return next(iter(shape))
        for dim in self.PREFERED_DIMS:
            if dim in shape:
                return dim
        raise ValueError(f"Could not find a prefered dim in {shape}")

    def _dim_name_to_index(self, dim: Optional[str] = None):
        """Position of the dimension named `dim`; `None` is accepted only for 1D circuits."""
        if dim is None:
            assert len(self._shape) == 1, f"Implicit dimention is possible only for 1D tensor. Got: {self}."
            return 0
        else:
            assert dim in self._shape, f"Dimension {dim} not in {self._shape}"
            return self._shape.index(dim)

    @property
    def shape(self) -> Dict[str, int]:
        """Mapping from dimension name to its size."""
        return {d: s for d, s in zip(self._shape, self.value.shape)}

    def __getitem__(self, index: Union[int, slice, Tuple[str, Union[int, slice]]]):
        """
        Index one named dimension: `circuit['pos', 2:]`, or `circuit[3]` for a 1D circuit.

        An integer index removes the dimension, a slice keeps it.
        """
        if isinstance(index, tuple):
            dim, index = index
        elif isinstance(index, (int, slice)):
            dim = None
        else:
            raise ValueError(f"Invalid index type: {type(index)}")

        dim_index = self._dim_name_to_index(dim)

        indexing = [slice(None)] * len(self._shape)
        indexing[dim_index] = index

        if isinstance(index, int):
            new_shape = self._shape[:dim_index] + self._shape[dim_index + 1:]
        else:
            new_shape = list(self._shape)

        # Index with a tuple: a *list* of slices is ambiguous with advanced indexing.
        return Circuit(self.model, self.value[tuple(indexing)], new_shape)

    def _mul(self, other: Tensor, other_shape: List[str], bias: Optional[Tensor] = None):
        """
        Contract this circuit with `other` over one dimension they share by name.

        Other shared dimensions are matched element-wise and kept once. `bias`, if given,
        must have the dimensions of `other` minus the contracted one, and is added to the result.
        NOTE(review): for a dimensionless circuit the result is `value * other` and `bias`
        is not applied — confirm this is intended.
        """
        if not self._shape:
            value = self.value * other
            return Circuit(self.model, value, list(other_shape))

        # Check that no dim is repeated in each tensor
        assert len(set(self._shape)) == len(self._shape), f"Repeated dimensions in {self._shape}"
        assert len(set(other_shape)) == len(other_shape), f"Repeated dimensions in {other_shape}"

        # Find the dimension on which to multiply
        possible_multiplied_dims = set(self._shape) & set(other_shape)
        dim = self._prefered_dim(possible_multiplied_dims)

        pattern_1 = " ".join(self._shape)
        pattern_2 = " ".join(other_shape)

        shape = [d for d in self._shape if d != dim] + [d for d in other_shape if d not in possible_multiplied_dims]
        pattern_target = ' '.join(shape)

        value = einops.einsum(
            self.value,
            other,
            f"{pattern_1}, {pattern_2} -> {pattern_target}",
        )

        if bias is not None:
            # Insert singleton axes so that the bias broadcasts against the result
            bias_shape = [d for d in other_shape if d != dim]
            target_shape = [d if d in bias_shape else "1" for d in shape]
            bias = einops.rearrange(bias, f"{' '.join(bias_shape)} -> {' '.join(target_shape)}")
            value += bias

        return Circuit(self.model, value, shape)

    def _rename(self, old: str, new: str):
        """Rename a dimension of this circuit in place."""
        self._shape = [new if d == old else d for d in self._shape]

    def _renamed(self, old: str, new: str):
        """Return a circuit sharing `value` where dimension `old` is called `new`."""
        return Circuit(self.model, self.value, [new if d == old else d for d in self._shape])

    def embedding(self):
        """Multiply by the token embedding `W_E`; the receiver is left untouched."""
        if 'vocab' in self._shape:
            return self._renamed("vocab", "vocab_1")._mul(self.model.W_E, ["vocab_2", "dmodel"])
        else:
            return self._mul(self.model.W_E, ["vocab", "dmodel"])

    def pos_embedding(self):
        """Multiply by the positional embedding `W_pos`; the receiver is left untouched."""
        if 'pos' in self._shape:
            return self._renamed("pos", "pos_1")._mul(self.model.W_pos, ["pos_2", "dmodel"])
        else:
            return self._mul(self.model.W_pos, ["pos", "dmodel"])

    def unembed(self):
        """Multiply by the unembedding `W_U` (`dmodel` x `vocab_out`) and add `b_U`."""
        # W_U has two dimensions, both must be named for the contraction over `dmodel`.
        return self._renamed("vocab", "vocab_in")._mul(
            self.model.W_U, ["dmodel", "vocab_out"], self.model.b_U)

    def ov(self, layer: int, head: Optional[int] = None):
        """
        Apply the OV circuit (`W_V` then `W_O`, plus `b_O`) of one head, or of every head of `layer`.

        `b_O` has one vector per layer, not per head: it is added to the output of each head,
        so summing the result over `head` would count it once per head.
        """
        if head is None:
            index = (layer,)
            dims_v = ["head", "dmodel", "dhead"]
            dims_o = ["head", "dhead", "dmodel"]
        else:
            index = (layer, head)
            dims_v = ["dmodel", "dhead"]
            dims_o = ["dhead", "dmodel"]

        out = self._mul(self.model.W_V[index], dims_v)
        out = out._mul(self.model.W_O[index], dims_o)
        # b_O is indexed by layer only; `add` broadcasts it over the other dimensions.
        return out.add(Circuit(self.model, self.model.b_O[layer], ["dmodel"]))

    @wrap_print
    def qk(self, layer: Optional[int] = None, head: Optional[int] = None, *, key: "Circuit"):
        """
        Attention scores between this circuit (queries) and `key` (keys).

        Args:
            layer: attention layer, or None to keep a `layer` dimension.
            head: attention head, or None to keep a `head` dimension.
            key: circuit fed to `W_K`; must use the same model.

        Dimensions present on both sides (other than `layer`, `head`, `dhead`) are kept
        separately with a `_q` / `_k` suffix. When both `pos_q` and `pos_k` exist, they
        are moved last and the causal mask is applied (-inf where key position > query position).
        """
        assert self.model is key.model, "Both circuits must have the same model"

        dims = []
        index = ()

        if layer is None:
            dims.append("layer")
            # Keep the layer axis explicitly so that a head index lands on the head axis
            index += (slice(None),)
        else:
            index += (layer,)

        if head is None:
            dims.append("head")
        else:
            index += (head,)

        dims += ["dmodel", "dhead"]

        # Apply the query
        query = self._mul(self.model.W_Q[index], dims, self.model.b_Q[index])

        # Apply the key
        keys = key._mul(self.model.W_K[index], dims, self.model.b_K[index])

        # Combine the two
        # We always want to match the head dimension, and there might be an other dimension
        # that is shared between the two, which we dont want to match
        unmatched_dims = (set(query._shape) & set(keys._shape)) - {'dhead', 'head', 'layer'}
        query._shape = [d + "_q" if d in unmatched_dims else d for d in query._shape]
        keys._shape = [d + "_k" if d in unmatched_dims else d for d in keys._shape]

        out = query._mul(keys.value, keys._shape)

        # If we have two "pos" dimensions, use the mask to simulate the triangular matrix
        tokens_dims = ["pos_q", "pos_k"]
        if all(d in out._shape for d in tokens_dims):
            mask = t.triu(t.ones(out.shape['pos_q'], out.shape['pos_k']), diagonal=1)
            mask = mask.to(out.value.device, t.bool)
            # Reshape to put pos_q and pos_k at the end
            new_shape = [d for d in out._shape if d not in tokens_dims] + tokens_dims
            out = out.rearange(new_shape)
            out.value.masked_fill_(mask, float("-inf"))

        return out

    def add(self, other: "Circuit"):
        """Sum two circuits; the one with fewer dimensions is broadcast over the other."""
        assert self.model is other.model, "Both circuits must have the same model"

        # We keep the one with the largest number of dimensions
        if len(self._shape) < len(other._shape):
            return other.add(self)

        assert set(other.shape) <= set(self.shape), f"Cannot add {other.shape} to {self.shape}"

        other_new_shape = [d if d in other._shape else "1" for d in self._shape]
        other_value = einops.rearrange(other.value, f"{' '.join(other._shape)} -> {' '.join(other_new_shape)}")

        return Circuit(self.model, self.value + other_value, self._shape)

    def new_dim(self, dim: str):
        """Append a trailing dimension of size 1 named `dim`."""
        assert dim not in self._shape, f"Cannot add {dim} to {self.shape}, already present"
        return Circuit(self.model, self.value.unsqueeze(-1), self._shape + [dim])

    def rearange(self, new_shape: List[str]):
        """Reorder the dimensions to `new_shape` (einops pattern semantics)."""
        value = einops.rearrange(self.value, f"{' '.join(self._shape)} -> {' '.join(new_shape)}")
        return Circuit(self.model, value, new_shape)

    def softmax(self, dim: Union[int, str] = -1):
        """Softmax over `dim`, given as a position or as a dimension name."""
        if isinstance(dim, str):
            dim = self._dim_name_to_index(dim)
        return Circuit(self.model, self.value.softmax(dim), self._shape)

    def norm(self, dim: Optional[str] = None, keepdim: bool = False):
        """Norm over the dimension `dim`, which is removed unless `keepdim` is True."""
        dim_index = self._dim_name_to_index(dim)
        if keepdim:
            new_shape = list(self._shape)
        else:
            new_shape = self._shape[:dim_index] + self._shape[dim_index + 1:]
        return Circuit(self.model, self.value.norm(dim=dim_index, keepdim=keepdim), new_shape)

    def normalise(self, dim: Optional[str] = None):
        """Divide by the norm over `dim`, so that vectors along `dim` have unit norm."""
        dim_index = self._dim_name_to_index(dim)
        return Circuit(self.model, self.value / self.value.norm(dim=dim_index, keepdim=True), self._shape)

    def remove_diag(self):
        """Set the diagonal of a 2D circuit to zero."""
        # On a 1D tensor `.diag()` would build a matrix instead of extracting the diagonal
        assert self.value.ndim == 2, f"remove_diag needs a 2D circuit. Got: {self}."
        return Circuit(self.model, self.value - self.value.diag().diag(), self._shape)

    def flatten(self, dim1: str, dim2: str):
        """Flatten the two dimensions into one"""
        assert dim1 in self._shape, f"Dimension {dim1} not in {self._shape}"
        assert dim2 in self._shape, f"Dimension {dim2} not in {self._shape}"
        assert dim1 != dim2, f"Cannot flatten {dim1} and {dim2} as they are the same"

        # Put the two dimensions at the end
        new_shape = [d for d in self._shape if d not in [dim1, dim2]] + [dim1, dim2]
        reordered = self.rearange(new_shape)

        # Flatten the two dimensions
        new_shape = new_shape[:-2] + [f"flat_{dim1}_{dim2}"]
        value = reordered.value.flatten(-2, -1)

        return Circuit(self.model, value, new_shape)

    def plot(self, facet_name: Optional[str]=None, type: Literal['imshow', 'line', 'histogram'] = 'imshow', **kwargs):
        """
        Plot the circuit and return the circuit that was plotted (so that chains can continue).

        Args:
            facet_name: dimension moved first before plotting; defaults to `head` for a
                3D circuit that has one.
            type: 'imshow' plots the last two dimensions, 'line' and 'histogram' the last one.
            **kwargs: forwarded to the plotting function; defaults are only filled in
                for the keys that are absent.

        Raises:
            ValueError: if there are more than two dimensions on top of the plotted ones,
                or if `type` is unknown.
        """
        circuit = self
        if facet_name is None:
            if len(circuit._shape) == 3 and 'head' in circuit._shape:
                facet_name = 'head'

        if facet_name:
            circuit = circuit.by(facet_name)

        dim_plot = 2 if type == 'imshow' else 1
        nb_dims = len(circuit._shape)

        if nb_dims > dim_plot:
            # The dimension just before the plotted ones becomes the facet
            facet_dim = (-dim_plot - 1) % nb_dims
            facet_name = circuit._shape[facet_dim]
            nb_plots = circuit.value.shape[facet_dim]
            kwargs.setdefault('facet_col', facet_dim)
            wrap = 2 if nb_plots == 4 else 3
            kwargs.setdefault('facet_col_wrap', wrap)
            # Round the number of rows up: a partially filled row still needs its height
            nb_rows = max(1, -(-nb_plots // wrap))
            kwargs.setdefault('height', 500 * nb_rows)

            if facet_name.startswith('head'):
                kwargs.setdefault('facet_labels', [f"Head {i}" for i in range(nb_plots)])
            elif facet_name.startswith('layer'):
                kwargs.setdefault('facet_labels', [f"Layer {i}" for i in range(nb_plots)])
            else:
                print(f"Warning: Don't know how to label {facet_name}")

        if nb_dims == dim_plot + 2:
            kwargs.setdefault('animation_frame', 0)
        elif nb_dims > dim_plot + 2:
            raise ValueError(f"Cannot plot {circuit}: Too many dimensions.")

        kwargs.setdefault('title', str(circuit))

        x_name = circuit._shape[-1]
        if x_name.startswith('vocab'):
            if circuit.shape[x_name] == len(TOKEN_NAMES):
                kwargs.setdefault('x', TOKEN_NAMES)
            elif circuit.shape[x_name] == len(CELL_TOKEN_NAMES):
                kwargs.setdefault('x', CELL_TOKEN_NAMES)

        if type == 'imshow':
            y_name = circuit._shape[-2]
            kwargs.setdefault('xaxis_title', x_name)
            kwargs.setdefault('yaxis_title', y_name)
            if y_name.startswith('vocab'):
                if circuit.shape[y_name] == len(TOKEN_NAMES):
                    kwargs.setdefault('y', TOKEN_NAMES)
                elif circuit.shape[y_name] == len(CELL_TOKEN_NAMES):
                    kwargs.setdefault('y', CELL_TOKEN_NAMES)

            try:
                imshow(circuit.value, **kwargs)
            except Exception:
                # Say which circuit failed before letting the error propagate
                print("Error while plotting", circuit)
                raise

        elif type == 'line':
            kwargs.setdefault('xaxis', x_name)
            line(circuit.value, **kwargs)
        elif type == 'histogram':
            # kwargs.setdefault('xaxis', x_name)
            display(px.histogram(circuit.value, **kwargs))
        else:
            raise ValueError(f"Unknown type: {type}")
        return circuit
            
# Root circuit with no named dimension: every chain below starts from it.
C = Circuit(model)

# (Circuit(model)
#     .embedding()
#     .new_dim('pos')
#     .add(C.pos_embedding())
#     .qk(0, key=C.embedding())
#     .by("pos", "head")
#     .softmax()
# ).plot()

# Norm over `dmodel` of (token embedding + positional embedding) for every (vocab, pos) pair,
# skipping vocab index 0 and positions 0-1: plotted first as a 2D map, then as a histogram.
(C
    .embedding()
    .new_dim('pos')
    .add(C.pos_embedding())
    .norm(dim='dmodel')
    ['vocab', 1:]
    ['pos', 2:]
    .plot(title="Distribution of the norm of the positional embeddings + token embeddings",
          height=800)
    .flatten('vocab', 'pos')
    .plot(type='histogram', 
          title="Distribution of the norm of the positional embeddings + token embeddings",
          )
)


            
# Layer-0 query/key scores when both sides are the unit-normalised positional embeddings.
(C
    .pos_embedding()
    .normalise('dmodel')
    .qk(0, key=C
        .pos_embedding()
        .normalise('dmodel')
    )
).plot('head', title="Attention score from positional embeddings matching")
    

# (C
#     .embedding()
#     .new_dim('pos')
#     .add(C.pos_embedding())
#     .qk(0, key=C
#         .embedding())
#     .softmax()
# ).plot('head')

Circuit(head=8, pos_q=59, pos_k=59)

In [159]:
num_components = 15

# Position 0 is left out of the PCA.
pos_emb = model.W_pos[1:]
pos_emb_normalised = pos_emb / pos_emb.norm(dim=-1, keepdim=True)

pca = plot_PCA(pos_emb, name="Positional embeddings")
# Show which positions are aligned with the PCA axes (first `num_components`)
pca_directions = pca.components_[:num_components]
# NOTE(review): the dtype comes from `pca.components_`; confirm it matches `pos_emb` for the matmul below.
pca_directions = t.tensor(pca_directions, device=device)


# Show the dot product between the PCA directions and the positional embeddings
imshow(pca_directions @ pos_emb_normalised.T,
        xaxis_title="Position",
        yaxis_title="PCA direction index",
       title="Dot product between PCA directions and positional embeddings")

In [None]:
# Plot the key bias of layer 0 (`b_K[0]`)
line(
    model.b_K[0],
)

In [64]:
# Project helper (not defined in this notebook), called for layer 0.
logit_attribution_emb_ov_circuit(model, 0)

In [105]:
# Projection of the layer-0 attention output bias through the unembedding: one value per output token.
# NOTE(review): `token_names` is lower-case here while `Circuit.plot` uses `TOKEN_NAMES` — confirm both exist.
line(model.b_O[0] @ model.W_U, x=token_names)
model.b_O.shape

torch.Size([8, 512])

In [92]:
# Norm of the layer-0 key bias along its last dimension (one value per head).
model.b_K[0].norm(dim=-1)

tensor([4.1463, 0.8333, 3.9880, 3.9358, 0.8857, 0.7185, 2.8934, 2.5379])

In [None]:
# Plotting helper built from the validation games and the model; used by the next cell.
model_metrics = ModelMetricPlotter(val_tokens, model)

In [None]:
# Loss per move is plotted twice, with and without the positional flag set to True.
# NOTE(review): the meaning of that flag is not visible in this notebook.
model_metrics.plot_loss_per_move(True)
model_metrics.plot_loss_per_move()
model_metrics.plot_whole_board_accuracy()

In [None]:
# Shape of the second value returned by `generate_training_data` above.
val_valid.shape

In [None]:
def explain_game(
    tokens: Int[Tensor, 'move'],
    model: HookedTransformer,
    move: Optional[int] = None,
    # Extra args
    extra_boards: Optional[Dict[str, Tensor]] = None,
) -> None:
    """
    Plot what the model sees and predicts after a given move of one game.

    Three boards are always shown: the board state, the valid moves and the model's
    predictions, followed by one board per entry of `extra_boards`.

    Args:
        tokens: the moves of a single game, as a tensor or a list of token ids.
        model: the model to run on the game.
        move: index of the last move fed to the model. Negative values count from the
            end; defaults to the last move. 0-d integer tensors are accepted.
        extra_boards: additional boards to display, keyed by their label. Boolean boards
            are scaled to the colour range of the other boards. The dict is not modified.
    """
    if isinstance(tokens, list):
        tokens = t.tensor(tokens, dtype=t.long)
    if move is None:
        move = tokens.shape[0] - 1
    # int() turns a 0-d tensor (e.g. obtained by iterating over an index tensor) into a
    # plain index, so that slicing and the title below behave the same for both.
    move = int(move) % tokens.shape[0]

    tokens = tokens[None, :move+1].to(model.cfg.device)
    board_indices = tokens_to_board(tokens)

    logits = model(tokens)

    # Common scale so that the 0/1 and -1/0/1 boards are visible next to the predictions
    scale = logits.abs().max().item() * 0.5
    predictions = logits_to_board(logits, 'log_prob')
    state = move_sequence_to_state(board_indices, 'mine-their') * scale
    valid = move_sequence_to_state(board_indices, mode="valid") * scale

    # Scale extra boards if they are boolean. Work on a copy: the caller's dict is left untouched.
    scaled_extra_boards = {
        name: board * scale if board.dtype == t.bool else board
        for name, board in (extra_boards or {}).items()
    }

    # Plotting
    boards = t.stack([
        state[0, -1],
        valid[0, -1],
        predictions[0, -1],
        *scaled_extra_boards.values(),
    ], dim=0)
    plot_square_as_board(
        boards,
        facet_col=0,
        facet_col_wrap=3 if len(boards) != 4 else 2,
        facet_labels=["State before", "Valid moves", "Model logits", *scaled_extra_boards.keys()],
        title=f"Game after {to_board_label(board_indices[0, -1].item())} - blue to play - move {move}",
    )


# A hand-written two-move game (list input), then move 0 of the validation game at index 3.
explain_game([20, 19], model)
explain_game(val_tokens[3], model, 0)

In [None]:
@t.inference_mode()
def find_fail_datapoints(
    model: HookedTransformer,
    tokens: Int[Tensor, 'game move'],
    threshold: float = 3,
    nb_examples: int = 10,
    biggest_mistakes: bool = True,
) -> None:
    """
    Find datapoints where the model fails to predict the valid moves.

    Plots the board for each datapoint where the model fails to predict the valid moves.

    Args:
        model: the model to evaluate.
        tokens: the games to look at; only the first 59 moves of each game are used.
        threshold: a cell is predicted as valid when its probability is above
            `1 / (threshold * number of valid moves)`.
        nb_examples: maximum number of positions to plot.
        biggest_mistakes: if True, plot the positions with the largest error (and a
            histogram of the errors); otherwise plot a random sample of the positions
            where at least one cell is mispredicted.
    """
    max_moves = 59

    # Get the probabilities for each cell
    tokens = tokens[:, :max_moves].to(model.cfg.device)
    logits = model(tokens)
    probabilities = logits_to_board(logits, 'prob')  # [game, move, row, col]

    # Compute the number of valid moves
    valid_moves = move_sequence_to_state(tokens_to_board(tokens), mode="valid")[:, :max_moves]
    nb_valid_moves = valid_moves.sum(dim=(-1, -2), keepdim=True)  # [game, move, 1, 1]

    # Thresholded predictions: needed by both selection modes, since the plotted
    # "Model's mistakes" board is derived from them.
    predictions = probabilities > (1 / (threshold * nb_valid_moves))
    correct = predictions == valid_moves

    # Find which moves are considered correct
    if biggest_mistakes:
        error = t.zeros_like(probabilities)
        error[valid_moves] = 1 - (probabilities * nb_valid_moves)[valid_moves]
        error[~valid_moves] = (probabilities * nb_valid_moves)[~valid_moves]
        # Worst cell of each board: [game, move]
        error = error.max(dim=-1).values.max(dim=-1).values
        nb_moves = error.shape[1]
        flat_error = error.flatten()
        # plot histogram of errors
        display(px.histogram(flat_error.cpu()))
        top = flat_error.topk(min(nb_examples, flat_error.numel())).indices
        # Decode with the actual number of moves, which can be smaller than 59
        incorrect_indices = [divmod(i, nb_moves) for i in top.tolist()]
    else:
        correct_boards = correct.all(-1).all(-1)

        # Find the indices of the incorrect moves
        candidates = t.nonzero(~correct_boards, as_tuple=False)
        # Sample from the incorrect indices
        candidates = candidates[t.randperm(len(candidates))]
        incorrect_indices = candidates[:nb_examples].tolist()

    # Plot the incorrect moves
    for game, move in incorrect_indices:
        explain_game(tokens[game],
                     model,
                     move,
                     extra_boards={
                         "Model's mistakes": ~correct[game, move],
                     })
        # plot_board(TOKENS_TO_BOARD[tokens[game]])
        # plot_board(TOKENS_TO_BOARD[tokens[game]])


# One randomly sampled failure among the first 20 validation games (threshold mode).
find_fail_datapoints(
    model,
    val_tokens[:20],
    threshold=10,
    nb_examples=1,
    biggest_mistakes=False,
)


## Attention

In [None]:
def plot_average_attention(model: HookedTransformer, tokens: Int[Tensor, "game move"],
                           *layers: int):
    """
    Display the attention patterns of every head, averaged over the games in `tokens`.

    Args:
        model: the model to run.
        tokens: games to average over; only the first 59 moves are used.
        *layers: layers to display. All layers are shown when none is given.
    """
    cfg = model.cfg
    nb_moves = min(tokens.shape[1], 59)
    total_attention = t.zeros(cfg.n_layers, cfg.n_heads, nb_moves, nb_moves, device=cfg.device)

    # The hook point is passed by keyword, so the second parameter must be called `hook`.
    def accumulate_pattern(activation: Float[Tensor, "game head query key"], hook: HookPoint):
        total_attention[hook.layer()] += activation.mean(0)

    def is_pattern_hook(name):
        return 'pattern' in name

    model.run_with_hooks(tokens[:, :59], fwd_hooks=[(is_pattern_hook, accumulate_pattern)])

    # Plot all the attention patterns
    selected_layers = list(layers) if layers else list(range(cfg.n_layers))

    plots = einops.rearrange(total_attention[selected_layers], "layer head row col -> (layer head) row col")
    labels = []
    for layer in selected_layers:
        for head in range(cfg.n_heads):
            labels.append(f"Head {head} Layer {layer}")
    tokens_labels = [f" {i}" for i in range(nb_moves)]
    display(circuitsvis.attention.attention_patterns(tokens_labels, plots, labels))

In [None]:
# Layer 0 only, "averaged" over a single validation game.
plot_average_attention(model, val_tokens[:1], 0)

In [92]:
def plot_head_attention(model: HookedTransformer, tokens: Int[Tensor, "game move"], 
                        layer: int,
                        head: int):
    """
    Show the attention pattern of a single head, with one pattern per game.

    Args:
        model: the model to run.
        tokens: the games to run; only the first 59 moves are used.
        layer: layer of the head.
        head: index of the head inside the layer.

    Returns:
        The circuitsvis rendering of the patterns.
    """
    nb_games = tokens.shape[0]
    nb_moves = min(tokens.shape[1], 59)
    attention = t.zeros(nb_games, nb_moves, nb_moves, device=model.cfg.device)

    # The hook point is passed by keyword, so the second parameter must be called `hook`.
    def store_head_pattern(activation: Float[Tensor, "game head query key"], hook: HookPoint):
        attention[...] += activation[:, head]

    pattern_hook_name = utils.get_act_name('pattern', layer)
    model.run_with_hooks(tokens[:, :59], fwd_hooks=[(pattern_hook_name, store_head_pattern)])

    # Plot the attention patterns
    tokens_labels = [f" {i}" for i in range(nb_moves)]
    rendering = circuitsvis.attention.attention_patterns(tokens_labels, attention)
    return rendering

# Head 0 of layer 0 on the first 12 validation games.
plot_head_attention(model, val_tokens[:12], 0, 0)

In [None]:
# Remove any hook still attached to the model before the ablation experiments.
model.reset_hooks()

## Ablation Study

In [None]:
individual_heads = False

# The last `n_games` sample games are used to measure the effect of each ablation.
n_games = 50
tokens = full_games_tokens[-n_games:].to(device)
board_index = full_games_board_index[-n_games:].to(device)


def get_metrics(model):
    """Metrics of `model` on the ablation games, as a tensor (see `get_loss`)."""
    return get_loss(model, tokens, board_index, 5, -5).to_tensor()


zero_ablation_metrics = zero_ablation(model, get_metrics, individual_heads).cpu()
# Moved to the CPU like `zero_ablation_metrics`: the next cell subtracts one from the other.
base_metrics = get_metrics(model).cpu()

In [None]:
# Plotting the results
x_labels = [f"Head {i}" for i in range(model.cfg.n_heads)] + ["All Heads", "MLP"]
y_labels = [f"Layer {i}" for i in range(model.cfg.n_layers)]
# Without per-head ablation only the last two columns ("All Heads" and "MLP") are kept.
if not individual_heads:
    x_labels = x_labels[-2:]

# Only the first three metrics are shown, as a difference to the un-ablated model.
imshow(
    zero_ablation_metrics[:3] - base_metrics[:3, None, None],
    title="Metric increase after zeroing each component",
    x=x_labels,
    y=y_labels,
    facet_col=0,
    facet_labels=["Loss", "Cell accuracy", "Board accuracy"],
)

In [None]:
# Ablate the attention output of every layer from layer `n` onwards

def filter(name: str, start_layer: int = 0):
    """
    Tell whether the hook called `name` is an attention output of a layer >= `start_layer`.

    Names outside of the blocks ('hook_embed', 'hook_pos_embed', 'ln_final.hook_scale',
    'ln_final.hook_normalized') are never selected.
    """
    parts = name.split(".")
    inside_block = len(parts) >= 2 and parts[0] == "blocks"
    if inside_block:
        # The layer number is the second component of the name: "blocks.<layer>. ..."
        layer = int(parts[1])
        return "attn_out" in name and layer >= start_layer
    return False

# For each start layer, attach `zero_ablation_hook` to the attention output of that layer
# and of all later ones, then record the metrics while the hooks are active.
metrics_per_layer = []
for start_layer in range(model.cfg.n_layers):
    with model.hooks(fwd_hooks=[(partial(filter, start_layer=start_layer),
                                    zero_ablation_hook)]):
        metrics_per_layer.append(get_metrics(model))

In [None]:
# %% Plot
# One column per start layer; only the first three metrics are plotted.
lines = t.stack(metrics_per_layer, dim=1).cpu()
line(
    lines[:3],
    x=[f"≥ {i}" for i in range(model.cfg.n_layers)],
    facet_col=0,
    #  facet_col_wrap=3,
    facet_labels=[
        "Loss",
        "Cell accuracy",
        "Board accuracy",
    ],  # 'False Positive', 'False Negative', 'True Positive', 'True Negative'],
    title="Metrics after zeroing all attention heads above a layer",
)

## Exploration of the probe

In [None]:
neel_probe = get_neels_probe(False, device)

# The last axis of the probe holds the three options of each square.
# NOTE(review): the order used here (blank, my, their) differs from the
# ["blank", "their", "my"] labels used in later cells — confirm which one is right.
blank_probe, my_probe, their_probe = neel_probe.unbind(dim=-1)
# Difference directions: blank vs. the mean of the two occupied options, and mine vs. theirs.
blank_direction = blank_probe - (their_probe + my_probe) / 2
my_direction = my_probe - their_probe

### Accuracy and cosine similarity

In [None]:
# Board accuracy per move of Neel's probe on the validation games.
# The commented keyword arguments are alternative views kept for quick switching.
plot_aggregate_metric(
    val_tokens,
    model,
    neel_probe,
    # per_option=True,
    # per_move="cell_accuracy",
    per_move="board_accuracy",
    name="Neel's probe",
    # prediction="softmax",
    # prediction='logprob',
);

In [None]:
# For each option, compare the probe vectors of all squares with one another.
for probe, name in zip([blank_probe, their_probe, my_probe], ["blank", "their", "my"]):
    # One row per square, so that similarities are taken between squares
    probe = einops.rearrange(probe, "d_model rows cols -> (rows cols) d_model")
    plot_similarities(probe,
                      title=f"Similarity between {name} vectors",
                      x=full_board_labels,
                      y=full_board_labels)

In [None]:
# %% Similarity between mine and theirs (for each square)
# `plot_similarities_2` is a project helper taking the two sets of vectors to compare.
plot_similarities_2(my_probe, their_probe)

In [None]:
# Load the probes with index 0, 1 and 2 and plot their accuracy on the last 100 sample games.
probes = [get_probe(i, device=device) for i in range(3)]

for probe, name in zip(probes, ["new probe", "orthogonal probe", "orthogonal probe 2"]):
    plot_probe_accuracy(
        model,
        probe.to(device),
        full_games_tokens[-100:],
        full_games_board_index[-100:],
        per_option=True,
        name=name,
    )


Similarities between the blank probe and the token embeddings

In [None]:
# Token embeddings without token 0, scattered onto a 64-square board layout.
token_embs = model.W_E[1:]
token_embs_64 = t.zeros((64, token_embs.shape[1]), device=token_embs.device)
# NOTE(review): assumes TOKENS_TO_BOARD has one entry per row of `token_embs` — TODO confirm,
# since the same table is indexed with raw tokens in the data-loading cell.
token_embs_64[TOKENS_TO_BOARD] = token_embs
print(token_embs.shape)
# Same layout as one option of the probe: d_model x rows x cols
token_embs_64 = einops.rearrange(token_embs_64, "(rows cols) d_model -> d_model rows cols", rows=8)

plot_similarities_2(
    neel_probe[..., 0],
    token_embs_64,
    name="Blank probe and token embeddings",
)

Compute and show probe vector norms

In [None]:
# Norm over the first axis of the probe, i.e. one norm per (row, col, option) vector.
probe_norm = neel_probe.norm(dim=0)
px.histogram(probe_norm.cpu().flatten(), title="Probe vector norms", labels={"value": "norm"})

### Using UMAP

In [None]:

import umap
import umap.plot
import pandas as pd

# NOTE(review): `linear_probe` is only assigned in the "Using PCA" cell further down, so this
# cell fails on a fresh kernel run from top to bottom — confirm which probe is meant here.
vectors = einops.rearrange(linear_probe,
                            "d_model rows cols options -> (options rows cols) d_model")

# 2D embedding of all probe vectors, using the cosine distance.
mapper = umap.UMAP(metric="cosine").fit(vectors.cpu().numpy())

# One label per vector, in the same (options rows cols) order as `vectors`.
# NOTE(review): the option order assumed here is blank, their, my; an earlier cell
# unbinds Neel's probe as blank, my, their — confirm.
labels = [probe_name for probe_name in ["blank", "their", "my"] for _ in full_board_labels]
hover_data = pd.DataFrame({
    "square": full_board_labels * 3,
    "probe": labels,
})

umap.plot.show_interactive(
    umap.plot.interactive(mapper, labels=labels, hover_data=hover_data, theme="inferno"))


### Using PCA

In [None]:
# Everywhere else in this notebook the first positional argument of `get_probe` is the
# index of the probe and the device is passed by keyword.
linear_probe = get_probe(0, device=device)

# %% Run PCA on the vectors of the probe
# One row per (option, square) vector
vectors = einops.rearrange(linear_probe, "d_model rows cols options -> (options rows cols) d_model")
plot_PCA(vectors, "the probe vectors")
# %% The same be per option
for i in range(3):
    plot_PCA(
        linear_probe[..., i],
        f"the probe vectors for option {i}",
        flip_dim_order=True,
        absolute=True,
    )

# %% Normalise the probe then run PCA
# Each probe vector lives along d_model (axis 0), so that is the axis to normalise over;
# the last axis only enumerates the three options.
normalised_probe = linear_probe / linear_probe.norm(dim=0, keepdim=True)
plot_PCA(normalised_probe, "the normalised probe vectors", flip_dim_order=True)

# %% Same PCA but with the unembeddings
# W_U is d_model x vocab: transpose to get one row per vector, as for the other calls
plot_PCA(model.W_U.T, "the unembeddings")

# %%
plot_PCA(model.W_pos, "the positional embeddings")
# %%
all_vectors = [
    model.W_U.T,
    model.W_E,
    model.W_pos,
    vectors,
]
# Standardise each family of vectors separately before mixing them in one PCA
all_vectors = [(v - v.mean(dim=0)) / v.std(dim=0) for v in all_vectors]

plot_PCA(t.cat(all_vectors, dim=0), "the embeddings and unembeddings")
# %%
# `my_direction` and `blank_direction` come from the "Exploration of the probe" cell
plot_PCA(t.cat([my_direction, blank_direction], dim=1).flatten(1).T, "the direction vectors")
# %%
plot_PCA(my_direction.flatten(1).T, "the my-vs-their direction vectors")
# %%
plot_PCA(blank_direction.flatten(1).T, "the blank direction vectors")


## Training probes

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from probes import ProbeTrainingArgs, LitLinearProbe, PROBE_DIR

In [None]:
# Set to True to call `make_training_data()`; otherwise `get_training_data()` is used.
MAKE_NEW_TRAINING_DATA = False
if MAKE_NEW_TRAINING_DATA:
    games_tokens, games_valid_moves = make_training_data()
else:
    games_tokens, games_valid_moves = get_training_data()

# Board indices and "mine-their" board states derived from the tokens.
games_board_index = tokens_to_board(games_tokens)
games_states = move_sequence_to_state(games_board_index, mode="mine-their")

In [None]:
# 1000 validation games with a fixed seed; `valid_tokens` is read as a global by `train_probe`.
valid_tokens, _ = generate_training_data(1_000, seed=69)
valid_states = move_sequence_to_state(tokens_to_board(valid_tokens), mode="mine-their")

### Compute stats on the dataset

In [None]:
# Load the statistics from `STATS_PATH` when the file exists, compute them otherwise.
# NOTE(review): nothing in this cell writes `STATS_PATH`; presumably `compute_stats` does — confirm.
if not STATS_PATH.exists():
    stats = compute_stats(games_states, games_valid_moves)
else:
    stats = t.load(STATS_PATH)
    print("Stats shape:", stats.shape)

# Names of the entries along the first axis of `stats` (used as facet labels below).
stat_names = ["Empty", "My piece", "Their piece", "Valid move"]


In [None]:
# Average over axis 1 of `stats` (the move axis, as the per-move plots below suggest): one board per statistic.
plot_square_as_board(
    stats.mean(1),
    title="Average frequency of each cell being ...",
    facet_col=0,
    facet_labels=stat_names,
)



In [None]:
# Plot per move
# The first statistic ("Empty") is skipped; the rest are averaged over the board (axes 2 and 3).
lines = stats[1:].mean((2, 3)).T

df = pd.DataFrame(lines.tolist(), columns=stat_names[1:])
df["Move"] = df.index
# Add proportion of my pieces
df["Proportion of my pieces"] = df["My piece"] / (df["My piece"] + df["Their piece"])
# Long format: one row per (move, cell type), as expected by px.line with `color`
df = df.melt(id_vars=["Move"], var_name="Cell type", value_name="Frequency")
px.line(
    df,
    x="Move",
    y="Frequency",
    color="Cell type",
    title="Frequency of each cell being ...",
    labels={
        "value": "Frequency",
        "index": "Move"
    },
    # legend=stat_names,
)

In [None]:
# %% Plot stats per cell and move
# Disabled with `if 0`: change the condition to render the large grid of boards.
if 0:
    moves_to_show = [0, 5, 10, 20, 30, 40, 50, 55]
    # One board per (move, statistic), statistics varying fastest
    x = einops.rearrange(stats[:, moves_to_show], "option m r c -> (m option) r c")
    labels = [f"{name} (move {move})" for move in moves_to_show for name in stat_names]

    plot_square_as_board(
        x,
        facet_col=0,
        facet_col_wrap=4,
        facet_labels=labels,
        title="Average frequency of each cell being ...",
        height=3000,
    )


### Actual Training

In [None]:
import dataclasses


def train_probe(args: ProbeTrainingArgs, *orthogonal):
    """
    Train a linear probe, log the run to Weights & Biases, save it and plot its accuracy.

    Args:
        args: training configuration. If `args.probe_name` still equals the
            class-level default, it is replaced IN PLACE by
            `probe-<len(orthogonal)>`, followed by `-bw` when
            `args.black_and_white` is set and `-unbiased` when
            `args.correct_for_dataset_bias` is set.
        *orthogonal: forwarded to `LitLinearProbe` after `args`.
            NOTE(review): presumably the already-trained probes the new one has
            to be orthogonal to; confirm in probes.py.

    Returns:
        The trained `lit_ortho_probe.linear_probe`.

    Notes:
        - Reads the notebook globals `model` and `valid_tokens` (the plot uses
          the global `valid_tokens`, not `args.valid_tokens`).
        - An existing file at `PROBE_DIR / f"{args.probe_name}.pt"` is never
          overwritten; a warning is printed instead.
    """
    if args.probe_name == ProbeTrainingArgs.probe_name:
        args.probe_name = f'probe-{len(orthogonal)}'
        if args.black_and_white:
            args.probe_name += '-bw'
        if args.correct_for_dataset_bias:
            args.probe_name += '-unbiased'

    lit_ortho_probe = LitLinearProbe(model, args, *orthogonal)

    # Close any run that an earlier call may have left open.
    wandb.finish()
    logger = WandbLogger(save_dir=os.path.join(os.getcwd(), "logs"), project='orthogonal-probes')
    config = dataclasses.asdict(args)
    # The token tensors are data, not hyperparameters: keep them out of the logged config.
    del config['train_tokens']
    del config['valid_tokens']
    logger.log_hyperparams(config)
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        logger=logger,
        log_every_n_steps=5,
        val_check_interval=100,
        check_val_every_n_epoch=None,
    )

    try:
        trainer.fit(model=lit_ortho_probe)
        probe = lit_ortho_probe.linear_probe

        # Persist the probe BEFORE plotting, so that an error in the plotting
        # code cannot throw away the result of the training run.
        path = PROBE_DIR / f"{args.probe_name}.pt"
        if not path.exists():
            t.save(probe, path)
            print(f"Saved probe to {path.resolve()}")
        else:
            print(f"Warning: {path.resolve()} already exists. Not saving the probe.")

        plot_aggregate_metric(
            valid_tokens,
            model,
            probe,
            # per_option=True,
            per_move="board_accuracy",
        )
    finally:
        # Always close the W&B run, including when training or plotting raises.
        wandb.finish()

    return probe

In [None]:
# Fetch an existing probe; it is passed to `train_probe` as an extra positional
# argument in the next cell.
# NOTE(review): the meaning of the index 0 and of base_name='probe-0.9' is
# defined by probes.get_probe, which is not visible here. Confirm.
probe = get_probe(0, device=device, base_name='probe-0.9')

In [None]:
# Train a new probe, handing the previously loaded `probe` to `train_probe` as
# its extra positional argument.
# The second positional argument is `val_tokens` (the 100-game set generated
# near the top of the notebook), not the 1,000-game `valid_tokens`.
probe2 = train_probe(
    ProbeTrainingArgs(
        games_tokens,
        val_tokens,
        lr=0.001,
        batch_size=600,
        max_epochs=12,
    ),
    probe,
)

## Constructing residual stream

In [None]:
def make_residual_stream(
    world: Union[Int[Tensor, "row=8 cols=8"], str],
    probe: Float[Tensor, "d_model rows cols options=3"],
) -> Float[Tensor, "d_model"]:
    """
    Create the embedding of a board state according to the probe.

    Every square equal to 0 contributes its "blank" direction
    (blank minus the mean of mine and theirs); every other square contributes
    its "mine minus theirs" direction multiplied by the square's value, so a
    square holding -1 contributes the opposite direction.

    Args:
        world: the board state, with 0 for blank, +1 for mine, -1 for theirs.
            A string is first converted with `board_to_tensor`.
        probe: directions in the residual stream that correspond to each square.
            The last dimension is the options, with 0 for blank, 1 for mine, 2 for theirs.

    Returns:
        The sum of the per-square contributions, a vector of size `d_model`
        on the probe's device and in the probe's dtype.
    """

    if isinstance(world, str):
        world = board_to_tensor(world)

    # Per-square directions, each of shape (d_model, rows, cols).
    blank_direction = probe[..., 0] - (probe[..., 1] + probe[..., 2]) / 2
    my_direction = probe[..., 1] - probe[..., 2]

    world = world.to(probe.device)
    # Only the squares that exist in `world` contribute, exactly as when
    # iterating over `world.shape`.
    n_rows, n_cols = world.shape
    blank_direction = blank_direction[:, :n_rows, :n_cols]
    my_direction = my_direction[:, :n_rows, :n_cols]

    # Vectorised form of the per-square sum: the blank mask selects the blank
    # direction, and the signed board value scales the mine/theirs direction
    # (it is 0 on blank squares, so the two terms never overlap).
    is_blank = (world == 0).to(probe.dtype)
    signed_world = world.to(probe.dtype)
    per_square = blank_direction * is_blank + my_direction * signed_world

    return per_square.sum(dim=(1, 2))


# %% Try to run the model on a virtual residual stream


def hook(activation: Float[Tensor, "game move d_model"], hook: HookPoint):
    """
    Overwrite, in place, the last position of `activation` with the global `resid`.

    `resid` is looked up when the hook runs, not when it is defined; in this
    notebook it is assigned in a later cell, so that cell has to be executed
    before a model run uses this hook. The `hook` argument is unused.
    """
    last_position = -1
    activation[:, last_position] = resid


# Run the model on the first 20 moves of the first focus game while `hook`
# overwrites the last position of the residual stream entering block `layer`.
# NOTE(review): `hook` reads the global `resid`, which is only assigned in the
# following cell; on a fresh kernel that cell has to be run before this code.
layer = 4
act_name = utils.get_act_name("resid_pre", layer)
osef_input = focus_games_tokens[:1, :20]  # 1 game, 20 moves
logits = model.run_with_hooks(osef_input, fwd_hooks=[(act_name, hook)])

# Plot what the model predicts
# (`logits` is rebound here: from the raw model output to the board-shaped
# values for the last position of the single game)
logits = logits_to_board(logits[0, -1], "log_prob")
plot_square_as_board(logits, title="Model predictions")


In [None]:
# Hand-written 8x8 boards in the text format accepted by `board_to_tensor`.
# NOTE(review): presumably "." is an empty square and "x" / "o" are the two
# players; confirm against board_to_tensor.
# Only `board_3` is used below; `board` and `board_2` are alternatives.
board = """
........
........
........
...xo...
...ox...
........
........
........
"""

board_2 = """
........
........
........
..xxxx..
.xooooo.
.o..ox..
.x.oxo..
........
"""
board_3 = """
........
........
........
...xo...
...oo...
....o...
........
........
"""

# Build the residual-stream vector that the global `hook` writes into the model.
# NOTE(review): `linear_probe` is not assigned in the cells visible around here
# (they define `probe` and `probe2`); confirm which probe is intended.
board_tensor = board_to_tensor(board_3)
resid = make_residual_stream(board_tensor, linear_probe)
# plot_square_as_board(board_tensor)


In [None]:
@t.inference_mode()
def modify_resid_given_probe(
        model: HookedTransformer,
        moves_orig: Int[Tensor, "move"],
        moves_new: Int[Tensor, "move"],
        *probes: Float[Tensor, "d_model rows cols options=3"],
        layer: int = 6,
        cells: Tuple[str, ...] = (),
):
    """
    Patch the probe subspace of one game with that of another and plot the effect.

    The residual stream entering block `layer` is cached for `moves_new`. While
    the model runs on `moves_orig`, the last position of that same activation is
    replaced by `swap_subspace(orig, new, probe_vectors)`.
    NOTE(review): `swap_subspace` is defined elsewhere; presumably it swaps in
    the component of `new` that lies in the span of the probe vectors. Confirm.

    Args:
        model: the transformer to run.
        moves_orig: tokens of the game being patched. Despite the annotation,
            it is used as a 2-D (game, move) batch whose first game is plotted.
        moves_new: tokens of the game providing the replacement, same layout.
        *probes: probe tensors; they are stacked along a new last axis and
            flattened into one list of `d_model` vectors.
        layer: index of the block whose `resid_pre` activation is patched.
        cells: board labels (e.g. "D2"). When non-empty, only the probe vectors
            of these squares are used, and the "expected" moves of the patched
            board are computed from the original board with these squares
            copied over from the new board.

    Returns:
        None. Six boards are plotted through `plot_square_as_board`.
    """
    act_name = utils.get_act_name("resid_pre", layer)

    def is_patched_activation(name):
        return name == act_name

    new_logits, new_cache = model.run_with_cache(
        moves_new,
        names_filter=is_patched_activation,
    )

    def hook(orig_activation: Float[Tensor, "game move d_model"], hook: HookPoint):
        # Gather every probe direction into one (n_vectors, d_model) matrix.
        stacked = t.stack(probes, dim=-1).to(orig_activation.device)
        if cells:
            coords = t.tensor([board_label_to_row_col(label) for label in cells])
            stacked = stacked[:, coords[:, 0], coords[:, 1]]

        probe_vectors = einops.rearrange(stacked, "d_model ... -> (...) d_model")

        # Only the last position is patched.
        orig_last = orig_activation[:, -1]
        new_last = new_cache[act_name][:, -1]
        orig_activation[:, -1] = swap_subspace(orig_last, new_last, probe_vectors)

    patched_logits = model.run_with_hooks(
        moves_orig,
        fwd_hooks=[(act_name, hook)],
    )

    # Valid-move maps to display next to the model's predictions.
    orig_valid_moves = move_sequence_to_state(tokens_to_board(moves_orig), mode="valid")
    if not cells:
        new_valid_moves = move_sequence_to_state(tokens_to_board(moves_new), mode="valid")
    else:
        coords = [board_label_to_row_col(label) for label in cells]
        index = tuple(zip(*coords))

        # Final board of each game; the selected squares of the new board are
        # then written into the original board.
        new_board_state = move_sequence_to_state(tokens_to_board(moves_new), mode="normal")[0, -1]
        orig_board_state = move_sequence_to_state(tokens_to_board(moves_orig), mode="normal")[0, -1]
        orig_board_state[index] = new_board_state[index]

        valid_cells = valid_moves_from_board(orig_board_state, moves_orig.shape[1])
        new_valid_moves = one_hot(valid_cells).reshape(1, 1, 8, 8)

    # Board-shaped log-probabilities at the last position of the first game.
    orig_logits = logits_to_board(model(moves_orig)[0, -1], "log_prob")
    patched_logits = logits_to_board(patched_logits[0, -1], "log_prob")
    new_logits = logits_to_board(new_logits[0, -1], "log_prob")

    # The valid-move maps are multiplied by this so they share the logits' range.
    scale = new_logits.abs().max().cpu()

    panels = [
        orig_valid_moves[0, -1] * scale,
        orig_logits,
        patched_logits,
        new_logits,
        new_valid_moves[0, -1] * scale,
        patched_logits - orig_logits,
    ]
    all_logits = t.stack([panel.cpu() for panel in panels], dim=-1)

    # NOTE(review): the labels below are not in the same order as `panels`.
    # Whether that is right depends on how plot_square_as_board assigns labels
    # to wrapped facets; verify against the rendered figure.
    plot_square_as_board(
        all_logits,
        title="Model predictions",
        facet_col=-1,
        facet_col_wrap=3,
        facet_labels=[
            "New expected",
            "new logits",
            "logit diff (patch - orig)",
            "original expected",
            "original logits",
            "patched logits",
        ],
    )

    # plot_square_as_board(logits_to_board(new_logits[0, -1], 'log_prob'),
    #                      title="Model predictions (new)")
    # plot_square_as_board(logits_to_board(patched_logits[0, -1], 'log_prob'),
    #                      title="Model predictions (patched)")

    # with model.hooks(fwd_hooks=[(act_name, hook)]):
    #     plot_board_log_probs(
    #         tokens_to_board(moves_new[0]),
    #         patched_logits[0],
    #     )


# Patch focus game `orig_index` with information from focus game `new_index`,
# both truncated to their first `move_index` moves, at block `layer`, using
# only the probe vectors of square "D2".
# NOTE(review): `probes` is not assigned in the cells visible around here;
# confirm where it comes from.
orig_index = 2
new_index = 3
move_index = 20
layer = 4
# Slicing with i:i + 1 keeps the leading game axis (a batch of one game).
orig_games = focus_games_tokens[orig_index:orig_index + 1, :move_index]
new_games = focus_games_tokens[new_index:new_index + 1, :move_index]

modify_resid_given_probe(model, orig_games, new_games, *probes, layer=layer, cells=["D2"])


In [None]:
# Show the two games used in the patching experiment, truncated to the same
# `move_index` moves as the tokens passed to `modify_resid_given_probe`.
plot_single_board(focus_games_board_index[orig_index, :move_index], title="Original game")
plot_single_board(focus_games_board_index[new_index, :move_index], title="New game")