In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torchvision.models as models

import os
import numpy as np
from sklearn import metrics
from tqdm import trange, tqdm

import matplotlib.pyplot as plt

import utilities as UT
from ranksvm import get_dynamic_image

In [2]:
def prep_data(LABEL_PATH ,TEST_NUM):
    # This function is used to prepare train/test labels for 5-fold cross-validation
    TEST_LABEL = LABEL_PATH + '/fold_' + str(TEST_NUM) +'.csv'

    # combine train labels
    filenames = [LABEL_PATH + '/fold_0.csv', 
                LABEL_PATH + '/fold_1.csv', 
                LABEL_PATH + '/fold_2.csv', 
                LABEL_PATH + '/fold_3.csv', 
                LABEL_PATH + '/fold_4.csv', ]

    filenames.remove(TEST_LABEL)

    with open(LABEL_PATH + '/combined_train_list.csv', 'w') as combined_train_list:
        for fold in filenames:
            for line in open(fold, 'r'):                
                combined_train_list.write(line)
    TRAIN_LABEL = LABEL_PATH + '/combined_train_list.csv'
    
    return TRAIN_LABEL, TEST_LABEL

In [3]:
class Dataset_Early_Fusion(Dataset):
    def __init__(self, 
                 label_file='/data/scratch/xxing/adni_dl/Preprocessed/ADNI2_MRItrain_list.csv'):         
        self.files = UT.read_csv(label_file)
    def __len__(self):
        return len(self.files)
    def __getitem__(self,idx):
        temp = self.files[idx]        
        full_path = temp[0]        
        
        label = full_path.split('/')[-2]
        if(label=='CN'):
            label=0
        elif(label=='AD'):
            label=1
        else:
            print('Label Error')
        
        im = np.load(full_path) 
        im = get_dynamic_image(im)
        im = np.expand_dims(im,0)
        im = np.concatenate([im,im,im], 0)
        
        return im, int(label), full_path # output image shape [T,C,W,H]

In [4]:
class att(nn.Module):
    def __init__(self, input_channel):  
        "the soft attention module"
        super(att,self).__init__()
        self.channel_in = input_channel
    
        self.conv1 = nn.Sequential(
            nn.Conv2d(
            in_channels=input_channel,      
            out_channels=512,    
            kernel_size=1), 
            nn.ReLU()
            )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
            in_channels=512,      
            out_channels=256,    
            kernel_size=1), 
            nn.ReLU()
            )
        self.conv3 =nn.Sequential(
            nn.Conv2d(
            in_channels=256,      
            out_channels=64,    
            kernel_size=1), 
            nn.ReLU()
            )  
        self.conv4 =nn.Sequential(
            nn.Conv2d(
            in_channels=64,      
            out_channels=1,    
            kernel_size=1), 
            nn.Softmax(dim=2)
            )
    def forward(self, x):
        mask = x
        mask = self.conv1(mask)
        mask = self.conv2(mask)
        mask = self.conv3(mask)
        att = self.conv4(mask)
        #print(att.size())
        output = torch.mul(x, att)
        return output
    
