# Creating (and saving a single instance of) an h5py dataset for the hierarchical_lfads code (or any other time-series reconstruction model I cook up or find on GitHub)

Michael Nolan

2020.10.26

In [None]:
import os
import site
from glob import glob
site.addsitedir("C:\\Users\\mickey\\aoLab\\Code\\hierarchical_lfads")
from utils import write_data
import pickle as pkl
import h5py
import numpy as np
import scipy as sp
import aopy
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
def get_concat_dataset( src_t=1.0, trg_t=0.0, step_t=1.0, filt_str='',
                        data_path="C:\\Users\\mickey\\aoLab\\Data\\WirelessData\\Goose_Multiscale_M1" ):
    """Build one concatenated dataset from every matching Goose ECoG recording file.

    Searches ``data_path`` for files matching
    ``18032[0-9]*/0[0-9]*/*ECOG_3.clfp_ds250{filt_str}.dat``. Each file is wrapped
    in ``aopy.data.DataFile`` and then ``aopy.data.DatafileDataset``, and the
    datasets are concatenated with ``aopy.data.DatafileConcatDataset``.

    Parameters
    ----------
    src_t : float
        Source window length passed to ``DatafileDataset`` (presumably seconds --
        confirm against aopy).
    trg_t : float
        Target window length passed to ``DatafileDataset``.
    step_t : float
        Step between successive windows passed to ``DatafileDataset``.
    filt_str : str
        Filter suffix inserted into the file name (e.g. ``'_fl0u10'``); ``''``
        selects the unfiltered files.
    data_path : str
        Root directory to search. Defaults to the original recording location.

    Returns
    -------
    aopy.data.DatafileConcatDataset
        All matching files, concatenated in sorted file-path order.

    Raises
    ------
    FileNotFoundError
        If no file under ``data_path`` matches the search pattern.
    """
    # os.path.join keeps the pattern portable instead of hard-coding '\\'
    file_pattern = os.path.join(
        data_path, '18032[0-9]*', '0[0-9]*', f'*ECOG_3.clfp_ds250{filt_str}.dat'
    )
    # sort so the sample order (and any seeded shuffle downstream) is reproducible
    data_file_list = sorted(glob(file_pattern))
    print(f'files found:\t{len(data_file_list)}')
    if not data_file_list:
        raise FileNotFoundError(f'no data files match {file_pattern}')
    # create dataset interface, honouring the caller's window parameters
    df_list = [aopy.data.DataFile(dfp) for dfp in data_file_list]
    dfds_list = [aopy.data.DatafileDataset(df, src_t, trg_t, step_t) for df in df_list]
    dfcds = aopy.data.DatafileConcatDataset(dfds_list)
    print(f'total samples: {len(dfcds)}')
    return dfcds

In [None]:
def create_ndarray_from_cds( cds, detrend=True, zscore=True ):
    """Stack every sample of ``cds`` into one array, optionally detrended and z-scored.

    Each sample is taken as ``cds[i][0]`` and treated as a 2-D (time, channel)
    array. Detrending and z-scoring are applied along axis 0, i.e. per channel
    over time. A sample is dropped from the returned array if any of its
    channels has a NaN time-mean. One case is a channel that is flat after
    detrending, which z-scores to 0/0.

    Parameters
    ----------
    cds : indexable dataset
        Must support ``len(cds)`` and ``cds[i]``. ``cds[i][0]`` must be a 2-D
        array-like; numpy arrays and CPU torch tensors both work.
    detrend : bool
        Remove a per-channel linear trend with ``scipy.signal.detrend``.
    zscore : bool
        Z-score each channel over time with ``scipy.stats.zscore``.

    Returns
    -------
    data : np.ndarray, shape (n_kept, n_t, n_ch)
        Processed samples, with NaN-containing samples removed.
    empty_ch_count : np.ndarray, shape (n_trial,)
        Number of NaN-mean channels found in each original sample (float dtype).

    Raises
    ------
    ValueError
        If ``cds`` is empty.
    """
    # explicit submodule import: `import scipy as sp` does not load
    # scipy.signal / scipy.stats on older SciPy releases
    from scipy import signal, stats

    n_trial = len(cds)
    if n_trial == 0:
        raise ValueError('cds contains no samples')
    # np.shape works for numpy arrays and torch tensors alike (.size() is torch-only)
    n_t, n_ch = np.shape(cds[0][0])
    data = np.empty((n_trial, n_t, n_ch))
    empty_ch_count = np.empty(n_trial)
    for trial_idx in tqdm(range(n_trial)):
        # convert up front so the NaN check below also works when both
        # detrend and zscore are disabled
        sample = np.asarray(cds[trial_idx][0], dtype=float)
        if detrend:
            sample = signal.detrend(sample, axis=0, type='linear')
        if zscore:
            sample = stats.zscore(sample, axis=0)
        # a channel's mean is NaN iff it contains a NaN (e.g. zero-variance z-score)
        empty_ch_count[trial_idx] = np.isnan(sample.mean(axis=0)).sum()
        data[trial_idx] = sample
    data = data[empty_ch_count == 0]
    return data, empty_ch_count

