In [1]:
import os, time
import numpy as np
import pickle
from tqdm import tqdm

from numpy.linalg import norm, svd
from helpers import comp_pr
from sklearn.linear_model import RidgeCV

# Data path
from specs import data_path, rng

from data_loaders import (
    load_golub_2018, load_hennig_2018, load_degenhart_2020, load_russo_2018, load_nlb_maze, load_nlb_rtt)

In [2]:
def comp_corr(X, w_out):
    """Normalized overlap between states and output weights.

    Computes ``norm(X @ w_out.T) / (norm(X) * norm(w_out))`` with the
    Frobenius (matrix) / 2-norm (vector) from ``numpy.linalg.norm``. Because
    the Frobenius norm is submultiplicative, the value lies in [0, 1].

    No centering is performed here. If the temporal mean should be removed,
    the caller must do it (comp_ridge centers its data along time before
    calling this).

    Parameters
    ----------
    X : array
        States, presumably (n_samples, dim_hid) as passed by comp_ridge.
    w_out : array
        Readout weights with the same trailing dimension as X, e.g.
        ``RidgeCV.coef_`` of shape (dim_hid,) or (n_out, dim_hid).

    Returns
    -------
    float
        The normalized overlap; nan (with a RuntimeWarning) if X or w_out is
        all zeros, because the denominator is then 0.
    """
    return norm(X @ w_out.T) / (norm(X) * norm(w_out))

def comp_ridge(hids, output, comp_neuron_subs=False, d_thr=0.9,
               n_fit_t=20, frac_n_t=1/2, n_fit_n=20,
               alpha_range=None, frac_dim_hids=None):
    """Fit a cross-validated ridge readout from `hids` to `output` and summarize it.

    Both arrays are centered along axis 0 (time) before fitting. Three sets of
    results are computed, in this order (random draws come from the
    module-level `rng`):
      1. a fit on all samples,
      2. `n_fit_t` refits on random subsets of `frac_n_t` of the time points,
      3. only if `comp_neuron_subs`: `n_fit_n` refits on random subsets of
         neurons for every fraction in `frac_dim_hids`.

    Parameters
    ----------
    hids : array, (n_samples, dim_hid)
        Neural states; cast to float.
    output : array
        Target signal with n_samples rows (presumably (n_samples,) or
        (n_samples, n_out) -- TODO confirm against the loaders).
    comp_neuron_subs : bool
        Whether to run the neuron-subset fits. If False, the neuron-subset
        results are returned as all-zero placeholders.
    d_thr : float
        Threshold passed to `comp_dim_var_fit`.
    n_fit_t : int
        Number of time-point subsets.
    frac_n_t : float
        Fraction of time points per subset.
    n_fit_n : int
        Number of repetitions per neuron-subset size.
    alpha_range : array or None
        Candidate ridge penalties; defaults to ``np.logspace(-3, 6, 20)``.
    frac_dim_hids : array or None
        Fractions of neurons to keep; defaults to ``np.linspace(0.1, 1., 10)``.

    Returns
    -------
    res : dict
        Full-data scalars plus ``w_out`` (the ridge coefficients).
    res_subs_t : dict
        Label -> array of shape (n_fit_t,).
    res_subs_n : dict
        Label -> array of shape (n_fit_n, len(frac_dim_hids)).
    """
    X = hids.astype(float)
    # Center along time so that the SVD modes in comp_dim_var_fit are PCs.
    X = X - X.mean(0)
    y = output - output.mean(0)

    if alpha_range is None:
        alpha_range = np.logspace(-3, 6, 20)
    if frac_dim_hids is None:
        frac_dim_hids = np.linspace(0.1, 1., 10)
    # A single estimator is reused (and refit) for every fit below.
    ridge = RidgeCV(alphas=alpha_range)

    res = _ridge_full_fit(X, y, ridge, d_thr)
    res_subs_t = _ridge_time_subsets(X, y, ridge, d_thr, n_fit_t, frac_n_t)
    res_subs_n = _ridge_neuron_subsets(X, y, ridge, n_fit_n, frac_dim_hids,
                                       comp_neuron_subs)
    return res, res_subs_t, res_subs_n


