In [None]:
import numpy as np
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = ''  # restrict GPU usage
import matplotlib.pyplot as plt
import tensorflow as tf
from zstitch import zstitch, get_z_step_mm
from tqdm.notebook import tqdm
import scipy.io
from mcam_loading_scripts import load_xyz

# Choose MCAM datasets to train on
Specify the paths to the MCAM dataset and calibration files, as well as the hyperparameters, for the sample you wish to train on (selected via `sample_id`).

In [None]:
# Per-sample configuration. DEFAULT_SAMPLE_CONFIG applies to every sample; a preset in
# SAMPLE_CONFIGS only lists the keys it overrides.
DEFAULT_SAMPLE_CONFIG = {
    'cam_slice0': (0, 9), 'cam_slice1': (0, 6),  # by default, use all cameras
    'camera_dims_during_acquisition': (9, 6),    # shape used during acquisition
    'filters_list': [32, 32, 32, 64, 64, 64],    # CNN architecture (list of filter numbers)
}

SAMPLE_CONFIGS = {
    # chair painting sample:
    'chair_painting': {
        'directory': '/data/20220213_chair_painting/',
        'restore_path': '/data/20220213_green_noise_target/flat_ref_optimized_params.mat',
        'restore_path_single_cam': '/data/20220213_green_noise_target/flat_ref_optimized_params_single_cam.mat',
        'weighted_sharpness_thresholds': [1, 1.5],
    },
    # PCB sample:
    'PCB': {
        'directory': '/home/kevin/data/20211219_PCB_160B_rechunked',
        'restore_path': '/home/kevin/data/20211219_noise_target/flat_ref_optimized_params.mat',
        'restore_path_single_cam': '/home/kevin/data/20211219_noise_target/flat_ref_optimized_params_single_cam.mat',
        'filters_list': [32, 32, 32, 64, 64],
        'weighted_sharpness_thresholds': [1.5, 2],
        'cam_slice0': (0, 8),
    },
    # BGA chips sample:
    'BGA': {
        'directory': '/data/20220207_bga_chips/',
        'restore_path': '/data/20220207_green_noise_target/flat_ref_optimized_params.mat',
        'restore_path_single_cam': '/data/20220207_green_noise_target/flat_ref_optimized_params_single_cam.mat',
        'weighted_sharpness_thresholds': [1.5, 2],
        'cam_slice0': (0, 8),
    },
    # pin array sample (acquired with an 8x3 subset of the cameras):
    'PGA': {
        'directory': '/data/20220217_pin_array/',
        'restore_path': '/data/20220217_green_noise_target/flat_ref_optimized_params.mat',
        'restore_path_single_cam': '/data/20220217_green_noise_target/flat_ref_optimized_params_single_cam.mat',
        'weighted_sharpness_thresholds': [1, 1.5],
        'cam_slice0': (0, 8),
        'cam_slice1': (0, 3),
        'camera_dims_during_acquisition': (8, 3),
    },
}


def get_sample_config(sample_id):
    """Return the configuration for `sample_id`: DEFAULT_SAMPLE_CONFIG overridden by its preset.

    Raises:
        ValueError: if `sample_id` is not a key of SAMPLE_CONFIGS (the message lists the
            valid choices).
    """
    if sample_id not in SAMPLE_CONFIGS:
        raise ValueError(f'invalid sample_id {sample_id!r}; expected one of {list(SAMPLE_CONFIGS)}')
    return {**DEFAULT_SAMPLE_CONFIG, **SAMPLE_CONFIGS[sample_id]}


# pick a sample:
sample_id = 'chair_painting'  # 'chair_painting', 'PCB', 'BGA', or 'PGA'
sample_config = get_sample_config(sample_id)

directory = sample_config['directory']
restore_path = sample_config['restore_path']
restore_path_single_cam = sample_config['restore_path_single_cam']
filters_list = np.array(sample_config['filters_list'])
weighted_sharpness_thresholds = list(sample_config['weighted_sharpness_thresholds'])
cam_slice0 = sample_config['cam_slice0']
cam_slice1 = sample_config['cam_slice1']
camera_dims_during_acquisition = sample_config['camera_dims_during_acquisition']

