# Find the best parameters for the RL model and fit

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from tqdm.auto import tqdm
from rl_analysis.util import count_transitions
from rl_analysis.io.df import dlight_exclude
from rl_analysis.models_rl.models import simulate
from rl_analysis.models_rl.loss import rl_loss
from rl_analysis.models_rl.util import (
    shuffle_rows_copy,
    zscore_missing_data,
    compare_tms,
)
from copy import deepcopy
from sklearn import model_selection

import joblib
import pandas as pd
import numpy as np
import os

In [3]:
# use the numpyro helpers to set the device count
import numpyro
import multiprocessing as mp

numpyro.set_host_device_count(mp.cpu_count() * 10)
numpyro.set_platform("cpu")

import jax
from jax.experimental.maps import Mesh, xmap

In [4]:
import toml

# Parse the shared analysis configuration file that sits one directory up.
with open("../analysis_configuration.toml", "r") as config_file:
    analysis_config = toml.load(config_file)

In [5]:
# Pull individual sections out of the parsed configuration.
raw_dirs = analysis_config["raw_data"]  # joined with a file name when loading the input table
dlight_cfg = analysis_config["dlight_common"]  # forwarded as keyword arguments to dlight_exclude
proc_dirs = analysis_config["intermediate_results"]  # where fitted parameters and results are written
# NOTE(review): the three sections below are not referenced again in the visible part of this notebook.
lagged_cfg = analysis_config["dlight_lagged_correlations"]
encoding_cfg = analysis_config["dlight_encoding_features"]
figure_cfg = analysis_config["figures"]

## Pre-process data

### Split into train/test by session

In [6]:
# Load the offline dLight table and pass it through `dlight_exclude`.
use_data = dlight_exclude(
    pd.read_parquet(
        os.path.join(raw_dirs["rl_modeling"], "rl_modeling_dlight_data_offline.parquet")
    ),
    exclude_3s=False,
    exclude_target=False,
    **dlight_cfg,
)
# Keep DLS rows from sessions other than 1-4, then drop rows without a feature value.
keep_rows = ~use_data["session_number"].isin([1, 2, 3, 4]) & (use_data["area"] == "dls")
use_data = use_data.loc[keep_rows].dropna(subset=["signal_reref_dff_z_max"])

In [7]:
# Column names used when gathering per-session arrays.
syllable_key = "syllable"
feature_key = "signal_reref_dff_z_max"  # reref'd dlight, max per transition
# NOTE(review): reference_key is not used in the visible part of this notebook.
reference_key = "reference_dff_z_max"  # reference alone

In [8]:
# Every session contributes the same amount of data: sessions with fewer than
# `cutoff_uuid` syllables are dropped and the first `cutoff_date` syllables of
# each remaining session are used (2500 in this case).
cutoff_uuid = 2500
cutoff_date = cutoff_uuid
test_size = 0.5
truncate = 10

session_cols = ["mouse_id", "uuid"]
use_data = use_data.groupby(session_cols).filter(lambda x: len(x) >= cutoff_uuid)
group_obj = use_data.groupby(session_cols)


def _gather_sessions(column):
    """Stack the first `cutoff_date` values of `column` from every session, one row per session."""
    return np.array(
        [session_df[column].values[:cutoff_date] for _, session_df in group_obj]
    )


# gather data, split off training data for optimizing hyperparameters
seqs = _gather_sessions(syllable_key)
features = _gather_sessions(feature_key)
timestamps = _gather_sessions("timestamp")
keys = [session_key for session_key, _ in group_obj]

In [9]:
# Largest syllable label present in the data.
# NOTE(review): if syllable labels are 0-indexed this is one less than the number of
# labels — confirm. In the visible cells the value only sizes `base_q_table`.
nsyllables = use_data["syllable"].max()

In [10]:
# Default Q-table template: ones everywhere, zeros on the diagonal
# (state == action, i.e. a syllable followed by itself).
base_q_table = np.ones((nsyllables, nsyllables))
np.fill_diagonal(base_q_table, 0)

In [11]:
# One transition-count matrix per session, with the diagonal blanked out (NaN),
# then cropped to the first `truncate` rows and columns.
ref_tms = []
for _seq in seqs:
    _tm = count_transitions(_seq, K=57).astype("float")
    np.fill_diagonal(_tm, np.nan)
    ref_tms.append(_tm)
