In [1]:
import torch
import numpy as np
from spyrit.misc.walsh_hadamard import walsh_matrix
from spyrit.core.Forward_Operator import *
from spyrit.core.Preprocess import *

## INITIALIZE COMMON MATRICES AND OPERATORS

### Instantiate Hsub and Perm matrices

In [2]:
# Problem sizes shared by every cell below.
h, w = 32, 32
img_size = h * w
nb_measurements = 400
batch_size = 10

# Seed every RNG used in this notebook so that Restart & Run All draws the same
# random inputs on each run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Full Walsh matrix of order img_size; keep its first nb_measurements rows as
# the sub-sampled measurement matrix.
Hcomplete = walsh_matrix(img_size)
Hsub = Hcomplete[0:nb_measurements, :]

# Random (img_size, img_size) permutation matrix for Forward_operator_Split_ft_had.
# The original cell used a dense random matrix. A genuine permutation has the same
# shape and dtype (float64) and presumably matches what the operator expects
# (TODO confirm).
Perm = np.eye(img_size)[np.random.permutation(img_size)]

### Instantiate the different forward operators

In [3]:
# Plain forward operator, used below with the shift- and pos-Poisson preprocessors.
FO = Forward_operator(Hsub)

# Split forward operator; in the recorded output of the SPP cell it returns
# 2 * nb_measurements values per image.
FO_Split = Forward_operator_Split(Hsub)

# Split operator that additionally receives the Perm matrix and the image dimensions.
# NOTE(review): the dimensions are passed as (w, h). This is harmless here because
# h == w, but check the constructor's parameter order before using non-square images.
FO_Split_ft_had = Forward_operator_Split_ft_had(Hsub, Perm, w, h)

## PARENT CLASS: Preprocess_Split_diag_poisson

In [4]:
# Intensity level shared by the preprocessing operators in this notebook
# (presumably the Poisson photon level; TODO confirm against the spyrit docs).
alpha = 9
SPP = Preprocess_Split_diag_poisson(
    alpha,            # intensity level
    nb_measurements,  # number of rows kept in Hsub
    img_size,         # number of pixels, h * w
)

### forward

In [5]:
# Random batch of flattened images, shape (batch_size, img_size).
x = torch.tensor(np.random.random([batch_size, img_size]), dtype=torch.float)

# The split operator yields 2 * nb_measurements values per image; SPP maps them
# to nb_measurements values.
y_FO_Split = FO_Split(x)
y_SPP = SPP(y_FO_Split, FO_Split)
print(y_FO_Split.shape)
print(y_SPP.shape)

# Check the shapes instead of only printing them.
assert y_FO_Split.shape == (batch_size, 2 * nb_measurements)
assert y_SPP.shape == (batch_size, nb_measurements)

torch.Size([10, 800])
torch.Size([10, 400])


### sigma

In [6]:
# sigma (presumably a noise-variance estimate) computed from the raw split
# measurements: one value per measurement.
y_sigma = SPP.sigma(y_FO_Split)
print(y_sigma.shape)
assert y_sigma.shape == (batch_size, nb_measurements)

torch.Size([10, 400])


### sigma_expe

In [7]:
# Set the experimental parameters used by the *_expe methods
# (gain, dark-level mean/std and binning, judging by the names; TODO confirm).
SPP.set_expe(gain=1, mudark=0, sigdark=0, nbin=1)

# NOTE(review): SPP.sigma above receives the split measurements y_FO_Split, but
# this call receives the image batch x (img_size columns). The output still has
# nb_measurements columns. Confirm that x is the intended input.
y_sigma_expe = SPP.sigma_expe(x)
print(y_sigma_expe.shape)
assert y_sigma_expe.shape == (batch_size, nb_measurements)

torch.Size([10, 400])


### sigma_from_image

In [8]:
# NOTE(review): sigma_from_image is disabled in this notebook and no reason is
# recorded; presumably the call failed or the method was not ready.
# TODO: enable it (with a shape assertion) or remove this section.
# y = SPP.sigma_from_image(x, FO)

### forward_expe

In [9]:
# NOTE(review): calling FO_Split_ft_had directly is disabled. The next cell instead
# feeds forward_expe a random tensor with the split-measurement shape. The reason
# is not recorded. TODO: confirm, then enable or remove this cell.
# xsub = FO_Split_ft_had(x)
# print(xsub.shape)

In [10]:
# Random stand-in for raw split measurements, shape (batch_size, 2 * nb_measurements).
# The sizes come from the config cell instead of repeating the literals 10 and 2*400.
xsub = torch.tensor(np.random.random([batch_size, 2 * nb_measurements]), dtype=torch.float)

