In [1]:
# Process-level settings for the JAX/XLA runtime and for CUDA device selection.
# They are set in the first cell, ahead of the cell that imports jax.
# NOTE(review): presumably these have to be in the environment before jax initialises its backend — confirm.
%env XLA_PYTHON_CLIENT_PREALLOCATE=true
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.2
%env JAX_DEBUG_NANS=true
%env CUDA_VISIBLE_DEVICES=0

env: XLA_PYTHON_CLIENT_PREALLOCATE=true
env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.2
env: JAX_DEBUG_NANS=true
env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import os

from cepnem_jax import h5_to_dict, dict_to_h5

### Define constants used when importing the data

In [3]:
# Scale factors applied while loading the behavioral signals: velocity, head angle and
# pumping are divided by v_STD, θh_STD and P_STD respectively.
# vT_STD is defined here but not referenced by any of the visible cells.
# NOTE(review): the _STD suffix suggests these are standard deviations — confirm their origin.
v_STD = 0.06030961137253011
vT_STD = 0.9425527026496543
θh_STD = 0.49429038957075727
P_STD = 1.2772001409506841

### Define paths for each project

In [4]:
# Per-project settings keyed by project name; the loading cell iterates over every entry.
prj_data = {}

In [5]:
# PRJ_RIM

# Input locations for prj_rim. Note that the structured metadata file is read from disk
# as soon as this cell runs, because h5_to_dict is called inside the literal.
prj_data["prj_rim"] = {
    "processed_h5_directory": "/data3/prj_rim/processed_h5", # Directory containing the processed h5 files
    "neuropal_file": "/data3/prj_rim/Decoding_Data/Aggregated_Traces_h5/dict_neuropal_label.h5", # File containing the neuropal labels
    "octanol_directory": "/data3/prj_rim/OctanolEncounters", # Directory containing the octanol events
    "structured_dict": h5_to_dict("/data3/prj_rim/Decoding_Data/Structured_Data_Info.h5") # Per-dataset metadata; the loading cell reads its "Tags" and "Flipped" entries
    }

In [6]:
# PRJ_KFC

# prj_data["prj_kfc"] = {
#     "processed_h5_directory": "/store1/prj_kfc/data/processed_h5", # Directory containing the processed h5 files
#     "neuropal_file": "/store1/prj_jax/Aggregated_Traces_h5/dict_neuropal_label_prj_kfc.h5", # File containing the neuropal labels
#     "structured_dict": h5_to_dict("/store1/prj_kfc/Structured_Data_Info.h5")
#     }

In [7]:
# result_dict_comb: neuron name -> dataset -> neuropal label entry, merged over all projects.
# data:             dataset -> traces and normalised behavior for that recording.
result_dict_comb = {}
data = {}

for project, prj_vals in prj_data.items():
    structured = prj_vals["structured_dict"]

    # Tags a dataset must carry (all of them) to be included; this depends on the project only.
    required_tags = set(["wt"] if project == "prj_rim" else ["baseline", "neuropal"])
    ds_of_interest = [ds for ds, info in structured.items() if required_tags.issubset(set(info["Tags"]))]

    result_dict = h5_to_dict(prj_vals["neuropal_file"]) # Load the neuropal labels

    for neuron in result_dict.keys():
        # Ignore glia, unknown neurons and uncertain ('?') labels
        if neuron == 'glia' or neuron == 'UNKNOWN' or '?' in neuron:
            continue

        # Create the combined entry the first time a neuron is seen
        combined_labels = result_dict_comb.setdefault(neuron, {})

        # Ignore the SMB neurons in the prj_kfc
        excluded_neuron = project == 'prj_kfc' and neuron.startswith('SMB')

        for ds in list(result_dict[neuron].keys()):
            if ds in ds_of_interest and not excluded_neuron:
                combined_labels[ds] = result_dict[neuron][ds]

    for ds in ds_of_interest:
        loaded_data = h5_to_dict(os.path.join(prj_vals["processed_h5_directory"], ds + "-data.h5"))
        behavior = loaded_data['behavior']

        # Recordings without a pumping trace are left out of `data` entirely
        if 'pumping' not in behavior.keys():
            continue

        # Head angle is negated for datasets marked as flipped
        head_sign = -1 if structured[ds]['Flipped'] else 1

        entry = {}
        data[ds] = entry
        entry['neuron_traces'] = loaded_data['gcamp']['trace_array']
        entry['velocity'] = behavior['velocity'] / v_STD
        entry['head_angle'] = behavior['head_angle'] / θh_STD * head_sign
        entry['pumping'] = behavior['pumping'] / P_STD
        entry['reversal_events'] = behavior['reversal_events'].T

