# Compute syllable sequences and entropy in an expanding window after syllable onset

1. Note that you can use a dask scheduler by setting `dask_address` to the address of the scheduler (see `../analysis_configuration.toml`)
1. Note that you must compute `syllable_stats_offline.toml` via `all_behavior_00_statemap_and_syllable_stats.ipynb` first!

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from rl_analysis.batch import apply_parallel_joblib
from rl_analysis.io.df import dlight_exclude
from rl_analysis.photometry.lagged import get_syllable_sequence_stats
from rl_analysis.util import randomize_cols
from rl_analysis.info.util import dm_entropy
from functools import partial
from joblib import Parallel, delayed
from copy import deepcopy

import tempfile
import pandas as pd
import numpy as np
import os

In [3]:
# redirect to server app logs for long-running operations...
import sys
from contextlib import redirect_stderr

terminal = sys.__stderr__

In [4]:
# note that we're using observed TRUE for downstream groupby so whatever is a categorical will get dropped
# if it's not observed
convs = {
    "timestamp": "float32",
    "snippet": "int64",
    "syllable": "category",
    "next_syllable": "uint8",
    "bin": "category",
    "time_bin": "float32",
    "dlight_bin": "category",
    "mouse_id": "category",
    "uuid": "category",
    "counts": "uint16",
}

In [5]:
force = True
# dask_address = None

In [6]:
import toml

with open("../analysis_configuration.toml", "r") as f:
    analysis_config = toml.load(f)

In [7]:
raw_dirs = analysis_config["raw_data"]
proc_dirs = analysis_config["intermediate_results"]
lagged_cfg = analysis_config["dlight_lagged_correlations"]
dlight_cfg = analysis_config["dlight_common"]

In [8]:
file_suffix = "offline" if lagged_cfg["use_offline"] else "online"
load_file = os.path.join(raw_dirs["dlight"], f"dlight_snippets_{file_suffix}.parquet")

In [9]:
file, ext = os.path.splitext(load_file)
features_save_file = f"{file}_features{ext}"

if lagged_cfg["use_renormalized"]:
    file, ext = os.path.splitext(features_save_file)
    features_save_file = f"{file}_renormalize{ext}"

rle_save_file = features_save_file.replace("snippet", "usage")

dirname, filename = os.path.split(features_save_file)
file, ext = os.path.splitext(filename)
file = os.path.join(proc_dirs["dlight"], file)
if lagged_cfg["estimate_within_bin"]:
    file = f"{file}_withinbin"

# file = f"{file}_{use_tup}"
results_file = f"{file}_lag_global_entropy.parquet"
results_tm_file = f"{file}_lag_global_entropy_tms.npy"
corr_file = f"{file}_lag_global_entropy_correlations.parquet"
shuffle_file = f"{file}_lag_global_entropy_shuffle.parquet"
ents_file = f"{file}_lag_global_entropy_ents.parquet"

In [10]:
use_features = []
for _use_win in lagged_cfg["use_windows"]:
    use_features += [f"{_}_{_use_win}" for _ in lagged_cfg["use_neural_features"]]
scalar_keys = lagged_cfg["usage_and_scalars"]["scalars"]

## Load in pre-processed data

In [11]:
feature_df = pd.read_parquet(features_save_file)
rle_df = pd.read_parquet(rle_save_file)

In [12]:
features_save_file

'/home/markowitzmeister_gmail_com/jeff_win_share/reinforcement_data/_final_test/_data/dlight_raw_data/dlight_snippets_offline_features.parquet'

In [13]:
use_area = "dls"
# we can only exclude sessions since these calculations depend on contiguity, remove specific trials later
feature_df = feature_df.loc[
    (feature_df["area"] == use_area) & (~feature_df["session_number"].isin([1, 2]))
].copy()

In [14]:
with open(
    os.path.join(proc_dirs["dlight"], "lagged_analysis_session_bins.toml"), "r"
) as f:
    use_session_bins = toml.load(f)["session_bins"]

