In [1]:
# Auto Reload
%load_ext autoreload
%autoreload 2

### Data Preprocessing

In [2]:
import pandas as pd
import numpy as np

# Feature table: read the CSV and drop its first column
# (presumably a saved row index -- TODO confirm against the file).
x = pd.read_csv(
    "./data/KAMIR_data_preprocessing_no_normalization.csv",
    encoding='unicode_escape',
).iloc[:, 1:]

# Label table: drop the first column as well, then keep the column at
# position 20 of what remains as the target; missing labels become 0.
y = (
    pd.read_csv("./data/KAMIR_Labels1.csv", encoding='unicode_escape')
    .iloc[:, 1:]
    .iloc[:, 20]
    .fillna(0)
)

In [3]:
# Positional index of every feature column in `x` (0 .. n_columns - 1).
idx = list(range(len(x.columns)))

In [4]:
# Names of the columns of `x` that are treated as categorical features.
# Their positions and cardinalities are later handed to TabNet as
# `cat_idxs` / `cat_dims`. The strings must match the CSV header exactly.
cat_col = ['Sex',
 'ECG Change',
 'Chest pain',
 'FMC',
 'Killip Class at admission',
 'ST change on ECG',
 'Heart rhythm on ECG',
 'Smoking History',
 'Initial diagnosis',
 'STEMI',
 'NSTEMI',
 'Coronary artery stenosis',
 'Puncture route',
 'Target vessel',
 'Lesion type of target vessel',
 'Pre TIMI flow of target vessel',
 'Treatment of target Vessel',
 'Post TIMI flow of target vessel',
 'Result of PCI',
 'Status of revascularization',
 'MR grade',
 'Final diagnosis',
 '12M_MR grade',
 '12M_Oral hypoglycemic agent (1)',
 '12M_Oral anticoagulant',
 '12M_Ezetimide',
 '12M_Fibrate',
 '12M_Statin',
 '12M_Aspirin',
 '12M_ACE inhibitor',
 '12M_Beta-blocker',
 '12M_Ca-channel blocker',
 '12M_Ticagrelor',
 '12M_Cilostazol',
 '12M_Prasugurel',
 '12M_Clopidogrel',
 'on Treatment',
 'on Treatment.1',
 'on Treatment.2',
 'on Treatment.3',
 'on Treatment.4',
 'on Treatment.5',
 'DiagnosisCerebrovascular Disease',
 'on Treatment.6',
 'Menopause',
 'Hysterectomy History',
 'Number of involved vessels',
 'Index procedure',
 'Staged PCI',
 '12M_ARB',
 '12M_Omega3 FA',
 '12M_Oral hypoglycemic agent (3)',
 '12M_Oral hypoglycemic agent (2)']

# Column position inside `x` of each categorical feature.
cat_idx = [x.columns.get_loc(name) for name in cat_col]

# Cardinality reported for each categorical feature: largest observed code + 1.
# NOTE(review): this assumes the columns hold 0-based integer codes -- confirm.
cat_dim = [max(x.iloc[:, position]) + 1 for position in cat_idx]

In [5]:
# Keep only the positions of the non-categorical (continuous) columns.
# The list is rebuilt rather than mutated with idx.remove(): re-running this
# cell is then harmless (remove() raises ValueError once the values are gone)
# and the membership test is O(1) instead of a list scan per removal.
_categorical_positions = set(cat_idx)
idx = [position for position in idx if position not in _categorical_positions]

In [6]:
from sklearn import preprocessing

# Rescale the continuous columns in place (categorical codes are left untouched).
# NOTE(review): preprocessing.normalize works per ROW (axis=1): every sample is
# scaled to unit L2 norm across its continuous features. It does not scale each
# feature column. The defaults are spelled out to make that visible -- confirm
# this is the intended preprocessing.
continuous_block = x.iloc[:, idx]
x.iloc[:, idx] = preprocessing.normalize(continuous_block, norm='l2', axis=1)

In [7]:
from sklearn.model_selection import train_test_split

# 80/20 hold-out split, stratified on the label so both parts keep the same
# class ratio; the fixed random_state makes the split repeatable.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.2,
    random_state=22,
    stratify=y,
)

In [8]:
import torch
import torch.nn as nn

