In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from vector_quantize_pytorch import FSQ, VectorQuantize
import math
import einops
from fancy_einsum import einsum
import numpy as np

### TRANSFORMER LENS ###
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

## Cache residual streams

In [59]:
# Load the model
PTH_LOCATION = "data/transformer_lens.pth"
model_dict = torch.load(PTH_LOCATION)
model = HookedTransformer(model_dict["config"])
model.load_state_dict(torch.load(PTH_LOCATION)["model"])

# Load our tensors
train_data = torch.load("data/train_data.pt")
eval_data = torch.load("data/eval_data.pt")
print(train_data.shape, eval_data.shape)

_, train_cache = model.run_with_cache(train_data)
_, eval_cache = model.run_with_cache(eval_data)

train_residual_stream, train_labels = train_cache.decompose_resid(return_labels=True)
eval_residual_stream, eval_labels = eval_cache.decompose_resid(return_labels=True)

train_residual_stream.shape # (n_layers, n_examples, seq_len, d_model) = torch.Size([4, 204134, 9, 32])

# We want each of the indices in the seq_len to be a separate example
train_residual_stream = einops.rearrange(train_residual_stream, "layers examples seq_len d_model -> layers (examples seq_len) d_model")
eval_residual_stream = einops.rearrange(eval_residual_stream, "layers examples seq_len d_model -> layers (examples seq_len) d_model")
print(train_residual_stream.shape) # torch.Size([4, 1837206, 32])

# Save residual streams
# torch.save(train_residual_stream, "data/train_residual_stream.pt")
# torch.save(eval_residual_stream, "data/eval_residual_stream.pt")

torch.Size([204134, 9]) torch.Size([51034, 9])
torch.Size([4, 1837206, 32])


In [3]:
def get_original_data_from_index(i, seq_len=9):
    # Calculate the original example index and position within the sequence
    original_index = i // seq_len
    position = i % seq_len
    
    # Assuming train_data is loaded and available
    # Note: train_data's assumed shape here is (n_examples, seq_len, d_model)
    original_data = train_data[original_index]
    
    # Slicing to get the specific position data
    # Note: Adjust this slicing based on how the original data is structured and what's needed
    specific_data = original_data[:position+1]  # Adjust if you need a different slice
    
    return specific_data

# Example usage
i = 1000000  # Example index in the train_residual_stream
specific_data = get_original_data_from_index(i)
print(specific_data)  # This will depend on the exact slicing and data shape

tensor([8, 0])


In [39]:
def get_original_data_from_indices(indices, seq_len=9):
    # Initialize a list to hold the data slices corresponding to each index
    data_slices = []
    
    for i in indices:
        # Calculate the original example index and position within the sequence for each index
        original_index = i // seq_len
        position = i % seq_len
        
        # Assuming train_data is loaded and available
        # Extract the specific slice for the current index
        original_data = train_data[original_index]
        
        # Adjust the slicing based on how the original data is structured and what's needed
        specific_data = original_data[:position+1]  # Adjust if you need a different slice

        # Specific data is always a one-dimensional tensor - pad to be a one-d tensor of length 9
        specific_data = F.pad(specific_data, (0, 9 - specific_data.shape[0]), "constant", -1)
        
        # Append the extracted slice to the list
        data_slices.append(specific_data)

    # Stack all tensors in the list to get a single tensor
    data_slices = torch.stack(data_slices)

    return data_slices

# Example usage with a list of different indices
indices = [3, 4, 5]  # Example indices in the train_residual_stream
data_slices = get_original_data_from_indices(indices)
print(data_slices)

tensor([[ 8,  1,  6,  7, -1, -1, -1, -1, -1],
        [ 8,  1,  6,  7,  0, -1, -1, -1, -1],
        [ 8,  1,  6,  7,  0,  5, -1, -1, -1]])


## VQ-VAE

