In [1]:
from tabpfn.transformer_make_model import TabPFNMaker, load_model_maker, extract_linear_model, predict_with_linear_model, ForwardLinearModel, MotherNet
from tabpfn import encoders
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
path = "models_diff/prior_diff_real_checkpoint_output_attention_nlayer6_mlp_lr0001_multiclass_04_14_2023_20_16_03_n_0_epoch_on_exit.cpkt"
# The checkpoint unpacks as (state_dict, <unused>, config).
# map_location="cpu" lets a checkpoint written from a GPU run load on a CPU-only host;
# load_state_dict copies the tensors into the model's own parameters anyway.
model_state, _, config = torch.load(path, map_location="cpu")
encoder = encoders.Linear(config['num_features'], config['emsize'], replace_nan_by_zero=True)
y_encoder = encoders.OneHotAndLinear(config['max_num_classes'], emsize=config['emsize'])
loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.ones(int(config['max_num_classes'])))

# Both model classes take the same constructor arguments; only the class differs.
model_kwargs = dict(
    ninp=config['emsize'],
    nlayers=config['nlayers'],
    n_out=config['max_num_classes'],
    nhead=config['nhead'],
    nhid=config['emsize'] * config['nhid_factor'],
    encoder=encoder,
    y_encoder=y_encoder,
)
model_maker = config.get('model_maker', "")
if model_maker == "mlp":
    model = MotherNet(**model_kwargs)
elif model_maker:
    model = TabPFNMaker(**model_kwargs)
else:
    # Without this branch `model` would be undefined and the lines below would raise NameError.
    raise ValueError(f"checkpoint {path!r} has no 'model_maker' in its config; cannot choose a model class")

model.criterion = loss
# Keys presumably carry a "module." prefix from (Distributed)DataParallel wrapping during
# training; strip it only at the start of each key (str.replace would also hit it mid-key).
module_prefix = 'module.'
model_state = {
    (k[len(module_prefix):] if k.startswith(module_prefix) else k): v
    for k, v in model_state.items()
}

# NOTE(review): for this checkpoint load_state_dict fails. The state_dict has
# decoder.query / decoder.output_layer.* keys that MotherNet does not define, and
# decoder.mlp.0.weight has a different shape. The checkpoint was presumably trained with a
# different decoder configuration — TODO confirm which config key selects it.
model.load_state_dict(model_state)

RuntimeError: Error(s) in loading state_dict for MotherNet:
	Unexpected key(s) in state_dict: "decoder.query", "decoder.output_layer.q_proj_weight", "decoder.output_layer.k_proj_weight", "decoder.output_layer.v_proj_weight", "decoder.output_layer.in_proj_bias", "decoder.output_layer.out_proj.weight", "decoder.output_layer.out_proj.bias". 
	size mismatch for decoder.mlp.0.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([256, 128]).

In [3]:
import numpy as np
from tabpfn.utils import normalize_data, normalize_by_used_features_f

def extract_mlp_model(model, X_train, y_train, device="cpu", max_features=100):
    """Run a MotherNet-style model on a training set and return the generated MLP weights.

    The transformer consumes the whole (normalized, zero-padded) training set, and its
    decoder emits the parameters of a one-hidden-layer network. The model's input encoder
    weights/bias are folded into the first layer, so the returned first-layer weights act
    on the normalized features directly.

    Parameters
    ----------
    model : module exposing ``encoder``, ``y_encoder``, ``transformer_encoder`` and
        ``decoder`` attributes; ``decoder(...)`` must return ``(b1, w1, b2, w2)``.
    X_train : array-like, shape (n_samples, n_features)
    y_train : array-like, shape (n_samples,)
        The number of classes is taken as the number of distinct labels, and the first
        ``n_classes`` outputs are kept. Labels are therefore presumably expected to be
        0..n_classes-1 (e.g. from LabelEncoder) — TODO confirm.
    device : str or torch.device
        Device for the forward pass.
    max_features : int, default 100
        Width the features are zero-padded to. Must match the model's input encoder.

    Returns
    -------
    tuple of numpy arrays ``(b1, w1, b2, w2)``
        Hidden bias, first-layer weights (one row per input feature), output bias
        truncated to ``n_classes``, and output weights truncated to ``n_classes`` columns.

    Raises
    ------
    ValueError
        If ``X_train`` has more than ``max_features`` columns.
    """
    X_train = np.asarray(X_train)
    n_samples, n_features = X_train.shape
    if n_features > max_features:
        raise ValueError(
            f"X_train has {n_features} features but the model accepts at most {max_features}"
        )
    n_classes = len(np.unique(y_train))

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    ys = torch.Tensor(y_train).to(device)
    xs = torch.Tensor(X_train).to(device)

    # Pure inference: no autograd graph is needed for any of the tensors below.
    with torch.no_grad():
        # Normalize using all training rows (the "eval position" is the full set).
        xs_norm = normalize_data(xs, n_samples)
        xs_norm = normalize_by_used_features_f(xs_norm, n_features, max_features,
                                               normalize_with_sqrt=False)
        # Zero-pad the feature axis up to the fixed input width of the encoder.
        padding = torch.zeros((n_samples, max_features - n_features), device=device)
        x_padded = torch.cat([xs_norm, padding], dim=1)

        # unsqueeze(1) inserts a size-1 axis: (n_samples, 1, max_features).
        x_src = model.encoder(x_padded.unsqueeze(1))
        y_src = model.y_encoder(ys.unsqueeze(1).unsqueeze(-1))
        output = model.transformer_encoder(x_src + y_src)
        b1, w1, b2, w2 = model.decoder(output)

        # Fold the (assumed nn.Linear-style) input encoder into the first MLP layer:
        # (x W_e^T + b_e) w1 + b1 == x (W_e^T w1) + (b_e w1 + b1).
        # Only the encoder columns for real (non-padded) features are kept.
        encoder_weight = model.encoder.get_parameter("weight")
        encoder_bias = model.encoder.get_parameter("bias")
        total_weights = torch.matmul(encoder_weight[:, :n_features].T, w1)
        total_biases = torch.matmul(encoder_bias, w1) + b1

    # Dividing by n_features / max_features presumably undoes the rescaling applied by
    # normalize_by_used_features_f — TODO confirm against tabpfn.utils.
    used_fraction = n_features / max_features
    return (
        to_numpy(total_biases.squeeze()),
        to_numpy(total_weights.squeeze()) / used_fraction,
        to_numpy(b2.squeeze()[:n_classes]),
        to_numpy(w2.squeeze()[:, :n_classes]),
    )

