In [None]:
# Dataset to reconstruct; must be one of: 'fruitfly', 'zebrafish', 'trachea', 'esophagus'.
sample_id = "zebrafish"

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from time import time
import scipy.io
from paraOCRT import *
import h5py
import os

# --- batching ---
# A-scans drawn per angle in each batch; this times the number of angles is the total
# number of A-scans per batch.
batch_size_stratified = 50
# NOTE(review): 400**2 presumably corresponds to the number of lateral scan positions
# (num_x * num_y) of the datasets — confirm before changing datasets.
n_batches = 400 ** 2 / batch_size_stratified

# --- reconstruction grid ---
scale = 0.5
RI_scale = 0.25
pixel_size = 1  # reconstruction pixel size in um, when scale=1

# --- optimizer ---
momentum = 4 / n_batches  # see paper for explanation
lr_multiplier = 0.1  # rescale all learning rates by this

# Batch from a subset of the angles: all 96 angle indices except the excluded ones.
th_range = np.delete(np.arange(96), [27, 72, 73, 74, 75])

# --- iteration schedule ---
update_gradient_after = 1000  # only start updating tf.Variables after this many iterations
loss_print_iter = 10  # print loss every this many iters
plot_iter = 1000  # plot every this many iters
num_iter = 50000  # total number of iterations

In [None]:
# Per-sample settings.
#   recon_shape: reconstruction grid shape passed to paraOCRT
#   xyz_offset:  offset used to center the reconstruction
# Every sample follows the same on-disk layout:
#   ./data/<sample_id>/<sample_id>.hdf5   -- where the dataset is stored
#   ./data/<sample_id>/tf_ckpts/          -- parameters optimized on a calibration sample,
#                                            used to initialize the reconstruction
_sample_settings = {
    'trachea':   {'recon_shape': (1500, 1500, 500),  'xyz_offset': (0, 190, 0)},
    'esophagus': {'recon_shape': (1000, 1000, 850),  'xyz_offset': (0, 330, 0)},
    'fruitfly':  {'recon_shape': (1200, 1200, 1000), 'xyz_offset': (0, 280, 80)},
    'zebrafish': {'recon_shape': (1200, 1200, 700),  'xyz_offset': (0, 130, 20)},
}

# ValueError (a subclass of Exception) signals a bad configuration value; repr() keeps the
# message well-formed even if sample_id is not a string.
if sample_id not in _sample_settings:
    raise ValueError('invalid sample_id: ' + repr(sample_id)
                     + '; expected one of ' + repr(sorted(_sample_settings)))

path = './data/{0}/{0}.hdf5'.format(sample_id)  # where the dataset is stored
restore_path = './data/{0}/tf_ckpts/'.format(sample_id)  # initialize with parameters optimized on calibration sample
recon_shape = _sample_settings[sample_id]['recon_shape']
xyz_offset = np.array(_sample_settings[sample_id]['xyz_offset'])  # to center the reconstruction

In [None]:
# Read the scan dimensions and the nominal galvo/probe positions from the HDF5 file.
with h5py.File(path, 'r') as f:
    attrs = f.attrs
    num_x, num_y, num_th = attrs['num_x'], attrs['num_y'], attrs['num_th']
    A_scan_num = attrs['Ascan_len']  # A-scan length
    stack_shape = (num_th, num_x, num_y, A_scan_num)
    num_Ascans = np.prod(stack_shape[:-1])  # number of A-scans total
    # Unpacking consumes the generator immediately, i.e. while the file is still open.
    hx, hy, x, y = (np.array(f[key]) for key in ('galvo_x', 'galvo_y', 'probe_x', 'probe_y'))

# NOTE(review): fixed shift of the nominal probe y positions; units/motivation not documented — confirm.
y -= 10
# Normalize each galvo axis to its first value, flipping the sign.
hx, hy = (-h / h[0] for h in (hx, hy))
galvo_xy = np.stack([hx, hy], axis=1)
probe_xy = np.stack([x, y], axis=1)
    
