# Compute syllable sequences and entropy in an expanding window after syllable onset (for fig. 4)

1. Note that you can use a dask scheduler by setting `dask_address` to the address of the scheduler (see `../analysis_configuration.toml`)
1. Note that you must compute `syllable_stats_offline.toml` via `all_behavior_00_statemap_and_syllable_stats.ipynb` first!

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from rl_analysis.batch import (
    apply_parallel_joblib,
)
from rl_analysis.photometry.lagged import get_syllable_sequence_stats
from rl_analysis.io.df import dlight_exclude
from rl_analysis.info.util import dm_entropy
from rl_analysis.util import randomize_cols
from functools import partial

import tempfile
import pandas as pd
import numpy as np
import os

In [3]:
# note that we're using observed TRUE for downstream groupby so whatever is a categorical will get dropped
# if it's not observed
convs = {
    "timestamp": "float32",
    "snippet": "int64",
    "syllable": "category",
    "next_syllable": "uint8",
    "bin": "category",
    "time_bin": "float32",
    "dlight_bin": "category",
    "mouse_id": "category",
    "uuid": "category",
    "counts": "uint16",
}

In [4]:
force = True
# dask_address = None

In [5]:
import toml

with open("../analysis_configuration.toml", "r") as f:
    analysis_config = toml.load(f)

In [6]:
raw_dirs = analysis_config["raw_data"]
proc_dirs = analysis_config["intermediate_results"]
lagged_cfg = analysis_config["dlight_lagged_correlations"]
dlight_cfg = analysis_config["dlight_common"]
dask_address = analysis_config["dask"].get("address")

In [7]:
file_suffix = "offline" if lagged_cfg["use_offline"] else "online"
load_file = os.path.join(raw_dirs["dlight"], f"dlight_snippets_{file_suffix}.parquet")

In [9]:
file, ext = os.path.splitext(load_file)
features_save_file = f"{file}_features{ext}"

if lagged_cfg["use_renormalized"]:
    file, ext = os.path.splitext(features_save_file)
    features_save_file = f"{file}_renormalize{ext}"

rle_save_file = features_save_file.replace("snippet", "usage")

dirname, filename = os.path.split(features_save_file)
file, ext = os.path.splitext(filename)
file = os.path.join(proc_dirs["dlight"], file)

if lagged_cfg["estimate_within_bin"]:
    file = f"{file}_withinbin"

results_file = f"{file}_lag_global_entropy_for_index.parquet"
results_tm_file = f"{file}_lag_global_entropy_tms_for_index.npy"
corr_file = f"{file}_lag_global_entropy_correlations_for_index_per_mouse.parquet"
shuffle_file = f"{file}_lag_global_entropy_shuffle_for_index_per_mouse.parquet"
ents_file = f"{file}_lag_global_entropy_ents_for_index.parquet"

In [10]:
use_features = []
for _use_win in lagged_cfg["use_windows"]:
    use_features += [f"{_}_{_use_win}" for _ in lagged_cfg["use_neural_features"]]
scalar_keys = lagged_cfg["usage_and_scalars"]["scalars"]

## Load in pre-processed data

In [11]:
feature_df = pd.read_parquet(features_save_file)
rle_df = pd.read_parquet(rle_save_file)

In [12]:
use_area = "dls"
# we can only exclude sessions since these calculations depend on contiguity, remove specific trials later
feature_df = feature_df.loc[
    (feature_df["area"] == use_area) & (~feature_df["session_number"].isin([1, 2]))
].copy()

In [13]:
with open(
    os.path.join(proc_dirs["dlight"], "lagged_analysis_session_bins.toml"), "r"
) as f:
    use_session_bins = toml.load(f)["session_bins"]

syllable_stats = toml.load(
    os.path.join(proc_dirs["dlight"], f"syllable_stats_photometry_{file_suffix}.toml")
)
usage = syllable_stats["usages"]
mapping = {int(k): int(v) for k, v in syllable_stats["syllable_to_sorted_idx"].items()}
reverse_mapping = {
    int(k): int(v) for k, v in syllable_stats["sorted_idx_to_syllable"].items()
}

