# seq2seq training with DataFile datasets
Michael Nolan - 2020.09.11.3125

In [1]:
import aopy
import ecog_is2s

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt

from functools import partial

import os.path as path
import glob
import sys  # needed for sys.path.append below (was missing -> NameError)

# modules that aren't done yet
# They are not installed as packages, so their source directory is appended to
# the import path before importing them. Presumably `estimation` and `util`
# live in py4sid — TODO confirm.
sys.path.append('C:\\Users\\mickey\\aoLab\\code\\py4sid')
# sys.path.append('/home/mickey/analyze/')
import estimation
import util

In [None]:
# get data files and create datafile objects
data_path_root = 'E:\\aoLab\\Data\\WirelessData\\Goose_Multiscale_M1'
data_path_day = path.join(data_path_root,'180325')
# One wildcard level for the per-recording subdirectories, then the ECoG file
# pattern. path.join keeps the separator correct for the host OS. glob returns
# files in filesystem-dependent order, so the list is sorted. This makes the
# file order reproducible, and with it the sample indices of the concatenated
# dataset built from it (later cells pick samples by fixed index).
data_file_list = sorted(glob.glob(path.join(data_path_day,'*','*ECOG*clfp_ds250_fl0u10.dat')))
print(f'files found:\t{len(data_file_list)}')
print(f'files: {data_file_list}')
datafile_list = [aopy.data.DataFile(df) for df in data_file_list]


In [None]:
# Window lengths (seconds): source window, target window, and step between
# windows. The sample-plotting cell places the target window at src_t s on a
# seconds axis. step_t is presumably the stride between consecutive windows —
# TODO confirm against aopy.data.DatafileDataset.
src_t, trg_t, step_t = 1.0, 0.5, 0.5
scale_factor = 0.25
transform = partial(aopy.data.data_transform_normalize, scale_factor=scale_factor)
# One windowed dataset per recording file, then one concatenated view over all.
datafile_dataset_list = [
    aopy.data.DatafileDataset(df, src_t, trg_t, step_t, transform=transform)
    for df in datafile_list
]
datafile_concat_dataset = aopy.data.DatafileConcatDataset(datafile_dataset_list)
srate = datafile_concat_dataset.srate

In [None]:
# Plot one normalized sample for a single channel: the source window followed
# by the target window, on a time axis in seconds.
src, trg = datafile_concat_dataset[1000]
src_time = np.arange(src.shape[-1]) / srate
trg_time = src_t + np.arange(trg.shape[-1]) / srate
plot_ch_idx = 11
f, ax = plt.subplots(1, 1, dpi=100, figsize=(10, 4))
ax.plot(src_time, src[plot_ch_idx, :], label='src')
ax.plot(trg_time, trg[plot_ch_idx, :], label='trg')
ax.set(xlabel='time (s)', ylabel='(a.u.)', title='Normalized Data Sample')
ax.legend(loc=0)

## Linear Methods - baselines for comparison
So: we have a data sampling interface that gives us access to the entire first day's data at once. The samples are normalized to lie roughly in the range $[-1, 1]$.

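Now that I have that, I can easily (!) test some linear prediction models to get baseline performance measures. The first and most basic (really basic) of these is a per-sample least-squares (MSE) estimate of the one-step signal dynamics, $\hat{A} = \arg\min_A \sum_t \lVert x_{t+1} - A x_t \rVert^2$. It is fit on each source window and then rolled forward from the last source sample to predict the target window. Here's an implementation: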

In [None]:
import tqdm

def mse_est(datafile_concat_dataset, progress=True):
    """Fit and evaluate a per-sample one-step linear dynamics model.

    For every sample ``(src, trg)`` in the dataset, fit the matrix ``A_hat``
    that minimizes ``||src[:, 1:] - A_hat @ src[:, :-1]||^2`` on the source
    window. Then roll it forward from the last source sample to predict the
    whole target window, and score the prediction per channel.

    Parameters
    ----------
    datafile_concat_dataset : dataset
        Object supporting ``len()``, an ``n_ch`` attribute and indexing that
        returns ``(src, trg)`` arrays of shape ``(n_ch, n_time)``.
    progress : bool, optional
        Wrap the sample loop in a ``tqdm`` progress bar (default ``True``).

    Returns
    -------
    fve : ndarray, shape (n_ch, n_sample)
        Fraction of target variance explained,
        ``1 - mean((trg - out)**2) / var(trg)``, per channel and sample.
        Can be negative. A channel whose target is constant yields inf/nan.
    A_hat_all : ndarray, shape (n_ch, n_ch, n_sample)
        The fitted dynamics matrix for each sample.
    """
    n_sample = len(datafile_concat_dataset)
    n_ch = datafile_concat_dataset.n_ch
    fve = np.zeros((n_ch, n_sample))
    A_hat_all = np.zeros((n_ch, n_ch, n_sample))
    sample_iter = range(n_sample)
    if progress:
        sample_iter = tqdm.tqdm(sample_iter)
    for sample_idx in sample_iter:
        src, trg = datafile_concat_dataset[sample_idx]
        # Least-squares fit of Y = A_hat @ X. lstsq solves X.T @ A_hat.T = Y.T.
        # This equals (Y X^T)(X X^T)^-1 when X X^T is invertible, but avoids an
        # explicit inverse: it is better conditioned and returns the min-norm
        # solution instead of raising or blowing up when X X^T is (near-)singular.
        X = src[:, :-1]
        Y = src[:, 1:]
        A_hat = np.linalg.lstsq(X.T, Y.T, rcond=None)[0].T
        A_hat_all[:, :, sample_idx] = A_hat
        # Roll the fitted dynamics forward from the last source sample.
        out = np.zeros(trg.shape)
        out[:, 0] = A_hat @ src[:, -1]
        for est_idx in range(1, trg.shape[-1]):
            out[:, est_idx] = A_hat @ out[:, est_idx - 1]
        # Residual power must be the mean squared error. np.var(trg - out)
        # would subtract the residual's mean and so ignore any constant offset
        # between prediction and target, inflating the FVE.
        ss_err = np.mean((trg - out) ** 2, axis=-1)
        ss_trg = np.var(trg, axis=-1)
        fve[:, sample_idx] = 1 - ss_err / ss_trg
    return fve, A_hat_all

# fve_mse: per-channel, per-sample FVE; A_hat: stack of per-sample dynamics matrices
fve_mse, A_hat = mse_est(datafile_concat_dataset=datafile_concat_dataset)

In [None]:
# Histogram of single-trial FVE for the plotted channel. Samples with
# FVE <= -1 (and NaNs, which fail the comparison) are left out.
fve_plot_ch = fve_mse[plot_ch_idx, :]
plt.hist(fve_plot_ch[fve_plot_ch > -1.0], 100, label='FVE')
plt.axvline(1.0, color='r', label='max')
plt.legend(loc=0)
plt.xlabel('FVE')
plt.title('MSE FVE, single-trial estimate')

...not great! Let's take a look at the best case:

In [None]:
# NOTE: this cell is disabled (every line is commented out). It re-fits and
# plots the sample with the highest FVE in fve_mse.
# NOTE(review): the sample is fetched at best_sample_idx+2, but the title
# reports fve_mse[best_ch_idx,best_sample_idx]. The plotted trial is therefore
# not the one whose FVE is shown. Confirm whether the +2 offset is intentional
# before re-enabling.
# NOTE(review): re-enabling it would also overwrite A_hat (the per-sample stack
# returned by mse_est) with a single 2-D matrix, breaking the A_hat plots below.
# _best_sample_idx = np.nanargmax(fve_mse)
# best_ch_idx = _best_sample_idx // len(datafile_concat_dataset)
# best_sample_idx = _best_sample_idx % len(datafile_concat_dataset)
# # get sample
# src, trg = datafile_concat_dataset.__getitem__(best_sample_idx+2)
# # estimate dynamics (MSE)
# X = src[:,:-1]
# Y = src[:,1:]
# A_hat = (Y @ X.T) @ np.linalg.inv(X @ X.T)
# # predict target activity
# out = np.zeros(trg.shape)
# out[:,0] = A_hat @ src[:,-1]
# for est_idx in range(1,trg.shape[-1]):
#     out[:,est_idx] = A_hat @ out[:,est_idx-1]
# # measure error
# ss_err = np.var(trg-out, axis=-1)
# ss_trg = np.var(-trg.mean(axis=-1)[:,None] + trg, axis=-1)
# # fve[:,sample_idx] = 1 - ss_err/ss_trg
# # plot prediction
# plt.plot(src_time,src[best_ch_idx,:],label='src')
# plt.plot(trg_time,trg[best_ch_idx,:],label='trg')
# plt.plot(trg_time,out[best_ch_idx,:],label='out')
# plt.ylim([-1,1])
# plt.legend(loc=0)
# plt.xlabel('time (s)')
# plt.ylabel('(a.u.)')
# plt.title(f'Best MSE Prediction (fve = {fve_mse[best_ch_idx,best_sample_idx]:0.2f})')
# # f = plt.gcf()

In [None]:
# Show the fitted dynamics matrix of the sample 5000 from the end of the
# dataset, on a fixed symmetric colour scale.
f, ax = plt.subplots(1, 1, dpi=80)
cmap = plt.cm.coolwarm
im = ax.imshow(A_hat[:, :, -5000], cmap=cmap, vmin=-1.1, vmax=1.1)
plt.colorbar(im)

In [None]:
# Distribution of log10 Frobenius norms of the per-sample dynamics matrices.
A_hat_norms = np.linalg.norm(A_hat, axis=(0, 1))
plt.hist(np.log10(A_hat_norms), 500);

## Subspace ID

In [None]:
from synthetic_data import rand_lds_and_data
from estimation import estimate_parameters_4sid, estimate_parameters_moments
from util import plot_eigvals, normalize, plot_singularvals

# Synthetic linear dynamical system — x: LDS state, y: measurement.
n = 16     # x dimensions
p = 8      # y dimensions
T = 30000  # time points

# generate a system and simulate from it
(A, B, C, D), (x, y) = rand_lds_and_data(T, n, p, eig_min=0.5, eig_max=1.0)

In [None]:
# Use the moment-based estimator to fit a latent linear system to one source
# window, then forecast the target window from it.
trial_idx = 1003
src, trg = datafile_concat_dataset[trial_idx]
lags = 80
latent_dims = 10
# NOTE(review): this rebinds A_hat (previously the per-sample stack from
# mse_est), so re-running the A_hat plotting cells afterwards will fail/mislead.
A_hat, C_hat = estimate_parameters_moments(src.T, lags, latent_dims)
# Initial latent state for the target window: project the LAST SOURCE sample
# into the latent space and step it forward once. This mirrors mse_est
# (out[:,0] = A_hat @ src[:,-1]). The previous pinv(C_hat) @ trg[:,0] leaked
# the first target sample into the "prediction".
x0 = A_hat @ (np.linalg.pinv(C_hat) @ src[:, -1])
x_out = np.zeros((latent_dims, trg.shape[-1]))
x_out[:, 0] = x0
for out_idx in range(1, trg.shape[-1]):
    x_out[:, out_idx] = A_hat @ x_out[:, out_idx - 1]
# Map every latent state to channel space at once (same as column-by-column).
y_out = C_hat @ x_out

In [None]:
# plot target and SSID prediction for one channel
plt_ch_idx = 0
# Time axis (seconds) derived from the target's actual length and the sampling
# rate, offset by the source window as in the sample-plotting cell. The old
# hard-coded np.arange(500) only matched trg when it had exactly 500 samples,
# and it was labeled ms.
trg_time = np.arange(trg.shape[-1]) / srate + src_t
plt.plot(trg_time, trg[plt_ch_idx, :], label='trg')
plt.plot(trg_time, y_out[plt_ch_idx, :], label='out')
plt.legend(loc=0)
plt.xlabel('time (s)')
plt.ylabel(f'ch {plt_ch_idx} (a.u.)')