In [None]:
import os.path as path
from os import makedirs, chmod
import glob
import functools

import time
import datetime
import tqdm

import aopy
import ecog_is2s

import numpy as np
import scipy as sp
import pandas as pd
import torch
import torch.optim as optim
from torchvision.transforms import Compose

import matplotlib.pyplot as plt

import sys

In [None]:
# Locate the ECoG recordings on disk and load a single one of them.
# Day folders matching '18032*' are searched below the recording root.
data_path_root = 'C:\\Users\\mickey\\aoLab\\Data\\WirelessData\\Goose_Multiscale_M1'
data_path_day = path.join(data_path_root,'18032*')
# alternative file pattern kept for reference: '0[0-9]*\\*ECOG*clfp_ds250.dat'
ecog_file_pattern = path.join(data_path_day,'0[0-9]*\\*ECOG*clfp_ds250_fl0u10.dat')
data_file_list = glob.glob(ecog_file_pattern)
# Everything runs on the CPU; automatic GPU selection is disabled on purpose:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
print('mounting to device: {}'.format(device))
print(f'files found:\t{len(data_file_list)}')
print(f'files: {data_file_list}')
# Wrap every matched path in an aopy DataFile, then pick the recording to analyse.
datafile_list = list(map(aopy.data.DataFile, data_file_list))
file_idx = 6
selected_datafile = datafile_list[file_idx]
# Transposed after reading; later cells treat axis 0 as samples and axis 1 as channels.
data = selected_datafile.read().T
mask = selected_datafile.data_mask
srate = selected_datafile.srate

In [None]:
# Dilate the sample mask: a sample is flagged when any masked sample falls inside the
# 60*srate-sample window centred on it (flagged samples are dropped in the split cell).
# The convolution is done on floats and thresholded, instead of on booleans, because
# casting an FFT-based convolution result back to bool turns round-off noise into True.
# assumes mask is a 1-D boolean array with one entry per sample -- TODO confirm
fill_kernel = np.ones(int(60*srate))  # int(): array lengths must be integers even if srate is a float
fill_mask = sp.signal.convolve(np.asarray(mask,dtype=float),fill_kernel,mode='same') > 0.5

In [None]:
# Show the path of the recording that was actually loaded (index file_idx),
# so the displayed path cannot drift from the data being analysed.
datafile_list[file_idx].data_file_path

In [None]:
# get test/train splits
# Drop every sample flagged by the fill mask.
# NOTE: this reassigns `data`, so the cell is not idempotent -- re-run the loading cell first.
data = data[~fill_mask,:]
# Channel selection: keep channels whose variance lies strictly within
# 1.5 standard deviations of the across-channel mean variance.
ch_var = data.var(axis=0)
ch_var_mean = ch_var.mean()
ch_var_std = ch_var.std()
ch_var_lo = ch_var_mean-1.5*ch_var_std
ch_var_hi = ch_var_mean+1.5*ch_var_std
ch_idx = (ch_var > ch_var_lo) & (ch_var < ch_var_hi)
n_samples, _ = data.shape
n_ch = sum(ch_idx)
# Contiguous split along the sample axis: train first, then validation, then test.
train_test_valid_frac = (.8, 0.1, 0.1)
n_train_samples, n_valid_samples, n_test_samples = (round(n_samples*frac) for frac in train_test_valid_frac)
train_stop = n_train_samples
valid_stop = n_train_samples+n_valid_samples
train_idx = np.arange(0,train_stop)
valid_idx = np.arange(train_stop,valid_stop)
# the test split takes whatever remains, so rounding never loses samples
test_idx = np.arange(valid_stop,n_samples)
# Each split is restricted to the selected channels and then shifted by its own
# grand mean (one scalar over all samples and channels of that split).
train_data = data[train_idx,:][:,ch_idx]
train_data = train_data - train_data.mean()
valid_data = data[valid_idx,:][:,ch_idx]
valid_data = valid_data - valid_data.mean()
test_data = data[test_idx,:][:,ch_idx]
test_data = test_data - test_data.mean()

In [None]:
# Sanity check: shapes of the cleaned recording and of the train / test splits.
for split_arr in (data, train_data, test_data):
    print(split_arr.shape)

