## Instructions
There are three calibration steps (all steps are required):
1. Camera pose and distortion estimation
2. Illumination optimization -- reducing intensity variation
3. "Digital refocusing" -- adjusting the zero-height reference

Under "Specify sample and calibration step", select the sample via `sample_id` and the calibration step (one of the three listed above) via `calibration_step`. Make sure `directory` points to the folder containing that sample's data. Then run the cells under the relevant headers. Each step generates a `.mat` calibration file that is used by the next step, so run the steps in order. Restart the notebook after each step. The final calibration file (`calibration2.mat`) will be used by the training/inference notebook.


In [None]:
import numpy as np
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # restrict GPU usage
import matplotlib.pyplot as plt
import tensorflow as tf
tf.config.experimental.enable_tensor_float_32_execution(False)
from mesoSfM import mesoSfM, monitor_progress
from mcam3d import mcam3d, flatten_illumination, load_stack
from tqdm.notebook import tqdm
import xarray as xr
import scipy.io

## Specify sample and calibration step
Set `sample_id` and `calibration_step`.

In [None]:
# User settings for this run. Both values are validated by the if/elif chains
# in the following cell, which raise on anything unrecognized.
sample_id = 'zebrafish'  # 'zebrafish', 'fruit_flies', or 'harvester_ants'
calibration_step = 'camera_pose'  # should be, in order, one of: 'camera_pose', 'illumination', 'refocusing'

In [None]:
# Folder holding each supported sample's data and calibration files.
sample_directories = {
    'fruit_flies': '/data/fruit_flies/',
    'harvester_ants': '/data/harvester_ants/',
    'zebrafish': '/data/zebrafish/',
}
if sample_id not in sample_directories:
    raise Exception('invalid sample_id: ' + sample_id)
directory = sample_directories[sample_id]

# Per-step settings: which dataset to load, which .mat file provides the
# starting parameters, and where the optimized parameters are written.
# Each step starts from the file saved by the step before it.
if calibration_step == 'camera_pose':
    # calibration dataset of the flat patterned target
    filename = 'calibration_dataset.nc'
    video_frame = None
    binning = 1
    initial_guess_path = '/data/camera_calibration_initial_guess.mat'
    optimize_ONLY_illum = False
    savepath = os.path.join(directory, 'calibration0.mat')
elif calibration_step == 'illumination':
    # raw video dataset
    filename = 'raw_video.nc'
    video_frame = 25
    binning = 2
    # this initial guess is generated by the 'camera_pose' calibration step:
    initial_guess_path = os.path.join(directory, 'calibration0.mat')
    optimize_ONLY_illum = True
    savepath = os.path.join(directory, 'calibration1.mat')
elif calibration_step == 'refocusing':
    # raw video dataset
    filename = 'raw_video.nc'
    video_frame = 25
    binning = 2
    # this initial guess is generated by the 'illumination' calibration step:
    initial_guess_path = os.path.join(directory, 'calibration1.mat')
    binning_ratio = 0.5  # bin ratio relative to the effective binning in initial_guess_path
    savepath = os.path.join(directory, 'calibration2.mat')
else:
    raise Exception('invalid calibration_step: ' + calibration_step)

In [None]:
# load MCAM image data:
im_stack = load_stack(directory, filename, video_frame=video_frame)

magnification = 0.1133
num_channels = im_stack.shape[-1]
inds_keep = np.arange(len(im_stack))
camera_array_dims = np.array([9, 6])
recon_shape = np.array((16000, 13000)) // binning
ul_offset = np.array([6500, 4400]) // binning
optimize_illum = True

if calibration_step == 'illumination':
    # estimate an organism-free background (assuming the organisms move)
    im_stack_new = 0
    for i in tqdm(range(25, 601)):
        im_stack_ = load_stack(directory, filename, video_frame=i)
        im_stack_new = np.maximum(im_stack_new, im_stack_)
    im_stack = im_stack_new
    
# load previously optimized parameters:
restored = scipy.io.loadmat(initial_guess_path)
variable_initial_values = {key:restored[key].squeeze() 
                           for key in restored if '__' not in key}
variable_initial_values['rc'] = variable_initial_values['rc'] / binning

## Camera pose and illumination calibration
If `calibration_step` is `'camera_pose'` or `'illumination'`, run these cells:

In [None]:
# Gradient-descent optimization for the 'camera_pose' and 'illumination' steps.
# Leaves the optimized variables in `variable_initial_values` for the save cell.
if calibration_step == 'refocusing':
    raise Exception('skip these cells; run the digital refocusing cells below')

# the illumination step runs at a smaller scale than the camera-pose step
scale = .05 if calibration_step == 'illumination' else .15
num_iter = 201
# ul_coords is sized from inds_keep so that it always has one row per image
# actually passed in `stack`, even if inds_keep selects a subset of im_stack
a = mcam3d(stack=im_stack[inds_keep], ul_coords=np.zeros((len(inds_keep), 2)),
           recon_shape=recon_shape, ul_offset=ul_offset,
           scale=scale,
           batch_size=None,
           momentum=None,
           report_error_map=True,  # this slows down the gradient_update step;
           )
# NOTE(review): negative learning rates appear to mark a variable as "not
# optimized" -- confirm against the mcam3d implementation
learning_rates = {'camera_focal_length': -1e-3, 'camera_height': 1e-3,
                  'ground_surface_normal': 1e-3, 'camera_in_plane_angle': 1e-3, 'rc': 1,
                  'gain': -1e-3, 'bias': -1e-3, 'radial_camera_distortion': 1e-3, 'illum_flat': -1e-2,
                  'ego_height': -1}
