# Debugging autoreload

In [None]:
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
import torch
from glob import glob
import shap
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig, GatedAdditiveTreeEnsembleConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.classification import (
    multiclass_accuracy,
    multiclass_f1_score,
    multiclass_precision,
    multiclass_recall,
    multiclass_specificity,
    multiclass_cohen_kappa,
    multiclass_auroc
)
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.metrics import confusion_matrix
import pandas as pd
import pathlib
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from pytorch_tabular.utils import get_balanced_sampler
from src.plot.radar import radar_factory


# Load data

In [None]:
# Locations of the dataset and of the YAML configuration files.
path_data = "D:/YandexDisk/Work/bbd/immunology/002_central_vs_yakutia/classification"
path_configs = "D:/Work/bbs/notebooks/immunology/002_central_vs_yakutia/pt_configs"

# Samples table (row labels come from the first column) and the list of model
# input features (the row labels of feats.xlsx).
data = pd.read_excel(f"{path_data}/data.xlsx", index_col=0)
feats = pd.read_excel(f"{path_data}/feats.xlsx", index_col=0).index.values.tolist()

# Which "Split_<id>" column separates train+validation ("trn_val") from test ("tst").
test_split_id = 0

# Validation folds: how many, the shuffling seed, and the fold used whenever a
# single train/validation pair is needed.
val_n_splits = 4
val_random_state = 1337
val_fold_id = 0

split_col = f"Split_{test_split_id}"
fold_cols = [f"Fold_{i}" for i in range(val_n_splits)]
model_cols = feats + ['Region']

# Every fold column starts as a copy of the split column; the "trn_val" rows
# are relabelled "trn"/"val" below, all other rows keep their split label.
for fold_col in fold_cols:
    data[fold_col] = data[split_col]

in_trn_val = data[split_col] == 'trn_val'
stratify_cat_parts = {
    region: data.index[(data['Region'] == region) & in_trn_val].values
    for region in ('Central', 'Yakutia')
}

# Folds are built separately inside each region and stratified by binned age.
for part, ids in stratify_cat_parts.items():
    print(f"{part}: {len(ids)}")
    con = data.loc[ids, 'Age'].values
    ptp = np.ptp(con)
    num_bins = 5
    # Equal-width bins over the observed age range, padded by 10% of the
    # range on both sides.
    bins = np.linspace(np.min(con) - 0.1 * ptp, np.max(con) + 0.1 * ptp, num_bins + 1)
    binned = np.digitize(con, bins) - 1
    # Occupancy of every age bin (not used further in this cell).
    unique, counts = np.unique(binned, return_counts=True)
    occ = dict(zip(unique, counts))
    k_fold = RepeatedStratifiedKFold(
        n_splits=val_n_splits,
        n_repeats=1,
        random_state=val_random_state
    )
    splits = k_fold.split(X=ids, y=binned, groups=binned)

    # The splitter yields positions into `ids`; translate them back to row labels.
    for fold_id, (ids_trn, ids_val) in enumerate(splits):
        data.loc[ids[ids_trn], f"Fold_{fold_id}"] = "trn"
        data.loc[ids[ids_val], f"Fold_{fold_id}"] = "val"

# Frames handed to the models: features + target, plus the fold labels for the
# cross-validation frame.
test = data.loc[data[split_col] == "tst", model_cols]
train_validation = data.loc[data[split_col] == "trn_val", model_cols + fold_cols]
train_only = data.loc[data[f"Fold_{val_fold_id}"] == "trn", model_cols]
validation_only = data.loc[data[f"Fold_{val_fold_id}"] == "val", model_cols]

# Positional (train, validation) row indices into `train_validation`, one pair per fold.
cv_indexes = []
for fold_col in fold_cols:
    fold_labels = train_validation[fold_col]
    cv_indexes.append(
        (
            np.where(train_validation.index.isin(train_validation.index[fold_labels == 'trn']))[0],
            np.where(train_validation.index.isin(train_validation.index[fold_labels == 'val']))[0],
        )
    )

## Prepare balanced sampler

