In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
import numpy as np
import torch
from matplotlib import pyplot as plt
from syd import make_viewer
from tqdm import tqdm

from vrAnalysis.database import get_database
from vrAnalysis.helpers import Timer, get_placefield_location, cross_validate_trials, sort_by_preferred_environment
from vrAnalysis.sessions import B2Session
from vrAnalysis.processors import SpkmapProcessor
from vrAnalysis.processors.support import median_zscore
from vrAnalysis.processors.placefields import get_placefield, get_frame_behavior, get_placefield_prediction
from dimilibi import Population
from dimilibi import ReducedRankRegression, RidgeRegression
from dimilibi import measure_r2, mse
from dimensionality_manuscript.registry import PopulationRegistry, get_model, ModelName
from dimensionality_manuscript.regression_models.hyperparameters import PlaceFieldHyperparameters, RBFPosHyperparameters, ReducedRankRegressionHyperparameters

# get session database
sessiondb = get_database("vrSessions")

# get population registry and models
registry = PopulationRegistry()
ext_model = get_model("external_placefield_1d", registry)
int_model = get_model("internal_placefield_1d", registry)
ext_gain_model = get_model("external_placefield_1d_gain", registry)
int_gain_model = get_model("internal_placefield_1d_gain", registry)
ext_vgain_model = get_model("external_placefield_1d_vector_gain", registry)
int_vgain_model = get_model("internal_placefield_1d_vector_gain", registry)
rbfpos_decoder_only_model = get_model("rbfpos_decoder_only", registry)
rbfpos_model = get_model("rbfpos", registry)
rbfpos_leak_model = get_model("rbfpos_leak", registry)
rrr_model = get_model("rrr", registry)

In [None]:
# Aesthetic: make a repr for various things like Report, RegressionModel, etc.

In [6]:
session = sessiondb.iter_sessions(imaging=True)[0]
report_ext_model = ext_model.process(session, spks_type="oasis")
report_int_model = int_model.process(session, spks_type="oasis")
report_ext_gain_model = ext_gain_model.process(session, spks_type="oasis")
report_int_gain_model = int_gain_model.process(session, spks_type="oasis")
report_ext_vgain_model = ext_vgain_model.process(session, spks_type="oasis")
report_int_vgain_model = int_vgain_model.process(session, spks_type="oasis")

In [8]:
for report in [report_ext_model, report_int_model, report_ext_gain_model, report_int_gain_model, report_ext_vgain_model, report_int_vgain_model]:
    print(report.metrics["r2"])


0.020181808054372086
0.0241917768758938
0.02955026077368994
0.028654429647126944
0.021623953914423844
0.025190177370237654


In [7]:
# Check vectorized gain model!
idx_session = 0
session = sessiondb.iter_sessions(imaging=True)[idx_session]
spks_type = "oasis"

with Timer("Getting session data"):
    source_data_train, target_data_train, frame_behavior_train = int_model.get_session_data(session, spks_type, "train")
    source_data_test, target_data_test, frame_behavior_test = int_model.get_session_data(session, spks_type, "test")

with Timer("Training"):
    trained_model = int_model.train(session, spks_type=spks_type)

with Timer("Predicting"):
    predicted_data, extras = int_model.predict(session, trained_model, spks_type=spks_type)

Getting session data || elapsed time: 0.2907 seconds
Training || elapsed time: 0.0208 seconds
Predicting || elapsed time: 0.0296 seconds


In [34]:
from sklearn.decomposition import randomized_svd

def get_arousal_coefficients(data, n_iter: int = 5):
    U, s, V = randomized_svd(data.numpy(), n_components=1, n_iter=n_iter)
    arousal_coefficients = U[:, 0]
    arousal_estimate = V[0] * s[0]
    return arousal_coefficients, arousal_estimate

with Timer("Extracting arousal coefficients train"):
    source_prediction_train = torch.tensor(get_placefield_prediction(trained_model[1], frame_behavior_train)[0].T)
    target_prediction_train = torch.tensor(get_placefield_prediction(trained_model[0], frame_behavior_train)[0].T)
    source_deviation_train = source_data_train - source_prediction_train
    target_deviation_train = target_data_train - target_prediction_train
    num_source = source_data_train.shape[0]
    num_target = target_data_train.shape[0]
    full_deviation_train = torch.cat([source_deviation_train, target_deviation_train], dim=0)

    n_iter = 100
    arousal_coefficients_source, arousal_estimate_source = get_arousal_coefficients(source_deviation_train, n_iter=n_iter)
    arousal_coefficients_target, arousal_estimate_target = get_arousal_coefficients(target_deviation_train, n_iter=n_iter)
    arousal_coefficients_full, arousal_estimate_full = get_arousal_coefficients(full_deviation_train, n_iter=n_iter)

    ac_source_from_full = arousal_coefficients_full[:num_source]
    ac_target_from_full = arousal_coefficients_full[num_source:]