class CNN(nn.Module):
    def __init__(self, 
                 num_classes=2, 
                 feature='Vgg11', 
                 feature_shape=(512,7,7),
                 pretrained=True, 
                 requires_grad=False):         
        
        super(CNN, self).__init__()

        # Feature Extraction
        if(feature=='Alex'):
            self.ft_ext = models.alexnet(pretrained=pretrained) 
            self.ft_ext_modules = list(list(self.ft_ext.children())[:-2][0][:9])            
            
        elif(feature=='Res34'):
            self.ft_ext = models.resnet34(pretrained=pretrained) 
            self.ft_ext_modules=list(self.ft_ext.children())[0:3]+list(self.ft_ext.children())[4:-2] # remove the Maxpooling layer
            
        elif(feature=='Res18'):
            self.ft_ext = models.resnet18(pretrained=pretrained) 
            self.ft_ext_modules=list(self.ft_ext.children())[0:3]+list(self.ft_ext.children())[4:-2] # remove the Maxpooling layer
            
        elif(feature=='Vgg16'):
            self.ft_ext = models.vgg16(pretrained=pretrained) 
            self.ft_ext_modules=list(self.ft_ext.children())[0][:30] # remove the Maxpooling layer
            
        elif(feature=='Vgg11'):
            self.ft_ext = models.vgg11(pretrained=pretrained) 
            self.ft_ext_modules=list(self.ft_ext.children())[0][:19] # remove the Maxpooling layer
            
        elif(feature=='Mobile'):
            self.ft_ext = models.mobilenet_v2(pretrained=pretrained) 
            self.ft_ext_modules=list(self.ft_ext.children())[0] # remove the Maxpooling layer
            
        self.ft_ext=nn.Sequential(*self.ft_ext_modules)                
        for p in self.ft_ext.parameters():
            p.requires_grad = requires_grad
            
        # Classifier
        if(feature=='Alex'):
            feature_shape=(256,5,5)
        elif(feature=='Res34'):
            feature_shape=(512,7,7)
        elif(feature=='Res18'):
            feature_shape=(512,7,7)
        elif(feature=='Vgg16'):
            feature_shape=(512,6,6)
        elif(feature=='Vgg11'):
            feature_shape=(512,6,6)
        elif(feature=='Mobile'):
            feature_shape=(1280,4,4)
            
        conv1_output_features = int(feature_shape[0])
        
        fc1_input_features = int(conv1_output_features*feature_shape[1]*feature_shape[2])
        fc1_output_features = int(conv1_output_features*2)
        fc2_output_features = int(fc1_output_features/4)
        
        self.attn=att(conv1_output_features)
                
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=feature_shape[0],      
                out_channels=conv1_output_features,    
                kernel_size=1,       
            ),
            nn.BatchNorm2d(conv1_output_features),
            nn.ReLU()
        )                    
        self.fc1 = nn.Sequential(
             nn.Linear(fc1_input_features, fc1_output_features),
             nn.BatchNorm1d(fc1_output_features),            
             nn.ReLU()
         )

        self.fc2 = nn.Sequential(
             nn.Linear(fc1_output_features, fc2_output_features),
             nn.BatchNorm1d(fc2_output_features),
             nn.ReLU()
         )
        
        self.out = nn.Linear(fc2_output_features, num_classes)
        
    def forward(self, x, drop_prob=0.5):
        x = self.ft_ext(x)
        #print(x.size())
        x= self.attn(x)
        #x = self.conv1(x)
        x = x.view(x.size(0), -1) 
        x = self.fc1(x)
        x = nn.Dropout(drop_prob)(x)
        x = self.fc2(x)
        x = nn.Dropout(drop_prob)(x)        
        prob = self.out(x) 
        
        return prob

In [5]:
def train(train_dataloader, val_dataloader, feature='Vgg11'):
    net = CNN(feature=feature).to(device)
    
    opt = torch.optim.Adam(net.parameters(), lr=LR, weight_decay=0.001)
