In [None]:
# Data preprocessing and network architectures are modified from https://github.com/ahmedbesbes/mrnet

# This is a notebook to train a model on the MRNet dataset

In [None]:
import argparse
import os
import pickle
import shutil
import time
from datetime import datetime

# Restrict CUDA to GPU 0; kept ahead of the torch import.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import metrics
from torch.autograd import Variable
from torchvision import transforms
from tqdm import tqdm

from dataloader_mrnet import MRDataset
import model_mrnet

In [None]:
# --- experiment settings ---
# Diagnosis target; choices are 'abnormal', 'acl', 'meniscus'.
task = 'acl'
plane = 'sagittal'

# --- optimisation / logging settings ---
lr = 1e-4
epochs = 50
log_every = 100

# Instead of s=5: as the batch size is 1, this is set much larger than
# before so that the fluctuations are reasonable.
threshold = 20


In [None]:
def train_model(model, train_loader, epoch, num_epochs, optimizer, current_lr, log_every=100):
    """Run one training epoch over `train_loader` and return (mean loss, AUC).

    Each batch is an `(image, label, weight)` triple. The leading dimension of
    `label` and `weight` is stripped with `[0]` before the loss is computed,
    and the loss is `BCEWithLogitsLoss` weighted by that batch's `weight`.
    The AUC is computed from element `[0][1]` of the label and of the
    sigmoid of the prediction.

    Args:
        model: network to train; moved to the GPU when CUDA is available.
        train_loader: iterable of `(image, label, weight)` batches.
        epoch: zero-based epoch index (printed as `epoch + 1`).
        num_epochs: total number of epochs; used only in the log line.
        optimizer: optimizer that updates the model's parameters.
        current_lr: learning rate; used only in the log line.
        log_every: print running statistics every `log_every` batches.

    Returns:
        Tuple `(train_loss_epoch, train_auc_epoch)`, both rounded to 4
        decimals. The AUC is 0.5 whenever `roc_auc_score` raises ValueError.
    """
    model.train()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()

    y_preds = []
    y_trues = []
    losses = []

    def _current_auc():
        # AUC over everything seen so far; 0.5 when it cannot be computed
        # (e.g. the labels collected so far do not allow a ROC curve).
        try:
            return metrics.roc_auc_score(y_trues, y_preds)
        except ValueError:
            return 0.5

    for i, (image, label, weight) in enumerate(train_loader):
        optimizer.zero_grad()

        if use_cuda:
            image = image.cuda()
            label = label.cuda()
            weight = weight.cuda()

        label = label[0]
        weight = weight[0]

        prediction = model(image.float())

        loss = torch.nn.BCEWithLogitsLoss(weight=weight)(prediction, label)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        probas = torch.sigmoid(prediction)

        y_trues.append(int(label[0][1]))
        y_preds.append(probas[0][1].item())

        # The AUC is only needed for the log line and the final result, so it
        # is not recomputed over the whole history after every batch.
        if i > 0 and i % log_every == 0:
            print('''[Epoch: {0} / {1} |Single batch number : {2} / {3} ]| avg train loss {4} | train auc : {5} | lr : {6}'''.
                  format(
                      epoch + 1,
                      num_epochs,
                      i,
                      len(train_loader),
                      np.round(np.mean(losses), 4),
                      np.round(_current_auc(), 4),
                      current_lr
                  )
                  )

    train_loss_epoch = np.round(np.mean(losses), 4)
    train_auc_epoch = np.round(_current_auc(), 4)
    return train_loss_epoch, train_auc_epoch