ref_tms = np.array(ref_tms)[:, :truncate, :truncate]

In [12]:
# Per-session transition-count matrices with the diagonal set to NaN, cropped to
# the first `truncate` rows and columns. These feed the per-fold reference TMs.
ref_tms_clip = []
for _seq in seqs:
    _tm = count_transitions(_seq[:], K=57).astype("float")
    np.fill_diagonal(_tm, np.nan)
    ref_tms_clip.append(_tm)
# Fix: crop the matrices computed in this cell. The original cropped `ref_tms`
# here, which silently discarded the list built above.
ref_tms_clip = np.array(ref_tms_clip)[:, :truncate, :truncate]

In [13]:
# Two-fold split over sessions (rows of `seqs`), in order and without shuffling.
splits = model_selection.KFold(n_splits=2, shuffle=False)
# Materialize the (train_idx, test_idx) pairs so a fold can be looked up by number.
splits = list(splits.split(range(len(seqs))))
n_splits = len(splits)

In [14]:
def _mean_zscored_tm(session_idx):
    """Average the transition matrices of the given sessions after z-scoring along the last axis.

    NaNs left by the z-scoring step are replaced with 0 before averaging.
    """
    z_tms = zscore_missing_data(ref_tms_clip[session_idx], axis=-1)
    # Pass the fill value by keyword: the second positional parameter of
    # np.nan_to_num is `copy`, so `nan_to_num(x, 0)` was really setting copy=False.
    return np.nan_to_num(z_tms, nan=0.0).mean(axis=0)


# Average reference TM per fold, keyed by fold number.
train_ref_tms = {}
test_ref_tms = {}
for i, _split in enumerate(splits):
    train_ref_tms[i] = _mean_zscored_tm(_split[0])
    test_ref_tms[i] = _mean_zscored_tm(_split[1])

In [15]:
# Transitions whose session-averaged count is <= this threshold are masked out of the Q-table.
# Note: results were acceptable with a threshold of 1; currently testing with 0.
mask_threshold = 0

In [16]:
# Build a (truncate x truncate) mask: 1 where a Q-table entry is kept, NaN where it is excluded.
training_nan_mask = np.ones((truncate, truncate), dtype="float")
# Exclude transitions whose session-averaged count is at or below the threshold.
training_nan_mask[ref_tms.mean(axis=0) <= mask_threshold] = np.nan
# The diagonal is always excluded.
np.fill_diagonal(training_nan_mask, np.nan)

In [17]:
# z-score the DA feature within each session (one row per session):
# subtract the row mean, then divide by the row standard deviation.
centered_features = features - features.mean(axis=1, keepdims=True)
z_features = centered_features / centered_features.std(axis=1, keepdims=True)

At this point our data per session is organized by:

1. train/test transition matrices `train_ref_tms` and `test_ref_tms`.
   1. These variables are dicts containing the average TM per fold
2. DA per syllable is stored in `z_features`, which is z-scored per session.

# Parameter grid-search

In [18]:
# Hyperparameter scan: every key maps to the list of values tried for that parameter.
parameter_dct = dict(
    gamma=np.around(np.arange(0.5, 0.95, 0.1), 2),
    alpha=np.around(np.arange(0.1, 0.9, 0.1), 2),
    temperature_baseline=[0.1, 0.2, 0.3, 0.4],
    temperature_alpha=[0.1],  # tau ~ 10 syllables
    temperature_peak=[0.25, 0.5, 0.75, 1],
    temperature_threshold1=[1.0, 2.0],
    temperature_threshold2=[-np.inf],
)

# number of repeats simulated for every parameter set and fold
repeats = 10

In [19]:
# Parameter names in declaration order; later combined with "fold" as groupby keys.
param_names = list(parameter_dct.keys())

In [20]:
# Named-axis spec for the five positional arguments handed to the xmapped `rl_loss`
# (in call order: parameters, q_table, input data, shifted input data, PRNG keys).
in_axes = [
    [...],  # parameters: no "dataset" axis
    ["dataset", ...],  # leading axis is the "dataset" axis
    ["dataset", ...],
    ["dataset", ...],
    ["dataset", ...],
]

In [21]:
# parallelize our simulations: the "dataset" axis is assigned to the mesh axis "x",
# which the calling cell binds to JAX devices through `Mesh(..., ("x",))`.
loss_xmap = xmap(
    rl_loss, in_axes=in_axes, out_axes=["dataset", ...], axis_resources={"dataset": "x"}
)

