In [1]:
%load_ext autoreload
%autoreload 2

## Download and load SAE

In [2]:
import wandb

entity = "andyrdt"
project = "othello_gpt_sae"

for artifact in [
    "sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_512:latest",
    "sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_1024:latest",
    "sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_2048:latest",
    "sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_4096:latest",
]:
    artifact_path = f"{entity}/{project}/{artifact}"
    api = wandb.Api()
    artifact = api.artifact(artifact_path)
    artifact.download()

[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [3]:
from sae_training.utils import LMSparseAutoencoderSessionloader

path ="./artifacts/sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_512:v0/final_sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_512.pt"
# path ="./artifacts/sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_1024:v0/final_sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_1024.pt"
# path ="./artifacts/sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_2048:v3/final_sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_2048.pt"
# path="./artifacts/sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_4096:v4/final_sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_4096.pt"
model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(path)
sparse_autoencoder.eval()

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loaded pretrained model othello-gpt into HookedTransformer
Moving model to device:  cuda


Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Dataset is tokenized! Updating config.


SparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
)

## Othello stuff

In [4]:
import sys
import numpy as np
import torch
torch.set_grad_enabled(False)

from pathlib import Path

!git clone https://github.com/likenneth/othello_world
OTHELLO_ROOT = Path("./othello_world/")
sys.path.append(str(OTHELLO_ROOT/"mechanistic_interpretability"))

from mech_interp_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState

fatal: destination path 'othello_world' already exists and is not an empty directory.


