In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torchvision.models as models

import os
import numpy as np
from sklearn import metrics
from tqdm import trange, tqdm

import matplotlib.pyplot as plt

import utilities as UT

In [2]:
def prep_data(LABEL_PATH ,TEST_NUM, NUM_FOLDS=5):
    """Build the train/test label files for one round of k-fold cross-validation.

    The fold files are expected at ``LABEL_PATH + '/fold_<k>.csv'`` for
    ``k`` in ``range(NUM_FOLDS)``.  Fold ``TEST_NUM`` is held out as the test
    list; every other fold is concatenated, in fold order, into
    ``LABEL_PATH + '/combined_train_list.csv'`` (overwritten on each call).

    Args:
        LABEL_PATH: Directory containing the ``fold_<k>.csv`` files.
        TEST_NUM: Index of the fold to hold out.
        NUM_FOLDS: Total number of fold files (default 5).

    Returns:
        ``(TRAIN_LABEL, TEST_LABEL)``: path of the combined training list and
        path of the held-out fold file.

    Raises:
        ValueError: If ``TEST_NUM`` does not name one of the fold files.
    """
    fold_files = [LABEL_PATH + '/fold_' + str(k) + '.csv' for k in range(NUM_FOLDS)]
    TEST_LABEL = LABEL_PATH + '/fold_' + str(TEST_NUM) + '.csv'
    if TEST_LABEL not in fold_files:
        raise ValueError('TEST_NUM=%r is not a valid fold index for %d folds'
                         % (TEST_NUM, NUM_FOLDS))

    TRAIN_LABEL = LABEL_PATH + '/combined_train_list.csv'
    with open(TRAIN_LABEL, 'w') as combined_train_list:
        for fold in fold_files:
            if fold == TEST_LABEL:
                continue
            # Context manager so each fold file is closed deterministically.
            with open(fold, 'r') as fold_file:
                for line in fold_file:
                    # A fold whose last row lacks a newline would otherwise be
                    # glued to the first row of the next fold.
                    if not line.endswith('\n'):
                        line += '\n'
                    combined_train_list.write(line)

    return TRAIN_LABEL, TEST_LABEL

