## Overview
There are 3 rounds of optimization at different resolutions. The first two rounds are common to all methods. There are two separate sections for round 3, depending on whether to regularize with CNN reparameterization or total variation.

In [None]:
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # make only one GPU visible if multiple are present
os.environ['TF_GPU_HOST_MEM_LIMIT_IN_MB'] = '32000'  # for tensorflow LMS
import matplotlib.pyplot as plt
import cv2
from time import time
import tensorflow as tf
from mesoSfM import mesoSfM, stack_loader_phone, monitor_progress, xcorr_initial_guess
from tqdm.notebook import tqdm
from datetime import date

# turn on TensorFlow Large Model Support (LMS) together with its memory defragmentation:
tf.config.experimental.set_lms_enabled(True)
tf.config.experimental.set_lms_defrag_enabled(True)


def set_seed(seed=0):
    """Seed the global NumPy and TensorFlow random number generators.

    Nondeterministic GPU operations are not covered by this, so runs may still differ.
    """
    np.random.seed(seed)
    tf.random.set_seed(seed)


set_seed(10)

## Specify dataset and hyperparameters
Specify the dataset and distortion method in the next cell.

In [None]:
# which dataset to reconstruct -- one of: cut_cards, painting, helicopter_seeds, PCB, tuning_set
dataset = 'cut_cards'
# lens distortion model -- 'radial' or 'piecewise_linear'
distort_method = 'piecewise_linear'
# number of radial terms, used only by the 'radial' model; the highest power is 2 * num_radial_terms
num_radial_terms = 2

In [None]:
directory = './data/' + dataset + '/'
im_stack = stack_loader_phone(directory)  # every image of the dataset is loaded into memory
num_channels = im_stack.shape[-1]
inds_keep = np.arange(len(im_stack))  # indices of the images to use (by default, all of them)
recon_shape = np.array((5500, 6500))  # reconstruction shape in pixels, before dividing by pre_downsample_factor below
dither_coords = True

# Initial camera positions: datasets that are difficult to register get a sequential
# cross-correlation-based estimate; every other dataset starts from all-zero positions.
if dataset in ('helicopter_seeds', 'tuning_set', 'PCB'):
    x_pos, y_pos = xcorr_initial_guess(im_stack)
elif dataset == 'cut_cards':
    x_pos, y_pos = xcorr_initial_guess(im_stack, crop_frac=.2)
else:
    y_pos = np.zeros(len(im_stack))
    x_pos = np.zeros(len(im_stack))

# global xy offset in the reconstruction in pixels (larger for PCB):
ul_offset = np.array([1400, 1400]) if 'PCB' in directory else np.array([1100, 1100])

# Magnifications pre-calibrated using control points. The entries are checked in order and the
# first dataset name that occurs in `directory` determines the magnification.
for dataset_key, calibrated_magnification in (('cut_cards', 0.069315160567588),
                                               ('tuning_set', 0.0835),
                                               ('painting', 0.058971786873920),
                                               ('PCB', 0.076508519683842),
                                               ('helicopter_seeds', 0.076790599352193)):
    if dataset_key in directory:
        magnification = calibrated_magnification
        break
else:
    raise Exception('magnification not defined')

# the images were already downsampled by 2x, so indicate that here:
pre_downsample_factor = 2

recon_shape //= pre_downsample_factor
ul_offset //= pre_downsample_factor

## Round 1 of optimization
First, optimize at a low resolution (heavy downsampling), updating only the xy shifts and keeping everything else fixed.

In [None]:
set_seed(1)

# round 1 works on heavily downsampled images:
a = mesoSfM(stack=im_stack[inds_keep],
            ul_coords=np.stack([y_pos[inds_keep], x_pos[inds_keep]]).T,
            recon_shape=recon_shape,
            ul_offset=ul_offset,
            scale=.05*pre_downsample_factor)

# Only 'rc' gets a positive learning rate; per the section header only the xy shifts are updated
# in this round, so the negative values presumably keep the other parameters fixed.
a.create_variables(
    deformation_model='camera_parameters_perspective_to_orthographic',
    learning_rates={
        'camera_focal_length': -1e-3,
        'camera_height': -1e-3,
        'ground_surface_normal': -1e-3,
        'camera_in_plane_angle': -1e-3,
        'rc': 10,
        'gain': -1e-3,
        'bias': -1e-3,
        'ego_height': -1e-3,
    },
    variable_initial_values=None,
    remove_global_transform=True,
    antialiasing_filter=True,
)

stack_downsamp, rc_downsamp = a.generate_dataset()