In [None]:
# AR model
from statsmodels.tsa.api import VAR
# Vector autoregression fitted on the training split with maxlags=10.
# Alternatives that were tried and left disabled:
# model.select_order(10)
# model = VARMAX(data, order=(10, 10))
# model_fit = model.fit(disp=False)
model = VAR(train_data)
model_fit = model.fit(maxlags=10)

In [None]:
# Lag order reported by the fitted model; later cells use it as the number of
# seed samples handed to model_fit.forecast.
fit_order = model_fit.k_ar
# print(fit_order)

In [None]:
# Single example forecast: seed the model with fit_order test samples starting at
# start_idx and predict the following 250 steps.
start_idx = 20000
forecast_seed = test_data[start_idx:start_idx+fit_order,:]
pred_data = model_fit.forecast(forecast_seed,steps=250)

In [None]:
# Overlay the example forecast and the held-out target for channel 0.
# The window length is taken from pred_data so the time axis, the target slice and the
# prediction always have the same number of samples, whatever the sampling rate is.
n_pred = pred_data.shape[0]
t_pred = np.arange(n_pred)/srate
trg_trace = test_data[start_idx+fit_order:start_idx+fit_order+n_pred,0]
pred_trace = pred_data[:,0]
# Pearson correlation between target and forecast, shown in the title
r_val = np.corrcoef(trg_trace,pred_trace)[0,1]
plt.plot(t_pred,trg_trace,label='trg')
plt.plot(t_pred,pred_trace,label='pred');
plt.xlabel('time (s)')
# raw string: '\m' is not a valid escape sequence in a normal string literal
plt.ylabel(r'amp ($\mu$V)')
plt.title(f'p={fit_order} (r = {r_val:0.3f})')
plt.legend(loc=0)

In [None]:
# Score forecasts over consecutive, non-overlapping windows of the test split.
# Each window holds fit_order seed samples followed by pred_window_T seconds of targets.
pred_window_T = 1
n_pred_step = int(pred_window_T*srate)  # int(): used as an array length / step count
pred_window_n = n_pred_step + fit_order
n_pred_window = test_data.shape[0]//pred_window_n
mse = np.empty((n_pred_window,n_ch))
rpe = np.empty((n_pred_window,n_ch))
corr = np.empty((n_pred_window,n_ch))
p_lim = [2.5,97.5]
for pred_win_idx in tqdm.tqdm(range(n_pred_window)):
    window_idx = pred_win_idx*pred_window_n + np.arange(pred_window_n)
    # Windows are taken from test_data: it holds the n_ch channels the model was fitted on
    # and is the held-out split the window count above was derived from.
    data_window = test_data[window_idx,:]
    trg = data_window[fit_order:,:]
    pred = model_fit.forecast(data_window[:fit_order,:],steps=n_pred_step)
    # "mse" holds the root-mean-square error per channel (note the square root)
    mse[pred_win_idx,:] = np.sqrt(np.mean((trg - pred)**2, axis=0))
    # relative prediction error: RMSE scaled by the target's standard deviation
    data_std = trg.std(axis=0)
    rpe[pred_win_idx,:] = mse[pred_win_idx,:]/data_std
    # take the full corrcoef matrix, cut to the pred-vs-target cross terms, keep the diagonal
    corr[pred_win_idx,:] = np.diag(np.corrcoef(pred,trg,rowvar=False)[:n_ch,n_ch:])
# Fisher z-transform of the correlations
ft_corr = np.arctanh(corr)
# per-channel mean and p_lim percentiles across windows
mse_mean = mse.mean(axis=0)
mse_95ci = np.percentile(mse,p_lim,axis=0)
rpe_mean = rpe.mean(axis=0)
rpe_95ci = np.percentile(rpe,p_lim,axis=0)
ft_corr_mean = ft_corr.mean(axis=0)
ft_corr_95ci = np.percentile(ft_corr,p_lim,axis=0)

