[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit-examples/blob/master/tutorial/tuto_pinvnet_cnn.ipynb)

# PinvNet (pseudo-inverse) + CNN

This tutorial shows how to simulate measurements and perform image reconstruction using PinvNet (pseudoinverse linear network) with a CNN denoiser as the last layer. It is a simplified version of the documentation tutorial [Pseudoinverse solution + CNN denoising](https://spyrit.readthedocs.io/en/master/gallery/tuto_pseudoinverse_cnn_linear.html#sphx-glr-gallery-tuto-pseudoinverse-cnn-linear-py), adapted to run on Google Colab.

The measurement operator is chosen as the positive part of a Hadamard matrix, but it can be replaced by any matrix.

### Set up Google Colab

On colab, to run on GPU, select *GPU* from the navigation menu *Runtime/Change runtime type*.

In [None]:
!nvidia-smi

Set *mode_colab=True* to run in Colab. Google Drive is then mounted, which is needed to import the spyrit modules.

In [None]:
mode_colab = True
if (mode_colab is True):
    # Connect to googledrive
    #if 'google.colab' in str(get_ipython()):
    # Mount google drive to access files via colab
    from google.colab import drive
    drive.mount("/content/gdrive")  
    %cd /content/gdrive/MyDrive/    

### Clone Spyrit package

Clone and install the spyrit package if it is not installed yet; otherwise, move to the spyrit folder.

In [None]:
# #%%capture
install_spyrit = True
if (mode_colab is True):
    if install_spyrit is True:
        # Clone and install
        !git clone https://github.com/openspyrit/spyrit.git
        %cd spyrit
        !pip install -e .

        # Install extra dependencies
        !pip install gdown
        !pip install girder_client
    else:
        # cd to spyrit folder if it is already cloned in your drive
        %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit

    # Add paths for modules
    import sys
    sys.path.append('./spyrit/core')
    sys.path.append('./spyrit/misc')
    sys.path.append('./spyrit/tutorial')
else:
    # Change path to spyrit/
    # Assumes we are in /spyrit-examples/tutorial
    import os
    os.chdir('../../spyrit')
    !pwd

## Settings and requirements

### Set Parameters

In [None]:
# Parameters
h = 64                          # Image height and width (square image)
M = h**2 // 4                   # Number of measurements (undersampling factor 4)

B = 10                          # Batch size
imgs_path = './spyrit/images'   # Path to image examples

## Load data

Load a batch of images from the folder *spyrit/images*. The transform returned by `transform_gray_norm` converts the images to grayscale and normalizes them to $[-1,1]$, the range used for training.

In [None]:
import os
from spyrit.misc.statistics import transform_gray_norm
import torchvision
import torch
from spyrit.misc.disp import imagesc

h = 64            # image size hxh
i = 1             # Image index (modify to change the image)

# Create a transform for natural images to normalized grayscale image tensors
transform = transform_gray_norm(img_size=h)

# Create dataset and loader (expects class folder 'images/test/')
dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=B)

x, _ = next(iter(dataloader))
print(f'Shape of input images: {x.shape}')

# Select image
x = x[i:i+1,:,:,:]
x = x.detach().clone()
b,c,h,w = x.shape

# plot
x_plot = x.view(-1,h,h).cpu().numpy()
imagesc(x_plot[0,:,:], r'$x$ in [-1, 1]')
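For reference, mapping 8-bit pixel values to $[-1,1]$ amounts to scaling them to $[0,1]$ and then normalizing with mean 0.5 and standard deviation 0.5. The NumPy sketch below illustrates only this affine map on toy values. It assumes 8-bit input and is not the implementation of `transform_gray_norm` (which also takes the target size `img_size`); the helper `to_minus_one_one` is hypothetical.

```python
import numpy as np

def to_minus_one_one(pixels: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map 8-bit grayscale values in [0, 255] to [-1, 1].

    Equivalent to scaling to [0, 1] and then normalizing with
    mean 0.5 and standard deviation 0.5: (p / 255 - 0.5) / 0.5.
    """
    scaled = pixels.astype(np.float64) / 255.0  # [0, 255] -> [0, 1]
    return (scaled - 0.5) / 0.5                 # [0, 1]   -> [-1, 1]

# Toy 2x2 "image": black, white and two intermediate gray levels
img = np.array([[0, 255], [51, 204]], dtype=np.uint8)
print(to_minus_one_one(img))
```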

## Data simulation

Data simulation in spyrit uses three operators from `spyrit.core.noise` and `spyrit.core.meas`, which correspond to image normalization, the forward operator and noise perturbation. In the example below, this corresponds to the following steps:

$$
x \xrightarrow[\text{Step 1}]{\text{NoNoise}} \tilde{x}=\frac{x+1}{2} \xrightarrow[\text{Step 2}]{\text{Linear}} y=H\tilde{x} \xrightarrow[\text{Step 3}]{\text{Poisson}} \mathcal{P}(\alpha y)
$$
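The three steps can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not spyrit's implementation: the toy sizes, the random 0/1 matrix `H` and the value of `alpha` are hypothetical, and Poisson noise is drawn with NumPy's `Generator.poisson`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes: 4x4 image (16 pixels), 6 measurements
n_pixels, n_meas = 16, 6
H = rng.integers(0, 2, size=(n_meas, n_pixels)).astype(float)  # hypothetical positive 0/1 matrix
x = rng.uniform(-1.0, 1.0, size=n_pixels)                       # image in [-1, 1]
alpha = 100.0                                                    # hypothetical mean photon count

x_tilde = (x + 1) / 2             # Step 1: normalization to [0, 1]
y = H @ x_tilde                   # Step 2: linear forward operator
y_noisy = rng.poisson(alpha * y)  # Step 3: Poisson noise, integer photon counts
print(y_noisy)
```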

- **Normalization operator**: Given an image $x$ in $[-1, 1]$ (the range used for training), the image is first normalized so that $\tilde{x}$ lies in $[0, 1]$:

$$
\tilde{x}=\frac{x+1}{2}
$$
using the `spyrit.core.noise.NoNoise` operator. This normalization is required in order to apply the forward operator to positive images (see the tutorial on [acquisition operators](https://spyrit.readthedocs.io/en/pinv_cnn/gallery/tuto_acquisition_operators.html)).

- **Forward operator**: Measurements $y$ are obtained via the linear operator $H$:

$$
y = H\tilde{x}.
$$

- **Noise operator**: Data are finally perturbed by Poisson noise as

$$
\mathcal{P}(\alpha y),
$$

where $\alpha$ is the mean photon count, using spyrit's `spyrit.core.noise.Poisson` operator.


### Define a measurement operator

We consider the case where the measurement matrix is the positive part of a Hadamard matrix and the sampling operator preserves only the first $M$ low-frequency coefficients. In the code below, `M` is the number of measurements, `h` the image height, `Sampling_map` the binary map of the coefficients to keep, and `Perm` the permutation matrix used for undersampling.

In [None]:
# Sampling parameters
und = 4                # undersampling factor
M = h**2 // und        # number of measurements (undersampling factor = 4)

In [None]:
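import numpy as np
import math
from spyrit.misc.sampling import Permutation_Matrix
from spyrit.misc.walsh_hadamard import walsh2_matrix

# 2D Walsh-Hadamard matrix (h^2 x h^2), keeping only its positive part
F = walsh2_matrix(h)
F = np.where(F>0, F, 0)

# Low-frequency sampling map: keep the top-left M_xy x M_xy block of coefficients
Sampling_map = np.ones((h,h))
M_xy = math.ceil(M**0.5)
Sampling_map[:,M_xy:] = 0
Sampling_map[M_xy:,:] = 0

# Reorder the rows so that the sampled coefficients come first, then keep M rows
Perm = Permutation_Matrix(Sampling_map)
F = Perm@F
H = F[:M,:]
print(f"Shape of the measurement matrix: {H.shape}")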

imagesc(Sampling_map, 'low-frequency sampling map')
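To see how a low-frequency sampling map and a permutation select the first $M$ coefficients, here is a toy NumPy sketch on a $4\times 4$ grid. It assumes that `Permutation_Matrix` orders the coefficients by decreasing value of the sampling map; we reproduce this with a stable `argsort`, but spyrit's exact tie-breaking may differ. The names `h_toy`, `m_xy` and `M_toy` are hypothetical.

```python
import numpy as np

h_toy, m_xy = 4, 2                 # toy 4x4 grid, keep the top-left 2x2 block
M_toy = m_xy**2                    # number of kept coefficients

sampling_map = np.zeros((h_toy, h_toy))
sampling_map[:m_xy, :m_xy] = 1     # ones mark the low-frequency coefficients to keep

# Sort the flattened coefficients by decreasing map value. A stable sort keeps
# the row-major order among ties, so the kept coefficients come first.
order = np.argsort(-sampling_map.ravel(), kind="stable")
perm = np.eye(h_toy**2)[order]     # row k of perm selects coefficient order[k]

# Stand-in for the rows of a transform matrix: entry i is labeled by its index i
rows = np.arange(h_toy**2, dtype=float)
kept = (perm @ rows)[:M_toy]       # analogous to keeping the first M rows of Perm @ F
print(order[:M_toy], kept)
```

Applying such a permutation to the rows of a full $h^2 \times h^2$ matrix and keeping the first $M$ rows mirrors what the cell above does with `Perm@F` and `F[:M,:]`.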

Then, we instantiate a `spyrit.core.meas.Linear` measurement operator; setting `pinv=True` gives access to the pseudo-inverse solution.

In [None]:
from spyrit.core.meas import Linear
meas_op = Linear(H, pinv=True)

### Noise operator

In the noiseless case, we consider the `spyrit.core.noise.NoNoise` operator. To simulate Poisson noise instead, we would use `spyrit.core.noise.Poisson` and set `N0` (the mean photon count $\alpha$) to the desired noise level.

In [None]:
from spyrit.core.noise import NoNoise
# from spyrit.core.noise import Poisson

N0 = 1.0         # Noise level (noiseless)
noise = NoNoise(meas_op)
# noise = Poisson(meas_op, N0)

# Simulate measurements
y = noise(x.view(b*c,h*w))
print(f'Shape of raw measurements: {y.shape}')
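To get a feel for how `N0` sets the noise level, the NumPy sketch below draws Poisson counts $\mathcal{P}(N_0 y)$ for a constant, hypothetical measurement value $y=0.5$. A Poisson variable with mean $\lambda$ has standard deviation $\sqrt{\lambda}$, so the relative noise decreases as $1/\sqrt{N_0 y}$. This assumes spyrit's `Poisson` operator follows the model $\mathcal{P}(\alpha y)$ given above.

```python
import numpy as np

rng = np.random.default_rng(0)
y_value = 0.5                            # hypothetical noiseless measurement
y = np.full(100_000, y_value)            # many draws to estimate the statistics

rel_noise = {}
for N0 in (1.0, 10.0, 100.0, 1000.0):
    counts = rng.poisson(N0 * y)                  # Poisson(N0 * y) photon counts
    rel_noise[N0] = counts.std() / counts.mean()  # empirical std / mean
    theory = 1 / np.sqrt(N0 * y_value)            # sqrt(lambda) / lambda
    print(f"N0 = {N0:7.1f}: relative noise {rel_noise[N0]:.3f} (theory {theory:.3f})")
```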

### Preprocessing operator

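The previous steps simulate the raw measurements, which are acquired from the normalized image $\tilde{x}$. A fourth step preprocesses them so that they correspond to the original image $x$:

$$
y \xrightarrow[\text{Step 4}]{\text{Prep}} m=\frac{2y}{\alpha}-H\mathbb{1},
$$

where $y$ here denotes the raw (possibly noisy) measurements and $\mathbb{1}$ is the all-ones vector. Since the raw measurements have mean $\alpha H\tilde{x}$ and $2\tilde{x}-\mathbb{1}=x$, this gives $m=Hx$ on average (exactly in the noiseless case with $\alpha=1$). In spyrit, this step is done with `spyrit.core.prep.DirectPoisson`.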

We now compute the preprocessed measurements corresponding to an image in $[-1,1]$ and plot them.

In [None]:
from spyrit.core.prep import DirectPoisson
prep = DirectPoisson(N0, meas_op) # Maps raw measurements of x_tilde to measurements of x

m = prep(y)
print(f'Shape of the preprocessed measurements: {m.shape}')
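The identity behind the preprocessing step can be checked numerically. The NumPy sketch below uses a hypothetical random 0/1 matrix `H`, a hypothetical `alpha`, and the noiseless mean $\alpha H\tilde{x}$ of the raw measurements; it only illustrates the formula $m = 2y/\alpha - H\mathbb{1}$ and is not `DirectPoisson` itself.

```python
import numpy as np

rng = np.random.default_rng(1)
n_pixels, n_meas = 16, 6
H = rng.integers(0, 2, size=(n_meas, n_pixels)).astype(float)  # hypothetical positive matrix
x = rng.uniform(-1.0, 1.0, size=n_pixels)                       # image in [-1, 1]
alpha = 10.0                                                     # hypothetical photon count

x_tilde = (x + 1) / 2                # Step 1: normalization
y = alpha * (H @ x_tilde)            # noiseless mean of the raw measurements
ones = np.ones(n_pixels)
m = 2 * y / alpha - H @ ones         # Step 4: preprocessing

print(np.allclose(m, H @ x))         # m are the measurements of the original x
```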

To display the subsampled measurement vector as an image in the transformed domain, we use the `spyrit.misc.sampling.meas2img` function.


In [None]:
# plot
from spyrit.misc.sampling import meas2img

m_plot = m.detach().numpy().squeeze()
m_plot = meas2img(m_plot, Sampling_map)
print(f'Shape of the preprocessed measurement image: {m_plot.shape}')

imagesc(m_plot, 'Preprocessed measurements (no noise)')
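For intuition about the reshaping done in the cell above, the sketch below scatters a measurement vector back onto a 2D grid, filling unsampled positions with zeros. The helper `meas_to_image` is hypothetical: it assumes the measurements are ordered like the coefficients selected by a stable descending sort of the sampling map (kept coefficients first, in row-major order), and spyrit's `meas2img` may handle the ordering differently.

```python
import numpy as np

def meas_to_image(meas: np.ndarray, sampling_map: np.ndarray) -> np.ndarray:
    """Hypothetical helper: place a measurement vector on the 2D sampling grid."""
    flat_map = sampling_map.ravel()
    order = np.argsort(-flat_map, kind="stable")  # kept coefficients first
    img = np.zeros(flat_map.size)
    img[order[: meas.size]] = meas                # scatter measurements back
    return img.reshape(sampling_map.shape)        # unsampled positions stay 0

# Toy 4x4 map keeping the top-left 2x2 block
sampling_map = np.zeros((4, 4))
sampling_map[:2, :2] = 1
print(meas_to_image(np.array([1.0, 2.0, 3.0, 4.0]), sampling_map))
```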

## PinvNet Network

We consider the `spyrit.core.recon.PinvNet` class, which reconstructs an image by computing the pseudoinverse solution and feeding it to a neural network denoiser. To compute the pseudoinverse solution only, the denoiser can be set to the identity operator.

<img src="https://spyrit.readthedocs.io/en/master/_images/pinvnet.png" alt="drawing" width="400" class="center" />

Note that the acquisition (forward and noise operators), the preprocessing, the pseudo-inverse reconstruction and the denoiser are all layers of PinvNet, and only the last one (the denoiser) has learnable parameters. We therefore pass `noise`, `prep` and `denoi` to `PinvNet`.

In [None]:
from spyrit.core.recon import PinvNet

# Create PinvNet
pinv_net = PinvNet(noise, prep, denoi=torch.nn.Identity())

Then, we reconstruct the image from the raw measurement vector `y` using the `reconstruct` method, which applies the preprocessing, the pseudo-inverse and the denoiser in turn.

In [None]:
# Reconstruct
x_rec = pinv_net.reconstruct(y)

We plot the results.

In [None]:
# plot
x_plot = x.squeeze().cpu().numpy()
x_plot2 = x_rec.squeeze().cpu().numpy()

import matplotlib.pyplot as plt
from spyrit.misc.disp import add_colorbar, noaxis

f, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))
im1=ax1.imshow(x_plot, cmap='gray')
ax1.set_title('Ground-truth image', fontsize=20)
noaxis(ax1)
add_colorbar(im1, 'bottom', size='20%')

im2=ax2.imshow(x_plot2, cmap='gray')
ax2.set_title('Pinv reconstruction', fontsize=20)
noaxis(ax2)
add_colorbar(im2, 'bottom', size='20%')
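To make the structure of PinvNet explicit, here is a toy NumPy stand-in for the pipeline shown in the figure above: preprocessing, pseudo-inverse, then a denoiser that defaults to the identity. The class `ToyPinvNet`, the random matrix and the sizes are hypothetical; spyrit's `PinvNet` is a PyTorch module and its internals may differ. The final check relies on the pseudo-inverse property $HH^{+}H = H$: the pseudo-inverse reconstruction reproduces the noiseless measurements exactly, even though it generally differs from $x$ when $M$ is smaller than the number of pixels.

```python
import numpy as np

class ToyPinvNet:
    """Hypothetical NumPy stand-in for a PinvNet-like pipeline (not spyrit's class).

    raw y -> m = 2y/alpha - H1 -> H^+ m -> denoiser
    """

    def __init__(self, H, alpha=1.0, denoiser=lambda z: z):
        self.H = H
        self.alpha = alpha
        self.H_pinv = np.linalg.pinv(H)  # fixed, no learnable parameters
        self.denoiser = denoiser         # identity by default

    def reconstruct(self, y):
        m = 2 * y / self.alpha - self.H @ np.ones(self.H.shape[1])  # preprocessing
        x_pinv = self.H_pinv @ m                                     # pseudo-inverse solution
        return self.denoiser(x_pinv)                                 # denoising layer

rng = np.random.default_rng(2)
H = rng.integers(0, 2, size=(6, 16)).astype(float)  # hypothetical 0/1 matrix
x = rng.uniform(-1.0, 1.0, size=16)                 # image in [-1, 1]
y = H @ ((x + 1) / 2)                               # noiseless raw measurements (alpha = 1)

net = ToyPinvNet(H)
x_rec = net.reconstruct(y)
print(np.allclose(H @ x_rec, H @ x))                # consistent with the measurements
```

Passing any callable as `denoiser`, e.g. `lambda z: np.clip(z, -1, 1)`, plays the role that replacing `torch.nn.Identity()` by a trained network plays in the real PinvNet.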


## PinvNet Network + CNN

Artefacts can be reduced by selecting a neural network denoiser (the last layer of PinvNet). We select a simple CNN using the `spyrit.core.nnet.ConvNet` class, but it can be replaced by any neural network (e.g., a UNet from `spyrit.core.nnet.Unet`).

<img src="https://spyrit.readthedocs.io/en/master/_images/pinvnet_cnn.png" alt="drawing" width="400" class="center" />

In [None]:
from spyrit.core.nnet import ConvNet
from spyrit.core.train import load_net

# Define PInvNet with ConvNet denoising layer
denoi = ConvNet()
pinv_net_cnn = PinvNet(noise, prep, denoi)

# Send to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pinv_net_cnn = pinv_net_cnn.to(device)

As an example, we use a simple ConvNet that has been pretrained on the STL-10 dataset. We download the pretrained weights and load them into the network.

In [None]:
# Pretrained ConvNet weights (30 epochs on STL-10)
url_cnn = 'https://drive.google.com/file/d/1IZYff1xQxJ3ckAnObqAWyOure6Bjkj4k/view?usp=drive_link'
name_cnn = 'pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07'

# Model folder and path of the weights (without extension)
model_path = "./model"
model_cnn_path = os.path.join(model_path, name_cnn)

try:
    import gdown

    # Create model folder if needed
    os.makedirs(model_path, exist_ok=True)

    # Download model weights
    gdown.download(url_cnn, f'{model_cnn_path}.pth', quiet=False, fuzzy=True)

    # Load model weights into the network
    load_net(model_cnn_path, pinv_net_cnn, device, False)
    print(f'Model {model_cnn_path} loaded.')
except Exception as err:
    print(f'Model {model_cnn_path} could not be loaded: {err}')

We now reconstruct the image using PinvNet with pretrained CNN denoising and plot the results side by side with the PinvNet reconstruction without denoising.

In [None]:
# Reconstruct
with torch.no_grad():
    x_rec_cnn = pinv_net_cnn.reconstruct(y.to(device))

In [None]:
# plot
x_plot = x.squeeze().cpu().numpy()
x_plot2 = x_rec.squeeze().cpu().numpy()
x_plot3 = x_rec_cnn.squeeze().cpu().numpy()

import matplotlib.pyplot as plt
from spyrit.misc.disp import add_colorbar, noaxis

f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15,5))
im1=ax1.imshow(x_plot, cmap='gray')
ax1.set_title('Ground-truth image', fontsize=20)
noaxis(ax1)
add_colorbar(im1, 'bottom', size='20%')

im2=ax2.imshow(x_plot2, cmap='gray')
ax2.set_title('PinvNet reconstruction', fontsize=20)
noaxis(ax2)
add_colorbar(im2, 'bottom', size='20%')

im3=ax3.imshow(x_plot3, cmap='gray')
ax3.set_title('PinvNet with CNN', fontsize=20)
noaxis(ax3)
add_colorbar(im3, 'bottom', size='20%')