### Find all neuron classes with DV pairs, keeping only the ventral (V) neurons

In [8]:
# Group the ventral neurons by class, where the class is the first three characters of the
# name and "ventral" means the fourth character is "V".
#   - names ending in L or R are collected as {'L': [...], 'R': [...]}
#   - all other names are collected in a flat list
neuron_dv_classes = {}

for neuron in result_dict_comb:
    # Names shorter than four characters have no fourth character to test
    if len(neuron) < 4 or neuron[3] not in "V":
        continue

    class_prefix, side = neuron[:3], neuron[-1]
    if side in "LR":
        neuron_dv_classes.setdefault(class_prefix, {'L': [], 'R': []})[side].append(neuron)
    else:
        neuron_dv_classes.setdefault(class_prefix, []).append(neuron)

### Some helper functions

In [9]:
def intersection(lsts):
    """Return the elements common to every list in ``lsts``, without duplicates.

    Args:
        lsts: A sequence of iterables of hashable items.

    Returns:
        A list of the shared items, in the order they first appear in
        ``lsts[0]``. An empty ``lsts`` yields ``[]`` and a single list yields
        its de-duplicated contents (both previously raised IndexError).
    """
    lsts = list(lsts)
    if not lsts:
        return []

    first = list(lsts[0])
    common = set(first).intersection(*lsts[1:])
    # Walk the first list rather than the set so the result order does not
    # depend on set iteration order (which varies with string hash randomisation).
    return [item for item in dict.fromkeys(first) if item in common]
    
def union(lsts):
    """Return every distinct element found in any list in ``lsts``.

    Args:
        lsts: A sequence of iterables of hashable items.

    Returns:
        A list of the distinct items in first-seen order. An empty ``lsts``
        yields ``[]`` and a single list yields its de-duplicated contents
        (both previously raised IndexError).
    """
    # A dict keeps insertion order, so the result is reproducible across runs,
    # unlike iterating a set of strings under hash randomisation.
    merged = {}
    for lst in lsts:
        merged.update(dict.fromkeys(lst))
    return list(merged)

### Parse the data

In [10]:
# Define the time ranges around the events of interest in each dataset

key = jax.random.PRNGKey(0)

time_range = 12
offset = 0
max_shift = 0
min_reversal_len = time_range + offset

decoding_data = {}

result_dict = result_dict_comb

# For every dataset keep the reversals that are long enough and are followed by at least
# two more head-angle samples, and record for each:
#   reversal_starts: reversal start - offset + time_range
#   reversal_ends:   reversal end - offset
#   all_turns:       "D" if the mean head angle over the two samples after the reversal is > 0, else "V"
for ds, ds_data in data.items():
    head_angle = ds_data['head_angle']
    reversal_starts, reversal_ends, turns = [], [], []

    for reversal in ds_data['reversal_events']:
        rev_start, rev_end = reversal[0], reversal[1]

        long_enough = rev_end - rev_start >= min_reversal_len
        if not (long_enough and rev_end + 2 < len(head_angle)):
            continue

        # Determine the turn after the reversal
        mean_angle_after = np.mean([head_angle[rev_end + 1], head_angle[rev_end + 2]])
        post_reversal_turn = "D" if mean_angle_after > 0 else "V"

        reversal_starts.append(rev_start - offset + time_range)
        reversal_ends.append(rev_end - offset)
        turns.append(post_reversal_turn)

    ds_data['reversal_starts'] = reversal_starts
    ds_data['reversal_ends'] = reversal_ends
    ds_data['all_turns'] = turns

    # if("D" not in turns):
    #     print(f"Warning: No dorsal turns detected in {ds}")
    #     data[ds]['reversal_starts'] = []
    #     data[ds]['reversal_ends'] = []
    #     data[ds]['all_turns'] = []

