In [1]:
import os
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
import wandb
import plotly.express as px
from pathlib import Path
import itertools
import random
from IPython.display import display
import wandb
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd
import torch

import circuits.eval_sae_as_classifier as eval_sae
import circuits.analysis as analysis
import circuits.eval_board_reconstruction as eval_board_reconstruction
import circuits.get_eval_results as get_eval_results
import circuits.f1_analysis as f1_analysis
import circuits.utils as utils
import circuits.pipeline_config as pipeline_config

from huggingface_hub import hf_hub_download
import chess_utils

import pickle
# Load the pickled `meta` object from `meta.pkl` in the current working directory.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
# NOTE(review): `meta` is not referenced in the cells visible below — confirm it is still needed.
with open('meta.pkl', 'rb') as picklefile:
    meta = pickle.load(picklefile)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the first GPU, but fall back to the CPU so that later `to_device` calls
# do not hard-fail on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): absolute, machine-specific path — point this at your local checkout.
autoencoder_group_path = "/root/chessgpt_git/chessgpt_git/SAE_BoardGameEval/autoencoders/testing_chess/"
# Derived from the group path so the two locations cannot drift apart
# (evaluates to ".../testing_chess/trainer4/", same as the previous literal).
autoencoder_path = os.path.join(autoencoder_group_path, "trainer4/")

othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)
config = pipeline_config.Config()

# These both significantly reduce peak GPU memory usage
config.batch_size = 5
config.analysis_on_cpu = True

# Precompute will create both datasets and save them as pickle files
# If precompute == False, it creates the dataset on the fly
# This is far slower when evaluating multiple SAEs, but for an exploratory run it is fine
config.precompute = False

# Number of inputs used by each pipeline stage
config.eval_results_n_inputs = 1000
config.eval_sae_n_inputs = 1000
config.board_reconstruction_n_inputs = 1000

# Once you have run the analysis, you can set these to False and it will load the saved results
config.run_analysis = False
config.run_board_reconstruction = False
config.run_eval_sae = False
config.run_eval_results = False

# If you want to save the results of the analysis
config.save_results = True
config.save_feature_labels = True

print(f"Is Othello: {othello}")

Is Othello: False


In [3]:
# Load the saved feature labels for the SAE at `autoencoder_path`.
# Use the first of the recommended indexing functions for this game type.
indexing_functions = eval_sae.get_recommended_indexing_functions(othello)
indexing_function = indexing_functions[0]

# Location of the aggregation results pickle for this SAE / input count / indexing function.
expected_aggregation_output_location = eval_sae.get_output_location(
    autoencoder_path,
    n_inputs=config.eval_sae_n_inputs,
    indexing_function=indexing_function,
)

# NOTE(review): `config.analysis_on_cpu` is set to True in the config cell, yet the analysis
# device here is the main `device` — confirm which one is intended.
analysis_device = device

torch.cuda.empty_cache()

# The feature labels file name is obtained by swapping the file-name suffix of the results path.
expected_feature_labels_output_location = expected_aggregation_output_location.replace(
    "results.pkl", "feature_labels.pkl"
)

# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(expected_feature_labels_output_location, "rb") as f:
    feature_labels = pickle.load(f)
# `utils` here is `circuits.utils` (it shadows the earlier `transformer_lens.utils` import).
feature_labels = utils.to_device(feature_labels, analysis_device)

In [4]:
def rc_to_square_notation(row, col):
    """Convert zero-based (row, col) board indices to a square name such as "F3".

    The column selects the file letter from "ABCDEFGH" and the row is shifted by
    one to give the rank number, so (0, 0) -> "A1" and (7, 7) -> "H8".
    """
    rank = row + 1
    file_letter = "ABCDEFGH"[col]
    return "{}{}".format(file_letter, rank)

