In [None]:
import os.path as path
from os import makedirs, chmod
import glob
import functools

import time
import datetime
import tqdm

import aopy
import ecog_is2s

import numpy as np
from numpy.matlib import repmat
import scipy as sp
import pandas as pd
import torch
import torch.optim as optim
from torchvision.transforms import Compose

from statsmodels.tsa.api import VAR

import matplotlib.pyplot as plt

import sys

In [None]:
def partition_data(data,train_valid_test_frac=(0.8,0.2,0.0),ch_idx=None):
    """Select channels and split a recording into contiguous train/valid/test segments.

    Parameters
    ----------
    data : np.ndarray, shape (n_samples, n_channels)
        Recording to split along its first axis.
    train_valid_test_frac : sequence of 3 floats
        Fractions of the samples assigned to the train and validation segments
        (entries 0 and 1). The test segment always receives every sample that is
        left after those two, so entry 2 does not move any boundary.
    ch_idx : sequence of int, optional
        Column indices to keep. Defaults to the fixed list of 42 channels that
        this notebook has always used.

    Returns
    -------
    train_data, valid_data, test_data : np.ndarray
        The three segments restricted to ``ch_idx``; each has its own scalar mean
        (taken over all of its samples and channels) subtracted.
    ch_idx : sequence of int
        The channel indices that were applied.
    """
    if ch_idx is None:
        ch_idx = [28,27,5,24,57,43,8,31,41,37,44,47,59,9,49,20,53,45,25,14,13,40,21,35,39,48,58,52,33,46,56,17,60,30,23,61,15,34,54,51,10,42]
    # partition boundaries along the sample axis
    n_samples, _ = data.shape
    n_train_samples = round(n_samples*train_valid_test_frac[0])
    n_valid_samples = round(n_samples*train_valid_test_frac[1])
    train_idx = np.arange(0,n_train_samples)
    valid_idx = np.arange(n_train_samples,n_train_samples+n_valid_samples)
    test_idx = np.arange(n_train_samples+n_valid_samples,n_samples)

    def _select_and_center(sample_idx):
        # keep the requested samples/channels and remove the segment's scalar mean
        segment = data[sample_idx,:][:,ch_idx]
        # an empty segment has no mean; subtracting 0.0 avoids numpy's
        # "Mean of empty slice" warning while still returning a float array
        offset = segment.mean() if segment.size else 0.0
        return segment - offset

    train_data = _select_and_center(train_idx)
    valid_data = _select_and_center(valid_idx)
    test_data = _select_and_center(test_idx)
    return train_data, valid_data, test_data, ch_idx

def train_AR_model(train_data,model_order):
    """Fit a vector autoregressive (VAR) model to the training data.

    ``train_data`` is handed directly to ``VAR`` and ``model_order`` is passed as
    the first positional argument of its ``fit`` method. The fitted results
    object is returned.
    """
    return VAR(train_data).fit(model_order)

def compute_err_metrics(pred,trg,axis=0):
    """Score a prediction against its target with several error measures.

    Parameters
    ----------
    pred, trg : np.ndarray, shape (n_sample, n_ch)
        Prediction and target; both must be 2-D and of identical shape.
    axis : int or tuple of int
        Axis (or axes) the errors are averaged over.

    Returns
    -------
    mse, rmse, rpe, corr, mae, nmae
        Mean squared error, its square root, RMSE divided by the target standard
        deviation, per-channel correlation between prediction and target, mean
        absolute error, and MAE divided by the mean absolute target value.
        ``corr`` is always computed column by column, whatever ``axis`` is.
    """
    assert pred.shape == trg.shape, "pred and trg arrays must have equal shape"
    _, n_ch = pred.shape
    residual = pred-trg
    # squared-error family
    mse = np.square(residual).mean(axis=axis)
    rmse = np.sqrt(mse)
    rpe = rmse/trg.std(axis=axis)
    # absolute-error family
    mae = np.abs(residual).mean(axis=axis)
    nmae = mae/np.abs(trg).mean(axis=axis)
    # corrcoef stacks pred columns then trg columns; the n_ch-th off-diagonal
    # pairs each prediction channel with the matching target channel
    joint_corr = np.corrcoef(pred,trg,rowvar=False)
    corr = np.diagonal(joint_corr,offset=n_ch)
    return mse, rmse, rpe, corr, mae, nmae