In [11]:
# Extract neuron activation and head curvature data for each event of interest.
#
# For every ventral neuron class this builds, per reversal, two windows of `time_range`
# samples (one ending at the stored reversal start index, one ending at the stored
# reversal end index), each with three columns: neuron trace, head angle, zeros.
# The windows are split by post-reversal turn direction and written to one h5 file per class.

for dv_class, members in neuron_dv_classes.items():
    # An unpaired class is a flat list of names; a paired class is {'L': [...], 'R': [...]}.
    # Only the first name of each list is used. A side with no labelled neuron is skipped
    # here instead of raising IndexError on `[0]`.
    if isinstance(members, list):
        candidate_neurons = members[:1]
    else:
        candidate_neurons = [members[side][0] for side in ("L", "R") if members[side]]

    # Datasets in which at least one candidate neuron is labelled, in first-seen order.
    # dict.fromkeys de-duplicates without a set, so the row order of the saved arrays and
    # the sequence of random L/R draws no longer change with string hash randomisation.
    labelled_datasets = dict.fromkeys(ds for name in candidate_neurons for ds in result_dict_comb[name])

    # A dataset can carry labels yet be absent from `data` (the loading cell skips
    # recordings without pumping); drop those rather than raising KeyError.
    datasets = [ds for ds in labelled_datasets if ds in data]

    # Turn labels in the same dataset/event order as the rows filled in below.
    # Built once: both passes of the loop below visit the events in this order.
    all_turns = [turn for ds in datasets for turn in data[ds]['all_turns']]
    n_valid_turns = len(all_turns)

    reversal_start_neuron_data = np.zeros((n_valid_turns, time_range, 3))
    reversal_end_neuron_data = np.zeros((n_valid_turns, time_range, 3))

    for source, neuron_data_arr in zip(["start", "end"], [reversal_start_neuron_data, reversal_end_neuron_data]):
        # Go through datasets and extract the neural data for each reversal event
        idx = 0
        for ds in datasets:
            # Candidates labelled in this dataset; for a paired class the order is L, then R
            present = [name for name in candidate_neurons if ds in result_dict_comb[name]]

            for t_event in data[ds][f'reversal_{source}s']:
                window = slice(t_event - time_range, t_event)

                # Extract the head curvature data
                head_curv = data[ds]['head_angle'][window]

                if len(present) == 2:
                    # Both L and R are detected. Randomly choose between the left and right neurons.
                    # NOTE(review): the draw is repeated per event and per pass, so the start and
                    # end windows of one reversal can come from different sides.
                    key, sk = jax.random.split(key)
                    use_left = bool(jax.random.uniform(sk, (1,))[0] > 0.5)
                    chosen = present[0] if use_left else present[1]
                else:
                    # Unpaired class, or only one side labelled in this dataset
                    chosen = present[0]

                # NOTE(review): the -1 presumably converts a 1-based label index to a 0-based column — confirm
                column = result_dict_comb[chosen][ds]["index"] - 1
                neuron_data = data[ds]['neuron_traces'][window, column]

                # Stack the neural data and head curvature data
                neuron_data_arr[idx] = np.stack([neuron_data, head_curv, np.zeros((time_range))], axis=1)
                idx += 1

    # Reformat the data into a format that can be used for decoding
    ventral_turn_start = jnp.array([row for turn, row in zip(all_turns, reversal_start_neuron_data) if turn == "V"])
    dorsal_turn_start = jnp.array([row for turn, row in zip(all_turns, reversal_start_neuron_data) if turn == "D"])

    ventral_turn_end = jnp.array([row for turn, row in zip(all_turns, reversal_end_neuron_data) if turn == "V"])
    dorsal_turn_end = jnp.array([row for turn, row in zip(all_turns, reversal_end_neuron_data) if turn == "D"])

    # Print out some stats
    print(f"Class: {dv_class}V Turns: {n_valid_turns:4d} (D: {dorsal_turn_start.shape[0]:4d}, V: {ventral_turn_start.shape[0]:4d})")

    decoding_data[dv_class] = {
        'ventral_turn_start': ventral_turn_start,
        'dorsal_turn_start': dorsal_turn_start,
        'ventral_turn_end': ventral_turn_end,
        'dorsal_turn_end': dorsal_turn_end,
    }

    # Save to an h5 file
    dict_to_h5(f"/home/alex/data3/prj_rim/Decoding_Data/Finalized/DVPostReversal/turn_predict_{dv_class}V.h5", decoding_data[dv_class], overwrite=True)