In [None]:
def evaluate_model(model, val_loader, epoch, num_epochs, current_lr, log_every=20):
    """Evaluate `model` on `val_loader` and return (mean loss, AUC).

    Each batch is an `(image, label, weight)` triple. The leading dimension of
    `label` and `weight` is stripped with `[0]` before the loss is computed,
    and the loss is `BCEWithLogitsLoss` weighted by that batch's `weight`.
    The AUC is computed from element `[0][1]` of the label and of the
    sigmoid of the prediction. No parameters are updated.

    Args:
        model: network to evaluate; put in eval mode and moved to the GPU
            when CUDA is available.
        val_loader: iterable of `(image, label, weight)` batches.
        epoch: zero-based epoch index (printed as `epoch + 1`).
        num_epochs: total number of epochs; used only in the log line.
        current_lr: learning rate; used only in the log line.
        log_every: print running statistics every `log_every` batches.

    Returns:
        Tuple `(val_loss_epoch, val_auc_epoch)`, both rounded to 4 decimals.
        The AUC is 0.5 whenever `roc_auc_score` raises ValueError.
    """
    model.eval()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()

    y_trues = []
    y_preds = []
    losses = []

    def _current_auc():
        # AUC over everything seen so far; 0.5 when it cannot be computed
        # (e.g. the labels collected so far do not allow a ROC curve).
        try:
            return metrics.roc_auc_score(y_trues, y_preds)
        except ValueError:
            return 0.5

    # Nothing is back-propagated here, so skip building autograd graphs.
    with torch.no_grad():
        for i, (image, label, weight) in enumerate(val_loader):

            if use_cuda:
                image = image.cuda()
                label = label.cuda()
                weight = weight.cuda()

            label = label[0]
            weight = weight[0]

            prediction = model(image.float())

            loss = torch.nn.BCEWithLogitsLoss(weight=weight)(prediction, label)
            losses.append(loss.item())

            probas = torch.sigmoid(prediction)

            y_trues.append(int(label[0][1]))
            y_preds.append(probas[0][1].item())

            # The AUC is only needed for the log line and the final result.
            if i > 0 and i % log_every == 0:
                print('''[Epoch: {0} / {1} |Single batch number : {2} / {3} ] | avg val loss {4} | val auc : {5} | lr : {6}'''.
                      format(
                          epoch + 1,
                          num_epochs,
                          i,
                          len(val_loader),
                          np.round(np.mean(losses), 4),
                          np.round(_current_auc(), 4),
                          current_lr
                      )
                      )

    val_loss_epoch = np.round(np.mean(losses), 4)
    val_auc_epoch = np.round(_current_auc(), 4)
    return val_loss_epoch, val_auc_epoch

In [None]:
def get_lr(optimizer):
    """Return the 'lr' entry of the optimizer's first parameter group.

    Yields None when the optimizer has no parameter groups.
    """
    group_rates = (group['lr'] for group in optimizer.param_groups)
    return next(group_rates, None)

In [None]:
def find_GD(model, train_loader, num_grad_samples=None, num_std_samples=40):
    """Compute the average gradient disparity between pairs of samples.

    The computation has two stages:

    1. The loss of up to `num_std_samples` batches is evaluated and its
       standard deviation is taken (done manually because the batch size
       is 1).
    2. For up to `num_grad_samples` batches, the loss is divided by that
       standard deviation and back-propagated; the gradients of all
       parameters are flattened into one vector per batch.

    The result is the mean L2 distance between every pair of those vectors.

    NOTE(review): stage 1 runs in whatever mode (train/eval) the model is in
    when this function is called; only stage 2 forces `model.train()`.
    The model is left in training mode on return.

    Args:
        model: network whose parameter gradients are compared. Batches are
            moved to the device of its first parameter.
        train_loader: iterable of `(image, label, weight)` batches; a new
            iteration is started for each stage.
        num_grad_samples: number of batches used in stage 2. Defaults to the
            notebook-level `threshold`.
        num_std_samples: number of batches used in stage 1.

    Returns:
        avg_grad_dis: the average gradient disparity as a Python float.

    Raises:
        ValueError: if fewer than two gradient samples could be collected,
            so that no pair exists.
    """
    if num_grad_samples is None:
        # Fall back to the notebook-level setting so existing calls keep working.
        num_grad_samples = threshold

    # Follow the model's device instead of requiring a GPU.
    device = next(model.parameters()).device

    def _prepare(batch):
        # Move one batch to the model's device and strip the leading
        # dimension of label and weight.
        image, label, weight = batch
        return image.to(device).float(), label.to(device)[0], weight.to(device)[0]

    def _loss(image, label, weight):
        criterion = torch.nn.BCEWithLogitsLoss(weight=weight)
        return criterion(model(image), label)

    # Stage 1: because the batch size is 1, we manually find the loss std.
    # Only loss values are needed, so no autograd graph is built.
    sample_losses = []
    with torch.no_grad():
        # zip with range stops after num_std_samples batches, or earlier if
        # the loader is shorter, without raising StopIteration.
        for _, batch in zip(range(num_std_samples), train_loader):
            sample_losses.append(_loss(*_prepare(batch)).item())
    loss_std = float(np.std(np.array(sample_losses), axis=0))

    # Stage 2: set model in training mode (need this because of dropout).
    model.train()
    per_sample_grads = []
    for _, batch in zip(range(num_grad_samples), train_loader):
        model.zero_grad()
        scaled_loss = _loss(*_prepare(batch)) / loss_std
        scaled_loss.backward()
        flat_grad = torch.cat([
            param.grad.reshape(-1)
            for param in model.parameters()
            if param.grad is not None
        ])
        per_sample_grads.append(flat_grad.detach().cpu().numpy())

    # Do not leave the scaled probe gradients on the parameters.
    model.zero_grad()

    num_collected = len(per_sample_grads)
    if num_collected < 2:
        raise ValueError(
            "find_GD needs at least two gradient samples, got %d" % num_collected)

    # Average the L2 distance over every unordered pair (a < b).
    total_distance = 0.0
    for a in range(num_collected):
        for b in range(a + 1, num_collected):
            total_distance += float(
                np.linalg.norm(per_sample_grads[a] - per_sample_grads[b]))

    num_pairs = num_collected * (num_collected - 1) // 2
    avg_grad_dis = total_distance / num_pairs
    return avg_grad_dis