# optimization loop: log every 10 iterations, show progress plots every 100:
losses = []
for ii in tqdm(range(401)):
    start = time()
    loss_i, recon, normalize, error_map = a.gradient_update(stack_downsamp, rc_downsamp)
    losses.append([loss.numpy() for loss in loss_i] if type(loss_i) is list else loss_i.numpy())

    if ii % 10 == 0:
        print(ii, losses[-1], time() - start)
    if ii % 100 == 0:
        monitor_progress(recon, error_map, losses)

# the optimized variables initialize the next round:
variable_initial_values = a.get_all_variables()

## Round 2 of optimization
Next, optimize at a higher resolution, updating all parameters, including distortion, except for the height map.

In [None]:
set_seed(2)

# learning rates for round 2 (as in round 1, negative values presumably keep a parameter fixed):
learning_rates = {
    'camera_focal_length': -1e-3,
    'camera_height': 1e-3,
    'ground_surface_normal': 1e-3,
    'camera_in_plane_angle': 1e-3,
    'rc': .1,
    'gain': -1e-3,
    'bias': -1e-3,
    'ego_height': -1e-3,
}

# the distortion model enters in this round; add its initial values and learning rates:
if distort_method == 'piecewise_linear':
    variable_initial_values['camera_distortion_center'] = np.zeros([1, 2]) + 1e-7
    variable_initial_values['radial_camera_distortion_piecewise_linear'] = np.zeros(50)
    learning_rates.update({
        'radial_camera_distortion_piecewise_linear': 1e-3,
        'camera_distortion_center': 1e-4,
    })
elif distort_method == 'radial':
    variable_initial_values['camera_distortion_center'] = np.zeros([1, 2]) + 1e-7
    variable_initial_values['radial_camera_distortion'] = np.zeros([1, num_radial_terms])
    learning_rates.update({
        'camera_distortion_center': 1e-4,
        'radial_camera_distortion': 1e-2,
    })
else:
    raise Exception('invalid distort_method')

# round 2 works at a higher resolution than round 1:
a = mesoSfM(stack=im_stack[inds_keep],
            ul_coords=np.stack([y_pos[inds_keep], x_pos[inds_keep]]).T,
            recon_shape=recon_shape,
            ul_offset=ul_offset,
            scale=.15*pre_downsample_factor)

a.create_variables(
    deformation_model='camera_parameters_perspective_to_orthographic',
    learning_rates=learning_rates,
    variable_initial_values=variable_initial_values,
    remove_global_transform=True,
    antialiasing_filter=True,
)

# without batching, this is a single-batch dataset:
stack_downsamp, rc_downsamp = a.generate_dataset()

# optimization loop: log every 10 iterations, show progress plots every 100:
losses = []
for ii in tqdm(range(401)):
    start = time()
    loss_i, recon, normalize, error_map = a.gradient_update(stack_downsamp, rc_downsamp)
    losses.append([loss.numpy() for loss in loss_i] if type(loss_i) is list else loss_i.numpy())
    if ii % 10 == 0:
        print(ii, losses[-1], time() - start)

    if ii % 100 == 0:
        monitor_progress(recon, error_map, losses)
        if len(losses) > 201:
            # zoom in on the later part of the loss curve:
            plt.plot(losses[200:])
            plt.show()

# the optimized variables initialize the next round:
variable_initial_values = a.get_all_variables()
# name -> index of each trainable variable, used to look up its optimizer
# (e.g. to anneal a learning rate in the CNN round):
variable_indices = {v.name: idx for idx, v in enumerate(a.train_var_list)}

## Two options from this point
Now that we've gotten a decent optimized result using only camera parameters, we're ready to go to the final resolution AND optimize the height map. We have two options for round 3 for regularizing the height map: 1) reparameterization with a convolutional neural network (CNN), and 2) total variation (TV). Run only the cells corresponding to the desired regularization -- these cells are separated by different headers:
- Round 3 of optimization with total variation
- Round 3 of optimization with CNN reparameterization

## Round 3 of optimization with CNN reparameterization
To test out different CNN architectures, modify `CNN_architecture` below, which is a list of filter numbers (see supplementary document of paper). The main one used for the paper was `[16, 16, 16, 32, 32]`, but other ones tested were: `[16, 16, 32, 32]`, `[16, 16, 16, 16]`, and `[16, 16, 16, 16, 16]`. This round will take up the bulk of the total processing time.

In [None]:
CNN_architecture = [16, 16, 16, 32, 32]  # number of filters per layer
reg_coefs = {'height_map': .1}
num_iter = 10001

