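# Loading hc3 data and fitted regression models

Load the hc3 session `ec014.29_ec014.468_isi5`, evaluate the regression-model checkpoints fitted to it, and inspect the loaded dataset.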

### Table of contents

1. [**One-dimensional regression**](#1D)
2. [**Two-dimensional regression**](#2D)



In [1]:
%load_ext autoreload
%autoreload

import sys

sys.path.append("../../../GaussNeuro")
import gaussneuro as lib

sys.path.append("../fit/")
import hc3

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np


import analyze_hc3
import utils

  from tqdm.autonotebook import tqdm


In [2]:
jax.config.update('jax_platform_name', 'cpu')
#jax.config.update("jax_enable_x64", True)

In [3]:
### names ###
# Each checkpoint name is: shared dataset tag + per-model descriptor
# + shared 'X[x-hd-theta]_Z[]' segment + per-model 'freeze[...]' list.
_data_tag = 'ec014.29_ec014.468_isi5ISI5sel0.0to0.5_'
_covariates = 'X[x-hd-theta]_Z[]'

# (model descriptor, '-'-joined frozen parameter names) for every fitted model
_model_specs = [
    ('PP-log__factorized_gp-32-1000', ''),
    ('PP-log_rcb-8-17.-36.-6.-30.-self-H500_factorized_gp-32-1000',
     'obs_model0spikefilter0a-obs_model0spikefilter0log_c-obs_model0spikefilter0phi'),
    ('gamma-log__rate_renewal_gp-32-1000', ''),
    ('invgauss-log__rate_renewal_gp-32-1000', ''),
    ('lognorm-log__rate_renewal_gp-32-1000', ''),
    ('isi4__nonparam_pp_gp-64-matern32-matern32-1000-n2.', 'obs_model0log_warp_tau'),
]
reg_config_names = [
    f'{_data_tag}{model}_{_covariates}_freeze[{frozen}]'
    for model, frozen in _model_specs
]

# the last entry (non-parametric ISI model) is the one used for tuning analysis
tuning_model_name = reg_config_names[-1]

# locations of the hc3 data and the saved model checkpoints
data_path = '../../data/hc3/'
checkpoint_dir = '../checkpoint/'

# seed both the NumPy generator and the JAX PRNG key from one value
seed = 123
rng = np.random.default_rng(seed)
prng_state = jr.PRNGKey(seed)

batch_size = 10000

In [4]:
### load dataset ###
session_name = 'ec014.29_ec014.468_isi5'
max_ISI_order = 4
# NOTE(review): the checkpoint names in reg_config_names encode 'ISI5', but this
# loads with max_ISI_order = 4 (the resulting dataset name contains 'ISI4').
# Confirm that evaluating with a different ISI order than the one encoded in the
# checkpoint names is intended.

# selection range matching the 'sel0.0to0.5' tag in the checkpoint names
select_fracs = [0.0, 0.5]

# test selections: consecutive 0.1-wide windows covering [0.5, 1.0]
_test_edges = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
test_select_fracs = [[lo, hi] for lo, hi in zip(_test_edges[:-1], _test_edges[1:])]

# load the main selection first, then each test selection in order
dataset_dict, *test_dataset_dicts = [
    hc3.spikes_dataset(session_name, data_path, max_ISI_order, fracs)
    for fracs in [select_fracs, *test_select_fracs]
]

In [None]:
# Evaluate the checkpoints named in reg_config_names on dataset_dict and on
# every entry of test_dataset_dicts.
regression_dict = utils.evaluate_regression_fits(
    checkpoint_dir,
    reg_config_names,
    hc3.observed_kernel_dict_induc_list,  # hc3 helper handed through to utils
    dataset_dict,
    test_dataset_dicts,
    rng,
    prng_state,
)

Analyzing regression for ec014.29_ec014.468_isi5ISI5sel0.0to0.5_PP-log__factorized_gp-32-1000_X[x-hd-theta]_Z[]_freeze[]...


In [12]:
# Dataset metadata (keys such as 'tbin', 'name', 'neurons', 'metainfo').
dataset_dict['properties']

{'tbin': array(0.001),
 'name': 'ec014.29_ec014.468_isi5ISI4sel0.0to0.5',
 'neurons': 49,
 'metainfo': {}}

In [13]:
# Alignment start index stored by hc3.spikes_dataset (its meaning is defined there).
dataset_dict['align_start_ind']

2437742

In [16]:
# Total spike count per neuron: summing over the last axis leaves one entry per
# neuron (49 here, matching properties['neurons']).
dataset_dict['spiketrains'].sum(axis=-1)

array([3411., 3583., 5281.,  325.,  936.,  173., 2179., 1142., 1994.,
       2080.,  798., 5554., 6053., 3751., 1570.,  284., 4961., 4514.,
        478., 4774., 6572., 2019., 5604., 1799., 3964.,  691.,  659.,
        192., 3704., 1792., 4544.,  911., 5594.,  950.,  342., 1122.,
        382., 3123., 2549., 4931., 1014., 3305., 2768.,  682.,  546.,
       2984.,  916., 1714., 2754.])

In [8]:
# ISI array; its last axis has max_ISI_order entries. Presumably laid out as
# (time step, neuron, ISI order) — TODO confirm against hc3.spikes_dataset.
dataset_dict['ISIs']

array([[[8.80392742e+00, 8.97099495e+00, 9.07203579e+00, 9.50991333e-01],
        [2.11502910e+00, 2.65399003e+00, 1.31000087e-01, 7.54993856e-01],
        [4.71884060e+00, 8.00000038e-03, 8.09999853e-02, 9.00000054e-03],
        ...,
        [2.93826080e+02, 8.64870987e+01, 9.71232529e+01, 2.61132774e+01],
        [1.00000005e-03, 8.37775517e+00, 4.67639275e+01, 1.80976830e+01],
        [1.16000056e-01, 2.39000306e-01, 1.36200762e+00, 1.34100664e+00]],

       [[8.80492783e+00, 8.97099495e+00, 9.07203579e+00, 9.50991333e-01],
        [2.11602902e+00, 2.65399003e+00, 1.31000087e-01, 7.54993856e-01],
        [4.71984053e+00, 8.00000038e-03, 8.09999853e-02, 9.00000054e-03],
        ...,
        [2.93827087e+02, 8.64870987e+01, 9.71232529e+01, 2.61132774e+01],
        [2.00000009e-03, 8.37775517e+00, 4.67639275e+01, 1.80976830e+01],
        [1.17000058e-01, 2.39000306e-01, 1.36200762e+00, 1.34100664e+00]],

       [[8.80592823e+00, 8.97099495e+00, 9.07203579e+00, 9.50991333e-01],
        