In [16]:
import os
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
import wandb
import plotly.express as px
from pathlib import Path
import itertools
import random
from IPython.display import display
import wandb
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import Union, Optional, Tuple, Callable, Dict
from collections import Counter
import typeguard
from functools import partial
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd
import torch
import matplotlib.pyplot as plt

import circuits.eval_sae_as_classifier as eval_sae
import circuits.analysis as analysis
import circuits.eval_board_reconstruction as eval_board_reconstruction
import circuits.get_eval_results as get_eval_results
import circuits.f1_analysis as f1_analysis
import circuits.utils as utils
import circuits.pipeline_config as pipeline_config
from circuits.dictionary_learning.dictionary import AutoEncoder, GatedAutoEncoder, AutoEncoderNew
import common
import chess_utils
import chess
#from plotly_utils import imshow
#from neel_plotly import scatter, line



device = t.device('cuda' if t.cuda.is_available() else 'cpu')
import pickle
with open('meta.pkl', 'rb') as picklefile:
    meta = pickle.load(picklefile)

In [134]:
autoencoder = common.load_autoencoder(device)
model = common.load_model(device)
dataset = common.get_dataset(device)
TRAIN_TEST_GAME_SPLIT = 4000



Loading statistics aggregation dataset


In [3]:
encoded_inputs_tensor = t.stack([t.tensor(x) for x in dataset['encoded_inputs']]).to(device)
is_check = dataset['board_to_check_state'].squeeze(-1)

In [4]:
aggregation_results = common.get_aggregation_results(1000)
formatted_results = common.get_formatted_results(aggregation_results)
features_for_check_state = common.get_true_feature_indices(formatted_results, "board_to_check_state")
features_for_check_state = features_for_check_state.to(device)

In [5]:

max_sae_activations = []
def get_activation(name):
    def hook(model, input, output):
        encoded_activations = autoencoder.encode(output[0]) # batch_size x len_seq x n_features_sae
        collapsed = (encoded_activations.max(dim=0).values.max(dim=0).values)
        max_sae_activations.append(collapsed.unsqueeze(0))
    return hook

t.set_grad_enabled(False)
activation_handle = model.transformer.h[5].register_forward_hook(get_activation(f"resid_stream_{5}"))
num_batches = 100
batch_size = len(encoded_inputs_tensor) // num_batches
if device == 'cpu':
    num_batches = 1
    batch_size = 1
    
for i in range(num_batches):
    model(encoded_inputs_tensor[i*batch_size:(i+1)*batch_size])
    t.cuda.empty_cache()
activation_handle.remove()

max_activations = t.concat(max_sae_activations).max(dim=0).values

In [6]:
relevant_sae_activations = []
def get_activation(name, relevant_features):
    def hook(model, input, output):
        encoded_activations = autoencoder.encode(output[0]) # batch_size x len_seq x n_features_sae
        relevant_sae_activations.append(encoded_activations[:, :, features_for_check_state])
    return hook

t.set_grad_enabled(False)
activation_handle = model.transformer.h[5].register_forward_hook(get_activation(f"resid_stream_{5}", features_for_check_state))
num_batches = 100
batch_size = len(encoded_inputs_tensor) // num_batches
if device == 'cpu':
    num_batches = 1
    batch_size = 1
for i in range(num_batches):
    model(encoded_inputs_tensor[i*batch_size:(i+1)*batch_size])
    t.cuda.empty_cache()
activation_handle.remove()

In [7]:
all_relevant_sae_activations = t.concat(relevant_sae_activations, dim=0)
is_this_a_dot = (encoded_inputs_tensor == meta['stoi']['.']).to(device)
all_relevant_sae_activations.shape, is_this_a_dot.shape


(torch.Size([5000, 256, 30]), torch.Size([5000, 256]))

In [8]:
sae_activations_for_dots = all_relevant_sae_activations * is_this_a_dot.unsqueeze(-1)
inverse_sae_activations_for_dots = (max_activations[features_for_check_state] - all_relevant_sae_activations) * is_this_a_dot.unsqueeze(-1)

In [9]:
sae_activations_for_dots.shape, max_activations.shape

(torch.Size([5000, 256, 30]), torch.Size([4096]))

In [10]:
threshold = 0.4
indices_of_high_activation = (sae_activations_for_dots > threshold*max_activations[features_for_check_state]).nonzero()
indices_of_low_activation = (inverse_sae_activations_for_dots > 0.99*max_activations[features_for_check_state]).nonzero()
indices_of_high_activation.shape, indices_of_low_activation.shape

