In [None]:
import torch
import numpy as np
import spyrit.misc.walsh_hadamard as wh

from matplotlib import pyplot as plt

from spyrit.learning.model_Had_DCAN import *
from spyrit.misc.disp import torch2numpy
from spyrit.misc.statistics import Cov2Var
from spyrit.learning.nets import *

from spas import read_metadata, reconstruction_hadamard
from spas import ReconstructionParameters, setup_reconstruction, load_noise, reconstruct
from spas.noise import noiseClass
from spas.visualization import *
#from siemens_star_analysis import *

# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Torch device: {device}')

In [None]:
# Walsh-ordered Hadamard matrix from spyrit, built for an image side of 64
# (presumably 4096 x 4096 for 64x64 images — confirm against spyrit docs).
# NOTE(review): 64 duplicates `img_size`, which is only defined in a later cell.
H = wh.walsh2_matrix(64)

In [None]:
# Load the raw spectral measurements. np.load on an .npz returns an NpzFile
# holding an open archive; using it as a context manager closes that handle
# once the array has been read into memory.
with np.load('./data/zoom_x1_starsector/zoom_x1_starsector_spectraldata.npz') as f:
    spectral_data = f['spectral_data']

# Acquisition description: object, light source, patterns, wavelengths, etc.
metadata, acquisition_metadata, spectrometer_parameters, dmd_parameters = \
    read_metadata('./data/zoom_x1_starsector/zoom_x1_starsector_metadata.json')
wavelengths = np.asarray(acquisition_metadata.wavelengths)

print(f'Spectral data dimensions: {spectral_data.shape}')
print(f'Wavelength range: {wavelengths[0]} - {wavelengths[-1]} nm')

print('\nAcquired data description:')
print(f'Light source: {metadata.light_source}')
print(f'Object: {metadata.object}')
print(f'Filter: {metadata.filter}')
print(f'Patterns: {acquisition_metadata.pattern_amount}')
print(f'Integration time: {spectrometer_parameters.integration_time_ms} ms')

In [None]:
# Inspect the acquisition metadata returned by read_metadata above.
metadata

In [None]:
# Full hypercube reconstruction from every acquired pattern.
recon = reconstruction_hadamard(acquisition_metadata.patterns, 'walsh', H, spectral_data)

# Sum over axis 2 (presumably the spectral axis) to get a single grayscale image.
plt.imshow(np.sum(recon, axis=2), cmap='gray')
plt.title('Hadamard reconstruction, summed over wavelengths')
plt.colorbar()

In [None]:
# Noise-model parameters (mu, sigma, k) stored in fit_model2.npz. The context
# manager closes the .npz archive once the arrays have been read.
with np.load('./fit_model2.npz') as data:
    mu = data['mu']
    sigma = data['sigma']
    coeff = data['k']
noise = noiseClass(mu, sigma, coeff)

In [None]:
# Reference ("ground truth") image: bin all acquired measurements over
# 530-730 nm and reconstruct with every pattern. Outputs carry a _GT suffix so
# they do not share names with the subsampled binning done later; otherwise
# re-running this cell would silently replace the `noise_bin` used by the
# network reconstruction.
F_bin_GT, wavelengths_bin_GT, bin_width_GT, noise_bin_GT = spectral_binning(
    spectral_data.T, wavelengths, 530, 730, 1, noise)
recon_GT = reconstruction_hadamard(acquisition_metadata.patterns, 'walsh', H, F_bin_GT.T)

plt.imshow(recon_GT, cmap='gray')
plt.title('Reference reconstruction (all patterns, 530-730 nm)')
plt.colorbar()

In [None]:
# Peak value of the reference reconstruction (quick sanity check).
recon_GT.max()

In [None]:
def subsample(spectral_data, CR):
    """Keep the first ``CR`` positive/negative measurement pairs.

    Measurements are expected to be interleaved along axis 0 as
    ``[pos_0, neg_0, pos_1, neg_1, ...]``. Taking the first ``CR`` even rows
    and the first ``CR`` odd rows and re-interleaving them is exactly the
    leading ``2*CR`` rows, which is what is returned.

    Args:
        spectral_data: 1-D array of measurements, or 2-D array whose axis 0
            indexes measurements (axis 1 presumably indexes wavelengths).
        CR: number of positive/negative pairs to keep.

    Returns:
        A new float64 NumPy array of shape ``(2*CR,)`` or
        ``(2*CR, spectral_data.shape[1])`` that never aliases the input.

    Raises:
        ValueError: if ``spectral_data`` is not 1-D or 2-D, if ``CR`` is
            negative, or if fewer than ``2*CR`` measurements are available.
    """
    spectral_data = np.asarray(spectral_data)

    # The previous implementation left its result unbound for other ranks.
    if spectral_data.ndim not in (1, 2):
        raise ValueError(
            f'spectral_data must be 1-D or 2-D, got {spectral_data.ndim}-D')
    if CR < 0:
        raise ValueError(f'CR must be non-negative, got {CR}')
    if spectral_data.shape[0] < 2 * CR:
        raise ValueError(
            f'need at least {2 * CR} measurements for CR={CR}, '
            f'got {spectral_data.shape[0]}')

    # np.array copies, and float64 matches the np.zeros buffer used before.
    return np.array(spectral_data[:2 * CR], dtype=np.float64)

