In [20]:
import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torchvision.transforms.functional as TF
import torch.utils.data as data

from torchvision import transforms
from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine

# Make Dataset pipeline

In [22]:
class MRDataset(data.Dataset):
    """MRNet dataset for one diagnosis task and one imaging plane.

    Reads ``{root_dir}{split}-{task}.csv`` (headerless ``id,label`` rows) and
    loads one volume per row from ``{root_dir}{split}/{plane}/{id:04d}.npy``,
    where ``split`` is ``'train'`` or ``'valid'``.

    Args:
        root_dir: prefix prepended verbatim to every path, so it must be empty
            or end with a path separator.
        task: task name used in the CSV filename (e.g. ``'acl'``).
        plane: plane sub-directory name (e.g. ``'coronal'``).
        train: use the training split; when False the validation split is used
            and ``transform`` is ignored, so validation data is never augmented.
        transform: optional callable applied to the raw numpy volume.
        weights: optional 2-element weight vector returned with every item; by
            default ``[1, n_negative / n_positive]`` computed from the labels
            (``[1, 1]`` when the split has no positive cases).
    """

    def __init__(self, root_dir, task, plane, train=True, transform=None, weights=None):
        super().__init__()
        self.task = task
        self.plane = plane
        self.root_dir = root_dir
        self.train = train
        if self.train:
            self.fold_path = self.root_dir + 'train/{}/'.format(plane)
            self.records = pd.read_csv(self.root_dir + 'train-{}.csv'.format(task), header=None, names=['id', 'label'])
        else:
            # No augmentation on the validation split.
            transform = None
            self.fold_path = self.root_dir + 'valid/{}/'.format(plane)
            self.records = pd.read_csv(self.root_dir + 'valid-{}.csv'.format(task), header=None, names=['id', 'label'])

        # Volume files are named by the 4-digit zero-padded case id (e.g. 0007.npy).
        self.records['id'] = self.records['id'].map(lambda i: str(i).zfill(4))
        self.paths = [self.fold_path + filename + '.npy' for filename in self.records['id'].tolist()]
        self.labels = self.records['label'].tolist()

        self.transform = transform

        if weights is None:
            pos = np.sum(self.labels)
            neg = len(self.labels) - pos
            # Scale the positive column by the class ratio. A split without positives
            # would otherwise give an infinite (or NaN) weight and a non-finite loss.
            self.weights = torch.FloatTensor([1, neg / pos if pos > 0 else 1.0])
        else:
            self.weights = torch.FloatTensor(weights)

    def __len__(self):
        """Number of cases listed in the split's CSV."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(volume, one_hot_label, self.weights)`` for case ``index``.

        The label is a float tensor: ``[1, 0]`` for 0 and ``[0, 1]`` for 1.
        Without a transform, the raw volume (assumed ``(slices, H, W)`` -- TODO
        confirm for all planes) is replicated 3 times along a new axis 1 and
        returned as a FloatTensor.

        Raises:
            ValueError: if the CSV label is not 0 or 1.
        """
        array = np.load(self.paths[index])
        label = self.labels[index]

        if label == 1:
            label = torch.FloatTensor([0, 1])
        elif label == 0:
            label = torch.FloatTensor([1, 0])
        else:
            raise ValueError('expected a binary label (0 or 1), got {!r} for {}'.format(label, self.paths[index]))

        if self.transform:
            array = self.transform(array)
        else:
            array = np.stack((array, ) * 3, axis=1)
            array = torch.FloatTensor(array)

        return array, label, self.weights

# Build Model

