In [2]:
import reg_mri
import os
from glob import glob
from utils import compute_mean_dice
import nibabel as nib
from scipy.spatial.distance import dice
import numpy as np
import itk
import SimpleITK as sitk
import scipy.ndimage
import scipy
import matplotlib.pyplot as plt
from transforms_dict import getRegistrationEvalInverseTransformForMRI, SaveTransformForMRI
from tqdm import tqdm
import monai
import subprocess
from monai.transforms import AsDiscrete, MaskIntensity, RandAffine, Affine
import torch

In [3]:
import argparse
import logging
import sys

import numpy as np
import torch
from monai.transforms import AsDiscrete, MaskIntensity, RandAffine, Affine
from monai.utils import set_determinism
from monai.losses import LocalNormalizedCrossCorrelationLoss
from torch.utils.tensorboard import SummaryWriter
import random

import utils_parser
from reg_data import getRegistrationDataset
from reg_model import getRegistrationModel
from utils import compute_mean_dice, getAdamOptimizer, getReducePlateauScheduler, loadExistingModel, getDevice
from utils import print_model_output, print_weights, add_weights_to_name, compute_landmarks_distance_local
from loss import compute_affine_loss, get_jacobian, antifolding_loss, JacobianDet
from loss import get_deformable_registration_loss_from_weights, get_affine_registration_loss_from_weights, jacobian_loss
from models import TrilinearLocalNet
from torchinfo import summary
from miseval import evaluate
import time

In [4]:
def compute_affine_loss_test(ddf1, ddf2, affine_matrix):
    ## A(x+u'(x)) = x+u(x)
    ## => Ax + Au'(x) = x+u(x)
    ## => Ax-x + Au'(x)-u(x) = 0
    ## => Affineloss = 1/128³*Sum(x€128³) |Ax+Au'(x) - (x+u(x))|²
    # input = Ax+Au'(x) ||| target = x+u(x)
    som = 0
    device = getDevice()
    #print(affine_matrix)
    for i in range(128):
        for j in range(128):
            for k in range(128):
                x = torch.tensor([[i, j, k]]).to(device)
                
                ux = ddf1[0,:,i,j,k]
                ux = torch.tensor([[ux[0],ux[1],ux[2]]]).to(device)
                ux = torch.transpose(ux,0,1)
                ux = torch.mm(torch.eye(3).to(device), ux)
                ux = torch.transpose(ux,0,1)
    
                x_aff = torch.tensor([[i,j,k,1]]).type(torch.cuda.FloatTensor).to(device)
                x_aff = torch.transpose(x_aff,0,1)
                Ax = torch.mm(affine_matrix,x_aff) 
                Ax = torch.transpose(Ax,0,1)
                Ax = Ax[0,0:3]
                
                Aux = ddf2[0,:,i,j,k]
                Aux_aff = torch.tensor([[Aux[0],Aux[1],Aux[2],1]]).type(torch.cuda.FloatTensor).to(device)
                Aux_aff = torch.transpose(Aux_aff,0,1)
                Aux = torch.mm(affine_matrix,Aux_aff)
                Aux = torch.transpose(Aux,0,1)
                Aux = Aux[0,0:3]                   
                
                input = Ax+Aux
                target = x+ux
                terme = torch.abs((input-target).squeeze())
                term = torch.sum(terme)
                som += term * term
                break
            break
        break
    som /= 128*128*128
    return som                

In [5]:
def compute_affine_loss_vector(u1, u2, A):
    ## => Affineloss = 1/128³*Sum(x€128³) |Ax+Au'(x) - (x+u(x))|²
    
    # input = AX+AU2(X) ||| target = X+U1(x)
    
    device = getDevice()    
    #A = torch.linalg.inv(A)   
    u1 = u1.to(device)
    #print(u1.shape)
    u2 = u2.to(device)    
    
    a = torch.arange(128,dtype=torch.float64, device=device)
    one = torch.ones([128, 128, 128], dtype=torch.float64, device=device)
    img = torch.stack([torch.meshgrid(a,a,a)[0], torch.meshgrid(a,a,a)[1], torch.meshgrid(a,a,a)[2]])
    
    img_aff = torch.stack([torch.meshgrid(a,a,a)[0], torch.meshgrid(a,a,a)[1], torch.meshgrid(a,a,a)[2], one]) 
    #print(img.shape)
    #print(img[:,78,92,101])
    
    A_stack = A.type(torch.float64).repeat(128*128*128,1,1).to(device)

    
    X = img.view((3, 128*128*128)).permute(1,0).unsqueeze(2)    
    #print(X.shape)
    U1X = torch.as_tensor(u1.as_tensor().view((3, 128*128*128, 1)).permute(1,0,2), dtype=torch.float64, device=device)
    #print(U1X.shape)
    
    X_aff = img_aff.view((4, 128*128*128)).permute(1,0).unsqueeze(2)    
    U2X = torch.as_tensor(u2.as_tensor().view((3, 128*128*128, 1)).permute(1,0,2), dtype=torch.float64, device=device)   
    one_aff = torch.ones([128*128*128,1], dtype=torch.float64, device=device)
    U2X_aff = torch.stack([U2X[:,0], U2X[:,1], U2X[:,2], one_aff], dim=1)
    AXplusAU2X = torch.bmm(A_stack, X_aff+U2X_aff)[:,0:3]#.squeeze().permute(1,0).view((4,128,128,128))
    #print(AXplusAU2X.shape)
    
    input = AXplusAU2X
    target = X+U1X
    loss = torch.nn.MSELoss()
    affine_loss = loss(input, target)
    
    return affine_loss