In [None]:
# Sampler for the training dataloader, built from the region labels of the
# `train_only` rows. Presumably it weights classes inversely to their
# frequency — confirm against pytorch_tabular's `get_balanced_sampler`.
sampler_balanced = get_balanced_sampler(train_only['Region'].values.ravel())

# Models Search Spaces

## CategoryEmbeddingModel Search Space

In [None]:
# Base CategoryEmbeddingModel configuration; the tuner overrides the keys below.
model_config = read_parse_config(f"{path_configs}/models/CategoryEmbeddingModelConfig.yaml", CategoryEmbeddingModelConfig)

# Candidate values for every tuned hyper-parameter.
search_space = {
    "model_config__layers": [
        "256-128-64", "512-256-128", "32-16", "32-32-16", "16-8",
        "32-16-8", "128-64", "128-128", "16-16",
    ],
    "model_config.head_config__dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [42, 1337, 666],
}

# Size of the full grid = product of the number of candidates per parameter.
grid_size = np.prod([len(candidates) for candidates in search_space.values()])
print(grid_size)

## GANDALF Search Space

In [None]:
# Base GANDALF configuration; the tuner overrides the keys below.
model_config = read_parse_config(f"{path_configs}/models/GANDALFConfig.yaml", GANDALFConfig)

# Candidate values for every tuned hyper-parameter.
search_space = {
    "model_config__gflu_stages": [5, 10, 15, 20, 25, 30, 35],
    "model_config__gflu_dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__gflu_feature_init_sparsity": [0.1, 0.2, 0.3, 0.4, 0.5],
    "model_config.head_config__dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337, 666],
}

# Size of the full grid = product of the number of candidates per parameter.
grid_size = np.prod([len(candidates) for candidates in search_space.values()])
print(grid_size)

## TabNetModel Search Space

In [None]:
# Base TabNet configuration; the tuner overrides the keys below.
model_config = read_parse_config(f"{path_configs}/models/TabNetModelConfig.yaml", TabNetModelConfig)

# Candidate values for every tuned hyper-parameter.
search_space = {
    "model_config__n_d": [8, 16, 24, 32, 40, 48],
    "model_config__n_a": [8, 16, 24, 32, 40, 48],
    "model_config__n_steps": [3, 5, 7],
    "model_config__gamma": [1.3, 1.4, 1.5, 1.6, 1.7, 1.8],
    "model_config__n_independent": [1, 2, 3, 4, 5],
    "model_config__n_shared": [1, 2, 3, 4, 5],
    "model_config__mask_type": ["sparsemax", "entmax"],
    "model_config.head_config__dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337, 666],
}

# Size of the full grid = product of the number of candidates per parameter.
grid_size = np.prod([len(candidates) for candidates in search_space.values()])
print(grid_size)

## DANet Search Space

In [None]:
# Base DANet configuration; the tuner overrides the keys below.
model_config = read_parse_config(f"{path_configs}/models/DANetConfig.yaml", DANetConfig)

# Candidate values for every tuned hyper-parameter.
search_space = {
    "model_config__n_layers": [4, 8, 16, 20, 32],
    "model_config__abstlay_dim_1": [8, 16, 32, 64],
    "model_config__k": [3, 4, 5, 6, 7],
    "model_config__dropout_rate": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config.head_config__dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337, 666],
}

# Size of the full grid = product of the number of candidates per parameter.
grid_size = np.prod([len(candidates) for candidates in search_space.values()])
print(grid_size)

## FTTransformer Search Space

In [None]:
# Base FTTransformer configuration; the tuner overrides the keys below.
model_config = read_parse_config(f"{path_configs}/models/FTTransformerConfig.yaml", FTTransformerConfig)

# Candidate values for every tuned hyper-parameter.
search_space = {
    "model_config__num_heads": [2, 4, 8, 16, 32],
    "model_config__num_attn_blocks": [4, 6, 8, 10, 12],
    "model_config__attn_dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__add_norm_dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__ff_dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config.head_config__dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337, 666],
}