In [4]:
class TransformerAutoencoder(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, 
                 codebook_size=128, codebook_dim= 16, threshold_ema_dead_code=2, dropout=0.1):
        super(TransformerAutoencoder, self).__init__()
        self.input_dim = input_dim
        self.d_model = d_model

        # VQ Quantizer
        dim = 256
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.quantizer = VectorQuantize(
            dim=dim,  # Assuming the dimensionality to match d_model for simplicity
            codebook_size=codebook_size,  # Example codebook size
            codebook_dim=codebook_dim,  # This is an illustrative example, adjust based on your model's needs
            decay=0.8,
            commitment_weight=1.0,
            use_cosine_sim=True,  # Example, adjust as needed
            threshold_ema_dead_code = threshold_ema_dead_code
        )
        self.bottleneck_dim = dim

        # Positional Encoding
        self.positional_encoder = PositionalEncoding(d_model, dropout)

        # Encoder Layer
        encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_encoder_layers)

        # Linear projection down - replaces torch.zeros with nn.Linear
        self.encoder_output_projection = nn.Linear(d_model, self.bottleneck_dim)

        # Linear projection up - replaces torch.zeros with nn.Linear
        self.decoder_input_projection = nn.Linear(self.bottleneck_dim, d_model)

        # Decoder Layer
        decoder_layers = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_decoder_layers)

        self.encoder_input_projection = nn.Linear(input_dim, d_model)
        self.decoder_output_projection = nn.Linear(d_model, input_dim)

    def forward(self, src):
        # Encode
        src = self.encoder_input_projection(src)
        src = self.positional_encoder(src)
        memory = self.transformer_encoder(src)
        
        # Apply the encoder output projection down
        memory = F.relu(self.encoder_output_projection(memory))

        # Vector quantize the memory
        quantized_memory, _, commit_loss = self.quantize(memory)
        #print(f"Quantised Memory shape = {quantized_memory.shape}")

        # Decode
        quantized_memory = F.relu(self.decoder_input_projection(quantized_memory))
        output = self.transformer_decoder(quantized_memory, quantized_memory)
        output = self.decoder_output_projection(output)
        return output, commit_loss

    def quantize(self, bottleneck):
        quantized, indices, commit_loss = self.quantizer(bottleneck)
        return quantized, indices, commit_loss

    def quantized_indices(self, src):
        # Encode
        src = self.encoder_input_projection(src)
        src = self.positional_encoder(src)
        memory = self.transformer_encoder(src)
        
        # Apply the encoder output projection down
        memory = F.relu(self.encoder_output_projection(memory))

        # Vector quantize the memory
        quantised, indices, _ = self.quantize(memory)

        return quantised, indices
    
    def indices_to_rep(self, indices):
        with torch.no_grad():
            low_dim_vectors = self.quantizer.get_codes_from_indices([indices])
            quantized_memory = self.quantizer.project_out(low_dim_vectors)
            # Decode
            quantized_memory = F.relu(self.decoder_input_projection(quantized_memory))
            output = self.transformer_decoder(quantized_memory, quantized_memory)
            output = self.decoder_output_projection(output)
            return output

        
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [5]:
train_streams = torch.load("data/train_residual_stream.pt")
eval_streams = torch.load("data/eval_residual_stream.pt")

# Keep 5% of both
# train_streams = train_streams[:, :int(train_streams.shape[1]*0.05), :]
# eval_streams = eval_streams[:, :int(eval_streams.shape[1]*0.05), :]

print(train_streams.shape, eval_streams.shape)

torch.Size([4, 1837206, 32]) torch.Size([4, 459306, 32])


In [6]:
# Combine the train and eval streams
combined_streams = torch.cat((train_streams, eval_streams), 1)
print(combined_streams.shape)

torch.Size([4, 2296512, 32])


In [None]:
# from datasets import Dataset

# # Push to HuggingFace hub
# dataset = Dataset.from_dict({"data": combined_streams})
# dataset.push_to_hub("ttt_resid_streams")

In [8]:
# Example Configuration
input_dim = train_streams.shape[-1]  # Size of the input
sequence_length = train_streams.shape[0]  # Length of the sequence
d_model = 64  # The number of expected features in the encoder/decoder inputs
nhead = 4  # The number of heads in the multiheadattention models
num_encoder_layers = 1  # The number of sub-encoder-layers in the encoder
num_decoder_layers = 1  # The number of sub-decoder-layers in the decoder
dim_feedforward = 128  # The dimension of the feedforward network model
dropout = 0.1  # The dropout value
codebook_size=32
codebook_dim=8

