# Here we are computing the "binned" correlation between dLight and scalar variables shown in Figure 1

Specifically, in this notebook we are:

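1. Z-scoring the dLight data per session
1. Binning dLight and scalars over a range of timescales (and per syllable)
1. For dLight, computing the average per bin and the peak rate per bin
1. Correlating the binned dLight with the binned scalars per session, and repeating the correlation with shuffled dLight as a control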

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from rl_analysis.io.df import dlight_exclude
from rl_analysis.util import rle
from functools import partial
from sklearn.utils import shuffle
from joblib import delayed, Parallel
from typing import Sequence
from contextlib import redirect_stderr

import toml
import numpy as np
import pandas as pd
import os
import sys

# Handle to the interpreter's original stderr stream; used below as the target
# of redirect_stderr and of the progress prints around the joblib calls.
terminal = sys.__stderr__

## Load in new dlight data and preprocess

In [3]:
# Parse the analysis configuration TOML (the path is relative to the working directory).
with open("../analysis_configuration.toml", "r") as f:
    analysis_config = toml.load(f)

In [4]:
# Pull out the configuration sections this notebook reads from.
raw_dirs = analysis_config["raw_data"]  # indexed with "dlight" to locate the input parquet
proc_dirs = analysis_config["intermediate_results"]  # indexed with "dlight" when writing outputs
dlight_cfg = analysis_config["dlight_basic_analysis"]  # supplies "scalars" and "timescale_correlation"
dlight_common_cfg = analysis_config["dlight_common"]  # forwarded as keyword arguments to dlight_exclude

In [5]:
# Load the processed photometry table.
dlight_df = pd.read_parquet(os.path.join(raw_dirs["dlight"], "dlight_photometry_processed_full.parquet"))

# Columns whose names end in signal_dff, reference_dff or reref_dff.
signal_keys = dlight_df.filter(regex="(signal|reference|reref)_dff$").columns.tolist()

# Replace the -5 label with a missing value (presumably -5 marks unlabeled frames -- TODO confirm),
# store labels as nullable unsigned ints, then drop every row without a label.
dlight_df["labels"] = dlight_df["predicted_syllable (offline)"].replace(-5, np.nan).astype("UInt8")
dlight_df = dlight_df.dropna(subset=["labels"])

# Candidate behavioural variables: the transition flag plus the scalars named in the config.
data_keys = [
    "transition",
] + dlight_cfg["scalars"]

# Per uuid, flag rows whose label differs from the previous row's label.
dlight_df["transition"] = dlight_df.groupby("uuid", group_keys=False)["labels"].apply(
    lambda x: (x.diff() != 0).astype("float")
)
# Keep only the keys that actually exist as columns in the table.
data_keys = dlight_df.columns.intersection(data_keys).tolist()

# Timestamps of the rows flagged as transitions.
transition_tstamps = dlight_df.loc[dlight_df["transition"] == 1, "timestamp"]

# transition_time holds the timestamp on transition rows and NaN everywhere else.
dlight_df["transition_time"] = np.nan
dlight_df.loc[dlight_df["transition"] == 1, "transition_time"] = transition_tstamps

## Strip out the data we need

In [6]:
# Names of the "_z"-suffixed counterparts of each dF/F column.
signal_keys_z = [f"{_}_z" for _ in signal_keys]

In [7]:
# Keep only the "_z" column names that are actually present in the table.
all_signal_keys = signal_keys_z
all_signal_keys = dlight_df.columns.intersection(all_signal_keys).tolist()

In [8]:
# Apply the project's rle helper to each uuid's label sequence
# (presumably run-length encoding, i.e. one entry per syllable instance -- TODO confirm).
rle_df = dlight_df.groupby("uuid")["labels"].apply(rle)
rle_df = rle_df.dropna().astype("int8")
# Label values ordered from most to least frequent in the rle output.
sorted_cats = rle_df.value_counts().index

In [9]:
# we can't cleanly exclude target so get rid of all post sessions
# Apply the shared exclusion criteria from the configuration (exclude_target is off here).
use_data = dlight_exclude(
    dlight_df, syllable_key="labels", exclude_target=False, **dlight_common_cfg
).sort_index()
# Drop chrimson sessions numbered 3 and 4.
use_data = use_data.loc[
    ~((use_data["opsin"] == "chrimson") & (use_data["session_number"].isin([3, 4])))
].copy()

# NOTE(review): `cats` is not used again in the visible cells.
cats = np.arange(len(sorted_cats))
zdata_keys = [f"{_}_z" for _ in data_keys]

# re-standardize dLight data so we're using comparable thresholds...
# (z-score within each uuid, after the exclusions above)
use_data[all_signal_keys] = use_data.groupby("uuid")[all_signal_keys].transform(
    lambda x: (x - x.mean()) / x.std()
)

# Z-score the behavioural variables within each uuid into new "_z" columns;
# the original columns are left untouched.
use_data[zdata_keys] = use_data.groupby("uuid")[data_keys].transform(
    lambda x: (x - x.mean()) / x.std()
)
use_data["labels"] = use_data["labels"].astype("int8")

In [10]:
# Running counter within each uuid that increments every time the label changes,
# so consecutive rows with the same label share one syllable_number.
use_data["syllable_number"] = (
    use_data.groupby(["uuid"], group_keys=False)["labels"].apply(lambda x: (x.diff() != 0).cumsum()).astype("uint16")
)
# Store the identifier columns as categoricals.
use_data["uuid"] = use_data["uuid"].astype("category")
use_data["area"] = use_data["area"].astype("category")
use_data["mouse_id"] = use_data["mouse_id"].astype("category")

In [11]:
# Number of rows in each (uuid, syllable_number) run, attached to every row as "duration".
syllable_size = use_data.groupby(["uuid", "syllable_number"]).size().rename("duration")
use_data = (
    use_data.set_index(["uuid", "syllable_number"]).join(syllable_size).reset_index()
)

## Define some useful helper functions

In [12]:
# Bin widths to sweep; the configured "bins" entry is unpacked as the arguments of np.arange.
timescales = np.arange(*dlight_cfg["timescale_correlation"]["bins"])

In [13]:
# Cast the identifier columns back to plain strings (they were made categorical earlier).
use_data["uuid"] = use_data["uuid"].astype("str")
use_data["mouse_id"] = use_data["mouse_id"].astype("str")

In [14]:
# From here on only this single signal column is analysed. The function defaults
# below capture this list at definition time.
all_signal_keys = ["signal_reref_dff_z"]

In [15]:
def peak_rate_cross(dat: pd.DataFrame, threshold: float = 1.96):
    """Fraction of consecutive sample pairs that cross `threshold` upwards.

    A pair counts as a crossing when the earlier sample is strictly below
    `threshold` and the later sample is strictly above it.

    Parameters
    ----------
    dat
        Object exposing ``to_numpy()``; its values are read in order.
    threshold
        Level that must be crossed from below.

    Returns
    -------
    float
        Mean of the crossing indicator over all consecutive pairs, or NaN
        when there are two or fewer samples.
    """
    samples = dat.to_numpy()

    # too few samples: report a missing value rather than a rate
    if len(samples) <= 2:
        return np.nan

    below_before = samples[:-1] < threshold
    above_after = samples[1:] > threshold
    return np.mean(below_before & above_after)

In [16]:
def bin_data(
    timescale: float,
    all_signal_keys: Sequence[str] = all_signal_keys,
    data_keys: Sequence[str] = data_keys + ["labels", "duration"],
    neural_agg: str = "mean",
    data_agg: str = "mean",
) -> pd.DataFrame:
    """Aggregate the module-level `use_data` table into time bins.

    Parameters
    ----------
    timescale
        Width of the fixed bins laid over `timestamp` (edges run from 0 up to
        1900 in steps of `timescale`), or the string "syllable" to use the
        existing `syllable_number` column as the bin.
    all_signal_keys
        Signal columns, aggregated with `neural_agg`.
    data_keys
        Behavioural columns, aggregated with `data_agg`.
    neural_agg
        Aggregation for the signal columns (a name or a callable).
    data_agg
        Aggregation for the behavioural columns.

    Returns
    -------
    pd.DataFrame
        One row per (uuid, mouse_id, bin) with the aggregated columns plus
        `timescale` and `neural_agg` bookkeeping columns. For fixed-width
        bins, only bins whose row count equals the modal row count are kept.
    """
    data = use_data

    if timescale == "syllable":
        time_bin = data["syllable_number"]
    else:
        bin_edges = np.arange(0, 1900, timescale)
        time_bin = pd.cut(data["timestamp"], bins=bin_edges, labels=False).astype(
            "int16"
        )

    # one grouping, reused for every aggregation below
    grouped = data.groupby([data["uuid"], data["mouse_id"], time_bin], observed=True)

    if neural_agg == data_agg:
        agg_matrix = grouped[all_signal_keys + data_keys].agg(neural_agg)
    else:
        neural_part = grouped[all_signal_keys].agg(neural_agg)
        scalar_part = grouped[data_keys].agg(data_agg)
        agg_matrix = scalar_part.join(neural_part)

    # the size filter is meaningless for syllables, whose lengths vary by design
    if timescale != "syllable":
        # bins that are not the modal size have missing frames and/or sit at an edge
        bin_sizes = grouped.size()
        modal_sz = bin_sizes.mode().iat[0]
        full_bins = bin_sizes.loc[bin_sizes == modal_sz].index
        agg_matrix = agg_matrix.loc[full_bins]

    agg_matrix["timescale"] = timescale
    # callables are recorded by name, string aggregations as given
    agg_matrix["neural_agg"] = getattr(neural_agg, "__name__", neural_agg)
    return agg_matrix

In [17]:
# Queue one job per timescale plus one per-syllable job, using the
# threshold-crossing rate as the neural aggregation...
func = partial(bin_data, neural_agg=peak_rate_cross)
delays = [delayed(func)(_timescale) for _timescale in timescales]
delays += [delayed(func)("syllable")]

# ...and the same set of jobs again using the per-bin mean.
func = partial(bin_data, neural_agg="mean")
delays += [delayed(func)(_timescale) for _timescale in timescales]
delays += [delayed(func)("syllable")]


# Run every job across worker processes; stderr is pointed at `terminal` while they run.
with redirect_stderr(terminal):
    print(f"{len(delays)} jobs to process", file=terminal)
    agg_mats = Parallel(n_jobs=30, verbose=10, backend="multiprocessing")(delays)

242 jobs to process
[Parallel(n_jobs=30)]: Using backend MultiprocessingBackend with 30 concurrent workers.
[Parallel(n_jobs=30)]: Done   1 tasks      | elapsed:   19.6s
[Parallel(n_jobs=30)]: Done  12 tasks      | elapsed:   20.9s
[Parallel(n_jobs=30)]: Done  25 tasks      | elapsed:   23.8s
[Parallel(n_jobs=30)]: Done  38 tasks      | elapsed:   31.4s
[Parallel(n_jobs=30)]: Done  53 tasks      | elapsed:   35.0s
[Parallel(n_jobs=30)]: Done  68 tasks      | elapsed:   42.6s
[Parallel(n_jobs=30)]: Done  85 tasks      | elapsed:   49.7s
[Parallel(n_jobs=30)]: Done 102 tasks      | elapsed:   53.7s
[Parallel(n_jobs=30)]: Done 121 tasks      | elapsed:  1.0min
[Parallel(n_jobs=30)]: Done 140 tasks      | elapsed:  1.1min
[Parallel(n_jobs=30)]: Done 161 tasks      | elapsed:  1.2min
[Parallel(n_jobs=30)]: Done 182 tasks      | elapsed:  1.3min
[Parallel(n_jobs=30)]: Done 208 out of 242 | elapsed:  1.4min remaining:   13.5s
[Parallel(n_jobs=30)]: Done 233 out of 242 | elapsed:  1.5min remai

In [18]:
# Stack the per-job tables in job order and turn the group index into ordinary columns.
agg_data = pd.concat(agg_mats).reset_index()

In [19]:
# Encode the per-syllable "timescale" as the integer sentinel -5 so the column can be cast to int.
agg_data.loc[agg_data["timescale"] == "syllable", "timescale"] = -5
agg_data["timescale"] = agg_data["timescale"].astype("int")

In [20]:
# Working copy used for the correlations and the shuffle control; agg_data itself is written out later.
use_agg_data = agg_data.copy()

In [21]:
# Keyword arguments forwarded to every .corr() call (observed and shuffled).
corr_kwargs = {
    "method": "pearson"
}

In [22]:
# Observed correlation matrices, one per (timescale, uuid, neural_agg) group.
# dropna() first removes any row with a missing value in any column.
corrs = (
    use_agg_data.dropna()
    .groupby(["timescale", "uuid", "neural_agg"])[data_keys + all_signal_keys]
    .corr(**corr_kwargs)
)

In [23]:
# Name the innermost index level (the row variable of each correlation matrix) "feature".
corrs.index = corrs.index.rename("feature", level=-1)

In [24]:
def shuffler(
    idx: int,
    shuffle_col: str = "signal_reref_dff_z",
    shuffle_group_by: Sequence[str] = ["timescale", "uuid", "neural_agg"],
    corr_group_by: Sequence[str] = ["timescale", "uuid", "neural_agg"],
    corr_keys: Sequence[str] = data_keys + all_signal_keys,
    corr_kwargs: dict = corr_kwargs,
):
    """Compute one shuffle-control replicate of the binned correlations.

    `shuffle_col` is permuted within each `shuffle_group_by` group of a copy
    of the module-level `use_agg_data`, and correlations are then recomputed
    per `corr_group_by` group.

    Parameters
    ----------
    idx
        Replicate number; used as the random seed and stored in the output.
    shuffle_col
        Column whose values are permuted.
    shuffle_group_by
        Columns defining the groups within which values are permuted.
    corr_group_by
        Columns defining the groups within which correlations are computed.
    corr_keys
        Columns entering each correlation matrix.
    corr_kwargs
        Keyword arguments forwarded to ``.corr()``.

    Returns
    -------
    pd.DataFrame
        One row per `corr_group_by` group holding the correlation of every
        `corr_keys` column with the shuffled `shuffle_col`, plus an "index"
        column equal to `idx`.
    """
    use_df = use_agg_data.copy()

    # Permute within each group and write the result back onto that group's own rows.
    # `transform` realigns the output with the original row order. Taking `.values`
    # from `apply` instead yields values in sorted-group order, which only matches the
    # frame when it is already sorted by the group keys; otherwise values leak across groups.
    # The permutation is applied to a NumPy array so that index alignment cannot undo it.
    use_df[shuffle_col] = use_df.groupby(shuffle_group_by)[shuffle_col].transform(
        lambda x: shuffle(x.to_numpy(), random_state=idx)
    )

    # NOTE(review): the observed correlations are computed after a row-wise dropna();
    # no dropna() is applied here -- confirm whether the two should match.
    corrs = use_df.groupby(corr_group_by)[corr_keys].corr(**corr_kwargs)
    corrs.index = corrs.index.set_names("feature", level=-1)
    corrs["index"] = idx  # tag every row with the replicate number
    # keep only each feature's correlation with the shuffled column
    corrs = corrs.xs(shuffle_col, level="feature")
    return corrs

In [25]:
# Run one shuffle replicate per seed 0..nshuffles-1 across worker processes;
# stderr is pointed at `terminal` while they run.
with redirect_stderr(terminal):
    nshuffles = dlight_cfg["timescale_correlation"]["nshuffles"]
    print(f"{nshuffles} jobs to process", file=terminal)
    shuffle_mats = Parallel(n_jobs=-5, verbose=10, backend="multiprocessing")(
        [delayed(shuffler)(_idx) for _idx in range(nshuffles)]
    )

1000 jobs to process
[Parallel(n_jobs=-5)]: Using backend MultiprocessingBackend with 124 concurrent workers.
[Parallel(n_jobs=-5)]: Done  17 tasks      | elapsed:  1.1min
[Parallel(n_jobs=-5)]: Done  40 tasks      | elapsed:  1.1min
[Parallel(n_jobs=-5)]: Done  65 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-5)]: Done  90 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-5)]: Done 117 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-5)]: Done 144 tasks      | elapsed:  2.0min
[Parallel(n_jobs=-5)]: Done 173 tasks      | elapsed:  2.1min
[Parallel(n_jobs=-5)]: Done 202 tasks      | elapsed:  2.2min
[Parallel(n_jobs=-5)]: Done 233 tasks      | elapsed:  2.2min
[Parallel(n_jobs=-5)]: Done 264 tasks      | elapsed:  2.9min
[Parallel(n_jobs=-5)]: Done 297 tasks      | elapsed:  3.1min
[Parallel(n_jobs=-5)]: Done 330 tasks      | elapsed:  3.2min
[Parallel(n_jobs=-5)]: Done 365 tasks      | elapsed:  3.2min
[Parallel(n_jobs=-5)]: Done 400 tasks      | elapsed:  3.9min
[Parallel(n_jobs=-5)]:

In [26]:
# Stack all shuffle replicates; the "index" column tells them apart.
shuffle_df = pd.concat(shuffle_mats)

In [27]:
# Keep only each feature's correlation with the dLight signal column.
# The selection removes the "feature" level, so checking for that level keeps
# the cell safe to re-run while still raising if the signal key is genuinely missing.
if "feature" in corrs.index.names:
    corrs = corrs.xs("signal_reref_dff_z", level="feature")

In [28]:
# Per-uuid metadata lookup table, taken from the first row seen for each uuid.
meta_keys = ["mouse_id", "area", "opsin"]
metadata = dlight_df.drop_duplicates(["uuid"]).set_index("uuid")[meta_keys]

In [29]:
# Attach the per-uuid metadata to each output table.
for _key in meta_keys:
    # observed correlations: add as an extra index level unless it is already one
    if _key not in corrs.index.names:
        corrs[_key] = corrs.index.get_level_values("uuid").map(metadata[_key])
        corrs = corrs.set_index(_key, append=True)

    # shuffled correlations: same treatment
    if _key not in shuffle_df.index.names:
        shuffle_df[_key] = shuffle_df.index.get_level_values("uuid").map(metadata[_key])
        shuffle_df = shuffle_df.set_index(_key, append=True)

    # binned data: plain column, overwritten on every run
    agg_data[_key] = agg_data["uuid"].map(metadata[_key])

In [30]:
# Write the binned data, the observed correlations and the shuffle control
# to the dlight intermediate-results directory.
agg_data.to_parquet(
    os.path.join(proc_dirs["dlight"], "scalar_correlations_data.parquet")
)
corrs.to_parquet(
    os.path.join(proc_dirs["dlight"], "scalar_correlations.parquet")
)
shuffle_df.to_parquet(
    os.path.join(proc_dirs["dlight"], "scalar_correlations_shuffle.parquet")
)