[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit-examples/blob/master/tutorial/tuto_core_2d_drunet.ipynb)

# Tutorial 2D - Image reconstruction for single-pixel imaging with DRUNet denoising


This tutorial shows how to simulate data and reconstruct images with the spyrit toolbox for 2D single-pixel imaging. In particular, it covers DCDRUNet, which leverages the pretrained DRUNet denoising network.

DRUNet is taken from the Deep Plug-and-Play Image Restoration (DPIR) toolbox: https://github.com/cszn/DPIR

June 2023

The image below displays the ground-truth image, undersampled data and reconstruction with different methods.

<img src="https://drive.google.com/uc?id=1bbmk2dHbCQ92_YBO-pXP3qcitv4p3d2U" alt="drawing" width="800"/>

For **data simulation**, the tutorial loads images from ImageNet and simulates split measurements based on 
an undersampled Hadamard operator (see [Tutorial on split measurements](https://spyrit.readthedocs.io/en/pinv_cnn/gallery/tuto_acquisition_split_measurements.html#sphx-glr-gallery-tuto-acquisition-split-measurements-py)). You can select the noise level and the undersampling factor. 

**Image reconstruction** is performed using the following methods: 
-    Pseudo-inverse
-    PInvNet:        Linear net (same result as Pseudo-inverse)
-    DCNet:          Data completion net with unit matrix denoising
-    DCUNet:         Data completion with UNet denoising, trained on the stl10 dataset (requires downloading the UNet weights). 
-    DCDRUNet:       Data completion with pretrained DRUNet denoising.


### Set up Google Colab

On colab, to run on GPU, select *GPU* from the navigation menu *Runtime/Change runtime type*.

In [None]:
!nvidia-smi

Set *mode_colab=True* to run in Colab. Google Drive is then mounted, which is needed to import the spyrit modules.

In [None]:
mode_colab = True
if (mode_colab is True):
    # Connect to googledrive
    #if 'google.colab' in str(get_ipython()):
    # Mount google drive to access files via colab
    from google.colab import drive
    drive.mount("/content/gdrive")  
    %cd /content/gdrive/MyDrive/    

### Clone Spyrit package

Clone and install the spyrit package if it is not installed yet; otherwise, move to the spyrit folder. 

The installation below is set up for Colab. To run the notebook locally, clone both *spyrit* and *spyrit-examples* at the same level (data files will be downloaded automatically below): 

```
    openspyrit/
    ├───spyrit
    │   ├───stat
    │       ├───Average_64x64.npy
    │       ├───Cov_64x64.npy
    │   ├───spyrit
    │       ├───model_zoo
    │           ├───dc-net_unet_... .pth
    │           ├───drunet_gray.pth
    ├───spyrit-examples
    │   ├───tutorial
    │       ├───tuto_core_2d_short_drunet.ipynb
```


In [None]:
# #%%capture
install_spyrit = True
if (mode_colab is True):
    if install_spyrit is True:
        # Clone and install
        !git clone https://github.com/openspyrit/spyrit.git
        %cd spyrit
        !pip install -e .

        # Checkout to ongoing branch
        !git fetch --all
        !pip install girder_client
    else:
        # cd to the spyrit folder if it is already cloned in your drive
        %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit

    # Add paths for modules
    import sys
    sys.path.append('./spyrit/core')
    sys.path.append('./spyrit/misc')
    sys.path.append('./spyrit/tutorial')
else:
    # Change path to spyrit/
    # Assumes we are in /spyrit-examples/tutorial
    import os
    os.chdir('../../spyrit')
    !pwd

In [None]:
# Install extra dependencies
!pip install tensorboard
!pip install gdown

In [None]:
# Load spyrit modules
from spyrit.core.meas import HadamSplit
from spyrit.core.noise import NoNoise, Poisson, PoissonApproxGauss
from spyrit.core.prep import SplitPoisson
from spyrit.core.recon import PseudoInverse, PinvNet, DCNet, DCDRUNet
from spyrit.core.nnet import Unet
from spyrit.misc.statistics import Cov2Var, data_loaders_stl10, transform_gray_norm
from spyrit.misc.disp import imagesc 
from spyrit.misc.sampling import meas2img2
from spyrit.core.train import load_net
from spyrit.external.drunet import UNetRes as drunet


## Settings and requirements

In [None]:
import os
import numpy as np
import torch
import torchvision

import gdown
import girder_client

## Download covariance and DCNet model

In this tutorial, we adopt a full covariance matrix, which takes into account correlated measurements (see [full covariance](https://spyrit.readthedocs.io/en/pinv_cnn/gallery/tuto_acquisition_split_measurements.html#sphx-glr-gallery-tuto-acquisition-split-measurements-py)). This requires downloading the covariance matrix (`download_cov=True`). Alternatively, you can set a unit covariance, which leads to pixelized reconstructions.

In [None]:
# Set download data covariance to True for realistic simulations
download_cov = True
if (download_cov is True):
    # api Rest url of the warehouse
    url='https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'
    
    # Generate the warehouse client
    gc = girder_client.GirderClient(apiUrl=url)

    #%% Download the covariance matrix and mean image
    data_folder = './stat/'
    dataId_list = [
            '63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64)
            '63935a224d15dd536f048496', # for reconstruction (imageNet, 64)
            ]
    for dataId in dataId_list:
        myfile = gc.getFile(dataId)
        gc.downloadFile(dataId, data_folder + myfile['name'])

    print(f'Created {data_folder}') 
    !ls $data_folder

    #%% Download the models
    data_folder = './model/'
    dataId_list = [
                #'644a38c985f48d3da07140ba', # N_rec = 64, M = 4095
                '644a38c785f48d3da07140b7', # N_rec = 64, M = 1024
                #'644a38c585f48d3da07140b4', # N_rec = 64, M = 512
                ]

### Set Parameters

In [None]:
# Parameters
H = 64                          # Image height (square image assumed)
M = H**2 // 4                   # Number of measurements (undersampling by a factor of 4)
B = 10                          # Batch size
alpha = 100                     # Maximum number of photons per pixel (counts)

imgs_path = './spyrit/images'
cov_name = './stat/Cov_64x64.npy'   # Covariance file; if missing, the identity is used instead

# use GPU, if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Load data

Load a batch of images from the folder *spyrit/images*. Images are converted to grayscale and normalized to $[-1,1]$ (the range used for training) by **transform_gray_norm**. 

In [None]:
# Create a transform for natural images to normalized grayscale image tensors
transform = transform_gray_norm(img_size=H)

# Create dataset and loader (expects class folder 'images/test/')
dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = min(B, len(dataset)))

# Select image
x0, _ = next(iter(dataloader))
x0 = x0[1:6,:,:,:]
x = x0.detach().clone()
b,c,h,w = x.shape
x = x.view(b*c,h*w)
print(f'Shape of incoming image (b*c,h*w): {x.shape}')

x_plot = x.view(-1,H,H).cpu().numpy()    
imagesc(x_plot[0,:,:],'Ground-truth image normalized')
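
The normalization and flattening performed in the cell above can be illustrated with plain NumPy. This is only a sketch under stated assumptions: the random 8-bit batch is hypothetical, and the mapping `(v - 0.5) / 0.5` is assumed to be the kind of rescaling that brings $[0,1]$ images to $[-1,1]$; it is not taken from the spyrit source.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical batch of 8-bit grayscale images: (batch, channel, height, width)
imgs_uint8 = rng.integers(0, 256, size=(5, 1, 64, 64), dtype=np.uint8)

# Map [0, 255] -> [0, 1] -> [-1, 1]
imgs = imgs_uint8.astype(np.float32) / 255.0
imgs = (imgs - 0.5) / 0.5

# Flatten each image to a row vector, mirroring x.view(b*c, h*w)
b, c, h, w = imgs.shape
x_flat = imgs.reshape(b * c, h * w)

print(x_flat.shape, float(x_flat.min()), float(x_flat.max()))
```

Each row of `x_flat` is one image, which is the layout expected by a measurement operator written as a matrix product.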

## Neural network pipeline and operators

### Experimental data simulation: Split measurement, noise and raw measurement operators

We adopt linear measurements using a Hadamard operator corrupted by Poisson noise.

Data simulation in spyrit relies on three operations, provided by **spyrit.core.noise** and **spyrit.core.meas**: image normalization, split measurements and noise perturbation. In the example below, this corresponds to the following steps:

$$
x \xrightarrow[\text{Step 1}]{\text{NoNoise}} \tilde{x}=\frac{x+1}{2} \xrightarrow[\text{Step 2}]{\text{HadamSplit}} y=P\tilde{x} \xrightarrow[\text{Step 3}]{\text{Poisson}} \mathcal{P}(\alpha y)
$$

- Step 1: Given an image $x$ with values in $[-1, 1]$ (for training), the image is first normalized such that $\tilde{x}$ takes values in $[0, 1]$ as

$$
\tilde{x}=\frac{x+1}{2}
$$
using the **spyrit.core.noise.NoNoise** operator. This normalization is required in order to apply the forward operator to positive images (see tutorial on [acquisition operators](https://spyrit.readthedocs.io/en/pinv_cnn/gallery/tuto_acquisition_operators.html)).

- Step 2: Split measurements $y$ are obtained via the linear operator $P$: 

$$
y = P\tilde{x} = 
\begin{pmatrix}
H_{+} \\
H_{-}
\end{pmatrix}\tilde{x}=
\begin{pmatrix}
\max(H, 0) \\
\max(0,-H)
\end{pmatrix}\tilde{x}
$$

where $H=(H_{+}-H_{-})$ is a Hadamard matrix (see [Tutorial on split measurements](https://spyrit.readthedocs.io/en/pinv_cnn/gallery/tuto_acquisition_split_measurements.html#sphx-glr-gallery-tuto-acquisition-split-measurements-py)). 

- Step 3: Data is finally perturbed by Poisson noise as

$$
\tilde{y} = \mathcal{P}(\alpha y)
$$

where $\alpha$ accounts for the mean photon counts, using spyrit's **spyrit.core.noise.Poisson**. 

In the code below, `M` is the number of measurements, `H` the image height, and `Ord` the ordering matrix used for undersampling. 
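
The three simulation steps can be reproduced on a toy problem with NumPy alone. The sketch below is an illustration, not spyrit's implementation: `sylvester_hadamard` is a hypothetical helper, the first `M` rows are kept instead of using an ordering matrix, and the positive and negative patterns are simply stacked as in the equation for $P$ (spyrit may arrange the rows differently).

```python
import numpy as np

def sylvester_hadamard(n):
    """Return the n x n Hadamard matrix (n must be a power of two)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

rng = np.random.default_rng(0)
h = 8                  # toy image side (the tutorial uses 64)
N = h * h              # number of pixels
M = N // 4             # number of retained patterns
alpha = 100.0          # image intensity in photons

x = rng.uniform(-1.0, 1.0, size=N)    # flattened image in [-1, 1]

# Step 1: map the image from [-1, 1] to [0, 1]
x_tilde = (x + 1.0) / 2.0

# Step 2: split the retained patterns into positive and negative parts
H_sub = sylvester_hadamard(N)[:M]     # first M rows, for illustration only
H_pos = np.maximum(H_sub, 0.0)
H_neg = np.maximum(-H_sub, 0.0)
P = np.vstack([H_pos, H_neg])         # shape (2M, N), non-negative entries
y = P @ x_tilde

# Step 3: Poisson noise on the scaled measurements
y_noisy = rng.poisson(alpha * y).astype(float)

print(P.shape, y_noisy.shape)
```

Because both $P$ and $\tilde{x}$ are non-negative, the Poisson rate $\alpha y$ is always valid and the simulated counts are non-negative.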


In [None]:
# Operators 
#
# Order matrix with shape (H, H) used to compute the permutation matrix 
# (as undersampling taking the first rows only)
try:
    Cov  = np.load(cov_name)
except FileNotFoundError:
    Cov = np.eye(H*H)
    print(f"Cov matrix {cov_name} not found! Set to the identity")
    
Ord = Cov2Var(Cov)

# Measurement operator: 
# Computes linear measurements y=Px, where P is a linear operator (matrix) with positive entries      
# such that P=[H_{+}; H_{-}]=[max(H,0); max(0,-H)], H=H_{+}-H_{-}
meas_op = HadamSplit(M, H, Ord)

# Simulates raw split measurements. Images are provided in the range [-1,1]
# and are first mapped to [0,1]: y=0.5*H(1 + x)
# noise = NoNoise(meas_op) # noiseless
noise = Poisson(meas_op, alpha)

# Simulate raw measurements (non-negative measurements)
y = noise(x)
print(f'Shape of simulated measurements y: {y.shape}')

m_plot = y.numpy()   
m_plot = meas2img2(m_plot.T, Ord)
m_plot = np.moveaxis(m_plot,-1, 0)
print(f'Shape of reshaped simulated measurements y: {m_plot.shape}')

imagesc(m_plot[0,:,:],'Simulated Measurement')

Note that measurements are positive, as expected experimentally. 
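
The role of the ordering matrix `Ord` in the undersampling can be illustrated as follows. This is a hedged sketch: it assumes that the ordering is derived from the variance of each Hadamard coefficient (the diagonal of the covariance matrix) and that the `M` coefficients with the largest variance are retained. The random covariance and all variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
h = 4                  # toy image side
N = h * h              # number of Hadamard coefficients
M = N // 4             # number of retained coefficients

# Hypothetical covariance of the Hadamard coefficients (symmetric, PSD)
A = rng.normal(size=(N, N))
cov = A @ A.T

# Variance of each coefficient, arranged as an (h, h) map
var_map = np.diag(cov).reshape(h, h)

# Rank coefficients by decreasing variance and keep the first M
flat_var = var_map.ravel()
order = np.argsort(-flat_var, kind="stable")
kept = order[:M]
dropped = order[M:]

# Boolean sampling mask in the (h, h) coefficient domain
mask = np.zeros(N, dtype=bool)
mask[kept] = True
mask = mask.reshape(h, h)

print(mask.astype(int))
```

In this sketch, an identity covariance gives equal variances everywhere, so the stable sort falls back to the natural order of the coefficients.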


### Preprocessing measurement operator 
The previous steps simulate experimental split measurements, which involve only non-negative patterns and intensities. A fourth step preprocesses the raw data $y$, acquired with the split measurement operator for the image $\tilde{x}$, in order to recover the measurements $m$ of the original image $x$:

$$
y \xrightarrow[\text{Step 4}]{\text{Prep}} \tilde{m}=y_+-y_- \longrightarrow m=\frac{2\tilde{m}}{\alpha}-H\mathbb{I},
$$

where $y_+$ and $y_-$ are the raw measurements associated with $H_+$ and $H_-$ (in the noiseless case, $y_+=\alpha H_+\tilde{x}$) and $\mathbb{I}$ is the image whose pixels are all equal to one. In spyrit, this step is done with **spyrit.core.prep.SplitPoisson**.
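
The preprocessing step can be checked numerically in the noiseless case. Since $\tilde{x}=(x+1)/2$, the difference of the positive and negative measurements is $\alpha H\tilde{x}$, so $2(y_+-y_-)/\alpha = Hx + H\mathbb{I}$ and subtracting $H\mathbb{I}$ leaves $Hx$. The sketch below verifies this with NumPy; `sylvester_hadamard` is a hypothetical helper and the first `M` rows stand in for the retained patterns.

```python
import numpy as np

def sylvester_hadamard(n):
    """Return the n x n Hadamard matrix (n must be a power of two)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

rng = np.random.default_rng(0)
N, M, alpha = 64, 16, 100.0

x = rng.uniform(-1.0, 1.0, size=N)     # image in [-1, 1]
x_tilde = (x + 1.0) / 2.0              # image in [0, 1]

H = sylvester_hadamard(N)[:M]          # retained patterns (illustration)
H_pos = np.maximum(H, 0.0)
H_neg = np.maximum(-H, 0.0)

# Noiseless raw measurements, scaled by the intensity alpha
y_pos = alpha * (H_pos @ x_tilde)
y_neg = alpha * (H_neg @ x_tilde)

# Preprocessing: undo the split, the scaling and the affine shift
m = 2.0 * (y_pos - y_neg) / alpha - H @ np.ones(N)

print(np.allclose(m, H @ x))
```

With Poisson noise the equality only holds on average, which is why the reconstruction includes a denoising step.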

In [None]:
# Preprocess the raw data acquired with split measurement operator assuming Poisson noise
prep = SplitPoisson(alpha, meas_op)

# Preprocessed data
m = prep(y)
print(f'Shape of preprocessed data m: {m.shape}')


m_plot = m.numpy()   
m_plot = meas2img2(m_plot.T, Ord)
m_plot = np.moveaxis(m_plot,-1, 0)
print(f'Shape of reshaped simulated measurements m: {m_plot.shape}')

imagesc(m_plot[0,:,:],'Preprocessed data')

### Reconstruction operators

Let $H$ represent the retained Hadamard patterns. The Hadamard coefficients $m$ are given in terms of the image $x$ as 

$$
m=Hx.
$$

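Since the Hadamard patterns are mutually orthogonal, an estimate $\hat{x}$ of the unknown image can be recovered by applying the pseudo-inverse $H^{\dagger}$, which coincides with the inverse Hadamard transform when all patterns are retained: 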

$$
\hat{x}=H^{\dagger}m.
$$
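
For a Hadamard matrix of order $N$, the rows satisfy $HH^T=N\,I$, so the pseudo-inverse of the retained patterns reduces to $H^{\dagger}=H^T/N$. The sketch below checks this with NumPy on a toy problem; `sylvester_hadamard` is a hypothetical helper and the first `M` rows stand in for the retained patterns.

```python
import numpy as np

def sylvester_hadamard(n):
    """Return the n x n Hadamard matrix (n must be a power of two)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

rng = np.random.default_rng(0)
N, M = 64, 16

x = rng.uniform(-1.0, 1.0, size=N)
H_full = sylvester_hadamard(N)
H = H_full[:M]                         # retained patterns (illustration)

# Undersampled measurements and pseudo-inverse reconstruction
m = H @ x
x_hat = H.T @ m / N

# Same result as the generic pseudo-inverse
x_hat_ref = np.linalg.pinv(H) @ m

# With all patterns retained, the image is recovered exactly
x_full = H_full.T @ (H_full @ x) / N

print(np.allclose(x_hat, x_hat_ref), np.allclose(x_full, x))
```

The undersampled estimate `x_hat` reproduces the measurements (`H @ x_hat` equals `m`) but sets the missing coefficients to zero, which motivates the data completion step.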

Image reconstruction in spyrit comprises four steps (the last one is optional):

1. Denoising

2. Data completion

3. Hadamard inverse transform

4. Nonlinear postprocessing (denoising)

$$
m\in\mathbb{R}^{M} \xrightarrow[\text{Denoising}]{\text{Step 1}} y_1\in\mathbb{R}^{M} \xrightarrow[\text{Completion}]{\text{Step 2}} y_2\in\mathbb{R}^{N-M} \xrightarrow[\text{Inverse}]{\text{Step 3}} \hat{x}\in\mathbb{R}^{N} \xrightarrow[\text{Postprocessing}]{\text{Step 4}} \mathcal{D}(\hat{x})\in\mathbb{R}^{N_x\times N_y}
$$

with $y=[y_1^T; y_2^T]^T$ and $\mathcal{D}$ a denoising operator. 
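
Step 2 estimates the missing coefficients $y_2$ from the acquired ones $y_1$. One standard choice, assumed here for illustration and not confirmed by this tutorial as spyrit's exact implementation, is the conditional mean under a zero-mean Gaussian prior with covariance $\Sigma$, that is $y_2=\Sigma_{21}\Sigma_{11}^{-1}y_1$. The sketch uses a hypothetical low-rank covariance for which the completion is exact.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M = 16, 4           # total and acquired numbers of coefficients

# Hypothetical model: coefficients y = B z, hence covariance B B^T.
# The top block is close to the identity to keep the problem well conditioned.
B_top = np.eye(M) + 0.1 * rng.normal(size=(M, M))
B_bottom = rng.normal(size=(N - M, M))
B = np.vstack([B_top, B_bottom])
cov = B @ B.T

# One realization of the full coefficient vector
z = rng.normal(size=M)
y_full = B @ z
y1 = y_full[:M]        # acquired (and here noiseless) coefficients

# Covariance blocks: acquired/acquired and missing/acquired
S11 = cov[:M, :M]
S21 = cov[M:, :M]

# Data completion: conditional mean of the missing coefficients
y2_hat = S21 @ np.linalg.solve(S11, y1)

print(np.allclose(y2_hat, y_full[M:]))
```

With a unit covariance, `S21` is zero and the completion returns zeros, so no information is added beyond the pseudo-inverse.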

In spyrit, the four steps are implemented in **spyrit.core.recon.PinvNet** and **spyrit.core.recon.DCNet**, which handle single-pixel imaging data automatically. The denoising network of the nonlinear step is provided by **spyrit.core.nnet.Unet**.

In [None]:
# Pseudo-inverse operator
pinv = PseudoInverse()

# Reconstruction
z_pinv = pinv(m, meas_op)
print(f'Shape of reconstructed image z: {z_pinv.shape}')

z_plot = z_pinv.view(-1,H,H).numpy()
imagesc(z_plot[0,:,:],'Pseudo-inverse reconstruction')

In [None]:
# Pseudo-inverse net

# Reconstruction with the core module (linear net)
pinvnet = PinvNet(noise, prep)
 
# use GPU, if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pinvnet = pinvnet.to(device)

x = x0.detach().clone()
x = x.to(device)
z_pinvnet = pinvnet(x)
# z_pinvnet = pinvnet.reconstruct(y)

z_plot = z_pinvnet.view(-1,H,H).detach().cpu().numpy()
imagesc(z_plot[0,:,:],'PinvNet reconstruction')

In [None]:
# DCNet

# Reconstruction with DCNet (linear net + denoising net)
dcnet = DCNet(noise, prep, Cov)

#y = pinvnet.acquire(x)         # or equivalently here: y = dcnet.acquire(x)
#m = pinvnet.meas2img(y)        # zero-padded images (after preprocessing)
dcnet = dcnet.to(device)
z_dcnet = dcnet.reconstruct(y.to(device))  # reconstruct from raw measurements
#x_dcnet_2 = dcnet(x)   # another reconstruction, from the ground-truth image

z_plot = z_dcnet.view(-1,H,H).detach().cpu().numpy()
imagesc(z_plot[0,:,:],'DCNet reconstruction')

In [None]:
# Pretrained DC UNet (UNet denoising)
denoi = Unet()
dcnet_unet = DCNet(noise, prep, Cov, denoi)

# Load previously trained model
try:
    # Download weights
    url_unet = 'https://drive.google.com/file/d/1LBrjU0B-Tecd4GBRozX9-24LTRzIiMzA/view?usp=drive_link'
    model_unet_path = "./spyrit/model_zoo"
    
    if os.path.exists(model_unet_path) is False:
        os.mkdir(model_unet_path)
        print(f'Created {model_unet_path}')

    model_unet_path = os.path.join(model_unet_path, 'dc-net_unet_imagenet_var_N0_10_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_256_reg_1e-07_light')
    gdown.download(url_unet, f'{model_unet_path}.pth', quiet=False,fuzzy=True)

    # Load model
    load_net(model_unet_path, dcnet_unet, device, False)
    
    dcnet_unet = dcnet_unet.to(device)
    with torch.no_grad():
        z_dcunet = dcnet_unet.reconstruct(y.to(device))  # reconstruct from raw measurements

    z_plot = z_dcunet.view(-1,H,H).detach().cpu().numpy()
    imagesc(z_plot[0,:,:],'DC UNet reconstruction', show=False)
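except Exception as e:
    # Report the actual cause (download, loading or reconstruction failure)
    print(f'Model {model_unet_path} could not be loaded or run: {e}')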


In [None]:
# DC DRUNet (with pretrained DRUNet denoising)
#
# Download weights
model_drunet_path = './spyrit/model_zoo'
url_drunet = 'https://drive.google.com/file/d/1oSsLjPPn6lqtzraFZLZGmwP_5KbPfTES/view?usp=drive_link'

if os.path.exists(model_drunet_path) is False:
    os.mkdir(model_drunet_path)
    print(f'Created {model_drunet_path}')

model_drunet_path = os.path.join(model_drunet_path, 'drunet_gray.pth')
gdown.download(url_drunet, model_drunet_path, quiet=False,fuzzy=True)

# Define denoising network
n_channels = 1                   # 1 for grayscale image    
denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',                     
            downsample_mode="strideconv", upsample_mode="convtranspose")  

# Load pretrained model
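try:       
    # map_location allows loading GPU-trained weights on a CPU-only machine
    denoi_drunet.load_state_dict(torch.load(model_drunet_path, map_location=device), strict=True)       
    print(f'Model {model_drunet_path} loaded.')
except FileNotFoundError:
    print(f'Model {model_drunet_path} not found!')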
    load_drunet = False

denoi_drunet.eval()         
for k, v in denoi_drunet.named_parameters():             
    v.requires_grad = False  
print(sum(map(lambda x: x.numel(), denoi_drunet.parameters())) )  

# Define DCDRUNet
#noise_level = 10
#dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet, noise_level=noise_level)
dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet)

# Reconstruct
# Set the noise level: the higher the noise level, the stronger the denoising
noise_level = 10
dcdrunet.set_noise_level(noise_level)
dcdrunet = dcdrunet.to(device)
with torch.no_grad():
    # reconstruct from raw measurements
    z_dcdrunet = dcdrunet.reconstruct(y.to(device))  

denoi_drunet = denoi_drunet.to(device)

# DCDRUNet
z_plot = z_dcdrunet.view(-1,H,H).detach().cpu().numpy()
imagesc(z_plot[0,:,:],f'DC DRUNet reconstruction noise level={noise_level}', show=False)
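
The denoiser above is built with `in_nc=n_channels+1` because DRUNet takes a noise-level map as an extra input channel next to the image. The sketch below shows how such an input can be assembled with NumPy. It rests on assumptions: the noise level is taken on a $[0,255]$ scale and divided by 255, and the image is mapped to $[0,1]$ before denoising; how `DCDRUNet` handles this internally is not shown in this tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_channels, h, w = 5, 1, 64, 64

# Hypothetical reconstruction in [-1, 1], mapped to [0, 1] (assumption)
x_hat = rng.uniform(-1.0, 1.0, size=(batch, n_channels, h, w)).astype(np.float32)
x01 = (x_hat + 1.0) / 2.0

# Noise level on a [0, 255] scale, rescaled to the [0, 1] image range
noise_level = 10
sigma = noise_level / 255.0

# Constant noise-level map, concatenated along the channel axis
noise_map = np.full((batch, 1, h, w), sigma, dtype=np.float32)
net_input = np.concatenate([x01, noise_map], axis=1)

print(net_input.shape)
```

A larger `noise_level` tells the network to expect more noise, which is consistent with stronger smoothing of the reconstruction.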