def compute_prediction_metrics(test_data,model_fit,pred_window_T,bin_T,p_lim=[2.5, 97.5],srate=250):
    """Score multi-step forecasts of a fitted model over consecutive data windows.

    ``test_data`` is cut into non-overlapping windows of ``model_fit.k_ar`` seed
    samples followed by the samples to be predicted. For each window the model
    forecasts the prediction samples from the seed samples, and the forecast is
    scored with ``compute_err_metrics`` over the whole prediction window and
    over consecutive time bins of width ``bin_T``.

    Parameters
    ----------
    test_data : np.ndarray, shape (n_sample, n_ch)
        Data to evaluate on.
    model_fit : fitted model
        Must expose ``k_ar`` and ``forecast(y, steps=...)``.
    pred_window_T : float
        Length of the prediction window, in the same time unit as ``1/srate``.
    bin_T : float
        Width of the time bins the prediction window is divided into.
    p_lim : sequence of 2 floats
        Lower and upper percentiles reported under the ``*_95ci`` keys.
    srate : float
        Sampling rate used to convert window/bin lengths to sample counts.

    Returns
    -------
    stat_dict : dict
        Per-channel mean and percentile limits of every metric, for the whole
        window (``<metric>_mean``/``<metric>_95ci``) and per time bin
        (``<metric>_bin_mean``/``<metric>_bin_95ci``).
    stat_dict_all : dict
        Same layout for metrics pooled over channels (no correlation entries).
    bin_T_left_edge : np.ndarray
        Left edge of every time bin.
    """
    # order in which compute_err_metrics returns its values
    metric_order = ('mse','rmse','rpe','corr','mae','nmae')
    # correlation is only tracked per channel, never pooled
    pooled_metrics = ('mse','rmse','rpe','mae','nmae')

    n_sample, n_ch = test_data.shape
    fit_order = model_fit.k_ar
    pred_window_n = int(pred_window_T*srate + fit_order)
    n_pred_step = pred_window_n - fit_order
    n_pred_window = n_sample//pred_window_n

    # time axis of the forecast; built from the forecast length so bin masks
    # always match the number of predicted samples
    pred_time = np.arange(n_pred_step)/srate
    bin_T_left_edge = np.arange(pred_window_T,step=bin_T)
    bin_T_right_edge = bin_T_left_edge + bin_T
    n_time_bin = len(bin_T_left_edge)
    # bin membership does not depend on the window, so compute it once
    bin_masks = [
        np.logical_and(pred_time >= left_edge, pred_time < right_edge)
        for left_edge, right_edge in zip(bin_T_left_edge,bin_T_right_edge)
    ]

    # one sample of every metric per prediction window
    per_ch = {name: np.empty((n_pred_window,n_ch)) for name in metric_order}
    per_ch_bin = {name: np.empty((n_pred_window,n_ch,n_time_bin)) for name in metric_order}
    pooled = {name: np.empty((n_pred_window)) for name in pooled_metrics}
    pooled_bin = {name: np.empty((n_pred_window,n_time_bin)) for name in pooled_metrics}

    for pred_win_idx in tqdm.tqdm(range(n_pred_window)):
        window_start = pred_win_idx*pred_window_n
        data_window = test_data[window_start:window_start+pred_window_n,:]
        seed = data_window[:fit_order,:]
        trg = data_window[fit_order:,:]
        pred = model_fit.forecast(seed,steps=n_pred_step)
        ## time bins
        for tb_idx, bin_idx in enumerate(bin_masks):
            pred_bin = pred[bin_idx,:]
            trg_bin = trg[bin_idx,:]
            for name, value in zip(metric_order,compute_err_metrics(pred_bin,trg_bin)):
                per_ch_bin[name][pred_win_idx,:,tb_idx] = value
            for name, value in zip(metric_order,compute_err_metrics(pred_bin,trg_bin,axis=(0,1))):
                if name in pooled_bin:
                    pooled_bin[name][pred_win_idx,tb_idx] = value
        ## whole prediction window
        for name, value in zip(metric_order,compute_err_metrics(pred,trg)):
            per_ch[name][pred_win_idx,:] = value
        # every pooled metric goes to its own pooled array, so the pooled RMSE
        # can no longer overwrite the per-channel RMSE row
        for name, value in zip(metric_order,compute_err_metrics(pred,trg,axis=(0,1))):
            if name in pooled:
                pooled[name][pred_win_idx] = value
        # TODO: add spectral / coherence stats here

    # get stats from sample distributions
    stat_dict = {}
    for name in pooled_metrics:
        stat_dict[f'{name}_mean'] = per_ch[name].mean(axis=0)
        stat_dict[f'{name}_95ci'] = np.percentile(per_ch[name],p_lim,axis=0)
        stat_dict[f'{name}_bin_mean'] = per_ch_bin[name].mean(axis=0)
        stat_dict[f'{name}_bin_95ci'] = np.percentile(per_ch_bin[name],p_lim,axis=0)
    # correlations are averaged after an arctanh transform and mapped back with tanh
    stat_dict['corr_mean'] = np.tanh(np.arctanh(per_ch['corr']).mean(axis=0))
    stat_dict['corr_95ci'] = np.percentile(per_ch['corr'],p_lim,axis=0)
    stat_dict['corr_bin_mean'] = np.tanh(np.arctanh(per_ch_bin['corr']).mean(axis=0))
    stat_dict['corr_bin_95ci'] = np.percentile(per_ch_bin['corr'],p_lim,axis=0)

    stat_dict_all = {}
    for name in pooled_metrics:
        stat_dict_all[f'{name}_mean'] = pooled[name].mean(axis=0)
        stat_dict_all[f'{name}_95ci'] = np.percentile(pooled[name],p_lim,axis=0)
        stat_dict_all[f'{name}_bin_mean'] = pooled_bin[name].mean(axis=0)
        stat_dict_all[f'{name}_bin_95ci'] = np.percentile(pooled_bin[name],p_lim,axis=0)
    return stat_dict, stat_dict_all, bin_T_left_edge