In [9]:
model = TransformerAutoencoder(input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, 
                               codebook_size=codebook_size, dropout=dropout)

In [10]:
x = train_streams[:, 0, :].unsqueeze(1)
x.shape

torch.Size([4, 1, 32])

In [11]:
from torch.utils.data import DataLoader, Dataset

class ResidDataset(Dataset):
    """Dataset wrapping tensors where the second dimension is treated as the batch dimension."""
    def __init__(self, data_tensor):
        # Expect data_tensor to be of shape [C, N, H] where N is the batch dimension.
        self.data_tensor = data_tensor

    def __getitem__(self, index):
        # This method should return a single sample.
        return self.data_tensor[:, index, :]

    def __len__(self):
        return self.data_tensor.size(1)

# Wrap train_streams in a TensorDataset and DataLoader
batch_size = 8192
train_dataset = ResidDataset(train_streams)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(train_dataset[0].shape)
# Print the first batch
for i, batch in enumerate(train_loader):
    print(batch.shape)
    break

torch.Size([4, 32])
torch.Size([8192, 4, 32])


In [12]:
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adjust the learning rate as needed
loss_function = nn.MSELoss()
epochs = 10

In [13]:
from tqdm import tqdm

def train(model, epochs, optimizer, train_streams, eval_streams, verbose=True, print_epoch=None, batch_size=2048):
    # Wrap train_streams in a TensorDataset and DataLoader
    train_dataset = ResidDataset(train_streams)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    train_losses, eval_losses, unique_indices_utilised = [], [], []

    for epoch in range(epochs):
        model.train()  # Set the model to training mode
        epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Batches for epoch {epoch}"):
            optimizer.zero_grad()  # Zero the gradients
            batch_streams = einops.rearrange(batch, "batch seq_len d_model -> seq_len batch d_model")
            outputs, commit_loss = model(batch_streams)  # Forward pass: compute the model output
            loss = loss_function(outputs, batch_streams)  # Compute the loss
            loss += commit_loss[0]
            
            loss.backward()  # Backward pass: compute gradient of the loss with respect to model parameters
            optimizer.step()  # Perform a single optimization step (parameter update)
            
            epoch_loss += loss.item() * batch_streams.size(0)  # Accumulate the loss, scaled by batch size

        epoch_loss /= len(train_loader.dataset)  # Normalize loss by the number of samples

        divider = epochs // 10 if print_epoch is None else print_epoch

        if epoch % divider == 0:
            # Calculate eval loss
            model.eval()  # Set the model to evaluation mode
            eval_loss = 0
            with torch.no_grad():  # No need to track the gradients
                eval_outputs, eval_commit_loss = model(eval_streams)  # Forward pass: compute the model output
                eval_loss = loss_function(eval_outputs, eval_streams)  # Compute the loss
                eval_loss += eval_commit_loss[0]
            # Calculate unique codes on train
            _, train_indices = model.quantized_indices(train_streams.cpu())
            unique_indices = torch.unique(train_indices)
            unique_indices = len(unique_indices)
            unique_indices_utilised.append(unique_indices)

            if verbose:
                print(f"Epoch [{epoch}/{epochs}], Train Loss: {epoch_loss:.3f}, Eval Loss: {eval_loss:.3f}, Unique Indices: {unique_indices}")
            train_losses.append(epoch_loss)
            eval_losses.append(eval_loss)

    return model, train_losses, eval_losses, unique_indices_utilised

In [14]:
#model, train_losses, eval_losses, unique_indices_utilised = train(model, epochs, optimizer, train_streams, eval_streams)

## Load the model and have a look

In [15]:
train_streams = torch.load("data/train_residual_stream.pt")
eval_streams = torch.load("data/eval_residual_stream.pt")
resid_streams = torch.cat((train_streams, eval_streams), 1)
resid_streams.shape

torch.Size([4, 2296512, 32])