# Size of the full grid = product of the number of candidates per parameter.
grid_size = np.prod([len(candidates) for candidates in search_space.values()])
print(grid_size)

# Grid Search and Random Search

In [None]:
%%capture

strategy = 'random_search' # 'grid_search'
seed = 1337
n_random_trials = 250
is_cross_validation = True

if grid_size < n_random_trials and strategy == 'random_search':
    strategy = 'grid_search'

data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
data_config['continuous_feature_transform'] = 'yeo-johnson'
data_config['normalize_continuous_features'] = True
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
trainer_config['checkpoints'] = None
trainer_config['load_best'] = False
trainer_config['auto_lr_find'] = True
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    suppress_lightning_logger=True,
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    if is_cross_validation:
        result = tuner.tune(
            train=train_validation,
            validation=None,
            search_space=search_space,
            metric="accuracy",
            mode="max",
            strategy=strategy,
            n_trials=n_random_trials,
            cv=cv_indexes,
            return_best_model=True,
            verbose=False,
            progress_bar=False,
            random_state=seed,
            train_sampler=sampler_balanced
        )
    else: 
        result = tuner.tune(
            train=train_only,
            validation=validation_only,
            search_space=search_space,
            metric="accuracy",
            mode="max",
            strategy=strategy,
            n_trials=n_random_trials,
            cv=None,
            return_best_model=True,
            verbose=False,
            progress_bar=False,
            random_state=seed,
            train_sampler=sampler_balanced
        )

result.trials_df.to_excel(f"{trainer_config['checkpoints_path']}/trials/{model_config['_model_name']}_{strategy}_{seed}_{optimizer_config['lr_scheduler']}.xlsx")

# Model Sweep Training

## Generate models' configs from trials files

In [None]:
# Number of best trials (highest accuracy) taken from every trials workbook.
n_top_trials = 50

target_models_types = [
    'CategoryEmbeddingModel',
    'GANDALF',
    'TabNetModel',
    'DANet',
    # 'FTTransformer'
]

data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)

common_params = {
    "task": "classification",
}

# Template for the head; only `dropout` is replaced per trial.
head_config = LinearHeadConfig(
    layers="",
    activation='ReLU',
    dropout=0.1,
    use_batch_norm=False,
    initialization="kaiming"
).__dict__

# For every model type: its config class and the tuned hyper-parameters, each
# with the cast that restores a plain Python value after the Excel round trip
# (e.g. integers read back as floats when all columns of a sheet are numeric).
tuned_params = {
    'CategoryEmbeddingModel': (
        CategoryEmbeddingModelConfig,
        {'layers': str, 'learning_rate': float, 'seed': int},
    ),
    'GANDALF': (
        GANDALFConfig,
        {
            'gflu_stages': int,
            'gflu_feature_init_sparsity': float,
            'gflu_dropout': float,
            'learning_rate': float,
            'seed': int,
        },
    ),
    'TabNetModel': (
        TabNetModelConfig,
        {
            'n_steps': int,
            'n_shared': int,
            'n_independent': int,
            'n_d': int,
            'n_a': int,
            'mask_type': str,
            'gamma': float,
            'learning_rate': float,
            'seed': int,
        },
    ),
    'FTTransformer': (
        FTTransformerConfig,
        {
            'num_heads': int,
            'num_attn_blocks': int,
            'attn_dropout': float,
            'add_norm_dropout': float,
            'ff_dropout': float,
            'learning_rate': float,
            'seed': int,
        },
    ),
    'DANet': (
        DANetConfig,
        {
            'n_layers': int,
            'abstlay_dim_1': int,
            'k': int,
            'dropout_rate': float,
            'learning_rate': float,
            'seed': int,
        },
    ),
}

