In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from tqdm.auto import tqdm
from rl_analysis.plotting import (
    setup_plotting_env,
    clean_plot_labels,
    clean_ticks,
    savefig,
    plot_pval,
)
from functools import partial
from joblib import Parallel, delayed

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns

In [None]:
import toml

with open("../analysis_configuration.toml", "r") as f:
    analysis_config = toml.load(f)

In [None]:
raw_dirs = analysis_config["raw_data"]
proc_dirs = analysis_config["intermediate_results"]
lagged_cfg = analysis_config["dlight_lagged_correlations"]
figure_cfg = analysis_config["figures"]

In [None]:
file_suffix = "offline" if lagged_cfg["use_offline"] else "online"
load_file = os.path.join(raw_dirs["dlight"], f"dlight_snippets_{file_suffix}.parquet")

In [None]:
use_area = "dls"

In [None]:
def _with_suffix(path, suffix):
    """Return `path` with `suffix` inserted between its stem and extension."""
    stem, extension = os.path.splitext(path)
    return f"{stem}{suffix}{extension}"


# Feature file derived from the raw snippet file (optionally the renormalized one).
features_save_file = _with_suffix(load_file, "_features")
if lagged_cfg["use_renormalized"]:
    features_save_file = _with_suffix(features_save_file, "_renormalize")

# The usage file sits next to the feature file. Only the file *name* is
# rewritten, so a directory that happens to contain "snippet" is left intact.
features_dir, features_name = os.path.split(features_save_file)
rle_save_file = os.path.join(features_dir, features_name.replace("snippet", "usage"))

# Intermediate results are stored under proc_dirs["dlight"], keyed by the
# feature file's stem (plus "_withinbin" when estimating within bins).
file = os.path.join(proc_dirs["dlight"], os.path.splitext(features_name)[0])
if lagged_cfg["estimate_within_bin"]:
    file = f"{file}_withinbin"

results_file = f"{file}_lag_usage_and_scalars_{use_area}.parquet"
corr_file = f"{file}_lag_usage_and_scalars_correlations_{use_area}.parquet"
shuffle_file = f"{file}_lag_usage_and_scalars_shuffle_{use_area}.parquet"

In [None]:
# Feature columns are named "<neural feature>_<window>", windows in the outer loop.
use_features = [
    f"{_neural}_{_window}"
    for _window in lagged_cfg["use_windows"]
    for _neural in lagged_cfg["use_neural_features"]
]
scalar_keys = lagged_cfg["usage_and_scalars"]["scalars"]

# Load lagged-correlation results, compare them to shuffles, and plot

In [None]:
agg_keys = ["bin", "feature", "time_bin"]
final_agg_keys = ["feature", "bin"]

In [None]:
rle_df = pd.read_parquet(rle_save_file)
obs_corrs = pd.read_parquet(corr_file)
shuffle_df = pd.read_parquet(shuffle_file)

In [None]:
obs_corrs_raw = obs_corrs.copy()
obs_corrs = obs_corrs.groupby(agg_keys).mean()
shuffle_mus = shuffle_df.groupby(obs_corrs.index.names).mean()
shuffle_sigs = shuffle_df.groupby(obs_corrs.index.names).std()
obs_mu = ((obs_corrs - shuffle_mus) / shuffle_sigs).groupby(final_agg_keys).mean()

shuffle_mu = (
    shuffle_df.reset_index().set_index(agg_keys + ["idx"])[use_features]
    - shuffle_mus[use_features]
) / shuffle_sigs[use_features]

In [None]:
# Per-draw z-scores averaged over time bins; dropping "idx" leaves one row per
# shuffle draw under each (feature, bin) key.
shuffle_compare = shuffle_mu.groupby(final_agg_keys + ["idx"]).mean()[use_features]
shuffle_compare.index = shuffle_compare.index.droplevel("idx")
obs_compare = obs_mu[use_features]

dfs = []
for _feature in use_features:
    # Empirical p-value: fraction of shuffle draws whose |z| is strictly larger
    # than the observed |z| for the same (feature, bin).
    _pval = obs_compare.groupby(["feature", "bin"])[_feature].apply(
        lambda x: (
            x.abs().values < shuffle_compare.loc[x.name][_feature].abs().values
        ).mean()
    )
    dfs.append(_pval)
pval_df = pd.concat(dfs, axis=1)

from statsmodels.stats.multitest import multipletests