In [84]:
# Example Configuration
input_dim = resid_streams.shape[-1]  # Size of the input
sequence_length = resid_streams.shape[0]  # Length of the sequence
d_model = 64  # The number of expected features in the encoder/decoder inputs
nhead = 2  # The number of heads in the multiheadattention models
num_encoder_layers = 1  # The number of sub-encoder-layers in the encoder
num_decoder_layers = 1  # The number of sub-decoder-layers in the decoder
dim_feedforward = 128  # The dimension of the feedforward network model
dropout = 0.1  # The dropout value
codebook_size=32
codebook_dim=8

model = TransformerAutoencoder(input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, 
                               codebook_size=codebook_size, dropout=dropout)

In [85]:
# Load the model
device = 'cpu'
model.load_state_dict(torch.load("data/vq_vae.pt", map_location=torch.device(device)))

<All keys matched successfully>

In [86]:
train_streams[:, :1000, :].shape

torch.Size([4, 1000, 32])

In [87]:
# Random sample of residual streams
#random_indices = np.random.choice(resid_streams.shape[1], 1000, replace=False)

all_quantised, all_indices = model.quantized_indices(train_streams[:, :10000, :].to(device))
all_quantised = all_quantised.detach().cpu().numpy()
all_indices = all_indices.T
print(all_indices)

# Print number of unique indices overall
unique_indices = torch.unique(all_indices)
print(f"Unique indices: {len(unique_indices)}")

tensor([[29,  4, 28, 15],
        [ 6,  8, 23, 21],
        [26, 26, 23, 21],
        ...,
        [ 8,  0,  0, 14],
        [12,  0,  1,  7],
        [ 6,  4, 25, 27]])
Unique indices: 32


In [88]:
small_train = get_original_data_from_indices([x for x in range(10000)])
small_train.shape

torch.Size([10000, 9])

In [89]:
# Plotly plot the frequency of codes by column in all_indices
import plotly.express as px

fig = px.histogram(x=all_indices.flatten(), title="Frequency of codes by column in all_indices")
fig.show()

In [163]:
# Multiple histogram of the frequency of codes by column in all_indices
import plotly.graph_objects as go

labels = ['embed', 'pos_embed', '0_attn_out', '1_attn_out']
fig = go.Figure()
for i in range(all_indices.shape[1]):
    fig.add_trace(go.Histogram(x=all_indices[:, i], name=labels[i]))
fig.update_layout(title="Frequency of codes by column in all_indices")
fig.show()

In [164]:
# Some autointerpretability
code = 29
layer = 2
cl_indices = np.where(all_indices[:, layer] == code)[0]
cl_indices_neg = np.where(all_indices[:, layer] != code)[0]
print(cl_indices.shape, cl_indices_neg.shape)

(1067,) (8933,)


In [165]:
# Using indices, get all examples of small train data where this occurs
positive_examples = small_train[cl_indices]
negative_examples = small_train[cl_indices_neg]
positive_indices = all_indices[cl_indices]
negative_indices = all_indices[cl_indices_neg]
# Combine positive and negative indices
all_indices_cl = np.concatenate((positive_indices, negative_indices), axis=0)
all_examples = np.concatenate((positive_examples, negative_examples), axis=0)

In [166]:
print(positive_examples.shape, negative_examples.shape)
print(positive_indices.shape, negative_indices.shape)
print(all_indices_cl.shape, all_examples.shape)

torch.Size([1067, 9]) torch.Size([8933, 9])
torch.Size([1067, 4]) torch.Size([8933, 4])
(10000, 4) (10000, 9)


In [167]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import OneHotEncoder


labels = np.ones(positive_indices.shape[0])
labels = np.concatenate((labels, np.zeros(negative_indices.shape[0])))

# Initialize the OneHotEncoder
encoder = OneHotEncoder()

# Fit the encoder and transform `all_indices` to a one-hot encoded matrix
all_indices_encoded = encoder.fit_transform(all_examples)

# Feature names for interpretation
length = all_examples.shape[1]
original_feature_names = ["Feature_{}".format(i) for i in range(length)]

# Mapping each one-hot encoded column back to the original feature and category
feature_mapping = []
for i, categories in enumerate(encoder.categories_):
    feature_mapping.extend([(original_feature_names[i], category) for category in categories])

# Train test split on the one-hot encoded data
X_train, X_test, y_train, y_test = train_test_split(all_indices_encoded, labels, test_size=0.2, random_state=49)

