In [None]:
import os.path as path
from os import makedirs, chmod
import glob
import functools

import time
import datetime
import tqdm

import aopy
import ecog_is2s

import numpy as np
import torch
import torch.optim as optim
from torchvision.transforms import Compose

import matplotlib.pyplot as plt

In [None]:
# file list to dataset
# NOTE(review): absolute, machine-specific paths; move to a config cell / env var before sharing.
data_path_root = 'C:\\Users\\mickey\\aoLab\\Data\\WirelessData\\Goose_Multiscale_M1'
data_path_day = path.join(data_path_root,'18032[5-7]')
# Build the sub-pattern with path.join rather than a hard-coded '\\' separator, and sort the
# matches: glob returns files in filesystem-dependent order, and the concatenated dataset built
# below presumably indexes samples in the order of this list, so sorting keeps sample indices
# (e.g. sample_idx used further down) reproducible across machines/runs.
data_file_pattern = path.join(data_path_day,'0[0-9]*','*ECOG*clfp_ds250.dat')
data_file_list = sorted(glob.glob(data_file_pattern))
assert data_file_list, f'no data files match {data_file_pattern}'
# for GPU evaluation use: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
print(f'mounting to device: {device}')
print(f'files found:\t{len(data_file_list)}')
print(f'files: {data_file_list}')
# wrap each matched path in an aopy DataFile
datafile_list = [aopy.data.DataFile(df) for df in data_file_list]

In [None]:
# Window lengths for each sample (presumably seconds): src_t of model input followed by
# trg_t of target; step_t is set to their sum.
src_t = 1.0
trg_t = 0.5
step_t = src_t + trg_t
# no need for the srate parameter, dx est. is z-scored as well
diff_transform = ecog_is2s.Util.add_signal_diff()
zscore_transform = ecog_is2s.Util.local_zscore()


def transform(sample):
    """Z-score a sample, then apply the signal-difference transform to the result."""
    return diff_transform(zscore_transform(sample))


# one DatafileDataset per DataFile, all sharing the same window settings and device
dfds_list = [
    aopy.data.DatafileDataset(datafile, src_t, trg_t, step_t, device=device)
    for datafile in datafile_list
]
datafile_concatdataset = aopy.data.DatafileConcatDataset(dfds_list, transform=transform)

In [None]:
# Train/validation/test loaders; partition=(4, 1, 1) presumably gives the split proportions.
# NOTE(review): rand_part=True with no seed makes the split differ between runs, and none of
# these loaders is used by the cells below (the evaluation loop iterates the full dataset).
partition = (4, 1, 1)
batch_size = 500
train_loader, valid_loader, test_loader = datafile_concatdataset.get_data_loaders(
    partition=partition,
    batch_size=batch_size,
    rand_part=True,
)

In [None]:
# Model hyperparameters. They must match those used to train the checkpoint loaded in the next
# cell; load_state_dict raises if parameter names or shapes disagree.
n_ch = datafile_concatdataset.n_ch
n_unit = 512  # hidden size, i.e. 2**9
n_layers = 1
dropout = 0.3
use_diff = True
bidirectional = False
model = ecog_is2s.Seq2Seq.Seq2Seq_GRU(
    input_dim=n_ch,
    hid_dim=n_unit,
    n_layers=n_layers,
    enc_len=0,
    dec_len=0,
    device=device,
    dropout=dropout,
    use_diff=use_diff,
    bidirectional=bidirectional,
).to(device)

In [None]:
# load parameters from file
model_path = "D:\\Users\\mickey\\Data\\models\\pyt\\seq2seq\\"
model_name = "enc1.0_dec0.5_srate250_20201010163509"
# map_location remaps stored tensors onto `device`, so a checkpoint saved from a GPU run
# can still be loaded when evaluating on the CPU.
checkpoint_dict = torch.load(path.join(model_path,model_name,'checkpoint.pt'),map_location=device)
model.load_state_dict(checkpoint_dict['model_state_dict'])
# Inference mode: disables dropout (the model was built with dropout=0.3) so evaluation
# outputs are not randomly perturbed.
model.eval()
model.device

