In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
import numpy as np
import torch
from matplotlib import pyplot as plt
from syd import make_viewer
from tqdm import tqdm

from vrAnalysis.database import get_database
from vrAnalysis.helpers import Timer, get_placefield_location, cross_validate_trials, sort_by_preferred_environment
from vrAnalysis.sessions import B2Session
from vrAnalysis.processors import SpkmapProcessor
from vrAnalysis.processors.support import median_zscore
from vrAnalysis.processors.placefields import get_placefield, get_frame_behavior, get_placefield_prediction
from dimilibi import Population
from dimilibi import ReducedRankRegression, RidgeRegression
from dimilibi import measure_r2, mse
from dimensionality_manuscript.registry import PopulationRegistry, get_model, ModelName
from dimensionality_manuscript.regression_models.hyperparameters import PlaceFieldHyperparameters, RBFPosHyperparameters, ReducedRankRegressionHyperparameters

# get session database
sessiondb = get_database("vrSessions")

# get population registry and models
registry = PopulationRegistry()
ext_model = get_model("external_placefield_1d", registry)
int_model = get_model("internal_placefield_1d", registry)
ext_gain_model = get_model("external_placefield_1d_gain", registry)
int_gain_model = get_model("internal_placefield_1d_gain", registry)
rbfpos_decoder_only_model = get_model("rbfpos_decoder_only", registry)
rbfpos_model = get_model("rbfpos", registry)
rbfpos_leak_model = get_model("rbfpos_leak", registry)
rrr_model = get_model("rrr", registry)

In [None]:
# Aesthetic: make a repr for various things like Report, RegressionModel, etc.

In [34]:
# 0. Why is RBFPos the worst? 
# 0. Why isn't gain helping? 
# 0. Check to make sure best hyperparameters aren't at edges of grid!
# ----------------------------
# 1. I'll just analyze the results, let's make the r2 plots across models, averaging each mouse
# 2. Should consider saving more than just mse for each model in get_best_score... I mean it's worth it to have a bit of summary no? 
# 3. Then make the auxiliary plots for the models, showing what they do etc beyond the schematic

# Plot Structure:
# - for each model, show predictions (using a session with average performance? or same session throughout?)
# - show any internals of the model where relevant
# - for RRR, analyze how the RRR latents contain the spatial information or not. Probably using RidgeRegression in both directions.

In [6]:
# This code checks where the best hyperparameters are for each model
show_results = False
if show_results:
    for session in sessiondb.gen_sessions(imaging=True):
        print(session.session_print())
        if not rbfpos_decoder_only_model.check_existing_hyperparameters(session, spks_type="oasis"):
            continue
        hyp, best_score, results = rbfpos_decoder_only_model.get_best_hyperparameters(session, spks_type="oasis")
        params = [prm for prm in results.columns.tolist() if "score" not in prm]
        prmvalues = [results[p].unique() for p in params]
        best_result = results.loc[results["score"].idxmin()]
        best_grididx = [1.0* (prmvalues[i] == best_result[p]) for i, p in enumerate(params)]
        for prm, grididx in zip(params, best_grididx):
            print(prm, grididx)

In [4]:
for s in sessiondb.gen_sessions(imaging=True, mouseName="ATL057"):
    print(s.session_print())

ATL057/2024-07-08/701
ATL057/2024-07-09/701
ATL057/2024-07-10/701
ATL057/2024-07-12/701
ATL057/2024-07-16/701
ATL057/2024-07-19/701


In [23]:
session = B2Session.create("CR_Hippocannula6", "2022-08-30", "701", params=dict(spks_type="oasis"))

In [9]:
do_grid_search = False
if do_grid_search:
    with Timer("Grid Search"):
        grid_params, grid_score, grid_results = rbfpos_model._optimize_grid(session, "oasis", "train", "validation")
    print(grid_params, grid_score)
    print(grid_results)

In [None]:
session = random.choice(sessiondb.iter_sessions(imaging=True))
model = rrr_model
with Timer("Optuna Study"):
    opt_params, opt_score, opt_results = model._optimize_optuna(session, "oasis", "train", "validation", n_trials=30, nan_safe=True)
    
hyperparameters = model._model_hyperparameters(**opt_params)
report = model.process(session, spks_type="oasis", hyperparameters=hyperparameters)

print(opt_params, opt_score)
# print(opt_results)
print(model.get_best_score(session, spks_type="oasis")["r2"])
print(report.metrics["r2"])

Optuna search:   0%|          | 0/30 [00:00<?, ?it/s]

Optuna Study || elapsed time: 324.8005 seconds
{'rank': 100, 'alpha': 13.414362834698029} 0.001431182841770351
0.17497670650482178
0.1737656593322754


In [31]:
for irow, row in opt_results.iterrows():
    if row["rank"] == 100:
        print(row["rank"], np.round(row["alpha"], 3), np.round(row["score"] * 1000, 3))

100.0 0.315 1.536
100.0 63.512 1.446
100.0 8.472 1.432
100.0 2.481 1.447
100.0 0.042 1.829
100.0 0.042 1.83
100.0 0.017 2.111
100.0 29.154 1.435
100.0 2.538 1.447
100.0 6.797 1.433
100.0 0.192 1.58
100.0 11.282 1.431
100.0 28.772 1.435
100.0 8.86 1.432
100.0 9.956 1.431
100.0 0.436 1.514
100.0 1.576 1.458
100.0 75.967 1.45
100.0 14.647 1.431
100.0 3.757 1.44
100.0 21.781 1.433
100.0 14.805 1.431
100.0 0.818 1.481
100.0 15.757 1.431
100.0 40.19 1.439
100.0 3.497 1.441
100.0 90.565 1.454
100.0 13.414 1.431
100.0 0.912 1.476
100.0 4.231 1.438


