In [None]:
# GETTING STARTED
# This demo was run on:
# * Ubuntu    18.04.4 LTS / CentOS 7.7.1908
# * CUDA      10.1
# * anaconda  1.7.2
# * gcc       6.3.1
# * pytorch   1.6
#
# To get started, run the following commands in a terminal:
#
#   git clone git@github.com:balbasty/nitorch.git
#   cd nitorch
#   conda env create --file ./conda/nitorch-demo.yml
#   conda activate nitorch-demo
#   pip install .

### GETTING STARTED

First, we import the required packages.

In [None]:
# Python
import os
import wget
import math
from timeit import default_timer as timer
import zipfile

# Torch / NiBabel
import nibabel as nib
import torch
from torch.nn import functional as F

# NiTorch
from nitorch.tools.affine_reg import (apply2affine, reslice2fix, run_affine_reg, test_cost_function)
from nitorch.plot import show_slices
from nitorch.tools.preproc import reslice2world

Next, we define a viewer that displays images in world space.

In [None]:
def show_in_world_space(pths):
    """Display every input image after reslicing it to world space.

    Parameters
    ----------
    pths : sequence
        Inputs accepted by ``reslice2world`` (file paths in this notebook).
        The i-th image is drawn in figure number ``i``.
    """
    for fig_idx, pth in enumerate(pths):
        # reslice2world's result is indexable; element 1 is the image data handed to show_slices
        img = reslice2world(pth, write=False)[1]
        show_slices(img, fig_num=fig_idx)

Then, we get the PyTorch device (the GPU if one is available, otherwise the CPU).

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_type)
if device_type == 'cuda':
    # Report the name of GPU 0 and the CUDA version PyTorch was built with
    # (torch.cuda.is_available() would always be True in this branch, so it carries no information)
    print('GPU: ' + torch.cuda.get_device_name(0) + ', CUDA: ' + str(torch.version.cuda))
else:
    print('CPU')