In [6]:
def compute_affine_loss_brouillon(u1, u2, A):
    ## => Affineloss = 1/128³*Sum(x€128³) |Ax+Au'(x) - (x+u(x))|²
    
    # input = AX+AU2(X) ||| target = X+U1(x)
    
    device = getDevice()    
    A = torch.linalg.inv(A).type(torch.FloatTensor)
    print(A)
    u1 = u1.to(device).squeeze().reshape(3,128*128*128).type(torch.FloatTensor)
    u2 = u2.to(device).squeeze().reshape(3,128*128*128).type(torch.FloatTensor)
    
    a = torch.arange(128,dtype=torch.float64, device=device)    
    x = torch.stack(torch.meshgrid(a,a,a)).view(3,128*128*128)
    
    h1 = x + u1
    
    h2 = x + u2
    h2 = h2.view((3,128*128*128)).type(torch.FloatTensor)
    
    h2 = torch.mm(A[:3,:3],h2) + A[:3,3].unsqueeze(1)
    
    loss = torch.nn.MSELoss()
    affine_loss = loss(h1, h2)
    return affine_loss

In [7]:
A = torch.tensor([[1., 0., 0., 2.],
                    [0., 1., 0., 3.],
                    [0., 0., 1., 4.],
                    [0., 0., 0., 1.]])


In [8]:
T = A[:3,3].reshape(1,3,1,1,1)
print(T.shape)

torch.Size([1, 3, 1, 1, 1])


In [9]:
R = torch.tensor([[1., 0., 0., 0.],
                    [0., 0., -1., 0.],
                    [0., 1., 0., 0.],
                    [0., 0., 0., 1.]])

In [10]:
print(ddf.shape)

NameError: name 'ddf' is not defined

In [11]:
ddf2 = ddf.permute(0,1,2,4,3)

NameError: name 'ddf' is not defined

In [215]:
ddf[0,:,78,101,123].array

array([ 0.06156899,  0.71789205, -0.0716481 ], dtype=float32)

In [217]:
ddf2[0,:,78,123,101].array

array([ 0.06156899,  0.71789205, -0.0716481 ], dtype=float32)

In [1]:
compute_affine_loss_brouillon(ddf, ddf2, T)

NameError: name 'compute_affine_loss_brouillon' is not defined

In [179]:
compute_affine_loss_vector(ddf, ddf, B)

tensor(5402.0856, dtype=torch.float64, grad_fn=<MseLossBackward0>)

In [131]:
np.min(ddf)

-4.6532774

In [15]:
outfolders = [
        "paper-old-0.1",
        "paper-affine-0.1",   
        "paper-old-8.0",
        "paper-affine-8.0",   
]
models = [
    "paper/local_overfit_feminad_old_1.0-0.0-0.1.pth",
    "paper/local_overfit_feminad_old_affine_1.0-0.0-0.1.pth",
    "paper/local_overfit_feminad_old_1.0-0.0-8.0.pth",
    "paper/local_overfit_feminad_old_affine_1.0-0.0-8.0.pth",
]

outdataset = 'Feminad'
mris = sorted(glob(os.path.join('dataset2', outdataset, 'MRI_N4_Resample_Norm_Identity_Affine', "*.nii.gz")))
atlas_name = "dataset2/Atlas/Identity_Feminad_Template.nii.gz"
affine = nib.load(atlas_name).affine
header = nib.load(atlas_name).header 

