In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from tabpfn.transformer_make_model import TransformerModelMaker, load_model_maker, extract_linear_model, predict_with_linear_model, ForwardLinearModel


In [3]:
from sklearn.datasets import load_iris
iris = load_iris()

In [4]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=3)

In [5]:
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class ForwardLinearModelASDF(ClassifierMixin, BaseEstimator):
    def __init__(self, path=None, n_permutations=1, device="cpu"):
        self.n_permutations = n_permutations
        self.device = device
        self.path = path or "models_diff/prior_diff_real_checkpoint_predict_linear_coefficients_nlayer_6_multiclass_04_11_2023_01_26_19_n_0_epoch_94.cpkt"
        
    def fit(self, X, y):
        self.X_train_ = X
        model = load_model_maker(self.path).to(self.device)
        n_classes = len(np.unique(y))
        weight_list = []
        bias_list = []
        n_features = X.shape[1]
        for j in range(self.n_permutations):
            if j > 0:
                perm = np.random.permutation(n_features)
                inv = np.zeros(n_features, dtype=int)
                inv[perm] = np.arange(n_features)
            else:
                # always include original order
                perm = np.arange(n_features)
                inv = perm
            for i in range(n_classes):
                if i > 0:
                    continue
                indices = np.mod(np.arange(n_classes) + i, n_classes)
                weights, biases = extract_linear_model(model, X[:, perm], np.mod(y + i, n_classes), device=self.device)
                weight_list.append(weights[:, indices][inv, :])
                bias_list.append(biases[indices])
        self.weights_ = np.mean(weight_list, axis=0)
        self.biases_ = np.mean(bias_list, axis=0)
        return self
        
    def predict_proba(self, X):
        return predict_with_linear_model(self.X_train_, X, self.weights_, self.biases_)
    
    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)

In [6]:

bla = ForwardLinearModel(path="models_diff/prior_diff_real_checkpoint_predict_linear_coefficients_nlayer_12_multiclass_04_11_2023_23_25_22_n_0_epoch_75.cpkt", device="cuda").fit(X_train, y_train)
bla.score(X_test, y_test)

0.9473684210526315

In [7]:
from sklearn.model_selection import cross_validate
import pandas as pd
pd.DataFrame(cross_validate(ForwardLinearModel(), iris.data, iris.target)).mean()

fit_time      2.732323
score_time    0.000868
test_score    0.920000
dtype: float64

In [8]:
from sklearn.model_selection import cross_validate
pd.DataFrame(cross_validate(ForwardLinearModel(path="models_diff/prior_diff_real_checkpoint_predict_linear_coefficients_nlayer_12_multiclass_04_11_2023_23_25_22_n_0_epoch_75.cpkt"), iris.data, iris.target)).mean()

fit_time      5.277533
score_time    0.000983
test_score    0.966667
dtype: float64

In [12]:
from sklearn.model_selection import cross_validate
pd.DataFrame(cross_validate(ForwardLinearModel(path="models_diff/prior_diff_real_checkpoint_predict_linear_coefficients_nlayer_12_multiclass_04_11_2023_23_25_22_n_0_epoch_80.cpkt", device="cuda"), iris.data, iris.target)).mean()

fit_time      0.915291
score_time    0.000622
test_score    0.966667
dtype: float64

In [77]:
inv = np.zeros(4, dtype=int)
inv[bla] = np.arange(4)

In [78]:
bla[inv]

array([0, 1, 2, 3])

In [10]:
from sklearn.model_selection import cross_validate
pd.DataFrame(cross_validate(ForwardLinearModel(), iris.data, iris.target)).mean()

fit_time      2.720631
score_time    0.000746
test_score    0.920000
dtype: float64

In [11]:
from sklearn.linear_model import LogisticRegression
from time import time
import pandas as pd
pd.DataFrame(cross_validate(LogisticRegression(), iris.data, iris.target)).mean()

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


fit_time      6.662136
score_time    0.000401
test_score    0.973333
dtype: float64

