In [1]:
# This code is used for JHU CS 482/682: Deep Learning 2019 Spring Project
# Copyright: Zhaoshuo Li, Ding Hao, Mingyi Zheng
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as functional
from torch.utils.data import DataLoader, Dataset
# from torchvision import transforms
import torchvision.transforms.functional as TF

import numpy as np
import random
from tensorboardX import SummaryWriter

import transforms
from dataset import *
from visualization import *
from label_conversion import *
from dice_loss import *
from model_trainning import *
from model_pretrainning import *

from unet import *

In [2]:
# Restrict this process to physical GPU 2. The variable only takes effect if it
# is set before CUDA is first initialised in this kernel.
visible_gpu = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpu

# Prefer the (visible) GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# device = torch.device("cpu")  # uncomment to force CPU execution

# On a CUDA machine this prints "cuda".
print(device)

cuda


# Seed PyTorch, NumPy and Python's `random`

In [3]:
# IMPORTANT: seed every RNG with the same value before training a new network,
# so that runs are reproducible.
seed = 256
pretrain_seed = 128  # separate seed used by the transformation-pretraining section
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(seed)

## Hyperparameters

In [4]:
# Batch sizes for the training / validation DataLoaders.
train_batch_size = validation_batch_size = 10
learning_rate = 1e-3   # initial Adam learning rate
num_epochs = 70        # epochs per segmentation training run
num_class = 12         # number of segmentation classes

In [5]:
# Uniform (all-ones) per-class weights of shape (num_class, 1) for the DICE
# loss, allocated directly on the training device.
weights = torch.ones((num_class, 1), device=device)
dice_loss = DICELoss(weights)

## Visualization

In [6]:
# Initialize the visualization environment.
# No log directory is passed, so tensorboardX picks its default location.
# This single writer is handed to every run_training / run_pretraining / test
# call in this notebook, so all experiments log through the same writer.
# NOTE(review): consider one writer (or a tag prefix) per experiment so the
# baseline, augmentation and fine-tuning curves do not mix — confirm how
# run_training names its scalars.
writer = SummaryWriter()

## U-Net model

In [None]:
# Build the segmentation U-Net (with batch normalisation) on the target device.
# nn.Module.to returns the module itself, so `model` is the same object and the
# bare expression below still displays it as the cell output.
model = unet(useBN=True).to(device)
model

## Optimizer and learning-rate scheduler

In [None]:
# Adam optimizer over all U-Net parameters; StepLR multiplies the learning rate
# by 0.1 every 10 scheduler steps.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

## Baseline, without augmentation

In [None]:
# IMPORTANT: re-seed before training a new network so each experiment starts
# from the same random state.
seed = 256
for reseed in (random.seed, torch.manual_seed, np.random.seed):
    reseed(seed)

In [None]:
# define dataset (baseline: no augmentation on either split)
train_dataset=MICCAIDataset(data_type = "train", transform=None)
validation_dataset=MICCAIDataset(data_type = "validation", transform=None)
label_converter = LabelConverter()

# # show one example
# img,label = train_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# # show one example
# img,label = validation_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# initialize the dataloaders.
# Only the training loader is shuffled. An averaged validation metric does not
# depend on sample order. A shuffled loader (with no explicit generator) also
# draws from the global torch RNG every time it is iterated, which would couple
# the training batch order to how often validation runs.
train_generator = DataLoader(train_dataset,shuffle=True,batch_size=train_batch_size,num_workers=8)
validation_generator = DataLoader(validation_dataset,shuffle=False,batch_size=validation_batch_size,num_workers=8)

## Start training

In [None]:
# Train the baseline (no augmentation) network and keep the returned best weights.
best_model_wts = run_training(
    model, device, num_class, scheduler, optimizer, dice_loss, num_epochs,
    train_generator, train_dataset,
    validation_generator, validation_dataset,
    writer,
)

In [None]:
## restore the best weights found during training, then save them to disk
model.load_state_dict(best_model_wts)
vanilla_ckpt_path = 'vanilla_trained_unet_new_dice.pt'
torch.save(model.state_dict(), vanilla_ckpt_path)

## Test

In [None]:
# Held-out test split: no transform, fixed order, batches of 4.
test_dataset = MICCAIDataset(data_type="test", transform=None)
test_generator = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=8,
)