In [3]:
class Dataset_Early_Fusion(Dataset):
    """Dataset of ``.npy`` volumes whose class is encoded in the parent directory name.

    ``label_file`` is read with ``UT.read_csv``; the first column of each row
    is used as the path of a ``.npy`` file.  The directory that directly
    contains the file must be named ``CN`` (label 0) or ``AD`` (label 1).
    """

    # Parent-directory name -> integer class label.
    _LABELS = {'CN': 0, 'AD': 1}

    def __init__(self,
                 label_file='/data/scratch/gliang/data/adni/ADNI2_MRI/train_list.csv'):
        self.files = UT.read_csv(label_file)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(im, label, full_path)`` where ``im`` is the loaded array
            reshaped to ``(1, 110, 110, 110)``, ``label`` is 0 or 1 and
            ``full_path`` is the path the array was loaded from.

        Raises:
            ValueError: If the parent directory name is neither ``CN`` nor ``AD``.
        """
        full_path = self.files[idx][0]

        class_name = full_path.split('/')[-2]
        if class_name not in self._LABELS:
            # Fail with the offending path instead of printing and then
            # crashing inside int() with an unrelated message.
            raise ValueError("Label Error: expected parent directory 'CN' or 'AD', "
                             "got %r for %r" % (class_name, full_path))
        label = self._LABELS[class_name]

        im = np.load(full_path)
        im = np.reshape(im, (1,110,110,110))  # add a leading singleton axis
        return im, label, full_path

In [None]:
class Bottleneck(nn.Module):
    """Residual block: (BatchNorm3d -> ReLU -> Conv3d) twice, plus an identity shortcut.

    The shortcut adds the unmodified input to the output of the second
    convolution, so the block only works when the output has the same shape
    as the input, i.e. ``inplanes == planes`` and ``stride == 1``.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride applied by both 3x3x3 convolutions.
    """

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm3d(inplanes)
        self.relu = nn.ReLU()
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1)
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        # Normalize the output of conv1; feeding `x` here would throw away
        # the whole first BN -> ReLU -> Conv stage.
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out += residual
        return out

class ResNet3D(nn.Module):
    """3D residual CNN classifier for single volumes.

    Layout: a three-convolution stem (the last one strided), then three
    groups of two ``Bottleneck`` blocks separated by strided
    BatchNorm -> Conv3d transitions, a 7x7x7 max-pool and two fully
    connected layers.

    Args:
        num_classes: Number of outputs of the final linear layer.
        input_shape: ``(channels, d1, d2, d3)`` of one input volume (no batch
            axis).  The channel count sizes the first convolution and the
            spatial sizes determine the input width of ``fc1``.

    Raises:
        ValueError: If a spatial size is too small to survive the three
            stride-2 convolutions followed by the 7x7x7 max-pool.
    """

    def __init__(self, num_classes=2, input_shape=(1,110,110,110)):
        super(ResNet3D, self).__init__()
        #stage 1
        self.conv1 = nn.Sequential(
            nn.Conv3d(
            in_channels=input_shape[0],
            out_channels=32,
            kernel_size=(3,3,3),
            padding=1
            ),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(
            in_channels=32,
            out_channels=32,
            kernel_size=(3,3,3),
            padding=1
            ),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(
            in_channels=32,
            out_channels=64,
            kernel_size=(3,3,3),
            stride=2,
            padding=1
            )
        )
        #stage 2
        self.bot2=Bottleneck(64,64,1)
        #stage 3
        self.bot3=Bottleneck(64,64,1)
        #stage 4
        self.conv4=nn.Sequential(
            nn.BatchNorm3d(64),
            nn.Conv3d(
            in_channels=64,        # input channels
            out_channels=64,       # n_filters
            kernel_size=(3,3,3),   # filter size
            padding=1,
            stride=2
            )
        )
        #stage 5
        self.bot5=Bottleneck(64,64,1)
        #stage 6
        self.bot6=Bottleneck(64,64,1)
        #stage 7
        self.conv7=nn.Sequential(
            nn.BatchNorm3d(64),
            nn.Conv3d(
            in_channels=64,        # input channels
            out_channels=128,      # n_filters
            kernel_size=(3,3,3),   # filter size
            padding=1,
            stride=2
            )
        )
        #stage 8
        self.bot8=Bottleneck(128,128,1)

        #stage 9
        self.bot9=Bottleneck(128,128,1)

        #stage 10
        self.conv10=nn.Sequential(
            nn.MaxPool3d(kernel_size=(7,7,7)))

        # 128 channels times the pooled spatial volume; this is 1024 for the
        # default 110x110x110 input.
        flat_features = 128 * self._pooled_volume(input_shape[1:])
        fc1_output_features=128
        self.fc1 = nn.Sequential(
             nn.Linear(flat_features, fc1_output_features),
             nn.ReLU()
        )

        # NOTE(review): train() in this notebook passes these outputs to
        # CrossEntropyLoss, which applies log-softmax itself, so the Sigmoid
        # confines the values it sees to (0, 1).  Kept to preserve the
        # existing behavior; consider returning raw logits instead.
        self.fc2 = nn.Sequential(
        nn.Linear(fc1_output_features, num_classes),
        nn.Sigmoid()
        )

    @staticmethod
    def _pooled_volume(spatial_shape):
        """Number of spatial positions left per channel after stage 10."""
        volume = 1
        for size in spatial_shape:
            # Three convolutions with kernel 3, padding 1, stride 2
            # (stages 1, 4 and 7); all other convolutions keep the size.
            for _ in range(3):
                size = (size - 1) // 2 + 1
            # MaxPool3d with kernel 7 (stride defaults to the kernel size).
            size = size // 7
            if size < 1:
                raise ValueError('input_shape spatial sizes %r are too small for ResNet3D'
                                 % (tuple(spatial_shape),))
            volume *= size
        return volume

    def forward(self, x, drop_prob=0.8):
        """Run the network on a batch ``x``.

        ``drop_prob`` is accepted for interface compatibility but is not
        used; the network contains no dropout layer.
        """
        x = self.conv1(x)
        x = self.bot2(x)
        x = self.bot3(x)
        x = self.conv4(x)
        x = self.bot5(x)
        x = self.bot6(x)
        x = self.conv7(x)
        x = self.bot8(x)
        x = self.bot9(x)
        x = self.conv10(x)
        x = x.view(x.size(0), -1) # flatten to (batch_size, features)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

In [None]:
# --- Experiment configuration ---------------------------------------------
# Directory with the fold_<k>.csv label lists consumed by prep_data().
LABEL_PATH = '/data/scratch/xxing/adni_dl/Preprocessed/ADNI2_MRI'

# Index of the CUDA device to use when CUDA is available.
GPU = 5
BATCH_SIZE = 8
EPOCHS = 150

# Optimizer learning rate and per-class weights for the loss.
LR = 1e-4
LOSS_WEIGHTS = torch.tensor([1., 1.])

# Fall back to the CPU when no CUDA runtime is present.
if torch.cuda.is_available():
    device = torch.device('cuda:' + str(GPU))
else:
    device = torch.device('cpu')

In [None]:
def _classification_metrics(loss, y_true, y_pred):
    """Return ``[loss, acc, auc, f1, precision, recall, ap]`` for one pass over a split.

    ``y_true`` and ``y_pred`` are the accumulators built by
    ``UT.assemble_labels``; they are used here as labels and per-class
    scores with classes along axis 1 (presumably a 1-D and a 2-D tensor --
    confirm against ``UT.assemble_labels``).
    """
    pred_class = torch.max(y_pred, 1)[1]
    acc = float(torch.sum(pred_class == y_true)) / float(len(y_pred))
    auc = metrics.roc_auc_score(y_true, y_pred[:, 1])
    f1 = metrics.f1_score(y_true, pred_class)
    precision = metrics.precision_score(y_true, pred_class)
    recall = metrics.recall_score(y_true, pred_class)
    # NOTE(review): average precision is computed from hard class
    # predictions, not scores; kept so reported numbers stay comparable.
    ap = metrics.average_precision_score(y_true, pred_class)
    return [loss, acc, auc, f1, precision, recall, ap]


def train(train_dataloader, val_dataloader):
    """Train a fresh ``ResNet3D`` for ``EPOCHS`` epochs and track the best epoch.

    After every epoch the model is evaluated on ``val_dataloader``.  The
    best epoch is the one with the highest validation accuracy, with ties
    broken by the higher validation AUC.

    Args:
        train_dataloader: Yields ``(img, label, path)`` training batches.
        val_dataloader: Yields ``(img, label, path)`` validation batches.

    Returns:
        ``(train_hist, val_hist, test_performance, test_y_true, test_y_pred, full_path)``

        * ``train_hist`` / ``val_hist``: one
          ``[loss, acc, auc, f1, precision, recall, ap]`` list per epoch.
        * ``test_performance``: ``[best_epoch, loss, acc, auc, f1, precision,
          recall, ap]`` on the validation split at the best epoch.
        * ``test_y_true`` / ``test_y_pred``: validation labels and model
          outputs at the best epoch.
        * ``full_path``: validation sample paths, in the order of
          ``test_y_pred``, as seen at the best epoch.
    """
    net = ResNet3D().to(device)

    # Alternatives tried previously: Adam(weight_decay=0.001), and the
    # ExponentialLR / CyclicLR schedulers.
    opt = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9,nesterov=True)
    loss_fcn = torch.nn.CrossEntropyLoss(weight=LOSS_WEIGHTS.to(device))

    t = trange(EPOCHS, desc=' ', leave=True)

    train_hist = []
    val_hist = []

    # Sentinels below any reachable accuracy / AUC, so the first epoch is
    # always recorded and every returned value is defined.
    old_acc = -1.0
    old_auc = -1.0
    test_acc = 0
    test_performance = []
    test_y_true = None
    test_y_pred = None
    full_path = []

    for e in t:
        y_true = []
        y_pred = []

        val_y_true = []
        val_y_pred = []

        train_loss = 0
        val_loss = 0

        # training
        net.train()
        for step, (img, label, _) in enumerate(train_dataloader):
            img = img.float().to(device)
            label = label.long().to(device)
            opt.zero_grad()
            out = net(img)
            loss = loss_fcn(out, label)

            loss.backward()
            opt.step()

            label = label.cpu().detach()
            out = out.cpu().detach()
            y_true, y_pred = UT.assemble_labels(step, y_true, y_pred, label, out)

            train_loss += loss.item()

        train_loss = train_loss/(step+1)
        train_metrics = _classification_metrics(train_loss, y_true, y_pred)

        # val
        net.eval()
        epoch_paths = []
        with torch.no_grad():
            for step, (img, label, paths) in enumerate(val_dataloader):
                img = img.float().to(device)
                label = label.long().to(device)
                out = net(img)
                loss = loss_fcn(out, label)
                val_loss += loss.item()

                label = label.cpu().detach()
                out = out.cpu().detach()
                val_y_true, val_y_pred = UT.assemble_labels(step, val_y_true, val_y_pred, label, out)

                epoch_paths.extend(paths)

        val_loss = val_loss/(step+1)
        val_metrics = _classification_metrics(val_loss, val_y_true, val_y_pred)
        val_acc, val_auc = val_metrics[1], val_metrics[2]

        train_hist.append(train_metrics)
        val_hist.append(val_metrics)

        # `test_acc` here is the best validation accuracy seen in earlier epochs.
        t.set_description("Epoch: %i, train loss: %.4f, train acc: %.4f, val loss: %.4f, val acc: %.4f, test acc: %.4f"
                          %(e, train_loss, train_metrics[1], val_loss, val_acc, test_acc))

        # Better accuracy wins; on equal accuracy the higher AUC wins.
        improved = (val_acc > old_acc) or (val_acc == old_acc and val_auc > old_auc)
        if improved:
            old_acc = val_acc
            old_auc = val_auc
            test_acc = val_acc
            test_y_true = val_y_true
            test_y_pred = val_y_pred
            full_path = epoch_paths
            test_performance = [e] + val_metrics

    return train_hist, val_hist, test_performance, test_y_true, test_y_pred, full_path

In [None]:
#DATA_PATH = '/data/scratch/gliang/data/adni/ADNI2_MRI_Feature/Alex_Layer-9_DynamicImage'
#FEATURE_SHAPE=(256,5,5)
#print('DATA_PATH:',DATA_PATH)

# Accumulators pooled over all cross-validation folds.
train_hist = []
val_hist = []
test_performance = []
test_y_true = np.asarray([])
test_y_pred = np.asarray([])
full_path = np.asarray([])

for i in range(5):
    print('Train Fold', i)

    # Fold i is held out; the remaining folds form the training list.
    TEST_NUM = i
    TRAIN_LABEL, TEST_LABEL = prep_data(LABEL_PATH, TEST_NUM)

    train_dataset = Dataset_Early_Fusion(label_file=TRAIN_LABEL)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=1,
        drop_last=False,
    )

    val_dataset = Dataset_Early_Fusion(label_file=TEST_LABEL)
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=1,
        drop_last=False,
    )

    cur_result = train(train_dataloader, val_dataloader)
    (fold_train_hist, fold_val_hist, fold_performance,
     fold_y_true, fold_y_pred, fold_paths) = cur_result

    train_hist.append(fold_train_hist)
    val_hist.append(fold_val_hist)
    test_performance.append(fold_performance)

    # Append this fold's held-out labels, outputs and paths to the pooled arrays.
    test_y_true = np.concatenate((test_y_true, fold_y_true.numpy()))
    fold_y_pred_np = fold_y_pred.numpy()
    if len(test_y_pred) == 0:
        test_y_pred = fold_y_pred_np
    else:
        test_y_pred = np.vstack((test_y_pred, fold_y_pred_np))
    full_path = np.concatenate((full_path, np.asarray(fold_paths)))
    print('finish')

print(test_performance)

# Metrics over the predictions pooled from every fold.
test_y_true = torch.tensor(test_y_true)
test_y_pred = torch.tensor(test_y_pred)
pooled_pred_class = torch.max(test_y_pred, 1)[1]

test_acc = float(torch.sum(pooled_pred_class == test_y_true.long())) / float(len(test_y_pred))
test_auc = metrics.roc_auc_score(test_y_true, test_y_pred[:,1])
test_f1 = metrics.f1_score(test_y_true, pooled_pred_class)
test_precision = metrics.precision_score(test_y_true, pooled_pred_class)
test_recall = metrics.recall_score(test_y_true, pooled_pred_class)
test_ap = metrics.average_precision_score(test_y_true, pooled_pred_class)

print('ACC %.4f, AUC %.4f, F1 %.4f, Prec %.4f, Recall %.4f, AP %.4f'
      %(test_acc, test_auc, test_f1, test_precision, test_recall, test_ap))