In [1]:
import os
import random
from glob import glob
from random import randrange

import numpy as np
import pandas as pd
import torch
from libauc.losses import AUCMLoss
from libauc.optimizers import PESG
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from dataset import ChexpertDataset
from model import MultiLabelClassification
from libauc.optimizers import PESG

In [2]:
# Get Data
# Load the extra-validation split: age/sex table, hidden feature vectors,
# labels and image paths. DATA_DIR is resolved relative to the kernel's
# current working directory, so the kernel must be started where ../data exists.
BASE_DIR = os.getcwd()
DATA_DIR = os.path.join(BASE_DIR, '..', 'data')
extra_valid_age_sex_df = pd.read_csv(os.path.join(DATA_DIR, 'extra_valid_age_sex.csv'))
extra_valid_hidden_features = np.load(os.path.join(DATA_DIR, 'extra_valid_hidden_features.npy'))
extra_valid_labels = np.load(os.path.join(DATA_DIR, 'extra_valid_labels.npy'))
# glob() yields paths in arbitrary, filesystem-dependent order; sort so the
# list is identical across machines and re-runs.
extra_valid_images = sorted(glob(os.path.join(DATA_DIR, 'extraValid', '*')))

In [3]:
# Optionally append age and sex to the hidden feature vectors.
# Set USE_AGE_SEX = True to actually feed age/sex to the model. False keeps the
# hidden features only, which is the configuration that produced the logged run.
USE_AGE_SEX = False

# Encode sex as integers (F=0, M=1, O=2) and coerce every column to numeric.
extra_valid_age_sex_df = extra_valid_age_sex_df.replace(['F', 'M', 'O'], [0, 1, 2])
extra_valid_age_sex_df = extra_valid_age_sex_df.apply(pd.to_numeric)
extra_valid_age_sex_np = extra_valid_age_sex_df.to_numpy()

# The age/sex table must be row-aligned with the hidden features either way.
assert len(extra_valid_age_sex_np) == len(extra_valid_hidden_features), \
    'age/sex rows do not match hidden-feature rows'

# Name kept for downstream cells. It only contains age/sex when USE_AGE_SEX is True.
if USE_AGE_SEX:
    hidden_with_age_sex = np.concatenate(
        (extra_valid_hidden_features, extra_valid_age_sex_np), axis=1
    ).astype('float32')
else:
    hidden_with_age_sex = extra_valid_hidden_features.astype('float32')

In [4]:
def set_all_seeds(SEED):
    """Seed every RNG the notebook uses and force deterministic cuDNN.

    Args:
        SEED: integer seed applied to Python's ``random`` module, NumPy's
            global RNG and PyTorch (CPU and all CUDA devices).
    """
    # REPRODUCIBILITY
    random.seed(SEED)  # the notebook imports from `random`; it must be seeded too
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed = 144
set_all_seeds(seed)

In [5]:
# Hold out 20% of the rows for validation, reusing the global seed so the
# split is the same on every run.
(X_train, X_val,
 y_train, y_val) = train_test_split(
    hidden_with_age_sex,
    extra_valid_labels,
    test_size=0.2,
    random_state=seed,
)

In [6]:
# Uncertain-label handling ("U-Ones"): relabel every -1 entry as positive (1).
# This is a hard relabel, not label smoothing. -1 presumably encodes "uncertain"
# per the CheXpert labelling convention -- TODO confirm for these .npy labels.
# train_test_split returns indexed copies, so editing y_train/y_val in place
# leaves extra_valid_labels untouched, and re-running this cell is a no-op.
y_train[y_train == -1] = 1
y_val[y_val == -1] = 1

In [7]:
# Wrap each split in a Dataset and batch it. Only the training loader
# shuffles; validation order stays fixed.
# NOTE(review): each split is built with scale_X=True independently. If
# ChexpertDataset fits its scaler on the data it is given, X_val is scaled with
# its own statistics rather than the training ones -- confirm in dataset.py.
train_set = ChexpertDataset(X_train, y=y_train, scale_X=True)
val_set = ChexpertDataset(X_val, y=y_val, scale_X=True)
loader_opts = dict(batch_size=16, num_workers=2)
train_loader = DataLoader(train_set, shuffle=True, **loader_opts)
val_loader = DataLoader(val_set, shuffle=False, **loader_opts)

In [8]:
# Parameters
BATCH_SIZE = 16       # NOTE(review): not used here -- the DataLoaders above hard-code batch_size=16; keep in sync
imratio = 0.3424      # presumably the positive-label ratio expected by AUCM/PESG -- TODO confirm against y_train
lr = 0.05             # using smaller learning rate is better
gamma = 500
weight_decay = 1e-5
margin = 1.0
EPOCHS = 24
DECAY_EVERY = 4       # epochs between optimizer.update_regularizer calls
DECAY_FACTOR = 5      # passed to update_regularizer; the logged lr drops 0.05 -> 0.01 -> 0.002 ...
EVAL_EVERY = 400      # training batches between in-epoch validation passes
features = X_train.shape[1]
classes = y_train.shape[1]