This demo uses three MRIs (T1w, T2w and PDw) from the BrainWeb simulator (https://brainweb.bic.mni.mcgill.ca/brainweb/).
The first time the notebook is run, these images are downloaded and extracted into the current working directory (normally the folder containing this notebook).

In [None]:
# URL to MRIs
url = 'https://www.dropbox.com/s/8xh6tnf47hzkykx/brainweb.zip?dl=1'

# Paths to the downloaded zip-file and to the MRIs it contains (all in the current working directory)
cwd = os.getcwd()
pth_zip = os.path.join(cwd, 'brainweb.zip')
pth_mris = [os.path.join(cwd, 't1_icbm_normal_1mm_pn0_rf0.nii'),
            os.path.join(cwd, 't2_icbm_normal_1mm_pn0_rf0.nii'),
            os.path.join(cwd, 'pd_icbm_normal_1mm_pn0_rf0.nii')]

# Download/extract only when some MRI is missing. Checking the MRIs themselves (not only the zip)
# avoids re-downloading the archive when it was deleted after a previous extraction.
if not all(os.path.exists(p) for p in pth_mris):
    # Download file
    if not os.path.exists(pth_zip):
        print('Downloading image...', end='')
        wget.download(url, pth_zip)
        print('done!')
    # Unzip file
    with zipfile.ZipFile(pth_zip, 'r') as zip_ref:
        zip_ref.extractall(cwd)

# Fail here with a clear message rather than later inside the viewer if any MRI is still missing
missing = [p for p in pth_mris if not os.path.exists(p)]
if missing:
    raise FileNotFoundError('Expected MRI(s) not found after extraction: ' + ', '.join(missing))

Let's have a look at the images.

In [None]:
# Plot each original MRI after reslicing it to world space
show_in_world_space(pths=pth_mris)

### INSPECT REGISTRATION COST FUNCTION

Now, let's inspect how the cost function behaves as we keep one image fixed and move the second image.
Feel free to test the different options below.

In [None]:
cost_fun = 'nmi'  # ['nmi', 'mi', 'ncc', 'ecc', 'njtv']
mean_space = False  # Only available for the 'njtv' cost function (as it is groupwise)
samp = 2  # Level of sub-sampling
ix_par = 0  # What parameter in the affine transformation to modify (0, ..., 11)
x_mn_mx = 30  # Min/max value of parameter
x_step = 0.1  # Step-size of parameter

# Reject the option combination documented above as unsupported
if mean_space and cost_fun != 'njtv':
    raise ValueError("mean_space=True is only available for the 'njtv' cost function")

# Forward the configured mean_space (not a hard-coded False) so the option above actually takes effect
test_cost_function([pth_mris[0], pth_mris[1]],
    cost_fun=cost_fun, mean_space=mean_space, samp=samp, ix_par=ix_par, x_mn_mx=x_mn_mx, x_step=x_step)

### CREATE RIGIDLY MISALIGNED IMAGES

Next, we rigidly misalign all of the input scans; this writes new images with the prefix `ma_`.

In [None]:
# Bounds of the random rigid misalignment: each parameter is drawn uniformly from [-max, max]
# (these are maxima, not standard deviations)
t_max = 10    # Max translational offset (presumably mm, as interpreted by apply2affine; TODO confirm)
r_max = 0.25  # Max rotational offset (presumably radians; TODO confirm)
n_mri = len(pth_mris)

# Seed the global RNG so that the misalignment is reproducible across runs
torch.manual_seed(0)
t_rand = torch.empty(n_mri, 3, dtype=torch.float64).uniform_(-t_max, t_max)  # random translation
r_rand = torch.empty(n_mri, 3, dtype=torch.float64).uniform_(-r_max, r_max)  # random rotation
# One row per image: 3 translation parameters followed by 3 rotation parameters
q_ma = torch.cat((t_rand, r_rand), dim=1)
pth_mris_ma = apply2affine(pth_mris, q_ma, prefix='ma_')

Let's have a look at the misaligned images.

In [None]:
# Plot each misaligned MRI after reslicing it to world space
show_in_world_space(pths=pth_mris_ma)

### PAIRWISE REGISTRATION

Next, we align the images using pairwise registration: one of the images is kept fixed and all other images are
registered to this fixed target image.

In [None]:
# Registration settings
cost_fun = 'nmi'  # Normalised mutual information
fix = 0           # Index of the image kept fixed (the first of the three)
samp = (4, 2)     # Default sub-sampling scheme (speeds up the registration)

# Register every other image in pth_mris_ma to the fixed one.
# The inputs may also be given as tensors (see the run_affine_reg documentation).
reg_opts = dict(device=device, samp=samp, cost_fun=cost_fun, fix=fix)
q_est_pw, mat_fix, dim_fix = run_affine_reg(pth_mris_ma, **reg_opts)

# NOTE: the estimated parameters can be written into the image headers with, e.g.,
#   pth_mris_pa = apply2affine(pth_mris_ma, q_est_pw, B=B)
# NOTE(review): `B` is not defined anywhere in this notebook; confirm which argument
# apply2affine expects before using this example.

Let's look at the pairwise registration result.

In [None]:
# Reslice the misaligned images onto the fixed image's grid using the
# pairwise estimates (q_est_pw), keeping the result in memory (write=False)
rdat, _ = reslice2fix(pth_mris_ma, q_est_pw, mat_fix, dim_fix,
                      device=device, write=False)
# Display the resliced images
_ = show_slices(rdat)

### GROUPWISE REGISTRATION

Finally, we align the images using groupwise registration, where the cost function is optimised over all images
at the same time. This is done using the Normalised Joint Total Variation (NJTV) cost function.

In [None]:
# Registration settings
cost_fun = 'njtv'  # Normalised joint total variation
mean_space = True  # Optimise a mean-space fit
samp = (4, 2)      # Default sub-sampling scheme (speeds up the registration)

# Register all images jointly. With mean_space=True, the returned mat_fix/dim_fix
# are used below as the mean space to reslice into.
reg_opts = dict(device=device, samp=samp, cost_fun=cost_fun, mean_space=mean_space)
q_est_gw, mat_fix, dim_fix = run_affine_reg(pth_mris_ma, **reg_opts)

Let's look at the groupwise registration result.

In [None]:
# Reslice the misaligned images into the mean space using the
# groupwise estimates (q_est_gw), keeping the result in memory (write=False)
rdat, _ = reslice2fix(pth_mris_ma, q_est_gw, mat_fix, dim_fix,
                      device=device, write=False)
# Display the resliced images
_ = show_slices(rdat)