blur_sigma = 8  # gaussian blur radius for calculating sharpness
num_patch = 2  # number of patches per batch
patch_size = 576  # size of square patch
skip_list = [0]*len(filters_list)  # no skip connections

# Physics self-supervised training (dynamically loading from storage)
Train a CNN to map z-stacks to 3D height using focus cues and stereo, across entire datasets of up to 2.1 TB. Because such a dataset is too large to fit in RAM, patches are randomly sampled and loaded from disk on the fly.

In [None]:
def generate_model_and_dataset(im_stack_inds_keep, recon_shape, ul_offset, downsamp, nominal_z_slices, 
                               nominal_z_slices_global,  # needed for training from disk, but not per camera
                               filters_list, skip_list, variable_initial_values, z_step_mm,
                               num_patch=num_patch, patch_size=patch_size,  # for making dataset
                               preferred_camera=(0, 0)
                              ):
    """Create the zstitch model (camera parameters + CNN) and a patch dataset streamed from disk.

    Training in this mode streams patches from the dataset in `directory`, so the full
    per-camera datasets are never loaded into memory. The single camera's stack passed as
    `im_stack_inds_keep` is loaded only so that performance can be monitored during
    optimization.

    Args:
        im_stack_inds_keep: in-memory image stack of the monitoring camera. The size of its
            last axis is passed to the network as `num_channels_rgb`.
        recon_shape: reconstruction shape, forwarded to `zstitch`.
        ul_offset: upper-left offset of the reconstruction, forwarded to `zstitch`.
        downsamp: downsampling factor; `zstitch` is created with `scale=.01*downsamp`, and
            `a.scale` is then reused as the visitation-log scale.
        nominal_z_slices: nominal z-slice values for the images in `im_stack_inds_keep`.
        nominal_z_slices_global: nominal z-slice values for all images; passed to the
            disk-streaming dataset.
        filters_list: CNN filter counts, forwarded to `define_network_and_camera_params`.
        skip_list: skip-connection settings, forwarded to `define_network_and_camera_params`.
        variable_initial_values: dict of initial values passed to `create_variables`.
        z_step_mm: z step size in mm, forwarded to `zstitch`.
        num_patch: patches per batch for the disk dataset. NOTE: the default is bound to the
            notebook global's value at the time this cell was executed.
        patch_size: square patch size for the disk dataset (same default-binding caveat).
        preferred_camera: (row, col) forwarded to `generate_visitation_log_for_all_cameras`.

    Returns:
        Tuple `(a, dataset, visitation_log_vars)`: the configured zstitch object, the dataset
        returned by `generate_patched_dataset_from_disk`, and the visitation-log variables.

    Note:
        Also reads the notebook globals `blur_sigma`, `camera_dims_during_acquisition`,
        `weighted_sharpness_thresholds`, `restore_path`, `directory`, `cam_slice0` and
        `cam_slice1`.
    """
    
    # create visitation_log at lower scale for all cameras:
    a = zstitch(stack=im_stack_inds_keep, 
                ul_coords=np.zeros((len(im_stack_inds_keep), 2)),  # ul_coords will be replaced
                recon_shape=recon_shape, 
                ul_offset=ul_offset,
                scale=.01*downsamp,
                batch_size=None,
                momentum=None,
                report_error_map=False,
                sigma=blur_sigma,
                truncate=2,
                z_step_mm=z_step_mm,
                camera_dims=camera_dims_during_acquisition,
               )
    a.nominal_z_slices = tf.constant(nominal_z_slices)  # tensorarray.write doesn't accept np arrays
    a.use_camera_calibration = True
    # Hard-coded optics constants. NOTE(review): confirm these match the MCAM used for the data.
    a.effective_focal_length_mm = 25.05486
    a.magnification_j = 0.8448
    a.weighted_sharpness_loss = True 
    a.weighted_sharpness_thresholds = weighted_sharpness_thresholds
    # NOTE(review): several learning rates below are negative (focal length, gain, bias);
    # presumably zstitch treats a negative rate as "keep this variable fixed" -- confirm.
    a.create_variables(deformation_model='camera_parameters_perspective_to_orthographic',
        learning_rates={'camera_focal_length': -1e-3, 'camera_height': 1e-3, 'ground_surface_normal': 1e-3,
                       'camera_in_plane_angle': 1e-3, 'rc': .1, 'gain': -1e-3, 'bias': -1e-3, 'ego_height': 1e-3,
                       'radial_camera_distortion': 1e-3},
                       variable_initial_values=variable_initial_values,
                       remove_global_transform=True, antialiasing_filter=False)
    stack_downsamp, rc_downsamp = a.generate_dataset()
    
    # One pass with update_gradient=False. Only `tracked` is used afterwards (its 'vanish_warp'
    # and 'camera_to_vanish_point_xyz' entries define the network); the other outputs are unused.
    loss_i, recon, normalize, error_map, tracked = a.gradient_update(stack_downsamp, rc_downsamp,
                                                                     return_tracked_tensors=True, 
                                                                     update_gradient=False)
    a.recompute_CNN = True  # only if you need to save memory
    a.unet_scale = .0001
    a.define_network_and_camera_params(tracked['vanish_warp'], tracked['camera_to_vanish_point_xyz'], 
                                       num_channels_rgb=im_stack_inds_keep.shape[-1],
                                       architecture='fcnn', filters_list=filters_list, 
                                       skip_list=skip_list,
                                       learning_rate=1e-3/10)
    
    # generate visitation log:
    with tf.device('/CPU:0'):
        if camera_dims_during_acquisition == (9,6):
            # If you choose to slice a fraction of the cameras, you don't need to indicate that here if you 
            # chose to slice the cameras AFTER acquisition.
            # e.g., if you acquired all 9x6 cameras, but wish to only use 8x5 (specified via cam_slice0, 
            # cam_slice1), you don't need to indicate that here.
            visitation_log_vars = a.generate_visitation_log_for_all_cameras(restore_path, reuse_log=False,  
                                                                            preferred_camera=preferred_camera
                                                                           )
        else:
            # however, if you acquired fewer than 9x6 cameras, then here you need to specify the crop 
            # here via cam_slice0/1.
            visitation_log_vars = a.generate_visitation_log_for_all_cameras(restore_path, reuse_log=False,  
                                                                            preferred_camera=preferred_camera,
                                                                            cam_slice0=cam_slice0, 
                                                                            cam_slice1=cam_slice1
                                                                           )
        
    # generate dataset: random patches streamed from `directory`, restricted to cam_slice0/1.
    # NOTE(review): the meaning of the positional `patch_size*2` argument is defined by
    # zstitch.generate_patched_dataset_from_disk -- confirm there.
    a.visitation_log_scale = a.scale
    dataset = a.generate_patched_dataset_from_disk(directory, num_patch, patch_size, nominal_z_slices_global, 
                                                   patch_size*2, prefetch=-1, sample_margin=.2, 
                                                   cam_slice0=cam_slice0, cam_slice1=cam_slice1)
        
    return a, dataset, visitation_log_vars