plt.close('all')
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].plot(arousal_estimate_full, color="black", linewidth=0.5)
ax[0].plot(guess_arousal_estimate, color="red", linewidth=0.5)
# ax[0].plot(arousal_estimate_source, color="red", linewidth=0.5)
# ax[0].plot(arousal_estimate_target, color="blue", linewidth=0.5)
ax[1].scatter(ac_source_from_full, arousal_coefficients_source, alpha=0.1, color="black")
ax[2].scatter(arousal_estimate_source, arousal_estimate_full, alpha=0.1, color="black")
plt.show()

Extracting arousal coefficients train || elapsed time: 0.6595 seconds
True


In [24]:
len(frame_behavior_train), source_prediction_train.shape

(1466, torch.Size([2118, 1466]))

In [12]:
U.shape, V.shape, s.shape

((2118, 1), (1, 1466), (1,))

In [None]:
U @ 

In [3]:
source_data_train.shape, target_data_train.shape, source_data_test.shape, target_data_test.shape

(torch.Size([1228, 6004]),
 torch.Size([1227, 6004]),
 torch.Size([1228, 750]),
 torch.Size([1227, 750]))

In [None]:
# 0. Why is RBFPos the worst? 
# 0. Why isn't gain helping? 
# 0. Check to make sure best hyperparameters aren't at edges of grid!
# ----------------------------
# 1. I'll just analyze the results, let's make the r2 plots across models, averaging each mouse
# 3. Then make the auxiliary plots for the models, showing what they do etc beyond the schematic

# Plot Structure:
# - for each model, show predictions (using a session with average performance? or same session throughout?)
# - show any internals of the model where relevant
# - for RRR, analyze how the RRR latents contain the spatial information or not. Probably using RidgeRegression in both directions.

In [6]:
# This code checks where the best hyperparameters are for each model
show_results = False
if show_results:
    for session in sessiondb.gen_sessions(imaging=True):
        print(session.session_print())
        if not rbfpos_decoder_only_model.check_existing_hyperparameters(session, spks_type="oasis"):
            continue
        hyp, best_score, results = rbfpos_decoder_only_model.get_best_hyperparameters(session, spks_type="oasis")
        params = [prm for prm in results.columns.tolist() if "score" not in prm]
        prmvalues = [results[p].unique() for p in params]
        best_result = results.loc[results["score"].idxmin()]
        best_grididx = [1.0* (prmvalues[i] == best_result[p]) for i, p in enumerate(params)]
        for prm, grididx in zip(params, best_grididx):
            print(prm, grididx)

In [4]:
for s in sessiondb.gen_sessions(imaging=True, mouseName="ATL057"):
    print(s.session_print())

ATL057/2024-07-08/701
ATL057/2024-07-09/701
ATL057/2024-07-10/701
ATL057/2024-07-12/701
ATL057/2024-07-16/701
ATL057/2024-07-19/701


In [3]:
session = B2Session.create("CR_Hippocannula6", "2022-08-30", "701", params=dict(spks_type="oasis"))

In [9]:
do_grid_search = False
if do_grid_search:
    with Timer("Grid Search"):
        grid_params, grid_score, grid_results = rbfpos_model._optimize_grid(session, "oasis", "train", "validation")
    print(grid_params, grid_score)
    print(grid_results)

In [32]:
# session = random.choice(sessiondb.iter_sessions(imaging=True))
model = rrr_model
with Timer("Optuna Study"):
    opt_params, opt_score, opt_results = model._optimize_optuna(session, "oasis", "train", "validation", n_trials=30, nan_safe=True)
    
hyperparameters = model._model_hyperparameters(**opt_params)
report = model.process(session, spks_type="oasis", hyperparameters=hyperparameters)

print(opt_params, opt_score)
# print(opt_results)
print(model.get_best_score(session, spks_type="oasis")["r2"])
print(report.metrics["r2"])

Optuna search:   0%|          | 0/30 [00:00<?, ?it/s]

Optuna Study || elapsed time: 1561.8268 seconds
{'rank': 100, 'alpha': 10.389182971504683} 0.000541586778126657
0.14549678564071655
0.14552301168441772


In [27]:
# session = random.choice(sessiondb.iter_sessions(imaging=True))
model = rrr_model
with Timer("Golden Section Search"):
    opt_params, opt_score, opt_results = model._optimize_golden(session, "oasis", "train", "validation", nan_safe=True)
    
hyperparameters = model._model_hyperparameters(**opt_params)
report = model.process(session, spks_type="oasis", hyperparameters=hyperparameters)

print(opt_params, opt_score)
# print(opt_results)
print(model.get_best_score(session, spks_type="oasis")["r2"])
print(report.metrics["r2"])

Golden Section Search || elapsed time: 28.2870 seconds
{'alpha': 11.36915244958074, 'rank': 359} 0.0014256469439715147
0.17497670650482178
0.17607557773590088


In [18]:
session = B2Session.create("CR_Hippocannula6", "2022-08-30", "701", params=dict(spks_type="oasis"))
show_full_results = True
if show_full_results:
    for model in [ext_model, int_model, ext_gain_model, int_gain_model, rbfpos_decoder_only_model, rbfpos_model, rbfpos_leak_model,rrr_model]:
        print(model.get_best_score(session, spks_type="oasis", method="best")["r2"])