In [27]:
model.get_best_hyperparameters(session, spks_type="oasis")[0]

ReducedRankRegressionHyperparameters(rank=100, alpha=10.0)

In [3]:
opt_params

{'use_smoothing': True, 'smooth_width': 2.178479577197587, 'num_bins': 36}

In [18]:
session = B2Session.create("CR_Hippocannula6", "2022-08-30", "701", params=dict(spks_type="oasis"))
show_full_results = True
if show_full_results:
    for model in [ext_model, int_model, ext_gain_model, int_gain_model, rbfpos_decoder_only_model, rbfpos_model, rbfpos_leak_model,rrr_model]:
        print(model.get_best_score(session, spks_type="oasis", method="best")["r2"])

0.15390602828627287
0.1552059872009105
0.158541797739719
0.15745377474761046
0.1553073525428772
0.13527292013168335
0.1482049822807312
0.17497670650482178


In [None]:
# This code runs another grid search "locally" with updated parameters...
run_grid_search = True
if run_grid_search:
    from itertools import product

    alpha_encoder = np.logspace(-2, 2, 9)
    alpha_decoder = np.logspace(-2, 2, 9)
    num_basis = [100]
    basis_width = [5]

    best_score = -np.inf
    best_hyp = None

    param_grid = list(product(num_basis, basis_width, alpha_encoder, alpha_decoder))
    for hyp in tqdm(param_grid):
        hyp = RBFPosHyperparameters(*hyp)
        r2 = rbfpos_leak_model.process(session, spks_type="oasis", hyperparameters=hyp).metrics["r2"]
        if r2 > best_score:
            best_score = r2
            best_hyp = hyp
        
    print("RBFPos(OPT)", best_hyp, best_score)

show_full_results = False
if show_full_results:
    for model in [ext_model, int_model, ext_gain_model, int_gain_model, rbfpos_decoder_only_model, rbfpos_model, rbfpos_leak_model,rrr_model]:
        print(model.get_best_score(session, spks_type="oasis")["r2"])

100%|██████████| 81/81 [01:25<00:00,  1.06s/it]


RBFPos(OPT) 

AttributeError: 'RBFPosHyperparameters' object has no attribute 'alpha'

In [10]:
MODEL_NAMES: list[ModelName] = [
    "external_placefield_1d",
    "internal_placefield_1d",
    "external_placefield_1d_gain",
    "internal_placefield_1d_gain",
    "rbfpos_decoder_only",
    "rbfpos",
    "rbfpos_leak",
    "rrr",
]

spks_type = "oasis"
sessiondb = get_database("vrSessions")
registry = PopulationRegistry()

get_sample_r2 = True
num_sessions = len(sessiondb.iter_sessions(imaging=True))
mouse_names = []
scores = np.full((len(MODEL_NAMES), num_sessions), np.nan)
if get_sample_r2:
    r2_avg_from_samples = np.full((len(MODEL_NAMES), num_sessions), np.nan)

for imodel, model_name in enumerate(MODEL_NAMES):
    model = get_model(model_name, registry)
    for isession, session in enumerate(tqdm(sessiondb.iter_sessions(imaging=True, session_params=dict(spks_type=spks_type)))):
        if get_sample_r2:
            hyps = model.get_best_hyperparameters(session, spks_type=spks_type)[0]
            report = model.process(session, spks_type=spks_type, hyperparameters=hyps)
            r2_avg_from_samples[imodel, isession] = measure_r2(report.predicted_data, report.target_data, reduce="mean", dim=0)

        metrics = model.get_best_score(session, spks_type=spks_type)
        scores[imodel, isession] = metrics["r2"]
        mouse_names.append(session.mouse_name)

100%|██████████| 149/149 [03:23<00:00,  1.37s/it]
100%|██████████| 149/149 [02:33<00:00,  1.03s/it]
100%|██████████| 149/149 [02:23<00:00,  1.04it/s]
100%|██████████| 149/149 [03:05<00:00,  1.25s/it]
100%|██████████| 149/149 [00:59<00:00,  2.50it/s]
100%|██████████| 149/149 [01:20<00:00,  1.84it/s]
100%|██████████| 149/149 [01:50<00:00,  1.35it/s]
100%|██████████| 149/149 [08:27<00:00,  3.41s/it]


In [29]:
mice = np.unique(mouse_names)
results = {mouse: np.zeros(len(MODEL_NAMES)) for mouse in mice}
r2res = {mouse: np.zeros(len(MODEL_NAMES)) for mouse in mice}
counts = {mouse: 0 for mouse in mice}
for isession in range(num_sessions):
    cname = mouse_names[isession]
    results[cname] += scores[:, isession]
    r2res[cname] += r2_avg_from_samples[:, isession]
    counts[cname] += 1

for mouse in mice:
    results[mouse] /= counts[mouse]
    r2res[mouse] /= counts[mouse]

mouse_averages = np.stack([results[mouse] for mouse in mice])
mouse_averages_r2 = np.stack([r2res[mouse] for mouse in mice])

idx_keep = [0, 1, 2, 3, 4, -1]

plt.close('all')
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].plot(range(len(idx_keep)), mouse_averages[:, idx_keep].T, alpha=0.4)
ax[0].plot(range(len(idx_keep)), np.mean(mouse_averages[:, idx_keep], axis=0), color="black", alpha=1.0)
ax[1].plot(range(len(idx_keep)), mouse_averages_r2[:, idx_keep].T, alpha=0.4)
ax[1].plot(range(len(idx_keep)), np.mean(mouse_averages_r2[:, idx_keep], axis=0), color="black", alpha=1.0)
plt.show()