In [28]:
class MRNet(nn.Module):
    """AlexNet-based study classifier that max-pools slice features.

    Every slice of one study goes through the pretrained AlexNet convolutional
    trunk. Each slice's feature map is average-pooled to one value per channel.
    The slice vectors are then reduced with an element-wise max over slices and
    mapped to 2 logits by a linear layer.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and their order (including the historical ``classifer``
        # spelling) define the state_dict keys of saved checkpoints -- keep them.
        self.pretrained_model = models.alexnet(pretrained=True)
        self.pooling_layer = nn.AdaptiveAvgPool2d(1)
        self.classifer = nn.Linear(256, 2)

    def forward(self, x):
        """Return ``(1, 2)`` logits for a single study.

        ``x`` is assumed to be ``(1, slices, 3, H, W)`` (DataLoader with
        batch_size=1); the leading batch dimension is squeezed away, so the
        slices act as the batch for the AlexNet trunk.
        """
        slices = x.squeeze(0)
        feature_maps = self.pretrained_model.features(slices)
        slice_vectors = self.pooling_layer(feature_maps)
        slice_vectors = slice_vectors.view(slice_vectors.size(0), -1)
        study_vector = slice_vectors.max(0, keepdim=True)[0]
        return self.classifer(study_vector)

# Train one model per MRI plane (coronal, axial, sagittal)

In [11]:
import shutil
import time
from datetime import datetime
import torch.optim as optim
from torch.autograd import Variable
from sklearn import metrics

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


In [34]:
def train_model(model, train_loader, epoch, num_epochs, optimizer, current_lr):
    """Run one training epoch and return its mean loss and ROC-AUC.

    Each loader item is ``(image, label, weight)``.
    ``label`` is assumed to be a one-hot ``(1, 2)`` tensor whose column 1 is the
    positive class. ``weight`` is assumed to be ``(1, 2)``; its first row is used
    as the element-wise ``weight`` of the BCE-with-logits loss.

    Args:
        model: module mapping ``image`` to ``(1, 2)`` logits.
        train_loader: iterable of ``(image, label, weight)`` supporting ``len()``.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        num_epochs: total number of epochs, used only in progress messages.
        optimizer: optimizer over ``model``'s parameters.
        current_lr: learning rate, used only in progress messages.

    Returns:
        ``(train_loss_epoch, train_auc_epoch)``, each rounded to 4 decimals.
        The AUC is 0.5 while it is undefined (fewer than two classes seen).
        The loss is NaN when the loader is empty.
    """
    model.train()

    if torch.cuda.is_available():
        model.cuda()

    y_preds = []
    y_trues = []
    losses = []

    def running_auc():
        # ROC-AUC is undefined until both classes have been seen.
        if len(set(y_trues)) < 2:
            return 0.5
        try:
            return metrics.roc_auc_score(y_trues, y_preds)
        except ValueError:
            return 0.5

    for i, (image, label, weight) in enumerate(train_loader):
        optimizer.zero_grad()

        if torch.cuda.is_available():
            image, label, weight = image.cuda(), label.cuda(), weight.cuda()

        # Batching adds a leading dim to the dataset's weight vector; use row 0.
        weight = weight[0]

        # Call the module (not .forward) so registered hooks run.
        prediction = model(image.float())

        loss = F.binary_cross_entropy_with_logits(prediction, label, weight=weight)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        probas = torch.sigmoid(prediction)

        y_trues.append(int(label[0][1]))
        y_preds.append(probas[0][1].item())

        # AUC over the whole history is O(n); compute it only when reported.
        if i % 100 == 0 and i > 0:
            print('Epoch: {} / {} | Single batch number : {} / {} | avg train loss : {} | train auc : {} | lr : {}'.format(
                epoch + 1, num_epochs, i, len(train_loader), np.round(np.mean(losses), 4), np.round(running_auc(), 4), current_lr
            ))

    train_loss_epoch = np.round(np.mean(losses), 4) if losses else float('nan')
    train_auc_epoch = np.round(running_auc(), 4)
    return train_loss_epoch, train_auc_epoch

In [36]:
def evaluate_model(model, valid_loader, epoch, num_epochs, current_lr):
    """Evaluate ``model`` on ``valid_loader`` and return mean loss and ROC-AUC.

    Each loader item is ``(image, label, weight)``.
    ``label`` is assumed to be a one-hot ``(1, 2)`` tensor whose column 1 is the
    positive class. ``weight`` is assumed to be ``(1, 2)``; its first row is used
    as the element-wise ``weight`` of the BCE-with-logits loss.
    Gradients are disabled for the whole pass.

    Args:
        model: module mapping ``image`` to ``(1, 2)`` logits.
        valid_loader: iterable of ``(image, label, weight)`` supporting ``len()``.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        num_epochs: total number of epochs, used only in progress messages.
        current_lr: learning rate, used only in progress messages.

    Returns:
        ``(val_loss_epoch, val_auc_epoch)``, each rounded to 4 decimals.
        The AUC is 0.5 while it is undefined (fewer than two classes seen).
        The loss is NaN when the loader is empty.
    """
    model.eval()

    if torch.cuda.is_available():
        model.cuda()

    y_preds = []
    y_trues = []
    losses = []

    def running_auc():
        # ROC-AUC is undefined until both classes have been seen.
        if len(set(y_trues)) < 2:
            return 0.5
        try:
            return metrics.roc_auc_score(y_trues, y_preds)
        except ValueError:
            return 0.5

    # Inference only: without no_grad every forward pass records an autograd graph.
    with torch.no_grad():
        for i, (image, label, weight) in enumerate(valid_loader):

            if torch.cuda.is_available():
                image, label, weight = image.cuda(), label.cuda(), weight.cuda()

            # Batching adds a leading dim to the dataset's weight vector; use row 0.
            weight = weight[0]

            prediction = model(image.float())

            loss = F.binary_cross_entropy_with_logits(prediction, label, weight=weight)
            losses.append(loss.item())

            probas = torch.sigmoid(prediction)

            y_trues.append(int(label[0][1]))
            y_preds.append(probas[0][1].item())

            # AUC over the whole history is O(n); compute it only when reported.
            if i % 20 == 0 and i > 0:
                print('Epoch: {} / {} | Single batch number : {} / {} | avg val loss : {} | val auc : {} | lr : {}'.format(
                    epoch + 1, num_epochs, i, len(valid_loader), np.round(np.mean(losses), 4), np.round(running_auc(), 4), current_lr
                ))

    val_loss_epoch = np.round(np.mean(losses), 4) if losses else float('nan')
    val_auc_epoch = np.round(running_auc(), 4)
    return val_loss_epoch, val_auc_epoch

In [26]:
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group, or None if it has none."""
    return next((group['lr'] for group in optimizer.param_groups), None)