syllable_stats = toml.load(
    os.path.join(proc_dirs["dlight"], f"syllable_stats_photometry_{file_suffix}.toml")
)
usage = syllable_stats["usages"]
mapping = {int(k): int(v) for k, v in syllable_stats["syllable_to_sorted_idx"].items()}
reverse_mapping = {
    int(k): int(v) for k, v in syllable_stats["sorted_idx_to_syllable"].items()
}

use_syllables = np.array(list(mapping.keys()))
use_syllables = use_syllables[use_syllables >= 0]

In [15]:
feature_df["syllable"] = feature_df["syllable"].map(mapping)
rle_df["syllable"] = rle_df["syllable"].map(mapping)

# Stage data for downstream computation

In [16]:
feature_df["time_bin"] = pd.cut(feature_df["timestamp"], use_session_bins, labels=False)

try:
    feature_df = feature_df.set_index("window_tup", append=True)
except KeyError:
    pass

wins = feature_df.index.get_level_values(-1).unique()
idx = pd.IndexSlice

In [17]:
dfs = []
for _idx in tqdm(wins):
    use_vals = feature_df.loc[idx[:, _idx], lagged_cfg["use_neural_features"]]
    use_vals.columns = [f"{_}_{_idx}" for _ in use_vals.columns]
    use_vals.index = use_vals.index.droplevel(-1)
    dfs.append(use_vals)

  0%|          | 0/8 [00:00<?, ?it/s]

In [18]:
meta_cols = feature_df.columns.difference(feature_df.filter(regex="dff").columns)

meta_df = feature_df[meta_cols].loc[idx[:, _idx], :]
meta_df.index = meta_df.index.droplevel(-1)

feature_df = pd.concat(dfs, axis=1).join(meta_df)
rle_df["duration"] = rle_df.groupby("uuid")["timestamp"].shift(-1) - rle_df["timestamp"]
feature_df = feature_df.reset_index()

In [19]:
rle_df["syllable_num"] = rle_df.groupby(["uuid"])["syllable"].transform(
    lambda x: np.arange(len(x))
)
feature_df = feature_df.sort_values(["timestamp", "uuid"]).dropna(subset=["timestamp"])
rle_df = rle_df.sort_values(["timestamp", "uuid"])
rle_df["timestamp"] = rle_df["timestamp"].astype("float32")

feature_df["uuid"] = feature_df["uuid"].astype("str")
feature_df = pd.merge_asof(
    feature_df, rle_df[["timestamp", "uuid", "syllable_num"]], on="timestamp", by="uuid"
)
feature_df = feature_df.sort_values(["uuid", "timestamp"])
rle_df = rle_df.sort_values(["uuid", "timestamp"])

# Now we compute values triggered on syllable instances, storing features we want to split by downstream (dLight, scalars, etc.)

In [20]:
usage_bins = np.arange(*lagged_cfg["entropy"]["bins"])
idx = pd.IndexSlice
feature_df = feature_df.reset_index(drop=True)

In [21]:
if lagged_cfg["estimate_within_bin"]:
    group_keys = ["uuid", "time_bin"]
else:
    group_keys = ["uuid"]
group_obj = feature_df.groupby(group_keys)

In [22]:
use_syllables = [int(_) for _ in use_syllables]

In [23]:
if not os.path.exists(results_file) or force:
    func = partial(
        get_syllable_sequence_stats,
        usage_bins=usage_bins,
        chk_syllables=np.arange(100),
        dlight_features=use_features,
        K=len(usage),
        truncate=syllable_stats["truncate"],
    )
    with redirect_stderr(terminal):
        print(f"{group_obj.ngroups} Total groups to process", file=terminal)
        syllable_rates = apply_parallel_joblib(
            group_obj, func, n_jobs=-1, backend="loky"
        )
    syllable_rates = syllable_rates.reset_index()

    for k, v in tqdm(convs.items()):
        try:
            syllable_rates[k] = syllable_rates[k].astype(v)
        except KeyError:
            pass

    save_rates = syllable_rates[syllable_rates.columns.difference(["tm"])]
    tm_list = syllable_rates["tm"].to_list()

    syllable_rates["total_duration"] = (
        syllable_rates["count"] * syllable_rates["duration"]
    )
    save_rates.to_parquet(results_file)
    np.save(results_tm_file, tm_list)