# Train the decision tree model on the one-hot encoded data
tree = DecisionTreeClassifier()
tree.fit(X_train, y_train)

# Get the predictions
predictions = tree.predict(X_test)
accuracy = (predictions == y_test).mean()
print(f"Accuracy: {accuracy:.3f}")

# Print max depth of the tree
print(f"Max Depth: {tree.get_depth()}")

Accuracy: 0.932
Max Depth: 21


In [168]:
import pandas as pd

def plot_feature_importances(tree, feature_mapping, top_n = 5, save_title = None):
    # Plot feature importances
    importances = tree.feature_importances_

    # Get the feature names
    feature_names = [f"Position{feature[0].replace("Feature_", "")}_Code{feature[1]}" for feature in feature_mapping]

    # Create a DataFrame with the feature importances
    df = pd.DataFrame({"Feature": feature_names, "Importance": importances})
    df = df.sort_values("Importance", ascending=False)
    # Keep only the top n=10 features
    df = df.head(top_n)

    # Plot the feature importances
    fig = px.bar(df, x="Feature", y="Importance")
    fig.show()

plot_feature_importances(tree, top_n=15, feature_mapping=feature_mapping)

This is awesome! It seems like we've found a code that determines whether there's a token at position 7 on the board or not. Interestingly, it looks like the best predictors of code 29 in layer 2 (`0_attn_out`) are both the presence of 7 and its absence (i.e. -1 in lots of positions). Let's package this all up into a single function that takes in a layer and a code and returns the most important tic-tac-toe "tokens" for that code.

In [169]:
def dt_for_code_to_board(positive_indices, negative_indices, all_examples):
    labels = np.ones(positive_indices.shape[0])
    labels = np.concatenate((labels, np.zeros(negative_indices.shape[0])))

    # Initialize the OneHotEncoder
    encoder = OneHotEncoder()

    # Fit the encoder and transform `all_indices` to a one-hot encoded matrix
    all_indices_encoded = encoder.fit_transform(all_examples)

    # Feature names for interpretation
    length = all_examples.shape[1]
    original_feature_names = ["Feature_{}".format(i) for i in range(length)]

    # Mapping each one-hot encoded column back to the original feature and category
    feature_mapping = []
    for i, categories in enumerate(encoder.categories_):
        feature_mapping.extend([(original_feature_names[i], category) for category in categories])

    # Train test split on the one-hot encoded data
    X_train, X_test, y_train, y_test = train_test_split(all_indices_encoded, labels, test_size=0.2, random_state=49)

    # Train the decision tree model on the one-hot encoded data
    tree = DecisionTreeClassifier()
    tree.fit(X_train, y_train)

    # Get the predictions
    predictions = tree.predict(X_test)
    accuracy = (predictions == y_test).mean()
    return tree, accuracy, feature_mapping

def code_layer_to_board_importance(code: int, layer: int, all_indices, small_train, verbose=True):
    # Some autointerpretability
    cl_indices = np.where(all_indices[:, layer] == code)[0]
    cl_indices_neg = np.where(all_indices[:, layer] != code)[0]
    
    # Get the examples
    # Using indices, get all examples of small train data where this occurs
    positive_examples = small_train[cl_indices]
    negative_examples = small_train[cl_indices_neg]
    positive_indices = all_indices[cl_indices]
    negative_indices = all_indices[cl_indices_neg]
    # Combine positive and negative indices
    all_indices_cl = np.concatenate((positive_indices, negative_indices), axis=0)
    all_examples = np.concatenate((positive_examples, negative_examples), axis=0)

    # Train the decision tree model
    tree, accuracy, feature_mapping = dt_for_code_to_board(positive_indices, negative_indices, all_examples)
    if verbose:
        print(f"Accuracy: {accuracy:.3f}")
        print(f"Max Depth: {tree.get_depth()}")

    # Return feature importances
    importances = tree.feature_importances_

    # Get the feature names
    feature_names = [f"Position{feature[0].replace("Feature_", "")}_Code{feature[1]}" for feature in feature_mapping]

    # Create a DataFrame with the feature importances
    df = pd.DataFrame({"Feature": feature_names, "Importance": importances})
    df = df.sort_values("Importance", ascending=False)
    # Keep only non-zero importances
    df = df[df["Importance"] > 0]
    df.reset_index(inplace=True, drop=True)

    return df, accuracy

