# ***Automatic Neuron Tracking System for Unconstrained Nematodes (ANTSUN) v2.1.0-Unsupervised***

This notebook performs unsupervised NeuroPAL neuron labeling. It takes as input a fully trained CellDiscoveryNet and a directory containing 4-D multispectral images of worms, their ROI (segmentation) images, and Euler-registered versions for every pair of images.

Note that Euler registration (including rotating the worms to lie on the same side) must already have been performed BEFORE running this notebook, and CellDiscoveryNet must already be trained. If you haven't done so, please refer to the training notebook.

### Set GPU and server version

In [None]:
nvidia_smi_device = 0
flv_c = 3

In [None]:
if flv_c == 1
    gpu_device = 0
elseif flv_c == 2
    if nvidia_smi_device < 2
        gpu_device = nvidia_smi_device + 2
    else
        gpu_device = nvidia_smi_device - 2
    end
elseif flv_c == 3
    if nvidia_smi_device < 2
        gpu_device = 1 - nvidia_smi_device
    else
        gpu_device = nvidia_smi_device
    end
else
    error("Unsupported flv-c version")
end 
gpu_device = 2 # NOTE: hard-coded override; this replaces the value computed above
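
The index mapping in the cell above can also be wrapped in a function, which makes it easy to check each server version in isolation. The following is a minimal sketch: `map_gpu_device` is a hypothetical helper name, not part of ANTSUN, and it mirrors only the `if`/`elseif` branches shown above (it does not reproduce the final hard-coded `gpu_device = 2`).

```julia
# Hypothetical helper (not part of ANTSUN): translate `nvidia_smi_device`
# into `gpu_device` for a given server version `flv_c`.
function map_gpu_device(flv_c::Integer, nvidia_smi_device::Integer)
    if flv_c == 1
        return 0
    elseif flv_c == 2
        # devices 0,1 map to 2,3 and devices 2,3 map to 0,1
        return nvidia_smi_device < 2 ? nvidia_smi_device + 2 : nvidia_smi_device - 2
    elseif flv_c == 3
        # devices 0 and 1 are swapped; the others are unchanged
        return nvidia_smi_device < 2 ? 1 - nvidia_smi_device : nvidia_smi_device
    else
        error("Unsupported flv-c version")
    end
end

# Print the mapping for one server version as a quick sanity check.
for dev in 0:3
    println("flv-c 3, nvidia_smi_device $dev -> gpu_device ", map_gpu_device(3, dev))
end
```

Keeping the mapping in a function also makes an accidental override easier to spot, since the result is assigned in exactly one place.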

In [None]:
ENV["CUDA_VISIBLE_DEVICES"] = "$(gpu_device)"

In [None]:
# Flavell lab packages
using ND2Process
using GPUFilter
using WormFeatureDetector
using NRRDIO
using FlavellBase
using SegmentationTools
using ImageDataIO
using RegistrationGraph
using ExtractRegisteredData
using CaAnalysis
using BehaviorDataNIR
using UNet2D
using ImageRegistration

# Other packages
using ProgressMeter
using PyCall
using PyPlot
using Statistics
using StatsBase
using DelimitedFiles
using Images
using Cairo
using Distributions
using DataStructures
using HDF5
using Interact
using WebIO
using Plots
# using GraphPlot
# using LightGraphs
# using SimpleWeightedGraphs
using Dates
using JLD2
using TotalVariation
using VideoIO
using MultivariateStats
using FFTW
using LinearAlgebra
using GLMNet
using InformationMeasures
using CUDA
using LsqFit
using Optim
using Rotations
using CoordinateTransformations
using ImageTransformations
using Interpolations
using H5Zblosc
using SparseArrays
using StatsPlots

using Distributed
using Clustering
using CSV
using DataFrames
using LaTeXStrings

### Set dataset path parameters

`path_root_process_*` are the output directories that this notebook writes to.

`root_path_data` is the directory containing the input data.

Most other parameters should not need to be changed from their default values.

In [None]:
## ND2 parameters
ch_bluegreen = 1 # GCaMP and BFP
ch_red = 2 # mNeptune
n_rec = 1 # number of Illumination Sequence module recordings (.nd2) in the HDF5 file 
h5_confocal_time_lag = 0 # if this is the second confocal dataset using the same h5 file, set equal to number of frames in 1st dataset
spacing_lat = 0.54 # Spacing of each voxel in the xy plane. If different, it will likely be necessary to modify pipeline parameters.
spacing_axi = 0.54 # Spacing of each voxel in the z dimension. If different, it will likely be necessary to modify pipeline parameters.

## Path to data directory
data_dir = "/data4"

## Freely-moving parameters
max_graph_num = 800 # time point just before laser intensity is changed in freely-moving dataset. If not changed, set to length of the dataset.
blue_laser = [13,15] # array of blue laser values before vs after param["max_graph_num"] in the freely-moving dataset
green_laser = [15,17] # array of green laser values before vs after param["max_graph_num"] in the freely-moving dataset


## Data path parameters
prj = "prj_register" # change to your project - ie: the location to save the data
path_raw_data = "" # change to the directory with all raw data files

datasets = Dict() # for each dataset, set the corresponding dictionary entry to the path of the ND2 file for that dataset (without the .nd2 extension)
datasets["freely_moving"] = "multicolor_deepreg_test_6"


## You should not need to change anything below this line except possibly `reg_timepts[dataset] = 30`.
datasets_freely_moving = ["freely_moving"] # freely-moving datasets
dataset_central = "freely_moving" # central dataset - register other datasets to this. Must contain camera-alignment registration
datasets_register = [] # datasets to register back to central dataset

n_timepts_merge = Dict() # number of registrations to attempt back to the central dataset, for each registered dataset
for dataset in datasets_register
    n_timepts_merge[dataset] = 200
end

## Immobilized parameters
reg_timepts = Dict()
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    reg_timepts[dataset] = 30 # timepoint to register everything to, for the immobilized datasets
end


rotate_img_x = false # whether to rotate the image about the x-axis (if the worm was put on the slide the wrong way)

fm = datasets["freely_moving"]
path_root_process_freelymoving = "$(data_dir)/$(prj)/data_processed/$(fm)_output"
path_root_process_immobilized = "$(data_dir)/$(prj)/data_processed/$(fm)_output/neuropal"

create_dir(path_root_process_freelymoving)
create_dir(path_root_process_immobilized)

root_path_data = "/data4/prj_register/multicolor_deepreg_test_5"
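
The output locations set above follow a fixed layout under `data_dir`. As a minimal sketch, the same layout can be computed by a small function. `output_dirs` is a hypothetical name introduced here, and `mkpath` from Julia's standard library is used as a stand-in for `create_dir`, whose exact behavior is assumed to be similar.

```julia
# Hypothetical helper: compute the two output directories used by this notebook
# from the data directory, project name, and freely-moving dataset name.
function output_dirs(data_dir::AbstractString, prj::AbstractString, fm::AbstractString)
    freely_moving = "$(data_dir)/$(prj)/data_processed/$(fm)_output"
    immobilized = "$(freely_moving)/neuropal"
    return (freely_moving = freely_moving, immobilized = immobilized)
end

# Demonstration in a temporary directory, so nothing is written under /data4.
demo_root = mktempdir()
dirs = output_dirs(demo_root, "prj_register", "multicolor_deepreg_test_6")
mkpath(dirs.freely_moving)   # stand-in for create_dir (assumption)
mkpath(dirs.immobilized)

println(dirs.freely_moving)
println(dirs.immobilized)
```

Note that the immobilized directory is nested inside the freely-moving output directory, so creating the former also creates the latter.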

In [None]:
param_paths = Dict()
params = Dict()
data_dicts = Dict()
error_dicts = Dict();

### Set various directory paths

Many paths are set here, most of which can be left at their default values. (This code block is imported from ANTSUN 2.1.0, but many steps are skipped in this notebook.)

However, the entry `param_path["path_deepreg_weights"]` should be set to the path of the trained CellDiscoveryNet weights.
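
Before running the loop below, it can help to confirm that the weights path points at a real checkpoint. The sketch below assumes the weights are stored as a TensorFlow-style checkpoint, where a prefix such as `ckpt-596` is accompanied by a `ckpt-596.index` file. This is an assumption about the file layout and is not confirmed by this notebook. `checkpoint_exists` is a hypothetical helper.

```julia
# Hypothetical helper: true if `prefix` is itself a file, or if a
# "<prefix>.index" file exists (assumed TensorFlow-style checkpoint layout).
checkpoint_exists(prefix::AbstractString) = isfile(prefix) || isfile(prefix * ".index")

# Self-contained demonstration using a temporary directory instead of real weights.
tmp = mktempdir()
fake_prefix = joinpath(tmp, "ckpt-596")
touch(fake_prefix * ".index")

println(checkpoint_exists(fake_prefix))                # true
println(checkpoint_exists(joinpath(tmp, "ckpt-000")))  # false
```

In the notebook itself, the same check could be applied to `param_paths[dataset]["path_deepreg_weights"]` once the loop has run.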

In [None]:
for dataset in keys(datasets)
    exp_prefix = datasets[dataset]
    if dataset in datasets_freely_moving
        path_root_process = path_root_process_freelymoving
    else
        path_root_process = joinpath(path_root_process_immobilized, datasets[dataset])
    end
    
    if !(dataset in keys(params))
        params[dataset] = Dict()
    end
    param = params[dataset]
    
    if !(dataset in keys(param_paths))
        param_paths[dataset] = Dict{String,Any}()
    end
    param_path = param_paths[dataset]

    param_path["path_root_process"] = path_root_process

    param_path["path_unet_pred"] = "/home/adam/src/pytorch-3dunet/pytorch3dunet/predict.py" # Path to the pytorch-3dunet installation on the local machine
    param_path["path_unet_param"] = "/data3/shared/dl_weights/3dunet_NeuroPAL/instance_segmentation_test.yaml" # Path to the UNet parameter file
    param_path["path_unet_py_env"] = "/home/adam/.julia/conda/3/bin/activate" # Path to a script that should be run to configure environment variables for the pytorch-3dunet
    param_path["path_transformix"] = (flv_c == 2) ? "/bin/transformix" : "/usr/local/bin/transformix" # Path to transformix executable on the LOCAL machine
    param_path["path_elastix_local"] = (flv_c > 1) ? "/bin/elastix" : "/usr/local/bin/elastix" # Path to elastix executable on the LOCAL machine

    
    param_path["path_head_unet_model"] = "/data1/shared/head_detector_0124/head_detector_0124/unet2d-head-detector/unet2d-head-detector_best.pt" # path to head detection unet model
    param_path["path_2d_unet_param"] = "/data3/shared/dl_weights/behavior_nir/worm_segmentation_best_weights_0310.pt"

    param_path["path_dir_lock"] = "/home/adam/lock" # path to lock directory

    if dataset in datasets_freely_moving
        param_path["path_head_rotate"] = "/om2/group/flavell/script/registration/euler_registration/euler_head_rotate.py"
    else
        param_path["path_head_rotate"] = nothing
    end
    param_path["path_head_rotate_activity_marker"] = nothing
    param_path["path_elastix"] = "/om2/user/aaatanas/elastix-5.0.1/build/bin/elastix"
    param_path["path_run_elastix"] = "/om2/group/flavell/script/registration/run_elastix_command.sh"

    om_user = "aaatanas" # Your username on OpenMind
    param_path["path_om_home"] = "/om2/user/$om_user"
    param_path["path_om_env"] = "/home/$(om_user)/.bashrc" # path to user environment variable script on OpenMind
    param_path["path_om_data"] = joinpath(param_path["path_om_home"], "$(exp_prefix)_output")
    param_path["path_om_home_scripts"] = "/om2/user/$om_user"
    param_path["path_om_scripts"] = joinpath(param_path["path_om_home_scripts"], "$(exp_prefix)_output")

    param_path["path_om_euler_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_freely_moving_euler.txt" # Path to Euler parameter file on OpenMind
    param_path["path_om_euler_am_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_freely_moving_euler_output.txt" # Path to Euler parameter file on OpenMind
    param_path["path_om_affine_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_freely_moving_affine.txt" # Path to Affine parameter file on OpenMind
    
    if dataset in datasets_freely_moving
        param_path["path_om_bspline_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_freely_moving_bspline.txt" # Path to BSpline parameter file on OpenMind
        param["max_graph_num"] = max_graph_num # maximum length before graph split
        param["blue_laser"] = blue_laser
        param["green_laser"] = green_laser
        
        param["blue_zero_thresh"] = 4.5
        param["blue_min_laser"] = 5.2
        param["blue_max_interpolate"] = 2.0

        param["green_zero_thresh"] = 7.0
        param["green_min_laser"] = 7.7
        param["green_max_interpolate"] = 2.0

        intensity = h5read("/data3/shared/2022-05-11-laser-power.h5", "488nm/1/intensity")
        laser_perc = h5read("/data3/shared/2022-05-11-laser-power.h5", "488nm/1/laser_percent")

        intensity2 = h5read("/data3/shared/2022-05-11-laser-power.h5", "488nm/2/intensity")
        laser_perc2 = h5read("/data3/shared/2022-05-11-laser-power.h5", "488nm/2/laser_percent")
        
        param["blue_laser_intensity"] = (intensity .+ intensity2) ./ 2
        param["blue_laser_perc"] = laser_perc

        param["green_laser_intensity"] = h5read("/data3/shared/2022-05-11-laser-power.h5", "561nm/1/intensity")
        param["green_laser_perc"] = h5read("/data3/shared/2022-05-11-laser-power.h5", "561nm/1/laser_percent");
        
        param["good_registration_resolutions"] = [(0,0)] #= which registration resolutions
            are good enough to extract data from =#
        param["reg_n_resolution"] = [0,0,4]

        param["good_registration_resolutions_to_immobilized"] = [(2,0),(2,1),(2,2),(2,3)]

        @assert(laser_perc == laser_perc2)
        
        # NeuroPAL labeling settings
        param_path["path_root_process_label"] = joinpath(path_root_process, "neuropal_label")
        create_dir(param_path["path_root_process_label"])
        param_path["path_neuropal_img"] = joinpath(param_path["path_root_process_label"], "NeuroPAL.nrrd")
        param_path["path_neuropal_img_mNeptune_GCaMP"] = joinpath(param_path["path_root_process_label"], "NeuroPAL_mNeptune_GCaMP_bleedthrough.nrrd")
        param_path["path_bfp_img"] = joinpath(param_path["path_root_process_label"], "BFP.nrrd")
        param_path["path_ofp_img"] = joinpath(param_path["path_root_process_label"], "OFP.nrrd")
        param_path["path_gcamp_img"] = joinpath(param_path["path_root_process_label"], "OFP_GCaMP.nrrd") # actual GCaMP
        param_path["path_mNeptune_img"] = joinpath(param_path["path_root_process_label"], "mNeptune.nrrd")
        param_path["path_mNeptune_gcamp_img"] = joinpath(param_path["path_root_process_label"], "mNeptune_GCaMP_bleedthrough.nrrd") # GCaMP bleedthrough
        param_path["path_all_red_img"] = joinpath(param_path["path_root_process_label"], "all_red.nrrd")
        param_path["path_neuron_img"] = joinpath(param_path["path_root_process_label"], "neuron_rois.nrrd")
        param_path["path_dir_autolabel_input"] = joinpath(param_path["path_root_process_label"], "autolabel_data")
        param_path["path_h5_autolabel_input"] = joinpath(param_path["path_dir_autolabel_input"], "NeuroPAL.h5")
        param_path["path_neuron_img_crop"] = joinpath(param_path["path_root_process_label"], "neuron_rois_cropped.h5")
        param_path["path_autolabel_csv"] = joinpath(param_path["path_root_process_label"], "labels.csv")
        param_path["path_autolabel_param"] = "/data3/shared/dl_weights/NeuroPAL_autolabel_release/v2/instance_segmentation_test.yaml"
        param_path["path_autolabel_neuron_ids"] = "/data3/shared/dl_weights/NeuroPAL_autolabel_release/v2/extracted_neuron_ids.h5"

        param["crop_size"] = (284, 120, 64, 4)

        # Deepnet registration settings
        np = pyimport("numpy")

        param_path["path_dir_nrrd_filt_recropped"] = joinpath(param_path["path_root_process"], "NRRD_filtered_recropped")
        param_path["path_dir_roi_watershed_recropped"] = joinpath(param_path["path_root_process"], "img_roi_watershed_recropped")
        param_path["path_dir_ch1_recropped"] = joinpath(param_path["path_root_process"], "ch1_recropped")
        param_path["path_dir_ch1_registered"] = joinpath(param_path["path_root_process"], "ch1_registered")
        param_path["path_deepreg_weights"] = "/data4/prj_register/multicolor_deepreg_test_5/multicolor_gncc/save/ckpt-596"
        param_path["path_deepreg_config"] = "/data3/shared/dl_weights/deepreg/large_crop/config.yaml"
        
        param_path["path_dir_nrrd_filt_recropped"] = joinpath(param_path["path_root_process"], "NRRD_filtered_recropped")
        create_dir(param_path["path_dir_nrrd_filt_recropped"])

        param["deepreg_batch_size"] = 3
        if flv_c == 3 && nvidia_smi_device == 1
            param["deepreg_batch_size"] = 6
        end
        if flv_c == 2 && nvidia_smi_device < 2
            param["deepreg_batch_size"] = 2
        end
        param["deepreg_label_size"] = (200, 3)
        

        param["euler_downsample_factor"] = 4
        param["euler_batch_size"] = 5000
        param["euler_x_translation_range_1"] = np.sort(np.concatenate((
            np.linspace(-0.24, 0.24, 49),
            np.linspace(-0.46, -0.25, 8),
            np.linspace(0.25, 0.46, 8),
            np.linspace(0.5, 1, 3),
            np.linspace(-1, -0.5, 3))))
        param["euler_x_translation_range_2"] = np.zeros(1)
        param["euler_y_translation_range_1"] = np.sort(np.concatenate((
            np.linspace(-0.28, 0.28, 29),
            np.linspace(-0.54, -0.3, 5),
            np.linspace(0.3, 0.54, 5),
            np.linspace(0.6, 1.4, 3),
            np.linspace(-1.4, -0.6, 3))))
        param["euler_y_translation_range_2"] = np.zeros(1)
        param["euler_z_translation_range_1"] = np.linspace(-1.0, 1.0, 201)
        param["euler_z_translation_range_2"] = np.zeros(1)
        param["euler_theta_rotation_range_xy"] = np.sort(np.concatenate((
            np.linspace(0, 19, 20),
            np.linspace(20, 160, 29),
            np.linspace(161, 199, 39),
            np.linspace(200, 340, 29),
            np.linspace(341, 359, 19))))
        param["euler_theta_rotation_range_xz"] = np.zeros(1)
        param["euler_theta_rotation_range_yz"] = np.zeros(1)

        create_dir(param_paths[dataset_central]["path_dir_roi_watershed_recropped"])
        create_dir(param_paths[dataset]["path_dir_nrrd_filt_recropped"])
        
    else
        param_path["path_om_bspline_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_immobilized_bspline.txt" # Path to BSpline parameter file on OpenMind
        param["reg_timept"] = reg_timepts[dataset]
        param["reg_n_resolution"] = [0,0,2] # the number of euler, affine and BSpline registrations
        param["good_registration_resolutions"] = [(2,0), (2,1)] #= which registration resolutions
            are good enough to extract data from =#
        param["registration_resolution_neuropal"] = (0,1) # registration resolution for neuroPAL immobilized registration
    end

    # BFP camera alignment settings
    if dataset == "BFP"
        param["use_camera_alignment"] = true
        
        param_path["path_camera_alignment"] = joinpath(path_root_process, "camera_align_BFP")
        param_path["path_elastix_param_camera_alignment"] = "/data3/shared/elastix_parameters/parameters_bfp_registration_euler.txt"
    end

    
    param["nonsynced_roi_val"] = 4095

    param_path["path_om_tmp"] = joinpath(param_path["path_om_data"], "temp"); # Path to a temporary directory on OpenMind to store data
    
    param["spacing_axi"] = spacing_axi
    param["spacing_lat"] = spacing_lat
    param["n_z"] = 80
    param["z_range"] = 4:80

    
    ### crop
    param["crop_threshold_size"] = 20
    param["crop_threshold_intensity"] = 7.

    ### UNet
    param["seg_threshold_unet"] = 0.5 # The UNet output confidence threshold for a pixel to be counted as a neuron.
    param["seg_min_neuron_size"] = 7 # Neurons with fewer than this many voxels will be discarded.
    param["seg_threshold_watershed"] = [0.7, 0.8, 0.9] #= The image is re-segmented at these thresholds,
        and checked for neurons that were split - these neurons will be segmented by watershed.=#
    param["seg_watershed_min_neuron_sizes"] = [5,4,4] #= When the image is re-segmented, neurons smaller
        than the size for the corresponding threshold will be discarded.=#
    param["instance_seg_num_batches"] = 5 # number of batches of data to run simultaneously
        # each batch = number of threads. Lower number = less memory used, but more chance of bad parity

    ### Registration graph
    param["degree"] = 10 # degree of registration graph
    param["degree_dataset"] = 2 # degree of inter-dataset registration graph

    ### Head finder
    param["head_threshold"] = 0.5 # head detection UNet threshold
    param["head_max_distance"] = [20,35,35] #= distance thresholds for the blobification of the worm
        that finds the head location=#
    param["head_err_threshold"] = 50 # sets a warning flag in the blobification
    param["head_vc_err_threshold"] = 150 # sets a warning flag in the blobification
    param["head_edge_err_threshold"] = 1 # sets a warning flag in the blobification

    ### Worm curve finder
    param["worm_curve_n_pts"] = 14 # number of points in worm curve detector
    param["worm_curve_head_idx"] = 4 #= index of "head" point - ideally near the front of the
        brain but not on the tip of the nose=#
    param["worm_curve_tail_idx"] = 11 #= index of "tail" point - ideally near the back of the
        brain but not on the ventral cord=#
    param["worm_curve_downscale"] = 2 # binning factor for worm curve detector

    ### Registration parameters
    param["reg_n_resolution_activity_marker"] = [3]


    ### Quality control and heuristic parameters
    param["smeared_neuron_threshold"] = 20000 # ROIs that get smeared out to more than this many pixels due to registration will be deleted
    param["max_centroid_dist"] = 10 # maximum centroid distance post-registration 
    param["quality_metric"] = "NCC" # registration quality metric
    param["regularization_key"] = "nonrigid_penalty"
    

    param["matrix_self_weight"] = 1.0 # weight of the diagonal of registration map matrix
    param["min_cluster_weight"] = 1e-6 # minimum weight
    param["overlap_weight"] = 1.0 # weight for ROI overlaps
    param["centroid_weight"] = 1.0 # weight for centroid distance between ROIs
    param["activity_diff_weight"] = 3.0 # weight for red activity difference
    param["q_weight"] = 25.0 # weight for NCC of registration
    param["regularization_weight"] = 1.0 # weight of DDF regularization penalty
    param["displacement_weight"] = 2.0 # weight of displacement heuristic (ROIs that move farther are less likely to be correct)

    ### Clustering parameters
    param["cluster_height_thresh"] = -0.0001 # height threshold that controls when multiple clusters are merged together. Reasonable range -0.3 to 0
    param["cluster_overlap_thresh"] = 0.05 # tolerance for a cluster to have multiple ROIs from the same frame. Reasonable range 0 to 0.2
    
    param["frac_detections_threshold"] = 0.5 # fraction of time points in which an ROI candidate must be detected
    

    ### Deconvolution parameters
    param["gcamp_decay_response"] = 0.52; # GCaMP decay time, in seconds
    
    param["duration"] = Dates.Time(4,0,0)
    param["duration_julia"] = "2-0"
    param["partition_sbatch_script"] = "normal"
    param["memory"] = 1
    param["cpu_per_task"] = 16
    param["ch_marker"] = ch_red
    param["ch_activity"] = ch_bluegreen
    param["list_ch"] = unique([param["ch_marker"], param["ch_activity"]])
    param["email"] = nothing
    param["use_sbatch"] = true
    param["server"] = "openmind7.mit.edu"
    param["server_dtn"] = "openmind-dtn.mit.edu"
    param["job_name"] = "elx"
    param["job_name_activity_marker"] = "elx_ch12"
    param["job_name_marker_activity"] = "elx_ch21"
    param["array_size"] = 450
    param["partition"] = "use-everything"
    param["elx_wait_delay"] = 3600
    param["lock_wait"] = 0.5
    param["user"] = om_user;
    
    param["FLIR_FPS"] = 20.0; # FPS of NIR camera
    param["bad_tracking"] = []; # points to delete (NIR timeframe) due to bad tracking
    param["n_rec"] = n_rec # number of Illumination Sequence module recordings (.nd2) in the HDF5 file

    ### Spline parameters
    param["num_center_pts"] = 1000 # number of points in raw worm spline
    param["segment_len"] = 7 # length of raw spline per segment of distance-corrected spline
    param["img_label_size"] = (480,360) # size of UNet input image

    ### Spline self-intersect correction parameters
    param["med_axis_shorten_threshold"] = 14 # if the medial axis shortens by at least this amount, trigger self-intersect correction
    param["nose_confidence_threshold"] = 0.99 # threshold for UNet's confidence of nose location for it to be used to crop medial axis
    param["nose_crop_threshold"] = 20 # maximum number of points in medial axis that can be cropped to the nose
    param["med_axis_self_dist_threshold"] = 40 # proximity of far-away medial axis points that will trigger self-intersect correction
    param["loop_dist_threshold"] = 60 # distance between medial axis points beyond which collisions/proximity trigger self-intersect correction
    param["max_med_axis_delta"] = Inf # change in pixels from points in one time point's medial axis to the next to trigger self-intersect correction
    param["trim_head_tail"] = 15; # amount to trim head and tail of worm boundary to allow motion of worm
    param["boundary_thickness"] = 5; # thickness of worm boundary
    param["close_pts_threshold"] = 30; # minimum distance on spline for boundary to be enforced
    param["worm_thickness_pad"] = 3 # gap between worm and worm boundary
    param["min_len_percent"] = 90; # lower percentile of non-self-intersect time points to be included in worm thickness computation
    param["max_len_percent"] = 98; # upper percentile of non-self-intersect time points to be included in worm thickness computation
    
    ### Body angle parameters
    param["body_angle_pos_lag"] = 2; # how far to enforce angle-continuity along a time point's spline
    param["body_angle_t_lag"] = 40; # how far to enforce angle-continuity along a particular spline-angle across time points
    param["max_pt"] = 31; # crop the spline at this location

    param["nose_pts"] = [1,2,3]; # points in spline corresponding to nose angle
    param["head_pts"] = [1,5,8]; # points in spline corresponding to head angle
    param["seg_range"] = param["head_pts"][3]:param["max_pt"]; # range of points for computing worm centroid

    ### Reversal parameters
    param["rev_len_thresh"] = 2; # length threshold for reversal events
    param["rev_v_thresh"] = -0.005; # velocity threshold for reversal events

    ### Angular velocity parameters
    param["filt_len_angvel"] = 150

    ### Pumping parameters
    param["filt_len_pumping"] = 40

    ### Other parameters
    param["num_lag"] = 1;  # number of NIR time points to smooth velocity and other variables with SG filter
    param["v_stage_m_filt"] = 10; # length of NIR time point window for position filtering
    param["v_stage_λ_filt"] = 250.0; # parameter for position filtering
    
    param["num_eigenworm"] = 5 # number of eigenworms to compute

    # behavior variables for MI and neuron prediction GLM in addition to eigenworm and body angle
    param["vars"] = ["velocity_stage", "worm_angle", "head_angle", "nose_angle", "angular_velocity", "worm_curvature"]

    # behavior variables for multi-dataset concatenation
    param["concat_vars"] = ["angular_velocity", "worm_angle", "head_angle", "nose_angle", "speed_stage", "worm_curvature",
            "ventral_worm_curvature", "nose_curling", "velocity_stage",
            "zeroed_x_confocal", "zeroed_y_confocal", "body_angle", "body_angle_all", "body_angle_absolute", "rev_times"]
    param["t_concat_vars"] = ["all_rev"];
    param["nir_concat_vars"] = ["nir_velocity_stage", "nir_head_angle", "nir_body_angle", "nir_body_angle_all",
            "nir_body_angle_absolute", "nir_nose_curling",
            "zeroed_x", "zeroed_y"]

    param["beta_values"] = [Int32(round(1.6^x)) for x=0:9] # inverse values of L1 norm penalty to use for neuron prediction GLM
    param["t_thresh_vals_base"] = -100:50:100; # thresholds between training and validation data to use for neuron prediction GLM

    param["rev_neuron_thresh"] = 0.5 # goodness of fit required to classify a neuron as a reversal neuron
    param["type1_thresh"] = 2; # maximum reversal length threshold for a type-1 reversal neuron

    param["fwd_neuron_thresh"] = 0.75; # goodness of fit required to classify a neuron as a forward neuron
    param["turning_neuron_thresh"] = 0.75; # goodness of fit required to classify a neuron as a turning neuron
    
    
    
    param_path["path_param_path"] = joinpath(path_root_process, "param_path.jld2")

    param_path["path_nd2"] = joinpath(path_raw_data, exp_prefix*".nd2")
    
    if dataset in datasets_freely_moving
        param_path["path_h5"] = joinpath(path_raw_data, exp_prefix*".h5")
    end
    path_dir_nd2, name_nd2 = splitdir(param_path["path_nd2"])
    param_path["img_prefix"] = splitext(name_nd2)[1]
    path_root_raw = path_dir_nd2
    list_ch = param["list_ch"]

    dir_nrrd = "NRRD"
    dir_MIP = "MIP"
    dir_nrrd_shearcorrect = "NRRD_shearcorrect"
    dir_MIP_shearcorrect = "MIP_shearcorrect"
    dir_nrrd_crop = "NRRD_cropped"
    dir_MIP_crop = "MIP_cropped"
    dir_nrrd_filt = "NRRD_filtered"
    dir_MIP_filt = "MIP_filtered"

    dir_roi = "img_roi"
    dir_roi_watershed = "img_roi_watershed"
    dir_roi_watershed_uncropped = "img_roi_watershed_uncropped"
    dir_marker_signal = "ch4_signal"
    dir_centroid = "centroids"
    dir_unet_data = "unet_data"

    dir_worm_curve = "worm_curve"
    dir_reg = "Registered"
    dir_reg_activity_marker = "Registered_ch1to2"
    dir_reg_marker_activity = "Registered_ch2to1"

    dir_transformed = "img_roi_transformed"
    dir_transformed_activity_marker = "NRRD_activity_transformed"
    dir_activity_signal = "gcamp_activity_transformed"
    dir_cmd = "elx_commands"
    dir_cmd_am = "elx_commands_ch1to2"
    dir_cmd_ma = "elx_commands_ch2to1"
    dir_cmd_array = "elx_commands_array"
    dir_log = "log"

    dir_roi_candidates = "roi_candidates"

    txt_file_extension = ".txt"
    name_head_pos = "head_pos"
    name_elastix_difficulty = "elastix_difficulty"
    name_reg_prob = "registration_problems"
    name_reg_quality = "registration_quality"
    name_qdict = "q_dict"
    name_best_reg = "best_reg"
    name_roi_overlap = "roi_overlap"
    name_roi_activity_diff = "roi_activity_diff"
    name_label_map = "label_map"
    name_inv_map = "inv_map"
    name_data_dict = "data_dict.jld2"
    name_param = "param.jld2"
    name_param_path = "param_path.jld2"
    name_error_dict = "error_dict.jld2"

    param_path["path_root_process"] = path_root_process
    param_path["name_head_rotate_logfile"] = "headrotate.log"
    param_path["name_transform_activity_marker"] = "TransformParameters.0.txt"
    param_path["name_transform_activity_marker_avg"] = "TransformParameters.0_avg.txt"
    param_path["name_transform_activity_marker_roi"] = "TransformParameters.0_roi.txt"
    param_path["key_transform_parameters"] = "TransformParameters"
    
    if !(dataset in datasets_freely_moving)
        param_path["path_dir_reg_activity"] = joinpath(param_path["path_root_process"], "Registered_G")
        param_path["path_om_reg_activity"] = joinpath(param_path["path_om_data"], "Registered_G")
        param_path["path_dir_cmd_activity"] = joinpath(param_path["path_root_process"], "elx_commands_G")
        param_path["path_om_cmd_activity"] = joinpath(param_path["path_om_data"], "elx_commands_G")
        param["job_name_activity"] = "elx_ch1"
        param_path["path_nrrd_avg"] = joinpath.(param_path["path_root_process"], ["avg_ch1.nrrd", "avg_ch2.nrrd"])
        param_path["path_nrrd_avg_ch1toch2"] = joinpath(param_path["path_root_process"], "avg_ch1toch2")
        param_path["path_nrrd_avg_ch1toch2_reg"] = joinpath(param_path["path_root_process"], "avg_ch1toch2_reg")
        param_path["nrrd_avg_res"] = "result.2.R0.nrrd"
        param_path["parameter_files_local"] = [joinpath("$(data_dir)/shared/elastix_parameters", "parameters_immobilized_bspline_turbo_v2.txt")]
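        param_path["path_dir_reg_neuropal"] = joinpath(param_path["path_root_process"], "Registered_to_$(dataset_immobilized_central)") # NOTE: `dataset_immobilized_central` is not defined in the cells above; set it before adding any immobilized dataset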
        param_path["euler_head_params"] = joinpath(param_path["path_dir_reg_neuropal"], "euler.txt")
        param_path["path_nrrd_avg_translate"] = joinpath.(param_path["path_root_process"], ["avg_ch1_translate.nrrd", "avg_ch2_translate.nrrd"])
        param_path["path_nrrd_avg_ch1toch2_translate"] = joinpath(param_path["path_root_process"], "avg_ch1toch2_translate.nrrd")
    end


    param_path["path_dir_unet_data"], param_path["path_dir_nrrd"], param_path["path_dir_MIP"], 
        param_path["path_dir_nrrd_shearcorrect"], param_path["path_dir_MIP_shearcorrect"], 
        param_path["path_dir_nrrd_crop"],
        param_path["path_dir_nrrd_filt"], param_path["path_dir_MIP_crop"], param_path["path_dir_MIP_filt"] = 
            joinpath.(path_root_process, [dir_unet_data, dir_nrrd, dir_MIP, dir_nrrd_shearcorrect, dir_MIP_shearcorrect, dir_nrrd_crop, dir_nrrd_filt, dir_MIP_crop, dir_MIP_filt])

    param_path["path_dir_transformed"], param_path["path_dir_transformed_activity_marker"], param_path["path_dir_activity_signal"] =
        joinpath.(path_root_process, [dir_transformed, dir_transformed_activity_marker, dir_activity_signal])

    param_path["path_dir_roi"], param_path["path_dir_roi_watershed"], param_path["path_dir_roi_watershed_uncropped"], 
        param_path["path_dir_marker_signal"], param_path["path_dir_centroid"], param_path["path_dir_unet_data"] =
            joinpath.(path_root_process, [dir_roi, dir_roi_watershed, dir_roi_watershed_uncropped, dir_marker_signal,
                dir_centroid, dir_unet_data])

    param_path["path_dir_worm_curve"], param_path["path_dir_reg"], param_path["path_dir_reg_activity_marker"], 
        param_path["path_dir_reg_marker_activity"],
        param_path["path_roi_candidates"], param_path["path_dir_cmd"], param_path["path_dir_cmd_am"], param_path["path_dir_cmd_ma"],
        param_path["path_dir_cmd_array"], param_path["path_data_dict"], 
        param_path["path_param"], param_path["path_param_path"], param_path["path_error_dict"] =
            joinpath.(path_root_process,
                [dir_worm_curve, dir_reg, dir_reg_activity_marker, dir_reg_marker_activity, dir_roi_candidates, dir_cmd, 
                    dir_cmd_am, dir_cmd_ma, dir_cmd_array, name_data_dict, name_param,
                    name_param_path, name_error_dict])

    param_path["path_om_nrrd_filt"], param_path["path_om_nrrd"], param_path["path_om_reg"], 
        param_path["path_om_reg_activity_marker"], param_path["path_om_reg_marker_activity"],
        param_path["path_om_log"] = joinpath.(param_path["path_om_data"],
            [dir_nrrd_filt, dir_nrrd, dir_reg, dir_reg_activity_marker, dir_reg_marker_activity, dir_log])

    param_path["path_om_cmd"], param_path["path_om_cmd_am"], param_path["path_om_cmd_ma"], 
        param_path["path_om_cmd_array"] = joinpath.(param_path["path_om_scripts"], 
            [dir_cmd, dir_cmd_am, dir_cmd_ma, dir_cmd_array])

    param_path["path_head_pos"], param_path["path_elastix_difficulty"], param_path["path_reg_prob"], 
        param_path["path_qdict"], param_path["path_best_reg"], param_path["path_roi_overlap"], 
        param_path["path_roi_activity_diff"], param_path["path_reg_quality"], param_path["path_label_map"], param_path["path_inv_map"] = joinpath.(path_root_process, 
            [name_head_pos, name_elastix_difficulty, name_reg_prob, name_qdict, name_best_reg,
            name_roi_overlap, name_roi_activity_diff, name_reg_quality, name_label_map,
            name_inv_map] .* txt_file_extension)
    

    param_path["path_dir_mask"] = nothing
    param_path["path_om_mask"] = nothing

    param_path["parameter_files"] = [param_path["path_om_euler_param"], param_path["path_om_affine_param"], param_path["path_om_bspline_param"]]
    param_path["parameter_files_activity_marker"] = [param_path["path_om_euler_am_param"]]

    create_dir(param_path["path_root_process"])
    if !(dataset in datasets_freely_moving)
        create_dir(param_path["path_nrrd_avg_ch1toch2"])
        create_dir(param_path["path_dir_reg_neuropal"])
    end

    create_dir(param_path["path_dir_transformed_activity_marker"])
    create_dir(param_path["path_dir_transformed"])
    create_dir(param_path["path_roi_candidates"])
    create_dir(param_path["path_dir_activity_signal"])
    create_dir(param_path["path_dir_marker_signal"])
    create_dir(param_path["path_dir_reg"])
end

If you don't want to keep the default batch size, you can change it here. The main relevant factor is GPU VRAM.

In [None]:
params[dataset_central]["deepreg_batch_size"] = 6
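
As a rough guide when choosing the batch size, the sketch below estimates the memory taken by the two input arrays that are allocated for each batch (the fixed and moving images, which `zeros` creates as `Float64`). The crop size is a hypothetical placeholder rather than a value from this notebook, and `input_batch_bytes` is an illustrative helper. The network's activations on the GPU will need considerably more memory than the inputs alone, so treat the result as a lower bound.

```julia
# Hypothetical crop size (x, y, z, channels); substitute params[dataset]["crop_size"].
crop_size_example = (284, 120, 64, 4)

# Bytes needed for the fixed and moving input arrays of one batch.
# `zeros` allocates Float64 (8 bytes per element) by default.
function input_batch_bytes(batch_size, crop_size; bytes_per_element=8)
    return 2 * batch_size * prod(crop_size) * bytes_per_element
end

for batch_size in (2, 4, 6, 8)
    gib = input_batch_bytes(batch_size, crop_size_example) / 2^30
    println("batch_size = $batch_size: about $(round(gib, digits=2)) GiB for inputs")
end
```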

In [None]:
## This cell overwrites the previously saved parameter dictionaries; skip it to keep them.
for dataset in keys(datasets)
    if "get_basename" in keys(param_paths[dataset])
        delete!(param_paths[dataset], "get_basename")
    end

    param_path = param_paths[dataset]
    param = params[dataset]
    JLD2.@save(param_path["path_param_path"], param_path)
    add_get_basename!(param_path)
    JLD2.@save(param_path["path_param"], param)
end

Set `param["t_range"]` to a list of the datasets (identified numerically, by index) that you want to analyze. We use all available datasets here, including the one with incorrect human labels, because the unsupervised algorithm never uses those labels.

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        path_root_process = path_root_process_freelymoving
    else
        path_root_process = joinpath(path_root_process_immobilized, datasets[dataset])
    end
    path_param_path = joinpath(path_root_process, "param_path.jld2")
    
    if isfile(path_param_path)
        f = JLD2.jldopen(path_param_path)
        param_paths[dataset] = f["param_path"]
        close(f)
    end
    param_path = param_paths[dataset]

    change_rootpath!(param_path, path_root_process)

    if isfile(param_path["path_param"])
        f = JLD2.jldopen(param_path["path_param"])
        params[dataset] = f["param"]
        close(f)
    end
    
    param = params[dataset]

    add_get_basename!(param_path)
    
    if isfile(param_path["path_data_dict"])
        f = JLD2.jldopen(param_path["path_data_dict"])
        data_dicts[dataset] = f["data_dict"]
        close(f)
    else
        data_dicts[dataset] = Dict()
    end

    data_dict = data_dicts[dataset]

    param["t_range"] = [t for t in 1:102] # change this line of code as needed

    error_dicts[dataset] = Dict()
end

for dataset in datasets_register
    path = joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2")
    if isfile(path)
        f = JLD2.jldopen(path)
        data_dicts["$(dataset)_to_central"] = f["data_dict"]
        close(f)
    else
        data_dicts["$(dataset)_to_central"] = Dict()
    end
    error_dicts["$(dataset)_to_central"] = Dict()
end

In [None]:
# let
#     f = JLD2.jldopen("/scratch/adam/multicolor_deepreg_test_3/euler_parameters.jld2")
#     data_dicts[dataset_central]["euler_parameters"] = f["euler_parameters"]
#     close(f)
# end

### Specify which animals to use

Each dataset corresponds to one animal. If you want to keep track of which datasets were used in training/validation/testing of this notebook's version of CellDiscoveryNet, set `datasets_train`, `datasets_val`, and `datasets_test` here appropriately.

`datasets_` should contain all animals you wish to analyze (train, val, and test).

In [None]:
datasets_prj_neuropal = ["2022-07-15-06", "2022-07-15-12", "2022-07-20-01", "2022-07-26-01", "2022-08-02-01", "2023-01-23-08", "2023-01-23-15", "2023-01-23-21", "2023-01-19-08", "2023-01-19-22", "2023-01-09-28", "2023-01-17-01", "2023-01-19-15", "2023-01-23-01", "2023-03-07-01", "2022-12-21-06", "2023-01-05-18", "2023-01-06-01", "2023-01-06-08", "2023-01-09-08", "2023-01-09-15", "2023-01-09-22", "2023-01-10-07", "2023-01-10-14", "2023-01-13-07", "2023-01-16-01", "2023-01-16-08", "2023-01-16-15", "2023-01-16-22", "2023-01-17-07", "2023-01-17-14", "2023-01-18-01"]
datasets_prj_rim = ["2023-06-09-01", "2023-07-28-04", "2023-06-24-02", "2023-07-07-11", "2023-08-07-01", "2023-06-24-11", "2023-07-07-18", "2023-08-18-11", "2023-06-24-28", "2023-07-11-02", "2023-08-22-08", "2023-07-12-01", "2023-07-01-09", "2023-07-13-01", "2023-06-09-10", "2023-07-07-01", "2023-08-07-16", "2023-08-22-01", "2023-08-23-23", "2023-08-25-02", "2023-09-15-01", "2023-09-15-08", "2023-08-18-18", "2023-08-19-01", "2023-08-23-09", "2023-08-25-09", "2023-09-01-01", "2023-08-31-03", "2023-07-01-01", "2023-07-01-23"]

datasets_prj_aversion = ["2023-03-30-01", "2023-06-29-01", "2023-06-29-13", "2023-07-14-08", "2023-07-14-14", "2023-07-27-01", "2023-08-08-07", "2023-08-14-01", "2023-08-16-01", "2023-08-21-01", "2023-09-07-01", "2023-09-14-01", "2023-08-15-01", "2023-10-05-01", "2023-06-23-08", "2023-12-11-01", "2023-06-21-01"]
datasets_prj_5ht = ["2022-07-26-31", "2022-07-26-38", "2022-07-27-31", "2022-07-27-38", "2022-07-27-45", "2022-08-02-31", "2022-08-02-38", "2022-08-03-31"]
datasets_prj_starvation = ["2023-05-25-08", "2023-05-26-08", "2023-06-05-10", "2023-06-05-17", "2023-07-24-27", "2023-09-27-14", "2023-05-25-01", "2023-05-26-01", "2023-07-24-12", "2023-07-24-20", "2023-09-12-01", "2023-09-19-01", "2023-09-29-19", "2023-10-09-01", "2023-09-13-02"]

# append all datasets together
datasets_ = []
append!(datasets_, datasets_prj_neuropal)
append!(datasets_, datasets_prj_rim)
append!(datasets_, datasets_prj_aversion)
append!(datasets_, datasets_prj_5ht)
append!(datasets_, datasets_prj_starvation)

datasets_val = ["2023-06-24-02", "2023-08-07-01", "2023-08-19-01", # RIM datasets
                "2022-07-26-01", "2023-01-23-21", "2023-01-23-01", # NeuroPAL datasets
                "2023-07-14-08", # Aversion datasets
                "2022-08-02-31", # 5-HT datasets
                "2023-07-24-27", "2023-07-24-20"] # Starvation datasets
datasets_test = ["2023-08-22-01", "2023-07-07-18", "2023-07-01-23",  # RIM datasets
                 "2023-01-06-01", "2023-01-10-07", "2023-01-17-07", # NeuroPAL datasets
                 "2023-08-21-01", "2023-06-23-08", # Aversion datasets
                 "2022-07-27-38", # 5-HT datasets
                 "2023-10-09-01", "2023-09-13-02" # Starvation datasets
                 ]
datasets_train = [dataset for dataset in datasets_ if !(dataset in datasets_val) && !(dataset in datasets_test)]
datasets_ = deepcopy(datasets_train)
append!(datasets_, datasets_val)
append!(datasets_, datasets_test);
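
A quick sanity check on the split can catch a dataset that was accidentally listed twice or left out. The sketch below uses hypothetical toy names and an illustrative helper, `splits_are_consistent`, which is not part of any lab package. It also shows that rebuilding the list as train, then val, then test changes the numeric index of each dataset.

```julia
# Toy dataset names (hypothetical); the real lists are defined in the cell above.
all_example   = ["a-01", "a-02", "a-03", "a-04", "a-05"]
val_example   = ["a-02"]
test_example  = ["a-05"]
train_example = [d for d in all_example if !(d in val_example) && !(d in test_example)]

# Returns true when the three splits are pairwise disjoint and together cover `all_sets`.
function splits_are_consistent(all_sets, train, val, test)
    disjoint = isempty(intersect(train, val)) &&
               isempty(intersect(train, test)) &&
               isempty(intersect(val, test))
    covers = Set(vcat(train, val, test)) == Set(all_sets)
    no_duplicates = length(train) + length(val) + length(test) == length(unique(all_sets))
    return disjoint && covers && no_duplicates
end

@assert splits_are_consistent(all_example, train_example, val_example, test_example)

# After reordering, "a-02" moves from position 2 to position 4.
reordered_example = vcat(train_example, val_example, test_example)
@assert findfirst(==("a-02"), reordered_example) == 4
```

Any numeric dataset index used later refers to the position in the reordered list, not in the original concatenation.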

You can also specify which side each worm was lying on here. Note that you should have already rotated the worms to lie on the same side before running this notebook, but there are still optical differences between the two sides. Loading the values here allows certain diagnostic code blocks to check whether there is a problem registering one orientation to the other.

In [None]:
θh_pos_is_ventral = Dict(
    "2023-06-09-01"=> true,
    "2023-06-24-02"=> false,
    "2023-06-24-28"=> true,
    "2023-07-01-01"=> true,
    "2023-07-01-09"=> false,
    "2023-07-07-01"=> false,
    "2023-07-07-18"=> true,
    "2023-07-11-02"=> false,
    "2023-07-28-04"=> true,
    "2023-07-07-11"=> false,
    "2023-07-12-01"=> true,
    "2023-08-07-01"=> false,
    "2023-08-22-08"=> true,
    "2023-08-18-11"=> false,
    "2023-06-24-11"=> true,
    "2023-07-13-01"=> false,
    "2023-08-07-16"=> false,
    "2023-06-09-10"=> true,
    "2023-08-22-01"=> false,
    "2023-08-23-23"=> false,
    "2023-08-25-02"=> true,
    "2023-09-15-01"=> true,
    "2023-09-15-08"=> true,
    "2023-08-18-18"=> false,
    "2023-08-19-01"=> true,
    "2023-08-23-09"=> true,
    "2023-09-02-10"=> true,
    "2023-08-25-09"=> false,
    "2023-09-01-01"=> true,
    "2023-08-31-03"=> false,
    "2023-07-01-23"=> false,
    "2021-05-26-07"=> true,
    "2021-06-11-01"=> true,
    "2021-08-04-06"=> false,
    "2021-08-17-01"=> true,
    "2021-08-18-01"=> true,
    "2021-09-06-09"=> true,
    "2021-09-14-01"=> true,
    "2021-09-14-05"=> false,
    "2021-09-22-05"=> true,
    "2021-09-23-01"=> true,
    "2021-09-30-01"=> false,
    "2021-10-26-01"=> false,
    "2021-11-12-01"=> true,
    "2021-11-12-05"=> false,
    "2022-01-07-03"=> true, # NOT ACTUALLY COMPUTED
    "2022-01-09-01"=> false,
    "2022-01-17-01"=> false,
    "2022-01-23-01"=> true,
    "2022-01-26-01"=> true,
    "2022-01-27-01"=> false,
    "2022-01-27-04"=> true,
    "2022-02-08-01"=> true,
    "2022-02-08-04"=> false,
    "2022-02-16-01"=> false,
    "2022-02-16-04"=> true,
    "2022-03-15-04"=> true,
    "2022-03-16-01"=> true, # NOT ACTUALLY COMPUTED
    "2022-03-16-02"=> true, # NOT ACTUALLY COMPUTED
    "2022-03-22-01"=> true,
    "2022-04-05-01"=> true,
    "2022-04-12-04"=> true,
    "2022-04-14-04"=> true,
    "2022-04-18-04"=> false,
    "2022-05-17-01"=> false,
    "2022-05-17-06"=> false,
    "2022-05-25-02"=> false,
    "2022-06-14-01"=> true,
    "2022-06-14-07"=> true,
    "2022-06-14-13"=> true,
    "2022-06-28-01"=> true,
    "2022-06-28-07"=> true,
    "2022-07-15-06"=> true,
    "2022-07-15-12"=> true,
    "2022-07-20-01"=> true,
    "2022-07-26-01"=> true,
    "2022-07-29-08"=> true,
    "2022-08-02-01"=> true,
    "2022-12-21-06"=> true,
    "2023-01-05-01"=> true,
    "2023-01-05-18"=> true,
    "2023-01-06-01"=> true,
    "2023-01-06-08"=> true,
    "2023-01-06-15"=> true,
    "2023-01-09-08"=> true,
    "2023-01-09-15"=> true,
    "2023-01-09-22"=> true,
    "2023-01-09-28"=> true,
    "2023-01-10-07"=> true,
    "2023-01-10-14"=> true,
    "2023-01-13-07"=> true,
    "2023-01-16-01"=> true,
    "2023-01-16-08"=> true,
    "2023-01-16-15"=> true,
    "2023-01-16-22"=> true,
    "2023-01-17-01"=> true,
    "2023-01-17-07"=> true,
    "2023-01-17-14"=> true,
    "2023-01-18-01"=> true,
    "2023-01-19-01"=> false,
    "2023-01-19-08"=> true,
    "2023-01-19-15"=> false,
    "2023-01-19-22"=> true,
    "2023-01-23-01"=> true,
    "2023-01-23-08"=> true,
    "2023-01-23-15"=> true,
    "2023-01-23-21"=> true,
    "2023-03-07-01"=> true,
    "2022-07-26-31"=> true,
    "2022-07-26-38"=> true,
    "2022-07-27-31"=> true,
    "2022-07-27-38"=> true,
    "2022-07-27-45"=> true,
    "2022-08-02-31"=> true,
    "2022-08-02-38"=> true,
    "2022-08-03-31"=> false,
    "2023-03-30-01"=> true,
    "2023-06-21-01"=> false,
    "2023-06-23-08"=> true,
    "2023-06-29-01"=> false,
    "2023-06-29-13"=> true,
    "2023-07-14-08"=> true,
    "2023-07-14-14"=> false,
    "2023-07-27-01"=> true,
    "2023-07-27-08"=> true,
    "2023-08-08-07"=> false,
    "2023-08-14-01"=> true,
    "2023-08-15-01"=> false,
    "2023-08-16-01"=> true,
    "2023-08-21-01"=> true,
    "2023-09-07-01"=> false,
    "2023-09-14-01"=> true,
    "2023-09-25-01"=> true,
    "2023-10-05-01"=> false,
    "2023-12-11-01"=> true,
    "2023-05-25-08"=> false,
    "2023-05-26-08"=> false,
    "2023-06-05-10"=> true,
    "2023-06-05-17"=> false,
    "2023-07-24-27"=> true,
    "2023-09-27-14"=> false,
    "2023-05-25-01"=> false,
    "2023-05-26-01"=> false,
    "2023-05-30-14"=> false,
    "2023-07-24-12"=> false,
    "2023-07-24-20"=> false,
    "2023-09-12-01"=> false,
    "2023-09-19-01"=> false,
    "2023-09-29-19"=> false,
    "2023-10-09-01"=> false,
    "2023-10-09-07"=> false,
    "2023-09-13-02"=> false
);
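
Because the orientation dictionary is maintained by hand, it can help to confirm that every dataset you plan to analyze has an entry. A minimal sketch with hypothetical toy inputs standing in for `θh_pos_is_ventral` and `datasets_`:

```julia
# Hypothetical toy inputs standing in for `θh_pos_is_ventral` and `datasets_`.
orientation_example = Dict("a-01" => true, "a-02" => false, "a-03" => true)
datasets_example = ["a-01", "a-02", "a-04"]

# Datasets that have no recorded orientation.
missing_orientation = [d for d in datasets_example if !haskey(orientation_example, d)]

# Count the analyzed datasets by recorded value.
n_true = count(d -> get(orientation_example, d, false), datasets_example)
n_false = count(d -> haskey(orientation_example, d) && !orientation_example[d], datasets_example)

println("missing: ", missing_orientation)
println("true: $n_true, false: $n_false")
```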

In [None]:
root_path = root_path_data;

### Set up registration graph

We define the registration graph to be the complete graph: every pair of datasets `(i, j)` with `i < j` is one registration problem.

In [None]:
data_dicts[dataset_central]["registration_problems"] = []
for i in 1:length(datasets_)
    for j in i+1:length(datasets_)
        push!(data_dicts[dataset_central]["registration_problems"], (i,j))
    end
end
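
The number of registration problems in a complete graph grows quadratically with the number of datasets, which determines the runtime and disk usage of the following sections. The sketch below repeats the pair construction in a standalone helper (`complete_graph_problems` is an illustrative name) and checks the count against $n(n-1)/2$.

```julia
# Build all unordered pairs (i, j) with i < j for n items, as in the cell above.
function complete_graph_problems(n)
    problems = Tuple{Int,Int}[]
    for i in 1:n
        for j in i+1:n
            push!(problems, (i, j))
        end
    end
    return problems
end

n_example = 102  # number of datasets listed in this notebook
problems_example = complete_graph_problems(n_example)
@assert length(problems_example) == n_example * (n_example - 1) ÷ 2  # 5151 pairs
```

In each pair the first index is treated as the moving image and the second as the fixed image, so the lower-indexed dataset is always the one being warped.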

### Run CellDiscoveryNet

In [None]:
new_process = addprocs(1)[1]

In [None]:
@everywhere workers() begin
    # ENV["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
    using ImageDataIO
    using PyCall
    using NRRDIO
    using Statistics
    using ImageRegistration
    using FlavellBase
    using ProgressMeter
    using MultivariateStats
    using HDF5
end

In [None]:
@eval @everywhere param_paths=$param_paths
@eval @everywhere datasets=$datasets
@eval @everywhere registration_problems=$(Dict(dataset => data_dicts[dataset]["registration_problems"] for dataset in datasets_freely_moving))
@eval @everywhere params=$params
@eval @everywhere datasets_freely_moving=$datasets_freely_moving

In [None]:
result = @spawnat new_process begin
    predict = pyimport("deepreg.predict")
    overwrite_existing = false # set this to `true` to overwrite existing registrations; otherwise the code will skip over them and focus on incomplete registrations
    model = nothing
    nrrd_spacing = nothing
    for dataset in datasets_freely_moving
        batch_size = params[dataset]["deepreg_batch_size"]
        for i in 1:batch_size:length(registration_problems[dataset])
            batch = registration_problems[dataset][i:min(i+batch_size-1, length(registration_problems[dataset]))]
            if isfile(joinpath(param_paths[dataset]["path_dir_reg"], "$(batch[end][1])to$(batch[end][2])", "predicted_fixed_image.nrrd")) && !overwrite_existing
                continue
            end
            fixed_images = zeros((batch_size, params[dataset]["crop_size"]...))
            moving_images = zeros((batch_size, params[dataset]["crop_size"]...))
            

            for j in 1:length(batch)
                moving, fixed = batch[j]

                h5open(joinpath(root_path, "euler_tfm_moving/$(moving)_$(fixed).h5"), "r") do f
                    moving_images[j,:,:,:,:] .= permutedims(read(f["raw"]), (4,3,2,1)) # need to permute dims since julia and python read images in different dimensions
                end

                h5open(joinpath(root_path, "img_fixed/$(fixed).h5"), "r") do f
                    fixed_images[j,:,:,:,:] .= permutedims(read(f["raw"]), (4,3,2,1)) # need to permute dims since julia and python read images in different dimensions
                end
            end
            
            ddf, pred_fixed_image, model = predict.unwrapped_predict(fixed_images, moving_images, "/dev/null", params[dataset]["deepreg_label_size"], params[dataset]["deepreg_label_size"], model, 
                    param_paths[dataset]["path_deepreg_weights"], joinpath(root_path, "config_batch.yaml"))
            for (j, k) in enumerate(batch)
                ddf_batch = ddf[j,:,:,:,:]
                predicted_fixed_image_batch = pred_fixed_image[j,:,:,:,:]
                
                save_dir_problem = joinpath(param_paths[dataset]["path_dir_reg"], "$(k[1])to$(k[2])")

                create_dir(save_dir_problem)
            
                h5open(joinpath(save_dir_problem, "ddf.h5"), "w") do file
                    write(file, "ddf", ddf_batch)
                end
            
                write_nrrd(joinpath(save_dir_problem, "predicted_fixed_image.nrrd"), Float32.(predicted_fixed_image_batch), (0.54, 0.54, 0.54))
            end
        end
    end
    nothing
end
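
The `permutedims(..., (4,3,2,1))` calls in the cell above reverse the axis order of each array read from HDF5. Julia arrays are column-major, so a dataset written from Python (row-major) arrives with its axes reversed. A minimal sketch with a small in-memory array, assuming the files were written from Python with axes ordered (x, y, z, channel); the sizes are hypothetical.

```julia
# Shape as seen from Julia after reading a Python-written (x, y, z, channel) array:
# the axes arrive reversed, i.e. (channel, z, y, x).
raw_example = reshape(collect(1:2*3*4*5), (2, 3, 4, 5))  # (channel=2, z=3, y=4, x=5)

# Reverse the axis order to recover (x, y, z, channel).
img_example = permutedims(raw_example, (4, 3, 2, 1))

@assert size(img_example) == (5, 4, 3, 2)
# Element (c, z, y, x) of the raw array is element (x, y, z, c) of the permuted one.
@assert img_example[5, 4, 3, 2] == raw_example[2, 3, 4, 5]
@assert img_example[1, 2, 3, 1] == raw_example[1, 3, 2, 1]
```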

In [None]:
fetch(result)

In [None]:
rmprocs(new_process)
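
The prediction loop in this section walks through the registration problems in consecutive batches. The sketch below isolates that index arithmetic in a hypothetical helper, `make_batches`, to show that the final batch may be shorter than `batch_size`.

```julia
# Split a list of problems into consecutive batches of at most `batch_size` items,
# using the same index arithmetic as the prediction loop.
function make_batches(problems, batch_size)
    batches = Vector{typeof(problems)}()
    for i in 1:batch_size:length(problems)
        push!(batches, problems[i:min(i + batch_size - 1, length(problems))])
    end
    return batches
end

problems_example = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (1, 5)]
batches_example = make_batches(problems_example, 3)

@assert length.(batches_example) == [3, 3, 1]
@assert batches_example[end] == [(1, 5)]
```

Two consequences follow from the loop as written. The input arrays are always allocated at the full batch size, so unused slots in a short final batch stay zero and their outputs are never saved. The skip check looks only at the last problem of each batch, so changing the batch size between runs moves the batch boundaries and can change which problems are skipped.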

### Transform ROI images

This section transforms the ROI images through the DDFs (dense displacement fields) generated by CellDiscoveryNet.

In [None]:
new_process = addprocs(1)[1]

In [None]:
@everywhere workers() begin
    using ImageDataIO
    using PyCall
    using NRRDIO
    using Statistics
    using ImageRegistration
    using FlavellBase
    using MultivariateStats
    using HDF5
end

In [None]:
@eval @everywhere param_paths=$param_paths
@eval @everywhere datasets=$datasets
@eval @everywhere registration_problems=$(Dict(dataset => data_dicts[dataset]["registration_problems"] for dataset in datasets_freely_moving))
@eval @everywhere params=$params
@eval @everywhere datasets_freely_moving=$datasets_freely_moving

In [None]:
result = @spawnat new_process begin
    tf = pyimport("tensorflow")
    layer = pyimport("deepreg.model.layer")

    memory_dict = Dict()
    error_dicts = Dict()

    for dataset in datasets_freely_moving
        error_dicts[dataset] = Dict()
        for problem in registration_problems[dataset]
            try
                moving, fixed = problem
                euler_transformed_moving_roi_image = nothing

                h5open(joinpath(root_path, "euler_tfm_moving_roi", "$(moving)_to_$(fixed).h5"), "r") do f
                    euler_transformed_moving_roi_image = permutedims(read(f[collect(keys(f))[1]]), (3,2,1))
                end

                ddf = nothing
        
                h5open(joinpath(param_paths[dataset]["path_dir_reg"], "$(moving)to$(fixed)", "ddf.h5"), "r") do f
                    ddf = read(f["ddf"])
                end
        
                warping = layer.Warping(fixed_image_size=params[dataset]["crop_size"], batch_size=1, interpolation="nearest")
                euler_transformed_moving_roi_image_tf = tf.cast(tf.expand_dims(euler_transformed_moving_roi_image, axis=0), tf.float32)
        
                ddf_transformed_roi_image = warping(inputs=[ddf, euler_transformed_moving_roi_image_tf]).numpy()[1,:,:,:]
        
                save_dir = joinpath(param_paths[dataset]["path_dir_transformed"], "$(moving)to$(fixed)")
                create_dir(save_dir)
        
                write_nrrd(joinpath(save_dir, "result.nrrd"), floor.(UInt16, clamp.(ddf_transformed_roi_image, typemin(UInt16), typemax(UInt16))), (0.54, 0.54, 0.54))
            catch e
                error_dicts[dataset][problem] = e
            end
        end
    end
    error_dicts
end


In [None]:
error_dict = fetch(result)

In [None]:
rmprocs(new_process)
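
The ROI images store integer labels, which is why the warping in this section uses `interpolation="nearest"`. Linear interpolation would blend neighboring labels into values that do not correspond to any ROI. The toy 1-D sketch below illustrates the difference; it is not the DeepReg warping layer, and the helper names and the displacement are hypothetical.

```julia
# 1-D label image: 0 is background, 3 and 7 are two ROI labels.
labels_example = [0, 3, 3, 0, 7, 7, 0]

# Sample a 1-D array at a fractional position with nearest-neighbor rounding.
nearest_sample(img, pos) = img[clamp(round(Int, pos), 1, length(img))]

# Sample a 1-D array at a fractional position with linear interpolation.
function linear_sample(img, pos)
    lo = clamp(floor(Int, pos), 1, length(img))
    hi = clamp(lo + 1, 1, length(img))
    w = pos - floor(pos)
    return (1 - w) * img[lo] + w * img[hi]
end

# A hypothetical displacement of +0.4 voxels at every position.
positions = (1:length(labels_example)) .+ 0.4

warped_nearest = [nearest_sample(labels_example, p) for p in positions]
warped_linear = [linear_sample(labels_example, p) for p in positions]

@assert all(v -> v in (0, 3, 7), warped_nearest)    # labels are preserved
@assert any(v -> !(v in (0, 3, 7)), warped_linear)  # blended values are not valid labels
```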

### Compute centroids

Compute centroids of all neuron ROIs and transform the centroids through CellDiscoveryNet's DDFs.

In [None]:
for dataset in datasets_freely_moving
    data_dicts[dataset]["roi_centroids_recropped"] = Dict()
    @showprogress for t in params[dataset]["t_range"]
        img_roi_watershed = nothing

        h5open(joinpath(root_path, "roi_fixed", "$(t).h5"), "r") do f
            img_roi_watershed = permutedims(read(f["roi"]), (3,2,1)) # (x,y,z)
        end

        data_dicts[dataset]["roi_centroids_recropped"][t] = get_centroids_preservenum(img_roi_watershed)
    end

    data_dicts[dataset]["roi_centroids_transformed"] = Dict()
    @showprogress for problem in data_dicts[dataset]["registration_problems"]
        img_roi_watershed = read_img(NRRD(joinpath(param_paths[dataset]["path_dir_transformed"], "$(problem[1])to$(problem[2])", "result.nrrd")))
        data_dicts[dataset]["roi_centroids_transformed"][problem] = get_centroids_preservenum(img_roi_watershed)
    end
    
    data_dicts[dataset]["roi_displacement"] = Dict()
    @showprogress for problem in data_dicts[dataset]["registration_problems"]
        moving, fixed = problem
        data_dicts[dataset]["roi_displacement"][problem] = Dict()
        ddf = Dict()
        h5open(joinpath(param_paths[dataset]["path_dir_reg"], "$(moving)to$(fixed)", "ddf.h5"), "r") do f
            ddf["ddf"] = read(f["ddf"])
        end
        ddf = ddf["ddf"]

        fixed_roi = nothing
        h5open(joinpath(root_path, "roi_fixed", "$(fixed).h5"), "r") do f
            fixed_roi = permutedims(read(f["roi"]), (3,2,1))
        end

        dist_dict = Dict()
        for i in 1:size(data_dicts[dataset]["roi_centroids_recropped"][problem[2]],1)
            x, y, z = Int.(round.(data_dicts[dataset]["roi_centroids_recropped"][problem[2]][i,:]))
            if x > 0
                dist_dict[i] = [norm(ddf[x,y,z,:])]
            end
        end
        for roi_val in keys(dist_dict)
            data_dicts[dataset]["roi_displacement"][problem][roi_val] = mean(dist_dict[roi_val])
        end
    end
    data_dicts[dataset]["centroid_dist_dict"] = compute_centroid_dist_dict(data_dicts[dataset]["roi_centroids_recropped"], data_dicts[dataset]["roi_centroids_transformed"], max_dist=params[dataset]["max_centroid_dist"])
end

In [None]:
for dataset in datasets_freely_moving
    data_dicts[dataset]["roi_centroids_euler_tfm"] = Dict()
    @showprogress for (moving, fixed) in data_dicts[dataset]["registration_problems"]
        img_roi_watershed = nothing

        h5open(joinpath(root_path_data, "euler_tfm_moving_roi", "$(moving)_to_$(fixed).h5"), "r") do f
            img_roi_watershed = permutedims(read(f[collect(keys(f))[1]]), (3,2,1))
        end

        data_dicts[dataset]["roi_centroids_euler_tfm"][(moving, fixed)] = get_centroids_preservenum(Int.(img_roi_watershed))
    end
end

In [None]:
for dataset in datasets_freely_moving
    data_dicts[dataset]["euler_centroid_dist_dict"] = compute_centroid_dist_dict(data_dicts[dataset]["roi_centroids_recropped"], data_dicts[dataset]["roi_centroids_euler_tfm"], max_dist=params[dataset]["max_centroid_dist"])
end
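
For intuition, the sketch below computes ROI centroids on a toy labeled image and then looks up the displacement magnitude at one centroid, mirroring the `norm(ddf[x,y,z,:])` step above. `toy_centroids` is a hypothetical stand-in for `get_centroids_preservenum`, whose exact behavior may differ. The `x > 0` check in the displacement loop suggests that absent labels are marked with a non-positive centroid; the row of zeros used here is an assumption.

```julia
using Statistics
using LinearAlgebra

# Centroid of every ROI in a labeled 3-D image. Row i holds the centroid of the
# ROI with label i; labels with no voxels get a row of zeros (assumed convention).
function toy_centroids(img_roi::AbstractArray{<:Integer,3})
    n_roi = maximum(img_roi)
    centroids = zeros(n_roi, 3)
    for label in 1:n_roi
        voxels = findall(==(label), img_roi)
        isempty(voxels) && continue
        centroids[label, :] = [mean(v[d] for v in voxels) for d in 1:3]
    end
    return centroids
end

img_roi_example = zeros(Int, 4, 4, 2)
img_roi_example[1:2, 1:2, 1] .= 1   # ROI 1
img_roi_example[4, 3:4, 2] .= 3     # ROI 3 (label 2 is absent)

c = toy_centroids(img_roi_example)
@assert c[1, :] == [1.5, 1.5, 1.0]
@assert c[2, :] == [0.0, 0.0, 0.0]
@assert c[3, :] == [4.0, 3.5, 2.0]

# Displacement magnitude at the (rounded) centroid of ROI 3 in a toy DDF.
ddf_example = zeros(4, 4, 2, 3)
ddf_example[4, 4, 2, :] = [3.0, 4.0, 0.0]
x, y, z = round.(Int, c[3, :])
@assert norm(ddf_example[x, y, z, :]) == 5.0
```

Keeping one row per label, even for absent labels, means that row indices can be used directly as ROI identifiers when centroids from different images are compared.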

### Compute nonrigidity of the DDF transformations

The nonrigidity of each DDF is used as a heuristic: registrations with more nonrigid DDFs are weighted less heavily.

In [None]:
new_process = addprocs(1)[1]

In [None]:
@everywhere workers() begin
    using HDF5
    using PyCall
    using ImageDataIO
end

In [None]:
@eval @everywhere param_paths=$param_paths
@eval @everywhere datasets=$datasets
@eval @everywhere registration_problems=$(Dict(dataset => data_dicts[dataset]["registration_problems"] for dataset in datasets_freely_moving))
@eval @everywhere params=$params
@eval @everywhere datasets_freely_moving=$datasets_freely_moving
@eval @everywhere dataset_central=$dataset_central

In [None]:
result = @spawnat new_process begin
    reg_metric_dict = Dict()

    deepreg_loss = pyimport("deepreg.loss.deform")
    nonrigid_penalty = deepreg_loss.NonRigidPenalty(params[dataset_central]["crop_size"])

    for dataset in datasets_freely_moving
        reg_metric_dict[dataset] = Dict()
    
        for problem in registration_problems[dataset]
            file_path = joinpath(param_paths[dataset]["path_dir_reg"], "$(problem[1])to$(problem[2])", "ddf.h5")

            # Skip problems whose DDF file was never written
            if !isfile(file_path)
                @warn "No DDF found for $(problem[1])to$(problem[2])"
                continue
            end

            ddf_ = zeros(params[dataset]["crop_size"][1:3]..., 3)
            h5open(file_path, "r") do file
                ddf_ .= read(file["ddf"])
            end
            ddf_ = reshape(ddf_, (1, size(ddf_)...)) # add a leading batch dimension
        
            reg_metric_dict[dataset][problem] = Dict()
            reg_metric_dict[dataset][problem]["nonrigid_penalty"] = nonrigid_penalty(ddf_).numpy()[1]
        end
    end
    reg_metric_dict
end

In [None]:
fetch(result)

In [None]:
for dataset in datasets_freely_moving
    data_dicts[dataset]["reg_metric_dict"] = fetch(result)[dataset]
end

In [None]:
rmprocs(new_process)
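
To build intuition for what a nonrigidity score measures, the sketch below uses a crude proxy: the sum of squared second differences of the displacement field along each spatial axis. This proxy is swapped in for illustration only and is not claimed to match the `NonRigidPenalty` loss used above. It is zero for any affine (and therefore any rigid) displacement field and grows as the field bends locally. The function name and toy fields are hypothetical.

```julia
# Sum of squared second differences of a displacement field along each spatial axis.
# `ddf` has shape (x, y, z, 3).
function second_difference_energy(ddf::AbstractArray{<:Real,4})
    energy = 0.0
    for axis in 1:3
        n = size(ddf, axis)
        n < 3 && continue
        a = selectdim(ddf, axis, 1:n-2)
        b = selectdim(ddf, axis, 2:n-1)
        c = selectdim(ddf, axis, 3:n)
        energy += sum(abs2, a .- 2 .* b .+ c)
    end
    return energy
end

# A pure translation: every voxel moves by the same vector.
translation = zeros(6, 6, 6, 3)
translation[:, :, :, 1] .= 2.0

# A locally varying field: the x displacement depends quadratically on x.
bumpy = zeros(6, 6, 6, 3)
for x in 1:6
    bumpy[x, :, :, 1] .= 0.1 * x^2
end

@assert second_difference_energy(translation) == 0.0
@assert second_difference_energy(bumpy) > 0.0
```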

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    data_dict["t_range"] = params[dataset]["t_range"]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

for dataset = datasets_register
    data_dict = data_dicts["$(dataset)_to_central"]
    JLD2.@save(joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2"), data_dict)
end

### Make quality dictionary

This code computes an NCC-based (normalized cross-correlation) metric of how well each registration performed.


In [None]:
for dataset in keys(datasets)
    evaluation_functions = Dict()
    param = params[dataset]
    param_path = param_paths[dataset]
    data_dict = data_dicts[dataset]
    error_dict = error_dicts[dataset]
    
    path_fixed = (dataset in datasets_freely_moving) ? param_path["path_dir_nrrd_filt_recropped"] : param_path["path_dir_nrrd"]
    
    function eval_fn(moving, fixed, resolution)
        fixed_img = nothing
        
        h5open(joinpath(root_path, "img_fixed/$(fixed).h5"), "r") do f
            fixed_img = permutedims(read(f["raw"]), (4,3,2,1)) # need to permute dims since julia and python read images in different dimensions
        end
        pred_fixed_img = read_img(NRRD(joinpath(param_path["path_dir_reg"], "$(moving)to$(fixed)", "predicted_fixed_image.nrrd")))
        return calculate_ncc(fixed_img, pred_fixed_img)
    end

    evaluation_functions[param["quality_metric"]] = eval_fn
    data_dict["q_dict"], data_dict["best_reg"], error_dict["q_dict_errors"] = make_quality_dict(param_path, param, data_dict["registration_problems"], evaluation_functions);
end;

In [None]:
for dataset in keys(datasets)
    for k in keys(data_dicts[dataset]["q_dict"])
        for r in keys(data_dicts[dataset]["q_dict"][k])
            if isnan(data_dicts[dataset]["q_dict"][k][r]["NCC"])
                data_dicts[dataset]["q_dict"][k][r]["NCC"] = 0.0
            end
        end
    end
end
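
The NaN replacement above is needed because a normalized cross-correlation is undefined when one of the images has zero variance. The sketch below uses a textbook definition of NCC (the Pearson correlation of voxel values); `toy_ncc` is a hypothetical stand-in, and `calculate_ncc` may differ in detail.

```julia
using Statistics

# Normalized cross-correlation of two equally sized arrays (Pearson correlation of
# their voxel values). Returns NaN when either array has zero variance.
function toy_ncc(a::AbstractArray, b::AbstractArray)
    da = a .- mean(a)
    db = b .- mean(b)
    return sum(da .* db) / sqrt(sum(abs2, da) * sum(abs2, db))
end

fixed_example = Float64[1 2; 3 4]
@assert toy_ncc(fixed_example, 2 .* fixed_example .+ 5) ≈ 1.0   # invariant to gain and offset
@assert toy_ncc(fixed_example, -fixed_example) ≈ -1.0
@assert isnan(toy_ncc(fixed_example, zeros(2, 2)))              # constant image: 0/0
```

Setting NaN scores to 0.0 treats such registrations as uninformative rather than letting NaN propagate into later comparisons.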

Fraction of registration problems whose best NCC exceeds 0.8:


In [None]:
let
    dataset = "freely_moving"
    q_dict = data_dicts[dataset]["q_dict"]
    # Count the problems whose best NCC over all entries exceeds 0.8
    n_good = count(k -> maximum(q_dict[k][r]["NCC"] for r in keys(q_dict[k])) > 0.8, keys(q_dict))
    n_good / length(data_dicts[dataset]["registration_problems"])
end

### Get activity in each channel (color)

In [None]:
img_ = nothing
img_roi_watershed_ = nothing

@showprogress for t in params[dataset_central]["t_range"]
    img = nothing
        
    h5open(joinpath(root_path, "img_fixed", "$(t).h5"), "r") do f
        img = permutedims(read(f["raw"]), (4,3,2,1)) # need to permute dims since julia and python read images in different dimensions
    end

    img_roi_watershed = nothing

    h5open(joinpath(root_path, "roi_fixed", "$(t).h5"), "r") do f
        img_roi_watershed = permutedims(read(f["roi"]), (3,2,1))
    end

    for ch in 1:4
        activity = Float64.(get_activity(img_roi_watershed, img[:,:,:,ch] .* 4095))
        create_dir(joinpath(param_paths[dataset_central]["path_root_process"], "ch$(ch)_signal"))
        write_activity(activity, joinpath(param_paths[dataset_central]["path_root_process"], "ch$(ch)_signal", "$(t).txt"))
    end
end
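
As an illustration of the per-channel extraction above, the sketch below computes the mean intensity inside each ROI of a toy 2-D image. `toy_roi_means` is a hypothetical stand-in; whether `get_activity` uses the mean or another statistic is an assumption. The factor of 4095 in the cell above presumably rescales normalized intensities to a 12-bit range, which is also an assumption.

```julia
using Statistics

# Mean intensity of `img` inside each ROI of `img_roi`. Entry i belongs to label i;
# labels with no voxels get 0.0 (assumed convention).
function toy_roi_means(img_roi::AbstractArray{<:Integer}, img::AbstractArray{<:Real})
    n_roi = maximum(img_roi)
    means = zeros(n_roi)
    for label in 1:n_roi
        mask = img_roi .== label
        any(mask) && (means[label] = mean(img[mask]))
    end
    return means
end

img_roi_example = [1 1 0; 0 2 2]
channel_example = [0.2 0.4 0.9; 0.9 0.5 0.7]   # intensities scaled to [0, 1]

means = toy_roi_means(img_roi_example, channel_example .* 4095)
@assert means ≈ [0.3 * 4095, 0.6 * 4095]
```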


### Compute ROI match matrix

In [None]:
for dataset in datasets_freely_moving
    @time data_dicts[dataset]["roi_overlaps"], data_dicts[dataset]["roi_activity_diff"], error_dicts[dataset]["overlap_errors"] = extract_roi_overlap_deepreg_multicolor(
        data_dicts[dataset]["registration_problems"], param_paths[dataset], params[dataset], joinpath(root_path, "roi_fixed")
    );
end

In [None]:
for dataset in datasets_freely_moving
    @time data_dicts[dataset]["regmap_matrix"], data_dicts[dataset]["label_map"] = make_regmap_matrix_multicolor(
        data_dicts[dataset]["centroid_dist_dict"], data_dicts[dataset]["roi_overlaps"], 
        data_dicts[dataset]["q_dict"], data_dicts[dataset]["best_reg"], data_dicts[dataset]["reg_metric_dict"], 
        data_dicts[dataset]["roi_displacement"], param_paths[dataset], param_paths[dataset], activity_diff_weight=0.0, color_diff_weight=7.0
    );
end

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    data_dict["t_range"] = params[dataset]["t_range"]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

for dataset = datasets_register
    data_dict = data_dicts["$(dataset)_to_central"]
    JLD2.@save(joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2"), data_dict)
end

### Load human and AutoCellLabeler labels

Up to this point, neither ANTSUN 2U nor CellDiscoveryNet has used any human labels. The labels are loaded here only to check the accuracy of the computed unsupervised labels.

In [None]:
autolabel = pyimport("autolabel")

data_dicts[dataset_central]["labels"] = []
@showprogress for (t, dataset) in enumerate(datasets_)
    push!(data_dicts[dataset_central]["labels"], autolabel.map_roi_to_neuron(joinpath("/data3/adam/new_unet_train/csv_paper_2", dataset * " Neuron ID.csv"), confidence_threshold=4))
end

In [None]:
data_dicts[dataset_central]["autolabels"] = []
@showprogress for (t, dataset) in enumerate(datasets_)
    push!(data_dicts[dataset_central]["autolabels"], autolabel.map_roi_to_neuron(joinpath("/data3/adam/new_unet_train/csv_extracted_paper_all", dataset * ".csv"), confidence_threshold=3))
end

In [None]:
neuron_ids_nolr = ["I3", "MI", "M4", "RMED", "RMEV", "ALA", "RID", "M1", "I6", "RIS", "RIH", "M5", "I4", "AQR", "RIR", "VB02", "VB01", "I5", "AVL", "VA01", "VD01", "AVG", "DD01", "DB02", "UNKNOWN", "glia", "granule"];
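
The list above names labels that have no left/right counterpart. The accuracy computation later in this notebook uses it to collapse each label to a neuron class by dropping an `-alt` suffix and then the trailing L/R character, unless the label is in this list. A minimal sketch of that normalization, using a hypothetical helper `label_to_class` and a short subset of the list; the handling of uncertain labels containing `?` is omitted.

```julia
# Subset of the unpaired-neuron list defined above, for illustration.
nolr_example = ["RIS", "AVL", "UNKNOWN"]

# Collapse a neuron label to its class: drop an "-alt" suffix, then drop the final
# L/R character unless the neuron is unpaired.
function label_to_class(label::AbstractString, nolr)
    if endswith(label, "-alt")
        label = label[1:end-4]
    end
    if !(label in nolr)
        label = label[1:end-1]
    end
    return label
end

@assert label_to_class("AVAL", nolr_example) == "AVA"
@assert label_to_class("AVAR-alt", nolr_example) == "AVA"
@assert label_to_class("RIS", nolr_example) == "RIS"
```

Comparing classes rather than full labels means that a left/right swap is not counted as a labeling error.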

### Perform clustering with different values of $w_7$

The following code block clusters the registration heuristic matrix with different $w_7$ values, and computes the accuracy of each clustering against both the human labels and the AutoCellLabeler labels.

In [None]:
accuracy_vals = Dict()
accuracy_vals_autolabel = Dict()
accuracy_vals_cumulative = Dict()
accuracy_vals_cumulative_autolabel = Dict()
cutoff_vals = [0.0] # this is a list of negative w_7 values
for i in 1:36
    if i != 12 # skip -1e-9, save it to the end
        push!(cutoff_vals, -1e-12 * 10^(i/4))
    end
end

push!(cutoff_vals, -1e-9) # add back -1e-9
exclude = false
@showprogress for cutoff = cutoff_vals
    if exclude
        data_dicts[dataset_central]["new_label_map"], data_dicts[dataset_central]["inv_map"] = find_neurons(data_dicts[dataset_central]["regmap_matrix_exclude"], data_dicts[dataset_central]["label_map_exclude"], overlap_threshold=0.00, height_threshold=cutoff);
    else
        data_dicts[dataset_central]["new_label_map"], data_dicts[dataset_central]["inv_map"] = find_neurons(data_dicts[dataset_central]["regmap_matrix"], data_dicts[dataset_central]["label_map"], overlap_threshold=0.00, height_threshold=cutoff);
    end

    data_dicts[dataset_central]["inv_map_labels"] = Dict()
    data_dicts[dataset_central]["inv_map_classes"] = Dict()
    data_dicts[dataset_central]["inv_map_all_labels"] = Dict()
    data_dicts[dataset_central]["inv_map_all_classes"] = Dict()

    data_dicts[dataset_central]["autolabel_inv_map_labels"] = Dict()
    data_dicts[dataset_central]["autolabel_inv_map_classes"] = Dict()
    data_dicts[dataset_central]["autolabel_inv_map_all_labels"] = Dict()
    data_dicts[dataset_central]["autolabel_inv_map_all_classes"] = Dict()

    for n in keys(data_dicts[dataset_central]["inv_map"])
        threshold = (length(datasets_) - 1) / 2 + 1
        if exclude
            threshold = (length(datasets_) - length(timepoints_exclude)) / 2 + 1
        end
        if length(data_dicts[dataset_central]["inv_map"][n]) < threshold
            continue
        end

        data_dicts[dataset_central]["inv_map_all_classes"][n] = []
        data_dicts[dataset_central]["inv_map_labels"][n] = Dict()
        data_dicts[dataset_central]["inv_map_classes"][n] = Dict()
        data_dicts[dataset_central]["inv_map_all_labels"][n] = []

        data_dicts[dataset_central]["autolabel_inv_map_all_classes"][n] = []
        data_dicts[dataset_central]["autolabel_inv_map_labels"][n] = Dict()
        data_dicts[dataset_central]["autolabel_inv_map_classes"][n] = Dict()
        data_dicts[dataset_central]["autolabel_inv_map_all_labels"][n] = []

        for t in keys(data_dicts[dataset_central]["inv_map"][n])
            data_dicts[dataset_central]["inv_map_labels"][n][t] = []
            data_dicts[dataset_central]["inv_map_classes"][n][t] = []

            data_dicts[dataset_central]["autolabel_inv_map_labels"][n][t] = []
            data_dicts[dataset_central]["autolabel_inv_map_classes"][n][t] = []
            if t == 63 # incorrectly labeled dataset (labels are mismatched with ROIs)
                continue
            end
            for roi in data_dicts[dataset_central]["inv_map"][n][t]
                found_label = false
                found_autolabel = false
                if roi in keys(data_dicts[dataset_central]["labels"][t][1])
                    append!(data_dicts[dataset_central]["inv_map_all_labels"][n], data_dicts[dataset_central]["labels"][t][1][roi])
                    for label in data_dicts[dataset_central]["labels"][t][1][roi]
                        push!(data_dicts[dataset_central]["inv_map_labels"][n][t], label)
                        if occursin("-alt", label)
                            label = label[1:end-4]
                        end
                        if !(label in neuron_ids_nolr)
                            label = label[1:end-1]
                        end
                        if occursin("?", label)
                            continue
                        end
                        push!(data_dicts[dataset_central]["inv_map_all_classes"][n], label)
                        push!(data_dicts[dataset_central]["inv_map_classes"][n][t], label)
                        found_label = true
                    end
                end

                if !found_label
                    push!(data_dicts[dataset_central]["inv_map_labels"][n][t], "UNKNOWN")
                    push!(data_dicts[dataset_central]["inv_map_classes"][n][t], "UNKNOWN")
                end

                if t == 61
                    continue
                end
                if roi in keys(data_dicts[dataset_central]["autolabels"][t][1])
                    append!(data_dicts[dataset_central]["autolabel_inv_map_all_labels"][n], data_dicts[dataset_central]["autolabels"][t][1][roi])
                    for label in data_dicts[dataset_central]["autolabels"][t][1][roi]
                        push!(data_dicts[dataset_central]["autolabel_inv_map_labels"][n][t], label)
                        if occursin("-alt", label)
                            label = label[1:end-4]
                        end
                        if !(label in neuron_ids_nolr)
                            label = label[1:end-1]
                        end
                        if occursin("?", label)
                            continue
                        end
                        push!(data_dicts[dataset_central]["autolabel_inv_map_all_classes"][n], label)
                        push!(data_dicts[dataset_central]["autolabel_inv_map_classes"][n][t], label)
                        found_autolabel = true
                    end
                end


                if !found_autolabel
                    push!(data_dicts[dataset_central]["autolabel_inv_map_labels"][n][t], "UNKNOWN")
                    push!(data_dicts[dataset_central]["autolabel_inv_map_classes"][n][t], "UNKNOWN")
                end
            end
        end
    end

    accuracy_vals[cutoff] = Dict()
    n_correct = 0
    n_incorrect = 0
    for n in keys(data_dicts[dataset_central]["inv_map_all_classes"])
        accuracy_vals[cutoff][n] = Dict()
        most_common = countmap(data_dicts[dataset_central]["inv_map_all_classes"][n])
        if length(most_common) == 0
            accuracy_vals[cutoff][n] = 0
            continue
        end
        if length(data_dicts[dataset_central]["inv_map_all_classes"][n]) > 2
            n_correct += maximum(values(most_common))
            n_incorrect += length(data_dicts[dataset_central]["inv_map_all_classes"][n]) - maximum(values(most_common))
        elseif cutoff == -1e-9
            println(n)
        end
        accuracy_vals[cutoff][n] = maximum(values(most_common)) / length(data_dicts[dataset_central]["inv_map_all_classes"][n])
    end

    accuracy_vals_cumulative[cutoff] = n_correct / (n_correct + n_incorrect)

    n_correct_autolabel = 0
    n_incorrect_autolabel = 0
    accuracy_vals_autolabel[cutoff] = Dict()
    for n in keys(data_dicts[dataset_central]["autolabel_inv_map_all_classes"])
        accuracy_vals_autolabel[cutoff][n] = Dict()
        most_common = countmap(data_dicts[dataset_central]["autolabel_inv_map_all_classes"][n])
        if length(most_common) == 0
            accuracy_vals_autolabel[cutoff][n] = 0
            continue
        end
        if length(data_dicts[dataset_central]["autolabel_inv_map_all_classes"][n]) > 2
            n_correct_autolabel += maximum(values(most_common))
            n_incorrect_autolabel += length(data_dicts[dataset_central]["autolabel_inv_map_all_classes"][n]) - maximum(values(most_common))
        end
        accuracy_vals_autolabel[cutoff][n] = maximum(values(most_common)) / length(data_dicts[dataset_central]["autolabel_inv_map_all_classes"][n])
    end

    accuracy_vals_cumulative_autolabel[cutoff] = n_correct_autolabel / (n_correct_autolabel + n_incorrect_autolabel)
end
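The label-to-class normalization used inside the loop above can be sketched in isolation. This is a minimal illustration, not the notebook's own helper: the function name `label_to_class` and the contents of `NOLR_EXAMPLE` (a stand-in for `neuron_ids_nolr`) are assumptions made for the example. It applies the same checks in the same order as the loop: strip a trailing `-alt`, strip the final L/R character unless the class has no L/R suffix, then discard anything that still contains `?`.

```julia
# Hypothetical stand-in for `neuron_ids_nolr`: classes that carry no L/R suffix.
const NOLR_EXAMPLE = Set(["AVG", "RID"])

# Map a full label to its class, or return `nothing` if the label is uncertain.
# Assumes ASCII labels, as the character-index slicing in the loop above does.
function label_to_class(label::AbstractString, nolr=NOLR_EXAMPLE)
    if occursin("-alt", label)
        label = label[1:end-4]      # drop the trailing "-alt"
    end
    if !(label in nolr)
        label = label[1:end-1]      # drop the final L/R character
    end
    return occursin("?", label) ? nothing : label
end

label_to_class("AVAL")       # "AVA"
label_to_class("RIML-alt")   # "RIM"
label_to_class("AVG")        # "AVG"
```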
            

Cumulative accuracy at the selected cutoff, $w_7 = 10^{-9}$ (stored under the key `-1e-9`):

In [None]:
accuracy_vals_cumulative[-1e-9]
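As a rough illustration of what this number measures, the sketch below computes the same plurality-based cumulative accuracy on toy data. The names `label_counts`, `cumulative_accuracy`, and `toy_clusters` are hypothetical, and `label_counts` is a plain-Julia substitute for `countmap`. Each linked neuron contributes its plurality count as "correct" and the rest of its labels as "incorrect"; neurons with fewer than three class labels are left out, matching the `> 2` check above.

```julia
# Count occurrences of each label (plain-Julia substitute for `countmap`).
function label_counts(labels)
    counts = Dict{String,Int}()
    for l in labels
        counts[l] = get(counts, l, 0) + 1
    end
    return counts
end

# Fraction of human labels that agree with the plurality label of their cluster.
function cumulative_accuracy(clusters; min_labels=3)
    n_correct = 0
    n_total = 0
    for labels in values(clusters)
        length(labels) < min_labels && continue
        n_correct += maximum(values(label_counts(labels)))
        n_total += length(labels)
    end
    return n_correct / n_total
end

# Toy data: linked neuron ID => class labels collected across animals.
toy_clusters = Dict(
    1 => ["AVA", "AVA", "AVA", "RIM"],  # 3 of 4 agree
    2 => ["AIB", "AIB", "AIB"],         # 3 of 3 agree
    3 => ["SMD", "RMD"],                # too few labels, skipped
)
acc = cumulative_accuracy(toy_clusters)  # 6 / 7
```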


### Plot accuracy vs number of neurons tradeoff curve

In [None]:
med_accuracy = []
mean_accuracy = []
cum_accuracy = []
n_detection = []

selected_point = -1e-9

for cutoff in reverse(sort(collect(keys(accuracy_vals))))
    if length(accuracy_vals[cutoff]) > 5
        push!(med_accuracy, median(values(accuracy_vals[cutoff])))
        push!(mean_accuracy, mean(values(accuracy_vals[cutoff])))
        push!(n_detection, length(values(accuracy_vals[cutoff])))
        push!(cum_accuracy, accuracy_vals_cumulative[cutoff])
    end
    # uncomment the following line of code to print out accuracy values
    # println(cutoff, " ", accuracy_vals_cumulative[cutoff], " ", length(values(accuracy_vals[cutoff])))
end

gr()
Plots.scatter(cum_accuracy, n_detection, label=nothing, size=(210,130), color="#1f77b4")
Plots.scatter!([accuracy_vals_cumulative[selected_point]], [length(values(accuracy_vals[selected_point]))], label=L"w_7=10^{-9}", color="red", legend=:bottomleft)
# Plots.xlabel!("ANTSUN-Unsupervised accuracy")
# Plots.ylabel!("Number of linked neuron IDs")
Plots.ylims!(0,150)
Plots.xlims!(0.9, 1.0001)
Plots.xticks!(0.9:0.02:1.0)
plot!(
    xguidefont = font("DejaVu Sans", pointsize=7),
    yguidefont = font("DejaVu Sans", pointsize=7),
    xtickfont = font("DejaVu Sans", pointsize=7),
    ytickfont = font("DejaVu Sans", pointsize=7),
    grid = false,
    tickdir=:out
)

Plots.savefig("/data3/prj_register/figures/figure_6/accuracy_vs_detection.pdf")
Plots.plot!()
# Plots.ylims!(0,250)

### Compute centroid distance

In [None]:
data_dicts[dataset_central]["inv_labels"] = []

for (t, dataset) in enumerate(datasets_)
    push!(data_dicts[dataset_central]["inv_labels"], Dict())
    for roi in keys(data_dicts[dataset_central]["labels"][t][1])
        for i in 1:length(data_dicts[dataset_central]["labels"][t][1][roi])
            neuron = data_dicts[dataset_central]["labels"][t][1][roi][i]
            if !(neuron in keys(data_dicts[dataset_central]["inv_labels"][t]))
                data_dicts[dataset_central]["inv_labels"][t][neuron] = []
            end
            push!(data_dicts[dataset_central]["inv_labels"][t][neuron], roi)
        end
    end
end
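The inversion performed above (from ROI → labels to label → ROIs) can be sketched on a single toy animal. The function name `invert_labels` and the example data are assumptions for illustration only.

```julia
# Invert a mapping ROI => list of labels into label => list of ROIs.
function invert_labels(roi_to_labels::Dict)
    label_to_rois = Dict{String,Vector{Int}}()
    for (roi, labels) in roi_to_labels
        for label in labels
            # get! creates the empty list the first time a label is seen
            push!(get!(label_to_rois, label, Int[]), roi)
        end
    end
    return label_to_rois
end

# Toy data: two ROIs carry the same label, and one ROI carries two labels.
roi_to_labels = Dict(4 => ["AVAL"], 9 => ["RIML", "RIMR"], 12 => ["AVAL"])
inverse = invert_labels(roi_to_labels)
sort(inverse["AVAL"])  # [4, 12]
```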

In [None]:
data_dicts[dataset_central]["label_centroid_distance"] = []
data_dicts[dataset_central]["label_centroid_distance_highncc"] = []
data_dicts[dataset_central]["euler_label_centroid_distance"] = []
data_dicts[dataset_central]["label_centroid_distance_train"] = []
data_dicts[dataset_central]["label_centroid_distance_valtest"] = []
ncc_dist = Float64[]

for problem in data_dicts[dataset_central]["registration_problems"]
    push!(ncc_dist, data_dicts[dataset_central]["q_dict"][problem][(0,0)]["NCC"])
end

perc_90_ncc = quantile(ncc_dist, 0.9)
@showprogress for problem in data_dicts[dataset_central]["registration_problems"]
    dist_problem = []
    moving, fixed = problem
    if moving == 63 || fixed == 63 # corrupted labels
        continue
    end
    for (label_moving, rois_moving) in data_dicts[dataset_central]["inv_labels"][moving]
        for (label_fixed, rois_fixed) in data_dicts[dataset_central]["inv_labels"][fixed]
            # compute centroid distance only over matching labels
            if label_moving != label_fixed
                continue
            end
            # ignore labels containing the '?' character
            if occursin("?", label_moving) || occursin("?", label_fixed)
                continue
            end

            rois_moving = [r for r in rois_moving if r <= size(data_dicts[dataset_central]["roi_centroids_transformed"][(moving, fixed)],1)]
            rois_fixed = [r for r in rois_fixed if r <= size(data_dicts[dataset_central]["roi_centroids_recropped"][fixed],1)]

            centroid_moving = mean(data_dicts[dataset_central]["roi_centroids_transformed"][(moving, fixed)][rois_moving,:], dims=1)
            centroid_euler_moving = mean(data_dicts[dataset_central]["roi_centroids_euler_tfm"][(moving, fixed)][rois_moving,:], dims=1)
            centroid_fixed = mean(data_dicts[dataset_central]["roi_centroids_recropped"][fixed][rois_fixed,:], dims=1)

            if centroid_moving[1] < 0 || centroid_fixed[1] < 0
                continue
            end

            dist = norm(centroid_moving .- centroid_fixed)
            if isnan(dist)
                continue
            end

            push!(data_dicts[dataset_central]["label_centroid_distance"], dist)
            if problem[1] <= length(datasets_train) && problem[2] <= length(datasets_train)
                push!(data_dicts[dataset_central]["label_centroid_distance_train"], dist)
            else
                push!(data_dicts[dataset_central]["label_centroid_distance_valtest"], dist)
            end
            push!(data_dicts[dataset_central]["euler_label_centroid_distance"], norm(centroid_euler_moving .- centroid_fixed))
            if data_dicts[dataset_central]["q_dict"][problem][(0,0)]["NCC"] > perc_90_ncc
                push!(data_dicts[dataset_central]["label_centroid_distance_highncc"], dist)
            end
            # if norm(centroid_euler_moving .- centroid_fixed) > 200
            #     println(moving, " ", fixed, " ", label_moving, " ", label_fixed, " ", rois_moving, " ", rois_fixed, " ", dist, " ", norm(centroid_euler_moving .- centroid_fixed))
            # end
        end
    end
end
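The per-label distance computed in the loop above reduces to averaging the centroids of the ROIs that share a label in each image and taking the Euclidean norm of the difference. A minimal sketch with toy coordinates follows; the function name `label_centroid_distance` and the arrays are hypothetical. An empty ROI list produces `NaN`, which the loop above filters out with `isnan`.

```julia
using LinearAlgebra: norm
using Statistics: mean

# Distance between the mean centroid of a label's ROIs in the moving image
# and the mean centroid of the same label's ROIs in the fixed image.
# Centroid arrays are (number of ROIs) x 3.
function label_centroid_distance(centroids_moving, rois_moving, centroids_fixed, rois_fixed)
    # drop ROI indices that fall outside the centroid arrays
    rois_moving = [r for r in rois_moving if r <= size(centroids_moving, 1)]
    rois_fixed = [r for r in rois_fixed if r <= size(centroids_fixed, 1)]
    c_moving = mean(centroids_moving[rois_moving, :], dims=1)
    c_fixed = mean(centroids_fixed[rois_fixed, :], dims=1)
    return norm(c_moving .- c_fixed)
end

moving = [0.0 0.0 0.0; 2.0 0.0 0.0; 10.0 10.0 10.0]
fixed = [4.0 4.0 0.0; 9.0 9.0 9.0]
d = label_centroid_distance(moving, [1, 2], fixed, [1])  # mean (1,0,0) vs (4,4,0)
```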



### Compute average number of labels per animal

In [None]:
len_include = length(datasets_) - 1 # subtract 1 for dataset 63 which has incorrect labels and is excluded

sum([length(data_dicts[dataset_central]["inv_map"][x]) for x in keys(data_dicts[dataset_central]["inv_map"]) if length(data_dicts[dataset_central]["inv_map"][x]) > len_include / 2]) / len_include
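To make the expression above easier to read, here is the same calculation on toy data. The function name `mean_links_per_animal` and the dictionary `toy_inv_map` are illustrative assumptions. Only linked neurons detected in more than half of the animals are counted, and the total number of detections is divided by the number of animals.

```julia
# inv_map: linked neuron ID => Dict(animal index => ROIs in that animal)
function mean_links_per_animal(inv_map, n_animals)
    kept = [length(links) for links in values(inv_map) if length(links) > n_animals / 2]
    return sum(kept) / n_animals
end

toy_inv_map = Dict(
    1 => Dict(1 => [3], 2 => [5], 3 => [2], 4 => [8]),  # found in 4 of 4 animals
    2 => Dict(1 => [4], 2 => [6], 3 => [1]),            # found in 3 of 4 animals
    3 => Dict(2 => [9]),                                # found in 1 of 4, dropped
)
avg = mean_links_per_animal(toy_inv_map, 4)  # (4 + 3) / 4
```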

### Plot NCC distribution

In [None]:
let
    ncc_dist = Float64[]
    train_ncc_dist = Float64[]
    val_test_ncc_dist = Float64[]
    for problem in data_dicts[dataset_central]["registration_problems"]
        push!(ncc_dist, data_dicts[dataset_central]["q_dict"][problem][(0,0)]["NCC"])
        if problem[1] <= length(datasets_train) && problem[2] <= length(datasets_train)
            push!(train_ncc_dist, data_dicts[dataset_central]["q_dict"][problem][(0,0)]["NCC"])
        else
            push!(val_test_ncc_dist, data_dicts[dataset_central]["q_dict"][problem][(0,0)]["NCC"])
        end
    end
    println(length(train_ncc_dist))
    println(mean(ncc_dist))
    println(median(ncc_dist))
    println(percentile(ncc_dist, 90))
    println(mean(train_ncc_dist))
    println(mean(val_test_ncc_dist))

    println(median(ncc_dist))

    gr()
    StatsPlots.violin([ncc_dist, train_ncc_dist, val_test_ncc_dist], label=nothing, size=(210,130),  color="#1f77b4")

    plot!(
        xguidefont = font("DejaVu Sans", pointsize=7),
        yguidefont = font("DejaVu Sans", pointsize=7),
        xtickfont = font("DejaVu Sans", pointsize=7),
        ytickfont = font("DejaVu Sans", pointsize=7),
        grid = false,
        tickdir=:out
    )
    ylims!(0,1)
    # ylabel!("NCC")
    xticks!(1:3, ["All", "Train", "Val+\nTest"])

end

Plots.savefig("/data3/prj_register/figures/figure_6/NCC_violin.pdf")
plot!()



### Plot centroid distance distribution

In [None]:
gr()
StatsPlots.violin([log.(2, x) for x in [#data_dicts[dataset_central]["euler_label_centroid_distance"],
        data_dicts[dataset_central]["label_centroid_distance"], 
        data_dicts[dataset_central]["label_centroid_distance_train"], 
        data_dicts[dataset_central]["label_centroid_distance_valtest"], 
        data_dicts[dataset_central]["label_centroid_distance_highncc"]]],
    label=nothing, size=(210,130), color="#1f77b4", scale=:area)


println(median(data_dicts[dataset_central]["euler_label_centroid_distance"]))
println(median(data_dicts[dataset_central]["label_centroid_distance"]))
println(percentile(data_dicts[dataset_central]["label_centroid_distance"], 20))
println(mean(data_dicts[dataset_central]["label_centroid_distance"]))
println(median(data_dicts[dataset_central]["label_centroid_distance_train"]))
println(median(data_dicts[dataset_central]["label_centroid_distance_valtest"]))
println(median(data_dicts[dataset_central]["label_centroid_distance_highncc"]))
println(percentile(data_dicts[dataset_central]["label_centroid_distance_highncc"], 20))
ylims!(-4,8)
yticks!(-4:2:8, ["1/16", "1/4", "1", "4", "16", "64", "256"])
xticks!(1:4, ["All", "Train", "Val+\nTest", "High\nNCC"])
# ylabel!("Centroid distance")
plot!(
    xguidefont = font("DejaVu Sans", pointsize=7),
    yguidefont = font("DejaVu Sans", pointsize=7),
    xtickfont = font("DejaVu Sans", pointsize=7),
    ytickfont = font("DejaVu Sans", pointsize=7),
    grid = false,
    tickdir=:out
)
Plots.savefig("/data3/prj_register/figures/figure_6/centroid_distance_violin.pdf")
Plots.plot!()



### Plot matrix of labels per animal

Each entry in the matrix is one label detected in one animal. Its color indicates accuracy: green means accurate (matching the plurality label), orange means inaccurate, blue means not labeled by the human ground-truth labeler, and black means not labeled by ANTSUN 2U. In the code below, entries with more than one class in a single animal are assigned the same blue.

Since this analysis is based on human labels, we exclude dataset 63, whose labels are incorrect:

In [None]:
params[dataset_central]["t_range"] = [t for t in 1:102 if t != 63]

In [None]:
data_dicts[dataset_central]["labels_deepreg"] = Dict()
thresh = 3 # number of matching labels needed to give a row label

let
    unknown_new = []
    u_count = 0
    mtx = zeros(length(keys(data_dicts[dataset_central]["inv_map_classes"])), length(params[dataset_central]["t_range"]))
    labels = []
    for (i, n) in enumerate(collect(keys(data_dicts[dataset_central]["inv_map_classes"])))
        most_common = countmap(data_dicts[dataset_central]["inv_map_all_classes"][n])
        if length(most_common) == 0 || maximum(values(most_common)) < thresh
            most_common_label = "UNKNOWN"
        else
            most_common_label = first(sort(collect(most_common), by=x->x[2], rev=true))[1]
        end

        # now disambiguate L/R (mostly for CSV output file purposes, as L/R is not currently used in analysis)
        valid_labels = [most_common_label]
        if !(most_common_label in neuron_ids_nolr)
            valid_labels = [most_common_label * "L", most_common_label * "R"]
        end
        most_common_full = countmap([label for label in data_dicts[dataset_central]["inv_map_all_labels"][n] if label in valid_labels])
        if (length(most_common_full) == 0) || (maximum(values(most_common)) < thresh)
            @assert(most_common_label == "UNKNOWN")
            most_common_full_label = "NEW " * string(u_count + 1)
            u_count += 1
            push!(unknown_new, length([k for k in keys(data_dicts[dataset_central]["inv_map"][n]) if k != 63]))
        else
            most_common_full_label = first(sort(collect(most_common_full), by=x->x[2], rev=true))[1]
        end

        push!(labels, most_common_full_label)


        most_common_class = most_common_full_label
        if !(most_common_full_label in neuron_ids_nolr)
            most_common_class = most_common_full_label[1:end-1]
        end

        if most_common_class != most_common_label
            @warn((n, " ", most_common_label, " ", most_common_full_label))
        end

        for (j, t) in enumerate(params[dataset_central]["t_range"])
            if !(t in keys(data_dicts[dataset_central]["inv_map_classes"][n]))
                mtx[i,j] = 0
            elseif length(data_dicts[dataset_central]["inv_map_classes"][n][t]) == 0
                mtx[i,j] = 1
            elseif length(data_dicts[dataset_central]["inv_map_classes"][n][t]) > 1
                mtx[i,j] = 2
            else
                label = data_dicts[dataset_central]["inv_map_classes"][n][t][1]
                if label == "UNKNOWN" || most_common_label == "UNKNOWN"
                    mtx[i,j] = 3
                elseif label != most_common_label
                    mtx[i,j] = 4
                else
                    mtx[i,j] = 5
                end
            end
            if !(t in keys(data_dicts[dataset_central]["inv_map"][n]))
                continue
            end
            if !(t in keys(data_dicts[dataset_central]["labels_deepreg"]))
                data_dicts[dataset_central]["labels_deepreg"][t] = Dict()
            end
            for roi in data_dicts[dataset_central]["inv_map"][n][t]
                @assert !(roi in keys(data_dicts[dataset_central]["labels_deepreg"][t]))
                data_dicts[dataset_central]["labels_deepreg"][t][roi] = most_common_full_label
            end
        end
    end

    mtx_autolabel = zeros(length(keys(data_dicts[dataset_central]["autolabel_inv_map_classes"])), length(params[dataset_central]["t_range"]))
    labels_autolabel = []
    for (i, n) in enumerate(collect(keys(data_dicts[dataset_central]["autolabel_inv_map_classes"])))
        most_common = countmap(data_dicts[dataset_central]["autolabel_inv_map_all_classes"][n])
        if length(most_common) == 0
            most_common_label = "UNKNOWN"
        else
            most_common_label = first(sort(collect(most_common), by=x->x[2], rev=true))[1]
        end
        push!(labels_autolabel, most_common_label)
        for (j, t) in enumerate(params[dataset_central]["t_range"])
            if !(t in keys(data_dicts[dataset_central]["autolabel_inv_map_classes"][n]))
                mtx_autolabel[i,j] = 0
            elseif length(data_dicts[dataset_central]["autolabel_inv_map_classes"][n][t]) == 0
                mtx_autolabel[i,j] = 1
            elseif length(data_dicts[dataset_central]["autolabel_inv_map_classes"][n][t]) > 1
                mtx_autolabel[i,j] = 2
                println(n, " ", t)
            else
                label = data_dicts[dataset_central]["autolabel_inv_map_classes"][n][t][1]
                if label == "UNKNOWN" || most_common_label == "UNKNOWN"
                    mtx_autolabel[i,j] = 3
                elseif label != most_common_label
                    mtx_autolabel[i,j] = 4
                else
                    mtx_autolabel[i,j] = 5
                end
            end
        end
    end

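    # Fraction of incorrect (4) labels per column, among entries coded 4 or 5.
    # Note: this is computed for reference only; the column order actually used comes from the hierarchical clustering below.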
    incorrect_labels = sum(mtx .== 4, dims=1)[1,:] ./ (sum(mtx .== 4, dims=1)[1,:] .+ sum(mtx .== 5, dims=1)[1,:])

    # Function to compute categorical distance between two vectors
    function categorical_dist(a, b)
        sum(a .!= b)
    end

    # Compute the row distance matrix
    num_rows = size(mtx, 1)
    row_dist = zeros(num_rows, num_rows)
    for i in 1:num_rows
        for j in 1:num_rows
            row_dist[i, j] = categorical_dist(mtx[i, :], mtx[j, :])
        end
    end

    # Compute the column distance matrix
    num_cols = size(mtx, 2)
    col_dist = zeros(num_cols, num_cols)
    for i in 1:num_cols
        for j in 1:num_cols
            col_dist[i, j] = categorical_dist(mtx[:, i], mtx[:, j])
        end
    end

    # Perform hierarchical clustering
    row_clustering = hclust(row_dist, linkage=:average)
    col_clustering = hclust(col_dist, linkage=:average)

    # Get the order of rows and columns
    row_order = row_clustering.order
    col_order = col_clustering.order

    # Reorder the matrix
    reordered_mtx = mtx[row_order, col_order]
    reordered_mtx_autolabel = mtx_autolabel[row_order, col_order]
    reordered_labels = labels[row_order]
    reordered_labels_autolabel = labels_autolabel[row_order]

    cmap = PyPlot.matplotlib.colors.ListedColormap(["black", "red", "C0", "C0", "C1", "C2"])
    figure(figsize=(10,10))
    imshow(reordered_mtx, cmap=cmap)
    font = Dict("family" => "DejaVu Sans", "size" => 7)

    PyPlot.xlabel("dataset", fontdict=font)

    PyPlot.yticks(0:length(reordered_labels)-1, reordered_labels, fontsize=5, fontname="DejaVu Sans")
    PyPlot.xticks([], [], fontsize=7, fontname="DejaVu Sans")

    PyPlot.ylabel("neuron", fontdict=font)

    create_dir("/data3/prj_register/figures/figure_S6")
    PyPlot.savefig("/data3/prj_register/figures/figure_S6/full_label_mtx.pdf")

    figure(figsize=(10,10))
    imshow(reordered_mtx_autolabel, cmap=cmap)
    font = Dict("family" => "DejaVu Sans", "size" => 7)

    PyPlot.xlabel("animal", fontdict=font)

    PyPlot.yticks(0:length(reordered_labels_autolabel)-1, reordered_labels_autolabel, fontsize=5, fontname="DejaVu Sans")
    PyPlot.xticks([], [], fontsize=7, fontname="DejaVu Sans")

    PyPlot.ylabel("neuron", fontdict=font)

    labels_sorted = sort(collect(labels))
end
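The integer codes written into `mtx` above, and the distance used to cluster its rows and columns, can be summarized in a small standalone sketch. The function name `entry_code` and its use of `nothing` for "neuron not linked in this animal" are assumptions made for the example; `categorical_dist` is the same one-line definition as in the cell above.

```julia
# Code for one (neuron, animal) entry, given the classes found in that animal
# (or `nothing` if the neuron was not linked there) and the row's plurality label.
function entry_code(classes_at_t, plurality)
    classes_at_t === nothing && return 0     # not linked in this animal
    isempty(classes_at_t) && return 1        # linked, but no class list recorded
    length(classes_at_t) > 1 && return 2     # more than one class
    label = classes_at_t[1]
    (label == "UNKNOWN" || plurality == "UNKNOWN") && return 3
    return label == plurality ? 5 : 4        # 5 = matches plurality, 4 = mismatch
end

# Number of positions at which two coded vectors differ.
categorical_dist(a, b) = sum(a .!= b)

row_a = [entry_code(["AVA"], "AVA"), entry_code(["RIM"], "AVA"), entry_code(nothing, "AVA")]
row_b = [entry_code(["AVA"], "AVA"), entry_code(["AVA"], "AVA"), entry_code(nothing, "AVA")]
categorical_dist(row_a, row_b)  # the rows differ only in the second entry
```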

Compute the number of labeled neurons per animal in each of the testing datasets (indices 92 to 102), for comparison with AutoCellLabeler and human relabels.

In [None]:
n_labels = []

for idx=92:102
    push!(n_labels, length(data_dicts[dataset_central]["labels_deepreg"][idx]))
end

n_labels

Helper function to output ANTSUN 2U labels in the same format as AutoCellLabeler outputs, compatible with any post-processing code.

In [None]:
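# CSV.write and DataFrame are used below; load them here in case they were not loaded earlier
using CSV
using DataFrames

function write_csv_files(output_dir::String, labels::Dict, roi_coords::Dict, datasets::Vector{String}, add_qm::Bool=true, neuron_ids_nolr::Vector{String}=neuron_ids_nolr)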
    create_dir(output_dir)
    for (dataset_idx, label_dict) in labels
        dataset_name = datasets[dataset_idx]
        csv_filename = joinpath(output_dir, "$dataset_name Neuron ID.csv")
        
        # Prepare the data
        neuron_class = String[]
        coordinates = String[]
        roi_id = Int[]
        confidence = Int[]
        notes = String[]
        
        for (roi_idx, label) in label_dict
            if add_qm && !(label in neuron_ids_nolr)
                label = label * "?"
            end
            push!(neuron_class, label)
            coord = roi_coords[dataset_idx][roi_idx, :]
            push!(coordinates, "$(coord[1]),$(coord[2]),$(coord[3])")
            push!(roi_id, roi_idx)
            push!(confidence, 5)
            push!(notes, "")
        end
        
        # Create DataFrame with spaces in column names
        df = DataFrame(
            Symbol("Neuron Class") => neuron_class, 
            Symbol("Coordinates") => coordinates, 
            Symbol("ROI ID") => roi_id, 
            Symbol("Confidence") => confidence, 
            Symbol("Notes") => notes
        )
        
        # Write DataFrame to CSV
        CSV.write(csv_filename, df)
    end
end

Assign "UNKNOWN" label to all unlabeled ROIs

In [None]:
data_dicts[dataset_central]["padded_labels_deepreg"] = Dict()

for dataset_idx in keys(data_dicts[dataset_central]["labels_deepreg"])
    data_dicts[dataset_central]["padded_labels_deepreg"][dataset_idx] = Dict()
    for roi in keys(data_dicts[dataset_central]["labels_deepreg"][dataset_idx])
        data_dicts[dataset_central]["padded_labels_deepreg"][dataset_idx][roi] = data_dicts[dataset_central]["labels_deepreg"][dataset_idx][roi]
        if occursin("UNKNOWN", data_dicts[dataset_central]["labels_deepreg"][dataset_idx][roi])
            data_dicts[dataset_central]["padded_labels_deepreg"][dataset_idx][roi] = "UNKNOWN"
        end
    end
    for roi in 1:size(data_dicts[dataset_central]["roi_centroids_recropped"][dataset_idx], 1)
        if !(roi in keys(data_dicts[dataset_central]["padded_labels_deepreg"][dataset_idx]))
            data_dicts[dataset_central]["padded_labels_deepreg"][dataset_idx][roi] = "UNKNOWN"
        end
    end
end
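The padding step above can be sketched for a single animal as follows. The function name `pad_labels` and the toy labels are hypothetical. Every ROI index from 1 to the number of ROIs that lacks a label receives `"UNKNOWN"`, while existing labels are kept.

```julia
# Return a copy of `labels` in which every ROI in 1:n_rois has an entry.
function pad_labels(labels::Dict{Int,String}, n_rois::Int)
    padded = copy(labels)
    for roi in 1:n_rois
        haskey(padded, roi) || (padded[roi] = "UNKNOWN")
    end
    return padded
end

padded = pad_labels(Dict(2 => "AVAL", 5 => "NEW 1"), 5)
# ROIs 1, 3 and 4 are now "UNKNOWN"; ROIs 2 and 5 keep their labels
```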


Write the CSV files:

In [None]:
write_csv_files(joinpath(param_paths[dataset_central]["path_root_process"], "csv_deepreg"), data_dicts[dataset_central]["padded_labels_deepreg"], data_dicts[dataset_central]["roi_centroids_recropped"], datasets_, false)

### Save your data

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    data_dict["t_range"] = params[dataset]["t_range"]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

for dataset = datasets_register
    data_dict = data_dicts["$(dataset)_to_central"]
    JLD2.@save(joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2"), data_dict)
end