In [1]:
import os
import time
from math import ceil

import torch
from torch import nn
import torchio as tio
from IPython import display

#import pytorch_ssim

from Unet import UNet
from pretrain_datasets import load_pretrain_datasets,load_kidney_seg
from Utils import DiceLoss, dice_ratio, LovaszSoftmax

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits



In [2]:
# Set parameters: every tunable knob for the pretraining run lives in this cell.

num_workers = 6               # number of parallel workers for dataloaders
batch_size = 1                # dataloader batch size
epochs = 10                   # number of epochs per dataset
data_shape = (128, 128, 128)  # resize all data to this shape
num_filters = 32              # number of channels throughout the network

rate = .0001                  # Adam learning rate

num_classes = 2               # network output channels; the training loop builds a 2-channel target

# Seed torch's RNG (CPU and CUDA) so torch-driven randomness, such as weight
# initialisation and random augmentation sampling, is repeatable across runs.
# cuDNN kernels may still introduce nondeterminism.
seed = 42
torch.manual_seed(seed)

# Choose dataset order and included sets. 0-braintumor, 1-heart, 2-liver, 3-hippocampus, 4-prostate,
# 5-lung, 6-pancreas, 7-hepaticvessels, 8-spleen, 9-colontumor, 10-kidney
dataset_order = [2, 7, 8, 6, 4, 10]

In [3]:
# Data augmentation pipeline passed to every training DataLoader.

# Random flip, applied to a subject with probability 0.3.
flip = tio.RandomFlip(p=0.3)

# Spatial augmentation, applied with probability 0.4: exactly one of an affine
# transform or an elastic deformation is picked, weighted 0.4 / 0.6.
spatial = tio.OneOf(
    {
        tio.RandomAffine(): 0.4,
        tio.RandomElasticDeformation(): 0.6,
    },
    p=0.4,
)

# NOTE(review): `noise` is constructed but NOT included in `transform` below,
# so no noise augmentation is applied. Add it to the Compose list if intended.
noise = tio.RandomNoise(std=0.05, p=0.3)

transform = tio.Compose([flip, spatial])


In [4]:
# Load data: one DataLoader per pretraining dataset. The kidney segmentation
# loader is appended last. Presumably load_pretrain_datasets returns loaders for
# indices 0-9, which puts kidney at index 10 as listed for dataset_order (confirm).
dataloaders = load_pretrain_datasets(data_shape, batch=batch_size, workers=num_workers, transform=transform)
dataloaders.append(load_kidney_seg(data_shape, batch=batch_size, workers=num_workers, transform=transform))

# Build the network on the GPU. The arguments are assumed to be
# (in_channels=1, out_channels=num_classes, base filter count); TODO confirm against Unet.UNet.
net = UNet(1, num_classes, num_filters).cuda()

# Weight initialisation: Kaiming-normal weights for every (transposed) 3D
# convolution, and identity affine parameters for instance norms that have them.
for m in net.modules():
    if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
        nn.init.kaiming_normal_(m.weight)
    elif isinstance(m, nn.InstanceNorm3d) and m.affine:
        # InstanceNorm3d defaults to affine=False, in which case weight/bias are
        # None and touching them would raise AttributeError.
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

# Loss criteria. BCELoss expects probabilities, so the training loop applies
# torch.sigmoid to the network output before computing the losses.
criterion_bce = torch.nn.BCELoss().cuda()
criterion_lova = LovaszSoftmax().cuda()
criterion_dice = DiceLoss().cuda()

# Optimizer over all network parameters; learning rate comes from `rate`.
optimizer = torch.optim.Adam(net.parameters(), lr=rate)


In [5]:
# Sequential pretraining: train the same network on each dataset in
# `dataset_order` for `epochs` epochs. Weights and Adam state carry over between datasets.
for loadnum, dataset_num in enumerate(dataset_order):
    dataloader = dataloaders[dataset_num]

    for e in range(epochs):
        epoch_start_time = time.time()
        start_time = time.time()  # reset after every progress print
        epoch_loss = 0.
        # Counted explicitly so an empty loader cannot divide by a missing/stale batch_idx.
        num_batches = 0
        net.train()
        for batch_idx, subject in enumerate(dataloader):
            image = subject['img'][tio.DATA].cuda()
            label = subject['label'][tio.DATA].cuda()

            # Two-channel target: channel 0 marks label value 1, channel 1 marks
            # label value 2 (background 0 is implicit). Assumes `label` has a
            # single channel on dim 1; TODO confirm. `label` is already on the GPU.
            label1 = torch.cat(((label == 1).float(), (label == 2).float()), dim=1)

            optimizer.zero_grad()

            # BCELoss expects probabilities, so squash the raw network output.
            output = torch.sigmoid(net(image))

            loss_bce = criterion_bce(output, label1)
            # NOTE(review): this uses 1 - criterion_lova(...). If Utils.LovaszSoftmax
            # returns a loss (lower is better), this term is maximised rather
            # than minimised. Confirm the sign against its implementation.
            loss_lova = 1 - criterion_lova(output, label)
            loss_dice = criterion_dice(output, label1)

            loss = loss_bce + loss_lova + loss_dice

            epoch_loss += loss.item()
            num_batches += 1

            if batch_idx % 10 == 0:
                display.clear_output(wait=True)
                # `.4f` means 4 decimal places; the previous `4f` only set a minimum width.
                print_line = (
                    'Dataset {} of {}\n'
                    'Epoch: {} | Batch: {} -----> Train loss: {:.4f} Cost Time: {}\n'
                    'Batch BCE Loss: {:.4f} || '
                    'Batch Lovasz Loss: {:.4f} || '
                    'Batch DICE Loss: {:.4f} || '
                ).format(loadnum + 1, len(dataset_order), e, batch_idx, epoch_loss / num_batches,
                         time.time() - start_time,  # seconds since the previous progress print
                         loss_bce.item(), loss_lova.item(), loss_dice.item())

                print(print_line)
                start_time = time.time()

            loss.backward()
            optimizer.step()

        mean_loss = epoch_loss / num_batches if num_batches else float('nan')
        print('Epoch {} Finished ! Loss is {:.4f}'.format(e, mean_loss))

        print("Epoch time: ", time.time() - epoch_start_time)
    # To reset Adam's moment estimates between datasets, re-create the optimizer here:
    # optimizer = torch.optim.Adam(net.parameters(), lr=rate)

Dataset 6 of 6
Epoch: 9 | Batch: 200 -----> Train loss: 1.191116 Cost Time: 8.624786138534546
Batch BCE Loss: 0.265333 || Batch Lovasz Loss: 0.100302 || Batch DICE Loss: 0.834664 || 
Epoch 9 Finished ! Loss is 1.192281
Epoch time:  274.9685034751892


In [6]:
# Save the pretrained weights under a timestamped filename so runs never overwrite each other.
timestr = time.strftime("%Y%m%d-%H%M%S")
# Output directory. Set PRETRAIN_MODEL_DIR to override the original machine-specific default.
folder_name = os.environ.get('PRETRAIN_MODEL_DIR', '/home/mitch/fewshotlocal/models/pretrain')
# torch.save does not create missing parent directories.
os.makedirs(folder_name, exist_ok=True)
model_path = os.path.join(folder_name, 'model_{}.pth'.format(timestr))
torch.save(net.state_dict(), model_path)
print('Saved model to {}'.format(model_path))