#     opt = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma= 0.985)
#     scheduler = torch.optim.lr_scheduler.CyclicLR(opt, 
#                                                   base_lr=LR, 
#                                                   max_lr=0.001, 
#                                                   step_size_up=100,
#                                                   cycle_momentum=False)
    loss_fcn = torch.nn.CrossEntropyLoss(weight=LOSS_WEIGHTS.to(device))
        
    t = trange(EPOCHS, desc=' ', leave=True)

    train_hist = []
    val_hist = []
    pred_result = []
    old_acc = 0
    old_auc = 0
    test_acc = 0
    best_epoch = 0
    test_performance = []
    for e in t:    
        y_true = []
        y_pred = []
        
        val_y_true = []
        val_y_pred = []                
        
        train_loss = 0
        val_loss = 0

        # training
        net.train()
        for step, (img, label, _) in enumerate(train_dataloader):
            img = img.float().to(device)
            label = label.long().to(device)
            opt.zero_grad()
            out = net(img)
            loss = loss_fcn(out, label)

            loss.backward()
            opt.step()
            
            label = label.cpu().detach()
            out = out.cpu().detach()
            y_true, y_pred = UT.assemble_labels(step, y_true, y_pred, label, out)        

            train_loss += loss.item()

        train_loss = train_loss/(step+1)
        acc = float(torch.sum(torch.max(y_pred, 1)[1]==y_true))/ float(len(y_pred))
        auc = metrics.roc_auc_score(y_true, y_pred[:,1])
        f1 = metrics.f1_score(y_true, torch.max(y_pred, 1)[1])
        precision = metrics.precision_score(y_true, torch.max(y_pred, 1)[1])
        recall = metrics.recall_score(y_true, torch.max(y_pred, 1)[1])
        ap = metrics.average_precision_score(y_true, torch.max(y_pred, 1)[1]) #average_precision

        scheduler.step()

        # val
        net.eval()
        full_path = []
        with torch.no_grad():
            for step, (img, label, _) in enumerate(val_dataloader):
                img = img.float().to(device)
                label = label.long().to(device)
                out = net(img)
                loss = loss_fcn(out, label)
                val_loss += loss.item()

                label = label.cpu().detach()
                out = out.cpu().detach()
                val_y_true, val_y_pred = UT.assemble_labels(step, val_y_true, val_y_pred, label, out)
                
                for item in _:
                    full_path.append(item)
                
        val_loss = val_loss/(step+1)
        val_acc = float(torch.sum(torch.max(val_y_pred, 1)[1]==val_y_true))/ float(len(val_y_pred))
        val_auc = metrics.roc_auc_score(val_y_true, val_y_pred[:,1])
        val_f1 = metrics.f1_score(val_y_true, torch.max(val_y_pred, 1)[1])
        val_precision = metrics.precision_score(val_y_true, torch.max(val_y_pred, 1)[1])
        val_recall = metrics.recall_score(val_y_true, torch.max(val_y_pred, 1)[1])
        val_ap = metrics.average_precision_score(val_y_true, torch.max(val_y_pred, 1)[1]) #average_precision


        train_hist.append([train_loss, acc, auc, f1, precision, recall, ap])
        val_hist.append([val_loss, val_acc, val_auc, val_f1, val_precision, val_recall, val_ap])             

        t.set_description("Epoch: %i, train loss: %.4f, train acc: %.4f, val loss: %.4f, val acc: %.4f, test acc: %.4f" 
                          %(e, train_loss, acc, val_loss, val_acc, test_acc))


        if(old_acc<val_acc):
            old_acc = val_acc
            old_auc = val_auc
            best_epoch = e
            test_loss = 0
            test_y_true = val_y_true
            test_y_pred = val_y_pred            

            test_loss = val_loss
            test_acc = float(torch.sum(torch.max(test_y_pred, 1)[1]==test_y_true))/ float(len(test_y_pred))
            test_auc = metrics.roc_auc_score(test_y_true, test_y_pred[:,1])
            test_f1 = metrics.f1_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_precision = metrics.precision_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_recall = metrics.recall_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_ap = metrics.average_precision_score(test_y_true, torch.max(test_y_pred, 1)[1]) #average_precision

            test_performance = [best_epoch, test_loss, test_acc, test_auc, test_f1, test_precision, test_recall, test_ap]
        
        if(old_acc==val_acc) and (old_auc<val_auc):
            old_acc = val_acc
            old_auc = val_auc
            best_epoch = e
            test_loss = 0
            test_y_true = val_y_true
            test_y_pred = val_y_pred            

            test_loss = val_loss
            test_acc = float(torch.sum(torch.max(test_y_pred, 1)[1]==test_y_true))/ float(len(test_y_pred))
            test_auc = metrics.roc_auc_score(test_y_true, test_y_pred[:,1])
            test_f1 = metrics.f1_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_precision = metrics.precision_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_recall = metrics.recall_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_ap = metrics.average_precision_score(test_y_true, torch.max(test_y_pred, 1)[1]) #average_precision

            test_performance = [best_epoch, test_loss, test_acc, test_auc, test_f1, test_precision, test_recall, test_ap]
    return train_hist, val_hist, test_performance, test_y_true, test_y_pred, full_path

In [6]:
LABEL_PATH = '/data/scratch/xxing/adni_dl/Preprocessed/ADNI2_MRI'

GPU = 3
BATCH_SIZE = 16
EPOCHS = 150

LR = 0.0001
LOSS_WEIGHTS = torch.tensor([1., 1.]) 

device = torch.device('cuda:'+str(GPU) if torch.cuda.is_available() else 'cpu')

In [7]:
#DATA_PATH = '/data/scratch/gliang/data/adni/ADNI2_MRI_Feature/Alex_Layer-9_DynamicImage'
#FEATURE_SHAPE=(256,5,5)
#print('DATA_PATH:',DATA_PATH)