# disabling the graph layout optimizer might save some memory:
tf.config.optimizer.set_experimental_options({'layout_optimizer': False})

# learning rates for all parameters, including the height map ('ego_height'):
learning_rates = {
    'camera_focal_length': -1e-3,
    'camera_height': 1e-3,
    'ground_surface_normal': 1e-3,
    'camera_in_plane_angle': 1e-3,
    'rc': .1,
    'gain': -1e-3,
    'bias': -1e-3,
    'ego_height': 1e-3,
}
if distort_method == 'piecewise_linear':
    learning_rates.update({
        'radial_camera_distortion_piecewise_linear': 1e-3,
        'camera_distortion_center': 1e-4,
    })
elif distort_method == 'radial':
    learning_rates.update({
        'camera_distortion_center': 1e-4,
        'radial_camera_distortion': 1e-2,
    })
else:
    raise Exception('invalid distort_method')

# Divergence criterion for rolling back to a previous checkpoint. With the poorer radial
# distortion model the height maps are less consistent, so the threshold is doubled.
height_map_reg_threshold = 8 * (2 if distort_method == 'radial' else 1)

In [None]:
seed = 5
loss_num = 1  # index of the height-map regularization term within each list of losses
while True:
    # Sometimes the optimization unpredictably encounters a NaN, or it doesn't start converging;
    # if this happens, start over with a different seed. The try-except inside the while loop
    # keeps retrying until one full optimization loop completes.
    # NOTE(review): any ValueError raised inside the try block (not only the ones raised
    # deliberately below) triggers a restart.
    try:
        set_seed(seed)

        # create mesoSfM object with particular settings:
        a = mesoSfM(stack=im_stack[inds_keep], ul_coords=np.stack([y_pos[inds_keep], x_pos[inds_keep]]).T,
                    recon_shape=recon_shape, ul_offset=ul_offset, scale=.5*pre_downsample_factor,
                    batch_size=6, momentum=.5, batch_across_images=True)

        # these are for getting absolute height values:
        a.use_absolute_scale_calibration = True
        a.effective_focal_length_mm = 4.3
        a.magnification_j = magnification

        # CNN settings:
        a.filters_list = CNN_architecture
        a.skip_list = [0] * len(CNN_architecture)  # no skip connections
        a.save_iter = 50  # checkpoint parameters every this many iterations
        a.recompute_CNN = True  # gradient checkpointing to reduce memory usage

        # create tf Variables:
        a.create_variables(deformation_model='camera_parameters_perspective_to_orthographic_unet',
                           learning_rates=learning_rates, variable_initial_values=variable_initial_values,
                           remove_global_transform=True, antialiasing_filter=False)

        # Map variable names to optimizer indices for THIS object. The round-2 mapping belonged to a
        # different object with different learning rates ('ego_height' was not trained there).
        # Like round 2, this assumes a.optimizer_list is ordered like a.train_var_list.
        variable_indices = {var.name: i for i, var in enumerate(a.train_var_list)}

        # create dataset; deliberately not called `dataset`, which holds the dataset name chosen
        # at the top of the notebook:
        batches = a.generate_dataset()

        # optimization loop:
        ii = 0
        losses = list()
        for stack_downsamp, rc_downsamp in batches:
            start = time()

            loss_i = a.gradient_update(stack_downsamp, rc_downsamp, dither_coords=dither_coords,
                                       reg_coefs=reg_coefs, return_loss_only=True)
            if type(loss_i) is list:
                losses.append([loss.numpy() for loss in loss_i])
            else:
                losses.append(loss_i.numpy())

            if ii % 250 == 0:
                print(ii, losses[-1], time() - start)

            # check for NaN every iteration, including the first save_iter iterations, so
            # that NaN values are neither trained on nor checkpointed:
            if np.isnan(losses[-1]).any():
                seed += 1  # alter the seed to avoid running into the same issue
                raise ValueError('NaN encountered during optimization')

            # conditions for rolling back to an earlier checkpoint:
            last_few_losses = [loss[loss_num] for loss in losses[-10:-1]]
            roll_back_condition = losses[-1][loss_num] > height_map_reg_threshold

            if len(losses) <= a.save_iter:
                # don't do anything early on; let the optimizer explore a bit while checkpoints are accumulating
                pass

            elif roll_back_condition and len(losses) <= 2*a.save_iter:
                # now that the optimizer has explored a bit, if it's too far off track early on,
                # it may never find its way -- just restart; this exception will be picked up
                # by the try-except construct;
                seed += 1  # alter the seed to avoid running into the same issue
                raise ValueError('Optimization diverged early on')

            elif ii == 2000 and losses[-1][loss_num] < 1e-4:
                # if by iteration 2000 the height-map regularization term is still below 1e-4
                # (presumably the height map has barely moved), restart; change the seed, since
                # restarting with the same seed would likely repeat the same run:
                seed += 1
                raise ValueError('Optimization never converged')

            elif roll_back_condition and len(losses) > 2*a.save_iter:
                # after 2*save_iter iterations, there are two checkpoints to use; roll back:
                print('Last few losses: ' + str(last_few_losses))
                print('Current loss: ' + str(losses[-1][loss_num]))
                a.restore_all_variables()
                # anneal the learning rate of 'ego_height':
                optim = a.optimizer_list[variable_indices['ego_height:0']]
                optim.lr.assign(optim.lr * .9)
                # remove the most recent loss values so that they don't influence the next test for divergence:
                losses = losses[:-2]

            if len(losses) % a.save_iter == 0:
                # periodically checkpoint all variables for rolling back; see above;
                a.checkpoint_all_variables(path='./tf_ckpts')

            # break out of the for-loop when num_iter iterations are reached:
            if ii == num_iter:
                break
            else:
                ii += 1

        break  # the full optimization loop completed without errors
    except ValueError as err:
        print(f'{err}; restarting the optimization with seed {seed} ...')