In [None]:
# Repeat the whole experiment `num_avgs` times, each run training a fresh
# model for `num_epochs` epochs.
# NOTE(review): this cell sets its own `num_epochs` and `lr`; the `epochs`
# value from the configuration cell is not used here.
num_avgs = 5
num_epochs = 100

# these lists store the results over multiple runs
tLosses_avg = []
tAUCs_avg = []
vLosses_avg = []
vAUCs_avg = []
GD_avg = []

for run_idx in range(num_avgs):
    # Fresh datasets, loaders, model and optimizer for every run.
    train_dataset = MRDataset('./data/', task,
                              plane, train=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True, drop_last=False)
    validation_dataset = MRDataset(
        './data/', task, plane, train=False)
    # The original passed `shuffle=-True` (a typo that evaluates truthy);
    # validation is not shuffled.
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=1, shuffle=False, drop_last=False)

    mrnet = model_mrnet.MRNet()

    if torch.cuda.is_available():
        mrnet = mrnet.cuda()
    lr = 1e-4
    optimizer = optim.SGD(mrnet.parameters(), lr=lr)

    # Per-epoch histories for this run.
    tLosses = []
    tAUCs = []
    vLosses = []
    vAUCs = []
    GD = []
    t_start_training = time.time()
    for epoch in range(num_epochs):
        current_lr = lr
        t_start = time.time()

        # Gradient disparity is measured before the epoch's parameter updates.
        gd = find_GD(mrnet, train_loader)
        train_loss, train_auc = train_model(
            mrnet, train_loader, epoch, num_epochs, optimizer, current_lr, log_every)
        val_loss, val_auc = evaluate_model(
            mrnet, validation_loader, epoch, num_epochs,  current_lr)

        tLosses.append(train_loss)
        tAUCs.append(train_auc)
        vLosses.append(val_loss)
        vAUCs.append(val_auc)
        GD.append(gd)

        t_end = time.time()
        delta = t_end - t_start

        print("train loss : {0} | train auc {1} | val loss {2} | val auc {3} | pac {4} | elapsed time {5} s".format(
            train_loss, train_auc, val_loss, val_auc, gd, delta))
        print('-' * 30)

    t_end_training = time.time()
    # Report the elapsed time itself rather than the two raw timestamps.
    print('training took %s s' % (t_end_training - t_start_training))
    tLosses_avg.append(tLosses)
    tAUCs_avg.append(tAUCs)
    vLosses_avg.append(vLosses)
    vAUCs_avg.append(vAUCs)
    GD_avg.append(GD)


In [None]:
# Save the results in a file (requires `pickle` from the import cell).
# Each list is written as its own pickle record, in the order below, so
# reading them back takes one `pickle.load` call per list on the same open
# file handle.
List = [vLosses_avg, tLosses_avg, vAUCs_avg, tAUCs_avg, GD_avg]

# rename the file to have the results saved on
with open('temp.data', 'wb') as filehandle:
    # store the data as binary data stream
    for history in List:
        pickle.dump(history, filehandle)
