## Instructions
This notebook contains code for training a CNN (in one or two steps) and for generating stitched photometric composites/3D height maps via CNN inference. 
1. Specify the sample and set the optimization round to 1. Make sure `directory` is correct.
2. Run each cell in sequence until the CNN checkpoint file is saved. If a second round is necessary, restart the notebook and set the optimization round to 2. 
3. Run the cells under `Generate and save full reconstructions`, which load the saved CNN checkpoint files and perform CNN inference. This will gradually generate the photometric composite and 3D height map frames.

Each section below contains further details.

In [None]:
import numpy as np
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # restrict GPU usage
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
tf.config.experimental.enable_tensor_float_32_execution(False)
from mcam3d import mcam3d, flatten_illumination, load_stack
from tqdm.notebook import tqdm
import xarray as xr
import scipy.io

## Choose sample
Set `sample_id` and `optimization_round`. Round 1 registers images using the RGB intensities, while round 2 uses normalized high-pass-filtered versions. Restart the notebook in between rounds. Only the terrestrial organisms (fruit flies and harvester ants) need round 2.

In [None]:
# Which dataset to process; the cells below pick the data directory, training time points,
# and hyperparameters based on this value.
sample_id = 'harvester_ants'  # 'zebrafish', 'fruit_flies', or 'harvester_ants'
# Training round; the hyperparameter cell and the checkpoint-saving cell branch on this value.
optimization_round = 1  # 1 or 2

In [None]:
# Per-sample data directory and the video frames used for training.
sample_settings = {
    'fruit_flies': ('/data/fruit_flies/',
                    range(25, 601, 37)),  # to make same length as range(0, 601, 40)
    'harvester_ants': ('/data/harvester_ants/', range(0, 601, 40)),
    'zebrafish': ('/data/zebrafish/', range(0, 601, 40)),
}
if sample_id not in sample_settings:
    raise Exception('invalid sample_id: ' + sample_id)
directory, time_points = sample_settings[sample_id]

# One (video file, calibration file, frame index) entry per training frame;
# every entry points at the same two files and differs only in the frame index.
video_file = os.path.join(directory, 'raw_video.nc')
calibration_file = os.path.join(directory, 'calibration2.mat')
sample_list = [(video_file, calibration_file, t) for t in time_points]

num_dataset = len(sample_list)

In [None]:
# Training hyperparameters, selected by sample type and optimization round.
# Every branch must define the same set of names, because later cells read all of them:
#   height_map_reg_coef, hpf_sigmas, support_constraint_coef, num_iter, ckpt_path,
#   support_constraint_threshold, rejection_threshold
if sample_id == 'fruit_flies' or sample_id == 'harvester_ants':  # terrestrial samples
    if optimization_round == 1:
        height_map_reg_coef = 500
        hpf_sigmas = None
        support_constraint_coef = None
        num_iter = 30001
        ckpt_path = None  # nothing to restore in the first round
    elif optimization_round == 2:
        height_map_reg_coef = 50
        hpf_sigmas = (1, 2, 4)
        support_constraint_coef = 100
        num_iter = 70001
        ckpt_path = os.path.join(directory, 'CNN_round1')  # start from the round-1 network
    else:
        # fail here rather than with a NameError in a later cell
        raise ValueError('optimization_round must be 1 or 2, got ' + repr(optimization_round))
    support_constraint_threshold = 130
    rejection_threshold = 130
elif sample_id == 'zebrafish':  # aquatic samples
    assert optimization_round == 1, 'zebrafish uses only one optimization round'
    height_map_reg_coef = 50
    hpf_sigmas = None
    support_constraint_coef = 100
    num_iter = 70001
    support_constraint_threshold = 185
    rejection_threshold = None
    ckpt_path = None  # was misspelled `ckt_path`, which left `ckpt_path` undefined for zebrafish
else:
    raise ValueError('invalid sample_id: ' + str(sample_id))

## CNN-based stitching