In [22]:
# Start from lag 0. Each simulation is run twice: once on the DA trace as-is and
# once on a control trace shifted by 5 syllables along the time axis.
# NOTE(review): the original comment described the objective as a log-likelihood
# difference; the parameter selection further down uses `diff_r_tm` — confirm.
def lag_func(x):
    """Roll `x` by 0 positions along axis 1 (identity lag)."""
    return np.roll(x, 0, axis=1)


def shifted_lag_func(x):
    """Roll `x` by -5 positions along axis 1 (circular shift used as control)."""
    return np.roll(x, -5, axis=1)

In [23]:
# Full Cartesian product of the values listed in `parameter_dct`.
param_grid = model_selection.ParameterGrid(parameter_dct)

The input data for the simulations:

TIMESTEP X (CURRENT_SYLLABLE, NEXT_SYLLABLE, DA_PEAK_VALUE)

Each timestep is a syllable (i.e. the data is run-length-encoded)

Overall, we're testing if DA, acting as reward in a q-learning model, produces a q-table similar to the empirical TM

1. The state is the current syllable, and the action is the next syllable, thus $s_t, a_t$ is simply $syll_t, syll_{t+1}$
1. Policy is simulated as a softmax, though this doesn't matter since we're only looking at the q-table
1. We assume that $r_t$ is $\text{DA}_t$, or the peak dLight during $syll_t$
1. Temperature increases if $\text{DA}_t\gt\beta_1$ and decreases if $\text{DA}_t\lt\beta_2$. The temperature then relaxes back to its baseline exponentially, viz. $\tau_{t+1} = \alpha \tau_{baseline} + (1 - \alpha) \tau_{t}$
1. At the end of the simulation we measure the Pearson correlation between the model's final q-table and the mouse's empirical TM
1. To ensure the model isn't cheating, compare to trial-shuffled data or, even better, to data where DA is shifted in time by a fixed amount per trial (here a trial is a 30-minute session)

In [24]:
# Simulation input, stacked on the last axis: CURRENT SYLLABLE, NEXT SYLLABLE, DA FEATURE.
current_sylls = seqs[:, :-1]
next_sylls = seqs[:, 1:]
da_features = z_features[:, :-1]

input_data = np.stack([current_sylls, next_sylls, lag_func(da_features)], axis=2)
# Same syllable pairs, with the DA feature passed through the control shift.
shifted_input_data = np.stack(
    [current_sylls, next_sylls, shifted_lag_func(da_features)], axis=2
)

In [25]:
# Expand the grid so that every parameter set appears once per (repeat, fold) pair.
param_grid = list(param_grid)
param_names = list(parameter_dct.keys())

new_grid = [
    {**deepcopy(grid_point), "repeat": repeat_idx, "fold": fold_idx}
    for grid_point in param_grid
    for repeat_idx in range(repeats)
    for fold_idx in range(n_splits)
]

In [26]:
def init_q_table(base=base_q_table, truncate=truncate, nmats=1):
    """Build the starting Q-tables for a batch of simulations.

    Returns a float array holding `nmats` stacked copies of the top-left
    `truncate` x `truncate` corner of `base`.
    """
    corner = base[:truncate, :truncate]
    return np.array([corner for _ in range(nmats)], dtype="float")

