[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_core_2d_short.ipynb)

# Tutorial 2D - Image reconstruction for single-pixel imaging


This tutorial shows how to simulate data and perform image reconstruction with spyrit toolbox for 2D single-pixel imaging. 


The image below displays the ground-truth image, undersampled data and reconstruction with different methods.

<img src="https://drive.google.com/uc?id=1zBrCKHWM-AnAL1bs6HIZZ49wztWk_rvG" alt="drawing" width="800"/>


For **data simulation**, it loads images from ImageNet and simulates measurements based on
an undersampled Hadamard operator. You can select the number of counts and the undersampling factor.

**Image reconstruction** is performed using the following methods: 
-    Pseudo-inverse
-    PInvNet:        Linear net (same result as Pseudo-inverse)
-    DCNet:          Data completion net with unit matrix denoising
-    DCUNet:         Data completion with UNet denoising, trained on stl10 dataset (requires downloading the UNet weights).

For simplicity, this tutorial can also run a simplified data simulation based on a unit covariance matrix, where undersampling keeps the first $M$ measurements. To replicate the results above, set *download_cov=True* (the default in the cell below) so that the full covariance matrix is downloaded.

In [None]:
# Set download_cov to True to download the data covariance matrix and mean
# image, used for realistic simulations (otherwise the covariance defaults to
# the identity matrix).
# Downloading the data takes a few minutes.
download_cov = True

### Set google colab

Set *mode_colab=True* to run in Colab. This mounts Google Drive, which is needed to import the spyrit modules.

In [None]:
# Set to True when running on Google Colab: the branch below mounts Google
# Drive and moves into it with the IPython `%cd` magic.
mode_colab = True
if (mode_colab is True):
    # Connect to Google Drive
    # NOTE(review): could be auto-detected with
    #   if 'google.colab' in str(get_ipython()):
    # Mount Google Drive to access files via Colab
    from google.colab import drive
    drive.mount("/content/gdrive")  
    %cd /content/gdrive/MyDrive/    

On colab, to run on GPU, select *GPU* from the navigation menu *Runtime/Change runtime type*.

In [None]:
# List the NVIDIA GPU(s) visible to this runtime (shell command run via IPython `!`)
!nvidia-smi

## Settings and requirements

In [None]:
import os
import numpy as np
import torch
import torchvision

### Clone Spyrit package

Clone and install the spyrit package if it is not installed yet; otherwise, move to the spyrit folder.

In [None]:
# Get the spyrit sources. On Colab, either clone and install them
# (install_spyrit=True) or cd into an existing clone on Google Drive.
# Locally, move from the notebook folder to the repository root.
install_spyrit = True
if (mode_colab is True):
    if install_spyrit is True:
        # Clone the repository, move into it and install it in editable mode
        !git clone https://github.com/openspyrit/spyrit.git
        %cd spyrit
        !pip install -e .

        # Fetch all remote branches (note: no branch is checked out here)
        !git fetch --all
        # girder_client is imported by the data-download cell
        !pip install girder_client
    else:
        # cd to the spyrit folder if it is already cloned in your drive
        %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit

    # Add paths for modules (relative to the current folder, i.e. the spyrit clone)
    import sys
    sys.path.append('./spyrit/core')
    sys.path.append('./spyrit/misc')
    sys.path.append('./spyrit/tutorial')
else:
    # Change path to the repository root
    # NOTE(review): assumes the notebook is started from spyrit/tutorial/
    # inside the repository (as in the Colab link at the top) — confirm.
    os.chdir('../..')
    !pwd

In [None]:
# Load spyrit modules
from spyrit.core.meas import HadamSplit
from spyrit.core.noise import NoNoise, Poisson, PoissonApproxGauss
from spyrit.core.prep import SplitPoisson
from spyrit.core.recon import PseudoInverse, PinvNet, DCNet
from spyrit.core.nnet import Unet
from spyrit.misc.statistics import Cov2Var, data_loaders_stl10, transform_gray_norm
from spyrit.misc.disp import imagesc 
from spyrit.misc.sampling import meas2img2
from spyrit.core.train import load_net

