In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

  from .autonotebook import tqdm as notebook_tqdm


## Predictions

In [2]:
# Load the model checkpoint ONCE and take both the config and the weights from it
# (previously the same file was unpickled twice).
# NOTE: torch.load unpickles the file, which can execute arbitrary code — only
# load checkpoints from a trusted source.
PTH_LOCATION = "data/transformer_lens.pth"
model_dict = torch.load(PTH_LOCATION)
model = HookedTransformer(model_dict["config"])
model.load_state_dict(model_dict["model"])

<All keys matched successfully>

In [6]:
# Load the saved train / eval game tensors (shapes are printed below).
train_data, eval_data = [torch.load(path) for path in ("data/train_data.pt", "data/eval_data.pt")]
print(train_data.shape, eval_data.shape)

torch.Size([204134, 9]) torch.Size([51034, 9])


In [7]:
# Greedy (argmax) prediction at every position of one randomly chosen eval game.
model.eval()
idx = np.random.randint(0, len(eval_data))  # random game index
example = eval_data[idx:idx+1, :]
example_logits = model(example).squeeze()
example_pred = example_logits.argmax(dim=-1)  # argmax over the vocab dimension
print(f"Example: {example.squeeze()}")
print(f"Pred:    {example_pred.squeeze()}")

Example: tensor([3, 1, 5, 7, 8, 2, 4, 9, 9])
Pred:    tensor([6, 7, 8, 8, 0, 4, 6, 9, 9])


In [8]:
# Predict the token that follows the first `layer` moves of a random eval game.
model.eval()
idx = np.random.randint(0, len(eval_data))
example = eval_data[idx:idx+1, :]
layer = 5
src, target = example[:, :layer], example[:, layer]
example_logits = model(src)
print(f"Logits shape: {example_logits.shape}")
# Only the last position's logits predict the next (unseen) token.
example_pred = example_logits[:, -1].argmax(dim=-1)
print(f"Example: {example.squeeze()}")
print(f"Src: {src}")
print(f"Target:  {target.squeeze()}")
print(f"Pred:    {example_pred.squeeze()}")

Logits shape: torch.Size([1, 5, 10])
Example: tensor([5, 3, 0, 8, 4, 6, 7, 2, 1])
Src: tensor([[5, 3, 0, 8, 4]])
Target:  6
Pred:    7


In [9]:
def predict_next_token(src, model):
    """Return the model's greedy (argmax) prediction for the token after ``src``.

    Args:
        src: Integer token tensor of shape ``(batch, seq_len)``. A 1-D
            ``(seq_len,)`` tensor is treated as a batch of one.
        model: Module returning logits of shape ``(batch, seq_len, vocab)``.
            It is switched to eval mode as a side effect.

    Returns:
        Tensor of shape ``(batch,)`` holding the argmax token at the last position.
    """
    model.eval()
    if src.dim() == 1:
        src = src.unsqueeze(0)
    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        example_logits = model(src)
    return example_logits[:, -1].argmax(-1)

# Use the helper to predict the move after the first `layer` moves of a random eval game.
idx = np.random.randint(0, len(eval_data))
example = eval_data[idx:idx+1, :]
layer = 7
src, target = example[:, :layer], example[:, layer]
example_pred = predict_next_token(src, model)
print(f"Target:  {target.squeeze()}")
print(f"Pred:    {example_pred.squeeze()}")

Target:  4
Pred:    6


In [10]:
import torch
import torch.nn.functional as F
import plotly.express as px