model_list = []
for model_type in target_models_types:
    # A KeyError here means `target_models_types` names a model without an
    # entry in `tuned_params`.
    config_class, param_casts = tuned_params[model_type]
    trials_files = glob(f"{trainer_config['checkpoints_path']}/trials/{model_type}*.xlsx")
    for trials_file in trials_files:
        df_trials = pd.read_excel(trials_file, index_col=0)
        # Best trials first: accuracy is maximised, so sort in DESCENDING
        # order before keeping the head.
        df_trials = df_trials.sort_values('accuracy', ascending=False).head(n_top_trials)
        for _, row in df_trials.iterrows():
            head_config_tmp = copy.deepcopy(head_config)
            head_config_tmp['dropout'] = float(row['model_config.head_config__dropout'])
            # Re-read the base config for every trial so that no state leaks
            # from one generated config into the next.
            model_config = read_parse_config(f"{path_configs}/models/{model_type}Config.yaml", config_class)
            for param, cast in param_casts.items():
                model_config[param] = cast(row[f"model_config__{param}"])
            model_config['head_config'] = head_config_tmp
            model_list.append(config_class(**model_config))
print(len(model_list))

## Perform model sweep

In [None]:
%%capture

for seed in [1337, 55763, 40279, 8751, 234461]:

    trainer_config['seed'] = seed
    trainer_config['checkpoints'] = 'valid_loss'
    trainer_config['load_best'] = True
    trainer_config['auto_lr_find'] = True
    
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sweep_df, best_model = model_sweep_custom(
            task="classification",
            # train=train_validation,
            train=train_only,
            # validation=None,
            validation=validation_only,
            test=test,
            data_config=data_config,
            optimizer_config=optimizer_config,
            trainer_config=trainer_config,
            model_list=model_list,
            common_model_args=common_params,
            metrics=[
                "accuracy",
                "f1_score",
                "precision",
                "recall",
                "specificity",
                "cohen_kappa",
                "auroc"
            ],
            metrics_params=[
                {'task': 'multiclass', 'num_classes': 2, 'average': 'macro'},
                {'task': 'multiclass', 'num_classes': 2, 'average': 'macro'},
                {'task': 'multiclass', 'num_classes': 2, 'average': 'macro'},
                {'task': 'multiclass', 'num_classes': 2, 'average': 'macro'},
                {'task': 'multiclass', 'num_classes': 2, 'average': 'macro'},
                {'task': 'multiclass', 'num_classes': 2},
                {'task': 'multiclass', 'num_classes': 2, 'average': 'macro'},
            ],
            metrics_prob_input=[True, True, True, True, True, True, True],
            rank_metric=("accuracy", "higher_is_better"),
            return_best_model=True,
            seed=seed,
            progress_bar=False,
            verbose=False,
            suppress_lightning_logger=True,
            train_sampler=sampler_balanced
        )
    fn_suffix = f"{seed}_{best_model.config['lr_scheduler']}_{best_model.config['continuous_feature_transform']}"
    sweep_df.style.background_gradient(
        subset=[
            "train_loss",
            "validation_loss",
            "test_loss",
            "time_taken",
            "time_taken_per_epoch"
        ], cmap="RdYlGn_r"
    ).to_excel(f"{trainer_config['checkpoints_path']}/sweep_{fn_suffix}.xlsx")

## Save best models

In [None]:
%%capture

seed = 234461
models_ids = [196, 159]
explain_method = "GradientShap"
explain_baselines = "b|1000"

class_names = ['Central', 'Yakutia']

sweep_suffix = f"{seed}_{optimizer_config['lr_scheduler']}_{data_config['continuous_feature_transform']}"
path_models = f"{trainer_config['checkpoints_path']}/candidates/{sweep_suffix}"
pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)

sweep_df = pd.read_excel(f"{trainer_config['checkpoints_path']}/sweep_{sweep_suffix}.xlsx", index_col=0)

