In [1]:
# Auto Reload
%load_ext autoreload
%autoreload 2

### Data Preprocess

In [2]:
import pandas as pd
import numpy as np

x = pd.read_csv("./data/KAMIR_data_preprocessing_no_normalization.csv", encoding = 'unicode_escape')
y = pd.read_csv("./data/KAMIR_Labels1.csv", encoding = 'unicode_escape')
x = x.iloc[:, 1:]
y = y.iloc[:, 1:]
y = y.iloc[:, 20]
y = y.fillna(0)

In [3]:
idx = []
for i in range(len(x.columns)):
    idx.append(i)

In [4]:
cat_col = ['Sex',
 'ECG Change',
 'Chest pain',
 'FMC',
 'Killip Class at admission',
 'ST change on ECG',
 'Heart rhythm on ECG',
 'Smoking History',
 'Initial diagnosis',
 'STEMI',
 'NSTEMI',
 'Coronary artery stenosis',
 'Puncture route',
 'Target vessel',
 'Lesion type of target vessel',
 'Pre TIMI flow of target vessel',
 'Treatment of target Vessel',
 'Post TIMI flow of target vessel',
 'Result of PCI',
 'Status of revascularization',
 'MR grade',
 'Final diagnosis',
 '12M_MR grade',
 '12M_Oral hypoglycemic agent (1)',
 '12M_Oral anticoagulant',
 '12M_Ezetimide',
 '12M_Fibrate',
 '12M_Statin',
 '12M_Aspirin',
 '12M_ACE inhibitor',
 '12M_Beta-blocker',
 '12M_Ca-channel blocker',
 '12M_Ticagrelor',
 '12M_Cilostazol',
 '12M_Prasugurel',
 '12M_Clopidogrel',
 'on Treatment',
 'on Treatment.1',
 'on Treatment.2',
 'on Treatment.3',
 'on Treatment.4',
 'on Treatment.5',
 'DiagnosisCerebrovascular Disease',
 'on Treatment.6',
 'Menopause',
 'Hysterectomy History',
 'Number of involved vessels',
 'Index procedure',
 'Staged PCI',
 '12M_ARB',
 '12M_Omega3 FA',
 '12M_Oral hypoglycemic agent (3)',
 '12M_Oral hypoglycemic agent (2)']

cat_idx = []
for i in cat_col:
    cat_idx.append(x.columns.get_loc(i))

cat_dim = []

for i in cat_idx:
    cat_dim.append(max(x.iloc[:,i])+1)

In [5]:
for i in cat_idx:
    idx.remove(i)

In [6]:
from sklearn import preprocessing

x.iloc[:,idx] = preprocessing.normalize(x.iloc[:,idx])

In [7]:
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=22, stratify=y)
x_train, x_valid, y_train, y_valid =  train_test_split(x_train, y_train, test_size=0.2, random_state=22, stratify=y_train)

In [8]:
print("train :", len(y_train[y_train==1]))
print("valid :",len(y_valid[y_valid==1]))
print("test :",len(y_test[y_test==1]))

train : 578
valid : 144
test : 180


In [9]:
import torch
import torch.nn as nn

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

In [10]:
x_train = np.array(x_train)
x_valid = np.array(x_valid)
x_test = np.array(x_test)
y_train = np.array(y_train)
y_valid = np.array(y_valid)
y_test = np.array(y_test)


# x_train = torch.tensor(x_train, device=device)
# y_train = torch.tensor(y_train)
# x_valid = torch.tensor(x_valid, device=device)
# y_valid = torch.tensor(y_valid, device=device)
# x_test = torch.tensor(x_test, device=device)
# y_test = torch.tensor(y_test, device=device)

### Model : TabNet

In [11]:
from pytorch_tabnet.tab_model import TabNetClassifier

In [12]:
tabnet = TabNetClassifier(cat_idxs=cat_idx,
                       cat_dims=cat_dim,
                       cat_emb_dim=1,
                       optimizer_fn=torch.optim.Adam,
                       optimizer_params=dict(lr=1e-2),
                       scheduler_params={"step_size":50,
                                         "gamma":0.9},
                       scheduler_fn=torch.optim.lr_scheduler.StepLR,
                       mask_type='sparsemax' # "sparsemax", entmax
                      )

Device used : cuda


In [13]:
max_epochs = 200

tabnet.fit(
    X_train=x_train, y_train=y_train,
    eval_set=[(x_train, y_train), (x_valid, y_valid)],
    eval_name=['train', 'valid'],
    eval_metric=['auc'],
    max_epochs=max_epochs , patience=20,
    batch_size=1024, virtual_batch_size=128,
    num_workers=0,
    weights=0,
    drop_last=False,
)

