In [16]:
import numpy as np
from pathlib import Path
import pandas as pd
import time
import torch.nn as nn
import torch

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from torch.utils.data import DataLoader, Dataset

from sklearn.metrics import accuracy_score
from sklearn.datasets import load_breast_cancer, load_iris, load_wine
from sklearn.model_selection import train_test_split
from scripts.decision_boundary import DecisionBoundaryDisplay
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.preprocessing import LabelEncoder


from scripts.transformer_prediction_interface import TabPFNClassifier, load_model_workflow, transformer_predict, get_params_from_config
import torch.optim as optim

In [17]:
class SklearnDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = torch.tensor(self.data[index], dtype=torch.float32)
        y = torch.tensor(self.target[index], dtype=torch.long)
        return x, y
    
    # class MetaDataset(Dataset):
    # def __init__(self, *datasets):
    #     self.datasets = datasets
    #     self.lengths = [len(dataset) for dataset in self.datasets]

    # def __len__(self):
    #     return sum(self.lengths)

    # def __getitem__(self, index):
    #     data, target = [], []
    #     for i, dataset in enumerate(self.datasets):
    #         data_i, target_i = dataset[index % self.lengths[i]]
    #         data.append(data_i)training
    #         target.append(target_i)
    #     return tuple(data), tuple(target)


In [28]:
epochs = 10
device = 'cpu'
lr = 0.0001
batch_size = 32

In [30]:
breast_cancer_data, breast_cancer_targets = load_breast_cancer(return_X_y=True)
breast_cancer_data_train, breast_cancer_data_test, breast_cancer_targets_train, breast_cancer_targets_test = train_test_split(breast_cancer_data, breast_cancer_targets, test_size=0.2, random_state=42)

iris_data, iris_targets = load_iris(return_X_y=True)
iris_data_train, iris_data_test, iris_targets_train, iris_targets_test = train_test_split(iris_data, iris_targets, test_size=0.2, random_state=42)

wine_data, wine_targets = load_wine(return_X_y=True)
wine_data_train, wine_data_test, wine_targets_train, wine_targets_test = train_test_split(wine_data, wine_targets, test_size=0.2, random_state=42)

breast_cancer_dataset_train = SklearnDataset(breast_cancer_data_train, breast_cancer_targets_train)
breast_cancer_dataset_test = SklearnDataset(breast_cancer_data_test, breast_cancer_targets_test)

iris_dataset_train = SklearnDataset(iris_data_train, iris_targets_train)
iris_dataset_test = SklearnDataset(iris_data_test, iris_targets_test)

wine_dataset_train = SklearnDataset(wine_data_train, wine_targets_train)
wine_dataset_test = SklearnDataset(wine_data_test, wine_targets_test)

breast_cancer_train_data_loader = DataLoader(breast_cancer_dataset_train, batch_size=batch_size, shuffle=True)
iris_train_data_loader = DataLoader(iris_dataset_train, batch_size=batch_size, shuffle=True)
wine_train_data_loader = DataLoader(wine_dataset_train, batch_size=batch_size, shuffle=True)

classifier = TabPFNClassifier(device=device, N_ensemble_configurations=4, only_inference=False)

classifier.model[2].train()

optimizer = optim.Adam(classifier.model[2].parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

for e in range(epochs):
    print('=' * 15, 'Epoch', e,'=' * 15)
    for  i, data in enumerate(breast_cancer_train_data_loader):
        x, y = data
        
        if i == 0:
            classifier.fit(x, y)
        
        if i != 0:
            optimizer.zero_grad()
            prediction = classifier.predict_proba2(x)
            prediction = prediction.squeeze(0)  
            loss = criterion(prediction,y)
            print('ds1 | loss =',loss.item())
            loss.backward()
            optimizer.step()
            
    for  i, data in enumerate(iris_train_data_loader):
        x, y = data
        
        if i == 0:
            classifier.fit(x, y)
        
        if i != 0:
            # optimizer.zero_grad()
            prediction = classifier.predict_proba2(x)
            prediction = prediction.squeeze(0)
            loss = criterion(prediction,y)
            print('ds2 | loss =',loss.item())
            loss.backward()
            optimizer.step()
            
    for  i, data in enumerate(wine_train_data_loader):
        x, y = data
        
        if i == 0:
            classifier.fit(x, y)
        
        if i != 0:
            # optimizer.zero_grad()
            prediction = classifier.predict_proba2(x)
            prediction = prediction.squeeze(0)
            loss = criterion(prediction,y)
            print('ds3 | loss =',loss.item())
            loss.backward()
            optimizer.step()



Using style prior: True
Using cpu:0 device
Using a Transformer with 25.82 M parameters
ds1 | loss = 0.36399906873703003
ds1 | loss = 0.335908442735672
ds1 | loss = 0.3980216085910797
ds1 | loss = 0.3818013668060303
ds1 | loss = 0.3268836438655853
ds1 | loss = 0.37069404125213623
ds1 | loss = 0.3374892473220825
ds1 | loss = 0.3330860435962677
ds1 | loss = 0.3522685170173645
ds1 | loss = 0.3365355134010315
ds1 | loss = 0.3558477461338043
ds1 | loss = 0.3335947096347809
ds1 | loss = 0.34391486644744873
ds1 | loss = 0.3135858476161957
ds2 | loss = 0.7244208455085754
ds2 | loss = 0.6147929430007935
ds2 | loss = 0.6760572791099548
ds3 | loss = 0.6453214883804321
ds3 | loss = 0.6615816354751587
ds3 | loss = 0.5514776706695557
ds3 | loss = 0.6233497858047485
ds1 | loss = 0.34421566128730774
ds1 | loss = 0.48586297035217285
ds1 | loss = 0.4543382525444031
ds1 | loss = 0.5462649464607239
ds1 | loss = 0.5119904279708862
ds1 | loss = 0.43735796213150024
ds1 | loss = 0.3968279957771301
ds1 | loss =