In [None]:
def plot_trial_prediction(src,trg,out,srate=1,ch_idx=0,dpi=100,ax=None):
    """Plot one channel of a source window, its target continuation and a prediction.

    Parameters
    ----------
    src : array-like, shape (n_t_src, n_ch)
        Input window; axis 0 is treated as time.
    trg : array-like, shape (n_t_trg, n_ch)
        Target window, drawn immediately after ``src`` on the time axis.
    out : array-like, shape (n_t_trg, n_ch)
        Prediction of ``trg``.
    srate : float
        Samples per time unit; sample ``i`` is drawn at ``i / srate``.
    ch_idx : int
        Channel (column) to plot.
    dpi : int
        Resolution of the new figure; only used when ``ax`` is None.
    ax : matplotlib Axes, optional
        Axes to draw into. A new single-axes figure is created when None.

    Returns
    -------
    ax
        The axes drawn into. Its title reports the mean squared error of
        ``out`` against ``trg`` on channel ``ch_idx``.
    """
    # np.asarray accepts NumPy arrays and CPU torch tensors alike
    src = np.asarray(src)
    trg = np.asarray(trg)
    out = np.asarray(out)
    n_t_src = src.shape[0]
    n_t_trg = trg.shape[0]
    time_src = np.arange(n_t_src)/srate
    # the target segment starts right where the source segment ends
    time_trg = np.arange(n_t_trg)/srate + n_t_src/srate
    mse = np.mean((trg[:,ch_idx] - out[:,ch_idx])**2)
    if ax is None:
        _, ax = plt.subplots(1,1,dpi=dpi)
    ax.plot(time_src,src[:,ch_idx],label='src')
    ax.plot(time_trg,trg[:,ch_idx],label='trg')
    ax.plot(time_trg,out[:,ch_idx],label='out')
    ax.legend(loc=0)
    ax.set_xlabel('time')
    ax.set_ylabel('(a.u.)')
    ax.set_title(f'ch. {ch_idx}, mse = {mse:0.4f}')
    return ax


In [None]:
sample_idx = 100
ch_idx = 20
src, trg = datafile_concatdataset[sample_idx]
# Inference only: skip building the autograd graph. [None] adds a leading (presumably batch)
# axis for the model; [0] drops it again when plotting.
with torch.no_grad():
    out, enc, dec = model(src[None,:,:],trg[None,:,:])
# .cpu() is required before .numpy() when the model runs on a GPU
out = out.detach().cpu().numpy()
enc = enc.detach().cpu().numpy()
dec = dec.detach().cpu().numpy()
srate = datafile_concatdataset.srate
fig, ax = plt.subplots(2,1,dpi=100,constrained_layout=True)
plot_trial_prediction(src,trg,out[0,:,:],srate=srate,ch_idx=ch_idx,ax=ax[0])
# second panel: the model's enc/dec outputs, drawn against an all-zero "prediction"
plot_trial_prediction(enc[0,:,:],dec[0,:,:],np.zeros(dec.shape[1:]),srate=srate,ch_idx=ch_idx,ax=ax[1]);

