# Evaluate AutoCellLabeler on SWF415 Data

This notebook evaluates the performance of the TagRFP-only AutoCellLabeler network on freely-moving data from the SWF415 strain, a different strain than the NeuroPAL strain the network was trained on.

This notebook assumes that you have already run the `AutoCellLabeler_freely_moving` notebook through the "Run TagRFP-only AutoCellLabeler" section to format the SWF415 data and run the model on it.

This notebook currently only supports using one time point per animal.

In [None]:
# Flavell lab packages
using ND2Process
using GPUFilter
using WormFeatureDetector
using NRRDIO
using FlavellBase
using SegmentationTools
using ImageDataIO
using RegistrationGraph
using ExtractRegisteredData
using CaAnalysis
using BehaviorDataNIR
using UNet2D


# Other packages
using ProgressMeter
using PyCall
using PyPlot
using Statistics
using StatsBase
using DelimitedFiles
using Images
using Cairo
using Distributions
using DataStructures
using HDF5
using Interact
using WebIO
using Plots
# using GraphPlot
# using LightGraphs
# using SimpleWeightedGraphs
using Dates
using JLD2
using TotalVariation
using VideoIO
using Distributions
using MultivariateStats
using FFTW
using LinearAlgebra
using GLMNet
using InformationMeasures
using CUDA
using LsqFit
using Optim
using Rotations
using CoordinateTransformations
using ImageTransformations
using Interpolations

## Load SWF415 Datasets

This section loads the ANTSUN outputs for the SWF415 datasets so that we can access the mapping between ROIs at the timepoints where AutoCellLabeler was run and neuron indices used to fit CePNEM.

