# AutoCellLabeler on Freely-Moving Data

In [None]:
import os

import pandas as pd 

import nrrd
import numpy as np

import h5py

from tqdm import tqdm

from matplotlib import pyplot as plt

import shutil

import openpyxl
import csv
import re

import itertools

from functools import reduce

from autolabel import *

from mpl_toolkits.axes_grid1 import make_axes_locatable



## Load datasets, neuron names

In [None]:
datasets_prj_neuropal = ["2022-07-15-06", "2022-07-15-12", "2022-07-20-01", "2022-07-26-01", "2022-08-02-01", "2023-01-23-08", "2023-01-23-15", "2023-01-23-21", "2023-01-19-08", "2023-01-19-22", "2023-01-09-28", "2023-01-17-01", "2023-01-19-15", "2023-01-23-01", "2023-03-07-01", "2022-12-21-06", "2023-01-05-18", "2023-01-06-01", "2023-01-06-08", "2023-01-09-08", "2023-01-09-15", "2023-01-09-22", "2023-01-10-07", "2023-01-10-14", "2023-01-13-07", "2023-01-16-01", "2023-01-16-08", "2023-01-16-15", "2023-01-16-22", "2023-01-17-07", "2023-01-17-14", "2023-01-18-01"]
datasets_prj_rim = ["2023-06-09-01", "2023-07-28-04", "2023-06-24-02", "2023-07-07-11", "2023-08-07-01", "2023-06-24-11", "2023-07-07-18", "2023-08-18-11", "2023-06-24-28", "2023-07-11-02", "2023-08-22-08", "2023-07-12-01", "2023-07-01-09", "2023-07-13-01", "2023-06-09-10", "2023-07-07-01", "2023-08-07-16", "2023-08-22-01", "2023-08-23-23", "2023-08-25-02", "2023-09-15-01", "2023-09-15-08", "2023-08-18-18", "2023-08-19-01", "2023-08-23-09", "2023-08-25-09", "2023-09-01-01", "2023-08-31-03", "2023-07-01-01", "2023-07-01-23"]
datasets_prj_aversion = ["2023-03-30-01", "2023-06-29-01", "2023-06-29-13", "2023-07-14-08", "2023-07-14-14", "2023-07-27-01", "2023-08-08-07", "2023-08-14-01", "2023-08-16-01", "2023-08-21-01", "2023-09-07-01", "2023-09-14-01", "2023-08-15-01", "2023-10-05-01", "2023-06-23-08", "2023-12-11-01", "2023-06-21-01"]
datasets_prj_5ht = ["2022-07-26-31", "2022-07-26-38", "2022-07-27-31", "2022-07-27-38", "2022-07-27-45", "2022-08-02-31", "2022-08-02-38", "2022-08-03-31"]
datasets_prj_starvation = ["2023-05-25-08", "2023-05-26-08", "2023-06-05-10", "2023-06-05-17", "2023-07-24-27", "2023-09-27-14", "2023-05-25-01", "2023-05-26-01", "2023-07-24-12", "2023-07-24-20", "2023-09-12-01", "2023-09-19-01", "2023-09-29-19", "2023-10-09-01", "2023-09-13-02"]

datasets = datasets_prj_neuropal + datasets_prj_rim + datasets_prj_aversion + datasets_prj_5ht + datasets_prj_starvation
# A dataset listed under two projects would be ambiguous and could leak between splits,
# so fail loudly instead of just printing a boolean.
assert len(set(datasets)) == len(datasets), "Duplicate dataset IDs across project lists"

datasets_val = ['2023-06-24-02', '2023-08-07-01', '2023-08-19-01', # RIM datasets
                '2022-07-26-01', '2023-01-23-21', '2023-01-23-01', # NeuroPAL datasets
                '2023-07-14-08', # Aversion datasets
                '2022-08-02-31', # 5-HT datasets
                '2023-07-24-27', '2023-07-24-20'] # Starvation datasets
datasets_test = ['2023-08-22-01', '2023-07-07-18', '2023-07-01-23',  # RIM datasets
                 '2023-01-06-01', '2023-01-10-07', '2023-01-17-07', # Neuropal datasets
                 '2023-08-21-01', "2023-06-23-08", # Aversion datasets
                 '2022-07-27-38', # 5-HT datasets
                 '2023-10-09-01', '2023-09-13-02' # Starvation datasets
                 ]

# Held-out splits must reference known datasets and must not overlap each other.
assert set(datasets_val) <= set(datasets), "Validation set contains unknown datasets"
assert set(datasets_test) <= set(datasets), "Test set contains unknown datasets"
assert not set(datasets_val) & set(datasets_test), "Validation and test sets overlap"

# Everything not held out for validation or testing is used for training (original order preserved).
datasets_held_out = set(datasets_val) | set(datasets_test)
datasets_train = [dataset for dataset in datasets if dataset not in datasets_held_out]

You can download this file from [our Dropbox](https://www.dropbox.com/scl/fo/ealblchspq427pfmhtg7h/ALZ7AE5o3bT0VUQ8TTeR1As?rlkey=1e6tseyuwd04rbj7wmn2n6ij7&e=4&st=ybsvv0ry&dl=0) under `AutoCellLabeler/train_val_test_data/extracted_neuron_ids_final_1.h5`:

In [None]:
path_extracted_neuron_ids = "/store1/PublishedData/Data/prj_register/AutoCellLabeler/train_val_test_data/extracted_neuron_ids_final_1.h5"
# Read the raw `neuron_ids` entries (stored as UTF-8 byte strings) and decode them into Python str.
with h5py.File(path_extracted_neuron_ids, 'r') as neuron_id_file:
    raw_neuron_ids = neuron_id_file['neuron_ids'][:]
extracted_neuron_ids = [raw_id.decode('utf-8') for raw_id in raw_neuron_ids]

## Copy files from freely-moving datasets

This block of code copies random time points from freely-moving datasets, directly from ANTSUN output. There are also pre-copied files available in [our Dropbox](https://www.dropbox.com/scl/fo/ealblchspq427pfmhtg7h/ALZ7AE5o3bT0VUQ8TTeR1As?rlkey=1e6tseyuwd04rbj7wmn2n6ij7&e=4&st=ybsvv0ry&dl=0) under `AutoCellLabeler/freely_moving_eval`. If using those files, feel free to skip this section. If using ANTSUN output, it is important to run this code block.

Replace the paths to the data with where they are on your machine.

In [None]:
# ANTSUN output roots for each project; point these at where the data lives on your machine.
input_paths = dict(
    prj_rim="/store1/prj_rim/data_processed",
    prj_neuropal="/store1/prj_neuropal/data_processed",
    prj_starvation="/data1/prj_starvation/data_processed",
    prj_5ht="/data3/prj_5ht/published_data/data_processed_neuropal",
    prj_aversion="/data1/prj_aversion/data_processed",
)

# Destination directories for the copied NRRD images and watershed ROI files.
output_path_nrrd, output_path_roi = "/path/to/your/data_dir/nrrd", "/path/to/your/data_dir/roi"

In [None]:
dataset_timepoints = {}
n_trials = 100 # number of distinct timepoints sampled per test dataset
max_t = 1600 # max timepoint in each dataset
max_reattempts = 1000 # extra draws allowed per sample before giving up on a dataset
random_seed = 0 # fixed seed so the sampled timepoints are reproducible
rng = np.random.default_rng(random_seed)

# Which project (and hence which entry of `input_paths`) each dataset belongs to.
# Checked in this order; the first project containing the dataset wins.
project_datasets = {
    "prj_rim": datasets_prj_rim,
    "prj_neuropal": datasets_prj_neuropal,
    "prj_starvation": datasets_prj_starvation,
    "prj_5ht": datasets_prj_5ht,
    "prj_aversion": datasets_prj_aversion,
}


def _project_dir(dataset):
    """Return the ANTSUN output root for `dataset`, or None if it belongs to no known project."""
    for project, project_members in project_datasets.items():
        if dataset in project_members:
            return input_paths[project]
    return None


def _timepoint_source_paths(base_path, dataset, t):
    """Return the (NRRD image, watershed ROI) source paths for timepoint `t` of `dataset`."""
    nrrd_path = os.path.join(base_path, "NRRD_cropped", dataset + "_t" + str(t).zfill(4) + "_ch2.nrrd")
    watershed_path = os.path.join(base_path, "img_roi_watershed", str(t) + ".nrrd")
    return nrrd_path, watershed_path


os.makedirs(output_path_nrrd, exist_ok=True)
os.makedirs(output_path_roi, exist_ok=True)

for dataset in datasets_test:
    prj_dir = _project_dir(dataset)
    if prj_dir is None:
        continue

    base_path = os.path.join(prj_dir, dataset + "_output")
    dataset_timepoints[dataset] = []
    for _ in tqdm(range(n_trials), desc=dataset):
        # Rejection-sample a timepoint that was not already chosen and whose image and ROI files exist.
        for _attempt in range(max_reattempts + 1):
            t = int(rng.integers(max_t))
            nrrd_path, watershed_path = _timepoint_source_paths(base_path, dataset, t)
            if t not in dataset_timepoints[dataset] and os.path.exists(nrrd_path) and os.path.exists(watershed_path):
                break
        else:
            raise ValueError("Could not find valid timepoint for dataset " + dataset)

        output_stem = dataset + "_" + str(t)
        expand_nrrd_dimension(nrrd_path, os.path.join(output_path_nrrd, output_stem + ".nrrd"))
        # Rename while copying: the source file is just `<t>.nrrd`, which collides across datasets,
        # and the H5 conversion step reads ROI files named `<dataset>_<t>.nrrd`.
        shutil.copy(watershed_path, os.path.join(output_path_roi, output_stem + ".nrrd"))
        dataset_timepoints[dataset].append(t)


## Convert files to H5

This section converts some previously-copied files to H5 format. If you have already downloaded the data files from the Dropbox link above, start here. If you're using copied data from this notebook, set `input_path_nrrd` and `input_path_roi` to the previous code block's `output_path_nrrd` and `output_path_roi`, respectively.

In [None]:
# Inputs: freely-moving NRRD images and their watershed ROI images (the Dropbox `freely_moving_eval` copies by default).
input_path_nrrd, input_path_roi = (
    "/store1/PublishedData/Data/prj_register/AutoCellLabeler/freely_moving_eval/nrrd",
    "/store1/PublishedData/Data/prj_register/AutoCellLabeler/freely_moving_eval/img_roi_watershed",
)

# Outputs: cropped ROI files and the H5 files written by the conversion step.
output_path_roi_crop, output_path_h5 = "/path/to/your/data_dir/roi_crop", "/path/to/your/data_dir/h5"


In [None]:
def list_files_to_dict(directory):
    """
    Helper function to list which timepoints were selected in each NRRD directory.

    Files are expected to be named ``<uid>_<timepoint>.nrrd``. The split happens at the
    LAST underscore, so uids that themselves contain ``_`` are handled. ``.nrrd`` files
    without any underscore, and files with other extensions, are ignored.

    Parameters
    ----------
    directory : str
        Directory to scan (non-recursively).

    Returns
    -------
    dict[str, list[str]]
        Maps each uid to the list of timepoint strings found for it. Files are visited in
        sorted filename order, so the result is reproducible across runs and filesystems.
    """
    uid_dict = {}
    # os.listdir order is arbitrary; sort so downstream iteration order is deterministic.
    for file in sorted(os.listdir(directory)):
        if not file.endswith(".nrrd"):
            continue
        uid, sep, timepoint = file.rpartition("_")
        if not sep:
            continue  # not a `<uid>_<timepoint>.nrrd` file
        timepoint = timepoint.split(".")[0]  # Remove the extension
        uid_dict.setdefault(uid, []).append(timepoint)
    return uid_dict


uid_timepoint_dict = list_files_to_dict(directory=input_path_nrrd)  # dataset uid -> selected timepoint strings


In [None]:
# Make sure the conversion outputs have somewhere to go, then show how many timepoints
# were found for each dataset as a quick sanity check before the (slow) conversion.
os.makedirs(output_path_roi_crop, exist_ok=True)
os.makedirs(output_path_h5, exist_ok=True)
{uid: len(timepoints) for uid, timepoints in uid_timepoint_dict.items()}

In [None]:
cropouts = {}

# Convert every selected (dataset, timepoint) NRRD + ROI pair into the cropped H5 inputs for the network,
# keeping whatever create_h5_from_nrrd returns for each timepoint.
for dataset, timepoints_for_dataset in uid_timepoint_dict.items():
    dataset_cropouts = cropouts[dataset] = []
    for t in tqdm(timepoints_for_dataset):
        stem = f"{dataset}_{t}"
        dataset_cropouts.append(
            create_h5_from_nrrd(
                os.path.join(input_path_nrrd, f"{stem}.nrrd"),
                os.path.join(output_path_h5, f"{stem}.h5"),
                os.path.join(input_path_roi, f"{stem}.nrrd"),
                os.path.join(output_path_roi_crop, f"{stem}.h5"),
                (64, 120, 284),  # crop size
                len(extracted_neuron_ids) + 1,  # one per neuron ID plus one extra (presumably background — confirm)
            )
        )

## Run TagRFP-only AutoCellLabeler

To run the TagRFP-only AutoCellLabeler network on these H5 files, see the [`pytorch-3dunet` package](https://github.com/flavell-lab/pytorch-3dunet), which contains the code and parameter files for this `all_red` network. This network's weights can be found in [our Dropbox](https://www.dropbox.com/scl/fo/ealblchspq427pfmhtg7h/ALZ7AE5o3bT0VUQ8TTeR1As?rlkey=1e6tseyuwd04rbj7wmn2n6ij7&e=4&st=ybsvv0ry&dl=0) under `AutoCellLabeler/model_weights/paper_all_red.pytorch`.

## Load ROI matches

Assuming you are using ANTSUN-processed data, you can run the `extract_roi_matches.ipynb` Julia notebook to extract the ROI matches dictionary from the ANTSUN outputs. This dictionary matches the ROIs in the freely-moving data to those in the immobilized NeuroPAL images, which can be used to assess the quality of the AutoCellLabeler predictions on the freely-moving data.

If you are using the data from our Dropbox, this dictionary is available in `AutoCellLabeler/freely_moving_eval/roi_match.h5`.

In [None]:
def load_h5_to_dict(filename):
    """
    Load every top-level HDF5 dataset in `filename` into a dict of numpy arrays keyed by dataset name.

    Each array is transposed (axis order reversed) on load. This is presumably because the file is
    produced by the Julia ROI-match extraction notebook and Julia stores arrays column-major.
    """
    with h5py.File(filename, 'r') as h5_in:
        return {name: h5_in[name][:].transpose() for name in h5_in.keys()}

roi_matches = load_h5_to_dict("/store1/PublishedData/Data/prj_register/AutoCellLabeler/freely_moving_eval/roi_match.h5")

## Load AutoCellLabeler predictions

In [None]:
# Directory passed to `output_label_file` as the per-(dataset, timepoint) CSV destination.
# Replace with a path on your machine, like the other output paths in this notebook.
output_path_csv = "/path/to/your/data_dir/csv"

In [None]:
prob_dict_paper_all_red_fm = {}
contaminated_neurons_paper_all_red_fm = {}
output_dict_paper_all_red_fm = {}
roi_sizes = {}
for dataset_test in uid_timepoint_dict:
    # Per-dataset dicts keyed by timepoint string, registered in the top-level dicts up front.
    prob_by_t, contaminated_by_t, size_by_t, output_by_t = {}, {}, {}, {}
    prob_dict_paper_all_red_fm[dataset_test] = prob_by_t
    contaminated_neurons_paper_all_red_fm[dataset_test] = contaminated_by_t
    roi_sizes[dataset_test] = size_by_t
    output_dict_paper_all_red_fm[dataset_test] = output_by_t
    for t in tqdm(uid_timepoint_dict[dataset_test]):
        stem = dataset_test + "_" + str(t)
        roi_crop_h5 = os.path.join(output_path_roi_crop, stem + ".h5")
        # Combine the cropped ROI file with the network's predictions into per-ROI probabilities.
        prob_by_t[t], contaminated_by_t[t] = create_probability_dict(
            roi_crop_h5,
            os.path.join(output_path_h5, stem + "_predictions.h5"),
            contamination_threshold=0.75,
        )
        size_by_t[t] = get_roi_size(roi_crop_h5)
        output_by_t[t] = output_label_file(
            prob_by_t[t],
            contaminated_by_t[t],
            size_by_t[t],
            path_extracted_neuron_ids,
            os.path.join(input_path_roi, stem + ".nrrd"),
            os.path.join(output_path_csv, stem + ".csv"),
            max_prob_decrease=0.0,
            confidence_demote=-1,
        )


## Parse AutoCellLabeler predictions

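In the freely-moving dataset, we can take advantage of multiple timepoints of data to construct more accurate predictions. For each ROI in the immobilized NeuroPAL image, this code sums the probability arrays of every freely-moving ROI that was matched to it across all sampled timepoints. It then renormalizes the result so each ROI's probabilities sum to 1. In effect, this pools a neuron's predictions over all timepoints where it was linked to an ROI.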

In [None]:
summed_roi_matches = {}
max_t_np = 0 # largest matched immobilized-ROI index seen so far, across all datasets
max_roi_count = 339 # maximum number of ROIs in the immobilized dataset. You may need to increase this.
n_labels = len(extracted_neuron_ids) + 1 # same label count passed to create_h5_from_nrrd

for dataset in uid_timepoint_dict:
    # Rows are indexed directly by the value stored in roi_matches
    # (presumably the 1-based immobilized ROI ID — confirm against the Julia extraction notebook).
    summed_roi_matches[dataset] = np.zeros((max_roi_count, n_labels))
    count = np.zeros(max_roi_count) # number of freely-moving ROIs pooled into each row
    for t_str in uid_timepoint_dict[dataset]:
        t = int(t_str)
        for roi in prob_dict_paper_all_red_fm[dataset][t_str]:
            # roi_matches comes from Julia (1-based), so both the timepoint and the ROI index shift by -1.
            # The match test and the matched value must read the same element.
            match = roi_matches[dataset][t - 1, roi - 1]
            if not match > 0:
                continue # presumably 0 (or NaN) means this ROI was not linked at this timepoint
            t_np = int(match)
            if t_np >= max_roi_count:
                raise ValueError(
                    f"Dataset {dataset} matches immobilized ROI {t_np}, which does not fit in "
                    f"max_roi_count={max_roi_count}; increase max_roi_count."
                )
            max_t_np = max(max_t_np, t_np)
            count[t_np] += 1
            summed_roi_matches[dataset][t_np, :] += prob_dict_paper_all_red_fm[dataset][t_str][roi]

    # Normalize each row to sum to 1; rows with no matches stay all-zero (the tiny floor avoids 0/0).
    summed_roi_matches[dataset] /= np.maximum(np.sum(summed_roi_matches[dataset], axis=1), 1e-100)[:, np.newaxis]


## Save your data

In [None]:
# Write each dataset's normalized matched-ROI probability matrix as its own HDF5 dataset, named by dataset ID.
with h5py.File("/path/to/your/data_dir/summed_roi_matches_paper_all_red_fm.h5", 'w') as summed_out:
    for dataset, summed_matrix in summed_roi_matches.items():
        summed_out.create_dataset(dataset, data=summed_matrix)

In [None]:
def save_nested_dict_to_h5(filename, data):
    """
    Write a three-level nested dict ``{dataset: {timepoint: {roi: array}}}`` to an HDF5 file.

    The file is created or truncated. Each dataset key becomes a top-level group, each timepoint
    a subgroup (named ``str(timepoint)``), and each ROI an HDF5 dataset (named ``str(roi)``)
    holding the corresponding array.
    """
    with h5py.File(filename, 'w') as h5_out:
        for dataset_key, per_timepoint in data.items():
            dataset_grp = h5_out.create_group(dataset_key)
            for timepoint_key, per_roi in per_timepoint.items():
                timepoint_grp = dataset_grp.create_group(str(timepoint_key))
                for roi_key, values in per_roi.items():
                    timepoint_grp.create_dataset(str(roi_key), data=values)

In [None]:
# Save the per-timepoint ROI probability dicts (dataset -> timepoint -> ROI -> array).
# Replace with a path on your machine, like the other output paths in this notebook.
save_nested_dict_to_h5('/path/to/your/data_dir/prob_dict_paper_all_red_fm.h5', prob_dict_paper_all_red_fm)
