In [1]:
import sys
import os

# Absolute location of the `src` folder under the current working directory
# (abspath resolves a relative path against os.getcwd()).
src_path = os.path.abspath('src')

# Prepend it to the module search path once; reruns leave sys.path untouched.
if src_path not in sys.path:
    sys.path[:0] = [src_path]

In [2]:
import pandas as pd
from tabpfn_extensions import TabPFNRegressor, TabPFNClassifier
from tabpfn_extensions.embedding import TabPFNEmbedding

import argparse
import os
import numpy as np
import random
from tqdm import tqdm

from src.data_utils import load_data_for_tabPFN
from src.data_constants import targets

# Seed NumPy's global RNG and the stdlib RNG so reruns are reproducible.
# `s` is reused further down as the regressor's random_state.
s = 42
np.random.seed(s)
random.seed(s)

# Load the dataset. Only CSV paths are handled, so fail fast with a clear
# message instead of leaving `data`/`ids` unbound and hitting a NameError
# a few lines later.
f = 'data/processed/vitals_yeo-johnson_test_data.csv'
if not f.endswith(".csv"):
    raise ValueError(f"Unsupported data file (expected a .csv path): {f}")
data, ids = load_data_for_tabPFN(f)
print(f"Loaded data from {f}")

# Split the frame into target columns (y) and feature columns (x).
y = data[targets]
x = data.drop(columns=targets)
vecs = []  # placeholder list; no visible cell fills it
print(x.columns)
# The first entry of `targets` is the column the model is fitted against.
t = targets[0]


Data after dropping rows with NA target values, (489, 45) matrix.
Unique patient IDs in the data: 171
Loaded data from data/processed/vitals_yeo-johnson_test_data.csv
Index(['PTID', 'EXAMDATE', 'DX_bl', 'AGE', 'PTGENDER', 'PTEDUCAT', 'PTETHCAT',
       'PTRACCAT', 'PTMARRY', 'APOE4', 'PIB', 'ABETA', 'TAU', 'Hippocampus',
       'EXAMDATE_bl', 'ADAS11_bl', 'ADAS13_bl', 'MMSE_bl', 'Ventricles_bl',
       'Hippocampus_bl', 'MOCA_bl', 'ABETA_bl', 'TAU_bl', 'PIB_bl', 'Years_bl',
       'Month_bl', 'VSWEIGHT', 'VSWTUNIT', 'VSBPSYS', 'VSBPDIA', 'VSPULSE',
       'VSRESP', 'VSTEMP', 'VSTMPSRC', 'VSTMPUNT'],
      dtype='object')


In [3]:
# Inspect the dtype of every feature column (string-valued columns show up as `object`).
x.dtypes

PTID               object
EXAMDATE          float64
DX_bl              object
AGE               float64
PTGENDER           object
PTEDUCAT          float64
PTETHCAT           object
PTRACCAT           object
PTMARRY            object
APOE4             float64
PIB               float64
ABETA             float64
TAU               float64
Hippocampus       float64
EXAMDATE_bl       float64
ADAS11_bl         float64
ADAS13_bl         float64
MMSE_bl           float64
Ventricles_bl     float64
Hippocampus_bl    float64
MOCA_bl           float64
ABETA_bl          float64
TAU_bl            float64
PIB_bl            float64
Years_bl          float64
Month_bl          float64
VSWEIGHT          float64
VSWTUNIT          float64
VSBPSYS           float64
VSBPDIA           float64
VSPULSE           float64
VSRESP            float64
VSTEMP            float64
VSTMPSRC          float64
VSTMPUNT          float64
dtype: object

In [4]:
# Peek at EXAMDATE, which is held as float64 rather than a datetime.
# presumably a day count produced by load_data_for_tabPFN — TODO confirm the unit/origin
x["EXAMDATE"]

10      15141.0
12      15475.0
14      15839.0
16      16191.0
17      16588.0
         ...   
1642    18498.0
1643    18898.0
1644    19269.0
1645    17318.0
1647    19052.0
Name: EXAMDATE, Length: 489, dtype: float64

In [5]:
# Display the feature frame as loaded (before any validation / dtype fixing).
x