## Train Coronal

In [23]:
 augmentor = Compose([
        transforms.Lambda(lambda x: torch.Tensor(x)),
        RandomRotate(25),
        RandomTranslate([0.11, 0.11]),
        RandomFlip(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
    ])

coronal_train_dataset = MRDataset(root_dir='', task='acl', plane='coronal', transform=augmentor, train=True)
coronal_train_loader = torch.utils.data.DataLoader(coronal_train_dataset, batch_size=1, shuffle=True)

coronal_val_dataset = MRDataset(root_dir='', task='acl', plane='coronal', transform=augmentor, train=False)
coronal_val_loader = torch.utils.data.DataLoader(coronal_val_dataset, batch_size=1, shuffle=False)

In [40]:
# Sanity check: print the tensor shapes of the first validation batch only.
# Break right after it instead of loading a second volume just to stop.
for image, label, weight in coronal_val_loader:
    print(image.size(), label.size(), weight.size())
    break

torch.Size([1, 25, 3, 256, 256]) torch.Size([1, 2]) torch.Size([1, 2])


In [None]:
# Fresh coronal model plus its optimizer, LR scheduler and best-so-far trackers.
coronal_mrnet = MRNet()
if torch.cuda.is_available():
    coronal_mrnet.cuda()

optimizer = torch.optim.Adam(
    coronal_mrnet.parameters(),
    lr=1e-5,
    weight_decay=0.1,
)

# Scale the LR by 0.3 when the value passed to scheduler.step() stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=3,
    factor=.3,
    threshold=1e-4,
    verbose=True,
)

best_val_loss = float('inf')
best_val_acc = 0.0  # holds the best validation ROC-AUC despite the name


In [43]:
# Coronal training loop. Each epoch: one training pass and one validation pass, then an
# LR-scheduler step on the validation loss. The model is saved whenever validation AUC
# improves, and training stops after EARLY_STOPPING_PATIENCE epochs without a val-loss decrease.
NUM_EPOCHS = 50  # same budget as the axial/sagittal runs (and the recorded output)
EARLY_STOPPING_PATIENCE = 5
iteration_change_loss = 0  # epochs since the validation loss last improved
t_start_training = time.time()