# Instantiate paraOCRT for this dataset, then configure it via attributes.
a = paraOCRT(
    recon_shape=recon_shape,
    RI_shape_scale=RI_scale,
    dxyz=pixel_size,
    hdf5_path=path,
    batch_size=batch_size_stratified,
    scale=scale,
    momentum=momentum,
)

# (attribute, value) pairs applied in this exact order.
_model_settings = (
    ('shuffle_size', 1),  # don't shuffle
    ('prefetch', 5),
    ('th_range', th_range),
    ('data_num_x', num_x),
    ('data_num_y', num_y),
    ('data_num_z', A_scan_num),
    ('z_downsamp', 10),  # downsample the A-scan when propagating to save memory/time
    ('use_first_reflection_RI_loss', True),
    ('correct_momentum_bias_in_loss', True),
    ('xyz_offset', xyz_offset),
    ('n_back', 1.342),
    ('z_start', 0),
    ('z_end', 1500),
)
for _attr_name, _attr_value in _model_settings:
    setattr(a, _attr_name, _attr_value)
    
# Full menu of per-variable learning rates; a negative value means "not optimized".
learning_rates = {'f_mirror': -1e-1, 'f_lens': -1e-1,
                  'galvo_xy': 1e-3, 'galvo_normal': 1e-3, 'galvo_theta': 1e-3,
                  'probe_dx': 1e-2, 'probe_dy': 1e-2, 'probe_z': 1e-2, 'probe_normal': 1e-3,
                  'probe_theta': 1e-3, 'd_before_f': 1e-3, 'RI': 1e-2, 'Ascan_background': 1e-3,
                  'delta_r': 1e-3, 'delta_u': 1e-3, 'galvo_xy_per': 1e-3,
                  'galvo_theta_in_plane_per': 1e-3, 'probe_dxyz_per': 1e-3,
                  'r_2nd_order': 1e-3, 'u_2nd_order': 1e-3,
                  'dome_inner_radius': -1e-3, 'dome_outer_radius': -1e-3, 'dome_center': 1e-3,
                  'r_higher_order': 1e-3, 'u_higher_order': 1e-3,
                  }

# Only optimize RI (and Ascan_background): disable everything else with -1 ...
_kept_trainable = ('RI', 'Ascan_background')
learning_rates = {name: (rate if name in _kept_trainable else -1)
                  for name, rate in learning_rates.items()}
# ... and re-enable the boundary conditions with their own rates.
learning_rates.update(delta_r=1, delta_u=1e-2, galvo_theta_in_plane_per=1e-3)

# Global rescale; disabled (negative) entries stay negative for a positive multiplier.
if lr_multiplier is not None:
    learning_rates = {name: rate * lr_multiplier for name, rate in learning_rates.items()}
        
# Initial values for the model's variables. The 'galvo_xy' expression (~1.2 per axis, float32)
# is kept verbatim so its float32 rounding is unchanged.
variable_initial_values = {
    'f_mirror': 12.5,
    'f_lens': 30,
    'galvo_xy': .7 * np.ones(2, dtype=np.float32) * 12/7,
    'galvo_normal': np.array((1e-7, 1e-7, -1), dtype=np.float32),
    'galvo_theta': 0,
    'probe_dx': 0,
    'probe_dy': 0,
    'probe_z': -10,
    'probe_normal': np.array((1e-7, 1e-7, -1), dtype=np.float32),
    'probe_theta': 0,
    'd_before_f': .725,
    'effective_inverse_momentum': 11446,
}
reg_coefs = {'first_reflection_RI': 1e-5, 'RI_TV2': 3e-8}  # regularization coefficients

a.create_variables(
    nominal_probe_xy=probe_xy,
    nominal_galvo_xy=galvo_xy,
    propagation_model='parabolic_dome_nonparametric_higher_order_correction',
    learning_rates=learning_rates,
    variable_initial_values=variable_initial_values,
)

