# Compute the timecourse of cumulative syllable counts

In [None]:
# IPython setup: enable the autoreload extension in mode 2 (reload imported
# modules before executing code, so edits to helpers such as rl_analysis are
# picked up) and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from collections import defaultdict

import pandas as pd

import os
import numpy as np

In [3]:
import toml

# Load the shared analysis configuration (paths and parameters used below).
# TOML documents are UTF-8 by specification; pass the encoding explicitly so
# parsing does not depend on the platform's locale default encoding.
with open("../analysis_configuration.toml", "r", encoding="utf-8") as f:
    analysis_config = toml.load(f)

In [4]:
# Pull the configuration sections this notebook relies on, in a fixed order.
_config_sections = ("raw_data", "intermediate_results", "closed_loop_behavior", "figures")
raw_dirs, proc_dirs, closed_loop_cfg, figure_cfg = (
    analysis_config[_section_name] for _section_name in _config_sections
)

In [5]:
# Path of the closed-loop behavior statistics TOML under the intermediate-results directory.
_stats_dir = proc_dirs["closed_loop_behavior"]
test_statistics_file = os.path.join(_stats_dir, "stats_closed_loop_behavior.toml")

In [6]:
# Bin size of the precomputed learning timecourse; it selects which parquet file is loaded.
_timecourse_cfg = closed_loop_cfg["learning_timecourse"]
bin_size = _timecourse_cfg["bin_size"]

In [7]:
# The binned timecourse file name embeds the bin size. Despite the variable
# name, this is the input file that is read with pd.read_parquet.
_timecourse_name = f"learning_timecourse_binsize-{bin_size}.parquet"
save_file = os.path.join(raw_dirs["closed_loop_behavior"], _timecourse_name)

In [8]:
# Nested accumulator for test statistics: each missing key yields a fresh empty dict.
# NOTE(review): presumably destined for test_statistics_file — confirm.
test_stats = defaultdict(dict)

# Load in data

In [9]:
from rl_analysis.util import hampel_filter

In [10]:
# Keyword arguments forwarded to hampel_filter when use_filtered_counts is True.
hampel_kwargs = dict(threshold=6, window=7, min_periods=1, center=True)

In [11]:
# Load the binned learning timecourse selected by bin_size.
use_df = pd.read_parquet(path=save_file)

In [12]:
# Keep only rows whose syllable is the one that was targeted.
_is_target_row = use_df["syllable"] == use_df["target_syllable"]
use_df = use_df[_is_target_row].copy()

In [13]:
# Order rows by session and bin within each syllable group; the cumulative
# sums computed later accumulate in this row order.
_row_order = ["syllable_group_unique", "session_number", "syllable", "bin_start"]
use_df = use_df.sort_values(by=_row_order)

In [14]:
# Log2-transform the fold changes (np.log2 maps zeros to -inf).
for _measure in ("usage", "count"):
    use_df[f"log2_fold_change_{_measure}"] = np.log2(use_df[f"fold_change_{_measure}"])

In [15]:
# Rows with genotype "snc-dls-eyfp" get their opsin label set to "eyfp".
_eyfp_rows = use_df["genotype"] == "snc-dls-eyfp"
use_df.loc[_eyfp_rows, "opsin"] = "eyfp"

In [16]:
# Per-bin measures that are optionally Hampel-filtered and then accumulated:
# every prefix combined with both "count" and "usage".
filter_columns = [
    f"{_prefix}{_measure}"
    for _prefix in ("change_", "log2_fold_change_", "", "baseline_")
    for _measure in ("count", "usage")
]
# Set to True to run hampel_filter on each column before accumulating.
use_filtered_counts = False

In [17]:
# Grouping columns for the per-series Hampel filter and cumulative sums.
group_keys = "session_type syllable_group_unique syllable rle".split()

In [18]:
# `formatter` maps a base column name to the column that gets accumulated next.
if not use_filtered_counts:
    # Accumulate the raw per-bin columns.
    formatter = "{}"
else:
    # Hampel-filter each column within its group_keys series (NaN keys kept)
    # and store the result alongside the original as "<col>_filtered".
    for _col in tqdm(filter_columns):
        _series_groups = use_df.groupby(group_keys, dropna=False)[_col]
        use_df[f"{_col}_filtered"] = _series_groups.transform(
            lambda _values: hampel_filter(_values, **hampel_kwargs)
        )
    formatter = "{}_filtered"

In [19]:
def _grouped_cumsum(frame, keys, column):
    """Return the running sum of ``frame[column]`` within each group of ``keys``.

    ``dropna=False`` keeps rows with NaN key values as their own groups; the
    sum follows the frame's current row order.
    """
    return frame.groupby(keys, dropna=False)[column].transform(lambda s: s.cumsum())