train_hist = []
val_hist = []
test_performance = []
test_y_true = np.asarray([])
test_y_pred = np.asarray([])
full_path = np.asarray([])
for i in range(0, 5):
    print('Train Fold', i)
    
    TEST_NUM = i
    TRAIN_LABEL, TEST_LABEL = prep_data(LABEL_PATH, TEST_NUM)
    
    train_dataset = Dataset_Early_Fusion(label_file=TRAIN_LABEL)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=1, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

    val_dataset = Dataset_Early_Fusion(label_file=TEST_LABEL)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        
    cur_result = train(train_dataloader, val_dataloader)
    
    train_hist.append(cur_result[0])
    val_hist.append(cur_result[1]) 
    test_performance.append(cur_result[2]) 
    test_y_true = np.concatenate((test_y_true, cur_result[3].numpy()))
    if(len(test_y_pred) == 0):
        test_y_pred = cur_result[4].numpy()
    else:
        test_y_pred = np.vstack((test_y_pred, cur_result[4].numpy()))
    full_path = np.concatenate((full_path, np.asarray(cur_result[5])))

print(test_performance)

test_y_true = torch.tensor(test_y_true)
test_y_pred = torch.tensor(test_y_pred)
test_acc = float(torch.sum(torch.max(test_y_pred, 1)[1]==test_y_true.long()))/ float(len(test_y_pred))
test_auc = metrics.roc_auc_score(test_y_true, test_y_pred[:,1])
test_f1 = metrics.f1_score(test_y_true, torch.max(test_y_pred, 1)[1])
test_precision = metrics.precision_score(test_y_true, torch.max(test_y_pred, 1)[1])
test_recall = metrics.recall_score(test_y_true, torch.max(test_y_pred, 1)[1])
test_ap = metrics.average_precision_score(test_y_true, torch.max(test_y_pred, 1)[1])

print('ACC %.4f, AUC %.4f, F1 %.4f, Prec %.4f, Recall %.4f, AP %.4f' 
      %(test_acc, test_auc, test_f1, test_precision, test_recall, test_ap))

Train Fold 0


Epoch: 149, train loss: 0.0119, train acc: 1.0000, val loss: 1.0106, val acc: 0.6087, test acc: 0.7391: 100%|██████████| 150/150 [06:53<00:00,  2.75s/it]


Train Fold 1


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
Epoch: 149, train loss: 0.1654, train acc: 0.9639, val loss: 0.2534, val acc: 0.7647, test acc: 1.0000: 100%|██████████| 150/150 [06:44<00:00,  2.68s/it]


Train Fold 2


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
Epoch: 149, train loss: 0.4540, train acc: 0.8659, val loss: 0.4339, val acc: 0.6111, test acc: 0.8333: 100%|██████████| 150/150 [06:46<00:00,  2.64s/it]


Train Fold 3


Epoch: 149, train loss: 0.0336, train acc: 1.0000, val loss: 0.5678, val acc: 0.7143, test acc: 0.8929: 100%|██████████| 150/150 [06:57<00:00,  2.69s/it]


Train Fold 4


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
Epoch: 149, train loss: 0.0319, train acc: 1.0000, val loss: 0.9059, val acc: 0.6429, test acc: 0.9286: 100%|██████████| 150/150 [06:54<00:00,  2.72s/it]

[[115, 0.7638657465577126, 0.7391304347826086, 0.6428571428571428, 0.7999999999999999, 0.75, 0.8571428571428571, 0.7298136645962733], [114, 0.07645551860332489, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [63, 0.3206184506416321, 0.8333333333333334, 0.9125000000000001, 0.7999999999999999, 0.8571428571428571, 0.75, 0.753968253968254], [76, 0.5696732699871063, 0.8928571428571429, 0.8333333333333333, 0.8421052631578948, 0.8888888888888888, 0.8, 0.7825396825396825], [29, 0.47631120681762695, 0.9285714285714286, 0.7777777777777778, 0.9473684210526316, 0.9, 1.0, 0.9]]
ACC 0.8700, AUC 0.8495, F1 0.8687, Prec 0.8600, Recall 0.8776, AP 0.8147



