In [None]:
# Imports and global configuration for the 2-D feature-learning / registration experiments.
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm,trange 
import os
# Select the GPU by PCI bus id. These variables only take effect if they are set before CUDA
# is first initialised in this kernel (importing torch alone does not initialise CUDA);
# restart the kernel to change the device.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
import sys
import time
import numpy as np
import imageio
import torchvision
import torchinfo

# Image height and width (pixels) used for every sampling grid in this notebook.
H = 160; W = 192
# NOTE(review): wildcard import — adapt_resnet, create_dataset, MIND2D_64, correlate,
# coupled_convex, soft_dice and jacobian_det2d are presumably provided by models_utils;
# consider importing them explicitly so readers can see where each name comes from.
# TODO: no random seed is set anywhere; seed torch here if runs must be reproducible.
from models_utils import *

In [None]:
# Instantiate the feature extractor once to inspect its architecture.
# Training inputs are (batch, 1, H, W) = (2, 1, 160, 192); summarise with that exact shape
# (the previous (2, 1, 192, 160) described a transposed image).
resnet = adapt_resnet()
print(torchinfo.summary(resnet,(2,1,H,W),depth=1))

# Load all subjects. imgs_all / labels_all are indexed per subject by the training cell;
# NOTE(review): labels2_all is not used by the training loop in this notebook section.
imgs_all,labels_all,labels2_all = create_dataset()



In [None]:
# Create a random sampling layout for the self-similarities used by the MIND loss.
# NOTE(review): no RNG seed is set, so `layout` (and all augmentations) differ between runs.
layout = torch.randn(2,64,1,2).cuda()
# Identity sampling grid in normalised [-1, 1] coordinates, shape (1, H, W, 2).
grid = F.affine_grid(torch.eye(2,3).unsqueeze(0).cuda(),(1,1,H,W),align_corners=False)

N_ITER = 24000      # training iterations per fold
ACCUM_STEPS = 8     # gradients are accumulated over this many iterations per optimizer step
LOG_WINDOW = 90     # number of most recent iterations averaged in the progress bar
N_SUBJECTS = 10     # number of subjects the folds are drawn from
N_FOLDS = 5         # each fold holds out 2 consecutive subjects


def fold_split(fold, n_subjects=N_SUBJECTS, n_test=2):
    """Return (train, test) subject indices for cross-validation fold `fold`.

    Fold k holds out the contiguous subjects [k*n_test, (k+1)*n_test) and trains on the
    remaining ones, e.g. fold 3 -> test = [6, 7], train = [0, 1, 2, 3, 4, 5, 8, 9].
    """
    test = torch.arange(fold*n_test, (fold+1)*n_test)
    keep = torch.ones(n_subjects, dtype=torch.bool)
    keep[test] = False
    return torch.arange(n_subjects)[keep], test


def running_mean(values, i, window=LOG_WINDOW):
    """Mean of the last `window` entries of `values` up to and including index `i`.

    Clamped at the start so early iterations average what exists instead of wrapping
    around to unfilled entries (or averaging an empty slice to NaN).
    """
    return values[max(0, i-window+1):i+1].mean().item()


