In [None]:
%load_ext autoreload
%autoreload 2
from copy import deepcopy
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.linear_model import ElasticNetCV, LinearRegression, RidgeCV, LassoCV
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.validation import check_is_fitted
from sklearn.utils import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y
from sklearn.utils.validation import _check_sample_weight
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor, AdaBoostClassifier, AdaBoostRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm import tqdm
from collections import defaultdict
import dvu
import pandas as pd
import matplotlib.pyplot as plt
import json
from matplotlib.colors import TwoSlopeNorm
from matplotlib.colors import Normalize
import joblib
import viz
from interpret import show

import imodels
from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.base import RegressorMixin, ClassifierMixin
from imodels.algebraic.gam_multitask import MultiTaskGAMRegressor

# Fit some simple GAMs

In [39]:
d = defaultdict(list)
for dset in ['bike_sharing', 'california_housing', 'diabetes_regr', 'heart', 'satellite_image', 'abalone', 'echo_months']:
    for use_single_task_with_reweighting in [False, True]:
        if use_single_task_with_reweighting:
            fit_linear_frac_list = [0.1, 0.25, 0.5, 0.75, 0.9, None]
        else:
            fit_linear_frac_list = [None]
        for fit_linear_frac in fit_linear_frac_list:
            print(f'{dset=}, {use_single_task_with_reweighting=}, {fit_linear_frac=}')
            X, y, feature_names = imodels.get_clean_dataset(dset)
            X = StandardScaler().fit_transform(X)
            y = StandardScaler().fit_transform(y.reshape(-1, 1)).ravel()
            X, X_test, y_train, y_test = train_test_split(
                X, y, random_state=42, test_size=0.25)
            ebm_kwargs = dict(
                random_state=42,
                n_jobs=-2,
            )

            gam = MultiTaskGAMRegressor(
                multitask=use_single_task_with_reweighting,
                fit_linear_frac=fit_linear_frac,
                interactions=False,
                use_single_task_with_reweighting=use_single_task_with_reweighting,
                ebm_kwargs=ebm_kwargs)

            np.random.seed(42)
            gam.fit(X, y_train)
            d['dset'].append(dset)
            d['use_single_task_with_reweighting'].append(
                use_single_task_with_reweighting)
            d['fit_linear_frac_list'].append(fit_linear_frac)
            d['test_corr'].append(np.corrcoef(
                y_test, gam.predict(X_test))[0, 1])
            d['test_r2'].append(gam.score(X_test, y_test))
            d['train_corr'].append(np.corrcoef(y_train, gam.predict(X))[0, 1])
            d['train_r2'].append(gam.score(X, y_train))
            if hasattr(gam, 'lin_model') and hasattr(gam.lin_model, 'coef_'):
                d['coef'].append(gam.lin_model.coef_)
            else:
                d['coef'].append([])
            # print(pd.DataFrame(d))
joblib.dump(pd.DataFrame(
    d), '../figs/use_single_task_reweighting_results.pkl')

dset='bike_sharing', use_single_task_with_reweighting=False, fit_linear_frac=None
fetching 42712 from openml
dset='bike_sharing', use_single_task_with_reweighting=True, fit_linear_frac=0.1
fetching 42712 from openml
dset='bike_sharing', use_single_task_with_reweighting=True, fit_linear_frac=0.25
fetching 42712 from openml
dset='bike_sharing', use_single_task_with_reweighting=True, fit_linear_frac=0.5
fetching 42712 from openml
dset='bike_sharing', use_single_task_with_reweighting=True, fit_linear_frac=0.75
fetching 42712 from openml
dset='bike_sharing', use_single_task_with_reweighting=True, fit_linear_frac=0.9
fetching 42712 from openml
dset='bike_sharing', use_single_task_with_reweighting=True, fit_linear_frac=None
fetching 42712 from openml
dset='california_housing', use_single_task_with_reweighting=False, fit_linear_frac=None
fetching california_housing from sklearn
dset='california_housing', use_single_task_with_reweighting=True, fit_linear_frac=0.1
fetching california_housing fro

['../figs/use_single_task_reweighting_results.pkl']

In [None]:
df = pd.DataFrame(d)
df = df.fillna(1)

In [52]:
d = (
    df
    .pivot_table(
        index='dset', values='test_corr',
        columns=['use_single_task_with_reweighting', 'fit_linear_frac_list']
    )
    .round(3)
)

# add mean row (with std)
d.loc['AVG'] = d.mean()
# d.loc['std'] = d.std()
(
    d

    .style.background_gradient(cmap='Blues', axis=1)
    .format(precision=3)
    # make numbers in bottom row bold and underlined and large
    .apply(lambda x: ['font-weight: bold; text-decoration: underline; font-size: 110%' if x.name == 'AVG' else '' for i in x], axis=1)


)

use_single_task_with_reweighting,False,True,True,True,True,True,True
fit_linear_frac_list,1.000000,0.100000,0.250000,0.500000,0.750000,0.900000,1.000000
dset,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
abalone,0.719,0.692,0.708,0.713,0.718,0.717,0.702
bike_sharing,0.836,0.83,0.833,0.835,0.836,0.835,0.836
california_housing,0.875,0.854,0.864,0.87,0.874,0.874,0.875
diabetes_regr,0.708,0.698,0.673,0.709,0.731,0.667,0.675
echo_months,0.664,0.665,0.662,0.664,0.665,0.663,0.659
heart,0.749,0.686,0.695,0.665,0.72,0.759,0.708
satellite_image,0.916,0.914,0.914,0.915,0.914,0.912,0.916
AVG,0.781,0.763,0.764,0.767,0.78,0.775,0.767


In [None]:
df1 = df[(df['use_single_task_with_reweighting'] == True)
         * (df['fit_linear_frac_list'] == 1)]
plt.figure(figsize=(12, 4))
for i in range(len(df1)):
    plt.subplot(1, len(df1), i+1)
    plt.hist(df1['coef'].iloc[i])
    plt.title(df1['dset'].iloc[i])
    plt.xlabel('Coefficient')
    plt.ylabel('Count')
plt.tight_layout()
plt.show()