In [None]:
def generate_model_and_dataset(im_stack_inds_keep, recon_shape, ul_offset, downsamp, filters_list, 
                               variable_initial_values, laplace_sigmas=None, gaussian_only=False,
                               use_neighboring_cameras=True, skip_list=None, hpf_sigmas=None,
                               num_patch=2, patch_size=576, sample_margin=None,  # for making dataset
                               visitation_log_margins=None, connectivity=np.ones((1, 3)),
                               rejection_threshold=None,  # if None, don't do anything; otherwise, reject batches
                               # based on whether it samples a region that meets this threshold
                               architecture='fcnn',
                              ):
    """Build an `mcam3d` model for one image stack together with its patched dataset.

    Steps, in order:
      1. instantiate `mcam3d` at a reduced scale (`.01*downsamp*binning`) and create its
         variables from `variable_initial_values`;
      2. run one `gradient_update` with `update_gradient=False` to obtain a low-resolution
         reconstruction and the tracked tensors, and plot that reconstruction;
      3. define the CNN and camera parameters;
      4. generate the visitation log (on the CPU);
      5. optionally expand the stack channels with neighboring cameras and postpend
         filtered channels used for registration;
      6. generate the patched dataset.

    Besides its parameters, this function reads the module-level globals `inds_keep`,
    `binning` and `camera_array_dims`; they must be set before it is called.

    Args:
        im_stack_inds_keep: image stack passed to `mcam3d` as `stack` and to the
            channel-postpending step.
        recon_shape, ul_offset: forwarded unchanged to `mcam3d`.
        downsamp: factor multiplied into the low-resolution scale.
        filters_list: forwarded to `define_network_and_camera_params`.
        variable_initial_values: forwarded to `create_variables`.
        laplace_sigmas: if not None, sigmas for the postpended registration channels
            (used together with `gaussian_only`). Mutually exclusive with `hpf_sigmas`.
        gaussian_only: forwarded with `laplace_sigmas`.
        use_neighboring_cameras: if true, call `expand_stack_channels_for_CNN`.
        skip_list: forwarded to `define_network_and_camera_params`; defaults to a list
            of zeros with the same length as `filters_list`.
        hpf_sigmas: if not None, sigmas for the postpended registration channels with
            `use_hpf_norm=True`. Mutually exclusive with `laplace_sigmas`.
        num_patch, patch_size, sample_margin: forwarded to `generate_patched_dataset`
            (`patch_size*2` is passed as its third positional argument).
        visitation_log_margins: forwarded to `generate_visitation_log` as `camera_margins`.
        connectivity: forwarded to `expand_stack_channels_for_CNN`; its `.size` is passed
            to the network definition as `num_inputs_to_expanded_stack`.
        rejection_threshold: if not None, pixels whose channel-1 value in the
            low-resolution reconstruction is below this threshold are marked as good
            regions and passed to `generate_patched_dataset`.
        architecture: forwarded to `define_network_and_camera_params`.

    Returns:
        Tuple `(a, dataset, visitation_log_vars)`: the `mcam3d` object, the result of
        `generate_patched_dataset`, and the result of `generate_visitation_log`.

    Raises:
        AssertionError: if both `laplace_sigmas` and `hpf_sigmas` are given.

    NOTE(review): when `connectivity` is None, `num_inputs_to_expanded_stack` is never
    assigned, so the `define_network_and_camera_params` call fails with an unbound-local
    error. Confirm the intended value for that case before allowing `connectivity=None`.
    """
    
    # create visitation_log at lower scale:
    # NOTE(review): `inds_keep` and `binning` are globals here, not parameters.
    a = mcam3d(stack=im_stack_inds_keep, ul_coords=np.zeros((len(inds_keep), 2)),
                recon_shape=recon_shape, ul_offset=ul_offset,
                scale=.01*downsamp*binning,
                batch_size=None,
                momentum=None,
                report_error_map=False,
               )
    a.use_camera_calibration = True
    # presumably fixed optical constants of the imaging system -- TODO confirm source/units
    a.effective_focal_length_mm = 25.05486
    a.magnification_j = 0.1133
    a.create_variables(deformation_model='camera_parameters_perspective_to_orthographic',
        learning_rates={'camera_focal_length': -1e-3, 'camera_height': 1e-3, 'ground_surface_normal': 1e-3,
                       'camera_in_plane_angle': 1e-3, 'rc': .1, 'gain': -1e-3, 'bias': -1e-3, 'ego_height': 1e-3,
                       'radial_camera_distortion': 1e-3},
                       variable_initial_values=variable_initial_values,
                       remove_global_transform=True, antialiasing_filter=False)
    stack_downsamp, rc_downsamp = a.generate_dataset()

    # single pass without a gradient step: only the reconstruction and tracked tensors are needed
    loss_i, recon, normalize, error_map, tracked = a.gradient_update(stack_downsamp, rc_downsamp,
                                                                     return_tracked_tensors=True, 
                                                                     update_gradient=False)

    # plot recon:
    plt.figure(figsize=(8,5))
    plt.subplot(121)
    plt.imshow(np.uint8(recon[:,:,1]))
    plt.title('low resolution recon')
    if rejection_threshold is not None:
        # pixels below the threshold in channel 1 count as "good" sampling regions
        good_regions = recon[:,:,1] < rejection_threshold
        plt.subplot(122)
        plt.imshow(good_regions)
        plt.title('good regions')
    else:
        good_regions = None
    plt.show()
    
    a.recompute_CNN = True  # to save memory

    # define network:
    a.unet_scale = .0001
    if skip_list is None:
        skip_list = [0]*len(filters_list)
    if connectivity is not None:
        num_inputs_to_expanded_stack = connectivity.size
    # NOTE(review): `num_inputs_to_expanded_stack` is unbound below if connectivity is None
    a.define_network_and_camera_params(tracked['vanish_warp'], tracked['camera_to_vanish_point_xyz'], 
                                       num_channels_rgb=3,
                                       architecture=architecture,
                                       filters_list=filters_list, 
                                       skip_list=skip_list, 
                                       num_inputs_to_expanded_stack=num_inputs_to_expanded_stack,
                                       learning_rate=1e-3)

    # generate visitation log:
    with tf.device('/CPU:0'):
        visitation_log_vars = a.generate_visitation_log(camera_margins=visitation_log_margins)

    # to allow CNN access to neighboring cameras:
    # NOTE(review): `camera_array_dims` is a global here, not a parameter.
    if use_neighboring_cameras:
        a.expand_stack_channels_for_CNN(connectivity=connectivity, camera_array_dims=camera_array_dims)
        
    # postpend channels that will be used for stitching instead of the CNN inputs (or a subset thereof):
    if laplace_sigmas is not None:
        assert hpf_sigmas is None
        a.postpend_edge_filtered_channels_for_registration(im_stack_inds_keep, sigmas=laplace_sigmas,
                                                           gaussian_only=gaussian_only
                                                          )
    if hpf_sigmas is not None:
        assert laplace_sigmas is None
        a.postpend_edge_filtered_channels_for_registration(im_stack_inds_keep, sigmas=hpf_sigmas, 
                                                           use_hpf_norm=True)
        
    # generate dataset
    a.visitation_log_scale = a.scale
    dataset = a.generate_patched_dataset(num_patch, patch_size, patch_size*2, 
                                         fracture_big_tensors=True, sample_margin=sample_margin,
                                         good_regions=good_regions
                                        )
        
    return a, dataset, visitation_log_vars