for model_id in models_ids:

    tabular_model = TabularModel(
        data_config=data_config,
        model_config=ast.literal_eval(sweep_df.at[model_id, 'params']),
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        verbose=True,
        suppress_lightning_logger=False
    )
    datamodule = tabular_model.prepare_dataloader(
        train=train_only,
        validation=validation_only,
        seed=seed,
    )
    model = tabular_model.prepare_model(
        datamodule
    )
    tabular_model._prepare_for_training(
        model,
        datamodule
    )
    tabular_model.load_weights(sweep_df.at[model_id, 'checkpoint'])
    tabular_model.evaluate(test, verbose=False)
    tabular_model.save_model(f"{path_models}/{model_id}")
    
    loaded_model = TabularModel.load_model(f"{path_models}/{model_id}")
    
    df = data.loc[:, ['Age', 'SImAge', 'Sex', 'Region']]
    df.loc[train_only.index, 'Group'] = 'Train'
    df.loc[validation_only.index, 'Group'] = 'Validation'
    df.loc[test.index, 'Group'] = 'Test'
    df = pd.concat(
        [
            df,
            loaded_model.predict(data),
            loaded_model.predict(data, ret_logits=True).loc[:, ['logits_0', 'logits_1']]
        ],
        axis=1
    )
    df.rename(columns={'prediction': 'Prediction', 'logits_0': 'Central_logits', 'logits_1': 'Yakutia_logits'},
              inplace=True)
    df['Region ID'] = df['Region']
    df['Region ID'].replace({'Central': 0, 'Yakutia': 1}, inplace=True)
    df['Prediction ID'] = df['Prediction']
    df['Prediction ID'].replace({'Central': 0, 'Yakutia': 1}, inplace=True)
    df.to_excel(f"{path_models}/{model_id}/df.xlsx")
    
    colors_groups = {
        'Train': 'chartreuse',
        'Validation': 'dodgerblue',
        'Test': 'crimson',
    }
    
    metrics_w_avg = [
        "accuracy",
        "f1_score",
        "precision",
        "recall",
        "specificity",
        "auroc"
    ]
    metrics_wo_avg = [
        "cohen_kappa"
    ]
    metrics_names = {
        "accuracy": "Accuracy",
        "f1_score": "F-1 Score",
        "precision": "Precision",
        "recall": "Recall",
        "specificity": "Specificity",
        "auroc": "AUROC",
        "cohen_kappa": "Cohen Kappa"
    }
    
    df_metrics = pd.DataFrame(
        index=[f"{m}_macro" for m in metrics_w_avg] +
              [f"{m}_weighted" for m in metrics_w_avg] +
              metrics_wo_avg,
        columns=list(colors_groups.keys()),
        data=np.zeros((len(metrics_w_avg) * 2 + 1, len(colors_groups))),
    )
    for group in colors_groups.keys():
        pred = torch.from_numpy(df.loc[df['Group'] == group, 'Prediction ID'].values)
        real = torch.from_numpy(df.loc[df['Group'] == group, 'Region ID'].values)
        probs = torch.from_numpy(df.loc[df['Group'] == group, ['Central_probability', 'Yakutia_probability']].values)
        for avg_type in ['macro', 'weighted']:
            df_metrics.at[f"accuracy_{avg_type}", group] = multiclass_accuracy(preds=pred, target=real, num_classes=2, average=avg_type).numpy()
            df_metrics.at[f"f1_score_{avg_type}", group] = multiclass_f1_score(preds=pred, target=real, num_classes=2, average=avg_type).numpy()
            df_metrics.at[f"precision_{avg_type}", group] = multiclass_precision(preds=pred, target=real, num_classes=2, average=avg_type).numpy()
            df_metrics.at[f"recall_{avg_type}", group] = multiclass_recall(preds=pred, target=real, num_classes=2, average=avg_type).numpy()
            df_metrics.at[f"specificity_{avg_type}", group] = multiclass_specificity(preds=pred, target=real, num_classes=2, average=avg_type).numpy()
            df_metrics.at[f"auroc_{avg_type}", group] = multiclass_auroc(preds=probs, target=real, num_classes=2, average=avg_type).numpy()
        df_metrics.at["cohen_kappa", group] = multiclass_cohen_kappa(preds=pred, target=real, num_classes=2).numpy()
        
        conf_mtx = confusion_matrix(real, pred)
        cm_sum = np.sum(conf_mtx, axis=1, keepdims=True)
        cm_perc = conf_mtx / cm_sum.astype(float) * 100
        annot = np.empty_like(conf_mtx).astype(str)
        nrows, ncols = conf_mtx.shape
        for i in range(nrows):
            for j in range(ncols):
                c = conf_mtx[i, j]
                p = cm_perc[i, j]
                if i == j:
                    s = cm_sum[i]
                    annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)
                elif c == 0:
                    annot[i, j] = ''
                else:
                    annot[i, j] = '%.1f%%\n%d' % (p, c)
        conf_mtx = pd.DataFrame(conf_mtx, index=class_names, columns=class_names)
        conf_mtx.index.name = 'Actual'
        conf_mtx.columns.name = 'Predicted'
        fig, ax = plt.subplots(figsize=(1.5*len(class_names), 0.8*len(class_names)))
        heatmap = sns.heatmap(conf_mtx, annot=annot, fmt='', ax=ax)
        heatmap.set_aspect('equal', adjustable='box')
        fig.savefig(f"{path_models}/{model_id}/confusion_matrix_{group}.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path_models}/{model_id}/confusion_matrix_{group}.pdf", bbox_inches='tight')
        plt.close(fig)
        
    df_metrics.to_excel(f"{path_models}/{model_id}/metrics.xlsx", index_label="Metrics")
    
    for avg_type in ['macro', 'weighted']:
        n_categories = len(metrics_w_avg) + len(metrics_wo_avg)
        theta = radar_factory(n_categories, frame='polygon')
        
        case_data = df_metrics.loc[[f"{m}_{avg_type}" for m in metrics_w_avg] + metrics_wo_avg, list(colors_groups.keys())].T.values
        
        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar'))
        ax.set_rgrids(list(np.linspace(0, 1, 21)))
        for d, group in zip(case_data, colors_groups):
            ax.plot(theta, d, color=colors_groups[group])
            ax.fill(theta, d, facecolor=colors_groups[group], alpha=0.25, label='_nolegend_')
        ax.set_varlabels([metrics_names[m_name] for m_name in metrics_w_avg + metrics_wo_avg])
        labels = (list(colors_groups.keys()))
        legend = ax.legend(labels, loc=(0.9, .95), labelspacing=0.1, fontsize='small')
        fig.savefig(f"{path_models}/{model_id}/metrics_{avg_type}.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path_models}/{model_id}/metrics_{avg_type}.pdf", bbox_inches='tight')
        plt.close(fig)
    
    try:
        explanation = loaded_model.explain(train_only, method=explain_method, baselines=explain_baselines)
        
        sns.set_theme(style='whitegrid')
        fig = shap.summary_plot(
            shap_values=explanation.loc[:, feats].values,
            features=train_only.loc[:, feats].values,
            feature_names=feats,
            max_display=len(feats),
            plot_type="violin",
            show=False,
        )
        plt.savefig(f"{path_models}/{model_id}/explain_logits_beeswarm.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_models}/{model_id}/explain_logits_beeswarm.pdf", bbox_inches='tight')
        plt.close(fig)
        
        sns.set_theme(style='whitegrid')
        fig = shap.summary_plot(
            shap_values=explanation.loc[:, feats].values,
            features=train_only.loc[:, feats].values,
            feature_names=feats,
            max_display=len(feats),
            plot_type="bar",
            show=False,
        )
        plt.savefig(f"{path_models}/{model_id}/explain_logits_bar.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_models}/{model_id}/explain_logits_bar.pdf", bbox_inches='tight')
        plt.close(fig)
    
    except NotImplementedError:
        pass


