In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import scipy
import importlib
import matplotlib.pyplot as plt
from glob import glob
import sys
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
import pdb
from tqdm import tqdm
import pandas as pd
import pickle
from sklearn.preprocessing import normalize, StandardScaler

In [4]:
sys.path.append('../..')
from loaders import load_sabes, load_peanut
from decoders import lr_decoder
from utils import calc_loadings

In [7]:
sys.path.append('/home/akumar/nse/neural_control/submit_files')

#### Goal: Find the supervised subspaces

### Sabes

In [88]:
data_files = glob('/mnt/Secondary/data/sabes/*.mat')

In [89]:
# Use a fixed set of preprocessing/decoding parameters
sabes_args = importlib.import_module('sabes_decoding_args')

In [90]:
sabes_args.loader_args

[{'bin_width': 50,
  'filter_fn': 'none',
  'filter_kwargs': {},
  'boxcox': 0.5,
  'spike_threshold': 100}]

In [83]:
with open('/home/akumar/nse/neural_control/data/sabes_dimreduc_df.dat', 'rb') as f:
    sabes_dimreduc_df = pickle.load(f)

In [86]:
np.unique(sabes_dimreduc_df['dimreduc'].values)

array(['DCA', 'OLS1', 'OLS3', 'OLS5', 'PCA', 'SFA'], dtype=object)

In [84]:
sabes_dimreduc_df.keys()

Index(['dim', 'fold_idx', 'train_idxs', 'test_idxs', 'data_file',
       'loader_args', 'T', 'dimreduc', 'coef', 'score'],
      dtype='object')

In [100]:
sabes_args.decoder_args[0]

{'trainlag': 4, 'testlag': 4, 'decoding_window': 5}

In [106]:
results_list = []
for i, data_file in tqdm(enumerate(data_files)):

    dat = load_sabes(data_file, **sabes_args.loader_args[0])
    X = np.squeeze(dat['spike_rates'])
    Y = np.squeeze(dat['behavior'])
    cv = KFold(n_splits=5, shuffle=False)
    fold_idx = 0
    for train_idxs, test_idxs in cv.split(X, Y):
        
        Xtrain = X[train_idxs, :]
        Xtest = X[test_idxs, :]
        Ytrain = Y[train_idxs, :]
        Ytest = Y[test_idxs, :]

        _, _, _, lm = lr_decoder(Xtest, Xtrain, Ytest, Ytrain, **sabes_args.decoder_args[0])
        
        U, S, Vh = scipy.linalg.svd(lm.coef_, full_matrices=False)
        results_dict = {}
        results_dict['data_file'] = data_file.split('/')[-1]
        results_dict['coef'] = Vh.T
        ### CALCULATING LOADINGS ONTO THE VELCOCITY PREDICTION
        results_dict['loadings'] = calc_loadings(Vh.T[:, 2:4], sabes_args.decoder_args[0]['decoding_window'])
        for k, v in sabes_args.decoder_args[0].items():
            results_dict[k] = v
        for k, v in sabes_args.loader_args[0].items():
            results_dict[k] = v

        results_dict['fold_idx'] = fold_idx        
        results_list.append(results_dict)
        fold_idx += 1


0it [00:00, ?it/s]

Processing spikes


100%|██████████| 1/1 [00:16<00:00, 16.06s/it]
1it [00:19, 19.72s/it]

Processing spikes


100%|██████████| 1/1 [00:22<00:00, 22.52s/it]
2it [00:46, 23.91s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.57s/it]
3it [00:52, 15.53s/it]

Processing spikes


100%|██████████| 1/1 [00:30<00:00, 30.91s/it]
4it [01:29, 23.98s/it]

Processing spikes