Unnamed: 0,PTID,EXAMDATE,DX_bl,AGE,PTGENDER,PTEDUCAT,PTETHCAT,PTRACCAT,PTMARRY,APOE4,...,Month_bl,VSWEIGHT,VSWTUNIT,VSBPSYS,VSBPDIA,VSPULSE,VSRESP,VSTEMP,VSTMPSRC,VSTMPUNT
10,002_S_0413,15141.0,CN,76.3,Female,16.0,Not Hisp/Latino,White,Married,0.0,...,61.1475,57.606184,1.0,120.0,62.0,69.0,18.0,36.888889,1.0,1.0
12,002_S_0413,15475.0,CN,76.3,Female,16.0,Not Hisp/Latino,White,Married,0.0,...,72.0984,58.513368,1.0,131.0,64.0,71.0,15.0,36.555556,1.0,1.0
14,002_S_0413,15839.0,CN,76.3,Female,16.0,Not Hisp/Latino,White,Married,0.0,...,84.0328,58.059776,1.0,125.0,69.0,75.0,15.0,36.777778,1.0,1.0
16,002_S_0413,16191.0,CN,76.3,Female,16.0,Not Hisp/Latino,White,Married,0.0,...,95.5738,58.966960,1.0,115.0,67.0,82.0,15.0,36.333333,1.0,1.0
17,002_S_0413,16588.0,CN,76.3,Female,16.0,Not Hisp/Latino,White,Married,0.0,...,108.5900,58.966960,1.0,110.0,62.0,77.0,16.0,36.833333,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1642,941_S_4187,18498.0,LMCI,62.0,Male,16.0,Not Hisp/Latino,White,Married,0.0,...,107.7700,86.182480,1.0,148.0,71.0,70.0,18.0,36.222222,3.0,1.0
1643,941_S_4187,18898.0,LMCI,62.0,Male,16.0,Not Hisp/Latino,White,Married,0.0,...,120.8850,85.275296,1.0,153.0,74.0,69.0,16.0,36.000000,3.0,1.0
1644,941_S_4187,19269.0,LMCI,62.0,Male,16.0,Not Hisp/Latino,White,Married,0.0,...,133.0490,84.368112,1.0,155.0,78.0,66.0,18.0,36.400000,3.0,2.0
1645,941_S_6017,17318.0,LMCI,76.6,Male,17.0,Not Hisp/Latino,White,Married,1.0,...,0.0000,77.110640,1.0,160.0,70.0,51.0,20.0,36.833333,1.0,1.0


In [None]:
from tabpfn.config import ModelInterfaceConfig

In [None]:
# Inference configuration with both shift methods set to None; it is handed to
# the regressor below through `inference_config`.
# presumably this switches off TabPFN's feature/class shifting so that transformed
# columns keep the input order — TODO confirm against the tabpfn.config docs
config = ModelInterfaceConfig(
    FEATURE_SHIFT_METHOD = None,
    CLASS_SHIFT_METHOD = None
)

In [None]:
# Single-estimator TabPFN regressor, seeded with `s` and using the custom
# inference config.
#
# The categorical positions index into the columns of `x`. Against the column
# listing printed when the data was loaded they select PTID, EXAMDATE, AGE,
# PTEDUCAT, PTETHCAT, PTRACCAT, Hippocampus, Month_bl, VSWEIGHT and VSWTUNIT,
# while the object-dtype columns DX_bl, PTGENDER and PTMARRY are not listed.
# NOTE(review): this looks like an index list built for a different column
# layout — confirm the intended categorical set.
reg = TabPFNRegressor(
    n_estimators=1,
    random_state=s,
    inference_config=config,
    categorical_features_indices=[0, 1, 3, 5, 6, 7, 13, 25, 26, 27],
)

# Attach the feature names by hand, before fit() has run.
# NOTE(review): a trailing-underscore attribute set before fitting can make
# scikit-learn's fitted checks pass early — confirm this assignment is needed.
reg.feature_names_in_ = x.columns

In [None]:
from tabpfn.utils import validate_X_predict, _fix_dtypes

In [None]:
# Re-check the feature dtypes right before fitting.
x.dtypes

In [None]:
# Fit the regressor on all feature columns against the single target column `t`.
reg.fit(x, y[t])

In [None]:
# Run tabpfn's predict-time input validation on the features.
# This rebinds `x`, so rerunning the cell feeds already-validated data back in;
# restart from the loading cell to get the raw frame again.
x = validate_X_predict(x, reg)

In [None]:
# Apply tabpfn's dtype fix-up using the categorical positions configured on `reg`.
# `_fix_dtypes` is an underscore-prefixed (private) helper, and `x` is rebound again here.
x = _fix_dtypes(x, cat_indices=reg.categorical_features_indices)

In [None]:
from sklearn.base import check_is_fitted, is_classifier

# check_is_fitted() reports through exceptions, not a return value: it raises
# NotFittedError for an unfitted estimator and returns None otherwise, so
# printing its result could only ever show "None".
# NOTE(review): `feature_names_in_` was assigned by hand right after `reg` was
# constructed; unless the estimator defines `__sklearn_is_fitted__`, that
# attribute alone may satisfy this check — confirm.
check_is_fitted(reg)
print("reg is fitted")