# Best model processing

## Load best model

In [None]:
# Identify the chosen candidate: the sweep it came from (by seed) and its row id.
seed = 1337
model_id = 197

class_names = ['Central', 'Yakutia']

# Fresh configs from the YAML files; they determine the candidate folder name.
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)
data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)

lr_scheduler = optimizer_config['lr_scheduler']
feature_transform = data_config['continuous_feature_transform']
sweep_suffix = f"{seed}_{lr_scheduler}_{feature_transform}"
path_model = f"{trainer_config['checkpoints_path']}/candidates/{sweep_suffix}/{model_id}"

# Load the model saved in the candidate's folder.
loaded_model = TabularModel.load_model(path_model)

## Best model probabilistic explanation

In [None]:
def predict_func(X):
    """Return the per-class probabilities of `loaded_model` for a raw feature matrix.

    `X` is a 2-D array whose columns follow the order of `feats`; the result
    is an array with the 'Central_probability' and 'Yakutia_probability'
    columns, in that order.
    """
    frame = pd.DataFrame(data=X, columns=feats)
    predictions = loaded_model.predict(frame)
    return predictions.loc[:, ['Central_probability', 'Yakutia_probability']].values

# Feature matrix of the training rows; it serves both as the explainer's
# background data and as the set of samples being explained.
X = train_only[feats].values