# Accumulate each measure across sessions (within group_keys) and, separately,
# within each individual session.
for _col in tqdm(filter_columns):
    _source_col = formatter.format(_col)
    use_df[f"{_col}_cumulative"] = _grouped_cumsum(use_df, group_keys, _source_col)
    use_df[f"{_col}_cumulative_within_session"] = _grouped_cumsum(
        use_df, group_keys + ["session_number"], _source_col
    )

  0%|          | 0/8 [00:00<?, ?it/s]

In [20]:
# Pseudocounts added to numerator and denominator of the log2 fold change, per measure.
eps = dict(count=2, usage=1e-2)
# Measures whose cumulative values are expressed relative to baseline.
normalize_cols = ["count", "usage"]

In [21]:
# For each cumulative measure (across sessions and within session), add the
# difference from baseline and a pseudocounted log2 fold change over baseline.
for _col in tqdm(normalize_cols):
    group_key = f"{_col}_cumulative"
    session_key = f"{_col}_cumulative_within_session"
    _pseudocount = eps[_col]

    for _use_key in (group_key, session_key):
        _observed = use_df[_use_key]
        _reference = use_df[f"baseline_{_use_key}"]
        use_df[f"change_{_use_key}_v2"] = _observed - _reference
        # The pseudocount keeps the ratio finite when a cumulative value is
        # zero (assuming non-negative counts/usages).
        use_df[f"log2_fold_change_{_use_key}_v2"] = np.log2(
            (_observed + _pseudocount) / (_reference + _pseudocount)
        )

  0%|          | 0/2 [00:00<?, ?it/s]

In [22]:
# Keep rows with stim_duration == 0.25 from the listed opsin groups.
_kept_opsins = ["chr2", "ctrl", "halo", "chrimson", "eyfp"]
_keep_rows = (use_df["stim_duration"] == 0.25) & use_df["opsin"].isin(_kept_opsins)
cum_df = use_df[_keep_rows].copy()

In [23]:
# Step 1: keep (mouse_id, cohort, target_syllable) combinations with data in
# both session 1 and session 2. After de-duplicating on session_number within
# sessions {1, 2}, each combination has at most 2 rows, so ">= 2" means both.
session_count = (
    cum_df.loc[cum_df["session_number"].isin([1, 2])]
    .drop_duplicates(["mouse_id", "cohort", "target_syllable", "session_number"])
    .groupby(["mouse_id", "cohort", "target_syllable"])
    .size()
)
include_tups = session_count[session_count >= 2].index
# Select rows matching the surviving MultiIndex tuples.
# NOTE(review): groupby's default dropna=True means combinations with a NaN key
# never reach include_tups and are dropped; the output also appears to be
# ordered by include_tups rather than by the earlier sort — confirm both are
# intended.
cum_df = cum_df.set_index(include_tups.names).loc[include_tups].reset_index()

# Step 2: keep only (mouse_id, cohort) pairs with at least two distinct target
# syllables remaining after step 1.
mouse_count = (
    cum_df.drop_duplicates(["mouse_id", "cohort", "target_syllable"])
    .groupby(["mouse_id", "cohort"])
    .size()
)
include_mice = mouse_count[mouse_count >= 2].index.tolist()
cum_df = cum_df.set_index(["mouse_id", "cohort"]).loc[include_mice].reset_index()

In [24]:
# Renumber rows 0..n-1 before saving.
cum_df = cum_df.reset_index(drop=True)

In [25]:
# Persist the filtered timecourse with cumulative columns, alongside the input file.
_processed_file = os.path.join(
    raw_dirs["closed_loop_behavior"], "learning_timecourse_processed.parquet"
)
cum_df.to_parquet(_processed_file)

# Compute the learning summary and save it

In [26]:
# Summary = rows at the largest bin_start in cum_df (taking the cumulative
# columns at that bin).
# Series.max skips NaN and returns NaN on an empty frame. A NumPy reduction over
# .unique() would instead propagate a single NaN bin_start (matching no rows,
# i.e. a silently empty summary) and raise on an empty frame.
# NOTE(review): this is the global maximum across all mice/sessions; groups
# whose last bin is earlier are not represented in the summary — confirm.
max_bin = cum_df["bin_start"].max()
summary_df = cum_df.loc[cum_df["bin_start"] == max_bin].copy()

In [27]:
# Renumber summary rows 0..n-1 before saving.
summary_df = summary_df.reset_index(drop=True)

In [28]:
# Persist the per-bin summary next to the processed timecourse.
_summary_file = os.path.join(
    raw_dirs["closed_loop_behavior"],
    "learning_timecourse_processed_summary.parquet",
)
summary_df.to_parquet(_summary_file)