# Here we're computing features for the "syllable-time" or "kernel regression" model

The most important features here are:

1. counts
2. entropy
3. scalars

Each syllable statistic is computed for the current syllable over a series of bin sizes, e.g. count_5 represents how often the current syllable at time T is expressed from T to T + 5.  Same for entropy.  

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from rl_analysis.util import rle, count_transitions, pd_zscore
from rl_analysis.batch import apply_parallel_joblib
from rl_analysis.info.util import dm_entropy
from rl_analysis.io.df import dlight_exclude_toml
from rl_analysis.photometry.encoding.features import get_counts_persample, get_entropy_persample, split_array
from functools import partial

import pandas as pd
import numpy as np
import os

In [3]:
import toml

with open("../analysis_configuration.toml", "r") as f:
    analysis_config = toml.load(f)

In [4]:
raw_dirs = analysis_config["raw_data"]
proc_dirs = analysis_config["intermediate_results"]
encoding_cfg = analysis_config["dlight_encoding_features"]
dlight_cfg = analysis_config["dlight_common"]

# load in the dlight data, preprocessing

In [5]:
partition_cols = ["area", "mouse_id", "uuid"]

In [6]:
use_dct = dlight_exclude_toml(
    os.path.join(raw_dirs["dlight"], "dlight_photometry_processed_full.toml"),
    **analysis_config["dlight_common"]
)

In [7]:
dlight_df = pd.read_parquet(
    os.path.join(raw_dirs["dlight"], "dlight_photometry_processed_full.parquet"),
    filters=[("uuid", "in", list(use_dct.keys()))],
).sort_index()

In [8]:
signal_keys = dlight_df.filter(regex="(signal|reference|reref)_dff$").columns.tolist()
dlight_df["labels"] = (
    dlight_df[encoding_cfg["label_key"]].replace(-5, np.nan).astype("UInt8")
)
dlight_df = dlight_df.dropna(subset=["labels"])

In [9]:
dlight_df = dlight_df[~dlight_df["session_number"].isin([1, 2, 3, 4])].copy()

In [10]:
dlight_df[partition_cols] = dlight_df[partition_cols].astype("str")
dlight_df["syllable_number"] = dlight_df.groupby("uuid", group_keys=False)["labels"].transform(
    lambda x: (x.astype("float").diff() != 0).cumsum()
)

In [11]:
rle_df = dlight_df.groupby("uuid")["labels"].apply(rle)
rle_df.index = rle_df.index.droplevel(0)
rle_df = dlight_df.loc[rle_df.index]

## Get counts...

In [12]:
count_df = (
    rle_df.groupby("uuid")["labels"].value_counts(normalize=False).rename("count")
)
count_df_z = (
    count_df.groupby("uuid")
    .transform(lambda x: (x - x.mean()) / x.std())
    .rename("count_z")
)

In [13]:
dlight_df.drop(
    ["count", "count_z", "count_ave", "count_ave_z"],
    axis=1,
    errors="ignore",
    inplace=True,
)

In [14]:
if dlight_df.index.names != count_df.index.names:
    try:
        dlight_df = dlight_df.set_index(count_df.index.names)
    except KeyError:
        dlight_df = dlight_df.reset_index().set_index(count_df.index.names)
dlight_df = dlight_df.join(count_df, how="left").join(count_df_z, how="left")

In [15]:
count_ave_df = (
    rle_df.groupby("mouse_id")["labels"]
    .value_counts(normalize=True)
    .rename("count_ave")
)
count_ave_df_z = (
    count_ave_df.groupby("mouse_id")
    .transform(lambda x: (x - x.mean()) / x.std())
    .rename("count_ave_z")
)

In [16]:
if dlight_df.index.names != count_ave_df.index.names:
    try:
        dlight_df = dlight_df.set_index(count_ave_df.index.names)
    except KeyError:
        dlight_df = dlight_df.reset_index().set_index(count_ave_df.index.names)
dlight_df = dlight_df.join(count_ave_df, how="left").join(count_ave_df_z, how="left")
# dlight_df = pd.concat([dlight_df, count_ave_df, count_ave_df_z], axis=1)

## Get entropy (and maybe transition rank)

In [17]:
if "offline" in encoding_cfg["label_key"]:
    syllable_stats = toml.load(
        os.path.join(proc_dirs["dlight"], "syllable_stats_photometry_offline.toml")
    )
else:
    syllable_stats = toml.load(
        os.path.join(proc_dirs["dlight"], "syllable_stats_photometry_online.toml")
    )
usage = syllable_stats["usages"]
mapping = {int(k): int(v) for k, v in syllable_stats["syllable_to_sorted_idx"].items()}
reverse_mapping = {
    int(k): int(v) for k, v in syllable_stats["sorted_idx_to_syllable"].items()
}

In [18]:
# entropy is computed with usage sorted syllable IDs, we then map back to original IDs
# after the calculation
rle_df["labels_sorted"] = rle_df["labels"].map(mapping)

