In [16]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
import numpy as np
import torch
from matplotlib import pyplot as plt
from syd import make_viewer
from tqdm import tqdm

from vrAnalysis.database import get_database
from vrAnalysis.helpers import Timer, get_placefield_location, cross_validate_trials, sort_by_preferred_environment
from vrAnalysis.sessions import B2Session
from vrAnalysis.processors import SpkmapProcessor
from vrAnalysis.processors.support import median_zscore
from vrAnalysis.processors.placefields import get_placefield, get_frame_behavior, get_placefield_prediction
from dimilibi import Population
from dimilibi import ReducedRankRegression
from dimilibi import measure_r2
from dimensionality_manuscript.regression_models.registry import PopulationRegistry, get_model
from dimensionality_manuscript.regression_models.hyperparameters import PlaceFieldHyperparameters, RBFPosHyperparameters, ReducedRankRegressionHyperparameters

# get session database
sessiondb = get_database("vrSessions")

# get population registry and models
registry = PopulationRegistry()
ext_model = get_model("external_placefield_1d", registry)
int_model = get_model("internal_placefield_1d", registry)
ext_gain_model = get_model("external_placefield_1d_gain", registry)
int_gain_model = get_model("internal_placefield_1d_gain", registry)
rbfpos_model = get_model("rbfpos", registry)
rrr_model = get_model("rrr", registry)

In [2]:
# KEY STEPS:
# 1. Speed up slow models :)
# ... Check the duplicate neuron improvement issue. 

In [60]:
import speedystats as ss

session = B2Session.create("ATL012", "2023-03-03", "701", params=dict(spks_type="oasis"))
population = Population(session.spks[:, session.idx_rois].T)
for alpha in torch.logspace(0, 4, 9):
    model = ReducedRankRegression(alpha=alpha, fit_intercept=True)
    model_max = ReducedRankRegression(alpha=alpha, fit_intercept=True)

    train_source, train_target = population.get_split_data(0, center=False, scale=True, scale_type="std", pre_split=True)
    test_source, test_target = population.get_split_data(1, center=False, scale=True, scale_type="std", pre_split=True)
    train_source_max, train_target_max = population.get_split_data(0, center=False, scale=True, scale_type="max", pre_split=True)
    test_source_max, test_target_max = population.get_split_data(1, center=False, scale=True, scale_type="max", pre_split=True)

    model.fit(train_source.T, train_target.T)
    model_max.fit(train_source_max.T, train_target_max.T)

    s = model.score(test_source.T, test_target.T, rank=200, nonnegative=True)
    s_max = model_max.score(test_source_max.T, test_target_max.T, rank=200, nonnegative=True)

    print(f"{alpha:.1f}: {s:.3f}, {s_max:.3f}, {s_max - s:.3f}")

1.0: -0.283, -0.062, 0.221
3.2: -0.281, -0.022, 0.259
10.0: -0.275, 0.007, 0.282
31.6: -0.258, 0.023, 0.281
100.0: -0.221, 0.023, 0.244
316.2: -0.158, 0.014, 0.172
1000.0: -0.081, 0.005, 0.086
3162.3: -0.016, -0.003, 0.014
10000.0: 0.021, -0.014, -0.035


In [42]:
print(np.min(prct), np.max(prct))
plt.close('all')
plt.hist(prct, bins=np.linspace(0, 10, 1001))
plt.show()

0.0 6000.632080078125


In [None]:
fpath = r"D:\localData\analysis\placeCellSingleSession\temp\population_ATL012_2023-03-03_701"
rpath = r"D:\localData\analysis\placeCellSingleSession\temp\rrr_optimization_results_ATL012_2023-03-03_701"
rspath = r"D:\localData\analysis\placeCellSingleSession\temp\rrr_optimization_results_state_ATL012_2023-03-03_701_fast"
session = B2Session.create("ATL012", "2023-03-03", "701", params=dict(spks_type="oasis"))

obj = np.load(fpath, allow_pickle=True)
res = np.load(rpath, allow_pickle=True)
resstate = np.load(rspath, allow_pickle=True)
print(obj["num_neurons"], obj["num_timepoints"])

valid_planes = session.valid_plane_idx()
plane_idx = session.get_plane_idx()
roi_in_keep = np.isin(plane_idx, [1, 2, 3, 4])
print(roi_in_keep.sum())
print(session.spks.shape[0])

print(res.keys())
print(np.array(res["alphas"])/10000)
print(res["ranks"])
print(res["test_scores"])

print(resstate.keys())
print(resstate["pf_pred_score"])
print(resstate["pf_pred_score_withgain"])
print(resstate["pf_pred_score_withcvgain"])

spks_from_past = session.spks[:, roi_in_keep].T
print(spks_from_past.shape)

population = Population.make_from_indices(obj, spks_from_past)
source_train, target_train = population.get_split_data(0, center=False, scale=True, scale_type="preserve")
source_test, target_test = population.get_split_data(2, center=False, scale=True, scale_type="preserve")