In [None]:
# when all is done, plot results:
loss_i, recon, normalize, error_map, tracked = a.gradient_update(stack_downsamp, rc_downsamp, update_gradient=False,
                                                                 dither_coords=False, reg_coefs=reg_coefs,
                                                                 return_tracked_tensors=True)
monitor_progress(recon, error_map, losses, tracked)

## Round 3 of optimization with total variation
To obtain results with total variation (TV) regularization, after round 2 above, skip down to this cell without running the cells under **Round 3 of optimization with CNN reparameterization**.

This round will be done in two steps: 
1. Optimize all parameters with downsampling. The difference between this and round 2 is that the height map is also optimized.
2. Optimize all parameters at the final resolution. This step also requires batching to satisfy memory constraints.

### First, optimize all parameters with downsampling
Regardless of the final TV coefficient, this step is always run. No need to modify anything here.

In [None]:
# disabling the graph layout optimizer might save some memory:
tf.config.optimizer.set_experimental_options({'layout_optimizer': False})
set_seed(2)

# learning rates -- unlike round 2, the height map ('ego_height') is optimized as well:
learning_rates = {
    'camera_focal_length': -1e-3,
    'camera_height': 1e-3,
    'ground_surface_normal': 1e-3,
    'camera_in_plane_angle': 1e-3,
    'rc': .1,
    'gain': -1e-3,
    'bias': -1e-3,
    'ego_height': 1e-3,
}
reg_coefs = {'TV': .001, 'height_map': .1}  # regularization coefficients for this step

if distort_method == 'piecewise_linear':
    learning_rates.update({
        'radial_camera_distortion_piecewise_linear': 1e-3,
        'camera_distortion_center': 1e-4,
    })
elif distort_method == 'radial':
    learning_rates.update({
        'camera_distortion_center': 1e-4,
        'radial_camera_distortion': 1e-2,
    })
else:
    raise Exception('invalid distort_method')

# same resolution as round 2:
a = mesoSfM(stack=im_stack[inds_keep],
            ul_coords=np.stack([y_pos[inds_keep], x_pos[inds_keep]]).T,
            recon_shape=recon_shape,
            ul_offset=ul_offset,
            scale=.15*pre_downsample_factor)

# absolute-height calibration:
a.use_absolute_scale_calibration = True
a.effective_focal_length_mm = 4.3
a.magnification_j = magnification

a.create_variables(
    deformation_model='camera_parameters_perspective_to_orthographic',
    learning_rates=learning_rates,
    variable_initial_values=variable_initial_values,
    remove_global_transform=True,
    antialiasing_filter=True,
)

stack_downsamp, rc_downsamp = a.generate_dataset()

# optimization loop: log every 10 iterations, show progress plots every 100:
losses = []
for ii in tqdm(range(401)):
    start = time()
    loss_i, recon, normalize, error_map, tracked = a.gradient_update(
        stack_downsamp, rc_downsamp, reg_coefs=reg_coefs, return_tracked_tensors=True)
    losses.append([loss.numpy() for loss in loss_i] if type(loss_i) is list else loss_i.numpy())
    if ii % 10 == 0:
        print(ii, losses[-1], time() - start)

    if ii % 100 == 0:
        monitor_progress(recon, error_map, losses, tracked)
        if len(losses) > 201:
            # zoom in on the later part of the loss curve:
            plt.plot(losses[200:])
            plt.show()