def generate_full_recon_on_cpu(a, margin=2):
    """Reconstruct the in-memory single-camera data on the CPU and plot it to monitor progress.

    Shows two figures: all channels except the last, displayed as an 8-bit image, and the
    last channel (the height map) with the 'turbo' colormap.

    Args:
        a: zstitch object, as returned by `generate_model_and_dataset`.
        margin: forwarded to `a.generate_full_recon` (default 2, matching previous behavior).
    """
    with tf.device('/CPU:0'):
        recon_full, _ = a.generate_full_recon(margin=margin)
        recon_full = recon_full.numpy()

    # Clip before casting. NumPy does not define float->uint8 casts of out-of-range values:
    # they may wrap around instead of saturating.
    image = np.clip(recon_full[:, :, :-1], 0, 255).astype(np.uint8)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    ax.set_title('single-camera reconstruction')
    plt.show()

    height = recon_full[:, :, -1]
    fig, ax = plt.subplots(figsize=(10, 10))
    height_image = ax.imshow(height, cmap='turbo')
    # Robust color limits from a 4x-subsampled map, ignoring zero-valued pixels. If no
    # nonzero pixel remains, np.percentile would fail (or give NaN limits, depending on the
    # NumPy version) on the empty selection, so keep matplotlib's default limits instead.
    sampled = height[::4, ::4]
    sampled = sampled[sampled != 0]
    if sampled.size > 0:
        height_image.set_clim(*np.percentile(sampled, [0.5, 99.99]))
    ax.set_title('single-camera height map')
    plt.show()