alpha = 1e5
model = ReducedRankRegression(alpha=alpha, fit_intercept=True).fit(source_train.T, target_train.T)
# print(model.score(source_test.T, target_test.T, rank=200, nonnegative=True))


# print(model.score(source_test.T, target_test.T, rank=200, nonnegative=True))
# print(rrr_model.score(session, spks_type="oasis", reduce="mean", hyperparameters=ReducedRankRegressionHyperparameters(rank=200, alpha=1e5)))

8888 10851
8888
10851
dict_keys(['mouse_name', 'datestr', 'sessionid', 'ranks', 'alphas', 'test_scores', 'test_scaled_mses'])
[10. 10. 10. 10.  1. 10. 10. 10. 10.]
(1, 2, 3, 5, 8, 15, 50, 100, 200)
[tensor(0.0439), tensor(0.0650), tensor(0.0858), tensor(0.1157), tensor(0.1424), tensor(0.1712), tensor(0.1863), tensor(0.1982), tensor(0.2080)]
dict_keys(['mouse_name', 'datestr', 'sessionid', 'study', 'params', 'test_score', 'test_scaled_mse', 'test_score_direct', 'test_scaled_mse_direct', 'params_direct', 'pf_pred_score', 'encoder_position_score', 'latent_position_score', 'pf_pred_score_withgain', 'opt_pos_estimate_source', 'opt_pos_estimate_target', 'opt_pos_estimate_target_withgain', 'opt_pos_estimate_position', 'pf_pred_score_withcvgain', 'opt_pos_estimate_target_withcvgain', 'rbfpos_to_target_score', 'test_score_doublecv'])
tensor(-0.0683)
tensor(0.0326)
tensor(0.0284)
(8888, 10851)
svd || elapsed time: 11.6880 seconds
linalg.svd || elapsed time: 12.1205 seconds
True


In [12]:
source_train.shape, target_train.shape

(torch.Size([4444, 6508]), torch.Size([4444, 6508]))

In [94]:
model_predicted = model.predict(source_test.T, rank=200, nonnegative=True)
print(model_predicted.shape, target_test.shape)

print(measure_r2(model_predicted, target_test.T, dim=0))
print(model.score(source_test.T, target_test.T, rank=200, nonnegative=True))

torch.Size([1301, 4444]) torch.Size([4444, 1301])
0.00898069329559803
0.00898069329559803


In [101]:
mse, predicted, target, extras = rrr_model.score(session, spks_type="oasis", reduce="none", hyperparameters=ReducedRankRegressionHyperparameters(rank=200, alpha=1e5), full_results=True)

print(mse.mean())
print(np.mean((predicted - target)**2))
print(measure_r2(predicted, target, dim=0))
print(measure_r2(predicted, target, dim=None))

print(predicted.shape, target.shape)
print(source_test.shape, target_test.shape)

4.918569
4.91857
0.07077919691801071
0.06868863105773926
(1534, 742) (1534, 742)
torch.Size([4444, 1301]) torch.Size([4444, 1301])


In [32]:
for session in sessiondb.gen_sessions(imaging=True):
    hyperparameters = PlaceFieldHyperparameters(num_bins=50, smooth_width=2.0)
    mse, predicted, target, extras = ext_gain_model.score(session, spks_type="oasis", reduce="none", hyperparameters=hyperparameters, full_results=True)
    print(np.mean(mse))

    constant_mse = np.mean((target - np.mean(target, axis=1, keepdims=True))**2, axis=1)
    scaled_mse = np.full_like(mse, np.nan)
    np.divide(mse, constant_mse, where=constant_mse > 0, out=scaled_mse)
    print(np.nanmean(scaled_mse))

    msecirc, pcirc, tcirc, _ = ext_gain_model.score(session, spks_type="oasis", train_split="train", test_split="train", reduce="none", hyperparameters=hyperparameters, full_results=True)
    print(np.mean(msecirc))
    break

8.529827318529819
1.2954132696524524e+31
8.445601370780215


In [34]:
isortr2 = np.argsort(scaled_mse)
viewer = make_viewer()
viewer.add_integer("roi", value=0, min=0, max=len(mse)-1)
viewer.add_float_range("ylims", min=0, max=1)
viewer.add_float_range("xlims0", min=0, max=predicted.shape[1])
viewer.add_float_range("xlims1", min=0, max=pcirc.shape[1])
def plot(state):
    idx_roi = isortr2[state["roi"]]
    mean_target = np.mean(target[idx_roi])
    constant_mse = np.mean((target[idx_roi] - mean_target)**2)
    scaled_mse = mse[idx_roi] / (constant_mse + 1e-9)
    fig, ax = plt.subplots(2, 1, figsize=(12, 7), layout="constrained", sharey=True)
    ax[0].plot(predicted[idx_roi], color="b")
    ax[0].plot(target[idx_roi], color="k")
    ax[0].set_title(f"MSE: {mse[idx_roi]:.6f}, Constant MSE: {constant_mse:.6f}, Scaled: {scaled_mse:.6f}")
    ax[1].plot(pcirc[idx_roi], color="b")
    ax[1].plot(tcirc[idx_roi], color="k")
    ax[1].set_title(f"MSE: {msecirc[idx_roi]:.6f}")
    ax[0].set_ylim(state["ylims"][0], state["ylims"][1])
    ax[0].set_xlim(state["xlims0"][0], state["xlims0"][1])
    ax[1].set_xlim(state["xlims1"][0], state["xlims1"][1])
    return fig