else:
    save_rates = pd.read_parquet(results_file)
    tm_mat = np.load(results_tm_file)
    tm_list = list(tm_mat)
save_rates["tm"] = tm_list

2500 Total groups to process
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 128 concurrent workers.
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:    4.9s
[Parallel(n_jobs=-1)]: Done  32 tasks      | elapsed:    5.6s
[Parallel(n_jobs=-1)]: Done  57 tasks      | elapsed:    6.5s
[Parallel(n_jobs=-1)]: Done  82 tasks      | elapsed:    7.3s
[Parallel(n_jobs=-1)]: Done 109 tasks      | elapsed:    8.3s
[Parallel(n_jobs=-1)]: Done 136 tasks      | elapsed:    9.3s
[Parallel(n_jobs=-1)]: Done 165 tasks      | elapsed:   11.2s
[Parallel(n_jobs=-1)]: Done 194 tasks      | elapsed:   12.6s
[Parallel(n_jobs=-1)]: Done 225 tasks      | elapsed:   13.9s
[Parallel(n_jobs=-1)]: Done 256 tasks      | elapsed:   15.1s
[Parallel(n_jobs=-1)]: Done 289 tasks      | elapsed:   16.5s
[Parallel(n_jobs=-1)]: Done 322 tasks      | elapsed:   17.6s
[Parallel(n_jobs=-1)]: Done 357 tasks      | elapsed:   18.7s
[Parallel(n_jobs=-1)]: Done 392 tasks      | elapsed:   20.1s
[Parallel(n_jobs=-1)]: Do

  0%|          | 0/10 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  save_rates["tm"] = tm_list


# Compute correlations

In [24]:
from copy import deepcopy

In [25]:
use_syllable_rates = save_rates.copy()
use_syllable_rates = use_syllable_rates.loc[use_syllable_rates["area"] == "dls"].copy()
use_syllable_rates = dlight_exclude(
    use_syllable_rates, exclude_3s=False, syllable_key="syllable", **dlight_cfg
)

In [26]:
idx = pd.IndexSlice

In [27]:
# number of quantile bins for dlight data
ndlight_bins = lagged_cfg["entropy"]["ndlight_bins"]
ndlight_bins = 20
dlight_bin_keys = ["mouse_id"]

pre_agg_keys = deepcopy(lagged_cfg["entropy"]["pre_agg_keys"])
agg_keys = deepcopy(lagged_cfg["entropy"]["agg_keys"])
correlation_keys = deepcopy(lagged_cfg["entropy"]["correlation_keys"])

if lagged_cfg["estimate_within_bin"]:
    pre_agg_keys = ["time_bin"] + pre_agg_keys
    agg_keys = ["time_bin"] + agg_keys
    correlation_keys = ["time_bin"] + correlation_keys

# what features are we correlating? (dlight_bin, entropy and number of transitions in bin)
corr_cols = ["dlight_bin", "entropy"]

In [28]:
correlation_keys

['time_bin', 'dlight_bin_feature', 'bin']

In [29]:
def get_bins(x, ndlight_bins):
    try:
        results = pd.qcut(x, ndlight_bins, labels=False, duplicates="drop")
    except IndexError:
        results = np.ones((len(x),))
        results[:] = np.nan
        results = pd.Series(data=results, index=x.index)
    return results