def plot_board(board_RR: torch.Tensor, title: str = "Board", png_filename: Optional[str] = None):
    """
    Plot an 8x8 board as a grayscale heatmap and annotate its maximum square.

    The largest value is written in red on its square, formatted as a whole-number
    percentage (e.g. 0.87 is shown as "87%").

    Args:
        board_RR (torch.Tensor): A 2D tensor of shape (8, 8). The color scale is fixed to
            [0, 1]. Row 0 is drawn at the bottom (rank 1) and column 0 on the left (file A).
            The tensor may live on any device; it is copied to the CPU for plotting.
        title (str): Title of the plot.
        png_filename (Optional[str]): If given, the figure is also saved to this path
            before it is shown.

    Raises:
        AssertionError: If `board_RR` is not of shape (8, 8).
    """
    # Imported here because the notebook's import cell does not bring `plt` into scope.
    import matplotlib.pyplot as plt

    assert board_RR.shape == (8, 8), "board_RR must be of shape 8x8"

    # Matplotlib cannot read GPU memory (or tensors that require grad), so take a detached
    # CPU copy first, then flip vertically so that row 0 ends up at the bottom of the image.
    board_RR = torch.flip(board_RR.detach().cpu(), [0])

    plt.imshow(board_RR.numpy(), cmap='gray_r', interpolation='none', vmin=0, vmax=1)
    plt.colorbar()  # Adds a colorbar to help identify the values
    plt.title(title)

    # Set labels for columns (A-H)
    plt.xticks(range(8), ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])

    # Set labels for rows: image row 0 (top) is rank 8, image row 7 (bottom) is rank 1
    plt.yticks(range(8), range(8, 0, -1))

    # Draw the square borders on minor ticks, offset by half a cell in x and y.
    # NOTE(review): the y offset is 0.51 rather than 0.5 — kept from the original;
    # presumably a nudge so the outermost line is not clipped. Confirm.
    plt.gca().set_xticks([x - 0.5 for x in range(1, 9)], minor=True)
    plt.gca().set_yticks([y - 0.51 for y in range(1, 9)], minor=True)
    plt.grid(True, which='minor', color='black', linewidth=1, linestyle='-', alpha=0.5)

    # Locate the maximum on the flipped board so the coordinates match what is displayed.
    max_value = torch.max(board_RR).item()
    max_i, max_j = divmod(int(torch.argmax(board_RR)), 8)

    # Display the maximum value in red text at the corresponding position
    plt.text(max_j, max_i, f"{max_value:.0%}", color='red', ha='center', va='center', fontsize=12)

    if png_filename is not None:
        plt.savefig(png_filename)

    plt.show()

# Maps a piece-class index (0-12) to its display name; the position in the list is the index.
# Black pieces come first, "Blank" sits in the middle at 6, white pieces mirror the black ones.
num_to_class = dict(enumerate([
    "Black King",
    "Black Queen",
    "Black Rook",
    "Black Bishop",
    "Black Knight",
    "Black Pawn",
    "Blank",
    "White Pawn",
    "White Knight",
    "White Bishop",
    "White Rook",
    "White Queen",
    "White King",
]))

In [5]:
function_of_interest = "board_to_piece_masked_blank_and_initial_state"

board_state_feature_labels_TFRRC = feature_labels[function_of_interest]
print(f"Board state feature labels: {board_state_feature_labels_TFRRC.shape}")
# Index into the leading axis of the labels tensor
# (presumably the list of thresholds, not a threshold value — TODO confirm).
threshold = 2

board_state_feature_labels_FRRC = board_state_feature_labels_TFRRC[threshold]
# Per-feature count: sum the labels over both board axes and the class axis.
board_state_counts_F = einops.reduce(board_state_feature_labels_FRRC, "F R1 R2 C -> F", "sum")

# Only scan the first `max_features` features to keep the printout short.
max_features = 175
demo_idx = 0
# Slicing clamps to the real number of features (no IndexError when there are fewer than
# `max_features`), and `.tolist()` copies the counts to the host once instead of per element.
for i, n_classified in enumerate(board_state_counts_F[:max_features].tolist()):
    if n_classified > 0:
        print(f"Feature {i} has {n_classified} classified squares")
        # The last feature with at least one classified square becomes the demo feature.
        demo_idx = i

demo_feature_labels_RRC = board_state_feature_labels_FRRC[demo_idx]
print(f"\nFeature {demo_idx} has {board_state_counts_F[demo_idx].item()} classified squares")

# Index tensors (row, column, class) of every entry labelled 1 for the demo feature.
classified_squares = torch.where(demo_feature_labels_RRC == 1)
print(f"Classified squares as tensors: {classified_squares}")

row, column, classes = classified_squares

print(f"\nClassified squares for feature {demo_idx} at threshold {threshold}:")
for square_row, square_col, piece_class in zip(row.tolist(), column.tolist(), classes.tolist()):
    print(rc_to_square_notation(square_row, square_col), num_to_class[piece_class])

Board state feature labels: torch.Size([11, 3914, 8, 8, 13])
Feature 1 has 1 classified squares
Feature 44 has 1 classified squares
Feature 52 has 2 classified squares
Feature 59 has 1 classified squares
Feature 66 has 1 classified squares
Feature 70 has 1 classified squares
Feature 75 has 2 classified squares
Feature 107 has 1 classified squares
Feature 109 has 1 classified squares
Feature 121 has 1 classified squares
Feature 127 has 1 classified squares
Feature 134 has 1 classified squares
Feature 143 has 1 classified squares
Feature 146 has 4 classified squares
Feature 150 has 1 classified squares
Feature 152 has 1 classified squares
Feature 165 has 2 classified squares
Feature 172 has 1 classified squares