In [None]:
def _to_numpy(x):
    """Return ``x`` as a NumPy array, detaching torch tensors and moving them to the CPU."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def eval_sample(dataset,sample_idx,model):
    """Run ``model`` on one dataset sample and score its prediction of the target.

    Parameters
    ----------
    dataset : indexable
        ``dataset[sample_idx]`` returns ``(src, trg)``; both are indexed as 2-D
        arrays whose axis 0 is treated as time (assumed (time, channel) - TODO confirm).
    sample_idx : int
        Index of the sample to evaluate.
    model : callable
        ``model(src_batch, trg_batch) -> (out, enc, dec)``; inputs and ``out``
        carry a leading axis of size 1.

    Returns
    -------
    mse_ch : ndarray
        Per-channel mean squared error of the prediction over time.
    mse : float
        Mean squared error over time and channels.
    rpe_ch : ndarray
        Per-channel RMSE divided by the target's standard deviation over time.
    rpe : float
        Overall RMSE divided by the RMS deviation of the target from its
        per-channel temporal means.
    """
    src, trg = dataset[sample_idx]
    # inference only: no autograd graph is needed
    with torch.no_grad():
        out, _, _ = model(src[None,:,:],trg[None,:,:])
    pred = _to_numpy(out)[0,:,:]
    target = _to_numpy(trg)
    sq_err = (target - pred)**2
    # true mean squared errors (the RMSE is only used inside the relative errors below)
    mse_ch = sq_err.mean(axis=0)
    mse = sq_err.mean()
    # squared deviation of the target from its per-channel temporal mean
    sq_dev = (target - target.mean(axis=0))**2
    rpe_ch = np.sqrt(mse_ch)/np.sqrt(sq_dev.mean(axis=0))
    rpe = np.sqrt(mse)/np.sqrt(sq_dev.mean())
    return mse_ch, mse, rpe_ch, rpe

In [None]:
# Score every sample of the concatenated dataset with eval_sample.
# NOTE(review): this covers all samples, not just a held-out test partition; if the model was
# trained on these same files, the metrics include training data - confirm against training split.
n_samples = len(datafile_concatdataset)
# NaN-initialised (instead of np.empty) so rows left unfilled, e.g. after an interrupted run,
# show up as missing rather than holding arbitrary memory contents.
mse = np.full(n_samples, np.nan)
mse_ch = np.full((n_samples, datafile_concatdataset.n_ch), np.nan)
rpe = np.full(n_samples, np.nan)
rpe_ch = np.full((n_samples, datafile_concatdataset.n_ch), np.nan)
# inference only: no autograd graph is needed for any sample
with torch.no_grad():
    for sample_idx in tqdm.tqdm(range(n_samples)):
        mse_ch[sample_idx,:], mse[sample_idx], rpe_ch[sample_idx,:], rpe[sample_idx] = eval_sample(datafile_concatdataset,sample_idx,model)


In [None]:
# Collect evaluation settings and per-sample metrics; saved next to the model checkpoint.
result_dict = {
    'model_file': path.join(model_path,model_name,'checkpoint.pt'),
    # Plain path strings rather than aopy DataFile objects: the saved file can then be
    # unpickled without aopy and does not rely on DataFile being picklable.
    'test_file_list': data_file_list,
    'src_t': src_t,
    'trg_t': trg_t,
    'step_t': step_t,
    'mse': mse,
    'mse_ch': mse_ch,
    'rpe': rpe,
    'rpe_ch': rpe_ch
}
torch.save(result_dict,path.join(model_path,model_name,'test_result_data.pt'))

In [None]:
# plot the performance metrics
def plot_results(result_dict, bins=100, ax=None):
    """Plot histograms of the per-sample 'mse' and 'rpe' arrays in ``result_dict``.

    Parameters
    ----------
    result_dict : dict
        Must contain the keys 'mse' and 'rpe' (one value per sample).
    bins : int or sequence
        Histogram bins, passed straight to ``Axes.hist``.
    ax : sequence of two matplotlib Axes, optional
        Axes for the MSE and RPE histograms. A new 2x1 figure is created when None.

    Returns
    -------
    ax
        The two axes drawn into.
    """
    if ax is None:
        _, ax = plt.subplots(2,1,dpi=100,constrained_layout=True)
    for metric_ax, key, label in ((ax[0], 'mse', 'MSE'), (ax[1], 'rpe', 'RPE')):
        metric_ax.hist(result_dict[key],bins,label=label)
        metric_ax.set_xlabel(label)
        metric_ax.set_ylabel('count')
    return ax
plot_results(result_dict=result_dict);  # trailing ';' suppresses the cell's output repr