In [2]:
# This code is used for JHU CS 482/682: Deep Learning 2019 Spring Project
# Copyright: Zhaoshuo Li, Ding Hao, Mingyi Zheng
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as functional
from torch.utils.data import DataLoader, Dataset
# from torchvision import transforms
import torchvision.transforms.functional as TF

import numpy as np
import random
from tensorboardX import SummaryWriter

import transforms
from dataset import *
from visualization import *
from label_conversion import *
from dice_loss import *
from model_trainning import *
from model_pretrainning import *

from unet import *

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"]="2"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

cuda


# Seed pytorch and numpy and random

In [4]:
# IMPORTANT!
# must seed the same value each time when training a new network
seed = 256
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
pretrain_seed = 128

## Hyperparameters

In [5]:
train_batch_size = 10
validation_batch_size=10
learning_rate = 0.001
num_epochs = 70
num_class = 12

In [6]:
weights = torch.ones((num_class,1))
weights = weights.to(device)
dice_loss = DICELoss(weights) 

## Visualization

In [7]:
# Initialize the visualization environment
writer = SummaryWriter()

## Unet

In [None]:
# initialize model
model = unet(useBN=True)
model.to(device)

## Optimizer and Scheduler and loss

In [None]:
# intialize optimizer and lr decay
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

## Baseline, without augmentation

In [20]:
# IMPORTANT!
# must seed the same value each time when training a new network
seed = 256
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [21]:
# define dataset
train_dataset=MICCAIDataset(data_type = "train", transform=None)
validation_dataset=MICCAIDataset(data_type = "validation", transform=None)
label_converter = LabelConverter()

# # show one example
# img,label = train_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# # show one example
# img,label = validation_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# intialize the dataloader
train_generator = DataLoader(train_dataset,shuffle=True,batch_size=train_batch_size,num_workers=8)
validation_generator = DataLoader(validation_dataset,shuffle=True,batch_size=validation_batch_size,num_workers=8)

## Start training

In [None]:
best_model_wts = run_training(model,device,num_class,scheduler,optimizer,dice_loss,num_epochs,train_generator,train_dataset,validation_generator,validation_dataset,writer)

In [None]:
## load best model weights
model.load_state_dict(best_model_wts)
## save model
torch.save(model.state_dict(), 'vanilla_trained_unet_new_dice.pt')

## Test

In [None]:
# load test dataset
test_dataset=MICCAIDataset(data_type = "test", transform=None)
test_generator=DataLoader(test_dataset,shuffle=False,batch_size=4,num_workers=8)

In [None]:
# load model
model.load_state_dict(torch.load('vanilla_trained_unet_new_dice.pt'))
model.to(device)
print("Model loaded")

In [None]:
final_dice = test(model,device,dice_loss,num_class,test_generator,test_dataset,writer)

# Data Augmentation

In [None]:
# IMPORTANT!
# must seed the same value each time when training a new network
seed = 256
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
pretrain_seed = 128

In [None]:
# define dataset
train_dataset=MICCAIDataset(data_type = "train", transform=transforms)
validation_dataset=MICCAIDataset(data_type = "validation", transform=None)
label_converter = LabelConverter()

# # show one example
# img,label = train_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# # show one example
# img,label = validation_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# intialize the dataloader
train_generator = DataLoader(train_dataset,shuffle=True,batch_size=train_batch_size,num_workers=8)
validation_generator = DataLoader(validation_dataset,shuffle=True,batch_size=validation_batch_size,num_workers=8)

In [None]:
best_model_wts = run_training(model,device,num_class,scheduler,optimizer,dice_loss,num_epochs,train_generator,train_dataset,validation_generator,validation_dataset,writer)

In [None]:
## load best model weights
model.load_state_dict(best_model_wts)
## save model
torch.save(model.state_dict(), 'aug_trained_unet.pt')

In [None]:
# load model
model.load_state_dict(torch.load('aug_trained_unet.pt'))
model.to(device)
print("Model loaded")

