[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit-examples/blob/master/tutorial/tuto_train_lin_meas_colab.ipynb)

# Tutorial to train a reconstruction network 

This tutorial trains a reconstruction network for 2D single-pixel imaging on the STL-10 dataset from linear measurements. Specifically, we use a positive Hadamard matrix, but it can be replaced by any other matrix.
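The NumPy sketch below roughly illustrates positive Hadamard measurements. It makes the following assumptions:

- 2D patterns are built from the Sylvester construction.
- Their ±1 entries are mapped to {0, 1}.
- Only the first M patterns are kept.

The helper names are hypothetical. spyrit's own operators, including its 'var'/'rect' subsampling, may order and scale the patterns differently.

```python
import numpy as np

def sylvester_hadamard(n):
    """Return the n x n Sylvester Hadamard matrix (n must be a power of 2)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of 2")
    H = np.array([[1]], dtype=np.int64)
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])  # doubling step of the Sylvester construction
    return H

def positive_hadamard_2d(img_size):
    """2D Hadamard patterns for img_size x img_size images, mapped from {-1, +1} to {0, 1}."""
    H1 = sylvester_hadamard(img_size)
    H2 = np.kron(H1, H1)   # (img_size**2, img_size**2), entries +/-1
    return (H2 + 1) // 2   # positive part: +1 -> 1, -1 -> 0

img_size = 8                                  # small size so the example runs instantly
M = img_size**2 // 4                          # undersampling factor of 4
H_pos = positive_hadamard_2d(img_size)[:M]    # keep the first M patterns (assumed ordering)

rng = np.random.default_rng(0)
x = rng.random((img_size, img_size))          # toy image with values in [0, 1]
y = H_pos @ x.ravel()                         # linear measurements, shape (M,)
print(H_pos.shape, y.shape)
```

The same construction with `img_size = 64` gives a 4096 x 4096 matrix.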

Training is performed by calling *train_gen_meas.py*. Its parameters control the acquisition, the network and training setup (e.g., the network architecture), the optimisation, and the use of TensorBoard.

Currently you can train the following networks by modifying the network architecture variable *arch*: 

- 'dc-net': Denoised Completion Network (DCNet). 
- 'pinv-net': Pseudo Inverse Network (PinvNet).
- 'upgd': Unrolled proximal gradient descent (UPGD). 

and one of the following denoisers, selected by the variable *denoi*:

- 'cnn': CNN without batch normalization
- 'cnnbn': CNN with batch normalization
- 'unet': UNet (0.5 M trainable parameters)
- 'drunet': DRUNet (high capacity residual UNet that allows training for all noise levels)
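Conceptually, a pseudo-inverse network works in two steps:

1. It maps the measurements back to the image domain with the pseudo-inverse of the measurement matrix.
2. It applies a learned denoiser.

The NumPy sketch below only illustrates this two-step structure. `pinv_net_sketch` is a hypothetical helper, and the identity function stands in for the trained CNN/UNet. This is not spyrit's PinvNet implementation.

```python
import numpy as np

def pinv_net_sketch(y, H, denoiser=lambda z: z):
    """x_hat = denoiser(pinv(H) @ y); the default denoiser is an identity placeholder."""
    x_pinv = np.linalg.pinv(H) @ y   # minimum-norm least-squares solution
    return denoiser(x_pinv)

rng = np.random.default_rng(0)
N, M = 64, 16                                         # 4x undersampling
H = rng.integers(0, 2, size=(M, N)).astype(float)     # toy nonnegative measurement matrix
x = rng.random(N)
y = H @ x
x_hat = pinv_net_sketch(y, H)
print(np.allclose(H @ x_hat, y))  # the pseudo-inverse image reproduces the measurements
```

When M < N, the pseudo-inverse returns the minimum-norm image consistent with the measurements. Restoring the missing information is left to the denoiser.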

## Settings and requirements

### Set up Google Colab

On Colab, select a GPU under *Runtime/Change runtime type*.

In [None]:
!nvidia-smi

### Dependencies

In [None]:
import os
import datetime

First, mount Google Drive to access the spyrit modules.

In [None]:
mode_colab = True
if (mode_colab is True):
    # Connect to Google Drive
    #if 'google.colab' in str(get_ipython()):
    # Mount Google Drive to access files from Colab
    from google.colab import drive
    drive.mount("/content/gdrive")
    %cd /content/gdrive/MyDrive/

    # For the profiler
    !pip install -U tensorboard-plugin-profile

    # Load the TensorBoard notebook extension
    %load_ext tensorboard
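The commented-out `get_ipython()` check above hints at detecting Colab automatically instead of hard-coding `mode_colab = True`. The minimal sketch below assumes that the `google.colab` package is only importable on Colab. `running_in_colab` is a hypothetical helper.

```python
def running_in_colab() -> bool:
    """Return True if google.colab can be imported (assumed to mean we are on Colab)."""
    try:
        import google.colab  # noqa: F401  (only available on Colab)
    except ImportError:
        return False
    return True

mode_colab = running_in_colab()
print("Running on Colab:", mode_colab)
```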

### Clone Spyrit package

Clone and install the spyrit package if it is not installed yet (otherwise, simply change to the spyrit folder), then clone spyrit-examples and move to its tutorial folder.

In [None]:
if (mode_colab is True):
    # Clone and install
    !git clone https://github.com/openspyrit/spyrit.git
    %cd spyrit
    !pip install -e .

    # Fetch all remote branches (check out a specific branch here if needed)
    !git fetch --all

    # Add paths for modules
    import sys
    sys.path.append('./spyrit/core')
    sys.path.append('./spyrit/misc')
    sys.path.append('./spyrit/tutorial')
    %cd ..

    # Clone spyrit-examples and move to its tutorial folder
    !git clone https://github.com/openspyrit/spyrit-examples.git
    %cd spyrit-examples/tutorial
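To check that the installation worked, you can query the installed distribution with the standard library. This small sketch assumes the distribution is named `spyrit`, as on PyPI. `installed_version` is a hypothetical helper.

```python
import importlib.metadata

def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None

print("spyrit version:", installed_version("spyrit"))
```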

## Train

### Select data and training parameters

You can choose the following parameters:
- Measurement type (forward model):
    - --meas: Measurement operator: 'hadam-split', 'hadam-pos'. Default="hadam-split" 
    - --noise: Noise operator: 'poisson', 'gauss-approx', 'no-noise'. Default="poisson"
    - --prep: Preprocessing operator: 'dir-poisson', 'split-poisson'. Default="dir-poisson"

- Acquisition: 
    - --img_size: Height / width dimension, default=64
    - --M: Number of undersampling patterns, default=512
    - --subs: Subsampling scheme, 'var' or 'rect', default="var"

- Network and training: 
    - --data: stl10 or imagenet, default="stl10"
    - --model_root: Path to model saving files, default='./model/'
    - --data_root: Path to the dataset, default="./data/"
    - --N0: Mean maximum total number of photons, default=10
    - --stat_root: Path to precomputed data (cov matrix), default=""
    - --arch: Choose among 'dc-net','pinv-net', 'upgd', default="dc-net"
    - --denoi: Choose among 'cnn','cnnbn', 'unet', default="unet"

- Optimisation:
    - --num_epochs: Number of training epochs, default=30
    - --batch_size: Size of each training batch, default=512
    - --reg: Regularisation Parameter, default=1e-7
    - --step_size: Scheduler Step Size, default=10
    - --gamma: Scheduler Decrease Rate, default=0.5
    - --checkpoint_model: Optional path to checkpoint model, default=""
    - --checkpoint_interval: Interval between saving model checkpoints, default=0
    - Training uses the *Adam* optimizer and the *MSELoss* loss.

- Tensorboard:
    - --tb_path: Relative path for Tensorboard experiment tracking logs, default=False
    - --tb_prof: Code profiler with Tensorboard, default=False
    - Logged data: scalars *train_loss* and *val_loss*, and images (ground truth and predictions for a dataset example at different epochs).
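The *--step_size* and *--gamma* options describe a step-decay schedule: the learning rate is multiplied by `gamma` every `step_size` epochs. The sketch below computes that schedule in plain Python. It assumes the following:

- The initial learning rate `lr0 = 1e-3` is a hypothetical value; it is not listed above.
- The scheduler behaves like PyTorch's `torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma)`.

```python
def step_lr(lr0, epoch, step_size=10, gamma=0.5):
    """Learning rate at a given (0-based) epoch under step decay."""
    return lr0 * gamma ** (epoch // step_size)

lr0 = 1e-3  # hypothetical initial learning rate
for epoch in (0, 9, 10, 19, 20, 29):
    print(f"epoch {epoch:2d}: lr = {step_lr(lr0, epoch):.2e}")
```

With the defaults (30 epochs, step size 10, gamma 0.5), the learning rate is halved twice and ends at a quarter of its initial value.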


In this tutorial, we consider noiseless data (`noise = 'no-noise'`, with `N0 = 1`) and an undersampling factor of 4, i.e., M = 64 × 64 / 4 = 1024 measurements. Training is done on the STL-10 dataset with the default values of the remaining parameters, and the experiment is tracked with TensorBoard.
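The toy NumPy forward model below shows why `N0` only matters when noise is simulated. The measurements are `N0 * H @ x`, optionally replaced by Poisson draws with that mean. The sketch makes these assumptions:

- Image values lie in [0, 1].
- A random binary matrix replaces the Hadamard patterns.
- `simulate_measurements` is a hypothetical helper.

spyrit's noise and preprocessing operators may normalise differently.

```python
import numpy as np

def simulate_measurements(H, x, N0, noise="poisson", rng=None):
    """Toy forward model: y = N0 * H x, optionally replaced by Poisson counts with that mean."""
    rng = np.random.default_rng() if rng is None else rng
    y_clean = N0 * (H @ x)
    if noise == "no-noise":
        return y_clean
    if noise == "poisson":
        return rng.poisson(y_clean).astype(float)  # photon counts
    raise ValueError(f"unsupported noise type: {noise}")

rng = np.random.default_rng(0)
H = rng.integers(0, 2, size=(16, 64)).astype(float)  # toy measurement matrix
x = rng.random(64)                                   # toy image in [0, 1]
y0 = simulate_measurements(H, x, N0=1.0, noise="no-noise")
y1 = simulate_measurements(H, x, N0=10.0, noise="poisson", rng=rng)
print(np.allclose(y0, H @ x))  # with 'no-noise' and N0 = 1, y is exactly H x
```

With `'no-noise'`, `N0` is only a scale factor, so the tutorial can set `N0 = 1`. With `'poisson'`, a small `N0` means few photons and hence a low signal-to-noise ratio.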

In [None]:
# Parameters
# (the first three parameters let train_gen_meas.py
# handle common measurement types)
meas = 'hadam-pos'    # measurement type
noise = 'no-noise' # noise type
prep = 'dir-poisson'    # preprocessing type
#
N0 = 1.0        # ph/pixel max: number of counts
img_size = 64   # image size
M =  img_size**2 // 4  # Num measurements = subsampled by factor 4
data_root = './data/'
data = 'stl10'
arch = 'pinv-net' # Network architecture
denoi = 'cnn' # Denoiser architecture
num_epochs = 30

# Tensorboard logs path
name_run = "stl10_hadampos"
mode_tb = True
if (mode_tb is True):
    now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    tb_path = f'runs/runs_{name_run}_n{int(N0)}_m{M}/{now}'
    print(f"Tensorboard logdir {tb_path}")
else:
    tb_path = None
    
tb_prof = False # TensorBoard profiler (see --tb_prof); not passed to the command below
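Before launching a long run, it can help to check the notebook parameters against the documented options. The sketch below rebuilds a hypothetical subset of the command-line interface with `argparse`, using only the flags, choices and defaults listed earlier. It is not the actual parser of `train_gen_meas.py`. An invalid value, such as a misspelled `arch`, makes `parse_args` exit with an error.

```python
import argparse

def make_parser():
    """Hypothetical subset of the train_gen_meas.py CLI, rebuilt from the documented options."""
    p = argparse.ArgumentParser(description="Subset of the documented training options")
    p.add_argument("--meas", default="hadam-split", choices=["hadam-split", "hadam-pos"])
    p.add_argument("--noise", default="poisson", choices=["poisson", "gauss-approx", "no-noise"])
    p.add_argument("--prep", default="dir-poisson", choices=["dir-poisson", "split-poisson"])
    p.add_argument("--img_size", type=int, default=64)
    p.add_argument("--M", type=int, default=512)
    p.add_argument("--N0", type=float, default=10)
    p.add_argument("--arch", default="dc-net", choices=["dc-net", "pinv-net", "upgd"])
    # 'drunet' is listed in the introduction but not in the --denoi option list
    p.add_argument("--denoi", default="unet", choices=["cnn", "cnnbn", "unet", "drunet"])
    p.add_argument("--num_epochs", type=int, default=30)
    return p

parser = make_parser()
args = parser.parse_args([
    "--meas", "hadam-pos", "--noise", "no-noise", "--prep", "dir-poisson",
    "--N0", "1.0", "--M", "1024", "--arch", "pinv-net", "--denoi", "cnn",
    "--num_epochs", "30",
])
if args.M > args.img_size ** 2:  # cannot use more patterns than pixels
    parser.error("M cannot exceed img_size**2")
print(args)
```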

### Train

In this notebook, training is done by calling `train_gen_meas.py`. This script sets up the data, the model and the training parameters for the provided tutorials, then calls `train_model` from the `spyrit.core.train` module. For custom training, you may call `train_model` directly from your own version of `train_gen_meas.py`.

If running `!python3 train_gen_meas.py` directly fails, you can use `subprocess` instead (commented out in the cell below), although the training output is then not displayed while it runs. You can still monitor training with TensorBoard (see below).

Approximate training time:
- 2 min to download stl10
- 2 min per epoch
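These figures give a rough estimate of the total run time for the `num_epochs = 30` used here. Actual times depend on the GPU that Colab assigns.

$$
T \approx 2\ \text{min} + 30 \times 2\ \text{min} = 62\ \text{min}
$$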

In [None]:
# Run train_gen_meas.py
!python3 train_gen_meas.py --meas $meas --noise $noise --prep $prep --N0 $N0 --M $M --tb_path $tb_path --arch $arch --denoi $denoi --num_epochs $num_epochs

#import subprocess
#subprocess.run(['python3', 'train_gen_meas.py', '--meas', meas, '--noise', noise, '--prep', prep,
#                '--N0', str(N0), '--M', str(M), 
#                '--arch', arch, '--denoi', denoi, '--num_epochs', str(num_epochs),
#                '--tb_path', tb_path])
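If you use `subprocess`, you can still follow the training log as it is produced by reading the child's output line by line. This minimal standard-library sketch works as follows:

- `build_train_cmd` and `run_streaming` are hypothetical helpers.
- `--tb_path` is only passed when logging is enabled, so `None` is never sent as a string.
- `-u` disables Python's output buffering in the child, so lines appear promptly.

```python
import subprocess
import sys

def build_train_cmd(meas, noise, prep, N0, M, arch, denoi, num_epochs, tb_path=None):
    """Assemble the train_gen_meas.py command line from the notebook parameters."""
    cmd = [sys.executable, "-u", "train_gen_meas.py",
           "--meas", meas, "--noise", noise, "--prep", prep,
           "--N0", str(N0), "--M", str(M),
           "--arch", arch, "--denoi", denoi,
           "--num_epochs", str(num_epochs)]
    if tb_path:  # skip the flag entirely when TensorBoard logging is off
        cmd += ["--tb_path", tb_path]
    return cmd

def run_streaming(cmd):
    """Run cmd, echoing its stdout/stderr line by line; return the exit code."""
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          text=True, bufsize=1) as proc:
        for line in proc.stdout:
            print(line, end="")
    return proc.returncode  # set once the with-block has waited for the process

# In the notebook (uses the variables defined in the parameter cell):
# cmd = build_train_cmd(meas, noise, prep, N0, M, arch, denoi, num_epochs, tb_path)
# run_streaming(cmd)
```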

You can check that logs are being saved under `spyrit-examples/tutorial/runs` (click the `Files` icon in the left panel, or browse directly in your Drive).


## Check that the model is saved

In [None]:
# List model
!ls -R model

In [None]:
# List runs
!ls -R runs
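If you prefer a check that works without shell commands, `pathlib` can list what was written under `model` and `runs`. In this small sketch, the folder names come from the cells above and no particular file extension is assumed. `list_files` is a hypothetical helper.

```python
from pathlib import Path

def list_files(root):
    """Return all files below root, sorted by path (empty list if root does not exist)."""
    root = Path(root)
    if not root.is_dir():
        return []
    return sorted(p for p in root.rglob("*") if p.is_file())

for folder in ("model", "runs"):
    files = list_files(folder)
    print(f"{folder}: {len(files)} file(s)")
    for p in files:
        print("   ", p)
```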

### Tensorboard

Launch TensorBoard to visualize the tracked metrics and images. Select *SCALARS* or *IMAGES* to view losses/metrics or reconstructed images, respectively. More options (CNN weights, profiling) are available in the top-right corner.

You can also launch TensorBoard from the separate notebook *launch_tensorboard_colab.ipynb* during training, although this may not always work.

In [None]:
# Launch TensorBoard
# %tensorboard --logdir $tb_path
%tensorboard --logdir runs

In [None]:
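# If TensorBoard was launched twice, find the process using port 6006
# and kill it (replace 17387 with the PID reported by lsof)
#!lsof -i:6006
#!kill -9 17387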
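You can check whether something is already listening on TensorBoard's default port 6006 before launching it again. A minimal standard-library sketch follows. `port_in_use` is a hypothetical helper.

```python
import socket

def port_in_use(port, host="127.0.0.1", timeout=0.5):
    """Return True if a TCP connection to host:port succeeds."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(timeout)
        return s.connect_ex((host, port)) == 0

print("Port 6006 in use:", port_in_use(6006))
```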

## Close the Colab session!

Don't forget to close the Colab session by terminating the instance from the top menu *Runtime/Manage sessions*.