def _ridge_full_fit(X, y, ridge, d_thr):
    """Fit `ridge` on all samples; return scalar summaries and the weights."""
    n_samples, dim_hid = X.shape
    ridge.fit(X, y)
    r_sq = ridge.score(X, y)        # in-sample score
    ridge_alpha = ridge.alpha_      # penalty selected by cross-validation
    w_out = ridge.coef_
    corr_w_x = comp_corr(X, w_out)  # overlap between weights and states
    pr = comp_pr(X)                 # participation ratio
    # comp_dim_var_fit refits `ridge`; w_out has already been captured above.
    d_var, d_fit_rel, _, _ = comp_dim_var_fit(X, y, ridge, d_thr)
    return {
        "r_sq": r_sq,
        "ridge_alpha": ridge_alpha,
        "corr_w_x": corr_w_x,
        "pr": pr,
        "n_samples": n_samples,
        "dim_hid": dim_hid,
        "d_var": d_var,
        "d_fit_rel": d_fit_rel,
        "ratio_d_fit_var_rel": d_fit_rel / d_var,
        "w_out": w_out,
    }


def _ridge_time_subsets(X, y, ridge, d_thr, n_fit, frac):
    """Refit on `n_fit` random row subsets (time points, drawn without replacement).

    Returns a dict label -> array of shape (n_fit,).
    """
    lbls = ["r_sq", "ridge_alpha", "corr_w_x", "pr",
            "d_var", "d_fit_rel", "ratio_d_fit_var_rel"]
    n_samples = X.shape[0]
    n_subs = int(n_samples * frac)
    out = np.zeros((len(lbls), n_fit))
    for i_fit in tqdm(range(n_fit)):
        idx_s = rng.choice(n_samples, n_subs, replace=False)
        # NOTE(review): subsets are rows of the globally centered data and are
        # not re-centered; confirm this is intended for the SVD downstream.
        X_s, y_s = X[idx_s], y[idx_s]
        ridge.fit(X_s, y_s)
        w_out_s = ridge.coef_
        out[0, i_fit] = ridge.score(X_s, y_s)
        out[1, i_fit] = ridge.alpha_
        out[2, i_fit] = comp_corr(X_s, w_out_s)
        out[3, i_fit] = comp_pr(X_s)
        d_var, d_fit_rel, _, _ = comp_dim_var_fit(X_s, y_s, ridge, d_thr)
        out[4:7, i_fit] = d_var, d_fit_rel, d_fit_rel / d_var
    return dict(zip(lbls, out))


def _ridge_neuron_subsets(X, y, ridge, n_fit, frac_dim_hids, compute):
    """Refit on random column subsets (neurons), `n_fit` times per subset size.

    Returns a dict label -> array of shape (n_fit, len(frac_dim_hids)); the
    arrays stay all zeros when `compute` is False.
    """
    lbls = ["r_sq", "ridge_alpha", "corr_w_x", "pr"]
    dim_hid = X.shape[1]
    # Keep at least one neuron: for dim_hid < 10 the smallest fractions would
    # otherwise truncate to 0 columns and the ridge fit would fail.
    dim_hid_subs = np.maximum(np.int_(dim_hid * np.asarray(frac_dim_hids)), 1)
    out = np.zeros((len(lbls), n_fit, len(dim_hid_subs)))
    if compute:
        for i_fit in tqdm(range(n_fit)):
            for i_n, dim_hid_i in enumerate(dim_hid_subs):
                idx_n = rng.choice(dim_hid, dim_hid_i, replace=False)
                X_n = X[:, idx_n]
                ridge.fit(X_n, y)
                w_out_s = ridge.coef_
                out[0, i_fit, i_n] = ridge.score(X_n, y)
                out[1, i_fit, i_n] = ridge.alpha_
                out[2, i_fit, i_n] = comp_corr(X_n, w_out_s)
                out[3, i_fit, i_n] = comp_pr(X_n)
    return dict(zip(lbls, out))