In [5]:
board_seqs_int = torch.tensor(np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_int_small.npy"), dtype=torch.long)
board_seqs_string = torch.tensor(np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_string_small.npy"), dtype=torch.long)

num_games, length_of_game = board_seqs_int.shape
print("Number of games:", num_games,)
print("Length of game:", length_of_game)

Number of games: 100000
Length of game: 60


In [6]:
stoi_indices = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
]
alpha = "ABCDEFGH"

def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"

board_labels = list(map(to_board_label, stoi_indices))

In [7]:
moves_int = board_seqs_int[0, :30]

# This is implicitly converted to a batch of size 1
logits = model(moves_int)
print("logits:", logits.shape)

logits: torch.Size([1, 30, 61])


In [8]:
logit_vec = logits[0, -1]
log_probs = logit_vec.log_softmax(-1)
# Remove passing
log_probs = log_probs[1:]
assert len(log_probs)==60

temp_board_state = torch.zeros(64, device=logit_vec.device)
# Set all cells to -15 by default, for a very negative log prob - this means the middle cells don't show up as mattering
temp_board_state -= 13.
temp_board_state[stoi_indices] = log_probs

In [9]:
from neel_plotly import imshow
def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Takes a square input (8 by 8) and plot it as a board. Can do a stack of boards via facet_col=0"""
    if diverging_scale:
        imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], color_continuous_scale="RdBu", color_continuous_midpoint=0., aspect="equal", **kwargs)
    else:
        imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], color_continuous_scale="Blues", color_continuous_midpoint=None, aspect="equal", **kwargs)
plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False, title="Example Log Probs")

In [10]:
num_games = 200
focus_games_int = board_seqs_int[:num_games]
focus_games_string = board_seqs_string[:num_games]

In [11]:
def one_hot(list_of_ints, num_classes=64):
    out = torch.zeros((num_classes,), dtype=torch.float32)
    out[list_of_ints] = 1.
    return out
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = torch.zeros((num_games, 60, 64), dtype=torch.float32)
for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())
print("focus states:", focus_states.shape)
print("focus_valid_moves", focus_valid_moves.shape)

focus states: (200, 60, 8, 8)
focus_valid_moves torch.Size([200, 60, 64])


In [12]:
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].cuda())

In [13]:
full_linear_probe = torch.load(OTHELLO_ROOT/"mechanistic_interpretability/main_linear_probe.pth")
full_linear_probe.shape # [modes, d_model, row, col, options]

torch.Size([3, 512, 8, 8, 3])

In [14]:
rows = 8
cols = 8 
options = 3
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2
linear_probe = torch.zeros(model.cfg.d_model, rows, cols, options, device="cuda")
linear_probe[..., blank_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 0] + full_linear_probe[white_to_play_index, ..., 0])
linear_probe[..., their_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 1] + full_linear_probe[white_to_play_index, ..., 2])
linear_probe[..., my_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 2] + full_linear_probe[white_to_play_index, ..., 1])

In [15]:
blank_probe = linear_probe[..., 0] - linear_probe[..., 1] * 0.5 - linear_probe[..., 2] * 0.5
my_probe = linear_probe[..., 2] - linear_probe[..., 1]

In [16]:
game_index = 4
move = 20
plot_single_board(focus_games_string[game_index, :move+1])
# plot_single_board(focus_games_string[game_index, :move].tolist()+[16], title="Corrupted Game (blank plays C0)")

### Max activating dataset examples

In [17]:
def state_stack_to_one_hot(state_stack):
    one_hot = torch.zeros(
        state_stack.shape[0], # num games
        state_stack.shape[1], # num moves
        8, # rows
        8, # cols
        3, # the two options
        device=state_stack.device,
        dtype=torch.int,
    )
    one_hot[..., 0] = state_stack == 0 # empty
    one_hot[..., 1] = state_stack == -1 # white
    one_hot[..., 2] = state_stack == 1 # black
    
    return one_hot

# We first convert the board states to be in terms of my (+1) and their (-1)
alternating = np.array([-1 if i%2 == 0 else 1 for i in range(focus_games_int.shape[1])])
flipped_focus_states = focus_states * alternating[None, :, None, None]

# We now convert to one hot
focus_states_flipped_one_hot = state_stack_to_one_hot(torch.tensor(flipped_focus_states))

# Take the argmax
focus_states_flipped_value = focus_states_flipped_one_hot.argmax(dim=-1)

# focus_states_flipped_value does this but has their's==2 and mine==1, not 1 and -1. So let's convert, this is boring and you don't need to understand it.
focus_states_flipped_pm1 = torch.zeros_like(focus_states_flipped_value, device="cuda")
focus_states_flipped_pm1[focus_states_flipped_value==2] = 1.
focus_states_flipped_pm1[focus_states_flipped_value==1] = -1.

In [18]:
layer, neuron = 0, 774
neuron_acts_post = focus_cache['post', layer][..., neuron]

imshow(
    neuron_acts_post,
    title=f"L{layer}N{neuron} acts over focus games",
    xaxis="Move", yaxis="Game"
)

In [19]:
neuron_activated_mask = neuron_acts_post > 0.0
current_move_when_active = focus_games_int[:, :-1][neuron_activated_mask.cpu()]
print("Current move when L0774 activates:", int_to_label(current_move_when_active))
print("Fraction of time current_move==E0 when L0N774 activates:", (current_move_when_active == to_int('E0')).float().mean().item())

Current move when L0774 activates: ['E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0'

In [20]:
top_moves = neuron_acts_post.flatten() > neuron_acts_post.flatten().quantile(0.99)
plot_square_as_board(
    focus_states_flipped_pm1[:, :-1].reshape(-1, rows, cols)[top_moves].float().mean(dim=0),
    title="Max activating dataset examples for L0N774",
)
plot_square_as_board(
    focus_states_flipped_pm1[:, :-1].abs().reshape(-1, rows, cols)[top_moves].float().mean(dim=0),
    title="Max activating dataset examples for L0N774 (abs value)",
)

### Spectrum plots

In [21]:
import pandas as pd
import plotly.express as px

In [22]:
layer, neuron = 0, 774
neuron_acts_pre = focus_cache['pre', layer][..., neuron]

label = focus_games_int[:, :-1] == to_int('E0')

df = pd.DataFrame({"pre_act": neuron_acts_pre.flatten().tolist(), "label": label.flatten().tolist()})

px.histogram(
    df,
    nbins=100, histnorm="percent",
    x="pre_act", color="label",
    title=f"L{layer}N{neuron} spectrum plot testing current_move==E0"
)





## SAE features

This section might be bugged, I didn't have time to properly debug.

But the idea is to compare the activations of each feature to the state of each cell in the board. We can try and compute the correlation between a feature activating and the state of a cell.

In [23]:
import einops

In [24]:
print(f"focus_games: {focus_games_int.shape}")
print(f"focus_states: {focus_states.shape}")

focus_games: torch.Size([200, 60])
focus_states: (200, 60, 8, 8)


In [25]:
# focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].cuda())

x = einops.rearrange(focus_cache[sparse_autoencoder.cfg.hook_point], "batch seq_len d_model -> (batch seq_len) d_model")
sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid = sparse_autoencoder(x)

In [26]:
print(feature_acts.shape)
focus_states_flattened = einops.rearrange(focus_states[:, :-1], "batch seq_len rows cols -> (batch seq_len) (rows cols)")

focus_is_empty = torch.tensor(focus_states_flattened < 0).float()
print(focus_is_empty.shape)

torch.Size([11800, 512])
torch.Size([11800, 64])


In [27]:
combined_tensor = torch.cat(((feature_acts == 0).float(), focus_is_empty.to(feature_acts)), dim=1) # first d_features are the feature acts, next 64 are the is_empty

corr_matrix = torch.corrcoef(combined_tensor.T)

relevant_corr = corr_matrix[:sparse_autoencoder.d_sae, sparse_autoencoder.d_sae:sparse_autoencoder.d_sae+focus_is_empty.shape[1]]

In [28]:
# replace nan values with 0
relevant_corr = torch.where(torch.isnan(relevant_corr), torch.zeros_like(relevant_corr), relevant_corr)

In [29]:
px.histogram(
    relevant_corr.flatten().cpu().numpy(),
    nbins=100,
    title="Correlation between feature acts and is_empty"
)

In [30]:
top_empty_vals, top_empty_idxs = relevant_corr.flatten().topk(10)

feature_ids = top_empty_idxs // 64
square_ids = top_empty_idxs % 64

In [31]:
print(feature_ids[0])
print(square_ids[0])

print(relevant_corr[feature_ids[0], square_ids[0]])

tensor(485, device='cuda:0')
tensor(43, device='cuda:0')
tensor(0.3359, device='cuda:0')


In [32]:
import plotly.graph_objects as go

for i in range(10):

    square_id = square_ids[i]
    feature_id = feature_ids[i]
    board_label = board_labels[square_id]
    empty_criteria = focus_states_flattened[:, square_id] == 0

    fig = go.Figure()

    true_activations_empty = (feature_acts[empty_criteria, feature_id]>0).cpu().numpy().sum()
    false_activations_empty = (feature_acts[empty_criteria, feature_id]<=0).cpu().numpy().sum()
    true_activations_not_empty = (feature_acts[~empty_criteria, feature_id]>0).cpu().numpy().sum()
    false_activations_not_empty = (feature_acts[~empty_criteria, feature_id]<=0).cpu().numpy().sum()

    fig.add_trace(
        go.Bar(
            x=['False', 'True'],
            y=[false_activations_empty, true_activations_empty],
            name="Empty",
            marker_color='blue'
        )
    )

    fig.add_trace(
        go.Bar(
            x=['False', 'True'],
            y=[false_activations_not_empty, true_activations_not_empty],
            name="Not Empty",
            marker_color='red'
        )
    )

    fig.update_layout(
        title=f"Feature {feature_id} vs IsEmpty({board_label})",
        xaxis_title="Feature fired?",
        yaxis_title="Count",
        barmode='group',
        width=600,
    )

    fig.show()