In [1]:
import os
from random import randrange
from glob import glob

import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from datasets import CheXpertDataset, CheXpertImageDataset
from model import MultiLabelClassification
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from libauc.losses import AUCMLoss, CrossEntropyLoss
from libauc.optimizers import PESG, Adam
from libauc.models import DenseNet121
import matplotlib.pyplot as plt
from torchvision import transforms

In [2]:
# Get Data
BASE_DIR = os.getcwd()
DATA_DIR = os.path.join(BASE_DIR, '..', 'data')
extra_valid_age_sex_df = pd.read_csv(os.path.join(DATA_DIR,'extra_valid_age_sex.csv'))
extra_valid_hidden_features = np.load(os.path.join(DATA_DIR,'extra_valid_hidden_features.npy'))
extra_valid_labels = np.load(os.path.join(DATA_DIR,'extra_valid_labels.npy'))
extra_valid_images = glob(os.path.join(DATA_DIR, 'extraValid', '*'))

In [3]:
def set_all_seeds(SEED):
    # REPRODUCIBILITY
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
# Create CSV's of extra_valid_labels
label_headers = ['Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis',  'Pleural Effusion']
extra_valid_labels_df = pd.DataFrame(extra_valid_labels, columns=label_headers)

In [5]:
SEED = 123
set_all_seeds(SEED)

In [6]:
#Split into train and val sets
# Split into train and validation
train_images, val_images, train_labels, val_labels = train_test_split(extra_valid_images, extra_valid_labels_df,
                                                  test_size=0.2, random_state=SEED)

In [None]:
pretrain_set = CheXpertImageDataset(labels=train_labels, images=train_images, use_upsampling=False,
                                 image_size=320, mode='train', class_index=-1)
preval_set = CheXpertImageDataset(labels=val_labels, images=val_images, use_upsampling=False,
                                 image_size=320, mode='valid', class_index=-1)
pretrain_loader = DataLoader(pretrain_set,
                         batch_size=16,
                         shuffle=True,
                         num_workers=2)
preval_loader = DataLoader(preval_set,
                        batch_size=16,
                        shuffle=False,
                        num_workers=2)

In [None]:
# paramaters
#TODO: Find how many epochs to run
SEED = 123
BATCH_SIZE = 32
lr = 1e-4
weight_decay = 1e-5

# model
model = DenseNet121(pretrained=True, last_activation=None, activations='relu', num_classes=5)
model = model.cuda()

# define loss & optimizer
CELoss = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# training
best_val_auc = 0 
for epoch in range(15):
    for idx, data in enumerate(pretrain_loader):
        train_data, train_labels = data
        train_data, train_labels  = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        loss = CELoss(y_pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # validation  
        if idx % 400 == 0:
            model.eval()
            with torch.no_grad():    
                test_pred = []
                test_true = [] 
                for jdx, data in enumerate(preval_loader):
                    test_data, test_labels = data
                    test_data = test_data.cuda()
                    y_pred = model(test_data)
                    test_pred.append(y_pred.cpu().detach().numpy())
                    test_true.append(test_labels.numpy())

            test_true = np.concatenate(test_true)
            test_pred = np.concatenate(test_pred)
            val_auc_mean =  roc_auc_score(test_true, test_pred) 
            model.train()

            if best_val_auc < val_auc_mean:
                best_val_auc = val_auc_mean
                torch.save(model.state_dict(), 'ce_pretrained_model.pth')

            print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc ))

In [None]:
def aug(p=0.5):
    return transforms.Compose([transforms.RandomHorizontalFlip()], p=p)

In [7]:
class_id = 1 # 0:Cardiomegaly, 1:Edema, 2:Consolidation, 3:Atelectasis, 4:Pleural Effusion 
train_set = CheXpertImageDataset(labels=train_labels, images=train_images, use_upsampling=False,
                                 image_size=320, mode='train', class_index=-1)
val_set = CheXpertImageDataset(labels=val_labels, images=val_images, use_upsampling=False,
                                 image_size=320, mode='valid', class_index=-1)
train_loader = DataLoader(train_set,
                         batch_size=16,
                         shuffle=True,
                         num_workers=2)
val_loader = DataLoader(val_set,
                        batch_size=16,
                        shuffle=False,
                        num_workers=2)