explainer = shap.SamplingExplainer(predict_func, X)
shap_values = explainer.shap_values(X)
expected_value = explainer.expected_value

In [None]:
# One worksheet per class; rows are the training samples, columns the features.
with pd.ExcelWriter(f"{path_model}/shap_values.xlsx", engine='xlsxwriter') as writer:
    for class_id, class_name in enumerate(class_names):
        class_shap_values = shap_values[:, :, class_id]
        df_shap_values = pd.DataFrame(class_shap_values, index=train_only.index.values, columns=feats)
        df_shap_values.to_excel(writer, sheet_name=class_name)

In [None]:
# For every class: violin ("beeswarm" files), bar and heatmap views of the
# SHAP values of that class' predicted probability.
for class_id, class_name in enumerate(class_names):
    class_shap_values = shap_values[:, :, class_id]

    sns.set_theme(style='whitegrid')
    shap.summary_plot(
        shap_values=class_shap_values,
        features=X,
        feature_names=feats,
        max_display=len(feats),
        plot_type="violin",
        show=False,
    )
    # Same x-range for every class so the plots are directly comparable.
    plt.gca().set_xlim([-0.4, 0.4])
    sns.despine(left=False, right=False, bottom=False, top=False)
    # axes[1] is the second axes of the figure — presumably the colour bar
    # added by the summary plot; confirm against the shap version in use.
    plt.gcf().axes[1].tick_params(labelsize=20)
    plt.gcf().axes[1].set_ylabel("Feature value", fontsize=22)
    plt.gca().tick_params(labelsize=20)
    plt.gca().set_xlabel("SHAP value (impact on model output)", fontsize=22)
    plt.savefig(f"{path_model}/explain_{class_name}_probs_beeswarm.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_model}/explain_{class_name}_probs_beeswarm.pdf", bbox_inches='tight')
    # Close the current figure, i.e. the one the summary plot drew into.
    plt.close()

    sns.set_theme(style='whitegrid')
    shap.summary_plot(
        shap_values=class_shap_values,
        features=X,
        feature_names=feats,
        max_display=len(feats),
        plot_type="bar",
        show=False,
    )
    sns.despine(left=False, right=False, bottom=False, top=False)
    plt.savefig(f"{path_model}/explain_{class_name}_probs_bar.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_model}/explain_{class_name}_probs_bar.pdf", bbox_inches='tight')
    plt.close()

    # `expected_value` holds one base value per model output (the model
    # function returns one probability column per class), so the explanation
    # of a single class must use that class' entry only; repeating the whole
    # vector would give base_values a shape inconsistent with `values`.
    explanation = shap.Explanation(
        values=class_shap_values,
        base_values=np.full(X.shape[0], expected_value[class_id]),
        data=X,
        feature_names=feats
    )
    shap.plots.heatmap(
        explanation,
        show=False,
        max_display=len(feats),
        instance_order=explanation.sum(1)
    )
    plt.savefig(f"{path_model}/explain_{class_name}_probs_heatmap.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_model}/explain_{class_name}_probs_heatmap.pdf", bbox_inches='tight')
    plt.close()

## Additional plots for best model

### Radar chart

In [None]:
# Plot colour of every data subset (only 'Test' is drawn in this cell).
colors_groups = {
    'Train': 'chartreuse',
    'Validation': 'dodgerblue',
    'Test': 'crimson',
}