In [30]:
cat_rates = []
for _feature in tqdm(use_features):
    _tmp_rates = use_syllable_rates.copy()
    _tmp_rates["dlight_bin"] = (
        _tmp_rates.groupby(dlight_bin_keys)[_feature].transform(
            lambda x: get_bins(x, ndlight_bins)
        )
        # .transform(lambda x: pd.qcut(x, ndlight_bins, labels=False, duplicates="drop"))
    ).astype("category")
    _tmp_rates["dlight_bin_feature"] = _feature
    cat_rates.append(_tmp_rates.dropna(subset=["dlight_bin"]))

  0%|          | 0/1 [00:00<?, ?it/s]

In [31]:
use_syllable_rates = pd.concat(cat_rates, axis=0).reset_index(drop=True)

In [32]:
# aggregation one
agg_rates = (
    use_syllable_rates.groupby(pre_agg_keys, observed=True)["tm"].sum().sort_index()
)
agg_rates = agg_rates.reset_index()

In [33]:
# need to reconvert data types to conserve memory
for k, v in tqdm(convs.items()):
    try:
        agg_rates[k] = agg_rates[k].astype(v)
    except KeyError:
        pass

  0%|          | 0/10 [00:00<?, ?it/s]

In [34]:
# aggregation 2
ave_rates = agg_rates.groupby(agg_keys, observed=True)["tm"].sum().reset_index()

In [35]:
# count transitions in each bin and threshold
totals = ave_rates.groupby(agg_keys)["tm"].apply(lambda x: x.sum().sum())
use_syllables = totals[totals > 100].index

In [36]:
# ensure we have all bins
use_rates = ave_rates.reset_index().set_index(agg_keys).loc[use_syllables]
chk_keys = list(set(agg_keys[:-1]) - set(["dlight_bin"]))

In [37]:
truncate = syllable_stats["truncate"]

In [38]:
ents = use_rates["tm"].apply(
    lambda x: dm_entropy(
        x[:truncate, :truncate], alpha="perks", marginalize=False, axis=1
    )
)
ents.to_frame().to_parquet(os.path.join(proc_dirs["dlight"], ents_file))

In [39]:
use_rates["entropy"] = ents
use_rates = use_rates.reset_index()
use_rates["dlight_bin"] = use_rates["dlight_bin"].astype("float")
use_rates["bin"] = use_rates["bin"].astype("uint16")

In [40]:
use_rates["dlight_bin_feature"].unique()

array(['signal_reref_dff_z_max_(0.0, 0.3)'], dtype=object)

In [41]:
size_stats = {}

In [42]:
correlation_sz = use_rates.groupby(correlation_keys)[corr_cols].size()
size_stats["n_per_group"] = correlation_sz.describe().to_dict()
size_stats["n_total_mouse_time_bin_lag_points"] = len(
    use_rates.loc[use_rates["dlight_bin_feature"] == use_features[0]]
)
size_stats["n_mice"] = use_rates["mouse_id"].nunique()
size_stats["n_bins"] = use_rates["time_bin"].nunique()
size_stats["n_dlight_bins"] = use_rates["dlight_bin"].nunique()

with open(os.path.join(proc_dirs["dlight"], "stats_lagged_entropy_n.toml"), "w") as f:
    toml.dump(size_stats, f)

In [43]:
if not os.path.exists(corr_file) or force:
    obs_corrs = use_rates.groupby(correlation_keys)[corr_cols].corr(
        **lagged_cfg["entropy"]["corr_kwargs"]
    )
    obs_corrs.index = obs_corrs.index.set_names("feature", level=-1)
    obs_corrs.to_parquet(corr_file)
else:
    obs_corrs = pd.read_parquet(corr_file)