In [None]:
# pred - real correlation, fisher transform
def metric_channel_plot(metric,p_val,title,ylabel):
    """Violin plot of a metric per channel, with its mean and percentile markers.

    Parameters
    ----------
    metric : 2-D array, one row per observation and one column per channel.
    p_val : percentile or sequence of percentiles forwarded to ``np.percentile``.
    title : str, figure title.
    ylabel : str, y-axis label.

    Returns
    -------
    The matplotlib figure that was created (20x2 inches, 100 dpi).
    """
    # Channel count comes from the array being plotted rather than a notebook global,
    # so the function also works for metrics with a different number of columns.
    n_plot_ch = metric.shape[1]
    # marker x-positions 1..n, matching the original np.arange(1, n_ch+1) layout
    ch_pos = np.arange(1,n_plot_ch+1)
    metric_mean = metric.mean(axis=0)
    metric_95ci = np.percentile(metric,p_val,axis=0)
    f = plt.figure(figsize=(20,2),dpi=100)
    plt.violinplot(metric);
    plt.plot(ch_pos,metric_mean,'.',label='mean')
    # transposed so each percentile becomes one line across channels
    plt.plot(ch_pos,metric_95ci.T,'.',label='95% CI')
    plt.xlabel('ch.')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend(loc=0)
    return f
# Per-channel distribution of the Fisher-transformed correlation, saved next to the notebook.
ft_corr_title = f'Fisher Transform Correlation (p = {fit_order})'
f = metric_channel_plot(ft_corr,p_lim,ft_corr_title,'z')
f.savefig(f'td_corr_p{fit_order}_10Hz.png')

In [None]:
# Per-channel relative prediction error; the y-range is limited to [0, 5] and a
# horizontal reference line is drawn at rpe = 1 before the figure is saved.
rpe_title = f'Relative Prediction Error (p = {fit_order})'
f = metric_channel_plot(rpe,p_lim,rpe_title,'rpe')
plt.ylim(0,5)
plt.axhline(1,color='k',alpha=0.5)
f.savefig(f'rpe_p{fit_order}_10Hz.png')

In [None]:
def compute_err_metrics(pred,trg,axis=0):
    """Compare a prediction with its target.

    Parameters
    ----------
    pred, trg : 2-D arrays of identical shape (rows x columns).
    axis : axis or tuple of axes over which the error statistics are reduced.

    Returns
    -------
    mse : root of the mean squared error over ``axis`` (i.e. an RMSE, despite the name).
    rpe : ``mse`` divided by the standard deviation of ``trg`` over ``axis``.
    corr : correlation coefficient between matching columns of ``pred`` and ``trg``;
        always one value per column, independent of ``axis``.
    """
    assert pred.shape == trg.shape, "pred and trg arrays must have equal shape"
    _, n_col = pred.shape
    residual = pred - trg
    rms_err = np.sqrt(np.mean(residual**2,axis=axis))
    rel_err = rms_err / trg.std(axis=axis)
    # corrcoef on the column-stacked inputs gives a (2*n_col, 2*n_col) matrix; the
    # upper-right block holds pred-vs-trg terms and its diagonal pairs matching columns.
    full_corr = np.corrcoef(pred,trg,rowvar=False)
    cross_corr = full_corr[:n_col,n_col:]
    return rms_err, rel_err, np.diag(cross_corr)

