In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import numpy as np
import joblib
import sys
import imodels
import imodelsx.process_results
from collections import defaultdict
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_breast_cancer
from tabpfn import TabPFNClassifier, TabPFNRegressor
from sklearn.datasets import make_classification, make_regression
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, r2_score

# --- data ---
n, d = 2000, 20  # number of samples and number of features passed to the generators
test_frac = 0.2  # fraction of samples held out as the test split

# --- data-generating process (DGP) ---
n_classes, n_informative = 2, 10

In [None]:
# Increasing tree depth experiment:
# relabel one fixed synthetic dataset with decision trees of increasing depth,
# fit TabPFN on the tree's training labels, and measure how often TabPFN's
# test predictions agree with the tree's test predictions.

# The base features do not depend on depth, so generate and split them once.
X, y = make_classification(
    n_samples=n, n_features=d, n_informative=n_informative, n_classes=n_classes, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=test_frac, random_state=42)

depths = [1, 2, 3, 4, 5]  # , 6, 7, 8, 9, 10]
records = []
for depth in tqdm(depths):
    # Seed the tree: sklearn permutes features at every split, so without a
    # fixed random_state the relabelled targets (and the accuracies) vary run to run.
    tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
    tree.fit(X_train, y_train)
    y_train_tree = tree.predict(X_train)
    y_test_tree = tree.predict(X_test)

    classifier = TabPFNClassifier(device='cuda')
    classifier.fit(X_train, y_train_tree)
    y_test_tabpfn = classifier.predict(X_test)
    acc = accuracy_score(y_test_tree, y_test_tabpfn)
    # Store depth and accuracy together so rows stay aligned even if the loop
    # is interrupted part-way (pre-filling the depth column broke DataFrame
    # construction in that case).
    records.append({'depth': depth, 'acc': acc})
    print(acc)

r = pd.DataFrame(records)
fig, ax = plt.subplots()
sns.lineplot(data=r, x='depth', y='acc', ax=ax)
ax.set(title='TabPFN agreement with decision-tree labels',
       xlabel='tree max_depth', ylabel='test accuracy vs. tree predictions')
plt.show()

# Activations analysis

In [None]:
# regression
X, y = make_regression(
    n_samples=n, n_features=d, n_informative=n_informative, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=test_frac, random_state=42)
regressor = TabPFNRegressor(device='cuda', n_estimators=1)
regressor.fit(X_train, y_train)
y_test_tabpfn = regressor.predict(X_test)
print('r2 test', r2_score(y_test, y_test_tabpfn))

# classification
X, y = make_classification(
    n_samples=n, n_features=d, n_informative=n_informative, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=test_frac, random_state=42)
classifier = TabPFNClassifier(device='cuda', n_estimators=1)
classifier.fit(X_train, y_train)
y_test_tabpfn = classifier.predict(X_test)
print('acc test', accuracy_score(y_test, y_test_tabpfn))

In [100]:
# Choose which fitted TabPFN estimator to inspect (swap in `regressor` to look
# at the regression model instead).
model = classifier
m = getattr(model, 'model_')  # underlying PyTorch module (a PerFeatureTransformer in this run)
activations = dict()  # filled by the forward hooks: hook name -> captured module output


def get_activation(name, store=None):
    """Build a forward hook that records a module's output under ``name``.

    Parameters
    ----------
    name : str
        Key under which the captured output is stored.
    store : dict, optional
        Mapping to write into. When omitted, the notebook-level
        ``activations`` dict is used. It is looked up each time the hook
        fires, so reassigning ``activations`` later still redirects existing
        hooks (same as the original behaviour).

    Returns
    -------
    callable
        ``hook(model, input, output)`` for
        ``torch.nn.Module.register_forward_hook``. It stores ``output``
        detached from the autograd graph (``.detach()`` shares storage, it
        does not copy). Tuple/list outputs, as many attention modules return,
        are detached element-wise, and non-tensor entries such as ``None``
        are kept as-is. The hook returns ``None``, so the module's output is
        left unchanged.
    """
    def _detach(value):
        # Tensors (anything exposing .detach) are detached directly.
        if hasattr(value, 'detach'):
            return value.detach()
        # Containers are detached element-wise; the container kind is kept.
        if isinstance(value, tuple):
            return tuple(_detach(item) for item in value)
        if isinstance(value, list):
            return [_detach(item) for item in value]
        return value

    def hook(model, input, output):
        target = activations if store is None else store
        target[name] = _detach(output)
    return hook


# Register forward hooks on every encoder layer (its feature-attention and
# item-attention sub-modules, plus the layer itself) and on each module of the
# standard decoder. The handles are kept so the hooks can be removed after the
# forward pass. Otherwise every re-run of this cell stacks another copy of each
# hook on the same modules, and hooks registered under old key names keep
# writing stale entries into `activations`. Hooks left behind by earlier runs
# without handles can only be cleared by restarting the kernel.
hook_handles = []
for idx, block in enumerate(m.transformer_encoder.layers):
    hook_handles.append(block.self_attn_between_features.register_forward_hook(
        get_activation(f'enc_attn_f_{idx}')))
    hook_handles.append(block.self_attn_between_items.register_forward_hook(
        get_activation(f'enc_attn_i_{idx}')))
    hook_handles.append(block.register_forward_hook(
        get_activation(f'enc_{idx}')))