100%|██████████| 1/1 [00:13<00:00, 13.44s/it]
5it [01:43, 20.64s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.50s/it]
6it [01:48, 15.06s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.32s/it]
7it [01:52, 11.47s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.26s/it]
8it [01:57,  9.48s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.89s/it]
9it [02:02,  7.99s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.46s/it]
10it [02:06,  6.77s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.80s/it]
11it [02:11,  6.43s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.55s/it]
12it [02:17,  6.10s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.29s/it]
13it [02:24,  6.52s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.85s/it]
14it [02:30,  6.28s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.91s/it]
15it [02:36,  6.14s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.63s/it]
16it [02:41,  5.90s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.32s/it]
17it [02:46,  5.63s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.65s/it]
18it [02:51,  5.60s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.63s/it]
19it [02:57,  5.53s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.38s/it]
20it [03:03,  5.79s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.84s/it]
21it [03:11,  6.39s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.17s/it]
22it [03:16,  5.95s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.22s/it]
23it [03:22,  5.94s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.31s/it]
24it [03:28,  5.96s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.62s/it]
25it [03:34,  6.08s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.46s/it]
26it [03:41,  6.16s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.73s/it]
27it [03:49,  6.71s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.51s/it]
28it [03:57,  8.49s/it]


In [107]:
sabes_supervised_decoding = pd.DataFrame(results_list)

In [108]:
sabes_supervised_decoding.iloc[0]['loadings'].shape

(186,)

In [109]:
sabes_supervised_decoding.iloc[0]['coef'].shape

(930, 6)

In [95]:
with open('/home/akumar/nse/neural_control/data/sabes_supervised_decoding.dat', 'wb') as f:
    f.write(pickle.dumps(sabes_supervised_decoding))

### Peanut

In [49]:
data_file = '/mnt/Secondary/data/peanut/data_dict_peanut_day14.obj'

In [62]:
# Use a fixed set of preprocessing/decoding parameters
#sys.path.append('/home/akumar/nse/neural_control/submit_files')
peanut_args = importlib.import_module('peanut_kca_args')
peanut_decoding_args = importlib.import_module('peanut_decoding_args')

In [66]:
with open('/home/akumar/nse/neural_control/data/peanut_kca_decoding_df.dat', 'rb') as f:
    peanut_decoding_df = pickle.load(f)

In [76]:
peanut_loader_args = peanut_args.loader_args[0]
peanut_loader_args.pop('epoch')

2

In [80]:
epochs = np.arange(2, 18, 2)
results_list = []
for i, epoch in tqdm(enumerate(epochs)):

    # Using defaults
    dat = load_peanut(data_file, epoch, **peanut_loader_args)

    X = np.squeeze(dat['spike_rates'])
    Y = np.squeeze(dat['behavior'])
    cv = KFold(n_splits=5, shuffle=False)
    fold_idx = 0
    for train_idxs, test_idxs in cv.split(X, Y):
        
        Xtrain = X[train_idxs, :]
        Xtest = X[test_idxs, :]
        Ytrain = Y[train_idxs, :]
        Ytest = Y[test_idxs, :]

        _, _, _, lm = lr_decoder(Xtest, Xtrain, Ytest, Ytrain, **peanut_decoding_args.decoder_args[0])
        
        U, S, Vh = scipy.linalg.svd(lm.coef_, full_matrices=False)
        results_dict = {}
        results_dict['epoch'] = epoch
        for k, v in peanut_loader_args.items():
            results_dict[k] = v
        for k, v in peanut_decoding_args.decoder_args[0].items():
            results_dict[k] = v

        results_dict['fold_idx'] = fold_idx        
        results_list.append(results_dict)
        fold_idx += 1


8it [00:31,  3.91s/it]


In [81]:
peanut_supervised_decoding = pd.DataFrame(results_list)

In [82]:
with open('/home/akumar/nse/neural_control/data/peanut_supervised_decoding.dat', 'wb') as f:
    f.write(pickle.dumps(peanut_supervised_decoding))

### Assess whether improved decoding performance is associated with closer alignment with the supervised subspace

In [5]:
with open('/home/akumar/nse/neural_control/data/peanut_supervised_decoding.dat', 'rb') as f:
    peanut_supervised_df = pickle.load(f)