use_syllables = np.array(list(mapping.keys()))
use_syllables = use_syllables[use_syllables >= 0]

In [14]:
feature_df["syllable"] = feature_df["syllable"].map(mapping)
rle_df["syllable"] = rle_df["syllable"].map(mapping)

# Stage data for downstream computation

In [15]:
feature_df["time_bin"] = pd.cut(feature_df["timestamp"], use_session_bins, labels=False)

In [16]:
try:
    feature_df = feature_df.set_index("window_tup", append=True)
except KeyError:
    pass

In [17]:
wins = feature_df.index.get_level_values(-1).unique()

In [18]:
idx = pd.IndexSlice

In [19]:
dfs = []
for _idx in tqdm(wins):
    use_vals = feature_df.loc[idx[:, _idx], lagged_cfg["use_neural_features"]]
    use_vals.columns = [f"{_}_{_idx}" for _ in use_vals.columns]
    use_vals.index = use_vals.index.droplevel(-1)
    dfs.append(use_vals)

  0%|          | 0/8 [00:00<?, ?it/s]

In [20]:
meta_cols = feature_df.columns.difference(feature_df.filter(regex="dff").columns)

In [21]:
meta_df = feature_df[meta_cols].loc[idx[:, _idx], :]
meta_df.index = meta_df.index.droplevel(-1)

In [22]:
feature_df = pd.concat(dfs, axis=1).join(meta_df)

In [23]:
rle_df["duration"] = rle_df.groupby("uuid")["timestamp"].shift(-1) - rle_df["timestamp"]

In [24]:
feature_df = feature_df.reset_index()

In [25]:
rle_df["syllable_num"] = rle_df.groupby(["uuid"])["syllable"].transform(
    lambda x: np.arange(len(x))
)
feature_df = feature_df.sort_values(["timestamp", "uuid"]).dropna(subset=["timestamp"])
rle_df = rle_df.sort_values(["timestamp", "uuid"])
rle_df["timestamp"] = rle_df["timestamp"].astype("float32")

feature_df["uuid"] = feature_df["uuid"].astype("str")
feature_df = pd.merge_asof(
    feature_df, rle_df[["timestamp", "uuid", "syllable_num"]], on="timestamp", by="uuid"
)
feature_df = feature_df.sort_values(["uuid", "timestamp"])
rle_df = rle_df.sort_values(["uuid", "timestamp"])

# Now we compute values triggered on syllable instances, storing features we want to split by downstream (dLight, scalars, etc.)

In [26]:
usage_bins = np.arange(5, 61, 5)

In [27]:
truncate = syllable_stats["truncate"]

In [28]:
idx = pd.IndexSlice

In [29]:
if lagged_cfg["estimate_within_bin"]:
    group_keys = ["uuid", "time_bin"]
else:
    group_keys = ["uuid"]
group_obj = feature_df.groupby(group_keys)

In [30]:
use_syllables = [int(_) for _ in use_syllables]

In [31]:
if not os.path.exists(results_file) or force:
    func = partial(
        get_syllable_sequence_stats,
        chk_syllables=np.arange(100),
        truncate=syllable_stats["truncate"],
        usage_bins=usage_bins,
        dlight_features=use_features,
        K=len(usage),
    )
    print(group_obj.ngroups)
    syllable_rates = apply_parallel_joblib(group_obj, func, n_jobs=-10, backend="loky")
    syllable_rates = syllable_rates.reset_index()

    for k, v in tqdm(convs.items()):
        try:
            syllable_rates[k] = syllable_rates[k].astype(v)
        except KeyError:
            pass

    save_rates = syllable_rates[syllable_rates.columns.difference(["tm"])]
    tm_list = syllable_rates["tm"].to_list()

    syllable_rates["total_duration"] = (
        syllable_rates["count"] * syllable_rates["duration"]
    )
    save_rates.to_parquet(results_file)
    np.save(results_tm_file, tm_list)