df, accuracy = code_layer_to_board_importance(29, 2, all_indices, small_train)
print(df.head(10))

Accuracy: 0.936
Max Depth: 21
            Feature  Importance
0   Position3_Code7    0.069498
1  Position3_Code-1    0.068297
2   Position4_Code7    0.061443
3   Position2_Code7    0.047282
4   Position1_Code7    0.046994
5   Position2_Code5    0.040697
6  Position6_Code-1    0.038264
7  Position2_Code-1    0.037090
8   Position0_Code1    0.033667
9   Position1_Code1    0.028997


In [170]:
df.head(10)['Feature'].values

array(['Position3_Code7', 'Position3_Code-1', 'Position4_Code7',
       'Position2_Code7', 'Position1_Code7', 'Position2_Code5',
       'Position6_Code-1', 'Position2_Code-1', 'Position0_Code1',
       'Position1_Code1'], dtype=object)

Now, we want to pass this information to GPT-4 to see if it can interpret the actual features found, and find some similarity. (In the future, we should have a hold out test set of positive and negative examples, and see how well GPT-4 can predict positive/negative presence of code based on the board state. Also, we should compare this with sparse autoencoder. Also compare accuracy to just decision tree - want to show this quantitative approach is much better.) For now, we'll just get GPT-4 to write a quick description of the general "feature" the code is representing.

In [171]:
import yaml
from openai import AzureOpenAI

config = yaml.safe_load(open("config.yaml"))

client = AzureOpenAI(
    azure_endpoint=config["base_url"],
    api_key=config["azure_api_key"],
    api_version=config["api_version"],
)


def get_response(client: AzureOpenAI, messages: list, **kwargs):
    """
    Generic function to get a content response from the OpenAI API
    """

    config = yaml.safe_load(open("config.yaml", "r")) # need to refigure this

    response = client.chat.completions.create(
        model=config["deployment_name"], messages=messages, **kwargs
    )
    return response.choices[0].message.content

In [181]:
def format_prompt(feature_df: pd.DataFrame):

    PROMPT = f"""
    ### Task Description

    The objective is to identify patterns within a subset of tic-tac-toe board states using a data-driven approach. Each tic-tac-toe board state is represented through a series of features that describe the sequence of moves made during the game. These features are labeled as "Position<X>_Code<Y>", where:

    - X represents the sequence of the move within the game, ranging from 0 (the first move) to 8 (the potential final move).
    - Y indicates a specific position on the tic-tac-toe board, with values ranging from 0 to 8 that correspond to positions from the top left (0) to the bottom right (8) of the board. A value of -1 for Y signifies that the game has not yet reached the corresponding number of moves (given by X). A value of 9 for Y signifies that the game has been won prior to that move (given by X).

    The importance of each feature is quantified, allowing us to deduce the most significant patterns in the progression of the game states. 

    ### Objective

    Your task is to analyze the provided feature importance rankings to deduce the most likely significant pattern or strategy indicated by the data. This involves understanding which board positions are prioritized and at which point in the game they become crucial.

    ### Example

    Consider the following set of feature importance data for a subset of tic-tac-toe games:

    ```
    Feature            Importance
    Position3_Code7    0.070284
    Position3_Code-1   0.067388
    Position4_Code7    0.061443
    Position2_Code7    0.048451
    Position1_Code7    0.046994
    ```

    **Analysis and Conclusion:**

    Based on the data, the most significant pattern observed is that "A token is placed in position 7 by the 5th move." This pattern is deduced from the recurring importance of position 7 in the rankings, indicating its critical role in the progression or outcome of the game within this subset. 
    This pattern is one-sentence, specific and to the point.

    ### Task for You

    Given a new set of feature importance rankings for a different subset of tic-tac-toe games:\n
    
    {feature_df.to_markdown()}\n
    
    Analyze the data to identify the most likely significant pattern. Only return one sentence about the pattern that is specific and to the point. Do not qualify, just make a Bayesian guess and then assume it to be true.
    """

    return PROMPT