In [27]:
# simulate each parameter set and store results
# Each entry of `new_grid` is one (parameter set, repeat, fold) combination; the
# simulation runs on that fold's training sessions only.
dcts = []
for dct in tqdm(new_grid):
    # Strip the bookkeeping keys so `use_dct` holds model parameters only.
    use_dct = deepcopy(dct)
    repeat_num = use_dct.pop("repeat")
    use_split = use_dct.pop("fold")
    train_idx, test_idx = splits[use_split]

    # Largest device count that evenly divides the number of training sessions.
    for use_devices in range(len(jax.devices()), 0, -1):
        if len(train_idx) % use_devices == 0:
            break

    with Mesh(np.array(jax.devices()[:use_devices]), ("x",)):
        # One PRNG key per session; every session in a repeat gets the same key,
        # seeded by the repeat number.
        prng_keys = np.array(
            [jax.random.PRNGKey(repeat_num) for _ in range(len(train_idx))]
        )

        # Start every session from the NaN mask (1 = kept entry, NaN = excluded).
        q_table = init_q_table(base=training_nan_mask, nmats=len(train_idx))
        results, shifted_results = loss_xmap(
            use_dct,
            q_table,
            input_data[train_idx],
            shifted_input_data[train_idx],
            prng_keys,
        )

        # multi-lag?
        loss, (lls, ps, actions), q_table = results
        (
            shifted_loss,
            (shifted_lls, shifted_ps, shifted_actions),
            shifted_q_table,
        ) = shifted_results

        # z-score each session's final Q-table along the last axis, zero the NaNs,
        # and average over sessions. Note: the second positional argument of
        # np.nan_to_num is `copy` (0 -> in place); NaNs become the default 0.0.
        mean_q_table = np.nan_to_num(
            zscore_missing_data(np.array(q_table), axis=-1), 0
        ).mean(axis=0)
        mean_shifted_q_table = np.nan_to_num(
            zscore_missing_data(np.array(shifted_q_table), axis=-1), 0
        ).mean(axis=0)

        # Re-apply the mask so excluded transitions stay NaN in the averages.
        mean_q_table[np.isnan(training_nan_mask)] = np.nan
        mean_shifted_q_table[np.isnan(training_nan_mask)] = np.nan

        _dct = {
            "q_table": mean_q_table,
            "shifted_q_table": mean_shifted_q_table,
            "repeat": repeat_num,
            "fold": use_split,
            "loss": loss.mean(),
            "shifted_loss": shifted_loss.mean(),
        }
        dcts.append(_dct)

  0%|          | 0/25600 [00:00<?, ?it/s]

In [28]:
# q_table

In [29]:
# Attach each simulation result to a copy of the grid entry that produced it
# (grid entries and results are in the same order).
new_dcts = deepcopy(new_grid)
for grid_entry, sim_result in zip(new_dcts, dcts):
    grid_entry.update(sim_result)

In [30]:
# One row per (parameter set, repeat, fold).
loss_df = pd.DataFrame(new_dcts)
# The stored loss values are array scalars; cast so the column is a plain float column.
loss_df["loss"] = loss_df["loss"].astype("float")
# NOTE(review): "shifted_loss" is left uncast, so it remains a non-numeric column.
# loss_df["shifted_loss"] = loss_df["shifted_loss"].astype("float")

In [31]:
# Attach the training-fold reference TM to every row.
loss_df["ref_tm"] = loss_df["fold"].map(train_ref_tms)

# Average over repeats: one row per (parameter set, fold).
group_keys = param_names + ["fold"]
loss_groups = loss_df.groupby(group_keys)

# Request numeric columns explicitly instead of relying on the deprecated default
# (this is what triggered the warning captured under this cell).
ave_loss_df = loss_groups.mean(numeric_only=True)
# The array-valued columns are excluded above, so average them one at a time.
ave_loss_df["q_table"] = loss_groups["q_table"].mean()
ave_loss_df["shifted_q_table"] = loss_groups["shifted_q_table"].mean()
ave_loss_df["ref_tm"] = loss_groups["ref_tm"].mean()

  ave_loss_df = loss_df.groupby(group_keys).mean()


In [32]:
def _tm_correlation(q_column):
    """Row-wise `compare_tms` between the reference TM and the Q-table stored in `q_column`."""
    return ave_loss_df.apply(
        lambda row: compare_tms(row["ref_tm"], row[q_column]), axis=1
    )


ave_loss_df["r_tm"] = _tm_correlation("q_table")
ave_loss_df["shifted_r_tm"] = _tm_correlation("shifted_q_table")

# Score = similarity with real DA minus similarity with shifted DA, where the
# shifted term is floored at 0 so a negative control cannot inflate the score.
ave_loss_df["diff_r_tm"] = ave_loss_df["r_tm"] - ave_loss_df["shifted_r_tm"].clip(lower=0)

ave_loss_df["raw_loss"] = ave_loss_df["loss"]
ave_loss_df["total_r_tm"] = ave_loss_df["r_tm"] + 2 * ave_loss_df["diff_r_tm"]

In [33]:
# sns.heatmap(ave_loss_df.xs(0,level="fold").groupby(["alpha","temperature_peak"])["r_tm"].mean().unstack())

# Collect optimal parameters

In [35]:
# Order in which parameters are fixed by the greedy selection in the next cell:
# the first entry is chosen first, each later one conditional on the earlier choices.
param_order = [
    "temperature_baseline",
    "temperature_peak",
    "gamma",
    "alpha",
    "temperature_threshold1",
    "temperature_threshold2",
    "temperature_alpha",
]