else:
    save_rates = pd.read_parquet(results_file)
    # tm_mat = np.load(results_tm_file)
    tm_mat = np.load(results_tm_file)[:, :truncate, :truncate]
    tm_list = list(tm_mat)
save_rates["tm"] = tm_list

2500


[Parallel(n_jobs=-10)]: Using backend LokyBackend with 119 concurrent workers.
[Parallel(n_jobs=-10)]: Done   4 tasks      | elapsed:    4.4s
[Parallel(n_jobs=-10)]: Done  27 tasks      | elapsed:    5.1s
[Parallel(n_jobs=-10)]: Done  50 tasks      | elapsed:    6.0s
[Parallel(n_jobs=-10)]: Done  75 tasks      | elapsed:    6.9s
[Parallel(n_jobs=-10)]: Done 100 tasks      | elapsed:    7.9s
[Parallel(n_jobs=-10)]: Done 127 tasks      | elapsed:    9.1s
[Parallel(n_jobs=-10)]: Done 154 tasks      | elapsed:   10.6s
[Parallel(n_jobs=-10)]: Done 183 tasks      | elapsed:   12.6s
[Parallel(n_jobs=-10)]: Done 212 tasks      | elapsed:   14.0s
[Parallel(n_jobs=-10)]: Done 243 tasks      | elapsed:   15.6s
[Parallel(n_jobs=-10)]: Done 274 tasks      | elapsed:   17.0s
[Parallel(n_jobs=-10)]: Done 307 tasks      | elapsed:   18.4s
[Parallel(n_jobs=-10)]: Done 340 tasks      | elapsed:   19.8s
[Parallel(n_jobs=-10)]: Done 375 tasks      | elapsed:   21.3s
[Parallel(n_jobs=-10)]: Done 410 tasks 

  0%|          | 0/10 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  save_rates["tm"] = tm_list


# Compute correlations

In [32]:
from copy import deepcopy
import gc

In [33]:
del feature_df, rle_df
gc.collect()

0

In [35]:
use_syllable_rates = save_rates.copy()
use_syllable_rates = use_syllable_rates.loc[use_syllable_rates["area"] == "dls"].copy()
use_syllable_rates = dlight_exclude(
    use_syllable_rates, exclude_3s=False, syllable_key="syllable", **dlight_cfg
)

In [36]:
idx = pd.IndexSlice


In [37]:
ndlight_bins = 20
bin_column = use_features[0]
dlight_bin_keys = ["mouse_id", "syllable"]
bin_data = True

pre_agg_keys = deepcopy(lagged_cfg["entropy"]["pre_agg_keys"]) + ["syllable"]
agg_keys = deepcopy(lagged_cfg["entropy"]["agg_keys"]) + ["syllable"]
correlation_keys = deepcopy(lagged_cfg["entropy"]["correlation_keys"])
correlation_keys = ["dlight_bin_feature", "bin", "mouse_id", "syllable"]

if lagged_cfg["estimate_within_bin"]:
    pre_agg_keys = ["time_bin"] + pre_agg_keys
    agg_keys = ["time_bin"] + agg_keys
    correlation_keys = ["time_bin"] + correlation_keys

# what features are we correlating? (dlight_bin, entropy and number of transitions in bin)
corr_cols = ["dlight_bin", "entropy"]

In [38]:
def get_bins(x, ndlight_bins):
    try:
        results = pd.qcut(x, ndlight_bins, labels=False, duplicates="drop")
    except IndexError:
        results = np.ones((len(x),))
        results[:] = np.nan
        results = pd.Series(data=results, index=x.index)
    return results