# Device selection is pinned to the CPU; the disabled line below would pick
# CUDA when it is available.
# NOTE(review): `device` is not referenced by any cell visible in this notebook.
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

In [9]:
# Convert the splits to NumPy arrays for TabNet.
# Caution: from this cell on `x` / `y` hold the TRAINING split only (they held
# the full table before), so earlier cells that use `x.columns` cannot be
# re-run without reloading the data first.
x, y = np.array(x_train), np.array(y_train)
x_test, y_test = np.array(x_test), np.array(y_test)

In [10]:
# Balanced class weights: total / (2 * class_count), so the rarer class
# (label 1) receives the larger weight.
total = len(y)
alive = int(np.count_nonzero(y == 0))
mortality = int(np.count_nonzero(y == 1))
weight_0, weight_1 = (
    (1 / class_count) * (total) / 2.0 for class_count in (alive, mortality)
)
print(weight_0, weight_1, sep="\n")

0.5306451612903226
8.657894736842104


### Model : TabNet with Optuna

In [11]:
from pytorch_tabnet.tab_model import TabNetClassifier
import optuna
from optuna import Trial, visualization
from sklearn.model_selection import StratifiedKFold

In [12]:
def Objective(trial):
    """Optuna objective: cross-validated score of a TabNet classifier.

    Samples one set of TabNet hyper-parameters from ``trial``, trains a fresh
    ``TabNetClassifier`` on each of 5 stratified folds of the module-level
    training arrays ``x`` / ``y`` and averages the per-fold ``best_cost``
    (the best value of the early-stopping metric, here validation accuracy).

    Parameters
    ----------
    trial : optuna.Trial
        Trial object used to draw the hyper-parameters.

    Returns
    -------
    float
        Mean of ``tabnet.best_cost`` over the folds (to be maximised).

    Notes
    -----
    Reads the module-level names ``x``, ``y``, ``cat_idx``, ``cat_dim``,
    ``weight_0`` and ``weight_1``.
    """
    mask_type     = trial.suggest_categorical("mask_type", ["entmax", "sparsemax"])
    n_da          = trial.suggest_int("n_da", 20, 64, step=4)
    n_steps       = trial.suggest_int("n_steps", 4, 10, step=3)
    gamma         = trial.suggest_float("gamma", 1.0, 2.0, step=0.1)
    n_shared      = trial.suggest_int("n_shared", 1, 5, step=1)
    lambda_sparse = trial.suggest_float("lambda_sparse", 1e-6, 1e-3, log=True)
    # NOTE(review): values above 1.0 make StepLR *increase* the learning rate
    # at every step -- confirm that the upper bound of 2.0 is intended.
    sc_gamma      = trial.suggest_float("sc_gamma", 0.5, 2.0, step=0.1)
    sc_steps      = trial.suggest_int("sc_steps", 10, 50, step=10)
    # Training-length parameters are drawn once per trial, not once per fold.
    patience      = trial.suggest_int("patience", low=15, high=30)
    max_epochs    = trial.suggest_int('epochs', 15, 100)

    tabnet_params = dict(cat_idxs         = cat_idx,
                         cat_dims         = cat_dim,
                         n_d              = n_da,   # decision and attention widths are tied
                         n_a              = n_da,
                         n_steps          = n_steps,
                         gamma            = gamma,
                         lambda_sparse    = lambda_sparse,
                         optimizer_fn     = torch.optim.Adam,
                         optimizer_params = dict(lr=2e-2, weight_decay=1e-5),
                         mask_type        = mask_type,
                         n_shared         = n_shared,
                         scheduler_params = dict(gamma=sc_gamma, step_size=sc_steps),
                         scheduler_fn     = torch.optim.lr_scheduler.StepLR,
                         verbose          = 0,
                         )

    kf = StratifiedKFold(n_splits=5, random_state=42, shuffle=True)
    fold_scores = []
    # StratifiedKFold needs the labels to build class-balanced folds, so `y`
    # must be passed to split() together with `x`.
    for fit_index, valid_index in kf.split(x, y):
        x_fit, x_valid = x[fit_index], x[valid_index]
        y_fit, y_valid = y[fit_index], y[valid_index]
        tabnet = TabNetClassifier(**tabnet_params)
        tabnet.fit(X_train     = x_fit,
                   y_train     = y_fit,
                   eval_set    = [(x_valid, y_valid)],  # enables early stopping
                   patience    = patience,
                   max_epochs  = max_epochs,
                   eval_metric = ['accuracy'],
                   weights     = {0: weight_0, 1: weight_1})
        fold_scores.append(tabnet.best_cost)
    return np.mean(fold_scores)