In [None]:
downsamp = 1
z_step_ratio = 1  # ratio of z step size of flat reference to sample of interest
camera_dims = camera_dims_during_acquisition
xy_scans = (8, 8)
array_dims = (camera_dims[0]*xy_scans[0], camera_dims[1]*xy_scans[1])
num_cameras = np.prod(camera_dims)  # e.g., 54
num_images = np.prod(array_dims)  # e.g., 3456
image_inds = np.arange(num_images).reshape(array_dims)  # 0, ..., 3455 shaped as (72, 48)
# ^ used to pick out from variable_initial_values

ckpt_path = os.path.join(directory, 'CNN_ckpts')
print('CNN checkpoint path: ' + ckpt_path)
train_from_scratch = True  # if not, then load from above ckpt

# get precalibrated values:
restored = scipy.io.loadmat(restore_path)
if 'z_step_mm__' in restored:
    # loadmat returns numeric variables as arrays with at least 2 dimensions (a saved scalar
    # comes back with shape (1, 1)). Squeeze, as is done for the other restored fields, so
    # z_step_mm does not broadcast an extra axis into downstream computations.
    # NOTE(review): assumes get_z_step_mm below also returns a scalar -- confirm.
    z_step_mm = restored['z_step_mm__'].squeeze()
else:
    z_step_mm = get_z_step_mm(directory)  # z step in mm

# nominal z-slice values, rescaled about their mean by z_step_ratio:
nominal_z_slices_global = restored['nominal_z_slices__'].flatten().astype(np.float32)
z_mean = nominal_z_slices_global.mean()
nominal_z_slices_global = np.float32((nominal_z_slices_global - z_mean)*z_step_ratio + z_mean)

if 'pre_downsample_factor__' in restored:
    pre_downsample_factor = restored['pre_downsample_factor__'].flatten()
else:
    pre_downsample_factor = 1
# Keys containing '__' are metadata (loadmat's own __header__/__version__/__globals__ and the
# '*__' entries read above); every other key is an initial value for a model variable.
variable_initial_values_global = {key: restored[key].squeeze()
                                  for key in restored if '__' not in key}
variable_initial_values_global['rc'] = variable_initial_values_global['rc'] / downsamp * pre_downsample_factor

restored_single_cam = scipy.io.loadmat(restore_path_single_cam)
recon_shape = restored_single_cam['recon_shape__'].flatten()
ul_offset = restored_single_cam['ul_offset__'].flatten()

In [None]:
# Load the xyz stack of a single camera. It is only used to monitor progress; the training
# batches are streamed from disk by `dataset`.
rr = 3  # row of the monitoring camera (a fixed choice, not random)
cc = 2  # column of the monitoring camera
assert 0 <= rr < camera_dims[0] and 0 <= cc < camera_dims[1], 'monitoring camera outside the camera grid'

im_stack_rgb = load_xyz(directory, 'full', cam_slice0=[rr, rr+1], cam_slice1=[cc, cc+1], keep_green_only=True)
im_stack_rgb = im_stack_rgb.transpose(1, 2, 3, 4, 5, 0)  # move z-stack dim to end
im_stack = im_stack_rgb.squeeze()  # just use green channel; converting to grayscale takes lot of memory
im_stack = im_stack.reshape([-1] + list(im_stack.shape[2:]))  # flatten

# Downsample into new names so that re-running this cell does not compound the downsampling
# of the globals `recon_shape` / `ul_offset` loaded in the setup cell.
recon_shape_model = recon_shape
ul_offset_model = ul_offset
if downsamp > 1:
    im_stack = im_stack[:, ::downsamp, ::downsamp, :]
    recon_shape_model = recon_shape // downsamp
    ul_offset_model = ul_offset // downsamp

# Pick out this camera's entries from variable_initial_values_global. Only variables with one
# entry per image (leading dimension == num_images) are indexed; all others are shared as-is.
variable_inds = image_inds[rr*xy_scans[0]:(rr+1)*xy_scans[0], cc*xy_scans[1]:(cc+1)*xy_scans[1]].flatten()
variable_initial_values = {
    var_name: (variable[variable_inds]
               if len(variable.shape) > 0 and variable.shape[0] == num_images else variable)
    for var_name, variable in variable_initial_values_global.items()
}
nominal_z_slices = nominal_z_slices_global[variable_inds]

