In [1]:
# This code is used for JHU CS 482/682: Deep Learning 2019 Spring Project
# Copyright: Zhaoshuo Li, Ding Hao, Mingyi Zheng
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as functional
from torch.utils.data import DataLoader, Dataset
# from torchvision import transforms
import torchvision.transforms.functional as TF

import copy
import numpy as np
import random
from tensorboardX import SummaryWriter

import transforms
from dataset import *
from visualization import *
from label_conversion import *
from dice_loss import *
from model_trainning import *

from unet import *

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="2"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

cuda


# Seed pytorch and numpy and random

In [3]:
# IMPORTANT!
# must seed the same value each time when training a new network
seed = 256
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
pretrain_seed = 128

## Hyperparameters

In [4]:
train_batch_size = 10
validation_batch_size=10
learning_rate = 0.001
num_epochs = 70
num_class = 12

In [5]:
weights = torch.ones((num_class,1))
weights = weights.to(device)
dice_loss = DICELoss(weights) 

## Visualization

In [6]:
# Initialize the visualization environment
writer = SummaryWriter()

## Unet

In [7]:
# initialize model
model = unet(useBN=True)
model.to(device)

unet(
  (conv1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.1)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.1)
  )
  (conv3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=

## Optimizer and Scheduler and loss

In [8]:
# intialize optimizer and lr decay
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

## Baseline, without augmentation

In [9]:
# IMPORTANT!
# must seed the same value each time when training a new network
seed = 256
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [10]:
# define transform
train_dataset=MICCAIDataset(data_type = "train", transform=None)
validation_dataset=MICCAIDataset(data_type = "validation", transform=None)
label_converter = LabelConverter()

# # show one example
# img,label = train_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# # show one example
# img,label = validation_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

In [11]:
# intialize the dataloader
train_generator = DataLoader(train_dataset,shuffle=True,batch_size=train_batch_size,num_workers=8)
validation_generator = DataLoader(validation_dataset,shuffle=True,batch_size=validation_batch_size,num_workers=8)

## Start training

In [12]:
print("Training Started!")

# initialize best_acc for comparison
best_acc = 0.0
train_iter = 0
val_iter = 0

for epoch in range(num_epochs):
    print("\nEPOCH " +str(epoch+1)+" of "+str(num_epochs)+"\n")
    
    # train
    train_loss, train_iter = train(model,device,scheduler,optimizer,dice_loss,train_generator,train_dataset,writer,train_iter)

    # validate
    with torch.no_grad():
        validation_loss, tp, fp, fn, val_iter = validate(model,device,dice_loss,num_class,validation_generator,validation_dataset,writer,val_iter)
        epoch_acc = (2*tp + 1e-7)/ (2*tp+fp+fn+1e-7)
        epoch_acc = epoch_acc.mean()
    
        # loss
        writer.add_scalar('data/Training Loss (per epoch)',train_loss,epoch)
        writer.add_scalar('data/Validation Loss (per epoch)',validation_loss,epoch)
        
        # randomly show one validation image 
        sample = validation_dataset.__getitem__(random.randint(0,len(validation_dataset)-1))
        img = sample[0]*0.5+0.5
        label = sample[1]
        tmp_img = sample[0].reshape(1,3,256,320)
        pred = functional.softmax(model(tmp_img.cuda()), dim=1)
        pred_label = torch.max(pred,dim=1)[1]
        pred_label = pred_label.type(label.type())
        # to plot
        tp_img = np.array(img)
        tp_label = train_dataset.label_converter.label2color(label.permute(1,2,0)).transpose(2,0,1)
        tp_pred = train_dataset.label_converter.label2color(pred_label.permute(1,2,0)).transpose(2,0,1)
        
        writer.add_image('Input', tp_img, epoch)
        writer.add_image('Label', tp_label, epoch)
        writer.add_image('Prediction', tp_pred, epoch)
        
        # deep copy the model
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

Training Started!

EPOCH 1 of 70

Epoch Loss: 0.8637
----------
Vaildation Loss: 0.7967
0 Class, True Pos 7684378.0, False Pos 4168689.0, Flase Neg 5421315.0
1 Class, True Pos 2070202.0, False Pos 2615101.0, Flase Neg 468997.0
2 Class, True Pos 874423.0, False Pos 1120524.0, Flase Neg 461007.0
3 Class, True Pos 48003.0, False Pos 342140.0, Flase Neg 515767.0
4 Class, True Pos 77455.0, False Pos 206603.0, Flase Neg 4860064.0
5 Class, True Pos 2347620.0, False Pos 4954450.0, Flase Neg 1146179.0
6 Class, True Pos 0.0, False Pos 0.0, Flase Neg 94186.0
7 Class, True Pos 0.0, False Pos 0.0, Flase Neg 139522.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 0.0, False Pos 0.0, Flase Neg 114383.0
10 Class, True Pos 1101180.0, False Pos 1634672.0, Flase Neg 1644299.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 175961.0
----------

EPOCH 2 of 70

Epoch Loss: 0.7352
----------
Vaildation Loss: 0.7007
0 Class, True Pos 9023306.0, False Pos 4619559.0, Flase Neg 4082387.


EPOCH 11 of 70

Epoch Loss: 0.3922
----------
Vaildation Loss: 0.3885
0 Class, True Pos 12023302.0, False Pos 2711390.0, Flase Neg 1082391.0
1 Class, True Pos 2118503.0, False Pos 242109.0, Flase Neg 420696.0
2 Class, True Pos 874231.0, False Pos 220953.0, Flase Neg 461199.0
3 Class, True Pos 427825.0, False Pos 136948.0, Flase Neg 135945.0
4 Class, True Pos 3643480.0, False Pos 547556.0, Flase Neg 1294039.0
5 Class, True Pos 2771661.0, False Pos 593029.0, Flase Neg 722138.0
6 Class, True Pos 57557.0, False Pos 27372.0, Flase Neg 36629.0
7 Class, True Pos 113319.0, False Pos 20199.0, Flase Neg 26203.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 67223.0, False Pos 40113.0, Flase Neg 47160.0
10 Class, True Pos 2264242.0, False Pos 192227.0, Flase Neg 481237.0
11 Class, True Pos 73333.0, False Pos 78868.0, Flase Neg 102628.0
----------

EPOCH 12 of 70

Epoch Loss: 0.3759
----------
Vaildation Loss: 0.3768
0 Class, True Pos 11896876.0, False Pos 2326099.0, Flas


EPOCH 21 of 70

Epoch Loss: 0.3351
----------
Vaildation Loss: 0.3436
0 Class, True Pos 12015969.0, False Pos 2040843.0, Flase Neg 1089724.0
1 Class, True Pos 2274324.0, False Pos 357284.0, Flase Neg 264875.0
2 Class, True Pos 987785.0, False Pos 302733.0, Flase Neg 347645.0
3 Class, True Pos 460690.0, False Pos 160802.0, Flase Neg 103080.0
4 Class, True Pos 4038209.0, False Pos 505553.0, Flase Neg 899310.0
5 Class, True Pos 2860573.0, False Pos 356793.0, Flase Neg 633226.0
6 Class, True Pos 63027.0, False Pos 33497.0, Flase Neg 31159.0
7 Class, True Pos 116952.0, False Pos 18612.0, Flase Neg 22570.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 72586.0, False Pos 16964.0, Flase Neg 41797.0
10 Class, True Pos 2275044.0, False Pos 121475.0, Flase Neg 470435.0
11 Class, True Pos 103716.0, False Pos 62009.0, Flase Neg 72245.0
----------

EPOCH 22 of 70

Epoch Loss: 0.3343
----------
Vaildation Loss: 0.3398
0 Class, True Pos 12132931.0, False Pos 2101304.0, Flase


EPOCH 31 of 70

Epoch Loss: 0.3264
----------
Vaildation Loss: 0.3472
0 Class, True Pos 11861386.0, False Pos 1801571.0, Flase Neg 1244307.0
1 Class, True Pos 2264041.0, False Pos 301196.0, Flase Neg 275158.0
2 Class, True Pos 1004943.0, False Pos 314857.0, Flase Neg 330487.0
3 Class, True Pos 448997.0, False Pos 137315.0, Flase Neg 114773.0
4 Class, True Pos 4169389.0, False Pos 650505.0, Flase Neg 768130.0
5 Class, True Pos 2947533.0, False Pos 428843.0, Flase Neg 546266.0
6 Class, True Pos 60605.0, False Pos 26804.0, Flase Neg 33581.0
7 Class, True Pos 116271.0, False Pos 17325.0, Flase Neg 23251.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 71338.0, False Pos 14444.0, Flase Neg 43045.0
10 Class, True Pos 2303925.0, False Pos 143320.0, Flase Neg 441554.0
11 Class, True Pos 107075.0, False Pos 53757.0, Flase Neg 68886.0
----------

EPOCH 32 of 70

Epoch Loss: 0.3267
----------
Vaildation Loss: 0.3424
0 Class, True Pos 11919918.0, False Pos 1835226.0, Flas


EPOCH 41 of 70

Epoch Loss: 0.3282
----------
Vaildation Loss: 0.3533
0 Class, True Pos 12034553.0, False Pos 2032926.0, Flase Neg 1071140.0
1 Class, True Pos 2280905.0, False Pos 334539.0, Flase Neg 258294.0
2 Class, True Pos 975863.0, False Pos 265499.0, Flase Neg 359567.0
3 Class, True Pos 441639.0, False Pos 124935.0, Flase Neg 122131.0
4 Class, True Pos 3968811.0, False Pos 438174.0, Flase Neg 968708.0
5 Class, True Pos 2991041.0, False Pos 519830.0, Flase Neg 502758.0
6 Class, True Pos 59924.0, False Pos 25628.0, Flase Neg 34262.0
7 Class, True Pos 114277.0, False Pos 14915.0, Flase Neg 25245.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 70975.0, False Pos 13492.0, Flase Neg 43408.0
10 Class, True Pos 2262256.0, False Pos 117858.0, Flase Neg 483223.0
11 Class, True Pos 105845.0, False Pos 51555.0, Flase Neg 70116.0
----------

EPOCH 42 of 70

Epoch Loss: 0.3226
----------
Vaildation Loss: 0.3444
0 Class, True Pos 11852059.0, False Pos 1796600.0, Flase


EPOCH 51 of 70

Epoch Loss: 0.3352
----------
Vaildation Loss: 0.3484
0 Class, True Pos 12070818.0, False Pos 2050860.0, Flase Neg 1034875.0
1 Class, True Pos 2253026.0, False Pos 273061.0, Flase Neg 286173.0
2 Class, True Pos 984220.0, False Pos 273649.0, Flase Neg 351210.0
3 Class, True Pos 442921.0, False Pos 122775.0, Flase Neg 120849.0
4 Class, True Pos 4098483.0, False Pos 552590.0, Flase Neg 839036.0
5 Class, True Pos 2926987.0, False Pos 392805.0, Flase Neg 566812.0
6 Class, True Pos 61849.0, False Pos 28719.0, Flase Neg 32337.0
7 Class, True Pos 116177.0, False Pos 16859.0, Flase Neg 23345.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 71569.0, False Pos 14318.0, Flase Neg 42814.0
10 Class, True Pos 2243489.0, False Pos 104886.0, Flase Neg 501990.0
11 Class, True Pos 103976.0, False Pos 41403.0, Flase Neg 71985.0
----------

EPOCH 52 of 70

Epoch Loss: 0.3281
----------
Vaildation Loss: 0.3510
0 Class, True Pos 12056365.0, False Pos 1968532.0, Flase

Epoch Loss: 0.3291
----------
Vaildation Loss: 0.3458
0 Class, True Pos 12087098.0, False Pos 2039869.0, Flase Neg 1018595.0
1 Class, True Pos 2261029.0, False Pos 277250.0, Flase Neg 278170.0
2 Class, True Pos 972363.0, False Pos 254004.0, Flase Neg 363067.0
3 Class, True Pos 452840.0, False Pos 136273.0, Flase Neg 110930.0
4 Class, True Pos 4210133.0, False Pos 659254.0, Flase Neg 727386.0
5 Class, True Pos 2831376.0, False Pos 277486.0, Flase Neg 662423.0
6 Class, True Pos 61159.0, False Pos 26811.0, Flase Neg 33027.0
7 Class, True Pos 115801.0, False Pos 15955.0, Flase Neg 23721.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 73774.0, False Pos 16599.0, Flase Neg 40609.0
10 Class, True Pos 2246466.0, False Pos 105391.0, Flase Neg 499013.0
11 Class, True Pos 96628.0, False Pos 27881.0, Flase Neg 79333.0
----------

EPOCH 62 of 70

Epoch Loss: 0.3312
----------
Vaildation Loss: 0.3402
0 Class, True Pos 11998162.0, False Pos 1890063.0, Flase Neg 1107531.0
1 C

In [13]:
## load best model weights
model.load_state_dict(best_model_wts)
## save model
torch.save(model.state_dict(), 'vanilla_trained_unet_new_dice.pt')

## Test

In [9]:
# load test dataset
test_dataset=MICCAIDataset(data_type = "test", transform=None)
test_generator=DataLoader(test_dataset,shuffle=False,batch_size=4,num_workers=8)

In [10]:
# load model
model.load_state_dict(torch.load('vanilla_trained_unet_new_dice.pt'))
model.to(device)
print("Model loaded")

Model loaded


In [11]:
final_dice = test(model,device,dice_loss,num_class,test_generator,test_dataset,writer)

Dice Score: 0.7406
0 Class, True Pos 15401127.0, False Pos 2537299.0, Flase Neg 1300958.0
1 Class, True Pos 2773163.0, False Pos 322557.0, Flase Neg 383260.0
2 Class, True Pos 1277054.0, False Pos 315812.0, Flase Neg 563094.0
3 Class, True Pos 679472.0, False Pos 177035.0, Flase Neg 188786.0
4 Class, True Pos 5218920.0, False Pos 714988.0, Flase Neg 1293891.0
5 Class, True Pos 3484841.0, False Pos 527246.0, Flase Neg 725667.0
6 Class, True Pos 90449.0, False Pos 38994.0, Flase Neg 41281.0
7 Class, True Pos 129632.0, False Pos 22884.0, Flase Neg 33159.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 313.0
9 Class, True Pos 110730.0, False Pos 22739.0, Flase Neg 41244.0
10 Class, True Pos 2378386.0, False Pos 201263.0, Flase Neg 285521.0
11 Class, True Pos 142682.0, False Pos 50967.0, Flase Neg 74610.0
----------


# Data Augmentation

# Pretraining