In [1]:
# This code is used for JHU Medical Image Analysis Project
# Copyright: Zhaoshuo Li
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as functional
from torch.utils.data import DataLoader, Dataset
# from torchvision import transforms
import torchvision.transforms.functional as TF

import copy
import numpy as np
import random
from tensorboardX import SummaryWriter

from dataset import *
from visualization import *
from dice_loss import *
from model_training import *
from albumentations import *

from model_from_ternaus import *
from unet import *

In [2]:
# Expose only the GPUs with index 1 and 2 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report the number of visible GPUs and the device that was selected.
print(torch.cuda.device_count())
print(device)

2
cuda


# Seed PyTorch, NumPy and Python's random module

In [3]:
# IMPORTANT!
# Every time a new network is trained, the same seed value has to be applied
# to Python's random module, PyTorch and NumPy.
seed = 256
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(seed)
# Separate seed value kept alongside the main one.
pretrain_seed = 128

## Hyperparameters

In [4]:
# Batch sizes used by the training and validation DataLoaders.
train_batch_size = validation_batch_size = 1
# Initial learning rate handed to the optimizer.
learning_rate = 1e-3
# Number of epochs passed to run_training().
num_epochs = 70
# Number of classes passed to run_training().
num_class = 4

In [5]:
# Per-class weights handed to the DICE loss.
# The base values are inverted (a larger base value gives a smaller weight) and
# then rescaled so that the weights sum to the number of entries.
base_values = torch.tensor([10.0, 2.0, 1.0, 2.0])
inverse_values = 1 / base_values
# The scale comes from the tensor size rather than a hard-coded 4, so it stays
# correct if the list of classes ever changes.
weights = inverse_values / torch.sum(inverse_values) * inverse_values.numel()
print(weights)

tensor([0.1905, 0.9524, 1.9048, 0.9524])


In [6]:
# Place the class weights on the selected device, build the DICE loss from them,
# and move the loss object itself to the same device. The final expression is
# left bare so the notebook displays the loss object.
class_weights = weights.to(device)
dice_loss = DICELoss(class_weights)
dice_loss.to(device)

DICELoss()

## Visualization

In [7]:
# Initialize the visualization environment: a tensorboardX SummaryWriter created
# with its default arguments. The same writer is handed to every run_training()
# call in this notebook.
writer = SummaryWriter()

# Baseline, without augmentation

### seed

In [8]:
# IMPORTANT!
# Re-apply the same seed to Python's random module, PyTorch and NumPy before
# the baseline network is created and trained.
seed = 256
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(seed)

### Unet with Resnet

In [9]:
# UNET
# Build the AlbuNet without pretrained weights. The number of output classes is
# taken from the `num_class` hyperparameter (the same value run_training()
# receives) instead of repeating the literal 4, so the two cannot drift apart.
model = AlbuNet(num_classes=num_class, num_filters=32, pretrained=False, is_deconv=False)
# Wrap the model in nn.DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
# Move the model to the selected device; as the last expression of the cell this
# also displays the module structure.
model.to(device)