def comp_dim_var_fit(X_i, y_i, ridge, d_thr):
    """Compare the variance-explained dimension of X_i with the dimension needed to fit y_i.

    Parameters
    ----------
    X_i : array, (n_samples, dim_hid)
        States. The SVD below is taken without centering, so X_i should already
        be centered along axis 0 (otherwise the mean appears as an extra mode).
    y_i : array
        Targets passed to ``ridge.fit``.
    ridge : estimator
        Object with ``fit(X, y)`` and ``score(X, y)`` (e.g. sklearn RidgeCV).
        It is refit in place; afterwards it holds the last projection fit.
    d_thr : float
        Threshold fraction used both for the cumulative explained variance
        and for the fraction of the full-data score that must be exceeded.

    Returns
    -------
    d_var : int
        Smallest number of leading singular modes whose cumulative explained
        variance exceeds d_thr; all modes if it is never exceeded.
    d_fit_rel : int
        Smallest number of leading modes whose projection scores above
        ``d_thr * full score``; all available modes if that never happens.
    d_fits : list of int
        Numbers of modes that were tried.
    r_sq_ps : list of float
        Scores of the corresponding projection fits.
    """
    # Left singular vectors of X_i.T span neuron space (PCA axes for centered data).
    U, S, _ = svd(X_i.T, full_matrices=False)
    # Beyond U.shape[1] = min(n_samples, dim_hid) modes the projection no longer changes.
    n_modes = U.shape[1]

    # Dimension of the data from the cumulative explained variance ratio.
    S_sq = S**2
    cevr = S_sq.cumsum() / S_sq.sum()
    i_thr = np.flatnonzero(cevr > d_thr)
    d_var = i_thr[0] + 1 if len(i_thr) else n_modes

    # Reference score of the fit on the full data.
    ridge.fit(X_i, y_i)
    r_sq_full = ridge.score(X_i, y_i)

    # Fit on the first k modes until the score exceeds d_thr * r_sq_full.
    d_fits = []
    r_sq_ps = []
    d_fit_rel = n_modes  # fallback if the threshold is never exceeded
    for d_fit in range(1, n_modes + 1):
        X_ip = X_i @ U[:, :d_fit]  # projection onto the leading modes
        ridge.fit(X_ip, y_i)
        r_sq_p = ridge.score(X_ip, y_i)
        d_fits.append(d_fit)
        r_sq_ps.append(r_sq_p)
        if r_sq_p > r_sq_full * d_thr:
            d_fit_rel = d_fit
            break

    return d_var, d_fit_rel, d_fits, r_sq_ps
    


In [6]:
# Compute results: run comp_ridge on every dataset / output modality and pickle them.

from importlib import reload
import data_loaders; reload(data_loaders)  # pick up edits to data_loaders without restarting the kernel
from data_loaders import (
    load_golub_2018, load_hennig_2018, load_degenhart_2020, load_russo_2018, load_nlb_maze, load_nlb_rtt)

dataset_supers = [
    "bci-golub_2018",
    "bci-hennig_2018",
    "bci-degenhart_2020",
    "russo_2018_1",
    "russo_2018_2",
    "nlb-mc_maze_large",
    "nlb-mc_rtt",
    # "nlb-mc_maze_small",
]

# Decide whether to compute everything or only the velocity outputs.
# compute_mods = "vel_only"
compute_mods = "all"
if compute_mods not in ("vel_only", "all"):
    # Fail here instead of with a NameError on `ba_learning` inside the loop.
    raise ValueError("compute_mods must be 'vel_only' or 'all', got %r" % (compute_mods,))

# Learning epochs analysed for the golub / hennig BCI datasets.
if compute_mods == "vel_only":
    ba_learning = ["before"]
else:
    # ba_learning = ["before", "after"]
    ba_learning = ["before"]

# Compute subsets of neurons?
comp_neuron_subs = True  # alternative: compute_mods == "vel_only"