with tf.device('/CPU:0'):
    a, dataset, visitation_log_vars = generate_model_and_dataset(im_stack, recon_shape_model, ul_offset_model,
                                                                 downsamp, nominal_z_slices, 
                                                                 nominal_z_slices_global,
                                                                 filters_list, skip_list,
                                                                 variable_initial_values,
                                                                 z_step_mm=z_step_mm,
                                                                 num_patch=num_patch, patch_size=patch_size,
                                                                 preferred_camera=(rr, cc))

plt.imshow((visitation_log_vars[0][rr, cc].numpy()>0).sum(2))
plt.title('visitation log')
plt.show()

# restore CNN:
if train_from_scratch:
    print('did not restore CNN from ckpt')
else:
    a.ckpt = None
    a.checkpoint_all_variables(ckpt_path, skip_saving=True)
    a.restore_all_variables(ckpt_no=0)
    a.ckpt = None
    print('restored CNN from ckpt')
    

In [None]:
num_iter = 100001  # the loop stops after the update performed at ii == num_iter
loss_smoothing_window = 200  # moving-average length for the "blurred loss" plots
loss_burn_in = 50  # initial updates excluded from the blurred loss plots

ii = 0
losses = list()

for batch in tqdm(dataset):
    if any(len(batch[0][i]) < 2 for i in range(num_patch)):
        # there needs to be at least two patches to register!
        continue

    loss_i, recon_i, normalize_i, grads_i, norm_i = a.gradient_update_patch(batch, height_map_reg_coef=.5,
                                                                            return_gradients=True,
                                                                            clip_gradient_norm=10,
                                                                            dither_coords=True,
                                                                            downsample_factor=1,
                                                                            sharpness_reg_coef=5000,
                                                                            stitch_loss_coef=None,
                                                                            argmax_loss_coef=1,
                                                                            use_hpf_for_MSE_loss=False,
                                                                            orthorectify=False,
                                                                           )
    # one entry per update: the list of loss terms, converted to numpy
    losses.append([l.numpy() for l in loss_i])

    if len(losses) % 100 == 0:
        print(len(losses), 'Loss terms: ' + str(losses[-1]))

    if len(losses) % 1000 == 0 or len(losses) == 1:
        plt.plot(losses)
        plt.title('loss history')
        plt.show()

        # Every entry of `losses` is a list of loss terms, whether `loss_i` is a list or a
        # tuple, so smooth each term (column) separately. np.convolve only accepts 1-D input.
        loss_history = np.asarray(losses[loss_burn_in:])
        if len(loss_history) >= loss_smoothing_window:
            kernel = np.ones(loss_smoothing_window) / loss_smoothing_window
            for term_index in range(loss_history.shape[1]):
                plt.plot(np.convolve(loss_history[:, term_index], kernel, 'valid'))
                plt.title(f'blurred loss (term {term_index})')
                plt.show()

        recon_patch = np.asarray(recon_i)
        plt.subplot(121)
        # clip before the uint8 cast so out-of-range values saturate instead of wrapping
        plt.imshow(np.clip(recon_patch[:, :, 0, :-1], 0, 255).astype(np.uint8))
        plt.title('patch green channel')
        plt.subplot(122)
        plt.imshow(recon_patch[:, :, 0, -1], cmap='jet')
        plt.title('patch height')
        plt.show()

    if ii % 10000 == 0 and ii != 0:
        # check progress on single-camera data:
        generate_full_recon_on_cpu(a)

    if ii == num_iter:
        break
    ii += 1

In [None]:
# Checkpoint the trained variables under <directory>/CNN_ckpts.
# NOTE: this is the same directory as `ckpt_path` (used for restoring), not a unique per-run one.
CNN_ckpt_path = os.path.join(directory, 'CNN_ckpts')

a.ckpt = None  # clear any previous checkpoint object before checkpointing to CNN_ckpt_path
a.checkpoint_all_variables(path=CNN_ckpt_path)
print(CNN_ckpt_path)  # report the path that was actually written to
print('checkpointed network')