DataParallel(
  (module): AlbuNet(
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (encoder): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_s

### Optimizer and learning rate decay

In [10]:
# Initialize the optimizer and the learning-rate decay.
# Adam is given every model parameter and starts from the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# StepLR multiplies the learning rate by gamma=0.1 once every step_size=10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

### data generator

In [11]:
# Baseline pipeline: no augmentation beyond bringing every sample to 256x256.
# PadIfNeeded pads anything smaller than 256 in height or width (border_mode=0,
# value=0), then RandomCrop cuts out a 256x256 window.
train_both_aug = Compose([
    PadIfNeeded(min_height=256, min_width=256, border_mode=0, value=0, p=1),
    RandomCrop(height=256, width=256, p=1),
])
# NOTE(review): validation also uses RandomCrop, so samples larger than 256x256
# are cropped at a different position on every pass. Consider a deterministic
# crop here - TODO confirm.
val_both_aug = Compose([
    PadIfNeeded(min_height=256, min_width=256, border_mode=0, value=0, p=1),
    RandomCrop(height=256, width=256, p=1),
])

# Each dataset gets its pipeline via transform_both; no image-only transform is used.
train_dataset = ACDCDataset(data_type="train", transform_both=train_both_aug, transform_image=None)
validation_dataset = ACDCDataset(data_type="validation", transform_both=val_both_aug, transform_image=None)

# # show one example
# img_ED,img_ES,label_ED,label_ES = train_dataset.__getitem__(0)
# imshow(img_ED.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(img_ES.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(label_ED.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(label_ES.permute(1,2,0)[:,:,3],denormalize=False)

# # show one example
# img_ED,img_ES,label_ED,label_ES = validation_dataset.__getitem__(0)
# imshow(img_ED.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(img_ES.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(label_ED.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(label_ES.permute(1,2,0)[:,:,3],denormalize=False)

# Initialize the dataloaders.
# Training samples are drawn in a new random order every epoch.
train_generator = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, num_workers=8)
# Validation is iterated in a fixed order: shuffling adds nothing to evaluation
# and only makes per-sample validation output harder to compare between epochs.
validation_generator = DataLoader(validation_dataset, shuffle=False, batch_size=validation_batch_size, num_workers=8)

### Start training

In [12]:
# Train the baseline model. The two returned values are kept as the best model
# weights and the dice score, and the score is printed afterwards.
best_model_wts, dice_score = run_training(
    model,
    device,
    num_class,
    scheduler,
    optimizer,
    dice_loss,
    num_epochs,
    train_generator,
    train_dataset,
    validation_generator,
    validation_dataset,
    writer,
)
print(dice_score)

Training Started!

EPOCH 1 of 70



  "See the documentation of nn.Upsample for details.".format(mode))


RuntimeError: CUDA out of memory. Tried to allocate 48.00 MiB (GPU 0; 5.94 GiB total capacity; 597.26 MiB already allocated; 23.88 MiB free; 9.86 MiB cached)

### save model

In [None]:
## load best model weights
model.load_state_dict(best_model_wts)
## save model
# When the model is wrapped in nn.DataParallel, every state_dict key carries a
# "module." prefix, so the checkpoint layout would depend on how many GPUs were
# visible during training. Save the underlying module instead so the file can
# be loaded into a plain (unwrapped) model on any machine.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), 'vanilla_trained_unet_limited_data.pt')

### Save model as ONNX

In [None]:
# The ONNX exporter does not support nn.DataParallel wrappers, and the baseline
# model is wrapped whenever more than one GPU is visible, so export the
# underlying module.
export_model = model.module if isinstance(model, nn.DataParallel) else model
# Random input of shape (1, 1, 256, 256) on the CPU, used only to trace the network.
# NOTE(review): the printed AlbuNet encoder starts with Conv2d(3, 64, ...), while
# this input has a channel dimension of 1 - confirm the model's forward() accepts it.
dummy_input = torch.randn(1, 1, 256, 256, device='cpu')
torch.onnx.export(export_model.cpu(), dummy_input, "vanilla_trained_unet_limited_data.onnx", verbose=True)

In [None]:
# Dump every learnt parameter, name first and values second, so they can be
# checked against what MATLAB imported.
for param_name, param_tensor in model.named_parameters():
    print(param_name, param_tensor.data)

# With Data Augmentation

### seed

In [None]:
# IMPORTANT!
# Re-apply the same seed to Python's random module, PyTorch and NumPy before
# the augmented-data network is created and trained.
seed = 256
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(seed)

### Unet

In [None]:
# initialize model
# Builds the project's `unet` with useBN=True (presumably enables batch
# normalization - confirm in unet.py). Unlike the baseline section, this model
# is not wrapped in nn.DataParallel.
model = unet(useBN=True)
# Move the model to the selected device; as the last expression of the cell this
# also displays the module.
model.to(device)

### Optimizer and learning rate decay

In [None]:
# Initialize the optimizer and the learning-rate decay for the new model.
# Adam is given every model parameter and starts from the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# StepLR multiplies the learning rate by gamma=0.1 once every step_size=10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