def smooth_random_field(n, h, w, factor=8):
    """Smooth random 2-channel field on the GPU, shape (n, 2, h, w) when h, w divide by `factor`.

    Gaussian noise at 1/`factor` resolution is smoothed by three 5x5 box filters and
    bilinearly upsampled by `factor`.
    """
    field = torch.randn(n,2,h//factor,w//factor).cuda()
    for _ in range(3):
        field = F.avg_pool2d(field,5,stride=1,padding=2)
    return F.interpolate(field,scale_factor=factor,mode='bilinear',align_corners=False)


# Sigmoid schedules over the iterations: `ramp` fades in the intensity augmentation,
# `ramp2` fades in the Jacobian-regularisation weight. Identical for every fold.
ramp = torch.sigmoid(torch.linspace(-5,15,N_ITER))
ramp2 = torch.sigmoid(torch.linspace(-15,15,N_ITER))

# Run training for all five folds (here with the MIND loss).
for fold in range(N_FOLDS):
    train, test = fold_split(fold)

    resnet = adapt_resnet()
    resnet.cuda()

    print('====== starting fold '+str(fold)+' =======')

    t0 = time.time()
    run_loss = torch.zeros(N_ITER)
    run_jstd = torch.zeros(N_ITER)
    optimizer = torch.optim.Adam(resnet.parameters(),lr=0.001)
    # Stepped once per iteration (not per optimizer step): the 3000-step period is
    # measured in iterations, i.e. 3000/ACCUM_STEPS optimizer updates.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,3000,1)
    grid_sp = 4     # resnet features are reshaped to (H//grid_sp, W//grid_sp)
    disp_hw = 4     # displacement search radius passed to correlate / coupled_convex
    # NOTE(review): `coeffs` and `alpha` are not used below — coupled_convex is called with
    # alpha=20 and an empty coeffs tensor. Kept in case another cell reads them; confirm intent.
    coeffs = 1*torch.tensor([0.003,0.03]); alpha = 50

    # Loop-invariant tensors, hoisted out of the iteration loop.
    # Candidate displacements in [-disp_hw, disp_hw]^2, shape (2, (2*disp_hw+1)**2, 1).
    disp_mesh_t = F.affine_grid(disp_hw*torch.eye(2,3).cuda().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,3,1,2).reshape(2,-1,1).flip(0)
    # Divisor converting pixel displacements into normalised grid coordinates.
    pix2norm = torch.tensor([(H-1)/2,(W-1)/2]).cuda().view(1,2,1,1)

    with tqdm(total=N_ITER, file=sys.stdout) as pbar:
        for i in range(N_ITER):
            # Random training subject, then two random images of that subject.
            ix = train[torch.randperm(len(train))][0]
            idx = torch.randperm(imgs_all[ix].shape[0])[:2]
            # One random affine augmentation (0.9 zoom + noise), shared by both images.
            A = F.affine_grid(torch.eye(2,3).unsqueeze(0)*torch.tensor([.9,.9,1])+torch.randn(2,3)*.15,(1,1,H,W),align_corners=False).cuda().repeat(2,1,1,1)

            # NOTE: `input` shadows the builtin; name kept because other cells may read it.
            input = F.grid_sample(torch.cat((imgs_all[ix][idx[0:1]],imgs_all[ix][idx[1:]]),0),A,align_corners=False)
            # MIND descriptors are taken before the intensity augmentation below.
            with torch.no_grad():
                mind_mov = MIND2D_64(input[1:],layout,grid)
                mind_fix = MIND2D_64(input[:1],layout,grid)

            # Smooth multiplicative (channel 0) and additive (channel 1) intensity augmentation.
            random_field = ramp[i]*6*smooth_random_field(2,H,W)
            input = (input*(1+random_field[:,:1])+random_field[:,1:])

            output = F.grid_sample(labels_all[ix][idx[0:2]],A,align_corners=False)

            feat = resnet(input).reshape(2,64,H//grid_sp,W//grid_sp)

            ssd = correlate(feat[:1],feat[1:],disp_hw,grid_sp)
            disp_soft = coupled_convex(ssd,disp_mesh_t,grid_sp,alpha=20,coeffs=torch.empty(0))
            # Upsample to full resolution in pixels, normalise, swap channel order, smooth twice.
            disp = F.interpolate(disp_soft,size=(H,W),mode='bilinear',align_corners=False)*grid_sp
            disp = (disp/pix2norm).flip(1)
            disp = F.avg_pool2d(F.avg_pool2d(disp,7,stride=1,padding=3),7,stride=1,padding=3)

            warp_grid = disp.permute(0,2,3,1)+grid
            mind_warp = F.grid_sample(mind_mov,warp_grid,align_corners=False)
            mind_loss = F.mse_loss(mind_fix*10,mind_warp*10)

            label_fix = output[0:1]
            label_mov = output[1:2]
            label_warp = F.grid_sample(label_mov,warp_grid,align_corners=False)
            dice_loss = 1-soft_dice(label_fix,label_warp).mean()

            jdet = jacobian_det2d(disp)
            jstd = jdet.std()

            # Store plain floats: assigning a grad-carrying tensor into these buffers would
            # attach them to the autograd graph and keep every iteration's `jdet` alive.
            run_loss[i] = dice_loss.item()
            run_jstd[i] = jstd.item()

            loss = mind_loss+2*dice_loss+(ramp2[i]*.5)*jstd
            loss.backward()
            str1 = f"iter: {i}, loss: {running_mean(run_loss,i):0.3f}, stddev: {running_mean(run_jstd,i):0.3f}, runtime: {time.time()-t0:0.3f} sec"

            pbar.set_description(str1)
            pbar.update(1)

            # Gradient accumulation: one optimizer update every ACCUM_STEPS iterations.
            if i%ACCUM_STEPS == ACCUM_STEPS-1:
                optimizer.step()
                optimizer.zero_grad()
            # Start the scheduler only after the first optimizer.step() (at i = ACCUM_STEPS-1).
            if i > ACCUM_STEPS:
                scheduler.step()
    # Pickles the whole module (not a state_dict); unpickling needs its defining classes importable.
    torch.save(resnet.cpu(),f'clust2d_sensors_resnet_aug_fold{fold}_hw{disp_hw}_mind_R1.pth')
    resnet.cuda()
    print('====== finished fold '+str(fold)+' =======')


    