prompt = format_prompt(df.head(10))
messages = [{"role": "user", "content": prompt}]
response = get_response(client, messages)

In [173]:
response

'"Most games in this subset see a token placed in position 7 by the 4th move, and also frequently in position 5 by the 3rd move."'

This seems to be working. Now, let's get all the most commonly occurring codes in the layer they occur, and collect a dataset of interpretable codes.

In [182]:
# Collect the histogram data as a df
hist_data = []
for i in range(all_indices.shape[1]):
    hist, bins = np.histogram(all_indices[:, i], bins=32)
    hist_data.append(hist)

hist_df = pd.DataFrame(hist_data).T
# Set column names
column_labels = ['embed', 'pos_embed', '0_attn_out', '1_attn_out']
hist_df.columns = column_labels
hist_df

Unnamed: 0,embed,pos_embed,0_attn_out,1_attn_out
0,1041,3337,2987,156
1,0,0,804,107
2,0,0,0,332
3,0,0,0,369
4,0,2184,330,319
5,0,0,0,313
6,2035,0,664,395
7,0,1,27,978
8,1004,0,615,305
9,0,1139,0,247


Let's just focus on the first attention layer for now, because it has a lot of codes that the embed and unembed layers don't have, and is much less uniform than the last layer.

In [183]:
# Get the row index of the 10 most frequent codes for the '0_attn_out' column
top_10_codes = hist_df['0_attn_out'].nlargest(10).index

code_interp = []

for code in tqdm(top_10_codes):
    code_df, accuracy = code_layer_to_board_importance(code, 2, all_indices, small_train, verbose=False)
    prompt = format_prompt(code_df.head(5))
    if code_df.shape[0] > 0:
        auto_interpretation = get_response(client, [{"role": "user", "content": prompt}])
        # Create dictionary
        code_interp.append({"code": code, "accuracy": accuracy, "df": code_df, "interpretation": auto_interpretation})

100%|██████████| 10/10 [00:16<00:00,  1.67s/it]


In [184]:
for interp in code_interp:
    print('-'*100)
    print(f"Code: {interp['code']}, Accuracy: {interp['accuracy']}")
    print(interp['df'].head(5))
    print(interp['interpretation'])

----------------------------------------------------------------------------------------------------
Code: 0, Accuracy: 0.9
            Feature  Importance
0  Position4_Code-1    0.306978
1  Position6_Code-1    0.099413
2   Position8_Code9    0.075209
3   Position7_Code9    0.070067
4  Position7_Code-1    0.027932
"The game often ends (or is decided) by the 9th move, preferably placing a token in the 7th or 8th position prematurely or the game doesn't progress beyond the 4th or 6th move."
----------------------------------------------------------------------------------------------------
Code: 30, Accuracy: 0.9875
            Feature  Importance
0  Position2_Code-1    0.277346
1  Position4_Code-1    0.070525
2   Position1_Code8    0.067394
3   Position1_Code2    0.059119
4   Position0_Code1    0.057803
"By the third move, a token is either not yet placed or placed in position 1 or position 2, indicating a strategy that prioritizes early control of the center and bottom positions on the

Plot some of these:

In [189]:
from utils import plot_tic_tac_toe_on_tokens, plot_tic_tac_toe_boards

# Some autointerpretability
code = 29
layer = 2
cl_indices = np.where(all_indices[:, layer] == code)[0]
cl_indices_neg = np.where(all_indices[:, layer] != code)[0]

# Using indices, get all examples of small train data where this occurs
positive_examples = small_train[cl_indices]

x = positive_examples[0]

# Remove -1s
x = x[x != -1]
x

tensor([8, 1, 6, 7])

In [200]:
# Get a random index from range length of positive examples
i = np.random.choice(range(len(positive_examples)))
x = positive_examples[i]
x = x[x != -1]
plot_tic_tac_toe_on_tokens(x)

In [211]:
# Get 9 random examples
random_indices = np.random.choice(range(len(positive_examples)), 9, replace=False)
random_examples = positive_examples[random_indices]
random_examples = [x[x != -1] for x in random_examples]
plot_tic_tac_toe_boards(random_examples)