# Metrics stored once per averaging mode ("<metric>_<avg_type>" rows) ...
metrics_w_avg = [
    "accuracy",
    "f1_score",
    "precision",
    "recall",
    "specificity",
    "auroc"
]
# ... and metrics stored under their plain name.
metrics_wo_avg = [
    "cohen_kappa"
]
# Axis labels shown on the chart.
metrics_names = {
    "accuracy": "Accuracy",
    "f1_score": "F-1 Score",
    "precision": "Precision",
    "recall": "Recall",
    "specificity": "Specificity",
    "auroc": "AUROC",
    "cohen_kappa": "Cohen Kappa"
}

df_metrics = pd.read_excel(f"{path_model}/metrics.xlsx", index_col=0)

test_color = colors_groups['Test']
for avg_type in ['macro', 'weighted']:
    n_categories = len(metrics_w_avg) + len(metrics_wo_avg)
    theta = radar_factory(n_categories, frame='polygon')

    # Test-set values in the order of the chart axes.
    metric_rows = [f"{m}_{avg_type}" for m in metrics_w_avg] + metrics_wo_avg
    case_data = df_metrics.loc[metric_rows, ['Test']].values

    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(3.5, 3.5), subplot_kw=dict(projection='radar'))
    ax.set_rgrids(list(np.linspace(0, 1, 101)))
    ax.plot(theta, case_data, color=test_color)
    ax.fill(theta, case_data, facecolor=test_color, alpha=0.25, label='_nolegend_')
    ax.set_varlabels([metrics_names[m_name] for m_name in metrics_w_avg + metrics_wo_avg])
    for extension, save_kwargs in (("png", {"dpi": 200}), ("pdf", {})):
        fig.savefig(f"{path_model}/metrics_test_only_{avg_type}.{extension}", bbox_inches='tight', **save_kwargs)
    plt.close(fig)

# Simple TabularModel training

In [None]:
# Trainer settings from the YAML file, with three overrides for this run.
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
trainer_config['auto_lr_find'] = True
trainer_config['load_best'] = True
trainer_config['checkpoints'] = 'valid_loss'

# Data, model and optimizer configs are passed as YAML paths; only the trainer
# config is passed as an already-parsed object.
data_config_path = f"{path_configs}/DataConfig.yaml"
model_config_path = f"{path_configs}/models/CategoryEmbeddingModelConfig.yaml"
optimizer_config_path = f"{path_configs}/OptimizerConfig.yaml"

tabular_model = TabularModel(
    data_config=data_config_path,
    model_config=model_config_path,
    optimizer_config=optimizer_config_path,
    trainer_config=trainer_config,
    verbose=True,
    suppress_lightning_logger=False
)

# Fit on the fixed train/validation pair; no target transform and no extra
# callbacks are passed.
tabular_model.fit(train=train_only, validation=validation_only)

## Play with trained model

In [None]:
# Predict on the held-out test frame, requesting the 'rich' progress bar.
tabular_model.predict(test, progress_bar='rich')

In [None]:
# Evaluate on the test frame; ckpt_path="best" presumably selects the best
# saved checkpoint — confirm against the pytorch_tabular version in use.
tabular_model.evaluate(test, verbose=True, ckpt_path="best")

In [None]:
# Display the checkpoints directory stored in the model's config.
tabular_model.config['checkpoints_path']

In [None]:
# Print the best-model path recorded by the trainer's checkpoint callback.
print(tabular_model.trainer.checkpoint_callback.best_model_path)

In [None]:
# Show the summary of the trained model.
tabular_model.summary()

In [None]:
# Save the trained model into the configured checkpoints directory.
tabular_model.save_model(tabular_model.config['checkpoints_path'])

In [None]:
# Save the model's configuration into the same checkpoints directory.
tabular_model.save_config(tabular_model.config['checkpoints_path'])

In [None]:
# Reload the model from the checkpoints directory, replacing the in-memory instance.
tabular_model = TabularModel.load_model(tabular_model.config['checkpoints_path'])