In [2]:
# standard library
import datetime
import math
import os
import pickle as pkl
import random
import sys
import time

# third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
# `import scipy` alone does not guarantee the submodules are loaded on older
# SciPy releases; the notebook uses sp.signal.decimate and sp.stats.zscore.
import scipy.signal
import scipy.stats
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, random_split
from torch.utils.data.sampler import SequentialSampler, BatchSampler, SubsetRandomSampler
# import sklearn
# import progressbar as pb
# import argparse # add back in once this runs

# project-local
from aopy import datareader, datafilter
from ecog_is2s import EcogDataloader, Training
from ecog_is2s.model import Encoder, Decoder, Seq2Seq
from ecog_is2s.model import Util


ImportError: cannot import name 'Encoder' from 'ecog_is2s' (/Users/mickey/anaconda3/envs/ecog_is2s/lib/python3.7/site-packages/ecog_is2s-0.1-py3.7.egg/ecog_is2s/__init__.py)

In [None]:
# load, preprocess example data (same as training, so sue me)
# seed RNG for pytorch/np
SEED = 5050
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# set device - CUDA if you've got it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('mounting to device: {}'.format(device))

# load data
platform_name = sys.platform
if platform_name == 'darwin':
    # local machine
    data_file_full_path = '/Volumes/Samsung_T5/aoLab/Data/WirelessData/Goose_Multiscale_M1/180325/001/rec001.LM1_ECOG_3.clfp.dat'
    mask_file_path = "/Volumes/Samsung_T5/aoLab/Data/WirelessData/Goose_Multiscale_M1/180325/001/rec001.LM1_ECOG_3.clfp.mask.pkl"
    model_save_dir_path = '/Volumes/Samsung_T5/aoLab/Data/models/pyt/seq2seq/'
elif platform_name == 'linux2':
    # HYAK, baby!
    data_file_full_path = '/gscratch/stf/manolan/Data/WirelessData/Goose_Multiscale_M1/180325/001/rec001.LM1_ECOG_3.clfp.dat'
    mask_file_path = "/gscratch/stf/manolan/Data/WirelessData/Goose_Multiscale_M1/180325/001/rec001.LM1_ECOG_3.clfp.mask.pkl"
elif platform_name == 'linux':
    # google cloud, don't fail me now
    data_file_full_path = '/home/mickey/rec001.LM1_ECOG_3.clfp.dat'
    mask_file_path = '/home/mickey/rec001.LM1_ECOG_3.clfp.mask.pkl'
    model_save_dir_path = '/home/mickey/models/pyt/seq2seq/'

# make sure the output directory actually exists
if not os.path.exists(model_save_dir_path):
    os.makedirs(model_save_dir_path)

data_in, data_param, data_mask = datareader.load_ecog_clfp_data(data_file_name=data_file_full_path)
srate_in= data_param['srate']
num_ch = data_param['num_ch']
# we already found the appropriate data masks, so just load them in
with open(mask_file_path, 'rb') as f:
    mask_data = pkl.load(f)
hf_mask = mask_data["hf"]
sat_mask = mask_data["sat"]

# mask data array, remove obvious outliers
data_in[:,np.logical_or(hf_mask,sat_mask)] = 0.

# downsample data
srate_down = 250

# create dataset object from file
srate = srate_in
# data_in = np.double(data_in[:,:120*srate])
# enc_len = args.encoder_depth
# dec_len = args.decoder_depth
# seq_len = enc_len+dec_len # use ten time points to predict the next time point

total_len_T = 1*60 # I just don't have that much time!
total_len_n = total_len_T*srate_in
data_idx = data_in.shape[1]//2 + np.arange(total_len_n)
print('Downsampling data from {0} to {1}'.format(srate_in,srate_down))
data_in = np.float32(sp.signal.decimate(data_in[:,data_idx],srate_in//srate_down,axis=-1))
print('Data Size:\t{}'.format(data_in.shape))

# filter dead channels
ch_rms = np.std(data_in,axis=-1)
ch_m = np.mean(ch_rms)
ch_low_lim = ch_m - 2*np.std(ch_rms)
ch_up_lim = ch_m + 2*np.std(ch_rms)
ch_idx = np.logical_and(ch_rms > ch_low_lim, ch_rms < ch_up_lim)
ch_list = np.arange(num_ch)[ch_idx]
num_ch_down = len(ch_list)
print('Num. ch. used:\t{}'.format(num_ch_down))
print('Ch. dropped:\t{}'.format(np.arange(num_ch)[np.logical_not(ch_idx)]))

data_in = data_in[ch_idx,:]

In [None]:
# plot data statistics over windows
n_data = data_in.shape[-1]
data_z = sp.stats.zscore(data_in,axis=-1)
data_cov = np.cov(data_in)
data_z_cov = np.cov(data_z)
# full data covariance
f,ax = plt.subplots(2,1,figsize=(10,10))
plt.colorbar(ax[0].imshow(data_cov),label='Covariance')
ax[0].set_title('ECoG Covariance')
# f.show()
# f,ax = plt.subplots(1,1,figsize=(10,10))
plt.colorbar(ax[1].imshow(data_z_cov,label='Normalized Data'),label='Covariance')
ax[1].set_title('Normalized ECoG Covariance')
f.show()
# plot differences between standard deviations 
# print(np.mean(data_in,axis=-1),np.mean(data_z,axis=-1))
# print(np.std(data_in,axis=-1),np.std(data_z,axis=-1))
# n_row = 7
# n_col = 8
# f,ax = plt.subplots(7,8,figsize=(16,14))
# # print(ax)
# for r_i in range(n_row):
#     for c_i in range(n_col):
#         ch_i = r_i*n_col + c_i
#         ax[r_i,c_i].plot(data_in[ch_i,:])
# f.show()

In [None]:
# create model from trained state dict
device = torch.device('cpu')
model_file_path = '/Volumes/Samsung_T5/aoLab/Data/models/pyt/seq2seq/enc1000_dec1000_nl2_2020042115161587482188/model_checkpoint.pt'
model0_file_path = '/Volumes/Samsung_T5/aoLab/Data/models/pyt/seq2seq/enc1000_dec1000_nl2_2020042115161587482188/example_sequence_figs/data_tuple_epoch10.pt'
model_def = torch.load(model_file_path,map_location=device)
# model0_def = torch.load(model0_file_path,map_location=device)
print(model_def['model_state_dict'].keys())

In [None]:
# Show the tensor stored under 'encoder.rnn.weight_ih_l0' as an image in the
# first of two side-by-side panels.
# NOTE(review): by PyTorch's RNN parameter naming this should be the encoder's
# layer-0 input-to-hidden weight matrix -- confirm against the model definition.
f,ax = plt.subplots(1,2,figsize=(10,8))
ax[0].imshow(model_def['model_state_dict']['encoder.rnn.weight_ih_l0'])