In [1]:
import torch
import numpy as np
from scipy.signal import welch
import matplotlib.pyplot as plt
import os
import sys
import site
# Put the parent directory on sys.path; presumably that is where the project
# modules imported further down (objective, lfads, utils) live -- TODO confirm.
# os.path.join/os.path.pardir replaces the old '\..' literal, which relied on an
# invalid escape sequence and only formed a valid path on Windows.
site.addsitedir(os.path.join(os.path.curdir, os.path.pardir))

In [2]:
#-------------------------------------------------------------------
#-------------------------------------------------------------------

def prep_model(model_name, data_dict, data_suffix, batch_size, device, hyperparams):
    """Build the network identified by ``model_name``.

    Args:
        model_name: Architecture identifier. Only 'lfads_ecog' is handled.
        data_dict: Mapping that must contain 'test_ecog_fl0u10' (an array whose
            last axis length is used as the input size and whose dtype is
            forwarded) and 'dt'.
        data_suffix: Not used by this function; kept for interface compatibility.
            NOTE(review): the data key is hard-coded to 'test_ecog_fl0u10'
            rather than derived from this suffix -- confirm that is intended.
        batch_size: Not used by this function; kept for interface compatibility.
        device: Forwarded to ``prep_lfads_ecog``.
        hyperparams: Forwarded to ``prep_lfads_ecog``.

    Returns:
        The model built by ``prep_lfads_ecog``; the accompanying loss objective
        is discarded.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    # Fail loudly up front: previously an unknown name fell through to
    # ``return model`` and surfaced as an opaque UnboundLocalError.
    if model_name != 'lfads_ecog':
        raise ValueError("Unsupported model_name: %r (expected 'lfads_ecog')" % (model_name,))

    test_data = data_dict['test_ecog_fl0u10']
    model, _objective = prep_lfads_ecog(input_dims=test_data.shape[-1],
                                        hyperparams=hyperparams,
                                        device=device,
                                        dtype=test_data.dtype,
                                        dt=data_dict['dt'])
    return model

#-------------------------------------------------------------------
#-------------------------------------------------------------------

def prep_lfads_ecog(input_dims, hyperparams, device, dtype, dt):
    """Instantiate the single-session ECoG LFADS network and its loss.

    Args:
        input_dims: Forwarded to the network constructor as ``input_size``.
        hyperparams: Nested mapping; its 'model' and 'objective' sub-mappings
            supply the constructor arguments read below.
        device: Passed to the network constructor, and used to place both the
            network and the loss via ``.to(device)``.
        dtype: Accepted but not used in this function.
        dt: Accepted but not used in this function.

    Returns:
        Tuple ``(model, objective)``.
    """
    from objective import LFADS_Loss, LogLikelihoodGaussian
    from lfads import LFADS_Ecog_SingleSession_Net

    model_hp = hyperparams['model']
    network = LFADS_Ecog_SingleSession_Net(
        input_size=input_dims,
        factor_size=model_hp['factor_size'],
        g_encoder_size=model_hp['g_encoder_size'],
        c_encoder_size=model_hp['c_encoder_size'],
        g_latent_size=model_hp['g_latent_size'],
        u_latent_size=model_hp['u_latent_size'],
        controller_size=model_hp['controller_size'],
        generator_size=model_hp['generator_size'],
        prior=model_hp['prior'],
        clip_val=model_hp['clip_val'],
        dropout=model_hp['dropout'],
        # The hyperparameter key and the constructor keyword differ here.
        do_normalize_factors=model_hp['normalize_factors'],
        max_norm=model_hp['max_norm'],
        device=device,
    )
    model = network.to(device)

    loglikelihood = LogLikelihoodGaussian()

    objective_hp = hyperparams['objective']
    loss_weights = {'kl': objective_hp['kl'],
                    'l2': objective_hp['l2']}
    loss = LFADS_Loss(
        loglikelihood=loglikelihood,
        loss_weight_dict=loss_weights,
        l2_con_scale=objective_hp['l2_con_scale'],
        l2_gen_scale=objective_hp['l2_gen_scale'],
    )
    objective = loss.to(device)

    return model, objective

#-------------------------------------------------------------------
#-------------------------------------------------------------------
    
def prep_data(data_dict, data_suffix, batch_size, device):
    """Wrap the train/valid arrays of ``data_dict`` in DataLoaders.

    Args:
        data_dict: Mapping containing 'train_<suffix>', 'valid_<suffix>' and
            'dt'. The train array must be 3-D, since its shape is unpacked
            into three values. May also contain '<split>_rates',
            '<split>_latent' and '<split>_spikes' ground-truth entries.
        data_suffix: Suffix used to form the 'train_...' / 'valid_...' keys.
        batch_size: Batch size for both loaders.
        device: Device each sample is moved to when it is fetched.

    Returns:
        Tuple ``(train_dl, valid_dl, input_size, plotter)``. ``train_dl``
        shuffles and ``valid_dl`` does not; ``input_size`` is the last
        dimension of the train array; ``plotter`` is currently always None.
    """
    train_data = torch.Tensor(data_dict['train_%s' % data_suffix])
    valid_data = torch.Tensor(data_dict['valid_%s' % data_suffix])

    num_trials, num_steps, input_size = train_data.shape

    train_ds = EcogTensorDataset(train_data, device=device)
    valid_ds = EcogTensorDataset(valid_data, device=device)

    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=batch_size)

    # The values below only feed the (currently disabled) plotter; they are kept
    # so it can be re-enabled without rebuilding them.
    # Use numpy directly instead of the private ``torch._np`` alias, and scale an
    # integer range so float rounding cannot add an extra time step.
    TIME = np.arange(num_steps) * data_dict['dt']
    train_truth = _collect_truth(data_dict, 'train')
    valid_truth = _collect_truth(data_dict, 'valid')

    plotter = None
    # plotter = {'train' : Plotter(time=TIME, truth=train_truth),
    #            'valid' : Plotter(time=TIME, truth=valid_truth)}

    return train_dl, valid_dl, input_size, plotter


def _collect_truth(data_dict, split):
    """Gather the optional ground-truth arrays stored for ``split``."""
    # Each split is checked against its own keys; the train spikes used to be
    # guarded by the presence of 'valid_spikes'.
    return {kind: data_dict['%s_%s' % (split, kind)]
            for kind in ('rates', 'latent', 'spikes')
            if '%s_%s' % (split, kind) in data_dict}

#-------------------------------------------------------------------
#-------------------------------------------------------------------

from torch.utils.data.dataset import Dataset
class EcogTensorDataset(Dataset):
    """Dataset over one or more tensors that share their first dimension.

    A sample is the tuple formed by indexing every wrapped tensor with the
    same index along dimension 0; each element is moved to ``device`` when
    it is fetched.

    Arguments:
        *tensors (Tensor): tensors whose first dimensions have equal size.
        device: target passed to ``Tensor.to`` for every returned element.
    """

    def __init__(self, *tensors, device='cpu'):
        # Every tensor must agree with the first one on its leading size.
        assert not any(other.size(0) != tensors[0].size(0) for other in tensors)
        self.tensors = tensors
        self.device = device

    def __getitem__(self, index):
        sample = []
        for source in self.tensors:
            sample.append(source[index].to(self.device))
        return tuple(sample)

    def __len__(self):
        first = self.tensors[0]
        return first.size(0)

#-------------------------------------------------------------------
#-------------------------------------------------------------------

import h5py
def read_data(data_fname,keys):
    """Load selected top-level entries of an HDF5 file as numpy arrays.

    Args:
        data_fname: Path of the HDF5 file to open read-only.
        keys: Container of entry names to keep; top-level entries whose name
            is not in ``keys`` are skipped.
    Returns:
        A dictionary mapping each kept entry name to ``np.array`` of its value.
    Raises:
        IOError: Re-raised after printing a message when the file cannot be
            read.
    """
    try:
        with h5py.File(data_fname, 'r') as hf:
            data_dict = {}
            for name, entry in hf.items():
                if name in keys:
                    # Materialise while the file is still open.
                    data_dict[name] = np.array(entry)
        return data_dict
    except IOError:
        print("Cannot open %s for reading." % data_fname)
        raise

In [3]:
# --- Run configuration -------------------------------------------------------
# NOTE(review): the two paths are machine-specific absolute Windows paths; they
# must be edited before the notebook can run anywhere else.

# What to build and which data variant it belongs to.
model_name = "lfads_ecog"
data_suffix = "ecog_fl0u10"

# Where the hyperparameter file and the dataset live.
hyperparameter_path = "C:\\Users\\mickey\\aoLab\\code\\hierarchical_lfads\\hyperparameters\\ecog\\lfads_ecog_3.yaml"
data_path = "D:\\Users\\mickey\\Data\\datasets\\ecog\\goose_wireless\\gw_250_fl0u20"

# Execution settings.
device = 'cpu'
batch_size = 1000

In [4]:
from utils import  load_parameters
# Load the hyperparameters and only the two dataset entries this notebook uses.
hyperparams = load_parameters(hyperparameter_path)
data_dict = read_data(data_path, keys=['test_ecog_fl0u10', 'dt'])

# Build the (still untrained) network; weights are loaded in a later cell.
model = prep_model(model_name=model_name,
                   data_dict=data_dict,
                   data_suffix=data_suffix,
                   batch_size=batch_size,
                   device=device,
                   hyperparams=hyperparams)

# Sampling rate in Hz. Round before converting: a bare int() truncates, so a
# reciprocal that lands just below the true rate (e.g. a float32 dt of 0.004
# gives 249.99998) would otherwise become 249 instead of 250.
srate = int(np.round(1 / data_dict['dt']))

In [5]:
model_path = "D:\\Users\\mickey\\Data\\models\\pyt\\lfads\\gw_250_fl0u20\\lfads_ecog\\cenc0_cont0_fact64_genc128_gene128_glat128_seqlen50_ulat0_orion-\\checkpoints\\best.pth"
# model_path = "D:\\Users\\mickey\\Data\\models\\pyt\\lfads\\gw_250_fl0u10\\lfads_ecog\\cenc0_cont0_fact64_genc128_gene128_glat128_ulat0_orion-\\checkpoints\\best.pth"
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint saved from a GPU run still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpoint = torch.load(model_path, map_location=device)
# The network weights are stored under the 'net' key of the checkpoint.
model.load_state_dict(checkpoint['net'])

<All keys matched successfully>

In [6]:
n_test_samples = data_dict['test_ecog_fl0u10'].shape[0]
# Inference only: switch off training-time behaviour such as dropout and skip
# building the autograd graph. Calling the module (rather than .forward) also
# runs any registered hooks.
model.eval()
with torch.no_grad():
    rates = model(torch.tensor(data_dict['test_ecog_fl0u10'], dtype=torch.float))
# NOTE(review): `rates` holds the raw model output. The disabled line below
# suggests the factors sit at rates[0]['factors'] -- confirm against the model's
# forward() before passing `rates` to the PSD cells.
# rates = rates[0]['factors'].detach().numpy()

In [None]:
# Welch PSD of the model output, taken along axis 0.
# NOTE(review): fs=250 is hard-coded; presumably it should equal the `srate`
# derived from dt earlier -- confirm. The data PSD cell uses axis=1, so also
# confirm which axis of `rates` is time.
f_psd, factor_psd = welch(rates, fs=250, axis=0)
# Average over the last axis and show the result in dB.
plt.plot(f_psd, 10 * np.log10(np.mean(factor_psd, axis=-1)));

In [None]:
# Collapse axes 1 and 2 of the factor PSD, leaving one value per frequency bin:
# the 2.5th / 97.5th percentiles (stacked along a new leading axis) and the mean.
factor_psd_95ci = np.percentile(factor_psd, q=[2.5, 97.5], axis=(1, 2))
factor_psd_mean = np.mean(factor_psd, axis=(1, 2))

In [None]:
# Factor PSD in dB: mean curve over the 2.5-97.5 percentile band.
fig, ax = plt.subplots(nrows=1, ncols=1, dpi=80)
ax.fill_between(
    f_psd,
    10 * np.log10(factor_psd_95ci[0]),
    10 * np.log10(factor_psd_95ci[1]),
    alpha=0.3,
    label='95% CI',
)
ax.plot(f_psd, 10 * np.log10(factor_psd_mean), label='mean')
ax.legend(loc=0)
ax.set_ylabel('PSD (dB)')
ax.set_xlabel('freq. (Hz)')
ax.set_title('Factor PSD, mean & CI')

In [None]:
# Welch PSD of the held-out ECoG data along axis 1. This overwrites the `f_psd`
# produced by the factor PSD cell.
# NOTE(review): fs=250 is hard-coded; presumably it should equal the `srate`
# derived from dt earlier -- confirm.
f_psd, data_psd = welch(data_dict['test_ecog_fl0u10'], fs=250, axis=1)
# Average over the last axis, convert to dB, and transpose so frequency runs
# along the first axis for plotting.
plt.plot(f_psd, np.transpose(10 * np.log10(np.mean(data_psd, axis=-1))));

In [None]:
# Summarise the data PSD across axis 0: the 2.5th / 97.5th percentiles and the
# central curve.
# NOTE(review): despite its name, `data_psd_mean` holds the MEDIAN; the name is
# kept because later cells reference it.
data_psd_95ci = np.percentile(data_psd, q=[2.5, 97.5], axis=0)
data_psd_mean = np.median(data_psd, axis=0)

In [None]:
# Data PSD in dB for index 0 of the last axis: median curve over the 2.5-97.5
# percentile band.
fig, ax = plt.subplots(1, 1, dpi=80)
ax.fill_between(f_psd, 10*np.log10(data_psd_95ci[0,:,0]), 10*np.log10(data_psd_95ci[1,:,0]), alpha=0.3, label='95% CI')
# `data_psd_mean` is computed with np.median, so label the curve as what it is.
ax.plot(f_psd, 10*np.log10(data_psd_mean[:,0]), label='median')
ax.legend(loc=0)
ax.set_xlabel('freq. (Hz)')
ax.set_ylabel('PSD (dB)')
ax.set_title('Data PSD, median & CI')

In [None]:
# Data vs. factor PSD in dB, drawn twice: the left panel shows the full
# frequency range, the right panel the same curves zoomed to 0-20 Hz.
# NOTE(review): `f_psd` here is the frequency axis of the data PSD; the factor
# curves are assumed to share it -- confirm both welch calls yield the same bins.
fig, ax = plt.subplots(1, 2, dpi=80, figsize=(8, 4))
for panel in ax:
    panel.fill_between(f_psd, 10*np.log10(data_psd_95ci[0,:,0]), 10*np.log10(data_psd_95ci[1,:,0]), alpha=0.3, label='data 95% CI')
    # `data_psd_mean` is computed with np.median, so label it as a median.
    panel.plot(f_psd, 10*np.log10(data_psd_mean[:,0]), label='data median')
    panel.fill_between(f_psd, 10*np.log10(factor_psd_95ci[0,:]), 10*np.log10(factor_psd_95ci[1,:]), alpha=0.3, label='factor 95% CI')
    panel.plot(f_psd, 10*np.log10(factor_psd_mean), label='factor mean')
    panel.legend(loc=0)
    panel.set_xlabel('freq. (Hz)')
    panel.set_ylabel('PSD (dB)')
    panel.set_title('Data (median), Factor (mean) & CI')
ax[1].set_xlim(0,20)