(torch.Size([4115, 3]), torch.Size([3608581, 3]))

In [150]:
indices_of_high_activation[:, 1].unique(), features_for_check_state.unique()

(tensor([ 20,  21,  22,  28,  29,  30,  31,  32,  33,  36,  37,  38,  39,  40,
          41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,
          55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,
          69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,
          83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,
          97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
         111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124,
         125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
         139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152,
         153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166,
         167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180,
         181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194,
         195, 196, 197, 198, 199, 200, 201, 202, 203

How well do we do if we just use 'is_check' as a classifier for some random feature?

In [155]:
def is_check(board):
    return board.is_check()

i = 5
boards_to_show = 3
low_to_high_data_points = 50

def get_last_space_index(game_string):
    for i in range(len(game_string)-1, -1, -1):
        if game_string[i] == ' ':
            return i
    return -1

def get_activation_data_for_feature(i, low_to_high_data_points, use_train_data):
    if use_train_data:
        filtered_indices_of_high_activation = indices_of_high_activation[indices_of_high_activation[:, 0] < TRAIN_TEST_GAME_SPLIT]
        filtered_indices_of_low_activation = indices_of_low_activation[indices_of_low_activation[:, 0] < TRAIN_TEST_GAME_SPLIT]
    else:
        filtered_indices_of_high_activation = indices_of_high_activation[indices_of_high_activation[:, 0] >= TRAIN_TEST_GAME_SPLIT]
        filtered_indices_of_low_activation = indices_of_low_activation[indices_of_low_activation[:, 0] >= TRAIN_TEST_GAME_SPLIT]
    
    high_activation_indices_for_feature = filtered_indices_of_high_activation[filtered_indices_of_high_activation[:, 2] == i]
    low_activation_indices_for_feature = filtered_indices_of_low_activation[filtered_indices_of_low_activation[:, 2] == i]
    low_activation_indices_for_feature = filtered_indices_of_low_activation[filtered_indices_of_low_activation[:, 1] > 100]
    low_activation_indices_for_feature = low_activation_indices_for_feature[t.randint(0, len(low_activation_indices_for_feature), (low_to_high_data_points * len(high_activation_indices_for_feature), ))]
    return high_activation_indices_for_feature, low_activation_indices_for_feature


def get_counts(indices_for_feature, board_meets_criteria, use_train_data):
    positive_count = 0
    negative_count = 0
    for game_index, position_index, _ in indices_for_feature:
        if use_train_data == (game_index >= TRAIN_TEST_GAME_SPLIT):
            continue
        game_string = dataset['decoded_inputs'][game_index][:position_index]
        last_space_index = get_last_space_index(game_string)
        game_string = dataset['decoded_inputs'][game_index][:last_space_index]
        board = chess_utils.pgn_string_to_board(game_string)
        if board_meets_criteria(board):
            positive_count += 1
        else:
            negative_count += 1
    return positive_count, negative_count

def evaluate_board_metric_criteria_for_feature(i, board_meets_criteria, use_train_data):
    high_activation_indices_for_feature, low_activation_indices_for_feature = get_activation_data_for_feature(i, low_to_high_data_points, use_train_data)

    true_positive_count, false_negative_count = get_counts(high_activation_indices_for_feature, board_meets_criteria, use_train_data)
    false_positive_count, true_negative_count = get_counts(low_activation_indices_for_feature, board_meets_criteria, use_train_data)
    return true_positive_count, false_negative_count, false_positive_count, true_negative_count
    
def get_summary_stats(true_positive_count, false_negative_count, false_positive_count, true_negative_count):
    accuracy = (true_positive_count + true_negative_count) / (true_positive_count + true_negative_count + false_positive_count + false_negative_count)
    precision = true_positive_count / (true_positive_count + false_positive_count)
    recall = true_positive_count / (true_positive_count + false_negative_count)
    f1_score = 2 * (precision * recall)/(precision + recall)
    return accuracy, precision, recall, f1_score

true_positive_count, false_negative_count, false_positive_count, true_negative_count = evaluate_board_metric_criteria_for_feature(i, is_check, use_train_data=True)
accuracy, precision, recall, f1_score = get_summary_stats(true_positive_count, false_negative_count, false_positive_count, true_negative_count)
print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1_score}")


Accuracy: 0.9747899159663865
Precision: 0.4375
Recall: 1.0
F1 Score: 0.6086956521739131


Let's try to classify the feature using some simple hardcoded criteria

In [135]:
def split_into_three_categories(a, b):
    if a < b:
        return -1
    else:
        return int(a > b)

def get_attacking_pieces_list(board):
    return [board.piece_at(x).symbol() for x in board.attackers(not board.turn, board.king(board.turn))]

def count_attacking_pieces(board):
    return len(get_attacking_pieces_list(board))

def piece_attacking_king(board):
    output = get_attacking_pieces_list(board)
    if len(output) > 1:
        print(f"{len(output)} attackers")
    if len(output) == 0:
        print("No attackers")
        return None
    return output[0]

def location_of_piece_attacking_king(board):
    output = [chess.square_name(x) for x in board.attackers(not board.turn, board.king(board.turn))]

    if len(output) == 0:
        print("No attackers")
        return None
    return output[0]

def relative_location_of_piece_attacking_king(board):
    king_location = board.king(board.turn)
    attacker_location = location_of_piece_attacking_king(board)
    if attacker_location is None:
        return None
    king_file, king_rank = chess.square_file(king_location), chess.square_rank(king_location)
    attacker_file, attacker_rank = chess.square_file(chess.parse_square(attacker_location)), chess.square_rank(chess.parse_square(attacker_location))
    return attacker_file - king_file, attacker_rank - king_rank

def direction_of_attack(board):
    king_location = board.king(board.turn)
    attacker = piece_attacking_king(board)
    king_square = chess.square_name(king_location)
    attacker_square = location_of_piece_attacking_king(board)
    if attacker_square is None:
        return None
    index_1 = split_into_three_categories(king_square[0], attacker_square[0])
    index_2 = split_into_three_categories(king_square[1], attacker_square[1])
    if index_1 == index_2 == 0:
        raise Exception("Index 1 and 2 are 0")
    return index_1, index_2

def is_check(board):
    return board.is_check()


criteria_functions = [is_check, piece_attacking_king, location_of_piece_attacking_king, direction_of_attack, count_attacking_pieces, relative_location_of_piece_attacking_king]

def collect_information_about_criteria(high_activation_indices, criteria_functions):
    output_criteria_dict = {x.__name__: [] for x in criteria_functions}
    for game_index, position_index, _ in high_activation_indices:
        if game_index >= TRAIN_TEST_GAME_SPLIT:
            continue
        game_string = dataset['decoded_inputs'][game_index][:position_index]
        last_space_index = get_last_space_index(game_string)
        game_string = dataset['decoded_inputs'][game_index][:last_space_index]
        board = chess_utils.pgn_string_to_board(game_string)
        for criteria_function in criteria_functions:
            output_criteria_dict[criteria_function.__name__].append(criteria_function(board))
    for criterion in criteria_functions:
        output_criteria_dict[criterion.__name__] = Counter(output_criteria_dict[criterion.__name__])
    return output_criteria_dict

collect_information_about_criteria(high_activation_indices_for_feature, criteria_functions)

        


2 attackers
2 attackers


{'is_check': Counter({True: 210}),
 'piece_attacking_king': Counter({'n': 201, 'p': 6, 'q': 1, 'r': 1, 'b': 1}),
 'location_of_piece_attacking_king': Counter({'g4': 23,
          'd3': 20,
          'c2': 18,
          'e2': 18,
          'f3': 17,
          'f4': 13,
          'e4': 13,
          'f2': 13,
          'e3': 10,
          'g3': 7,
          'h3': 7,
          'd4': 5,
          'c3': 5,
          'c4': 5,
          'f5': 4,
          'e5': 3,
          'h4': 3,
          'd2': 3,
          'g2': 3,
          'a4': 2,
          'a2': 2,
          'b4': 2,
          'a3': 2,
          'b3': 2,
          'a5': 2,
          'f6': 1,
          'e1': 1,
          'd5': 1,
          'b5': 1,
          'f8': 1,
          'h5': 1,
          'c1': 1,
          'g5': 1}),
 'direction_of_attack': Counter({(1, -1): 132,
          (-1, -1): 67,
          (1, 1): 7,
          (-1, 1): 2,
          (-1, 0): 1,
          (0, -1): 1}),
 'count_attacking_pieces': Counter({1: 209, 2: 1}),
 

In [136]:
# game_string
# board.piece_at(board.peek().to_square)

In [137]:
all_criteria_dicts = {}
for i in range(len(features_for_check_state)):
    print(i)
    high_activation_indices_for_feature, low_activation_indices_for_feature = get_activation_data_for_feature(i, low_to_high_data_points)
    all_criteria_dicts[i] = collect_information_about_criteria(high_activation_indices_for_feature, criteria_functions)
    # true_positive_count, false_negative_count = get_counts(high_activation_indices_for_feature, board_meets_criteria)
    # false_positive_count, true_negative_count = get_counts(low_activation_indices_for_feature, board_meets_criteria)

0
1
2
3
4
5
6
7
8
9
2 attackers
2 attackers
10
No attackers
No attackers
No attackers
No attackers
No attackers
11
12
No attackers
No attackers
No attackers
No attackers
No attackers
No attackers
No attackers
No attackers
No attackers
No attackers
13
14
15
16
2 attackers
2 attackers
2 attackers
2 attackers
2 attackers
2 attackers
2 attackers
2 attackers
2 attackers
2 attackers
17
18
19
No attackers
No attackers
No attackers
No attackers
No attackers
20
21
22
23
24
25
26
27
28
29
2 attackers
2 attackers


In [120]:
all_criteria_acceptable_keys_dict = {}
threshold_for_including_key = 0.05
threshold_for_including_feature = 0.5

for i, criteria_dict in all_criteria_dicts.items():
    criterion_acceptable_keys_dict = {}
    for criterion, criterion_results_counter in criteria_dict.items():
        counter_values = list(criterion_results_counter.values())
        data_point_count = sum(counter_values)
        if len(counter_values) == 0 or max(counter_values) < data_point_count * threshold_for_including_feature:
            continue
        acceptable_keys = [x for x in criterion_results_counter.keys() if criterion_results_counter[x]/data_point_count > threshold_for_including_key]
        criterion_acceptable_keys_dict[criterion] = acceptable_keys
    if len(criterion_acceptable_keys_dict) == 0:
        continue
    all_criteria_acceptable_keys_dict[i] = criterion_acceptable_keys_dict
print(all_criteria_acceptable_keys_dict.keys())
    

dict_keys([0, 1, 3, 5, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 21, 22, 25, 26, 28, 29])


In [156]:
for i in all_criteria_acceptable_keys_dict.keys():
    print(i)
    print("Just using is check")
    true_positive_count, false_negative_count, false_positive_count, true_negative_count = evaluate_board_metric_criteria_for_feature(i, is_check, use_train_data=False)
    print(true_positive_count, false_negative_count, false_positive_count, true_negative_count)
    accuracy, precision, recall, f1_score = get_summary_stats(true_positive_count, false_negative_count, false_positive_count, true_negative_count)
    # print(f"Accuracy: {round(accuracy, 3)}")
    print(f"Precision: {round(precision, 3)}")
    # print(f"Recall: {round(recall, 3)}")
    # print(f"F1 Score: {round(f1_score, 3)}")

    custom_criteria_dict = all_criteria_acceptable_keys_dict[i]
    def custom_criteria(board):
        for criterion_function in criteria_functions:
            if criterion_function.__name__ not in custom_criteria_dict:
                continue
            criterion_value = criterion_function(board)
            if criterion_value not in custom_criteria_dict[criterion_function.__name__]:
                return False
        return True
    print("Using custom criteria")
    true_positive_count, false_negative_count, false_positive_count, true_negative_count = evaluate_board_metric_criteria_for_feature(i, custom_criteria, use_train_data=False)
    print(true_positive_count, false_negative_count, false_positive_count, true_negative_count)
    accuracy, precision, recall, f1_score = get_summary_stats(true_positive_count, false_negative_count, false_positive_count, true_negative_count)
    # print(f"Accuracy: {round(accuracy, 3)   }")
    print(f"Precision: {round(precision, 3)}")
    # print(f"Recall: {round(recall, 3)}")
    # print(f"F1 Score: {round(f1_score, 3)}")
    

0
Just using is check
24 0 40 1160
Precision: 0.375
Using custom criteria
24 0 22 1178
Precision: 0.522
1
Just using is check
128 0 168 6232
Precision: 0.432
Using custom criteria
2 attackers
126 2 63 6337
Precision: 0.667
3
Just using is check
31 0 43 1507
Precision: 0.419
Using custom criteria
31 0 2 1548
Precision: 0.939
5
Just using is check
23 0 38 1112
Precision: 0.377
Using custom criteria
19 4 1 1149
Precision: 0.95
8
Just using is check


ValueError: Invalid move: Bb4

In [128]:
criterion

'relative_location_of_piece_attacking_king'