## Download covariance and DCNet model

Download the full covariance matrix (if it is not available, the covariance is set to the identity matrix). Expected folder layout:
```
├───stat
│   ├───Average_64x64.npy
│   ├───Cov_64x64.npy
├───model
│   ├───dc-net_unet_... .pth
├───spyrit
```

In [None]:
if (download_cov is True):
    import girder_client

    # api Rest url of the warehouse
    url='https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'
    
    # Generate the warehouse client
    gc = girder_client.GirderClient(apiUrl=url)

    #%% Download the covariance matrix and mean image
    data_folder = './stat/'
    dataId_list = [
            '63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64)
            '63935a224d15dd536f048496', # for reconstruction (imageNet, 64)
            ]
    for dataId in dataId_list:
        myfile = gc.getFile(dataId)
        gc.downloadFile(dataId, data_folder + myfile['name'])

    print(f'Created {data_folder}') 
    !ls $data_folder

    #%% Download the models
    data_folder = './model/'
    dataId_list = [
                #'644a38c985f48d3da07140ba', # N_rec = 64, M = 4095
                '644a38c785f48d3da07140b7', # N_rec = 64, M = 1024
                #'644a38c585f48d3da07140b4', # N_rec = 64, M = 512
                ]

In [None]:
# Parameters
H = 64                          # Image height (assumed square image: H x H pixels)
M = H**2 // 4                   # Num measurements: keeps 1/4 of the H*H Hadamard patterns
                                # (undersampling factor 4, i.e. M = 1024 for H = 64)
B = 10                          # Batch size
alpha = 100                     # ph/pixel max: number of counts
load_cov = False                # Load cov matrix (requires ./stat/Cov_{H}x{H}.npy);
                                # otherwise, set to unit matrix
                                # NOTE(review): not referenced by the visible cells
load_unet = True                # Load pretrained UNet denoising
                                # NOTE(review): not referenced by the visible cells

imgs_path = './spyrit/images'

# Covariance file for H x H images (same naming as the files stored in ./stat/)
cov_name = f'./stat/Cov_{H}x{H}.npy'

# use GPU, if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Load data

Load a batch of images from the folder *spyrit/images*.

In [None]:
# Natural images -> normalized grayscale H x H tensors
transform = transform_gray_norm(img_size=H)

# ImageFolder expects one sub-folder per class (e.g. 'images/test/');
# the batch size is capped by the number of available images
dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=min(B, len(dataset)))

# Ground truth: images 1 to 5 of the first batch
x0, _ = next(iter(dataloader))
x0 = x0[1:6]
x = x0.detach().clone()

# Flatten each (image, channel) pair into one row of h*w pixels
b, c, h, w = x.shape
x = x.view(b * c, h * w)
print(f'Shape of incoming image (b*c,h*w): {x.shape}')

x_plot = x.view(-1, H, H).cpu().numpy()
imagesc(x_plot[0], 'Ground-truth image normalized')

## Neural network pipeline and operators

### Experimental data simulation: Split measurement, noise and raw measurement operators

We adopt linear measurements using a Hadamard operator corrupted by Poisson noise.

Data simulation in spyrit is done with operators from *spyrit.core.meas* and *spyrit.core.noise* that perform image normalization, split measurements and noise perturbation. In the example below, this corresponds to the following steps:

$$
x \xrightarrow[\text{Step 1}]{\text{NoNoise}} \frac{x+1}{2} \xrightarrow[\text{Step 2}]{\text{HadamSplit}} y=Px \xrightarrow[\text{Step 3}]{\text{Poisson}} \mathcal{P}(\alpha y)
$$

- Step 1: Given an image $x$ with values in $[-1, 1]$, the image is first normalized to $[0, 1]$ as

$$
\frac{x+1}{2}
$$
using the *spyrit.core.noise.NoNoise* operator.

- Step 2: Split measurements $y$ are obtained via the linear operator $P$: 

$$
y = Px = 
\begin{pmatrix}
H_{+} \\
H_{-}
\end{pmatrix}x=
\begin{pmatrix}
\max(H, 0) \\
\max(0,-H)
\end{pmatrix}x
$$

