# Here we're computing features for the decoding and reinforcement learning models

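1. Run this notebook twice: once with `use_offline` enabled and once with it disabled in the `dlight_lagged_correlations` section of `analysis_configuration.toml`, so that `file_suffix` is `offline` for one run and `online` for the other.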

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from rl_analysis.util import pd_zscore
import pandas as pd
import numpy as np
import os

In [3]:
import toml

# Parse the shared analysis configuration (TOML) into a dict; the path is
# relative to the directory the notebook is run from.
with open("../analysis_configuration.toml", "r") as f:
    analysis_config = toml.load(f)

In [4]:
# Pull out the configuration sections by name.
# `raw_dirs` / `proc_dirs` are indexed below to build input and output paths;
# `lagged_cfg` supplies the `use_offline` and `use_renormalized` flags.
# NOTE(review): `encoding_cfg` and `figure_cfg` are not referenced in the visible cells.
raw_dirs = analysis_config["raw_data"]
proc_dirs = analysis_config["intermediate_results"]
lagged_cfg = analysis_config["dlight_lagged_correlations"]
encoding_cfg = analysis_config["dlight_encoding_features"]
figure_cfg = analysis_config["figures"]

In [5]:
# The suffix is driven by the `use_offline` flag in the configuration and is
# embedded in every input/output file name below.
file_suffix = "offline" if lagged_cfg["use_offline"] else "online"
# Manual override (uncomment to ignore the configuration flag):
# file_suffix = "online"
# Snippets parquet for the selected suffix, located in the raw dlight directory.
load_file = os.path.join(raw_dirs["dlight"], f"dlight_snippets_{file_suffix}.parquet")

In [6]:
# possibly get more compressive and noise robust...
def minmax_da_scale(x, q1=0, q2=1):
    """Rescale `x` so its `q1` quantile maps to 0 and its `q2` quantile maps to 1.

    With the defaults (q1=0, q2=1) this is ordinary min-max scaling.

    Parameters
    ----------
    x : pandas.Series
        Values to rescale.
    q1, q2 : float
        Quantile levels passed to `x.quantile`.

    Returns
    -------
    pandas.Series
        `(x - Q(q1)) / (Q(q2) - Q(q1))`, where `Q` is `x.quantile`.
    """
    floor = x.quantile(q1)
    span = x.quantile(q2) - floor
    return (x - floor) / span


def squash(x, q1=0.2, q2=0.8):
    """Apply a quantile-based dead zone to `x`.

    Values inside the band between the `q1` and `q2` quantiles of `x` become 0.
    Values above the upper quantile are shifted down by that quantile's value,
    and values below the lower quantile are shifted by the lower quantile's
    value, so both tails start from 0 at the band edges.

    Parameters
    ----------
    x : pandas.Series
        Values to transform. The input is left unmodified.
    q1, q2 : float
        Lower and upper quantile *levels* (fractions) delimiting the band.

    Returns
    -------
    pandas.Series
        A new series with the same index as `x`; NaNs stay NaN.
    """
    lower = x.quantile(q1)
    upper = x.quantile(q2)
    # The thresholds must be the quantile values, not the levels q1/q2:
    # comparing against the raw fractions (e.g. 0.2) would shift the freshly
    # zeroed centre as well. Subtracting the clipped series yields 0 inside
    # the band and the distance to the nearest band edge outside it, and it
    # builds a new Series instead of writing into the caller's data.
    return x - x.clip(lower=lower, upper=upper)

# Alias the imported z-scoring helper and give it a short `__name__`: the
# transform loop further down builds output column names from each function's
# `__name__`. Note that this renames the imported function object itself.
zscore = pd_zscore
zscore.__name__ = "zscore"

def threshold(x, threshold=1.96):
    """Map `x` to a three-level code: +1, -1 or 0.

    Parameters
    ----------
    x : pandas.Series
        Values to classify. The input is left unmodified.
    threshold : float
        Cut-off applied symmetrically around zero.

    Returns
    -------
    pandas.Series
        1 where `x > threshold`, -1 where `x < -threshold`, and `x * 0`
        elsewhere (so NaNs stay NaN).
    """
    signs = x * 0
    signs.loc[x > threshold] = 1
    signs.loc[x < -threshold] = -1
    return signs

In [7]:
# Subset of columns to load from the snippets parquet (passed as `columns=`
# to `pd.read_parquet`), so the full table is not read into memory.
read_columns = [
    "uuid",
    "signal_reference_corr",
    "signal_max",
    "reference_max",
    "velocity_2d_mm",
]

In [8]:
# Load only the selected columns, then average every loaded column within each `uuid`.
# NOTE(review): neither `raw_df` nor `meta_df` is referenced again in the visible cells.
raw_df = pd.read_parquet(load_file, columns=read_columns)
meta_df = raw_df.groupby("uuid").mean()

In [9]:
# Derive the features file name from the snippets file: "<stem>_features<ext>",
# initially in the same directory as `load_file`.
file, ext = os.path.splitext(load_file)
features_save_file = f"{file}_features{ext}"

# Re-root the stem (without extension) under the intermediate-results dlight directory.
dirname, filename = os.path.split(features_save_file)
file, ext = os.path.splitext(filename)
file = os.path.join(proc_dirs["dlight"], file)

# NOTE(review): only the renormalized variant uses the re-rooted stem. When
# `use_renormalized` is false, `features_save_file` still points next to
# `load_file` (the raw dlight directory) -- confirm this is intended.
if lagged_cfg["use_renormalized"]:
    features_save_file = f"{file}_renormalize{ext}"

# The usage file shares the features file name with "snippet" -> "usage".
# NOTE(review): `str.replace` acts on the whole path, so a directory name
# containing "snippet" would be rewritten too.
rle_save_file = features_save_file.replace("snippet", "usage")

In [10]:
# Window label to keep; compared as a string against the `window_tup` column
# of the features table when filtering below.
window_tup = "(0.0, 0.3)"

## Load in pre-processed data

### Load raw data and munge

In [11]:
# Load the features table and the companion usage table from the paths built above.
feature_df = pd.read_parquet(features_save_file)
rle_df = pd.read_parquet(rle_save_file)

In [12]:
# Keep only rows belonging to the selected window; `.copy()` detaches the
# result so later column assignments operate on an independent frame.
feature_df = feature_df[feature_df["window_tup"].eq(window_tup)].copy()

In [13]:
# Load the syllable statistics file for the selected suffix from the
# intermediate-results dlight directory.
syllable_stats = toml.load(
    os.path.join(proc_dirs["dlight"], f"syllable_stats_photometry_{file_suffix}.toml")
)
usage = syllable_stats["usages"]
# TOML table keys are strings, so both lookup tables are rebuilt with integer
# keys and values.
mapping = {int(k): int(v) for k, v in syllable_stats["syllable_to_sorted_idx"].items()}
reverse_mapping = {
    int(k): int(v) for k, v in syllable_stats["sorted_idx_to_syllable"].items()
}

# Original syllable ids present in the mapping, with negative ids dropped.
use_syllables = np.array(list(mapping.keys()))
use_syllables = use_syllables[use_syllables >= 0]

In [14]:
# Preserve the original ids before overwriting the columns with their mapped values.
feature_df["syllable_original_id"] = feature_df["syllable"]
feature_df["target_syllable_original_id"] = feature_df["target_syllable"]
# `Series.map` with a dict yields NaN for any id that is not a key of `mapping`.
feature_df["syllable"] = feature_df["syllable"].map(mapping)
rle_df["syllable"] = rle_df["syllable"].map(mapping)
feature_df["target_syllable"] = feature_df["target_syllable"].map(mapping)

In [15]:
# Cast to plain strings; `feature_df["uuid"]` receives the same cast before the
# two tables are merged by `uuid`.
rle_df["uuid"] = rle_df["uuid"].astype("str")

In [16]:
# `pd.merge_asof` requires both frames to be sorted by the `on` key.
feature_df = feature_df.sort_values(["timestamp", "uuid"])
rle_df = rle_df.sort_values(["timestamp", "uuid"])

# Number the rows of each `uuid` 0..n-1 *after* sorting, so the counter
# follows timestamp order instead of whatever row order the parquet file
# happened to have. `cumcount` does this without a per-group Python callback.
rle_df["syllable_num"] = rle_df.groupby("uuid").cumcount()

# presumably matches the dtype of `feature_df["timestamp"]`, which merge_asof
# needs to agree on both sides -- TODO confirm
rle_df["timestamp"] = rle_df["timestamp"].astype("float32")

feature_df["uuid"] = feature_df["uuid"].astype("str")
# Backward as-of join (the default direction): each feature row takes the
# `syllable_num` of the latest usage row of the same `uuid` whose timestamp
# is <= the feature row's timestamp.
feature_df = pd.merge_asof(
    feature_df, rle_df[["timestamp", "uuid", "syllable_num"]], on="timestamp", by="uuid"
)

# Restore a per-uuid, time-ordered layout for the downstream cells.
feature_df = feature_df.sort_values(["uuid", "timestamp"])
rle_df = rle_df.sort_values(["uuid", "timestamp"])

In [17]:
# Column names used below. `feature_key` and `reference_key` seed the
# `zscore_keys` / `squash_keys` lookups.
syllable_key = "syllable"
feature_key = "signal_reref_dff_z_max"
reference_key = "reference_dff_z_max"

In [18]:
# Lookup tables from a source column name to the name of its derived column:
# "<column>_z" for the z-scored version, "<column>_squash" for the squashed one.
zscore_keys = {column: f"{column}_z" for column in (feature_key, reference_key)}
squash_keys = {column: f"{column}_squash" for column in (feature_key, reference_key)}

In [19]:
# Source column -> list of transforms to apply to it. Each transform adds a
# column named "<source>_<function.__name__>" in the loop below.
transform_keys = {
    "signal_reref_dff_z_max": [zscore, squash, threshold],
}

In [20]:
# `use_data` is an alias, not a copy: the columns added below also appear on `feature_df`.
use_data = feature_df  # including all data 2/16/2022 to fold in stim data
# Collect the names of the source columns plus every derived column.
all_keys = list(set(transform_keys.keys()))
for k, funcs in tqdm(transform_keys.items()):
    for v in funcs:
        # The derived column is named after the source column and the transform function.
        new_key = f"{k}_{v.__name__}"
        # Apply the transform separately within each `uuid`.
        use_data[new_key] = use_data.groupby("uuid")[k].transform(v)
        all_keys.append(new_key)

  0%|          | 0/1 [00:00<?, ?it/s]

In [21]:
# Floor each timestamp in `date` to the start of its day.
use_data["date_rnd"] = use_data["date"].dt.floor("d")

# Order rows by mouse, day, `uuid`, then timestamp.
use_data = use_data.sort_values(["mouse_id", "date_rnd", "uuid", "timestamp"])
# Number of distinct `uuid` values per mouse per day.
# NOTE(review): `cnts` is not referenced again in the visible cells.
cnts = use_data.groupby(["mouse_id", "date_rnd"], observed=True)["uuid"].nunique()

In [22]:
# Write the augmented table, tagged with the current suffix, to the directory
# configured under raw_data -> rl_modeling.
use_data.to_parquet(
    os.path.join(
        raw_dirs["rl_modeling"], f"rl_modeling_dlight_data_{file_suffix}.parquet"
    )
)

In [23]:
# Record which variant (offline/online) this run produced.
print(file_suffix)

online
