In [1]:
import reg_mri
import os
from glob import glob
from utils import compute_mean_dice
import nibabel as nib
from scipy.spatial.distance import dice
import numpy as np
import itk
import SimpleITK as sitk
import scipy.ndimage
import scipy
import matplotlib.pyplot as plt
from transforms_dict import getRegistrationEvalInverseTransformForMRI, SaveTransformForMRI
from tqdm import tqdm
import monai
import subprocess
from monai.transforms import AsDiscrete, MaskIntensity, RandAffine, Affine
import torch

In [2]:
import argparse
import logging
import sys

import numpy as np
import torch
from monai.transforms import AsDiscrete, MaskIntensity, RandAffine, Affine
from monai.utils import set_determinism
from monai.losses import LocalNormalizedCrossCorrelationLoss
from torch.utils.tensorboard import SummaryWriter
import random

import utils_parser
from reg_data import getRegistrationDataset
from reg_model import getRegistrationModel
from utils import compute_mean_dice, getAdamOptimizer, getReducePlateauScheduler, loadExistingModel, getDevice
from utils import print_model_output, print_weights, add_weights_to_name, compute_landmarks_distance_local
from loss import compute_affine_loss, get_jacobian, antifolding_loss, JacobianDet
from loss import get_deformable_registration_loss_from_weights, get_affine_registration_loss_from_weights, jacobian_loss
from models import TrilinearLocalNet
from torchinfo import summary
from miseval import evaluate
import time

In [5]:
from monai.networks.blocks import Warp
from monai.networks.utils import meshgrid_ij

def compute_ddfcompare_loss(u1, u2, image_size=None):
    """Measure how far the composition of two displacement fields is from identity.

    The deformed coordinates ``X + u1`` are treated as a 3-channel image,
    resampled through ``u2`` with MONAI's ``Warp`` and compared (mean squared
    error) with the identity coordinate grid ``X``.

    Args:
        u1: displacement field indexed as ``u1[0, c, i, j, k]``; only batch
            item 0 is used.
        u2: displacement field passed to ``Warp`` as its ``ddf`` argument. It
            must have batch size 1 because the warped image has batch size 1.
        image_size: spatial size of the identity grid. When ``None`` (default)
            it is taken from ``u1.shape[2:]``.

    Returns:
        Scalar tensor holding the MSE between the warped coordinates and ``X``.

    NOTE(review): ``Warp`` is built with ``"reflection"``, so samples that fall
    outside the volume are presumably filled from reflected values; the loss is
    then not expected to be exactly zero even for exact inverse fields - confirm.
    """
    if image_size is None:
        image_size = tuple(u1.shape[2:])

    warp_stn = Warp("bilinear", "reflection")

    # Identity coordinate grid, shape (3, *image_size), created on u1's device
    # so that CUDA inputs do not trigger a device-mismatch error.
    mesh_points = [torch.arange(0, dim, device=u1.device) for dim in image_size]
    X = torch.stack(meshgrid_ij(*mesh_points), dim=0).to(dtype=torch.float)

    # Coordinates after applying the first displacement field.
    u1X = X + u1[0]

    # Resample all three coordinate channels through u2 in a single call; the
    # same sampling grid is used for every channel, so this matches warping
    # each channel separately. Indexing with [0] (rather than squeeze()) keeps
    # spatial axes of length 1 intact.
    u2u1X = warp_stn(u1X.unsqueeze(0), u2)[0]

    loss = torch.nn.MSELoss()
    return loss(u2u1X, X)

In [6]:
from monai.networks.utils import meshgrid_ij
from monai.transforms import AffineGrid

def get_affine_warp(affine):
    """Convert an affine into a displacement field on a 128x128x128 grid.

    The transformed coordinates returned by ``affine_transform(affine)`` have
    the identity index grid subtracted from them, leaving the per-voxel
    displacement.

    Args:
        affine: affine parameters forwarded unchanged to ``affine_transform``.

    Returns:
        ``affine_transform(affine)`` minus the float identity grid of shape
        ``(3, 128, 128, 128)`` (broadcast over the leading batch axis).
    """
    spatial_shape = (128, 128, 128)
    axes = tuple(torch.arange(0, extent) for extent in spatial_shape)
    identity = torch.stack(meshgrid_ij(*axes), dim=0)
    identity = identity.to(dtype=torch.float)
    return affine_transform(affine) - identity