def get_player_positions(board_state, player):
    """Return the ``(x, y)`` grid coordinates of the moves made by ``player``.

    ``board_state`` is a sequence of move tokens in play order (a list or a 1-D
    tensor). Assuming the two players alternate starting with player 0 ('X'),
    player 0 owns the moves at even sequence indices and player 1 ('O') those at
    odd indices. The player is decided by the move's index, not by the cell
    number (the previous ``pos % 2`` test conflated the two).

    Cell tokens ``0..8`` map row-major onto the 3x3 grid as ``x = token % 3``
    (column) and ``y = token // 3`` (row). Any other token is not a board cell
    and is skipped; for example, 9 pads a game after it ends.

    Returns:
        List of ``(x, y)`` tuples of plain ints.
    """
    positions = []
    for move_idx, token in enumerate(board_state):
        cell = int(token)  # works for ints and 0-d tensors alike
        if move_idx % 2 != player or not 0 <= cell < 9:
            continue
        positions.append((cell % 3, cell // 3))
    return positions

def add_player_annotations(fig, positions, player_symbol):
    """Place ``player_symbol`` as a large, arrow-less text label at each ``(x, y)``.

    Coordinates are given in data units of the figure's x/y axes.
    """
    for col, row in positions:
        fig.add_annotation(
            x=col,
            y=row,
            text=player_symbol,
            showarrow=False,
            font=dict(size=40),
            xref="x",
            yref="y",
        )
        
def add_token_circles(fig, positions):
    """Draw a filled green circle (radius 0.25, data units) centred on each ``(x, y)``."""
    radius = 0.25
    for cx, cy in positions:
        fig.add_shape(
            type="circle",
            xref="x",
            yref="y",
            x0=cx - radius,
            y0=cy - radius,
            x1=cx + radius,
            y1=cy + radius,
            line_color="green",
            fillcolor="green",
        )

def plot_board_with_logits(board_state, step_logits, layer):
    """Plot the model's next-move distribution at one step as a 3x3 heatmap.

    Args:
        board_state: Move-token tensor for one game, indexed as
            ``board_state[0, pos]`` (so presumably shape ``(1, seq_len)``).
        step_logits: Per-position logits indexed as ``step_logits[pos, :]``.
            The first 9 vocab entries are reshaped onto the board, and the
            last entry is the "game over" token.
        layer: Sequence position whose prediction is plotted, i.e. after moves
            ``0..layer`` have been played. Despite the name, this is a position
            in the game, not a transformer layer.

    The plot title shows the argmax token, the true next token (``None`` when
    ``layer`` is the final position and there is no next token), and the
    game-over probability.
    """
    # Logits at position `layer` predict the token at `layer + 1`. At the last
    # position there is nothing left to compare against, so avoid an IndexError.
    has_target = layer + 1 < board_state.shape[1]
    target = board_state[0, layer + 1].item() if has_target else None
    moves_so_far = board_state[0, :layer + 1]

    # Softmax over the FULL vocab (cells + game over) BEFORE dropping the
    # game-over entry, so all probabilities share one distribution.
    probs = F.softmax(step_logits[layer, :], dim=-1)
    game_over_prob = probs[-1].item()
    next_token = probs.argmax().item()
    cell_probs = probs[:-1].reshape(3, 3)

    fig = px.imshow(cell_probs.cpu().detach().numpy(), text_auto=True)

    # Mark occupied cells with a circle, then label each one with its player.
    x_positions = get_player_positions(moves_so_far, player=0)
    o_positions = get_player_positions(moves_so_far, player=1)
    add_token_circles(fig, x_positions + o_positions)
    add_player_annotations(fig, x_positions, 'X')
    add_player_annotations(fig, o_positions, 'O')

    fig.update_xaxes(side="top")  # This will put the (0,0) position of imshow in the top left corner
    # NOTE(review): text_auto=True makes px.imshow label cells via a "%{z}"
    # template; overriding it with "%{text}" may blank the labels if no text
    # array is set — confirm the rendered output.
    fig.update_traces(texttemplate="%{text}", textfont_size=20)  # Set text size

    # Hide axis labels, ticks, grid and zero lines.
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)

    fig.update_layout(autosize=False, width=500, height=500, title=f"Predicted = {next_token} (targ = {target}), Game over prob = {game_over_prob*100:.2f}%")
    fig.show()

# Demo: plot the model's prediction after move `layer` of a random eval game.
layer = 4  # Current step (sequence position)
model.eval()
idx = np.random.randint(0, len(eval_data))  # random game index
example = eval_data[idx:idx + 1, :]
example_logits = model(example).squeeze()
plot_board_with_logits(example, example_logits, layer=layer)


## Linear probe

In [21]:
import torch

def board_state_to_vector(board_state_tensor):
    """Encode a sequence of move tokens as a length-9 cell-occupancy vector.

    Every board cell (token ``0..8``) that appears in ``board_state_tensor`` is
    set to 1 and all other cells stay 0. The encoding does not record which
    player made each move. Tokens outside ``0..8`` are ignored rather than
    raising an ``IndexError`` (or wrapping, for negatives); for example, 9 pads
    a game after it ends.

    Args:
        board_state_tensor: Integer tensor (any shape) or sequence of move tokens.

    Returns:
        ``torch.int32`` tensor of shape ``(9,)``.
    """
    moves = torch.as_tensor(board_state_tensor, dtype=torch.long).flatten()
    on_board = moves[(moves >= 0) & (moves < 9)]
    state_vector = torch.zeros(9, dtype=torch.int32)
    state_vector[on_board] = 1
    return state_vector

# Example usage: occupancy vector after the first four moves of a random eval game.
idx = np.random.randint(0, len(eval_data))
board_state_tensor = eval_data[idx, :4]
state_vector = board_state_to_vector(board_state_tensor)
print(state_vector)

tensor([0, 0, 0, 0, 1, 1, 0, 1, 1], dtype=torch.int32)


In [None]:
def board_state_to_logit_vector_pair(board_state_tensor, transformer=None):
    """Pair a game prefix's occupancy vector with the model's next-token output.

    The model is run on the MOVE TOKENS themselves. Previously it was given the
    0/1 occupancy vector, which is not a valid move sequence.

    Args:
        board_state_tensor: Move tokens of one game prefix, shape ``(seq_len,)``
            or ``(1, seq_len)``.
        transformer: Model to query. Defaults to the notebook-global ``model``.

    Returns:
        ``(state_vector, logits, pred_token)``:
        - ``state_vector``: the 9-cell occupancy vector.
        - ``logits``: the model output for the move sequence, presumably
          ``(1, seq_len, vocab)``.
        - ``pred_token``: the argmax at the last position, i.e. the predicted
          next move, with shape ``(1,)``.
    """
    net = model if transformer is None else transformer
    state_vector = board_state_to_vector(board_state_tensor)
    tokens = board_state_tensor if board_state_tensor.dim() == 2 else board_state_tensor.unsqueeze(0)
    # Probe data collection only — no gradients needed.
    with torch.no_grad():
        logits = net(tokens)
    # Only the final position's logits predict the next (unseen) move.
    pred_token = logits[:, -1].argmax(-1)
    return state_vector, logits, pred_token

In [14]:
# Run the model on the first `layer` moves of a random eval game and check output shapes.
layer = 4  # Current step
model.eval()
idx = np.random.randint(0, len(eval_data))  # random game index
example = eval_data[idx:idx + 1, :layer]
example_logits = model(example).squeeze()

print(example.shape, example_logits.shape, sep="\n")

torch.Size([1, 4])
torch.Size([4, 10])


In [18]:
example_logits.size()

torch.Size([4, 10])

In [17]:
# Next pred: argmax over the vocab logits at the final position
pred = example_logits[-1].argmax()
pred

tensor(6)