In [35]:
pred1 = res.numpy().argmax(axis=2).ravel()

In [105]:
from tabpfn.utils import normalize_by_used_features_f


def extract_linear_model(model, X_train, y_train):
    max_features = 100
    eval_position = X_train.shape[0]
    n_classes = len(np.unique(y_train))
    n_features = X_train.shape[1]

    ys = torch.Tensor(y_train)
    xs = torch.Tensor(X_train)

    eval_xs_ = normalize_data(xs, eval_position)

    eval_xs = normalize_by_used_features_f(eval_xs_, X_train.shape[-1], max_features,
                                                   normalize_with_sqrt=False)
    x_all_torch = torch.Tensor(np.hstack([eval_xs, np.zeros((X_train.shape[0], 100 - X_train.shape[1]))]))
    
    x_src = model.encoder(x_all_torch.unsqueeze(1)[:len(X_train)])
    y_src = model.y_encoder(ys.unsqueeze(1).unsqueeze(-1))
    train_x = x_src + y_src
    # src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
    output = model.transformer_encoder(train_x)
    linear_model_coefs = model.decoder(output)
    encoder_weight = model.encoder.get_parameter("weight")
    encoder_bias = model.encoder.get_parameter("bias")

    total_weights = torch.matmul(encoder_weight[:, :n_features].T, linear_model_coefs[0, :-1, :n_classes])
    total_biases = torch.matmul(encoder_bias, linear_model_coefs[0, :-1, :n_classes]) + linear_model_coefs[0, -1, :n_classes]
    return total_weights.detach().numpy() / (n_features / max_features), total_biases.detach().numpy()

def predict_with_linear_model(model, X_train, y_train, X_test):
    max_features = 100
    eval_position = X_train.shape[0]
    n_classes = len(np.unique(y_train))
    n_features = X_train.shape[1]

    ys = torch.Tensor(y_train)
    xs = torch.Tensor(np.vstack([X_train, X_test]))

    eval_xs_ = normalize_data(xs, eval_position)

    eval_xs = normalize_by_used_features_f(eval_xs_, X_train.shape[-1], max_features,
                                                   normalize_with_sqrt=False)
    x_all_torch = torch.Tensor(np.hstack([eval_xs, np.zeros((eval_xs.shape[0], 100 - eval_xs.shape[1]))]))
    
    x_src = model.encoder(x_all_torch.unsqueeze(1)[:len(X_train)])
    y_src = model.y_encoder(ys.unsqueeze(1).unsqueeze(-1))
    train_x = x_src + y_src
    # src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
    output = model.transformer_encoder(train_x)
    linear_model_coefs = model.decoder(output)
    encoder_weight = model.encoder.get_parameter("weight")
    encoder_bias = model.encoder.get_parameter("bias")

    total_weights = torch.matmul(encoder_weight[:, :n_features].T, linear_model_coefs[0, :-1, :n_classes])
    total_biases = torch.matmul(encoder_bias, linear_model_coefs[0, :-1, :n_classes]) + linear_model_coefs[0, -1, :n_classes]
                      
    pred_simple = torch.matmul(model.encoder(x_all_torch),  linear_model_coefs[0, :-1, :n_classes]) + linear_model_coefs[0, -1, :n_classes]
    probs =  torch.nn.functional.softmax(pred_simple/ 0.8, dim=1)
    return total_weights.detach().numpy() / (n_features / max_features), total_biases.detach().numpy(), probs[eval_position:]


In [102]:
weights, biases = extract_linear_model(model, X_train, np.mod(y_train + 2, 3) )
weights, biases, probs = predict_with_linear_model(model, X_train, np.mod(y_train + 2, 3), X_test)


In [106]:
probs