def affine_transform(theta):
    """Apply 3x4 affine matrices to the index grid of a 128x128x128 volume.

    Args:
        theta: tensor that can be reshaped to ``(-1, 3, 4)``; each 3x4 slice is
            one affine matrix.

    Returns:
        Tensor of shape ``(b, 3, 128, 128, 128)`` with the transformed
        coordinates, where ``b`` is the number of 3x4 matrices in ``theta``.
    """
    spatial_shape = (128, 128, 128)
    axes = [torch.arange(0, extent) for extent in spatial_shape]
    coords = torch.stack(meshgrid_ij(*axes), dim=0).to(dtype=torch.float)

    # Append a channel of ones so the translation column of each matrix applies.
    ones_channel = torch.ones_like(coords[:1])
    homogeneous = torch.cat([coords, ones_channel])

    matrices = theta.reshape(-1, 3, 4)
    return torch.einsum("qijk,bpq->bpijk", homogeneous, matrices)

# Sanity check: an affine composed with its inverse should be close to identity.
# Build the affine with MONAI's AffineGrid from rotation (-pi/90, pi/90, pi/90),
# translation (1.21, 1.37, -1.24) and scale (1.027, 1.012, 0.984) parameters.
affine_grid = AffineGrid(rotate_params=(-np.pi/90, np.pi/90, np.pi/90), 
                    translate_params=(1.21,1.37,-1.24), 
                    scale_params=(1.027,1.012,0.984), 
                    device=None, 
                    dtype=np.float32,
                    affine=None)
# The call returns a pair; only the second element (the affine) is kept.
_, A = affine_grid(spatial_size=(128,128,128))
# Flatten to 16 values and keep the first 12 (the inversion below re-appends
# [0, 0, 0, 1] and reshapes back to 4x4).
A = A.reshape(16)[:12]
# Inverse affine in the same flat 12-value layout.
A_inv = torch.linalg.inv(torch.cat((A, torch.Tensor([0,0,0,1])), 0).reshape(4,4)).reshape(16)[:12]

# Displacement fields of the affine and of its inverse.
A_warp = get_affine_warp(A)
A_inv_warp = get_affine_warp(A_inv)

# Last expression of the cell: the composition loss is displayed as the output.
compute_ddfcompare_loss(A_warp, A_inv_warp)

tensor(3.2102)

In [None]:
# Identity parameters: zero rotation, zero translation, unit scale.
# NOTE: this rebinds the notebook-level name `affine_grid`.
affine_grid = AffineGrid(rotate_params=(0, 0, 0), 
                    translate_params=(0,0,0), 
                    scale_params=(1,1,1), 
                    device=None, 
                    dtype=np.float32,
                    affine=None)
# The call returns a pair; only the second element is kept as `affine_matrix`.
_, affine_matrix = affine_grid(spatial_size=(128,128,128))

In [None]:
from typing import List, Optional, Tuple, Union
from monai.networks.utils import meshgrid_ij
from monai.transforms import Affine

def affine_transform(theta: torch.Tensor, image_size=(128, 128, 128)):
    """Apply 3x4 affine matrices to the index grid of a volume.

    The identity grid is built inside the function. The previous version read
    a notebook-level ``grid`` variable that no visible cell defines, so it
    only worked with leftover kernel state.

    Args:
        theta: tensor that can be reshaped to ``(-1, 3, 4)``; each 3x4 slice is
            one affine matrix.
        image_size: spatial size of the index grid. Defaults to
            ``(128, 128, 128)``, the size used by ``get_affine_warp``.

    Returns:
        Tensor of shape ``(b, 3, *image_size)`` with the transformed
        coordinates, where ``b`` is the number of 3x4 matrices in ``theta``.
    """
    mesh_points = [torch.arange(0, dim) for dim in image_size]
    grid = torch.stack(meshgrid_ij(*mesh_points), dim=0).to(dtype=torch.float)
    # Extra channel of ones so the translation column of each matrix applies.
    grid_padded = torch.cat([grid, torch.ones_like(grid[:1])])
    grid_warped = torch.einsum("qijk,bpq->bpijk", grid_padded, theta.reshape(-1, 3, 4))
    return grid_warped