In [None]:
# Build one patched dataset per training frame, then interleave them into a single stream.
dataset_list = []  # one entry per frame; merged after the loop
downsamp = 1
for j, (nc_path, calibration_path, video_frame) in tqdm(enumerate(sample_list), total=len(sample_list)):
    # image stack for this video frame:
    im_stack = load_stack(nc_path, '', video_frame=video_frame)

    # calibration file: keys without '__' are precalibrated variables,
    # keys containing '__' are settings read individually below
    restored = scipy.io.loadmat(calibration_path)
    variable_initial_values = {name: value.squeeze()
                               for name, value in restored.items() if '__' not in name}
    variable_initial_values['rc'] = variable_initial_values['rc'] / downsamp

    inds_keep, recon_shape, ul_offset, camera_array_dims = (
        restored[name].flatten()
        for name in ('inds_keep__', 'recon_shape__', 'ul_offset__', 'camera_array_dims__'))
    optimize_illum = restored['optimize_illum__']
    binning = restored['binning__'].squeeze() if 'binning__' in restored else 1

    if downsamp > 1:
        im_stack = im_stack[:, ::downsamp, ::downsamp, :]
        recon_shape, ul_offset = recon_shape // downsamp, ul_offset // downsamp

    # keep only the selected cameras, flattening the illumination first when requested:
    im_stack_inds_keep = (
        flatten_illumination(im_stack, inds_keep, variable_initial_values['illum_flat'])
        if optimize_illum else im_stack[inds_keep])

    sample_margin = [0, 0, .07, .07]
    visitation_log_margins = [.081, 0]  # for individual camera FOVs
    print('Bin: ' + str(binning))

    # network depth (before the extra layer for harvester ants), patch size and
    # patches per batch, all chosen by binning:
    if binning == 1:
        num_layers, patch_size, num_patch = 5, 1024, 1
    elif binning == 2:
        num_layers, patch_size, num_patch = 4, 768, 2
    elif binning == 4:
        num_layers, patch_size, num_patch = 3, 384, 8
    else:
        raise Exception('binning must be 1, 2,  or 4')
    if 'harvester_ants' in nc_path:
        num_layers += 1  # harvester ants use one more layer than the other samples
    filters_list = np.array([32] * num_layers)
    skip_list = np.array([0] * len(filters_list))

    laplace_sigmas = None
    assert laplace_sigmas is None or hpf_sigmas is None  # only one of these can be chosen

    with tf.device('/CPU:0'):
        a, dataset, visitation_log_vars = generate_model_and_dataset(
            im_stack_inds_keep, recon_shape, ul_offset, downsamp, filters_list,
            variable_initial_values,
            laplace_sigmas=laplace_sigmas,
            hpf_sigmas=hpf_sigmas,
            gaussian_only=False,
            skip_list=skip_list,
            num_patch=num_patch, patch_size=patch_size,
            sample_margin=sample_margin,
            visitation_log_margins=visitation_log_margins,
            rejection_threshold=rejection_threshold,
        )
    dataset_list.append(dataset)