for epoch in range(NUM_EPOCHS):
    current_lr = get_lr(optimizer)

    t_start = time.time()

    # train_acc / val_acc are ROC-AUC values returned by train_model / evaluate_model.
    train_loss, train_acc = train_model(model=coronal_mrnet, train_loader=coronal_train_loader,
                                        epoch=epoch, num_epochs=NUM_EPOCHS, optimizer=optimizer, current_lr=current_lr)
    val_loss, val_acc = evaluate_model(model=coronal_mrnet, valid_loader=coronal_val_loader,
                                       epoch=epoch, num_epochs=NUM_EPOCHS, current_lr=current_lr)

    scheduler.step(val_loss)

    t_end = time.time()
    delta = t_end - t_start
    print("train loss : {0} | train auc {1} | val loss {2} | val auc {3} | elapsed time {4} s".format(
            train_loss, train_acc, val_loss, val_acc, delta))

    iteration_change_loss += 1
    print('-' * 30)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        file_name = 'coronal_best.pth'
        # Pickles the whole module; loading it requires the MRNet class to be importable.
        torch.save(coronal_mrnet, file_name)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        iteration_change_loss = 0

    if iteration_change_loss >= EARLY_STOPPING_PATIENCE:
        print('Early stopping after {0} iterations without the decrease of the val loss'.
              format(iteration_change_loss))
        break

t_end_training = time.time()
print('Training took {} s'.format(t_end_training - t_start_training))

Epoch: 1 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.1595 | train auc : 0.7798 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 200 / 1130 | avg train loss : 0.9981 | train auc : 0.8283 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.0506 | train auc : 0.8239 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.0607 | train auc : 0.8244 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.0355 | train auc : 0.8206 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 600 / 1130 | avg train loss : 0.9806 | train auc : 0.8304 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.005 | train auc : 0.8213 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.001 | train auc : 0.8266 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.0289 | train auc : 0.8123 | lr : 1e-05
Epoch: 1 / 50 | Singl

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 2 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.2368 | train auc : 0.7411 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.1178 | train auc : 0.8178 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.0236 | train auc : 0.8191 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 400 / 1130 | avg train loss : 0.9812 | train auc : 0.8323 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 500 / 1130 | avg train loss : 0.9787 | train auc : 0.8308 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.0102 | train auc : 0.8082 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 700 / 1130 | avg train loss : 0.9953 | train auc : 0.8113 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.0224 | train auc : 0.8106 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.0127 | train auc : 0.8176 | lr : 1e-05
Epoch: 2 / 50 | Sin

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 3 / 50 | Single batch number : 100 / 1130 | avg train loss : 0.9928 | train auc : 0.83 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 200 / 1130 | avg train loss : 0.9844 | train auc : 0.8273 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.0077 | train auc : 0.8008 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 400 / 1130 | avg train loss : 0.9787 | train auc : 0.8086 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 500 / 1130 | avg train loss : 0.9814 | train auc : 0.8101 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 600 / 1130 | avg train loss : 0.9858 | train auc : 0.8035 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 700 / 1130 | avg train loss : 0.966 | train auc : 0.8181 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 800 / 1130 | avg train loss : 0.9549 | train auc : 0.8263 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 900 / 1130 | avg train loss : 0.9639 | train auc : 0.8331 | lr : 1e-05
Epoch: 3 / 50 | Single

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 4 / 50 | Single batch number : 100 / 1130 | avg train loss : 0.6892 | train auc : 0.9403 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.0316 | train auc : 0.8289 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 300 / 1130 | avg train loss : 0.9988 | train auc : 0.8338 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.0021 | train auc : 0.8321 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 500 / 1130 | avg train loss : 0.9517 | train auc : 0.8393 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 600 / 1130 | avg train loss : 0.9891 | train auc : 0.8494 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 700 / 1130 | avg train loss : 0.9795 | train auc : 0.8554 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 800 / 1130 | avg train loss : 0.9802 | train auc : 0.8591 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 900 / 1130 | avg train loss : 0.9625 | train auc : 0.8608 | lr : 1e-05
Epoch: 4 / 50 | Sin

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 6 / 50 | Single batch number : 100 / 1130 | avg train loss : 0.6581 | train auc : 0.9157 | lr : 3e-06
Epoch: 6 / 50 | Single batch number : 200 / 1130 | avg train loss : 0.7785 | train auc : 0.8805 | lr : 3e-06
Epoch: 6 / 50 | Single batch number : 300 / 1130 | avg train loss : 0.8205 | train auc : 0.8716 | lr : 3e-06
Epoch: 6 / 50 | Single batch number : 400 / 1130 | avg train loss : 0.8609 | train auc : 0.8664 | lr : 3e-06
Epoch: 6 / 50 | Single batch number : 500 / 1130 | avg train loss : 0.8438 | train auc : 0.8829 | lr : 3e-06
Epoch: 6 / 50 | Single batch number : 600 / 1130 | avg train loss : 0.8715 | train auc : 0.8767 | lr : 3e-06
Epoch: 6 / 50 | Single batch number : 700 / 1130 | avg train loss : 0.8759 | train auc : 0.8771 | lr : 3e-06
Epoch: 6 / 50 | Single batch number : 800 / 1130 | avg train loss : 0.8778 | train auc : 0.8788 | lr : 3e-06
Epoch: 6 / 50 | Single batch number : 900 / 1130 | avg train loss : 0.9033 | train auc : 0.8749 | lr : 3e-06
Epoch: 6 / 50 | Sin

  "type " + obj.__name__ + ". It won't be checked "