def create_metric_stat_table(stat_dict):
    """Pack the whole-window, per-channel metric statistics into a one-row table.

    For each of mse, rmse, rpe, mae, nmae and corr the row holds the
    ``<metric>_mean`` entry of ``stat_dict`` plus the first and second rows of
    its ``<metric>_95ci`` entry (columns ``<metric>_ci_2.5`` / ``<metric>_ci_97.5``).
    """
    row = {}
    for metric in ('mse','rmse','rpe','mae','nmae','corr'):
        row[f'{metric}_mean'] = [stat_dict[f'{metric}_mean']]
        limits = stat_dict[f'{metric}_95ci']
        row[f'{metric}_ci_2.5'] = [limits[0,]]
        row[f'{metric}_ci_97.5'] = [limits[1,]]
    return pd.DataFrame(data=row)

def create_metric_all_stat_table(stat_dict):
    """Pack the whole-window, channel-pooled metric statistics into a one-row table.

    For each of mse, rmse, rpe, mae and nmae the row holds the ``<metric>_mean``
    entry of ``stat_dict`` plus the first and second elements of its
    ``<metric>_95ci`` entry (columns ``<metric>_ci_2.5`` / ``<metric>_ci_97.5``).
    """
    row = {}
    for metric in ('mse','rmse','rpe','mae','nmae'):
        row[f'{metric}_mean'] = [stat_dict[f'{metric}_mean']]
        limits = stat_dict[f'{metric}_95ci']
        row[f'{metric}_ci_2.5'] = [limits[0,]]
        row[f'{metric}_ci_97.5'] = [limits[1,]]
    return pd.DataFrame(data=row)

def create_metric_bin_stat_table(stat_dict):
    """Pack the time-binned metric statistics into a one-row table.

    For each of mse, rmse, rpe, mae and nmae the row holds the
    ``<metric>_bin_mean`` entry of ``stat_dict`` plus the first and second rows
    of its ``<metric>_bin_95ci`` entry (columns ``<metric>_ci_2.5`` /
    ``<metric>_ci_97.5``).
    """
    row = {}
    for metric in ('mse','rmse','rpe','mae','nmae'):
        row[f'{metric}_bin_mean'] = [stat_dict[f'{metric}_bin_mean']]
        limits = stat_dict[f'{metric}_bin_95ci']
        row[f'{metric}_ci_2.5'] = [limits[0,]]
        row[f'{metric}_ci_97.5'] = [limits[1,]]
    return pd.DataFrame(data=row)

In [None]:
# Find the recordings to analyse and wrap each one in an aopy DataFile.
# Directory layout searched: <root>/<day matching 18032*>/<session matching 0[0-9]*>/<file>
data_path_root = 'C:\\Users\\mickey\\aoLab\\Data\\WirelessData\\Goose_Multiscale_M1'
data_path_day = path.join(data_path_root,'18032*')
data_file_pattern = path.join(data_path_day,'0[0-9]*\\*ECOG*clfp_ds250_fl0u10.dat')
# alternative file pattern: '0[0-9]*\\*ECOG*clfp_ds250.dat'
data_file_list = glob.glob(data_file_pattern)
device = 'cpu'
print('mounting to device: {}'.format(device))
print(f'files found:\t{len(data_file_list)}')
print(f'files: {data_file_list}')
datafile_list = list(map(aopy.data.DataFile,data_file_list))