epoch 0  | loss: 0.91703 | train_auc: 0.45774 | valid_auc: 0.41076 |  0:00:01s
epoch 1  | loss: 0.39581 | train_auc: 0.62238 | valid_auc: 0.6069  |  0:00:02s
epoch 2  | loss: 0.2403  | train_auc: 0.71333 | valid_auc: 0.69395 |  0:00:04s
epoch 3  | loss: 0.19891 | train_auc: 0.81391 | valid_auc: 0.81656 |  0:00:05s
epoch 4  | loss: 0.16726 | train_auc: 0.86335 | valid_auc: 0.86902 |  0:00:06s
epoch 5  | loss: 0.15092 | train_auc: 0.89878 | valid_auc: 0.89684 |  0:00:08s
epoch 6  | loss: 0.13225 | train_auc: 0.91024 | valid_auc: 0.90742 |  0:00:09s
epoch 7  | loss: 0.12465 | train_auc: 0.91562 | valid_auc: 0.91071 |  0:00:10s
epoch 8  | loss: 0.1184  | train_auc: 0.9232  | valid_auc: 0.91635 |  0:00:12s
epoch 9  | loss: 0.11671 | train_auc: 0.93223 | valid_auc: 0.9204  |  0:00:13s
epoch 10 | loss: 0.11345 | train_auc: 0.93689 | valid_auc: 0.92564 |  0:00:14s
epoch 11 | loss: 0.11086 | train_auc: 0.93468 | valid_auc: 0.92624 |  0:00:15s
epoch 12 | loss: 0.11045 | train_auc: 0.93114 | vali

### Prediction and Confusion Matrix

In [14]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score

pred = tabnet.predict_proba(x_test)
test_auc = roc_auc_score(y_score=pred[:,1], y_true=y_test)

pred = tabnet.predict(x_test)
test_pre = precision_score(y_test, pred, average='macro')
test_rec = recall_score(y_test, pred, average='macro')
test_f1s = f1_score(y_test, pred, average='macro')

In [15]:
print("AUC :", test_auc)
print("Precision :", test_pre)
print("Recall :", test_rec)
print("F1 score :", test_f1s)

AUC : 0.958037263332579
Precision : 0.7220616570327554
Recall : 0.5156483367277664
F1 score : 0.5163985148514851


In [16]:
from sklearn.metrics import confusion_matrix
con_mat = confusion_matrix(y_test,pred, labels=[1,0])
print(con_mat)

[[   6  174]
 [   6 2940]]


In [17]:
print("Actual Death :", len(y_test[y_test==1]))
print("Actual Survive :", len(y_test[y_test==0]))
print("Predicted Death :", len(pred[pred==1]))
print("Predicted Survive :", len(pred[pred==0]))
print()
print("Death Recall :", con_mat[0][0]/len(y_test[y_test==1])*100)
print("Survival Recall :", con_mat[1][1]/len(y_test[y_test==0])*100)
print()
print("Death Precision :", con_mat[0][0]/len(pred[pred==1])*100)
print("Survival Precision :", con_mat[1][1]/len(pred[pred==0])*100)

Actual Death : 180
Actual Survive : 2946
Predicted Death : 12
Predicted Survive : 3114

Death Recall : 3.3333333333333335
Survival Recall : 99.79633401221996

Death Precision : 50.0
Survival Precision : 94.41233140655106


### Explainable

In [18]:
import matplotlib.pyplot as plt

explain_matrix, masks = tabnet.explain(x_test)

# for i in range(3):
#     plt.figure(figsize=(20, 400))
#     plt.imshow(masks[i][:50])

In [19]:
x_test.shape

(3126, 230)

In [20]:
important_feature = []

for i in range(len(explain_matrix)):
    imp_idx = np.where(explain_matrix[i]==max(explain_matrix[i]))[0][0] 
#     print("Test Case", i, ":", imp_idx)
    important_feature.append(imp_idx)

In [21]:
uniq_arr = np.unique(important_feature, return_counts=True)
print(uniq_arr)

(array([  9,  16,  18,  24,  43,  66,  67,  91,  92,  96,  98, 109, 112,
       118, 121, 136, 142, 149, 157, 159, 165, 167, 179, 188, 197, 200,
       203, 206, 209, 210, 225, 229]), array([   5,    4,    3,    5,    1,    4,   16,    4,    1,   80,   17,
          3,   99,    1,   15,   48,    2,    1,   45,    4,    1,  112,
         98,    8,  977,    2,    3,   17, 1538,    6,    3,    3]))


In [22]:
most_idx = np.where(uniq_arr[1]==max(uniq_arr[1]))

In [23]:
first_imp_feature = uniq_arr[0][most_idx][0]

In [24]:
x.columns[first_imp_feature]

'12M_ACE inhibitor'

### Save model

In [25]:
# saving_path_name = "saveModel/tabnet_2"
# saved_filepath = tabnet.save_model(saving_path_name)

### Load model

In [26]:
# loaded_tabnet = TabNetClassifier()
# loaded_tabnet.load_model(saved_filepath)

# max_epochs = 200

# loaded_tabnet.fit(
#     X_train=x_train, y_train=y_train,
#     eval_set=[(x_train, y_train), (x_valid, y_valid)],
#     eval_name=['train', 'valid'],
#     eval_metric=['auc'],
#     max_epochs=max_epochs , patience=20,
#     batch_size=1024, virtual_batch_size=128,
#     num_workers=0,
#     weights=1,
#     drop_last=False,
# )

### Reference
##### https://wsshin.tistory.com/5