In [None]:
# --- Reconstruction setup ---
img_size = 64        # image side length in pixels (images are img_size x img_size)
CR = 2048            # number of patterns kept; each is a positive/negative pair
net_arch = 0         # network variant

# --- Intensity distribution ---
N0 = 10_000
sig = 0.5

# --- Training hyper-parameters (they appear in the checkpoint file name) ---
num_epochs = 30
lr = 1e-3
step_size = 10
gamma = 0.5
batch_size = 512
reg = 1e-7

In [None]:
# Checkpoint tag. It must list the same hyper-parameters, in the same order, as
# the saved checkpoint name. The previous format string put img_size into the N0
# slot and omitted sig/Denoi_N. N0 is formatted as a float because the checkpoint
# was saved with N0 = 10000.0.
suffix = '_N0_{}_sig_{}_Denoi_N_{}_M_{}_epo_{}_lr_{}_sss_{}_sdr_{}_bs_{}_reg_{}'.format(
           float(N0), sig, img_size, CR, num_epochs, lr, step_size,
           gamma, batch_size, reg)

# Scale the measurement matrix and the mean by 1/img_size. Scale the covariance
# by 1/img_size**2, since covariance scales quadratically with the data.
H_network = H / img_size
Mean = np.load(f'./stats/Average_{img_size}x{img_size}.npy')/img_size
Cov  = np.load(f'./stats/Cov_{img_size}x{img_size}.npy')/img_size**2

model = DenoiCompNet(img_size, CR, Mean, Cov, net_arch, N0, sig, H_network, Cov2Var(Cov))
# Evaluates to './models/NET_c0mp_N0_10000.0_sig_0.5_Denoi_N_64_M_2048_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07'
# with the configuration above.
network_path = './models/NET_c0mp' + suffix
load_net(network_path, model, device)
model.to(device)

In [None]:
# Inspect the raw measurement array loaded from the .npz file
# (NumPy summarises large arrays in its repr).
spectral_data

In [None]:
# Smallest value of the loaded `mu` noise parameter (sanity check).
np.min(mu)

In [None]:
# Largest value of the loaded `sigma` noise parameter (sanity check).
np.max(sigma)

In [None]:
# Keep the first CR positive/negative pairs, then transpose so measurements run
# along the last axis (assuming spectral_data is (measurements, wavelengths)).
imgs = np.transpose(subsample(spectral_data, CR))
F_bin, wavelengths_bin_recon, bin_width, noise_bin = spectral_binning(
    imgs, wavelengths, 530, 730, 1, noise)

In [None]:
# Indices of wavelengths strictly inside (530, 730) nm. The 1-tuple unpacking
# only succeeds when `wavelengths` is 1-D.
(lambda_ind,) = np.nonzero(np.logical_and(wavelengths > 530, wavelengths < 730))
lambda_ind.shape

In [None]:
# Import `math` explicitly. Otherwise this cell only works if one of the
# star-imports in the first cell happens to re-export it, and it fails with a
# NameError on a fresh kernel.
import math

# NOTE(review): 1723 and 15 are unexplained constants (1723 may be the channel
# count shown by the previous cell) — TODO document their meaning.
math.sqrt(1723)*15

In [None]:
# Direct Hadamard reconstruction from only the first 2*CR (subsampled) patterns.
# Used as the baseline for the network reconstructions below.
recon_fbin = reconstruction_hadamard(acquisition_metadata.patterns[:2*CR], 'walsh', H, F_bin.T)

plt.imshow(recon_fbin, cmap='gray')
plt.title(f'Hadamard reconstruction, first {2*CR} patterns')
plt.colorbar()

In [None]:
# Pack the binned measurements as a (batch, channel, pattern) tensor on the device.
torch_img = torch.from_numpy(F_bin)
torch_img = torch_img.float()
torch_img = torch.reshape(torch_img, (1, 1, 2*CR)) # batches, channels, patterns
torch_img = torch_img.to(device)