where $H=(H_{+}-H_{-})$ is a Hadamard matrix. 

- Step 3: Data is finally perturbed by Poisson noise as

$$
\tilde{y} = \mathcal{P}(\alpha y)
$$

using spyrit's *spyrit.core.noise.Poisson* operator:
```
    meas_op = HadamSplit(M, h, Ord)
    noise = Poisson(meas_op, alpha)
    y = noise(x)
```
where $M$ is the number of measurements, $h$ the image height, and $Ord$ the ordering matrix used for undersampling. Inheritance is used as
```
    Poisson(NoNoise)
    NoNoise(nn.Module)
```


In [None]:
# Operators
#
# Order matrix with shape (H, H) used to compute the permutation matrix
# (as undersampling taking the first rows only). It is derived from the
# covariance matrix, which falls back to the identity when the file is missing.
try:
    Cov = np.load(cov_name)
except FileNotFoundError:
    # Only a missing file triggers the fallback; a corrupted file still raises
    Cov = np.eye(H*H)
    print(f"Cov matrix {cov_name} not found! Set to the identity")

Ord = Cov2Var(Cov)

# Measurement operator:
# Computes linear measurements y=Px, where P is a linear operator (matrix) with positive entries
# such that P=[H_{+}; H_{-}]=[max(H,0); max(0,-H)], H=H_{+}-H_{-}
meas_op = HadamSplit(M, H, Ord)

# Simulates raw split measurements from images in the range [0,1] assuming images provided in range [-1,1]
# y=0.5*H(1 + x)
# noise = NoNoise(meas_op) # noiseless
noise = Poisson(meas_op, alpha)

# Simulate raw measurements (non-negative measurements)
y = noise(x)
print(f'Shape of simulated measurements y: {y.shape}')

# Rearrange the measurements as images for display (batch axis moved first)
m_plot = y.numpy()
m_plot = meas2img2(m_plot.T, Ord)
m_plot = np.moveaxis(m_plot,-1, 0)
print(f'Shape of reshaped simulated measurements y: {m_plot.shape}')

imagesc(m_plot[0,:,:],'Simulated Measurement')

Note that the measurements are nonnegative, as expected experimentally.


### Preprocessing measurement operator 
The previous steps simulate experimental split measurements, which are nonnegative. A fourth step preprocesses the raw data acquired with a split measurement operator:

$$
y \xrightarrow[\text{Step 4}]{\text{Prep}} m=\frac{y_+-y_-}{\alpha},
$$

where $y_+=H_+x$ and $y_-=H_-x$. In spyrit, this is done with *spyrit.core.prep.SplitPoisson*:
```
    prep = SplitPoisson(alpha, meas_op)
    m = prep(y)
```
Now, measurements can be negative.

In [None]:
# Turn the raw split measurements into preprocessed data, assuming Poisson noise
prep = SplitPoisson(alpha, meas_op)
m = prep(y)
print(f'Shape of preprocessed data m: {m.shape}')

# Rearrange the preprocessed data as images for display (batch axis moved first)
m_plot = np.moveaxis(meas2img2(m.numpy().T, Ord), -1, 0)
print(f'Shape of reshaped simulated measurements m: {m_plot.shape}')

imagesc(m_plot[0, :, :], 'Preprocessed data')

### Reconstruction operators

Let $\tilde{H}$ represent the retained Hadamard patterns. The Hadamard coefficients $m$ are given in terms of the image $x$ as

$$
m=\tilde{H}x.
$$

The Hadamard operator is orthogonal, so an estimate of the unknown image $x$ can be recovered with the (pseudo-)inverse Hadamard transform

$$
\hat{x}=\tilde{H}^{\dagger}m.
$$

Image reconstruction in spyrit comprises four steps (the last one is optional):

1. Denoising

2. Data completion

3. Hadamard inverse transform

4. Nonlinear postprocessing

