# Compute syllable sequences and entropy in an expanding window after syllable onset (for display examples)

1. Note that you can use a dask scheduler by setting `dask_address` to the address of the scheduler (see `../analysis_configuration.toml`)
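1. Note that you must first compute the syllable statistics file loaded below (`syllable_stats_photometry_offline.toml`, or `syllable_stats_photometry_online.toml` when `use_offline` is false) via `all_behavior_00_statemap_and_syllable_stats.ipynb`!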

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from rl_analysis.batch import apply_parallel_joblib
from rl_analysis.io.df import dlight_exclude
from rl_analysis.photometry.lagged import get_syllable_sequence_stats
from rl_analysis.info.util import dm_entropy
from functools import partial
from copy import deepcopy

import pandas as pd
import numpy as np
import os

In [3]:
# Target dtype for each column, applied wherever the column is present.
# Downstream groupbys pass observed=True, so category levels that never
# occur in the data are dropped from the grouped results.
convs = dict(
    timestamp="float32",
    snippet="int64",
    syllable="category",
    next_syllable="uint8",
    bin="category",
    time_bin="float32",
    dlight_bin="category",
    mouse_id="category",
    uuid="category",
    counts="uint16",
)

In [4]:
# When True, results are recomputed and overwritten even if cached files already exist.
force = True

In [5]:
import toml

# Load the shared analysis configuration from the parent directory.
with open("../analysis_configuration.toml", "r") as f:
    analysis_config = toml.load(f)

In [6]:
# Pull out the configuration sections this notebook reads.
raw_dirs, proc_dirs, lagged_cfg, dlight_cfg = (
    analysis_config[_section]
    for _section in (
        "raw_data",
        "intermediate_results",
        "dlight_lagged_correlations",
        "dlight_common",
    )
)

In [7]:
# The snippet file to load is selected by the `use_offline` config flag.
if lagged_cfg["use_offline"]:
    file_suffix = "offline"
else:
    file_suffix = "online"
load_file = os.path.join(
    raw_dirs["dlight"],
    f"dlight_snippets_{file_suffix}.parquet",
)

In [9]:
# Derive every input/output path used below from the snippet file name.

# Input: per-snippet features, stored next to the snippet file.
file, ext = os.path.splitext(load_file)
features_save_file = f"{file}_features{ext}"

if lagged_cfg["use_renormalized"]:
    file, ext = os.path.splitext(features_save_file)
    features_save_file = f"{file}_renormalize{ext}"

dirname, filename = os.path.split(features_save_file)

# Input: the matching "usage" file. Only the file name is rewritten; replacing
# on the full path would also alter any directory whose name contains "snippet".
rle_save_file = os.path.join(dirname, filename.replace("snippet", "usage"))

# Outputs: same base name, placed under the intermediate-results directory.
file, ext = os.path.splitext(filename)
file = os.path.join(proc_dirs["dlight"], file)
file = f"{file}_for_examples"

results_file = f"{file}_lag_global_entropy.parquet"
results_tm_file = f"{file}_lag_global_entropy_tms.npy"
corr_file = f"{file}_lag_global_entropy_correlations.parquet"
ents_file = f"{file}_lag_global_entropy_ents.parquet"

In [10]:
# One "<feature>_<window>" column name per (window, neural feature) pair,
# ordered window-major.
use_features = [
    f"{_feature}_{_window}"
    for _window in lagged_cfg["use_windows"]
    for _feature in lagged_cfg["use_neural_features"]
]
scalar_keys = lagged_cfg["usage_and_scalars"]["scalars"]

## Load in pre-processed data

In [11]:
# Load the pre-computed feature table and the matching "usage" table.
# presumably rle_df holds one row per run-length-encoded syllable instance — TODO confirm
feature_df = pd.read_parquet(features_save_file)
rle_df = pd.read_parquet(rle_save_file)

In [12]:
# Keep only rows recorded in area "dls", and drop session numbers 1 and 2.
_in_dls = feature_df["area"] == "dls"
_excluded_session = feature_df["session_number"].isin([1, 2])
feature_df = feature_df.loc[_in_dls & ~_excluded_session].copy()

In [13]:
# Read the session bin edges from the intermediate-results directory;
# they are used below to bin timestamps into `time_bin`.
with open(
    os.path.join(proc_dirs["dlight"], "lagged_analysis_session_bins.toml"), "r"
) as f:
    use_session_bins = toml.load(f)["session_bins"]

In [14]:
# Load precomputed syllable statistics for the selected (offline/online) labels.
syllable_stats = toml.load(
    os.path.join(proc_dirs["dlight"], f"syllable_stats_photometry_{file_suffix}.toml")
)
usage = syllable_stats["usages"]
# TOML keys are strings, so cast both sides of each mapping back to int.
mapping = {int(k): int(v) for k, v in syllable_stats["syllable_to_sorted_idx"].items()}
reverse_mapping = {
    int(k): int(v) for k, v in syllable_stats["sorted_idx_to_syllable"].items()
}

# Non-negative syllable ids present in the mapping.
# NOTE: `use_syllables` is reassigned further down in this notebook.
use_syllables = np.array(list(mapping.keys()))
use_syllables = use_syllables[use_syllables >= 0]

# Relabel syllables with their sorted index in both tables.
feature_df["syllable"] = feature_df["syllable"].map(mapping)
rle_df["syllable"] = rle_df["syllable"].map(mapping)

# Stage data for downstream computation

In [15]:
# Assign each timestamp to a session bin; labels=False yields integer bin
# indices rather than interval labels.
feature_df["time_bin"] = pd.cut(feature_df["timestamp"], use_session_bins, labels=False)

In [16]:
# Append "window_tup" as the last index level; if the column is not
# available (KeyError), leave the index untouched.
try:
    feature_df = feature_df.set_index("window_tup", append=True)
except KeyError:
    pass

In [17]:
# Unique values of the last index level (one per window).
wins = feature_df.index.get_level_values(-1).unique()
idx = pd.IndexSlice

# Build one block of neural-feature columns per window: select that window's
# rows, suffix the column names with the window, and drop the window level so
# the blocks can be concatenated side by side.
# NOTE: `_idx` stays bound to the last window after the loop and is reused by
# the following cell.
dfs = []
for _idx in tqdm(wins):
    use_vals = feature_df.loc[idx[:, _idx], lagged_cfg["use_neural_features"]]
    use_vals.columns = [f"{_}_{_idx}" for _ in use_vals.columns]
    use_vals.index = use_vals.index.droplevel(-1)
    dfs.append(use_vals)

  0%|          | 0/8 [00:00<?, ?it/s]

In [18]:
# Metadata = every column whose name does not match "dff".
meta_cols = feature_df.columns.difference(feature_df.filter(regex="dff").columns)

# NOTE(review): `_idx` here is the loop variable left over from the previous
# cell (the last window); metadata rows are taken from that single window.
meta_df = feature_df[meta_cols].loc[idx[:, _idx], :]
meta_df.index = meta_df.index.droplevel(-1)

# Wide table: per-window feature columns side by side, plus the metadata.
feature_df = pd.concat(dfs, axis=1).join(meta_df)
# Duration of each rle row = time until the next row of the same uuid
# (NaN for the last row of each uuid), in the current row order of rle_df.
rle_df["duration"] = rle_df.groupby("uuid")["timestamp"].shift(-1) - rle_df["timestamp"]
feature_df = feature_df.reset_index()

In [19]:
# Number the rows 0..n-1 within each uuid, in the current row order of rle_df.
rle_df["syllable_num"] = rle_df.groupby(["uuid"])["syllable"].transform(
    lambda x: np.arange(len(x))
)
# merge_asof needs both frames sorted by the `on` key (timestamp).
feature_df = feature_df.sort_values(["timestamp", "uuid"]).dropna(subset=["timestamp"])
rle_df = rle_df.sort_values(["timestamp", "uuid"])
# presumably matches the timestamp dtype of feature_df for the merge — TODO confirm
rle_df["timestamp"] = rle_df["timestamp"].astype("float32")

feature_df["uuid"] = feature_df["uuid"].astype("str")
# Attach to each feature row the syllable_num of the latest rle row with the
# same uuid whose timestamp is at or before the feature row's timestamp.
feature_df = pd.merge_asof(
    feature_df, rle_df[["timestamp", "uuid", "syllable_num"]], on="timestamp", by="uuid"
)
# Restore per-uuid chronological order in both tables.
feature_df = feature_df.sort_values(["uuid", "timestamp"])
rle_df = rle_df.sort_values(["uuid", "timestamp"])

In [20]:
# Mean per-uuid count of each syllable, with the -5 label removed, then the
# syllables listed from most to least used.
_counts_per_uuid = rle_df.groupby("uuid")["syllable"].value_counts()
usages = _counts_per_uuid.groupby("syllable").mean()
usages = usages[usages.index != -5]
_ascending_usage = usages.sort_values()
use_syllables = _ascending_usage[::-1].index.tolist()

# Now we compute values triggered on syllable instances, storing features we want to split by downstream (dLight, scalars, etc.)

In [21]:
# Integers 5..24, passed as `usage_bins` to get_syllable_sequence_stats below.
usage_bins = np.arange(5, 25, 1)
idx = pd.IndexSlice
feature_df = feature_df.reset_index(drop=True)

# One group per uuid, again excluding session numbers 1 and 2.
group_keys = ["uuid"]
group_obj = feature_df.loc[~feature_df["session_number"].isin([1, 2])].groupby(
    group_keys
)

In [22]:
# Cast syllable ids to plain Python ints.
use_syllables = list(map(int, use_syllables))

In [23]:
# Compute the per-uuid syllable sequence statistics, or reload them from disk
# when a cached result exists and `force` is off.
if not os.path.exists(results_file) or force:
    # Bind the fixed keyword arguments; `func` is then applied to each group.
    func = partial(
        get_syllable_sequence_stats,
        usage_bins=usage_bins,
        chk_syllables=np.arange(100),
        dlight_features=use_features,
        truncate=syllable_stats["truncate"],
        K=len(usage),
    )
    # Run over the uuid groups in parallel (joblib, loky backend, n_jobs=-1).
    syllable_rates = apply_parallel_joblib(group_obj, func, n_jobs=-1, backend="loky")
    syllable_rates = syllable_rates.reset_index()

    # Cast columns to the dtypes listed in `convs`; missing columns are skipped.
    for k, v in tqdm(convs.items()):
        try:
            syllable_rates[k] = syllable_rates[k].astype(v)
        except KeyError:
            pass

    # The "tm" column is stored separately via np.save; every other column
    # goes to parquet.
    # presumably "tm" holds one transition-count array per row — TODO confirm
    save_rates = syllable_rates[syllable_rates.columns.difference(["tm"])]
    tm_list = syllable_rates["tm"].to_list()

    # NOTE(review): this column is added after `save_rates` was split off, so
    # it is not written to `results_file` and does not exist on the cached-load
    # path below — confirm whether it should be computed before the split.
    syllable_rates["total_duration"] = (
        syllable_rates["count"] * syllable_rates["duration"]
    )
    save_rates.to_parquet(results_file)
    np.save(results_tm_file, tm_list)
else:
    # Cached path: reload the table and the separately stored "tm" arrays.
    save_rates = pd.read_parquet(results_file)
    tm_mat = np.load(results_tm_file)
    tm_list = list(tm_mat)

  ret_list = Parallel(n_jobs=n_jobs, verbose=verbose, backend=backend, batch_size=batch_size)(
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 128 concurrent workers.
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:   21.2s
[Parallel(n_jobs=-1)]: Done  32 tasks      | elapsed:   30.6s
[Parallel(n_jobs=-1)]: Done  57 tasks      | elapsed:   41.2s
[Parallel(n_jobs=-1)]: Done  82 tasks      | elapsed:   53.1s
[Parallel(n_jobs=-1)]: Done 109 tasks      | elapsed:  1.1min
[Parallel(n_jobs=-1)]: Done 136 tasks      | elapsed:  1.4min
[Parallel(n_jobs=-1)]: Done 165 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 194 tasks      | elapsed:  2.1min
[Parallel(n_jobs=-1)]: Done 225 tasks      | elapsed:  2.4min
[Parallel(n_jobs=-1)]: Done 296 out of 500 | elapsed:  3.3min remaining:  2.3min
[Parallel(n_jobs=-1)]: Done 347 out of 500 | elapsed:  3.9min remaining:  1.7min
[Parallel(n_jobs=-1)]: Done 398 out of 500 | elapsed:  4.5min remaining:  1.1min
[Parallel(n_jobs=-1)]: Don

  0%|          | 0/10 [00:00<?, ?it/s]

# Compute correlations

In [24]:
# From here on only this single feature/window column is used for dLight
# binning (this replaces the feature list built from the config earlier).
use_features = ["signal_reref_dff_z_max_(0.0, 0.3)"]

In [25]:
# Work on a copy and re-attach the "tm" arrays; the list is assigned
# positionally, so it must be in the same row order as `save_rates`.
use_syllable_rates = save_rates.copy()
use_syllable_rates["tm"] = tm_list

In [26]:
# Keep rows from area "dls", then apply the project's exclusion helper with
# the shared dLight settings.
# NOTE(review): the exact exclusion criteria live in dlight_exclude — see its definition.
use_syllable_rates = use_syllable_rates.loc[use_syllable_rates["area"] == "dls"].copy()
use_syllable_rates = dlight_exclude(
    use_syllable_rates, exclude_3s=False, syllable_key="syllable", **dlight_cfg
)

In [27]:
# Shorthand for MultiIndex slicing.
idx = pd.IndexSlice

In [28]:
# number of quantile bins for dlight data
ndlight_bins = lagged_cfg["entropy"]["ndlight_bins"]

# column to bin for dlight correlation
bin_column = use_features[0]
# NOTE(review): `dlight_bin_keys` is not referenced by the binning cell below,
# which groups by "mouse_id" only — confirm which grouping is intended.
dlight_bin_keys = ["mouse_id", "syllable"]

# grouping keys from the config, copied and extended with "syllable"
# (and "mouse_id" for the correlations)
pre_agg_keys = deepcopy(lagged_cfg["entropy"]["pre_agg_keys"]) + ["syllable"]
agg_keys = deepcopy(lagged_cfg["entropy"]["agg_keys"]) + ["syllable"]
correlation_keys = deepcopy(lagged_cfg["entropy"]["correlation_keys"]) + [
    "syllable",
    "mouse_id",
]

# what features are we correlating? (dlight_bin and entropy)
corr_cols = ["dlight_bin", "entropy"]

In [29]:
def get_bins(x, ndlight_bins):
    """Return the quantile-bin index of each value in `x`.

    Args:
        x: pandas Series of values to bin.
        ndlight_bins: number of quantile bins requested from `pd.qcut`;
            duplicate bin edges are dropped.

    Returns:
        Integer bin codes from `pd.qcut` (labels=False). If `pd.qcut`
        raises IndexError, a float Series of NaN sharing `x`'s index is
        returned instead.
    """
    try:
        return pd.qcut(x, ndlight_bins, labels=False, duplicates="drop")
    except IndexError:
        # Binning failed for this input: mark every entry as missing.
        return pd.Series(data=np.full((len(x),), np.nan), index=x.index)

In [30]:
# dlight binning: for every feature, quantile-bin its values within each mouse
# and keep a tagged copy of the table recording which feature was binned.
_bin_within_mouse = partial(get_bins, ndlight_bins=ndlight_bins)

cat_rates = []
for _feature in tqdm(use_features):
    _tmp_rates = use_syllable_rates.copy()
    _binned = _tmp_rates.groupby(["mouse_id"])[_feature].transform(_bin_within_mouse)
    _tmp_rates["dlight_bin"] = _binned.astype("category")
    _tmp_rates["dlight_bin_feature"] = _feature
    cat_rates.append(_tmp_rates)

  0%|          | 0/1 [00:00<?, ?it/s]

In [31]:
# Stack the per-feature tables row-wise and renumber the rows.
use_syllable_rates = pd.concat(cat_rates, axis=0).reset_index(drop=True)

In [32]:
# aggregation one
agg_rates = (
    use_syllable_rates.groupby(pre_agg_keys, observed=True)["tm"].sum().sort_index()
)
agg_rates = agg_rates.reset_index()

In [33]:
# need to reconvert data types to conserve memory
for k, v in tqdm(convs.items()):
    try:
        agg_rates[k] = agg_rates[k].astype(v)
    except KeyError:
        pass

  0%|          | 0/10 [00:00<?, ?it/s]

In [34]:
# aggregation 2
ave_rates = agg_rates.groupby(agg_keys, observed=True)["tm"].sum().reset_index()

In [35]:
# Register tqdm's `progress_apply` on pandas objects.
tqdm.pandas()

In [36]:
# Number of aggregated rows (displayed as the cell output).
len(ave_rates)

159440

In [37]:
# Sum of all entries in each row's "tm" value.
ave_rates["totals"] = ave_rates["tm"].progress_apply(np.sum)

  0%|          | 0/159440 [00:00<?, ?it/s]

In [38]:
# ensure we have all bins
use_rates = ave_rates.reset_index().set_index(agg_keys)

In [39]:
# Size of the leading square sub-block of each "tm" matrix used for the entropy.
# truncate = lagged_cfg["entropy"]["tm_truncate"]
truncate = syllable_stats["truncate"]

In [40]:
# Entropy for every aggregated group, computed by dm_entropy on the leading
# truncate x truncate block of that group's "tm" matrix (axis=1).
# presumably "tm" rows index the current syllable and columns the next — TODO confirm
ents = use_rates["tm"].apply(
    lambda x: dm_entropy(
        x[:truncate, :truncate], alpha="perks", marginalize=False, axis=1
    )
)
# `ents_file` already starts with proc_dirs["dlight"], like the other result
# paths, so it is used as-is. Joining the directory onto it again only worked
# for absolute directories and duplicated the directory for relative ones.
ents.to_frame().to_parquet(ents_file)

In [41]:
# Attach the entropies (same index as use_rates) and flatten the index again.
use_rates["entropy"] = ents
use_rates = use_rates.reset_index()
# Cast the bin columns to numeric dtypes so they can be used in the correlation below.
use_rates["dlight_bin"] = use_rates["dlight_bin"].astype("float")
use_rates["bin"] = use_rates["bin"].astype("uint16")

In [42]:
# Sanity check: which feature(s) the dLight bins were derived from.
use_rates["dlight_bin_feature"].unique()

array(['signal_reref_dff_z_max_(0.0, 0.3)'], dtype=object)

In [43]:
# Reuse cached correlations when they exist and `force` is off; otherwise
# correlate dlight_bin with entropy within each correlation group and cache
# the result.
if os.path.exists(corr_file) and not force:
    obs_corrs = pd.read_parquet(corr_file)
else:
    _grouped_cols = use_rates.groupby(correlation_keys)[corr_cols]
    obs_corrs = _grouped_cols.corr(**lagged_cfg["entropy"]["corr_kwargs"])
    # Name the innermost index level (the correlated column) "feature".
    obs_corrs.index = obs_corrs.index.set_names("feature", level=-1)
    obs_corrs.to_parquet(corr_file)

In [44]:
# Show where the correlation table is stored.
print(corr_file)

/home/markowitzmeister_gmail_com/jeff_win_share/reinforcement_data/_final_test/_data/dlight_intermediate_results/dlight_snippets_offline_features_for_examples_lag_global_entropy_correlations.parquet
