In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState

In [21]:
from transformer_lens import HookedTransformer, HookedTransformerConfig
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre"
)
model = HookedTransformer(cfg)

In [5]:

model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
model = HookedTransformer.from_pretrained(model_name)

LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

In [8]:
from pathlib import Path
!git clone https://github.com/likenneth/othello_world
OTHELLO_ROOT = Path("./othello_world/")
import sys
sys.path.append(str(OTHELLO_ROOT/"mechanistic_interpretability"))
from mech_interp_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState



fatal: destination path 'othello_world' already exists and is not an empty directory.


In [10]:
import numpy as np
board_seqs_int = torch.tensor(np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_int_small.npy"), dtype=torch.long)
board_seqs_string = torch.tensor(np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_string_small.npy"), dtype=torch.long)

num_games, length_of_game = board_seqs_int.shape
print("Number of games:", num_games,)
print("Length of game:", length_of_game)

Number of games: 100000
Length of game: 60


In [11]:
stoi_indices = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
]
alpha = "ABCDEFGH"


def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"


board_labels = list(map(to_board_label, stoi_indices))

In [5]:
stoi_indices = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
]
alpha = "ABCDEFGH"


def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"


board_labels = list(map(to_board_label, stoi_indices))

In [22]:
moves_int = board_seqs_int[0, :30]

# This is implicitly converted to a batch of size 1
logits = model(moves_int)
print("logits:", logits.shape)

logits: torch.Size([1, 30, 61])


In [23]:
logit_vec = logits[0, -1]
log_probs = logit_vec.log_softmax(-1)
# Remove passing
log_probs = log_probs[1:]
assert len(log_probs)==60

temp_board_state = torch.zeros(64, device=logit_vec.device)
# Set all cells to -15 by default, for a very negative log prob - this means the middle cells don't show up as mattering
temp_board_state -= 13.
temp_board_state[stoi_indices] = log_probs

