In [None]:
import numpy as np
import pandas as pd
import os
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import adabound
from pycox.evaluation.concordance import concordance_td
from sklearn.metrics import accuracy_score, roc_auc_score
# Show whether PyTorch can see a CUDA device (value is displayed as the cell output).
torch.cuda.is_available()

In [None]:
class RAdam(optim.Optimizer):
    """Rectified Adam optimizer with an optional plain-momentum fallback.

    Keeps Adam-style first and second moment estimates per parameter. While
    the approximated SMA length ``N_sma`` is below 5, the adaptive denominator
    is not used: the update is either a bias-corrected momentum step
    (``degenerated_to_sgd=True``) or skipped entirely.

    Args:
        params: iterable of parameters, or list of parameter-group dicts.
        lr: learning rate (must be >= 0).
        betas: coefficients for the running averages of the gradient and its
            square; each must lie in [0, 1).
        eps: term added to the denominator (must be >= 0).
        weight_decay: multiplicative decay applied as
            ``p -= weight_decay * lr * p`` before the gradient update.
        degenerated_to_sgd: when True, use a momentum-only update while
            ``N_sma < 5``; when False, skip the update in that regime.

    Raises:
        ValueError: if ``lr``, ``eps`` or either beta is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        # A group that overrides `betas` gets its own step-size cache, because
        # the cached values depend on the betas.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        # `buffer` is a 10-slot cache of [step, N_sma, step_size], indexed by step % 10.
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # The update is computed on a float32 version of the parameter
                # and copied back into p.data at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword `value=` / `alpha=` forms replace the deprecated
                # scalar-first positional overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    # Another parameter in this group already computed the
                    # values for this step number; reuse them.
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        # Negative sentinel: no update is applied for this step.
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss

In [None]:
# Location of the input CSV. Set the DL_DATA_CSV environment variable to read
# the file from somewhere else; the default is the original path.
DATA_CSV_PATH = os.environ.get(
    'DL_DATA_CSV',
    'D:/Cho Lab Dropbox/연구과제별정리/02 목적과제_SPPEC_암종별(임상)_상희회준/위암/03 머신러닝_회준종혁다혜/00 ML_Data/DL_data_220629.csv',
)
mydata = pd.read_csv(DATA_CSV_PATH)

In [None]:
#mydata = mydata[mydata['Age']>74]

Redefine the event indicator as a binary complication flag (1 if `Clavien_Dindo` > 1, otherwise 0)

In [None]:
# Work on an independent copy so the outcome column added below leaves `mydata` unchanged.
dat = mydata.copy(deep=True)

In [None]:
# Binary outcome: 1 when Clavien_Dindo exceeds 1, 0 otherwise.
dat['complication'] = np.where(mydata['Clavien_Dindo'].gt(1), 1, 0)


In [None]:
# Outcome vector used for training and evaluation: the 0/1 complication column of `dat`.
event_data = dat['complication']
# Alternatives left by the author: use the raw Clavien_Dindo value as the outcome,
# or cross-tabulate it against operation year.
#event_data = mydata['Clavien_Dindo']
#pd.crosstab(mydata['OP_year'], mydata['Clavien_Dindo'])

In [None]:
# Pre-operative predictors, taken from `dat`.
pre_op_dat = dat.loc[:, [
    'Sex', 'Age', 'ASA_Score', 'Smoking', 'Drinking',
    'BMI', 'Hypertension', 'Diabetes',
    'Ass_condition_grp', 'Ass_lesion',
    # NOTE(review): this name ends with a space; confirm it matches the CSV header.
    'Clinical_Stage_grp ',
    'Histology', 'Lauren', 'Reconstruction', 'Combined_Resection',
    'Platelets', 'Albumin', 'Cell_Count',
    'Hemoglobin_status', 'Neutrophil_count_status',
]]

# Intra-/post-operative predictors, taken from the originally loaded frame `mydata`.
post_op_dat = mydata.loc[:, [
    'fStage_grp', 'Location', 'LN_Dissection', 'Operation',
    'Intraop_cc', 'OP_time', 'Z_EBL',
]]

In [None]:
# Total number of rows (subjects) in the dataset.
nsubject = len(dat)

In [None]:
# Columns kept as numeric; every other selected column is treated as categorical.
contvar = ['Age', 'Platelets', 'Albumin', 'Cell_Count']
catevar = pre_op_dat.columns[~pre_op_dat.columns.isin(contvar)]

# Same partition for the second predictor table.
contvar_post = ['OP_time', 'Z_EBL']
catevar_post = post_op_dat.columns[~post_op_dat.columns.isin(contvar_post)]

In [None]:
# Cast the categorical columns to pandas' category dtype; astype returns new
# frames, so the source tables are left as they were.
pre_op_dat2 = pre_op_dat.astype({name: "category" for name in catevar})
post_op_dat2 = post_op_dat.astype({name: "category" for name in catevar_post})

print(pre_op_dat2.dtypes, post_op_dat2.dtypes)

In [None]:
# One-hot encode the categorical columns. dtype=float keeps the indicator
# columns numeric: pandas >= 2.0 produces bool indicators by default, and a
# frame mixing bool and float columns converts to an object-dtype array, which
# the DataLoader collate step and torch.FloatTensor cannot consume.
pre_op_dat3 = pd.get_dummies(pre_op_dat2, columns=catevar, dtype=float)
pre_dat = pre_op_dat3.to_numpy()

post_op_dat3 = pd.get_dummies(post_op_dat2, columns=catevar_post, dtype=float)

In [None]:
# Count of missing values per column in the first encoded table.
pre_op_dat3.isna().sum()

In [None]:
# Count of missing values per column in the second encoded table.
post_op_dat3.isna().sum()

In [None]:
missing_col = ['Z_EBL']
# Mean imputation: each missing entry is replaced by its column's mean.
# Note that the mean is taken over every row, i.e. before the train/test
# split made later in the notebook.
for col in missing_col:
    post_op_dat3[col] = post_op_dat3[col].fillna(post_op_dat3[col].mean())

In [None]:
# Request a float array explicitly so bool indicator columns (the get_dummies
# default on pandas >= 2.0) cannot turn the result into an object array.
post_dat = post_op_dat3.to_numpy(dtype=float)

In [None]:
# Place the two feature matrices side by side (column-wise) to form the model input.
total_dat = np.concatenate((pre_dat, post_dat), axis=1)
print(pre_dat.shape, post_dat.shape, total_dat.shape)

In [None]:
# Hold-out rule: rows whose operation year is one of the listed years form the test set.
OP_year = mydata['OP_year']
testindex = np.isin(OP_year.to_numpy(), (2015, 2017, 2019, 2021))

In [None]:
# Training rows are those not flagged by the test mask.
total_train = total_dat[~testindex].copy()
total_test = total_dat[testindex].copy()

event_train = event_data.to_numpy()[~testindex].copy()
event_test = event_data.to_numpy()[testindex].copy()

ntrain = len(total_train)
# Displayed as cell output: (all subjects, training subjects, test subjects).
nsubject, ntrain, nsubject - ntrain

In [None]:
# Display the training labels (cell output only; nothing is assigned).
event_train

In [None]:
num_event = 1
batch_size = 256
hidden_size = 256

# Each sample is a [feature_row, label] pair; the DataLoader batches and shuffles them.
train_data = [[features, label] for features, label in zip(total_train, event_train)]
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)
train_loader

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
cuda = torch.cuda.is_available()
device = 'cuda:0' if cuda else 'cpu'

In [None]:
class MLP(nn.Module):
    """Fully connected network with LeakyReLU activations and one output unit.

    Layout: ``hidden_in`` -> LeakyReLU -> (``num_layer - 1``) x [Linear,
    LeakyReLU] -> ``hidden_out``. The output is a raw score with no final
    activation.

    Args:
        input_size: number of input features.
        hidden_size: width of every hidden layer.
        num_layer: total number of hidden layers; ``num_layer - 1`` extra
            Linear+LeakyReLU pairs follow the input layer.
        num_event: stored on the instance but not used by the layers; the
            output width is fixed at 1.
    """

    def __init__(self, input_size, hidden_size=128, num_layer=1, num_event=1):
        super(MLP, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layer = num_layer
        self.num_event = num_event

        self.hidden_in = nn.Linear(self.input_size, self.hidden_size)
        # Build a fresh Linear/LeakyReLU pair per extra layer. Repeating a
        # list of already-constructed modules would insert the same Linear
        # object several times, so all hidden layers would share one set of
        # weights.
        hiddens = []
        for _ in range(num_layer - 1):
            hiddens.append(nn.Linear(self.hidden_size, self.hidden_size))
            hiddens.append(nn.LeakyReLU(inplace=True))
        self.hiddens = nn.Sequential(*hiddens)
        self.hidden_out = nn.Linear(self.hidden_size, 1)
        self.activation = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        """Map a (batch, input_size) tensor to a (batch, 1) tensor of raw scores."""
        x = self.hidden_in(x)
        x = self.hiddens(self.activation(x))
        x = self.hidden_out(x)

        return x

In [None]:
# Preview model: the input width is read from the training matrix (as the
# training loop does) instead of a hard-coded feature count.
model = MLP(input_size=total_train.shape[-1], hidden_size=128, num_layer=2, num_event=num_event).to(device)
model

In [None]:
# Binary cross-entropy on raw logits; a sigmoid is applied explicitly only
# when predictions are computed for evaluation below.
criterion = nn.BCEWithLogitsLoss()

# Grid search over L1-penalty strength, hidden width and depth. Each
# configuration is trained from scratch, the checkpoint with the lowest mean
# training loss is written to `path`, and train/test accuracy and AUC are
# printed for that checkpoint.
for weight_decay in [1e-3]:
    for hidden_size in [16, 32, 64, 128, 256]:
        for num_layer in [1, 2, 3]:

            path = 'D:/models/prepost_binary/MultiMLP_{}hiddensize_{}layers_{:.0e}'.format(hidden_size, num_layer, weight_decay)
            # torch.save does not create missing folders, so make sure the
            # checkpoint directory exists before training starts.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            #if os.path.isfile(path):
            #    continue
            print(path[9:])

            model = MLP(input_size=total_train.shape[-1], hidden_size=hidden_size, num_layer=num_layer, num_event=num_event).to(device)
            #if os.path.isfile(path):
            #    model.load_state_dict(torch.load(path, map_location = device))

            lr = 1e-3
            # The optimizer's own weight decay is off; `weight_decay` from the
            # grid is used as the coefficient of the explicit L1 term below.
            optimizer = adabound.AdaBound(model.parameters(), lr=lr, weight_decay=0)

            loss_array = []
            patience = 0
            min_loss = np.inf
            for e in range(int(1e6)):

                loss_array_tmp = []

                for total_batch, event_batch in train_loader:

                    total_batch = total_batch.float()
                    event_batch = event_batch.reshape(-1,1).float()

                    y_pred = model(total_batch.to(device))

                    # L1 norm summed over all parameters.
                    norm = 0.
                    for parameter in model.parameters():
                        norm += torch.norm(parameter, p=1)

                    loss1 = criterion(y_pred, event_batch.to(device))

                    loss = loss1 + weight_decay*norm
                    # Only the unpenalised BCE term is logged and used for checkpointing.
                    loss_array_tmp.append(loss1.item())

                    model.zero_grad()

                    loss.backward()

                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                    optimizer.step()

                loss_array.append(np.mean(loss_array_tmp))
                if e % 100 == 0:
                    print('Epoch: ' + str(e) + 
                          ', Loss: '+ f'{loss_array[-1]:.4e}')
                # Checkpoint on every new minimum of the mean training loss;
                # otherwise count an epoch without improvement.
                if min_loss > loss_array[-1]:
                    patience = 0
                    min_loss = loss_array[-1]
                    torch.save(model.state_dict(), path)
                else:
                    patience += 1

                torch.cuda.empty_cache()

                # Stop once more than 1000 consecutive epochs brought no improvement.
                if patience > 1000:
                    break

            plt.plot(loss_array, label='Loss')
            plt.ylabel('loss')
            plt.xlabel('epoch')
            plt.yscale('log')
            plt.title(path[2:])
            plt.legend()
            plt.show()

            total_train_sort = torch.FloatTensor(total_train)
            total_test_sort = torch.FloatTensor(total_test)

            # Evaluate the best checkpoint rather than the last-epoch weights.
            model.load_state_dict(torch.load(path, map_location = device))
            model.eval()

            # Inference only: no autograd graph is needed for the predictions.
            with torch.no_grad():
                y_train = torch.sigmoid(model(total_train_sort.to(device))).cpu().numpy()
                y_test = torch.sigmoid(model(total_test_sort.to(device))).cpu().numpy()

            out_pred = np.where(y_train >= 0.5, 1, 0)
            acc_train = accuracy_score(event_train, out_pred.flatten())
            auc_train = roc_auc_score(event_train, y_train.flatten())

            out_pred = np.where(y_test >= 0.5, 1, 0)
            acc_test = accuracy_score(event_test, out_pred.flatten())
            auc_test = roc_auc_score(event_test, y_test.flatten())
            print('-------------------------------------------------------')
            print(path[9:])
            print('Train accuracy = {:.4f}, Test accuracy = {:.4f}'.format(acc_train, acc_test))
            print('Train AUC = {:.4f}, Test AUC = {:.4f}'.format(auc_train, auc_test))
            print('=======================================================')