In [4]:
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
iris = load_iris()
features, labels = iris.data, iris.target
# Fixed random_state so the train/test split (and every score below) is reproducible.
X_train, X_test, y_train, y_test = train_test_split(features, labels, random_state=3)


In [5]:
b1, w1, b2, w2 = extract_mlp_model(model, X_train, y_train, device="cpu")  # MLP generated from the training set

In [6]:
np.shape(b2)  # one output bias per class

(3,)

In [7]:
def predict_with_mlp_model(X_train, X_test, b1, w1, b2, w2, temperature=0.8):
    """Evaluate an MLP produced by ``extract_mlp_model`` and return class probabilities.

    Steps:

    1. Standardize ``X_test`` with the mean and ``ddof=1`` standard deviation of
       ``X_train``.
    2. Clip the result to [-100, 100].
    3. Apply ``relu(x @ w1 + b1) @ w2 + b2``.
    4. Convert to probabilities with a temperature-scaled softmax.

    Parameters
    ----------
    X_train : array-like, shape (n_train, n_features)
        Supplies the standardization statistics only.
    X_test : array-like, shape (n_test, n_features)
        Rows to score.
    b1, w1, b2, w2 : arrays
        MLP parameters as returned by ``extract_mlp_model``.
    temperature : float, default 0.8
        Logits are divided by this before the softmax. Values below 1 sharpen the output.

    Returns
    -------
    ndarray, shape (n_test, n_classes)
        Each row sums to 1.
    """
    from scipy.special import softmax

    X_train = np.asarray(X_train)
    X_test = np.asarray(X_test)
    # Intended to match the normalization applied during extraction
    # (TODO confirm against tabpfn.utils.normalize_data).
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0, ddof=1) + 1e-6  # eps avoids division by zero on constant columns
    X_test_scaled = np.clip((X_test - mean) / std, -100, 100)
    hidden = np.maximum(X_test_scaled @ w1 + b1, 0)  # ReLU hidden layer
    logits = hidden @ w2 + b2
    return softmax(logits / temperature, axis=1)

In [8]:
res = predict_with_mlp_model(X_train, X_train, b1=b1, w1=w1, b2=b2, w2=w2)  # scores the same rows the MLP was generated from

In [9]:
np.shape(res)  # (n_train_rows, n_classes)

(112, 3)

In [10]:
np.mean(res.argmax(axis=1) == y_train)  # training-set accuracy

0.9642857142857143

