In [1]:
import torch
import numpy as np
from torchvision import datasets, transforms
from pathlib import Path
import spyrit.misc.walsh_hadamard as wh

from spyrit.misc.statistics import stat_walsh_stl10
from spyrit.misc.statistics import *
from spyrit.misc.disp import *

In [2]:
# --- Acquisition / experiment parameters ---
img_size = 64   # images are resized to img_size x img_size
M = 1024        # number of Walsh patterns (measurements) kept
N0 = 50         # image intensity, in photons
bs = 10         # test-loader batch size

# --- Data paths (relative to the working directory) ---
data_root = Path('..', '..', 'data')
stats_root = data_root / 'stats_walsh'

In [None]:
#%% A batch of STL-10 test images
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(7)

# Grayscale -> resize -> tensor in [0, 1] -> normalised to [-1, 1].
transform = transforms.Compose(
    [transforms.functional.to_grayscale,
     transforms.Resize((img_size, img_size)),
     transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5])])

# Only `datasets` and `transforms` are imported from torchvision (the name
# `torchvision` itself is never bound), so use `datasets.STL10` directly;
# `torchvision.datasets.STL10` raised NameError on a fresh kernel.
testset = datasets.STL10(root=data_root, split='test', download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False)

In [None]:
# First test batch (the loader is not shuffled); labels are discarded.
inputs, _ = next(iter(testloader))
(b, c, h, w) = inputs.shape  # batch, channels, height, width

In [None]:
# Precomputed Walsh-domain statistics (presumably regenerated by running
# `stat_walsh_stl10()` if the .npy files are missing -- TODO confirm).
Cov = np.load(stats_root / f"Cov_{img_size}x{img_size}.npy")
Mean = np.load(stats_root / f"Average_{img_size}x{img_size}.npy")

# 2D Walsh matrix, not rescaled (dividing it by img_size was an alternative
# that was tried earlier).
H = wh.walsh2_matrix(img_size)

# Reorder the Walsh patterns according to Cov2Var(Cov) (presumably by
# decreasing variance) and keep the first M of them as measurement patterns.
Ord = Cov2Var(Cov)
Perm = Permutation_Matrix(Ord)
Hperm = Perm @ H
Pmat = Hperm[:M, :]

# Slicing silently returns fewer rows if M exceeds the number of patterns, and
# images are flattened to img_size**2 pixels downstream: fail early instead.
assert Pmat.shape == (M, img_size ** 2), f"unexpected Pmat shape {Pmat.shape}"


In [None]:
# One flattened image per row: shape (b*c, h*w).
x = inputs.view(b * c, h * w)
# All-zero initial estimate, passed as second argument to the reconstructors below.
x_0 = torch.zeros_like(x)

In [None]:
# Ground truth: second image of the batch, reshaped to h x w for display.
img = x[1, :].numpy()
imagesc(img.reshape((h, w)))

In [None]:
from spyrit.restructured.Updated_Had_Dcan import * 

In [None]:
# Forward operator built from Pmat, the M selected (permuted) Walsh patterns.
FO_split = Split_Forward_operator(Pmat)
# Acquisition model at intensity N0 photons; the name suggests Poisson noise
# approximated by a Gaussian (TODO confirm in Updated_Had_Dcan).
A_b = Bruit_Poisson_approx_Gauss(N0, FO_split)
# Measurement preprocessing, parameterised by N0, M and the pixel count w*h.
# Statement order is kept: module construction may draw from the torch RNG.
SPP = Split_diag_poisson_preprocess(N0, M, w*h)

In [None]:
# Simulated measurements of the flattened batch x (presumably noisy, so
# re-running this cell without re-seeding may give different values).
m = A_b(x)
# Preprocessed measurements: the input y of every reconstructor below.
y = SPP(m,FO_split)
# Per-measurement term from SPP.sigma(m), later passed to Generalised_Tikhonov_cg.
var = SPP.sigma(m)

# Pinv_orthogonal
## Instantiation

In [None]:
# Reconstructor Pinv_orthogonal (a pseudo-inverse, per its name) on the split forward operator.
P = Pinv_orthogonal(FO_split)

## Test the forward method

In [None]:
# Reconstruct the whole batch from the preprocessed measurements.
x_est = P(y)

In [None]:
# Pinv_orthogonal estimate of the second image. detach() matches the other
# display cells and avoids the RuntimeError .numpy() raises on tensors that
# track gradients; cpu() keeps the cell working if tensors are moved to a GPU.
img = x_est[1, :].detach().cpu().numpy()
imagesc(img.reshape((h, w)))

# learned_measurement_to_image
## Instantiation

In [None]:
# Learned reconstructor; its weights are untrained (random) at this point.
P = learned_measurement_to_image(FO_split)

## Test the forward method

In [None]:
# Forward pass of the untrained network on the preprocessed measurements.
x_est = P(y)

In [None]:
# Output of the untrained network for the second image.
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

As expected, the output shows nothing meaningful: the network currently has only random (untrained) weights.

# gradient_step
## Instantiation

In [None]:
# One gradient step per call; 0.0003 is presumably the step size mu referred
# to in the markdown below (TODO confirm against gradient_step's signature).
P = gradient_step(FO_split, 0.0003)

## Test the forward method

In [None]:
# First step, starting from the all-zero estimate x_0.
x_est = P(y, x_0)

In [None]:
# Estimate of the second image after one gradient step.
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

In [None]:
# Two gradient steps from x_0. Recomputing from the zero start (rather than
# `x_est = P(y, x_est)`) makes the cell idempotent: re-running it no longer
# silently adds an extra iteration.
x_est = x_0
for _step in range(2):
    x_est = P(y, x_est)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

In [None]:
# Three gradient steps from x_0, recomputed from the zero start so the cell is
# idempotent (re-running `x_est = P(y, x_est)` kept adding iterations).
x_est = x_0
for _step in range(3):
    x_est = P(y, x_est)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

After 3 iterations, we converge to the pseudo-inverse solution (given a correct initial value of $\mu$).

# Tikhonov_cg (not yet validated)
## Instantiation

In [None]:
P = Tikhonov_cg(FO_split, mu=1, n_iter=6)  # Tikhonov via CG, weight 1, 6 iterations

## Test the forward method

In [None]:
# Tikhonov-CG reconstruction from the zero start, second image displayed.
x_est = P(y, x_0)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

In [None]:
# Same CG solver with a much larger regularisation weight.
P = Tikhonov_cg(FO_split, mu=10_000, n_iter=6)
x_est = P(y, x_0)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

Too much regularisation, which is expected from the Tikhonov formula for such a large $\mu$.

In [None]:
P = Tikhonov_cg(FO_split, n_iter = 7, mu = 10000) # to check 6 is OK, & is not ok == Potential bug
x_est = P(y, x_0)
img = x_est[1,:]
img = img.detach().numpy();
imagesc(np.reshape(img,(h,w)))

# Tikhonov_cg_pylops /!\ Aborted for now — problems with the MatrixLinearOperator, which cannot handle the Hermitian inverse + multiplication
## Instantiation

In [None]:
# FO_pyl = Split_Forward_operator_pylops(Pmat)
# P = Tikhonov_cg_pylops(FO_pyl, n_iter = 6, mu = 1)

## Test the forward method

In [None]:
# x_est = P(y, x_0)
# img = x_est[1,:]
# img = img.detach().numpy();
# imagesc(np.reshape(img,(h,w)))

# Tikhonov_solve (validated)
## Instantiation

In [None]:
P = Tikhonov_solve(FO_split, mu=1e-1)  # Tikhonov_solve with a small weight (0.1)

## Test the forward method

In [None]:
# Tikhonov_solve reconstruction from the zero start, second image displayed.
x_est = P(y, x_0)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

In [None]:
# Same solver with a very large regularisation weight.
P = Tikhonov_solve(FO_split, mu=1_000_000)
x_est = P(y, x_0)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

Too much regularisation, which is expected from the Tikhonov formula for such a large $\mu$.

# Orthogonal Tikhonov (Validated)
## Instantiation

In [None]:
P = Orthogonal_Tikhonov(FO_split, mu=1e-1)  # orthogonal Tikhonov, small weight (0.1)

## Test the forward method

In [None]:
# Orthogonal Tikhonov reconstruction from the zero start, second image displayed.
x_est = P(y, x_0)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

In [None]:
# Orthogonal Tikhonov with a very large regularisation weight.
P = Orthogonal_Tikhonov(FO_split, mu=1_000_000)
x_est = P(y, x_0)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

Too much regularisation, which is expected from the Tikhonov formula for such a large $\mu$.

# Generalised_Tikhonov_cg (not yet validated)
## Instantiation

In [None]:
# Prior covariance built from the Walsh-domain statistics Cov.
# The (1/(h*w))**2 factor presumably compensates for H not being normalised
# (assumes H @ H.T == h*w * I -- confirm); prior_weight is an ad-hoc scale.
# NOTE(review): Hperm on both sides also permutes Cov. Cov2Var(Cov) is used
# to build the permutation, which suggests Cov is in natural Walsh order; in
# that case H.T @ Cov @ H may be what is intended -- to check (this section is
# "not yet validated").
prior_weight = 0.01
Sigma_prior = prior_weight * (1 / (h * w)) ** 2 * Hperm.T @ Cov @ Hperm
P = Generalised_Tikhonov_cg(FO_split, Sigma_prior=Sigma_prior, n_iter=6)

## Test the forward method

In [None]:
# Generalised Tikhonov (CG) reconstruction; var comes from SPP.sigma(m).
x_est = P(y, x_0, var)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

In [None]:
# Same prior with n_iter = 7 instead of 6. NOTE(review): compare with the
# 6-vs-7 iteration issue flagged for Tikhonov_cg -- to check.
P = Generalised_Tikhonov_cg(FO_split, n_iter=7, Sigma_prior=Sigma_prior)
x_est = P(y, x_0, var)
img = x_est[1, :].detach().numpy()
imagesc(img.reshape((h, w)))

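Too much regularisation (makes sense) due to the Tikhonov formula. **TODO:** re-check this remark. The cell above only changes `n_iter` (6 → 7), not the regularisation strength, so the remark appears to be copied from the Tikhonov sections. Compare with the `n_iter` 6 vs 7 issue noted for `Tikhonov_cg`.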

# Generalised_Tikhonov_cg_pylops /!\ Aborted for now
## Instantiation

In [None]:
# FO_pyl = Split_Forward_operator_pylops(Pmat)
# P = Generalised_Tikhonov_cg_pylops(FO_pyl, n_iter = 6, mu = 1)

## Test the forward method

In [None]:
# x_est = P(y, x_0)
# img = x_est[1,:]
# img = img.detach().numpy();
# imagesc(np.reshape(img,(h,w)))

# Generalised_Tikhonov_solve
## Instantiation

In [None]:
P = Generalised_Tikhonov_solve(FO_split, mu=1e-1)  # generalised Tikhonov solve, weight 0.1

## Test the forward method

In [None]:
# x_est = P(y, x_0)
# img = x_est[1,:]
# img = img.detach().numpy();
# imagesc(np.reshape(img,(h,w)))

In [None]:
# P = Tikhonov_solve(FO_split, mu = 1000000)
# x_est = P(y, x_0)
# img = x_est[1,:]
# img = img.detach().numpy();
# imagesc(np.reshape(img,(h,w)))

Too much regularisation is expected for a large $\mu$ due to the Tikhonov formula. The corresponding test cells above are currently commented out.

# Generalized_Orthogonal_Tikhonov
## Instantiation

In [None]:
P = Generalized_Orthogonal_Tikhonov(FO_split, mu=1e-1)  # weight 0.1

## Test the forward method

In [None]:
# x_est = P(y, x_0)
# img = x_est[1,:]
# img = img.detach().numpy();
# imagesc(np.reshape(img,(h,w)))

In [None]:
# P = Orthogonal_Tikhonov(FO_split,  mu = 1000000)
# x_est = P(y, x_0)
# img = x_est[1,:]
# img = img.detach().numpy();
# imagesc(np.reshape(img,(h,w)))

Too much regularisation is expected for a large $\mu$ due to the Tikhonov formula. The corresponding test cells above are currently commented out.