In [None]:
def create_data_dict( data, dt, cds, shuffle=True, rng_seed=42, train_valid_test_split=(0.7,0.2,0.1), filt_str='' ):
    """Split ``data`` along axis 0 into train/valid/test sets and record the run parameters.

    Parameters
    ----------
    data : np.ndarray
        Samples stacked along axis 0, e.g. the (n_trial, n_t, n_ch) array
        produced by ``create_ndarray_from_cds``.
    dt : float
        Sampling interval, stored in ``data_dict`` under key ``'dt'``.
    cds : concatenated dataset
        Must expose ``cds.datasets``. Each dataset needs
        ``.datafile.data_file_path``, ``.src_t`` and ``.step_t``. It is only
        used to fill ``param_dict``.
    shuffle : bool
        If True, permute samples along axis 0 with
        ``np.random.default_rng(rng_seed)`` before splitting.
    rng_seed : int
        Seed for the shuffle; recorded in ``param_dict``.
    train_valid_test_split : sequence of 3 floats
        Non-negative train / valid / test fractions summing to at most 1.
        Train and valid sizes are ``round(fraction * n_samples)``; valid is
        capped so it never exceeds what train leaves over. The test set
        receives every remaining sample, so the third fraction is validated
        but not used for sizing.
    filt_str : str
        Suffix appended to the ``train_ecog`` / ``valid_ecog`` / ``test_ecog`` keys.

    Returns
    -------
    data_dict : dict
        Keys ``f'train_ecog{filt_str}'``, ``f'valid_ecog{filt_str}'``,
        ``f'test_ecog{filt_str}'`` and ``'dt'``.
    param_dict : dict
        Source file list, ``src_t``, ``step_t``, ``rng_seed`` and the split.

    Raises
    ------
    ValueError
        If the split does not have exactly 3 entries, has a negative entry,
        or sums to more than 1.
    """
    if len(train_valid_test_split) != 3:
        raise ValueError(f'train_valid_test_split needs 3 entries, got {train_valid_test_split!r}')
    # small tolerance: e.g. 0.7 + 0.2 + 0.1 is not exactly 1.0 in floating point
    if min(train_valid_test_split) < 0 or sum(train_valid_test_split) > 1 + 1e-9:
        raise ValueError(
            f'train_valid_test_split must be non-negative and sum to <= 1, got {train_valid_test_split!r}'
        )
    # get split data indices
    n_samples = data.shape[0]
    n_train = min(round(n_samples * train_valid_test_split[0]), n_samples)
    # independent rounding can overshoot (e.g. 3 samples at 0.5/0.5 -> 2 + 2),
    # so cap valid at whatever train left over; test takes the remainder
    n_valid = min(round(n_samples * train_valid_test_split[1]), n_samples - n_train)
    if shuffle:
        # shuffle your dataset trials
        rng = np.random.default_rng(rng_seed)
        data = rng.permutation(data, axis=0)
    # create data dict
    data_dict = {
        f'train_ecog{filt_str}': data[:n_train],
        f'valid_ecog{filt_str}': data[n_train:n_train + n_valid],
        f'test_ecog{filt_str}': data[n_train + n_valid:],
        'dt': dt, # does this need to be in here? could this be external?
    }
    param_dict = {
        'file_list': [ds.datafile.data_file_path for ds in cds.datasets],
        'src_t': cds.datasets[0].src_t,
        'step_t': cds.datasets[0].step_t,
        'rng_seed': rng_seed,
        'train_valid_test_split': train_valid_test_split,
    }
    return data_dict, param_dict

In [None]:
# For each filter variant of the Goose wireless ECoG recordings, build a
# train/valid/test split, write it with write_data, and pickle the parameters
# needed to reproduce it.
h5_dataset_dir = "D:\\Users\\mickey\\Data\\datasets\\ecog\\goose_wireless"
# make sure the output directory exists before the first write
os.makedirs(h5_dataset_dir, exist_ok=True)
filt_str_list = ['','_fl0u10','_fl0u20','_fl0u30']
for filt_str in filt_str_list:
    cds = get_concat_dataset(filt_str = filt_str)
    dt = 1/cds.srate  # presumably srate is in Hz, giving dt in seconds -- confirm against aopy
    dataset, empty_count = create_ndarray_from_cds(cds)
    # report how many samples were discarded for containing NaN-mean channels
    print(f'{filt_str or "unfiltered"}: kept {dataset.shape[0]} of {len(empty_count)} samples')
    data_dict, param_dict = create_data_dict(dataset, dt, cds, filt_str=filt_str)
    h5_dataset_path = os.path.join(h5_dataset_dir,f"gw_250{filt_str}")
    write_data(h5_dataset_path,data_dict,compression=None)
    param_file_path = os.path.join(h5_dataset_dir,f"gw_250{filt_str}_param.pkl")
    with open(param_file_path,'wb') as param_f:
        pkl.dump(param_dict,param_f)