Class: CEPV Turns:  380 (D:   62, V:  318)
Class: IL1V Turns:  364 (D:   63, V:  301)
Class: IL2V Turns:  379 (D:   64, V:  315)
Class: OLQV Turns:  375 (D:   58, V:  317)
Class: RMDV Turns:  341 (D:   56, V:  285)
Class: RMEV Turns:  235 (D:   43, V:  192)
Class: SAAV Turns:  340 (D:   54, V:  286)
Class: SIAV Turns:  277 (D:   44, V:  233)
Class: SIBV Turns:  243 (D:   36, V:  207)
Class: SMBV Turns:  302 (D:   51, V:  251)
Class: SMDV Turns:  393 (D:   64, V:  329)
Class: URAV Turns:  273 (D:   44, V:  229)
Class: URYV Turns:  393 (D:   64, V:  329)


In [12]:

# Turn counts from an earlier run, before data were added and the labels were fixed.
# Kept as a hard-coded string for comparison with the counts printed by the cell above.
print("""
Class: CEPV Turns: 1214 (D: 356, V: 858)
Class: IL1V Turns:  933 (D: 231, V: 702)
Class: IL2V Turns: 1327 (D: 366, V: 961)
Class: OLQV Turns: 1330 (D: 364, V: 966)
Class: RMDV Turns: 1314 (D: 369, V: 945)
Class: RMEV Turns:  909 (D: 263, V: 646)
Class: SAAV Turns:  932 (D: 240, V: 692)
Class: SIAV Turns:  258 (D:  43, V: 215)
Class: SIBV Turns:  209 (D:  39, V: 170)
Class: SMBV Turns:  321 (D:  72, V: 249)
Class: SMDV Turns: 1320 (D: 374, V: 946)
Class: URAV Turns:  644 (D: 197, V: 447)
Class: URYV Turns: 1347 (D: 362, V: 985)
""")


Class: CEPV Turns: 1214 (D: 356, V: 858)
Class: IL1V Turns:  933 (D: 231, V: 702)
Class: IL2V Turns: 1327 (D: 366, V: 961)
Class: OLQV Turns: 1330 (D: 364, V: 966)
Class: RMDV Turns: 1314 (D: 369, V: 945)
Class: RMEV Turns:  909 (D: 263, V: 646)
Class: SAAV Turns:  932 (D: 240, V: 692)
Class: SIAV Turns:  258 (D:  43, V: 215)
Class: SIBV Turns:  209 (D:  39, V: 170)
Class: SMBV Turns:  321 (D:  72, V: 249)
Class: SMDV Turns: 1320 (D: 374, V: 946)
Class: URAV Turns:  644 (D: 197, V: 447)
Class: URYV Turns: 1347 (D: 362, V: 985)

