In [2]:
import argparse
import logging
import sys
import numpy as np
import torch
from utils import compute_mean_dice
import pandas as pd
import os
from glob import glob
from matplotlib import pyplot as plt
import monai
import torchinfo
import nibabel as nib
from monai.transforms import AsDiscrete
from miseval import evaluate
from scipy.io import savemat
import itk
import SimpleITK as sitk
import random
import subprocess
import sys

In [8]:
# Input locations, all relative to the notebook's working directory.
# Atlas volumes: the intensity template, its brain mask and its label map.
atlas_mri_file, atlas_mask_file, atlas_labels_file = (
    os.path.join("dataset3", "Atlas", fname)
    for fname in (
        "P56_Atlas_128_norm_id.nii.gz",
        "P56_Annotation_128_norm_id_mask.nii.gz",
        "atlas_gin_map6.nii.gz",
    )
)

# Directories holding the pre-existing synthetic ("Fakedata") MRIs and masks.
fakedata_mris, fakedata_masks = (
    os.path.join("dataset3", "Fakedata", subdir)
    for subdir in (
        "MRI_N4_Resample_Norm_Identity_Affine",
        "Mask_Resample_Identity_Affine",
    )
)

In [9]:
from monai.networks.utils import meshgrid_ij
def get_affine_warp(affine, image_size=(128, 128, 128)):
    """Turn 3x4 affine matrices into a dense displacement field on a 3-D voxel grid.

    The affine is applied to the voxel-index grid of ``image_size`` (coordinates 0..n-1 on
    each axis, origin at voxel (0, 0, 0), no centering). The identity grid is then subtracted,
    so the result holds, for every voxel, the displacement in voxel units.

    Args:
        affine: tensor or array-like whose elements reshape to ``(-1, 3, 4)``. Typically this is
            the first 12 values (top three rows) of a 4x4 homogeneous matrix.
        image_size: 3-D spatial size of the field; defaults to ``(128, 128, 128)``.

    Returns:
        float32 tensor of shape ``(B, 3, *image_size)``, where ``B = affine.numel() // 12``.
        It is placed on the same device as ``affine``.

    Raises:
        ValueError: if ``image_size`` is not three-dimensional.
    """
    if len(image_size) != 3:
        raise ValueError(f"image_size must have 3 entries, got {tuple(image_size)}")
    theta = torch.as_tensor(affine)
    # Build the index grid once; it serves both as the points to transform and as the
    # identity subtracted at the end.
    mesh_points = [torch.arange(0, dim, device=theta.device) for dim in image_size]
    grid = torch.stack(meshgrid_ij(*mesh_points), dim=0).to(dtype=torch.float)
    grid_padded = torch.cat([grid, torch.ones_like(grid[:1])])  # homogeneous coordinates
    # einsum does not mix dtypes. Cast so that float64 matrices work against the float32 grid
    # (e.g. the initial identity of RandAffineGrid).
    theta = theta.reshape(-1, 3, 4).to(dtype=grid.dtype)
    warped = torch.einsum("qijk,bpq->bpijk", grid_padded, theta)
    return warped - grid

def affine_transform(theta, image_size=(128, 128, 128)):
    """Apply 3x4 affine matrices to the voxel-index grid of a 3-D volume.

    Args:
        theta: tensor or array-like whose elements reshape to ``(-1, 3, 4)``. Each 3x4 matrix
            maps homogeneous voxel coordinates ``(i, j, k, 1)`` to new ``(x, y, z)`` coordinates.
        image_size: 3-D spatial size of the grid; defaults to ``(128, 128, 128)``.

    Returns:
        float32 tensor of shape ``(B, 3, *image_size)`` with the transformed coordinate of
        every voxel. It is placed on the same device as ``theta``.

    Raises:
        ValueError: if ``image_size`` is not three-dimensional.
    """
    if len(image_size) != 3:
        raise ValueError(f"image_size must have 3 entries, got {tuple(image_size)}")
    theta = torch.as_tensor(theta)
    # Build the grid on theta's device so GPU matrices do not hit a device mismatch in einsum.
    mesh_points = [torch.arange(0, dim, device=theta.device) for dim in image_size]
    grid = torch.stack(meshgrid_ij(*mesh_points), dim=0).to(dtype=torch.float)
    grid_padded = torch.cat([grid, torch.ones_like(grid[:1])])  # append the homogeneous 1
    # einsum requires matching dtypes, so float64 matrices are cast to the float32 grid dtype.
    theta = theta.reshape(-1, 3, 4).to(dtype=grid.dtype)
    return torch.einsum("qijk,bpq->bpijk", grid_padded, theta)

