<div align="center">
<h1> Residual Aligner-based Network (RAN) for Coarse-to-fine Discontinuous Deformable Registration </h1>

[![DOI](https://img.shields.io/badge/DOI-j.media.2023.103038-darkyellow)](https://doi.org/10.1016/j.media.2023.103038) \|
[![arXiv](https://img.shields.io/badge/arXiv-2203.04290-b31b1b.svg)](https://arxiv.org/abs/2203.04290) \|
<a href="https://github.com/jianqingzheng/res_aligner_net"><img src="https://img.shields.io/github/stars/jianqingzheng/res_aligner_net?style=social&label=Code+★" /></a>
\|
[![Explore RAN in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jianqingzheng/res_aligner_net/blob/main/res_aligner_net.ipynb)

</div>

Code for the *Medical Image Analysis* paper [Residual Aligner-based Network (RAN): Motion-separable structure for coarse-to-fine discontinuous deformable registration](https://doi.org/10.1016/j.media.2023.103038).


> This repo provides an implementation of the training and inference pipeline of RAN based on TensorFlow and Keras.



---
### Contents ###
- 1. Installation
- 2. Usage
  - 2.1. Setup (for unpaired data)
  - 2.2. Training (>1 week)
  - 2.3. Inference
  - 2.4. Visualization
- 3. Citing this work

---

In [None]:
#@title 1. Installation {run: "auto"}
#@markdown Clone code from the GitHub repo: https://github.com/jianqingzheng/res_aligner_net.git
%cd /content

!git clone https://github.com/jianqingzheng/res_aligner_net.git
%cd res_aligner_net/

#@markdown and install packages

import tensorflow as tf
print('tf version: ',tf.__version__)

!pip install pyquaternion==0.9.9

#@markdown > `tensorflow==2.3.1` was the version originally used; a different version is used here because of Colab compatibility issues.\
#@markdown > Other versions of the packages may also work.

---

## 2. Usage


### 2.1. Setup (for unpaired data) ###
```
[$DOWNLOAD_DIR]/res_aligner_net/
├── data/[$data_name]/dataset
|   |   # experimental dataset for training and testing (.nii|.nii.gz files)
|   ├── train/
|   |   ├── images/
|   |   |   ├── 0001.nii.gz
|   |   |   └── ...
|   |   └── labels/
|   |       ├── 0001.nii.gz
|   |       └── ...
|   └── test/
|       ├── images/
|       |   ├── 0001.nii.gz
|       |   └── ...
|       └── labels/
|           ├── 0001.nii.gz
|           └── ...
├── models/[$data_name]/
|   └── [$data_name]-[$model_name]/
|       |   # the files of model parameters (.tf.index and .tf.data-00000-of-00001 files)
|       ├── model_1_[$model_num].tf.index
|       ├── model_1_[$model_num].tf.data-00000-of-00001
|       └── ...
└── ...
```
> The data used for the experiments in the paper are publicly available: [abdomen CT](https://github.com/ucl-candi/datasets_deepreg_demo/archive/abdct.zip) and [lung CT](https://zenodo.org/record/3835682).
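
Before preprocessing your own data, it can help to confirm that the folder follows the layout above. The following is a minimal sketch, not part of the repository: the helper `find_unlabeled_images` is a hypothetical name, and it assumes that each image in `images/` is matched by a label file of the same name in `labels/`, as the tree suggests.

```python
import os
import tempfile


def find_unlabeled_images(dataset_dir, split):
    """Return image file names in `split` that have no same-named label file.

    Hypothetical helper: assumes labels share their file name with the image.
    """
    image_dir = os.path.join(dataset_dir, split, "images")
    label_dir = os.path.join(dataset_dir, split, "labels")

    def is_nifti(name):
        return name.endswith((".nii", ".nii.gz"))

    images = sorted(f for f in os.listdir(image_dir) if is_nifti(f))
    labels = {f for f in os.listdir(label_dir) if is_nifti(f)}
    return [f for f in images if f not in labels]


# Demo on a throwaway directory that mimics data/[$data_name]/dataset/train/
with tempfile.TemporaryDirectory() as root:
    dataset_dir = os.path.join(root, "data", "unpaired_ct_abdomen", "dataset")
    for sub in ("images", "labels"):
        os.makedirs(os.path.join(dataset_dir, "train", sub))
    for name in ("0001.nii.gz", "0002.nii.gz"):
        open(os.path.join(dataset_dir, "train", "images", name), "w").close()
    open(os.path.join(dataset_dir, "train", "labels", "0001.nii.gz"), "w").close()

    missing = find_unlabeled_images(dataset_dir, "train")
    print(missing)  # ['0002.nii.gz']
```

For a real dataset, pass the path of `data/[$data_name]/dataset` and check both the `train` and `test` splits.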



In [None]:
#@markdown \* Download data (default):
data_name = 'unpaired_ct_abdomen' #@param ["unpaired_ct_abdomen","unpaired_ct_lung"]

if data_name == 'unpaired_ct_abdomen':
  data_download_py="abd_data.py"
elif data_name == 'unpaired_ct_lung':
  data_download_py="lung_data.py"

import os
data_path='data'
os.makedirs(os.path.join(data_path,data_name), exist_ok=True)

!python external/deepreg/{data_download_py}
!python main_preprocess.py --proc_type train --data_name {data_name}
!python main_preprocess.py --proc_type test --data_name {data_name}


---

### 2.2. Training (>1 week) ###

1. Run ```python main_train.py --model_name $model_name --data_name $data_name --max_epochs $max_epochs```
2. Check the saved model in ```res_aligner_net/models/$data_name/$data_name-$model_name/```


<div align="center">

| Argument              | Description                                   |
| --------------------- | --------------------------------------------- |
| `--data_name`         | Name of the data folder                       |
| `--model_name`        | Name of the model to use                      |
| `--max_epochs`        | Maximum number of training epochs             |

</div>
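
When training is launched from a script instead of a shell, the command in step 1 can be assembled from the three documented arguments. This is a minimal sketch under that assumption: `build_train_command` is a hypothetical helper, and only the flags listed in the table are used.

```python
import shlex


def build_train_command(model_name, data_name, max_epochs):
    """Assemble the documented main_train.py call as an argument list.

    Hypothetical helper: only the three flags from the table are used.
    """
    return [
        "python", "main_train.py",
        "--model_name", str(model_name),
        "--data_name", str(data_name),
        "--max_epochs", str(int(max_epochs)),
    ]


cmd = build_train_command("RAN4", "unpaired_ct_abdomen", 0)
print(shlex.join(cmd))
# To launch it, run from the repository root:
#   import subprocess; subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting problems if a model or data name ever contains spaces.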


In [None]:
#@markdown \* Example for training (default):
data_name = 'unpaired_ct_abdomen' #@param ["unpaired_ct_abdomen","unpaired_ct_lung"]
model_name = 'RAN4' #@param {type:"string"}
max_epochs = 0  #@param {type:"integer"}

!python main_train.py --model_name {model_name} --data_name {data_name} --max_epochs {max_epochs}

#@markdown > `max_epochs=0` indicates training from scratch. \
#@markdown > Training from scratch takes more than 1 week,
#@markdown > which may not be feasible in this demo
#@markdown > (Colab sessions are limited to at most 12 or 24 hours).


---

### 2.3. Inference ###
1. Run ```python main_infer.py --model_name $model_name --data_name $data_name```
2. Check the results in ```res_aligner_net/data/$data_name/dataset/test/```

<div align="center">

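| Argument              | Description                                                        |
| --------------------- | ------------------------------------------------------------------ |
| `--data_name`         | Name of the data folder                                            |
| `--model_name`        | Name of the model to use                                           |
| `--model_id`          | ID of the trained model to load (`1`, `2`, or `3` in the example)   |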

</div>

In [None]:
#@markdown \* Example for inference (default):
data_name = 'unpaired_ct_abdomen' #@param ["unpaired_ct_abdomen","unpaired_ct_lung"]
model_name = 'RAN4' #@param {type:"string"}
model_id = "2" #@param ["1","2","3"]

!python main_infer.py --model_name {model_name} --model_id {model_id} --data_name {data_name}

#@markdown > `model_id==1` for a model after synthetic training,
#@markdown > `model_id==2` for a model after real training,
#@markdown > `model_id==3` for the model trained according to paper's settings
#@markdown > (seems to be incompatible with the version of TensorFlow/Keras in Colab).

In [49]:
#@markdown \* Download the result file (after inference).

from google.colab import files
import os
download_path = os.path.join('data',data_name,'dataset','test_proc','warped_img')

!zip -r results.zip {download_path}/*
files.download("results.zip")
# files.download(download_path)
print('Download the results from: '+download_path)

Download the results from: data/unpaired_ct_abdomen/dataset/test_proc/warped_img


---

In [None]:
#@title 2.4. Visualization

target_id = 0 #@param {type:"integer"}
source_id = 1 #@param {type:"integer"}

#@markdown > When `target_id==source_id`, the original image with id=`target_id` is shown.\
#@markdown > When `target_id!=source_id`, the image warped from `source_id` to `target_id` is shown.

download_path = os.path.join('data',data_name,'dataset','test_proc','warped_img')

if target_id == source_id:
  img_path = os.path.join(download_path,'img_target_'+str(target_id)+'.nii')
else:
  img_path = os.path.join(download_path,'img_warped_'+str(model_name)+'_'+str(target_id)+'_from_'+str(source_id)+'.nii')


import numpy as np
import ipywidgets as ipyw
import matplotlib.pyplot as plt
import nibabel as nib
class ImageSliceViewer3D:
  """
  ImageSliceViewer3D is for viewing volumetric image slices in Jupyter or
  IPython notebooks.

  The user can interactively select the slice shown in each of the three
  viewing planes.

  Arguments:
    volume = 3D input image
    figsize = default (100,100), to set the size of the figure
    cmap = default 'gray', string for the matplotlib colormap. You can find
    more matplotlib colormaps at the following link:
    https://matplotlib.org/users/colormaps.html

  """

  def __init__(self, volume, figsize=(100,100), cmap='gray'):
    self.volume = volume
    self.figsize = figsize
    self.cmap = cmap
    self.v = [np.min(volume), np.max(volume)]

    # Call to select slice plane
    ipyw.interact(self.views)

  def views(self):
    self.vol1 = np.transpose(self.volume, [1,2,0])
    self.vol2 = np.rot90(np.transpose(self.volume, [2,0,1]), 3) #rotate 270 degrees
    self.vol3 = np.transpose(self.volume, [0,1,2])
    maxZ1 = self.vol1.shape[2] - 1
    maxZ2 = self.vol2.shape[2] - 1
    maxZ3 = self.vol3.shape[2] - 1
    ipyw.interact(self.plot_slice,
        z1=ipyw.IntSlider(min=0, max=maxZ1, step=1, continuous_update=False,
        description='Axial:'),
        z2=ipyw.IntSlider(min=0, max=maxZ2, step=1, continuous_update=False,
        description='Coronal:'),
        z3=ipyw.IntSlider(min=0, max=maxZ3, step=1, continuous_update=False,
        description='Sagittal:'))
  def plot_slice(self, z1, z2, z3):
    # Plot the selected slice of each of the three views side by side
    f,ax = plt.subplots(1,3, figsize=self.figsize)
    ax[0].imshow(self.vol1[:,:,z1], cmap=plt.get_cmap(self.cmap),
        vmin=self.v[0], vmax=self.v[1])
    ax[1].imshow(self.vol2[:,:,z2], cmap=plt.get_cmap(self.cmap),
        vmin=self.v[0], vmax=self.v[1])
    ax[2].imshow(self.vol3[:,:,z3], cmap=plt.get_cmap(self.cmap),
        vmin=self.v[0], vmax=self.v[1])
    plt.show()

ImageSliceViewer3D(nib.load(img_path).get_fdata())
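
The file-name convention used by the cell above can be captured in a small helper for use outside the notebook. This is a minimal sketch: `result_image_path` and its `data_root` parameter are hypothetical names, and the naming pattern is taken from the cell itself, not from the inference code.

```python
import os


def result_image_path(data_name, model_name, target_id, source_id, data_root="data"):
    """Build the path of a result image, mirroring the visualization cell.

    Hypothetical helper: the target image is returned when both ids match,
    otherwise the image warped from `source_id` to `target_id`.
    """
    result_dir = os.path.join(data_root, data_name, "dataset", "test_proc", "warped_img")
    if target_id == source_id:
        file_name = f"img_target_{target_id}.nii"
    else:
        file_name = f"img_warped_{model_name}_{target_id}_from_{source_id}.nii"
    return os.path.join(result_dir, file_name)


warped = result_image_path("unpaired_ct_abdomen", "RAN4", target_id=0, source_id=1)
print(os.path.basename(warped))  # img_warped_RAN4_0_from_1.nii
```

The returned path can be passed to `nib.load` in the same way as `img_path` above.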


---

## 3. Citing this work

Any publication that discloses findings arising from using this source code or the network model should cite:
- Zheng, J. Q., Wang, Z., Huang, B., Lim, N. H., & Papież, B. W. "Residual Aligner-based Network (RAN): Motion-separable structure for coarse-to-fine discontinuous deformable registration." *Medical Image Analysis*, 2024, 91: 103038.
```bibtex
@article{ZHENG2024103038,
	title = {Residual Aligner-based Network (RAN): Motion-separable structure for coarse-to-fine discontinuous deformable registration},
	journal = {Medical Image Analysis},
	volume = {91},
	pages = {103038},
	year = {2024},
	issn = {1361-8415},
	doi = {10.1016/j.media.2023.103038},
	url = {https://www.sciencedirect.com/science/article/pii/S1361841523002980},
	author = {Jian-Qing Zheng and Ziyang Wang and Baoru Huang and Ngee Han Lim and Bartłomiej W. Papież},
	keywords = {Discontinuous deformable registration, Motion-separable structure, Motion disentanglement, Coarse-to-fine registration},
}
```
