[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/gallery/spyrit/tutorial/tuto_core_2d_short.ipynb)

# Demo for 2D single-pixel reconstruction


## Settings and requirements

First, we mount Google Drive.

In [None]:
mode_colab = True
if (mode_colab is True):
    # Connect to googledrive
    #if 'google.colab' in str(get_ipython()):
    # Mount google drive to access files via colab
    from google.colab import drive
    drive.mount("/content/gdrive")  
    %cd /content/gdrive/MyDrive/    

You can select a GPU via *Runtime > Change runtime type*.

Clone and install the spyrit package.

In [None]:
%%capture
if (mode_colab is True):
    # cd to spyrit folder is already cloned in your drive
    #%cd /content/gdrive/MyDrive/Colab_Notebooks/spyrit
    !git clone https://github.com/openspyrit/spyrit.git
    %cd spyrit
    !pip install -e .

In [None]:
if (mode_colab is True):
    # Checkout to ongoing branch
    !git checkout gallery
    !git branch


In [None]:
import os
import numpy as np

from spyrit.core.meas import HadamSplit
from spyrit.core.noise import NoNoise, Poisson
from spyrit.core.prep import SplitPoisson
from spyrit.core.recon import PseudoInverse, PinvNet, DCNet
from spyrit.misc.statistics import Cov2Var, data_loaders_stl10, transform_gray_norm
from spyrit.misc.disp import imagesc 
from spyrit.misc.sampling import meas2img2

import torch
import torchvision

In [None]:
# Parameters
H = 64                          # Image height (a square H x H image is assumed)
M = H**2 // 2                   # Number of measurements: subsampling by a factor 2
B = 10                          # Batch size
alpha = 100                     # Max number of photons per pixel (photon counts)

imgs_path = './spyrit/images'   # relative to the spyrit repository root

In [None]:
if mode_colab is False:
    # Running locally: the notebook is assumed to be launched from two levels
    # below the repository root (e.g. spyrit/tutorial/), so move up to the
    # root for imgs_path to resolve. NOTE(review): running this cell twice
    # moves up two more levels.
    os.chdir('../..')

## Load data

In [None]:
# Transform natural images into normalized grayscale image tensors of size H
transform = transform_gray_norm(img_size=H)

# ImageFolder expects imgs_path to contain one sub-folder per class
# (e.g. 'images/test/'); the batch size is capped by the dataset size.
dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=min(B, len(dataset)),
)

# Keep only the first batch; the class labels are not needed
x0, _ = next(iter(dataloader))
b, c, h, w = x0.shape
# Flatten each (channel) image into a row vector: shape (b*c, h*w)
x = x0.detach().clone().view(b * c, h * w)
print(f'Shape of incoming image (b*c,h*w): {x.shape}')

x_plot = x.view(-1, H, H).cpu().numpy()
imagesc(x_plot[0, :, :], 'Ground-truth image normalized')

## Operators

Data simulation comprises three steps:

1. Split Linear Measurements

2. Noiseless/noisy raw measurements (handling negative images)

3. Preprocess measurements



In [None]:
"""
Data Simulation:
    1) Split Linear Measurements:
            y = Px = [H_{+}; H_{-}]x
        
        spyrit.core.meas
        meas_op = LinearSplit(H), matrix H
        meas_op = HadamSplit(M, h, Ord), M: number of measurements, h: image height, Ord: ordering matrix for undersampling

        y = meas_op(x)

    2) Noiseless/noisy raw measurements (handling negative images):
        Handles the fact that images are in [-1, 1]: they are mapped to [0, 1]
        before the measurement operator is applied.
        Simulates raw measurements as expected by the single-pixel camera (no negative measurements)

        Noiseless:
                y = 0.5*P(1+x)

            spyrit.core.noise
            meas_op = HadamSplit(M, h, Ord)
            y = NoNoise(meas_op)(x) 

        Noisy:
                y = Poisson((alpha/2)*P(1+x))

            spyrit.core.noise
            meas_op = HadamSplit(M, h, Ord)    
            y = Poisson(meas_op, alpha)(x)

    3) Preprocess measurements (before reconstruction): 
        Preprocessing compensates for the image normalization and the splitting:
        it combines the positive and negative split measurements.
            m = 2*(y+ - y-)/alpha - H*1,   where 1 is the all-ones image
        
            spyrit.core.prep
            meas_op = HadamSplit(M, h, Ord)    
            m = SplitPoisson(alpha, meas_op)(y)

    4) Reconstruct

        Standard reconstruction:
            z = PseudoInverse()(m, meas_op)
        
        Inverse Net:
            Noiseless:
            pinv_net = PinvNet(NoNoise(meas_op), SplitPoisson(alpha, meas_op))
            z = pinv_net(x)

            Noisy:
            pinv_net = PinvNet(Poisson(meas_op, alpha), SplitPoisson(alpha, meas_op))
            z_invnet = pinv_net.reconstruct(y)

        DCNet:
            dcnet = DCNet(Poisson(meas_op, alpha), SplitPoisson(alpha, meas_op), Cov)
            y = dcnet.acquire(x) 
            z_dc = dcnet.reconstruct(y)
            """


### Split measurement and raw measurement operators

Split Linear Measurements:
$$
y = Px = [H_{+}; H_{-}]x
$$

Uses *spyrit.core.meas*

