# Align dLight data (plus other features) to syllable onset within a window

Each window is then concatenated together to create one (very large) dataframe.
Useful for computing syllable-triggered averages and syllable averages.

Note that you'll need to re-run this notebook once per alignment to get online-syllable, offline-syllable, and movement-initiation aligned snippets. Change the appropriate flags in `../analysis_configuration.toml` before each run. Most analyses use a window of -3 to +3 relative to onset, and a small subset of figures uses -10 to +10 relative to onset.

1. Run with `window_bounds=[-3, 3]`, `snippet_grab.label_key = "predicted_syllable"`, and data keys under `SHORT WIN KEYS`
1. Run with `window_bounds=[-3, 3]`, `snippet_grab.label_key = "predicted_syllable (offline)"`, and data keys under `SHORT WIN KEYS`
1. Run with `window_bounds=[-10, 10]`, `snippet_grab.label_key = "predicted_syllable (offline)"`, and data keys under `LONG WIN KEYS`
1. Run with `window_bounds=[-10, 10]`, `snippet_grab.label_key = "movement_initiations"`, and data keys under `LONG WIN KEYS`

In [1]:
# IPython setup: enable the autoreload extension in mode 2 (re-import modules
# before running each cell) and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from rl_analysis.photometry.util import align_window_to_label
from rl_analysis.batch import apply_parallel_joblib
from tqdm.auto import tqdm
from contextlib import redirect_stderr

import numpy as np
import pandas as pd
import os
import toml
import sys

# Keep a handle on the interpreter's original stderr stream; it is used further
# down as the redirect target while the parallel job runs.
terminal = sys.__stderr__

In [3]:
# Parse the shared analysis configuration file (TOML) into a dict.
with open("../analysis_configuration.toml", "r") as config_file:
    analysis_config = toml.load(config_file)

In [4]:
# Pull out the configuration sections referenced in this notebook, in the same
# lookup order as the section names listed below.
data_dirs, raw_dirs, proc_dirs, snippet_cfg = (
    analysis_config[section]
    for section in ("data_dirs", "raw_data", "intermediate_results", "dlight_snippet")
)

## Data loading

In [5]:
# Load the processed dLight photometry table from the raw dLight directory,
# then order its rows by index.
df_filtered = pd.read_parquet(
    os.path.join(raw_dirs["dlight"], "dlight_photometry_processed_full.parquet")
)
df_filtered = df_filtered.sort_index()

In [6]:
# Extra column names that are appended to the configured meta keys when the
# metadata columns are selected below.
quality_keys = ["signal_reference_corr", "reference_max", "signal_max"]

In [7]:
# Restrict the configured data / meta keys to columns that actually exist in
# the loaded table; the meta keys also include the quality columns.
all_data_keys = df_filtered.columns.intersection(snippet_cfg["data_keys"]).tolist()
all_meta_keys = df_filtered.columns.intersection(
    [*snippet_cfg["meta_keys"], *quality_keys]
).tolist()

In [8]:
# do some conversions to save memory...

In [9]:
# Downcast every float64 data column to float32 to halve its memory footprint.
data_dtypes = df_filtered[all_data_keys].dtypes
for k, v in tqdm(data_dtypes.items(), total=len(all_data_keys)):
    if not (v == "float64"):
        continue
    df_filtered[k] = df_filtered[k].astype("float32")

# Apply the explicit column -> dtype conversions listed in the configuration.
conversions = snippet_cfg["convs"]
for k, v in tqdm(conversions.items()):
    df_filtered[k] = df_filtered[k].astype(v)

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

In [10]:
# Display the snippet-grab settings (window bounds and label key) in effect for
# this run; these also determine the output file name when saving.
snippet_cfg["snippet_grab"]

{'window_bounds': [-3, 3], 'label_key': 'predicted_syllable'}

In [11]:
# Show the data columns that will be passed as data_keys to the alignment function.
print(all_data_keys)

['pc00', 'pc01', 'pc02', 'pc03', 'pc04', 'pc05', 'pc06', 'pc07', 'pc08', 'pc09', 'centroid_x_mm', 'centroid_y_mm', 'velocity_2d_mm', 'height_ave_mm', 'feedback_status', 'timestamp', 'angle', 'acceleration_2d_mm', 'jerk_2d_mm', 'angle_unwrapped', 'velocity_angle', 'velocity_height', 'signal_reref_dff', 'signal_reref_dff_z']


## Get snippets

In [12]:
# Build the callable that is applied to each uuid group: the sampling rate and
# the selected data / meta columns come from above, and the remaining options
# (window bounds, label key) are taken from the "snippet_grab" config section.
func = align_window_to_label(
    fs=analysis_config["common"]["fs"],
    data_keys=all_data_keys,
    meta_keys=all_meta_keys,
    **snippet_cfg["snippet_grab"],
)

In [13]:
# Group rows by "uuid"; the grouped object is handed to the parallel helper
# below together with the alignment callable.
group_obj = df_filtered.groupby("uuid")

In [14]:
# Number of uuid groups; reported below as the number of jobs to process.
njobs = group_obj.ngroups