Bad pipe message: %s [b'\xb4\xd0\xebs\xbe\n{\xad\xed\tj\xce\x80.m\xe5\xf8\xd3 \x078 >\x9c)\xed\x97\xb1\x82\x99\x85>5\xdb \xe3\xb2@\x89\xc5\xbd\x9f\xa2\xfd\xcd&\x17T\x87\xf0\x9a\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08', b'\x06\x04\x01\x05']
Bad pipe message: %s [b'']
Bad pipe message: %s [b"wq\xbd\x1f\x14\xbe(\x11\x04\xcd\xceM\t`\xb9\x05\xc1\x96 \xff\xf4\xd1\xd1L{\x07\x1d\xc3\xea\xdbR\x19S\xa5m\x8b\xe8\xe0'|\x01\xff\xeb\xad\x85\xcd\xda9K\x06\xe6\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00"]
Bad pipe message: %s [b'']
Bad pipe message: %s [b'\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00

In [17]:
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.preprocessing import LabelEncoder

class MotherNetClassifier(ClassifierMixin, BaseEstimator):
    """scikit-learn classifier backed by an MLP generated from a MotherNet checkpoint.

    ``fit`` loads the checkpoint and runs it once over the training set via
    ``extract_mlp_model``, keeping only the resulting MLP weights. ``predict_proba``
    then evaluates that MLP with NumPy via ``predict_with_mlp_model``.

    Parameters
    ----------
    path : str or None, default None
        Checkpoint file to load. ``None`` (or any falsy value) selects ``DEFAULT_PATH``.
    device : str, default "cpu"
        Torch device used for the forward pass in ``fit``.

    Attributes
    ----------
    X_train_ : ndarray
        Training features. Kept because prediction standardizes new rows with their
        mean/std.
    parameters_ : tuple
        ``(b1, w1, b2, w2)``, the extracted MLP weights.
    classes_ : ndarray
        Original class labels, indexed by encoded label.
    """

    DEFAULT_PATH = "models_diff/prior_diff_real_checkpoint_predict_mlp_nlayer12_multiclass_04_13_2023_16_41_16_n_0_epoch_37.cpkt"

    def __init__(self, path=None, device="cpu"):
        self.path = path or self.DEFAULT_PATH
        self.device = device

    @staticmethod
    def _load_model(path, device):
        """Build the model class named in the checkpoint config, load its weights, move it to ``device``."""
        # The checkpoint unpacks as (state_dict, <unused>, config).
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
        model_state, _, config = torch.load(path, map_location="cpu")
        n_out = config['max_num_classes']
        encoder = encoders.Linear(config['num_features'], config['emsize'], replace_nan_by_zero=True)
        y_encoder = encoders.OneHotAndLinear(n_out, emsize=config['emsize'])
        model_kwargs = dict(
            ninp=config['emsize'],
            nlayers=config['nlayers'],
            n_out=n_out,
            nhead=config['nhead'],
            nhid=config['emsize'] * config['nhid_factor'],
            encoder=encoder,
            y_encoder=y_encoder,
        )
        model_maker = config.get('model_maker', "")
        if model_maker == "mlp":
            model = MotherNet(**model_kwargs)
        elif model_maker:
            model = TabPFNMaker(**model_kwargs)
        else:
            raise ValueError(f"checkpoint {path!r} has no 'model_maker' in its config; cannot choose a model class")
        model.criterion = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.ones(int(n_out)))

        # Keys presumably carry a "module." prefix from (Distributed)DataParallel wrapping;
        # strip it only at the start of each key.
        prefix = 'module.'
        model_state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in model_state.items()
        }
        model.load_state_dict(model_state)
        model.to(device)
        # Inference mode, so any dropout/normalization layers the model may have behave
        # deterministically during extraction.
        model.eval()
        return model

    def fit(self, X, y):
        """Generate the MLP for the training data ``(X, y)`` and return ``self``."""
        X = np.asarray(X)
        label_encoder = LabelEncoder()
        # Encoded labels are 0..n_classes-1. extract_mlp_model counts distinct labels and
        # keeps the first n_classes outputs, so the encoding must be contiguous.
        y_encoded = label_encoder.fit_transform(y)
        model = self._load_model(self.path, self.device)
        self.parameters_ = extract_mlp_model(model, X, y_encoded, device=self.device)
        self.X_train_ = X
        self.classes_ = label_encoder.classes_
        return self

    def predict_proba(self, X):
        """Class probabilities, shape (n_samples, n_classes), columns ordered as ``classes_``."""
        return predict_with_mlp_model(self.X_train_, np.asarray(X), *self.parameters_)

    def predict(self, X):
        """Most probable original class label for each row of ``X``."""
        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]


In [15]:
clf_cpu = MotherNetClassifier()
clf_cpu.fit(X_train, y_train)
clf_cpu.score(X_test, y_test)  # held-out accuracy

0.9473684210526315

In [18]:
clf_cuda = MotherNetClassifier(device="cuda")
clf_cuda.fit(X_train, y_train)
clf_cuda.score(X_test, y_test)  # held-out accuracy, MotherNet forward pass on GPU

0.9473684210526315

Bad pipe message: %s [b'\xe9&\x9f\xab~\x90i3,\xf0g5\xeaS', b'\xfb/ \x05\x03\nX\x8cAhl~\x83\x1eK\xb6\xad\x16y\xbe\xee|\xd6\xc6%\xb3o\x14\x0c}\x0c\xa0!q\r\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 P']
Bad pipe message: %s [b"\x8f\xee(\xcf\xc2\x8b]\xd8\x804\xcf\xf3S/v}\xb2\xb0\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x