In [24]:
from neel_plotly import imshow
def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Takes a square input (8 by 8) and plot it as a board. Can do a stack of boards via facet_col=0"""
    if diverging_scale:
        imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], color_continuous_scale="RdBu", color_continuous_midpoint=0., aspect="equal", **kwargs)
    else:
        imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], color_continuous_scale="Blues", color_continuous_midpoint=None, aspect="equal", **kwargs)
plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False, title="Example Log Probs")

In [25]:
num_games = 50
focus_games_int = board_seqs_int[:num_games]
focus_games_string = board_seqs_string[:num_games]

In [26]:
def one_hot(list_of_ints, num_classes=64):
    out = torch.zeros((num_classes,), dtype=torch.float32)
    out[list_of_ints] = 1.
    return out
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = torch.zeros((num_games, 60, 64), dtype=torch.float32)
for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())
print("focus states:", focus_states.shape)
print("focus_valid_moves", focus_valid_moves.shape)

focus states: (50, 60, 8, 8)
focus_valid_moves torch.Size([50, 60, 64])


In [27]:
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].cuda())

In [29]:
full_linear_probe = torch.load(OTHELLO_ROOT/"mechanistic_interpretability/main_linear_probe.pth")
full_linear_probe.shape # [modes, d_model, row, col, options]

torch.Size([3, 512, 8, 8, 3])

In [30]:
rows = 8
cols = 8 
options = 3
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2
linear_probe = torch.zeros(cfg.d_model, rows, cols, options, device="cuda")
linear_probe[..., blank_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 0] + full_linear_probe[white_to_play_index, ..., 0])
linear_probe[..., their_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 1] + full_linear_probe[white_to_play_index, ..., 2])
linear_probe[..., my_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 2] + full_linear_probe[white_to_play_index, ..., 1])

In [31]:
blank_probe = linear_probe[..., 0] - linear_probe[..., 1] * 0.5 - linear_probe[..., 2] * 0.5
my_probe = linear_probe[..., 2] - linear_probe[..., 1]

In [32]:
full_linear_probe[black_to_play_index, ..., 0].shape

torch.Size([512, 8, 8])

In [35]:
torch.cosine_similarity(full_linear_probe[black_to_play_index, ..., 1].reshape(-1, rows*cols), full_linear_probe[white_to_play_index, ..., 1].reshape(-1, rows*cols), dim=0)

tensor([-0.2947, -0.2504, -0.2449, -0.4683, -0.3550, -0.3380, -0.3382, -0.2670,
        -0.3262, -0.1785, -0.3109, -0.3766, -0.3318, -0.1611,  0.0499, -0.3115,
        -0.2379, -0.2557, -0.5527, -0.4500, -0.5523, -0.5044, -0.3436, -0.3563,
        -0.3407, -0.4816, -0.5757, -0.8512, -0.8089, -0.5109, -0.3328, -0.3939,
        -0.4477, -0.2742, -0.5748, -0.8321, -0.8530, -0.5319, -0.2576, -0.5081,
        -0.4441, -0.3317, -0.4485, -0.6089, -0.5356, -0.4813, -0.2826, -0.2679,
        -0.3561,  0.0323, -0.3106, -0.3787, -0.3493, -0.3359, -0.1465, -0.2614,
        -0.2977, -0.3396, -0.2834, -0.5596, -0.3314, -0.2652, -0.3252, -0.2755],
       device='cuda:0', grad_fn=<SumBackward1>)

In [38]:
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 2048,
 'd_model': 512,
 'd_vocab': 61,
 'd_vocab_out': 61,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': True,
 'initializer_range': 0.035355339059327376,
 'model_name': 'custom',
 'n_ctx': 59,
 'n_devices': 1,
 'n_heads': 8,
 'n_layers': 8,
 'n_params': 25165824,
 'normalization_type': 'LNPre',
 'original_architecture': None,
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': None,
 'tokenizer_prepends_bos': None,
 'use_attn_in': False,
 'use_attn_result': Fa

In [36]:
game_index = 4
move = 20
plot_single_board(focus_games_string[game_index, :move+1])
plot_single_board(focus_games_string[game_index, :move].tolist()+[16], title="Corrupted Game (blank plays C0)")

## Max activating dataset examples

In [39]:
def state_stack_to_one_hot(state_stack):
    one_hot = torch.zeros(
        state_stack.shape[0], # num games
        state_stack.shape[1], # num moves
        8, # rows
        8, # cols
        3, # the two options
        device=state_stack.device,
        dtype=torch.int,
    )
    one_hot[..., 0] = state_stack == 0 # empty
    one_hot[..., 1] = state_stack == -1 # white
    one_hot[..., 2] = state_stack == 1 # black
    
    return one_hot

# We first convert the board states to be in terms of my (+1) and their (-1)
alternating = np.array([-1 if i%2 == 0 else 1 for i in range(focus_games_int.shape[1])])
flipped_focus_states = focus_states * alternating[None, :, None, None]

# We now convert to one hot
focus_states_flipped_one_hot = state_stack_to_one_hot(torch.tensor(flipped_focus_states))

# Take the argmax
focus_states_flipped_value = focus_states_flipped_one_hot.argmax(dim=-1)

# focus_states_flipped_value does this but has their's==2 and mine==1, not 1 and -1. So let's convert, this is boring and you don't need to understand it.
focus_states_flipped_pm1 = torch.zeros_like(focus_states_flipped_value, device="cuda")
focus_states_flipped_pm1[focus_states_flipped_value==2] = 1.
focus_states_flipped_pm1[focus_states_flipped_value==1] = -1.

In [41]:
layer, neuron = 0, 774
neuron_acts_post = focus_cache['post', layer][..., neuron]

imshow(
    neuron_acts_post,
    title=f"L{layer}N{neuron} acts over focus games",
    xaxis="Move", yaxis="Game"
)

In [42]:
neuron_activated_mask = neuron_acts_post > 0.0
current_move_when_active = focus_games_int[:, :-1][neuron_activated_mask.cpu()]
print("Current move when L0774 activates:", int_to_label(current_move_when_active))
print("Fraction of time current_move==E0 when L0N774 activates:", (current_move_when_active == to_int('E0')).float().mean().item())

Current move when L0774 activates: ['C3', 'C4', 'D5', 'C2', 'B7', 'F6', 'B5', 'D7', 'C5', 'G2', 'G3', 'G7', 'F4', 'G4', 'G5', 'F2', 'D6', 'H5', 'F1', 'E7', 'C1', 'B3', 'H4', 'H3', 'H2', 'C7', 'F7', 'B1', 'H6', 'D1', 'B2', 'F0', 'D0', 'D2', 'E0', 'A1', 'C0', 'E1', 'B0', 'E6', 'G0', 'A2', 'G1', 'A5', 'A3', 'H0', 'B6', 'A6', 'G6', 'F4', 'F5', 'F2', 'G3', 'C4', 'E5', 'F6', 'D6', 'E2', 'B4', 'C5', 'G7', 'C1', 'G6', 'F7', 'G5', 'C3', 'B3', 'H6', 'G4', 'B2', 'C2', 'B5', 'A6', 'B6', 'C6', 'H4', 'A4', 'C7', 'B0', 'C0', 'D5', 'E1', 'B1', 'E6', 'B7', 'H2', 'H7', 'D7', 'A5', 'A3', 'H5', 'D1', 'A7', 'H1', 'F0', 'E7', 'D0', 'E0', 'A2', 'G0', 'F1', 'G1', 'H0', 'E5', 'D2', 'D5', 'G2', 'C1', 'E6', 'H1', 'H5', 'C3', 'C2', 'C5', 'C6', 'E7', 'B7', 'C0', 'H7', 'B4', 'D6', 'C7', 'F7', 'F6', 'C4', 'B1', 'A4', 'B2', 'E0', 'H6', 'E1', 'A2', 'G3', 'F1', 'F2', 'E2', 'H3', 'B5', 'G7', 'F0', 'A3', 'H2', 'A1', 'B0', 'H0', 'A5', 'F4', 'G1', 'A0', 'B6', 'G0', 'G4', 'D7', 'A7', 'D2', 'F2', 'F4', 'F6', 'C4', 'B5', 'G6'

In [43]:
top_moves = neuron_acts_post.flatten() > neuron_acts_post.flatten().quantile(0.99)
plot_square_as_board(
    focus_states_flipped_pm1[:, :-1].reshape(-1, rows, cols)[top_moves].float().mean(dim=0),
    title="Max activating dataset examples for L0N774",
)
plot_square_as_board(
    focus_states_flipped_pm1[:, :-1].abs().reshape(-1, rows, cols)[top_moves].float().mean(dim=0),
    title="Max activating dataset examples for L0N774 (abs value)",
)

## Spectrum Plots

In [45]:
import pandas as pd

In [44]:
layer, neuron = 0, 774
neuron_acts_pre = focus_cache['pre', layer][..., neuron]

label = focus_games_int[:, :-1] == to_int('E0')

df = pd.DataFrame({"pre_act": neuron_acts_pre.flatten().tolist(), "label": label.flatten().tolist()})

px.histogram(
    df,
    nbins=100, histnorm="percent",
    x="pre_act", color="label",
    title=f"L{layer}N{neuron} spectrum plot testing current_move==E0"
)

NameError: name 'pd' is not defined