if optimize_illum:
    if optimize_ONLY_illum:
        # set every learning rate to -1; only illum_flat is re-enabled below
        learning_rates = dict.fromkeys(learning_rates, -1)
    learning_rates['illum_flat'] = 1e-2

a.create_variables(deformation_model='camera_parameters',
                   learning_rates=learning_rates,
                   variable_initial_values=variable_initial_values,
                   remove_global_transform=True, antialiasing_filter=True)
stack_downsamp, rc_downsamp = a.generate_dataset()  # when not batching, this is just a one-batch dataset;
reg_coefs = {'local_illum_flat': 1}

losses = list()

for ii in tqdm(range(num_iter)):
    loss_i, recon, normalize, error_map = a.gradient_update(stack_downsamp, rc_downsamp, reg_coefs=reg_coefs)
    # gradient_update may return one loss or a sequence of loss terms
    if isinstance(loss_i, (list, tuple)):
        losses.append([loss.numpy() for loss in loss_i])
    else:
        losses.append(loss_i.numpy())

    if ii % 10 == 0:
        print(ii, 'Loss terms: ' + str(losses[-1]))
    if ii % 100 == 0:
        monitor_progress(recon, error_map, losses)

# replace the initial guess with the optimized variables for the save cell
variable_initial_values = a.get_all_variables()

In [None]:
# save parameters for future use;
save_dict = {v: variable_initial_values[v].numpy() for v in variable_initial_values}
save_dict['inds_keep__'] = inds_keep  # use '__' to make it easier to distinguish between variables and non-var
save_dict['recon_shape__'] = recon_shape
save_dict['ul_offset__'] = ul_offset
save_dict['camera_array_dims__'] = camera_array_dims
save_dict['optimize_illum__'] = optimize_illum
save_dict['binning__'] = binning

scipy.io.savemat(savepath, save_dict)

## "Digital refocusing" step to center the zero-height reference
If `calibration_step` is `'refocusing'`, run these cells:

In [None]:
# Build the mcam3d model used by the digital refocusing line search below.
if calibration_step in ('camera_pose', 'illumination'):
    raise Exception('run the cells above in the previous section')

# Rescale rc into a copy of the restored values rather than in place, so that
# re-running this cell does not apply the binning_ratio rescale a second time.
refocus_initial_values = dict(variable_initial_values)
refocus_initial_values['rc'] = variable_initial_values['rc'] / binning_ratio

scale = .25  # can use scale=1 for bin4
# the antialiasing filter is only disabled when scale is exactly 1
antialiasing_filter = scale != 1
# ul_coords is sized from inds_keep so that it always has one row per image
# actually passed in `stack`, even if inds_keep selects a subset of im_stack
a = mcam3d(stack=im_stack[inds_keep], ul_coords=np.zeros((len(inds_keep), 2)),
           recon_shape=recon_shape, ul_offset=ul_offset,
           scale=scale,
           batch_size=None,
           momentum=None,
           report_error_map=False
           )
# no gradient steps are taken in this section (the line search below calls
# gradient_update with update_gradient=False)
learning_rates = {'camera_focal_length': -1e-3, 'camera_height': 1e-3,
                  'ground_surface_normal': 1e-3, 'camera_in_plane_angle': 1e-3, 'rc': 1,
                  'gain': -1e-3, 'bias': -1e-3, 'radial_camera_distortion': 1e-3, 'illum_flat': -1e-2}

a.create_variables(deformation_model='camera_parameters',
                   learning_rates=learning_rates,
                   variable_initial_values=refocus_initial_values,
                   remove_global_transform=True, antialiasing_filter=antialiasing_filter)
stack_downsamp, rc_downsamp = a.generate_dataset()

In [None]:
# do a line search on global rescaling of coordinates:
losses = list()
s_sweep = np.linspace(.96, 1.04, 45)
# s_sweep = np.linspace(.975, .985, 11)
rc0 = a.rc_ul_per_im.numpy()  # original

for scale_factor in tqdm(s_sweep):
    a.rc_ul_per_im.assign(rc0 * scale_factor)  # change global height
    loss_i, recon, normalize, error_map = a.gradient_update(stack_downsamp, rc_downsamp, update_gradient=False,
                                                            dither_coords=False
                                                           )

    losses.append(loss_i.numpy())
    
#     print(scale_factor, losses[-1])
#     plt.figure(figsize=(15,15))
#     plt.imshow(recon[:,:,:3].numpy().astype(np.uint8))
#     plt.show()
        
a.rc_ul_per_im.assign(rc0)  # restore if you want to run again
plt.plot(s_sweep, losses,'o-')
plt.xlabel('rc scale factor')
plt.ylabel('loss')

In [None]:
# once you're happy with the above result, pick out the scale:
i_best = np.argmin(losses)
rc_scale_best = s_sweep[i_best]
print(rc_scale_best)
a.rc_ul_per_im.assign(rc0 * rc_scale_best)

variable_initial_values = a.get_all_variables()

In [None]:
# save parameters for future use;
save_dict = {v: variable_initial_values[v].numpy() for v in variable_initial_values}
save_dict['inds_keep__'] = inds_keep  # use '__' to make it easier to distinguish between variables and non-var
save_dict['recon_shape__'] = recon_shape
save_dict['ul_offset__'] = ul_offset
save_dict['camera_array_dims__'] = camera_array_dims
save_dict['optimize_illum__'] = optimize_illum
save_dict['binning__'] = binning
save_dict['rc_rescale__'] = rc_scale_best

scipy.io.savemat(savepath, save_dict)