tms = rle_df.groupby(["mouse_id", "uuid"])["labels_sorted"].apply(
    lambda x: count_transitions(x, K=100)
)
truncate = syllable_stats["truncate"]

def ent_func(x: np.ndarray) -> float:
    return dm_entropy(x[:truncate], alpha="perks", marginalize=False, axis=None)

In [19]:
tm_df_rows = tms.apply(lambda x: split_array(x, mapping=reverse_mapping)).stack()
tm_df_rows = tm_df_rows.apply(lambda x: np.array(x))

In [20]:
agg_df_rows = tm_df_rows.groupby(["uuid", "labels"]).sum()
entropy_df = agg_df_rows.apply(ent_func).rename("entropy_out")
entropy_df_z = (
    entropy_df.groupby("uuid")
    .transform(lambda x: (x - x.mean()) / x.std())
    .rename("entropy_out_z")
)

In [21]:
agg_ave_df_rows = tm_df_rows.groupby(["mouse_id", "labels"]).sum()
entropy_ave_df = agg_ave_df_rows.apply(ent_func).rename("entropy_out_ave")
entropy_ave_df_z = (
    entropy_ave_df.groupby("mouse_id")
    .transform(lambda x: (x - x.mean()) / x.std())
    .rename("entropy_out_ave_z")
)

In [22]:
rle_df["labels_sorted_next"] = rle_df.groupby("uuid")["labels_sorted"].shift(-1)

In [23]:
dlight_df.drop(
    ["entropy_out", "entropy_out_z", "entropy_out_ave", "entropy_out_ave_z"],
    axis=1,
    inplace=True,
    errors="ignore",
)

In [24]:
if dlight_df.index.names != entropy_df.index.names:
    try:
        dlight_df = dlight_df.set_index(entropy_df.index.names)
    except KeyError:
        dlight_df = dlight_df.reset_index().set_index(entropy_df.index.names)
dlight_df = dlight_df.join(entropy_df, how="left").join(entropy_df_z, how="left")
# dlight_df = pd.concat([dlight_df, entropy_df, entropy_df_z], axis=1)

In [25]:
if dlight_df.index.names != entropy_ave_df.index.names:
    try:
        dlight_df = dlight_df.set_index(entropy_ave_df.index.names)
    except KeyError:
        dlight_df = dlight_df.reset_index().set_index(entropy_ave_df.index.names)
dlight_df = dlight_df.join(entropy_ave_df, how="left").join(
    entropy_ave_df_z, how="left"
)

In [26]:
def get_trans_prob(x, key1="labels_sorted", key2="labels_sorted_next"):

    index = x.index
    vals1 = x[key1].values
    vals2 = x[key2].values

    nans = np.isnan(vals1) | np.isnan(vals2)
    vals1 = vals1[~nans].astype("int")
    vals2 = vals2[~nans].astype("int")
    index = index[~nans]
    # get that tm

    use_tm = tms.loc[x.name].copy()
    use_tm = use_tm / use_tm.sum(axis=1, keepdims=True)
    ps = use_tm[vals1, vals2]
    new_series = pd.Series(data=ps, index=index)
    return new_series

In [27]:
trans_p_df = (
    rle_df.set_index("syllable_number")
    .groupby(["mouse_id", "uuid"])
    .apply(get_trans_prob)
    .rename("p_out")
)

  use_tm = use_tm / use_tm.sum(axis=1, keepdims=True)


In [28]:
dlight_df.drop("p_out", axis=1, inplace=True, errors="ignore")
try:
    dlight_df = dlight_df.reset_index()
except ValueError:
    pass
dlight_df = dlight_df.set_index(trans_p_df.index.names)
dlight_df = dlight_df.join(trans_p_df, how="left")

## Global entropy (sliding window)

In [29]:
group_obj = rle_df.set_index("syllable_number").groupby("uuid")["labels_sorted"]

In [30]:
func = partial(
    get_entropy_persample, bin_sizes=encoding_cfg["window_sizes"], truncate=truncate
)

In [31]:
entropy_bins = apply_parallel_joblib(group_obj, func, verbose=10, n_jobs=-1)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 128 concurrent workers.
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:   30.1s
[Parallel(n_jobs=-1)]: Done  54 out of 280 | elapsed:   34.3s remaining:  2.4min
[Parallel(n_jobs=-1)]: Done  83 out of 280 | elapsed:   35.9s remaining:  1.4min
[Parallel(n_jobs=-1)]: Done 112 out of 280 | elapsed:   37.7s remaining:   56.5s
[Parallel(n_jobs=-1)]: Done 141 out of 280 | elapsed:   41.4s remaining:   40.8s
[Parallel(n_jobs=-1)]: Done 170 out of 280 | elapsed:   45.4s remaining:   29.4s
[Parallel(n_jobs=-1)]: Done 199 out of 280 | elapsed:   47.6s remaining:   19.4s
[Parallel(n_jobs=-1)]: Done 228 out of 280 | elapsed:   49.1s remaining:   11.2s
[Parallel(n_jobs=-1)]: Done 257 out of 280 | elapsed:   50.7s remaining:    4.5s
[Parallel(n_jobs=-1)]: Done 280 out of 280 | elapsed:   52.3s finished