# BCI loaders whose outputs are looked up by learning epoch (see ba_learning).
bci_loaders = {"golub_2018": load_golub_2018, "hennig_2018": load_hennig_2018}
# The last character of a Russo dataset name selects the data file.
russo_files = {"1": "Cousteau_tt.mat", "2": "Drake_tt.mat"}


def select_output_mods(output_mods):
    """Keep only velocity modalities (suffix "vel") when compute_mods == "vel_only"."""
    if compute_mods == "vel_only":
        return [om for om in output_mods if om.split('_')[-1] == "vel"]
    return output_mods


results = {}
time0 = time.time()
for dataset_super in dataset_supers:
    if dataset_super.startswith('bci'):
        _, dataset = dataset_super.split('-')
        if dataset in bci_loaders:
            output_dict, hids_dict = bci_loaders[dataset](0)
            # Before and after learning
            for key in ba_learning:
                ds_name = dataset_super + '-' + key
                print(ds_name)
                output = output_dict[key]
                hids = hids_dict[key]
                results[ds_name] = comp_ridge(hids, output, comp_neuron_subs)
        elif dataset == "degenhart_2020":
            key = 'before'
            ds_name = dataset_super + '-' + key
            print(ds_name)
            output, hids = load_degenhart_2020(fit_kalman=False)
            results[ds_name] = comp_ridge(hids, output, comp_neuron_subs)

    if dataset_super.startswith('russo'):
        mat_file = russo_files[dataset_super[-1]]
        output_dict, hids_dict = load_russo_2018(mat_file, subs_step=5)
        for key in select_output_mods(["emg", "hand_pos", "hand_vel", "hand_acc"]):
            ds_name = dataset_super + '-' + key
            print(ds_name)
            output = output_dict[key]
            hids = hids_dict[key]
            results[ds_name] = comp_ridge(hids, output, comp_neuron_subs)

    if dataset_super.startswith('nlb'):
        nlb_task = dataset_super.split('-')[1]
        if nlb_task.startswith("mc_maze"):
            output_dict, hids_dict = load_nlb_maze(dataset_super)
            for output_mod in select_output_mods(["hand_pos", "hand_vel", "hand_acc"]):
                # Single trials ("single") or averages ("tca", stored under "<mod>_tca").
                for single_or_tca in ["single", "tca"]:
                    ds_name = dataset_super + '-' + output_mod + '-' + single_or_tca
                    print(ds_name)
                    key = (output_mod + "_tca") if single_or_tca == 'tca' else output_mod
                    output = output_dict[key]
                    hids = hids_dict[key]
                    results[ds_name] = comp_ridge(hids, output, comp_neuron_subs)

        if nlb_task.startswith("mc_rtt"):
            output_dict, hids_dict = load_nlb_rtt(dataset_super)
            for key in select_output_mods(["finger_pos", "finger_vel", "finger_acc"]):
                ds_name = dataset_super + '-' + key
                print(ds_name)
                output = output_dict[key]
                hids = hids_dict[key]
                results[ds_name] = comp_ridge(hids, output, comp_neuron_subs)

print("Took %.3f sec." % (time.time() - time0))

# Save [dataset_supers, results]; the file name records which modalities were computed.
res = [
    dataset_supers,
    results,
]
file_name = "data_corr_dims_vel_only.pkl" if compute_mods == "vel_only" else "data_corr_dims.pkl"
data_file = os.path.join(data_path, file_name)
with open(data_file, 'wb') as handle:
    pickle.dump(res, handle)
print('Saved to ', data_file)

bci-golub_2018-before


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.67it/s]
100%|███████████████████████████████████████████| 20/20 [00:13<00:00,  1.48it/s]


bci-hennig_2018-before


100%|███████████████████████████████████████████| 20/20 [00:06<00:00,  3.10it/s]
100%|███████████████████████████████████████████| 20/20 [00:35<00:00,  1.75s/it]


bci-degenhart_2020-before


100%|███████████████████████████████████████████| 20/20 [00:09<00:00,  2.02it/s]
100%|███████████████████████████████████████████| 20/20 [00:47<00:00,  2.36s/it]