Multi-label mode: True, Number of classes: [5]
------------------------------
Found 1608 images in total, 139 positive images, 1469 negative images
Cardiomegaly(C0): imbalance ratio is 0.0864

Found 1608 images in total, 1 positive images, 1607 negative images
Edema(C1): imbalance ratio is 0.0006

Found 1608 images in total, 13 positive images, 1595 negative images
Consolidation(C2): imbalance ratio is 0.0081

Found 1608 images in total, 106 positive images, 1502 negative images
Atelectasis(C3): imbalance ratio is 0.0659

Found 1608 images in total, 64 positive images, 1544 negative images
Pleural Effusion(C4): imbalance ratio is 0.0398

Multi-label mode: True, Number of classes: [5]
------------------------------
Found 403 images in total, 40 positive images, 363 negative images
Cardiomegaly(C0): imbalance ratio is 0.0993

Found 403 images in total, 1 positive images, 402 negative images
Edema(C1): imbalance ratio is 0.0025

Found 403 images in total, 4 positive images, 399 negative i

In [8]:
# paramaters
BATCH_SIZE = 16
imratio = train_set.imratio
lr = 0.05 # using smaller learning rate is better
gamma = 500
weight_decay = 1e-5
margin = 1.0

# model
set_all_seeds(SEED)
model = DenseNet121(pretrained=False, last_activation='sigmoid', activations='relu', num_classes=5)
model = model.cuda()

# load pretrained model
if True:
    PATH = 'ce_pretrained_model.pth' 
    state_dict = torch.load(PATH)
    state_dict.pop('classifier.weight', None)
    state_dict.pop('classifier.bias', None) 
    model.load_state_dict(state_dict, strict=False)


# define loss & optimizer
Loss = AUCMLoss(imratio=imratio)
optimizer = PESG(model, 
                 a=Loss.a, 
                 b=Loss.b, 
                 alpha=Loss.alpha, 
                 imratio=imratio, 
                 lr=lr, 
                 gamma=gamma, 
                 margin=margin, 
                 weight_decay=weight_decay)

best_val_auc = 0
for epoch in range(30):
    if epoch > 0 and epoch % 5 == 0:
        optimizer.update_regularizer(decay_factor=2)
    for idx, data in enumerate(train_loader):
        train_data, train_labels = data
        train_data, train_labels = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        loss = Loss(y_pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # validation
        if idx % 400 == 0:
            model.eval()
            with torch.no_grad():    
                test_pred = []
                test_true = [] 
                for jdx, data in enumerate(val_loader):
                    test_data, test_label = data
                    test_data = test_data.cuda()
                    y_pred = model(test_data)
                    test_pred.append(y_pred.cpu().detach().numpy())
                    test_true.append(test_label.numpy())

                test_true = np.concatenate(test_true)
                test_pred = np.concatenate(test_pred)
                val_auc =  roc_auc_score(test_true, test_pred) 
                model.train()

                if best_val_auc < val_auc:
                    best_val_auc = val_auc

            print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, lr=%.4f'%(epoch, idx, val_auc,  optimizer.lr))

print ('Best Val_AUC is %.4f'%best_val_auc)

Epoch=0, BatchID=0, Val_AUC=0.5639, lr=0.0500
Epoch=1, BatchID=0, Val_AUC=0.6253, lr=0.0500
Epoch=2, BatchID=0, Val_AUC=0.5871, lr=0.0500
Epoch=3, BatchID=0, Val_AUC=0.5788, lr=0.0500
Epoch=4, BatchID=0, Val_AUC=0.5619, lr=0.0500
Reducing learning rate to 0.02500 @ T=505!
Updating regularizer @ T=505!
Epoch=5, BatchID=0, Val_AUC=0.5607, lr=0.0250
Epoch=6, BatchID=0, Val_AUC=0.5275, lr=0.0250
Epoch=7, BatchID=0, Val_AUC=0.5382, lr=0.0250
Epoch=8, BatchID=0, Val_AUC=0.5243, lr=0.0250
Epoch=9, BatchID=0, Val_AUC=0.5357, lr=0.0250
Reducing learning rate to 0.01250 @ T=505!
Updating regularizer @ T=505!
Epoch=10, BatchID=0, Val_AUC=0.5490, lr=0.0125


KeyboardInterrupt: 