def get_affine(theta=0, tx=0, ty=0, tz=0, sx=1, sy=1, sz=1):
    """Build a flat 12-value affine (three rows of four) as a float tensor.

    Layout, row by row::

        [sx*cos(theta), -sin(theta),    0,  tx]
        [sin(theta),     sy*cos(theta), 0,  ty]
        [0,              0,             sz, tz]

    The scale factors multiply only the diagonal entries.

    Args:
        theta: rotation angle in radians (passed to ``np.cos`` / ``np.sin``).
        tx, ty, tz: entries of the last column.
        sx, sy, sz: factors applied to the diagonal entries.

    Returns:
        ``torch.float`` tensor of shape ``(12,)``.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rows = (
        (sx * cos_t, -sin_t, 0, tx),
        (sin_t, sy * cos_t, 0, ty),
        (0, 0, sz * 1, tz),
    )
    flat = [value for row in rows for value in row]
    return torch.tensor(flat, dtype=torch.float)

def get_affine_warp(affine):
    """Convert an affine into a displacement field on a 128x128x128 grid.

    The identity grid is built inline with ``meshgrid_ij`` because
    ``get_reference_grid`` is only defined in a later cell; calling it here
    raised ``NameError`` when the notebook was run top to bottom.

    The values at voxel ``(72, 85, 51)`` of the identity grid, the transformed
    grid and the displacement are printed as a debugging probe.

    Args:
        affine: affine parameters forwarded unchanged to ``affine_transform``.

    Returns:
        ``affine_transform(affine)`` minus the float identity grid of shape
        ``(3, 128, 128, 128)`` (broadcast over the leading batch axis).
    """
    image_size = (128, 128, 128)
    mesh_points = [torch.arange(0, dim) for dim in image_size]
    grid = torch.stack(meshgrid_ij(*mesh_points), dim=0).to(dtype=torch.float)
    print(grid[:,72,85,51])
    affine_grid = affine_transform(affine)
    print(affine_grid[0,:,72,85,51])
    # Displacement = transformed coordinates - original coordinates.
    affine_warp = affine_grid - grid
    print(affine_warp[0,:,72,85,51])
    print('-'*10)
    return affine_warp

def get_rotation_matrix_center(theta, tx, ty):
    """Affine for a rotation by ``theta`` in the first two axes about ``(tx, ty)``.

    The matrix is composed as ``translate(+tx, +ty) @ rotate(theta) @
    translate(-tx, -ty)``, so the point ``(tx, ty)`` is left fixed.

    Two defects of the previous version are fixed here:
      * the rotation matrix referenced ``sy``, which is not a parameter of this
        function (a pure rotation has no scale term);
      * the three matrices were combined with element-wise ``*`` instead of the
        matrix product ``@``.

    Args:
        theta: rotation angle in radians.
        tx: first coordinate of the rotation centre.
        ty: second coordinate of the rotation centre.

    Returns:
        ``torch.float`` tensor of shape ``(12,)``: the first three rows of the
        4x4 homogeneous matrix, flattened row by row.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    # Move the rotation centre to the origin.
    translation_1 = torch.tensor([1, 0, 0, -tx,
                                  0, 1, 0, -ty,
                                  0, 0, 1, 0,
                                  0, 0, 0, 1], dtype=torch.float).reshape(4, 4)

    rotation = torch.tensor([cos_t, -sin_t, 0, 0,
                             sin_t, cos_t, 0, 0,
                             0, 0, 1, 0,
                             0, 0, 0, 1], dtype=torch.float).reshape(4, 4)

    # Move the rotation centre back to where it was.
    translation_2 = torch.tensor([1, 0, 0, tx,
                                  0, 1, 0, ty,
                                  0, 0, 1, 0,
                                  0, 0, 0, 1], dtype=torch.float).reshape(4, 4)

    # Matrix product, applied right to left: translate, rotate, translate back.
    out = translation_2 @ rotation @ translation_1
    out = out.reshape(16)[:12]
    return out

#ptdr = get_rotation_matrix_center(np.pi/2,64,64)
    

# Parameters of the test affine: rotation angle pi/2, last-column entries
# (64, 64, 0) and unit scale factors.
alpha = np.pi/2
tx = 64
ty = 64
tz = 0
sx = 1
sy = 1
sz = 1
# Flat 12-value affine built from the parameters above.
A = get_affine(alpha, tx, ty, tz, sx, sy, sz)
# Inverse: re-append [0, 0, 0, 1], invert the 4x4, keep the first 12 values.
A_inv = torch.linalg.inv(torch.cat((A, torch.Tensor([0,0,0,1])), 0).reshape(4,4)).reshape(16)[:12]
# Displacement fields of the affine and of its inverse.
warp = get_affine_warp(A)
warp_inv = get_affine_warp(A_inv)
# Last expression of the cell: the composition loss is displayed as the output.
compute_ddfcompare_loss(warp, warp_inv)


In [None]:
from typing import List, Optional, Tuple, Union
from monai.networks.utils import meshgrid_ij
from monai.transforms import AffineGrid

# Identity parameters: zero rotation, zero translation, unit scale.
affine_grid = AffineGrid(rotate_params=(0, 0, 0), 
                    translate_params=(0,0,0), 
                    scale_params=(1,1,1), 
                    device=None, 
                    dtype=np.float32,
                    affine=None)
