# Here we are taking dataframes with dLight "snippets" (waveforms aligned to onset of an event) and computing basic summary features

1. Load in dLight and scalar snippets
1. Load in pre-snippet dLight data to collect stats for downstream processing (e.g. syllable ID and sequence)
1. Compute features

Note that you'll need to run this notebook twice to get both the online and the offline features, editing `../analysis_configuration.toml` between the two runs:

1. Run once with `use_offline = true` under `[dlight_transition_features]`
1. Run once with `use_offline = false` under `[dlight_transition_features]`

In [1]:
# Reload imported modules automatically before each cell executes, so edits
# to the project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from rl_analysis.util import rle
from rl_analysis.photometry.util import (
    renormalize_df,
    compute_pandas_feature,
    compute_numba_feature,
)
from rl_analysis.photometry.features import nanargmax
from joblib import Parallel, delayed
from contextlib import redirect_stderr

import pandas as pd
import numpy as np
import os
import sys

# Handle to the interpreter's original stderr stream; later cells pass it to
# `redirect_stderr` and `print(file=...)` while the parallel jobs run.
terminal = sys.__stderr__

# Preprocessing

In [3]:
import toml

# TOML documents are UTF-8 by specification, so decode explicitly instead of
# relying on the platform's locale-dependent default encoding.
with open("../analysis_configuration.toml", "r", encoding="utf-8") as f:
    analysis_config = toml.load(f)

In [4]:
# Pull the configuration sections used by this notebook out of the parsed TOML.
raw_dirs, data_dirs, proc_dirs = (
    analysis_config[_section]
    for _section in ("raw_data", "data_dirs", "intermediate_results")
)
transition_cfg, common_cfg = (
    analysis_config[_section]
    for _section in ("dlight_transition_features", "common")
)

In [5]:
# The `use_offline` switch selects both the snippet file to load and the
# syllable-label column that is used further down.
if transition_cfg["use_offline"]:
    file_suffix, syllable_key = "offline", "predicted_syllable (offline)"
else:
    file_suffix, syllable_key = "online", "predicted_syllable"

_snippet_filename = f"dlight_snippets_{file_suffix}.parquet"
load_file = os.path.join(raw_dirs["dlight"], _snippet_filename)

In [6]:
# Derive the output paths from the input path: insert "_features" (and
# "_renormalize" when renormalization is enabled) before the extension.
file, ext = os.path.splitext(load_file)
features_save_file = f"{file}_features{ext}"

if transition_cfg["renormalize"]:
    file, ext = os.path.splitext(features_save_file)
    features_save_file = f"{file}_renormalize{ext}"

# The usage table is written next to the features file. Swap
# "snippet" -> "usage" in the file name only, so that a directory whose path
# happens to contain "snippet" is not rewritten into a different location.
_save_dir, _save_name = os.path.split(features_save_file)
rle_save_file = os.path.join(_save_dir, _save_name.replace("snippet", "usage"))

In [7]:
# Show where the feature table will be written.
print(features_save_file)

/home/markowitzmeister_gmail_com/jeff_win_share/reinforcement_data/_final_test/_data/dlight_raw_data/dlight_snippets_online_features.parquet


# Basic logic: load in snippets for feature-ization, get usages from the "full" dataframe

In [8]:
# Columns that are cast to string in the dataframes loaded below.
partition_cols = ["area", "mouse_id", "uuid"]

In [9]:
# Read the snippet table in index order; optionally renormalize the dF/F traces.
snippet_df = pd.read_parquet(load_file)
snippet_df = snippet_df.sort_index()

if transition_cfg["renormalize"]:
    _renorm_keys = ["signal_reref_dff", "signal_dff", "reference_dff"]
    snippet_df = renormalize_df(snippet_df, normalize_keys=_renorm_keys)

In [10]:
# Store the partition columns as strings.
snippet_df[partition_cols] = snippet_df[partition_cols].astype("str")

# Keep only samples whose `x` lies inside the configured pre-window.
_in_pre_window = snippet_df["x"].between(*transition_cfg["pre_window"])
snippet_df = snippet_df.loc[_in_pre_window].copy()

In [11]:
# Only these columns of the full (non-snippet) photometry table are read.
_full_columns = [
    "uuid",
    "predicted_syllable (offline)",
    "predicted_syllable",
    "timestamp",
    "mouse_id",
    "session_number",
    "target_syllable",
    "opsin",
    "area",
    "date",
    "signal_max",
    "signal_reference_corr",
    "signal_reref_dff_z",
    "signal_dff",
    "stim_duration",
    "reference_dff",
    "labels",
    "signal_reref_dff",
]
_full_path = os.path.join(raw_dirs["dlight"], "dlight_photometry_processed_full.parquet")