In [32]:
dlight_df.drop(entropy_bins.columns, axis=1, inplace=True, errors="ignore")

In [33]:
try:
    dlight_df = dlight_df.reset_index()
except ValueError:
    pass

In [34]:
dlight_df = dlight_df.set_index(entropy_bins.index.names)
dlight_df = dlight_df.join(entropy_bins, how="left")

## Sliding window syllable counts

In [35]:
func = partial(
    get_counts_persample, bin_sizes=encoding_cfg["window_sizes"], truncate=truncate
)

In [36]:
count_bins = apply_parallel_joblib(group_obj, func, n_jobs=-1)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 128 concurrent workers.
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:    3.6s
[Parallel(n_jobs=-1)]: Done  54 out of 280 | elapsed:    5.5s remaining:   22.8s
[Parallel(n_jobs=-1)]: Done  83 out of 280 | elapsed:    6.5s remaining:   15.5s
[Parallel(n_jobs=-1)]: Done 112 out of 280 | elapsed:    7.9s remaining:   11.9s
[Parallel(n_jobs=-1)]: Done 141 out of 280 | elapsed:    9.2s remaining:    9.0s
[Parallel(n_jobs=-1)]: Done 170 out of 280 | elapsed:   10.5s remaining:    6.8s
[Parallel(n_jobs=-1)]: Done 199 out of 280 | elapsed:   11.6s remaining:    4.7s
[Parallel(n_jobs=-1)]: Done 228 out of 280 | elapsed:   12.3s remaining:    2.8s
[Parallel(n_jobs=-1)]: Done 257 out of 280 | elapsed:   12.9s remaining:    1.2s
[Parallel(n_jobs=-1)]: Done 280 out of 280 | elapsed:   13.8s finished


In [37]:
dlight_df.drop(count_bins.columns, axis=1, inplace=True, errors="ignore")

try:
    dlight_df = dlight_df.reset_index()
except ValueError:
    pass

dlight_df = dlight_df.set_index(count_bins.index.names)
dlight_df = dlight_df.join(count_bins, how="left")

In [38]:
stat_cols = entropy_bins.columns.union(count_bins.columns)

In [39]:
dlight_df["labels_sorted"] = dlight_df["labels"].map(mapping)

In [40]:
try:
    dlight_df = dlight_df.reset_index()
except ValueError:
    pass

In [41]:
within_z = dlight_df.groupby(["uuid", "labels_sorted"])[stat_cols].transform(
    pd_zscore
)
between_z = dlight_df.groupby(["uuid"])[stat_cols].transform(pd_zscore)

In [42]:
between_cols = []
within_cols = []
for _col in stat_cols:
    var, bin_size = _col.split("_")
    between_key = f"{var}_between_{bin_size}"
    within_key = f"{var}_within_{bin_size}"
    between_cols.append(between_key)
    within_cols.append(within_key)

In [43]:
within_z.columns = within_cols
between_z.columns = between_cols

In [44]:
dlight_df.drop(within_z.columns, axis=1, inplace=True, errors="ignore")

try:
    dlight_df = dlight_df.reset_index()
except ValueError:
    pass

dlight_df = dlight_df.join(within_z, how="left")

In [45]:
dlight_df.drop(between_z.columns, axis=1, inplace=True, errors="ignore")

try:
    dlight_df = dlight_df.reset_index()
except ValueError:
    pass

dlight_df = dlight_df.join(between_z, how="left")

In [46]:
# make sure we re-sort back into uuid/timestamp
dlight_df = dlight_df.sort_values(["uuid", "timestamp"]).reset_index(drop=True)

In [47]:
diff_cols = [
    _
    for _ in dlight_df.filter(regex="(entropy|count|counts|trans|p_out)").columns
    if "diff" not in _ and "trial" not in _
]

In [49]:
# take derivatives over stats columns
for _col in tqdm(diff_cols):
    dlight_df[f"{_col}_diff"] = dlight_df.groupby("uuid")[_col].diff()

  0%|          | 0/69 [00:00<?, ?it/s]

In [50]:
dlight_df["labels_sorted"] = dlight_df["labels"].map(mapping)
dlight_df["is_transition"] = dlight_df.groupby("uuid", group_keys=False)["labels_sorted"].transform(
    lambda x: x.diff() != 0
)
dlight_df["count_trans_only"] = dlight_df["count"] * dlight_df["is_transition"]

In [51]:
dlight_df.to_parquet(
    os.path.join(proc_dirs["dlight"], "encoding_model_data_preprocessed.parquet")
)