# NITorch (**N**euro**I**maging Py**Torch**)

# Affine Registration Demo

Also available as a self-contained Colab notebook:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/13eSBtEvAp1wIJD0Rlvq5Q9kJWnuEc7WI?usp=sharing "NITorch Affine Registration Demo")


##### For offline version:

In [None]:
# GETTING STARTED
# This demo was run on:
# * Ubuntu    18.04.4 LTS / CentOS 7.7.1908
# * CUDA      10.1
# * anaconda  1.7.2
# * gcc       6.3.1
# * pytorch   1.6
#
# To get started, run the following commands in a terminal:
#
#   git clone git@github.com:balbasty/nitorch.git
#   cd nitorch
#   conda env create --file ./conda/nitorch-demo.yml
#   conda activate nitorch-demo
#   pip install .

### Installations

First clone the repo...

In [None]:
!git clone https://github.com/balbasty/nitorch

Then install the dependencies, followed by NITorch itself

In [None]:
! pip install numpy
! pip install nibabel
! pip install matplotlib
! pip install scipy
! pip install wget

In [None]:
! pip install ./nitorch/

The cells above can be pasted into your own Colab notebook for an easy install

### GETTING STARTED

First, we will import required packages

In [None]:
# Python
import os
import wget
import zipfile
from shutil import copyfile

# Torch
import torch

# NITorch
from nitorch.core.pyutils import file_mod
from nitorch.tools.affine_reg._align import _test_cost
from nitorch.plot import show_slices
from nitorch.tools.preproc import (affine_align, world_reslice)
from nitorch.tools._preproc_utils import _format_input
from nitorch.core.linalg import _expm
from nitorch.spatial import affine_basis
from nitorch.io import map

and define some helper functions

In [None]:
def realign(pths, prefix='ma_', odir='', t_std=10, r_std=0.25):
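    """Apply a random rigid transformation to a copy of each image.

    Translations are drawn uniformly from [-t_std, t_std] and rotations
    from [-r_std, r_std]. Only the affine in each copy's header is modified.
    """
    # Make random realignment
    pths = list(pths)  # work on a copy so the caller's list is not modified
    N = len(pths)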
    q = torch.DoubleTensor(N, 6)
    torch.manual_seed(0)
    q[:, :3] = torch.DoubleTensor(N, 3).\
        uniform_(-t_std, t_std)  # random translation
    q[:, 3:] = torch.DoubleTensor(N, 3).\
        uniform_(-r_std, r_std)  # random rotation
    # Apply random realignment
    B = affine_basis(group='SE', dim=3)
    for n in range(N):
        # Make copy
        ipth = pths[n]
        opth = file_mod(ipth, prefix=prefix, odir=odir)
        os.makedirs(os.path.dirname(opth), exist_ok=True)
        copyfile(ipth, opth)
        # Compose transformations
        dat = map(opth)
        M = dat.affine
        R = _expm(q[n, ...], basis=B)
        # Solve R X = M, i.e. M <- inv(R) @ M (Tensor.solve as in PyTorch 1.6)
        M = M.solve(R)[0]
        # Modify affine in header
        dat.set_metadata(affine=M)
        pths[n] = opth

    return pths

def show_in_world_space(pths):
    """Show images in world space
    """
    for n in range(len(pths)):
        dat = world_reslice(pths[n], write=False)[0]
        _ = show_slices(dat, fig_num=n)

def show_dat(dat):
    """Show images.
    """
    for n in range(dat.shape[0]):
        _ = show_slices(dat[n, ...], fig_num=n)

and get the PyTorch device

In [None]:
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_type)
if device_type == 'cuda':
    print('GPU: ' + torch.cuda.get_device_name(0) + ', CUDA: ' + str(torch.version.cuda))
else:
    print('CPU')

This demo uses three MRIs (T1w, T2w, PDw) from the BrainWeb simulator (https://brainweb.bic.mni.mcgill.ca/brainweb/).
The first time the notebook is run, these images are downloaded to the current working directory (by default, the folder containing this notebook)

In [None]:
# URL to MRIs
url = 'https://www.dropbox.com/s/8xh6tnf47hzkykx/brainweb.zip?dl=1'

# Path to downloaded zip-file
cwd = os.getcwd()
pth_zip = os.path.join(cwd, 'brainweb.zip')
pth_mris = [os.path.join(cwd, 't1_icbm_normal_1mm_pn0_rf0.nii'),
            os.path.join(cwd, 't2_icbm_normal_1mm_pn0_rf0.nii'),
            os.path.join(cwd, 'pd_icbm_normal_1mm_pn0_rf0.nii')]

# Download file
if not os.path.exists(pth_zip):
    print('Downloading images...', end='')
    wget.download(url, pth_zip)
    print('done!')

# Unzip file
if not all(os.path.exists(p) for p in pth_mris):
    with zipfile.ZipFile(pth_zip, 'r') as zip_ref:
        zip_ref.extractall(cwd)

Let's have a look at the images

In [None]:
show_in_world_space(pth_mris)

### INSPECT REGISTRATION COST FUNCTION

Now, let's inspect the cost function's behaviour as we keep one image fixed and move the second image.
Feel free to test the different options below

In [None]:
cost_fun = 'nmi'    # One of ['nmi', 'mi', 'ncc', 'ecc', 'njtv', 'jtv']
mean_space = False  # Only used with the groupwise 'njtv' cost function
samp = 2            # Level of sub-sampling
ix_par = 0          # Index of the affine parameter to vary (0, ..., 11)
x_mn_mx = 30        # Parameter is varied from -x_mn_mx to +x_mn_mx
x_step = 0.5        # Step-size of parameter

dat, mat, _ = _format_input(pth_mris[:2], device=device)
_test_cost(dat, mat,
    ix_par=ix_par, cost_fun=cost_fun, mean_space=mean_space, samp=samp,
    x_mn_mx=x_mn_mx, x_step=x_step)
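
To help interpret the resulting curve, here is a minimal NumPy sketch of how a histogram-based normalised mutual information can be computed and how it changes as one image is translated. It is an illustration under stated assumptions, not NITorch's implementation: the function `normalised_mutual_information`, the number of bins and the synthetic images are all hypothetical, and NITorch may use a different sign or normalisation.

```python
import numpy as np


def normalised_mutual_information(x, y, bins=32):
    """Estimate NMI = (H(X) + H(Y)) / H(X, Y) from a joint histogram.

    Hypothetical helper for illustration only (not part of NITorch).
    """
    joint, _, _ = np.histogram2d(x.ravel(), y.ravel(), bins=bins)
    pxy = joint / joint.sum()   # joint probability
    px = pxy.sum(axis=1)        # marginal of x
    py = pxy.sum(axis=0)        # marginal of y

    def entropy(p):
        p = p[p > 0]            # empty bins contribute nothing
        return -np.sum(p * np.log(p))

    return (entropy(px) + entropy(py)) / entropy(pxy)


# Two synthetic "modalities": same geometry, different contrast
yy, xx = np.mgrid[:64, :64]
fixed = np.exp(-((xx - 32) ** 2 + (yy - 32) ** 2) / (2 * 10.0 ** 2))
moving = 1.0 - fixed ** 2

# Evaluate the measure while translating the moving image along one axis
shifts = list(range(-10, 11, 2))
costs = [normalised_mutual_information(fixed, np.roll(moving, s, axis=1))
         for s in shifts]
best_shift = shifts[int(np.argmax(costs))]
print(best_shift, [round(c, 3) for c in costs])
```

With this definition, larger values indicate better alignment, so the measure should peak where the two images overlap.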

### CREATE RIGIDLY REALIGNED IMAGES

Next, we will apply a random rigid transformation to each of the input scans,
so that they are no longer aligned. This writes new images prefixed `ra_`

In [None]:
pth_mris_ra = realign(pth_mris, prefix='ra_')
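
The `realign` helper defined earlier maps six parameters to a rigid matrix through a matrix exponential and then updates the affine stored in the image header. The sketch below reproduces that idea in plain NumPy. It is a rough illustration, not NITorch code: `se3_basis`, `expm` and `rigid_matrix` are hypothetical helpers, and the basis ordering (three translations followed by three rotations) and sign conventions are assumptions that may differ from those of `affine_basis`.

```python
import numpy as np


def se3_basis():
    """Hypothetical basis of se(3): 3 translation + 3 rotation generators."""
    B = np.zeros((6, 4, 4))
    for i in range(3):
        B[i, i, 3] = 1.0                      # translation along axis i
    for k, (i, j) in enumerate([(0, 1), (0, 2), (1, 2)]):
        B[3 + k, i, j] = 1.0                  # skew-symmetric rotation block
        B[3 + k, j, i] = -1.0
    return B


def expm(A, order=20):
    """Matrix exponential: scaling-and-squaring with a truncated Taylor series."""
    norm = np.linalg.norm(A, ord=np.inf)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    A = A / (2 ** s)                          # scale so the series converges fast
    E = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for n in range(1, order + 1):
        term = term @ A / n
        E = E + term
    for _ in range(s):                        # undo the scaling by squaring
        E = E @ E
    return E


def rigid_matrix(q):
    """Map 6 parameters to a 4x4 rigid matrix, R = expm(sum_k q[k] * B[k])."""
    return expm(np.tensordot(q, se3_basis(), axes=1))


q = np.array([5.0, -3.0, 2.0, 0.1, -0.2, 0.05])  # made-up parameters
R = rigid_matrix(q)

# Update a voxel-to-world affine as in `realign`: solve R X = M, so X = inv(R) @ M
M = np.diag([1.0, 1.0, 1.0, 1.0])                # toy 1 mm isotropic affine
M_new = np.linalg.solve(R, M)
print(np.round(R, 3))
```

Because only the affine in the header changes, the voxel data are never resampled; the image simply moves in world space.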

Let's have a look at the realigned images

In [None]:
show_in_world_space(pth_mris_ra)

### PAIRWISE REGISTRATION

We will now align the images using pairwise registration: one of the images
is kept fixed and all other images are registered to this fixed target image

In [None]:
# Parameters
cost_fun = 'nmi'  # The normalised mutual information cost
fix = 0           # Set the first image to the fixed one (remember, there are three in total)
samp = (4, 2)     # Use the default sub-sampling scheme (speeds up the registration)

# Do registration
dat_aligned = affine_align(pth_mris_ra, device=device, cost_fun=cost_fun,
                           samp=samp)[0]

Let's look at the pairwise registration result

In [None]:
show_dat(dat_aligned)

### GROUPWISE REGISTRATION

Finally, we will align the images using groupwise registration, where the cost function is optimised over all images
at the same time. This is done using the Normalised Joint Total Variation (NJTV) cost function

In [None]:
# Parameters
cost_fun = 'njtv'  # Normalised joint total variation
mean_space = True  # Optimise a mean-space fit
samp = (4, 2)      # Use the default sub-sampling scheme (speeds up the registration)

# Do registration
dat_aligned = affine_align(pth_mris_ra, device=device, cost_fun=cost_fun,
                           samp=samp, mean_space=mean_space)[0]
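
For intuition about why a joint total variation measure favours alignment, the NumPy sketch below computes a plain (unnormalised) joint total variation for a few synthetic images. It is only an illustration under stated assumptions: `joint_total_variation` and the toy images are hypothetical, and the normalisation used by NITorch's NJTV cost is omitted.

```python
import numpy as np


def joint_total_variation(images):
    """Sum over voxels of the L2 norm of the gradients pooled over all images.

    Hypothetical helper for illustration only (not part of NITorch).
    """
    sq = np.zeros(images[0].shape, dtype=float)
    for img in images:
        for g in np.gradient(img.astype(float)):  # one gradient per axis
            sq += g ** 2
    return np.sqrt(sq).sum()


# Three synthetic "contrasts" of the same disc
yy, xx = np.mgrid[:64, :64]
disc = ((xx - 32) ** 2 + (yy - 32) ** 2 < 15 ** 2).astype(float)
t1, t2, pd = 1.0 * disc, 0.5 * disc, 2.0 * disc

aligned = joint_total_variation([t1, t2, pd])
misaligned = joint_total_variation([t1,
                                    np.roll(t2, 5, axis=1),
                                    np.roll(pd, -4, axis=0)])
print(aligned, misaligned)
```

When edges coincide, their gradients share one square root per voxel, so the aligned value is lower than the misaligned one. This is why such a measure can be optimised over all images jointly rather than one pair at a time.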

Let's look at the groupwise registration result

In [None]:
show_dat(dat_aligned)