# ***Automatic Neuron Tracking System for Unconstrained Nematodes (ANTSUN) v1.0.0***

This pipeline takes as input the `.nd2` files from the confocal microscope for the freely-moving and immobilized (NeuroPAL) recordings, together with the `.h5` file from the NIR tracking camera. It processes the confocal data to locate and track neurons in the worm's head, generates GCaMP traces for those neurons, and produces the NeuroPAL RGB image used for labeling. It also processes the NIR data to extract behavioral variables that can be correlated with the calcium traces.


## General notes

### GPU acceleration

Many steps of `ANTSUN` are GPU-accelerated (these steps are labeled as such). Before running them, make sure that your device has sufficient free GPU memory; at least 8GB is recommended. You can check how much GPU memory is available by running `nvidia-smi`.

- On `flv-c1`, there is only one GPU, with 8GB of memory. If anyone else is using it, wait until the GPU is completely free before running any GPU-accelerated `ANTSUN` steps.
- On `flv-c2`, there are two GPUs with 16GB of memory each. By default, `ANTSUN` uses GPU 0 (the top one in `nvidia-smi`), with the exception of the 3D UNet, which uses GPU 1 by default. If GPU 0 is full, you can force `ANTSUN` to use GPU 1 by running the command `device!(1)` in this notebook.
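The free-memory check described above can be scripted. The sketch below is a hypothetical helper, not part of `ANTSUN`: it assumes the output format of `nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits` (one `index, free_MiB` line per GPU) and returns the GPU with the most free memory, provided it has at least 8 GiB free. Like the note above, it assumes the `nvidia-smi` index matches the number passed to CUDA.jl's `device!`.

```julia
# Hypothetical helper (not part of ANTSUN): parse nvidia-smi CSV output of the form
#   "<index>, <free memory in MiB>" (one line per GPU)
# and return the index of the GPU with the most free memory, or `nothing`
# if no GPU has at least `min_free_mib` MiB free.
function pick_free_gpu(smi_output::AbstractString; min_free_mib::Integer = 8 * 1024)
    best_idx, best_free = nothing, -1
    for line in split(strip(smi_output), '\n')
        isempty(strip(line)) && continue
        fields = strip.(split(line, ','))
        idx = parse(Int, fields[1])
        free_mib = parse(Int, fields[2])
        if free_mib >= min_free_mib && free_mib > best_free
            best_idx, best_free = idx, free_mib
        end
    end
    return best_idx
end

# On the server, the string would come from nvidia-smi itself, e.g.
# smi = read(`nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits`, String)
smi = """
0, 2048
1, 15360
"""
gpu = pick_free_gpu(smi)   # 1: only GPU 1 has at least 8 GiB free
# gpu === nothing || device!(gpu)   # select it with CUDA.jl, as described above
```

If the helper returns `nothing`, the safer choice is the one recommended above: wait until a GPU frees up.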

Note also that Julia does not reliably release its GPU memory allocations, so by the end of this notebook a substantial amount of memory will likely remain allocated on the GPU. To free this memory, **please restart the kernel after the notebook is done running.**

### Error recovery

If a step of the pipeline encounters an unexpected error, it is usually advisable to restart the kernel. **WARNING:** Restarting the kernel while a save operation is in progress could corrupt the data files and possibly force you to start over from the very beginning. Thus, before restarting the kernel, check that no `JLD2` files are actively being saved, and restart only once the pipeline has moved on to a step other than saving them.
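One rough way to check for an in-progress save is to look for `.jld2` files that were modified very recently. The function below is a hypothetical heuristic, not part of `ANTSUN`, and uses only Julia's `Base` (`walkdir`, `mtime`, `time`): a file that is still being written keeps getting a fresh modification time, so a non-empty result suggests waiting. It cannot prove that no save is in progress, and the 60-second window is an arbitrary assumption.

```julia
# Hypothetical heuristic (not part of ANTSUN): list `.jld2` files under `dir`
# whose modification time falls within the last `window_s` seconds.
function recently_modified_jld2(dir::AbstractString; window_s::Real = 60)
    now_s = time()
    recent = String[]
    for (root, _, files) in walkdir(dir)
        for f in files
            endswith(f, ".jld2") || continue
            path = joinpath(root, f)
            if now_s - mtime(path) < window_s
                push!(recent, path)
            end
        end
    end
    return recent
end

# Example: check the freely-moving output directory before restarting the kernel
# recent = recently_modified_jld2(path_root_process_freelymoving)
# isempty(recent) || @warn "JLD2 files were written recently; a save may be in progress" recent
```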

After you restart the kernel, rerun the "Load packages", "Data parameters" and "Initialize data dictionaries" steps of the pipeline. This should reset the pipeline to its last usable state. If the error was due to running out of memory or some other hardware issue, you should then be able to run the pipeline from the step that generated the error through to the end, provided sufficient computing resources are available.

If the error was due to you making a mistake in the "Data parameters" section of the pipeline, you may need to rerun the "Pipeline parameters" and "Save parameter settings" sections of the notebook after running the "Initialize data dictionaries" step, but note that you should **not** rerun the "Read meta-parameters from the ND2 file" step.

## Load packages

In [None]:
# Flavell lab packages
using ND2Process
using GPUFilter
using WormFeatureDetector
using NRRDIO
using FlavellBase
using SegmentationTools
using ImageDataIO
using RegistrationGraph
using ExtractRegisteredData
using CaAnalysis
using BehaviorDataNIR
using UNet2D


# Other packages
using ProgressMeter
using PyCall
using PyPlot
using Statistics
using StatsBase
using DelimitedFiles
using Images
using Cairo
using Distributions
using DataStructures
using HDF5
using Interact
using WebIO
using Plots
# using GraphPlot
# using LightGraphs
# using SimpleWeightedGraphs
using Dates
using JLD2
using TotalVariation
using VideoIO
using Distributions
using MultivariateStats
using FFTW
using LinearAlgebra
using GLMNet
using InformationMeasures
using CUDA
using LsqFit
using Optim
using Rotations
using CoordinateTransformations
using ImageTransformations
using Interpolations
using H5Zblosc

## Parameter settings

Expected runtime: 1 minute

### Data parameters

This section includes dataset-specific parameters, including the filenames, voxel spacing, laser intensity step-up, and channel identities. It must be run every time you initialize the pipeline.

**Important:** If you already have dictionaries saved and are rerunning the pipeline, skip to "Initialize data dictionaries" after running this section, and do not rerun the "Pipeline parameters" section unless you have modified something here.
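The laser step-up convention behind `max_graph_num`, `blue_laser` and `green_laser` in the cell below can be summarized as a one-line rule. The helper is hypothetical (not part of `ANTSUN`) and assumes, as the code comment states, that `max_graph_num` is the last time point acquired with the first laser value.

```julia
# Hypothetical helper (not part of ANTSUN): the laser value in effect at confocal
# time point `t`, assuming `max_graph_num` is the last time point acquired with
# the first value in `laser` ([before, after]).
laser_at(t::Integer, max_graph_num::Integer, laser::AbstractVector) =
    t <= max_graph_num ? laser[1] : laser[2]

# With the default values in the cell below (max_graph_num = 800, blue_laser = [13, 15]):
laser_at(800, 800, [13, 15])    # 13: the last time point before the step-up
laser_at(801, 800, [13, 15])    # 15
```

If the laser is never changed, setting `max_graph_num` to the length of the dataset means every time point uses the first value.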

In [None]:
## ND2 parameters
ch_bluegreen = 1 # GCaMP and BFP
ch_red = 2 # mNeptune
n_rec = 1 # number of Illumination Sequence module recordings (.nd2) in the HDF5 file 
h5_confocal_time_lag = 0 # if this is the second confocal dataset using the same h5 file, set equal to number of frames in 1st dataset
spacing_lat = 0.54 # Voxel spacing in the xy plane, in µm. If different, it will likely be necessary to modify pipeline parameters.
spacing_axi = 0.54 # Voxel spacing along the z axis, in µm. If different, it will likely be necessary to modify pipeline parameters.

## Server parameter
flv_c = 2 # version of flv-c server you are using

## Freely-moving parameters
max_graph_num = 800 # time point just before laser intensity is changed in freely-moving dataset. If not changed, set to length of the dataset.
blue_laser = [13,15] # array of blue laser values before vs after param["max_graph_num"] in the freely-moving dataset
green_laser = [15,17] # array of green laser values before vs after param["max_graph_num"] in the freely-moving dataset


## Data path parameters
prj = "prj_XYZ" # change to your project, i.e., the location to save the data
path_raw_data = "/data1/$(prj)/data_raw/2023-DATASET" # change to the directory with all raw data files

datasets = Dict() # for each filter, set the corresponding entry to the name of that dataset's ND2 file in path_raw_data (without the .nd2 extension)
datasets["freely_moving"] = "2023-DATASET"
datasets["all_red"] = "2023-DATASET"
datasets["OFP"] = "2023-DATASET"
datasets["mNeptune"] = "2023-DATASET"
datasets["mNeptune_gcamp"] = "2023-DATASET" # dataset with GCaMP bleedthrough
datasets["BFP"] = "2023-DATASET"


## You should not need to change anything below this line except possibly `reg_timepts[dataset] = 30`.
datasets_freely_moving = ["freely_moving"] # freely-moving datasets
dataset_central = "freely_moving" # central dataset - register other datasets to this. Must contain camera-alignment registration
datasets_register = ["all_red"] # datasets to register back to central dataset
dataset_immobilized_central = "all_red" # dataset to register immobilized datasets to, for NeuroPAL
@assert(dataset_immobilized_central in datasets_register, "Immobilized dataset must be registered to freely moving")

n_timepts_merge = Dict() # number of registrations to attempt back to the central dataset, for each registered dataset
for dataset in datasets_register
    n_timepts_merge[dataset] = 200
end

## Immobilized parameters
reg_timepts = Dict()
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    reg_timepts[dataset] = 30 # timepoint to register everything to, for the immobilized datasets
end


rotate_img_x = false # whether to rotate the image about the x-axis (if the worm was put on the slide the wrong way)

fm = datasets["freely_moving"]
path_root_process_freelymoving = "/data1/$(prj)/data_processed/$(fm)_output"
path_root_process_immobilized = "/data1/$(prj)/data_processed/$(fm)_output/neuropal"


path_neuropal_img = joinpath(path_root_process_immobilized, "NeuroPAL.nrrd")
path_neuropal_img_mNeptune_GCaMP = joinpath(path_root_process_immobilized, "NeuroPAL_mNeptune_GCaMP_bleedthrough.nrrd")
path_bfp_img = joinpath(path_root_process_immobilized, "BFP.nrrd")
path_ofp_img = joinpath(path_root_process_immobilized, "OFP.nrrd")
path_gcamp_img = joinpath(path_root_process_immobilized, "OFP_GCaMP.nrrd") # actual GCaMP
path_mNeptune_img = joinpath(path_root_process_immobilized, "mNeptune.nrrd")
path_mNeptune_gcamp_img = joinpath(path_root_process_immobilized, "mNeptune_GCaMP_bleedthrough.nrrd") # GCaMP bleedthrough
path_all_red_img = joinpath(path_root_process_immobilized, "all_red.nrrd")
path_neuron_img = joinpath(path_root_process_immobilized, "neuron_rois.nrrd")
create_dir(path_root_process_freelymoving)
create_dir(path_root_process_immobilized)
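Because the pipeline later builds each ND2 path as `joinpath(path_raw_data, exp_prefix * ".nd2")`, a typo in `datasets` only surfaces once the ND2 metadata is read. The hypothetical check below (not part of `ANTSUN`) catches it right after this cell by listing the dataset keys whose ND2 file is missing.

```julia
# Hypothetical sanity check (not part of ANTSUN): return the dataset keys whose
# `<name>.nd2` file does not exist in `path_raw_data`, mirroring the path
# construction used in the "Pipeline parameters" loop.
function missing_nd2_files(datasets::AbstractDict, path_raw_data::AbstractString)
    return sort([k for (k, name) in datasets
                 if !isfile(joinpath(path_raw_data, name * ".nd2"))])
end

# missing_ds = missing_nd2_files(datasets, path_raw_data)
# isempty(missing_ds) || error("ND2 files not found for datasets: $(missing_ds)")
```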

In [None]:
param_paths = Dict()
params = Dict()
data_dicts = Dict()
error_dicts = Dict();

### Pipeline parameters

This section includes parameters that influence how the pipeline will process the data. They are mostly heuristic parameters that can be tuned for optimal performance, as well as some parameters that change OpenMind compute settings. These are likely safe to leave as default unless you have significantly changed the data structure.

This section also sets default file name parameters, which specify where intermediate parameters and outputs of the pipeline are stored. Changing them is generally not advised.

Note that the first time you set up the pipeline, you will need to change the following entries to point to your files instead of Adam's:

- `param_path["path_unet_pred"]` should point to the `predict.py` script of your `pytorch-3dunet` installation
- `param_path["path_unet_py_env"]` should point to the python environment corresponding to your `pytorch-3dunet` installation (usually Julia's `conda` environment)
- `param_path["path_dir_lock"]` should point to a lock directory (you need to make this directory if it doesn't exist)
- `om_user` should be your OpenMind username (kerberos)
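After editing these entries and running the cell below, a quick existence check can confirm that the local paths are valid on this machine. The function is hypothetical (not part of `ANTSUN`) and only checks the three local paths listed above; `om_user` and the OpenMind paths live on a remote server and cannot be checked this way.

```julia
# Hypothetical check (not part of ANTSUN): return the keys among the user-specific
# local paths whose target does not exist on this machine.
function missing_user_paths(param_path::AbstractDict;
        keys_to_check = ("path_unet_pred", "path_unet_py_env", "path_dir_lock"))
    return [k for k in keys_to_check if !ispath(param_path[k])]
end

# for (dataset, pp) in param_paths
#     bad = missing_user_paths(pp)
#     isempty(bad) || @warn "Paths not found for $(dataset)" bad
# end
```

If only `path_dir_lock` is reported, the lock directory can be created with Julia's `mkpath`.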

Most pipeline parameters are stored in the dictionaries `param` and `param_path`. For useful functions for interacting with these dictionaries, see https://flavell-lab.github.io/ImageDataIO.jl/stable/.
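`params` and `param_paths` map each dataset name to its own `param` or `param_path` dictionary, which the loop below creates only if it does not exist yet. The sketch below reproduces that pattern with Julia's `get!(f, dict, key)`, which calls `f()` only when `key` is missing. It uses illustrative names (`demo_params`, `p`) so the notebook's own dictionaries are not overwritten. Reusing the existing dictionary appears to be what lets entries added by later steps (such as `x_range` from the ND2 step) survive a rerun of this section.

```julia
# Illustrative per-dataset layout; `demo_params` stands in for the notebook's `params`.
demo_params = Dict{String,Any}()

for dataset in ["freely_moving", "all_red"]
    # creates an empty Dict{String,Any} only if `dataset` is not yet a key,
    # equivalent to the `if !(dataset in keys(params))` check in the loop below
    p = get!(Dict{String,Any}, demo_params, dataset)
    p["spacing_lat"] = 0.54
end

# An entry added by a later step (placeholder value) ...
demo_params["all_red"]["x_range"] = 1:512
# ... is kept when the setup code runs again, because the existing Dict is reused
get!(Dict{String,Any}, demo_params, "all_red")["spacing_lat"] = 0.54
haskey(demo_params["all_red"], "x_range")   # true
```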

In [None]:
for dataset in keys(datasets)
    exp_prefix = datasets[dataset]
    if dataset in datasets_freely_moving
        path_root_process = path_root_process_freelymoving
    else
        path_root_process = joinpath(path_root_process_immobilized, datasets[dataset])
    end
    
    if !(dataset in keys(params))
        params[dataset] = Dict()
    end
    param = params[dataset]
    
    if !(dataset in keys(param_paths))
        param_paths[dataset] = Dict{String,Any}()
    end
    param_path = param_paths[dataset]
    
    param_path["path_unet_pred"] = "/home/adam/src/pytorch-3dunet/pytorch3dunet/predict.py" # Path to the pytorch-3dunet installation on the local machine
    param_path["path_unet_param"] = "/data1/shared/dl_weights/3dunet_540nm_voxels/instance_segmentation_test.yaml" # Path to the UNet parameter file
    param_path["path_unet_py_env"] = "/home/adam/.julia/conda/3/bin/activate" # Path to a script that should be run to configure environment variables for the pytorch-3dunet
    param_path["path_transformix"] = (flv_c == 2) ? "/bin/transformix" : "/usr/local/bin/transformix" # Path to transformix executable on the LOCAL machine
    param_path["path_elastix_local"] = (flv_c == 2) ? "/bin/elastix" : "/usr/local/bin/elastix" # Path to elastix executable on the LOCAL machine

    
    param_path["path_head_unet_model"] = "/data1/shared/dl_weights/head_detector_unet_0622/unet2d-head-detector_best.pt" # path to head detection unet model
    param_path["path_2d_unet_param"] = "/data1/shared/dl_weights/behavior_nir/worm_segmentation_best_weights_0310.pt"

    param_path["path_dir_lock"] = "/home/adam/lock" # path to lock directory

    if dataset in datasets_freely_moving
        param_path["path_head_rotate"] = "/om2/group/flavell/script/registration/euler_registration/euler_head_rotate.py"
    else
        param_path["path_head_rotate"] = nothing
    end
    param_path["path_head_rotate_activity_marker"] = nothing
    param_path["path_elastix"] = "/om2/user/aaatanas/elastix-5.0.1/build/bin/elastix"
    param_path["path_run_elastix"] = "/om2/group/flavell/script/registration/run_elastix_command.sh"

    om_user = "aaatanas" # Your username on OpenMind
    param_path["path_om_home"] = "/om2/user/$om_user"
    param_path["path_om_env"] = "/home/$(om_user)/.bashrc" # path to user environment variable script on OpenMind
    param_path["path_om_data"] = joinpath(param_path["path_om_home"], "$(exp_prefix)_output")
    param_path["path_om_home_scripts"] = "/om2/user/$om_user"
    param_path["path_om_scripts"] = joinpath(param_path["path_om_home_scripts"], "$(exp_prefix)_output")

    param_path["path_om_euler_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_freely_moving_euler.txt" # Path to Euler parameter file on OpenMind
    param_path["path_om_euler_am_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_freely_moving_euler_output.txt" # Path to Euler parameter file on OpenMind for activity-to-marker channel registration
    param_path["path_om_affine_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_freely_moving_affine.txt" # Path to Affine parameter file on OpenMind
    
    if dataset in datasets_freely_moving
        param_path["path_om_bspline_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_freely_moving_bspline.txt" # Path to BSpline parameter file on OpenMind
        param["max_graph_num"] = max_graph_num # maximum length before graph split
        param["blue_laser"] = blue_laser
        param["green_laser"] = green_laser
        
        param["blue_zero_thresh"] = 4.5
        param["blue_min_laser"] = 5.2
        param["blue_max_interpolate"] = 2.0

        param["green_zero_thresh"] = 7.0
        param["green_min_laser"] = 7.7
        param["green_max_interpolate"] = 2.0

        intensity = h5read("/data1/shared/2022-05-11-laser-power.h5", "488nm/1/intensity")
        laser_perc = h5read("/data1/shared/2022-05-11-laser-power.h5", "488nm/1/laser_percent")

        intensity2 = h5read("/data1/shared/2022-05-11-laser-power.h5", "488nm/2/intensity")
        laser_perc2 = h5read("/data1/shared/2022-05-11-laser-power.h5", "488nm/2/laser_percent")
        
        param["blue_laser_intensity"] = (intensity .+ intensity2) ./ 2
        param["blue_laser_perc"] = laser_perc

        param["green_laser_intensity"] = h5read("/data1/shared/2022-05-11-laser-power.h5", "561nm/1/intensity")
        param["green_laser_perc"] = h5read("/data1/shared/2022-05-11-laser-power.h5", "561nm/1/laser_percent");
        param["reg_n_resolution"] = [3,3,4] # the number of euler, affine and BSpline registrations
        
        param["good_registration_resolutions"] = [(2,0), (2,1), (2,2), (2,3)] #= which registration resolutions
            are good enough to extract data from =#
        
        @assert(laser_perc == laser_perc2)
    else
        param_path["path_om_bspline_param"] = "/om2/group/flavell/script/registration/elastix_parameters/parameters_immobilized_bspline.txt" # Path to BSpline parameter file on OpenMind
        param["reg_timept"] = reg_timepts[dataset]
        param["reg_n_resolution"] = [0,0,2] # the number of euler, affine and BSpline registrations
        param["good_registration_resolutions"] = [(2,0), (2,1)] #= which registration resolutions
            are good enough to extract data from =#
        param["registration_resolution_neuropal"] = (2,3) # registration resolution for neuroPAL immobilized registration
    end

    
    param["nonsynced_roi_val"] = 4095

    param_path["path_om_tmp"] = joinpath(param_path["path_om_data"], "temp"); # Path to a temporary directory on OpenMind to store data
    
    param["spacing_axi"] = spacing_axi
    param["spacing_lat"] = spacing_lat
    param["n_z"] = 80
    param["z_range"] = 4:80

    
    ### crop
    param["crop_threshold_size"] = 20
    param["crop_threshold_intensity"] = 7.

    ### UNet
    param["seg_threshold_unet"] = 0.5 # The UNet output confidence threshold for a pixel to be counted as a neuron.
    param["seg_min_neuron_size"] = 7 # Neurons with fewer than this many voxels will be discarded.
    param["seg_threshold_watershed"] = [0.7, 0.8, 0.9] #= The image is re-segmented at these thresholds,
        and checked for neurons that were split - these neurons will be segmented by watershed.=#
    param["seg_watershed_min_neuron_sizes"] = [5,4,4] #= When the image is re-segmented, neurons smaller
        than the size for the corresponding threshold will be discarded.=#
    param["instance_seg_num_batches"] = 5 # number of batches of data to run simultaneously
        # each batch = number of threads. Lower number = less memory used, but more chance of bad parity

    ### Registration graph
    param["degree"] = 10 # degree of registration graph
    param["degree_dataset"] = 2 # degree of inter-dataset registration graph

    ### Head finder
    param["head_threshold"] = 0.5 # head detection UNet threshold
    param["head_max_distance"] = [20,35,35] #= distance thresholds for the blobification of the worm
        that finds the head location=#
    param["head_err_threshold"] = 50 # sets a warning flag in the blobification
    param["head_vc_err_threshold"] = 150 # sets a warning flag in the blobification
    param["head_edge_err_threshold"] = 1 # sets a warning flag in the blobification

    ### Worm curve finder
    param["worm_curve_n_pts"] = 14 # number of points in worm curve detector
    param["worm_curve_head_idx"] = 4 #= index of "head" point - ideally near the front of the
        brain but not on the tip of the nose=#
    param["worm_curve_tail_idx"] = 11 #= index of "tail" point - ideally near the back of the
        brain but not on the ventral cord=#
    param["worm_curve_downscale"] = 2 # binning factor for worm curve detector

    ### Registration parameters
    param["reg_n_resolution_activity_marker"] = [3]


    ### Clustering parameters
    param["cluster_height_thresh"] = -0.002 # threshold for merging multiple clusters together. Reasonable range -0.3 to 0
    param["cluster_overlap_thresh"] = 0.05 # tolerance for a cluster to have multiple ROIs from the same frame. Reasonable range 0 to 0.2

    ### Quality control and heuristic parameters
    param["smeared_neuron_threshold"] = 20000 # ROIs that get smeared out to more than this many pixels due to registration will be deleted
    param["activity_diff_threshold"] = 0.3 # threshold for difference in mNeptune activity of two ROIs to trigger a penalty
    param["watershed_error_penalty"] = 0.5 # has no effect unless watershed error detection is enabled
    param["quality_metric"] = "NCC" # registration quality metric
    param["reg_quality_threshold"] = 0.9 # quality threshold for registrations
    param["matrix_self_weight"] = 0.5 # weight of the diagonal of registration map matrix
    param["size_mismatch_penalty"] = 2 # penalty for ROIs being different sizes
    param["frac_detections_threshold"] = 0.5 # fraction of time points in which an ROI candidate must be detected


    ### Deconvolution parameters
    param["gcamp_decay_response"] = 0.52; # GCaMP decay time, in seconds
    
    param["duration"] = Dates.Time(4,0,0)
    param["duration_julia"] = "5-0"
    param["memory"] = 1
    param["cpu_per_task"] = 16
    param["ch_marker"] = ch_red
    param["ch_activity"] = ch_bluegreen
    param["list_ch"] = unique([param["ch_marker"], param["ch_activity"]])
    param["email"] = nothing
    param["use_sbatch"] = true
    param["server"] = "openmind7.mit.edu"
    param["server_dtn"] = "openmind-dtn.mit.edu"
    param["job_name"] = "elx"
    param["job_name_activity_marker"] = "elx_ch12"
    param["job_name_marker_activity"] = "elx_ch21"
    param["array_size"] = 499
    param["partition"] = "use-everything"
    param["elx_wait_delay"] = 3600
    param["lock_wait"] = 0.5
    param["user"] = om_user;
    
    param["FLIR_FPS"] = 20.0; # FPS of NIR camera
    param["bad_tracking"] = []; # points to delete (NIR timeframe) due to bad tracking
    param["n_rec"] = n_rec # number of Illumination Sequence module recordings (.nd2) in the HDF5 file (set in "Data parameters")

    ### Spline parameters
    param["num_center_pts"] = 1000 # number of points in raw worm spline
    param["segment_len"] = 7 # length of raw spline per segment of distance-corrected spline
    param["img_label_size"] = (480,360) # size of UNet input image

    ### Spline self-intersect correction parameters
    param["med_axis_shorten_threshold"] = 14 # if the medial axis shortens by at least this amount, trigger self-intersect correction
    param["nose_confidence_threshold"] = 0.99 # threshold for UNet's confidence of nose location for it to be used to crop medial axis
    param["nose_crop_threshold"] = 20 # maximum number of points in medial axis that can be cropped to the nose
    param["med_axis_self_dist_threshold"] = 40 # proximity of far-away medial axis points that will trigger self-intersect correction
    param["loop_dist_threshold"] = 60 # distance between medial axis points beyond which collisions/proximity trigger self-intersect correction
    param["max_med_axis_delta"] = Inf # change in pixels from points in one time point's medial axis to the next to trigger self-intersect correction
    param["trim_head_tail"] = 15; # amount to trim head and tail of worm boundary to allow motion of worm
    param["boundary_thickness"] = 5; # thickness of worm boundary
    param["close_pts_threshold"] = 30; # minimum distance on spline for boundary to be enforced
    param["worm_thickness_pad"] = 3 # gap between worm and worm boundary
    param["min_len_percent"] = 90; # lower percentile of non-self-intersect time points to be included in worm thickness computation
    param["max_len_percent"] = 98; # upper percentile of non-self-intersect time points to be included in worm thickness computation
    
    ### Body angle parameters
    param["body_angle_pos_lag"] = 2; # how far to enforce angle-continuity along a time point's spline
    param["body_angle_t_lag"] = 40; # how far to enforce angle-continuity along a particular spline-angle across time points
    param["max_pt"] = 31; # crop the spline at this location

    param["nose_pts"] = [1,2,3]; # points in spline corresponding to nose angle
    param["head_pts"] = [1,5,8]; # points in spline corresponding to head angle
    param["seg_range"] = param["head_pts"][3]:param["max_pt"]; # range of points for computing worm centroid

    ### Reversal parameters
    param["rev_len_thresh"] = 2; # length threshold for reversal events
    param["rev_v_thresh"] = -0.005; # velocity threshold for reversal events

    ### Angular velocity parameters
    param["filt_len_angvel"] = 150

    ### Pumping parameters
    param["filt_len_pumping"] = 40

    ### Other parameters
    param["num_lag"] = 1;  # number of NIR time points to smooth velocity and other variables with SG filter
    param["v_stage_m_filt"] = 10; # length of NIR time point window for position filtering
    param["v_stage_λ_filt"] = 250.0; # parameter for position filtering
    
    param["num_eigenworm"] = 5 # number of eigenworms to compute

    # behavior variables for MI and neuron prediction GLM in addition to eigenworm and body angle
    param["vars"] = ["velocity_stage", "worm_angle", "head_angle", "nose_angle", "angular_velocity", "worm_curvature"]

    # behavior variables for multi-dataset concatenation
    param["concat_vars"] = ["angular_velocity", "worm_angle", "head_angle", "nose_angle", "speed_stage", "worm_curvature",
            "ventral_worm_curvature", "nose_curling", "velocity_stage",
            "zeroed_x_confocal", "zeroed_y_confocal", "body_angle", "body_angle_all", "body_angle_absolute", "rev_times"]
    param["t_concat_vars"] = ["all_rev"];
    param["nir_concat_vars"] = ["nir_velocity_stage", "nir_head_angle", "nir_body_angle", "nir_body_angle_all",
            "nir_body_angle_absolute", "nir_nose_curling",
            "zeroed_x", "zeroed_y"]

    param["beta_values"] = [Int32(round(1.6^x)) for x=0:9] # inverse values of L1 norm penalty to use for neuron prediction GLM
    param["t_thresh_vals_base"] = -100:50:100; # thresholds between training and validation data to use for neuron prediction GLM

    param["rev_neuron_thresh"] = 0.5 # goodness of fit required to classify a neuron as a reversal neuron
    param["type1_thresh"] = 2; # maximum reversal length threshold for a type-1 reversal neuron

    param["fwd_neuron_thresh"] = 0.75; # goodness of fit required to classify a neuron as a forward neuron
    param["turning_neuron_thresh"] = 0.75; # goodness of fit required to classify a neuron as a turning neuron
    
    
    
    param_path["path_param_path"] = joinpath(path_root_process, "param_path.jld2")

    param_path["path_nd2"] = joinpath(path_raw_data, exp_prefix*".nd2")
    
    if dataset in datasets_freely_moving
        param_path["path_h5"] = joinpath(path_raw_data, exp_prefix*".h5")
    end
    path_dir_nd2, name_nd2 = splitdir(param_path["path_nd2"])
    param_path["img_prefix"] = splitext(name_nd2)[1]
    path_root_raw = path_dir_nd2
    list_ch = param["list_ch"]

    dir_nrrd = "NRRD"
    dir_MIP = "MIP"
    dir_nrrd_shearcorrect = "NRRD_shearcorrect"
    dir_MIP_shearcorrect = "MIP_shearcorrect"
    dir_nrrd_crop = "NRRD_cropped"
    dir_MIP_crop = "MIP_cropped"
    dir_nrrd_filt = "NRRD_filtered"
    dir_MIP_filt = "MIP_filtered"

    dir_roi = "img_roi"
    dir_roi_watershed = "img_roi_watershed"
    dir_roi_watershed_uncropped = "img_roi_watershed_uncropped"
    dir_marker_signal = "marker_signal"
    dir_centroid = "centroids"
    dir_unet_data = "unet_data"

    dir_worm_curve = "worm_curve"
    dir_reg = "Registered"
    dir_reg_activity_marker = "Registered_ch1to2"
    dir_reg_marker_activity = "Registered_ch2to1"

    dir_transformed = "img_roi_transformed"
    dir_transformed_activity_marker = "NRRD_activity_transformed"
    dir_activity_signal = "gcamp_activity_transformed"
    dir_cmd = "elx_commands"
    dir_cmd_am = "elx_commands_ch1to2"
    dir_cmd_ma = "elx_commands_ch2to1"
    dir_cmd_array = "elx_commands_array"
    dir_log = "log"

    dir_roi_candidates = "roi_candidates"

    txt_file_extension = ".txt"
    name_head_pos = "head_pos"
    name_elastix_difficulty = "elastix_difficulty"
    name_reg_prob = "registration_problems"
    name_reg_quality = "registration_quality"
    name_qdict = "q_dict"
    name_best_reg = "best_reg"
    name_roi_overlap = "roi_overlap"
    name_roi_activity_diff = "roi_activity_diff"
    name_label_map = "label_map"
    name_inv_map = "inv_map"
    name_data_dict = "data_dict.jld2"
    name_param = "param.jld2"
    name_param_path = "param_path.jld2"
    name_error_dict = "error_dict.jld2"

    param_path["path_root_process"] = path_root_process
    param_path["name_head_rotate_logfile"] = "headrotate.log"
    param_path["name_transform_activity_marker"] = "TransformParameters.0.txt"
    param_path["name_transform_activity_marker_avg"] = "TransformParameters.0_avg.txt"
    param_path["name_transform_activity_marker_roi"] = "TransformParameters.0_roi.txt"
    param_path["key_transform_parameters"] = "TransformParameters"
    
    if !(dataset in datasets_freely_moving)
        param_path["path_dir_reg_activity"] = joinpath(param_path["path_root_process"], "Registered_G")
        param_path["path_om_reg_activity"] = joinpath(param_path["path_om_data"], "Registered_G")
        param_path["path_dir_cmd_activity"] = joinpath(param_path["path_root_process"], "elx_commands_G")
        param_path["path_om_cmd_activity"] = joinpath(param_path["path_om_data"], "elx_commands_G")
        param["job_name_activity"] = "elx_ch1"
        param_path["path_nrrd_avg"] = joinpath.(param_path["path_root_process"], ["avg_ch1.nrrd", "avg_ch2.nrrd"])
        param_path["path_nrrd_avg_ch1toch2"] = joinpath(param_path["path_root_process"], "avg_ch1toch2")
        param_path["path_nrrd_avg_ch1toch2_reg"] = joinpath(param_path["path_root_process"], "avg_ch1toch2_reg")
        param_path["nrrd_avg_res"] = "result.2.R0.nrrd"
        param_path["parameter_files_local"] = joinpath.("/data1/shared/elastix_parameters", ["parameters_euler_turbo_output.txt", "parameters_immobilized_affine.txt", "parameters_immobilized_bspline_turbo.txt"])
        param_path["path_dir_reg_neuropal"] = joinpath(param_path["path_root_process"], "Registered_to_$(dataset_immobilized_central)")
        param_path["euler_head_params"] = joinpath(param_path["path_dir_reg_neuropal"], "euler.txt")
        param_path["path_nrrd_avg_translate"] = joinpath.(param_path["path_root_process"], ["avg_ch1_translate.nrrd", "avg_ch2_translate.nrrd"])
        param_path["path_nrrd_avg_ch1toch2_translate"] = joinpath(param_path["path_root_process"], "avg_ch1toch2_translate.nrrd")
    end


    param_path["path_dir_unet_data"], param_path["path_dir_nrrd"], param_path["path_dir_MIP"], 
        param_path["path_dir_nrrd_shearcorrect"], param_path["path_dir_MIP_shearcorrect"], 
        param_path["path_dir_nrrd_crop"],
        param_path["path_dir_nrrd_filt"], param_path["path_dir_MIP_crop"], param_path["path_dir_MIP_filt"] = 
            joinpath.(path_root_process, [dir_unet_data, dir_nrrd, dir_MIP, dir_nrrd_shearcorrect, dir_MIP_shearcorrect, dir_nrrd_crop, dir_nrrd_filt, dir_MIP_crop, dir_MIP_filt])

    param_path["path_dir_transformed"], param_path["path_dir_transformed_activity_marker"], param_path["path_dir_activity_signal"] =
        joinpath.(path_root_process, [dir_transformed, dir_transformed_activity_marker, dir_activity_signal])

    param_path["path_dir_roi"], param_path["path_dir_roi_watershed"], param_path["path_dir_roi_watershed_uncropped"], 
        param_path["path_dir_marker_signal"], param_path["path_dir_centroid"], param_path["path_dir_unet_data"] =
            joinpath.(path_root_process, [dir_roi, dir_roi_watershed, dir_roi_watershed_uncropped, dir_marker_signal,
                dir_centroid, dir_unet_data])

    param_path["path_dir_worm_curve"], param_path["path_dir_reg"], param_path["path_dir_reg_activity_marker"], 
        param_path["path_dir_reg_marker_activity"],
        param_path["path_roi_candidates"], param_path["path_dir_cmd"], param_path["path_dir_cmd_am"], param_path["path_dir_cmd_ma"],
        param_path["path_dir_cmd_array"], param_path["path_data_dict"], 
        param_path["path_param"], param_path["path_param_path"], param_path["path_error_dict"] =
            joinpath.(path_root_process,
                [dir_worm_curve, dir_reg, dir_reg_activity_marker, dir_reg_marker_activity, dir_roi_candidates, dir_cmd, 
                    dir_cmd_am, dir_cmd_ma, dir_cmd_array, name_data_dict, name_param,
                    name_param_path, name_error_dict])

    param_path["path_om_nrrd_filt"], param_path["path_om_nrrd"], param_path["path_om_reg"], 
        param_path["path_om_reg_activity_marker"], param_path["path_om_reg_marker_activity"],
        param_path["path_om_log"] = joinpath.(param_path["path_om_data"],
            [dir_nrrd_filt, dir_nrrd, dir_reg, dir_reg_activity_marker, dir_reg_marker_activity, dir_log])

    param_path["path_om_cmd"], param_path["path_om_cmd_am"], param_path["path_om_cmd_ma"], 
        param_path["path_om_cmd_array"] = joinpath.(param_path["path_om_scripts"], 
            [dir_cmd, dir_cmd_am, dir_cmd_ma, dir_cmd_array])

    param_path["path_head_pos"], param_path["path_elastix_difficulty"], param_path["path_reg_prob"], 
        param_path["path_qdict"], param_path["path_best_reg"], param_path["path_roi_overlap"], 
        param_path["path_roi_activity_diff"], param_path["path_reg_quality"], param_path["path_label_map"], param_path["path_inv_map"] = joinpath.(path_root_process, 
            [name_head_pos, name_elastix_difficulty, name_reg_prob, name_qdict, name_best_reg,
            name_roi_overlap, name_roi_activity_diff, name_reg_quality, name_label_map,
            name_inv_map] .* txt_file_extension)
    

    param_path["path_dir_mask"] = nothing
    param_path["path_om_mask"] = nothing

    param_path["parameter_files"] = [param_path["path_om_euler_param"], param_path["path_om_affine_param"], param_path["path_om_bspline_param"]]
    param_path["parameter_files_activity_marker"] = [param_path["path_om_euler_am_param"]]

    create_dir(param_path["path_root_process"])
    if !(dataset in datasets_freely_moving)
        create_dir(param_path["path_nrrd_avg_ch1toch2"])
        create_dir(param_path["path_dir_reg_neuropal"])
    end

    create_dir(param_path["path_dir_transformed_activity_marker"])
    create_dir(param_path["path_dir_transformed"])
    create_dir(param_path["path_roi_candidates"])
    create_dir(param_path["path_dir_activity_signal"])
    create_dir(param_path["path_dir_marker_signal"])
end

### Read meta-parameters from ND2 file

The following section initializes variables based on the ND2 file's metadata.
**Do not run it again if you restart the kernel partway through the pipeline**, since it resets `param["t_range"]`, which later steps prune.

In [None]:
for dataset = keys(datasets)
    param_path = param_paths[dataset]
    param = params[dataset]
    println("Reading $(dataset) ND2 file from: "* param_path["path_nd2"])
    n_x, n_y, n_z, n_t, n_c = nd2dim(param_path["path_nd2"])
    
    param["x_range"] = 1:n_x
    param["y_range"] = 1:n_y
    param["θ"] = nothing # rotation
    
    t_range = 1:floor(Int32, n_t / param["n_z"]);
    param["num_detections_threshold"] = param["frac_detections_threshold"] * length(t_range);
    param["t_range"] = t_range
    param["max_t"] = maximum(param["t_range"])
    param["t_thresh_vals"] = param["t_thresh_vals_base"] .+ floor(Int32, param["max_t"] / 2)
    # if dataset in datasets_freely_moving
    #     @assert(param["max_graph_num"] >= param["max_t"] || 2*param["max_graph_num"] == param["max_t"], "Pipeline does not currently support a laser intensity change anywhere other than halfway through the dataset.")
    # end
end
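
The timepoint arithmetic in the cell above can be checked in isolation. The sketch below is a hypothetical helper (`timepoint_params` is not part of ANTSUN) that reproduces it: the frame count reported by `nd2dim` is divided by the number of slices per volume, a trailing incomplete volume is dropped, and the thresholds are derived from the resulting `t_range`. The input numbers are made up for illustration.

```julia
# Hypothetical helper mirroring the arithmetic in the cell above (not part of ANTSUN).
# `n_frames` is the frame count reported by `nd2dim`; `n_z` is the number of slices per volume.
function timepoint_params(n_frames::Integer, n_z::Integer,
                          frac_detections_threshold::Real, t_thresh_vals_base)
    t_range = 1:floor(Int32, n_frames / n_z)   # a trailing incomplete volume is dropped
    max_t = maximum(t_range)
    return (t_range = t_range,
            max_t = max_t,
            num_detections_threshold = frac_detections_threshold * length(t_range),
            t_thresh_vals = t_thresh_vals_base .+ floor(Int32, max_t / 2))
end

# Made-up example: 4830 frames at 80 slices per volume gives 60 complete volumes.
p = timepoint_params(4830, 80, 0.01, [0, 10])
```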

In [None]:
for dataset = datasets_register
    param_paths[dataset_central]["path_dir_cmd_$dataset"] = joinpath(param_paths[dataset_central]["path_root_process"], "elx_commands_$dataset")
    param_paths[dataset_central]["path_om_cmd_$dataset"] = joinpath(param_paths[dataset_central]["path_om_data"], "elx_commands_$dataset")
    params[dataset_central]["job_name_$dataset"] = dataset
    param_paths[dataset_central]["path_elastix_difficulty_$(dataset)"] = replace(param_paths[dataset_central]["path_elastix_difficulty"], ".txt"=>"_$(dataset).txt")
    param_paths[dataset_central]["path_reg_prob_$(dataset)"] = replace(param_paths[dataset_central]["path_reg_prob"], ".txt"=>"_$(dataset).txt")
    param_paths[dataset_central]["path_dir_reg_$(dataset)"] = param_paths[dataset_central]["path_dir_reg"]*"_$(dataset)"
    param_paths[dataset_central]["path_om_reg_$(dataset)"] = param_paths[dataset_central]["path_om_reg"]*"_$(dataset)"
    param_paths[dataset_central]["path_dir_transformed_$(dataset)"] = param_paths[dataset_central]["path_dir_transformed"]*"_$(dataset)";
    create_dir(param_paths[dataset_central]["path_dir_transformed_$(dataset)"])
end

#### Save parameter settings

This section saves all of the parameter settings to keep a record of which parameters were used to run the pipeline. By default, it does not overwrite previously saved parameter files; uncomment the code in the cell below to overwrite them. Note that `param["t_range"]` is updated dynamically as the pipeline progresses and problematic time points are removed. Reverting to the default version of this parameter can cause the code to crash, so if you intend to save a new version of the `param` dictionary, make sure it already contains the correct `t_range`.

In [None]:
## Uncomment this code to overwrite previous dictionaries.
for dataset in keys(datasets)
    if "get_basename" in keys(param_paths[dataset])
        delete!(param_paths[dataset], "get_basename")
    end

    param_path = param_paths[dataset]
    param = params[dataset]
    JLD2.@save(param_path["path_param_path"], param_path)
    add_get_basename!(param_path)
    JLD2.@save(param_path["path_param"], param)
end
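
If you want the non-overwriting behavior described above to be enforced in code rather than by commenting cells in and out, a guard like the following is one option. `save_unless_exists` is a hypothetical helper; in the notebook, `save_fn` would wrap the existing `JLD2.@save` call, while the demo writes a plain text file so that it runs without JLD2.

```julia
# Hypothetical helper: run `save_fn(path)` only if `path` does not exist yet,
# unless `overwrite=true`. Returns `true` if `save_fn` was called.
function save_unless_exists(save_fn, path::AbstractString; overwrite::Bool=false)
    if isfile(path) && !overwrite
        @info "Keeping existing file" path
        return false
    end
    save_fn(path)
    return true
end

# Demo with a plain text file instead of a JLD2 file.
demo_path = joinpath(mktempdir(), "param.txt")
wrote_first = save_unless_exists(p -> write(p, "first"), demo_path)   # file is created
wrote_again = save_unless_exists(p -> write(p, "second"), demo_path)  # skipped, file kept
```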

### Initialize data dictionaries

By default, the code tries to load parameter and data dictionaries from a previous run of the pipeline. If it can't find them, it will initialize data and error dictionaries to be empty. The error dictionary loading usually crashes, so it is left commented out by default.

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        path_root_process = path_root_process_freelymoving
    else
        path_root_process = joinpath(path_root_process_immobilized, datasets[dataset])
    end
    path_param_path = joinpath(path_root_process, "param_path.jld2")
    
    if isfile(path_param_path)
        f = JLD2.jldopen(path_param_path)
        param_paths[dataset] = f["param_path"]
        close(f)
    end
    param_path = param_paths[dataset]

    change_rootpath!(param_path, path_root_process)

    if isfile(param_path["path_param"])
        f = JLD2.jldopen(param_path["path_param"])
        params[dataset] = f["param"]
        close(f)
    end
    
    param = params[dataset]

    add_get_basename!(param_path)
    
    if isfile(param_path["path_data_dict"])
        f = JLD2.jldopen(param_path["path_data_dict"])
        data_dicts[dataset] = f["data_dict"]
        close(f)
    else
        data_dicts[dataset] = Dict()
    end


    data_dict = data_dicts[dataset]

    # in case of errors, initialize t_range parameter to last available value
    if !haskey(param, "t_range") || length(param["t_range"]) == 0
        if "head_pos" in keys(data_dict) && length(data_dict["head_pos"]) > 0
            param["t_range"] = sort(collect(keys(data_dict["head_pos"])))
        elseif "dict_param_crop_rot" in keys(data_dict) && length(data_dict["dict_param_crop_rot"]) > 0
            param["t_range"] = sort(collect(keys(data_dict["dict_param_crop_rot"])))
        elseif "shear_params" in keys(data_dict) && length(data_dict["shear_params"]) > 0
            param["t_range"] = sort(collect(keys(data_dict["shear_params"])))
        else
            n_x, n_y, n_z, n_t, n_c = nd2dim(param_path["path_nd2"])
            param["t_range"] = 1:floor(Int32, n_t / param["n_z"]);
        end
    end
    #if isfile(param_path["path_error_dict"])
    #    f = JLD2.jldopen(param_path["path_error_dict"])
    #    error_dict = f["error_dict"]
    #    close(f)
    #else
    error_dicts[dataset] = Dict()
    #end
end

for dataset in datasets_register
    path = joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2")
    if isfile(path)
        f = JLD2.jldopen(path)
        data_dicts["$(dataset)_to_central"] = f["data_dict"]
        close(f)
    else
        data_dicts["$(dataset)_to_central"] = Dict()
    end
    error_dicts["$(dataset)_to_central"] = Dict()
end
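
The `t_range` fallback above prefers, in order, the keys of `head_pos`, `dict_param_crop_rot`, and `shear_params`, and only re-reads the ND2 dimensions if none of them is populated. The hypothetical helper below isolates that rule so it can be tested without an ND2 file; `n_frames` and `n_z` stand in for the values the notebook obtains from `nd2dim` and `param["n_z"]`.

```julia
# Hypothetical helper isolating the fallback order used above (not part of ANTSUN).
# `n_frames` and `n_z` stand in for the values obtained from `nd2dim` and `param["n_z"]`.
function recover_t_range(data_dict::AbstractDict, n_frames::Integer, n_z::Integer)
    for key in ("head_pos", "dict_param_crop_rot", "shear_params")
        if haskey(data_dict, key) && !isempty(data_dict[key])
            return sort(collect(keys(data_dict[key])))   # timepoints that reached this step
        end
    end
    return collect(1:fld(n_frames, n_z))   # every complete volume
end

# `head_pos` is empty, so the crop/rotate keys are used instead.
demo_dict = Dict("head_pos" => Dict{Int,Any}(),
                 "dict_param_crop_rot" => Dict(3 => nothing, 1 => nothing, 2 => nothing))
t_from_crop = recover_t_range(demo_dict, 4800, 80)
t_from_nd2  = recover_t_range(Dict(), 4800, 80)
```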

In [None]:
# if "confocal_to_nir" in keys(data_dicts["freely_moving"])
#     vec_to_confocal = vec -> nir_vec_to_confocal(vec, data_dict["confocal_to_nir"], param["max_t"])
# end

# Pre-processing

See https://flavell-lab.github.io/ImageDataIO.jl/stable/, as well as the packages mentioned in each section.

### Sync NIR to confocal, find heat-stims

Expected runtime: 10 seconds.

This section matches timestamps between the confocal and NIR videos, so that neural activity and behavior can be matched across time.

If this section crashes, it is very likely that your dataset is unusable.

In [None]:
get_timing_info!(data_dicts["freely_moving"], params["freely_moving"], param_paths["freely_moving"]["path_h5"], h5_confocal_time_lag);
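
`get_timing_info!` performs the actual synchronization, and its internals are not described here. As a rough illustration of what matching timestamps means, the sketch below assigns each confocal volume the NIR frame with the closest timestamp, assuming both timestamp vectors are sorted and already on a common clock (the demo uses milliseconds). `nearest_nir_frame` is a hypothetical name.

```julia
# Hypothetical sketch: for each confocal volume timestamp, return the index of the NIR
# frame with the closest timestamp. Both inputs must be sorted and on the same clock.
function nearest_nir_frame(t_confocal::AbstractVector{<:Real}, t_nir::AbstractVector{<:Real})
    idx = Vector{Int}(undef, length(t_confocal))
    for (i, t) in enumerate(t_confocal)
        j = searchsortedfirst(t_nir, t)          # first NIR frame at or after t
        if j > length(t_nir)
            idx[i] = length(t_nir)               # t is after the last NIR frame
        elseif j > 1 && t - t_nir[j-1] <= t_nir[j] - t
            idx[i] = j - 1                       # previous NIR frame is at least as close
        else
            idx[i] = j
        end
    end
    return idx
end

# Demo in milliseconds: NIR frames every 50 ms, a few confocal volume timestamps.
nir_idx = nearest_nir_frame([0, 610, 1230, 5000], 0:50:2000)
```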

You can preview the dataset to view z-projected worm images here:

In [None]:
let
    dataset = "freely_moving"
    ch_view = 2 # 1 is blue or green, 2 is red
    figure(figsize=(10,10))
    stack = nd2preview(param_paths[dataset]["path_nd2"], return_data=true, z_crop=params[dataset]["z_range"], ch=ch_view, n_z=params[dataset]["n_z"]);
end

### Saving to NRRD

Expected runtime: 1 hour.

For more information, please refer to https://github.com/flavell-lab/ND2Process.jl

This section crops out piezo motion artifacts and saves each time point in the ND2 confocal recording as an NRRD file. If this section crashes, it may mean that you need to update the parameter `param["n_z"]`, or that your confocal recording terminated prematurely.

In [None]:
for dataset in keys(datasets)
    param_path = param_paths[dataset]
    param = params[dataset]
    nd2_to_nrrd(param_path["path_nd2"], param_path["path_root_process"], param["spacing_lat"], param["spacing_axi"], true,
        x_crop=param["x_range"], y_crop=param["y_range"], z_crop=param["z_range"], θ=param["θ"], chs=param["list_ch"], n_z=param["n_z"],
        NRRD_dir_name=param_path["path_dir_nrrd"], MIP_dir_name=param_path["path_dir_MIP"])
end
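
The notebook computes the number of timepoints as the ND2 frame count divided by `param["n_z"]`, which suggests that frames are stored volume by volume. Under that assumption (not verified against ND2Process.jl), the mapping between a flat frame index and a (timepoint, slice) pair is shown below, together with the frames that survive a `z_range` crop. It also illustrates why a wrong `n_z` misassigns slices in every subsequent volume. All function names are hypothetical.

```julia
# Hypothetical sketch, ASSUMING frames are stored volume by volume with n_z slices each
# (consistent with the notebook's n_t / n_z timepoint count, not verified against ND2Process.jl).
frame_index(t::Integer, z::Integer, n_z::Integer) = (t - 1) * n_z + z

# Inverse mapping: flat frame index -> (timepoint, slice), both 1-based.
function frame_to_tz(frame::Integer, n_z::Integer)
    t0, z0 = divrem(frame - 1, n_z)
    return (t0 + 1, z0 + 1)
end

# Flat frame indices of timepoint t that survive cropping to the slices in z_range.
cropped_frames(t::Integer, n_z::Integer, z_range) = [frame_index(t, z, n_z) for z in z_range]

n_z_demo = 80
first_of_t2 = frame_index(2, 1, n_z_demo)    # 81
last_of_t1  = frame_to_tz(80, n_z_demo)      # (1, 80)
kept_t2     = cropped_frames(2, n_z_demo, 3:5)
```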

You can view the raw data by running this code:

In [None]:
let
    dataset = "freely_moving"
    t = 250
    ch = 2 # 1 is blue or green, 2 is red
    img = read_img(NRRD(joinpath(param_paths[dataset]["path_dir_nrrd"], param_paths[dataset]["get_basename"](t,ch)*".nrrd")));
    view_roi_3D(img, nothing, nothing, raw_contrast=2)
end

### Shear Correction

Expected runtime: 1 hour. GPU-accelerated.

The worm motion introduces shear into the frames, which needs to be corrected for subsequent steps of the pipeline to work well. Only the freely-moving dataset requires shear correction.

For more information, please see https://github.com/flavell-lab/FFTRegGPU.jl

In [None]:
for dataset in datasets_freely_moving
    data_dicts[dataset]["shear_params"] = Dict()
    shear_correction_nrrd!(param_paths[dataset], params[dataset], params[dataset]["ch_marker"], data_dicts[dataset]["shear_params"]);
end
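
To illustrate what shear correction estimates, the sketch below finds, for each z-slice, the integer in-plane shift that best aligns it with the previous slice, and accumulates those shifts down the stack. It uses an exhaustive search over small shifts purely for clarity; FFTRegGPU.jl performs FFT-based registration on the GPU, and its exact procedure may differ. The function names and the `max_shift` parameter are hypothetical.

```julia
# Mean squared difference between `a` and `b` shifted by (dx, dy), i.e. comparing
# a[x, y] with b[x - dx, y - dy], over the overlapping region only.
function shifted_msd(a::AbstractMatrix, b::AbstractMatrix, dx::Int, dy::Int)
    nx, ny = size(a)
    xs = max(1, 1 + dx):min(nx, nx + dx)
    ys = max(1, 1 + dy):min(ny, ny + dy)
    s = 0.0
    for y in ys, x in xs
        s += (a[x, y] - b[x - dx, y - dy])^2
    end
    return s / (length(xs) * length(ys))
end

# Shift (dx, dy) to apply to `b` so that it best matches `a` (exhaustive search).
function best_shift(a, b; max_shift::Int=3)
    best, best_err = (0, 0), Inf
    for dx in -max_shift:max_shift, dy in -max_shift:max_shift
        err = shifted_msd(a, b, dx, dy)
        if err < best_err
            best, best_err = (dx, dy), err
        end
    end
    return best
end

# Cumulative correction of each slice of an (x, y, z) stack relative to slice 1.
function cumulative_shear(stack::AbstractArray{<:Real,3}; max_shift::Int=3)
    shifts = [(0, 0)]
    for z in 2:size(stack, 3)
        dx, dy = best_shift(stack[:, :, z-1], stack[:, :, z]; max_shift=max_shift)
        px, py = shifts[end]
        push!(shifts, (px + dx, py + dy))
    end
    return shifts
end

# Demo: a blob that drifts by +1 pixel in x with every slice.
stack = [exp(-((x - (9 + z))^2 + (y - 12)^2) / 8) for x in 1:24, y in 1:24, z in 1:3]
shear = cumulative_shear(stack)
```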

You can view the shear-corrected data by running this code. Shear correction acts along the `z`-axis, so it is recommended to set `axis` to 1 or 2 so that the `z`-axis is visible in the displayed view.

In [None]:
let
    t = 250
    ch = params["freely_moving"]["ch_marker"]
    img = read_img(NRRD(joinpath(param_paths["freely_moving"]["path_dir_nrrd_shearcorrect"], param_paths["freely_moving"]["get_basename"](t,ch)*".nrrd")));
    view_roi_3D(img, nothing, nothing, raw_contrast=2, axis=2)
end

### Crop and rotate images

Expected runtime: 15 minutes.

We next crop and rotate the images to avoid spending computation time dealing with empty pixels. Since we are dealing with freely moving animals, we cannot do this globally as allowed in `nd2_to_nrrd`, so we must instead do it on a frame-by-frame basis, modifying the cropping parameters each time to account for the worm's motion.

This code usually gives a warning about out-of-focus timepoints. If there are many of them (e.g., several hundred), your dataset likely had severe focus problems and may yield low-quality data.

For more information, please refer to https://flavell-lab.github.io/SegmentationTools.jl/stable/crop

In [None]:
for dataset = keys(datasets)
    data_dicts[dataset]["dict_param_crop_rot"] = Dict()
    nrrd_key = (dataset in datasets_freely_moving) ? "path_dir_nrrd_shearcorrect" : "path_dir_nrrd"
    error_dicts[dataset]["dict_error_crop_rot"], data_dicts[dataset]["focus_issues"] = crop_rotate!(param_paths[dataset], params[dataset], params[dataset]["t_range"], [params[dataset]["ch_marker"]], data_dicts[dataset]["dict_param_crop_rot"], save_MIP=true, nrrd_dir_key=nrrd_key)
end
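
The cropping half of this step can be sketched as a per-frame bounding box around bright pixels in a maximum-intensity projection. This is a hypothetical illustration (`crop_box`, `threshold`, and `pad` are assumptions, and rotation is omitted), not the SegmentationTools.jl implementation, but it shows why the box has to be recomputed for every frame of a freely-moving recording.

```julia
# Hypothetical sketch: bounding box of pixels brighter than `threshold` in a 2D
# max-intensity projection, padded by `pad` pixels and clamped to the image bounds.
function crop_box(mip::AbstractMatrix{<:Real}, threshold::Real; pad::Int=5)
    idx = findall(>(threshold), mip)
    isempty(idx) && error("Worm not visible")   # hypothetical per-frame failure
    xs = getindex.(idx, 1)
    ys = getindex.(idx, 2)
    nx, ny = size(mip)
    return (max(1, minimum(xs) - pad):min(nx, maximum(xs) + pad),
            max(1, minimum(ys) - pad):min(ny, maximum(ys) + pad))
end

# Demo: a bright rectangle standing in for the worm in one frame.
mip = zeros(50, 40)
mip[20:25, 10:30] .= 1.0
box = crop_box(mip, 0.5; pad=5)
box_clamped = crop_box(mip, 0.5; pad=25)
```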

### Delete bad frames
Remove all time points where cropping failed because the worm was not visible or the frame was otherwise unusable. Out-of-focus frames are kept.

In [None]:
for dataset = keys(datasets)
    param = params[dataset]
    param["t_range"] = [t for t in param["t_range"] if !(t in keys(error_dicts[dataset]["dict_error_crop_rot"])) || error_dicts[dataset]["dict_error_crop_rot"][t] == ErrorException("Worm out of focus")]
    data_dicts[dataset]["t_range"] = param["t_range"]
    JLD2.@save(param_paths[dataset]["path_param"], param)
end
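
The filtering rule in the cell above keeps a timepoint if cropping raised no error, or if the only error was the out-of-focus one. A self-contained version on a made-up error dictionary (the second error message is hypothetical) is:

```julia
# Made-up per-timepoint error dictionary, shaped like `error_dicts[dataset]["dict_error_crop_rot"]`.
errs = Dict(3 => ErrorException("Worm out of focus"),
            5 => ErrorException("Worm not found"))   # hypothetical message
t_range_demo = 1:6

# Keep a timepoint if it has no error, or if its only error is the out-of-focus one.
# Comparing the message text avoids depending on how exception objects compare.
is_kept(t) = !haskey(errs, t) || errs[t].msg == "Worm out of focus"
t_kept = filter(is_kept, collect(t_range_demo))
```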

### Total-Variation Noise Filtering

Expected runtime: 10 minutes. GPU-accelerated.

This script filters the cropped data in the marker (red) channel to denoise it and make it easier for the registration to converge. The channel alignment registration does not use filtered data.

For more information, please refer to https://github.com/flavell-lab/GPUFilter.jl

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

In [None]:
for dataset = keys(datasets)
    filter_nrrd_gpu(param_paths[dataset], param_paths[dataset]["path_dir_nrrd_crop"], params[dataset]["t_range"], [params[dataset]["ch_marker"]], param_paths[dataset]["get_basename"])
end

In [None]:
CUDA.reclaim()
GC.gc()
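
For intuition about total-variation filtering, here is a one-dimensional sketch: gradient descent on a smoothed TV energy, which flattens noise while preserving sharp edges. GPUFilter.jl filters 3D volumes on the GPU, and its solver and parameters are not assumed to match; `λ`, `ε`, `dt`, and `iters` are illustrative values. The step `dt` is kept below 2/(1 + 4λ/ε), the stability limit of plain gradient descent on this smoothed energy.

```julia
# Hypothetical 1D sketch of total-variation denoising: gradient descent on the smoothed energy
#   E(u) = 1/2 * sum((u - f).^2) + λ * sum(sqrt.(diff(u).^2 .+ ε^2))
# GPUFilter.jl works on 3D volumes on the GPU; its solver and parameters may differ.
function tv_denoise_1d(f::AbstractVector{<:Real}; λ=0.5, ε=1e-2, dt=0.004, iters=5000)
    u = float.(collect(f))
    g = similar(u)
    for _ in 1:iters
        g .= u .- f                          # gradient of the data-fidelity term
        for i in 1:length(u)-1
            d = u[i+1] - u[i]
            w = λ * d / sqrt(d^2 + ε^2)      # derivative of the smoothed TV term w.r.t. d
            g[i]   -= w
            g[i+1] += w
        end
        u .-= dt .* g                        # dt < 2 / (1 + 4λ/ε) keeps the iteration stable
    end
    return u
end

# Demo: a unit step corrupted by alternating ±0.1 noise.
f_noisy = [i <= 20 ? 0.0 : 1.0 for i in 1:40] .+ 0.1 .* (-1.0) .^ (1:40)
u_clean = tv_denoise_1d(f_noisy)
```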

### Finding head location

Expected runtime: 20 minutes. GPU-accelerated.

The first step in determining which timepoints to register to each other is to find the location of the worm's head in each frame and to identify frames where the head cannot be found. This is done with a custom-trained 2D UNet. For freely-moving datasets, timepoints where no head was found are retried on the non-shear-corrected images.

For more information, please refer to https://flavell-lab.github.io/SegmentationTools.jl/stable/find_head

In [None]:
for dataset = keys(datasets)
    param_path = param_paths[dataset]
    data_dict = data_dicts[dataset]
    param = params[dataset]
    error_dict = error_dicts[dataset]
    
    nrrd_dir = (dataset in datasets_freely_moving) ? "path_dir_nrrd_shearcorrect" : "path_dir_nrrd"
    img_size = size(read_img(NRRD(joinpath(param_path[nrrd_dir], param_path["get_basename"](1,2)*".nrrd"))))
    model = create_model(1,1,8,param_path["path_head_unet_model"]);
    data_dict["head_pos"], error_dict["head_errs"] = find_head_unet(param_path, param, data_dict["dict_param_crop_rot"], model, img_size, nrrd_dir=nrrd_dir);
    if !(dataset in datasets_freely_moving)
        data_dict["head_pos_uncropped"], error_dict["head_errs_uncropped"] = find_head_unet(param_path, param, data_dict["dict_param_crop_rot"], model, img_size, nrrd_dir=nrrd_dir, crop=false);
    end 
end;

In [None]:
for dataset = datasets_freely_moving
    param_path = param_paths[dataset]
    param = deepcopy(params[dataset])
    data_dict = data_dicts[dataset]
    error_dict = error_dicts[dataset]
    param["t_range"] = [t for t in keys(data_dicts["freely_moving"]["dict_param_crop_rot"]) if "crop" in keys(data_dicts["freely_moving"]["dict_param_crop_rot"][t]) && !(t in keys(data_dict["head_pos"]))]
    nrrd_dir = "path_dir_nrrd"
    img_size = size(read_img(NRRD(joinpath(param_path[nrrd_dir], param_path["get_basename"](1,2)*".nrrd"))))
    model = create_model(1,1,8,param_path["path_head_unet_model"]);
    head_pos, error_dict["head_errs_retry"] = find_head_unet(param_path, param, data_dict["dict_param_crop_rot"], model, img_size, nrrd_dir=nrrd_dir);
    for t in keys(head_pos)
        data_dict["head_pos"][t] = head_pos[t]
    end
end
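
How `find_head_unet` converts the UNet output into coordinates is not described here. One plausible scheme, shown only as a hypothetical sketch, is to take the probability-weighted centroid of pixels above a threshold and to report "not found" when no pixel passes, which is the kind of per-timepoint failure that the success fraction computed next summarizes.

```julia
# Hypothetical sketch: head position as the probability-weighted centroid of pixels above
# `threshold` in a 2D probability map; `nothing` means "head not found".
function head_from_probmap(prob::AbstractMatrix{<:Real}; threshold::Real=0.5)
    sx = 0.0; sy = 0.0; w = 0.0
    for I in CartesianIndices(prob)
        p = prob[I]
        if p > threshold
            sx += p * I[1]
            sy += p * I[2]
            w += p
        end
    end
    w == 0 && return nothing
    return [round(Int, sx / w), round(Int, sy / w)]
end

# Demo: a 3x3 high-probability patch centered at (11, 21).
prob = zeros(30, 30)
prob[10:12, 20:22] .= 0.9
hp_found = head_from_probmap(prob)
hp_missing = head_from_probmap(zeros(5, 5))
```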

This cell reports the fraction of timepoints at which the head-detection UNet found the worm's head. If this number is below 90%, it usually means that your dataset is very low quality. If you made a new strain and consistently get a value below 90%, you may need to retrain the 2D UNet on that strain.

In [None]:
length(data_dicts["freely_moving"]["head_pos"]) / params["freely_moving"]["max_t"]

### Manually adjusting head position

This cell lets you manually examine the head UNet output at individual timepoints in individual datasets. If you need to manually fix errors, comment out line 15 (`hp = data_dicts[dataset]["head_pos"][t]`) and uncomment line 16 (`hp = [x,y]`), replacing `x` and `y` with your manual estimates of the head position.

In [None]:
let
    dataset = "all_red"
    t = reg_timepts[dataset] # for freely-moving datasets, set this to any timepoint of interest instead
    println("Time point displayed: $t")
    contrast = 1
    
    path_nrrd = joinpath(param_paths[dataset]["path_dir_nrrd_crop"], param_paths[dataset]["get_basename"](t,2) * ".nrrd")
    img = maxprj(read_img(NRRD(path_nrrd)), dims=3)#[1:322,1:210]
    
    img_max = percentile(reshape(img, (prod(size(img)),)), 100-contrast)
    img_min = percentile(reshape(img, (prod(size(img)),)), contrast)
    
    img_hp = zeros(size(img))
    println("Image size: $(size(img))")
    hp = data_dicts[dataset]["head_pos"][t]
#     hp = [x,y]
    println("Detected head position: $(hp)")
    img_hp[max(1, hp[1]-5):min(size(img, 1), hp[1]+5),
            max(1, hp[2]-5):min(size(img, 2), hp[2]+5)] .= 1
    
    RGB.(max.(0,min.(1,(img .- img_min) ./ (img_max - img_min))), img_hp, 0)
end

If you need to manually adjust any incorrect or missing head positions, you can do so here by uncommenting this code and replacing `dataset`, `t`, `x`, and `y` with the appropriate values:

In [None]:
# data_dicts[dataset]["head_pos"][t] = [x,y]

### Saving head position

In [None]:
for dataset = datasets_freely_moving
    param = params[dataset]
    param["t_range"] = [t for t in param["t_range"] if (t in keys(data_dicts[dataset]["head_pos"]))]
    data_dicts[dataset]["t_range"] = param["t_range"]
    JLD2.@save(param_paths[dataset]["path_param"], param)
end

The following code displays a warning if the head detector failed on any timepoint in a dataset that is being registered back to the freely-moving dataset. If this happens, it is recommended to run the code in "Manually adjusting head position" with `t = reg_timepts[dataset]` and check that the head position is correct, or fix it if it is not.

In [None]:
for dataset = keys(datasets)
    f = open(param_paths[dataset]["path_head_pos"],"w")
    for t in params[dataset]["t_range"]
        if !(t in keys(data_dicts[dataset]["head_pos"]))
            if dataset in datasets_register
                @warn("$(dataset) timepoint $(t) head location not found.")
            end
            continue
        end
        hp = data_dicts[dataset]["head_pos"][t]
        write(f, "$(t)    $(hp[1]) $(hp[2])\n")
    end
    close(f)
end;
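
The file written above contains one line per timepoint: the timepoint, then the two head coordinates, separated by whitespace. A hypothetical reader (not part of ANTSUN) for checking which timepoints made it into the file, assuming integer coordinates as used for indexing above:

```julia
# Hypothetical reader for the head-position text file: one "t    x y" line per timepoint.
function read_head_pos(path::AbstractString)
    head_pos = Dict{Int, Vector{Int}}()
    for line in eachline(path)
        isempty(strip(line)) && continue   # tolerate blank lines
        t, x, y = parse.(Int, split(line))
        head_pos[t] = [x, y]
    end
    return head_pos
end

# Demo on a temporary file in the same format.
demo_file = tempname()
write(demo_file, "1    10 20\n3    11 19\n")
hp_read = read_head_pos(demo_file)
```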

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

# Registration

### Registration Graph

Expected runtime: 5 minutes.

This script determines which frames are similar to each other and generates a list of pairs of time points to register to each other. For more information, see https://flavell-lab.github.io/WormFeatureDetector.jl, https://github.com/flavell-lab/WormCurveFinder.jl, and https://flavell-lab.github.io/RegistrationGraph.jl.

In [None]:
for dataset = keys(datasets)
    data_dicts[dataset]["registration_problems"] = []
    if dataset in datasets_freely_moving
        continue
    end
    for t=1:params[dataset]["max_t"]
        if t!=params[dataset]["reg_timept"]
            push!(data_dicts[dataset]["registration_problems"], (t,params[dataset]["reg_timept"]))
        end
    end
end

In [None]:
params["freely_moving"]["t_range"] = sort(collect(keys(data_dicts["freely_moving"]["head_pos"])));

In [None]:
for dataset in datasets_register
    data_dicts["$(dataset)_to_central"] = Dict()
    data_dicts["$(dataset)_to_central"]["max_t"] = params["freely_moving"]["max_t"]+params[dataset]["max_t"]
    data_dicts["$(dataset)_to_central"]["curves"] = Vector{Tuple{Vector{Float64}, Vector{Float64}}}(undef, data_dicts["$(dataset)_to_central"]["max_t"])
    heur! = (t1, t2) -> elastix_difficulty_wormcurve!(data_dicts["$(dataset)_to_central"]["curves"], params["freely_moving"], param_paths["freely_moving"], param_paths[dataset], t1, t2, params["freely_moving"]["ch_marker"], save_curve_fig=false, max_fixed_t=params["freely_moving"]["max_t"]);
    data_dicts["$(dataset)_to_central"]["t_range"] = collect(deepcopy(params["freely_moving"]["t_range"]))
    append!(data_dicts["$(dataset)_to_central"]["t_range"], params[dataset]["t_range"].+params["freely_moving"]["max_t"])
    data_dicts["$(dataset)_to_central"]["elastix_difficulty"] = generate_elastix_difficulty(param_paths["freely_moving"]["path_elastix_difficulty_$(dataset)"], data_dicts["$(dataset)_to_central"]["t_range"], heur!);
end

In [None]:
for dataset in datasets_register
    if dataset in datasets_freely_moving
        error("Registering freely-moving datasets together not yet supported in this version of ANTSUN. Please see old ANTSUN_merge.ipynb script.")
    else
        data_dicts["$(dataset)_to_central"]["registration_problems"] = []
        best_timepts = sort(1:length(params[dataset_central]["t_range"]), by=i->data_dicts["$(dataset)_to_central"]["elastix_difficulty"][1:end-length(params[dataset]["t_range"]), length(params[dataset_central]["t_range"])+1+params[dataset]["reg_timept"]][i])
        for i=1:n_timepts_merge[dataset]
            push!(data_dicts["$(dataset)_to_central"]["registration_problems"], (params[dataset]["reg_timept"], params[dataset_central]["t_range"][best_timepts[i]]))
        end
    end
end
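
The cell above ranks the freely-moving timepoints by their registration difficulty against the immobilized dataset's `reg_timept` and keeps the `n_timepts_merge[dataset]` easiest ones. The same selection, written with `partialsortperm` on made-up numbers (`easiest_timepoints` is a hypothetical helper), looks like this:

```julia
# Hypothetical helper: the `n` timepoints in `t_range` with the lowest difficulty.
# `difficulty[i]` is the difficulty of registering against timepoint `t_range[i]`.
function easiest_timepoints(difficulty::AbstractVector{<:Real},
                            t_range::AbstractVector{<:Integer}, n::Integer)
    order = partialsortperm(difficulty, 1:n)   # indices of the n smallest difficulties, ascending
    return t_range[order]
end

best = easiest_timepoints([5.0, 1.0, 3.0, 0.5], [2, 4, 7, 9], 2)
```

Unlike a full `sort`, `partialsortperm` only orders the `n` smallest entries, which makes no practical difference here but states the intent directly.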

In [None]:
for dataset in datasets_freely_moving
    data_dict = data_dicts[dataset]
    param = params[dataset]
    param_path = param_paths[dataset]
    data_dict["curves"] = Vector{Tuple{Vector{Float64}, Vector{Float64}}}(undef, param["max_t"])
    heur! = (t1, t2) -> elastix_difficulty_wormcurve!(data_dict["curves"], param, param_path, param_path, t1, t2, param["ch_marker"], save_curve_fig=true);
    data_dict["elastix_difficulty"] = fill(Inf, (param["max_t"], param["max_t"]))
    data_dict["elastix_difficulty"][param["t_range"], param["t_range"]] .= generate_elastix_difficulty(param_path["path_elastix_difficulty"], param["t_range"], heur!);
end;

If the following code displays any `Disconnected graph` warnings, they are usually safe to ignore if they are due to `Disconnected for merge`, but if you see any output of the form `Disconnected for i=`, that represents a serious problem and may require manually adjusting the registration graph to ensure it is connected. It is very likely that you will get such a message if you try running `ANTSUN` on fully-immobilized animals; in this case it is recommended to bypass the registration graph computation entirely and manually specify the graph.

In [None]:
## THIS CODE WON'T WORK WITH >2 MERGES

for dataset in datasets_freely_moving
    param = params[dataset]
    param_path = param_paths[dataset]
    data_dict = data_dicts[dataset]
    error_dict = error_dicts[dataset]
    for i=1:(param["max_graph_num"] < param["max_t"] ? 2 : 1)
        path_diff_i = replace(param_path["path_elastix_difficulty"], ".txt"=>"_$(i).txt")
        if i == 1
            t_inc = [t for t in param["t_range"] if t <= param["max_graph_num"]]
        else
            t_inc = [t for t in param["t_range"] if t > param["max_graph_num"]]
        end

        open(path_diff_i, "w") do f
            write(f, string(t_inc)*"\n")
            write(f, replace(string(data_dict["elastix_difficulty"][t_inc, t_inc]), ";"=>"\n"))
        end

        graph = load_graph(path_diff_i, func=identity);
        subgraph, unconnected = make_voting_subgraph(graph, param["degree"])
        
        if length(unconnected) > 0
            println("Disconnected for i=$i: $unconnected")
        end
        output_graph(subgraph, replace(param_path["path_reg_prob"], ".txt"=>"_$(i).txt"))

        if i > 1
            path_diff_merge = replace(param_path["path_elastix_difficulty"], ".txt"=>"_merge$(i).txt")
            open(path_diff_merge, "w") do f
                t_merge = param["t_range"]
                write(f, string(t_merge)*"\n")

                diff_merge = deepcopy(data_dict["elastix_difficulty"])
                diff_merge[t_inc, t_inc] .= Inf

                t_inc_1 = [t for t in param["t_range"] if t <= param["max_graph_num"]]
                diff_merge[t_inc_1, t_inc_1] .= Inf

                write(f, replace(string(diff_merge[t_merge, t_merge]), ";"=>"\n"))
            end
            graph = load_graph(path_diff_merge, func=identity);
            subgraph, unconnected = make_voting_subgraph(graph, param["degree_dataset"])
            if length(unconnected) > 0
                println("Disconnected for merge: $unconnected")
            end
            output_graph(subgraph, replace(param_path["path_reg_prob"], ".txt"=>"_merge$(i).txt"))
        end
    end
    # loads registration problems you previously computed from the graph
    if param["max_graph_num"] < param["max_t"]
        data_dict["registration_problems"] = load_registration_problems([replace(param_path["path_reg_prob"], ".txt"=>"_1.txt"), replace(param_path["path_reg_prob"], ".txt"=>"_2.txt"), replace(param_path["path_reg_prob"], ".txt"=>"_merge2.txt")]);
    else
        data_dict["registration_problems"] = load_registration_problems([replace(param_path["path_reg_prob"], ".txt"=>"_1.txt")]);
    end;
end
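
The `Disconnected for i=` message reports timepoints that the registration graph cannot reach. The sketch below illustrates the idea on a toy difficulty matrix: each timepoint keeps edges to its `degree` easiest partners (infinite difficulties are never used), and a breadth-first search lists the nodes outside the largest connected component. This is a simplified stand-in, not RegistrationGraph.jl's `make_voting_subgraph`; both function names are hypothetical.

```julia
# Keep, for every node, undirected edges to its `degree` lowest-difficulty partners.
function knn_edges(difficulty::AbstractMatrix{<:Real}, degree::Integer)
    n = size(difficulty, 1)
    edges = Set{Tuple{Int,Int}}()
    for i in 1:n
        cand = [j for j in 1:n if j != i && isfinite(difficulty[i, j])]
        sort!(cand, by = j -> difficulty[i, j])
        for j in cand[1:min(degree, length(cand))]
            push!(edges, (min(i, j), max(i, j)))
        end
    end
    return edges
end

# Nodes that are not in the largest connected component (breadth-first search).
function unconnected_nodes(n::Integer, edges)
    adj = [Int[] for _ in 1:n]
    for (a, b) in edges
        push!(adj[a], b)
        push!(adj[b], a)
    end
    comp = zeros(Int, n)
    sizes = Int[]
    for s in 1:n
        comp[s] == 0 || continue
        push!(sizes, 0)
        c = length(sizes)
        queue = [s]
        comp[s] = c
        while !isempty(queue)
            v = popfirst!(queue)
            sizes[c] += 1
            for w in adj[v]
                if comp[w] == 0
                    comp[w] = c
                    push!(queue, w)
                end
            end
        end
    end
    main = argmax(sizes)
    return [v for v in 1:n if comp[v] != main]
end

# Toy example: timepoint 4 has no finite difficulty to any other timepoint.
D = fill(Inf, 4, 4)
D[1, 2] = D[2, 1] = 1.0
D[2, 3] = D[3, 2] = 2.0
D[1, 3] = D[3, 1] = 5.0
toy_edges = knn_edges(D, 1)
toy_unconnected = unconnected_nodes(4, toy_edges)
```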

In [None]:
for dataset = datasets_register
    open(param_paths[dataset_central]["path_reg_prob_$dataset"], "w") do f
        for problem in data_dicts["$(dataset)_to_central"]["registration_problems"]
            write(f, "$(problem[1]) $(problem[2])\n")
        end
    end
end
for dataset = keys(datasets)
    open(param_paths[dataset]["path_reg_prob"], "w") do f
        for problem in data_dicts[dataset]["registration_problems"]
            write(f, "$(problem[1]) $(problem[2])\n")
        end
    end
end

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

for dataset = datasets_register
    data_dict = data_dicts["$(dataset)_to_central"]
    JLD2.@save(joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2"), data_dict)
end

## Registration on the OpenMind cluster

After the registration problems have been determined, the data can be uploaded to OpenMind. This code generates registration scripts that run all of the registration problems using `elastix`, and then submits them all with `sbatch`. If you encounter errors here, you have most likely misconfigured something related to interfacing with OpenMind, either in a parameter setting in this notebook or on your computer (e.g., `ssh` keys).

### Immobilized and camera alignment registration

Expected runtime: 1-8 hours.

#### Set up all intra-filter immobilized registrations in the red channel

In [None]:
# syncs data to server and generates sbatch files for elastix
for dataset = keys(datasets)
    if dataset in datasets_freely_moving
        continue
    else
        @time write_sbatch_graph(data_dicts[dataset]["registration_problems"], param_paths[dataset], param_paths[dataset], params[dataset],
            nrrd_dir_key="path_dir_nrrd", nrrd_om_dir_key="path_om_nrrd")
    end
end

#### Set up camera alignment registration

In [None]:
@time write_sbatch_graph([(x,x) for x in params[dataset_central]["t_range"]], param_paths[dataset_central], 
    param_paths[dataset_central], params[dataset_central], moving_channel_key="ch_activity",
    reg_dir_key="path_dir_reg_activity_marker", reg_om_dir_key="path_om_reg_activity_marker",
    nrrd_dir_key="path_dir_nrrd", nrrd_om_dir_key="path_om_nrrd",
    path_head_rotate_key="path_head_rotate_activity_marker", parameter_files_key="parameter_files_activity_marker",
    job_name_key="job_name_activity_marker", cmd_dir_key="path_dir_cmd_am", cmd_om_key="path_om_cmd_am", clear_cmd_dir=true);


In [None]:
@time write_sbatch_graph([(x,x) for x in params[dataset_central]["t_range"]], param_paths[dataset_central], 
    param_paths[dataset_central], params[dataset_central], fixed_channel_key="ch_activity",
    reg_dir_key="path_dir_reg_marker_activity", reg_om_dir_key="path_om_reg_marker_activity",
    nrrd_dir_key="path_dir_nrrd", nrrd_om_dir_key="path_om_nrrd",
    path_head_rotate_key="path_head_rotate_activity_marker", parameter_files_key="parameter_files_activity_marker",
    job_name_key="job_name_marker_activity", cmd_dir_key="path_dir_cmd_ma", cmd_om_key="path_om_cmd_ma", clear_cmd_dir=false);

#### Set up intra-filter immobilized registrations in the green channel

In [None]:
for dataset = keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    @time write_sbatch_graph(data_dicts[dataset]["registration_problems"], param_paths[dataset], param_paths[dataset], params[dataset],
        moving_channel_key="ch_activity", fixed_channel_key="ch_activity", reg_dir_key="path_dir_reg_activity",
        reg_om_dir_key="path_om_reg_activity", job_name_key="job_name_activity", cmd_dir_key="path_dir_cmd_activity", 
        cmd_om_key="path_om_cmd_activity", clear_cmd_dir=false, nrrd_dir_key="path_dir_nrrd", nrrd_om_dir_key="path_om_nrrd")
end


In [None]:
paths_cmd = []
for dataset in keys(datasets)
    if !(dataset == dataset_central)
        push!(paths_cmd, param_paths[dataset]["path_om_cmd_array"])
    end
end

The following lines of code submit the jobs for running on OpenMind. Note that if the code is stalling for a long time without running, you may have a file that got stuck in your `lock` directory (this most commonly happens if a copy of `ANTSUN` fails to communicate with OpenMind to query whether its jobs are completed). Make sure that there is no other copy of `ANTSUN` that is currently running jobs on OpenMind, and then clear your `lock` directory to get the code to run.

In [None]:
get_lock(param_paths[dataset_central], params[dataset_central])
@time run_elastix_openmind(param_paths[dataset_central], params[dataset_central], extra_cmd_paths=paths_cmd)

Wait for OpenMind to finish running the jobs before submitting more:

In [None]:
@time wait_for_elastix(params[dataset_central])
release_lock(param_paths[dataset_central], params[dataset_central])
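
The internals of `get_lock` and `release_lock` are not shown in this notebook. A common way to build such a lock, sketched here purely as an assumption, is an atomic `mkdir`: creating the directory succeeds for exactly one process, and a directory left behind by a crashed run keeps every later caller waiting, which would produce a stall like the one described above. The function names are hypothetical.

```julia
# Hypothetical directory-based lock (not necessarily how `get_lock` works).
# `mkdir` throws if the directory already exists, so only one caller can succeed.
function try_lock(lock_path::AbstractString)
    try
        mkdir(lock_path)
        return true
    catch e
        e isa Base.IOError || rethrow()   # e.g. EEXIST: someone else holds the lock
        return false
    end
end

# Block until the lock is free; a stale lock directory makes this wait indefinitely.
function acquire_lock(lock_path::AbstractString; poll_seconds=5)
    while !try_lock(lock_path)
        sleep(poll_seconds)
    end
end

unlock!(lock_path::AbstractString) = rm(lock_path; force=true, recursive=true)

# Demo on a temporary directory.
lp = joinpath(mktempdir(), "lock")
first_ok = try_lock(lp)    # acquires the lock
second_ok = try_lock(lp)   # fails: lock already held
unlock!(lp)
third_ok = try_lock(lp)    # succeeds again after release
```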

### Freely-moving registration

Expected runtime: 2-5 days.

Note that you can continue executing subsequent code in `ANTSUN` while the freely-moving registration problems are ongoing in OpenMind. Alternatively, if you do not want to submit them to OpenMind until you are certain the dataset is usable, you can move this section to the beginning of the "Extract traces in freely-moving animal" section.

In [None]:
for dataset = datasets_freely_moving
    @time write_sbatch_graph(data_dicts[dataset]["registration_problems"], param_paths[dataset], param_paths[dataset], params[dataset])
end

In [None]:
for dataset = datasets_register
    @time write_sbatch_graph(data_dicts["$(dataset)_to_central"]["registration_problems"], param_paths[dataset_central], param_paths[dataset], 
        params[dataset_central], reg_dir_key="path_dir_reg_$(dataset)", reg_om_dir_key="path_om_reg_$(dataset)", cmd_dir_key="path_dir_cmd_$(dataset)",
        cmd_om_key="path_om_cmd_$dataset", job_name_key="job_name_$dataset", clear_cmd_dir=false)
end

In [None]:
get_lock(param_paths[dataset_central], params[dataset_central])
@time run_elastix_openmind(param_paths[dataset_central], params[dataset_central])

### Syncing the data

Expected runtime: 30-60 minutes.

After you run the registration problems on OpenMind, sync the data back to your local machine.

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    @time sync_registered_data(param_paths[dataset], params[dataset]);
end

In [None]:
@time sync_registered_data(param_paths[dataset_central], params[dataset_central], reg_dir_key="path_dir_reg_activity_marker", reg_om_dir_key="path_om_reg_activity_marker");
@time sync_registered_data(param_paths[dataset_central], params[dataset_central], reg_dir_key="path_dir_reg_marker_activity", reg_om_dir_key="path_om_reg_marker_activity");

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    @time sync_registered_data(param_paths[dataset], params[dataset], reg_dir_key="path_dir_reg_activity", reg_om_dir_key="path_om_reg_activity");
end

You will also need to modify the directories to point to your local machine instead of OpenMind:

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    error_dicts[dataset]["elastix_param_path_errors"] = fix_param_paths(data_dicts[dataset]["registration_problems"], param_paths[dataset], params[dataset]);
end

In [None]:
error_dicts[dataset_central]["elastix_param_path_errors_am"] = fix_param_paths([(x,x) for x in params[dataset_central]["t_range"]], param_paths[dataset_central], params[dataset_central];
        reg_dir_key="path_dir_reg_activity_marker", n_resolution_key="reg_n_resolution_activity_marker");
error_dicts[dataset_central]["elastix_param_path_errors_ma"] = fix_param_paths([(x,x) for x in params[dataset_central]["t_range"]], param_paths[dataset_central], params[dataset_central];
        reg_dir_key="path_dir_reg_marker_activity", n_resolution_key="reg_n_resolution_activity_marker");

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    error_dicts[dataset]["elastix_param_path_errors_activity"] = fix_param_paths(data_dicts[dataset]["registration_problems"], param_paths[dataset], params[dataset],
            reg_dir_key="path_dir_reg_activity");
end
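The path-fixing cells above are needed because elastix parameter files can embed absolute paths, for example in the `InitialTransformParametersFileName` entry that chains one transform to another. As a rough sketch of what such a fix involves (we have not checked how `fix_param_paths` is actually implemented), the hypothetical helper below swaps a remote path prefix for a local one in a single parameter file. Both prefixes in the usage example are made up.

```julia
# Hypothetical helper (not ANTSUN's `fix_param_paths`): replace a remote path
# prefix with a local one everywhere it appears in an elastix parameter file.
function rewrite_param_file!(path::AbstractString, remote_prefix::AbstractString,
                             local_prefix::AbstractString)
    text = read(path, String)
    # Number of lines that mention the remote prefix (i.e. that will be rewritten)
    n_lines = count(line -> occursin(remote_prefix, line), split(text, '\n'))
    write(path, replace(text, remote_prefix => local_prefix))
    return n_lines
end

# Usage on a throwaway file; both prefixes below are made-up examples.
path = tempname()
write(path, """
    (Transform "EulerTransform")
    (InitialTransformParametersFileName "/om/user/me/reg/1to2/TransformParameters.0.txt")
    """)
n = rewrite_param_file!(path, "/om/user/me/", "/data1/prj_XYZ/")
println(n, " line(s) rewritten")
print(read(path, String))
```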

# Instance segmentation

For more information, see https://github.com/flavell-lab/pytorch-3dunet and https://flavell-lab.github.io/SegmentationTools.jl/stable/segment

### Make HDF5 files from NRRD files

The U-Net requires data to be input in HDF5 format. Make sure to use the original NRRD files, **not** any noise-filtered version.

In [None]:
for dataset in keys(datasets)
    make_unet_input_h5(param_paths[dataset], param_paths[dataset]["path_dir_nrrd_crop"], params[dataset]["t_range"], params[dataset]["ch_marker"], param_paths[dataset]["get_basename"])
end

### Run `pytorch-3dunet`

Expected runtime: 1 hour. GPU-accelerated.

This command runs the 3D UNet to segment neurons in each dataset.

This step can crash if the GPU runs out of memory. If it does, restart the kernel, rerun the "Initialize data dictionaries" section, and then rerun this step and all subsequent steps.

In [None]:
for dataset in keys(datasets)
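    # `flv_c` is assumed to hold the server number set earlier (1 for flv-c1, 2 for flv-c2),
    # so this uses GPU 0 on flv-c1 and GPU 1 on flv-c2.
    @time call_unet(param_paths[dataset], gpu=flv_c - 1)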
end

The UNet segmentation used the GPU, so we can try to free GPU memory here, though note that the kernel will need to be restarted before all memory can be freed.

In [None]:
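GC.gc(true)    # first collect unreferenced GPU arrays...
CUDA.reclaim() # ...then return the freed memory from CUDA.jl's pool to the device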

### Watershedding Concave Neurons

Expected runtime: 1 hour.

The U-Net often fails to fully separate touching neurons into distinct instances, leaving oddly-shaped, concave ROIs that need further segmentation. We use a customized watershed algorithm to split them. If the results are poor initially, `instance_segment_threshold` has several heuristic keyword parameters that you can tune.

In [None]:
for dataset in keys(datasets)
    error_dicts[dataset]["seg_error"] = instance_segmentation_watershed(params[dataset], param_paths[dataset], param_paths[dataset]["path_dir_nrrd_crop"],
            params[dataset]["t_range"], param_paths[dataset]["get_basename"], save_centroid=true, save_signal=true, save_roi=true);
end
GC.gc()
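A common way to split a concave ROI, and the usual idea behind distance-transform watershed, is to place one seed at each peak of the ROI's distance transform: two touching neurons produce two peaks joined by a narrow neck. The sketch below is a self-contained 2D toy in plain Julia, not ANTSUN's `instance_segmentation_watershed`. The helper names and the `min_dist` threshold are hypothetical, and the final step assigns each pixel to its nearest seed, which is only a crude stand-in for actual watershed flooding.

```julia
# Toy sketch, not ANTSUN's watershed. Brute-force Euclidean distance transform:
# distance from each foreground pixel to the nearest background pixel.
function distance_transform_bruteforce(mask::AbstractMatrix{Bool})
    bg = [Tuple(I) for I in CartesianIndices(mask) if !mask[I]]
    d = zeros(Float64, size(mask))
    for I in CartesianIndices(mask)
        mask[I] || continue
        i, j = Tuple(I)
        d[I] = minimum(hypot(i - bi, j - bj) for (bi, bj) in bg)
    end
    return d
end

# Seeds: strict local maxima of the distance map (8-neighbourhood) above `min_dist`.
function find_seeds(d::AbstractMatrix{<:Real}; min_dist=3.0)
    R = CartesianIndices(d)
    Ifirst, Ilast, I1 = first(R), last(R), oneunit(first(R))
    seeds = CartesianIndex{2}[]
    for I in R
        d[I] > min_dist || continue
        if all(d[I] > d[J] for J in max(Ifirst, I - I1):min(Ilast, I + I1) if J != I)
            push!(seeds, I)
        end
    end
    return seeds
end

# Crude split (not true watershed flooding): label each foreground pixel with its nearest seed.
function split_by_nearest_seed(mask::AbstractMatrix{Bool}, seeds)
    labels = zeros(Int, size(mask))
    for I in CartesianIndices(mask)
        mask[I] || continue
        labels[I] = argmin([sum(abs2, Tuple(I - S)) for S in seeds])
    end
    return labels
end

# Toy ROI: two disks of radius 4.5 ("neurons") joined by a one-pixel neck.
mask = falses(21, 11)
for I in CartesianIndices(mask)
    i, j = Tuple(I)
    mask[I] = hypot(i - 6, j - 6) <= 4.5 || hypot(i - 16, j - 6) <= 4.5
end
mask[11, 6] = true

d = distance_transform_bruteforce(mask)
seeds = find_seeds(d)
labels = split_by_nearest_seed(mask, seeds)
println("seeds: ", seeds)
```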


You can view the segmentation quality at a given time point by uncommenting this code:

In [None]:
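# let
#     dataset = dataset_central # set this to the dataset you want to inspect
#     param_path = param_paths[dataset]
#     param = params[dataset]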
#     t = 221 
#     raw_contrast = 3

#     view_roi_3D(read_img(NRRD(joinpath(param_path["path_dir_nrrd_crop"], param_path["get_basename"](t, param["ch_marker"])*".nrrd"))), load_predictions(joinpath(param_path["path_dir_unet_data"], "$(t)_predictions.h5")), read_img(NRRD(joinpath(param_path["path_dir_roi_watershed"], "$(t).nrrd"))), plot_size=(1200, 400), raw_contrast=raw_contrast)
# end

# NeuroPAL RGB Image Extraction 

This section finishes immobilized dataset registrations and produces the NeuroPAL RGB image.

### Get GCaMP activity

Expected runtime: 1 hour. GPU-accelerated.

Get GCaMP activity for the UNet ROIs. First, we compute the median of the red-to-green (camera-alignment) registrations across time points and use it to align the two cameras; this median registration will be reused later in the immobilized NeuroPAL registration.

In [None]:
data_dicts[dataset_central]["euler_params"], data_dicts[dataset_central]["params_avg"], error_dicts[dataset_central]["average_am"] = average_am_registrations(params[dataset_central]["t_range"], param_paths[dataset_central]);
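Taking a median across time points makes the averaged camera alignment robust to a minority of failed per-time-point registrations. A minimal sketch of the idea, assuming the registrations are stored as a matrix of elastix `EulerTransform` parameters with one row per time point (three rotation angles followed by three translations). The numbers are made up, and this is not the implementation of `average_am_registrations`.

```julia
using Statistics

# Made-up EulerTransform parameters (rx, ry, rz in radians; tx, ty, tz),
# one row per time point. Row 4 mimics a failed registration.
euler_params = [
    0.010  -0.002   0.001   1.2  -0.4   0.3
    0.011  -0.001   0.002   1.1  -0.5   0.2
    0.009  -0.003   0.001   1.3  -0.3   0.3
    0.450   0.200  -0.300  15.0   9.0  -7.0
    0.010  -0.002   0.000   1.2  -0.4   0.4
]

# Column-wise median vs. mean over time points
params_median = vec(median(euler_params; dims=1))
params_mean = vec(mean(euler_params; dims=1))

println("median: ", params_median)
println("mean:   ", params_mean)
```

The single outlier pulls the mean translation far from the typical value, while the median stays put. Taking the median of each rotation angle independently is only reasonable because camera-alignment rotations are small.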

Next, we transform the GCaMP data through the camera-alignment registration, shear-correction, and cropping steps so that it aligns with the ROIs computed by running the UNet on the red-channel data. Importantly, we do **not** run the GCaMP data through the denoising step.

In [None]:
for dataset in datasets_freely_moving
    if !(dataset == dataset_central)
        t_valid = params[dataset_central]["t_range"][1]
        path_avg_reg = joinpath(param_paths[dataset_central]["path_dir_reg_activity_marker"], "$(t_valid)to$(t_valid)", param_paths[dataset_central]["name_transform_activity_marker_avg"])
        for t in params[dataset]["t_range"]
            cp(path_avg_reg, joinpath(param_paths[dataset]["path_dir_reg_activity_marker"], "$(t)to$(t)", param_paths[dataset]["name_transform_activity_marker_avg"]), force=true)
        end
    end
    error_dicts[dataset]["am_reg_errors"] = extract_activity_am_reg(param_paths[dataset], params[dataset], data_dicts[dataset]["shear_params"], data_dicts[dataset]["dict_param_crop_rot"], transform_key="name_transform_activity_marker_avg")
end

### Average immobilized datasets and transform to ch2

Expected runtime: 1 minute.

This step uses the immobilized registrations from OpenMind to align the immobilized images taken with the same filter and average them into a higher-SNR image. If some of those registrations failed, warning messages will appear. The dataset is likely still usable if only a few registrations failed, but if many failed, the quality of the immobilized data will suffer severely. In that case, you may need to rerun the immobilized registration on OpenMind, rerun the "Syncing the data" section, and then rerun the rest of the "NeuroPAL RGB Image Extraction" section.

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    for ch=[1,2]
        reg_dir = (ch == 1) ? param_paths[dataset]["path_dir_reg_activity"] : param_paths[dataset]["path_dir_reg"]
        averaged_stack = average_registered_images(reg_dir,
            joinpath(param_paths[dataset]["path_dir_nrrd"], param_paths[dataset]["get_basename"](reg_timepts[dataset],ch)*".nrrd"),
            data_dicts[dataset]["registration_problems"], param_paths[dataset]["nrrd_avg_res"])
        write_nrrd(param_paths[dataset]["path_nrrd_avg"][ch], averaged_stack, (params[dataset]["spacing_lat"], params[dataset]["spacing_lat"], params[dataset]["spacing_axi"]))
    end
end
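Averaging N registered stacks with independent noise reduces the noise standard deviation by roughly a factor of √N, which is where the SNR gain comes from. A hypothetical sketch of the averaging step, assuming failed registrations are represented as `nothing` (this is not `average_registered_images` itself, whose internals are not shown here):

```julia
using Statistics

# Hypothetical helper: average registered 3D stacks from one filter, skipping
# failed registrations (represented here as `nothing`).
# Returns the average and the number of stacks skipped.
function average_stacks(stacks::AbstractVector)
    ok = [s for s in stacks if s !== nothing]
    n_failed = length(stacks) - length(ok)
    isempty(ok) && error("all registrations failed")
    n_failed > 0 && @warn "$n_failed of $(length(stacks)) registrations failed and were skipped"
    return mean(ok), n_failed   # element-wise mean of the remaining stacks
end

stacks = Any[fill(0.0, 2, 2, 2), fill(1.0, 2, 2, 2), nothing, fill(2.0, 2, 2, 2)]
avg, n_failed = average_stacks(stacks)
```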

Very rarely, the intra-filter immobilized registrations will succeed but produce poor-quality output (visible in the registered images as duplicated or distorted neurons). If that happens, you can try to identify which time points gave the bad output and exclude them by uncommenting and adapting the following code. For each `n` in the list, it writes a partial average built from only the first `n - 1` non-reference time points, so you can see at which point the averaged image starts to degrade.

In [None]:
# for dataset in keys(datasets)
#     if dataset in datasets_freely_moving
#         continue
#     end
#     for n=[2,4,10,20,60]
#         reg_problems = []
#         count=1
#         for t=1:n-1
#             if t==reg_timepts[dataset]
#                 count += 1
#             end
#             push!(reg_problems, (count, reg_timepts[dataset]))
#             count += 1
#         end

#         for ch=[1,2]
#             reg_dir = (ch == 1) ? param_paths[dataset]["path_dir_reg_activity"] : param_paths[dataset]["path_dir_reg"]
#             averaged_stack = average_registered_images(reg_dir,
#                 joinpath(param_paths[dataset]["path_dir_nrrd"], param_paths[dataset]["get_basename"](reg_timepts[dataset],ch)*".nrrd"),
#                 reg_problems, param_paths[dataset]["nrrd_avg_res"])
#             write_nrrd(joinpath(param_paths[dataset]["path_root_process"], "partial_averaged_n$(n)_ch$(ch).nrrd"), averaged_stack, (params[dataset]["spacing_lat"], params[dataset]["spacing_lat"], params[dataset]["spacing_axi"]))
#         end
#     end
# end
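Once you have identified the bad time points, an alternative to the first-`n` scheme above is to build the list of registration problems explicitly. The hypothetical helper below (not part of ANTSUN) returns pairs of the form `(t, t_ref)`, the same form as `reg_problems` in the commented code, skipping the reference time point and any time points listed in `exclude`. The result could then be passed to `average_registered_images` in place of `reg_problems`.

```julia
# Hypothetical helper: pairs (t, t_ref) for every time point in `t_range`
# except the reference `t_ref` and any time points listed in `exclude`.
function reg_problems_excluding(t_range, t_ref::Integer; exclude=Int[])
    return [(t, t_ref) for t in t_range if t != t_ref && !(t in exclude)]
end

probs = reg_problems_excluding(1:8, 3; exclude=[5, 7])
println(probs)  # [(1, 3), (2, 3), (4, 3), (6, 3), (8, 3)]
```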

The following code transforms the immobilized registered images using the camera-alignment registration fit on the freely-moving data. This is especially important for the BFP channel.

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving
        continue
    end
    
    t_valid = params[dataset_central]["t_range"][1]
    run_transformix_img(joinpath(param_paths[dataset_central]["path_dir_reg_activity_marker"], "$(t_valid)to$(t_valid)"),
        param_paths[dataset]["path_nrrd_avg_ch1toch2"], param_paths[dataset]["path_nrrd_avg"][1], 
        joinpath(param_paths[dataset_central]["path_dir_reg_activity_marker"], "$(t_valid)to$(t_valid)", param_paths[dataset_central]["name_transform_activity_marker_avg"]),
        joinpath(param_paths[dataset_central]["path_dir_reg_activity_marker"], "$(t_valid)to$(t_valid)", "$(dataset)_"*param_paths[dataset_central]["name_transform_activity_marker_roi"]),        
        param_paths[dataset]["path_transformix"]);
end

Occasionally, the camera-alignment registration fit on the freely-moving data will fail to generalize to the immobilized data. Such a failure would be visible in this image as a misalignment between red neurons (red fluorescent protein channel) and green neurons (BFP channel). **It is strongly recommended to carefully inspect this image for such artifacts**, though note that there may be fewer green neurons than red neurons because not all neurons express BFP.

If you find such artifacts, you will need to manually refit the transformation on the immobilized images outside of this notebook, e.g., by running the following commands from the command line after replacing all relevant paths:

`cd /data1/prj_XYZ/data_processed/DATASET_output`

`mkdir camera_align_BFP`

`elastix -m /data1/prj_XYZ/data_processed/DATASET_output/neuropal/DATASET_BFP/avg_ch1toch2/result.nrrd -f /data1/prj_XYZ/data_processed/DATASET_output/neuropal/DATASET_BFP/avg_ch2.nrrd -out camera_align_BFP -p /data1/shared/elastix_parameters/parameters_euler_turbo_output.txt`

If you care about the GCaMP channel in other filters in the NeuroPAL immobilized data for some reason, you would also need to run analogous commands for the other filters and edit the code in "Run `transformix` to transform all channels using immobilized registration" accordingly.

In [None]:
let
    dataset = "BFP"
    img1 = maxprj(read_img(NRRD(joinpath(param_paths[dataset]["path_nrrd_avg_ch1toch2"], "result.nrrd"))), dims=3)[:,:,1]
    img2 = maxprj(read_img(NRRD(param_paths[dataset]["path_nrrd_avg"][2])), dims=3)[:,:,1]
#     RGB.(img1 ./ maximum(img1), img2 ./ maximum(img2), 0)
    RGB.((img2 .- median(img2)) ./ (maximum(img2) - median(img2)), (img1 .- median(img1)) ./ (maximum(img1) - median(img1)), 0)
end

If you use manual camera alignment, set the following parameter to `true`:

In [None]:
manual_camera_alignment = false;

If you find such artifacts in the immobilized images, it is also worth checking that the freely-moving images did not develop camera-alignment issues near the end of the recording (e.g., if someone bumped the cameras mid-session), as such issues would make the data unusable. However, this is extremely rare for freely-moving data. The code below checks time point 1550 (`t1550` in the file names); change it to a time point near the end of your recording if needed.

In [None]:
let
    dataset = "freely_moving"
    img1 = maxprj(read_img(NRRD(joinpath(param_paths[dataset]["path_dir_nrrd_crop"], "$(datasets["freely_moving"])_t1550_ch1.nrrd"))), dims=3)[:,:,1]
    img2 = maxprj(read_img(NRRD(joinpath(param_paths[dataset]["path_dir_nrrd_crop"], "$(datasets["freely_moving"])_t1550_ch2.nrrd"))), dims=3)[:,:,1]
#     RGB.(img1 ./ maximum(img1), img2 ./ maximum(img2), 0)
    RGB.((img2 .- median(img2)) ./ (maximum(img2) - median(img2)), (img1 .- median(img1)) ./ (maximum(img1) - median(img1)), 0)
end

### Register immobilized datasets to each other and compute NeuroPAL RGB image

Expected runtime: 1-3 hours. Unlike other registration problems, these registrations run locally on your computing server. They are very computationally expensive (using most cores available on the system) but only use the CPU, not the GPU. This step can take substantially longer if the server is under high load.

**This is the section of the pipeline that is most likely to require manual intervention.**

#### Set manual translation and rotation parameters

You should leave `manual_translate` and `manual_rotate` empty the first time you run the code, in the hope that the registration works out of the box, but be prepared to modify them and rerun this code several times.

In [None]:
## fully-manual translation
manual_translate = Dict()
manual_rotate = Dict()

In [None]:
for dataset in keys(manual_translate)
    img_dataset = read_img(NRRD(param_paths[dataset]["path_nrrd_avg"][2]))
    img_ch1toch2 = read_img(NRRD(joinpath(param_paths[dataset]["path_nrrd_avg_ch1toch2"], "result.nrrd")))
    Δ = manual_translate[dataset]
    s = size(img_dataset)
    rot_mtx = AngleAxis(get(manual_rotate, dataset, 0.0), 0, 0, 1) # no rotation unless one was set for this dataset
    img_translate = fill(median(img_dataset), s)
    img_translate[max(Δ[1]+1,1):min(s[1],Δ[1]+s[1]), max(Δ[2]+1,1):min(s[2],Δ[2]+s[2]), :] .=
            img_dataset[max(-Δ[1]+1,1):min(s[1],-Δ[1]+s[1]), max(-Δ[2]+1,1):min(s[2],-Δ[2]+s[2]), :]
    img_translate = warp(img_translate, recenter(rot_mtx, s.÷2), fillvalue=median(img_dataset))[1:s[1],1:s[2],1:s[3]]
    write_nrrd(param_paths[dataset]["path_nrrd_avg_translate"][2], img_translate, (spacing_lat, spacing_lat, spacing_axi))
    s = size(img_ch1toch2)
    img_translate = fill(median(img_ch1toch2), s)
    img_translate[max(Δ[1]+1,1):min(s[1],Δ[1]+s[1]), max(Δ[2]+1,1):min(s[2],Δ[2]+s[2]), :] .=
            img_ch1toch2[max(-Δ[1]+1,1):min(s[1],-Δ[1]+s[1]), max(-Δ[2]+1,1):min(s[2],-Δ[2]+s[2]), :]
    img_translate = warp(img_translate, recenter(rot_mtx, s.÷2), fillvalue=median(img_ch1toch2))[1:s[1],1:s[2],1:s[3]]
    write_nrrd(param_paths[dataset]["path_nrrd_avg_ch1toch2_translate"], img_translate, (spacing_lat, spacing_lat, spacing_axi))
end
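The slicing in the translation cell above is easy to get wrong, so here is the same integer-translation logic factored into a standalone function with a small check. This is a sketch: `shift_xy` is a hypothetical name, and the rotation step (`warp` with `recenter`) is omitted. Positive offsets move image content toward higher array indices, and uncovered pixels are filled with `fillval` (the cell above uses the image median).

```julia
# Hypothetical helper: translate a 3D stack by an integer in-plane offset
# Δ = (Δx, Δy), filling the uncovered border with `fillval`.
# Same index arithmetic as the translation cell above.
function shift_xy(img::AbstractArray{<:Real,3}, Δ::NTuple{2,Int}, fillval::Real)
    s = size(img)
    out = fill(fillval, s)
    dst_x = max(Δ[1] + 1, 1):min(s[1], Δ[1] + s[1])
    dst_y = max(Δ[2] + 1, 1):min(s[2], Δ[2] + s[2])
    src_x = max(-Δ[1] + 1, 1):min(s[1], -Δ[1] + s[1])
    src_y = max(-Δ[2] + 1, 1):min(s[2], -Δ[2] + s[2])
    out[dst_x, dst_y, :] .= img[src_x, src_y, :]
    return out
end

img = reshape(collect(1.0:12.0), 4, 3, 1)
shifted = shift_xy(img, (1, 0), 0.0)   # content moves one pixel toward higher x
println(shifted[:, :, 1])
```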

#### Run inter-filter immobilized registration locally

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving || dataset == dataset_immobilized_central
        continue
    end
    avg_path_key = (dataset in keys(manual_translate)) ? "path_nrrd_avg_translate" : "path_nrrd_avg"
    path_immobilized_central = param_paths[dataset_immobilized_central]["path_nrrd_avg"][2]
    run_registration(param_paths[dataset], path_immobilized_central, param_paths[dataset][avg_path_key][2])
end

### Inspect registration quality

In this block of code, you can inspect the quality of the immobilized registrations. Each registration output is indexed by two parameters, `X` and `Y`, which correspond to the `result.X.RY.nrrd` files written by elastix: `X` is the index of the parameter file (registration stage) and `Y` is the resolution level. Usually, you will want to set `X=2` and vary `Y` from 0 to 3. Run the code with each value of `Y`, inspect all the resulting images to find the best one, and repeat this for each immobilized dataset (except `all_red`, since that is the dataset you are registering the other datasets to).

A good match will show many yellow neurons, meaning that the green channel (your filter) matches the red channel (the `all_red` image you are registering your filter to). Since the entire purpose of the NeuroPAL strain is to give different neurons different intensities in different filters, some neurons will look redder and some greener. However, bad registrations can be recognized either by strange green distortions that don't look like neurons, or by substantial misalignment between red and green neurons.

It is fairly common for none of the registrations to look usable, especially for the `mNeptune` channel. This often happens because there was drift between the immobilized recordings that the registration failed to correct. You can check whether this is the case by setting `display_mode = "raw"` below and examining the output for severe drift (i.e., mismatch between green and red).

You can often correct for this drift manually. To do so, modify the `manual_translate` and `manual_rotate` parameters in the "Set manual translation and rotation parameters" step (e.g., `manual_translate = Dict("mNeptune" => (17,-10));` and `manual_rotate = Dict("mNeptune" => π/30);`), then set `display_mode = "translate"` and rerun this code. Make sure to rerun the entire "Set manual translation and rotation parameters" step before rerunning this block of code; it is not necessary to rerun the "Run inter-filter immobilized registration locally" step between iterations. By alternating between visualizing the images with this block and editing the translation and rotation parameters, you should be able to get good alignment between the red and green images.

The alignment does not need to be perfect. Once it is as good as you can get it with manual translation and rotation, rerun the "Run inter-filter immobilized registration locally" step, changing the `for dataset in keys(datasets)` line to loop over only the dataset you just aligned (e.g., `for dataset in ["mNeptune"]`). Once the registration is finished, set `display_mode = "register"` and repeat the process of finding the optimal `Y` value. Hopefully, one of the registrations will have succeeded this time.
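To get a starting guess for `manual_translate`, you can search for the integer in-plane shift that maximizes the correlation between the two maximum-intensity projections. The sketch below is not part of ANTSUN (`best_shift` is a hypothetical helper). It uses `circshift`, which wraps content around the borders, so it assumes the image borders are mostly background, and it ignores rotation. Positive shifts move content toward higher array indices, matching the translation cell in "Set manual translation and rotation parameters", but always confirm the result visually with `display_mode = "translate"`.

```julia
using Statistics

# Hypothetical helper: find the integer shift (dx, dy) that, applied to `moving`,
# best matches `fixed` (both 2D projections), scored by Pearson correlation.
function best_shift(fixed::AbstractMatrix, moving::AbstractMatrix; max_shift::Int=20)
    f = vec(Float64.(fixed))
    best, best_score = (0, 0), -Inf
    for dx in -max_shift:max_shift, dy in -max_shift:max_shift
        score = cor(f, vec(Float64.(circshift(moving, (dx, dy)))))
        if score > best_score
            best, best_score = (dx, dy), score
        end
    end
    return best, best_score
end

# Toy check: `moving` is `fixed` displaced by (-3, +5), so the fix is (+3, -5).
fixed = zeros(30, 30)
fixed[10:12, 8:10] .= 1.0
fixed[20:22, 18:21] .= 2.0
moving = circshift(fixed, (-3, 5))
shift, score = best_shift(fixed, moving; max_shift=6)
println(shift, " ", score)
```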

If none of the registrations succeed even after manually correcting for drift, and the problematic channel was `mNeptune` but the `mNeptune_gcamp` channel registered successfully, you can use that channel instead. In that case, notify whoever is labeling the NeuroPAL image that they will need to use the `mNeptune_gcamp` channel and deal with the bleedthrough it causes. If both `mNeptune` and `mNeptune_gcamp` failed, or if any other channel failed, your dataset will not be usable without a very large amount of effort.

**Note:** if you restarted the kernel, you will need to rerun the cells that define the `manual_camera_alignment`, `manual_translate`, and `manual_rotate` variables, or the following code blocks will not work.

In [None]:
let
    dataset = "BFP" # Set this to the immobilized dataset you are checking the registration quality of
    dim = 3 # Dimension to project the data. It is usually most useful to leave this at 3.
    display_mode = "register"
    X = 2
    Y = 0
    contrast = 0.4 # contrast parameter

    if display_mode == "raw"
        ## This loads the raw image for your immobilized dataset, without any registration being performed.
        img1 = maxprj(read_img(NRRD(param_paths[dataset]["path_nrrd_avg"][2])), dims=dim)[:,:,1]
    elseif display_mode == "translate"
        ## This loads the manually-translated image for your immobilized dataset, before registration.
        ## You should try to adjust the translation and rotation parameters to get this to line up with the `all_red` image as best as you can.
        img1 = maxprj(read_img(NRRD(param_paths[dataset]["path_nrrd_avg_translate"][2])), dims=dim)[:,:,1]
    elseif display_mode == "register"
        ## This loads the latest registered image of the given dataset.
        ## Ideally, if the registration worked well, it would look very similar to the `all_red` image.
        img1 = maxprj(read_img(NRRD(joinpath(param_paths[dataset]["path_dir_reg_neuropal"], "result.$(X).R$(Y).nrrd"))), dims=dim)[:,:,1]
    end

    ## This loads the `all_red` image, which is what you are trying to get your image to look like.
    img2 = maxprj(read_img(NRRD(param_paths["all_red"]["path_nrrd_avg"][2])), dims=dim)[:,:,1]
    
    RGB.((img2 .- median(img2)) ./ (contrast*maximum(img2)), (img1 .- median(img1)) ./ (contrast*maximum(img1)), 0)    
end
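Visual inspection is the criterion this notebook relies on, but when comparing several `(X, Y)` candidates it can help to also compute a rough numeric score. A minimal sketch, not part of ANTSUN, using the Pearson correlation between maximum-intensity projections (`mip` and `overlap_score` are hypothetical helpers). A higher score only suggests better overlap; it will not reliably flag the distortions described above.

```julia
using Statistics

# Hypothetical helper: maximum-intensity projection of a 3D stack along `dim`, as a 2D matrix.
mip(img::AbstractArray{<:Real,3}; dim::Int=3) = dropdims(maximum(img; dims=dim); dims=dim)

# Hypothetical helper: Pearson correlation between two equally sized projections.
overlap_score(a::AbstractMatrix, b::AbstractMatrix) = cor(vec(Float64.(a)), vec(Float64.(b)))

# Toy example: a reference stack with two "neurons", an aligned copy, and a copy
# shifted by 4 pixels along the first axis.
ref = zeros(20, 20, 5)
ref[5:7, 5:7, 2:3] .= 1.0
ref[13:15, 12:14, 3:4] .= 1.0
aligned = copy(ref)
misaligned = circshift(ref, (4, 0, 0))

s_aligned = overlap_score(mip(ref), mip(aligned))
s_misaligned = overlap_score(mip(ref), mip(misaligned))
println((s_aligned, s_misaligned))
```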

Modify these values based on which registrations look the best (set them to the `(X,Y)` pairs you found above):

In [None]:
params["BFP"]["registration_resolution_neuropal"] = (2,1)
params["OFP"]["registration_resolution_neuropal"] = (2,1)
params["mNeptune"]["registration_resolution_neuropal"] = (2,0)
params["mNeptune_gcamp"]["registration_resolution_neuropal"] = (2,1)

You can save your registration settings by uncommenting this block of code (this is recommended to do once you've finalized them):

In [None]:
## Uncomment this code to overwrite previous dictionaries.
# for dataset in keys(datasets)
#     if "get_basename" in keys(param_paths[dataset])
#         delete!(param_paths[dataset], "get_basename")
#     end

#     param_path = param_paths[dataset]
#     param = params[dataset]
#     JLD2.@save(param_path["path_param_path"], param_path)
#     add_get_basename!(param_path)
#     JLD2.@save(param_path["path_param"], param)
# end

### Run `transformix` to transform all channels using immobilized registration

The previous sections aligned only the red-channel images, and those registrations do not by themselves align the BFP channel, which is recorded on the other camera. Thus, we need to transform the BFP images through both the camera-alignment registration and the inter-filter registration to align them to the rest of the NeuroPAL image.

If you used manual camera-alignment registration, the code applies it to the BFP channel. If you also need manual camera alignment for datasets other than `BFP` (e.g., if for some reason you care about the GCaMP signal in the OFP filter), modify the `dataset == "BFP"` part of this code accordingly.

In [None]:
for dataset in keys(datasets)
    if dataset in datasets_freely_moving || dataset == dataset_immobilized_central
        continue
    end
    
    input_img_path = (dataset in keys(manual_translate)) ? param_paths[dataset]["path_nrrd_avg_ch1toch2_translate"] : joinpath(param_paths[dataset]["path_nrrd_avg_ch1toch2"], "result.nrrd")
    res = params[dataset]["registration_resolution_neuropal"]
    
    if manual_camera_alignment && dataset == "BFP"
        input_img_path = joinpath(param_paths["freely_moving"]["path_root_process"], "camera_align_BFP/result.0.R2.nrrd")
    end
    
    run_transformix_img(param_paths[dataset]["path_dir_reg_neuropal"],
        param_paths[dataset]["path_nrrd_avg_ch1toch2_reg"],
        input_img_path,
        joinpath(param_paths[dataset]["path_dir_reg_neuropal"], "TransformParameters.$(res[1]).R$(res[2]).txt"),
        joinpath(param_paths[dataset]["path_dir_reg_neuropal"], "TransformParameters.$(res[1]).R$(res[2])_roi.txt"),
        param_paths[dataset]["path_transformix"]);
end
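Chaining two registrations means composing their transforms, and the order matters. The sketch below illustrates this with 4×4 homogeneous matrices, assuming for simplicity that both transforms are rigid rotations about z plus translations. The matrices are made up, `rigid_z` and `apply_h` are hypothetical helpers, and the sketch does not reproduce elastix's own conventions for chaining transforms; it only shows that composition is not commutative.

```julia
using LinearAlgebra  # provides `≈` for arrays (via `norm`)

# Hypothetical helper: 4x4 homogeneous matrix for a rotation by θ (radians)
# about z, followed by a translation t.
function rigid_z(θ::Real, t::NTuple{3,<:Real})
    c, s = cos(θ), sin(θ)
    return [c -s 0.0 t[1]
            s  c 0.0 t[2]
            0.0 0.0 1.0 t[3]
            0.0 0.0 0.0 1.0]
end

# Hypothetical helper: apply a homogeneous transform to a 3D point.
apply_h(T::AbstractMatrix, p::AbstractVector) = (T * [p; 1.0])[1:3]

A = rigid_z(0.02, (1.5, -0.5, 0.0))    # made-up camera-alignment transform
B = rigid_z(-0.10, (12.0, 3.0, 1.0))   # made-up inter-filter transform
p = [40.0, 25.0, 10.0]

sequential = apply_h(B, apply_h(A, p))  # A first, then B
composite = apply_h(B * A, p)           # one combined matrix
println(sequential ≈ composite)         # true
println(apply_h(A * B, p) ≈ composite)  # false: order matters
```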

### Generate final NeuroPAL RGB images

This code generates the final NeuroPAL RGB images. As in the NeuroPAL manual, red = mNeptune, green = OFP, and blue = BFP. It saves two versions: `neuropal/NeuroPAL.nrrd`, which uses the `mNeptune` channel, and `neuropal/NeuroPAL_mNeptune_GCaMP_bleedthrough.nrrd`, which uses the `mNeptune_gcamp` channel instead. If you got the `mNeptune` registration to work, it is recommended to use `neuropal/NeuroPAL.nrrd`.

This code also saves a separate `nrrd` file for each filter individually (aligned, but not merged into an RGB image), and a `neuron_rois.nrrd` file containing the UNet output ROIs. When labeling the data, each ROI in that file should be assigned a neuron label if possible.

In [None]:
let
    cr_params = deepcopy(data_dicts[dataset_immobilized_central]["dict_param_crop_rot"][reg_timepts[dataset_immobilized_central]])
    vec = data_dicts["$(dataset_immobilized_central)_to_central"]["curves"][params[dataset_central]["max_t"]+reg_timepts[dataset_immobilized_central]]
   
    crop_x, crop_y, crop_z = cr_params["crop"]
    θ_ = cr_params["θ"]
    worm_centroid = cr_params["worm_centroid"]
    
    
    res = params["mNeptune"]["registration_resolution_neuropal"]
    mNeptune = read_img(NRRD(joinpath(param_paths["mNeptune"]["path_dir_reg_neuropal"], "result.$(res[1]).R$(res[2]).nrrd")))
    res = params["mNeptune_gcamp"]["registration_resolution_neuropal"]
    mNeptune_gcamp = read_img(NRRD(joinpath(param_paths["mNeptune_gcamp"]["path_dir_reg_neuropal"], "result.$(res[1]).R$(res[2]).nrrd")))
    res = params["OFP"]["registration_resolution_neuropal"]
    OFP = read_img(NRRD(joinpath(param_paths["OFP"]["path_dir_reg_neuropal"], "result.$(res[1]).R$(res[2]).nrrd")))
    OFP_GCaMP = read_img(NRRD(joinpath(param_paths["OFP"]["path_nrrd_avg_ch1toch2_reg"], "result.nrrd")))
    BFP = read_img(NRRD(joinpath(param_paths["BFP"]["path_nrrd_avg_ch1toch2_reg"], "result.nrrd")))
    all_red = read_img(NRRD(joinpath(param_paths["all_red"]["path_nrrd_avg"][2])))
    img_roi = read_img(NRRD(joinpath(param_paths["all_red"]["path_dir_roi_watershed"], "$(reg_timepts[dataset_immobilized_central]).nrrd")))
    
    
    # not already cropped - need to crop
    mNeptune, _, _, _ = crop_rotate(mNeptune, crop_x, crop_y, crop_z, θ_, worm_centroid)
    mNeptune_gcamp, _, _, _ = crop_rotate(mNeptune_gcamp, crop_x, crop_y, crop_z, θ_, worm_centroid)
    OFP, _, _, _ = crop_rotate(OFP, crop_x, crop_y, crop_z, θ_, worm_centroid)
    BFP, _, _, _ = crop_rotate(BFP, crop_x, crop_y, crop_z, θ_, worm_centroid)
    OFP_GCaMP, _, _, _ = crop_rotate(OFP_GCaMP, crop_x, crop_y, crop_z, θ_, worm_centroid)
    all_red, _, _, _ = crop_rotate(all_red, crop_x, crop_y, crop_z, θ_, worm_centroid)
    
    z_rot = 0
    if vec[1][1] > vec[1][end]
        z_rot = π
    end
    
    rot_x = AngleAxis(π*rotate_img_x, 1, 0, 0)
    rot_z = AngleAxis(z_rot, 0, 0, 1)
    rot_mtx = rot_x*rot_z
    
    img_size = size(mNeptune)
    
    mNeptune = warp(mNeptune, recenter(rot_mtx, img_size./2))
    mNeptune_gcamp = warp(mNeptune_gcamp, recenter(rot_mtx, img_size./2))
    OFP = warp(OFP, recenter(rot_mtx, img_size./2))
    OFP_GCaMP = warp(OFP_GCaMP, recenter(rot_mtx, img_size./2))
    BFP = warp(BFP, recenter(rot_mtx, img_size./2))
    img_roi = UInt16.(warp(img_roi, recenter(rot_mtx, img_size./2), Constant()))
    all_red = warp(all_red, recenter(rot_mtx, img_size./2))
    
    mNeptune = max.(mNeptune .- median(mNeptune), 0.)
    mNeptune_gcamp = max.(mNeptune_gcamp .- median(mNeptune_gcamp), 0.)
    OFP = max.(OFP .- median(OFP), 0.)
    OFP_GCaMP = max.(OFP_GCaMP .- median(OFP_GCaMP), 0.)
    BFP = max.(BFP .- median(BFP), 0.)
    all_red = max.(all_red .- median(all_red), 0.)
    
    
    img_RGB = zeros(eltype(mNeptune), (size(mNeptune)...,3))
    img_RGB[:,:,:,1] .= parent(mNeptune)
    img_RGB[:,:,:,2] .= parent(OFP)
    img_RGB[:,:,:,3] .= parent(BFP)
    
    write_nrrd(path_bfp_img, BFP, (spacing_lat, spacing_lat, spacing_axi))
    write_nrrd(path_ofp_img, OFP, (spacing_lat, spacing_lat, spacing_axi))
    write_nrrd(path_gcamp_img, OFP_GCaMP, (spacing_lat, spacing_lat, spacing_axi))
    write_nrrd(path_all_red_img, all_red, (spacing_lat, spacing_lat, spacing_axi))
    write_nrrd(path_mNeptune_img, mNeptune, (spacing_lat, spacing_lat, spacing_axi))
    write_nrrd(path_mNeptune_gcamp_img, mNeptune_gcamp, (spacing_lat, spacing_lat, spacing_axi))
    write_nrrd(path_neuron_img, img_roi, (spacing_lat, spacing_lat, spacing_axi))
    
    header = nrrd_header(eltype(img_RGB), (spacing_lat, spacing_lat, spacing_axi), size(img_RGB))
    write_nrrd(path_neuropal_img, img_RGB, header)
    
    img_RGB = zeros(eltype(mNeptune_gcamp), (size(mNeptune_gcamp)...,3))
    img_RGB[:,:,:,1] .= parent(mNeptune_gcamp)
    img_RGB[:,:,:,2] .= parent(OFP)
    img_RGB[:,:,:,3] .= parent(BFP)
    
    header = nrrd_header(eltype(img_RGB), (spacing_lat, spacing_lat, spacing_axi), size(img_RGB))
    write_nrrd(path_neuropal_img_mNeptune_GCaMP, img_RGB, header)
end
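The background handling in the cell above amounts to subtracting each channel's median and clipping negative values to zero before stacking the channels as red, green, and blue. A stripped-down sketch of just that step, without the cropping, rotation, or file output (`bgsub` and `compose_rgb` are hypothetical names):

```julia
using Statistics

# Hypothetical helper: subtract the median (a simple background estimate)
# and clip negative values to zero.
bgsub(x::AbstractArray{<:Real}) = max.(x .- median(x), 0.0)

# Hypothetical helper: stack three aligned 3D channels into an (x, y, z, 3) array ordered R, G, B.
function compose_rgb(r, g, b)
    size(r) == size(g) == size(b) || throw(DimensionMismatch("channels differ in size"))
    out = zeros(Float64, size(r)..., 3)
    out[:, :, :, 1] .= bgsub(r)   # red   = mNeptune
    out[:, :, :, 2] .= bgsub(g)   # green = OFP
    out[:, :, :, 3] .= bgsub(b)   # blue  = BFP
    return out
end

r = fill(2.0, 2, 2, 2); r[1, 1, 1] = 10.0
g = reshape(collect(1.0:8.0), 2, 2, 2)
b = zeros(2, 2, 2)
rgb = compose_rgb(r, g, b)
println(size(rgb))  # (2, 2, 2, 3)
```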

In [None]:
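GC.gc(true)    # first collect unreferenced GPU arrays...
CUDA.reclaim() # ...then return the freed memory from CUDA.jl's pool to the device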

# Behavioral Data

Expected runtime: 8 hours. GPU-accelerated.

Using the NIR data collected concurrently with the confocal data, we can extract behavioral parameters such as velocity, head curvature, and more, and align them with the confocal data.

For more information, please see https://flavell-lab.github.io/BehaviorDataNIR.jl/stable/

### Read data, extract worm spline and position

This section runs a 2D UNet to segment the worm in the NIR images. It then runs a custom spline-fitting algorithm that incorporates data from nearby time points to determine the correct spline in cases where the worm is touching itself.

**IMPORTANT:** Make sure you have GPU memory available before running this section! This is the one section of `ANTSUN` with a "silent failure" mode: the code can crash and still leave reasonable-looking but incorrect output. This is especially dangerous because Jupyter notebooks tend to disconnect the client from the Jupyter server, so outputs such as error messages may never reach you. Specifically, the code can crash during the step that corrects self-intersecting worm splines, which causes the head curvature and body curvature to be computed incorrectly. If the code executed correctly, you should see three complete progress bars (the middle one much shorter than the others).

In [None]:
for dataset = datasets_freely_moving
    param_path = param_paths[dataset]
    param = params[dataset]
    data_dict = data_dicts[dataset]
    error_dict = error_dicts[dataset]
    path_h5 = param_paths["freely_moving"]["path_h5"]
    
    pos_feature, pos_feature_unet = read_pos_feature(path_h5)
    
    pos_stage = read_stage(path_h5)
    
    data_dict["x_stage"] = impute_list(pos_stage[1, :])
    data_dict["y_stage"] = impute_list(pos_stage[2, :])
    
    mn_vec, mp_vec, orthog_mp_vec = nmp_vec(pos_feature);
    data_dict["pm_angle"] = vec_to_angle(mp_vec);

    data_dict["x_array"] = zeros(data_dict["max_t_nir"], param["num_center_pts"] + 1)
    data_dict["y_array"] = zeros(data_dict["max_t_nir"], param["num_center_pts"] + 1)
    data_dict["x_med"], data_dict["y_med"] = offset_xy(data_dict["x_stage"], data_dict["y_stage"], pos_feature_unet[2,:,:]);

    data_dict["nir_worm_angle"] = zeros(data_dict["max_t_nir"])
    data_dict["eccentricity"] = zeros(data_dict["max_t_nir"])

    data_dict["med_axis_dict"] = Dict()
    data_dict["med_axis_dict"][0] = nothing
    data_dict["pts_order_dict"] = Dict()
    data_dict["pts_order_dict"][0] = nothing
    data_dict["is_omega"] = Dict()

    path_weight = param_path["path_2d_unet_param"];

    worm_seg_model = create_model(1, 1, 16, path_weight);

    error_dict["worm_spline_errors_1"] = compute_worm_spline!(param, path_h5, worm_seg_model, nothing, data_dict["med_axis_dict"], data_dict["pts_order_dict"],
            data_dict["is_omega"], data_dict["x_array"], data_dict["y_array"], data_dict["nir_worm_angle"], 
            data_dict["eccentricity"]; timepts="all")

    data_dict["worm_thickness"], count = compute_worm_thickness(param, path_h5, worm_seg_model, data_dict["med_axis_dict"], data_dict["is_omega"]);

    error_dict["worm_spline_errors_2"] = compute_worm_spline!(param, path_h5, worm_seg_model, data_dict["worm_thickness"], data_dict["med_axis_dict"], data_dict["pts_order_dict"],
            data_dict["is_omega"], data_dict["x_array"], data_dict["y_array"], data_dict["nir_worm_angle"], 
            data_dict["eccentricity"]; timepts="all")
end

### Extract variables

This block of code places evenly-spaced points along the previously-computed splines, such that the spacing of the points is preserved across time.

It then uses the spline information along with stage position information to compute the worm's velocity, angular velocity, head curvature, nose curvature, body curvature, body segment angles, and eigenworms.

In [None]:
for dataset = datasets_freely_moving
    data_dict = data_dicts[dataset]
    param = params[dataset]
    vec_to_confocal = vec -> nir_vec_to_confocal(vec, data_dict["confocal_to_nir"], param["max_t"])

    
    interpolate_splines!(data_dict)
    data_dict["segment_end_matrix"] = get_segment_end_matrix(param, data_dict["x_array"], data_dict["y_array"]);

    data_dict["x_stage_confocal"] = vec_to_confocal(data_dict["x_stage"]);
    data_dict["y_stage_confocal"] = vec_to_confocal(data_dict["y_stage"]);
    data_dict["zeroed_x"], data_dict["zeroed_y"] = zero_stage(data_dict["x_med"], data_dict["y_med"]);
    data_dict["zeroed_x_confocal"] = vec_to_confocal(data_dict["zeroed_x"])
    data_dict["zeroed_y_confocal"] = vec_to_confocal(data_dict["zeroed_y"]);
    
    # extract variables at normal time points
    get_body_angles!(data_dict, param)
    get_angular_velocity!(data_dict, param)
    get_velocity!(data_dict, param)
    get_curvature_variables!(data_dict, param)
    get_nose_curling!(data_dict, param);

    # extract variables prior to confocal recording
    get_body_angles!(data_dict, param, prefix="pre_")
    get_angular_velocity!(data_dict, param, prefix="pre_")
    get_velocity!(data_dict, param, prefix="pre_")
    get_curvature_variables!(data_dict, param, prefix="pre_")
    get_nose_curling!(data_dict, param, prefix="pre_");
    data_dict["M_bodyangle"], data_dict["X_bodyangle"], data_dict["Yt_bodyangle"] = multivar_fit(data_dict["body_angle"], PCA, maxoutdim=param["num_eigenworm"]);    
end

The UNet segmentation step used the GPU, so we try to free some GPU memory here. Note that the kernel must be restarted before all of the GPU memory can be freed.

In [None]:
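GC.gc()         # first collect unreachable Julia objects, including GPU arrays
CUDA.reclaim()  # then return the freed pool memory to the GPU driver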

### Plot the eigenworms

In [None]:
let
    dataset = "freely_moving"
    data_dict = data_dicts[dataset]
    Plots.plot(data_dict["M_bodyangle"].proj[:,1])
    for i=2:size(data_dict["M_bodyangle"].proj,2)
        Plots.plot!(data_dict["M_bodyangle"].proj[:,i])
    end
    xlabel!("Worm segment")
    ylabel!("Weight")
    title!("Eigenworms")
end
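Eigenworms are the principal components of the body-angle matrix. The pipeline computes them with `multivar_fit(..., PCA, ...)`; as a rough illustration of the same idea, they can be obtained from an SVD of the mean-centered matrix. The function `eigenworms_svd` is hypothetical, and it assumes a matrix with rows = body segments and columns = time points that contains no missing values.

```julia
using LinearAlgebra, Statistics

# Hypothetical sketch: top k principal components ("eigenworms") of a body-angle
# matrix with rows = body segments and columns = time points (no NaNs assumed).
function eigenworms_svd(body_angle::AbstractMatrix{<:Real}, k::Integer)
    X = body_angle .- mean(body_angle; dims=2)   # center each segment over time
    F = svd(X)                                    # X = U * Diagonal(S) * V'
    proj = F.U[:, 1:k]                            # columns = eigenworms (segment weights)
    var_explained = F.S[1:k] .^ 2 ./ sum(F.S .^ 2)
    return proj, var_explained
end
```

Each column of `proj` plays the same role as a column of `M_bodyangle.proj` in the plot above, up to sign.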

### Make video of behavioral variables

In [None]:
for dataset = datasets_freely_moving
    param_path = param_paths[dataset]
    param = params[dataset]
    data_dict = data_dicts[dataset]
    error_dict = error_dicts[dataset]
    path_h5 = param_path["path_h5"]  # use this dataset's NIR file rather than a hard-coded "freely_moving" entry
    
    
    vars = []
    n_t = data_dict["max_t_nir"]
    push!(vars, ("NIR time", [t for t in 1:n_t], palette(:tab10)[2]))
    push!(vars, ("Confocal time", [Int32.(maximum(data_dict["nir_to_confocal"][1:t])) for t in 1:n_t], palette(:tab10)[2]))

    for k in ["x_stage", "y_stage"]
        push!(vars, (uppercase(k[1]), round.(data_dict[k], digits=1), palette(:tab10)[4]))
    end

    for k in ["velocity_stage", "head_angle", "worm_curvature"]
        push!(vars, (k, [let tc = maximum(Int32.(data_dict["nir_to_confocal"][1:t])); (tc == 0) ? NaN : round(data_dict[k][tc], digits=4) end for t in 1:n_t], palette(:tab10)[6]))
    end

    push!(vars, ("Reversal state", [let tc = maximum(Int32.(data_dict["nir_to_confocal"][1:t])); (tc == 0) ? NaN : tc in data_dict["all_rev"] end for t in 1:n_t], palette(:tab10)[9]))

    write_behavior_video(path_h5, joinpath(param_path["path_root_process"], "$(datasets[dataset])_20fps.mp4"), vars=vars, downsample=false)
end;
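In the cell above, each NIR frame is labeled with the latest confocal frame reached so far, computed as `maximum(data_dict["nir_to_confocal"][1:t])` for every `t`. That is a running (cumulative) maximum, which can also be computed in a single pass with `accumulate`. The sketch below assumes, as the cell above does when it checks `tc == 0`, that `nir_to_confocal` is an integer-valued vector with 0 for NIR frames before the first confocal frame; both function names are hypothetical.

```julia
# Running maximum: result[t] == maximum(nir_to_confocal[1:t]) for every t,
# computed in O(n) instead of O(n^2).
running_confocal_time(nir_to_confocal::AbstractVector{<:Real}) =
    Int32.(accumulate(max, nir_to_confocal))

# Look up a confocal-rate variable at NIR rate; NaN before the first confocal frame.
function confocal_var_at_nir(var::AbstractVector{<:Real}, nir_to_confocal)
    tc = running_confocal_time(nir_to_confocal)
    return [t == 0 ? NaN : Float64(var[t]) for t in tc]
end
```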

### Save your data

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    data_dict["t_range"] = params[dataset]["t_range"]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

for dataset = datasets_register
    data_dict = data_dicts["$(dataset)_to_central"]
    JLD2.@save(joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2"), data_dict)
end

# Extract traces in freely-moving animal

This code extracts GCaMP traces from registered and segmented data. For more information, please see https://flavell-lab.github.io/ExtractRegisteredData.jl.

While the code is running `wait_for_elastix`, it is safe to restart the kernel and debug any issues with the immobilized registration; the freely-moving registrations keep running on OpenMind in the meantime. Once you have finished debugging, come back here and start again from `wait_for_elastix`.

The `wait_for_elastix` command waits until all of your jobs on OpenMind are complete, so it is recommended not to have other jobs running on OpenMind while you are running `ANTSUN`. It is fine to run multiple copies of `ANTSUN` simultaneously, because the locking system prevents them from submitting jobs to OpenMind at the same time. (Be careful not to run out of GPU memory if you do this.)

You can also check the status of your OpenMind jobs manually by logging into OpenMind and running `squeue -u $USER`. Occasionally, some jobs get stuck with the status `(launch failed requeued held)`. If this happens, run `scontrol release jobid`, where `jobid` is the ID of the stuck job (use only the part before the `_` symbol). This should get the jobs unstuck.
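When many jobs are stuck, picking out their IDs by hand is tedious. The sketch below scans the text printed by `squeue` for the held status and returns the unique base job IDs (the part before `_`). The function name is hypothetical, it assumes the default `squeue` layout in which the job ID is the first whitespace-separated column, and it does not call `scontrol` itself.

```julia
# Hypothetical helper: given the text printed by `squeue -u $USER`, return the
# unique base job IDs (part before `_`) of jobs stuck in the
# "(launch failed requeued held)" state.
function held_job_ids(squeue_output::AbstractString)
    ids = String[]
    for line in split(squeue_output, '\n')
        occursin("launch failed requeued held", line) || continue
        fields = split(strip(line))
        isempty(fields) && continue
        base = String(first(split(fields[1], '_')))   # "12345_7" -> "12345"
        base in ids || push!(ids, base)
    end
    return ids
end
```

Each returned ID could then be passed to `scontrol release` on OpenMind.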

If you get unexpected output after running this section, or if you simply want to verify that everything worked before running it, you can check whether OpenMind completed your freely-moving registration problems successfully. First, make sure none of the registrations are still running (e.g., via `squeue`). Then run `more /om2/user/$USER/DATASET_output/log/julia_run_elx.txt` on OpenMind, replacing `DATASET` with this notebook's dataset. The output should contain the line `all jobs have been submitted`. If it doesn't, especially if it contains many SLURM-related error messages and warnings, some of the registration problems failed to be submitted to OpenMind. To resolve this, run `cd /om2/user/$USER/DATASET_output` and then `sbatch --partition=flavell run_elastix_julia.sh` to resubmit the remaining jobs. Do **not** try to resubmit the jobs to OpenMind via this notebook, or it will start over from the beginning, wasting a lot of computation time.

In [None]:
@time wait_for_elastix(params[dataset_central])

In [None]:
release_lock(param_paths[dataset_central], params[dataset_central])

### Sync freely-moving data

Expected runtime: 1 hour.

In [None]:
for dataset in datasets_freely_moving
    @time sync_registered_data(param_paths[dataset], params[dataset]);
end

In [None]:
for dataset in datasets_register
    @time sync_registered_data(param_paths[dataset_central], params[dataset_central], reg_dir_key="path_dir_reg_$dataset", reg_om_dir_key="path_om_reg_$dataset");
end

In [None]:
for dataset in datasets_freely_moving
    error_dicts[dataset]["elastix_param_path_errors"] = fix_param_paths(data_dicts[dataset]["registration_problems"], param_paths[dataset], params[dataset]);
end

In [None]:
for dataset in datasets_register
    error_dicts["$(dataset)_to_central"]["elastix_param_path_errors"] = fix_param_paths(data_dicts["$(dataset)_to_central"]["registration_problems"], param_paths[dataset_central], params[dataset_central], reg_dir_key="path_dir_reg_$dataset");
end

### Make quality dictionary

Expected runtime: 1 hour.

By default, each `elastix` registration produces several attempts at different resolution levels. This code computes a quality metric based on normalized cross-correlation (NCC) for each attempt, selects the best attempt for each registration problem, and discards a registration problem entirely if none of its attempts is of sufficiently high quality.

For more details, see https://flavell-lab.github.io/RegistrationGraph.jl/stable/postprocessing.
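For reference, the NCC between two equally sized images is the Pearson correlation of their voxel intensities. The sketch below computes it and converts it to a cost in the spirit of `metric_tfm(...; threshold=...)`. The conversion rule (cost `1 - NCC`, or `Inf` when the NCC is undefined or below the threshold) and the default threshold of 0.5 are assumptions for illustration, chosen to match how later cells treat `Inf` as a failed registration; `metric_tfm` may use a different rule.

```julia
using Statistics

# Normalized cross-correlation of two equally sized arrays (e.g. 3D images):
# sum((a - mean(a)) .* (b - mean(b))) / sqrt(sum((a - mean(a)).^2) * sum((b - mean(b)).^2))
function ncc(a::AbstractArray{<:Real}, b::AbstractArray{<:Real})
    size(a) == size(b) || throw(DimensionMismatch("images must have the same size"))
    da = a .- mean(a)
    db = b .- mean(b)
    return sum(da .* db) / sqrt(sum(abs2, da) * sum(abs2, db))
end

# Hypothetical quality cost: lower is better, Inf marks a rejected registration.
function ncc_cost(a, b; threshold::Real=0.5)
    r = ncc(a, b)
    return (isnan(r) || r < threshold) ? Inf : 1 - r
end
```

A constant image has zero variance, so its NCC is `NaN`; the sketch maps that case to `Inf`, mirroring the NaN-to-Inf replacement in the second cell below.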

In [None]:
for dataset in keys(datasets)
    evaluation_functions = Dict()
    param = params[dataset]
    param_path = param_paths[dataset]
    data_dict = data_dicts[dataset]
    error_dict = error_dicts[dataset]
    
    path_fixed = (dataset in datasets_freely_moving) ? param_path["path_dir_nrrd_filt"] : param_path["path_dir_nrrd"]
    
    evaluation_functions[param["quality_metric"]] = (moving, fixed, resolution) -> metric_tfm(calculate_ncc(read_img(NRRD(joinpath(path_fixed, param_path["get_basename"](fixed, param["ch_marker"])*".nrrd"))), read_img(NRRD(joinpath(param_path["path_dir_reg"], "$(moving)to$(fixed)", "result.$(resolution[1]).R$(resolution[2]).nrrd")))), threshold=param["reg_quality_threshold"])
    data_dict["q_dict"], data_dict["best_reg"], error_dict["q_dict_errors"] = make_quality_dict(param_path, param, data_dict["registration_problems"], evaluation_functions);
end;

In [None]:
for dataset in keys(datasets)
    for k in keys(data_dicts[dataset]["q_dict"])
        for r in keys(data_dicts[dataset]["q_dict"][k])
            if isnan(data_dicts[dataset]["q_dict"][k][r]["NCC"])
                data_dicts[dataset]["q_dict"][k][r]["NCC"] = Inf
            end
        end
    end
end

### Fraction of successful registrations

Typical values are between 50% and 70%. Values below 50% don't necessarily mean your data is unusable, but such data is likely of fairly low quality, and you will likely get fewer neurons and noisier traces.
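The one-line cell below counts a registration problem as successful when at least one of its attempts has a cost below `Inf`. An equivalent, more readable sketch follows; the function name is hypothetical, and it assumes the nested `q_dict[problem][resolution]["NCC"]` layout used in the cells above.

```julia
# Fraction of registration problems with at least one attempt whose cost is < Inf.
# q_dict[problem][resolution]["NCC"] holds the cost; Inf marks a failed attempt.
function fraction_successful(q_dict::AbstractDict, n_problems::Integer)
    n_ok = count(keys(q_dict)) do k
        any(q_dict[k][r]["NCC"] < Inf for r in keys(q_dict[k]))
    end
    return n_ok / n_problems
end
```

Usage would be `fraction_successful(data_dicts["freely_moving"]["q_dict"], length(data_dicts["freely_moving"]["registration_problems"]))`.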

In [None]:
let
    dataset = "freely_moving"
    length([k for k in keys(data_dicts["$(dataset)"]["q_dict"]) if minimum([data_dicts["$(dataset)"]["q_dict"][k][r]["NCC"] for r in keys(data_dicts["$(dataset)"]["q_dict"][k])]) < Inf]) / length(data_dicts["$(dataset)"]["registration_problems"])
end

### Compute ROI match matrix

Expected runtime: 1 hour.

This step computes the matrix of ROI matches. Its rows and columns correspond to every ROI detected by the UNet segmentation across all time points. Roughly, the (i, j)th entry represents the confidence that ROIs `i` and `j` are the same neuron, based on the registration between their time points that mapped one of them onto the other. This confidence also incorporates several heuristics that estimate how accurate the match is.
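To make the structure of this matrix concrete, the sketch below assembles a sparse, symmetric match matrix from a list of ROI overlaps. The input format and the scoring rule (overlap fraction multiplied by a registration quality in [0, 1], keeping the maximum when a pair is seen more than once) are assumptions for illustration; `make_regmap_matrix` uses its own heuristics.

```julia
using SparseArrays

# Hypothetical input: each match is (roi_i, roi_j, overlap_fraction, reg_quality),
# where roi_i and roi_j are global ROI indices across all time points.
function build_match_matrix(matches, n_rois::Integer)
    I = Int[]; J = Int[]; V = Float64[]
    for (i, j, overlap, quality) in matches
        conf = overlap * quality          # assumed confidence score
        # store both (i, j) and (j, i) so the matrix is symmetric
        append!(I, (i, j)); append!(J, (j, i)); append!(V, (conf, conf))
    end
    # combine = max keeps the strongest evidence when a pair appears twice
    return sparse(I, J, V, n_rois, n_rois, max)
end
```

A sparse representation fits here because each ROI is only matched to a small fraction of all other ROIs.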

In [None]:
for dataset in datasets_freely_moving
    @time data_dicts[dataset]["roi_overlaps"], data_dicts[dataset]["roi_activity_diff"], error_dicts[dataset]["overlap_errors"] = extract_roi_overlap(data_dicts[dataset]["best_reg"], param_paths[dataset], params[dataset]);
end

In [None]:
for dataset in datasets_freely_moving
    @time data_dicts[dataset]["regmap_matrix"], data_dicts[dataset]["label_map"] = make_regmap_matrix(data_dicts[dataset]["roi_overlaps"], data_dicts[dataset]["roi_activity_diff"], data_dicts[dataset]["q_dict"], data_dicts[dataset]["best_reg"], params[dataset]);
end

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

for dataset = datasets_register
    data_dict = data_dicts["$(dataset)_to_central"]
    JLD2.@save(joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2"), data_dict)
end

### Group ROIs into neurons and extract traces

Expected runtime: 1-3 days.

Based on the ROI match matrix, group matched ROIs together as the same neuron using a custom hierarchical clustering algorithm.
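The actual grouping uses a custom hierarchical clustering algorithm. As a much simpler stand-in that conveys the idea, the sketch below performs greedy single-linkage merging with a union-find structure: it visits ROI pairs from most to least confident and merges their groups, skipping any merge that would put two ROIs from the same time point into one neuron. The input format, the `min_conf` threshold, and this same-time-point rule are assumptions for illustration, not the behavior of `find_neurons`.

```julia
# Hypothetical greedy single-linkage grouping with a "cannot-link" rule:
# two ROIs observed at the same time point may not belong to the same neuron.
# pairs: vector of (roi_i, roi_j, confidence); roi_time[r] = time point of ROI r.
function group_rois(pairs, roi_time::AbstractVector{<:Integer}; min_conf::Real=0.5)
    n = length(roi_time)
    parent = collect(1:n)
    times = [Set([roi_time[r]]) for r in 1:n]   # time points covered by each group root

    function find_root(r)
        while parent[r] != r
            parent[r] = parent[parent[r]]   # path halving
            r = parent[r]
        end
        return r
    end

    for (i, j, c) in sort(collect(pairs); by = p -> p[3], rev = true)
        c < min_conf && break                       # remaining pairs are weaker
        ri, rj = find_root(i), find_root(j)
        ri == rj && continue                        # already in the same group
        isempty(intersect(times[ri], times[rj])) || continue   # would merge same-time ROIs
        parent[rj] = ri                             # union: attach rj's group to ri
        union!(times[ri], times[rj])
    end

    # relabel group roots as consecutive neuron IDs 1, 2, ...
    ids = Dict{Int,Int}()
    return [get!(ids, find_root(r), length(ids) + 1) for r in 1:n]
end
```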

In [None]:
for dataset in datasets_freely_moving
    @time data_dicts[dataset]["new_label_map"], data_dicts[dataset]["inv_map"], hmer = find_neurons(data_dicts[dataset]["regmap_matrix"], data_dicts[dataset]["label_map"], params[dataset]);
end

Now that we have the neuron identities, we can extract traces in both channels from neurons across frames.

In [None]:
for dataset in datasets_freely_moving
    @time data_dicts[dataset]["activity_traces"], error_dicts[dataset]["activity_traces_errors"] = extract_traces(data_dicts[dataset]["inv_map"], param_paths[dataset]["path_dir_activity_signal"]);
end

In [None]:
for dataset in datasets_freely_moving
    @time data_dicts[dataset]["marker_traces"], error_dicts[dataset]["marker_traces_errors"] = extract_traces(data_dicts[dataset]["inv_map"], param_paths[dataset]["path_dir_marker_signal"]);
end

# Process traces

Expected runtime: 10 minutes.

Now that traces have been extracted, we can process them to try to decrease noise levels.

For more information, please see https://flavell-lab.github.io/CaAnalysis.jl

### Get background levels

In [None]:
for dataset in datasets_freely_moving
    data_dicts[dataset]["activity_bkg"] = get_background(1:params[dataset]["max_t"], param_paths[dataset]["get_basename"], param_paths[dataset]["path_dir_nrrd"], params[dataset]["ch_activity"]);
    data_dicts[dataset]["marker_bkg"] = get_background(1:params[dataset]["max_t"], param_paths[dataset]["get_basename"], param_paths[dataset]["path_dir_nrrd"], params[dataset]["ch_marker"]);    
end

### Check for bad time points

The following code generates a heatmap showing, for each neuron, the time points at which its signal was detected. If you see large clumps of black (neuron not detected), check whether the worm was in focus at those time points, and consider deleting them by removing them from `param["t_range"]`.
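As a numeric complement to the heatmap, a hypothetical helper can list the time points in which only a small fraction of neurons were detected. It assumes a neurons × time matrix in which an undetected entry is stored as 0 or `NaN` (an assumption about the layout of `traces_array_quality`), and the default 50% cutoff is arbitrary.

```julia
# Hypothetical helper: time points where fewer than `min_frac` of neurons were
# detected. Assumes a neurons x time matrix with undetected entries = 0 or NaN.
function low_detection_timepoints(traces::AbstractMatrix{<:Real}; min_frac::Real=0.5)
    detected = .!(isnan.(traces) .| (traces .== 0))
    frac = vec(sum(detected; dims=1)) ./ size(traces, 1)   # detected fraction per time point
    return findall(<(min_frac), frac)
end
```

If the matrix columns correspond directly to time points, the flagged indices are candidates for removal from `param["t_range"]` after you have checked the worm's focus at those frames.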

In [None]:
for dataset in datasets_freely_moving
    data_dicts[dataset]["traces_array_quality"], hmap, vr = make_traces_array(data_dicts[dataset]["activity_traces"], threshold=params[dataset]["num_detections_threshold"])
end

heatmap(abs.(data_dicts[dataset_central]["traces_array_quality"]), clim=(0,0.00001))
title!("Time point quality check")
xlabel!("time (frame #)")
ylabel!("neuron")

### Apply data processing and denoising steps

Here we can apply a variety of data processing and denoising steps:

- `min_intensity` removes neurons with very low signal. For SWF360 it must be set to a negative value so that GFP-negative neurons are included; otherwise, it can be set to a small positive value.
- `bleach_corr=true` applies bleach correction by fitting a single-exponential bleaching model to the ratiometric traces.
- `divide=true` divides the green channel by the red channel.
- `normalize=true` normalizes every neuron trace to have a mean of 0, by dividing it by its mean and subtracting 1.
- `interpolate=true` interpolates the data at missing time points.
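To make these options concrete, the sketch below applies three of them to a single neuron: dividing green by red, correcting bleaching with a single-exponential model, and mean-normalizing. The exponential is fit here by linear least squares on the log of the ratio, a simplification chosen for brevity that assumes positive values and no additive offset; `process_traces` fits its own bleaching model across all traces, and its exact procedure may differ. The function name `process_one` is hypothetical.

```julia
using LinearAlgebra, Statistics

# Hypothetical per-neuron processing sketch (not process_traces itself).
# green, red: raw activity and marker traces of one neuron over time.
function process_one(green::AbstractVector{<:Real}, red::AbstractVector{<:Real})
    ratio = green ./ red                         # divide=true: green / red
    t = collect(1.0:length(ratio))
    # fit ratio ≈ A * exp(-k t) via least squares on log(ratio) = log(A) - k t
    X = hcat(ones(length(t)), t)
    coef = X \ log.(ratio)
    bleach_curve = exp.(X * coef)                # fitted single-exponential curve
    corrected = ratio ./ bleach_curve            # bleach_corr=true
    normalized = corrected ./ mean(corrected) .- 1   # normalize: mean 0
    return normalized, bleach_curve
end
```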

In [None]:
for dataset in datasets_freely_moving
    figure()
    data_dicts[dataset]["traces_array"], data_dicts[dataset]["traces_array_F_F20"], data_dicts[dataset]["raw_zscored_traces_array"],
            data_dicts[dataset]["valid_rois"], data_dicts[dataset]["bleach_param"], data_dicts[dataset]["bleach_curve"], data_dicts[dataset]["bleach_resid"] = 
            process_traces(params[dataset], data_dicts[dataset]["activity_traces"], data_dicts[dataset]["marker_traces"],
                        params[dataset]["num_detections_threshold"], 1:params[dataset]["max_t"], min_intensity=0, normalize_fn=mean,
                        activity_bkg=data_dicts[dataset]["activity_bkg"], marker_bkg=data_dicts[dataset]["marker_bkg"],
                        denoise=false, bleach_corr=true, divide=true, interpolate=true);
    PyPlot.savefig(joinpath(param_paths[dataset]["path_root_process"], "bleach.png"))
end

### Visualize traces

In [None]:
for dataset in datasets_freely_moving
    output_roi_candidates(data_dicts[dataset]["raw_zscored_traces_array"], data_dicts[dataset]["inv_map"], param_paths[dataset], params[dataset], param_paths[dataset]["get_basename"], params[dataset]["ch_marker"], params[dataset]["t_range"], data_dicts[dataset]["valid_rois"]);
end

In [None]:
heatmap(data_dicts[dataset_central]["traces_array_F_F20"],clims=(-1,4))
xlabel!("time (frame #)")
ylabel!("neuron")
title!("ΔF/F20")

In [None]:
heatmap(data_dicts[dataset_central]["raw_zscored_traces_array"], clim=(-1,4))
title!("GCaMP Traces")
xlabel!("time (frame #)")
ylabel!("neuron")

### Check bleaching amount

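The bleaching amount is the ratio of the first value of the fitted bleach curve to its last value. It is recommended to discard datasets with a bleaching amount of 2.3 or higher.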
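To interpret the threshold, assume for simplicity that the fitted bleach curve is a pure single exponential $B(t) = A e^{-kt}$ with no offset (an assumption; the pipeline's fitted model may differ). Over a recording of $T$ time points, the bleaching amount $S$ (`bleach_str` in the cell below) is then

$$S = \frac{B(1)}{B(T)} = e^{k(T-1)}, \qquad S \ge 2.3 \iff k\,(T-1) \ge \ln 2.3 \approx 0.833$$

In other words, under this assumption a dataset is flagged once the fitted curve has decayed to about 43% ($1/2.3$) of its initial value by the end of the recording.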

In [None]:
let
    dict_bleach = data_dicts["freely_moving"]
    bleach_str = dict_bleach["bleach_curve"][1] / dict_bleach["bleach_curve"][end]
    if bleach_str >= 2.3
        @warn("High bleaching strength: $bleach_str")
    else
        println(bleach_str)
    end
end

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    data_dict["t_range"] = params[dataset]["t_range"]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

for dataset = datasets_register
    data_dict = data_dicts["$(dataset)_to_central"]
    JLD2.@save(joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2"), data_dict)
end

# Match neurons to immobilized dataset

Expected runtime: 1 minute.

This section runs an algorithm similar to the one used to match ROIs across time points in the freely-moving datasets, but instead of generating neuron traces, it matches the neuron labels from the freely-moving datasets to the immobilized NeuroPAL dataset.
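Conceptually, each freely-moving neuron collects evidence from the immobilized ROIs it maps onto across many registrations. A hypothetical voting rule illustrates this: assign each neuron the immobilized ROI with the largest summed match confidence, and report that ROI's share of the neuron's total confidence. The input format and the rule itself are assumptions for illustration; `register_immobilized_rois` has its own matching logic.

```julia
# Hypothetical voting sketch. votes: vector of (neuron, immobilized_roi, confidence).
# Returns Dict neuron => (best_roi, share of that neuron's total confidence).
function assign_by_votes(votes)
    totals = Dict{Int, Dict{Int, Float64}}()
    for (neuron, roi, c) in votes
        per_roi = get!(totals, neuron) do
            Dict{Int, Float64}()
        end
        per_roi[roi] = get(per_roi, roi, 0.0) + c   # accumulate confidence per ROI
    end
    result = Dict{Int, Tuple{Int, Float64}}()
    for (neuron, per_roi) in totals
        best_roi, best_c = first(per_roi)
        for (roi, c) in per_roi
            if c > best_c
                best_roi, best_c = roi, c
            end
        end
        result[neuron] = (best_roi, best_c / sum(values(per_roi)))
    end
    return result
end
```

A low share signals that a neuron's registrations disagree about its immobilized counterpart, which is the kind of situation a match-confidence value is meant to capture.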

In [None]:
for dataset = datasets_register
    evaluation_functions = Dict()
    evaluation_functions[params[dataset_central]["quality_metric"]] = (moving, fixed, resolution) -> metric_tfm(calculate_ncc(read_img(NRRD(joinpath(param_paths[dataset_central]["path_dir_nrrd_filt"], param_paths[dataset_central]["get_basename"](fixed, params[dataset_central]["ch_marker"])*".nrrd"))), read_img(NRRD(joinpath(param_paths[dataset_central]["path_dir_reg_$(dataset)"], "$(moving)to$(fixed)", "result.$(resolution[1]).R$(resolution[2]).nrrd")))), threshold=params[dataset_central]["reg_quality_threshold"]);
    data_dicts["$(dataset)_to_central"]["q_dict"], data_dicts["$(dataset)_to_central"]["best_reg"], 
            error_dicts["$(dataset)_to_central"]["q_dict_errors"] = 
            make_quality_dict(param_paths[dataset_central], params[dataset_central], 
            data_dicts["$(dataset)_to_central"]["registration_problems"], evaluation_functions);
end;

In [None]:
for dataset = datasets_register
    for k in keys(data_dicts["$(dataset)_to_central"]["q_dict"])
        for r in keys(data_dicts["$(dataset)_to_central"]["q_dict"][k])
            if isnan(data_dicts["$(dataset)_to_central"]["q_dict"][k][r]["NCC"])
                data_dicts["$(dataset)_to_central"]["q_dict"][k][r]["NCC"] = Inf
            end
        end
    end
end

### Fraction of successful registrations

Typical values for this are between 10% and 30%.

If no registrations were successful, the most common culprit is that the head position was computed incorrectly for the `reg_timepts["all_red"]` timepoint in the `all_red` dataset. Navigate to the head-detection UNet part of `ANTSUN`, manually enter a corrected head position, and then rerun all of the steps of the pipeline relevant to the immobilized-to-freely-moving registration (including the immobilized-to-freely-moving registration on OpenMind).

If that wasn't the issue, or if the registration still fails, the immobilized image may be out of focus. In that case, your dataset is likely unusable.

In [None]:
let
    dataset = "all_red"
    length([k for k in keys(data_dicts["$(dataset)_to_central"]["q_dict"]) if 
                    minimum([data_dicts["$(dataset)_to_central"]["q_dict"][k][r]["NCC"] 
                    for r in keys(data_dicts["$(dataset)_to_central"]["q_dict"][k])]) < Inf]) /
                    length(data_dicts["$(dataset)_to_central"]["registration_problems"])
end

### Compute ROI matches

In [None]:
for dataset = datasets_register
    @time data_dicts["$(dataset)_to_central"]["roi_overlaps"], data_dicts["$(dataset)_to_central"]["roi_activity_diff"],
            error_dicts["$(dataset)_to_central"]["overlap_errors"] = 
            extract_roi_overlap(data_dicts["$(dataset)_to_central"]["best_reg"], 
            param_paths[dataset_central], params[dataset_central], reg_dir_key="path_dir_reg_$dataset", 
            transformed_dir_key="path_dir_transformed_$dataset", reg_problems_key="path_reg_prob_$dataset", 
            param_path_moving=param_paths[dataset]);
end

In [None]:
for dataset = datasets_register
    @time data_dicts["$(dataset)_to_central"]["regmap_matrix"], data_dicts["$(dataset)_to_central"]["label_map"] =
            make_regmap_matrix(data_dicts["$(dataset)_to_central"]["roi_overlaps"], 
            data_dicts["$(dataset)_to_central"]["roi_activity_diff"], data_dicts["$(dataset)_to_central"]["q_dict"],
            data_dicts["$(dataset)_to_central"]["best_reg"], params[dataset_central], 
            max_fixed_t=params[dataset_central]["max_t"]);
    data_dicts["$(dataset)_to_central"]["inv_map_regmap"] = invert_label_map(data_dicts["$(dataset)_to_central"]["label_map"])
end

In [None]:
for dataset = datasets_register
    @assert(!(dataset in datasets_freely_moving), "Registration to other freely-moving datasets not yet supported.")
    @time data_dicts["$(dataset)_to_central"]["roi_matches_all"], 
    data_dicts["$(dataset)_to_central"]["inv_matches_all"],
    data_dicts["$(dataset)_to_central"]["roi_matches"],
    data_dicts["$(dataset)_to_central"]["roi_match_confidence"] = register_immobilized_rois(
            data_dicts["$(dataset)_to_central"]["regmap_matrix"], 
            data_dicts["$(dataset)_to_central"]["label_map"],
            data_dicts["$(dataset)_to_central"]["inv_map_regmap"],
            data_dicts[dataset_central]["new_label_map"],
            data_dicts[dataset_central]["valid_rois"],
            params[dataset_central], reg_timepts[dataset])
end

# Import pumping

Pumping is not automatically computed by `ANTSUN`, so it must be annotated manually. You can import manually annotated pumping data here by uncommenting the code below and modifying the file path as appropriate. Importing pumping is necessary to run the `CePNEM` model on your data.

In [None]:
# import_pumping!(data_dicts["freely_moving"], params["freely_moving"], ["/path/to/pumping/file.csv"])

# Save your data

You can save the data at any time by running this command, and load from a previous save in the "Initialize Data Dictionaries" section. This is very helpful to avoid having to redo large amounts of computation in the event of a kernel failure, and can also be a way for collaborators to work on the same dataset without editing the same notebook files. You can also delete any variables that are taking up too much memory here, to prevent them from being auto-loaded when you restart the kernel.

In [None]:
for dataset = keys(datasets)
    data_dict = data_dicts[dataset]
    data_dict["t_range"] = params[dataset]["t_range"]
    JLD2.@save(param_paths[dataset]["path_data_dict"], data_dict)
end

for dataset = datasets_register
    data_dict = data_dicts["$(dataset)_to_central"]
    JLD2.@save(joinpath(param_paths[dataset_central]["path_root_process"], "$(dataset)_registered_data_dict.jld2"), data_dict)
end

In [None]:
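GC.gc()         # first collect unreachable Julia objects, including GPU arrays
CUDA.reclaim()  # then return the freed pool memory to the GPU driver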