In [None]:
import os
import sys
import math
import copy
import time
import queue
import random
import h5py as hp
import numpy as np
import scipy.misc
import scipy.io as sio
import scipy.ndimage
import skimage.feature
import matplotlib.pyplot as plt

import torch
import torch.optim
import torch.nn as nn
import torch.nn.init
import torch.nn.functional as F

from utils import *
from physical import *
from models import *

# --- Runtime environment ----------------------------------------------------
# Restrict CUDA to the first device.
# NOTE(review): this only has an effect if CUDA has not been initialised yet;
# confirm that none of the star-imported project modules touch the GPU on import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Disable HDF5 file locking (commonly needed on network / shared file systems).
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
# Allow more than one OpenMP runtime to be loaded in the same process.
os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE"

# Notebook-wide tensor types: legacy CUDA float tensor type and the complex dtype.
dtype = torch.cuda.FloatTensor
ctype = torch.complex64

# Double-precision pi from the standard library. The previous
# `torch.acos(torch.zeros(1)).item() * 2` went through float32 and was only
# accurate to ~1e-7, which leaked into every phase bound derived from `pi`.
pi = math.pi

In [None]:
# --- Load raw inputs from ./source ------------------------------------------
# Each file is opened read-only, the named datasets are read in full and
# wrapped as torch tensors; the file is closed again before the next one opens.

# Datasets 'I_back' and 'I_det' from the hologram/background pair file.
with hp.File('./source/holo_back_pair_beads_commerical_same_dist_02232022.h5', 'r') as holo_h5:
    _I_back, _I_det = (
        torch.from_numpy(holo_h5[key][()]) for key in ('I_back', 'I_det')
    )

# Real and imaginary parts of the Gaussian illumination profile.
with hp.File('./source/uinc_gaussian.h5', 'r') as gaussian_h5:
    _uinc_gaussian_r, _uinc_gaussian_i = (
        torch.from_numpy(gaussian_h5[key][()])
        for key in ('uinc_gaussian_r', 'uinc_gaussian_i')
    )

# Real and imaginary parts of the illumination profile used for reconstruction.
with hp.File('./source/uinc.h5', 'r') as uinc_h5:
    _uinc_r, _uinc_i = (
        torch.from_numpy(uinc_h5[key][()]) for key in ('uinc_r', 'uinc_i')
    )

# Recombine the stored real/imaginary parts into complex-valued fields.
_uinc_gaussian = _uinc_gaussian_r + _uinc_gaussian_i * 1j
_uinc = _uinc_r + _uinc_i * 1j


# Simple backpropagation demo
#
# Propagates a unit-amplitude, unit-phase object (times the illumination) with
# the forward operator, replaces the modulus of the result with the square
# root of the measured intensity, and propagates back with the inverse
# operator. The resulting amplitude/phase pair is kept as `one_iter_rec`.

# Computational grid and the trivial starting object.
dim = 512
pad_size = 0
amp = pad_object(torch.ones(1, 1, dim, dim), pad_size)
pha = pad_object(torch.ones(1, 1, dim, dim), pad_size)
obj = to_phasor_notation(amp, pha, 'cpu').squeeze(0).squeeze(0)

# Physical parameters handed to PhysicalSystem.initialize_parameters.
c_wv = 793e-9
Lz = 9.75e-3

p = PhysicalSystem(dim)
p.initialize_parameters(c_wv, Lz, 0.0e-6)
assert p.Nx == dim, "Check dimensions of a computational grid!"

# dim x dim window centred on `cnt_coord`, shared by all cropped arrays below.
cnt_coord = (800, 800)
half_dim = dim // 2
row_window = slice(cnt_coord[0] - half_dim, cnt_coord[0] + half_dim)
col_window = slice(cnt_coord[1] - half_dim, cnt_coord[1] + half_dim)

uinc = _uinc[row_window, col_window]
assert uinc.shape == (dim, dim), "Check dimensions of an illumination profile!"
H = p.forward_operator()
H_inv = p.inverse_operator()

# rot90 with k=0 leaves the crop unrotated; the call is kept as an orientation knob.
I_det = torch.rot90(_I_det[row_window, col_window], 0)
I_back = torch.rot90(_I_back[row_window, col_window], 0)

# Forward propagation: multiply the spectrum of (object x illumination) by H.
illuminated_spectrum = torch.fft.fft2(obj * uinc)
psi = torch.fft.ifft2(illuminated_spectrum * H)
# Keep the propagated phase, impose the measured amplitude sqrt(I_det).
# In-place ops are used on purpose so psi keeps its dtype.
psi /= torch.abs(psi)
psi *= torch.sqrt(I_det)
# Back-propagation: multiply the spectrum by H_inv.
psi = torch.fft.ifft2(torch.fft.fft2(psi) * H_inv)

# Side-by-side view: measurement, back-propagated amplitude, back-propagated phase.
cm = 'gray'
plt.figure(figsize=(15, 10))
for panel_idx, panel in enumerate((I_det, torch.abs(psi), torch.angle(psi)), start=1):
    plt.subplot(1, 3, panel_idx)
    plt.imshow(panel, cmap=cm)
plt.show()

one_iter_back_amp = torch.abs(psi)
one_iter_back_pha = torch.angle(psi)