0.15390602828627287
0.1552059872009105
0.158541797739719
0.15745377474761046
0.1553073525428772
0.13527292013168335
0.1482049822807312
0.17497670650482178


In [None]:
# This code runs another grid search "locally" with updated parameters...
run_grid_search = True
if run_grid_search:
    from itertools import product

    alpha_encoder = np.logspace(-2, 2, 9)
    alpha_decoder = np.logspace(-2, 2, 9)
    num_basis = [100]
    basis_width = [5]

    best_score = -np.inf
    best_hyp = None

    param_grid = list(product(num_basis, basis_width, alpha_encoder, alpha_decoder))
    for hyp in tqdm(param_grid):
        hyp = RBFPosHyperparameters(*hyp)
        r2 = rbfpos_leak_model.process(session, spks_type="oasis", hyperparameters=hyp).metrics["r2"]
        if r2 > best_score:
            best_score = r2
            best_hyp = hyp
        
    print("RBFPos(OPT)", best_hyp, best_score)

show_full_results = False
if show_full_results:
    for model in [ext_model, int_model, ext_gain_model, int_gain_model, rbfpos_decoder_only_model, rbfpos_model, rbfpos_leak_model,rrr_model]:
        print(model.get_best_score(session, spks_type="oasis")["r2"])

100%|██████████| 81/81 [01:25<00:00,  1.06s/it]


RBFPos(OPT) 

AttributeError: 'RBFPosHyperparameters' object has no attribute 'alpha'

In [10]:
MODEL_NAMES: list[ModelName] = [
    "external_placefield_1d",
    "internal_placefield_1d",
    "external_placefield_1d_gain",
    "internal_placefield_1d_gain",
    "rbfpos_decoder_only",
    "rbfpos",
    "rbfpos_leak",
    "rrr",
]

spks_type = "oasis"
sessiondb = get_database("vrSessions")
registry = PopulationRegistry()

get_sample_r2 = True
num_sessions = len(sessiondb.iter_sessions(imaging=True))
mouse_names = []
scores = np.full((len(MODEL_NAMES), num_sessions), np.nan)
if get_sample_r2:
    r2_avg_from_samples = np.full((len(MODEL_NAMES), num_sessions), np.nan)

for imodel, model_name in enumerate(MODEL_NAMES):
    model = get_model(model_name, registry)
    for isession, session in enumerate(tqdm(sessiondb.iter_sessions(imaging=True, session_params=dict(spks_type=spks_type)))):
        if get_sample_r2:
            hyps = model.get_best_hyperparameters(session, spks_type=spks_type)[0]
            report = model.process(session, spks_type=spks_type, hyperparameters=hyps)
            r2_avg_from_samples[imodel, isession] = measure_r2(report.predicted_data, report.target_data, reduce="mean", dim=0)

        metrics = model.get_best_score(session, spks_type=spks_type)
        scores[imodel, isession] = metrics["r2"]
        mouse_names.append(session.mouse_name)

100%|██████████| 149/149 [03:23<00:00,  1.37s/it]
100%|██████████| 149/149 [02:33<00:00,  1.03s/it]
100%|██████████| 149/149 [02:23<00:00,  1.04it/s]
100%|██████████| 149/149 [03:05<00:00,  1.25s/it]
100%|██████████| 149/149 [00:59<00:00,  2.50it/s]
100%|██████████| 149/149 [01:20<00:00,  1.84it/s]
100%|██████████| 149/149 [01:50<00:00,  1.35it/s]
100%|██████████| 149/149 [08:27<00:00,  3.41s/it]


In [29]:
mice = np.unique(mouse_names)
results = {mouse: np.zeros(len(MODEL_NAMES)) for mouse in mice}
r2res = {mouse: np.zeros(len(MODEL_NAMES)) for mouse in mice}
counts = {mouse: 0 for mouse in mice}
for isession in range(num_sessions):
    cname = mouse_names[isession]
    results[cname] += scores[:, isession]
    r2res[cname] += r2_avg_from_samples[:, isession]
    counts[cname] += 1

for mouse in mice:
    results[mouse] /= counts[mouse]
    r2res[mouse] /= counts[mouse]

mouse_averages = np.stack([results[mouse] for mouse in mice])
mouse_averages_r2 = np.stack([r2res[mouse] for mouse in mice])

idx_keep = [0, 1, 2, 3, 4, -1]

plt.close('all')
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].plot(range(len(idx_keep)), mouse_averages[:, idx_keep].T, alpha=0.4)
ax[0].plot(range(len(idx_keep)), np.mean(mouse_averages[:, idx_keep], axis=0), color="black", alpha=1.0)
ax[1].plot(range(len(idx_keep)), mouse_averages_r2[:, idx_keep].T, alpha=0.4)
ax[1].plot(range(len(idx_keep)), np.mean(mouse_averages_r2[:, idx_keep], axis=0), color="black", alpha=1.0)
plt.show()