In [None]:
import os
# restrict GPU usage; set before TensorFlow (or anything importing it) is loaded, so the
# restriction is in place before CUDA can be initialized
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
tf.config.experimental.enable_tensor_float_32_execution(False)  # default is True, which reduces precision
from time import time
import scipy.io
from ReFLeCT import *
import xarray as xr
import h5py
# kept after the star import so that any `tqdm` name ReFLeCT might export does not shadow the notebook one
from tqdm.notebook import tqdm

if len(tf.config.list_physical_devices('GPU')) == 0:
    print('warning, not using GPU')  # different behavior for out-of-bounds gather_nd between gpu vs cpu

# reconstruction / optimization settings (some are overridden per sample in the next cell):
scale = .5
batch_size = 20000
downsamp = 1
pixel_size = 8  # recon pixel size in um, when scale=1
loss_print_iter = 10  # print loss every this many iters
recon_shape = (512, 512, 512)
inds_keep = np.arange(54)
inds_keep = np.delete(inds_keep, [26, 27])
depth_of_field = 5000  # um
camera_dims = (9, 6)
superimpose_images = False
ignore_recon_from_restore_path = True  # probably always true
weight_background = .01  # for reducing background artifacts
shuffle_buffer = 196608  # for shuffling rays
preshuffle = True
interp_rays = False  # false during final optimization
recon_lr = 10  # learning rate of the reconstruction (set according to optimization_round)
attenuation_map_lr = .5
r_extra_lr = 10000
extra_finetune_order = 15  # specify an integer for extra polynomial modeling post-nmr
dither_rays = True
occupancy_grid = None  # pre-optimized occupancy grid
bias_threshold = 2  # pixel bias in gray levels (out of 255)
recon_filepath = None  # where to save the final reconstructions; if None, a path is generated automatically
use_attenuation_map = True
non_recon_parameters_to_optimize = ['attenuation_map']
plot_iter = 250

data_path = '/data/'

In [None]:
optimize_single_frame = True  # one time point or whole video? run the corresponding cells below

# pick one of the four samples:
sample_id = 'drosophila_muscle'
# sample_id = 'drosophila_pericardial'
# sample_id = 'zebrafish_notochord'
# sample_id = 'zebrafish_heart'


# set sample-specific parameters (attenuation_map_lr keeps its config-cell value unless overridden here):
if sample_id == 'zebrafish_notochord':
    binning = 4
    xyz_offset = np.array([0, 0, -800])  # in um
    recon_shape = (800, 800, 800)
    base_directory = os.path.join(data_path, 'zebrafish_notochord/')
    video_frames = np.arange(1201)
    reg_coefs = {'L1_attenuation_batch': {'coef': 0.018, 'scale': 1000}}
    num_iter = 250
elif sample_id == 'drosophila_pericardial':
    binning = 2
    xyz_offset = np.array([0, 0, -400])  # in um
    recon_shape = (700, 700, 700)
    base_directory = os.path.join(data_path, 'drosophila_pericardial/')
    video_frames = np.arange(301)
    reg_coefs = {'L1_attenuation_batch': {'coef': 0.036, 'scale': 1000}}
    attenuation_map_lr = .1
    num_iter = 750
elif sample_id == 'zebrafish_heart':
    binning = 2
    xyz_offset = np.array([0, 0, 0])  # in um
    recon_shape = (800, 800, 800)
    base_directory = os.path.join(data_path, 'zebrafish_heart/')
    video_frames = np.arange(301)
    reg_coefs = {'L1_attenuation_batch': {'coef': 0.012, 'scale': 1000}}
    num_iter = 400  # 2000
elif sample_id == 'drosophila_muscle':
    binning = 4
    xyz_offset = np.array([0, 0, 0])  # in um
    recon_shape = (700, 700, 700)
    base_directory = os.path.join(data_path, 'drosophila_muscle/')
    video_frames = np.arange(0, 1201)
    reg_coefs = {'L1_attenuation_batch': {'coef': 0.036, 'scale': 1000}}
    attenuation_map_lr = .1
    num_iter = 250
else:
    raise Exception('invalid sample_id')

if shuffle_buffer is not None:
    assert preshuffle  # otherwise, you might get poor shuffling

if optimize_single_frame:
    # only load one frame into memory; otherwise, it'll take a while
    video_frames = np.array([0])  # or pick a different single frame

restore_path = os.path.join(base_directory, 'calibration_parameters.mat')
filepaths = ['raw_video.nc']
# np.load on an .npz returns a lazily-reading archive that holds its file open;
# use it as a context manager so the handle is closed once the grid has been read
with np.load(os.path.join(base_directory, 'occupancy_grid.npz')) as occupancy_npz:
    occupancy_grid = occupancy_npz['occupancy_grid'][video_frames]

In [None]:
stack, stack_list = get_mcam_video_data(base_directory, filepaths, video_frames, superimpose_images)
num_video_frames = len(stack_list)

# remove pixel bias: raise every value to at least bias_threshold, then shift down so the floor is 0
if bias_threshold is not None:
    stack = np.maximum(stack, bias_threshold) - bias_threshold
    for frame_index, frame in enumerate(stack_list):
        stack_list[frame_index] = np.maximum(frame, bias_threshold) - bias_threshold

In [None]:
# plot raw multi-view data for one frame: one panel per image in `stack`, laid out on the
# camera_dims grid (rows, columns) rather than a hard-coded 9 x 6:
plt.figure(figsize=(10,15))
for i, im in enumerate(stack):
    plt.subplot(*camera_dims, i + 1)
    plt.imshow(im)
    plt.axis('off')
    plt.title(str(i))
plt.show()

In [None]:
# nominal (x, y) positions of the 54 cameras in mm, centered on zero; one row per camera,
# ordered with the camera_dims[1]-long coordinate varying fastest:
cam_sep = 13.5  # camera separation in mm
row = np.arange(camera_dims[1]) * cam_sep
col = -np.arange(camera_dims[0]) * cam_sep
row = row - row.mean()
col = col - col.mean()
row, col = np.meshgrid(row, col, indexing='ij')
# (len(row), len(col), 2) -> swap the grid axes -> flatten to (num_cameras, 2)
camera_positions = np.stack([row, col], axis=-1).transpose(1, 0, 2).reshape(-1, 2)
plt.plot(camera_positions[:, 0], camera_positions[:, 1], 'o-')
plt.gca().set_aspect('equal')

In [None]:
# choose the propagation model; optionally append the extra polynomial fine-tuning term.
# NOTE(review): the suffix is appended with no '_' separator ('...nmr_tubeextra_finetune');
# the substring checks in this notebook still match — confirm what ReFLeCT expects before changing it.
propagation_model = 'parabolic_nonparametric_higher_order_correction_nmr_tube'
if extra_finetune_order is not None:
    propagation_model += 'extra_finetune'
nmr_tube_model = 'thick_wall'

# instantiate para mcam:
a = para_mcam(stack,
              recon_shape=recon_shape,
              dxyz=pixel_size,
              scale=scale,
              xyz_offset=xyz_offset,
              batch_size=batch_size,
              batch_across_images=False,
              depth_of_field=depth_of_field,
              interp_rays=interp_rays,
              dither_rays=dither_rays,
              occupancy_grid=occupancy_grid)
if extra_finetune_order is not None:
    a.extra_finetune_order = extra_finetune_order

In [None]:
# per-parameter learning rates; a negative value means the parameter is not optimized:
learning_rates = {'f_mirror': -1e-1, 'f_lens': 1e-2,
                  'galvo_xy': -1e-3, 'galvo_normal': -1e-3, 'galvo_theta': -1e-3,
                  'probe_dx': 1e-2, 'probe_dy': 1e-2, 'probe_z': 1e-2, 'probe_normal': 1e-3,
                  'probe_theta': 1e-3, 'recon': 1, 'per_image_scale': 1e-3, 'per_image_bias': -1e-3,
                 }
if 'nonparametric' in propagation_model:
    learning_rates = {**learning_rates, 'delta_r': -1e-3, 'delta_u': -1e-3, 'r_2nd_order': -1e-3,
                      'u_2nd_order': -1e-3, 'r_higher_order': -1e-3, 'u_higher_order': -1e-3,}
if 'nmr_tube' in propagation_model:
    # freeze everything except the reconstruction and per-image scale, then add the tube parameters:
    for key in learning_rates:
        if key != 'recon' and key != 'per_image_scale':
            learning_rates[key] = -1
    learning_rates = {**learning_rates, 'nmr_outer_radius': 1e-3, 'nmr_inner_radius': 1e-3, 'nmr_delta_r': 1e-3,
                      'nmr_normal': 1e-3, 'nmr_theta': 1e-3, 'n_glass': -1e-3, 'n_medium': -1e-3}
if 'extra_finetune' in propagation_model:
    learning_rates = {**learning_rates, 'r_extra': r_extra_lr, 'u_extra': 1e-3}

if use_attenuation_map:
    learning_rates = {**learning_rates, 'attenuation_map': attenuation_map_lr}

# initial values: from the calibration .mat file if given (skipping loadmat's '__' metadata keys and any
# recon / xyz_shifts / attenuation entries), otherwise nominal values:
if restore_path is not None:
    restored = scipy.io.loadmat(restore_path)
    variable_initial_values = {key: restored[key].squeeze()
                               for key in restored
                               if '__' not in key and 'recon' not in key and 'xyz_shifts' not in key
                               and 'attenuation' not in key}
    if 'recon' in restored and not ignore_recon_from_restore_path:
        recon = restored['recon']
    else:
        recon = None
else:
    variable_initial_values = {'f_mirror': 25.4, 'f_lens': 26.23, 'galvo_xy': 4.224/2,
                               'galvo_normal': np.array((1e-7, 1e-7, -1), dtype=np.float32),
                               'galvo_theta': 0, 'probe_dx': 0, 'probe_dy': 0,  'probe_z': 25,
                               'probe_normal': np.array((1e-7, -1), dtype=np.float32),
                               'probe_theta': 0,
                              }
    recon = None

# seed nominal NMR-tube radii whenever they are not available from a restore file -- including the case
# where there is no restore file at all (previously the radii were then never initialized):
if ('nmr_tube' in propagation_model and nmr_tube_model is not None
        and (restore_path is None or 'nmr_outer_radius' not in restored)):
    if nmr_tube_model == 'thick_wall':
        variable_initial_values['nmr_outer_radius'] = 4.9635 / 2
        variable_initial_values['nmr_inner_radius'] = 3.43 / 2
    elif nmr_tube_model == 'thin_wall':
        variable_initial_values['nmr_outer_radius'] = 4.9635 / 2
        variable_initial_values['nmr_inner_radius'] = 4.2065 / 2
    else:
        raise ValueError('invalid nmr_tube_model: ' + str(nmr_tube_model))

if recon_lr is not None:
    learning_rates['recon'] = recon_lr
else:
    print('warning: recon_lr not set')

# final adjustment: only recon-related keys and those listed in non_recon_parameters_to_optimize
# stay trainable; everything else is frozen (-1):
for key in learning_rates:
    if key not in non_recon_parameters_to_optimize and 'recon' not in key:
        learning_rates[key] = -1

a.create_variables(nominal_probe_xy=camera_positions, inds_keep=inds_keep, propagation_model=propagation_model,
                   learning_rates=learning_rates, variable_initial_values=variable_initial_values,
                   stack_downsample_factor=downsamp, recon=recon, use_attenuation_map=use_attenuation_map
                  )

# build one dataset per video frame on the CPU; the loop variable is deliberately not named `stack`
# so the global `stack` used by earlier cells is not overwritten:
with tf.device('/cpu:0'):
    dataset_list = list()
    for i, frame_stack in tqdm(enumerate(stack_list), total=num_video_frames):
        # with multiple video frames, each dataset needs a unique identifier
        identifier = i if num_video_frames > 1 else None
        dataset_list.append(a.generate_dataset(a.format_stack(frame_stack), identifier=identifier,
                                               preshuffle=preshuffle, shuffle_buffer=shuffle_buffer, seed=i))

losses = list()
variables = list()
track_list = [key for key in learning_rates if learning_rates[key] > 0 and 'recon' not in key]

# iteration counter for the single-frame loop; that cell does not reset it, so re-running it
# continues counting from the current value
ii = 0

for key in learning_rates:
    if learning_rates[key] > 0:
        print(key, learning_rates[key])
print(reg_coefs)

## Single-frame optimization
Reconstruct the volume at a single time point. Projections and cross-sections are plotted periodically during optimization.

In [None]:
assert len(dataset_list) == 1
assert optimize_single_frame

# optimization loop for a single volume. `ii` is set in the setup cell and is not reset here, so
# re-running this cell continues from its current value; the loop applies iterations ii..num_iter
# (inclusive), which is what the progress-bar total reflects.
for batch in tqdm(dataset_list[0], total=max(num_iter - ii, 0) + 1):
    start = time()

    loss_i, recon_i, tracked = a.gradient_update(batch, return_tracked_tensors=True, reg_coefs=reg_coefs,
                                                 dataset_index=None, weight_background=weight_background
                                                )

    # gradient_update may return either a single loss tensor or a list of loss terms:
    if isinstance(loss_i, list):
        losses.append([loss.numpy() for loss in loss_i])
    else:
        losses.append(loss_i.numpy())

    if ii % loss_print_iter == 0:
        # snapshot all trainable variables except the two volumes:
        variables.append({key: a.train_var_dict[key].numpy() for key in a.train_var_dict
                          if key not in ['recon', 'attenuation_map']})
        print(ii, losses[-1], time() - start)

    if ii % plot_iter == 0:
        summarize_recon(recon_i[..., 0].numpy())  # plot cross sections of the reconstruction
        if use_attenuation_map:
            summarize_recon(a.train_var_dict['attenuation_map'].numpy())

        plt.figure(figsize=(13,4))
        plt.subplot(121)
        plt.plot(losses)
        plt.title('loss history')
        plt.legend(a.loss_list_names)
        plt.subplot(122)
        plt.plot(np.log(losses))
        plt.title('log loss history')
        plt.legend(a.loss_list_names)
        plt.show()

    # `>=` rather than `==`: if num_iter is lowered below the current ii between runs,
    # the loop still stops instead of running until the dataset is exhausted
    if ii >= num_iter:
        break
    ii += 1

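## Frame-by-frame video reconstructions
Run this section after experimenting with the single-frame section above. First set `optimize_single_frame = False` and re-run the setup cells, so that all video frames are loaded instead of only one.
The code below reconstructs one time point at a time and saves the two-channel (fluorescence and attenuation) reconstruction as an HDF5 file for each time point.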

In [None]:
assert not optimize_single_frame

# folder for saving all reconstructions; if it doesn't exist, create it
# (os.makedirs also creates missing parent folders of a user-supplied recon_filepath):
if recon_filepath is None:
    save_directory = os.path.join(base_directory, sample_id + '_4D_reconstructions')
else:
    save_directory = recon_filepath
if not os.path.exists(save_directory):
    os.makedirs(save_directory)
    print('made path: ' + save_directory)

# shape of each hdf5 volume: (channel, x, y, z), channels = (fluorescence, attenuation):
hdf5_shape = (2,) + tuple(int(dim * scale) for dim in recon_shape)
print(hdf5_shape)
hdf5_filename_base = 'recon'
compression_level = 1  # for gzip compression

In [None]:
# the two-channel output and the per-frame reset below both read the 'attenuation_map' entries of
# a.train_var_dict / a.optimizer_dict, so fail here instead of after the first frame's optimization:
assert use_attenuation_map, 'frame-by-frame saving requires use_attenuation_map = True'

for i_dataset, dataset in tqdm(zip(video_frames, dataset_list), total=len(dataset_list)):

    losses = list()
    ii = 0
    # inner loop is a reduced version of the single-frame optimization above;
    # it applies num_iter + 1 updates (ii = 0 .. num_iter):
    for batch in tqdm(dataset, total=num_iter + 1):

        loss_i, recon_i, tracked = a.gradient_update(batch, return_tracked_tensors=True, reg_coefs=reg_coefs,
                                                     dataset_index=None, weight_background=weight_background
                                                    )

        # gradient_update may return either a single loss tensor or a list of loss terms:
        if isinstance(loss_i, list):
            losses.append([loss.numpy() for loss in loss_i])
        else:
            losses.append(loss_i.numpy())

        if ii >= num_iter:
            break
        ii += 1

    # open, save, and close:
    save_start = time()
    hdf5_path = os.path.join(save_directory, hdf5_filename_base + '_' + str(i_dataset) + '.hdf5')
    with h5py.File(hdf5_path, 'w') as f:
        # save 2-channel 3D volume (channel 0: fluorescence, channel 1: attenuation):
        fluorescence = recon_i[..., 0].numpy()
        attenuation = a.train_var_dict['attenuation_map'].numpy()
        recon = np.stack([fluorescence, attenuation], axis=0)
        hdf5_dataset = f.create_dataset('cxyz', hdf5_shape, dtype='float32', data=recon,
                                        compression='gzip', compression_opts=compression_level)

        # save max projections along each spatial axis too, for convenience:
        hdf5_cxy = f.create_dataset('cxy', (2, hdf5_shape[1], hdf5_shape[2]), dtype='float32', data=recon.max(3))
        hdf5_cxz = f.create_dataset('cxz', (2, hdf5_shape[1], hdf5_shape[3]), dtype='float32', data=recon.max(2))
        hdf5_cyz = f.create_dataset('cyz', (2, hdf5_shape[2], hdf5_shape[3]), dtype='float32', data=recon.max(1))
    print('Time to save: ' + str(time() - save_start) + ' sec')

    # summarize results for every 10th frame index:
    if i_dataset % 10 == 0:
        summarize_recon(fluorescence)  # plot cross sections of the reconstruction
        summarize_recon(attenuation)

        plt.plot(losses)
        plt.title('loss history')
        plt.legend(a.loss_list_names)
        plt.show()
        plt.plot(np.log(losses))
        plt.title('log loss history')
        plt.legend(a.loss_list_names)
        plt.show()

    # reset the optimizer state and zero both volumes so the next frame is optimized from zeros:
    for var_name in ('recon', 'attenuation_map'):
        for var in a.optimizer_dict[var_name].variables():
            var.assign(tf.zeros_like(var))
        a.train_var_dict[var_name].assign(tf.zeros_like(a.train_var_dict[var_name]))