# the optimized variables initialize the final-resolution step:
variable_initial_values = a.get_all_variables()

### Second, optimize all parameters at the final resolution
Set the desired TV coefficient in the `reg_coefs` dictionary below. The ones used in the paper were 0.003, 0.01, 0.03, and 0.1.

In [None]:
reg_coefs = {'TV': .1, 'height_map': .1}  # the 'TV' entry is the TV regularization coefficient
num_iter = 2001  # iterations in the optimization loop

# the stitched result of the previous step initializes the reconstruction:
recon_previous = np.copy(recon.numpy())
normalize_previous = np.copy(normalize.numpy())

learning_rates = {
    'camera_focal_length': -1e-3,
    'camera_height': 1e-3,
    'ground_surface_normal': 1e-3,
    'camera_in_plane_angle': 1e-3,
    'rc': .1,
    'gain': -1e-3,
    'bias': -1e-3,
    'ego_height': 1e-3,
}

if distort_method == 'piecewise_linear':
    learning_rates.update({
        'radial_camera_distortion_piecewise_linear': 1e-3,
        'camera_distortion_center': 1e-4,
    })
elif distort_method == 'radial':
    learning_rates.update({
        'camera_distortion_center': 1e-4,
        'radial_camera_distortion': 1e-2,
    })
else:
    raise Exception('invalid distort_method')

In [None]:
seed = 3
while True:
    # Sometimes the optimization unpredictably encounters a NaN or diverges; if this happens,
    # start over with a different seed, because retrying with the same seed tends to repeat
    # the failure. The try-except in the while loop keeps trying until one full optimization
    # loop completes.
    # NOTE(review): any ValueError raised inside the try block (not only the one raised
    # deliberately below) triggers a restart.
    try:
        set_seed(seed)

        # create mesoSfM object with particular settings:
        a = mesoSfM(stack=im_stack[inds_keep], ul_coords=np.stack([y_pos[inds_keep], x_pos[inds_keep]]).T,
                    recon_shape=recon_shape, ul_offset=ul_offset, scale=.5*pre_downsample_factor,
                    batch_size=6, momentum=.5, batch_across_images=True)

        # these are for getting absolute height values:
        a.use_absolute_scale_calibration = True
        a.effective_focal_length_mm = 4.3
        a.magnification_j = magnification

        # create tf Variables, initialized with the previous step's stitched result:
        a.create_variables(deformation_model='camera_parameters_perspective_to_orthographic',
                           learning_rates=learning_rates, variable_initial_values=variable_initial_values,
                           recon=recon_previous, normalize=normalize_previous,
                           remove_global_transform=True, antialiasing_filter=False)

        # create dataset; deliberately not called `dataset`, which holds the dataset name chosen
        # at the top of the notebook:
        batches = a.generate_dataset()

        # optimization loop:
        ii = 0
        losses = list()
        for stack_downsamp, rc_downsamp in batches:
            start = time()

            loss_i = a.gradient_update(stack_downsamp, rc_downsamp, dither_coords=dither_coords,
                                       reg_coefs=reg_coefs, return_loss_only=True)
            if type(loss_i) is list:
                losses.append([loss.numpy() for loss in loss_i])
            else:
                losses.append(loss_i.numpy())

            if ii % 250 == 0:
                print(ii, losses[-1], time() - start)

            # restart on NaN, or when the first loss entry (presumably the total loss) exceeds 2e3;
            # atleast_1d also makes this work when a single scalar loss is returned:
            current_losses = np.atleast_1d(losses[-1])
            if np.isnan(current_losses).any() or current_losses[0] > 2e3:
                seed += 1  # alter the seed to avoid running into the same issue
                raise ValueError('NaN occurred or diverged')

            if ii == num_iter:
                break
            else:
                ii += 1

        break  # break out of the while loop after completing the optimization loop without errors

    except ValueError as err:
        print(f'{err}; restarting the optimization with seed {seed} ...')

In [None]:
# when all is done, plot results:
loss_i, recon, normalize, error_map, tracked = a.gradient_update(stack_downsamp, rc_downsamp, 
                                                                 dither_coords=dither_coords, reg_coefs=reg_coefs,
                                                                 return_tracked_tensors=True)
monitor_progress(recon, error_map, losses, tracked)