In [13]:
# TPE is Optuna's default sampler; it is created explicitly only to give it a
# seed, so the sequence of suggested hyper-parameters is reproducible.
study = optuna.create_study(
    direction="maximize",
    study_name='TabNet optimization',
    sampler=optuna.samplers.TPESampler(seed=42),
)
# `timeout` is a wall-clock budget in seconds (6 minutes). In the recorded run
# below a single trial took longer than that, so only one trial was completed.
study.optimize(Objective, timeout=6*60)

[32m[I 2022-05-12 07:31:09,626][0m A new study created in memory with name: TabNet optimization[0m



Early stopping occurred at epoch 62 with best_epoch = 37 and best_val_0_accuracy = 0.93483
Best weights from best epoch are automatically used!

Early stopping occurred at epoch 28 with best_epoch = 3 and best_val_0_accuracy = 0.94442
Best weights from best epoch are automatically used!

Early stopping occurred at epoch 30 with best_epoch = 5 and best_val_0_accuracy = 0.9416
Best weights from best epoch are automatically used!
Stop training because you reached max_epochs = 99 with best_epoch = 83 and best_val_0_accuracy = 0.9556
Best weights from best epoch are automatically used!
Stop training because you reached max_epochs = 99 with best_epoch = 97 and best_val_0_accuracy = 0.9592
Best weights from best epoch are automatically used!


[32m[I 2022-05-12 07:38:24,890][0m Trial 0 finished with value: 0.9471296601359456 and parameters: {'mask_type': 'entmax', 'n_da': 36, 'n_steps': 7, 'gamma': 1.1, 'n_shared': 1, 'lambda_sparse': 6.725627863421793e-05, 'sc_gamma': 2.0, 'sc_steps': 50, 'patience': 25, 'epochs': 99}. Best is trial 0 with value: 0.9471296601359456.[0m


In [14]:
# Hyper-parameters of the best trial, keyed by the names used in the
# suggest_* calls of Objective.
TabNet_params = study.best_params
print(TabNet_params)

{'mask_type': 'entmax', 'n_da': 36, 'n_steps': 7, 'gamma': 1.1, 'n_shared': 1, 'lambda_sparse': 6.725627863421793e-05, 'sc_gamma': 2.0, 'sc_steps': 50, 'patience': 25, 'epochs': 99}


In [15]:
# Constructor arguments for the final model: the tuned values from the study
# plus the fixed settings that were also used during the search.
final_params = {
    "cat_idxs": cat_idx,
    "cat_dims": cat_dim,
    "n_d": TabNet_params['n_da'],
    "n_a": TabNet_params['n_da'],
    "n_steps": TabNet_params['n_steps'],
    "gamma": TabNet_params['gamma'],
    "lambda_sparse": TabNet_params['lambda_sparse'],
    "optimizer_fn": torch.optim.Adam,
    "optimizer_params": {"lr": 2e-2, "weight_decay": 1e-5},
    "mask_type": TabNet_params['mask_type'],
    "n_shared": TabNet_params['n_shared'],
    "scheduler_params": {
        "gamma": TabNet_params["sc_gamma"],
        "step_size": TabNet_params["sc_steps"],
    },
    "scheduler_fn": torch.optim.lr_scheduler.StepLR,
    "verbose": 0,
}

# Upper bound on training epochs chosen by the study.
epochs = TabNet_params['epochs']

In [None]:
# Refit on the whole training split with the tuned hyper-parameters.
# No eval_set is supplied, so (as the recorded output below states) no early
# stopping is performed and `patience` has no effect in this call.
fit_options = {
    "patience": TabNet_params['patience'],
    "max_epochs": epochs,
    "eval_metric": ['accuracy'],
    "weights": {0: weight_0, 1: weight_1},
}
tabnet = TabNetClassifier(**final_params)
tabnet.fit(X_train=x, y_train=y, **fit_options)

No early stopping will be performed, last training weights will be used.


### Prediction and Confusion Matrix

In [None]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score

# AUC is computed from the predicted probability of class 1.
proba = tabnet.predict_proba(x_test)
test_auc = roc_auc_score(y_test, proba[:, 1])

# Hard labels feed the macro-averaged (unweighted mean over classes) metrics.
pred = tabnet.predict(x_test)
test_pre, test_rec, test_f1s = (
    metric(y_test, pred, average='macro')
    for metric in (precision_score, recall_score, f1_score)
)

In [None]:
# Report the hold-out metrics, one per line.
for label, value in (
    ("AUC \t\t:", test_auc),
    ("Precision \t:", test_pre),
    ("Recall \t\t:", test_rec),
    ("F1 score \t:", test_f1s),
):
    print(label, value)

In [None]:
from sklearn.metrics import confusion_matrix
# labels=[1, 0] puts class 1 first: row/column 0 is class 1 and row/column 1
# is class 0 (rows are true labels, columns are predictions).
con_mat = confusion_matrix(y_true=y_test, y_pred=pred, labels=[1, 0])
print(con_mat)

In [None]:
# Class counts in the ground truth and in the predictions (1 = death, 0 = survive).
actual_death = len(y_test[y_test == 1])
actual_survive = len(y_test[y_test == 0])
predicted_death = len(pred[pred == 1])
predicted_survive = len(pred[pred == 0])

# Correctly predicted samples per class, read off the confusion-matrix diagonal.
correct_death, correct_survive = con_mat[0][0], con_mat[1][1]

print("Actual Death :", actual_death)
print("Actual Survive :", actual_survive)
print("Predicted Death :", predicted_death)
print("Predicted Survive :", predicted_survive)
print()
# Recall: share of each true class that was recovered, in percent.
print("Death Recall :", correct_death/actual_death*100)
print("Survival Recall :", correct_survive/actual_survive*100)
print()
# Precision: share of each predicted class that was right, in percent.
print("Death Precision :", correct_death/predicted_death*100)
print("Survival Precision :", correct_survive/predicted_survive*100)

### Explainability

In [None]:
import matplotlib.pyplot as plt

# Feature attributions for the test set. `explain_matrix` is indexed per test
# sample below; `masks` presumably holds one attention mask per decision step
# -- confirm against the pytorch-tabnet documentation.
explain_matrix, masks = tabnet.explain(x_test)

# Optional visualisation of the masks for the first 50 test samples (disabled):
# for i in range(3):
#     plt.figure(figsize=(20, 400))
#     plt.imshow(masks[i][:50])

In [None]:
# Sanity check: dimensions of the test array.
x_test.shape

In [None]:
# For every test sample, the column position with the largest attribution.
# np.argmax returns the first position when the row maximum is tied.
important_feature = [np.argmax(row) for row in explain_matrix]

In [None]:
# Tuple of (sorted unique feature positions, number of test samples for which
# each of them is the top-attributed feature).
uniq_arr = np.unique(important_feature, return_counts=True)
print(uniq_arr)

In [None]:
# Position(s) inside uniq_arr with the highest count; np.where returns a tuple
# of index arrays, which is used directly as an index in the next cell.
most_idx = np.where(uniq_arr[1]==max(uniq_arr[1]))

In [None]:
# Column position of the feature that is most often the top-attributed one;
# on a tie the smallest position wins, because np.unique returns sorted values.
first_imp_feature = uniq_arr[0][most_idx][0]

### Save model

In [None]:
# saving_path_name = "saveModel/tabnet_optuna1"
# saved_filepath = tabnet.save_model(saving_path_name)

### Load model

In [None]:
# loaded_tabnet = TabNetClassifier()
# loaded_tabnet.load_model(saved_filepath)

# max_epochs = 200

# loaded_tabnet.fit(
#     X_train=x_train, y_train=y_train,
#     eval_set=[(x_train, y_train), (x_valid, y_valid)],
#     eval_name=['train', 'valid'],
#     eval_metric=['auc'],
#     max_epochs=max_epochs , patience=20,
#     batch_size=1024, virtual_batch_size=128,
#     num_workers=0,
#     weights=1,
#     drop_last=False,
# )

### Reference
##### https://wsshin.tistory.com/5