In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import numpy as np
import joblib
import sys
import imodels
import imodelsx.process_results
from collections import defaultdict
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_breast_cancer
from tabpfn import TabPFNClassifier, TabPFNRegressor
from sklearn.datasets import make_classification, make_regression
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, r2_score

# data
n = 2000
d = 20
test_frac = 0.2

# DGP
n_classes = 2
n_informative = 10

In [None]:
# increasing tree depth experiment

# generate synthetic data from decision trees of increasing depths and test tabpfn acc
r = defaultdict(list)
depths = [1, 2, 3, 4, 5]  # , 6, 7, 8, 9, 10]
r['depth'] = depths
for depth in tqdm(depths):
    X, y = make_classification(
        n_samples=n, n_features=d, n_informative=n_informative, n_classes=n_classes, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_frac, random_state=42)
    tree = DecisionTreeClassifier(max_depth=depth)
    tree.fit(X_train, y_train)
    y_train_tree = tree.predict(X_train)
    y_test_tree = tree.predict(X_test)

    classifier = TabPFNClassifier(device='cuda')
    classifier.fit(X_train, y_train_tree)
    y_test_tabpfn = classifier.predict(X_test)
    acc = accuracy_score(y_test_tree, y_test_tabpfn)
    r['acc'].append(acc)
    print(acc)

r = pd.DataFrame(r)
sns.lineplot(data=r, x='depth', y='acc')
plt.show()

# Activations analysis

In [84]:
# regression
X, y = make_regression(
    n_samples=n, n_features=d, n_informative=n_informative, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=test_frac, random_state=42)
regressor = TabPFNRegressor(device='cuda', n_estimators=1)
regressor.fit(X_train, y_train)
y_test_tabpfn = regressor.predict(X_test)
print('r2 test', r2_score(y_test, y_test_tabpfn))

# classification
X, y = make_classification(
    n_samples=n, n_features=d, n_informative=n_informative, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=test_frac, random_state=42)
classifier = TabPFNClassifier(device='cuda', n_estimators=1)
classifier.fit(X_train, y_train)
y_test_tabpfn = classifier.predict(X_test)
print('acc test', accuracy_score(y_test, y_test_tabpfn))

r2 test 0.9993153974391654
init shape (400, 20)
input preprocessed shape (400, 20)
init output torch.Size([400, 10])
here! [torch.Size([400, 2])]
acc test 0.9475


In [None]:
# model = regressor
model = classifier
m = model.model_
activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook


# Register hooks on specific layers (example: Transformer blocks)
for idx, block in enumerate(m.transformer_encoder.layers):
    block.register_forward_hook(get_activation(f'enc_{idx}'))
for idx, block in enumerate(m.decoder_dict.standard):
    block.register_forward_hook(get_activation(f'dec_{idx}'))

# Fit model (for inference)
preds = model.predict_proba(X_test)
# preds = model.predict(X_test, output_type='mode')

# Display extracted activations

print('enc sizes are batch_size, num_samples (train + test), d_model?, d_output?')
print(
    'dec sizes are (num_samples [test], 1, 10 classes [actual values are the first n_classes, rest are dropped])')
# note: regression basically does 5000-class classification over quantiles then averages them
# (classification is easier to study)
for layer, activation in activations.items():
    print(f"{layer}: {activation.shape}")

init shape (400, 20)
input preprocessed shape (400, 20)
init output torch.Size([400, 10])
filtered outputs, one for each classifier [torch.Size([400, 2])]
enc sizes are batch_size, num_samples (train + test), d_model?, d_output?
dec sizes are (num_samples [test], 1, 10 classes [actual values are the first n_classes, rest are dropped])
enc_0: torch.Size([1, 2000, 12, 192])
enc_1: torch.Size([1, 2000, 12, 192])
enc_2: torch.Size([1, 2000, 12, 192])
enc_3: torch.Size([1, 2000, 12, 192])
enc_4: torch.Size([1, 2000, 12, 192])
enc_5: torch.Size([1, 2000, 12, 192])
enc_6: torch.Size([1, 2000, 12, 192])
enc_7: torch.Size([1, 2000, 12, 192])
enc_8: torch.Size([1, 2000, 12, 192])
enc_9: torch.Size([1, 2000, 12, 192])
enc_10: torch.Size([1, 2000, 12, 192])
enc_11: torch.Size([1, 2000, 12, 192])
dec_0: torch.Size([400, 1, 768])
dec_1: torch.Size([400, 1, 768])
dec_2: torch.Size([400, 1, 10])


In [72]:
preds

array([[2.41386704e-04, 9.99758601e-01],
       [9.99445796e-01, 5.54217782e-04],
       [6.71208109e-05, 9.99932885e-01],
       [1.00600071e-01, 8.99399936e-01],
       [4.15149890e-03, 9.95848477e-01],
       [4.18101612e-04, 9.99581873e-01],
       [2.40788958e-03, 9.97592151e-01],
       [1.74628139e-01, 8.25371861e-01],
       [1.83178810e-03, 9.98168230e-01],
       [9.96417940e-01, 3.58208641e-03],
       [5.08290194e-02, 9.49171007e-01],
       [9.76236403e-01, 2.37635709e-02],
       [9.78870153e-01, 2.11298224e-02],
       [2.48119123e-02, 9.75188076e-01],
       [8.29987407e-01, 1.70012653e-01],
       [3.30220762e-04, 9.99669790e-01],
       [9.99418855e-01, 5.81174158e-04],
       [3.72496177e-03, 9.96275067e-01],
       [9.98074770e-01, 1.92521571e-03],
       [1.01816637e-04, 9.99898195e-01],
       [6.51944079e-04, 9.99348104e-01],
       [9.76233065e-01, 2.37669107e-02],
       [9.98103023e-01, 1.89696392e-03],
       [9.96846557e-01, 3.15343705e-03],
       [2.470357