In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
import numpy as np
import scipy
import importlib
import matplotlib.pyplot as plt
from glob import glob
import sys
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
import pdb
from tqdm import tqdm
import pandas as pd
import pickle

In [6]:
sys.path.append('../..')
from loaders import load_sabes, load_peanut
from decoders import lr_decoder
from utils import calc_loadings

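#### Goal: Characterize the dynamics matrix that lies within the subspace identified by supervised (linear-regression) decoding. (TODO: the original sentence was truncated — confirm the intended wording.)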

### Sabes

In [7]:
# Sort so the processing order (and hence the row order of the results frame)
# is reproducible; glob returns matches in arbitrary, filesystem-dependent order.
data_files = sorted(glob('/mnt/Secondary/data/sabes/*.mat'))

In [8]:
# Use a fixed set of preprocessing/decoding parameters.
# Guard the sys.path insertion so re-running this cell does not keep appending
# duplicate entries.
SUBMIT_FILES_DIR = '/home/akumar/nse/neural_control/submit_files'
if SUBMIT_FILES_DIR not in sys.path:
    sys.path.append(SUBMIT_FILES_DIR)
sabes_args = importlib.import_module('sabes_decoding_args')

In [10]:
# Precomputed dimensionality-reduction results for the Sabes sessions.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
sabes_dimreduc_path = '/home/akumar/nse/neural_control/data/sabes_dimreduc_df.dat'
with open(sabes_dimreduc_path, 'rb') as dimreduc_file:
    sabes_dimreduc_df = pickle.load(dimreduc_file)

In [106]:
# Fit a supervised linear decoder on every Sabes session using 5-fold,
# unshuffled (contiguous-block) cross-validation. For each (session, fold)
# record:
#   'coef'     - right singular vectors (columns) of the full decoder weights
#   'loadings' - calc_loadings of an orthonormal basis for the velocity decoder
# along with the loader/decoder parameters used and the fold index.
sabes_decoder_args = sabes_args.decoder_args[0]
sabes_loader_args = sabes_args.loader_args[0]
cv = KFold(n_splits=5, shuffle=False)

results_list = []
for data_file in tqdm(data_files):

    dat = load_sabes(data_file, **sabes_loader_args)
    X = np.squeeze(dat['spike_rates'])
    Y = np.squeeze(dat['behavior'])

    for fold_idx, (train_idxs, test_idxs) in enumerate(cv.split(X, Y)):

        Xtrain, Xtest = X[train_idxs, :], X[test_idxs, :]
        Ytrain, Ytest = Y[train_idxs, :], Y[test_idxs, :]

        _, _, _, lm = lr_decoder(Xtest, Xtrain, Ytest, Ytrain, **sabes_decoder_args)

        # Orthonormal basis for the row space of the full decoder weight matrix.
        _, _, Vh = scipy.linalg.svd(lm.coef_, full_matrices=False)

        # Loadings onto the velocity prediction. Singular vectors are ordered by
        # singular value, not by decoder output, so Vh.T[:, 2:4] would NOT pick
        # out the velocity decoder. Instead orthonormalize the decoder rows for
        # the velocity outputs.
        # NOTE(review): assumes lm.coef_ rows 2:4 are the velocity outputs
        # (e.g. rows ordered pos, vel, acc) — confirm against lr_decoder.
        _, _, Vh_vel = scipy.linalg.svd(lm.coef_[2:4, :], full_matrices=False)

        results_dict = {
            'data_file': data_file.split('/')[-1],
            'coef': Vh.T,
            'loadings': calc_loadings(Vh_vel.T, sabes_decoder_args['decoding_window']),
        }
        # Later updates win on key collisions, as in the original ordering.
        results_dict.update(sabes_decoder_args)
        results_dict.update(sabes_loader_args)
        results_dict['fold_idx'] = fold_idx
        results_list.append(results_dict)


0it [00:00, ?it/s]

Processing spikes


100%|██████████| 1/1 [00:16<00:00, 16.06s/it]
1it [00:19, 19.72s/it]

Processing spikes


100%|██████████| 1/1 [00:22<00:00, 22.52s/it]
2it [00:46, 23.91s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.57s/it]
3it [00:52, 15.53s/it]

Processing spikes


100%|██████████| 1/1 [00:30<00:00, 30.91s/it]
4it [01:29, 23.98s/it]

Processing spikes


100%|██████████| 1/1 [00:13<00:00, 13.44s/it]
5it [01:43, 20.64s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.50s/it]
6it [01:48, 15.06s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.32s/it]
7it [01:52, 11.47s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.26s/it]
8it [01:57,  9.48s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.89s/it]
9it [02:02,  7.99s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.46s/it]
10it [02:06,  6.77s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.80s/it]
11it [02:11,  6.43s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.55s/it]
12it [02:17,  6.10s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.29s/it]
13it [02:24,  6.52s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.85s/it]
14it [02:30,  6.28s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.91s/it]
15it [02:36,  6.14s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.63s/it]
16it [02:41,  5.90s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.32s/it]
17it [02:46,  5.63s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.65s/it]
18it [02:51,  5.60s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.63s/it]
19it [02:57,  5.53s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.38s/it]
20it [03:03,  5.79s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.84s/it]
21it [03:11,  6.39s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.17s/it]
22it [03:16,  5.95s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.22s/it]
23it [03:22,  5.94s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.31s/it]
24it [03:28,  5.96s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.62s/it]
25it [03:34,  6.08s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.46s/it]
26it [03:41,  6.16s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.73s/it]
27it [03:49,  6.71s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]
28it [03:57,  8.49s/it]


In [107]:
# One row per (session, fold) record produced by the decoding loop.
sabes_supervised_decoding = pd.DataFrame(data=results_list)

In [108]:
# Shape of the per-neuron loadings vector for the first record.
sabes_supervised_decoding['loadings'].iloc[0].shape

(186,)

In [109]:
# Shape of the stored singular-vector basis for the first record.
sabes_supervised_decoding['coef'].iloc[0].shape

(930, 6)

In [95]:
# Persist the Sabes results frame. pickle.dump streams directly into the file
# rather than building the full pickle bytes in memory first.
# NOTE(review): in the saved notebook this cell ran (In [95]) before the decoding
# loop (In [106]); re-run it after the loop so the file matches the results.
with open('/home/akumar/nse/neural_control/data/sabes_supervised_decoding.dat', 'wb') as f:
    pickle.dump(sabes_supervised_decoding, f)

### Peanut

In [49]:
# Single Peanut recording (day 14); individual epochs are chosen in the loop below.
data_file = '/mnt/Secondary/data/peanut/data_dict_peanut_day14.obj'

In [62]:
# Use a fixed set of preprocessing/decoding parameters.
# This section should not depend on the Sabes cells having run, so make sure the
# submit_files directory is importable here too (guarded to avoid duplicates).
if '/home/akumar/nse/neural_control/submit_files' not in sys.path:
    sys.path.append('/home/akumar/nse/neural_control/submit_files')
peanut_args = importlib.import_module('peanut_kca_args')
peanut_decoding_args = importlib.import_module('peanut_decoding_args')

In [66]:
# Precomputed KCA decoding results for Peanut.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
peanut_kca_path = '/home/akumar/nse/neural_control/data/peanut_kca_decoding_df.dat'
with open(peanut_kca_path, 'rb') as kca_file:
    peanut_decoding_df = pickle.load(kca_file)

In [76]:
# Loader kwargs for load_peanut without 'epoch' (the loop below passes each
# epoch positionally). Copy first so the imported module's dict is left intact
# and re-running this cell does not raise KeyError on the already-removed key.
peanut_loader_args = dict(peanut_args.loader_args[0])
peanut_loader_args.pop('epoch', None)

2

In [80]:
# Fit a supervised linear decoder on each Peanut epoch using 5-fold, unshuffled
# (contiguous-block) cross-validation, recording the decoder basis and the
# parameters used for each (epoch, fold).
epochs = np.arange(2, 18, 2)  # epochs 2, 4, ..., 16
peanut_decoder_args = peanut_decoding_args.decoder_args[0]
cv = KFold(n_splits=5, shuffle=False)

results_list = []
for epoch in tqdm(epochs):

    # Using defaults
    dat = load_peanut(data_file, epoch, **peanut_loader_args)

    X = np.squeeze(dat['spike_rates'])
    Y = np.squeeze(dat['behavior'])

    for fold_idx, (train_idxs, test_idxs) in enumerate(cv.split(X, Y)):

        Xtrain, Xtest = X[train_idxs, :], X[test_idxs, :]
        Ytrain, Ytest = Y[train_idxs, :], Y[test_idxs, :]

        _, _, _, lm = lr_decoder(Xtest, Xtrain, Ytest, Ytrain, **peanut_decoder_args)

        # Keep the right singular vectors of the decoder weights under 'coef'
        # so this frame carries the same basis as the Sabes results.
        # TODO(review): add 'loadings' too once the velocity rows / decoding
        # window for Peanut are confirmed.
        _, _, Vh = scipy.linalg.svd(lm.coef_, full_matrices=False)

        results_dict = {'epoch': epoch}
        results_dict.update(peanut_loader_args)
        results_dict.update(peanut_decoder_args)
        results_dict['data_file'] = data_file.split('/')[-1]
        results_dict['coef'] = Vh.T
        results_dict['fold_idx'] = fold_idx
        results_list.append(results_dict)


8it [00:31,  3.91s/it]


In [81]:
# One row per (epoch, fold) record produced by the Peanut decoding loop.
peanut_supervised_decoding = pd.DataFrame(data=results_list)

In [82]:
# Persist the Peanut results frame. pickle.dump streams directly into the file
# rather than building the full pickle bytes in memory first.
with open('/home/akumar/nse/neural_control/data/peanut_supervised_decoding.dat', 'wb') as f:
    pickle.dump(peanut_supervised_decoding, f)