```
    meas_op = LinearSplit(H), 
    meas_op = HadamSplit(M, h, Ord), matrix for undersampling
    y = meas_op(x)
```
where `LinearSplit` takes a linear measurement matrix $H$, and for `HadamSplit`, $M$ is the number of measurements, $h$ the image height, and $Ord$ the ordering matrix used for undersampling.

Below, we create the measurement and noise operators and then compute measurements as:
```
meas_op = HadamSplit(M, H, Ord)
noise = Poisson(meas_op, alpha)
y = noise(x)
```
where the noise operators use inheritance:
```
Poisson(NoNoise)
NoNoise(nn.module) 
```
and

$$
x \xrightarrow[]{\text{NoNoise}} \frac{x+1}{2} \xrightarrow[\text{meas\_op}]{\text{LinearSplit}} Px \xrightarrow[]{\text{Poisson}} y
$$


In [None]:
# Operators
#
# Ordering matrix of shape (H, H), from which the permutation is derived
# (undersampling keeps only the first M rows). Here it is computed from an
# identity covariance; Ord = np.ones((H,H)) is a possible alternative.
Cov = np.eye(H * H)
Ord = Cov2Var(Cov)

# Split measurement operator: y = P x, where P has non-negative entries,
# P = [H_{+}; H_{-}] = [max(H,0); max(0,-H)] and H = H_{+} - H_{-}
meas_op = HadamSplit(M, H, Ord)

# Raw-measurement simulator: images given in [-1, 1] are mapped to [0, 1],
# i.e. y = 0.5*P(1 + x), and then corrupted by Poisson noise (alpha counts).
# For noiseless data use instead: noise = NoNoise(meas_op)
noise = Poisson(meas_op, alpha)

# Simulate the raw (non-negative) measurements
y = noise(x)
print(f'Shape of simulated measurements y: {y.shape}')

# Reshape the measurement vectors into images (batch axis moved first) for display
m_plot = np.moveaxis(meas2img2(y.numpy().T, Ord), -1, 0)
print(f'Shape of reshaped simulated measurements y: {m_plot.shape}')

imagesc(m_plot[0, :, :], 'Simulated Measurement')

Note that the raw measurements are non-negative.

### Preprocess measurement operator 

Preprocessing compensates for the image normalization and the splitting: it combines the split measurements as
$$
m = \frac{2(y_+ - y_-)}{\alpha} - H\mathbf{1},
$$
where $\mathbf{1}$ denotes the all-ones image.

Uses *spyrit.core.prep*
```
    meas_op = HadamSplit(M, h, Ord)    
    m = SplitPoisson(alpha, meas_op)(y)
```

In [None]:
# Preprocessing operator for split measurements under Poisson noise: it
# combines the split measurements and compensates for the image normalization
prep = SplitPoisson(alpha, meas_op)

m = prep(y)   # preprocessed data
print(f'Shape of preprocessed data m: {m.shape}')

# Reshape the preprocessed vectors into images (batch axis moved first) for display
m_plot = np.moveaxis(meas2img2(m.numpy().T, Ord), -1, 0)
print(f'Shape of reshaped simulated measurements m: {m_plot.shape}')

imagesc(m_plot[0, :, :], 'Preprocessed data')

Now, the preprocessed measurements can be negative.

### Reconstruction operators

In [None]:
# Reconstruct the images from the preprocessed data m with the
# pseudo-inverse of the measurement operator
pinv = PseudoInverse()
z_pinv = pinv(m, meas_op)
print(f'Shape of reconstructed image z: {z_pinv.shape}')

# Reshape to a stack of H x H images and show the first one
z_plot = z_pinv.view(-1, H, H).numpy()
imagesc(z_plot[0, :, :], 'Pseudo-inverse reconstruction')

In [None]:
# Pseudo-inverse net
#
# PinvNet bundles the noise (acquisition) operator, the preprocessing
# operator and the pseudo-inverse reconstruction into one module (linear net).
pinvnet = PinvNet(noise, prep)

# Use the GPU, if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pinvnet = pinvnet.to(device)

# Forward pass from the ground-truth image batch: measurements are simulated
# again from x with the noise operator, then reconstructed.
# Alternatively, reconstruct from the raw measurements simulated above:
# z_pinvnet = pinvnet.reconstruct(y.to(device))
# NOTE: this rebinds x to the 4D image batch on `device`; re-run the data
# loading cell before re-running cells that expect the flattened CPU tensor.
x = x0.detach().clone()
x = x.to(device)
z_pinvnet = pinvnet(x)

# Display the PinvNet output (not z_pinv); move it to the CPU first since it
# may live on the GPU.
z_plot = z_pinvnet.view(-1, H, H).detach().cpu().numpy()
imagesc(z_plot[0, :, :], 'Pseudo-inverse net reconstruction')


In [None]:
# DCNet
#
# Reconstruction with DCNet (linear net + denoising net), built from the same
# noise and preprocessing operators plus the covariance matrix Cov, and moved
# to the selected device (nn.Module.to returns the module itself).
dcnet = DCNet(noise, prep, Cov).to(device)

# Other options:
#   y = dcnet.acquire(x)     # simulate raw measurements from the image batch
#   x_dcnet_2 = dcnet(x)     # reconstruct directly from the ground-truth image
# Here: reconstruct from the raw measurements simulated earlier
z_dcnet = dcnet.reconstruct(y.to(device))

z_plot = z_dcnet.view(-1, H, H).cpu().numpy()
imagesc(z_plot[0, :, :], 'DCNet reconstruction')