In [36]:
# Greedy, one-parameter-at-a-time selection per fold: walk `param_order`, and for
# each parameter keep the value with the highest mean `diff_r_tm` among the rows
# that match all previously selected values.
use_params = {}
for _fold in range(n_splits):
    use_loss_df = ave_loss_df.xs(_fold, level="fold")
    selected = {}
    for _param in param_order:
        candidates = use_loss_df
        if selected:
            candidates = use_loss_df.xs(
                tuple(selected.values()), level=tuple(selected.keys())
            )
        selected[_param] = candidates.groupby(_param)["diff_r_tm"].mean().idxmax()
    use_params[_fold] = selected

In [37]:
# Destination for the per-fold parameters selected above.
savename = os.path.join(proc_dirs["rl_modeling"], "rl_model_parameters.toml")

In [38]:
# Fold numbers are converted to strings for use as top-level TOML keys.
serializable_params = {str(fold): params for fold, params in use_params.items()}
with open(savename, "w") as param_file:
    toml.dump(serializable_params, param_file)

# Next scan over DA-->behavior lags

In [39]:
# Read the saved per-fold parameters back from disk.
with open(
    os.path.join(proc_dirs["rl_modeling"], "rl_model_parameters.toml"), "r"
) as param_file:
    fit_params = toml.load(param_file)

In [40]:
def _coerce_numeric(value):
    """Return `value` as an int or float when it is a numeric string; otherwise unchanged.

    Replaces `pd.to_numeric(value, errors="ignore")`, whose `errors="ignore"`
    option is deprecated in recent pandas releases.
    """
    if not isinstance(value, str):
        return value
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value


# Fold keys come back from TOML as strings ("0", "1"); convert them, and any
# string-valued parameters, back to numbers.
new_fit_params = {}
for _fold, _params in fit_params.items():
    _params = {k: _coerce_numeric(v) for k, v in _params.items()}
    new_fit_params[_coerce_numeric(_fold)] = _params

In [41]:
# From here on, use the parameters as reloaded from disk (keyed by fold number).
use_params = new_fit_params

In [42]:
# Named-axis spec for the four positional arguments handed to the xmapped `simulate`
# (in call order: parameters, q_table, input data, PRNG keys).
# NOTE: this rebinds `in_axes`, which earlier held the five-argument spec for `rl_loss`.
in_axes = [
    [...],
    ["dataset", ...],
    ["dataset", ...],
    ["dataset", ...],
]

In [43]:
# xmapped `simulate`, parallel over the "dataset" axis (mesh axis "x").
# NOTE: this rebinds `loss_xmap`, which earlier wrapped `rl_loss`; re-running the
# grid-search cell after this one would call `simulate` instead.
loss_xmap = xmap(
    simulate,
    in_axes=in_axes,
    out_axes=["dataset", ...],
    axis_resources={"dataset": "x"},
)

In [44]:
# Simulation functions to run, keyed by a label; only one entry is defined here.
sim_funcs = {
    "reward": loss_xmap,
}

In [45]:
# CURRENT SETTINGS
lags = np.arange(-10, 11)  # lags from -10 to +10 syllables, inclusive
# NOTE: this rebinds `repeats` (10 during the grid search); re-running the
# grid-building cell after this one would use 50 repeats.
repeats = 50
batches = 50

In [46]:
import itertools

In [47]:
# Output file for the lag scan; the array-valued columns go to a companion pickle file.
save_file_basename = "rl_model_heldout_results"
save_file = os.path.join(proc_dirs["rl_modeling"], f"{save_file_basename}_lags.parquet")

In [48]:
# Set to True to recompute results even when cached result files already exist.
force = False