with open('/home/akumar/nse/neural_control/data/peanut_supervised_decoding.dat', 'rb') as f:
    sabes_supervised_df = pickle.load(f)

In [6]:
with open('/home/akumar/nse/neural_control/data/peanut_dimreduc_df', 'rb') as f:
    peanut_dimreduc_df = pickle.load(f)
with open('/home/akumar/nse/neural_control/data/sabes_dimreduc_df', 'rb') as f:
    sabes_dimreduc_df = pickle.load(f)

FileNotFoundError: [Errno 2] No such file or directory: '/home/akumar/nse/neural_control/data/peanut_dimreduc_df'

### Fit supervised model to segmented behavior

In [8]:
# Sabes segmented
from segmentation import reach_segment_sabes

In [55]:
data_files = glob('/mnt/Secondary/data/sabes/*.mat')
# Use a fixed set of preprocessing/decoding parameters
sys.path.append('/home/akumar/nse/neural_control/submit_files')
sabes_args = importlib.import_module('sabes_dimreduc_args')
decoder_args = importlib.import_module('sabes_decoding_args')

In [56]:
decoder_args = {'trainlag':4, 'testlag':4, 'decoding_window':3}

In [7]:
start_times = {'indy_20160426_01': 0,
               'indy_20160622_01':1700,
               'indy_20160624_03': 500,
               'indy_20160627_01': 0,
               'indy_20160630_01': 0,
               'indy_20160915_01': 0,
               'indy_20160921_01': 0,
               'indy_20160930_02': 0,
               'indy_20160930_05': 300,
               'indy_20161005_06': 0,
               'indy_20161006_02': 350,
               'indy_20161007_02': 950,
               'indy_20161011_03': 0,
               'indy_20161013_03': 0,
               'indy_20161014_04': 0,
               'indy_20161017_02': 0,
               'indy_20161024_03': 0,
               'indy_20161025_04': 0,
               'indy_20161026_03': 0,
               'indy_20161027_03': 500,
               'indy_20161206_02': 5500,
               'indy_20161207_02': 0,
               'indy_20161212_02': 0,
               'indy_20161220_02': 0,
               'indy_20170123_02': 0,
               'indy_20170124_01': 0,
               'indy_20170127_03': 0,
               'indy_20170131_02': 0,
               }

In [61]:
results_list = []

# Binning of orientation
bins = np.arange(-np.pi,np.pi,.25 * np.pi)

for i, data_file in tqdm(enumerate(data_files)):

    dat = load_sabes(data_file, **sabes_args.loader_args[0])
    dat_segmented = reach_segment_sabes(dat, start_times[data_file.split('/')[-1].split('.mat')[0]])

    spike_rates = dat_segmented['spike_rates']
    spike_rates = spike_rates.reshape(spike_rates.shape[1], -1)
    vels = dat_segmented['vel']
    #‖𝑉(𝑡)‖
    peak_vels_in_windows = np.array([np.amax(np.absolute(vels[start : end + 1])) \
                                    for start, end in dat_segmented['transition_times']])[:,np.newaxis]
    orientation_in_windows = dat_segmented['transition_orientation']

    #‖𝑉(𝑡)‖𝑠𝑖𝑛[𝜃(𝑡)]
    peak_vels_in_windows = normalize(peak_vels_in_windows, axis = 0)
    vel_sin = normalize(np.sin(orientation_in_windows)[:,np.newaxis] * peak_vels_in_windows, axis = 0)
#     vel_sin = np.sin(orientation_in_windows)[:,np.newaxis] * peak_vels_in_windows
    #‖𝑉(𝑡)‖𝑐𝑜𝑠[𝜃(𝑡)]
    vel_cos = np.cos(orientation_in_windows)[:,np.newaxis] * peak_vels_in_windows