sums = [0,0,0,0]
for mri in mris:
    nib_mri = nib.load(mri).get_fdata()
    nib_mri = np.expand_dims(nib_mri, axis=0)    
    randaffine_transform = RandAffine(
                            mode='bilinear',
                            prob=1.0,
                            rotate_range=(np.pi/45, np.pi/45, np.pi/45),
                            scale_range=(0.1, 0.1, 0.1),
                            translate_range=(2, 2, 2),
                        )
    randaffine_moving_image = randaffine_transform(nib_mri).unsqueeze(0)
    randaffine_matrix = randaffine_transform.rand_affine_grid.get_transformation_matrix()     
    nib.save(nib.Nifti1Image(randaffine_moving_image.squeeze(), affine, header), "tmp3.nii.gz")
    
    for i, model in enumerate(models):
        if "newmodel" in model:
            newmodel = True
        else:
            newmodel = False
        pred_image, ddfs = reg_mri.main(model, mri, "tmp1", False, False, "local", newmodel=newmodel)
        u1 = ddfs[1]        
        pred_image, ddfs = reg_mri.main(model, "tmp3.nii.gz", "tmp2", False, False, "local", newmodel=newmodel)
        u2 = ddfs[1]
        affine_loss = compute_affine_loss_vector(u1,u2,randaffine_matrix)
        sums[i] += affine_loss
for i in range(len(sums)):
    sums[i] /= len(mris)
    outmessage = "{}: aff: {:.4f}".format(outfolders[i], sums[i])
    print(outmessage)
        

paper-old-0.1: aff: 38.1056
paper-affine-0.1: aff: 33.3974
paper-old-8.0: aff: 37.8664
paper-affine-8.0: aff: 33.1848


In [12]:
modelname = "testptdr.pth"
dataset = "feminadaffine"
ft = None
ct = None
batchsize = 1 
max_epochs = 1000
lr = 0.001
patience = 20
weights = [1.0, 0.0, 8.0]
registration_type = "local"
atlas = True
mask = False
pt = None
newmodel = False 
validfeminad = True
freeze = 0
cycle_consistent_training = False 
use_jacobian_loss = False
affine_consistent_training = True
use_antifolding_loss = False
use_ddf = False    
    
torch.multiprocessing.set_sharing_strategy('file_system')
torch.backends.cudnn.benchmark = True
modelname = add_weights_to_name(modelname, weights)
print_model_output(modelname)
set_determinism(seed=0)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
device = getDevice()

if newmodel:
    channels = 32
    extract = [0, 1, 2, 3, 4]
else:
    channels = 16
    extract = [0, 1, 2, 3]
model = getRegistrationModel(registration_type, img_size=128, pretrain_model=pt,
                                 channels=channels, extract=extract, use_ddf=use_ddf)

optimizer = getAdamOptimizer(model, lr)
learningrate = lr
scheduler = getReducePlateauScheduler(optimizer, factor=0.5, patience=20)
weights = loadExistingModel(model, optimizer, ft, ct, weights=weights, registration=True)
print_weights(weights)

dataloaders, size = getRegistrationDataset(dataset=dataset,
                                           batch=batchsize,
                                           training=True,
                                           augment=True,
                                           eval_augment=False,
                                           atlas=atlas,
                                           mask=mask,
                                           validfeminad=validfeminad,
                                           )

best_loss = np.inf
best_epoch = -1

sizelol = len(dataloaders["train"])
print(sizelol)

