In [1]:
# This code is used for JHU CS 482/682: Deep Learning 2019 Spring Project
# Copyright: Zhaoshuo Li, Ding Hao, Mingyi Zheng
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as functional
from torch.utils.data import DataLoader, Dataset
# from torchvision import transforms
import torchvision.transforms.functional as TF

import numpy as np
import random
from tensorboardX import SummaryWriter

import transforms
from dataset import *
from visualization import *
from label_conversion import *
from dice_loss import *
from model_trainning import *
from model_pretrainning import *
from model_from_ternaus import *

from unet import *

In [2]:
# Expose only physical GPU 2 to this process. CUDA reads this variable when it
# is first initialized, so it has to be set before any CUDA call (including the
# torch.cuda.is_available() check below).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Use the (visible) GPU when there is one, otherwise fall back to the CPU.
# To force CPU training, replace this with: device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# On a CUDA machine this prints "cuda".
print(device)

cuda


# Seed PyTorch, NumPy, and Python's `random` module

In [3]:
# IMPORTANT!
# Re-use the same seed every time a new network is trained so that runs are
# reproducible. Python's `random`, NumPy and PyTorch each keep an independent
# generator, so all three are seeded.
seed = 256
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(seed)
del _seed_rng

# Separate seed, presumably for the pre-training experiments (not referenced in
# this section).
pretrain_seed = 128

## Hyperparameters

In [4]:
# Training and validation loaders use the same batch size.
train_batch_size = validation_batch_size = 8
learning_rate = 1e-3   # initial Adam learning rate; StepLR decays it later
num_epochs = 150
num_class = 12         # number of segmentation classes

In [None]:
# Uniform per-class weights for the Dice loss: a (num_class, 1) tensor of ones,
# allocated directly on the training device.
weights = torch.ones((num_class, 1), device=device)
dice_loss = DICELoss(weights)

## Visualization (TensorBoard logging)

In [None]:
# Initialize the TensorBoard writer that the training loop logs to.
# With no arguments tensorboardX picks a default run directory
# (typically under ./runs — confirm for the installed version).
writer = SummaryWriter()

# Baseline (no data augmentation)

### Unet

In [None]:
# IMPORTANT!
# Re-seed right before building the model, so that weight initialization (and
# every random draw after it) is identical each time a new network is trained,
# regardless of what earlier cells consumed from the generators.
# The three generators are independent, so the seeding order does not matter.
seed = 256
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Build a U-Net with batch normalization enabled and move it to the selected
# device. nn.Module.to() returns the module itself, so the call can be chained.
model = unet(useBN=True).to(device)
print("model loaded")

model loaded


### Optimizer and learning-rate scheduler

In [None]:
# Adam optimizer plus a StepLR schedule that multiplies the learning rate by
# 0.95 on every scheduler step (presumably once per epoch, depending on how
# run_training calls scheduler.step()).
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.95,
)

In [None]:
# Define the datasets. The baseline uses no augmentation (transform=None).
train_dataset = MICCAIDataset(data_type="train", transform=None)
validation_dataset = MICCAIDataset(data_type="validation", transform=None)
label_converter = LabelConverter()

# To preview a sample (image + color-coded label), e.g.:
#   img, label = validation_dataset[0]
#   imshow(img.permute(1, 2, 0), denormalize=True)
#   imshow(validation_dataset.label_converter.label2color(label.permute(1, 2, 0)))

# Initialize the data loaders. Only the training loader shuffles. Validation
# order does not change metrics aggregated over the whole set (as the printed
# per-class totals suggest run_training does — confirm). A fixed order keeps
# per-batch validation output comparable across epochs and avoids drawing a
# fresh permutation from the global torch RNG on every validation pass.
train_generator = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, num_workers=8)
validation_generator = DataLoader(validation_dataset, shuffle=False, batch_size=validation_batch_size, num_workers=8)

### Start training

In [None]:
# Train for `num_epochs` epochs. run_training returns the weights of the best
# epoch together with its score (presumably validation accuracy/Dice — confirm
# in model_trainning).
best_model_wts, best_accuracy = run_training(
    model, device, num_class, scheduler, optimizer, dice_loss, num_epochs,
    train_generator, train_dataset,
    validation_generator, validation_dataset,
    writer,
)
print(best_accuracy)

Training Started!

EPOCH 1 of 150

Training Loss: 0.8568
0 Class, True Pos 29171220.0, False Pos 21813304.0, False Neg 23953550.0, Num Pixel 53124768.0, Dice score 0.56
1 Class, True Pos 6685966.0, False Pos 5457427.0, False Neg 4237535.0, Num Pixel 10923501.0, Dice score 0.58
2 Class, True Pos 3256451.0, False Pos 5489278.0, False Neg 2101208.0, Num Pixel 5357659.0, Dice score 0.46
3 Class, True Pos 296683.0, False Pos 1765565.0, False Neg 2180660.0, Num Pixel 2477343.0, Dice score 0.13
4 Class, True Pos 846285.0, False Pos 1868685.0, False Neg 19203332.0, Num Pixel 20049616.0, Dice score 0.07
5 Class, True Pos 7814052.0, False Pos 18079506.0, False Neg 5882046.0, Num Pixel 13696098.0, Dice score 0.39
6 Class, True Pos 15521.0, False Pos 3497506.0, False Neg 403097.0, Num Pixel 418618.0, Dice score 0.01
7 Class, True Pos 265.0, False Pos 5073.0, False Neg 612899.0, Num Pixel 613164.0, Dice score 0.00
8 Class, True Pos 0.0, False Pos 1615176.0, False Neg 3960.0, Num Pixel 3960.0, Dice 

In [None]:
# Restore the best-epoch weights returned by run_training, then save only the
# state dict (not the pickled module object).
# NOTE(review): assumes best_model_wts is a copy of the best weights rather
# than a live reference to model.state_dict() — confirm in run_training.
model.load_state_dict(best_model_wts)
torch.save(obj=model.state_dict(), f='vanilla_trained_unet_2.pt')