# Holm correction across all (feature, bin) tests, separately per neural feature.
# Assign the corrected column back rather than writing into `.values`: under
# pandas Copy-on-Write that array is read-only (or a detached copy).
for _feature in use_features:
    pval_df[_feature] = multipletests(pval_df[_feature].to_numpy(), method="holm")[1]

In [None]:
chk_features = use_features

# Scalars whose lagged correlation with each neural feature is fit below:
# behavioral ones plus the "_global_bin" version of every neural feature.
beh_features = [
    "count",
    "velocity_2d_mm_global_bin",
    "velocity_2d_mm_specific_bin",
] + [f"{_}_global_bin" for _ in chk_features]

# Keep only the neural features present in the correlation table.
chk_features = obs_corrs.columns.intersection(chk_features).tolist()

# Chance band: 95th percentile of |z| over shuffle draws and time bins for each
# (feature, bin), then averaged across features per lag bin.
chance = (
    shuffle_mu.abs().groupby(["feature", "bin"]).quantile(0.95).groupby("bin").mean()
)

In [None]:
# Offset subtracted from the z-scored curve in `fitter` when pin_zero=True
# (captured as fitter's default `threshold` when that cell runs); 0 = no offset.
threshold = 0

In [None]:
from scipy.optimize import curve_fit, OptimizeWarning
from sklearn.utils import resample
from sklearn.metrics import r2_score
import warnings


def fitter(
    corrs,
    idx,
    neural_feature="signal_reref_dff_z_max_abs_peak",
    beh_feature="count",
    threshold=threshold,
    pin_zero=True,
    resample_input=True,
    peak_search_bins=5,
):
    """Fit an exponential decay ``a * exp(-b * x) + c`` to one lagged-correlation curve.

    The correlations are optionally bootstrap-resampled, then averaged per
    ``agg_keys``. They are z-scored against the module-level ``shuffle_df``,
    averaged over time bins, and reduced to one value per lag ``bin`` for the
    requested (beh_feature, neural_feature) pair. The curve is sign-flipped so
    its early peak is positive, truncated to start at that peak, and fit.

    Parameters
    ----------
    corrs : pd.DataFrame
        Correlations that can be grouped by ``agg_keys``, with one column per
        neural feature.
    idx : int
        Seed for the bootstrap resample; also stored as ``"index"`` in the result.
    neural_feature : str
        Column of ``corrs`` to fit.
    beh_feature : str
        Value of the ``"feature"`` index level to fit.
    threshold : float
        Offset subtracted from the curve when ``pin_zero`` is True.
    pin_zero : bool
        Subtract ``threshold`` and clip negative values to 0 before fitting.
    resample_input : bool
        Bootstrap-resample the rows of ``corrs`` (seeded by ``idx``) before
        aggregating. Pass False for inputs that are already a single draw.
    peak_search_bins : int
        Number of leading lag bins searched for the peak the fit starts from.

    Returns
    -------
    dict
        ``a``, ``b``, ``offset``, ``tau`` (= 1 / b), ``index``, ``beh_feature``,
        ``neural_feature``, ``r2`` and ``maxloc``. An empty dict is returned when
        the curve is empty, fewer than three points remain after the peak, or
        ``curve_fit`` / ``r2_score`` raise RuntimeError/ValueError.
    """
    # Honour `resample_input`: previously the input was always bootstrapped.
    if resample_input:
        corrs = resample(corrs, random_state=idx)
    corrs = corrs.groupby(agg_keys).mean()

    shuffle_mean = shuffle_df.groupby(corrs.index.names).mean()
    shuffle_std = shuffle_df.groupby(corrs.index.names).std()
    res_mu = (
        ((corrs - shuffle_mean) / shuffle_std)
        .groupby(final_agg_keys)
        .mean()
        .xs(beh_feature, level="feature")
        .groupby("bin")[neural_feature]
        .mean()
    )

    # Plain float copies, so the in-place edits below never touch `res_mu`.
    x = np.asarray(res_mu.index, dtype=float)
    y = np.array(res_mu.to_numpy(), dtype=float)
    if y.size == 0:
        # np.argmax below would raise on an empty curve.
        return {}

    # Start the fit at the strongest of the first few bins, flipped to be positive.
    maxloc = np.argmax(np.abs(y[:peak_search_bins]))
    if y[maxloc] < 0:
        y *= -1

    # Lags measured relative to the peak bin.
    x = x[maxloc:] - x[maxloc]
    y = y[maxloc:]

    # Three free parameters need at least three points.
    if len(x) < 3:
        return {}

    if pin_zero:
        # Values below `threshold` are treated as zero correlation.
        y = np.clip(y - threshold, 0, np.inf)

    def func(x, a, b, c):
        return a * np.exp(-b * x) + c

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", OptimizeWarning)
            popt, _ = curve_fit(func, x, y, p0=[y[0], 0.05, 0])
        fitted_vals = func(x, *popt)
        return {
            "a": popt[0],
            "b": popt[1],
            "offset": popt[2],
            "tau": 1 / popt[1],
            "index": idx,
            "beh_feature": beh_feature,
            "neural_feature": neural_feature,
            "r2": r2_score(y, fitted_vals),
            "maxloc": maxloc,
        }
    except (RuntimeError, ValueError):
        return {}