russo_2018_1-emg


100%|███████████████████████████████████████████| 20/20 [00:10<00:00,  1.87it/s]
100%|███████████████████████████████████████████| 20/20 [00:42<00:00,  2.12s/it]


russo_2018_1-hand_pos


100%|███████████████████████████████████████████| 20/20 [00:13<00:00,  1.49it/s]
100%|███████████████████████████████████████████| 20/20 [00:38<00:00,  1.92s/it]


russo_2018_1-hand_vel


100%|███████████████████████████████████████████| 20/20 [00:11<00:00,  1.81it/s]
100%|███████████████████████████████████████████| 20/20 [00:36<00:00,  1.81s/it]


russo_2018_1-hand_acc


100%|███████████████████████████████████████████| 20/20 [00:12<00:00,  1.57it/s]
100%|███████████████████████████████████████████| 20/20 [00:41<00:00,  2.05s/it]


russo_2018_2-emg


100%|███████████████████████████████████████████| 20/20 [00:12<00:00,  1.56it/s]
100%|███████████████████████████████████████████| 20/20 [00:38<00:00,  1.92s/it]


russo_2018_2-hand_pos


100%|███████████████████████████████████████████| 20/20 [00:10<00:00,  1.99it/s]
100%|███████████████████████████████████████████| 20/20 [00:40<00:00,  2.02s/it]


russo_2018_2-hand_vel


100%|███████████████████████████████████████████| 20/20 [00:12<00:00,  1.65it/s]
100%|███████████████████████████████████████████| 20/20 [00:41<00:00,  2.07s/it]


russo_2018_2-hand_acc


100%|███████████████████████████████████████████| 20/20 [00:14<00:00,  1.35it/s]
100%|███████████████████████████████████████████| 20/20 [00:42<00:00,  2.12s/it]


nlb-mc_maze_large-hand_pos-single


100%|███████████████████████████████████████████| 20/20 [00:17<00:00,  1.13it/s]
100%|███████████████████████████████████████████| 20/20 [00:37<00:00,  1.90s/it]


nlb-mc_maze_large-hand_pos-tca


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  5.71it/s]
100%|███████████████████████████████████████████| 20/20 [00:13<00:00,  1.51it/s]


nlb-mc_maze_large-hand_vel-single


100%|███████████████████████████████████████████| 20/20 [00:20<00:00,  1.02s/it]
100%|███████████████████████████████████████████| 20/20 [00:39<00:00,  1.95s/it]


nlb-mc_maze_large-hand_vel-tca


100%|███████████████████████████████████████████| 20/20 [00:04<00:00,  4.22it/s]
100%|███████████████████████████████████████████| 20/20 [00:14<00:00,  1.39it/s]


nlb-mc_maze_large-hand_acc-single


100%|███████████████████████████████████████████| 20/20 [01:01<00:00,  3.05s/it]
100%|███████████████████████████████████████████| 20/20 [00:39<00:00,  1.97s/it]


nlb-mc_maze_large-hand_acc-tca


100%|███████████████████████████████████████████| 20/20 [00:08<00:00,  2.32it/s]
100%|███████████████████████████████████████████| 20/20 [00:15<00:00,  1.28it/s]
  obj = obj._drop_axis(labels, axis, level=level, errors=errors)


nlb-mc_rtt-finger_pos


100%|███████████████████████████████████████████| 20/20 [01:08<00:00,  3.40s/it]
100%|███████████████████████████████████████████| 20/20 [00:37<00:00,  1.90s/it]


nlb-mc_rtt-finger_vel


100%|███████████████████████████████████████████| 20/20 [00:24<00:00,  1.21s/it]
100%|███████████████████████████████████████████| 20/20 [00:38<00:00,  1.92s/it]


nlb-mc_rtt-finger_acc


100%|███████████████████████████████████████████| 20/20 [00:59<00:00,  2.95s/it]
100%|███████████████████████████████████████████| 20/20 [00:37<00:00,  1.88s/it]

Took 1142.214 sec.
Saved to  ../data/data_corr_dims.pkl



