# Compute syllable counts and scalars in an expanding window after syllable onset

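1. Note that you can use a dask scheduler by setting `address` under the `dask` section of `../analysis_configuration.toml` (read below into `dask_address`) to the address of the scheduler. If no address is set, the notebook falls back to local joblib workers.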

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from rl_analysis.batch import apply_parallel_joblib
from rl_analysis.io.df import dlight_exclude
from rl_analysis.photometry.lagged import get_syllable_and_scalar_rates
from functools import partial
from joblib import Parallel, delayed
from contextlib import redirect_stderr

# only needed if you're using dask
import tempfile
import dask.delayed as dask_delayed
from distributed import Client
import pandas as pd
import numpy as np
import os
import toml
import sys

# Handle on the interpreter's original stderr; later cells pass it to
# redirect_stderr/print so progress output is sent there explicitly.
terminal = sys.__stderr__

In [3]:
# When True, the "does the output file already exist?" checks in later cells
# are bypassed and results, correlations and shuffles are recomputed and
# rewritten.
force = True

In [4]:
from sklearn.utils import shuffle

In [5]:
# Load the shared analysis configuration; the cells below read data paths and
# per-analysis settings from it.
with open("../analysis_configuration.toml", "r") as f:
    analysis_config = toml.load(f)

In [6]:
# Pull out the configuration sections used throughout this notebook.
raw_dirs = analysis_config["raw_data"]
proc_dirs = analysis_config["intermediate_results"]
dlight_cfg = analysis_config["dlight_common"]
lagged_cfg = analysis_config["dlight_lagged_correlations"]
# None when the "dask" section has no "address" key; later cells then use
# joblib instead of a dask cluster.
dask_address = analysis_config["dask"].get("address")
# dask_address = None # FORCES THE USE OF MULTIPROCESSING

In [7]:
# Choose between the offline and online snippet files based on the config.
if lagged_cfg["use_offline"]:
    file_suffix = "offline"
else:
    file_suffix = "online"

snippets_name = f"dlight_snippets_{file_suffix}.parquet"
load_file = os.path.join(raw_dirs["dlight"], snippets_name)

In [9]:
# Recording area to analyze: used to filter rows on the "area" column and as
# a suffix on every output file name.
use_area = "dls"

In [10]:
# Features file: the snippets file name with "_features" inserted before the
# extension (plus "_renormalize" when renormalized features are requested).
file, ext = os.path.splitext(load_file)
features_save_file = f"{file}_features{ext}"

if lagged_cfg["use_renormalized"]:
    file, ext = os.path.splitext(features_save_file)
    features_save_file = f"{file}_renormalize{ext}"

dirname, filename = os.path.split(features_save_file)

# The usage (rle) file sits next to the features file with "snippet" swapped
# for "usage" in its name. Only the file name is rewritten, so a directory
# that happens to contain "snippet" is left untouched.
rle_save_file = os.path.join(dirname, filename.replace("snippet", "usage"))

# Outputs go to the intermediate-results directory, keeping the features
# file's stem as a common prefix.
file, ext = os.path.splitext(filename)
file = os.path.join(proc_dirs["dlight"], file)
if lagged_cfg["estimate_within_bin"]:
    file = f"{file}_withinbin"

results_file = f"{file}_lag_usage_and_scalars_{use_area}.parquet"
corr_file = f"{file}_lag_usage_and_scalars_correlations_{use_area}.parquet"
shuffle_file = f"{file}_lag_usage_and_scalars_shuffle_{use_area}.parquet"

In [11]:
# One neural-feature column name per (window, feature) pair, ordered by
# window first and by feature within each window.
use_features = [
    f"{feature}_{window}"
    for window in lagged_cfg["use_windows"]
    for feature in lagged_cfg["use_neural_features"]
]
scalar_keys = lagged_cfg["usage_and_scalars"]["scalars"]

## Load in pre-processed data

In [12]:
# Load the per-snippet feature table and the matching usage (rle) table
# written by an earlier preprocessing step.
feature_df = pd.read_parquet(features_save_file)
rle_df = pd.read_parquet(rle_save_file)

In [13]:
# Keep rows from the chosen area, dropping session numbers 1 and 2.
in_area = feature_df["area"] == use_area
dropped_session = feature_df["session_number"].isin([1, 2])
feature_df = feature_df.loc[in_area & ~dropped_session].copy()

In [14]:
# Keep only rows whose window is one of the configured windows.
window_mask = feature_df["window_tup"].isin(lagged_cfg["use_windows"])
feature_df = feature_df.loc[window_mask].copy()

In [15]:
# Read the session bin edges stored alongside the intermediate results; they
# are passed to pd.cut below to assign each row a time bin.
with open(
    os.path.join(proc_dirs["dlight"], "lagged_analysis_session_bins.toml"), "r"
) as f:
    use_session_bins = toml.load(f)["session_bins"]

## Stage data for downstream computation

In [16]:
# Label each row with the integer index of the session bin its timestamp
# falls in (labels=False returns bin codes; include_lowest=True makes the
# first bin include its left edge). Timestamps outside every bin become NaN.
feature_df["time_bin"] = pd.cut(
    feature_df["timestamp"], use_session_bins, labels=False, include_lowest=True
)

In [17]:
# Append window_tup as the innermost index level. When the column is missing,
# set_index raises KeyError, which is ignored here.
# NOTE(review): presumably this makes the cell safe to re-run after the column
# has already been moved into the index — confirm that is the only case.
try:
    feature_df = feature_df.set_index("window_tup", append=True)
except KeyError:
    pass

In [18]:
# Distinct values of the innermost index level (window_tup, when the cell
# above appended it).
wins = feature_df.index.get_level_values(-1).unique()

In [19]:
# Shorthand for building multi-level .loc selectors such as idx[:, window].
idx = pd.IndexSlice

In [20]:
# Build one frame per window holding that window's neural features, with the
# window appended to every column name and the window index level dropped.
dfs = []
neural_cols = lagged_cfg["use_neural_features"]
for _idx in tqdm(wins):
    window_vals = feature_df.loc[idx[:, _idx], neural_cols]
    suffixed = {col: f"{col}_{_idx}" for col in window_vals.columns}
    window_vals = window_vals.rename(columns=suffixed).droplevel(-1)
    dfs.append(window_vals)

  0%|          | 0/1 [00:00<?, ?it/s]

In [21]:
# Every column whose name does not match the regex "dff" is treated as
# metadata.
meta_cols = feature_df.columns.difference(feature_df.filter(regex="dff").columns)

In [22]:
# Take the metadata columns from the rows of a single window (the last one in
# `wins`), then drop the window level so the index lines up with the
# per-window feature frames. The window is picked explicitly instead of
# relying on a loop variable left over from another cell.
# NOTE(review): assumes metadata is identical across windows — confirm.
meta_window = wins[-1]
meta_df = feature_df.loc[idx[:, meta_window], meta_cols]
meta_df.index = meta_df.index.droplevel(-1)

In [23]:
# Place the per-window feature frames side by side, then attach the metadata
# columns by index.
feature_df = pd.concat(dfs, axis=1).join(meta_df)

In [24]:
# Duration of each row: the gap to the next row's timestamp within the same
# uuid. The last row of each uuid has no successor and gets NaN.
next_timestamp = rle_df.groupby("uuid")["timestamp"].shift(-1)
rle_df["duration"] = next_timestamp - rle_df["timestamp"]

In [25]:
# Move the index levels back into ordinary columns; the cells below reference
# "uuid" and "timestamp" as columns.
feature_df = feature_df.reset_index()

In [26]:
# Number the rows of each uuid 0, 1, 2, ... in their current order.
rle_df["syllable_num"] = rle_df.groupby(["uuid"])["syllable"].transform(
    lambda x: np.arange(len(x))
)
# merge_asof requires both frames to be sorted by the `on` key (timestamp).
feature_df = feature_df.sort_values(["timestamp", "uuid"])
rle_df = rle_df.sort_values(["timestamp", "uuid"])
# NOTE(review): presumably cast so the key dtype matches
# feature_df["timestamp"], which merge_asof requires — confirm.
rle_df["timestamp"] = rle_df["timestamp"].astype("float32")

feature_df["uuid"] = feature_df["uuid"].astype("str")
# Backward as-of join (the merge_asof default): each feature row receives the
# syllable_num of the latest rle row with the same uuid whose timestamp is at
# or before its own.
feature_df = pd.merge_asof(
    feature_df, rle_df[["timestamp", "uuid", "syllable_num"]], on="timestamp", by="uuid"
)
# Return both frames to per-uuid chronological order.
feature_df = feature_df.sort_values(["uuid", "timestamp"])
rle_df = rle_df.sort_values(["uuid", "timestamp"])

In [27]:
# Average per-uuid occurrence count of every syllable, with the label -5
# removed, listed from the most to the least used.
per_uuid_counts = rle_df.groupby("uuid")["syllable"].value_counts()
usages = per_uuid_counts.groupby("syllable").mean()
usages = usages[usages.index != -5]
use_syllables = usages.sort_values().index[::-1].tolist()

## Now we compute values triggered on syllable instances, storing features we want to split by downstream (dLight, scalars, etc.)

In [28]:
# Bin edges come from the config as np.arange arguments.
usage_bins = np.arange(*lagged_cfg["usage_and_scalars"]["bins"])

idx = pd.IndexSlice

feature_df = feature_df.reset_index(drop=True)

# One job per uuid, or per (uuid, time_bin) when estimating within bins.
group_keys = ["uuid", "time_bin"] if lagged_cfg["estimate_within_bin"] else ["uuid"]
group_obj = feature_df.groupby(group_keys)

use_syllables = list(map(int, use_syllables))

# We'll get syllable-specific and global versions of each feature. Scalars are
# tagged 0 and neural features 1.
# NOTE(review): the meaning of the 0/1 tag is defined by
# get_syllable_and_scalar_rates — confirm there.
scalar_entries = [(0, key) for key in scalar_keys]
neural_entries = [(1, feature) for feature in use_features]
chk_features = scalar_entries + neural_entries
chk_features_specific = chk_features

In [29]:
# Compute the per-group syllable rates and scalars, or reload them from disk
# when a cached file exists and `force` is off.
if (not os.path.exists(results_file)) or force:
    # Bind the arguments shared by every group so that only the group itself
    # varies per job.
    func = partial(
        get_syllable_and_scalar_rates,
        chk_syllables=use_syllables,
        usage_bins=usage_bins,
        dlight_features=use_features,
        additional_keys_global=chk_features,
        additional_keys_specific=chk_features_specific,
    )
    njobs = group_obj.ngroups
    # Route stderr (and the explicit print) to the original stderr stream.
    with redirect_stderr(terminal):
        print(f"{njobs} jobs to process", file=terminal)
        # Use the dask backend when a scheduler address is configured,
        # otherwise joblib's loky backend.
        syllable_rates = apply_parallel_joblib(
            group_obj,
            func,
            n_jobs=-1,
            verbose=10,
            backend="dask" if dask_address is not None else "loky",
            dask_address=dask_address,
        )
    syllable_rates.to_parquet(results_file)
else:
    syllable_rates = pd.read_parquet(results_file)

2500 jobs to process


# Z-score the results

In [30]:
# Restrict to the chosen area, then apply the project's shared exclusion
# criteria, passing the dlight_common config section as keyword arguments.
# NOTE(review): the exact exclusions are defined by dlight_exclude — see
# rl_analysis.io.df.
syllable_rates = syllable_rates.loc[syllable_rates["area"] == use_area].copy()
syllable_rates = dlight_exclude(
    syllable_rates, exclude_3s=False, syllable_key="syllable", **dlight_cfg
)

In [31]:
from copy import deepcopy

In [32]:
# Work on a copy so that extending the key list does not mutate the loaded
# configuration.
correlation_keys = deepcopy(lagged_cfg["usage_and_scalars"]["correlation_keys"])
if lagged_cfg["estimate_within_bin"]:
    # Correlations are then also computed separately per time bin.
    correlation_keys += ["time_bin"]

# All columns whose name ends in "_bin".
bin_keys = syllable_rates.filter(regex="_bin$").columns.tolist()

In [33]:
# Candidate columns are "count" plus every "*_bin" column; keep only those
# actually present in the results.
z_features = ["count"] + bin_keys
z_features = syllable_rates.columns.intersection(z_features).tolist()

In [34]:
# Keep only the neural feature columns that exist in the results.
use_features = syllable_rates.columns.intersection(use_features).tolist()

# Flatten the index into columns and store timestamps and z_features as
# float32.
syllable_rates = syllable_rates.reset_index()
syllable_rates["timestamp"] = syllable_rates["timestamp"].astype("float32")
syllable_rates[z_features] = syllable_rates[z_features].astype("float32")

In [35]:
# Store the time bin index compactly when estimating within bins.
# NOTE(review): uint8 cannot represent NaN or values above 255 — confirm that
# every row has a valid time_bin at this point.
if lagged_cfg["estimate_within_bin"]:
    syllable_rates["time_bin"] = syllable_rates["time_bin"].astype("uint8")

In [36]:
# Store the neural feature columns as float32 as well.
syllable_rates[use_features] = syllable_rates[use_features].astype("float32")

In [37]:
# Collects sample-size summaries that are written to a TOML file below.
size_stats = {}

In [38]:
# Number of rows that feed each correlation group.
correlation_sz = syllable_rates.groupby(correlation_keys)[
    z_features + use_features
].size()
size_stats["n_total_syllable_uuid_time_bin_points"] = len(syllable_rates)
size_stats["n_per_group"] = correlation_sz.describe().to_dict()

size_stats["n_mice"] = correlation_sz.index.get_level_values("mouse_id").nunique()
# "time_bin" is only a grouping key when correlations are estimated within
# bins; without this guard the lookup raises KeyError in the other case.
if "time_bin" in correlation_sz.index.names:
    size_stats["n_bins"] = correlation_sz.index.get_level_values("time_bin").nunique()
size_stats["n_syllables"] = correlation_sz.index.get_level_values("syllable").nunique()
# size_stats["n_groups"] = len(correlation_sz)

# Save the summary next to the other intermediate results.
with open(
    os.path.join(
        proc_dirs["dlight"], f"stats_lagged_usage_and_scalars_n_{use_area}.toml"
    ),
    "w",
) as f:
    toml.dump(size_stats, f)

In [39]:
# Save the per-group row counts as a flat table with a "size" column.
sizes_path = os.path.join(
    proc_dirs["dlight"], f"stats_lagged_usage_and_scalars_n_{use_area}.parquet"
)
group_sizes = correlation_sz.rename("size").reset_index()
group_sizes.to_parquet(sizes_path)

In [40]:
%%time
# Observed correlations: computed and cached, or reloaded when a cached file
# exists and `force` is off.
if not os.path.exists(corr_file) or force:
    # Pairwise correlation matrix of the selected columns within every
    # correlation_keys group; options come from the config.
    obs_corrs = syllable_rates.groupby(correlation_keys)[
        z_features + use_features
    ].corr(**lagged_cfg["correlation_kwargs"])
    # The innermost index level holds the row-side column name of each
    # correlation matrix; name it "feature".
    obs_corrs.index = obs_corrs.index.set_names("feature", level=-1)
    obs_corrs.to_parquet(corr_file)
else:
    obs_corrs = pd.read_parquet(corr_file)

CPU times: user 15.9 s, sys: 1.45 s, total: 17.3 s
Wall time: 17.4 s


In [42]:
# Downcast the grouping columns to small integer dtypes before the frame is
# handed to the shuffle workers.
# NOTE(review): int8 only holds -128..127 — confirm syllable labels fit.
convs = {
    "syllable": "int8",
    "bin": "uint16",
}
for k, v in convs.items():
    syllable_rates[k] = syllable_rates[k].astype(v)

In [43]:
# shuffle_by: groups inside which neural features are permuted.
# agg_group_by: levels over which shuffled correlations are averaged.
shuffle_by = ["uuid", "bin"]
agg_group_by = ["bin", "feature"]

# Keep time bins separate in both steps when estimating within bins.
if lagged_cfg["estimate_within_bin"]:
    shuffle_by.append("time_bin")
    agg_group_by.append("time_bin")

In [44]:
# De-duplicated union of every column involved in shuffling and correlation.
# NOTE(review): not referenced again in this notebook; `all_cols` below is
# built from the same lists and is what gets used.
all_shuffle_keys = list(set(use_features + shuffle_by + correlation_keys + z_features))

In [45]:
def shuffler_corr(
    df: pd.DataFrame,
    idx: int,
    shuffle_cols: list[str] = use_features,
    suffix: str = "shuffle",
    group_func: str = "mean",
    shuffle_by: list[str] = shuffle_by,
    group_by: list[str] = correlation_keys,
    agg_group_by: list[str] = agg_group_by,
    agg_func: str = "mean",
    compute_cols: list[str] = z_features + use_features,
) -> pd.DataFrame:
    """Compute one shuffled replicate of the grouped correlations.

    Within every ``shuffle_by`` group the rows of ``shuffle_cols`` are
    permuted as a unit (one permutation shared by all shuffled columns), which
    breaks their pairing with the remaining ``compute_cols``. Correlations are
    then computed per ``group_by`` group and aggregated over
    ``agg_group_by``.

    Args:
        df: Table containing ``shuffle_cols``, ``shuffle_by``, ``group_by``
            and ``compute_cols``. It is copied, not modified.
        idx: Identifier of this replicate, stored in the ``idx`` column of the
            result.
        shuffle_cols: Columns whose rows are permuted.
        suffix: Unused; kept for call compatibility.
        group_func: Unused; kept for call compatibility.
        shuffle_by: Columns defining the groups inside which rows are permuted.
        group_by: Columns defining the groups correlations are computed in.
        agg_group_by: Index levels of the correlation table to aggregate over.
        agg_func: Aggregation applied across ``agg_group_by`` groups.
        compute_cols: Columns included in each correlation matrix.

    Returns:
        The aggregated correlation table, indexed by ``agg_group_by``, with an
        extra ``idx`` column.
    """
    use_df = df.copy()
    # Any shuffled column works as the carrier here: only its index is used,
    # to draw a random permutation of row labels within each group.
    shuffle_idx = use_df.groupby(shuffle_by)[shuffle_cols[0]].transform(
        lambda x: shuffle(x.index)
    )
    # Assign by position (.values) so the permuted rows are not re-aligned to
    # their original labels.
    use_df[shuffle_cols] = use_df.loc[shuffle_idx, shuffle_cols].values
    shuffle_df = use_df.groupby(group_by)[compute_cols].corr(
        **lagged_cfg["correlation_kwargs"]
    )
    # The innermost level holds the row-side column name of each matrix.
    shuffle_df.index = shuffle_df.index.set_names("feature", level=-1)
    shuffle_df = shuffle_df.groupby(agg_group_by).agg(agg_func)
    shuffle_df["idx"] = idx
    return shuffle_df

In [46]:
# Only these columns are sent to the shuffle workers.
all_cols = list(set(shuffle_by + z_features + use_features + correlation_keys))

In [47]:
# Shuffled correlations: computed and cached, or reloaded when a cached file
# exists and `force` is off (the same pattern as the results and correlation
# files above).
if not os.path.exists(shuffle_file) or force:
    if dask_address is not None:
        client = Client(dask_address)
        # Ship the table to every worker once instead of once per task.
        scattered_data = client.scatter(syllable_rates[all_cols], broadcast=True)
        futures = client.compute(
            [
                dask_delayed(shuffler_corr)(scattered_data, sidx)
                for sidx in range(lagged_cfg["nshuffles"])
            ],
            resources={"WORKERS": 1},
        )
        ret_list = client.gather(futures)
        # Release the scattered copy held by the cluster.
        client.cancel(scattered_data)
    else:
        # Local fallback: joblib workers, with progress sent to the original
        # stderr.
        with redirect_stderr(terminal):
            ret_list = Parallel(n_jobs=40, verbose=10, temp_folder=tempfile.gettempdir())(
                delayed(shuffler_corr)(syllable_rates[all_cols], sidx)
                for sidx in range(lagged_cfg["nshuffles"])
            )
    shuffle_df = pd.concat(ret_list, axis=0)
    shuffle_df.to_parquet(shuffle_file)
else:
    # Without this branch shuffle_df would be undefined when the cached file
    # is reused.
    shuffle_df = pd.read_parquet(shuffle_file)

In [48]:
# Show the path of the shuffle results file.
print(shuffle_file)

/home/markowitzmeister_gmail_com/jeff_win_share/reinforcement_data/_final_test/_data/dlight_intermediate_results/dlight_snippets_offline_features_withinbin_lag_usage_and_scalars_shuffle_dls.parquet