In [None]:
# Bootstrap distribution of decay fits on the observed correlations:
# one fit per (neural feature, behavioral feature, resample seed).
delays = []
for _neural_feature in chk_features:
    for _feature in beh_features:
        _fitter = partial(
            fitter,
            beh_feature=_feature,
            neural_feature=_neural_feature,
            pin_zero=False,
            resample_input=True,
        )
        delays.extend(
            delayed(_fitter)(obs_corrs_raw, _resample)
            for _resample in range(lagged_cfg["nshuffles"])
        )
print(len(delays))
# joblib: n_jobs=-10 uses (n_cpus + 1 - 10) workers, i.e. all CPUs but nine.
ret_list = Parallel(n_jobs=-10, verbose=5)(delays)

In [None]:
nshuffles = len(shuffle_mu.index.get_level_values("idx").unique())

In [None]:
decay_df = pd.DataFrame([_ for _ in ret_list if _ is not None])

In [None]:
dfs = []
delays = []
for _neural_feature in chk_features:
    for _feature in beh_features:
        _fitter = partial(
            fitter,
            beh_feature=_feature,
            neural_feature=_neural_feature,
            pin_zero=False,
            resample_input=False,
        )
        for _resample in range(nshuffles):
            delays.append(
                delayed(_fitter)(shuffle_mu.xs(_resample, level="idx"), _resample)
            )
print(len(delays))
ret_list = Parallel(n_jobs=-10, verbose=5)(delays)

In [None]:
shuffle_decay_df = pd.DataFrame([_ for _ in ret_list if _ is not None])

In [None]:
# Optional quality filter on the bootstrap fits: positive r2, positive
# amplitude, and tau within [0, 1000]. Disabled, so every fit is plotted;
# set to True to restrict the tau boxplots to well-behaved fits.
filter_decay_fits = False

if filter_decay_fits:
    use_decay_df = decay_df[
        (decay_df["r2"] > 0) & (decay_df["tau"].between(0, 1000)) & (decay_df["a"] > 0)
    ]
else:
    use_decay_df = decay_df

In [None]:
setup_plotting_env()

In [None]:
neural_feature = chk_features[0]

In [None]:
plt_features = [
    "count",
    "velocity_2d_mm_global_bin",
    f"{neural_feature}_global_bin",
]

In [None]:
show_decay = plt_features

In [None]:
duration = np.around(
    (rle_df.groupby("uuid")["timestamp"].shift(-1) - rle_df["timestamp"]).median(), 2
)

In [None]:
# Short display labels used in legends and tick labels.
aliases = dict(
    count="counts",
    pseudocount="counts",
    usage="usages",
    duration_bin="duration",
    duration_global_bin="dur. (global)",
    velocity_2d_mm_global_bin="vel. (global)",
    velocity_height_bin="z vel.",
    velocity_angle_bin="ang. vel.",
    acceleration_2d_mm_bin="acc.",
    total_duration="time in syll.",
)
# Every neural feature's binned scalars share the same generic dlight labels.
aliases.update(
    {
        _key: _label
        for _neural in chk_features
        for _key, _label in (
            (f"{_neural}_global_bin", "dlight (global)"),
            (f"{_neural}_specific_bin", "dlight (specific)"),
        )
    }
)

In [None]:
use_offline = lagged_cfg["use_offline"]
renormalize = lagged_cfg["use_renormalized"]
within_bin = lagged_cfg["estimate_within_bin"]

In [None]:
smooth_kwargs = {"window": 1, "min_periods": 1, "center": True}

In [None]:
alpha_thresh = 0.05
continuity_thresh = 0

In [None]:
palette = sns.color_palette()

In [None]:
first_bin = obs_mu.index.get_level_values("bin").min()