In [49]:
# Lag scan on held-out data: for each fold, simulate the test sessions with the
# parameters selected on that fold, rolling the DA feature by every lag in `lags`,
# for a "dynamic" model and a "static" one (temperature_peak forced to 0).
dcts = []
inner_loop = list(
    itertools.product(["dynamic", "static"], lags, range(repeats), range(batches))
)
# Companion pickle holding the array-valued columns. Only the extension is
# swapped; the original `str.replace("parquet", "p")` would also rewrite any
# "parquet" appearing in a directory name.
tm_file = os.path.splitext(save_file)[0] + ".p"
use_seed = 0
# Recompute when forced or when either cached file is missing.
if force or not (os.path.exists(save_file) and os.path.exists(tm_file)):
    for _fold, (train_idx, test_idx) in enumerate(splits):
        # Largest device count that evenly divides the number of test sessions.
        for use_devices in range(len(jax.devices()), 0, -1):
            if len(test_idx) % use_devices == 0:
                break
        for k, v in sim_funcs.items():
            for _model, _lag, _repeat, _batch in tqdm(inner_loop):
                fold_params = deepcopy(use_params[_fold])

                if _model == "static":
                    fold_params["temperature_peak"] = 0

                with Mesh(np.array(jax.devices()[:use_devices]), ("x",)):

                    q_table = init_q_table(base=training_nan_mask, nmats=len(test_idx))
                    # Own name on purpose: the original rebound the global
                    # `input_data`, which the grid-search cell reads at lag 0.
                    lagged_input_data = np.stack(
                        [
                            seqs[:, :-1],
                            seqs[:, 1:],
                            np.roll((z_features[:, :-1]), _lag, axis=1),
                        ],
                        axis=2,
                    )
                    # Fresh seed for every simulation, shared by all sessions in it.
                    prng_keys = np.array(
                        [jax.random.PRNGKey(use_seed) for _ in range(len(test_idx))]
                    )
                    use_seed += 1
                    loss, lls, q_table = v(
                        fold_params, q_table, lagged_input_data[test_idx], prng_keys
                    )

                    # z-score along the last axis, zero the NaNs, average over
                    # sessions, then restore the mask. The fill value is passed by
                    # keyword because the second positional argument is `copy`.
                    mean_q_table = np.nan_to_num(
                        zscore_missing_data(np.array(q_table), axis=-1), nan=0.0
                    ).mean(axis=0)
                    mean_q_table[np.isnan(training_nan_mask)] = np.nan
                    mean_q_table_raw = np.nanmean(q_table, axis=0)
                    mean_q_table_raw[np.isnan(training_nan_mask)] = np.nan

                    _dct = {
                        "loss": loss.mean(),
                        # "tm_loss": tm_loss,
                        "func": v.__name__,
                        "lag": _lag,
                        "fold": _fold,
                        "repeat": _repeat,
                        "batch": _batch,
                        "model": _model,
                        "q_table": mean_q_table,
                        "q_table_raw": mean_q_table_raw,
                    }
                    dcts.append(_dct)

    time_df = pd.DataFrame(dcts)
    # Compare against the held-out (test) reference TM of each fold.
    time_df["ref_tm"] = time_df["fold"].map(test_ref_tms)
    time_df["r_tm"] = time_df.apply(
        lambda x: compare_tms(x["ref_tm"], x["q_table"]), axis=1
    )

    # Array-valued columns are stored separately from the parquet file.
    tm_dct = {}
    tm_dct["ref_tm"] = time_df["ref_tm"].to_list()
    tm_dct["q_table"] = time_df["q_table"].to_list()
    tm_dct["q_table_raw"] = time_df["q_table_raw"].to_list()

    time_df["loss"] = time_df["loss"].astype("float")
    time_df.drop(list(tm_dct.keys()), axis=1, errors="ignore").to_parquet(save_file)
    joblib.dump(tm_dct, tm_file)
else:
    # Load the cached table and re-attach the array-valued columns.
    time_df = pd.read_parquet(save_file)
    tm_dct = joblib.load(tm_file)
    for k, v in tm_dct.items():
        time_df[k] = v

  0%|          | 0/105000 [00:00<?, ?it/s]

  mean_q_table_raw = np.nanmean(q_table, axis=0)
  mean_q_table_raw = np.nanmean(q_table, axis=0)


  0%|          | 0/105000 [00:00<?, ?it/s]

In [50]:
# Average over repeats within each (lag, fold, batch, model) combination.
group_keys = ["lag", "fold", "batch", "model"]

In [51]:
time_groups = time_df.groupby(group_keys)

# Request numeric columns explicitly instead of relying on the deprecated default
# (this is what triggered the warning captured under this cell).
ave_time_df = time_groups.mean(numeric_only=True)
# The array-valued columns are excluded above, so average them one at a time.
ave_time_df["q_table"] = time_groups["q_table"].mean()
ave_time_df["q_table_raw"] = time_groups["q_table_raw"].mean()
ave_time_df["ref_tm"] = time_groups["ref_tm"].mean()
# Similarity between the repeat-averaged Q-table and the reference TM.
ave_time_df["ave_r_tm"] = ave_time_df.apply(
    lambda x: compare_tms(x["ref_tm"], x["q_table"]), axis=1
)

  ave_time_df = time_df.groupby(group_keys).mean()