In [44]:
import warnings
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from numpy.lib.stride_tricks import as_strided

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.networks.utils import meshgrid_ij, normalize_transform
from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
from monai.transforms.utils import (
    convert_pad_mode,
    create_control_grid,
    create_grid,
    create_rotate,
    create_scale,
    create_shear,
    create_translate,
    map_spatial_axes,
    scale_affine,
)
from monai.transforms.utils_pytorch_numpy_unification import allclose, linalg_inv, moveaxis
from monai.utils import (
    GridSampleMode,
    GridSamplePadMode,
    InterpolateMode,
    NumpyPadMode,
    convert_to_dst_type,
    convert_to_tensor,
    ensure_tuple,
    ensure_tuple_rep,
    ensure_tuple_size,
    fall_back_tuple,
    issequenceiterable,
    optional_import,
    pytorch_after,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string
from monai.transforms import AffineGrid

# Type of the *_range arguments of RandAffineGrid. The value is passed through ensure_tuple and
# then consumed one entry at a time by RandAffineGrid._get_rand_param:
#   - a scalar entry f yields one parameter sampled from U(-f, f);
#   - a (min, max) entry yields one parameter sampled from U(min, max);
#   - None entries are skipped.
RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]]


class RandAffineGrid(Randomizable, Transform):
    """Randomized ``AffineGrid`` that keeps the sampled affine matrix.

    On each call (with ``randomize=True``), rotation, shear, translation and scale parameters
    are sampled from the configured ranges. An ``AffineGrid`` is built from them, and its
    affine matrix is stored in ``self.affine`` (returned by ``get_transformation_matrix``).

    Sampling per range entry (see ``_get_rand_param``):
      * scalar ``f``: ``U(-f, f)``, quantized to ``decimals`` decimal places and kept inside
        ``[-|f|, |f|]``, then offset by 1.0 for scale;
      * ``(min, max)``: ``U(min, max)`` (offset by 1.0 for scale), not quantized.

    Args:
        rotate_range, shear_range, translate_range, scale_range: ranges as described by
            ``RandRange``; one parameter is produced per entry.
        as_tensor_output: deprecated, unused.
        device: device forwarded to ``AffineGrid``.
        decimals: number of decimal places used to quantize parameters drawn from scalar
            ranges. ``None`` disables quantization. Defaults to 2.
    """

    backend = AffineGrid.backend

    @deprecated_arg(name="as_tensor_output", since="0.6")
    def __init__(
        self,
        rotate_range: RandRange = None,
        shear_range: RandRange = None,
        translate_range: RandRange = None,
        scale_range: RandRange = None,
        as_tensor_output: bool = True,
        device: Optional[torch.device] = None,
        decimals: Optional[int] = 2,
    ) -> None:
        self.rotate_range = ensure_tuple(rotate_range)
        self.shear_range = ensure_tuple(shear_range)
        self.translate_range = ensure_tuple(translate_range)
        self.scale_range = ensure_tuple(scale_range)

        # Most recently sampled parameters; filled in by ``randomize``.
        self.rotate_params: Optional[List[float]] = None
        self.shear_params: Optional[List[float]] = None
        self.translate_params: Optional[List[float]] = None
        self.scale_params: Optional[List[float]] = None

        self.device = device
        self.decimals = decimals
        # Identity until the first call; afterwards the affine returned by AffineGrid.
        self.affine: Optional[torch.Tensor] = torch.eye(4, dtype=torch.float64)

    def _quantize(self, offset: float, bound: float) -> float:
        """Snap ``offset`` to the nearest multiple of ``10**-decimals`` within ``[-|bound|, |bound|]``.

        Rounding to nearest, rather than flooring, keeps the quantized values symmetric about
        zero. Clipping to the largest grid point not exceeding ``|bound|`` guarantees the result
        never leaves the requested range; if ``|bound|`` is below one step, the result is 0.
        """
        if self.decimals is None:
            return float(offset)
        step = 10.0 ** (-self.decimals)
        limit = np.floor(abs(bound) / step) * step
        return float(np.clip(np.round(offset / step) * step, -limit, limit))

    def _get_rand_param(self, param_range, add_scalar: float = 0.0):
        """Sample one parameter per entry of ``param_range``.

        Args:
            param_range: tuple of entries, each a scalar, a ``(min, max)`` pair, or ``None``
                (skipped).
            add_scalar: constant added to every sample (1.0 for scale factors).

        Returns:
            list of sampled floats.

        Raises:
            ValueError: if a sequence entry does not have exactly two elements.
        """
        out_param = []
        for f in param_range:
            if issequenceiterable(f):
                if len(f) != 2:
                    raise ValueError("If giving range as [min,max], should only have two elements per dim.")
                out_param.append(self.R.uniform(f[0], f[1]) + add_scalar)
            elif f is not None:
                # Quantize the zero-centred offset *before* adding add_scalar. Scale factors then
                # stay symmetric about 1.0 and inside [1 - |f|, 1 + |f|].
                offset = self.R.uniform(-f, f)
                out_param.append(self._quantize(offset, f) + add_scalar)
        return out_param

    def randomize(self, data: Optional[Any] = None) -> None:
        """Draw new rotate/shear/translate/scale parameters (``data`` is unused)."""
        self.rotate_params = self._get_rand_param(self.rotate_range)
        self.shear_params = self._get_rand_param(self.shear_range)
        self.translate_params = self._get_rand_param(self.translate_range)
        self.scale_params = self._get_rand_param(self.scale_range, 1.0)

    def __call__(
        self,
        spatial_size: Optional[Sequence[int]] = None,
        grid: Optional[NdarrayOrTensor] = None,
        randomize: bool = True,
    ) -> torch.Tensor:
        """Build an ``AffineGrid`` from the current (optionally re-sampled) parameters.

        Args:
            spatial_size: forwarded to ``AffineGrid``.
            grid: forwarded to ``AffineGrid``.
            randomize: if True, sample new parameters first.

        Returns:
            the grid produced by ``AffineGrid``. Its affine matrix is stored in ``self.affine``.
        """
        if randomize:
            self.randomize()
        affine_grid = AffineGrid(
            rotate_params=self.rotate_params,
            shear_params=self.shear_params,
            translate_params=self.translate_params,
            scale_params=self.scale_params,
            device=self.device,
        )
        _grid: torch.Tensor
        _grid, self.affine = affine_grid(spatial_size, grid)
        return _grid

    def get_transformation_matrix(self) -> Optional[torch.Tensor]:
        """Return the affine from the most recent call (a float64 identity before any call)."""
        return self.affine

