In [1]:
import warnings

warnings.filterwarnings('ignore')

import os
import random
import numpy as np
import pandas as pd
from tab_transformer_pytorch import TabTransformer
import scipy.stats as st
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder, StandardScaler, OneHotEncoder
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

Couldn't import dot_parser, loading of dot files will not be possible.


In [2]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = False

seed_everything()    

In [3]:
df = pd.read_csv('./impute_set/imp3.csv')

In [6]:
class CustomDataset(Dataset):
    def __init__(self, x, y, cat_col, numeric_col):
        self.X_cat = np.array(x[cat_col]).astype(np.int32)
        self.X_num = np.array(x[numeric_col]).astype(np.float32)
        self.y = np.array(y).astype(np.float32)
        
    def __len__(self):
        return len(self.X_cat)
    
    def __getitem__(self, idx):
        X_cat = torch.from_numpy(self.X_cat[idx])
        X_num = torch.from_numpy(self.X_num[idx])
        y = torch.from_numpy(self.y[idx])
        return X_cat, X_num, y
    

def preprocessing(df, numeric='minmax', category='label'):
    X = df.drop('BS3_1', axis=1)
    y = df[['BS3_1']]
    numeric_col = [
        'FEV1', 'FEV1FVC', 'age', 'BS6_3', 'BS6_2_1', 'BD1',
        '건강문해력', 'Total_slp_wk', 'EQ_5D', 'BE3_31', 'BE5_1', '질환유병기간'
    ]
    cat_col = []
    for col in X.columns:
        if col not in numeric_col:
            cat_col.append(col)

    df_num, df_cat = X[numeric_col], X[cat_col]
    if numeric == 'minmax':
        n_pre = MinMaxScaler()
    else:
        n_pre = StandardScaler()
    df_num = pd.DataFrame(n_pre.fit_transform(df_num), columns=df_num.columns)

    if category == 'label':
        c_pre = OrdinalEncoder()
        df_cat = pd.DataFrame(c_pre.fit_transform(df_cat), columns=df_cat.columns)
    else:
        c_pre = OneHotEncoder(sparse_output=False)
        df_cat = pd.DataFrame(c_pre.fit_transform(df_cat))

    X = pd.concat([df_num, df_cat], axis=1)
    uniques = []
    for col in cat_col:
        uniques.append(len(X[col].unique()))

    return X, y, uniques, numeric_col, cat_col    


def test_with_imputations(train_loader, test_loader, train_X, test_y, numeric_col, uniques):
    combined_array = np.column_stack((train_X[numeric_col].mean().values, train_X[numeric_col].std().values))

    # class_counts = torch.tensor([test_y.value_counts()[0], test_y.value_counts()[1]])
    # class_weights = 1.0 / class_counts
    # class_weights /= class_weights.sum()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_negatives = (test_y == 0).sum()
    num_positives = (test_y == 1).sum()
    class_weights = torch.tensor(num_negatives / num_positives, dtype=torch.float32).to(device)


    print(f'Weights: {class_weights}')

    device = torch.device('cuda')
    model = TabTransformer(
        categories=tuple(uniques),
        num_continuous=len(numeric_col),
        dim=64,dim_out=1,depth=6,heads=8,attn_dropout=.1,ff_dropout=.1,mlp_hidden_mults=(4,2),
        mlp_act=nn.ReLU(inplace=True),
        continuous_mean_std=torch.tensor(combined_array, dtype=torch.float32)
    )
    model = model.to(device)
    optim = Adam(model.parameters(), lr=.0001)
    # criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    best_f1 = 0.0
    best_epoch = 0
    epochs = 50
    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        for x_cat, x_num, yy in train_loader:
            optim.zero_grad()
            x_cat, x_num, yy = x_cat.to(device), x_num.to(device), yy.to(device)
            preds = model(x_cat, x_num)
            loss = criterion(preds.squeeze(), yy.squeeze())
            loss.backward()
            optim.step()
            running_loss += loss.item()
        # print(f'{epoch+1} Epoch | Loss: {running_loss/len(train_loader):.4f}')

        model.eval()
        val_loss = 0.0
        correct = 0
        val_preds = []
        val_targets = []
        with torch.no_grad():
            for x_cat, x_num, yy in test_loader:
                x_cat, x_num, yy = x_cat.to(device), x_num.to(device), yy.to(device)
                preds = model(x_cat, x_num)
                val_loss += criterion(preds.squeeze(), yy.squeeze()).item()
                yy = yy.detach().cpu().numpy().squeeze()
                preds = torch.sigmoid(preds).detach().cpu().numpy().squeeze()
                pred_labels = np.where(preds>=.5, 1, 0)
                correct += (pred_labels == yy).sum().item()
                
                val_preds.extend(pred_labels.tolist())
                val_targets.extend(yy.tolist())
        val_loss /= len(test_loader)
        val_f1 = f1_score(val_targets, val_preds, average='macro')
        if val_f1 > best_f1:
            best_f1 = val_f1
            best_epoch = epoch+1
            torch.save(model.state_dict(), 'bestTabTransformer.pth')
            
    print(f'Best Epoch: {best_epoch} | Best F1 : {best_f1:.4f}')  
    return best_f1  