$$
m\in\mathbb{R}^{M} \xrightarrow[\text{Denoising}]{\text{Step 1}} y_1\in\mathbb{R}^{M} \xrightarrow[\text{Completion}]{\text{Step 2}} y_2\in\mathbb{R}^{N-M} \xrightarrow[\text{Inverse}]{\text{Step 3}} \tilde{x}\in\mathbb{R}^{N} \xrightarrow[\text{Postprocessing}]{\text{Step 4}} \hat{x}\in\mathbb{R}^{N_x\times N_y}
$$

In spyrit, the four steps are contained in *spyrit.core.recon.PinvNet* or *spyrit.core.recon.DCNet*, and are automatically handled for single-pixel imaging data. The denoising network of the nonlinear step is implemented by *spyrit.core.nnet.Unet* and must be defined:

```
    denoi = Unet()
    dcnet_unet = DCNet(noise, prep, Cov, denoi)
    z_dcunet = dcnet_unet.reconstruct(y)
```

In [None]:
# Reconstruct from the preprocessed data with the pseudo-inverse of meas_op
pinv = PseudoInverse()
z_pinv = pinv(m, meas_op)
print(f'Shape of reconstructed image z: {z_pinv.shape}')

# Reshape the flat reconstructions into H x H images and show the first one
z_plot = z_pinv.view(-1, H, H).numpy()
imagesc(z_plot[0], 'Pseudo-inverse reconstruction')

In [None]:
# Pseudo-inverse net

# Reconstruction with the Core module PinvNet (linear net).
# pinvnet(x) takes ground-truth images; pinvnet.reconstruct(y) takes raw measurements.
pinvnet = PinvNet(noise, prep)

# Move to the device chosen in the parameters cell (GPU, if available)
pinvnet = pinvnet.to(device)

x = x0.detach().clone()
x = x.to(device)
with torch.no_grad():
    z_pinvnet = pinvnet(x)
# z_pinvnet = pinvnet.reconstruct(y)

# Plot the PinvNet output (not the PseudoInverse result of the previous cell);
# bring it back to the CPU first since it may live on the GPU
z_plot = z_pinvnet.view(-1,H,H).detach().cpu().numpy()
imagesc(z_plot[0,:,:],'PinvNet reconstruction')

In [None]:
# DCNet

# Reconstruction with DCNet (linear net + denoising net); no UNet denoiser is
# passed here, unlike the DC UNet cell.
dcnet = DCNet(noise, prep, Cov)

#y = pinvnet.acquire(x)         # or equivalently here: y = dcnet.acquire(x)
#m = pinvnet.meas2img(y)        # zero-padded images (after preprocessing)
dcnet = dcnet.to(device)
# Inference only: no autograd graph, so the result can be converted to numpy
with torch.no_grad():
    z_dcnet = dcnet.reconstruct(y.to(device))  # reconstruct from raw measurements
#x_dcnet_2 = dcnet(x)   # another reconstruction, from the ground-truth image

z_plot = z_dcnet.view(-1,H,H).detach().cpu().numpy()
imagesc(z_plot[0,:,:],'DCNet reconstruction')

In [None]:
# Pretrained DC UNet (UNet denoising)
denoi = Unet()
dcnet_unet = DCNet(noise, prep, Cov, denoi)

# Previously trained model.
# NOTE(review): the path has no '.pth' extension; load_net presumably appends
# it, and its 4th argument presumably disables strict state-dict loading — confirm.
model_path = "./model/dc-net_unet_imagenet_var_N0_10_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_256_reg_1e-07_light"
#model_path = './model/dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pth'

# Only a missing weight file is reported as "not found"; any other error
# (e.g. during reconstruction) is raised instead of being hidden.
try:
    load_net(model_path, dcnet_unet, device, False)
except FileNotFoundError:
    print(f'Model {model_path} not found!')
else:
    dcnet_unet = dcnet_unet.to(device)
    with torch.no_grad():
        z_dcunet = dcnet_unet.reconstruct(y.to(device))  # reconstruct from raw measurements

    z_plot = z_dcunet.view(-1,H,H).detach().cpu().numpy()
    imagesc(z_plot[0,:,:],'DC UNet reconstruction', show=False)