tensor([[3.8660e-02, 2.6740e-12, 9.6134e-01],
        [7.1888e-02, 1.0708e-11, 9.2811e-01],
        [3.9089e-02, 1.2088e-11, 9.6091e-01],
        [1.3291e-01, 4.6124e-11, 8.6709e-01],
        [6.4584e-04, 1.0191e-15, 9.9935e-01],
        [1.0291e-07, 1.0000e+00, 1.0170e-15],
        [8.1023e-01, 1.8967e-01, 9.9063e-05],
        [2.2517e-02, 1.5885e-12, 9.7748e-01],
        [3.1979e-03, 9.9680e-01, 1.0756e-08],
        [9.1629e-01, 8.3342e-02, 3.6608e-04],
        [9.7878e-01, 2.0469e-02, 7.4996e-04],
        [6.6996e-02, 1.0258e-11, 9.3300e-01],
        [9.5299e-01, 4.6785e-02, 2.2798e-04],
        [5.9498e-01, 4.0499e-01, 3.1640e-05],
        [1.2667e-05, 9.9999e-01, 1.8275e-11],
        [1.2139e-02, 4.3015e-13, 9.8786e-01],
        [4.8299e-02, 9.5170e-01, 3.7459e-07],
        [5.6333e-06, 9.9999e-01, 1.9947e-12],
        [1.9692e-04, 9.9980e-01, 3.1611e-10],
        [3.0287e-02, 1.5947e-12, 9.6971e-01],
        [1.2270e-02, 9.8773e-01, 1.8575e-07],
        [2.3998e-04, 9.9976e-01, 6

In [119]:
eval_xs_[len(X_train):]

tensor([[-1.6016,  0.3172, -1.3934, -1.3621],
        [-1.6016,  0.0835, -1.3359, -1.3621],
        [-0.5795,  0.7847, -1.3359, -1.0961],
        [-1.0906, -0.1503, -1.2784, -1.3621],
        [-0.1962,  3.1221, -1.3359, -1.0961],
        [ 2.3591, -1.0852,  1.7680,  1.4310],
        [-0.4517, -1.3189,  0.1011,  0.1009],
        [-1.6016,  0.7847, -1.3934, -1.2291],
        [ 0.5704, -0.3840,  1.0208,  0.7660],
        [ 1.0814,  0.0835,  0.3310,  0.2339],
        [-0.3240, -0.1503,  0.1586,  0.1009],
        [-1.8572, -0.1503, -1.4508, -1.3621],
        [-0.0684, -1.0852,  0.1011, -0.0321],
        [-0.4517, -1.7864,  0.1011,  0.1009],
        [ 1.0814,  0.0835,  1.0208,  1.5640],
        [-1.0906,  1.0184, -1.4508, -1.2291],
        [ 0.5704, -1.3189,  0.6184,  0.3670],
        [ 2.3591, -0.1503,  1.3082,  1.4310],
        [-0.1962, -1.3189,  0.6759,  1.0320],
        [-1.8572,  0.3172, -1.4508, -1.3621],
        [ 0.4426, -0.6177,  0.5609,  0.7660],
        [ 1.2092, -0.1503,  0.9633

In [128]:
X_test_scaled

array([[-1.60162237,  0.31721574, -1.39336883, -1.36212528],
       [-1.60162237,  0.08347783, -1.33588916, -1.36212528],
       [-0.57950439,  0.78469156, -1.33588916, -1.09611302],
       [-1.09056338, -0.15026009, -1.27840949, -1.36212528],
       [-0.19621015,  3.12207066, -1.33588916, -1.09611302],
       [ 2.3590848 , -1.08521173,  1.76801312,  1.43100346],
       [-0.45173964, -1.31894964,  0.10110264,  0.10094215],
       [-1.60162237,  0.78469156, -1.39336883, -1.22911915],
       [ 0.57037834, -0.383998  ,  1.02077739,  0.76597281],
       [ 1.08143732,  0.08347783,  0.33102132,  0.23394828],
       [-0.32397489, -0.15026009,  0.15858231,  0.10094215],
       [-1.85715186, -0.15026009, -1.4508485 , -1.36212528],
       [-0.0684454 , -1.08521173,  0.10110264, -0.03206398],
       [-0.45173964, -1.78642546,  0.10110264,  0.10094215],
       [ 1.08143732,  0.08347783,  1.02077739,  1.56400959],
       [-1.09056338,  1.01842947, -1.4508485 , -1.22911915],
       [ 0.57037834, -1.

In [117]:
xs = torch.Tensor(np.vstack([X_train, X_test]))

eval_xs_ = normalize_data(xs, len(X_train))

In [127]:
mean = X_train.mean(axis=0)
std = X_train.std(axis=0, ddof=1) + .000001
X_test_scaled = (X_test - mean) / std
X_test_scaled = np.clip(X_test_scaled, a_min=-100, a_max=100)
res2 = np.dot(X_test_scaled , weights) + biases
from scipy.special import softmax
probs2 = softmax(res2 / .8, axis=1)


In [124]:
probs

tensor([[3.8660e-02, 2.6740e-12, 9.6134e-01],
        [7.1888e-02, 1.0708e-11, 9.2811e-01],
        [3.9089e-02, 1.2088e-11, 9.6091e-01],
        [1.3291e-01, 4.6124e-11, 8.6709e-01],
        [6.4584e-04, 1.0191e-15, 9.9935e-01],
        [1.0291e-07, 1.0000e+00, 1.0170e-15],
        [8.1023e-01, 1.8967e-01, 9.9063e-05],
        [2.2517e-02, 1.5885e-12, 9.7748e-01],
        [3.1979e-03, 9.9680e-01, 1.0756e-08],
        [9.1629e-01, 8.3342e-02, 3.6608e-04],
        [9.7878e-01, 2.0469e-02, 7.4996e-04],
        [6.6996e-02, 1.0258e-11, 9.3300e-01],
        [9.5299e-01, 4.6785e-02, 2.2798e-04],
        [5.9498e-01, 4.0499e-01, 3.1640e-05],
        [1.2667e-05, 9.9999e-01, 1.8275e-11],
        [1.2139e-02, 4.3015e-13, 9.8786e-01],
        [4.8299e-02, 9.5170e-01, 3.7459e-07],
        [5.6333e-06, 9.9999e-01, 1.9947e-12],
        [1.9692e-04, 9.9980e-01, 3.1611e-10],
        [3.0287e-02, 1.5947e-12, 9.6971e-01],
        [1.2270e-02, 9.8773e-01, 1.8575e-07],
        [2.3998e-04, 9.9976e-01, 6

In [129]:
probs2

array([[3.86596546e-02, 2.67393133e-12, 9.61340345e-01],
       [7.18872965e-02, 1.07083961e-11, 9.28112703e-01],
       [3.90886567e-02, 1.20883877e-11, 9.60911343e-01],
       [1.32907184e-01, 4.61238117e-11, 8.67092816e-01],
       [6.45841319e-04, 1.01912324e-15, 9.99354159e-01],
       [1.02907803e-07, 9.99999897e-01, 1.01698184e-15],
       [8.10234363e-01, 1.89666575e-01, 9.90621714e-05],
       [2.25163715e-02, 1.58854090e-12, 9.77483628e-01],
       [3.19794815e-03, 9.96802041e-01, 1.07561765e-08],
       [9.16291372e-01, 8.33425423e-02, 3.66085924e-04],
       [9.78781307e-01, 2.04687284e-02, 7.49964994e-04],
       [6.69958395e-02, 1.02574856e-11, 9.33004160e-01],
       [9.52986698e-01, 4.67853176e-02, 2.27984869e-04],
       [5.94978896e-01, 4.04989463e-01, 3.16404333e-05],
       [1.26670696e-05, 9.99987333e-01, 1.82755961e-11],
       [1.21387794e-02, 4.30149550e-13, 9.87861221e-01],
       [4.82992958e-02, 9.51700330e-01, 3.74589210e-07],
       [5.63328661e-06, 9.99994

In [103]:
bla = np.mod(np.arange(3) + 2, 3)
bla

array([2, 0, 1])

In [104]:
probs[:, bla] - res

tensor([[[ 0.0000e+00,  3.3528e-08,  2.5370e-17],
         [-2.9802e-07,  3.2037e-07, -4.4235e-17],
         [ 1.1921e-07, -1.7881e-07, -2.1684e-17],
         [ 6.5565e-07, -5.3644e-07,  3.4694e-17],
         [ 0.0000e+00, -4.3074e-09, -3.9175e-21],
         [-3.9175e-21, -1.5703e-12,  0.0000e+00],
         [ 1.8917e-10,  0.0000e+00, -4.4703e-08],
         [-2.3842e-07,  2.1048e-07,  8.6736e-18],
         [-4.1744e-14,  0.0000e+00,  0.0000e+00],
         [-2.4738e-09,  7.1526e-07, -6.1840e-07],
         [ 0.0000e+00,  0.0000e+00,  9.3132e-09],
         [-1.7881e-07,  2.4587e-07,  3.8164e-17],
         [-1.6880e-09,  1.7881e-07, -2.1607e-07],
         [-4.7294e-11, -2.9802e-07,  3.5763e-07],
         [-1.7521e-16, -9.6406e-11,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  2.4666e-18],
         [-1.0516e-12, -1.1176e-08,  0.0000e+00],
         [ 1.5179e-17,  3.1832e-11,  0.0000e+00],
         [ 1.1935e-15,  5.5297e-10,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.8323e-17],


In [29]:
ss = StandardScaler()
ss.fit(X_train)
X_test_scaled = ss.transform(X_test)
bla = np.dot(X_test_scaled , weights) + biases

In [40]:
from scipy.special import softmax
probs = softmax(bla / np.exp(.8), axis=1)
pred = np.mod(np.argmax(probs, 1) - 2, 3)

In [51]:
bla = np.mod(np.arange(3) + 2, 3)
bla

array([2, 0, 1])

In [52]:
np.argmax(probs[:, bla], 1) == y_test

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True, False,  True,
        True,  True,  True,  True,  True, False,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False, False])

In [53]:
probs

array([[2.36820625e-01, 5.12587027e-05, 7.63128117e-01],
       [2.82203939e-01, 8.05778346e-05, 7.17715483e-01],
       [2.37561888e-01, 8.83038710e-05, 7.62349808e-01],
       [3.34660613e-01, 1.29683041e-04, 6.65209704e-01],
       [6.52735482e-02, 3.60800810e-06, 9.34722844e-01],
       [2.96277778e-03, 9.97033372e-01, 3.85000330e-06],
       [6.11093308e-01, 3.65052707e-01, 2.38539850e-02],
       [2.02397883e-01, 4.41226870e-05, 7.97557994e-01],
       [1.10685490e-01, 8.88133981e-01, 1.18052947e-03],
       [6.73633270e-01, 2.86042763e-01, 4.03239672e-02],
       [7.54451829e-01, 1.88415251e-01, 5.71329201e-02],
       [2.76698799e-01, 7.97929074e-05, 7.23221408e-01],
       [7.19628046e-01, 2.44576811e-01, 3.57951427e-02],
       [5.24335470e-01, 4.60511062e-01, 1.51534682e-02],
       [1.66110448e-02, 9.83258442e-01, 1.30512982e-04],
       [1.68225637e-01, 2.85999563e-05, 8.31745763e-01],
       [2.51566554e-01, 7.44805606e-01, 3.62784080e-03],
       [1.24506542e-02, 9.87490

In [54]:
res

tensor([[[9.6134e-01, 3.8660e-02, 2.6739e-12],
         [9.2811e-01, 7.1888e-02, 1.0708e-11],
         [9.6091e-01, 3.9089e-02, 1.2088e-11],
         [8.6709e-01, 1.3291e-01, 4.6124e-11],
         [9.9935e-01, 6.4584e-04, 1.0191e-15],
         [1.0170e-15, 1.0291e-07, 1.0000e+00],
         [9.9062e-05, 8.1023e-01, 1.8967e-01],
         [9.7748e-01, 2.2516e-02, 1.5885e-12],
         [1.0756e-08, 3.1979e-03, 9.9680e-01],
         [3.6609e-04, 9.1629e-01, 8.3343e-02],
         [7.4996e-04, 9.7878e-01, 2.0469e-02],
         [9.3300e-01, 6.6996e-02, 1.0257e-11],
         [2.2799e-04, 9.5299e-01, 4.6785e-02],
         [3.1640e-05, 5.9498e-01, 4.0499e-01],
         [1.8275e-11, 1.2667e-05, 9.9999e-01],
         [9.8786e-01, 1.2139e-02, 4.3015e-13],
         [3.7459e-07, 4.8299e-02, 9.5170e-01],
         [1.9947e-12, 5.6332e-06, 9.9999e-01],
         [3.1611e-10, 1.9692e-04, 9.9980e-01],
         [9.6971e-01, 3.0287e-02, 1.5947e-12],
         [1.8575e-07, 1.2270e-02, 9.8773e-01],
         [6.5

In [19]:
from tabpfn.utils import normalize_by_used_features_f
X_all = np.vstack([X_train, X_test])
max_features = 100
eval_xs = normalize_by_used_features_f(X_all, X_all.shape[-1], max_features,
                                               normalize_with_sqrt=False)
x_all_torch = torch.Tensor(np.hstack([eval_xs, np.zeros((X_all.shape[0], 100 - X_all.shape[1]))]))

In [20]:
eval_position = len(X_train)
eval_xs_ = normalize_data(xs, eval_position)
max_features = 100
eval_xs_ = normalize_by_used_features_f(eval_xs_, eval_xs_.shape[-1], max_features,
                                               normalize_with_sqrt=False)

In [21]:
eval_xs_.shape

torch.Size([150, 4])

In [22]:
eval_xs_[:112].std(0)

tensor([25.0000, 24.9999, 25.0000, 25.0000])

In [23]:
eval_xs_[:112].mean(0)

tensor([-3.0177e-05, -1.6349e-06,  2.4779e-06,  3.0654e-07])

In [24]:
x_all_torch = torch.concat([eval_xs_, torch.zeros((eval_xs.shape[0], 100 - eval_xs.shape[1]))], dim=1)
x_all_torch.shape

torch.Size([150, 100])

In [25]:
ys = torch.Tensor(y_train)
res = model((x_all_torch.unsqueeze(1), ys.unsqueeze(1)), single_eval_pos=len(X_train))

In [38]:
x_src = model.encoder(x_all_torch.unsqueeze(1)[:len(X_train)])
y_src = model.y_encoder(ys.unsqueeze(1).unsqueeze(-1))
train_x = x_src + y_src
# src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
output = model.transformer_encoder(train_x)
linear_model_coefs = model.decoder(output)


In [41]:
linear_model_coefs.shape

torch.Size([1, 513, 10])

In [99]:
def extract_linear_model(model, X_train, y_train):
    max_features = 100
    n_classes = len(np.unique(y_train))
    n_features = X_train.shape[1]

    ys = torch.Tensor(y_train)
    xs = torch.Tensor(X_train)

    eval_xs_ = normalize_data(xs, eval_position)

    eval_xs = normalize_by_used_features_f(eval_xs_, X_train.shape[-1], max_features,
                                                   normalize_with_sqrt=False)
    x_all_torch = torch.Tensor(np.hstack([eval_xs, np.zeros((X_train.shape[0], 100 - X_train.shape[1]))]))
    
    x_src = model.encoder(x_all_torch.unsqueeze(1)[:len(X_train)])
    y_src = model.y_encoder(ys.unsqueeze(1).unsqueeze(-1))
    train_x = x_src + y_src
    # src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
    output = model.transformer_encoder(train_x)
    linear_model_coefs = model.decoder(output)
    total_weights = torch.matmul(encoder_weight[:, :n_features].T, linear_model_coefs[0, :-1, :n_classes])
    total_biases = torch.matmul(encoder_bias, linear_model_coefs[0, :-1, :n_classes]) + linear_model_coefs[0, -1, :n_classes]
    return total_weights.detach().numpy() / (n_features / max_features), total_biases.detach().numpy()

In [100]:
weights, biases = extract_linear_model(model, X_train, y_train)

In [105]:
ss = StandardScaler()
ss.fit(X_train)
X_test_scaled = ss.transform(X_test)

In [107]:
bla = np.dot(X_test_scaled , weights) + biases

0.8947368421052632

In [49]:
encoder_weight = model.encoder.get_parameter("weight")
encoder_bias = model.encoder.get_parameter("bias")


In [70]:
n_features = X_train.shape[1]
n_classes = 3
total_weights = torch.matmul(encoder_weight[:, :n_features].T, linear_model_coefs[0, :-1, :n_classes])

In [71]:
total_biases = torch.matmul(encoder_bias, linear_model_coefs[0, :-1, :n_classes]) + linear_model_coefs[0, -1, :n_classes]
total_biases.shape

torch.Size([3])

In [72]:
eval_xs_.shape

torch.Size([150, 4])

In [73]:
total_weights.shape

torch.Size([4, 3])

In [83]:
((torch.matmul(eval_xs_, total_weights) + total_biases).argmax(1)[112:].detach().numpy() == y_test).mean()

0.8947368421052632

Bad pipe message: %s [b"y\xd8'\xbc(\xf7\xaf=\xed\x07\xd1\xabv\xdcsKB# \xfc\x8d\x94Z2\xab*\xe7\xea\xd0\xf0\x1f\xa7/^\x11\x99\x86\x82"]
Bad pipe message: %s [b'\x7f\x1fL\xb8\x9d\x7f\xe4]\xcf\xf4\xe2\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03']
Bad pipe message: %s [b'Cb\xb5\xf5P7\x99\x1f)>\x1a\x0e\x94\xd0\xc7cG\x05\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00']
Bad pipe message: %s [b"s\xc0w\x00\xc4\x00\xc3\xc0#\xc0'\x00g\x00@\xc0r\xc0v\x00\xbe\x00\xbd\xc0\n\xc0\x14\x009\x008\x00\x88\x00\x87\xc0\t\xc0\x13\x003\x002\x00\x9a\x00\x99\x00E\x00D\xc0\x07\xc0\x11\xc0\x08\xc0\x12\x00\x16\x00\x13\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00\xc0\x00<\x00\xba\x005\x00\x84\x00/\x00\x96\x00A\x00\x05\x00\n\x00\xff\x01\x00"]
Bad pipe message:

In [46]:
list(model.encoder.named_parameters())

[('weight',
  Parameter containing:
  tensor([[ 0.0670, -0.0613,  0.0725,  ...,  0.0519,  0.0198, -0.0384],
          [-0.0141,  0.0031,  0.0919,  ...,  0.0247, -0.0806,  0.0766],
          [-0.0151, -0.0461,  0.0375,  ...,  0.0191,  0.0504, -0.0113],
          ...,
          [-0.0445,  0.0645,  0.0400,  ..., -0.0645,  0.0573, -0.0435],
          [-0.0159, -0.0177, -0.0205,  ...,  0.0716, -0.0397,  0.0622],
          [ 0.0612,  0.0584,  0.0001,  ...,  0.0596, -0.0759,  0.1071]],
         requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([ 2.0529e-02, -9.1211e-02, -8.5835e-02,  7.3189e-02, -1.5755e-02,
          -1.4510e-01, -1.0569e-01, -3.5785e-02, -1.3507e-01,  8.9870e-02,
           7.2533e-02,  5.2945e-02,  1.1816e-01,  1.2436e-01, -2.9489e-02,
          -3.2538e-02, -1.4101e-01, -5.0609e-03,  1.4305e-01,  2.4858e-02,
          -3.9943e-02,  6.6448e-02,  1.3602e-01, -1.4026e-01, -2.9958e-02,
           1.1587e-01,  1.3399e-02, -1.3202e-01, -1.1304e-01, -1.3910e-01,
 

In [26]:
probs = torch.nn.functional.softmax(res[:, 0, :3] / 0.8, dim=-1)

In [27]:
probs.shape

torch.Size([38, 3])

In [28]:
y_pred = probs.detach().numpy()

In [29]:
y_pred.argmax(axis=1)

array([0, 0, 0, 0, 0, 2, 1, 0, 2, 1, 1, 0, 1, 1, 2, 0, 2, 2, 2, 0, 2, 2,
       2, 2, 0, 2, 2, 1, 1, 2, 0, 0, 2, 1, 0, 0, 2, 0])

In [30]:
(y_pred.argmax(axis=1) == y_test).mean()

0.8947368421052632

In [73]:
y_pred

array([[5.56962641e-06, 9.99994397e-01, 4.45785076e-12],
       [9.55580235e-01, 4.42057252e-02, 2.14084997e-04],
       [4.81489085e-04, 2.15239360e-17, 9.99518514e-01],
       [5.14747342e-04, 9.99485254e-01, 9.29835653e-11],
       [6.84229610e-03, 5.19721019e-14, 9.93157685e-01],
       [3.64074936e-06, 9.99996305e-01, 6.71815088e-13],
       [3.82941007e-03, 2.91459294e-14, 9.96170580e-01],
       [7.49539316e-01, 2.50392497e-01, 6.81771708e-05],
       [5.49787700e-01, 4.50191855e-01, 2.04297266e-05],
       [9.63276565e-01, 3.61998267e-02, 5.23653696e-04],
       [3.63744535e-02, 9.63625431e-01, 7.42356150e-08],
       [8.95416558e-01, 1.04382806e-01, 2.00692826e-04],
       [9.27551746e-01, 7.23605305e-02, 8.78107894e-05],
       [3.85184467e-01, 6.14796162e-01, 1.93364904e-05],
       [7.09265172e-01, 2.90685296e-01, 4.95604145e-05],
       [2.07310217e-03, 1.22282081e-15, 9.97926831e-01],
       [5.67619383e-01, 4.32327658e-01, 5.29848876e-05],
       [8.82348239e-01, 1.17537

In [65]:
res

tensor([[[4.4579e-12, 5.5696e-06, 9.9999e-01],
         [2.1408e-04, 9.5558e-01, 4.4206e-02],
         [9.9952e-01, 4.8149e-04, 2.1524e-17],
         [9.2984e-11, 5.1475e-04, 9.9949e-01],
         [9.9316e-01, 6.8423e-03, 5.1972e-14],
         [6.7182e-13, 3.6407e-06, 1.0000e+00],
         [9.9617e-01, 3.8294e-03, 2.9146e-14],
         [6.8177e-05, 7.4954e-01, 2.5039e-01],
         [2.0430e-05, 5.4979e-01, 4.5019e-01],
         [5.2365e-04, 9.6328e-01, 3.6200e-02],
         [7.4236e-08, 3.6374e-02, 9.6363e-01],
         [2.0069e-04, 8.9542e-01, 1.0438e-01],
         [8.7811e-05, 9.2755e-01, 7.2361e-02],
         [1.9336e-05, 3.8518e-01, 6.1480e-01],
         [4.9560e-05, 7.0927e-01, 2.9069e-01],
         [9.9793e-01, 2.0731e-03, 1.2228e-15],
         [5.2985e-05, 5.6762e-01, 4.3233e-01],
         [1.1424e-04, 8.8235e-01, 1.1754e-01],
         [9.7774e-01, 2.2262e-02, 4.5128e-12],
         [9.9840e-01, 1.6018e-03, 3.1781e-15],
         [3.9647e-09, 8.1633e-04, 9.9918e-01],
         [8.8

In [None]:
res = transformer_predict(model, xs.unsqueeze(1), ys.unsqueeze(1), eval_position=len(X_train), num_classes=4, N_ensemble_configurations=1)