# Training a convolutional network for anatomy-guided PET image denoising and deblurring

In this tutorial, we will learn how to set up and train a simple 3D convolutional network for anatomically guided denoising and deblurring of PET images.

We will set up a network that takes a batch of 3D tensors with 2 channels (e.g. PET and MR) as input and outputs a batch of 3D tensors with 1 channel (denoised and deblurred PET image).
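
To make the input and output shapes concrete, here is a minimal NumPy sketch. It assumes a channels-last layout `(batch, dim0, dim1, dim2, channels)`, which is the default convention for 3D convolutions in Keras. The batch size, the patch size and the random arrays are hypothetical placeholders, not values taken from the tutorial data.

```python
import numpy as np

# hypothetical sizes, chosen only for illustration
batch_size = 4
patch_shape = (16, 16, 16)

rng = np.random.default_rng(0)

# random stand-ins for a batch of PET and MR volumes
pet = rng.random((batch_size, *patch_shape), dtype=np.float32)
mr = rng.random((batch_size, *patch_shape), dtype=np.float32)

# stack PET and MR along a new last axis to get a 2-channel input batch
x = np.stack([pet, mr], axis=-1)

# stand-in for the 1-channel output (denoised and deblurred PET)
y = rng.random((batch_size, *patch_shape, 1), dtype=np.float32)

print("input batch shape: ", x.shape)
print("output batch shape:", y.shape)
```

The spatial dimensions of input and output are identical in this sketch; only the number of channels differs.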


The approach and the model architecture that we will use in this tutorial are inspired by Schramm et al., ["Approximating anatomically-guided PET reconstruction in image space using a convolutional neural network"](https://doi.org/10.1016/j.neuroimage.2020.117399), NeuroImage 2021, DOI 10.1016/j.neuroimage.2020.117399.

![Model architecture used in this tutorial](https://raw.githubusercontent.com/gschramm/pyapetnet/master/figures/fig_1_apetnet.png)



This tutorial uses simulated PET/MR data based on the BrainWeb phantoms. However, applying the same training strategy to real data should be straightforward.

To set up and train the model, we will use TensorFlow and Keras. Of course, the same basic concepts can be applied with any other deep learning framework (e.g. PyTorch).

### The tutorial is split into two notebooks
1. In the first notebook ([01_tf_data.ipynb](01_tf_data.ipynb)), we will learn how to set up a data loader pipeline that efficiently creates mini batches of training data, including data augmentation.
2. In the second notebook ([02_tf_models.ipynb](02_tf_models.ipynb)), we will learn how to set up and train the model architecture shown above.

Finally, it will be your turn to combine the knowledge from these two notebooks to train your own network.

### To run this tutorial, you need to install the following Python packages
- ```pyapetnet >= 1.1``` (will also install the dependencies tensorflow, nibabel, pymirc, ...)
- ```pydot >= 1.4``` 
- ```graphviz >= 0.16```
- ```ipympl >= 0.7```

All packages are available on PyPI and can be installed via:
```
pip install pyapetnet
pip install pydot
pip install graphviz
pip install ipympl
```
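
To check which of these packages are already installed, a small sketch based on the standard library module `importlib.metadata` can help. The helper name `installed_version` and the `required` dictionary are illustrative, and the sketch only reports the versions it finds; comparing them against the minimum versions is left to the reader.

```python
from importlib import metadata

# minimum versions as listed above
required = {
    "pyapetnet": "1.1",
    "pydot": "1.4",
    "graphviz": "0.16",
    "ipympl": "0.7",
}


def installed_version(package):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


for name, minimum in required.items():
    found = installed_version(name)
    status = found if found is not None else "not installed"
    print(f"{name:10s} required >= {minimum:5s} found: {status}")
```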

**If you are running these notebooks as part of the Training School for the Synergistic Image Reconstruction Framework (SIRF) and Core Imaging Library (CIL) on the <font color='red'>STFC jupyter cloud servers</font>, these packages are already available and do not need to be installed.**

### Downloading data used in this tutorial

The data sets that we will use in these notebooks are available on Zenodo at:
https://zenodo.org/record/4897350/files/brainweb_petmr.zip

Please download this zip file (ca. 10 GB; the download takes ca. 15 min, depending on your connection), unzip it, and place the extracted data somewhere on your machine.
In all notebooks, this location is stored in the ```data_path``` variable, which **might need to be adjusted.**
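
If you prefer to script the download, the following sketch uses only the Python standard library. The helper names `download` and `unzip` are illustrative, and the local file name `brainweb_petmr.zip` is an assumption. The internal layout of the archive is not checked here, so ```data_path``` may still need to be adjusted after extraction.

```python
import pathlib
import urllib.request
import zipfile

DATA_URL = "https://zenodo.org/record/4897350/files/brainweb_petmr.zip"


def download(url, target):
    """Download url to target, skipping the download if target already exists."""
    target = pathlib.Path(target)
    if not target.exists():
        urllib.request.urlretrieve(url, target)
    return target


def unzip(zip_file, output_dir):
    """Extract all members of zip_file into output_dir and return output_dir."""
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_file) as archive:
        archive.extractall(output_dir)
    return output_dir


# The calls are commented out because they trigger a download of ca. 10 GB:
# zip_file = download(DATA_URL, "brainweb_petmr.zip")
# unzip(zip_file, ".")
```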

**If you are running these notebooks as part of the Training School for the Synergistic Image Reconstruction Framework (SIRF) and Core Imaging Library (CIL) on the <font color='red'>STFC jupyter servers</font>, the data is available in
```/mnt/materials/SIRF/Fully3D/DL/brainweb_petmr/``` and does not need to be downloaded.**

The cell below looks for all subjects in ```data_path``` and should find 20 subjects.
If ```data_path``` is set correctly, the output should look something like this:

```
01 brainweb_petmr/subject04
02 brainweb_petmr/subject05
03 brainweb_petmr/subject06
04 brainweb_petmr/subject18
05 brainweb_petmr/subject20
06 brainweb_petmr/subject38
07 brainweb_petmr/subject41
08 brainweb_petmr/subject42
09 brainweb_petmr/subject43
10 brainweb_petmr/subject44
11 brainweb_petmr/subject45
12 brainweb_petmr/subject46
13 brainweb_petmr/subject47
14 brainweb_petmr/subject48
15 brainweb_petmr/subject49
16 brainweb_petmr/subject50
17 brainweb_petmr/subject51
18 brainweb_petmr/subject52
19 brainweb_petmr/subject53
20 brainweb_petmr/subject54
```


In [4]:
import pathlib

# adjust this variable to the path where the simulated PET/MR data from Zenodo was unzipped
data_path = pathlib.Path('brainweb_petmr')

# print all downloaded subjects in sorted order, numbered from 01
for i, p in enumerate(sorted(data_path.glob('subject??')), start=1):
    print(f'{i:02}', p)

01 brainweb_petmr/subject04
02 brainweb_petmr/subject05
03 brainweb_petmr/subject06
04 brainweb_petmr/subject18
05 brainweb_petmr/subject20
06 brainweb_petmr/subject38
07 brainweb_petmr/subject41
08 brainweb_petmr/subject42
09 brainweb_petmr/subject43
10 brainweb_petmr/subject44
11 brainweb_petmr/subject45
12 brainweb_petmr/subject46
13 brainweb_petmr/subject47
14 brainweb_petmr/subject48
15 brainweb_petmr/subject49
16 brainweb_petmr/subject50
17 brainweb_petmr/subject51
18 brainweb_petmr/subject52
19 brainweb_petmr/subject53
20 brainweb_petmr/subject54
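
As an optional sanity check, the number of subject directories can be tested programmatically. This is only a sketch: the helper name `find_subjects` is illustrative, and filtering for directories is an extra assumption that the cell above does not make.

```python
import pathlib


def find_subjects(data_path):
    """Return the sorted subject directories found directly below data_path."""
    candidates = pathlib.Path(data_path).glob("subject??")
    return sorted(p for p in candidates if p.is_dir())


# adjust the path in the same way as data_path above
subjects = find_subjects("brainweb_petmr")

if len(subjects) == 20:
    print("found all 20 subjects")
else:
    print(f"found {len(subjects)} subject directories instead of 20, please check data_path")
```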