In [61]:
# Generate N_SAMPLES synthetic volumes by warping the atlas MRI, mask and labels with random
# affine transforms. For each sample, the warped volumes and the displacement field of the
# inverse affine are saved under FAKEDATA_DIR.
N_SAMPLES = 33
SEED = 0  # fixed so the synthetic set can be regenerated identically
IMAGE_SIZE = (128, 128, 128)
FAKEDATA_DIR = os.path.join("dataset3", "Fakedata")


def _load_volume(img):
    """Return a NIfTI image's voxel data as a float32 tensor shaped (1, 1, *IMAGE_SIZE)."""
    return torch.from_numpy(img.get_fdata().reshape((1, 1) + IMAGE_SIZE)).to(dtype=torch.float)


# Load the atlas MRI once and reuse its geometry (affine/header) for every output file.
atlas_mri_img = nib.load(atlas_mri_file)
atlas_mri = _load_volume(atlas_mri_img)
atlas_mask = _load_volume(nib.load(atlas_mask_file))
atlas_label = _load_volume(nib.load(atlas_labels_file))
affine = atlas_mri_img.affine
header = atlas_mri_img.header

# NOTE(review): RandAffineGrid draws one parameter per entry of each range tuple.
#   - rotate_range=(-pi/90, pi/90) yields TWO rotation angles, each in U(-pi/90, pi/90); the
#     minus sign has no effect.
#   - translate_range=(-1, 1) likewise yields two offsets.
#   - scale_range=(0.05) is a plain float, not a 1-tuple, so it yields a single scale factor.
# Presumably AffineGrid then leaves the remaining axes unrotated/untranslated/unscaled.
# If symmetric ranges on all three axes were intended, use e.g. rotate_range=(np.pi/90,) * 3.
# Values are kept unchanged here — confirm the intent.
randaffine_grid = RandAffineGrid(rotate_range=(-np.pi/90, np.pi/90),
                                 translate_range=(-1, 1),
                                 scale_range=(0.05))
randaffine_grid.set_random_state(seed=SEED)
warp = monai.networks.blocks.Warp("bilinear", "zeros")
# Nearest-neighbour resampling keeps mask and label values discrete.
warp_nearest = monai.networks.blocks.Warp("nearest", "zeros")

output_dirs = {name: os.path.join(FAKEDATA_DIR, name) for name in ("Warp", "MRI", "Mask", "Labels")}
for out_dir in output_dirs.values():
    os.makedirs(out_dir, exist_ok=True)  # nib.save does not create missing directories


def _save_volume(volume, path):
    """Save a tensor as NIfTI using the atlas affine and header."""
    nib.save(nib.Nifti1Image(volume.detach().cpu().numpy(), affine, header), path)


for i in range(N_SAMPLES):
    print(i + 1, end='\r')
    randaffine_grid(spatial_size=IMAGE_SIZE)
    A = randaffine_grid.affine
    # The forward field resamples the atlas. The field of the inverse affine is the one
    # saved as "<i>_warp".
    A_warp = get_affine_warp(A.reshape(16)[:12])
    A_inv = torch.linalg.inv(A)
    A_invwarp = get_affine_warp(A_inv.reshape(16)[:12])
    _save_volume(A_invwarp, os.path.join(output_dirs["Warp"], f"{i}_warp.nii.gz"))

    mri = warp(atlas_mri, A_warp)
    mask = warp_nearest(atlas_mask, A_warp)
    label = warp_nearest(atlas_label, A_warp)

    _save_volume(mri.squeeze(), os.path.join(output_dirs["MRI"], f"{i}_mri_affine.nii.gz"))
    _save_volume(mask.squeeze(), os.path.join(output_dirs["Mask"], f"{i}_mask_affine.nii.gz"))
    _save_volume(label.squeeze(), os.path.join(output_dirs["Labels"], f"{i}_label_affine.nii.gz"))

33