#for epoch in range(-1, max_epochs):
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    train_loss, train_metric, train_dice, train_lbl_loss, train_img_loss, train_ddf_loss = 0, 0, 0, 0, 0, 0
    valid_loss, valid_metric, valid_dice, valid_lbl_loss, valid_img_loss, valid_ddf_loss = 0, 0, 0, 0, 0, 0
    train_aff_loss, valid_aff_loss = 0, 0

    for phase in ['train', 'valid']:
        if epoch == -1 and phase == 'train':
            continue
        if phase == 'train':
            model.train()
        elif phase == 'valid':
            model.eval()

        running_loss = 0.0

        running_aff_loss = 0.0

        running_metric = 0.0
        running_dice = 0.0
        
        running_img_loss = 0.0
        running_lbl_loss = 0.0
        running_ddf_loss = 0.0

        for i, data in enumerate(dataloaders[phase]):
            if i >= sizelol:
                break

            print(i, end='\r')
            optimizer.zero_grad(set_to_none=True)
            with torch.set_grad_enabled(phase == 'train'):
                ddf, pred_image, pred_label, dvf = model(data)

                pred_image = pred_image.to(device, non_blocking=True)
                pred_label = pred_label.to(device, non_blocking=True)
                pred_mask = AsDiscrete(threshold=0.5)(pred_label)
                pred_image_masked = MaskIntensity(mask_data=pred_mask)(pred_image)

                fixed_image = data['fixed_image'].to(device, non_blocking=True)
                fixed_label = data['fixed_label'].to(device, non_blocking=True)
                fixed_mask = AsDiscrete(threshold=0.5)(fixed_label)
                fixed_image_masked = MaskIntensity(mask_data=fixed_mask)(fixed_image)

                img_loss, lbl_loss, ddf_loss = get_deformable_registration_loss_from_weights(pred_image_masked,
                                                                                                     pred_mask,
                                                                                                     fixed_image_masked,
                                                                                                     fixed_mask,
                                                                                                     dvf,
                                                                                                     weights)
                loss = img_loss + lbl_loss + ddf_loss
                
                if affine_consistent_training:
                    randaffine_transform = RandAffine(
                        mode='bilinear',
                        prob=1.0,
                        rotate_range=(np.pi/90, np.pi/90, np.pi/90),
                        scale_range=(0.05, 0.05, 0.05),
                        translate_range=(2, 2, 2),
                    )
                    randaffine_moving_image = randaffine_transform(data["moving_image"][0, :, :, :, :]).unsqueeze(0)
                    randaffine_matrix = randaffine_transform.rand_affine_grid.get_transformation_matrix().to(device)
                    randaffine_transform_nearest = Affine(
                        mode='nearest',
                        affine=randaffine_matrix
                    )
                    randaffine_moving_label, _ = randaffine_transform_nearest(data["moving_label"][0, :, :, :, :])
                    randaffine_moving_label = randaffine_moving_label.unsqueeze(0)
                    randaffine_data = {
                        "fixed_image": fixed_image,
                        "fixed_label": fixed_label,
                        "moving_image": randaffine_moving_image,
                        "moving_label": randaffine_moving_label,
                    }

                    if registration_type.lower() == 'local':
                        randaffine_ddf, randaffine_pred_image, randaffine_pred_label, randaffine_dvf = model(randaffine_data)
                        randaffine_pred_image = randaffine_pred_image.to(device, non_blocking=True)
                        randaffine_pred_label = randaffine_pred_label.to(device, non_blocking=True)
                        randaffine_pred_mask = AsDiscrete(threshold=0.5)(randaffine_pred_label)
                        randaffine_pred_image_masked = MaskIntensity(mask_data=randaffine_pred_mask)(randaffine_pred_image)

                        affine_img_loss, affine_lbl_loss, affine_ddf_loss = get_deformable_registration_loss_from_weights(
                                randaffine_pred_image_masked,
                                randaffine_pred_mask,
                                fixed_image_masked,
                                fixed_mask,
                                randaffine_dvf,
                                weights)

                        loss2 = affine_img_loss + affine_lbl_loss + affine_ddf_loss

                        affine_loss = compute_affine_loss_vector(ddf, randaffine_ddf, randaffine_matrix)
                        loss = loss + loss2 + affine_loss


                dice_metric = compute_mean_dice(pred_mask, fixed_mask)
                metric = np.mean(compute_landmarks_distance_local(ddf, data)[1:])

                if phase == 'train':
                    learningrate = optimizer.param_groups[0]['lr']
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * fixed_image.size(0)
            running_aff_loss += affine_loss.item() * fixed_image.size(0)
            running_metric += metric.item() * fixed_image.size(0)
            running_dice += dice_metric.item() * fixed_image.size(0)
            running_img_loss += img_loss.item() * fixed_image.size(0)
            running_lbl_loss += lbl_loss.item() * fixed_image.size(0)
            running_ddf_loss += ddf_loss.item() * fixed_image.size(0)


        running_loss /= sizelol
        running_aff_loss /= sizelol
        running_metric /= sizelol
        running_dice /= sizelol
        running_img_loss /= sizelol
        running_lbl_loss /= sizelol
        running_ddf_loss /= sizelol

        if phase == 'train':
            train_loss, train_metric, train_img_loss, train_lbl_loss, train_ddf_loss = (
                running_loss, running_metric, running_img_loss, running_lbl_loss, running_ddf_loss)
            train_dice = running_dice
            train_aff_loss = running_aff_loss
        elif phase == 'valid':
            valid_loss, valid_metric, valid_img_loss, valid_lbl_loss, valid_ddf_loss = (
                running_loss, running_metric, running_img_loss, running_lbl_loss, running_ddf_loss)
            valid_dice = running_dice
            valid_aff_loss = running_aff_loss

        outmessage = "{}: loss: {:.4f} - metric: {:.4f} -- img: {:.4f}, lbl: {:.4f}, ddf: {:.4f}".format(
                phase, running_loss, running_metric, running_img_loss, running_lbl_loss, running_ddf_loss)
        outmessage += " -- aff: {:.4f}".format(running_aff_loss)
        outmessage += " -- dice: {:.4f}".format(running_dice)
        print(outmessage)

        if (phase == 'valid' and not validfeminad) or (phase == 'train' and validfeminad):
            scheduler.step(running_loss)
            if running_loss < best_loss:
                best_loss = running_loss
                best_epoch = epoch + 1
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'weights': weights,
                    'epoch': epoch,
                    'lr': lr,
                    },
                    './models/' + modelname
                )
                print(
                    "best loss {:.4f} at epoch {}".format(
                        best_loss, best_epoch
                    )
                )

=> Saving to testptdr_1.0-0.0-8.0.pth


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.