In [None]:
# load test dataset
test_dataset=MICCAIDataset(data_type = "test", transform=None)
test_generator=DataLoader(test_dataset,shuffle=False,batch_size=4,num_workers=8)

In [None]:
final_dice = test(model,device,dice_loss,num_class,test_generator,test_dataset,writer)

# Transformation Pretraining 

In [8]:
# IMPORTANT!
# must seed the same value each time when training a new network
pretrain_seed = 128
random.seed(pretrain_seed)
torch.manual_seed(pretrain_seed)
np.random.seed(pretrain_seed)

In [9]:
train_batch_size = 10
validation_batch_size=10
learning_rate = 0.001
num_epochs = 30
num_class = 12

In [10]:
pretrain_model = unet_pretrain(useBN=True)
pretrain_model.to(device)
print("pretrain model generated")

pretrain model generated


In [11]:
# intialize optimizer and lr decay and loss
optimizer = torch.optim.Adam(pretrain_model.parameters(),lr=learning_rate,weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = torch.nn.MSELoss(reduction='mean')

In [12]:
# define dataset
pretrain_dataset=Transformation_PretrainDataset(data_type = "train", transform=None)
prevalidation_dataset=Transformation_PretrainDataset(data_type = "validation", transform=None)
label_converter = LabelConverter()

# # show one example
# img,label = pretrain_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# imshow(label.permute(1,2,0),denormalize=True)

# # show one example
# img,label = validation_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# intialize the dataloader
pretrain_generator = DataLoader(pretrain_dataset,shuffle=True,batch_size=train_batch_size,num_workers=8)
prevalidation_generator = DataLoader(prevalidation_dataset,shuffle=True,batch_size=validation_batch_size,num_workers=8)

In [13]:
best_model_wts = run_pretraining(pretrain_model,device,scheduler,optimizer,criterion,num_epochs,pretrain_generator,pretrain_dataset,prevalidation_generator,prevalidation_dataset,writer)

Pre-Training Started!

EPOCH 1 of 30

Epoch Loss: 0.1407
----------
Vaildation Loss: 0.1641

EPOCH 2 of 30

Epoch Loss: 0.1216
----------
Vaildation Loss: 0.1452

EPOCH 3 of 30

Epoch Loss: 0.1201
----------
Vaildation Loss: 0.1343

EPOCH 4 of 30

Epoch Loss: 0.1201
----------
Vaildation Loss: 0.1268

EPOCH 5 of 30

Epoch Loss: 0.1183
----------
Vaildation Loss: 0.1305

EPOCH 6 of 30

Epoch Loss: 0.1189
----------
Vaildation Loss: 0.1183

EPOCH 7 of 30

Epoch Loss: 0.1158
----------
Vaildation Loss: 0.1180

EPOCH 8 of 30

Epoch Loss: 0.1162
----------
Vaildation Loss: 0.1166

EPOCH 9 of 30

Epoch Loss: 0.1157
----------
Vaildation Loss: 0.1228

EPOCH 10 of 30

Epoch Loss: 0.1161
----------
Vaildation Loss: 0.1155

EPOCH 11 of 30

Epoch Loss: 0.1117
----------
Vaildation Loss: 0.1095

EPOCH 12 of 30

Epoch Loss: 0.1111
----------
Vaildation Loss: 0.1097

EPOCH 13 of 30

Epoch Loss: 0.1111
----------
Vaildation Loss: 0.1092

EPOCH 14 of 30

Epoch Loss: 0.1113
----------
Vaildation Loss: 

In [14]:
## load best model weights
pretrain_model.load_state_dict(best_model_wts)
## save model
torch.save(pretrain_model.state_dict(), 'pre_trained_unet.pt')

## fine tune network

In [15]:
model = unet(useBN=True)
model.load_state_dict(pretrain_model.state_dict())
model.to(device)

unet(
  (conv1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.1)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.1)
  )
  (conv3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=

In [22]:
# intialize optimizer and lr decay
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [23]:
train_batch_size = 10
validation_batch_size=10
learning_rate = 0.001
num_epochs = 70
num_class = 12

In [24]:
# IMPORTANT!
# must seed the same value each time when training a new network
seed = 256
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [25]:
# define dataset
train_dataset=MICCAIDataset(data_type = "train", transform=None)
validation_dataset=MICCAIDataset(data_type = "validation", transform=None)
label_converter = LabelConverter()

# # show one example
# img,label = train_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# # show one example
# img,label = validation_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# intialize the dataloader
train_generator = DataLoader(train_dataset,shuffle=True,batch_size=train_batch_size,num_workers=8)
validation_generator = DataLoader(validation_dataset,shuffle=True,batch_size=validation_batch_size,num_workers=8)

In [26]:
best_model_wts = run_training(model,device,num_class,scheduler,optimizer,dice_loss,num_epochs,train_generator,train_dataset,validation_generator,validation_dataset,writer)

Training Started!

EPOCH 1 of 70

Epoch Loss: 0.8775
----------
Vaildation Loss: 0.8360
0 Class, True Pos 6505601.0, False Pos 4086435.0, Flase Neg 6600092.0
1 Class, True Pos 1160322.0, False Pos 913126.0, Flase Neg 1378877.0
2 Class, True Pos 1187133.0, False Pos 5768067.0, Flase Neg 148297.0
3 Class, True Pos 0.0, False Pos 0.0, Flase Neg 563770.0
4 Class, True Pos 627356.0, False Pos 1919970.0, Flase Neg 4310163.0
5 Class, True Pos 1710068.0, False Pos 3081792.0, Flase Neg 1783731.0
6 Class, True Pos 0.0, False Pos 0.0, Flase Neg 94186.0
7 Class, True Pos 0.0, False Pos 0.0, Flase Neg 139522.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 0.0, False Pos 0.0, Flase Neg 114383.0
10 Class, True Pos 1488044.0, False Pos 797526.0, Flase Neg 1257435.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 175961.0
----------

EPOCH 2 of 70

Epoch Loss: 0.7493
----------
Vaildation Loss: 0.7420
0 Class, True Pos 10090202.0, False Pos 7633683.0, Flase Neg 3015491.0
1 Cl

Epoch Loss: 0.4588
----------
Vaildation Loss: 0.4521
0 Class, True Pos 11463823.0, False Pos 2813130.0, Flase Neg 1641870.0
1 Class, True Pos 2194820.0, False Pos 367239.0, Flase Neg 344379.0
2 Class, True Pos 890175.0, False Pos 241077.0, Flase Neg 445255.0
3 Class, True Pos 415312.0, False Pos 143173.0, Flase Neg 148458.0
4 Class, True Pos 3419971.0, False Pos 915280.0, Flase Neg 1517548.0
5 Class, True Pos 2625886.0, False Pos 955157.0, Flase Neg 867913.0
6 Class, True Pos 56434.0, False Pos 28714.0, Flase Neg 37752.0
7 Class, True Pos 112441.0, False Pos 23730.0, Flase Neg 27081.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 2713.0, False Pos 18390.0, Flase Neg 111670.0
10 Class, True Pos 2256484.0, False Pos 301491.0, Flase Neg 488995.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 175961.0
----------

EPOCH 12 of 70

Epoch Loss: 0.4490
----------
Vaildation Loss: 0.4469
0 Class, True Pos 11560533.0, False Pos 2841164.0, Flase Neg 1545160.0
1 Class, 


EPOCH 21 of 70

Epoch Loss: 0.3953
----------
Vaildation Loss: 0.3983
0 Class, True Pos 11665615.0, False Pos 2340863.0, Flase Neg 1440078.0
1 Class, True Pos 2244923.0, False Pos 444530.0, Flase Neg 294276.0
2 Class, True Pos 992838.0, False Pos 351858.0, Flase Neg 342592.0
3 Class, True Pos 449059.0, False Pos 174809.0, Flase Neg 114711.0
4 Class, True Pos 3915939.0, False Pos 765258.0, Flase Neg 1021580.0
5 Class, True Pos 2696384.0, False Pos 491333.0, Flase Neg 797415.0
6 Class, True Pos 60211.0, False Pos 32325.0, Flase Neg 33975.0
7 Class, True Pos 115788.0, False Pos 23193.0, Flase Neg 23734.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 60943.0, False Pos 31654.0, Flase Neg 53440.0
10 Class, True Pos 2220061.0, False Pos 167856.0, Flase Neg 525418.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 175961.0
----------

EPOCH 22 of 70

Epoch Loss: 0.3964
----------
Vaildation Loss: 0.3978
0 Class, True Pos 11883627.0, False Pos 2511847.0, Flase Neg 12


EPOCH 31 of 70

Epoch Loss: 0.3844
----------
Vaildation Loss: 0.4009
0 Class, True Pos 11519668.0, False Pos 2103775.0, Flase Neg 1586025.0
1 Class, True Pos 2237108.0, False Pos 383696.0, Flase Neg 302091.0
2 Class, True Pos 1002274.0, False Pos 341618.0, Flase Neg 333156.0
3 Class, True Pos 440830.0, False Pos 150053.0, Flase Neg 122940.0
4 Class, True Pos 4008575.0, False Pos 863772.0, Flase Neg 928944.0
5 Class, True Pos 2811174.0, False Pos 633211.0, Flase Neg 682625.0
6 Class, True Pos 59130.0, False Pos 29947.0, Flase Neg 35056.0
7 Class, True Pos 114669.0, False Pos 20190.0, Flase Neg 24853.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 63861.0, False Pos 30364.0, Flase Neg 50522.0
10 Class, True Pos 2246010.0, False Pos 185515.0, Flase Neg 499469.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 175961.0
----------

EPOCH 32 of 70

Epoch Loss: 0.3817
----------
Vaildation Loss: 0.3933
0 Class, True Pos 11562162.0, False Pos 2102815.0, Flase Neg 15

Epoch Loss: 0.3881
----------
Vaildation Loss: 0.4032
0 Class, True Pos 11740397.0, False Pos 2374910.0, Flase Neg 1365296.0
1 Class, True Pos 2250692.0, False Pos 415909.0, Flase Neg 288507.0
2 Class, True Pos 967865.0, False Pos 286395.0, Flase Neg 367565.0
3 Class, True Pos 435345.0, False Pos 140289.0, Flase Neg 128425.0
4 Class, True Pos 3788312.0, False Pos 604908.0, Flase Neg 1149207.0
5 Class, True Pos 2854579.0, False Pos 756525.0, Flase Neg 639220.0
6 Class, True Pos 58195.0, False Pos 27683.0, Flase Neg 35991.0
7 Class, True Pos 111481.0, False Pos 16003.0, Flase Neg 28041.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 60703.0, False Pos 22972.0, Flase Neg 53680.0
10 Class, True Pos 2188250.0, False Pos 144027.0, Flase Neg 557229.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 175961.0
----------

EPOCH 42 of 70

Epoch Loss: 0.3824
----------
Vaildation Loss: 0.3992
0 Class, True Pos 11509574.0, False Pos 2101130.0, Flase Neg 1596119.0
1 Class, 


EPOCH 51 of 70

Epoch Loss: 0.3927
----------
Vaildation Loss: 0.4032
0 Class, True Pos 11708125.0, False Pos 2337974.0, Flase Neg 1397568.0
1 Class, True Pos 2230738.0, False Pos 367704.0, Flase Neg 308461.0
2 Class, True Pos 980399.0, False Pos 292064.0, Flase Neg 355031.0
3 Class, True Pos 429367.0, False Pos 129274.0, Flase Neg 134403.0
4 Class, True Pos 3969827.0, False Pos 807516.0, Flase Neg 967692.0
5 Class, True Pos 2795566.0, False Pos 598258.0, Flase Neg 698233.0
6 Class, True Pos 60759.0, False Pos 31837.0, Flase Neg 33427.0
7 Class, True Pos 114281.0, False Pos 19475.0, Flase Neg 25241.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 63350.0, False Pos 26397.0, Flase Neg 51033.0
10 Class, True Pos 2160225.0, False Pos 122304.0, Flase Neg 585254.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 175961.0
----------

EPOCH 52 of 70

Epoch Loss: 0.3872
----------
Vaildation Loss: 0.4046
0 Class, True Pos 11661762.0, False Pos 2220670.0, Flase Neg 144


EPOCH 61 of 70

Epoch Loss: 0.3867
----------
Vaildation Loss: 0.3880
0 Class, True Pos 11740396.0, False Pos 2303301.0, Flase Neg 1365297.0
1 Class, True Pos 2232041.0, False Pos 347992.0, Flase Neg 307158.0
2 Class, True Pos 952989.0, False Pos 254065.0, Flase Neg 382441.0
3 Class, True Pos 447925.0, False Pos 159336.0, Flase Neg 115845.0
4 Class, True Pos 4093354.0, False Pos 934550.0, Flase Neg 844165.0
5 Class, True Pos 2699475.0, False Pos 436721.0, Flase Neg 794324.0
6 Class, True Pos 60053.0, False Pos 29814.0, Flase Neg 34133.0
7 Class, True Pos 114502.0, False Pos 19191.0, Flase Neg 25020.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 499.0
9 Class, True Pos 66373.0, False Pos 33027.0, Flase Neg 48010.0
10 Class, True Pos 2184176.0, False Pos 136159.0, Flase Neg 561303.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 175961.0
----------

EPOCH 62 of 70

Epoch Loss: 0.3844
----------
Vaildation Loss: 0.3932
0 Class, True Pos 11672809.0, False Pos 2221484.0, Flase Neg 143

In [27]:
## load best model weights
model.load_state_dict(best_model_wts)
## save model
torch.save(model.state_dict(), 'transform_pretrained_unet.pt')

In [28]:
# load model
model.load_state_dict(torch.load('transform_pretrained_unet.pt'))
model.to(device)
print("Model loaded")

Model loaded


In [29]:
# load test dataset
test_dataset=MICCAIDataset(data_type = "test", transform=None)
test_generator=DataLoader(test_dataset,shuffle=False,batch_size=4,num_workers=8)

In [30]:
final_dice = test(model,device,dice_loss,num_class,test_generator,test_dataset,writer)

Dice Score: 0.6589
0 Class, True Pos 15175938.0, False Pos 3145996.0, Flase Neg 1526147.0
1 Class, True Pos 2712240.0, False Pos 370608.0, Flase Neg 444183.0
2 Class, True Pos 1269266.0, False Pos 362182.0, Flase Neg 570882.0
3 Class, True Pos 671018.0, False Pos 226524.0, Flase Neg 197240.0
4 Class, True Pos 5073795.0, False Pos 985669.0, Flase Neg 1439016.0
5 Class, True Pos 3208612.0, False Pos 501420.0, Flase Neg 1001896.0
6 Class, True Pos 87317.0, False Pos 39723.0, Flase Neg 44413.0
7 Class, True Pos 126181.0, False Pos 25057.0, Flase Neg 36610.0
8 Class, True Pos 0.0, False Pos 0.0, Flase Neg 313.0
9 Class, True Pos 100686.0, False Pos 33329.0, Flase Neg 51288.0
10 Class, True Pos 2271560.0, False Pos 231119.0, Flase Neg 392347.0
11 Class, True Pos 0.0, False Pos 0.0, Flase Neg 217292.0
----------