#    vel_cos = normalize(np.cos(orientation_in_windows)[:,np.newaxis] * peak_vels_in_windows, axis = 0)

    #Binning
    binned_indices = np.digitize(orientation_in_windows, bins)
    binned_indices = [np.where(binned_indices == idx) for idx in range(1,9)]

    for j, bin_idxs in enumerate(binned_indices):
        transitions_inbin = np.array(dat_segmented['transition_times'])[bin_idxs]
        X = np.squeeze(dat['spike_rates'])
        X = np.array([X[max(start_time, 0):min(end_time + 1, spike_rates.shape[0])] for start_time, end_time in transitions_inbin])
        Y = np.squeeze(dat['behavior'])
        Y= np.array([Y[max(start_time, 0):min(end_time + 1, spike_rates.shape[0])] for start_time, end_time in transitions_inbin])

        cv = KFold(n_splits=5, shuffle=False)
        fold_idx = 0
        for train_idxs, test_idxs in cv.split(X, Y):
            
            Xtrain = list(X[train_idxs])
            Xtest = list(X[test_idxs])
            Ytrain = list(Y[train_idxs])
            Ytest = list(Y[test_idxs])

            _, _, _, lm = lr_decoder(Xtest, Xtrain, Ytest, Ytrain, **decoder_args)
            
            U, S, Vh = scipy.linalg.svd(lm.coef_, full_matrices=False)
            
            results_dict = {}
            results_dict['data_file'] = data_file.split('/')[-1]
            results_dict['coef'] = Vh.T


            ### CALCULATING LOADINGS ONTO THE VELCOCITY PREDICTION
            results_dict['loadings'] = calc_loadings(Vh.T[:, 2:4], decoder_args['decoding_window'])
            for k, v in decoder_args.items():
                results_dict[k] = v
            for k, v in sabes_args.loader_args[0].items():
                results_dict[k] = v

            results_dict['fold_idx'] = fold_idx        
            results_dict['orientation_bin_index'] = j
            results_list.append(results_dict)
            fold_idx += 1


0it [00:00, ?it/s]

Processing spikes


100%|██████████| 1/1 [00:16<00:00, 16.01s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:22<00:00, 22.55s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.67s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:30<00:00, 30.85s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:13<00:00, 13.38s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.52s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.30s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.19s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.70s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.49s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.65s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.56s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.22s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.75s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.79s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.60s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.40s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.71s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.61s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.39s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.75s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.08s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.17s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.23s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.55s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.46s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.69s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.34s/it]
  m = straight[1]/straight[0]
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype,

In [62]:
segmented_supervised_df = pd.DataFrame(results_list)

In [65]:
with open('/home/akumar/nse/neural_control/data/sabes_segmented_supervised.dat', 'wb') as f:
    f.write(pickle.dumps(segmented_supervised_df))

In [None]:
## Peanut

In [5]:
from loaders import segment_peanut

In [12]:
data_file = '/mnt/Secondary/data/peanut/data_dict_peanut_day14.obj'
# Use a fixed set of preprocessing/decoding parameters
#sys.path.append('/home/akumar/nse/neural_control/submit_files')
peanut_args = importlib.import_module('peanut_kca_args')
peanut_decoding_args = importlib.import_module('peanut_decoding_args')

with open('/home/akumar/nse/neural_control/data/peanut_kca_decoding_df.dat', 'rb') as f:
    peanut_decoding_df = pickle.load(f)

peanut_loader_args = peanut_args.loader_args[0]
peanut_loader_args.pop('epoch')
epochs = np.arange(2, 18, 2)


In [13]:
decoding_args = {'trainlag':0, 'testlag':0, 'decoding_window':3}

In [None]:
load