Early stopping after 5 iterations without the decrease of the val loss
Training took 4142.865465164185 s


## Train Axial

In [46]:
# Axial plane: augmentation pipeline, data loaders, model, optimizer and scheduler.
# The augmentation ends by repeating the (slices, H, W) stack into 3 channels.
augmentor = Compose([
    transforms.Lambda(lambda volume: torch.Tensor(volume)),
    RandomRotate(25),
    RandomTranslate([0.11, 0.11]),
    RandomFlip(),
    transforms.Lambda(lambda volume: volume.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
])

axial_train_dataset = MRDataset(root_dir='', task='acl', plane='axial', train=True, transform=augmentor)
axial_train_loader = torch.utils.data.DataLoader(axial_train_dataset, batch_size=1, shuffle=True)

# MRDataset ignores `transform` when train=False, so validation stays un-augmented.
axial_val_dataset = MRDataset(root_dir='', task='acl', plane='axial', train=False, transform=augmentor)
axial_val_loader = torch.utils.data.DataLoader(axial_val_dataset, batch_size=1, shuffle=False)

axial_mrnet = MRNet()
if torch.cuda.is_available():
    axial_mrnet.cuda()

optimizer = torch.optim.Adam(
    axial_mrnet.parameters(),
    lr=1e-5,
    weight_decay=0.1,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=3,
    factor=.3,
    threshold=1e-4,
    verbose=True,
)

best_val_loss = float('inf')
best_val_acc = 0.0  # holds the best validation ROC-AUC despite the name


In [47]:
# Axial training loop. Each epoch: train, validate, then step the LR scheduler on the
# validation loss. Save the model when validation AUC improves, and stop once the
# validation loss has not decreased for EARLY_STOPPING_PATIENCE consecutive epochs.
NUM_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 5
iteration_change_loss = 0  # epochs since the validation loss last improved
t_start_training = time.time()

for epoch in range(NUM_EPOCHS):
    current_lr = get_lr(optimizer)
    t_start = time.time()

    train_loss, train_acc = train_model(
        model=axial_mrnet,
        train_loader=axial_train_loader,
        epoch=epoch,
        num_epochs=NUM_EPOCHS,
        optimizer=optimizer,
        current_lr=current_lr,
    )
    val_loss, val_acc = evaluate_model(
        model=axial_mrnet,
        valid_loader=axial_val_loader,
        epoch=epoch,
        num_epochs=NUM_EPOCHS,
        current_lr=current_lr,
    )
    scheduler.step(val_loss)

    t_end = time.time()
    delta = t_end - t_start
    summary = "train loss : {0} | train auc {1} | val loss {2} | val auc {3} | elapsed time {4} s"
    print(summary.format(train_loss, train_acc, val_loss, val_acc, delta))

    iteration_change_loss += 1
    print('-' * 30)

    # train_acc / val_acc are ROC-AUC values; checkpoint on the best validation AUC.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        file_name = 'axial_best.pth'
        torch.save(axial_mrnet, file_name)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        iteration_change_loss = 0

    if iteration_change_loss == EARLY_STOPPING_PATIENCE:
        message = 'Early stopping after {0} iterations without the decrease of the val loss'
        print(message.format(iteration_change_loss))
        break

t_end_training = time.time()
print('Training took {} s'.format(t_end_training - t_start_training))

Epoch: 1 / 50 | Single batch number : 100 / 1130 | avg train loss : 4.3027 | train auc : 0.3882 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 200 / 1130 | avg train loss : 3.6715 | train auc : 0.4519 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 300 / 1130 | avg train loss : 3.5081 | train auc : 0.5104 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 400 / 1130 | avg train loss : 3.3185 | train auc : 0.5299 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 500 / 1130 | avg train loss : 2.9846 | train auc : 0.5463 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 600 / 1130 | avg train loss : 2.8303 | train auc : 0.5429 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 700 / 1130 | avg train loss : 2.659 | train auc : 0.5539 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 800 / 1130 | avg train loss : 2.4993 | train auc : 0.571 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 900 / 1130 | avg train loss : 2.3824 | train auc : 0.5719 | lr : 1e-05
Epoch: 1 / 50 | Singl

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 2 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.4547 | train auc : 0.6767 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.6555 | train auc : 0.635 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.556 | train auc : 0.6339 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.4707 | train auc : 0.6664 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.3719 | train auc : 0.6788 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.3701 | train auc : 0.6829 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.3557 | train auc : 0.6848 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.3215 | train auc : 0.6926 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.3182 | train auc : 0.6943 | lr : 1e-05
Epoch: 2 / 50 | Singl

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 6 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.254 | train auc : 0.7795 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.1802 | train auc : 0.8055 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.1185 | train auc : 0.806 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.0866 | train auc : 0.7993 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.0565 | train auc : 0.8006 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.0421 | train auc : 0.8035 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.0556 | train auc : 0.793 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.0633 | train auc : 0.7845 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.09 | train auc : 0.7719 | lr : 1e-05
Epoch: 6 / 50 | Single b

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 7 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.0449 | train auc : 0.7697 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 200 / 1130 | avg train loss : 0.9898 | train auc : 0.7664 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.0778 | train auc : 0.7653 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.0546 | train auc : 0.7866 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.0529 | train auc : 0.7914 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.0709 | train auc : 0.7896 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.0761 | train auc : 0.7887 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.0962 | train auc : 0.7871 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.0702 | train auc : 0.7861 | lr : 1e-05
Epoch: 7 / 50 | Sin

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 8 / 50 | Single batch number : 100 / 1130 | avg train loss : 0.8775 | train auc : 0.8424 | lr : 3e-06
Epoch: 8 / 50 | Single batch number : 200 / 1130 | avg train loss : 0.9218 | train auc : 0.854 | lr : 3e-06
Epoch: 8 / 50 | Single batch number : 300 / 1130 | avg train loss : 0.9806 | train auc : 0.8342 | lr : 3e-06
Epoch: 8 / 50 | Single batch number : 400 / 1130 | avg train loss : 0.9935 | train auc : 0.8246 | lr : 3e-06
Epoch: 8 / 50 | Single batch number : 500 / 1130 | avg train loss : 0.9858 | train auc : 0.8234 | lr : 3e-06
Epoch: 8 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.025 | train auc : 0.8167 | lr : 3e-06
Epoch: 8 / 50 | Single batch number : 700 / 1130 | avg train loss : 0.9928 | train auc : 0.8305 | lr : 3e-06
Epoch: 8 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.04 | train auc : 0.8253 | lr : 3e-06
Epoch: 8 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.024 | train auc : 0.8256 | lr : 3e-06
Epoch: 8 / 50 | Single b

  "type " + obj.__name__ + ". It won't be checked "


Early stopping after 5 iterations without the decrease of the val loss
Training took 7006.245965003967 s


## Train Sagittal

In [48]:
# Sagittal plane: augmentation pipeline, data loaders, model, optimizer and scheduler.
# The augmentation ends by repeating the (slices, H, W) stack into 3 channels.
augmentor = Compose([
    transforms.Lambda(lambda volume: torch.Tensor(volume)),
    RandomRotate(25),
    RandomTranslate([0.11, 0.11]),
    RandomFlip(),
    transforms.Lambda(lambda volume: volume.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
])

sagittal_train_dataset = MRDataset(root_dir='', task='acl', plane='sagittal', train=True, transform=augmentor)
sagittal_train_loader = torch.utils.data.DataLoader(sagittal_train_dataset, batch_size=1, shuffle=True)

# MRDataset ignores `transform` when train=False, so validation stays un-augmented.
sagittal_val_dataset = MRDataset(root_dir='', task='acl', plane='sagittal', train=False, transform=augmentor)
sagittal_val_loader = torch.utils.data.DataLoader(sagittal_val_dataset, batch_size=1, shuffle=False)

sagittal_mrnet = MRNet()
if torch.cuda.is_available():
    sagittal_mrnet.cuda()

optimizer = torch.optim.Adam(
    sagittal_mrnet.parameters(),
    lr=1e-5,
    weight_decay=0.1,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=3,
    factor=.3,
    threshold=1e-4,
    verbose=True,
)

best_val_loss = float('inf')
best_val_acc = 0.0  # holds the best validation ROC-AUC despite the name


In [49]:
# Sagittal training loop. Each epoch: train, validate, then step the LR scheduler on the
# validation loss. Save the model when validation AUC improves, and stop once the
# validation loss has not decreased for EARLY_STOPPING_PATIENCE consecutive epochs.
NUM_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 5
iteration_change_loss = 0  # epochs since the validation loss last improved
t_start_training = time.time()

for epoch in range(NUM_EPOCHS):
    current_lr = get_lr(optimizer)
    t_start = time.time()

    train_loss, train_acc = train_model(
        model=sagittal_mrnet,
        train_loader=sagittal_train_loader,
        epoch=epoch,
        num_epochs=NUM_EPOCHS,
        optimizer=optimizer,
        current_lr=current_lr,
    )
    val_loss, val_acc = evaluate_model(
        model=sagittal_mrnet,
        valid_loader=sagittal_val_loader,
        epoch=epoch,
        num_epochs=NUM_EPOCHS,
        current_lr=current_lr,
    )
    scheduler.step(val_loss)

    t_end = time.time()
    delta = t_end - t_start
    summary = "train loss : {0} | train auc {1} | val loss {2} | val auc {3} | elapsed time {4} s"
    print(summary.format(train_loss, train_acc, val_loss, val_acc, delta))

    iteration_change_loss += 1
    print('-' * 30)

    # train_acc / val_acc are ROC-AUC values; checkpoint on the best validation AUC.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        file_name = 'sagittal_best.pth'
        torch.save(sagittal_mrnet, file_name)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        iteration_change_loss = 0

    if iteration_change_loss == EARLY_STOPPING_PATIENCE:
        message = 'Early stopping after {0} iterations without the decrease of the val loss'
        print(message.format(iteration_change_loss))
        break

t_end_training = time.time()
print('Training took {} s'.format(t_end_training - t_start_training))

Epoch: 1 / 50 | Single batch number : 100 / 1130 | avg train loss : 6.3007 | train auc : 0.4275 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 200 / 1130 | avg train loss : 4.4658 | train auc : 0.4677 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 300 / 1130 | avg train loss : 4.0244 | train auc : 0.4909 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 400 / 1130 | avg train loss : 3.5293 | train auc : 0.5045 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 500 / 1130 | avg train loss : 3.2177 | train auc : 0.5172 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 600 / 1130 | avg train loss : 2.9455 | train auc : 0.5403 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 700 / 1130 | avg train loss : 2.7955 | train auc : 0.558 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 800 / 1130 | avg train loss : 2.6942 | train auc : 0.5765 | lr : 1e-05
Epoch: 1 / 50 | Single batch number : 900 / 1130 | avg train loss : 2.555 | train auc : 0.5778 | lr : 1e-05
Epoch: 1 / 50 | Singl

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 2 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.1558 | train auc : 0.6723 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.2041 | train auc : 0.6816 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.388 | train auc : 0.674 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.48 | train auc : 0.6627 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.422 | train auc : 0.6623 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.416 | train auc : 0.649 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.4022 | train auc : 0.6452 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.4221 | train auc : 0.6335 | lr : 1e-05
Epoch: 2 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.3881 | train auc : 0.6323 | lr : 1e-05
Epoch: 2 / 50 | Single bat

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 3 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.2606 | train auc : 0.5801 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.2594 | train auc : 0.7222 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.2974 | train auc : 0.6871 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.2997 | train auc : 0.6941 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.3125 | train auc : 0.6901 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.3024 | train auc : 0.6834 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.3488 | train auc : 0.6692 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.3401 | train auc : 0.6691 | lr : 1e-05
Epoch: 3 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.3421 | train auc : 0.6585 | lr : 1e-05
Epoch: 3 / 50 | Sin

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 4 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.2513 | train auc : 0.7798 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.4916 | train auc : 0.7157 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.4026 | train auc : 0.7156 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.4043 | train auc : 0.7061 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.401 | train auc : 0.6905 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.3573 | train auc : 0.6943 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.3424 | train auc : 0.69 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.3098 | train auc : 0.6912 | lr : 1e-05
Epoch: 4 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.3 | train auc : 0.6904 | lr : 1e-05
Epoch: 4 / 50 | Single ba

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 5 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.0397 | train auc : 0.6364 | lr : 1e-05
Epoch: 5 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.1864 | train auc : 0.7404 | lr : 1e-05
Epoch: 5 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.1945 | train auc : 0.7563 | lr : 1e-05
Epoch: 5 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.2673 | train auc : 0.7484 | lr : 1e-05
Epoch: 5 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.2412 | train auc : 0.7499 | lr : 1e-05
Epoch: 5 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.2245 | train auc : 0.7393 | lr : 1e-05
Epoch: 5 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.1737 | train auc : 0.7289 | lr : 1e-05
Epoch: 5 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.1778 | train auc : 0.7202 | lr : 1e-05
Epoch: 5 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.1494 | train auc : 0.7248 | lr : 1e-05
Epoch: 5 / 50 | Sin

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 6 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.3529 | train auc : 0.697 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.2471 | train auc : 0.7368 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.1167 | train auc : 0.7686 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.1413 | train auc : 0.7506 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.1814 | train auc : 0.7374 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.1947 | train auc : 0.7392 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.1452 | train auc : 0.7477 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.1762 | train auc : 0.7456 | lr : 1e-05
Epoch: 6 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.2253 | train auc : 0.7295 | lr : 1e-05
Epoch: 6 / 50 | Sing

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 7 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.0925 | train auc : 0.8604 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.1609 | train auc : 0.8107 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.2283 | train auc : 0.7843 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.1718 | train auc : 0.779 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.2066 | train auc : 0.7781 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.2218 | train auc : 0.7631 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.1674 | train auc : 0.7737 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.1715 | train auc : 0.763 | lr : 1e-05
Epoch: 7 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.1287 | train auc : 0.7737 | lr : 1e-05
Epoch: 7 / 50 | Singl

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 8 / 50 | Single batch number : 100 / 1130 | avg train loss : 0.9526 | train auc : 0.7308 | lr : 1e-05
Epoch: 8 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.0167 | train auc : 0.7749 | lr : 1e-05
Epoch: 8 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.0199 | train auc : 0.7758 | lr : 1e-05
Epoch: 8 / 50 | Single batch number : 400 / 1130 | avg train loss : 1.0107 | train auc : 0.7869 | lr : 1e-05
Epoch: 8 / 50 | Single batch number : 500 / 1130 | avg train loss : 1.0459 | train auc : 0.7729 | lr : 1e-05
Epoch: 8 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.0314 | train auc : 0.7689 | lr : 1e-05
Epoch: 8 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.0904 | train auc : 0.7629 | lr : 1e-05
Epoch: 8 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.0691 | train auc : 0.7709 | lr : 1e-05
Epoch: 8 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.0853 | train auc : 0.7617 | lr : 1e-05
Epoch: 8 / 50 | Sin

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 9 / 50 | Single batch number : 100 / 1130 | avg train loss : 1.1174 | train auc : 0.829 | lr : 1e-05
Epoch: 9 / 50 | Single batch number : 200 / 1130 | avg train loss : 1.048 | train auc : 0.7782 | lr : 1e-05
Epoch: 9 / 50 | Single batch number : 300 / 1130 | avg train loss : 1.0328 | train auc : 0.7769 | lr : 1e-05
Epoch: 9 / 50 | Single batch number : 400 / 1130 | avg train loss : 0.9881 | train auc : 0.7902 | lr : 1e-05
Epoch: 9 / 50 | Single batch number : 500 / 1130 | avg train loss : 0.9882 | train auc : 0.7783 | lr : 1e-05
Epoch: 9 / 50 | Single batch number : 600 / 1130 | avg train loss : 1.0592 | train auc : 0.7622 | lr : 1e-05
Epoch: 9 / 50 | Single batch number : 700 / 1130 | avg train loss : 1.0891 | train auc : 0.7478 | lr : 1e-05
Epoch: 9 / 50 | Single batch number : 800 / 1130 | avg train loss : 1.0791 | train auc : 0.7491 | lr : 1e-05
Epoch: 9 / 50 | Single batch number : 900 / 1130 | avg train loss : 1.104 | train auc : 0.7451 | lr : 1e-05
Epoch: 9 / 50 | Single