dataset = a.generate_dataset()  # dataset yields stratified batches of random A-scans

losses = []  # loss recorded at each optimization iteration
ii = 0  # optimization iteration counter

# Restore parameters from a checkpoint; 'RI' is passed in var_ignore.
if restore_path is not None:
    a.checkpoint_all_variables(path=restore_path, skip_saving=True, var_ignore=['RI'])
    a.restore_all_variables()

In [None]:
# optimization loop:
# Runs iterations ii = 0 .. num_iter inclusive (it breaks after processing ii == num_iter),
# printing the loss every `loss_print_iter` iterations and plotting every `plot_iter`.

# Moving-average window (in iterations) for the "blurred loss history" plot, and the number
# of initial iterations dropped from it.
BLUR_WINDOW = 2000
BLUR_SKIP = 50


def _primary_loss(loss):
    """Return the first term of one recorded loss entry.

    An entry is either a list of per-term losses or a single loss value (see the isinstance
    branch in the loop); np.atleast_1d makes indexing work for both.
    """
    return np.atleast_1d(loss)[0]


def _plot_loss_history(loss_history, skip=BLUR_SKIP, window=BLUR_WINDOW):
    """Plot the raw, log, and moving-average (blurred) loss histories."""
    plt.plot(loss_history)
    plt.title('loss history')
    plt.show()
    plt.plot(np.log(loss_history))
    plt.title('log loss history')
    plt.show()

    primary = np.array([_primary_loss(loss) for loss in loss_history][skip:])
    # np.convolve(..., 'valid') silently swaps its arguments when the kernel is longer than
    # the signal, yielding a meaningless curve; only plot once a full window is available.
    if len(primary) >= window:
        plt.plot(np.convolve(primary, np.ones(window) / window, 'valid'))
        plt.title('blurred loss history')
        plt.show()


for Ascan_batch, batch_inds in dataset:
    if ii == 0:
        # use the average of the first batch of A-scans as the initial guess
        Ascan_background_init = Ascan_batch.numpy().mean(0).mean(0)
        a.train_var_dict['Ascan_background'].assign(Ascan_background_init)

    start = time()

    # Passed to gradient_update; per the config, tf.Variables are only updated after
    # `update_gradient_after` iterations.
    update_gradient = ii > update_gradient_after

    try:  # update variables
        loss_i, recon_i = a.gradient_update(Ascan_batch, batch_inds, update_gradient=update_gradient,
                                            reg_coefs=reg_coefs)
    except ValueError as err:
        # Observed once when update_gradient first switches from False to True. Falling
        # through would record the previous iteration's loss (or hit an undefined loss_i at
        # ii == 0), so report it and retry this iteration with the next batch (ii unchanged).
        print('iteration {}: skipping batch after ValueError: {}'.format(ii, err))
        continue

    if isinstance(loss_i, list):
        losses.append([loss.numpy() for loss in loss_i])
    else:
        losses.append(loss_i.numpy())

    # force RI to be greater than that of background:
    a.train_var_dict['RI'].assign(tf.math.maximum(a.train_var_dict['RI'], a.n_back))

    if ii % loss_print_iter == 0:
        print(ii, losses[-1], time() - start)

    if ii % plot_iter == 0:
        summarize_recon(recon_i.numpy())  # plot cross sections of the reconstruction
        summarize_recon(a.train_var_dict['RI'].numpy(), 'viridis', True)
        _plot_loss_history(losses)

    if ii == num_iter:
        break
    ii += 1
    

In [None]:
# Save the last reconstruction, the optimized 'RI' variable, and the loss history to
# <sample_id>.mat in the working directory.
mat_path = sample_id + '.mat'
results = {
    'recon': recon_i.numpy(),
    'RI': a.train_var_dict['RI'].numpy(),
    'losses': losses,
}
scipy.io.savemat(mat_path, results)