In [52]:
# Mean similarity per (lag, model), averaged over folds and batches.
ave_time_df["ave_r_tm"].groupby(["lag", "model"]).mean()

lag  model  
-10  dynamic   -0.016271
     static    -0.390744
-9   dynamic   -0.013490
     static    -0.421764
-8   dynamic    0.046994
     static    -0.413248
-7   dynamic    0.056691
     static    -0.388023
-6   dynamic    0.008396
     static    -0.442007
-5   dynamic    0.049587
     static    -0.383897
-4   dynamic    0.089757
     static    -0.365063
-3   dynamic   -0.046379
     static    -0.428004
-2   dynamic   -0.013731
     static    -0.417210
-1   dynamic   -0.067664
     static    -0.391930
 0   dynamic    0.272956
     static     0.163235
 1   dynamic    0.340368
     static     0.007984
 2   dynamic    0.197415
     static    -0.221188
 3   dynamic    0.105905
     static    -0.232961
 4   dynamic    0.091233
     static    -0.329570
 5   dynamic    0.032321
     static    -0.359427
 6   dynamic    0.018275
     static    -0.341908
 7   dynamic    0.073129
     static    -0.368829
 8   dynamic    0.019334
     static    -0.384257
 9   dynamic   -0.061409
     static 

In [53]:
# Per group: take the NaN-ignoring minimum of each row of the raw Q-table,
# then the median (0.5 quantile) of those row minima.
ave_time_df["q_table_raw"].apply(lambda x: np.quantile(np.nanmin(x, axis=1), 0.5))

lag  fold  batch  model  
-10  0     0      dynamic    0.877019
                  static    -0.080402
           1      dynamic    0.885933
                  static    -0.083562
           2      dynamic    0.886810
                               ...   
 10  1     47     static     0.073161
           48     dynamic    0.672214
                  static     0.071606
           49     dynamic    0.671364
                  static     0.071270
Name: q_table_raw, Length: 4200, dtype: float64

In [54]:
# For each model, find the lag with the highest mean `ave_r_tm`.
# `idxmax` returns (model, lag) index labels; the final apply keeps only the lag.
best_lag = (
    ave_time_df.groupby(["model", "lag"])["ave_r_tm"]
    .mean()
    .groupby("model")
    .idxmax()
    .apply(lambda x: x[-1])
)

In [55]:
def lag_func(x):
    """Roll `x` along axis 1 by the best lag found for the dynamic model.

    `best_lag` is looked up at call time, not at definition time.
    """
    return np.roll(x, best_lag.loc["dynamic"], axis=1)

In [56]:
# Number of shuffles ("batch" index) per repeat in the shuffle control below.
nrands = 100

In [57]:
# can add lag to the inner loop, then we subtract off the surround like with others...

In [58]:
# Output file for the shuffle control at the best lag.
# NOTE: this rebinds `save_file`, which previously pointed at the lag-scan results.
save_file = os.path.join(
    proc_dirs["rl_modeling"], f"{save_file_basename}_best_lag_rands.parquet"
)

