
ngunnar/learning-a-deformable-registration-pyramid


Learning a Deformable Registration Pyramid

This repository contains the implementation of a learning-based method for medical image registration. The method was submitted to the Learn2Reg 2020 Challenge. It is based on a 3D CNN pyramid of downsampled feature maps, in which displacement fields are estimated and refined at each level.
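The coarse-to-fine idea can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not the repository's code. Fields are stored as arrays of shape (D, H, W, 3) in voxel units, and each level doubles the resolution. Upsampling uses nearest-neighbour repetition, and `estimate_residual` is a hypothetical stand-in for the per-level network.

```python
import numpy as np


def upsample_displacement(disp):
    """Double the resolution of a displacement field of shape (D, H, W, 3).

    Nearest-neighbour repetition keeps the sketch simple (an assumption).
    The vectors are scaled by 2 because one voxel at the coarser level
    spans two voxels at the finer level.
    """
    up = disp.repeat(2, axis=0).repeat(2, axis=1).repeat(2, axis=2)
    return 2.0 * up


def coarse_to_fine(shapes, estimate_residual):
    """Estimate a field at the coarsest level, then upsample and refine it.

    shapes: spatial shapes from coarsest to finest, each twice the previous.
    estimate_residual: hypothetical stand-in for the per-level network; it is
    called as estimate_residual(disp, level) and returns a correction with
    the same shape as disp.
    """
    disp = np.zeros(tuple(shapes[0]) + (3,))
    for level, shape in enumerate(shapes):
        if level > 0:
            # Carry the coarser estimate to the current resolution.
            disp = upsample_displacement(disp)
        if disp.shape[:3] != tuple(shape):
            raise ValueError(f"level {level}: expected {shape}, got {disp.shape[:3]}")
        # Refine the field at this level with a residual update.
        disp = disp + estimate_residual(disp, level)
    return disp
```

Consider a constant correction of 0.5 voxels per level over three levels. The field is 0.5 at the coarsest level. It is doubled to 1.0 and refined to 1.5, then doubled to 3.0 and refined to 3.5 at the finest level.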

Environment

We use TensorFlow 2.0.2. To install all necessary libraries, run:

conda env create --file environment.yml

conda activate tf2

Architecture

Our method is inspired by [PWC-Net], a 2D optical flow method popular in computer vision. Below is an overview of the architecture and a detailed graph of the operations at each feature level.

Figures: architecture overview; operations at each feature level.

Description of the components:

  • Pyramid: Downsamples the moving and fixed images into several feature map levels using CNN layers. The same pyramid is used for both the moving and the fixed image.
  • Warp (W): Warps the feature maps of the moving image with the estimated displacement field.
  • Affine (A): A dense neural network that estimates the 12 parameters of a 3D affine transformation.
  • Cost volume (CV): Correlation between the warped feature maps from the moving image and the feature maps from the fixed image. For computational reasons, the cost volume is restricted to a local neighbourhood around each voxel.
  • Deform (D): A CNN that estimates the displacement field based on the affine displacement field, the cost volume and the feature maps from the fixed image.
  • Upsample (U): Upsamples the estimated displacement field from one level to the next.
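Two of the components above, the cost volume and the affine displacement field, can be sketched in NumPy. Several details are assumptions for illustration:

  • the names `cost_volume` and `affine_displacement`
  • the (D, H, W, C) layout
  • the search radius and zero padding
  • the row-wise 3x4 ordering of the 12 parameters
  • the use of voxel coordinates

The actual TensorFlow layers may differ.

```python
import numpy as np


def cost_volume(fixed, warped, radius=1):
    """Local correlation between fixed and warped feature maps.

    fixed, warped: arrays of shape (D, H, W, C).
    Returns shape (D, H, W, (2 * radius + 1) ** 3): one channel per offset in
    the cubic search window, holding the channel-averaged product of the
    fixed features and the warped features shifted by that offset.
    Neighbours outside the volume are zero-padded.
    """
    d, h, w, _ = fixed.shape
    r = radius
    padded = np.pad(warped, ((r, r), (r, r), (r, r), (0, 0)))
    costs = []
    for dz in range(2 * r + 1):
        for dy in range(2 * r + 1):
            for dx in range(2 * r + 1):
                # shifted[z, y, x] == warped[z + dz - r, y + dy - r, x + dx - r]
                shifted = padded[dz:dz + d, dy:dy + h, dx:dx + w, :]
                costs.append(np.mean(fixed * shifted, axis=-1))
    return np.stack(costs, axis=-1)


def affine_displacement(params, shape):
    """Turn 12 affine parameters into a dense displacement field.

    params: 12 values read row-wise as a 3x4 matrix [A | t] (assumed order).
    shape: spatial shape (D, H, W) of the output field.
    The displacement at voxel x is (A @ x + t) - x, in voxel coordinates.
    """
    m = np.asarray(params, dtype=float).reshape(3, 4)
    a, t = m[:, :3], m[:, 3]
    axes = [np.arange(n, dtype=float) for n in shape]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)  # (D, H, W, 3)
    return grid @ a.T + t - grid
```

With a search radius of 1, each voxel is compared against 27 neighbours. This keeps the cost volume small compared with a global correlation.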

Result

TBD

Dataset

See the Learn2Reg datasets page (https://learn2reg.grand-challenge.org/Datasets/) for instructions on obtaining the data.

To run the training and testing scripts, we assume the datasets are organized as follows:

+-- task_01
|   +-- pairs_val.csv
|   +-- NIFTI
+-- task_02
|   +-- pairs_val.csv
|   +-- training
+-- task_03
|   +-- pairs_val.csv
|   +-- Training
+-- task_04
|   +-- pairs_val.csv
|   +-- Training
+-- Test
|   +-- task_01
|   |   +-- pairs_val.csv
|   |   +-- NIFTI
|   +-- task_02
|   |   +-- Training
|   +-- task_03
|   |   +-- Training
|   +-- task_04
|   |   +-- Training
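A hypothetical helper for checking this layout and reading a task's `pairs_val.csv` is sketched below. The function names are illustrative only. The helper assumes just that each CSV has a header row. The column names are not specified here, so rows are returned as dictionaries keyed by whatever header the file contains.

```python
import csv
from pathlib import Path


def task_dir(dataset_root, task, test=False):
    """Folder of one task, e.g. <root>/task_02 or <root>/Test/task_02."""
    root = Path(dataset_root)
    if test:
        root = root / "Test"
    return root / f"task_{task:02d}"


def missing_pair_files(dataset_root, tasks=(1, 2, 3, 4)):
    """Return the expected pairs_val.csv paths that do not exist."""
    paths = [task_dir(dataset_root, t) / "pairs_val.csv" for t in tasks]
    return [p for p in paths if not p.is_file()]


def read_pairs(dataset_root, task, test=False):
    """Read a task's pairs_val.csv as a list of dicts keyed by the header row."""
    with open(task_dir(dataset_root, task, test) / "pairs_val.csv", newline="") as f:
        return list(csv.DictReader(f))
```

Calling `missing_pair_files("/data/Learn2Reg/")` before training gives a quick check that the dataset root points at the expected folders.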

Training

To train the model on the images (and segmentations) of Tasks 2, 3 and 4, run:

python train_model.py -ds {path to dataset root} -gpus {gpu numbers}

Example:
python train_model.py -ds /data/Learn2Reg/ -gpus 0,1,2
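The `-ds` and `-gpus` flags could be handled roughly as follows. This is a hypothetical sketch, not the parser in `train_model.py`. It assumes that `-gpus` is a comma-separated list of device ids. It also assumes GPU selection is done through the `CUDA_VISIBLE_DEVICES` environment variable, which must be set before TensorFlow initializes its devices.

```python
import argparse
import os


def parse_args(argv=None):
    """Hypothetical parser mirroring the -ds and -gpus flags shown above."""
    parser = argparse.ArgumentParser(description="Train the registration pyramid")
    parser.add_argument("-ds", required=True, help="path to the dataset root")
    parser.add_argument("-gpus", default="0", help="comma-separated GPU ids, e.g. 0,1,2")
    return parser.parse_args(argv)


def select_gpus(gpus):
    """Expose only the requested GPUs to this process; returns their ids.

    Must be called before TensorFlow is imported or initializes its devices.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus
    return [int(g) for g in gpus.split(",") if g.strip()]
```

Setting `CUDA_VISIBLE_DEVICES` renumbers the selected devices from 0 inside the process. A model placed on several GPUs therefore sees them as 0, 1, 2, … regardless of the physical ids passed in.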

To fine-tune the model on a specific task, run:

python train_tf_task{TASK #}.py -ds {path to dataset root} -gpus {gpu numbers}

Example:
python train_tf_task2.py -ds /data/Learn2Reg/ -gpus 0,1,2

Alternatively, feel free to modify these scripts or to create your own training procedure.

Testing

Create submission

Papers

TBD