# forward_expe returns the preprocessed measurements and one alpha estimate per
# batch item.
y_FE, alpha_est = SPP.forward_expe(xsub, FO_Split_ft_had)
print(y_FE.shape)
print(alpha_est)
assert y_FE.shape == (batch_size, nb_measurements)
assert alpha_est.shape == (batch_size,)

torch.Size([10, 400])
tensor([0.0218, 0.0227, 0.0244, 0.0275, 0.0268, 0.0231, 0.0241, 0.0256, 0.0241,
        0.0253])


### denormalize_expe

In [11]:
# Image batch reshaped to (batch_size, 1, h, w).
x1 = x.view(batch_size, 1, h, w)

# One random norm per image, shape (1, batch_size). The dtype is torch.float to
# match x1; np.random.random yields float64, which would otherwise promote the
# result to float64. The scale is alpha, which equals the literal 9 used originally.
norm = alpha * torch.tensor(np.random.random([1, batch_size]), dtype=torch.float)

y_DE = SPP.denormalize_expe(x1, norm, h, w)
print(y_DE.shape)
assert y_DE.shape == (batch_size, 1, h, w)

torch.Size([10, 1, 32, 32])


## PARENT CLASS: Preprocess_shift_poisson

In [12]:
# Shift-Poisson preprocessing with the same alpha and sizes as SPP
# (img_size is h * w, see the config cell).
PSP = Preprocess_shift_poisson(alpha, nb_measurements, img_size)
print(PSP)

Preprocess_shift_poisson()


### forward

In [13]:
# Input for PSP, shape (batch_size, nb_measurements + 1): a column of ones
# followed by the first nb_measurements pixels of x. Any tensor of that shape
# (e.g. a random one) would do; reusing x also exercises torch.cat.
x2 = torch.cat((torch.ones(batch_size, 1), x[:, :nb_measurements]), dim=1)
print("x2 shape:", x2.shape)
y_PSP = PSP(x2, FO)
print("y_PSP shape:", y_PSP.shape)

assert x2.shape == (batch_size, nb_measurements + 1)
assert y_PSP.shape == (batch_size, nb_measurements)

x2 shape: torch.Size([10, 401])
y_PSP shape: torch.Size([10, 400])


### sigma

In [14]:
# sigma for the shifted measurements: the extra leading column of x2 is not
# present in the output (nb_measurements columns).
sigma_PSP = PSP.sigma(x2)
print(sigma_PSP.shape)
assert sigma_PSP.shape == (batch_size, nb_measurements)

torch.Size([10, 400])


### sigma_from_image

In [15]:
# NOTE(review): sigma_from_image is disabled for PSP as well and no reason is
# recorded; presumably the call failed or the method was not ready.
# TODO: enable it (with a shape assertion) or remove this section.
# sig_im_PSP = PSP.sigma_from_image(x, FO)

### offset

In [16]:
# Offset computed from the image batch x: one value per image.
# A descriptive name replaces the generic `y`, which later cells reassign.
offset_PSP = PSP.offset(x)
print(offset_PSP.shape)
assert offset_PSP.shape == (batch_size, 1)

torch.Size([10, 1])


## PARENT CLASS: Preprocess_pos_poisson

In [17]:
# Pos-Poisson preprocessing built from the shared configuration instead of the
# duplicated literals 9, 400 and 32*32 (same values: alpha, nb_measurements, img_size).
PPP = Preprocess_pos_poisson(alpha, nb_measurements, img_size)
print(PPP)

Preprocess_pos_poisson()


### forward

In [25]:
# Random measurements, shape (batch_size, nb_measurements). The name xsub_pos
# avoids overwriting xsub, the (batch_size, 2 * nb_measurements) tensor used by
# the forward_expe cell, so that cell still works when re-run.
xsub_pos = torch.tensor(np.random.random([batch_size, nb_measurements]), dtype=torch.float)
y_PPP = PPP(xsub_pos, FO)

# NOTE(review): the recorded output shows three shapes for this single print.
# Presumably these are debug prints inside Preprocess_pos_poisson.forward; confirm.
print(y_PPP.shape)
assert y_PPP.shape == (batch_size, nb_measurements)

torch.Size([10, 400])
torch.Size([10, 400])
torch.Size([10, 400])


### offset

In [26]:
# Measurement-sized input for PPP.offset, shape (batch_size, nb_measurements).
# The name x_PPP avoids overwriting the image batch x that earlier cells use;
# re-running them after this cell would otherwise see the wrong shape.
# NOTE(review): PSP.offset above receives the image batch (img_size columns).
# Confirm that PPP.offset is meant to take nb_measurements columns.
x_PPP = torch.tensor(np.random.random([batch_size, nb_measurements]), dtype=torch.float)
offset_PPP = PPP.offset(x_PPP)
print(offset_PPP.shape)
assert offset_PPP.shape == (batch_size, 1)

torch.Size([10, 1])