# Reconstruct with the binned experimental noise parameters (mu, sigma, K), then
# apply (x + 1) * N0 / 2 with the configured N0 (presumably undoing a [-1, 1]
# normalisation — confirm against spyrit).
result = (model.forward_reconstruct_expe(
    torch_img, 1, 1, img_size, img_size,
    torch.from_numpy(noise_bin.mu).float().to(device),
    torch.from_numpy(noise_bin.sigma).float().to(device),
    torch.from_numpy(noise_bin.K).float().to(device),
) + 1) * N0 /2

result = result.cpu().detach().numpy().squeeze()

# Keep the intensity estimate under its own name. Binding it to `N0`, as
# before, overwrote the configured constant, so re-running this cell or the
# model-setup cell afterwards silently used the estimate instead of 10000.
_, N0_est = model.forward_preprocess_expe(torch_img, 1, 1, img_size, img_size)

In [None]:
# Display the noise-aware reconstruction computed in the previous cell.
# (pyplot is already imported in the first cell; the duplicate import was removed.)
plt.imshow(result, cmap='gray')
plt.title('Reconstruction: forward_reconstruct_expe')
plt.colorbar()

In [None]:
# Same (batch, channel, pattern) packing as for the noise-aware reconstruction.
torch_img = torch.from_numpy(F_bin)
torch_img = torch_img.float()
torch_img = torch.reshape(torch_img, (1, 1, 2*CR)) # batches, channels, patterns
torch_img = torch_img.to(device)

result = model.forward_reconstruct_pinv(
    torch_img, 1, 1, img_size, img_size,
)

# Rescale with the model's own N0 and convert back to a NumPy image.
result = (result+1)*model.N0/2
result = result.cpu().detach().numpy().squeeze()

# Left: network pseudo-inverse. Right: its difference from the direct Hadamard
# reconstruction of the same subsampled measurements (recon_fbin).
plt.subplot(121)
plt.imshow(result, cmap='gray')
plt.title('forward_reconstruct_pinv')
plt.colorbar()
plt.subplot(122)
plt.imshow(result - recon_fbin.squeeze(), cmap='gray')
plt.title('pinv - reconstruction_hadamard')
plt.colorbar()
plt.tight_layout()

In [None]:
# Copy F_bin into a fresh float32 tensor. torch.tensor always copies, whereas
# torch.from_numpy(F_bin).float() shares memory with F_bin when F_bin is already
# float32, so the in-place zeroing below could silently modify F_bin itself.
torch_img = torch.tensor(F_bin, dtype=torch.float32)
torch_img = torch.reshape(torch_img, (1, 1, 2*CR)) # batches, channels, patterns
# Zero the first two measurements.
# NOTE(review): presumably the first positive/negative pair — confirm why they are dropped.
torch_img[0,0,:2] = 0
torch_img = torch_img.to(device)

# NOTE(review): hard-coded intensity that overrides model.N0 for every later cell.
# Presumably the value estimated by forward_preprocess_expe — TODO derive it instead.
model.N0 = 9316.7578125

result = model.forward_reconstruct(
    torch_img, 1, 1, img_size, img_size,
)

# Unlike the pseudo-inverse cell, no (result + 1) * model.N0 / 2 rescaling is
# applied here — TODO confirm the output scale of forward_reconstruct.
result = result.cpu().detach().numpy().squeeze()

plt.imshow(result, cmap='gray')
plt.title('Reconstruction: forward_reconstruct')
plt.colorbar()

In [None]:
# Copy F_bin into a fresh float32 tensor. torch.tensor always copies, whereas
# torch.from_numpy(F_bin).float() shares memory with F_bin when F_bin is already
# float32, so the in-place zeroing below could silently modify F_bin itself.
torch_img = torch.tensor(F_bin, dtype=torch.float32)
torch_img = torch.reshape(torch_img, (1, 1, 2*CR)) # batches, channels, patterns
# Zero the first two measurements.
# NOTE(review): presumably the first positive/negative pair — confirm why they are dropped.
torch_img[0,0,:2] = 0
torch_img = torch_img.to(device)

# NOTE(review): hard-coded intensity that overrides model.N0 for every later cell.
# Presumably the value estimated by forward_preprocess_expe — TODO derive it instead.
model.N0 = 9316.7578125

result = model.forward_reconstruct_mmse(
    torch_img, 1, 1, img_size, img_size,
)

# No (result + 1) * model.N0 / 2 rescaling is applied here — TODO confirm the
# output scale of forward_reconstruct_mmse.
result = result.cpu().detach().numpy().squeeze()

plt.imshow(result, cmap='gray')
plt.title('Reconstruction: forward_reconstruct_mmse')
plt.colorbar()

In [None]:
# Leftover inspection cell: prints Python's help text for the MMSE
# reconstruction method. Safe to remove from a finished notebook.
help(model.forward_reconstruct_mmse)