In this notebook I experiment with bootstrap aggregation (bagging) of Kolmogorov–Arnold Networks (KANs). I want to see whether it improves test-set performance and reduces variance, overfitting, and sensitivity to extreme values.

This is primarily motivated by my reading of *The Elements of Statistical Learning* (2nd ed.), Chapter 8, which covers the bootstrap and, in Section 8.7, bagging.

In [173]:
import sys
# Make the project root importable so the local `train` and `utils` modules resolve.
# NOTE(review): assumes the kernel's working directory is two levels below the project root.
sys.path.append("../..")

In [174]:
# Experiment configuration: every tunable knob for this notebook lives here.
options = dict(
    test_size=0.2,              # proportion of scaffold groups held out for the test split
    num_bootstrap_samples=25,   # number of bagged KANs (one per bootstrap resample)
    random_seed=1738,           # base seed for the split, the bootstrap RNG and KAN initialisation
    num_train_iterations=500,   # training iterations per model
    num_repetitions=10,         # NOTE(review): not referenced in the cells shown here
)

In [175]:
import numpy as np
import torch
import torch.nn.functional as F
import polars as pl
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit

from train import train_regression

from utils.data_utils import get_all_descriptors_from_smiles_list, get_scaffolds, DESCRIPTOR_NAMES
from utils.evaluation_utils import regression_report

In [176]:
# Delaney (ESOL) aqueous-solubility dataset. The path is relative; it assumes the kernel
# runs from this notebook's directory.
filepath = '../../datasets/aqueous_solubility_delaney.csv'
df_delaney = pl.read_csv(source=filepath)
df_delaney.head(n=5)

Compound ID,measured log(solubility:mol/L),ESOL predicted log(solubility:mol/L),SMILES
str,f64,f64,str
"""1,1,1,2-Tetrachloroethane""",-2.18,-2.794,"""ClCC(Cl)(Cl)Cl"""
"""1,1,1-Trichloroethane""",-2.0,-2.232,"""CC(Cl)(Cl)Cl"""
"""1,1,2,2-Tetrachloroethane""",-1.74,-2.549,"""ClC(Cl)C(Cl)Cl"""
"""1,1,2-Trichloroethane""",-1.48,-1.961,"""ClCC(Cl)Cl"""
"""1,1,2-Trichlorotrifluoroethane""",-3.04,-3.077,"""FC(F)(Cl)C(F)(Cl)Cl"""


In [177]:
smiles = df_delaney.get_column('SMILES').to_list()  # plain Python list of SMILES strings for the descriptor helper

In [178]:
# Compute the molecular descriptor table for every SMILES string. as_polars=True presumably
# returns a polars DataFrame, since it is concatenated horizontally with df_delaney afterwards.
descriptors_df = get_all_descriptors_from_smiles_list(smiles, as_polars=True)

Calculating descriptors:   0%|          | 0/1144 [00:00<?, ?it/s]

Calculating descriptors: 100%|██████████| 1144/1144 [00:06<00:00, 176.62it/s]


In [179]:
# Descriptors are joined to the dataset purely by row position. Both frames must therefore
# have the same number of rows, in the same molecule order. A mismatch (e.g. a SMILES that
# failed descriptor calculation) would otherwise misalign features and targets, or error,
# depending on how the installed polars version handles unequal heights.
assert descriptors_df.height == df_delaney.height, (
    f'descriptor rows ({descriptors_df.height}) != dataset rows ({df_delaney.height})'
)

# Drop the ESOL model's own predicted solubility: it is another model's output, not a descriptor.
df = pl.concat([df_delaney, descriptors_df], how='horizontal').drop('ESOL predicted log(solubility:mol/L)')

# The merged frame replaces both inputs.
del df_delaney, descriptors_df

df.head()

