<div align="center">

<h1> Deformation-Recovery Diffusion Model (DRDM):
Instance Deformation for Image Manipulation and Synthesis
</h1>

<a href="https://jianqingzheng.github.io/def_diff_rec/"><img alt="Website" src="https://img.shields.io/website?url=https%3A%2F%2Fjianqingzheng.github.io%2Fdef_diff_rec%2F&up_message=online&up_color=darkcyan&down_message=offline&down_color=darkgray&label=Project%20Page"></a>
<a href="https://doi.org/10.48550/arXiv.2407.07295"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2407.07295-b31b1b.svg"></a>
<a href="https://github.com/jianqingzheng/def_diff_rec"><img src="https://img.shields.io/github/stars/jianqingzheng/def_diff_rec?style=social&label=Code+★" /></a>
</div>


Code for the paper [Deformation-Recovery Diffusion Model (DRDM): Instance Deformation for Image Manipulation and Synthesis](https://doi.org/10.48550/arXiv.2407.07295).

> This repo provides a PyTorch implementation of the DRDM training and inference pipeline.



---
### Contents ###
- 1. Installation
- 2. Usage
  - 2.1. Setup
  - 2.2. Training (~1 month)
  - 2.3. Augmentation
  - 2.4. Visualization
- 3. Citing this work

---

In [None]:
#@title 1. Installation {run: "auto"}
#@markdown Clone the code from the GitHub repo: https://github.com/jianqingzheng/def_diff_rec.git
%cd /content

!git clone https://github.com/jianqingzheng/def_diff_rec.git
%cd def_diff_rec/

#@markdown and install the required packages

import torch
print('torch version: ',torch.__version__)

!pip install pyquaternion==0.9.9
!pip install pydicom==2.4.4
#@markdown > `torch==1.12.1+cu113` was the version originally used; a different version is used here because of Colab compatibility issues.\
#@markdown > Other versions of the packages may also work.

---

## 2. Usage


### 2.1. Setup ###


Directory layout:
```
[$DOWNLOAD_DIR]/def_diff_rec/
├── Config/
|   |   # configuration files (.yaml files)
|   └── config_[$data_name].yaml
├── Data/
|   ├── Src_data/[$data_name]/
|   |   |   # processed image data for DRDM training (.nii|.nii.gz files)
|   |   ├── 0001.nii.gz
|   |   └── ...
|   ├── Tgt_data/[$data_name]/
|   |   ├── Tr/
|   |   |   |   # images for deformation (.nii|.nii.gz files)
|   |   |   ├── 0001.nii.gz
|   |   |   └── ...
|   |   └── Gt/
|   |       |   # labels for deformation (.nii|.nii.gz files)
|   |       ├── 0001.nii.gz
|   |       └── ...
|   └── Aug_data/[$data_name]/
|       |   # augmented data is exported here (.nii|.nii.gz files)
|       ├── img/
|       |   |   # deformed images (.nii|.nii.gz files)
|       |   ├── 0001.nii.gz
|       |   └── ...
|       ├── msk/
|       |   |   # deformed labels (.nii|.nii.gz files)
|       |   ├── 0001.nii.gz
|       |   └── ...
|       └── ddf/
|           |   # deformation fields (.nii|.nii.gz files)
|           ├── 0001.nii.gz
|           └── ...
├── models/
|   └── [$data_name]-[$model_name]/
|       |   # the files of model parameters (.pth files)
|       ├── [$epoch_id]_[$data_name]_[$model_name].pth
|       └── ...
└── ...
```
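
Before training or augmentation, it can help to confirm that the input folders from the layout above exist. The following is a minimal sketch, not part of the repository: `expected_dirs` and `missing_dirs` are hypothetical helper names, and the folder names are assumed to match the tree shown above exactly.

```python
from pathlib import Path


def expected_dirs(root, data_name):
    """Return the input folders that the layout above expects for one dataset."""
    root = Path(root)
    return [
        root / "Config",
        root / "Data" / "Src_data" / data_name,
        root / "Data" / "Tgt_data" / data_name / "Tr",
        root / "Data" / "Tgt_data" / data_name / "Gt",
    ]


def missing_dirs(root, data_name):
    """Return the expected folders that do not exist under root."""
    return [p for p in expected_dirs(root, data_name) if not p.is_dir()]


if __name__ == "__main__":
    # Run from inside the cloned def_diff_rec/ folder.
    for path in missing_dirs(".", "cmr"):
        print("missing:", path)
```

An empty result means every expected input folder is present; it does not check the contents of the folders.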


Configuration setting:

<div align="center">

| Argument              | Example           | Description                                	|
| --------------------- | ----------------- |----------------------------------------------|
| `--data_name` 	    |'cmr', 'lct'        | The data folder name                    |
| `--net_name` 	        |'recresacnet'      | The network name                    |
| `--ndims`             |2, 3                | The number of image dimensions                    |
| `--num_input_chn`     |1, 3                | The number of channels in the input image               |
| `--img_size`          |256, 128            | The image size                    |
| `--timesteps`         |80                 | The number of time steps for deformation             |
| `--v_scale`           |4.0e-05             | The scale of the deformation velocity field             |
| `--batchsize` 	    |64, 4               | The batch size for training                    |
| `--ddf_pad_mode` 	    |'border', 'zeros'   | The padding mode for integrating deformation field   |
| `--img_pad_mode` 	    |'border', 'zeros'   | The padding mode for resampling image    |
| `--resample_mode` 	|'nearest', 'bicubic'| The interpolation mode for resampling image     |
| `--device`            |'cuda', 'cpu'       | The device used for computation     |
| `--patients_list`     |[], [1], [1,2]       | The list of subjects selected for augmentation     |
</div>

> Edit the configuration settings in `[$DOWNLOAD_DIR]/def_diff_rec/Config/*.yaml`.


---

### 2.2. Training (~1 month) ###

1. Run ```python DRDM_train.py --config Config/config_$data_name.yaml```
2. Check the saved model in `/content/def_diff_rec/models`


In [None]:
#@markdown \* Example for training (default):
data_name = 'cmr' #@param ["cmr","lct"]

!python DRDM_train.py --config Config/config_{data_name}.yaml

#@markdown > Training from scratch takes around 1 month,
#@markdown > which is not feasible in this demo
#@markdown > (a Colab session is limited to less than 12 or 24 hours).
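
Because a run can be interrupted long before it finishes, it is useful to see which checkpoint was saved last. The sketch below is a hypothetical helper (`latest_checkpoint` is not part of the repository). It assumes that checkpoints follow the `[$epoch_id]_[$data_name]_[$model_name].pth` pattern from the directory layout and that `[$epoch_id]` is a plain integer; files that do not match are ignored.

```python
from pathlib import Path


def latest_checkpoint(models_dir, data_name, model_name):
    """Return the checkpoint path with the largest numeric epoch id, or None."""
    folder = Path(models_dir) / f"{data_name}-{model_name}"
    if not folder.is_dir():
        return None
    suffix = f"_{data_name}_{model_name}.pth"
    best = None
    for path in folder.glob(f"*{suffix}"):
        # Everything before the suffix is treated as the epoch id (assumption).
        epoch_text = path.name[: -len(suffix)]
        if not epoch_text.isdigit():
            continue
        epoch = int(epoch_text)
        if best is None or epoch > best[0]:
            best = (epoch, path)
    return None if best is None else best[1]


if __name__ == "__main__":
    print(latest_checkpoint("models", "cmr", "recresacnet"))
```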


---

### 2.3. Augmentation ###
1. Put the data to augment in `/content/def_diff_rec/Data/Tgt_data`
2. Run ```python DRDM_augment.py --config Config/config_$data_name.yaml```
3. Check the output data in `/content/def_diff_rec/Data/Aug_data`

In [None]:
#@markdown \* Example for augmentation (default):
data_name = 'cmr' #@param ["cmr","lct"]

!python DRDM_augment.py --config Config/config_{data_name}.yaml

#@markdown > The default model is `0000.pth`.
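
After augmentation, a quick sanity check is to count the files written to each output folder. This is a minimal sketch under the assumption that the outputs land in the `img/`, `msk/`, and `ddf/` subfolders shown in the directory layout; `count_outputs` is a hypothetical helper, not a function from the repository.

```python
from pathlib import Path


def count_outputs(aug_dir, subfolders=("img", "msk", "ddf")):
    """Count .nii and .nii.gz files in each output subfolder of aug_dir."""
    counts = {}
    for name in subfolders:
        folder = Path(aug_dir) / name
        if folder.is_dir():
            files = [p for p in folder.iterdir()
                     if p.name.endswith((".nii", ".nii.gz"))]
        else:
            files = []  # a missing subfolder counts as zero files
        counts[name] = len(files)
    return counts


if __name__ == "__main__":
    print(count_outputs(Path("Data") / "Aug_data" / "cmr"))
```

Whether the three counts should be equal depends on the dataset, for example on whether labels are provided.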

In [None]:
#@markdown \* Download the result files (after augmentation) from `/content/def_diff_rec/Data/Aug_data/$data_name`.

from google.colab import files
import os
download_path = os.path.join('Data','Aug_data',data_name)

# Zip the augmented data and download the archive
!zip -r results.zip {download_path}/*
files.download("results.zip")
# files.download(download_path)
print('Downloaded the results from: '+download_path)

---

In [None]:
#@title 2.4 Visualization
data_type = 'Aug_data' #@param ["Aug_data","Tgt_data","Src_data"]
data_format = 'img' #@param {type:"string"}
selected_img = 'Patient0001_Slice000000_AugImg0000_NoiseStep0048.nii.gz' #@param {type:"string"}

img_path = os.path.join('Data',data_type,data_name,data_format,selected_img)


import numpy as np
import ipywidgets as ipyw
import matplotlib.pyplot as plt
import nibabel as nib
class ImageSliceViewer3D:
  """
  ImageSliceViewer3D is for viewing volumetric image slices in Jupyter or
  IPython notebooks.

  The user can interactively select the slice shown in each of the three
  viewing planes.

  Arguments:
    volume = 3D input image
    figsize = default (100,100), size of the figure
    cmap = default 'gray', name of the matplotlib colormap. More colormaps
      are listed at:
      https://matplotlib.org/users/colormaps.html

  """

  def __init__(self, volume, figsize=(100,100), cmap='gray'):
    self.volume = volume
    self.figsize = figsize
    self.cmap = cmap
    self.v = [np.min(volume), np.max(volume)]

    # Call to select slice plane
    ipyw.interact(self.views)

  def views(self):
    self.vol1 = np.transpose(self.volume, [1,2,0])
    self.vol2 = np.rot90(np.transpose(self.volume, [2,0,1]), 3) #rotate 270 degrees
    self.vol3 = np.transpose(self.volume, [0,1,2])
    maxZ1 = self.vol1.shape[2] - 1
    maxZ2 = self.vol2.shape[2] - 1
    maxZ3 = self.vol3.shape[2] - 1
    ipyw.interact(self.plot_slice,
        z1=ipyw.IntSlider(min=0, max=maxZ1, step=1, continuous_update=False,
        description='Axial:'),
        z2=ipyw.IntSlider(min=0, max=maxZ2, step=1, continuous_update=False,
        description='Coronal:'),
        z3=ipyw.IntSlider(min=0, max=maxZ3, step=1, continuous_update=False,
        description='Sagittal:'))
  def plot_slice(self, z1, z2, z3):
    # Plot the selected slice of each view side by side,
    # using the same intensity range for all three
    f,ax = plt.subplots(1,3, figsize=self.figsize)
    ax[0].imshow(self.vol1[:,:,z1], cmap=plt.get_cmap(self.cmap),
        vmin=self.v[0], vmax=self.v[1])
    ax[1].imshow(self.vol2[:,:,z2], cmap=plt.get_cmap(self.cmap),
        vmin=self.v[0], vmax=self.v[1])
    ax[2].imshow(self.vol3[:,:,z3], cmap=plt.get_cmap(self.cmap),
        vmin=self.v[0], vmax=self.v[1])
    plt.show()

ImageSliceViewer3D(nib.load(img_path).slicer[:,:,:].get_fdata())
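
The example file name above appears to encode the patient, slice, augmentation index, and noise step. The sketch below parses those fields so that outputs can be filtered or grouped. The pattern is inferred from that single example name and may not cover every file the pipeline writes; `parse_aug_name` is a hypothetical helper, not part of the repository.

```python
import re

# Pattern inferred from 'Patient0001_Slice000000_AugImg0000_NoiseStep0048.nii.gz'
# (an assumption, not a documented naming scheme).
_AUG_NAME = re.compile(
    r"^Patient(?P<patient>\d+)_Slice(?P<slice>\d+)"
    r"_AugImg(?P<aug>\d+)_NoiseStep(?P<noise_step>\d+)\.nii(\.gz)?$"
)


def parse_aug_name(filename):
    """Return the numeric fields of an augmented file name, or None if no match."""
    match = _AUG_NAME.match(filename)
    if match is None:
        return None
    return {key: int(value) for key, value in match.groupdict().items()}


if __name__ == "__main__":
    print(parse_aug_name("Patient0001_Slice000000_AugImg0000_NoiseStep0048.nii.gz"))
```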


---

## 3. Citing this work

Any publication that discloses findings arising from using this source code or the network model should cite:

```bibtex
@article{zheng2024deformation,
  title={Deformation-Recovery Diffusion Model (DRDM): Instance Deformation for Image Manipulation and Synthesis},
  author={Zheng, Jian-Qing and Mo, Yuanhan and Sun, Yang and Li, Jiahua and Wu, Fuping and Wang, Ziyang and Vincent, Tonia and Papie{\.z}, Bart{\l}omiej W},
  journal={arXiv preprint arXiv:2407.07295},
  doi = {10.48550/arXiv.2407.07295},
  url = {https://doi.org/10.48550/arXiv.2407.07295},
  keywords = {Image Synthesis, Generative Model, Data Augmentation, Segmentation, Registration},
  year={2024}
}
```