# visitation log of the last frame processed:
plt.imshow((visitation_log_vars[0].numpy()>0).sum(2))
plt.title('visitation')
plt.show()

# combine all the datasets (i.e., cycle through all of them):
choice_dataset = tf.data.Dataset.range(num_dataset).repeat()
dataset = tf.data.experimental.choose_from_datasets(dataset_list, choice_dataset)

# restore checkpointed network:
if optimization_round == 2:
    a.ckpt = None
    a.checkpoint_all_variables(ckpt_path, skip_saving=True)
    a.restore_all_variables(ckpt_no=0)
    a.ckpt = None

In [None]:
# Training loop: one gradient update per usable batch, with periodic loss plots,
# patch previews, and full reconstructions. Stops after exactly `num_iter` updates.
ii = 0  # number of gradient updates completed so far (always equal to len(losses))
losses = list()
burn_in = 50  # initial updates excluded from the smoothed loss curve
smoothing_window = 2000  # width of the box filter used for the smoothed loss curve

for batch in tqdm(dataset, total=num_iter):
    if any(len(batch[0][i]) < 2 for i in range(num_patch)):
        # there needs to be at least two patches to register!
        continue

    loss_i, recon_i, normalize_i, grads_i, norm_i = a.gradient_update_patch(
        batch, height_map_reg_coef=height_map_reg_coef,
        return_gradients=True,
        clip_gradient_norm=10,
        dither_coords=True,
        downsample_factor=1,
        support_constraint_coef=support_constraint_coef,
        support_constraint_threshold=support_constraint_threshold,
    )
    losses.append([l.numpy() for l in loss_i])
    ii = len(losses)

    if ii % 100 == 0:
        print(ii, 'Loss terms: ' + str(losses[-1]))

    if ii % 2000 == 0 or ii == 1:
        plt.plot(losses)
        plt.title('loss history')
        plt.show()

        # Only smooth once a full window of post-burn-in samples exists. With fewer
        # samples than the kernel, np.convolve(..., 'valid') swaps its operands and
        # returns values that are not a moving average of the loss.
        if ii >= burn_in + smoothing_window:
            if isinstance(loss_i, (list, tuple)):
                loss_history = np.array([loss[0] for loss in losses[burn_in:]])
            else:
                loss_history = np.array(losses[burn_in:]).squeeze()
            box_filter = np.ones(smoothing_window) / smoothing_window
            plt.plot(np.convolve(loss_history, box_filter, 'valid'))
            plt.title('blurred loss history')
            plt.show()

        # preview of the current patch: photometric channels (left), last channel (right)
        plt.subplot(121)
        plt.imshow(np.uint8(recon_i[:,:,0, :-1]))
        plt.subplot(122)
        plt.imshow(recon_i[:,:,0, -1], cmap='jet')
        plt.show()

    if ii % 10000 == 0:
        # generate full reconstruction:
        with tf.device('/CPU:0'):
            if binning == 1:
                margin = 30
            elif binning == 4:
                margin = 6
            elif binning == 2:
                margin = 12
            else:
                raise ValueError('binning must be 1, 2, or 4')

            recon_full, _ = a.generate_full_recon(margin=margin)
            recon_full = recon_full.numpy()
            plt.figure(figsize=(10, 10))
            plt.imshow(recon_full[:, :, :-1].astype(np.uint8))
            plt.title('photometric recon')
            plt.show()
            plt.figure(figsize=(10, 10))
            plt.imshow(recon_full[:, :, -1], cmap='turbo')
            plt.title('height map recon')
            plt.show()

            # fixed crop window, specified in unbinned pixels:
            crop_rows = slice(3000//binning, 5000//binning)
            crop_cols = slice(4500//binning, 6500//binning)

            plt.figure(figsize=(15,15))
            plt.imshow(recon_full[crop_rows, crop_cols, :3].astype(np.uint8))  # RGB: no colormap applies
            plt.title('crop of photometric recon')
            plt.show()

            plt.figure(figsize=(15,15))
            plt.imshow(recon_full[crop_rows, crop_cols, -1], cmap='turbo')
            plt.title('crop of height map recon')
            plt.show()

    # The previous counter was compared before being incremented, so the loop ran
    # num_iter + 1 updates; stop after exactly num_iter instead.
    if ii >= num_iter:
        break


In [None]:
# save CNN checkpoint file in a unique directory:
a.ckpt = None
if optimization_round == 1:
    ckpt_path = os.path.join(directory, 'CNN_round1')
elif optimization_round == 2:
    ckpt_path = os.path.join(directory, 'CNN_round2')
print(ckpt_path)
a.checkpoint_all_variables(path=ckpt_path)

If a second round is needed, restart the notebook after round 1, set `optimization_round = 2`, and run the cells above again. Otherwise, continue to the next step:

## Generate and save full reconstructions
After optimizing the CNN (both rounds if applicable), run inference on the CNN to generate full videos here. If running from a reset notebook, run all cells until `generate_model_and_dataset` is defined.

In [None]:
# Settings for full-video inference: input/output paths, per-sample checkpoint and
# display ranges, per-binning network settings, and the list of frames to reconstruct.
binning = 2
nc_path = os.path.join(directory, 'raw_video.nc')
calibration_path = os.path.join(directory, 'calibration2.mat')
new_directory = os.path.join(directory, '3D_video_frames')  # where to save video frames
video_slice_range = np.arange(601)  # which frames to reconstruct
save_float32_copy = False  # script will save rgb and height maps as 3-channel uint8 images;
# if you want a float32 copy of the height map in mm, set save_float32_copy to True

if sample_id == 'zebrafish':
    ckpt_path = os.path.join(directory, 'CNN_round1')  # trained network
    height_clims = [-1.4, 1.4]  # height range in mm
    color_corr = np.array([1.1284974, 1.1000000, 1.3200001],
                          dtype=np.float32)  # correct the RGB color imbalance
elif sample_id == 'harvester_ants':
    ckpt_path = os.path.join(directory, 'CNN_round2')
    height_clims = [0, 4.5]
    color_corr = np.array([1.3615384, 1.1000000, 1.6225001],
                          dtype=np.float32)
elif sample_id == 'fruit_flies':
    ckpt_path = os.path.join(directory, 'CNN_round2')
    height_clims = [0, 4]
    color_corr = np.array([1.3597223, 1.1000000, 1.6316667],
                          dtype=np.float32)
else:
    # fail here rather than with a NameError in the inference loop
    raise ValueError('invalid sample_id: ' + str(sample_id))

# Network depth (before the extra layer for harvester ants), device, reconstruction
# margin and patch size, all chosen by binning. These values mirror the training cell.
if binning == 1:
    num_layers = 5
    device = '/CPU:0'  # OOM error if GPU not big enough
    margin = 20
    patch_size = 1024
elif binning == 2:
    num_layers = 4
    device = '/GPU:0'
    margin = 10
    patch_size = 768
elif binning == 4:
    num_layers = 3
    device = '/GPU:0'
    margin = 5
    patch_size = 384
else:
    raise ValueError('binning must be 1, 2, or 4')
if 'harvester_ants' in nc_path:
    num_layers += 1  # harvester ants use one more layer than the other samples
filters_list = np.array([32] * num_layers)

num_patch = 1
downsamp = 1
sample_list = [(nc_path, calibration_path, i) for i in video_slice_range]

num_dataset = len(sample_list)

# make new directory to save these images:
os.makedirs(new_directory, exist_ok=True)
savename_base = os.path.join(new_directory, 'stitched_rgb')
height_savename_base = os.path.join(new_directory, 'height_map')

# colormap object:
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
turbo_color = plt.get_cmap('turbo')(np.arange(256))[:, :-1]  # get look up table values (drop alpha)
# linearly interpolate between 256 colors (default behavior doesn't interpolate):
cmap = matplotlib.colors.LinearSegmentedColormap.from_list('turbo_interp', turbo_color, N=2**16)

In [None]:
# almost same script as for generating the tf.datasets, modified to just generate reconstructions:
for j, (nc_path, calibration_path, video_frame) in tqdm(enumerate(sample_list), total=len(sample_list)):
    # load image stack for one frame:
    im_stack = load_stack(nc_path, '', video_frame=video_frame)
    
    if downsamp > 1:
        im_stack = im_stack[:, ::downsamp, ::downsamp, :]
        recon_shape = recon_shape // downsamp
        ul_offset = ul_offset // downsamp
        x_pos = x_pos / downsamp
        y_pos = y_pos / downsamp
    
    # get precalibrated values:
    restored = scipy.io.loadmat(calibration_path)
    variable_initial_values = {key:restored[key].squeeze() 
                               for key in restored if '__' not in key}
    variable_initial_values['rc'] = variable_initial_values['rc'] / downsamp
    
    # get other settings:
    inds_keep = restored['inds_keep__'].flatten()
    recon_shape = restored['recon_shape__'].flatten()
    ul_offset = restored['ul_offset__'].flatten()
    camera_array_dims = restored['camera_array_dims__'].flatten()
    optimize_illum = restored['optimize_illum__']

    if optimize_illum:
        im_stack_inds_keep = flatten_illumination(im_stack, inds_keep, variable_initial_values['illum_flat'])
        print('optimized illumination')
    else:
        im_stack_inds_keep = im_stack[inds_keep]

    with tf.device(device):
        a, dataset, visitation_log_vars = generate_model_and_dataset(im_stack_inds_keep, recon_shape, ul_offset, 
                                                                     downsamp,
                                                                     filters_list,
                                                                     variable_initial_values,
                                                                     num_patch=num_patch, patch_size=patch_size
                                                                    )

        # restore old network:
        a.ckpt = None
        a.checkpoint_all_variables(ckpt_path,
                                   skip_saving=True)
        a.restore_all_variables(ckpt_no=0)
        a.ckpt = None
        
        # generate reconstruction:
        recon_full, normalize = a.generate_full_recon(margin=margin)
        recon_full = recon_full.numpy()
        normalize = normalize.numpy()
        
        if j == 0:
            # keep a fixed color range:
            clims = height_clims
            
            # crop out regions of zeros:
            r_nonzero, c_nonzero = np.nonzero(recon_full[:, :, -1])
            r0 = r_nonzero.min()
            r1 = r_nonzero.max()
            c0 = c_nonzero.min()
            c1 = c_nonzero.max()
            
        # crop margins:
        recon_cropped = recon_full[r0:r1, c0:c1].copy()
        normalize_cropped = normalize[r0:r1, c0:c1].copy()
        
        # generate height map:
        height_map = recon_cropped[:, :, -1]
            
        # names for saving:
        savename = savename_base + '_' + f'{j:04}' + '.png'
        height_savename = height_savename_base + '_' + f'{j:04}' + '.png'
        height_savename_f32 = height_savename_base + '_' + f'{j:04}' + '.mat'
        
        # save height map as float32 before casting down to uint8:
        if save_float32_copy:
            scipy.io.savemat(height_savename_f32, {'height_map': height_map})
        
        # cast down to save as RGB images:
        recon = np.uint8(np.clip(recon_cropped[:, :, :-1] * color_corr[None, None, :], 0, 255))
        height_map = np.clip(height_map, clims[0], clims[1])
        
        # save reconstruction:
        plt.imsave(savename, recon)
        
        # convert height map to RGB for saving:
        height_map -= clims[0]
        height_map /= (clims[1] - clims[0])
        height_map = (255*cmap(height_map)).astype(np.uint8)  # converted to rgb
        plt.imsave(height_savename, height_map)