This data is publicly available in [our Dropbox](https://www.dropbox.com/scl/fo/fb1cdxbwznhjp491ru6uq/ANJqOrenXBhA7lfxLYarRE0?rlkey=9ljwkyfhphumymgyzxfdf4c7s&st=49ev0ivz&dl=0).

In [None]:
# Dataset identifiers, one list per group (group meaning is inferred from the
# variable names -- confirm against the experiment records).
datasets_baseline = [
    "2021-05-26-07", "2021-06-11-01", "2021-08-04-06", "2021-08-17-01", "2021-08-18-01",
    "2021-09-22-05", "2021-10-26-01", "2021-11-12-01", "2021-11-12-05", "2022-01-09-01",
    "2022-01-17-01", "2022-04-05-01", "2022-04-12-04", "2022-04-14-04",
]
datasets_stim = [
    "2021-09-06-09", "2021-09-14-01", "2021-09-14-05", "2021-09-23-01", "2021-09-30-01",
]
datasets_stim_1600 = [
    "2022-02-08-04", "2022-02-16-01", "2022-02-16-04", "2022-03-15-04", "2022-03-22-01",
    "2022-04-18-04",
]
# Not concatenated below: every entry here already appears in `datasets_baseline`.
datasets_baseline_1600 = ["2022-01-17-01", "2022-04-05-01", "2022-04-12-04", "2022-04-14-04"]

# Everything to load, in order: baseline, stim, stim_1600. `vcat` builds a fresh
# vector, so the per-group lists above are left untouched.
datasets = vcat(datasets_baseline, datasets_stim, datasets_stim_1600)

length(datasets)

In [None]:
# Load, for every dataset, the ANTSUN path table (`param_paths`), the parameter
# dictionary (`params_dict`) and the data dictionary (`data_dicts`).
data_dicts = Dict()
params_dict = Dict()
param_paths = Dict()
for dataset in datasets
    path_root_process = "/store1/prj_kfc/data_processed/$(dataset)_output"
    path_param_path = joinpath(path_root_process, "param_path.jld2")

    # Every other file is located through the path table, so a dataset without
    # one is skipped after the warning (it will be absent from all three
    # dictionaries) instead of failing with a KeyError right after the warning.
    if !isfile(path_param_path)
        @warn("No param_path.jld2 file found for dataset: $dataset")
        continue
    end

    # The do-block form closes the file even if reading the entry throws.
    param_path = JLD2.jldopen(path_param_path, "r") do f
        f["param_path"]
    end
    param_paths[dataset] = param_path

    # Re-point the stored paths at this machine's processed-data root.
    change_rootpath!(param_path, path_root_process)

    if isfile(param_path["path_param"])
        params_dict[dataset] = JLD2.jldopen(param_path["path_param"], "r") do f
            f["param"]
        end
    else
        @warn("No parameter file found for dataset: $dataset")
    end

    add_get_basename!(param_path)

    # A missing data dictionary is tolerated and replaced by an empty one.
    data_dicts[dataset] = if isfile(param_path["path_data_dict"])
        JLD2.jldopen(param_path["path_data_dict"], "r") do f
            f["data_dict"]
        end
    else
        Dict()
    end
end

## Select which timepoint to use

By default, use the timepoint with the most neurons detected in the freely-moving traces. Note that you must also use this timepoint in the `AutoCellLabeler_freely_moving` notebook.

If you used other timepoints in the `AutoCellLabeler_freely_moving` notebook, set `best_timepts` to those timepoints.

In [None]:
# Pick one time point per dataset: the column of `traces_array_quality` with the
# largest number of entries above 0.00001 (ties resolve to the earliest column).
best_timepts = Dict()

for dataset in datasets
    quality = data_dicts[dataset]["traces_array_quality"]
    # number of above-threshold (non-interpolated) entries in each column
    n_above_per_column = map(column -> sum(column .> 0.00001), eachcol(quality))
    best_timepts[dataset] = argmax(n_above_per_column)
end

## Process the data

After running `AutoCellLabeler_freely_moving` through the "Run TagRFP-only AutoCellLabeler" section, import the data here. Set `output_path` to the location of the `AutoCellLabeler_freely_moving` output. It should contain the following subdirectories:

- `h5` should contain the AutoCellLabeler input and prediction files
- `roi` should contain the original, uncropped ROI files from SegmentationNet
- `roi_crop` should contain the same ROI files but cropped to AutoCellLabeler's crop size

The fully-processed data for some SWF415 datasets is available on [our Dropbox](https://www.dropbox.com/scl/fo/ealblchspq427pfmhtg7h/ALZ7AE5o3bT0VUQ8TTeR1As?rlkey=1e6tseyuwd04rbj7wmn2n6ij7&st=ybsvv0ry&dl=0) under `AutoCellLabeler/SWF415_data`.

In [None]:
# Location of the `AutoCellLabeler_freely_moving` output (must contain the `h5`,
# `roi` and `roi_crop` subdirectories). Edit this to match your setup.
#
# An alternative location used to be assigned here and was immediately
# overwritten by the line below, so it never took effect:
#   /store1/PublishedData/Data/prj_register/AutoCellLabeler/SWF415_data
output_path = "/data3/adam/new_unet_train/swf415_test"

# Label CSVs are written into this subdirectory by the next cell.
create_dir(joinpath(output_path, "csv"))

In [None]:
# Convert the AutoCellLabeler predictions of every baseline dataset into label
# data (`data_dict["neuropal_label_data"]`) and a per-dataset CSV file.
# The Python module is imported once instead of once per dataset.
let autolabel = pyimport("autolabel")
    for dataset in datasets_baseline
        data_dict = data_dicts[dataset]
        param = params_dict[dataset]

        # Labeling parameters, stored on the dataset's parameter dictionary.
        param["autolabel_merged_rois_dist_thresh"] = 8
        param["autolabel_alt_label_threshold"] = 0.0
        param["autolabel_minimum_probability"] = 0.01
        param["autolabel_lrswap_threshold"] = 0.1
        # NOTE(review): the next two values are stored but not forwarded to any
        # call in this cell -- confirm whether that is intended.
        param["autolabel_roi_edge_weight"] = 0.01
        param["autolabel_contamination_confidence_threshold"] = 0.75
        param["autolabel_contamination_num_threshold"] = 10
        param["autolabel_contamination_frac_threshold"] = 0.2
        param["autolabel_exclude_rois"] = []

        path_roi_crop = joinpath(output_path, "roi_crop", "$(dataset).h5")
        path_predictions = joinpath(output_path, "h5", "$(dataset)_predictions.h5")

        data_dict["roi_sizes"] = autolabel.get_roi_size(path_roi_crop)
        data_dict["neuropal_probability_dict"], data_dict["contaminated_neurons"] =
            autolabel.create_probability_dict(path_roi_crop, path_predictions)

        data_dict["neuropal_label_data"] = autolabel.output_label_file(
            data_dict["neuropal_probability_dict"],
            data_dict["contaminated_neurons"],
            data_dict["roi_sizes"],
            "/data3/adam/new_unet_train/extracted_neuron_ids_final_1.h5",
            joinpath(output_path, "roi", "$(dataset).nrrd"),
            joinpath(output_path, "csv", "$(dataset).csv"),
            max_distance=param["autolabel_merged_rois_dist_thresh"],
            max_prob_decrease=param["autolabel_alt_label_threshold"],
            min_prob=param["autolabel_minimum_probability"],
            exclude_rois=param["autolabel_exclude_rois"],
            roi_matches=[],
            lrswap_threshold=param["autolabel_lrswap_threshold"],
            contamination_threshold=param["autolabel_contamination_num_threshold"],
            contamination_frac_threshold=param["autolabel_contamination_frac_threshold"]
        )
    end
end

## Load CePNEM analysis dictionaries

These dictionaries store the results of the CePNEM model fits for all datasets. In the absence of available human labels for the SWF415 strain, consistency of these fits is used to evaluate the performance of AutoCellLabeler. They are available in [our Dropbox](https://www.dropbox.com/scl/fo/gms8q0hzcufczrqlpsk95/ANrhFTrQaLN6XJrTzkk5rBA?rlkey=ciiyxvsd3ya6i7tsyttnikfed&st=75yhtvvy&dl=0).

In [None]:
path_analysis_dict = "/store1/PublishedData/Data/prj_neuropal/analysis_dict.jld2"

# Load the CePNEM analysis dictionary. The do-block form closes the file even if
# reading fails; a missing file now produces a warning instead of silently
# leaving an empty dictionary that only fails later with a KeyError.
analysis_dict = if isfile(path_analysis_dict)
    JLD2.jldopen(path_analysis_dict, "r") do f
        f["analysis_dict"]
    end
else
    @warn("No analysis dictionary found at: $path_analysis_dict")
    Dict()
end;

In [None]:
# Labels that are never evaluated: entries whose `neuron_class` appears here are
# skipped by the concordance analysis.
excluded_classes = [
    "glia", "granule",
    "RIFL", "RIFR", "AFDL", "AFDR", "RMFL", "RMFR",
    "SIADL", "SIADR", "VA01", "VD01", "AVG", "DD01",
    "SABVL", "SABVR", "SIBDL", "SIBDR",
    "ADFL", "RIGL", "RIGR", "AVFL", "DB02",
];

In [None]:
path_fit_results_lite = "/store1/PublishedData/Data/prj_neuropal/fit_results_lite.jld2"

# Load the lightweight CePNEM fit results. The do-block form closes the file even
# if reading fails; a missing file now produces a warning instead of silently
# leaving an empty dictionary. The trailing `;` keeps the notebook from printing
# the whole dictionary.
fit_results = if isfile(path_fit_results_lite)
    JLD2.jldopen(path_fit_results_lite, "r") do f
        f["fit_results_lite"]
    end
else
    @warn("No fit results found at: $path_fit_results_lite")
    Dict()
end;

In [None]:
# NeuroPAL datasets that are excluded when sampling human-labeled reference
# neurons in the concordance analysis (heat-stimulation datasets).
datasets_neuropal_stim = [
    "2022-12-21-06",
    "2023-01-05-01", "2023-01-05-18",
    "2023-01-06-01", "2023-01-06-08", "2023-01-06-15",
    "2023-01-09-08", "2023-01-09-15", "2023-01-09-22",
    "2023-01-10-07", "2023-01-10-14",
    "2023-01-13-07",
    "2023-01-16-01", "2023-01-16-08", "2023-01-16-15", "2023-01-16-22",
    "2023-01-17-07", "2023-01-17-14",
    "2023-01-18-01",
];

## Examine AutoCellLabeler concordance with CePNEM expectations

In this code:

- `neurons_rev` represents a list of neurons expected to encode reverse locomotion
- `neurons_fwd` represents a list of neurons expected to encode forward locomotion
- `neurons_dorsal` represents a list of neurons expected to encode dorsal head curvature (behavior `θh` in the code)
- `neurons_ventral` represents a list of neurons expected to encode ventral head curvature (behavior `θh` in the code)

The code computes the fraction of AutoCellLabeler labels for those neurons whose traces encode the expected locomotion property, and compares it against the fraction of human-labeled neurons (in the NeuroPAL strain) that encode the expected locomotion property. Specifically, the code runs `n_trials` simulations in which each labeled neuron is randomly replaced by either a random human-labeled instance of that neuron (i.e., AutoCellLabeler got the label right) or a random other neuron in that dataset (i.e., AutoCellLabeler got the label wrong). The simulations are repeated for every assumed labeling accuracy from 0% to 100%, where the accuracy is the probability of the first case. In this way we can account for either (i) CePNEM failing to find significant encoding even when the label was correct or (ii) the neuron getting mislabeled as another neuron with the correct encoding.

By default, only baseline datasets are used since the heat stimulation can disrupt neural encodings.

In [None]:
# Neuron classes grouped by the encoding they are expected to show.
neurons_rev = ["AVAL", "AVAR", "AVA?", "AVEL", "AVER", "AVE?", "RIML", "RIMR", "RIM?", "AIBL", "AIBR", "AIB?"]
neurons_fwd = ["RIBL", "RIBR", "RIB?", "AVBL", "AVBR", "AVB?", "RID", "RMEL", "RMER", "RMED"]
neurons_dorsal = ["SMDDL", "SMDDR"]
neurons_ventral = ["SMDVL", "SMDVR", "RIVL", "RIVR"]

# Classes whose name is used as-is when looking up human-labeled matches; every
# other class has its last character (L / R / ?) removed first.
neurons_nolr = ["RID", "RMED"]

# Behavior under which each encoding is categorized in `analysis_dict`.
beh_dict = Dict(
    "rev" => "v",
    "fwd" => "v",
    "dorsal" => "θh",
    "ventral" => "θh"
)

# Opposite encoding; a neuron found in this category counts as "incorrect".
swap_dict = Dict(
    "rev" => "fwd",
    "fwd" => "rev",
    "dorsal" => "ventral",
    "ventral" => "dorsal"
)

neurons_of_interest = vcat(neurons_rev, neurons_fwd, neurons_dorsal, neurons_ventral)

# Expected encoding of each class of interest. `get!` keeps the first entry, so a
# class listed in several groups resolves in the order rev, fwd, dorsal, ventral.
encoding_by_class = Dict{String,String}()
for (group_encoding, group_classes) in (("rev", neurons_rev), ("fwd", neurons_fwd),
                                        ("dorsal", neurons_dorsal), ("ventral", neurons_ventral))
    for group_class in group_classes
        get!(encoding_by_class, group_class, group_encoding)
    end
end

conf_thresh = 3

# Observed outcome counts over all evaluated labels.
n_correct = 0
n_incorrect = 0
n_unknown = 0
n_both = 0
n_notrace = 0
conf = []

# Simulated outcome counts: one row per assumed accuracy (percent), one column
# per simulation trial.
acc_fracs = 0:100
n_trials = 1000
n_correct_randomcontrol = zeros(length(acc_fracs), n_trials)
n_incorrect_randomcontrol = zeros(length(acc_fracs), n_trials)
n_unknown_randomcontrol = zeros(length(acc_fracs), n_trials)
n_both_randomcontrol = zeros(length(acc_fracs), n_trials)

all_correct = []
all_incorrect = []
all_unknown = []
all_both = []
conf_succeed = Dict()

"""
    expected_encoding_flags(analysis_dict, dataset, neuron_idx, beh, encoding, swap)

Return `(correct, incorrect)` for one neuron. `correct` is true when `neuron_idx`
is listed under `encoding` for behavior `beh` in at least one of the dataset's
`enc_change_rngs` ranges; `incorrect` is the same test for the opposite encoding
`swap`. Both can be true at once.
"""
function expected_encoding_flags(analysis_dict, dataset, neuron_idx, beh, encoding, swap)
    correct = false
    incorrect = false
    for rng in analysis_dict["enc_change_rngs"][dataset]
        categorization = analysis_dict["neuron_categorization"][dataset][rng][beh]
        if neuron_idx in categorization[encoding]
            correct = true
        end
        if neuron_idx in categorization[swap]
            incorrect = true
        end
    end
    return correct, incorrect
end

"""
    resolve_trace_index(data_dict, initial_neuron_idx)

Translate `initial_neuron_idx` through every `data_dict` entry whose key contains
`"successful_idx_"`: the result is the position of `initial_neuron_idx` in that
entry, or `nothing` when it is not listed. If several such keys exist, the last
one containing the index wins. When no such key exists, `initial_neuron_idx` is
returned unchanged.
"""
function resolve_trace_index(data_dict, initial_neuron_idx)
    neuron_idx = nothing
    found = false
    for k in keys(data_dict)
        if occursin("successful_idx_", k)
            found = true
            for (i, n) in enumerate(data_dict[k])
                if n == initial_neuron_idx
                    neuron_idx = i
                    break
                end
            end
        end
    end
    return found ? neuron_idx : initial_neuron_idx
end

"""
    simulate_label_control!(n_correct_rc, n_incorrect_rc, n_unknown_rc, n_both_rc,
                            analysis_dict, reference_matches, dataset, n_traces,
                            beh, encoding, swap, acc_fracs)

Add one simulated outcome for a single labeled neuron to every cell `[j, i]` of
the four count matrices (row `j` per entry of `acc_fracs`, column `i` per trial).
With probability `acc_fracs[j] / 100` the neuron is replaced by a random
`(dataset, index)` pair from `reference_matches`; otherwise by a uniformly random
trace index `1:n_traces` of `dataset`.
"""
function simulate_label_control!(n_correct_rc, n_incorrect_rc, n_unknown_rc, n_both_rc,
                                 analysis_dict, reference_matches, dataset, n_traces,
                                 beh, encoding, swap, acc_fracs)
    for i in axes(n_correct_rc, 2)
        for (j, acc) in enumerate(acc_fracs)
            if rand() < acc / 100
                # label assumed right: use a human-labeled instance of this class
                (dataset_rand, n_rand) = rand(reference_matches)
                correct_random, incorrect_random = expected_encoding_flags(
                    analysis_dict, dataset_rand, n_rand, beh, encoding, swap)
            else
                # label assumed wrong: use a random neuron of the same dataset
                neuron_idx_random = rand(1:n_traces)
                correct_random, incorrect_random = expected_encoding_flags(
                    analysis_dict, dataset, neuron_idx_random, beh, encoding, swap)
            end

            if correct_random && incorrect_random
                n_both_rc[j, i] += 1
            elseif correct_random
                n_correct_rc[j, i] += 1
            elseif incorrect_random
                n_incorrect_rc[j, i] += 1
            else
                n_unknown_rc[j, i] += 1
            end
        end
    end
    return nothing
end

@showprogress for dataset in datasets_baseline
    data_dict = data_dicts[dataset]
    conf_succeed[dataset] = []

    # valid ROI value -> its position in `valid_rois`
    inv_valid_rois = Dict()
    for (i, roi) in enumerate(data_dict["valid_rois"])
        inv_valid_rois[roi] = UInt16(i)
    end

    for data in data_dict["neuropal_label_data"]
        neuron_class = data["neuron_class"]

        # Keep only confident labels that are neither excluded nor alternatives.
        keep_label = data["confidence"] >= conf_thresh && !(neuron_class in excluded_classes) && !occursin("alt", neuron_class)
        keep_label || continue
        push!(conf_succeed[dataset], (neuron_class, data["max_prob"]))

        # Only classes with an expected encoding are evaluated; the ROI -> trace
        # lookup below is skipped for everything else.
        haskey(encoding_by_class, neuron_class) || continue
        push!(conf, data["max_prob"])

        # ROI at the selected time point -> index of the neuron's trace.
        roi_label = get(data_dict["new_label_map"][best_timepts[dataset]], Int(data["roi_id"]), nothing)
        initial_neuron_idx = get(inv_valid_rois, roi_label, nothing)
        neuron_idx = resolve_trace_index(data_dict, initial_neuron_idx)

        # neuron_idx = rand(1:size(data_dicts[dataset]["traces_array"], 1))

        if isnothing(neuron_idx)
            n_notrace += 1
            continue
        end

        encoding = encoding_by_class[neuron_class]
        beh = beh_dict[encoding]
        swap = swap_dict[encoding]

        correct, incorrect = expected_encoding_flags(analysis_dict, dataset, neuron_idx, beh, encoding, swap)
        record = (dataset, Int(neuron_idx), neuron_class, data["max_prob"])
        if correct && incorrect
            n_both += 1
            push!(all_both, record)
        elseif correct
            n_correct += 1
            push!(all_correct, record)
        elseif incorrect
            n_incorrect += 1
            push!(all_incorrect, record)
        else
            push!(all_unknown, record)
            n_unknown += 1
        end

        neuron_class_nolr = neuron_class in neurons_nolr ? neuron_class : neuron_class[1:end-1]

        # Human-labeled instances of this class, excluding heat-stim datasets.
        # Filtering once per neuron gives the same uniform draw over baseline
        # matches as redrawing until a baseline dataset comes up, without the
        # repeated redraws inside the simulation loop.
        all_matches = analysis_dict["matches"][neuron_class_nolr]
        reference_matches = [m for m in all_matches if !(first(m) in datasets_neuropal_stim)]
        if isempty(reference_matches)
            # Same fallback as before (heat-stim matches are used), but the
            # warning is emitted once per neuron rather than once per draw.
            @warn("Could not find neuron $(neuron_class_nolr) in baseline data.")
            reference_matches = collect(all_matches)
        end

        n_traces = size(data_dict["traces_array"], 1)
        simulate_label_control!(n_correct_randomcontrol, n_incorrect_randomcontrol,
                                n_unknown_randomcontrol, n_both_randomcontrol,
                                analysis_dict, reference_matches, dataset, n_traces,
                                beh, encoding, swap, acc_fracs)
    end
end

println(n_correct, " ", n_incorrect, " ", n_unknown, " ", n_both)

Raw fraction of labels with expected encoding (not corrected for random sample control):

In [None]:
# Share of evaluated labels counted as "correct" among all four outcome counts.
n_correct / sum((n_correct, n_incorrect, n_unknown, n_both))

## Plot the results

Display the confidence interval based on the simulations, and plot the simulated vs observed CePNEM concordance.

In [None]:
# Compare the observed concordance with the simulated one at every assumed
# accuracy, print the range of accuracies consistent with the observation, and
# plot both.
let
    # Fraction of "correct" outcomes per (assumed accuracy, trial). Rows follow
    # `acc_fracs`, because the count matrices were allocated with
    # `length(acc_fracs)` rows.
    simulated_totals = n_correct_randomcontrol .+ n_incorrect_randomcontrol .+ n_unknown_randomcontrol .+ n_both_randomcontrol
    acc = n_correct_randomcontrol ./ simulated_totals
    levels = axes(acc, 1)

    # Mean and central 95% band (2.5th to 97.5th percentile) across trials.
    mean_acc = [mean(acc[i, :]) for i in levels]
    lower_bound = [quantile(acc[i, :], 0.025) for i in levels]
    upper_bound = [quantile(acc[i, :], 0.975) for i in levels]

    # Observed fraction; it does not depend on the assumed accuracy.
    measured = n_correct / (n_correct + n_incorrect + n_unknown + n_both)

    # Assumed accuracies whose 95% band contains the observed fraction.
    inside_interval = [lower_bound[i] <= measured <= upper_bound[i] for i in levels]
    first_inside = findfirst(inside_interval)
    last_inside = findlast(inside_interval)

    # Assumed accuracy whose simulated mean is closest to the observed fraction.
    intersection_index = argmin(abs.(mean_acc .- measured))

    # `findfirst` / `findlast` return `nothing` when the observation lies outside
    # every band; report that instead of failing on `nothing - 1`.
    describe_level(idx) = isnothing(idx) ? "none" : string(acc_fracs[idx])

    println("First parameter where measured value falls inside the 95% CI: ", describe_level(first_inside))
    println("Last parameter where measured value falls inside the 95% CI: ", describe_level(last_inside))
    println("Closest parameter to measured value: ", describe_level(intersection_index))

    # x values: assumed accuracy as a fraction (`acc_fracs` is in percent)
    x = collect(acc_fracs) ./ 100

    # Create the plot
    fig, ax = subplots(figsize=(4.5,2))

    # Simulated mean with the shaded 95% band
    ax.plot(x, mean_acc, label="Simulations", linewidth=2)
    ax.fill_between(x, lower_bound, upper_bound, alpha=0.2)

    intersection_value = x[intersection_index]
    intersection_y = mean_acc[intersection_index]

    # Dashed vertical line from the x axis up to the simulated mean at the
    # best-matching accuracy
    ax.plot([intersection_value, intersection_value], [0, intersection_y], "k--", linewidth=2, label=nothing)

    # Observed value, drawn at the best-matching accuracy
    ax.scatter([intersection_value], [measured], label="Autolabel", s=25, zorder=3)

    # Axes, legend and ticks
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 1)
    ax.legend(loc="lower right", fontsize=7)
    ax.tick_params(axis="both", which="major", labelsize=7)

    # Extra x tick at the best-matching accuracy, with a bold label
    ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0, intersection_value])
    ax.get_xticklabels()[end].set_weight("bold")

    # Disable top and right axes
    ax.spines["top"].set_visible(false)
    ax.spines["right"].set_visible(false)

    # Display the plot
    show()
    # PyPlot.savefig("/data3/prj_register/figures/figure_5/swf415_accuracy.pdf")
end
