In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import time
import matplotlib
import imageio
import scipy.io as sio
import warnings
warnings.filterwarnings('ignore')
from src.utils import *
from src.nets import *

In [None]:
#### load dataset ####
# Training images (color) and their label maps are read from MATLAB .mat files.
train_imgs_color = sio.loadmat('./data/training_images.mat')['ims']
train_segs = torch.from_numpy(sio.loadmat('./data/training_segmentations.mat')['segs']).long()

# make images gray: weighted sum over the last (color) axis.
# NOTE(review): the coefficients are the usual luma weights and assume the
# channels are ordered R, G, B -- confirm against the data file.
train_imgs = torch.from_numpy(np.dot(train_imgs_color, [0.299, 0.587, 0.114]))

# setting dataset for process: float32 images divided by 255, with a singleton
# channel axis inserted at dim 1; segmentations are used as loaded.
img_train = train_imgs.float().unsqueeze(1)/255.0
seg_train = train_segs

# Number of training samples. As written this selects every sample, so the
# slicing below keeps everything; lower `numtrain` to train on a subset.
numtrain = img_train.shape[0]
img_train = img_train[:numtrain]
seg_train = seg_train[:numtrain]

print(train_imgs.shape, train_segs.shape)
print(img_train.shape, seg_train.shape)

# label weight: inverse square root of each label's pixel count, normalised so
# that the weights average to 1.
a = torch.bincount(seg_train.contiguous().view(-1))
# bincount reports 0 for any label id below the maximum that never occurs;
# clamping to 1 keeps 1/count finite so a single missing label cannot turn
# every weight into 0/NaN. Counts >= 1 are unaffected.
inv_sqrt_count = torch.sqrt(1/a.float().clamp(min=1))
label_weights = inv_sqrt_count/inv_sqrt_count.mean()
print('label weights:',label_weights)

In [None]:
### hyper parameters ####
numepochs = 300
learning_rate = 0.001
lambda_weight = 0.001     # weight of the deformation-smoothness term
lambda_semantic = 1.0     # weight of the segmentation (semantic guidance) term
batchsize = 20            # number of image pairs per optimisation step

##### network initialization #####
# NOTE(review): `L` is not defined in the visible cells; presumably it is the
# number of label classes exported by `src.utils` / `src.nets`. It must equal
# label_weights.numel() for the `.view(1,L,1,1)` below to succeed -- confirm.
unet = UNet2D(L)
unet.apply(init_weights); unet.cuda()
regnet1 = RegNet(inch=L*2).cuda()
regnet1.apply(init_weights)
regnet2 = RegNet(inch=L*2).cuda()
regnet2.apply(init_weights)

metric = nn.L1Loss()
metric_seg = nn.CrossEntropyLoss(weight=label_weights.cuda())
# One optimizer drives all three networks jointly. `torch.optim` is referenced
# through the explicitly imported `torch` package rather than relying on a bare
# `optim` name leaking in from a star import.
trainable_params = list(unet.parameters())+list(regnet1.parameters())+list(regnet2.parameters())
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.97)

# Every step consumes `batchsize` (fixed, moving) pairs. Pairs that do not fill
# a complete batch are dropped instead of making the reshape below fail.
iters_per_epoch = numtrain // batchsize
if iters_per_epoch == 0:
    raise ValueError('numtrain (%d) must be at least batchsize (%d)' % (numtrain, batchsize))

# start training and validation
for epoch in range(numepochs):
    # Two independent permutations give random (fixed, moving) index pairs,
    # arranged as (batchsize, iters_per_epoch, 2).
    pair_idx = torch.cat((torch.randperm(numtrain).view(-1,1),torch.randperm(numtrain).view(-1,1)),1)
    idx_epoch = pair_idx[:iters_per_epoch*batchsize].view(batchsize,-1,2)

    unet.train()
    regnet1.train()
    regnet2.train()
    # Decay the learning rate once per epoch, but only after the optimizer has
    # stepped at least once: stepping the scheduler before the first
    # optimizer.step() skips the initial learning rate (PyTorch >= 1.1).
    # Epoch e therefore trains with learning_rate * 0.97**e.
    if epoch > 0:
        scheduler.step()

    for batch_idx in range(idx_epoch.shape[1]):
        idx = idx_epoch[:,batch_idx,:]
        optimizer.zero_grad()

        # image and segmentation
        # img_train[idx] carries the singleton channel axis at dim 2; squeeze
        # only that axis so a batch of size 1 keeps its batch dimension.
        img_aug, label_aug = augmentAffine(img_train[idx].squeeze(2).cuda(),seg_train[idx].float().cuda())

        # Index 0 / 1 along dim 1 select the two members of each pair.
        y_label1 = label_aug[:,0]; y_label2 = label_aug[:,1]
        label_y1 = labelMatrixOneHot(y_label1.cpu(),L)
        label_y2 = labelMatrixOneHot(y_label2.cpu(),L)

        # train UNet: the same network segments both images of the pair
        seg_predict1 = unet(img_aug[:,0:1].cuda())
        seg_predict2 = unet(img_aug[:,1:2].cuda())

        # train registration networks: regnet1 sees both predicted
        # segmentations, regnet2 refines using the output of the first warp
        predflow1 = regnet1(torch.cat((seg_predict1,seg_predict2),1))
        warped_mov1, estflow1 = BsplineTrafo(seg_predict2,predflow1,19)
        predflow2 = regnet2(torch.cat((seg_predict1,warped_mov1),1))

        # segmentation loss (semantic guidance); targets are the label maps
        # subsampled by 2 in both spatial axes
        # NOTE(review): implies the UNet output is half resolution -- confirm.
        target1 = y_label1[:,::2,::2].long().cuda()
        target2 = y_label2[:,::2,::2].long().cuda()
        semanticloss = 0.5*metric_seg(seg_predict1,target1)+0.5*metric_seg(seg_predict2,target2)
        warped_mov2, estflow2 = BsplineTrafo(warped_mov1.float().cuda(),predflow2,11)

        # deformation field is sum of estimated fields from two networks,
        # upsampled by 2 (align_corners=False is the default bilinear
        # behaviour, spelled out explicitly)
        def_x2 = F.interpolate(estflow1+estflow2,scale_factor=2,mode='bilinear',align_corners=False)
        warped_seg = warpImage(label_y2.float().cuda(),def_x2)

        # deformation loss: class-weighted mean absolute difference between the
        # warped one-hot labels of image 2 and the one-hot labels of image 1
        deformloss = torch.mean((warped_seg-label_y1.float().cuda()).abs()*label_weights.view(1,L,1,1).cuda())

        # regularization loss: penalise the deviation of each flow component
        # from a twice 5x5 average-pooled copy of itself
        dx = def_x2[:,0:1,:,:];         dy = def_x2[:,1:2,:,:]
        dx_smooth = F.avg_pool2d(F.avg_pool2d(dx,5,padding=2,stride=1),5,padding=2,stride=1)
        dy_smooth = F.avg_pool2d(F.avg_pool2d(dy,5,padding=2,stride=1),5,padding=2,stride=1)
        regloss = torch.norm(dx-dx_smooth)+torch.norm(dy-dy_smooth)

        # total loss
        loss = lambda_semantic*semanticloss + deformloss + lambda_weight*regloss

        # backpropagation
        loss.backward()
        optimizer.step()

        # Diagnostics on the Jacobian determinant of the field (computed once,
        # outside the autograd graph): its standard deviation and the fraction
        # of negative values.
        jac_det = jacobian_det(def_x2.detach().cpu())
        J = torch.std(jac_det)
        Jnegativ = (jac_det<0).float().mean()