In [None]:
syllable_stats = toml.load(
    os.path.join(proc_dirs["dlight"], "syllable_stats_photometry_offline.toml")
)
duration = np.around(float(syllable_stats["duration"]["median"]), 1)

In [None]:
def syllables_to_time(x):
    """Convert a lag in syllables to time using the median syllable `duration`."""
    return x * duration


def time_to_syllables(x):
    """Inverse of `syllables_to_time`."""
    return x / duration


# NOTE(review): `plt_features` / `show_decay` were built once from
# chk_features[0]. For every other neural feature, the third trace is still
# the first feature's "_global_bin" scalar. Confirm this is intended.
for neural_feature in chk_features:
    fig, ax = plt.subplots(
        1,
        2,
        figsize=(2.6, 1.7),  # for paper
        # figsize=(5.2, 3.4), # for sharing
        sharex=False,
        sharey=False,
        gridspec_kw={"width_ratios": [2.5, 1]},
    )

    # Left panel: z-scored lagged correlation vs. lag bin, one trace per feature.
    for _beh, _color in zip(plt_features, palette):
        ax[0].plot(
            obs_mu.loc[_beh, neural_feature].rolling(**smooth_kwargs).mean(),
            label=aliases.get(_beh, _beh),
            alpha=1 if _beh in show_decay else 0.15,
            clip_on=False,
            color=_color,
        )

    # Significance bars from the Holm-corrected shuffle p-values.
    plot_pval(
        pval_df[neural_feature],
        plt_features,
        "feature",
        ax=ax[0],
        colors=palette,
        offset=0.9,
        height=0.05,
        spacing=1.3,
        min_width=20,
        alpha_threshold=alpha_thresh,
        continuity_threshold=continuity_thresh,
    )

    # Grey chance band (95th percentile of shuffled |z|), mirrored around zero.
    chance_band = chance[neural_feature].rolling(**smooth_kwargs).mean().values
    ax[0].fill_between(
        chance.index,
        -chance_band,
        +chance_band,
        zorder=-200,
        color=[0.8, 0.8, 0.8],
    )
    ax[0].set_xlim(10, 400)
    ax[0].set_ylim(-7, 7)
    sns.despine(offset=4)
    clean_ticks(ax[0], "y", precision=0, dtype=int)

    # Right panel: fitted decay constants (tau) across bootstrap resamples.
    sns.boxplot(
        data=use_decay_df.loc[use_decay_df["neural_feature"] == neural_feature],
        x="beh_feature",
        order=show_decay,
        hue="beh_feature",
        hue_order=plt_features,
        dodge=False,
        showfliers=False,
        showcaps=False,
        palette=palette,
        y="tau",
        ax=ax[1],
    )

    ax[0].set_ylabel("Corr. (Pearson r, z)")
    ax[1].set_ylabel("Tau")
    ax[1].set_xlabel("")
    ax[1].set_ylim(0, 150)

    # Replace raw feature names in the tau legend with their short aliases.
    tau_legend = ax[1].legend(
        bbox_to_anchor=(0.8, 0.8), framealpha=1, loc="center left"
    )
    for _text in tau_legend.get_texts():
        if _text.get_text() in aliases:
            _text.set_text(aliases[_text.get_text()])

    plt.setp(ax[1].get_xticklabels(), rotation=90, ha="center")
    clean_plot_labels(label_map=aliases)

    sns.despine(ax=ax[0])
    # Secondary x-axis under the lag axis, expressing lags in time units.
    secax_x = ax[0].secondary_xaxis(
        -0.3, functions=(syllables_to_time, time_to_syllables)
    )
    ax[0].set_xticks([first_bin, 200, 400])
    secax_x.set_ticks([syllables_to_time(first_bin), 80, 160])

    fig.suptitle(neural_feature, fontsize=7)
    fig.tight_layout()
    plt.show()
    # Release the figure so a long feature loop does not accumulate open figures.
    plt.close(fig)

In [None]:
import json

pval_stats_file = os.path.join(
    proc_dirs["dlight"], f"stats_lagged_usage_and_scalars_{use_area}.toml"
)
with open(pval_stats_file, "w") as stats_fh:
    # Round-trip through JSON to turn the p-value frame into plain nested dicts,
    # then tag it with how the statistics were computed.
    plt_json = {
        **json.loads(pval_df.to_json()),
        "stat_type": "Pearson correlation",
        "p_type": "Comparison to shuffle",
    }
    toml.dump(plt_json, stats_fh)