for idx, block in enumerate(m.decoder_dict.standard):
    hook_handles.append(block.register_forward_hook(
        get_activation(f'dec_{idx}')))

# Run inference; the hooks fill `activations` as a side effect.
try:
    preds = model.predict_proba(X_test)
    # preds = model.predict(X_test, output_type='mode')
finally:
    for handle in hook_handles:
        handle.remove()

# Display extracted activations

print('enc sizes are batch_size, num_samples (train + test), n_heads, d_output?')
# NOTE(review): the feature-attention module reports _nhead 6 in its attribute
# dump, yet the third dim of the enc_* outputs is 12, so that dim is probably
# not n_heads -- confirm what it indexes. The enc_attn_i_* outputs carry 1600
# rows (n * (1 - test_frac), i.e. the training-set size) in a permuted layout.
# note: half the heads are training heads, whereas the other half are inference heads which do not attend to each other
print(
    'dec sizes are (num_samples [test], 1, 10 classes [actual values are the first n_classes, rest are dropped])')
# note: regression basically does 5000-class classification over quantiles then averages them
# (classification is easier to study)
for layer, activation in activations.items():
    print(f"{layer}: {activation.shape}")

init shape (400, 20)
input preprocessed shape (400, 20)
init output torch.Size([400, 10])
filtered outputs, one for each classifier [torch.Size([400, 2])]
enc sizes are batch_size, num_samples (train + test), n_heads, d_output?
dec sizes are (num_samples [test], 1, 10 classes [actual values are the first n_classes, rest are dropped])
enc_attn_0: torch.Size([1, 12, 1600, 192])
enc_attn_f_0: torch.Size([1, 2000, 12, 192])
enc_attn_i_0: torch.Size([1, 12, 1600, 192])
enc_0: torch.Size([1, 2000, 12, 192])
enc_attn_1: torch.Size([1, 12, 1600, 192])
enc_attn_f_1: torch.Size([1, 2000, 12, 192])
enc_attn_i_1: torch.Size([1, 12, 1600, 192])
enc_1: torch.Size([1, 2000, 12, 192])
enc_attn_2: torch.Size([1, 12, 1600, 192])
enc_attn_f_2: torch.Size([1, 2000, 12, 192])
enc_attn_i_2: torch.Size([1, 12, 1600, 192])
enc_2: torch.Size([1, 2000, 12, 192])
enc_attn_3: torch.Size([1, 12, 1600, 192])
enc_attn_f_3: torch.Size([1, 2000, 12, 192])
enc_attn_i_3: torch.Size([1, 12, 1600, 192])
enc_3: torch.Size(

In [96]:
# Dump the scalar-valued attributes (head count, head dims, flags, ...) of the
# first encoder layer's feature-attention module.
# Named `attn_attrs` rather than `d`: reusing `d` clobbered the feature-count
# constant from the config cell, which broke make_classification/make_regression
# (n_features=d) on any re-run of the data cells.
attn_attrs = vars(m.transformer_encoder.layers[0].self_attn_between_features)
# only show elements with scalar values
for k, v in attn_attrs.items():
    if np.isscalar(v):
        print(k, v)

training False
_input_size 192
_output_size 192
_d_k 32
_d_v 32
_nhead 6
_nhead_kv 6
recompute False
init_gain 1.0
two_sets_of_queries False


In [97]:
# Display the module tree of the underlying PyTorch model (the attribute names
# the hook-registration cell relies on: transformer_encoder.layers, decoder_dict, ...).
m

PerFeatureTransformer(
  (encoder): SequentialEncoder(
    (0): RemoveEmptyFeaturesEncoderStep()
    (1): NanHandlingEncoderStep()
    (2): VariableNumFeaturesEncoderStep()
    (3): InputNormalizationEncoderStep()
    (4): VariableNumFeaturesEncoderStep()
    (5): LinearInputEncoderStep(
      (layer): Linear(in_features=4, out_features=192, bias=False)
    )
  )
  (y_encoder): SequentialEncoder(
    (0): NanHandlingEncoderStep()
    (1): MulticlassClassificationTargetEncoder()
    (2): LinearInputEncoderStep(
      (layer): Linear(in_features=2, out_features=192, bias=True)
    )
  )
  (transformer_encoder): LayerStack(
    (layers): ModuleList(
      (0-11): 12 x PerFeatureEncoderLayer(
        (self_attn_between_features): MultiHeadAttention()
        (self_attn_between_items): MultiHeadAttention()
        (mlp): MLP(
          (linear1): Linear(in_features=192, out_features=768, bias=False)
          (linear2): Linear(in_features=768, out_features=192, bias=False)
        )
       