In [None]:
# Encode the validated features with the regressor-level fitted preprocessor.
# The cells below read `reg.preprocessor_.transformers_` to undo the column
# reordering of `tran` and slice `tran` as a 2-D array, so `tran` has to be the
# output of that same object. The executor is only seen with a plural
# `preprocessors` list elsewhere in this notebook, and those transforms are
# read through `.X`, not indexed directly.
tran = reg.preprocessor_.transform(x)

In [None]:
# List the preprocessors held by the fitted executor.
reg.executor_.preprocessors

In [None]:
# Show the ensemble configurations held by the fitted executor.
reg.executor_.ensemble_configs

In [None]:
def iter_outputs(
    self,
    X: np.ndarray,
    *,
    device: "torch.device",
    autocast: bool,
    only_return_standard_out: bool = True,
) -> "Iterator[tuple[torch.Tensor | dict, EnsembleConfig]]":
    """Build the per-estimator input tensors for ``X`` (reference snippet).

    NOTE(review): this looks like a partial paste of the method behind
    ``reg.executor_.iter_outputs`` — confirm. The visible body stops after the
    tensors are built: nothing is yielded or returned, so a call returns
    ``None``.

    The torch / Iterator / EnsembleConfig annotations are written as strings
    because none of those names are imported in this notebook; evaluating them
    eagerly would raise NameError as soon as the cell is run.

    Args:
        self: Object exposing ``preprocessors``, ``X_trains``, ``y_trains``,
            ``ensemble_configs`` and ``cat_ixs`` as parallel sequences.
        X: Test features; each preprocessor's ``transform(X).X`` is used.
        device: Device on which the tensors are created.
        autocast: Not used by the visible part of the body.
        only_return_standard_out: Not used by the visible part of the body.
    """
    # Local import: torch is used by the body but is not imported at notebook level.
    import torch

    for preprocessor, X_train, y_train, config, cat_ix in zip(
        self.preprocessors,
        self.X_trains,
        self.y_trains,
        self.ensemble_configs,
        self.cat_ixs,
    ):
        X_train = torch.as_tensor(X_train, dtype=torch.float32, device=device)  # noqa: PLW2901

        # Transform the test rows with this estimator's own preprocessor.
        X_test = preprocessor.transform(X).X
        X_test = torch.as_tensor(X_test, dtype=torch.float32, device=device)

        # Stack train and test rows, then insert a size-1 axis at position 1.
        X_full = torch.cat([X_train, X_test], dim=0).unsqueeze(1)
        y_train = torch.as_tensor(y_train, dtype=torch.float32, device=device)  # noqa: PLW2901

In [None]:
# Show the executor's ensemble configurations again (same expression as the earlier cell).
reg.executor_.ensemble_configs

In [None]:
# Iterate over what the fitted executor produces for `x` on CPU, without autocast,
# and print each item. This calls the method on `reg.executor_`, not the
# module-level `iter_outputs` function defined in the notebook.
for i in reg.executor_.iter_outputs(x ,device = 'cpu', autocast = False):
    print(i)

In [None]:
# Describe each fitted (name, transformer, columns) entry of the regressor's
# preprocessor, followed by a 40-dash separator line.
for name, transformer, columns in reg.preprocessor_.transformers_:
    print(
        f"Transformer: {name}",
        f"Applies to columns: {columns}",
        f"Transformer steps: {transformer}",
        "-" * 40,
        sep="\n",
    )

In [None]:
# Flatten the column selections of all fitted transformers, in transformer
# order, into one list.
shuffle_idx = [
    col_idx
    for _name, _transformer, cols in reg.preprocessor_.transformers_
    for col_idx in cols
]

In [None]:
# Permutation that sorts `shuffle_idx` ascending; it is used next to reorder the columns of `tran`.
# assumes shuffle_idx holds integer column positions (not labels) — TODO confirm
sort_idx = np.argsort(shuffle_idx)

In [None]:
# Reorder the columns of the transformed matrix with `sort_idx`.
# presumably this restores the original feature order — verify against `x.columns`
x_in = tran[:, sort_idx]

In [None]:
# Show the reordered transformed matrix as a DataFrame for inspection.
pd.DataFrame(x_in)

In [None]:
# Display the full loaded frame (feature and target columns together).
data

In [None]:
# Display the features in their current state; by this point `x` has been
# rebound by the validation and dtype-fixing cells.
x

In [None]:
# Show the transformed matrix in its raw (not reordered) column order, for comparison.
pd.DataFrame(tran)