In [None]:
# loop across all files to:
#   - train AR model
#   - compute error metrics
#   - create error table
# concatenate all error tables
# (the tables are written to disk in a later cell)

mask_buffer_T = 60
model_order = 10
pred_window_T = 1
bin_T = 0.1
stat_df_list = []
stat_all_df_list = []
stat_bin_df_list = []
file_used = np.zeros(len(datafile_list),dtype=bool)
file_path_used = []

date_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y%m%d%H%M%S')
analysis_path = f'D:\\Users\\mickey\\Data\\analysis\\prediction_p{model_order}_{pred_window_T}s_{date_str}'
# `os` itself is never imported in this notebook (only `os.path` as `path` and
# `makedirs`), so call `makedirs` directly. Creating the nested 'models' folder
# also creates `analysis_path`; exist_ok makes the cell safe to re-run.
makedirs(path.join(analysis_path,'models'),exist_ok=True)

for df_idx, df in enumerate(datafile_list):
    print(f'({df_idx+1}/{len(datafile_list)}) file {df.data_file_path}')
    data = df.read().T
    mask = df.data_mask
    srate = df.srate
    # only files whose mask covers less than 80% of the samples are analysed
    # NOTE(review): presumably True in `data_mask` marks samples to exclude - confirm against aopy
    if not mask.mean() < 0.8:
        continue
    file_used[df_idx] = True
    # widen masked regions - no chances
    print('masking data...')
    # np.ones needs an integer length; srate may be stored as a float
    mask_buffer_n = int(round(mask_buffer_T*srate))
    fill_mask = sp.signal.convolve(mask,np.ones(mask_buffer_n,dtype=bool),mode='same')
    data = data[~fill_mask,:]
    train_data, valid_data, test_data, ch_idx = partition_data(data)
    # train the linear model
    print('fitting model...')
    model_fit = train_AR_model(train_data,model_order)
    # compute validation metrics
    print('computing metrics:')
    stat_dict, stat_dict_all, bin_T_left_edge = compute_prediction_metrics(valid_data,model_fit,pred_window_T,bin_T,srate=srate)
    # convert to tables; the binned table is built from the channel-pooled statistics
    _stat_df = create_metric_stat_table(stat_dict)
    _stat_all_df = create_metric_all_stat_table(stat_dict_all)
    _stat_bin_df = create_metric_bin_stat_table(stat_dict_all)
    stat_df_list.append(_stat_df)
    stat_all_df_list.append(_stat_all_df)
    stat_bin_df_list.append(_stat_bin_df)
    # save the AR model
    model_fit.save(path.join(analysis_path,'models',f'model_{df_idx}'))
    file_path_used.append(df.data_file_path)

# pd.concat raises an opaque "No objects to concatenate" on an empty list;
# fail with a message that says what actually happened
if not stat_df_list:
    raise ValueError('no data file passed the mask criterion, so no metrics were computed')
stat_df = pd.concat(stat_df_list)


In [None]:
# Stack the per-file pooled and binned tables, then label every table row
# with the file it was computed from.
stat_all_df = pd.concat(stat_all_df_list)
stat_bin_df = pd.concat(stat_bin_df_list)
for _table in (stat_df, stat_all_df, stat_bin_df):
    _table['file_path'] = file_path_used
# stat_bin_df['bin_t'] = repmat(bin_T_left_edge,1,file_used.sum())[0]

In [None]:
# save performance metric table for this analysis
stat_df.to_csv(path.join(analysis_path,'prediction_metric_stats.csv'))
stat_all_df.to_csv(path.join(analysis_path,'prediction_metric_all_stats.csv'))
stat_bin_df.to_csv(path.join(analysis_path,'prediction_metric_bin_stats.csv'))
# analysis parameters; the bin left edges get their own key because a second
# 'bin_t' entry in the same dict literal silently replaced them with bin_T
param_dict = {
    'bin_t_left_edge': bin_T_left_edge,
    'pred_t': pred_window_T,
    'bin_t': bin_T,
    'model_p': model_order
}
# The positional dict is kept so existing readers of 'arr_0' still work (it is
# stored as a pickled object array). The keyword expansion additionally stores
# every parameter as a plain named array that loads without pickling.
np.savez(path.join(analysis_path,'param.npz'),param_dict,**param_dict)