In [39]:
# dlight binning
# NOTE PREVIOUSLY BINNED BY UUID
if bin_data:
    cat_rates = []
    for _feature in tqdm(use_features):
        _tmp_rates = use_syllable_rates.copy()
        if dlight_bin_keys is not None:
            _tmp_rates["dlight_bin"] = (
                _tmp_rates.groupby(dlight_bin_keys)[_feature].transform(
                    lambda x: get_bins(x, ndlight_bins)
                )
                # .transform(lambda x: pd.qcut(x, ndlight_bins, labels=False, duplicates="drop"))
            ).astype("category")
        else:
            _tmp_rates["dlight_bin"] = get_bins(_tmp_rates[_feature], ndlight_bins)
        _tmp_rates["dlight_bin_feature"] = _feature
        cat_rates.append(_tmp_rates.dropna(subset=["dlight_bin"]))
        use_syllable_rates = pd.concat(cat_rates, axis=0).reset_index(drop=True)
else:
    pass

  0%|          | 0/1 [00:00<?, ?it/s]

In [40]:
use_syllable_rates["syllable"] = use_syllable_rates["syllable"].astype("int")

# aggregation one
agg_rates = (
    use_syllable_rates.groupby(pre_agg_keys, observed=True)["tm"].sum().sort_index()
)
agg_rates = agg_rates.reset_index()

In [41]:
# need to reconvert data types to conserve memory
for k, v in tqdm(convs.items()):
    try:
        agg_rates[k] = agg_rates[k].astype(v)
    except KeyError:
        pass

  0%|          | 0/10 [00:00<?, ?it/s]

In [42]:
# aggregation 2
ave_rates = agg_rates.groupby(agg_keys, observed=True)["tm"].sum().dropna()

In [43]:
tqdm.pandas()

In [44]:
totals = ave_rates.progress_apply(lambda x: x.sum())

  0%|          | 0/770325 [00:00<?, ?it/s]

In [45]:
# count transitions in each bin and threshold
totals = ave_rates.groupby(agg_keys, observed=True).apply(lambda x: x.sum().sum())
use_syllables = totals[totals > 100].index

In [46]:
# ensure we have all bins
use_rates = ave_rates.loc[use_syllables]
chk_keys = list(set(agg_keys[:-1]) - set(["dlight_bin"]))

In [47]:
ents = use_rates.progress_apply(
    lambda x: dm_entropy(
        x[:truncate, :truncate], alpha="perks", marginalize=False, axis=1
    )
)
ents.to_frame().to_parquet(os.path.join(proc_dirs["dlight"], ents_file))

  0%|          | 0/494817 [00:00<?, ?it/s]

In [48]:
use_rates = pd.concat([use_rates, ents.rename("entropy")], axis=1)

if bin_data:
    use_rates = use_rates.reset_index()
    use_rates["dlight_bin"] = use_rates["dlight_bin"].astype("float")
    use_rates["bin"] = use_rates["bin"].astype("uint16")
else:
    neural_ave = save_rates.groupby(ave_rates.index.names)[use_features].mean()
    neural_ave = neural_ave.groupby(["bin"]).transform(
        lambda x: (x - x.mean()) / x.std()
    )
    use_rates = pd.concat([use_rates, neural_ave], axis=1).dropna()

In [50]:
%%time
if not os.path.exists(corr_file) or force:
    obs_corrs = use_rates.groupby(correlation_keys, observed=True)[corr_cols].corr(
        **lagged_cfg["entropy"]["corr_kwargs"]
    )
    obs_corrs.index = obs_corrs.index.set_names("feature", level=-1)
    obs_corrs.to_parquet(corr_file)
else:
    obs_corrs = pd.read_parquet(corr_file)

CPU times: user 8.01 s, sys: 272 ms, total: 8.28 s
Wall time: 8.29 s


In [51]:
use_correlation_keys = ["time_bin", "dlight_bin_feature", "bin", "mouse_id"]
# use_correlation_keys = correlation_keys
use_agg_keys = use_correlation_keys + ["dlight_bin"]

In [52]:
from copy import deepcopy