In [15]:
# Route stderr to the original terminal stream while the parallel job runs, and
# announce the job count on that same stream.
with redirect_stderr(terminal):
    print(f"{njobs} jobs to process", file=terminal)
    # n_jobs is passed through to the helper. The logged output shows joblib's
    # Parallel, where a negative n_jobs means (n_cpus + 1 + n_jobs) workers,
    # i.e. 19 CPUs are left free here.
    snippet_df = apply_parallel_joblib(group_obj, func, n_jobs=-20)

898 jobs to process
[Parallel(n_jobs=-20)]: Using backend LokyBackend with 109 concurrent workers.
[Parallel(n_jobs=-20)]: Done   3 tasks      | elapsed:   40.8s
[Parallel(n_jobs=-20)]: Done  24 tasks      | elapsed:   48.6s
[Parallel(n_jobs=-20)]: Done  47 tasks      | elapsed:   55.5s
[Parallel(n_jobs=-20)]: Done  70 tasks      | elapsed:  1.0min
[Parallel(n_jobs=-20)]: Done  95 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-20)]: Done 120 tasks      | elapsed:  1.3min
[Parallel(n_jobs=-20)]: Done 147 tasks      | elapsed:  1.4min
[Parallel(n_jobs=-20)]: Done 174 tasks      | elapsed:  1.5min
[Parallel(n_jobs=-20)]: Done 203 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-20)]: Done 232 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-20)]: Done 263 tasks      | elapsed:  1.9min
[Parallel(n_jobs=-20)]: Done 294 tasks      | elapsed:  2.1min
[Parallel(n_jobs=-20)]: Done 327 tasks      | elapsed:  2.2min
[Parallel(n_jobs=-20)]: Done 360 tasks      | elapsed:  2.4min
[Parallel(n_jobs=-2

### Saving

In [16]:
# Re-apply the configured column -> dtype conversions to the combined snippets.
conversions = snippet_cfg["convs"]
for k, v in tqdm(conversions.items()):
    converted = snippet_df[k].astype(v)
    snippet_df[k] = converted

  0%|          | 0/8 [00:00<?, ?it/s]

In [17]:
# Replace the existing index with a plain 0..n-1 running index.
snippet_df.index = pd.RangeIndex(len(snippet_df))

In [18]:
# Replace the "snippet" column with integer codes that are unique per
# (snippet, uuid) pair, numbered in order of first appearance.
try:
    # Fast path: pair the two columns with a private pandas helper, using the
    # categorical codes of "uuid" rather than the raw values.
    codes = pd.factorize(
        pd._libs.lib.fast_zip([snippet_df["snippet"].values, snippet_df["uuid"].cat.codes.values])
    )[0]
except AttributeError as e:
    # Fallback when an attribute used above is missing (e.g. the private helper
    # or the .cat accessor): build the pairs in pure Python from the raw values.
    print(e)
    codes = pd.factorize(
        list(zip(snippet_df["snippet"].tolist(), snippet_df["uuid"].tolist()))
    )[0]

snippet_df["snippet"] = codes

In [19]:
# Set pyarrow's CPU thread pool to 10 threads before the dataset is written.
import pyarrow as pa
pa.set_cpu_count(10)

In [20]:
# Choose the output dataset name from the snippet-grab configuration:
#   * a window wider than 10 (in window_bounds units) gets a "_longwin" suffix
#   * the label key selects the online / offline / movements file stem
grab_cfg = snippet_cfg["snippet_grab"]
label_key = grab_cfg["label_key"]

# Compare a scalar width rather than truth-testing a one-element numpy array.
window_start, window_stop = grab_cfg["window_bounds"]
file_suffix = "_longwin" if (window_stop - window_start) > 10 else ""

label_to_stem = {
    "predicted_syllable": "online",
    "predicted_syllable (offline)": "offline",
    "movement_initiations": "movements",
}

# Fail loudly on an unknown label key. The exception used to be constructed
# but never raised, so execution continued with `save_file` undefined, or
# stale from an earlier run in the same kernel, which could overwrite the
# wrong dataset.
if label_key not in label_to_stem:
    raise RuntimeError(f"label key not recognized: {label_key!r}")

save_file = os.path.join(
    raw_dirs["dlight"],
    f"dlight_snippets_{label_to_stem[label_key]}{file_suffix}.parquet",
)

# Write the snippets as a parquet dataset partitioned on area / mouse_id / uuid.
# NOTE(review): the extra keyword arguments are forwarded to the parquet engine
# (presumably pyarrow). "delete_matching" is expected to replace existing data
# in the partitions being written; confirm against the installed pyarrow version.
snippet_df.to_parquet(
    save_file,
    allow_truncated_timestamps=True,
    partition_cols=["area", "mouse_id", "uuid"],
    existing_data_behavior="delete_matching",
)

In [21]:
# Report where the snippet dataset was written.
print(save_file)

/home/markowitzmeister_gmail_com/jeff_win_share/reinforcement_data/_final_test/_data/dlight_raw_data/dlight_snippets_online.parquet