In [None]:
# Training pipeline handed to the dataset as transform_both:
# pad to at least 256x256, crop a random 256x256 window, apply Cutout with
# probability 0.5, then one of shift/scale/rotate, horizontal flip or vertical flip.
train_both_aug = Compose([
    PadIfNeeded(min_height=256, min_width=256, border_mode=0, value=0, p=1),
    RandomCrop(height=256, width=256, p=1),
    Cutout(p=0.5),
    OneOf([
        ShiftScaleRotate(p=0.7),
        HorizontalFlip(p=0.8),
        VerticalFlip(p=0.8),
    ]),
])
# Intensity augmentation: one of brightness/contrast jitter or gamma change.
# NOTE(review): this pipeline is defined but never handed to the dataset below
# (transform_image=None), so it is currently inactive. Pass
# transform_image=train_img_aug to enable it - TODO confirm the intent.
train_img_aug = Compose([
    OneOf([
        RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.5, 0.5), p=0.9),
        RandomGamma(gamma_limit=(50, 200), p=0.8),
    ]),
])


# Validation only brings every sample to 256x256.
# NOTE(review): RandomCrop makes the validation crop position vary between
# passes for samples larger than 256x256 - consider a deterministic crop.
val_both_aug = Compose([
    PadIfNeeded(min_height=256, min_width=256, border_mode=0, value=0, p=1),
    RandomCrop(height=256, width=256, p=1),
])

train_dataset = ACDCDataset(data_type="train", transform_both=train_both_aug, transform_image=None)
validation_dataset = ACDCDataset(data_type="validation", transform_both=val_both_aug, transform_image=None)

# # show one example
# img_ED,img_ES,label_ED,label_ES = train_dataset.__getitem__(0)
# imshow(img_ED.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(img_ES.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(label_ED.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(label_ES.permute(1,2,0)[:,:,3],denormalize=False)

# # show one example
# img_ED,img_ES,label_ED,label_ES = validation_dataset.__getitem__(0)
# imshow(img_ED.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(img_ES.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(label_ED.permute(1,2,0)[:,:,3],denormalize=False)
# imshow(label_ES.permute(1,2,0)[:,:,3],denormalize=False)

# Initialize the dataloaders.
# Training samples are drawn in a new random order every epoch.
train_generator = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, num_workers=8)
# Validation is iterated in a fixed order: shuffling adds nothing to evaluation
# and only makes per-sample validation output harder to compare between epochs.
validation_generator = DataLoader(validation_dataset, shuffle=False, batch_size=validation_batch_size, num_workers=8)

### Start Training

In [None]:
# Train the model on the augmented data. The two returned values are kept as
# the best model weights and the dice score.
best_model_wts, dice_score = run_training(
    model,
    device,
    num_class,
    scheduler,
    optimizer,
    dice_loss,
    num_epochs,
    train_generator,
    train_dataset,
    validation_generator,
    validation_dataset,
    writer,
)

In [None]:
## Put the best weights returned by training back into the model ...
model.load_state_dict(best_model_wts)
## ... and write the resulting state dict to disk.
checkpoint_path = 'aug_trained_unet.pt'
torch.save(model.state_dict(), checkpoint_path)

### Fine-tune with a lower learning rate

In [None]:
# Initialize a fresh optimizer and learning-rate decay for fine-tuning.
# The learning rate is fixed at 1e-4 here instead of the `learning_rate` hyperparameter.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
# StepLR multiplies the learning rate by gamma=0.1 once every step_size=10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)
# Make sure the model is on the selected device; as the last expression of the
# cell this also displays the module.
model.to(device)

In [None]:
# Fine-tuning run: trains the current model again with the fine-tuning
# optimizer and scheduler.
best_model_wts, dice_score = run_training(
    model,
    device,
    num_class,
    scheduler,
    optimizer,
    dice_loss,
    num_epochs,
    train_generator,
    train_dataset,
    validation_generator,
    validation_dataset,
    writer,
)
# Load the best weights of this run into the model, as is done after the other
# training runs in this notebook. Without this, `best_model_wts` is never used
# and the ONNX export that follows works on whatever weights the last epoch left.
model.load_state_dict(best_model_wts)

## Save model as ONNX

In [None]:
# Random input of shape (1, 1, 256, 256) on the CPU, used to trace the network.
dummy_input = torch.randn(1, 1, 256, 256, device='cpu')
# Move the model to the CPU as well and write the traced graph to an ONNX file,
# printing the exported graph (verbose=True).
cpu_model = model.cpu()
torch.onnx.export(cpu_model, dummy_input, "aug_trained_unet.onnx", verbose=True)

In [None]:
# Dump every learnt parameter, name first and values second, so they can be
# checked against what MATLAB imported.
for param_name, param_tensor in model.named_parameters():
    print(param_name, param_tensor.data)