full_df = pd.read_parquet(_full_path, columns=_full_columns)
full_df = full_df.sort_index()

In [12]:
# Store the partition columns as strings, as is done for the snippet table.
full_df[partition_cols] = full_df[partition_cols].astype("str")

In [13]:
# Expose the selected (online or offline) label column under the common name
# "syllable".
full_df = full_df.rename(columns={syllable_key: "syllable"})

In [14]:
# Run-length encode the syllable sequence of every session (uuid).
# NOTE(review): this assumes `rle` returns the entries at which a new run
# starts, keyed by the original row labels -- confirm in rl_analysis.util.
_run_starts = full_df.groupby("uuid")["syllable"].apply(rle)

# Drop the outermost index level of the grouped result so that the remaining
# labels can be used to look rows up in `full_df`.
_run_starts.index = _run_starts.index.droplevel(0)

# Take an explicit copy of the selected rows before assigning columns, matching
# the `.loc[...].copy()` pattern used for the other frames in this notebook and
# avoiding any chained-assignment ambiguity.
rle_df = full_df.loc[_run_starts.index].copy()

# Missing labels become -5 so that the column can be stored as int8.
rle_df["syllable"] = rle_df["syllable"].fillna(-5).astype("int8")

In [15]:
# Write the run-length encoded table to the "usage" path derived earlier.
rle_df.to_parquet(rle_save_file)

# Compute a large set of features on the waveform for correlations

In [16]:
# Replace the index with a plain 0..n-1 positional index.
snippet_df.index = pd.RangeIndex(len(snippet_df))

In [17]:
# Aggregations handed to `compute_pandas_feature`: named reductions plus the
# `nanargmax` callable.
compute_features = ["max", nanargmax, "mean", "min"]

# Aggregations handed to `compute_numba_feature`; none are enabled.
numba_features = []

In [18]:
# One metadata row per snippet: the first row found at x == 0, indexed by
# snippet id.
_rows_at_zero = snippet_df.loc[snippet_df["x"] == 0]
meta_df = _rows_at_zero.drop_duplicates("snippet").set_index("snippet")

In [19]:
# Per-snippet mean of the scalar columns over samples with 0 <= x <= duration.
_x_nonnegative = snippet_df["x"] >= 0
_x_within_duration = snippet_df["x"] <= snippet_df["duration"]
_event_rows = snippet_df.loc[_x_nonnegative & _x_within_duration]
ave_scalars = _event_rows.groupby("snippet")[transition_cfg["scalars"]].mean()

In [None]:
# Final trim: nothing outside the outermost edges of the configured windows is
# needed from here on.
win_arr = np.array(transition_cfg["windows"])
_earliest, _latest = win_arr.min(), win_arr.max()
_inside_windows = snippet_df["x"].between(_earliest, _latest, inclusive="both")
snippet_df = snippet_df.loc[_inside_windows].copy()

In [None]:
# Boolean flags derived from `feedback_status`: 1 -> feedback, 0 -> catch.
_status = snippet_df["feedback_status"]
snippet_df["is_feedback"] = _status.eq(1)
snippet_df["is_catch"] = _status.eq(0)

In [None]:
# Columns handed to the feature workers. `dict.fromkeys` removes duplicates
# while preserving first-seen order, so the column order of `use_df` is the
# same on every run (a `set` would order the names by their per-process
# randomized string hashes).
_use_columns = list(
    dict.fromkeys(
        transition_cfg["proc_keys"]
        + ["x", "snippet", "is_feedback", "is_catch", "duration"]
        + transition_cfg["scalars"]
    )
)
use_df = snippet_df[_use_columns].copy()

# Plain 0..n-1 positional index for the working frame.
use_df.index = range(len(use_df))

In [None]:
# `x` scaled by the configured `fs` and rounded to the nearest whole number.
use_df["x_samples"] = use_df["x"].mul(common_cfg["fs"]).round()

In [None]:
# Display the configured feature windows as the cell output.
transition_cfg["windows"]