In [59]:
# Shuffle control at the best lag: simulate the held-out sessions of each fold
# with shuffled DA features (see `shuffle_rows_copy`), rolled by the best lag
# found for the dynamic model.
inner_loop = list(
    itertools.product([best_lag.loc["dynamic"]], range(repeats), range(nrands))
)
# Companion pickle holding the array-valued columns. Only the extension is
# swapped; the original `str.replace("parquet", "p")` would also rewrite any
# "parquet" appearing in a directory name.
tm_file = os.path.splitext(save_file)[0] + ".p"
dcts = []
use_seed = 0
# Recompute when forced or when either cached file is missing.
if force or not (os.path.exists(save_file) and os.path.exists(tm_file)):
    for _fold, (train_idx, test_idx) in enumerate(splits):
        fold_params = use_params[_fold]

        # Largest device count that evenly divides the number of test sessions.
        for use_devices in range(len(jax.devices()), 0, -1):
            if len(test_idx) % use_devices == 0:
                break
        for k, v in sim_funcs.items():
            for _lag, _repeat, _rand in tqdm(inner_loop):
                with Mesh(np.array(jax.devices()[:use_devices]), ("x",)):

                    q_table = init_q_table(base=training_nan_mask, nmats=len(test_idx))
                    features_rand = shuffle_rows_copy(z_features.copy())
                    # Roll by the loop's own `_lag` (the value recorded below) rather
                    # than through the global `lag_func`, so the applied and recorded
                    # lag cannot diverge if `lag_func` is stale.
                    rand_input_data = np.stack(
                        [
                            seqs[:, :-1],
                            seqs[:, 1:],
                            np.roll(features_rand[:, :-1], _lag, axis=1),
                        ],
                        axis=2,
                    )

                    # Fresh seed for every simulation, shared by all sessions in it.
                    prng_keys = np.array(
                        [jax.random.PRNGKey(use_seed) for _ in range(len(test_idx))]
                    )
                    use_seed += 1
                    loss, lls, q_table = v(
                        fold_params, q_table, rand_input_data[test_idx], prng_keys
                    )

                    # z-score along the last axis, zero the NaNs, average over
                    # sessions, then restore the mask. The fill value is passed by
                    # keyword because the second positional argument is `copy`.
                    mean_q_table = np.nan_to_num(
                        zscore_missing_data(np.array(q_table), axis=-1), nan=0.0
                    ).mean(axis=0)
                    mean_q_table[np.isnan(training_nan_mask)] = np.nan
                    mean_q_table_raw = np.nanmean(q_table, axis=0)
                    mean_q_table_raw[np.isnan(training_nan_mask)] = np.nan

                    _dct = {
                        "loss": loss.mean(),
                        "func": v.__name__,
                        "fold": _fold,
                        "batch": _rand,
                        "repeat": _repeat,
                        "lag": _lag,
                        "q_table": mean_q_table,
                        "q_table_raw": mean_q_table_raw,
                    }
                    dcts.append(_dct)

    rand_df = pd.DataFrame(dcts)
    # Compare against the held-out (test) reference TM of each fold.
    rand_df["ref_tm"] = rand_df["fold"].map(test_ref_tms)
    rand_df["r_tm"] = rand_df.apply(
        lambda x: compare_tms(x["ref_tm"], x["q_table"]), axis=1
    )

    # Array-valued columns are stored separately from the parquet file.
    tm_dct = {}
    tm_dct["ref_tm"] = rand_df["ref_tm"].to_list()
    tm_dct["q_table"] = rand_df["q_table"].to_list()
    tm_dct["q_table_raw"] = rand_df["q_table_raw"].to_list()

    if "loss" in rand_df.columns:
        rand_df["loss"] = rand_df["loss"].astype("float")
    rand_df.drop(list(tm_dct.keys()), axis=1, errors="ignore").to_parquet(save_file)
    joblib.dump(tm_dct, tm_file)
else:
    # Load the cached table and re-attach the array-valued columns.
    rand_df = pd.read_parquet(save_file)
    tm_dct = joblib.load(tm_file)
    for k, v in tm_dct.items():
        rand_df[k] = v

  0%|          | 0/5000 [00:00<?, ?it/s]

  mean_q_table_raw = np.nanmean(q_table, axis=0)


  0%|          | 0/5000 [00:00<?, ?it/s]

In [60]:
# Average over repeats within each (lag, fold, batch) combination; "batch" is the shuffle index.
group_keys = ["lag", "fold", "batch"]

In [61]:
rand_groups = rand_df.groupby(group_keys)

# Request numeric columns explicitly instead of relying on the deprecated default
# (this is what triggered the warning captured under this cell).
ave_rand_df = rand_groups.mean(numeric_only=True)
# The array-valued columns are excluded above, so average them one at a time.
ave_rand_df["q_table"] = rand_groups["q_table"].mean()
ave_rand_df["q_table_raw"] = rand_groups["q_table_raw"].mean()
ave_rand_df["ref_tm"] = rand_groups["ref_tm"].mean()
# Argument order (reference first, Q-table second) now matches every other
# `compare_tms` call in this notebook; the original had them swapped here.
ave_rand_df["ave_r_tm"] = ave_rand_df.apply(
    lambda x: compare_tms(x["ref_tm"], x["q_table"]), axis=1
)

  ave_rand_df = rand_df.groupby(group_keys).mean()