def shuffler_corr(
    df,
    # tm_mat,
    tm_list,
    # col_marginals,
    idx,
    group_by=use_correlation_keys,
    pre_agg_group_by=use_agg_keys,
    pre_agg_func="sum",
    compute_cols=corr_cols,
    corr_kwargs=lagged_cfg["entropy"]["corr_kwargs"],
    truncate=truncate,
):

    use_df = df.copy()
    use_tms = deepcopy(tm_list)
    _ = [randomize_cols(_tm) for _tm in use_tms]

    use_df["tm"] = list(use_tms)

    use_agg_df = use_df.groupby(pre_agg_group_by, observed=True)["tm"].sum()
    use_agg_df = use_agg_df.reset_index()

    ents = (
        use_agg_df["tm"].apply(
            lambda x: (
                dm_entropy(
                    x[:truncate, :truncate], alpha="perks", marginalize=False, axis=1
                )
            )
        )
    ).rename("entropy")
    use_agg_df["entropy"] = ents
    use_agg_df["dlight_bin"] = use_agg_df["dlight_bin"].astype("float")
    use_agg_df["bin"] = use_agg_df["bin"].astype("uint16")

    shuffle_df = use_agg_df.groupby(group_by, observed=True)[compute_cols].corr(
        **lagged_cfg["entropy"]["corr_kwargs"]
    )
    shuffle_df.index = shuffle_df.index.set_names("feature", level=-1)
    shuffle_df["idx"] = idx
    return shuffle_df

In [53]:
features = agg_rates["dlight_bin_feature"].unique()

In [54]:
nshuffles = 100

In [55]:
from joblib import Parallel, delayed

In [None]:
%%time
if not os.path.exists(shuffle_file) or force:
    delays = []
    for _feature in features:
        proc_rates = agg_rates.loc[agg_rates["dlight_bin_feature"] == _feature].copy()
        proc_rates = proc_rates.set_index(use_syllables.names)
        proc_rates = proc_rates.loc[
            use_syllables.intersection(proc_rates.index)
        ].reset_index()
        copy_rates = proc_rates[proc_rates.columns.difference(["tm"])].copy()
        tm_list = np.stack(proc_rates["tm"])

        if not np.any((tm_list > 256)):
            tm_list = tm_list.astype("uint8")

        delays += [
            delayed(shuffler_corr)(copy_rates, tm_list, sidx)
            for sidx in range(nshuffles)
        ]

    print(len(delays))
    ret_list = Parallel(n_jobs=45, verbose=10, temp_folder=tempfile.gettempdir())(delays)
    shuffle_df = pd.concat(ret_list, axis=0)
    shuffle_df.to_parquet(shuffle_file)

100


[Parallel(n_jobs=45)]: Using backend LokyBackend with 45 concurrent workers.
[Parallel(n_jobs=45)]: Done   8 tasks      | elapsed:  8.4min
[Parallel(n_jobs=45)]: Done  22 out of 100 | elapsed:  8.6min remaining: 30.5min
[Parallel(n_jobs=45)]: Done  33 out of 100 | elapsed:  8.7min remaining: 17.7min
[Parallel(n_jobs=45)]: Done  44 out of 100 | elapsed:  8.9min remaining: 11.3min
[Parallel(n_jobs=45)]: Done  55 out of 100 | elapsed: 16.2min remaining: 13.3min
[Parallel(n_jobs=45)]: Done  66 out of 100 | elapsed: 16.4min remaining:  8.5min
[Parallel(n_jobs=45)]: Done  77 out of 100 | elapsed: 16.6min remaining:  5.0min
[Parallel(n_jobs=45)]: Done  88 out of 100 | elapsed: 16.8min remaining:  2.3min


In [58]:
print(shuffle_file)

/home/markowitzmeister_gmail_com/jeff_win_share/reinforcement_data/_final_test/_data/dlight_intermediate_results/dlight_snippets_offline_features_withinbin_lag_global_entropy_shuffle_for_index_per_mouse.parquet