Feature 172 has 1 classified squares
Classified squares as tensors: (tensor([2], device='cuda:0'), tensor([5], device='cuda:0'), tensor([8], device='cuda:0'))

Classified squares for feature 172 at threshold 2:
F3 White Knight


In [6]:
# Resolve the aggregation results pickle for this SAE, input count and indexing function.
expected_aggregation_output_location = eval_sae.get_output_location(
    autoencoder_path, n_inputs=config.eval_sae_n_inputs, indexing_function=indexing_function
)

# Read the pickle and move its contents onto `device` in one step.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(expected_aggregation_output_location, "rb") as f:
    aggregation_results = utils.to_device(pickle.load(f), device)

In [7]:
custom_functions = config.chess_functions

# Add the "off" tracker to the aggregation results, then normalize the "on" tracker
# followed by the "off" tracker, threading the result through each call.
formatted_results = analysis.add_off_tracker(aggregation_results, custom_functions, analysis_device)

for tracker_type in ("on", "off"):
    formatted_results = analysis.normalize_tracker(
        formatted_results, tracker_type, custom_functions, analysis_device
    )


In [8]:
# NOTE(review): this cell repeats the add_off_tracker / normalize_tracker sequence of the
# cell directly above with the same arguments and looks redundant. It is left unchanged
# because whether re-running these helpers on `aggregation_results` is idempotent cannot be
# determined from this notebook — confirm, then delete one of the two cells.
formatted_results = analysis.add_off_tracker(aggregation_results, custom_functions, analysis_device)

# Normalize the "on" tracker.
formatted_results = analysis.normalize_tracker(
    formatted_results,
    "on",
    custom_functions,
    analysis_device,
)

# Normalize the "off" tracker.
formatted_results = analysis.normalize_tracker(
    formatted_results,
    "off",
    custom_functions,
    analysis_device,
)


In [23]:
# Select the features whose normalized "on" value for the check-state function exceeds 0.999
# at index 2 of the leading axis; keep only the index along the feature axis.
check_state_on_normalized = formatted_results['board_to_check_state']['on_normalized']
chess_latent_indices = torch.where(check_state_on_normalized[2] > 0.999)[0]
chess_latent_indices

tensor([ 156,  215,  390,  583,  670,  709,  727, 1155, 1327, 1536, 1596, 1730,
        1777, 1797, 1881, 1957, 2154, 2201, 2276, 2283, 2424, 2485, 2674, 2835,
        3003, 3037, 3109, 3368, 3649, 3737], device='cuda:0')

In [17]:
# Inspect the shape of the normalized "on" tracker for the check-state function
# (the recorded output is torch.Size([11, 3914, 1, 1, 1])).
formatted_results['board_to_check_state']['on_normalized'].shape

torch.Size([11, 3914, 1, 1, 1])

In [22]:
# Look up the `alive_features` entries at the selected indices.
# presumably this maps positions among the alive features back to the original SAE
# latent ids — TODO confirm against circuits.eval_sae_as_classifier.
aggregation_results['alive_features'][chess_latent_indices]

tensor([ 167,  228,  410,  614,  706,  746,  765, 1218, 1396, 1610, 1673, 1810,
        1858, 1878, 1969, 2048, 2252, 2303, 2381, 2388, 2533, 2595, 2794, 2960,
        3133, 3168, 3240, 3517, 3814, 3911], device='cuda:0')

In [27]:
# List the trackers stored for the check-state function. The recorded output also contains
# 'off', 'on_normalized' and 'off_normalized', which suggests the analysis helpers above
# write into `aggregation_results` in place — TODO confirm.
aggregation_results['board_to_check_state'].keys()

dict_keys(['on', 'all', 'off', 'on_normalized', 'off_normalized'])

In [31]:
# NOTE(review): this cell does not run from a fresh kernel as written:
#   - `data` is passed as an argument before it is assigned anywhere in the visible cells;
#   - `n_inputs` is not defined in the visible cells (the config cell only sets
#     `config.eval_sae_n_inputs` and similar attributes);
#   - the recorded error mentions `batch_size`, which no longer appears here, so the saved
#     output is stale relative to the code.
# NOTE(review): `circuits.eval_sae_as_classifier` is already imported as `eval_sae` in the
# first cell; this mid-notebook import could be replaced by `eval_sae.prep_data_ae_buffer_and_model`.
from circuits.eval_sae_as_classifier import prep_data_ae_buffer_and_model
data, ae_bundle, pgn_strings_bL, encoded_inputs_bL = prep_data_ae_buffer_and_model(
        autoencoder_path,
        32,
        data,
        device,
        n_inputs,
    )

NameError: name 'batch_size' is not defined