In [29]:
results_list = []
for i, epoch in tqdm(enumerate(epochs)):

    # Using defaults
    dat = load_peanut(data_file, epoch, **peanut_loader_args)
    t1, t2 = segment_peanut(dat, '/mnt/Secondary/data/peanut/linearization_dict_peanut_day14.obj', epoch)

    X = np.array([np.squeeze(dat['spike_rates'])[tidxs] for tidxs in t1])        
    Y = np.array([np.squeeze(dat['behavior'])[tidxs] for tidxs in t1])

    pdb.set_trace()

    cv = KFold(n_splits=5, shuffle=False)
    fold_idx = 0
    for train_idxs, test_idxs in cv.split(X, Y):
        
        Xtrain = list(X[train_idxs])
        Xtest = list(X[test_idxs])
        Ytrain = list(Y[train_idxs])
        Ytest = list(Y[test_idxs])

        _, _, _, lm = lr_decoder(Xtest, Xtrain, Ytest, Ytrain, **decoding_args)
        
        U, S, Vh = scipy.linalg.svd(lm.coef_, full_matrices=False)
        results_dict = {}
        results_dict['coef'] = Vh.T
        results_dict['loadings'] = calc_loadings(Vh.T, decoding_args['decoding_window'])
        results_dict['epoch'] = epoch
        for k, v in peanut_loader_args.items():
            results_dict[k] = v
        for k, v in decoding_args.items():
            results_dict[k] = v

        results_dict['fold_idx'] = fold_idx        
        results_dict['transition_type'] = 1
        results_list.append(results_dict)
        fold_idx += 1

    X = np.array([np.squeeze(dat['spike_rates'])[tidxs] for tidxs in t2])        
    Y = np.array([np.squeeze(dat['behavior'])[tidxs] for tidxs in t2])

    cv = KFold(n_splits=5, shuffle=False)
    fold_idx = 0
    for train_idxs, test_idxs in cv.split(X, Y):
        
        Xtrain = list(X[train_idxs])
        Xtest = list(X[test_idxs])
        Ytrain = list(Y[train_idxs])
        Ytest = list(Y[test_idxs])

        _, _, _, lm = lr_decoder(Xtest, Xtrain, Ytest, Ytrain, **decoding_args)
        
        if np.any(np.isnan(lm.coef_)):
            pdb.set_trace()

        U, S, Vh = scipy.linalg.svd(lm.coef_, full_matrices=False)
        results_dict = {}
        results_dict['coef'] = Vh.T
        results_dict['loadings'] = calc_loadings(Vh.T, decoding_args['decoding_window'])

        results_dict['epoch'] = epoch
        for k, v in peanut_loader_args.items():
            results_dict[k] = v
        for k, v in decoding_args.items():
            results_dict[k] = v

        results_dict['fold_idx'] = fold_idx        
        results_dict['transition_type'] = 2
        results_list.append(results_dict)
        fold_idx += 1



  
  if __name__ == '__main__':


> [0;32m/tmp/ipykernel_7658/3265002900.py[0m(13)[0;36m<module>[0;34m()[0m
[0;32m     11 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     12 [0;31m[0;34m[0m[0m
[0m[0;32m---> 13 [0;31m    [0mcv[0m [0;34m=[0m [0mKFold[0m[0;34m([0m[0mn_splits[0m[0;34m=[0m[0;36m5[0m[0;34m,[0m [0mshuffle[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m    [0mfold_idx[0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m    [0;32mfor[0m [0mtrain_idxs[0m[0;34m,[0m [0mtest_idxs[0m [0;32min[0m [0mcv[0m[0;34m.[0m[0msplit[0m[0;34m([0m[0mX[0m[0;34m,[0m [0mY[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
(19,)
(19,)
(162, 54)


0it [02:18, ?it/s]


BdbQuit: 

In [25]:
segmented_supervised_df = pd.DataFrame(results_list)

In [26]:
with open('/home/akumar/nse/neural_control/data/peanut_segmented_supervised.dat', 'wb') as f:
    f.write(pickle.dumps(segmented_supervised_df))

In [27]:
segmented_supervised_df.iloc[0]

coef               [[-1.9392263429110821e-16, -1.439703190933491e...
loadings           [0.02646321329926449, 1.0, 0.08160567209556628...
epoch                                                              2
bin_width                                                         25
filter_fn                                                       none
filter_kwargs                                                     {}
boxcox                                                           0.5
spike_threshold                                                  200
speed_threshold                                                    4
trainlag                                                           0
testlag                                                            0
decoding_window                                                    3
fold_idx                                                           0
transition_type                                                    1
Name: 0, dtype: object