def compute_prediction_metrics(test_data,model_fit,pred_window_T,bin_T,p_lim=[2.5, 97.5],srate=250):
    """Score windowed forecasts of ``model_fit`` against held-out data.

    ``test_data`` is cut into consecutive, non-overlapping windows. Each window supplies
    ``model_fit.k_ar`` seed samples followed by ``pred_window_T`` seconds of targets. The
    model forecasts the target span from the seed, and the forecast is scored with
    ``compute_err_metrics`` over the whole span and inside ``bin_T``-second time bins.

    Parameters
    ----------
    test_data : 2-D array (samples x channels) of held-out data.
    model_fit : fitted model exposing ``k_ar`` and ``forecast(y, steps=...)``.
    pred_window_T : forecast horizon in seconds.
    bin_T : width in seconds of the time bins within the forecast horizon.
    p_lim : percentiles reported under the ``*_95ci`` keys (never mutated here).
    srate : sampling rate in samples per second.

    Returns
    -------
    stat_dict : dict of per-channel summaries across windows for 'mse', 'rpe' and 'corr':
        ``<m>_mean``, ``<m>_95ci``, ``<m>_bin_mean``, ``<m>_bin_95ci``. 'mse' values are
        root-mean-square errors; correlation means are averaged after a Fisher
        z-transform (arctanh) and mapped back with tanh.
    stat_dict_all : dict with the same keys for 'mse' and 'rpe', computed with the error
        pooled over channels.
    bin_T_left_edge : array with the left edge (seconds) of every time bin.

    Raises
    ------
    ValueError
        If ``test_data`` is too short to hold a single prediction window.
    """
    n_sample, n_ch = test_data.shape
    fit_order = model_fit.k_ar
    n_pred_step = int(pred_window_T*srate)
    pred_window_n = n_pred_step + fit_order
    # The window count is derived from the array that was passed in, not from a notebook
    # global, so the function is correct for any test_data.
    n_pred_window = n_sample//pred_window_n
    if n_pred_window < 1:
        raise ValueError(
            f"test_data has {n_sample} samples but one prediction window needs {pred_window_n}"
        )
    # Time axis of the forecast span (named t_pred so the imported time module is not shadowed).
    t_pred = np.arange(n_pred_step)/srate
    bin_T_left_edge = np.arange(pred_window_T,step=bin_T)
    bin_T_right_edge = bin_T_left_edge + bin_T
    # The sample mask of each time bin does not depend on the window: build it once.
    bin_masks = [
        np.logical_and(t_pred >= t_lo, t_pred < t_hi)
        for t_lo, t_hi in zip(bin_T_left_edge,bin_T_right_edge)
    ]
    n_time_bin = len(bin_masks)
    mse = np.empty((n_pred_window,n_ch))
    mse_all = np.empty((n_pred_window))
    mse_bin = np.empty((n_pred_window,n_ch,n_time_bin))
    mse_bin_all = np.empty((n_pred_window,n_time_bin))
    rpe = np.empty((n_pred_window,n_ch))
    rpe_all = np.empty((n_pred_window))
    rpe_bin = np.empty((n_pred_window,n_ch,n_time_bin))
    rpe_bin_all = np.empty((n_pred_window,n_time_bin))
    corr = np.empty((n_pred_window,n_ch))
    corr_bin = np.empty((n_pred_window,n_ch,n_time_bin))
    for pred_win_idx in tqdm.tqdm(range(n_pred_window)):
        win_start = pred_win_idx*pred_window_n
        data_window = test_data[win_start:win_start+pred_window_n,:]
        # the first fit_order samples seed the forecast, the remainder is the target
        trg = data_window[fit_order:,:]
        pred = model_fit.forecast(data_window[:fit_order,:],steps=n_pred_step)
        ## time bins
        for tb_idx, bin_mask in enumerate(bin_masks):
            pred_bin = pred[bin_mask,:]
            trg_bin = trg[bin_mask,:]
            mse_bin[pred_win_idx,:,tb_idx], rpe_bin[pred_win_idx,:,tb_idx], corr_bin[pred_win_idx,:,tb_idx] = compute_err_metrics(pred_bin,trg_bin)
            # axis=(0,1) pools the error over samples and channels
            mse_bin_all[pred_win_idx,tb_idx], rpe_bin_all[pred_win_idx,tb_idx], _ = compute_err_metrics(pred_bin,trg_bin,axis=(0,1))
        ## whole forecast span
        mse[pred_win_idx,:], rpe[pred_win_idx,:], corr[pred_win_idx,:] = compute_err_metrics(pred,trg)
        mse_all[pred_win_idx], rpe_all[pred_win_idx], _ = compute_err_metrics(pred,trg,axis=(0,1))
        # TODO: add coherence stats here (the unused per-window FFTs were removed until then)
    # get stats from sample distributions
    stat_dict = {
        'mse_mean': mse.mean(axis=0),
        'mse_95ci': np.percentile(mse,p_lim,axis=0),
        'mse_bin_mean': mse_bin.mean(axis=0),
        'mse_bin_95ci': np.percentile(mse_bin,p_lim,axis=0),
        'rpe_mean': rpe.mean(axis=0),
        'rpe_95ci': np.percentile(rpe,p_lim,axis=0),
        'rpe_bin_mean': rpe_bin.mean(axis=0),
        'rpe_bin_95ci': np.percentile(rpe_bin,p_lim,axis=0),
        'corr_mean': np.tanh(np.arctanh(corr).mean(axis=0)),
        'corr_95ci': np.percentile(corr,p_lim,axis=0),
        'corr_bin_mean': np.tanh(np.arctanh(corr_bin).mean(axis=0)),
        'corr_bin_95ci': np.percentile(corr_bin,p_lim,axis=0)
    }
    stat_dict_all = {
        'mse_mean': mse_all.mean(axis=0),
        'mse_95ci': np.percentile(mse_all,p_lim,axis=0),
        'mse_bin_mean': mse_bin_all.mean(axis=0),
        'mse_bin_95ci': np.percentile(mse_bin_all,p_lim,axis=0),
        'rpe_mean': rpe_all.mean(axis=0),
        'rpe_95ci': np.percentile(rpe_all,p_lim,axis=0),
        'rpe_bin_mean': rpe_bin_all.mean(axis=0),
        'rpe_bin_95ci': np.percentile(rpe_bin_all,p_lim,axis=0),
    }
    return stat_dict, stat_dict_all, bin_T_left_edge