In [44]:
def shuffler_corr(
    df: pd.DataFrame,
    tm_mat: np.ndarray,
    idx: int,
    shuffle_col: str = "counts",
    shuffle_by: list[str] = pre_agg_keys[-1],
    group_by: list[str] = correlation_keys,
    pre_agg_group_by: list[str] = agg_keys,
    compute_cols: list[str] = corr_cols,
    corr_kwargs: dict = lagged_cfg["entropy"]["corr_kwargs"],
    truncate: int = truncate,
) -> pd.DataFrame:

    use_df = df.copy()
    use_tms = deepcopy(tm_mat)
    _ = [randomize_cols(_tm) for _tm in use_tms]
    use_df["tm"] = list(use_tms)

    use_agg_df = use_df.groupby(pre_agg_group_by, observed=True)["tm"].sum()

    use_agg_df = (
        use_agg_df.reset_index().set_index(agg_keys).loc[use_syllables].reset_index()
    )
    ents = (
        use_agg_df["tm"].apply(
            lambda x: (
                dm_entropy(
                    x[:truncate, :truncate], alpha="perks", marginalize=False, axis=1
                )
            )
        )
    ).rename("entropy")
    use_agg_df["entropy"] = ents
    use_agg_df["dlight_bin"] = use_rates["dlight_bin"].astype("float")
    use_agg_df["bin"] = use_rates["bin"].astype("uint16")

    shuffle_df = use_agg_df.groupby(group_by)[compute_cols].corr(
        **lagged_cfg["entropy"]["corr_kwargs"]
    )
    shuffle_df.index = shuffle_df.index.set_names("feature", level=-1)
    shuffle_df["idx"] = idx
    return shuffle_df

In [45]:
if not os.path.exists(shuffle_file) or force:
    copy_rates = agg_rates[agg_rates.columns.difference(["tm"])].copy()
    tm_list = agg_rates["tm"].to_list()
    tm_mat = np.stack(tm_list)
    tm_mat = tm_mat[:, :truncate, :truncate].astype("uint8")

    if not np.any((tm_mat > 256)):
        tm_mat = tm_mat.astype("uint8")

    delays = [
        delayed(shuffler_corr)(copy_rates, tm_mat, sidx)
        for sidx in range(lagged_cfg["nshuffles"])
    ]

    with redirect_stderr(terminal):
        print(f"{len(delays)} jobs", file=terminal)
        ret_list = Parallel(n_jobs=-20, verbose=10, temp_folder=tempfile.gettempdir())(delays)
    shuffle_df = pd.concat(ret_list, axis=0)
    shuffle_df.to_parquet(shuffle_file)

1000 jobs
[Parallel(n_jobs=-20)]: Using backend LokyBackend with 109 concurrent workers.
[Parallel(n_jobs=-20)]: Done   3 tasks      | elapsed:   27.0s
[Parallel(n_jobs=-20)]: Done  24 tasks      | elapsed:   35.6s
[Parallel(n_jobs=-20)]: Done  47 tasks      | elapsed:   47.0s
[Parallel(n_jobs=-20)]: Done  70 tasks      | elapsed:   59.9s
[Parallel(n_jobs=-20)]: Done  95 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-20)]: Done 120 tasks      | elapsed:  1.4min
[Parallel(n_jobs=-20)]: Done 147 tasks      | elapsed:  1.6min
[Parallel(n_jobs=-20)]: Done 174 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-20)]: Done 203 tasks      | elapsed:  2.0min
[Parallel(n_jobs=-20)]: Done 232 tasks      | elapsed:  2.2min
[Parallel(n_jobs=-20)]: Done 263 tasks      | elapsed:  2.4min
[Parallel(n_jobs=-20)]: Done 294 tasks      | elapsed:  2.7min
[Parallel(n_jobs=-20)]: Done 327 tasks      | elapsed:  2.9min
[Parallel(n_jobs=-20)]: Done 360 tasks      | elapsed:  3.2min
[Parallel(n_jobs=-20)]: Done 

In [46]:
print(shuffle_file)

/home/markowitzmeister_gmail_com/jeff_win_share/reinforcement_data/_final_test/_data/dlight_intermediate_results/dlight_snippets_offline_features_withinbin_lag_global_entropy_shuffle.parquet