# Stack [amplitude, phase] along a new leading axis, then insert a singleton axis at position 1.
one_iter_rec = torch.stack([one_iter_back_amp, one_iter_back_pha], dim=0).unsqueeze(1)

In [None]:
# --- Measured spectrum ------------------------------------------------------
# Datasets 'yy' and 'xx' are read in full and flattened to 1-D; they are later
# passed to train() as the measured spectrum's y- and x-axis respectively.
with hp.File('./source/M810D3_30_7mA_bias.mat', 'r') as spectrum_file:
    meas_spec_commercial = spectrum_file['yy'][()].reshape(-1)
    meas_wv_commercial = spectrum_file['xx'][()].reshape(-1)


# --- Optimisation settings --------------------------------------------------
# 50 evenly spaced displacements spanning [-LEFT_WIDTH, +RIGHT_WIDTH].
LEFT_WIDTH, RIGHT_WIDTH = 70e-9, 65e-9
DISPLACEMENTS = torch.linspace(-LEFT_WIDTH, RIGHT_WIDTH, steps=50)
LR = 1e-3
EPOCH = 550

# --- Grid / network geometry ------------------------------------------------
c_wv = 793e-9
dim = 512
Ny_obj = Nx_obj = dim
in_channels, out_channels = 32, 1

# Value bounds handed to the amplitude and phase network constructors.
vmin_amp, vmax_amp = 0., 1.0
vmin_pha, vmax_pha = 0. * pi, 0.25 * pi

input_dims = (1, in_channels, Ny_obj, Nx_obj)
channels_down = [32, 64, 128, 256, 512]
channels_up = [32, 64, 128, 256, 512]

# --- Physical system --------------------------------------------------------
# Re-apply the same wavelength / distance to the PhysicalSystem created in the
# previous cell so this cell can be re-run on its own after editing the values.
Lz = 9.75e-3
l_dz = Lz
p.initialize_parameters(c_wv, Lz, 0.0e-6)

# One spectral weight per sampled displacement, created in 'constant' mode.
wv_g = wavelength_gammas(len(DISPLACEMENTS), 'constant')

# The amplitude and phase networks share an architecture and differ only in
# the value bounds passed as the last two arguments. Construction order
# (amplitude, phase, noise input) is kept as-is in case it affects RNG draws.
shared_arch_args = (in_channels, out_channels, channels_down, channels_up)
net_amp = model_parallel(*shared_arch_args, vmin_amp, vmax_amp)
net_pha = model_parallel(*shared_arch_args, vmin_pha, vmax_pha)
input_noise_net = noise_net(input_dims)

# Apply the network-specific weight initialisers to every sub-module.
net_amp.apply(weights_init_amp)
net_pha.apply(weights_init_pha)

# Pre-allocated logging buffers, one row per epoch plus one extra slot
# (1 + EPOCH entries). They start out as NaN so that entries which were never
# written stay distinguishable from real values.
#
# `np.nan` is used instead of the `np.NaN` alias, which was removed in
# NumPy 2.0 and raises AttributeError there; `np.full` allocates and fills in
# one step and yields the same float64 arrays as the old empty-then-assign code.
loss_list = np.full((1 + EPOCH,), np.nan)
weights_list = np.full((1 + EPOCH, len(DISPLACEMENTS)), np.nan)
Lz_list = np.full((1 + EPOCH,), np.nan)

In [None]:
# Run the optimisation. train() is defined in a project module that is not
# visible here, so the positional arguments must stay in exactly this order.
# The notes below describe what this notebook passes in; how train() uses each
# value is presumed from the names — verify against its definition.
train(net_amp,                            # Untrained neural network for amplitude (built with vmin_amp / vmax_amp)
      net_pha,                            # Untrained neural network for phase (built with vmin_pha / vmax_pha)
      input_noise_net,                    # Input random noise, created by noise_net(input_dims)
      l_dz,                               # Sample-to-camera distance (initialised to Lz)
      dim,                                # Image dimension (side length of the square grid)
      uinc,                               # Illumination beam profile, dim x dim crop of _uinc
      I_det,                              # Experimental intensity on the detector plane, dim x dim crop of _I_det
      one_iter_rec,                       # Stacked amplitude/phase from the single back-propagation step of the demo cell
      p,                                  # PhysicalSystem object
      loss_list,                          # NaN-initialised array (1 + EPOCH entries), presumably filled with the loss
      weights_list,                       # NaN-initialised array (1 + EPOCH rows, one column per displacement), presumably the learned spectrum
      Lz_list,                            # NaN-initialised array (1 + EPOCH entries), presumably the sample-to-camera distance
      c_wv,                               # Center wavelength
      wv_g,                               # Learned spectrum (\gamma_n's), from wavelength_gammas(..., 'constant')
      meas_wv_commercial,                 # Measured spectrum (x-axis), flattened 'xx' dataset
      meas_spec_commercial,               # Measured spectrum (y-axis), flattened 'yy' dataset
      displacements = DISPLACEMENTS,      # Sampled wavelengths: 50 offsets in [-LEFT_WIDTH, RIGHT_WIDTH]
      lr = LR,                            # Initial learning rate
      epochs = EPOCH                      # Number of epochs
     )