def test_with_5fold(df, numeric, category, shuffle=True):
    f1s = []
    X, y, uniques, numeric_col, cat_col = preprocessing(df, numeric, category)
    if shuffle:
        skf = StratifiedKFold(n_splits=5, shuffle=shuffle, random_state=42)
    else:
        skf = StratifiedKFold(n_splits=5, shuffle=shuffle)
    
    for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
        train_X, train_y = X.iloc[train_idx], y.iloc[train_idx]
        test_X, test_y = X.iloc[test_idx], y.iloc[test_idx]

        train_set = CustomDataset(train_X, train_y, cat_col, numeric_col)
        test_set = CustomDataset(test_X, test_y, cat_col, numeric_col)
        train_loader = DataLoader(train_set, batch_size=64, shuffle=True, pin_memory=True)
        test_loader = DataLoader(test_set, batch_size=64, shuffle=True, pin_memory=True)

        f1_value = test_with_imputations(
            train_loader, test_loader, train_X, test_y, numeric_col, uniques
        )
        f1s.append(f1_value)

    return f1s


def get_cv_results(f1s:list):
    f1s = np.array(f1s)
    mean_f1 = np.mean(f1s)
    std_f1 = np.std(f1s)
    ci95 = st.t.interval(.95, df=len(f1s)-1, loc=mean_f1, scale=std_f1/np.sqrt(len(f1s)))
    return mean_f1, std_f1, ci95

In [7]:
# minmax | label
f1s = test_with_5fold(df, numeric='minmax', category='label')
mean_f1, std_f1, ci95 = get_cv_results(f1s)
print(f1s)
print(f'CV Results: Mean {mean_f1:.2f} | Std {std_f1:.2f} | CI95% {ci95[0]:.2f}~{ci95[1]:.2f}')

Weights: tensor([6.], device='cuda:0')
Best Epoch: 2 | Best F1 : 0.8250
Weights: tensor([6.], device='cuda:0')
Best Epoch: 7 | Best F1 : 0.7883
Weights: tensor([6.], device='cuda:0')
Best Epoch: 1 | Best F1 : 0.7321
Weights: tensor([6.], device='cuda:0')
Best Epoch: 3 | Best F1 : 0.6764
Weights: tensor([5.3636], device='cuda:0')
Best Epoch: 3 | Best F1 : 0.7809
[0.825, 0.7883064516129032, 0.7321428571428572, 0.6764252696456087, 0.7808695652173914]
CV Results: Mean 0.76 | Std 0.05 | CI95% 0.70~0.82


In [8]:
# standard | label
f1s = test_with_5fold(df, numeric='standard', category='label')
mean_f1, std_f1, ci95 = get_cv_results(f1s)
print(f1s)
print(f'CV Results: Mean {mean_f1:.2f} | Std {std_f1:.2f} | CI95% {ci95[0]:.2f}~{ci95[1]:.2f}')

Weights: tensor([6.], device='cuda:0')
Best Epoch: 3 | Best F1 : 0.8179
Weights: tensor([6.], device='cuda:0')
Best Epoch: 9 | Best F1 : 0.7659
Weights: tensor([6.], device='cuda:0')
Best Epoch: 7 | Best F1 : 0.7486
Weights: tensor([6.], device='cuda:0')
Best Epoch: 2 | Best F1 : 0.6812
Weights: tensor([5.3636], device='cuda:0')
Best Epoch: 3 | Best F1 : 0.7989
[0.8179115570419918, 0.765886287625418, 0.7485632183908046, 0.6812386156648452, 0.7988505747126436]
CV Results: Mean 0.76 | Std 0.05 | CI95% 0.70~0.82