def update_ylim(state):
    idx_roi = isortr2[state["roi"]]
    all_data = np.concatenate((predicted[idx_roi], pcirc[idx_roi], target[idx_roi], tcirc[idx_roi]))
    ymin = np.minimum(np.min(all_data), 0)
    ymax = np.maximum(np.max(all_data), 1)
    viewer.update_float_range("ylims", value=(ymin, ymax), min=ymin, max=ymax)
viewer.on_change("roi", update_ylim)
update_ylim(viewer.state)
viewer.set_plot(plot)
viewer.show()

The behavior of the viewer will almost definitely not work as expected!


HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Parameters</b>'), IntSlider(value=0, continuous_up…

In [21]:
x = np.zeros(100)
x[0] = 100

np.sum((x - np.mean(x))**2), np.sum((x - 0)**2)

(9900.0, 10000.0)

In [182]:
for idx, session in enumerate(sessiondb.gen_sessions(imaging=True, mouseName="ATL027")):
    if len(session.environments) > 1:
        break

print(idx, session)

3 B2Session(mouse_name='ATL027', date='2023-07-24', session_id='701', spks_type='significant')


In [70]:
for model in [ext_model, int_model, ext_gain_model, int_gain_model]: #, rbfpos_model, rrr_model]:
    print(model.get_best_score(session, spks_type="oasis", reduce="mean"))

                                                            

-0.06549675360715122


                                                            

-0.07308238843951242


                                                            

-0.0759176616150783


                                                            

-0.08994863652436405


In [71]:
for model in [ext_model, int_model, ext_gain_model, int_gain_model]: #, rbfpos_model, rrr_model]:
    print(model.get_best_score(session, spks_type="deconvolved", reduce="mean"))

                                                            

-0.02211749862308902


                                                            

-0.023860554452715548


                                                            

-0.032284715467016216


                                                            

-0.038749732219321445


In [108]:
spks.shape

(12173, 17547)

In [154]:
with Timer("reluspks"):
    spks = session.spks[:, session.idx_rois]
    spks_ming0 = np.minimum.reduce(
        spks,
        axis=0,
        where=spks > 0,
        initial=np.max(spks)
    )
    spks_base = spks - spks_ming0
    spks_nonneg = np.maximum(spks_base, 0)

print(np.where(np.sum(spks > 0, axis=0)==0)[0])
np.max(spks_nonneg[:, 3223])

reluspks || elapsed time: 0.4476 seconds
[3223 4245 4538]


0.0

In [65]:
prms = ext_model.get_best_hyperparameters(session, spks_type="significant")
print(prms)
prms = PlaceFieldHyperparameters(num_bins=50, smooth_width=5.0)
prms = RBFPosHyperparameters(num_basis=20, basis_width=10, alpha=1e6)
prms = ReducedRankRegressionHyperparameters(rank=200, alpha=1e2)
r2, smse, predicted_data, target_data, extras = rrr_model.score(session, spks_type="deconvolved", train_split="half0", test_split="half1", reduce="none", hyperparameters=prms, full_results=True)

mse = ((predicted_data - target_data)**2).sum(axis=1)
print(r2.shape, mse.shape, smse.shape)
print(np.mean(r2), np.mean(mse), np.mean(smse))

PlaceFieldHyperparameters(num_bins=10, smooth_width=50.0)
(1291,) (1291,) (1291,)
-1.116276 10632.351 2.116276


In [66]:
rrr_model.get_best_score(session, spks_type="deconvolved")

                                                            

-0.05395011

In [55]:
from dimilibi import scaled_mse
scaled_mse(target_data.T, target_data.T + np.random.randn(target_data.shape[1], target_data.shape[0]) * 1000.0)

tensor(1.0003, dtype=torch.float64)

In [42]:
viewer = make_viewer()
viewer.add_integer("roi", value=0, min=0, max=len(r2)-1)
def plot(state):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(predicted_data[state["roi"]], color="b")
    ax.plot(target_data[state["roi"]], color="k")
    ax.set_title(f"R2: {r2[state['roi']]:.2f}, MSE: {mse[state['roi']]:.2f}, SMSE: {smse[state['roi']]:.2f}")
    return fig

viewer.set_plot(plot)
viewer.show()

HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Parameters</b>'), IntSlider(value=0, continuous_up…