# def compute_agg_stat_dataframe(stat_dict,)

# Score 1 s forecasts of the fitted model on the test split, resolved in 0.1 s time bins.
pred_window_T, bin_T = 1.0, 0.1
metric_stat_dict, metric_stat_dict_all, bin_time = compute_prediction_metrics(
    test_data,
    model_fit,
    pred_window_T=pred_window_T,
    bin_T=bin_T,
)

In [None]:
# Relative prediction error over forecast time: mean line with a percentile band.
# The channel-pooled dict is used because its binned entries hold one value per time bin,
# which is what the 1-D fill_between/plot calls need; the per-channel dict carries an
# extra channel axis and cannot be drawn this way.
rpe_bin_ci = metric_stat_dict_all['rpe_bin_95ci']
plt.fill_between(bin_time,rpe_bin_ci[0,:],rpe_bin_ci[1,:],alpha=0.3)
plt.plot(bin_time,metric_stat_dict_all['rpe_bin_mean']);
plt.xlabel('time (s)')
plt.ylabel('RPE (a.u.)')

In [None]:
# Per-channel prediction/target correlation over forecast time (mean with percentile band).
corr_bin_mean = metric_stat_dict['corr_bin_mean']
corr_bin_ci = metric_stat_dict['corr_bin_95ci']
# The loop variable is deliberately NOT called ch_idx: that name holds the boolean
# channel-selection mask and must not be overwritten by an integer.
for plot_ch in range(corr_bin_mean.shape[0]):
    plt.fill_between(bin_time,corr_bin_ci[0,plot_ch,:],corr_bin_ci[1,plot_ch,:],alpha=0.3)
    plt.plot(bin_time,corr_bin_mean[plot_ch,:]);
# reference line at perfect correlation
plt.axhline(1,color='k',linestyle=':')
plt.ylim(0,1.1)
plt.xlabel('time (s)')
# the plotted quantity is the correlation coefficient, not the RPE
plt.ylabel('corr. coef. (r)')
plt.title(f'Relative Prediction Accuracy, p = {fit_order}')

In [None]:
def create_metric_stat_table(stat_dict):
    """Collect the mse / rpe summary statistics of ``stat_dict`` in a one-row DataFrame.

    Parameters
    ----------
    stat_dict : mapping with the keys 'mse_mean', 'mse_95ci', 'rpe_mean' and 'rpe_95ci'.
        The ``*_95ci`` entries are indexed along their first axis: element 0 is stored
        as the lower bound and element 1 as the upper bound.

    Returns
    -------
    pandas.DataFrame with a single row and the columns ``<m>_mean``, ``<m>_ci_2.5`` and
    ``<m>_ci_97.5`` for m in ('mse', 'rpe'). Each value is wrapped in a one-element
    list, so array-valued statistics end up as a single cell.
    """
    columns = {}
    for metric in ('mse', 'rpe'):
        ci = stat_dict[f'{metric}_95ci']
        columns[f'{metric}_mean'] = [stat_dict[f'{metric}_mean']]
        columns[f'{metric}_ci_2.5'] = [ci[0,]]
        columns[f'{metric}_ci_97.5'] = [ci[1,]]
    return pd.DataFrame(data=columns)

# One summary table from the per-channel statistics and one from the
# channel-pooled statistics returned by compute_prediction_metrics.
stat_df = create_metric_stat_table(metric_stat_dict)
stat_all_df = create_metric_stat_table(metric_stat_dict_all)

In [None]:
# Display the per-channel summary table (rebuilt here; same inputs as stat_df).
create_metric_stat_table(metric_stat_dict)

In [None]:
# Display the channel-pooled 'mse' (an RMSE, see compute_err_metrics) averaged over
# prediction windows, one value per time bin.
metric_stat_dict_all['mse_bin_mean']