In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
import numpy as np
import torch
from matplotlib import pyplot as plt
from syd import make_viewer
from tqdm import tqdm

from vrAnalysis.database import get_database
from vrAnalysis.helpers import Timer, get_placefield_location, cross_validate_trials, sort_by_preferred_environment
from vrAnalysis.sessions import B2Session
from vrAnalysis.processors import SpkmapProcessor
from vrAnalysis.processors.support import median_zscore
from vrAnalysis.processors.placefields import get_placefield, get_frame_behavior, get_placefield_prediction
from dimilibi import Population
from dimilibi import ReducedRankRegression
from dimilibi import measure_r2, mse
from dimensionality_manuscript.regression_models.registry import PopulationRegistry, get_model
from dimensionality_manuscript.regression_models.hyperparameters import PlaceFieldHyperparameters, RBFPosHyperparameters, ReducedRankRegressionHyperparameters

# get session database
sessiondb = get_database("vrSessions")

# get population registry and models
registry = PopulationRegistry()
ext_model = get_model("external_placefield_1d", registry)
int_model = get_model("internal_placefield_1d", registry)
ext_gain_model = get_model("external_placefield_1d_gain", registry)
int_gain_model = get_model("internal_placefield_1d_gain", registry)
rbfpos_model = get_model("rbfpos", registry)
rrr_model = get_model("rrr", registry)

In [None]:
# KEY STEPS:
# 1. Speed up slow models :)

In [None]:
session = B2Session.create("ATL012", "2023-03-03", "701", params=dict(spks_type="oasis"))
population = Population(session.spks[:, session.idx_rois].T)
for alpha in torch.logspace(0, 4, 9):
    model = ReducedRankRegression(alpha=alpha, fit_intercept=True)
    model_max = ReducedRankRegression(alpha=alpha, fit_intercept=True)

    train_source, train_target = population.get_split_data(0, center=False, scale=True, scale_type="std", pre_split=True)
    test_source, test_target = population.get_split_data(1, center=False, scale=True, scale_type="std", pre_split=True)
    train_source_max, train_target_max = population.get_split_data(0, center=False, scale=True, scale_type="max", pre_split=True)
    test_source_max, test_target_max = population.get_split_data(1, center=False, scale=True, scale_type="max", pre_split=True)

    model.fit(train_source.T, train_target.T)
    model_max.fit(train_source_max.T, train_target_max.T)

    s = model.score(test_source.T, test_target.T, rank=200, nonnegative=True)
    s_max = model_max.score(test_source_max.T, test_target_max.T, rank=200, nonnegative=True)

    print(f"{alpha:.1f}: {s:.3f}, {s_max:.3f}, {s_max - s:.3f}")

1.0: -0.283, -0.062, 0.221
3.2: -0.281, -0.022, 0.259
10.0: -0.275, 0.007, 0.282
31.6: -0.258, 0.023, 0.281
100.0: -0.221, 0.023, 0.244
316.2: -0.158, 0.014, 0.172
1000.0: -0.081, 0.005, 0.086
3162.3: -0.016, -0.003, 0.014
10000.0: 0.021, -0.014, -0.035


In [32]:
for session in sessiondb.gen_sessions(imaging=True):
    hyperparameters = PlaceFieldHyperparameters(num_bins=50, smooth_width=2.0)
    mse, predicted, target, extras = ext_gain_model.score(session, spks_type="oasis", reduce="none", hyperparameters=hyperparameters, full_results=True)
    print(np.mean(mse))

    constant_mse = np.mean((target - np.mean(target, axis=1, keepdims=True))**2, axis=1)
    scaled_mse = np.full_like(mse, np.nan)
    np.divide(mse, constant_mse, where=constant_mse > 0, out=scaled_mse)
    print(np.nanmean(scaled_mse))

    msecirc, pcirc, tcirc, _ = ext_gain_model.score(session, spks_type="oasis", train_split="train", test_split="train", reduce="none", hyperparameters=hyperparameters, full_results=True)
    print(np.mean(msecirc))
    break