In [None]:
# load model
# map_location remaps tensors saved from a GPU onto `device`, so the checkpoint
# still loads when `device` fell back to the CPU.
model.load_state_dict(torch.load('vanilla_trained_unet_new_dice.pt', map_location=device))
model.to(device)
print("Model loaded")

In [None]:
# Evaluate the baseline model on the test split.
final_dice = test(
    model, device, dice_loss, num_class,
    test_generator, test_dataset,
    writer,
)

# Data Augmentation

In [None]:
# IMPORTANT!
# must seed the same value each time when training a new network
seed = 256
pretrain_seed = 128

# Build a fresh network, optimizer and LR schedule for this experiment.
# Without this, the augmentation run would reuse `model` (which holds the
# baseline's trained weights), its Adam state and its StepLR scheduler. The
# comparison with the baseline would then be contaminated.
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
model = unet(useBN=True)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Re-seed after model construction, as the baseline section does, so the data
# pipeline starts from the same random state as the baseline run.
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [16]:
# define dataset (augmentation applied to the training split only)
# NOTE(review): `transforms` here is the local module imported at the top of the
# notebook, not a composed transform object — confirm this is what
# MICCAIDataset expects for its `transform` argument.
train_dataset=MICCAIDataset(data_type = "train", transform=transforms)
validation_dataset=MICCAIDataset(data_type = "validation", transform=None)
label_converter = LabelConverter()

# # show one example
# img,label = train_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# # show one example
# img,label = validation_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# colorlabel = train_dataset.label_converter.label2color(label.permute(1,2,0))
# imshow(colorlabel)

# initialize the dataloaders.
# Only the training loader is shuffled. Validation order does not affect an
# averaged metric, and an unshuffled loader does not consume global torch RNG
# draws that would otherwise perturb the training batch order.
train_generator = DataLoader(train_dataset,shuffle=True,batch_size=train_batch_size,num_workers=8)
validation_generator = DataLoader(validation_dataset,shuffle=False,batch_size=validation_batch_size,num_workers=8)

In [None]:
# Train with augmented training data and keep the returned best weights.
best_model_wts = run_training(
    model, device, num_class, scheduler, optimizer, dice_loss, num_epochs,
    train_generator, train_dataset,
    validation_generator, validation_dataset,
    writer,
)

In [None]:
## restore the best augmented-run weights, then save them to disk
model.load_state_dict(best_model_wts)
aug_ckpt_path = 'aug_trained_unet.pt'
torch.save(model.state_dict(), aug_ckpt_path)

In [None]:
# load model
# map_location remaps tensors saved from a GPU onto `device`, so the checkpoint
# still loads when `device` fell back to the CPU.
model.load_state_dict(torch.load('aug_trained_unet.pt', map_location=device))
model.to(device)
print("Model loaded")

In [None]:
# Held-out test split (unaugmented), evaluated in a fixed order in batches of 4.
test_dataset = MICCAIDataset(transform=None, data_type="test")
test_generator = DataLoader(
    test_dataset, num_workers=8, batch_size=4, shuffle=False
)

In [None]:
# Evaluate the augmentation-trained model on the test split.
final_dice = test(
    model, device, dice_loss, num_class,
    test_generator, test_dataset,
    writer,
)

# Transformation Pretraining 

In [7]:
# IMPORTANT: seed every RNG with the pretraining seed before building the new
# network, so the pretraining run is reproducible.
pretrain_seed = 128
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(pretrain_seed)

In [8]:
# Pretraining reuses the batch sizes and learning rate, with a shorter schedule.
train_batch_size = validation_batch_size = 10
learning_rate = 1e-3   # initial Adam learning rate
num_epochs = 30        # pretraining epochs
num_class = 12         # number of segmentation classes

In [9]:
# Build the pretraining variant of the U-Net (with batch normalisation) on the
# target device; nn.Module.to returns the module itself.
pretrain_model = unet_pretrain(useBN=True).to(device)
print("pretrain model generated")

pretrain model generated


In [10]:
# Adam with L2 weight decay for pretraining, StepLR (x0.1 every 10 steps), and a
# mean-squared-error criterion for the pretraining target.
optimizer = torch.optim.Adam(
    params=pretrain_model.parameters(),
    lr=learning_rate,
    weight_decay=1e-3,
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)
criterion = torch.nn.MSELoss(reduction='mean')