# model
model = MultiLabelClassification(num_feature=features, num_class=classes)
model = model.cuda()

# define loss & optimizer
# NOTE(review): AUCMLoss is given one imratio for all label columns. libauc also
# provides a multi-label AUCM loss in recent versions -- consider it if the
# per-class positive ratios differ substantially.
Loss = AUCMLoss(imratio=imratio)
optimizer = PESG(model,
                 a=Loss.a,
                 b=Loss.b,
                 alpha=Loss.alpha,
                 imratio=imratio,
                 lr=lr,
                 gamma=gamma,
                 margin=margin,
                 weight_decay=weight_decay)


def evaluate_auc(model, loader):
    """Score ``model`` on every batch of ``loader`` with ROC-AUC.

    Args:
        model: network to evaluate; inputs are moved to CUDA before the forward pass.
        loader: DataLoader yielding ``(inputs, labels)`` batches.

    Returns:
        ``(auc, y_true, y_pred)``. ``auc`` is ``roc_auc_score(y_true, y_pred)``
        with sklearn's default averaging; the arrays are the concatenated labels
        and model outputs.

    The model is switched to eval mode for the pass and is always returned to
    train mode, even if scoring raises.
    """
    model.eval()
    try:
        preds, targets = [], []
        with torch.no_grad():
            for inputs, labels in loader:
                preds.append(model(inputs.cuda()).cpu().numpy())
                targets.append(labels.numpy())
        y_true = np.concatenate(targets)
        y_score = np.concatenate(preds)
        return roc_auc_score(y_true, y_score), y_true, y_score
    finally:
        model.train()


best_val_auc = 0
best_state = None  # cloned state_dict of the weights that achieved best_val_auc
for epoch in range(EPOCHS):
    if epoch % DECAY_EVERY == 0 and epoch > 0:
        optimizer.update_regularizer(decay_factor=DECAY_FACTOR)
    for idx, (train_data, train_labels) in enumerate(train_loader):
        train_data, train_labels = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        loss = Loss(y_pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # validation
        if idx % EVAL_EVERY == 0:
            val_auc, test_true, test_pred = evaluate_auc(model, val_loader)
            if best_val_auc < val_auc:
                best_val_auc = val_auc
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print('Epoch=%s, BatchID=%s, Val_AUC=%.4f, lr=%.4f'%(epoch, idx, val_auc,  optimizer.lr))

# With fewer than EVAL_EVERY batches per epoch (the logged run only ever shows
# BatchID=0), the check above runs after the first update of each epoch. The
# weights left by the rest of the final epoch would otherwise never be scored.
val_auc, test_true, test_pred = evaluate_auc(model, val_loader)
if best_val_auc < val_auc:
    best_val_auc = val_auc
    best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
print('Final Val_AUC=%.4f' % val_auc)

print ('Best Val_AUC is %.4f'%best_val_auc)

Epoch=0, BatchID=0, Val_AUC=0.5222, lr=0.0500
Epoch=1, BatchID=0, Val_AUC=0.6650, lr=0.0500
Epoch=2, BatchID=0, Val_AUC=0.7218, lr=0.0500
Epoch=3, BatchID=0, Val_AUC=0.7598, lr=0.0500
Reducing learning rate to 0.01000 @ T=404!
Updating regularizer @ T=404!
Epoch=4, BatchID=0, Val_AUC=0.7308, lr=0.0100
Epoch=5, BatchID=0, Val_AUC=0.7372, lr=0.0100
Epoch=6, BatchID=0, Val_AUC=0.7417, lr=0.0100
Epoch=7, BatchID=0, Val_AUC=0.7418, lr=0.0100
Reducing learning rate to 0.00200 @ T=404!
Updating regularizer @ T=404!
Epoch=8, BatchID=0, Val_AUC=0.7435, lr=0.0020
Epoch=9, BatchID=0, Val_AUC=0.7447, lr=0.0020
Epoch=10, BatchID=0, Val_AUC=0.7443, lr=0.0020
Epoch=11, BatchID=0, Val_AUC=0.7446, lr=0.0020
Reducing learning rate to 0.00040 @ T=404!
Updating regularizer @ T=404!
Epoch=12, BatchID=0, Val_AUC=0.7453, lr=0.0004
Epoch=13, BatchID=0, Val_AUC=0.7453, lr=0.0004
Epoch=14, BatchID=0, Val_AUC=0.7454, lr=0.0004
Epoch=15, BatchID=0, Val_AUC=0.7454, lr=0.0004
Reducing learning rate to 0.00008 @ T=4