8.529827318529819
1.2954132696524524e+31
8.445601370780215


In [34]:
isortr2 = np.argsort(scaled_mse)
viewer = make_viewer()
viewer.add_integer("roi", value=0, min=0, max=len(mse)-1)
viewer.add_float_range("ylims", min=0, max=1)
viewer.add_float_range("xlims0", min=0, max=predicted.shape[1])
viewer.add_float_range("xlims1", min=0, max=pcirc.shape[1])
def plot(state):
    idx_roi = isortr2[state["roi"]]
    mean_target = np.mean(target[idx_roi])
    constant_mse = np.mean((target[idx_roi] - mean_target)**2)
    scaled_mse = mse[idx_roi] / (constant_mse + 1e-9)
    fig, ax = plt.subplots(2, 1, figsize=(12, 7), layout="constrained", sharey=True)
    ax[0].plot(predicted[idx_roi], color="b")
    ax[0].plot(target[idx_roi], color="k")
    ax[0].set_title(f"MSE: {mse[idx_roi]:.6f}, Constant MSE: {constant_mse:.6f}, Scaled: {scaled_mse:.6f}")
    ax[1].plot(pcirc[idx_roi], color="b")
    ax[1].plot(tcirc[idx_roi], color="k")
    ax[1].set_title(f"MSE: {msecirc[idx_roi]:.6f}")
    ax[0].set_ylim(state["ylims"][0], state["ylims"][1])
    ax[0].set_xlim(state["xlims0"][0], state["xlims0"][1])
    ax[1].set_xlim(state["xlims1"][0], state["xlims1"][1])
    return fig
def update_ylim(state):
    idx_roi = isortr2[state["roi"]]
    all_data = np.concatenate((predicted[idx_roi], pcirc[idx_roi], target[idx_roi], tcirc[idx_roi]))
    ymin = np.minimum(np.min(all_data), 0)
    ymax = np.maximum(np.max(all_data), 1)
    viewer.update_float_range("ylims", value=(ymin, ymax), min=ymin, max=ymax)
viewer.on_change("roi", update_ylim)
update_ylim(viewer.state)
viewer.set_plot(plot)
viewer.show()

The behavior of the viewer will almost definitely not work as expected!


HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Parameters</b>'), IntSlider(value=0, continuous_up…

In [None]:
# prms = ext_model.get_best_hyperparameters(session, spks_type="significant")
# print(prms)
# prms = PlaceFieldHyperparameters(num_bins=50, smooth_width=5.0)
# prms = RBFPosHyperparameters(num_basis=20, basis_width=10, alpha=1e6)
session = B2Session.create("ATL012", "2023-03-03", "701", params=dict(spks_type="oasis"))
prms = ReducedRankRegressionHyperparameters(rank=100, alpha=1e5)
score, predicted_data, target_data, extras = rrr_model.score(session, spks_type="oasis", reduce="mean", hyperparameters=prms, full_results=True)
r2 = measure_r2(predicted_data, target_data, dim=None)
_mse = mse(predicted_data, target_data, dim=None)
print(score, r2)

0.0013813513796776533 -0.024595260620117188


In [66]:
rrr_model.get_best_score(session, spks_type="deconvolved")

                                                            

-0.05395011

In [55]:
from dimilibi import scaled_mse
scaled_mse(target_data.T, target_data.T + np.random.randn(target_data.shape[1], target_data.shape[0]) * 1000.0)

tensor(1.0003, dtype=torch.float64)

In [42]:
viewer = make_viewer()
viewer.add_integer("roi", value=0, min=0, max=len(r2)-1)
def plot(state):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(predicted_data[state["roi"]], color="b")
    ax.plot(target_data[state["roi"]], color="k")
    ax.set_title(f"R2: {r2[state['roi']]:.2f}, MSE: {mse[state['roi']]:.2f}, SMSE: {smse[state['roi']]:.2f}")
    return fig

viewer.set_plot(plot)
viewer.show()

HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Parameters</b>'), IntSlider(value=0, continuous_up…