In [11]:
# define dataset
pretrain_dataset=Transformation_PretrainDataset(data_type = "train", transform=None)
prevalidation_dataset=Transformation_PretrainDataset(data_type = "validation", transform=None)
label_converter = LabelConverter()

# # show one pretraining example (the label is displayed as an image)
# img,label = pretrain_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# imshow(label.permute(1,2,0),denormalize=True)

# # show one pretraining validation example
# img,label = prevalidation_dataset.__getitem__(0)
# imshow(img.permute(1,2,0),denormalize=True)
# imshow(label.permute(1,2,0),denormalize=True)

# initialize the dataloaders.
# Only the training loader is shuffled. Validation order does not affect an
# averaged loss, and an unshuffled loader does not consume global torch RNG
# draws that would otherwise perturb the training batch order.
pretrain_generator = DataLoader(pretrain_dataset,shuffle=True,batch_size=train_batch_size,num_workers=8)
prevalidation_generator = DataLoader(prevalidation_dataset,shuffle=False,batch_size=validation_batch_size,num_workers=8)

In [None]:
# Pretrain with the MSE criterion and keep the returned best weights.
best_model_wts = run_pretraining(
    pretrain_model, device, scheduler, optimizer, criterion, num_epochs,
    pretrain_generator, pretrain_dataset,
    prevalidation_generator, prevalidation_dataset,
    writer,
)

Pre-Training Started!

EPOCH 1 of 30

Epoch Loss: 0.1408
----------
Vaildation Loss: 0.1515

EPOCH 2 of 30

Epoch Loss: 0.1216
----------
Vaildation Loss: 0.1350

EPOCH 3 of 30

Epoch Loss: 0.1199
----------
Vaildation Loss: 0.1247

EPOCH 4 of 30

Epoch Loss: 0.1206
----------
Vaildation Loss: 0.1262

EPOCH 5 of 30

Epoch Loss: 0.1185
----------
Vaildation Loss: 0.1418

EPOCH 6 of 30

Epoch Loss: 0.1189
----------
Vaildation Loss: 0.1134

EPOCH 7 of 30

Epoch Loss: 0.1156
----------
Vaildation Loss: 0.1151

EPOCH 8 of 30

Epoch Loss: 0.1162
----------
Vaildation Loss: 0.1146

EPOCH 9 of 30

Epoch Loss: 0.1158
----------
Vaildation Loss: 0.1286

EPOCH 10 of 30



In [None]:
## restore the best pretraining weights, then save them to disk
pretrain_model.load_state_dict(best_model_wts)
pretrain_ckpt_path = 'pre_trained_unet.pt'
torch.save(pretrain_model.state_dict(), pretrain_ckpt_path)

## Fine-tune the network

In [None]:
# Initialise the segmentation U-Net from the pretrained weights.
# NOTE(review): load_state_dict is strict by default, so this assumes unet and
# unet_pretrain have identical parameter names and shapes (including the output
# layer) — confirm.
model = unet(useBN=True)
pretrained_state = pretrain_model.state_dict()
model.load_state_dict(pretrained_state)
model.to(device)

In [None]:
# Fine-tuning uses the same settings as the baseline segmentation run.
train_batch_size = validation_batch_size = 10
learning_rate = 1e-3   # initial Adam learning rate
num_epochs = 70        # fine-tuning epochs
num_class = 12         # number of segmentation classes

In [None]:
# IMPORTANT: re-seed before training so fine-tuning starts from the same random
# state as the other segmentation experiments.
seed = 256
for reseed in (random.seed, torch.manual_seed, np.random.seed):
    reseed(seed)

In [None]:
# Fresh Adam optimizer over the fine-tuned model's parameters; StepLR multiplies
# the learning rate by 0.1 every 10 scheduler steps.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

In [None]:
# Fine-tune the pretrained network on the segmentation task.
# NOTE(review): train_generator / validation_generator are whichever loaders were
# defined last in this kernel (the augmented ones if the notebook was run top to
# bottom) — confirm which training set fine-tuning is meant to use.
best_model_wts = run_training(
    model, device, num_class, scheduler, optimizer, dice_loss, num_epochs,
    train_generator, train_dataset,
    validation_generator, validation_dataset,
    writer,
)