In [None]:
# Queue the pandas-based jobs. Jobs belonging to one window are kept adjacent:
# first one job per entry of `compute_features`, then the feedback/catch flags,
# then the scalar columns.
delays = []
for time_win in transition_cfg["windows"]:
    _window = tuple(time_win)

    # summaries of the processed signal keys
    delays.extend(
        [
            delayed(compute_pandas_feature)(
                _window,
                use_df,
                _pandas_feature,
                keys=transition_cfg["proc_keys"],
            )
            for _pandas_feature in compute_features
        ]
    )

    # "any" aggregate of the feedback / catch flags
    delays.append(
        delayed(compute_pandas_feature)(
            _window, use_df, "any", keys=["is_feedback", "is_catch"]
        )
    )

    # "mean" aggregate of the scalar columns
    delays.append(
        delayed(compute_pandas_feature)(
            _window, use_df, "mean", keys=transition_cfg["scalars"]
        )
    )

In [None]:
import multiprocessing as mp
# Force the "spawn" start method for worker processes, overriding any start
# method that was already set (force=True).
mp.set_start_method('spawn', force=True) # prevents deadlocking

In [None]:
# Run the queued pandas jobs in a process pool; messages written to stderr are
# routed to the original terminal stream while the pool runs.
with redirect_stderr(terminal):
    print(f"{len(delays)} jobs to process", file=terminal)
    _pandas_pool = Parallel(verbose=10, backend="multiprocessing", n_jobs=20)
    results_pandas = _pandas_pool(delays)

In [None]:
# Queue one numba-based job per (window, feature) pair, window-major.
delays = [
    delayed(compute_numba_feature)(
        tuple(time_win),
        use_df,
        _numba_feature,
        keys=transition_cfg["proc_keys"],
    )
    for time_win in transition_cfg["windows"]
    for _numba_feature in numba_features
]

In [None]:
# Run the queued numba jobs in a process pool; messages written to stderr are
# routed to the original terminal stream while the pool runs.
with redirect_stderr(terminal):
    print(f"{len(delays)} jobs to process", file=terminal)
    _numba_pool = Parallel(verbose=10, backend="multiprocessing", n_jobs=25)
    results_numba = _numba_pool(delays)

In [None]:
# Single list of per-job result frames: numba results first, then pandas.
results = [*results_numba, *results_pandas]

In [None]:
# For every result frame, take the "window" entry of its first row; results
# are grouped by this value in the next cell.
# NOTE(review): the trailing [0] presumably picks the single value under a
# two-level column label -- confirm against compute_pandas_feature's output.
tmp = pd.Series([_result.iloc[0]["window"][0] for _result in results])

In [None]:
# Build one wide frame per distinct window value: concatenate that window's
# result frames side by side and keep the first occurrence of each column.
dfs = []
for _window_value in tqdm(tmp.unique()):
    _positions = tmp.index[tmp == _window_value]
    _wide = pd.concat([results[_pos] for _pos in _positions], axis=1)
    _is_repeat = _wide.columns.duplicated()
    dfs.append(_wide.loc[:, ~_is_repeat])

In [None]:
# Stack the per-window frames on top of each other.
feature_df = pd.concat(dfs)

# Flatten two-level column labels: ("key", "agg") -> "key_agg", and
# ("key", "") -> "key". A label that is not a tuple is kept unchanged; taking
# `label[0]` of a plain string would truncate it to its first character.
feature_df.columns = [
    ("_".join(_label) if len(_label[1]) > 0 else _label[0])
    if isinstance(_label, tuple)
    else _label
    for _label in feature_df.columns
]

# Attach the per-snippet metadata columns (joined on the index).
feature_df = feature_df.join(meta_df[transition_cfg["meta_keys"]])

In [None]:
# Replace any scalar columns already present with the per-snippet averages
# held in `ave_scalars`.
feature_df = feature_df.drop(columns=ave_scalars.columns, errors="ignore")
feature_df = feature_df.join(ave_scalars)

In [None]:
# String form of each window: convert the "window" column to tuples, then cast
# the resulting column to str.
feature_df["window_tup"] = feature_df["window"].array.to_tuples()
feature_df["window_tup"] = feature_df["window_tup"].astype("str")

In [None]:
# Left and right window edges as separate columns.
_window_array = feature_df["window"].array
feature_df["win_left"] = _window_array.left.to_numpy()
feature_df["win_right"] = _window_array.right.to_numpy()

In [None]:
# Divide every column whose name contains "idxmax" by the configured `fs`.
# NOTE(review): the feature list passes the `nanargmax` callable -- confirm the
# resulting column names actually contain "idxmax", otherwise nothing is
# selected here.
convert_cols = feature_df.filter(regex=".*idxmax").columns
feature_df[convert_cols] = feature_df[convert_cols] / common_cfg["fs"]

In [None]:
# Write the finished feature table to the path derived earlier.
feature_df.to_parquet(features_save_file)