[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_train_colab.ipynb)

# Tutorial to train a reconstruction network 

This tutorial trains a reconstruction network for 2D single-pixel imaging on the STL-10 dataset.

The current example trains DCNET (data completion with UNet denoising, 0.5 M trainable parameters).
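One common form of data completion estimates the missing coefficients from the acquired ones as the conditional mean of a Gaussian prior, using the mean and covariance statistics downloaded later in this notebook. The NumPy sketch below illustrates that idea with a hypothetical helper `complete_data`. It assumes that the statistics are expressed in the measurement domain and that the acquired coefficients come first. It is not spyrit's DCNET implementation, which may differ.

```python
import numpy as np

def complete_data(y1, mu, cov):
    """Hypothetical helper: Gaussian conditional-mean data completion.

    y1  : (M,) acquired coefficients (assumed to be the first M entries)
    mu  : (N,) prior mean of all N coefficients
    cov : (N, N) prior covariance of all N coefficients
    Returns the full (N,) vector: the acquired entries followed by
    the estimate mu2 + cov21 @ inv(cov11) @ (y1 - mu1).
    """
    m = y1.shape[0]
    mu1, mu2 = mu[:m], mu[m:]
    cov11 = cov[:m, :m]
    cov21 = cov[m:, :m]
    # Solve cov11 @ z = (y1 - mu1) rather than forming the inverse explicitly
    z = np.linalg.solve(cov11, y1 - mu1)
    y2 = mu2 + cov21 @ z
    return np.concatenate([y1, y2])

# Toy example: N = 4 coefficients, M = 2 of them acquired
rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
cov = a @ a.T + 4 * np.eye(4)  # symmetric positive definite
mu = np.zeros(4)
y_full = complete_data(np.array([1.0, -0.5]), mu, cov)
print(y_full.shape)  # (4,)
```

With a diagonal covariance, the acquired coefficients carry no information about the missing ones, so the sketch simply fills them in with the prior mean.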

## Settings and requirements

In [None]:
import os
import datetime

First, mount Google Drive to import the spyrit modules.

### Set up Google Colab

In [None]:
mode_colab = True
if mode_colab:
    # Mount Google Drive to access files from Colab
    # (alternative check: 'google.colab' in str(get_ipython()))
    from google.colab import drive
    drive.mount("/content/gdrive")
    %cd /content/gdrive/MyDrive/

    # Install the TensorBoard profiler plugin
    !pip install -U tensorboard-plugin-profile

    # Load the TensorBoard notebook extension
    %load_ext tensorboard
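Instead of hard-coding `mode_colab = True`, the runtime can be detected automatically. The sketch below assumes that the `google.colab` module is importable only inside Colab. The helper `running_in_colab` is hypothetical.

```python
import importlib.util

def running_in_colab() -> bool:
    """Return True if the google.colab module can be found (likely a Colab runtime)."""
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        # find_spec imports the parent package "google"; it is not installed here
        return False

mode_colab = running_in_colab()
print(f"mode_colab = {mode_colab}")
```

The `try` block is needed because `find_spec` on a dotted name imports the parent package first and raises `ModuleNotFoundError` if that package is missing.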

On Colab, choose a GPU at *Runtime/Change runtime type*, then check that it is available:

In [None]:
!nvidia-smi
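The same check can also be run from Python with PyTorch, on which spyrit is built. The sketch below falls back to the CPU when PyTorch or a CUDA device is unavailable. The helper name `pick_device` is hypothetical.

```python
def pick_device() -> str:
    """Return "cuda" if PyTorch sees a GPU, else "cpu" (hypothetical helper)."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so nothing can run on the GPU from here
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"

device = pick_device()
print(f"Using device: {device}")
```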

### Clone Spyrit package

Clone and install the spyrit package if it is not already installed. Otherwise, move to the spyrit folder.

In [None]:
install_spyrit = True
if mode_colab:
    if install_spyrit:
        # Clone and install spyrit in editable mode
        !git clone https://github.com/openspyrit/spyrit.git
        %cd spyrit
        !pip install -e .

        # Fetch all remote branches (e.g., to check out an ongoing branch)
        !git fetch --all
    else:
        # cd to the spyrit folder if it is already cloned in your drive
        %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit

    # Add paths for modules
    import sys
    sys.path.append('./spyrit/core')
    sys.path.append('./spyrit/misc')
    sys.path.append('./spyrit/tutorial')
else:
    # Change path to spyrit/ (assuming the notebook runs from spyrit/tutorial)
    os.chdir('../..')
    !pwd

## Download data

Download the covariance matrix and the mean image. Alternatively, install the *openspyrit/spas* package. The files are stored as follows:
```
├───stat
│   ├───Average_64x64.npy
│   ├───Cov_64x64.npy
```

In [None]:
download_cov = True
if download_cov:
    !pip install girder-client
    import girder_client

    # REST API URL of the warehouse
    url = 'https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'

    # Create the warehouse client
    gc = girder_client.GirderClient(apiUrl=url)

    # Download the covariance matrix and mean image
    data_folder = './stat/'
    os.makedirs(data_folder, exist_ok=True)  # make sure the target folder exists
    dataId_list = [
            '63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64)
            '63935a224d15dd536f048496', # for reconstruction (imageNet, 64)
            ]
    for dataId in dataId_list:
        myfile = gc.getFile(dataId)
        gc.downloadFile(dataId, data_folder + myfile['name'])

    print(f'Created {data_folder}')
    !ls $data_folder
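After the download, a quick sanity check can confirm that the files load and have consistent shapes. The sketch below assumes that the downloaded files are named as in the tree above (`Average_64x64.npy`, `Cov_64x64.npy`) and that the covariance is stored as a square matrix. The helper `load_stats` is hypothetical.

```python
import os
import numpy as np

def load_stats(stat_root: str, img_size: int = 64):
    """Load the mean image and covariance matrix (hypothetical helper)."""
    mean = np.load(os.path.join(stat_root, f"Average_{img_size}x{img_size}.npy"))
    cov = np.load(os.path.join(stat_root, f"Cov_{img_size}x{img_size}.npy"))
    # A covariance matrix must be square; fail early if a file is corrupted
    if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
        raise ValueError(f"Unexpected covariance shape: {cov.shape}")
    return mean, cov
```

Usage: `mean, cov = load_stats('./stat')`, then print `mean.shape` and `cov.shape` to inspect them.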

## Train

Train on the STL-10 dataset, with measurements perturbed by Poisson noise (100 photons) and an undersampling factor of 4.
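The undersampling factor follows from the image size. A 64 × 64 image has N = 4096 pixels, so acquiring M = 1024 measurements keeps one coefficient in four.

The Poisson perturbation can be illustrated generically. Each measurement, scaled to a mean of N0 photons, is replaced by a Poisson draw and then rescaled. This scaling and the helper `poisson_perturb` are assumptions for illustration and do not reproduce spyrit's exact noise model.

```python
import numpy as np

img_size = 64          # image side, matching the 64x64 statistics
N = img_size ** 2      # number of pixels: 4096
M = 1024               # number of measurements
print("undersampling factor:", N // M)  # 4

def poisson_perturb(y, N0, rng):
    """Hypothetical helper: draw photon counts with mean N0 * y, then rescale.

    y is assumed to be non-negative and normalized to [0, 1].
    """
    counts = rng.poisson(N0 * y)
    return counts / N0

rng = np.random.default_rng(0)
y_clean = np.full(M, 0.5)
y_noisy = poisson_perturb(y_clean, N0=100, rng=rng)
# Relative noise is about 1 / sqrt(N0 * y), so it shrinks as N0 grows
print("relative std:", y_noisy.std() / y_noisy.mean())
```

With N0 = 100 and y = 0.5, the expected relative standard deviation is about 1/√50 ≈ 0.14. Lowering N0 makes the reconstruction problem noisier.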

In [None]:
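# Parameters
N0 = 100                # Poisson noise level (number of photons)
M = 1024                # number of measurements (undersampling factor of 4)
data_root = './data/'   # dataset folder
data = 'stl10'          # dataset name
stat_root = './stat'    # folder with the mean image and covariance matrix
now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
tb_path = f'runs/runs_stl10_n100_m1024/{now}'  # TensorBoard log folder (or None)
tb_prof = True          # TensorBoard profiler (or False)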

In [None]:
# Run the training script
if mode_colab:
    # Copy train.py to the main directory for Colab
    !pwd
    !cp spyrit/tutorial/train.py .
    !python3 train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof
    !rm train.py
else:
    !python3 spyrit/tutorial/train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof
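Outside a notebook, the `!` shell escape is not available. In that case, the same command can be assembled with the standard library and launched with `subprocess.run`. The helper `build_train_cmd` is hypothetical, and the `tb_path` value below is a placeholder for the timestamped path. The flags are copied from the command above.

```python
import sys

def build_train_cmd(script, **params):
    """Build the train.py command line (hypothetical helper).

    Each keyword becomes a '--name value' pair, in the order given.
    """
    cmd = [sys.executable, script]
    for name, value in params.items():
        cmd += [f"--{name}", str(value)]
    return cmd

cmd = build_train_cmd(
    "spyrit/tutorial/train.py",
    N0=100, M=1024, data_root="./data/", data="stl10",
    stat_root="./stat", tb_path="runs/example",  # placeholder log folder
    tb_prof=True,
)
print(" ".join(cmd))
# To launch training from the spyrit/ folder:
# import subprocess; subprocess.run(cmd, check=True)
```

Passing a list rather than a single string avoids shell quoting issues with paths that contain spaces.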

## Evaluate the trained model

### TensorBoard

In [None]:
# Launch TensorBoard
# %tensorboard --logdir $tb_path
%tensorboard --logdir runs

In [None]:
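# If TensorBoard was launched twice, find the process using port 6006
# and kill it (replace 17387 with the PID reported by lsof)
#!lsof -i:6006
#!kill -9 17387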
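The PID `17387` above is only an example and must be replaced by the PID that `lsof` reports. The sketch below automates this step from Python. It assumes that `lsof` is installed; its `-t` option prints PIDs only. The helpers `parse_pids` and `kill_port` are hypothetical.

```python
import os
import signal
import subprocess

def parse_pids(lsof_terse_output: str) -> list:
    """Parse the output of `lsof -t`, which prints one PID per line."""
    return sorted({int(tok) for tok in lsof_terse_output.split() if tok.isdigit()})

def kill_port(port: int = 6006) -> list:
    """Send SIGKILL to every process with a connection on `port`; return the PIDs."""
    result = subprocess.run(
        ["lsof", "-t", f"-i:{port}"], capture_output=True, text=True
    )
    pids = parse_pids(result.stdout)
    for pid in pids:
        os.kill(pid, signal.SIGKILL)  # same effect as `kill -9 <pid>`
    return pids
```

Usage: call `kill_port(6006)` before relaunching `%tensorboard`. Note that this kills every process `lsof` lists for the port, not only TensorBoard.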