# Exploration of OthelloGPT

In [None]:
print(1 + 1)

In [None]:
%load_ext autoreload
%autoreload 2

import os

os.environ["ACCELERATE_DISABLE_RICH"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from functools import partial
from pathlib import Path
from typing import Dict, Literal, Optional, Tuple, Union

from rich import print as rprint
from circuitsvis.attention import attention_patterns
import einops
import pandas as pd
import plotly.express as px
import torch as t
import transformer_lens.utils as utils
from dataclasses import dataclass, asdict
import wandb
from jaxtyping import Float, Int
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint

from plotly_utils import imshow

import circuitsvis

import utils as u
import circuit
from utils import *
from plotting import *
from probes import get_probe, get_neels_probe, ProbeTrainingArgs, LitLinearProbe

from othello_world.mechanistic_interpretability.mech_interp_othello_utils import (
    plot_single_board, )




## Setup

Cells that should always be run first (model, device and sample data).

In [None]:
t.set_grad_enabled(False)
device = "cuda" if t.cuda.is_available() else "cpu"

cfg, model = get_othello_gpt(device)

In [None]:
val_tokens, val_valid = generate_training_data(100, 69)

In [None]:
# %% Loading sample data
full_games_tokens, full_games_board_index = load_sample_games()
num_games = 50
focus_games_tokens = full_games_tokens[:num_games]
focus_games_board_index = full_games_board_index[:num_games]
assert (TOKENS_TO_BOARD[focus_games_tokens] == focus_games_board_index).all()

focus_states = move_sequence_to_state(focus_games_board_index)
focus_valid_moves = move_sequence_to_state(focus_games_board_index, mode="valid")

print("focus states:", focus_states.shape)
print("focus_valid_moves", focus_valid_moves.shape)


## Exploring Othello

In [None]:
states = move_sequence_to_state(tokens_to_board(val_tokens[:1]), mode="black-white")
print(states.shape)
(circuit.Kuit(model, states[0], ["move", "row", "col"])
    .new_dim('facet')
    .by('move', 'facet')
    .plot()
)


## Exploration of the model

In [None]:
kuit = circuit.Kuit(model)

### Token embeddings

In [None]:
(kuit.pos_embedding(True)
     .plot(title="Normalised positional embedding matrix")
)
(kuit.embedding(True)
     .plot(title="Normalised token embedding matrix")
     .flatten()
     .histogram(title="Histogram of normalised token embeddings values")
)
(kuit.embedding()
     .norm('dmodel')
     [1:]  # skip the first one, which is big
     .histogram(title="Histogram of token embeddings norms", nbins=50)
)

The token embeddings look roughly one-hot encoded.

In [None]:
# For each d_model dimension, count how many tokens have |normalised embedding| > threshold
threshold = 0.2
(kuit.embedding(True)
     (lambda x: x.abs() > threshold)
     .sum('vocab')
     .line(title=f"Number of tokens with embedding value > {threshold}")
)


In [None]:
(kuit.pos_embedding(True)
     .pos_embedding(True)
     .plot(title="Cosine similarity between positional embeddings")
     ['pos_1', ::2]
     ['pos_2', ::2]
     .plot(title="Cosine similarity between even positional embeddings")
)

In [None]:
(kuit.embedding(True)
     .embedding(True)
     .remove_diag()
     .tokens_to_board('vocab_2', 0)
     .new_dim('facet')
     .by('vocab_1', 'facet')
     .plot(title="Cosine similarity between token embeddings",
           height=600)
)

### QK circuit of first layer

In [None]:
# (kuit
#     .embedding()
#     .new_dim('pos')
#     .add(kuit.pos_embedding())
#     .qk(0, key=kuit.embedding(), softmax=True)
#     .by("pos", "head")
# ).plot()

(kuit
    .pos_embedding()
    .normalise('dmodel')
    .qk(0, key=kuit
        .pos_embedding()
        .normalise('dmodel')
    )
    .softmax()
        (lambda x: x * t.arange(1, 60, device=device)[:, None].float())
    .by('head')
    .plot(title="Attention score from positional embeddings matching")
)
    

(kuit
    .embedding()
    .new_dim('pos')
    .add(kuit.pos_embedding())
    .norm(dim='dmodel')
    ['vocab', 1:]
    ['pos', 2:]
    .plot(title="Distribution of the norm of the positional embeddings + token embeddings",
          height=800)
    .flatten('vocab', 'pos')
    .histogram(title="Distribution of the norm of the positional embeddings + token embeddings")
)



# (C
#     .embedding()
#     .new_dim('pos')
#     .add(C.pos_embedding())
#     .qk(0, key=C
#         .embedding())
#     .softmax()
# ).plot('head')

In [None]:
from itertools import product

pos = kuit.pos_embedding()
pos.name = "pos"
emb = kuit.embed(val_tokens[0, :59], positional=False)
emb.name = "emb"

# `token` rather than `t`, so the torch alias is not visually shadowed.
move_labels = [to_board_label(tokens_to_board(token)) for token in val_tokens[0, :59]]

# Run every (key, query) pairing of positional / token embeddings through the
# layer-0 QK circuit. The receiver of `.qk` is the query side and `key=` is the
# key side, so the chain must start from `query`, not from `key`.
for key, query in product([pos, emb], [pos, emb]):
    (query.normalise('dmodel')
        .qk(0, key=key.normalise('dmodel'))
        .softmax()
        # Multiply query row i by i + 1. Presumably this is the number of
        # positions a causal query at i can attend to, so that uniform
        # attention reads as 1.
        (lambda x: x * t.arange(1, 60, device=device)[:, None].float())
        .by('head')
        .plot(facet_col_wrap=4,
              title=f"Attention score from {key.name} matching {query.name}",
              x=move_labels,
              y=move_labels,
        )
    )
        

In [None]:
num_components = 15

pos_emb = model.W_pos[1:]
pos_emb_normalised = pos_emb / pos_emb.norm(dim=-1, keepdim=True)

pca = plot_PCA(pos_emb, name="Positional embeddings")
# Keep the first `num_components` PCA directions. `components_` is a NumPy
# array whose dtype may differ from the embeddings' (e.g. float64). `@` does
# not mix float64 and float32 tensors, so cast explicitly.
pca_directions = t.tensor(pca.components_[:num_components], device=device, dtype=pos_emb.dtype)

# Show the dot product between the PCA directions and the unit-norm positional embeddings
imshow(pca_directions @ pos_emb_normalised.T,
       xaxis_title="Position",
       yaxis_title="PCA direction index",
       title="Dot product between PCA directions and positional embeddings")

### Other QK circuits

In [None]:
layer = 7
(kuit
    .embedding(True)
    .qk(layer, key=kuit.embedding(True))
    .tokens_to_board('vocab_k', 0)
    .print()
    .by('vocab_q', 'head')
    .plot(title="Attention score from each embedding as query, to each embedding as key",
          height=900, facet_col_wrap=4)
)


### Emb - OV - Unemb circuit

In [None]:
(circuit.Kuit(model)
    .embedding()
    .normalise('dmodel')
    .print()
    .ov(0)
    .print()
    .unembed()
    # .unembed_bias()
    .by('head')
    .plot(facet_col_wrap=4)
)

In [None]:
(Circuit(model)
    .unembed_bias()
    .tokens_to_board(fill=0)
    .plot(title="Unembedding bias"))

### Accuracy and failures of the model

In [None]:
model_metrics = ModelMetricPlotter(val_tokens, model)

In [None]:
model_metrics.plot_loss_per_move(True)
model_metrics.plot_loss_per_move()
model_metrics.plot_whole_board_accuracy()

In [None]:
val_valid.shape

In [None]:
def explain_game(
    tokens: Int[Tensor, 'move'],
    model: HookedTransformer,
    move: Optional[int] = None,
    # Extra args
    extra_boards: Optional[Dict[str, Tensor]] = None,
) -> None:
    """Plot the board, the valid moves and the model's predictions after a move.

    Args:
        tokens: token ids of a single game. A Python list is also accepted.
        model: the model to run on the game prefix.
        move: index of the last move fed to the model. Defaults to the last
            move in `tokens`; it is taken modulo the game length, so negative
            values count from the end. A 0-d tensor is accepted.
        extra_boards: extra facets to plot, as label -> board. Boolean boards
            are scaled to the logits' range so they are visible. The caller's
            dict is not modified.
    """
    if isinstance(tokens, list):
        tokens = t.tensor(tokens, dtype=t.long)
    n_moves = tokens.shape[0]
    if move is None:
        move = n_moves - 1
    # int() so that a 0-d tensor index prints cleanly in the title
    move = int(move) % n_moves

    tokens = tokens[None, :move+1].to(model.cfg.device)
    board_indices = tokens_to_board(tokens)

    logits = model(tokens)

    scale = logits.abs().max().item() * 0.5
    predictions = logits_to_board(logits, 'log_prob')
    state = move_sequence_to_state(board_indices, 'mine-their') * scale
    valid = move_sequence_to_state(board_indices, mode="valid") * scale

    # Scale boolean extra boards. Build a new dict instead of writing back
    # into the one the caller passed in.
    extra_boards = {
        label: board * scale if board.dtype == t.bool else board
        for label, board in (extra_boards or {}).items()
    }

    # Plotting. Move every board to the CPU first: they may come from
    # helpers that return tensors on different devices.
    boards = t.stack([
        board.cpu()
        for board in (
            state[0, -1],
            valid[0, -1],
            predictions[0, -1],
            *extra_boards.values(),
        )
    ], dim=0)
    plot_square_as_board(
        boards,
        facet_col=0,
        facet_col_wrap=3 if len(boards) != 4 else 2,
        facet_labels=["State before", "Valid moves", "Model logits", *extra_boards.keys()],
        title=f"Game after {to_board_label(board_indices[0, -1].item())} - blue to play - move {move}",
    )


explain_game([20, 19], model)
explain_game(val_tokens[3], model, 0)

In [None]:
@t.inference_mode()
def find_fail_datapoints(
    model: HookedTransformer,
    tokens: Int[Tensor, 'game move'],
    threshold: float = 3,
    nb_examples: int = 10,
    biggest_mistakes: bool = True,
) -> None:
    """
    Find datapoints where the model fails to predict the valid moves.

    Plots the board (via `explain_game`) for each selected datapoint, with
    an extra facet marking the cells the model got wrong.

    A cell counts as "predicted valid" when its probability exceeds
    1 / (threshold * number_of_valid_moves).

    Args:
        model: the model to evaluate.
        tokens: games to evaluate. Only the first 59 moves are used.
        threshold: controls the probability cut-off described above.
        nb_examples: maximum number of datapoints to plot.
        biggest_mistakes: if True, plot the positions with the largest error
            (and show a histogram of errors). Otherwise plot a random sample
            of positions where at least one cell is misclassified.
    """
    # Get the probabilities for each cell
    tokens = tokens[:, :59].to(model.cfg.device)
    n_moves = tokens.shape[1]
    logits = model(tokens)
    probabilities = logits_to_board(logits, 'prob')  # [game, move, row, col]

    # Compute the number of valid moves
    valid_moves = move_sequence_to_state(tokens_to_board(tokens), mode="valid")[:, :n_moves]
    valid_moves = valid_moves.to(probabilities.device)
    nb_valid_moves = valid_moves.sum(dim=(-1, -2), keepdim=True)  # [game, move, 1, 1]

    # Per-cell correctness. Computed for both modes because the plotting loop
    # below needs it (previously it only existed in the random-sample branch).
    predictions = probabilities > (1 / (threshold * nb_valid_moves))
    correct = predictions == valid_moves

    if biggest_mistakes:
        # Error of each cell relative to a uniform distribution over the valid moves
        scaled = probabilities * nb_valid_moves
        error = t.where(valid_moves, 1 - scaled, scaled)
        error = error.amax(dim=(-2, -1)).flatten()  # worst cell per [game * move]
        # plot histogram of errors
        display(px.histogram(error.cpu().numpy()))
        worst_flat = error.topk(min(nb_examples, error.numel())).indices
        # Decode flat indices with the real number of moves, not a hard-coded 59
        incorrect_indices = t.stack([worst_flat // n_moves, worst_flat % n_moves], dim=-1)
    else:
        correct_boards = correct.flatten(-2).all(dim=-1)
        # Find the indices of the incorrect moves, then sample from them
        incorrect_indices = t.nonzero(~correct_boards, as_tuple=False)
        incorrect_indices = incorrect_indices[t.randperm(len(incorrect_indices))]
        incorrect_indices = incorrect_indices[:nb_examples]

    # Plot the incorrect moves
    for game, move in incorrect_indices.tolist():
        explain_game(tokens[game],
                     model,
                     move,
                     extra_boards={
                         "Model's mistakes": ~correct[game, move],
                     })


find_fail_datapoints(
    model,
    val_tokens[:40, :30],
    threshold=2,
    nb_examples=10,
    biggest_mistakes=False,
)


## Attention

In [None]:
def plot_average_attention(model: HookedTransformer, tokens: Int[Tensor, "game move"],
                           *layers: int):
    """Display the attention pattern of every head in `layers`, averaged over games.

    Args:
        model: the model whose `pattern` hooks are recorded.
        tokens: a batch of games; only the first 59 moves are used.
        *layers: layers to show. All layers are shown when none are given.
    """
    model_cfg = model.cfg
    n_pos = min(tokens.shape[1], 59)
    mean_patterns = t.zeros(model_cfg.n_layers, model_cfg.n_heads, n_pos, n_pos,
                            device=model_cfg.device)

    # TransformerLens passes the hook point as the keyword `hook`, so that
    # parameter name must stay as is.
    def accumulate(pattern: Float[Tensor, "game head query key"], hook: HookPoint):
        mean_patterns[hook.layer()] += pattern.mean(0)

    def is_pattern_hook(hook_name):
        return 'pattern' in hook_name

    model.run_with_hooks(tokens[:, :59], fwd_hooks=[(is_pattern_hook, accumulate)])

    chosen_layers = list(layers) if layers else list(range(model_cfg.n_layers))

    per_head = einops.rearrange(mean_patterns[chosen_layers],
                                "layer head row col -> (layer head) row col")
    head_names = [
        f"Head {head} Layer {layer}"
        for layer in chosen_layers
        for head in range(model_cfg.n_heads)
    ]
    position_labels = [f" {i}" for i in range(n_pos)]
    display(circuitsvis.attention.attention_patterns(position_labels, per_head, head_names))

In [None]:
plot_average_attention(model, val_tokens[:1], 0)

In [None]:
def plot_head_attention(model: HookedTransformer, tokens: Int[Tensor, "game move"], 
                        layer: int,
                        head: int):
    """Return a circuitsvis view of one head's attention pattern for every game.

    The per-game patterns are stacked along the first axis, so circuitsvis
    shows each game where it would normally show a head.

    Args:
        model: the model to run.
        tokens: a batch of games; only the first 59 moves are used.
        layer: layer of the head.
        head: index of the head within `layer`.
    """
    n_pos = min(tokens.shape[1], 59)
    per_game = t.zeros(tokens.shape[0], n_pos, n_pos, device=model.cfg.device)

    # Written through `[...]` so the closure updates the outer tensor in place.
    def grab_head(pattern: Float[Tensor, "game head query key"], hook: HookPoint):
        per_game[...] += pattern[:, head]

    pattern_hook_name = utils.get_act_name('pattern', layer)
    model.run_with_hooks(tokens[:, :59], fwd_hooks=[(pattern_hook_name, grab_head)])

    position_labels = [f" {i}" for i in range(n_pos)]
    return circuitsvis.attention.attention_patterns(position_labels, per_game)

plot_head_attention(model, val_tokens[:12], 0, 4)

In [None]:
layer = 1
name = utils.get_act_name('pattern', layer)
_, cache = model.run_with_cache(val_tokens[:1, :59], names_filter=lambda n: n == name)
pattern = cache[name][0]  # pattern of the first game, indexed [head, pos_q, pos_k]
print(pattern.shape)

# The original expression opened a parenthesis that was never closed (syntax error).
(circuit.Kuit(model, pattern, ['head', 'pos_q', 'pos_k'])
    .by('head')
    .plot(title=f"Attention patterns of layer {layer}", facet_col_wrap=4)
)


## Ablation Study

In [None]:
individual_heads = False

n_games = 50
tokens = full_games_tokens[-n_games:].to(device)
board_index = full_games_board_index[-n_games:].to(device)
get_metrics = lambda model: get_loss(model, tokens, board_index, 5, -5).to_tensor()
zero_ablation_metrics = zero_ablation(model, get_metrics, individual_heads).cpu()
base_metrics = get_metrics(model)

In [None]:
# Plotting the results
x_labels = [f"Head {i}" for i in range(model.cfg.n_heads)] + ["All Heads", "MLP"]
y_labels = [f"Layer {i}" for i in range(model.cfg.n_layers)]
if not individual_heads:
    x_labels = x_labels[-2:]

imshow(
    zero_ablation_metrics[:3] - base_metrics[:3, None, None],
    title="Metric increase after zeroing each component",
    x=x_labels,
    y=y_labels,
    facet_col=0,
    facet_labels=["Loss", "Cell accuracy", "Board accuracy"],
)

In [None]:
# Ablate the attention output of every layer from `start_layer` onwards

def filter(name: str, start_layer: int = 0):
    """Hook-name predicate: True for `blocks.{i}.*attn_out*` hooks with i >= start_layer.

    NOTE(review): this name shadows the builtin `filter` in the notebook namespace.
    """
    if name.startswith("blocks."):
        block_index = int(name.split(".")[1])
        return "attn_out" in name and block_index >= start_layer
    # Non-block hooks: 'hook_embed', 'hook_pos_embed', 'ln_final.hook_scale', 'ln_final.hook_normalized'
    return False

metrics_per_layer = []
for start_layer in range(model.cfg.n_layers):
    with model.hooks(fwd_hooks=[(partial(filter, start_layer=start_layer),
                                    zero_ablation_hook)]):
        metrics_per_layer.append(get_metrics(model))

In [None]:
# %% Plot
lines = t.stack(metrics_per_layer, dim=1).cpu()
line(
    lines[:3],
    x=[f"≥ {i}" for i in range(model.cfg.n_layers)],
    facet_col=0,
    #  facet_col_wrap=3,
    facet_labels=[
        "Loss",
        "Cell accuracy",
        "Board accuracy",
    ],  # 'False Positive', 'False Negative', 'True Positive', 'True Negative'],
    title="Metrics after zeroing all attention heads above a layer",
)

## Exploration of the probe

In [None]:
neel_probe = get_neels_probe(False, device)

blank_probe, my_probe, their_probe = neel_probe.unbind(dim=-1)
blank_direction = blank_probe - (their_probe + my_probe) / 2
my_direction = my_probe - their_probe

### Accuracy and cosine similarity

In [None]:
plot_aggregate_metric(
    val_tokens,
    model,
    neel_probe,
    # per_option=True,
    # per_move="cell_accuracy",
    # per_move="board_accuracy",
    name="Neel's probe",
    # prediction="softmax",
    # prediction='logprob',
);

In [None]:
for probe, name in zip([blank_probe, their_probe, my_probe], ["blank", "their", "my"]):
    probe = einops.rearrange(probe, "d_model rows cols -> (rows cols) d_model")
    plot_similarities(probe,
                      title=f"Similarity between {name} vectors",
                      x=full_board_labels,
                      y=full_board_labels)

In [None]:
# %% Similarity between mine and theirs (for each square)
plot_similarities_2(my_probe, their_probe)

In [None]:
probes = [get_probe(i, device=device) for i in range(1)]

# Same call shape as the `plot_aggregate_metric(val_tokens, model, neel_probe, ...)`
# cell above. The previous `plot_agreggate_metric(model, probe, tokens, board_index, ...)`
# call used a misspelled name and a stale argument order.
for probe, name in zip(probes, ["new probe", "orthogonal probe", "orthogonal probe 2"]):
    plot_aggregate_metric(
        full_games_tokens[-100:],
        model,
        probe.to(device),
        per_option=True,
        name=name,
    )


Similarities between the blank probe and the token embeddings

In [None]:
token_embs = model.W_E[1:]
token_embs_64 = t.zeros((64, token_embs.shape[1]), device=token_embs.device)
token_embs_64[TOKENS_TO_BOARD] = token_embs
print(token_embs.shape)
token_embs_64 = einops.rearrange(token_embs_64, "(rows cols) d_model -> d_model rows cols", rows=8)

plot_similarities_2(
    neel_probe[..., 0],
    token_embs_64,
    name="Blank probe and token embeddings",
)

Compute and show probe vector norms

In [None]:
probe_norm = neel_probe.norm(dim=0)
px.histogram(probe_norm.cpu().flatten(), title="Probe vector norms", labels={"value": "norm"})

### Using UMAP

In [None]:

import umap
import umap.plot

# NOTE(review): `linear_probe` is loaded in the "Using PCA" cell further down;
# run that cell first.
vectors = einops.rearrange(linear_probe,
                            "d_model rows cols options -> (options rows cols) d_model")

mapper = umap.UMAP(metric="cosine").fit(vectors.cpu().numpy())

# The options axis is ordered blank / mine / theirs (see the
# `neel_probe.unbind(dim=-1)` split and the `make_residual_stream` docstring).
# The rearrange above puts options outermost, so the labels follow that order.
option_names = ["blank", "my", "their"]
labels = [probe_name for probe_name in option_names for _ in full_board_labels]
hover_data = pd.DataFrame({
    "square": full_board_labels * len(option_names),
    "probe": labels,
})

umap.plot.show_interactive(
    umap.plot.interactive(mapper, labels=labels, hover_data=hover_data, theme="inferno"))


### Using PCA

In [None]:
# Same call convention as the other `get_probe(index, device=...)` calls in
# this notebook. `get_probe(device)` passed the device as the probe index.
linear_probe = get_probe(0, device=device)

# %% Run PCA on the vectors of the probe
vectors = einops.rearrange(linear_probe, "d_model rows cols options -> (options rows cols) d_model")
plot_PCA(vectors, "the probe vectors")
# %% The same, but per option
for i in range(3):
    plot_PCA(
        linear_probe[..., i],
        f"the probe vectors for option {i}",
        flip_dim_order=True,
        absolute=True,
    )

# %% Normalise each probe vector, then run PCA.
# The vectors live along d_model (dim 0, as in the rearrange above), not along
# the options axis (dim -1).
normalised_probe = linear_probe / linear_probe.norm(dim=0, keepdim=True)
plot_PCA(normalised_probe, "the normalised probe vectors", flip_dim_order=True)

# %% Same PCA but with the unembeddings
plot_PCA(model.W_U, "the unembeddings")

# %% PCA of the positional embeddings
plot_PCA(model.W_pos, "the positional embeddings")

# %% Joint PCA. Each set is standardised per feature first, so no set dominates.
all_vectors = [
    model.W_U.T,
    model.W_E,
    model.W_pos,
    vectors,
]
all_vectors = [(v - v.mean(dim=0)) / v.std(dim=0) for v in all_vectors]

plot_PCA(t.cat(all_vectors, dim=0), "the embeddings, unembeddings and probe vectors")
# %% PCA of the derived "my" and "blank" directions
plot_PCA(t.cat([my_direction, blank_direction], dim=1).flatten(1).T, "the my and blank direction vectors")
plot_PCA(my_direction.flatten(1).T, "the my direction vectors")
plot_PCA(blank_direction.flatten(1).T, "the blank direction vectors")


## Training probes

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from probes import ProbeTrainingArgs, LitLinearProbe, PROBE_DIR

In [None]:
# Set to True to generate a fresh training set instead of loading the saved one.
MAKE_NEW_TRAINING_DATA = False

if MAKE_NEW_TRAINING_DATA:
    games_tokens, games_valid_moves = make_training_data()
else:
    games_tokens, games_valid_moves = get_training_data()

In [None]:
games_board_index = tokens_to_board(games_tokens)
games_states = move_sequence_to_state(games_board_index, mode="mine-their")

In [None]:
valid_tokens, _ = generate_training_data(1_000, seed=69)
valid_states = move_sequence_to_state(tokens_to_board(valid_tokens), mode="mine-their")

### Compute stats on the dataset

In [27]:
if not STATS_PATH.exists():
    stats = compute_stats(games_states, games_valid_moves)
else:
    stats = t.load(STATS_PATH)
    print("Stats shape:", stats.shape)

stat_names = ["Empty", "My piece", "Their piece", "Valid move"]


Stats shape: torch.Size([4, 60, 8, 8])


In [None]:
plot_square_as_board(
    stats.mean(1),
    title="Average frequency of each cell being ...",
    facet_col=0,
    facet_labels=stat_names,
)



In [None]:
# Plot per move
lines = stats[1:].mean((2, 3)).T

df = pd.DataFrame(lines.tolist(), columns=stat_names[1:])
df["Move"] = df.index
# Add propotion of my pieces
df["Proportion of my pieces"] = df["My piece"] / (df["My piece"] + df["Their piece"])
df = df.melt(id_vars=["Move"], var_name="Cell type", value_name="Frequency")
px.line(
    df,
    x="Move",
    y="Frequency",
    color="Cell type",
    title="Frequency of each cell being ...",
    labels={
        "value": "Frequency",
        "index": "Move"
    },
    # legend=stat_names,
)

In [None]:
# %% Plot stats per cell and move
if 0:
    moves_to_show = [0, 5, 10, 20, 30, 40, 50, 55]
    x = einops.rearrange(stats[:, moves_to_show], "option m r c -> (m option) r c")
    labels = [f"{name} (move {move})" for move in moves_to_show for name in stat_names]

    plot_square_as_board(
        x,
        facet_col=0,
        facet_col_wrap=4,
        facet_labels=labels,
        title="Average frequency of each cell being ...",
        height=3000,
    )


### Actual Training

In [None]:
import dataclasses


def train_probe(args: ProbeTrainingArgs, *orthogonal):
    """Train a `LitLinearProbe` on the notebook's global `model`, then plot and save it.

    Args:
        args: training configuration. NOTE: `args.probe_name` may be
            overwritten in place (see below).
        *orthogonal: extra probes forwarded to `LitLinearProbe`. Only their
            count is used here, in the generated name.

    Returns:
        The trained `linear_probe` of the Lightning module.
    """
    # If the caller kept the class-level default name, derive one from the
    # number of orthogonal probes and the training flags (mutates `args`).
    if args.probe_name == ProbeTrainingArgs.probe_name:
        args.probe_name = f'probe-{len(orthogonal)}'
        if args.black_and_white:
            args.probe_name += '-bw'
        if args.correct_for_dataset_bias:
            args.probe_name += '-unbiased'

    lit_ortho_probe = LitLinearProbe(model, args, *orthogonal)

    # Close any W&B run left open so this training logs into a fresh one.
    wandb.finish()
    logger = WandbLogger(save_dir=os.getcwd() + "/logs", project='orthogonal-probes')
    config = dataclasses.asdict(args)
    # The token tensors are not logged as hyper-parameters.
    del config['train_tokens']
    del config['valid_tokens']
    logger.log_hyperparams(config)
    # val_check_interval=100 with check_val_every_n_epoch=None: validate every
    # 100 training batches, independently of epoch boundaries.
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        logger=logger,
        log_every_n_steps=5,
        val_check_interval=100,
        check_val_every_n_epoch=None,
    )

    trainer.fit(model=lit_ortho_probe)
    probe = lit_ortho_probe.linear_probe
    
    # Evaluate on the notebook-global `valid_tokens`.
    plot_aggregate_metric(
        valid_tokens,
        model,
        probe,
        # per_option=True,
        per_move="board_accuracy",
    )

    wandb.finish()

    # Never overwrite an existing probe file.
    path = PROBE_DIR / f"{args.probe_name}.pt"
    if not path.exists():
        t.save(lit_ortho_probe.linear_probe, path)
        print(f"Saved probe to {path.resolve()}")
    else:
        print(f"Warning: {path.resolve()} already exists. Not saving the probe.")

    return probe

In [None]:
probe = get_probe(0, device=device, base_name='probe-0.9')

In [None]:
probe2 = train_probe(ProbeTrainingArgs(games_tokens, val_tokens, lr=0.001, batch_size=600, max_epochs=12),
                    probe,)

### Training with sweeps

In [None]:
# Train with sweeps
import yaml


config = yaml.safe_load(Path('sweep.yml').read_text())
rprint(config)

sweep_id = wandb.sweep(sweep=config, project='cell-probes-sweep')


In [None]:
def run():
    """W&B sweep entry point: train a C4 cell probe with the sweep's hyper-parameters.

    NOTE(review): `ProbeConfig` and `Probe` are not imported anywhere visible in
    this notebook. They may be stale names for `OthelloProbe.Config` /
    `OthelloProbe`; confirm before launching the sweep.
    """
    # `wb_run` instead of `run`, so the local does not shadow this function's name.
    with wandb.init() as wb_run:
        sweep_params = wb_run.config
        run_probe_config = ProbeConfig(
            cell="C4",
            trained_on_move=40,
            num_valid_games=1000,
            validate_every=20,
            **sweep_params,
        )
        cell_probe = Probe(model, run_probe_config)
        cell_probe.train(games_tokens, val_tokens)

wandb.agent(sweep_id, function=run, count=1)

### Probe per move/cell

In [4]:
val_tokens, _ = generate_training_data(1000, 69)
train_tokens, _ = get_training_data()

  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
import probes
from importlib import reload
reload(probes)
from probes import OthelloProbe

test_config = OthelloProbe.Config(
    cell="D3",
    num_probes=12,
    layer=3,
    options=2,
    validate_every=20,
    use_wandb=False,
    num_train_games=1000,
    num_val_games=100,
)
rprint(test_config)
probe = OthelloProbe(model, test_config)
train_data = probe.dataloader_activations(train_tokens[:, :test_config.num_probes])
val_data = probe.dataloader_activations(val_tokens[:, :test_config.num_probes], for_validation=True)



In [None]:
from probes import OthelloProbe
TRAIN = True
probes = []
# for layer in range(8):
for layer in [3]:
    probe_config = OthelloProbe.Config(
        cell="D3",
        num_probes=12,
        layer=layer,
        options=2,
        device=device,
        epochs=3,
        lr=0.002,
        # wd=0.00001,
        batch_size=1000,
        num_train_games=100_000,
        num_val_games=1_000,
        validate_every=20,
        use_wandb=True,
    )
    if TRAIN:
        probe = OthelloProbe(model, probe_config)

        with wandb.init(project='position-dependent-probes',
                        config=probe_config,
                        mode='online' if probe_config.use_wandb else 'disabled',
                        group=f'{probe_config.cell}-{probe_config.trained_on}'):

            probe.train_loop(
                # train_data,
                # val_data,
                probe.dataloader_activations(train_tokens[:, :probe_config.num_probes]),
                probe.dataloader_activations(val_tokens[:, :probe_config.num_probes], 1000, for_validation=True),
            )
        # probe.save()
    else:
        probe = probe_config.load(model)
    probes.append(probe)

For a given cell and a given training move, we have one probe per layer.
Each probe tries to predict the state of that cell at every move.

In [21]:
val_data = probes[0].dataloader(val_tokens[:, :probes[0].config.num_probes], for_validation=True)

In [30]:
losses = [
    probe.validate(val_data, use_wandb=False)
    for probe in tqdm(probes)
]

  0%|          | 0/1 [00:00<?, ?it/s]

In [130]:
for probe in probes:
    probe.save()



In [36]:
# nice plot
from circuit import MagicTensor

(MagicTensor(t.stack(losses) - baseline_random, ['layer', 'metric', 'time'])
    .by('metric')
    ['metric', 1]
    .plot(title=f"Probe on move {probes[0].config.trained_on} trying to predict the state of {probes[0].config.cell} at given timestep"
          "<br>Accuracy improvement over baseline (dataset stats)",
          color_continuous_scale='RdBu',
          color_continuous_midpoint=0.0,
          zmax=0.5,
          zmin=-0.5,
          # facet_labels=["Loss", "Accuracy"],
          # facet_col_wrap=1,
          )
);

In [33]:
baseline_stats

tensor([[[inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]], device='cuda:0')

In [None]:
from circuit import MagicTensor
(MagicTensor(t.stack([p.probe for p in probes]), ['layer', 'option', 'move', 'dmodel'])
    .norm('dmodel')
    ['layer', 5]
    ['option', 0]
    # .new_dim('facet')
    # .flatten('move', 'dmodel')
    # .flatten('layer', 'option')
    .flatten()
    .print()
    # .histogram(nbins=100)
    .plot()
 )

### Results: knowledge of previous state of D3

In [67]:
from probes import OthelloProbe
# Load the probes
d3_probes = [OthelloProbe.load(model, layer=layer, num_probes=41, options=2, cell="D3", num_val_games=1000) for layer in range(8)]
# Fix the dims since I changed the probe format
for p in d3_probes:
    p.probe = t.nn.Parameter(einops.rearrange(p.probe, 'option probe dmodel -> probe option dmodel'))
d3_cfg = d3_probes[0].config
# Compute their losses
val_data = p.dataloader(val_tokens[:, :d3_cfg.num_probes], for_validation=True)
d3_losses = t.stack([
    probe.validate(val_data, use_wandb=False)
    for probe in tqdm(d3_probes)
])

  0%|          | 0/8 [00:00<?, ?it/s]

In [70]:
from probes import baselines
(MagicTensor(d3_losses - baselines(d3_probes, val_data, 'random'), ['layer', 'metric', 'time'])
 ['metric', 1]
 .plot(title=f"Probe on move {d3_cfg.trained_on} trying to predict the state of {d3_cfg.cell} at given timestep"
             "<br>Accuracy improvement over baseline (dataset stats)",
       color_continuous_scale='RdBu',
       color_continuous_midpoint=0.0,
       zmax=0.5,
       zmin=-0.5,
       )
 );


## Constructing residual stream

In [None]:
def make_residual_stream(
    world: Union[Int[Tensor, "row=8 cols=8"], str],
    probe: Float[Tensor, "d_model rows cols options=3"],
) -> Float[Tensor, "d_model"]:
    """
    Create the embedding of a board state according to the probe.

    Each blank cell contributes its "blank" direction: the blank vector minus
    the mean of the mine/theirs vectors. Each occupied cell contributes its
    "mine minus theirs" direction multiplied by the cell value, so a cell of
    theirs (-1) adds the negated direction.

    Args:
        world: the board state, with 0 for blank, +1 for mine, -1 for theirs.
            A string is converted with `board_to_tensor`.
        probe: directions in the residual stream that correspond to each square.
            The last dimension is the options, with 0 for blank, 1 for mine, 2 for theirs.

    Returns:
        A vector of size d_model: the sum of all per-cell contributions.
    """
    if isinstance(world, str):
        world = board_to_tensor(world)

    d_model = probe.shape[0]
    blank_direction = probe[..., 0] - (probe[..., 1] + probe[..., 2]) / 2  # [d_model, rows, cols]
    my_direction = probe[..., 1] - probe[..., 2]  # [d_model, rows, cols]

    world = world.to(probe.device)
    is_blank = world == 0

    # Vectorised over the whole board instead of a Python double loop.
    # Blank cells are 0 in `world`, so they add nothing to the second term.
    # Accumulating into a zeros vector keeps the result dtype the same as the
    # loop-based version.
    embedding = t.zeros(d_model, device=probe.device)
    embedding += (blank_direction * is_blank).sum(dim=(1, 2))
    embedding += (my_direction * world).sum(dim=(1, 2))

    return embedding


# %% Try to run the model on a virtual residual stream


def hook(activation: Float[Tensor, "game move d_model"], hook: HookPoint):
    """Overwrite the residual stream at the last position with the global `resid`."""
    activation[:, -1] = resid


# `resid` is normally built by the board-definition cell further down. When
# the notebook runs top to bottom, that cell has not run yet and the hook
# would raise NameError. Fall back to the standard opening position instead.
if "resid" not in globals():
    opening_board = "\n" + "\n".join([
        "........",
        "........",
        "........",
        "...xo...",
        "...ox...",
        "........",
        "........",
        "........",
    ]) + "\n"
    resid = make_residual_stream(opening_board, linear_probe)

layer = 4
act_name = utils.get_act_name("resid_pre", layer)
osef_input = focus_games_tokens[:1, :20]  # 1 game, 20 moves
logits = model.run_with_hooks(osef_input, fwd_hooks=[(act_name, hook)])

# Plot what the model predicts
logits = logits_to_board(logits[0, -1], "log_prob")
plot_square_as_board(logits, title="Model predictions")


In [None]:
board = """
........
........
........
...xo...
...ox...
........
........
........
"""

board_2 = """
........
........
........
..xxxx..
.xooooo.
.o..ox..
.x.oxo..
........
"""
board_3 = """
........
........
........
...xo...
...oo...
....o...
........
........
"""

board_tensor = board_to_tensor(board_3)
resid = make_residual_stream(board_tensor, linear_probe)
# plot_square_as_board(board_tensor)


In [None]:
@t.inference_mode()
def modify_resid_given_probe(
        model: HookedTransformer,
        moves_orig: Int[Tensor, "move"],
        moves_new: Int[Tensor, "move"],
        *probes: Float[Tensor, "d_model rows cols options=3"],
        layer: int = 6,
        cells: Tuple[str, ...] = (),
):
    """
    Patch the probe subspace of the residual stream from one game into another.

    Runs `moves_new` and caches `resid_pre` at `layer`. Then runs `moves_orig`
    and, at the last position, uses `swap_subspace` to replace the part of the
    residual stream spanned by the probe directions with the new game's.
    Finally plots the expected valid moves and the log-probabilities of the
    original, patched and new runs.

    Args:
        model: the model to patch.
        moves_orig: tokens of the game being patched. Indexed as [0, -1]
            below, so presumably shaped (1, move); confirm.
        moves_new: tokens of the game providing the swapped-in subspace.
        *probes: probe tensors whose directions span the swapped subspace.
        layer: layer whose `resid_pre` is patched.
        cells: if non-empty, only these cells' probe directions (e.g. "D2")
            are swapped. The "New expected" board is then the original board
            with these cells copied from the new game.
    """
    act_name = utils.get_act_name("resid_pre", layer)
    new_logits, new_cache = model.run_with_cache(
        moves_new,
        names_filter=lambda name: name == act_name,
    )

    # Step 0. Collect the probe vectors spanning the subspace to swap. They do
    # not depend on the activation, so compute them once outside the hook.
    all_probes = t.stack(probes, dim=-1)
    if cells:
        rows_cols = t.tensor([board_label_to_row_col(cell) for cell in cells])
        all_probes = all_probes[:, rows_cols[:, 0], rows_cols[:, 1]]
    probe_vectors = einops.rearrange(all_probes, "d_model ... -> (...) d_model")

    # TransformerLens passes the hook point as the keyword `hook`.
    def hook(orig_activation: Float[Tensor, "game move d_model"], hook: HookPoint):
        orig_activation[:, -1] = swap_subspace(
            orig_activation[:, -1],
            new_cache[act_name][:, -1],
            probe_vectors.to(orig_activation.device),
        )

    patched_logits = model.run_with_hooks(
        moves_orig,
        fwd_hooks=[(act_name, hook)],
    )

    # Valid moves of the original game, and those expected after patching
    orig_valid_moves = move_sequence_to_state(tokens_to_board(moves_orig), mode="valid")
    if cells:
        index = tuple(zip(*(board_label_to_row_col(cell) for cell in cells)))

        # Compute the state of orig and new board
        new_board_state = move_sequence_to_state(tokens_to_board(moves_new), mode="normal")[0, -1]
        orig_board_state = move_sequence_to_state(tokens_to_board(moves_orig), mode="normal")[0, -1]
        # Put the cells of the new board in the orig board
        orig_board_state[index] = new_board_state[index]

        valid_cells = valid_moves_from_board(orig_board_state, moves_orig.shape[1])
        new_valid_moves = one_hot(valid_cells).reshape(1, 1, 8, 8)
    else:
        new_valid_moves = move_sequence_to_state(tokens_to_board(moves_new), mode="valid")

    orig_logits = logits_to_board(model(moves_orig)[0, -1], "log_prob")
    patched_logits = logits_to_board(patched_logits[0, -1], "log_prob")
    new_logits = logits_to_board(new_logits[0, -1], "log_prob")

    scale = new_logits.abs().max().cpu()

    # Label -> board, kept together so the facet labels cannot drift out of
    # sync with the stacking order (they previously did).
    panels = {
        "New expected": new_valid_moves[0, -1] * scale,
        "new logits": new_logits,
        "logit diff (patch - orig)": patched_logits - orig_logits,
        "original expected": orig_valid_moves[0, -1] * scale,
        "original logits": orig_logits,
        "patched logits": patched_logits,
    }

    # `board`, not `t`, so the torch alias is not shadowed in the comprehension.
    all_logits = t.stack([board.cpu() for board in panels.values()], dim=-1)
    plot_square_as_board(
        all_logits,
        title="Model predictions",
        facet_col=-1,
        facet_col_wrap=3,
        facet_labels=list(panels),
    )


orig_index = 2
new_index = 3
move_index = 20
layer = 4
orig_games = focus_games_tokens[orig_index:orig_index + 1, :move_index]
new_games = focus_games_tokens[new_index:new_index + 1, :move_index]

modify_resid_given_probe(model, orig_games, new_games, *probes, layer=layer, cells=["D2"])


In [None]:
plot_single_board(focus_games_board_index[orig_index, :move_index], title="Original game")
plot_single_board(focus_games_board_index[new_index, :move_index], title="New game")