Compound ID,measured log(solubility:mol/L),SMILES,MaxAbsEStateIndex,MaxEStateIndex,MinAbsEStateIndex,MinEStateIndex,qed,SPS,MolWt,HeavyAtomMolWt,ExactMolWt,NumValenceElectrons,NumRadicalElectrons,MaxPartialCharge,MinPartialCharge,MaxAbsPartialCharge,MinAbsPartialCharge,FpDensityMorgan1,FpDensityMorgan2,FpDensityMorgan3,BCUT2D_MWHI,BCUT2D_MWLOW,BCUT2D_CHGHI,BCUT2D_CHGLO,BCUT2D_LOGPHI,BCUT2D_LOGPLOW,BCUT2D_MRHI,BCUT2D_MRLOW,AvgIpc,BalabanJ,BertzCT,Chi0,Chi0n,Chi0v,Chi1,Chi1n,…,fr_imide,fr_isocyan,fr_isothiocyan,fr_ketone,fr_ketone_Topliss,fr_lactam,fr_lactone,fr_methoxy,fr_morpholine,fr_nitrile,fr_nitro,fr_nitro_arom,fr_nitro_arom_nonortho,fr_nitroso,fr_oxazole,fr_oxime,fr_para_hydroxylation,fr_phenol,fr_phenol_noOrthoHbond,fr_phos_acid,fr_phos_ester,fr_piperdine,fr_piperzine,fr_priamide,fr_prisulfonamd,fr_pyridine,fr_quatN,fr_sulfide,fr_sulfonamd,fr_sulfone,fr_term_acetylene,fr_tetrazole,fr_thiazole,fr_thiocyan,fr_thiophene,fr_unbrch_alkane,fr_urea
str,f64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""1,1,1,2-Tetrachloroethane""",-2.18,"""ClCC(Cl)(Cl)Cl""",5.116512,5.116512,0.039352,-1.276235,0.487138,12.0,167.85,165.834,165.891061,38.0,0.0,0.203436,-0.122063,0.203436,0.122063,1.166667,1.333333,1.333333,35.582798,10.92878,2.155416,-2.003515,2.259665,-2.010232,6.690915,1.27955,1.351644,3.16849,35.302969,5.207107,2.718965,5.74268,2.56066,1.187761,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""1,1,1-Trichloroethane""",-2.0,"""CC(Cl)(Cl)Cl""",5.060957,5.060957,1.083333,-1.083333,0.445171,12.0,133.405,130.381,131.930033,32.0,0.0,0.187382,-0.084013,0.187382,0.084013,1.2,1.2,1.2,35.582513,10.948044,2.065641,-1.94269,2.228454,-1.870955,6.667091,1.268178,0.721928,3.023716,20.364528,4.5,2.633893,4.90168,2.0,1.066947,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""1,1,2,2-Tetrachloroethane""",-1.74,"""ClC(Cl)C(Cl)Cl""",5.114198,5.114198,0.67284,-0.67284,0.527312,11.0,167.85,165.834,165.891061,38.0,0.0,0.137344,-0.102365,0.137344,0.102365,0.666667,0.833333,0.833333,35.544933,10.929605,2.063775,-1.94593,2.160253,-1.959015,6.563502,1.293168,1.360964,2.993497,26.529325,5.154701,2.666558,5.690274,2.642734,1.206205,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""1,1,2-Trichloroethane""",-1.48,"""ClCC(Cl)Cl""",5.095679,5.095679,0.308642,-0.405864,0.480258,9.6,133.405,130.381,131.930033,32.0,0.0,0.120829,-0.123772,0.123772,0.120829,1.4,1.6,1.6,35.539546,10.949696,1.882355,-1.820706,2.007265,-1.820631,6.469236,1.398333,1.378783,2.539539,18.854753,4.284457,2.41835,4.686137,2.270056,1.111945,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""1,1,2-Trichlorotrifluoroethane""",-3.04,"""FC(F)(Cl)C(F)(Cl)Cl""",11.544753,11.544753,3.685957,-4.22608,0.553756,14.25,187.375,187.375,185.901768,50.0,0.0,0.382976,-0.199489,0.382976,0.199489,0.875,1.0,1.0,35.539783,10.767629,2.506738,-2.131584,2.420075,-2.268487,6.509766,0.024089,1.253298,4.020392,67.509775,7.0,3.267787,5.535574,3.25,1.383893,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [180]:
# One scaffold label per molecule, used below as the grouping key for the train/test split.
# Molecules without a ring scaffold appear to get unique 'no_scaffold_<i>' labels.
# NOTE(review): confirm against get_scaffolds.
scaffolds = get_scaffolds(df['SMILES'].to_list())

# insert_column mutates df in place and refuses to add a column that already exists.
# Drop any copy left by a previous run so re-running this cell is safe.
if 'scaffolds' in df.columns:
    df = df.drop('scaffolds')

df.insert_column(
    index = 3,
    column = pl.Series('scaffolds', scaffolds)
)

Compound ID,measured log(solubility:mol/L),SMILES,scaffolds,MaxAbsEStateIndex,MaxEStateIndex,MinAbsEStateIndex,MinEStateIndex,qed,SPS,MolWt,HeavyAtomMolWt,ExactMolWt,NumValenceElectrons,NumRadicalElectrons,MaxPartialCharge,MinPartialCharge,MaxAbsPartialCharge,MinAbsPartialCharge,FpDensityMorgan1,FpDensityMorgan2,FpDensityMorgan3,BCUT2D_MWHI,BCUT2D_MWLOW,BCUT2D_CHGHI,BCUT2D_CHGLO,BCUT2D_LOGPHI,BCUT2D_LOGPLOW,BCUT2D_MRHI,BCUT2D_MRLOW,AvgIpc,BalabanJ,BertzCT,Chi0,Chi0n,Chi0v,Chi1,…,fr_imide,fr_isocyan,fr_isothiocyan,fr_ketone,fr_ketone_Topliss,fr_lactam,fr_lactone,fr_methoxy,fr_morpholine,fr_nitrile,fr_nitro,fr_nitro_arom,fr_nitro_arom_nonortho,fr_nitroso,fr_oxazole,fr_oxime,fr_para_hydroxylation,fr_phenol,fr_phenol_noOrthoHbond,fr_phos_acid,fr_phos_ester,fr_piperdine,fr_piperzine,fr_priamide,fr_prisulfonamd,fr_pyridine,fr_quatN,fr_sulfide,fr_sulfonamd,fr_sulfone,fr_term_acetylene,fr_tetrazole,fr_thiazole,fr_thiocyan,fr_thiophene,fr_unbrch_alkane,fr_urea
str,f64,str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""1,1,1,2-Tetrachloroethane""",-2.18,"""ClCC(Cl)(Cl)Cl""","""no_scaffold_0""",5.116512,5.116512,0.039352,-1.276235,0.487138,12.0,167.85,165.834,165.891061,38.0,0.0,0.203436,-0.122063,0.203436,0.122063,1.166667,1.333333,1.333333,35.582798,10.92878,2.155416,-2.003515,2.259665,-2.010232,6.690915,1.27955,1.351644,3.16849,35.302969,5.207107,2.718965,5.74268,2.56066,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""1,1,1-Trichloroethane""",-2.0,"""CC(Cl)(Cl)Cl""","""no_scaffold_1""",5.060957,5.060957,1.083333,-1.083333,0.445171,12.0,133.405,130.381,131.930033,32.0,0.0,0.187382,-0.084013,0.187382,0.084013,1.2,1.2,1.2,35.582513,10.948044,2.065641,-1.94269,2.228454,-1.870955,6.667091,1.268178,0.721928,3.023716,20.364528,4.5,2.633893,4.90168,2.0,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""1,1,2,2-Tetrachloroethane""",-1.74,"""ClC(Cl)C(Cl)Cl""","""no_scaffold_2""",5.114198,5.114198,0.67284,-0.67284,0.527312,11.0,167.85,165.834,165.891061,38.0,0.0,0.137344,-0.102365,0.137344,0.102365,0.666667,0.833333,0.833333,35.544933,10.929605,2.063775,-1.94593,2.160253,-1.959015,6.563502,1.293168,1.360964,2.993497,26.529325,5.154701,2.666558,5.690274,2.642734,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""1,1,2-Trichloroethane""",-1.48,"""ClCC(Cl)Cl""","""no_scaffold_3""",5.095679,5.095679,0.308642,-0.405864,0.480258,9.6,133.405,130.381,131.930033,32.0,0.0,0.120829,-0.123772,0.123772,0.120829,1.4,1.6,1.6,35.539546,10.949696,1.882355,-1.820706,2.007265,-1.820631,6.469236,1.398333,1.378783,2.539539,18.854753,4.284457,2.41835,4.686137,2.270056,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""1,1,2-Trichlorotrifluoroethane""",-3.04,"""FC(F)(Cl)C(F)(Cl)Cl""","""no_scaffold_4""",11.544753,11.544753,3.685957,-4.22608,0.553756,14.25,187.375,187.375,185.901768,50.0,0.0,0.382976,-0.199489,0.382976,0.199489,0.875,1.0,1.0,35.539783,10.767629,2.506738,-2.131584,2.420075,-2.268487,6.509766,0.024089,1.253298,4.020392,67.509775,7.0,3.267787,5.535574,3.25,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""vamidothion""",1.144,"""CNC(=O)C(C)SCCSP(=O)(OC)(OC)""","""no_scaffold_317""",11.615392,11.615392,0.003087,-2.968949,0.543859,13.5,287.343,269.199,287.041487,96.0,0.0,0.388103,-0.358225,0.388103,0.358225,1.375,1.9375,2.375,32.735952,10.444383,2.213171,-2.127898,2.587624,-2.27051,8.547558,-0.119315,2.323876,3.756993,256.422865,12.604448,9.888267,12.415688,7.486511,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""Vinclozolin""",-4.925,"""CC1(OC(=O)N(C1=O)c2cc(Cl)cc(Cl…","""O=C1COC(=O)N1c1ccccc1""",12.114445,12.114445,0.271366,-1.355281,0.782457,23.166667,286.114,277.042,284.995949,94.0,0.0,0.422243,-0.428036,0.428036,0.422243,1.277778,1.833333,2.277778,35.496835,10.071253,2.489549,-2.146264,2.32667,-2.365046,6.35264,-0.124869,2.489028,2.482348,538.003547,13.499636,9.444395,10.956253,8.370028,…,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""Warfarin""",-3.893,"""CC(=O)CC(c1ccccc1)c3c(O)c2cccc…","""O=c1oc2ccccc2cc1Cc1ccccc1""",12.412307,12.412307,0.064063,-0.614534,0.747626,12.217391,308.333,292.205,308.104859,116.0,0.0,0.343366,-0.506592,0.506592,0.343366,1.086957,1.73913,2.434783,16.39494,9.86892,2.251678,-2.169199,2.354357,-2.161027,5.847714,-0.116735,2.541895,2.258072,909.550973,16.396977,12.652568,12.652568,11.075387,…,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""Xipamide""",-3.79,"""Cc1cccc(C)c1NC(=O)c2cc(c(Cl)cc…","""O=C(Nc1ccccc1)c1ccccc1""",12.38,12.38,0.235113,-4.118515,0.786275,11.304348,354.815,339.695,354.044106,122.0,0.0,0.258979,-0.507064,0.507064,0.258979,1.217391,1.782609,2.304348,35.495694,10.084512,2.207049,-2.129387,2.298114,-2.170017,7.888858,0.101919,2.410222,2.481292,874.340157,17.361443,12.422273,13.994698,10.608079,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


# Split and Scale Data

In [181]:
from sklearn.preprocessing import RobustScaler

In [182]:
# Scaffold split: all molecules sharing a scaffold land on the same side. The test set
# therefore measures generalisation to unseen scaffolds, not to near-duplicates of training compounds.
splitter = GroupShuffleSplit(
    n_splits=1,
    test_size=options['test_size'],
    random_state=options['random_seed'],
)

train_idx, test_idx = next(splitter.split(X=df, groups=df['scaffolds']))

train_idx.shape, test_idx.shape

((940,), (204,))

In [183]:
# Row-select each split, then extract the descriptor features and the solubility target as tensors.
df_train, df_test = df[train_idx], df[test_idx]

X_train, X_test = (part[DESCRIPTOR_NAMES].to_torch() for part in (df_train, df_test))
y_train, y_test = (
    part['measured log(solubility:mol/L)'].to_torch().to(torch.float32)
    for part in (df_train, df_test)
)

# The per-split frames are no longer needed once converted.
del df_train, df_test

X_train.shape, y_train.shape, X_test.shape, y_test.shape

(torch.Size([940, 217]),
 torch.Size([940]),
 torch.Size([204, 217]),
 torch.Size([204]))

In [184]:
# Fit the scaler on the training descriptors only, so no test-set statistics leak in, then apply it
# to both splits. RobustScaler centres on the median and scales by the interquartile range, which
# limits the influence of extreme descriptor values.
scaler = RobustScaler()
X_train_scaled = torch.as_tensor(scaler.fit_transform(X_train), dtype=torch.float32)
X_test_scaled = torch.as_tensor(scaler.transform(X_test), dtype=torch.float32)

# Train Single Model

In [185]:
from kan import KAN

In [186]:
# Baseline: one KAN whose width list maps every descriptor directly to a single output node.
num_features = X_train.size(1)
single_model = KAN(
    width=[num_features, 1],
    seed=options['random_seed'],
    auto_save=False,
)

In [187]:
# Train the baseline on the full (non-resampled) training set; train and test metrics are logged periodically.
loss_dict = train_regression(
    single_model,
    X_train_scaled, y_train,
    X_test_scaled, y_test,
    num_itrs=options['num_train_iterations'],
)

Train iteration 0, mse: 15.15121841430664, r2: -2.9997494220733643, mae: 3.1482648849487305, rmse: 3.892456531524658
Test iteration 0, mse: 19.47254180908203, r2: -2.1771907806396484, mae: 3.5290088653564453, rmse: 4.4127702713012695
Train iteration 50, mse: 0.54261714220047, r2: 0.856755256652832, mae: 0.5660755038261414, rmse: 0.7366254925727844
Test iteration 50, mse: 0.8396344780921936, r2: 0.8630030155181885, mae: 0.6648234128952026, rmse: 0.9163157343864441
Train iteration 100, mse: 0.3416093587875366, r2: 0.9098190069198608, mae: 0.4494020640850067, rmse: 0.5844735503196716
Test iteration 100, mse: 0.6919218301773071, r2: 0.8871042132377625, mae: 0.5848508477210999, rmse: 0.8318184018135071
Train iteration 150, mse: 0.2690228521823883, r2: 0.9289810061454773, mae: 0.39788222312927246, rmse: 0.5186741352081299
Test iteration 150, mse: 0.6290132403373718, r2: 0.8973685503005981, mae: 0.5619761943817139, rmse: 0.7931035757064819
Train iteration 200, mse: 0.22571152448654175, r2: 0.

# Train Bag of KANs

In [188]:
# Bagging: train one KAN per bootstrap resample of the training set. Each resample draws
# n row indices uniformly with replacement from the n training rows. A dedicated, seeded
# generator makes the resamples reproducible.
torch_random_generator = torch.Generator().manual_seed(options['random_seed'])

num_train = X_train_scaled.shape[0]
num_bags = options['num_bootstrap_samples']

bootstrapped_models = []
# One training history per bootstrapped model, in training order. Each history is kept under its
# own name so the single model's `loss_dict` is not overwritten.
loss_dicts = []

for i in range(num_bags):
    print('-'*100, f'\nTraining bootstrapped model {i+1}/{num_bags}')
    bootstrap_sample_indices = torch.randint(
        low=0, high=num_train, size=(num_train,), generator=torch_random_generator
    )

    X_train_bootstrap = X_train_scaled[bootstrap_sample_indices]
    y_train_bootstrap = y_train[bootstrap_sample_indices]

    # Offset the init seed per model so ensemble members differ in initialisation as well as in data.
    bootstrapped_model = KAN(width=[num_features, 1], seed=options['random_seed'] + i, auto_save=False)

    # Every model is evaluated against the same, untouched test set.
    bootstrap_loss_dict = train_regression(bootstrapped_model, X_train_bootstrap, y_train_bootstrap,
                                           X_test_scaled, y_test,
                                           num_itrs=options['num_train_iterations'])

    bootstrapped_models.append(bootstrapped_model)
    loss_dicts.append(bootstrap_loss_dict)

---------------------------------------------------------------------------------------------------- 
Training bootstrapped model 1/100
Train iteration 0, mse: 15.173768997192383, r2: -2.9541091918945312, mae: 3.143610715866089, rmse: 3.8953521251678467
Test iteration 0, mse: 19.492952346801758, r2: -2.180521249771118, mae: 3.5307157039642334, rmse: 4.4150824546813965
Train iteration 50, mse: 0.4961289167404175, r2: 0.870714545249939, mae: 0.5424894094467163, rmse: 0.7043641805648804
Test iteration 50, mse: 1.025462031364441, r2: 0.8326829671859741, mae: 0.7039005756378174, rmse: 1.012650966644287
Train iteration 100, mse: 0.3026880919933319, r2: 0.9211229681968689, mae: 0.4194069504737854, rmse: 0.5501709580421448
Test iteration 100, mse: 0.8655644059181213, r2: 0.8587722778320312, mae: 0.6311904191970825, rmse: 0.9303571581840515
Train iteration 150, mse: 0.2195110023021698, r2: 0.9427979588508606, mae: 0.357797771692276, rmse: 0.46852001547813416
Test iteration 150, mse: 0.778685986

KeyboardInterrupt: 

In [163]:
class BagOfKans():
    """Bootstrap-aggregated (bagged) ensemble of regression models.

    The ensemble prediction is the unweighted mean, across models, of each model's
    first output column.

    Args:
        models: iterable of callables (e.g. trained KANs). Each maps X to a tensor of at
            least rank 2, whose column 0 is that model's prediction. The iterable is
            copied into a list, so generators are safe to pass and can be predicted
            with more than once.

    Raises:
        ValueError: if ``models`` is empty.
    """

    def __init__(self, models):
        # Materialise the iterable: a generator would otherwise be exhausted after the first call.
        self.models = list(models)
        if not self.models:
            raise ValueError('BagOfKans needs at least one model')

    def __call__(self, X):
        """Return the mean of ``model(X)[:, 0]`` over all models (shape of one model's column 0)."""
        member_preds = torch.stack([model(X)[:, 0] for model in self.models], dim=0)
        return member_preds.mean(dim=0)

# Compare Performance

In [164]:
# Predictions of the single baseline KAN (its first output column) for both splits, as numpy arrays.
y_pred_train_single, y_pred_test_single = (
    single_model(X_split)[:, 0].detach().numpy() for X_split in (X_train_scaled, X_test_scaled)
)

In [165]:
# Ensemble predictions: the mean output of the bootstrapped KANs for both splits, as numpy arrays.
bag_of_kans = BagOfKans(bootstrapped_models)

y_pred_train_bok, y_pred_test_bok = (
    bag_of_kans(X_split).detach().numpy() for X_split in (X_train_scaled, X_test_scaled)
)

In [166]:
# Single KAN on the training set (in-sample fit).
regression_report(y_train, y_pred_train_single)

{'R2': 0.9654132723808289,
 'MSE': 0.13101594150066376,
 'MAE': 0.2777465283870697,
 'MAPE': 892679356416.0,
 'RMSE': 0.3619612455368042}

In [167]:
# Bagged ensemble on the training set.
# NOTE(review): training-set MAPE is enormous for both models, presumably because some targets
# are close to zero. Prefer R2/MAE/RMSE for this comparison.
regression_report(y_train, y_pred_train_bok)

{'R2': 0.9607968330383301,
 'MSE': 0.14850324392318726,
 'MAE': 0.2864103317260742,
 'MAPE': 1322044882944.0,
 'RMSE': 0.3853611946105957}

In [168]:
# Single KAN on the held-out scaffold test set.
regression_report(y_test, y_pred_test_single)

{'R2': 0.8893250823020935,
 'MSE': 0.6783104538917542,
 'MAE': 0.5960142612457275,
 'MAPE': 0.27880141139030457,
 'RMSE': 0.8235960602760315}

In [169]:
# Bagged ensemble on the held-out scaffold test set: the comparison bagging is meant to improve.
regression_report(y_test, y_pred_test_bok)

{'R2': 0.9007682204246521,
 'MSE': 0.6081770658493042,
 'MAE': 0.5842803716659546,
 'MAPE': 0.2572971284389496,
 'RMSE': 0.7798570990562439}