# The call returns a pair; only the second element is kept as `affine_matrix`.
_, affine_matrix = affine_grid(spatial_size=(128,128,128))
# Invert before flattening, while `affine_matrix` is still a square matrix.
affine_matrix_inv = torch.linalg.inv(affine_matrix)
print(affine_matrix)
print(affine_matrix_inv)
# Rebind both names to the flat layout: first 12 of the 16 flattened values.
affine_matrix = affine_matrix.reshape(16)[:12]
affine_matrix_inv = affine_matrix_inv.reshape(16)[:12]

def affine_transform(theta: torch.Tensor, image_size=(128, 128, 128)):
    """Apply 3x4 affine matrices to the index grid of a volume.

    The identity grid is obtained from ``get_reference_grid`` inside the
    function. The previous version read a notebook-level ``grid`` variable
    that no visible cell defines, so it only worked with leftover kernel state.

    Args:
        theta: tensor that can be reshaped to ``(-1, 3, 4)``; each 3x4 slice is
            one affine matrix.
        image_size: spatial size of the index grid. Defaults to
            ``(128, 128, 128)``, the size used by ``get_affine_warp``.

    Returns:
        Tensor of shape ``(b, 3, *image_size)`` with the transformed
        coordinates, where ``b`` is the number of 3x4 matrices in ``theta``.
    """
    grid = get_reference_grid(image_size)
    # Extra channel of ones so the translation column of each matrix applies.
    grid_padded = torch.cat([grid, torch.ones_like(grid[:1])])
    grid_warped = torch.einsum("qijk,bpq->bpijk", grid_padded, theta.reshape(-1, 3, 4))
    return grid_warped


def get_affine(theta=0, tx=0, ty=0, tz=0, sx=1, sy=1, sz=1):
    """Return a flat 12-value affine (three rows of four) as a float tensor.

    Row 0 is ``[sx*cos(theta), -sin(theta), 0, tx]``, row 1 is
    ``[sin(theta), sy*cos(theta), 0, ty]`` and row 2 is ``[0, 0, sz, tz]``.
    The scale factors multiply only the diagonal entries.

    Args:
        theta: rotation angle in radians (passed to ``np.cos`` / ``np.sin``).
        tx, ty, tz: entries of the last column.
        sx, sy, sz: factors applied to the diagonal entries.

    Returns:
        ``torch.float`` tensor of shape ``(12,)``.
    """
    cosine, sine = np.cos(theta), np.sin(theta)
    first_row = [sx * cosine, -sine, 0, tx]
    second_row = [sine, sy * cosine, 0, ty]
    third_row = [0, 0, sz * 1, tz]
    return torch.tensor(first_row + second_row + third_row, dtype=torch.float)

def get_reference_grid(image_size: Union[Tuple[int], List[int]]) -> torch.Tensor:
    """Build the float index grid for a volume of the given size.

    Args:
        image_size: number of grid points along each spatial axis.

    Returns:
        Float tensor of shape ``(len(image_size), *image_size)``; channel ``c``
        holds the index along axis ``c`` at every grid point.
    """
    axes = [torch.arange(0, extent) for extent in image_size]
    coords = meshgrid_ij(*axes)
    stacked = torch.stack(coords, dim=0)  # (spatial_dims, ...)
    return stacked.to(dtype=torch.float)

def get_affine_warp(affine):
    """Convert an affine into a displacement field on a 128x128x128 grid.

    The values at voxel ``(72, 85, 51)`` of the identity grid, the transformed
    grid and the displacement are printed as a debugging probe, followed by a
    separator line.

    Args:
        affine: affine parameters forwarded unchanged to ``affine_transform``.

    Returns:
        ``affine_transform(affine)`` minus the identity grid from
        ``get_reference_grid((128, 128, 128))``.
    """
    probe = (72, 85, 51)
    identity = get_reference_grid((128, 128, 128))
    print(identity[(slice(None),) + probe])

    transformed = affine_transform(affine)
    print(transformed[(0, slice(None)) + probe])

    # Displacement = transformed coordinates - original coordinates.
    displacement = transformed - identity
    print(displacement[(0, slice(None)) + probe])
    print('-'*10)
    return displacement

# Shape of the flat affine before it is turned into a displacement field.
print(affine_matrix.shape)
affine_warp = get_affine_warp(affine_matrix)
print(affine_warp.shape)
# Displacement field of the inverse affine.
affine_warp_inv = get_affine_warp(affine_matrix_inv)

# Last expression of the cell: the composition loss is displayed as the output.
compute_ddfcompare_loss(affine_warp, affine_warp_inv)