# Dataset statistics + preprocessing
### Again and again and again... and again....
...for the last time!

Michael Nolan
2020.09.01

In [None]:
from aopy import datafilter, datareader

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from tqdm import tqdm

import os
import ntpath
import sys
from glob import glob
import pickle as pkl

In [None]:
# Collect every raw ECoG ".clfp.dat" recording found two directory levels
# below the "18*" folders of the data root.
data_root = "E:\\aoLab\\Data\\WirelessData\\Goose_Multiscale_M1\\"
# Joining the pattern component by component gives the same Windows-style
# glob pattern as spelling out the backslashes by hand.
data_file_list = glob(ntpath.join(data_root, '18*', '*', '*.clfp.dat'))
print(f'{len(data_file_list)} files found')

In [None]:
# Compute the spectrogram of multichannel ECoG data, one channel at a time.
# Computing all channels in a single call was observed to need far more RAM,
# so each channel is transformed separately and the results are stacked.
def multichannel_spectrogram(data_file, sgram_window_len=4, sgram_over_len=2, bw=2):
    """Compute a spectrogram for every channel of one recording.

    Parameters
    ----------
    data_file : str
        Path handed to ``datareader.load_ecog_clfp_data``.
    sgram_window_len, sgram_over_len, bw :
        Window length, overlap length and bandwidth, forwarded positionally
        to ``datafilter.mt_sgram``. The defaults reproduce the values that
        were previously hard-coded. Units are whatever ``mt_sgram`` expects
        (presumably seconds and Hz -- TODO confirm).

    Returns
    -------
    sgram_all : numpy.ndarray
        Per-channel spectrograms stacked along a new leading (channel) axis.
    f_sg, t_sg :
        The two axis vectors returned by the last ``mt_sgram`` call
        (presumably frequency and time); they are assumed to be identical
        for every channel.

    Raises
    ------
    ValueError
        Raised by ``np.stack`` when the recording reports zero channels.
    """
    # The mask returned by the loader is not needed here.
    data, exp, _mask = datareader.load_ecog_clfp_data(data_file)
    srate = exp['srate']
    num_ch = exp['num_ch']

    sgram_list = []
    for ch_idx in tqdm(range(num_ch)):
        f_sg, t_sg, ch_sgram = datafilter.mt_sgram(
            data[ch_idx, :], srate, sgram_window_len, sgram_over_len, bw
        )
        sgram_list.append(ch_sgram)

    # Release the raw recording *before* np.stack allocates the stacked
    # output, so the raw data and the stacked copy never coexist in memory.
    del data
    return np.stack(sgram_list), f_sg, t_sg

In [None]:
# Write multichannel spectrograms to a compressed .npz archive.
def save_multichannel_spectrogram(sgram,f_sg,t_sg,data_file):
    """Save a spectrogram and its axes next to the source data file.

    The archive is written to ``<data_file>.sgram.npz`` (same directory as
    ``data_file``) and holds the three arrays under the keys ``sgram``,
    ``f_sg`` and ``t_sg``.
    """
    parent_dir, source_name = ntpath.split(data_file)
    archive_path = ntpath.join(parent_dir, source_name + ".sgram.npz")
    np.savez_compressed(archive_path, sgram=sgram, f_sg=f_sg, t_sg=t_sg)

In [None]:
# Plot data spectrograms: one figure (and one PNG file) per channel.
def plot_mean_sgrams(sgram,f_sg,t_sg,data_file_path):
    """Save a spectrogram image for every channel of ``sgram``.

    Parameters
    ----------
    sgram : numpy.ndarray
        3-D array indexed as (channel, frequency, time); a ValueError is
        raised on unpacking if it does not have exactly three dimensions.
    f_sg, t_sg :
        Axis vectors; their first and last entries set the image extent
        (t_sg on the x axis, f_sg on the y axis).
    data_file_path : str
        Path of the source data file. Figures are written to a
        ``sgram_figs`` directory created next to it.

    Notes
    -----
    Power is shown in dB (``10*log10``) with the colour range clipped to
    0-35 and the y axis limited to 0-150.

    NOTE(review): figure names contain only the channel number, so several
    data files living in the same directory overwrite each other's figures.
    """
    fig_save_dir = ntpath.join(ntpath.dirname(data_file_path), 'sgram_figs')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(fig_save_dir, exist_ok=True)

    n_ch, _, _ = sgram.shape
    # The extent is the same for every channel, so compute it once.
    extent = (t_sg[0], t_sg[-1], f_sg[0], f_sg[-1])

    for ch_idx in tqdm(range(n_ch)):
        fig, ax = plt.subplots(1, 1, sharex=True, dpi=120, figsize=(15, 4))
        power_db = 10*np.log10(sgram[ch_idx, :, :])
        im = ax.imshow(power_db, clim=(0, 35), origin='lower', aspect='auto', extent=extent)
        ax.set_ylim(0, 150)
        fig.colorbar(im, ax=ax)
        ax.set_title(f'ch. {ch_idx+1}')
        fig.savefig(ntpath.join(fig_save_dir, f'sgram_ch{ch_idx+1}.png'))
        # Close explicitly so figures do not accumulate across channels.
        plt.close(fig=fig)

In [None]:
# Main processing loop. For every raw data file in data_file_list:
# - load the data and compute the per-channel spectrograms
# - save the spectrograms to a compressed file next to the data file
# - plot the spectrograms and save one figure per channel
for data_file in data_file_list:
    print(data_file)
    print("computing spectrograms...")
    sgram, f_sg, t_sg = multichannel_spectrogram(data_file)
    print("saving spectrogram file...")
    save_multichannel_spectrogram(sgram,f_sg,t_sg,data_file)
    print("saving spectrograms figures...")
    plot_mean_sgrams(sgram,f_sg,t_sg,data_file)
    # drop this name's reference to the spectrogram array before the next
    # file is processed
    del sgram

In [None]:
# Loop across all saved spectrogram files to:
# - create spectrogram mask (time)
# - create channel mask
# - overwrite existing mask file to include channel mask
# - save spectrogram mask as a separate file
# - save a figure summarizing both masks

sgram_file_list = glob(ntpath.join(data_root,'18*\\*\\*.sgram.npz'))

# Number of trailing characters stripped from a spectrogram file name to get
# the stem shared with its mask files (".dat.sgram.npz" is 14 characters).
sgram_suffix_len = len('.dat.sgram.npz')

# A time window is masked when its high-frequency power is more than this
# many standard deviations away from the channel median.
win_scale = 2.0

for sgram_file in sgram_file_list:

    print(sgram_file)

    # load sgram data
    print('loading data...')
    with np.load(sgram_file) as sgram_file_data:
        # The arrays are normally stored under explicit names; fall back to
        # numpy's positional names in case a file was saved without keywords.
        if 'sgram' in sgram_file_data.files:
            sgram_key, f_key, t_key = 'sgram', 'f_sg', 't_sg'
        else:
            sgram_key, f_key, t_key = 'arr_0', 'arr_1', 'arr_2'
        sgram = sgram_file_data[sgram_key]
        f_sg = sgram_file_data[f_key]
        t_sg = sgram_file_data[t_key]

    num_ch, num_f, num_t = sgram.shape

    # load mask
    # NOTE: unpickling runs arbitrary code; only load trusted mask files.
    file_stem = sgram_file[:-sgram_suffix_len]
    mask_file = file_stem + '.mask.pkl'
    with open(mask_file,'rb') as mask_fh:
        mask_data_in = pkl.load(mask_fh)

    # compute bad power windows
    print('computing sgram window mask...')
    # mean power above 100 on the f_sg axis, in dB: shape (channel, time)
    hf_power_db = 10*np.log10(sgram[:,f_sg > 100,:].mean(axis=1))
    hf_power_db_med = np.median(hf_power_db,axis=-1)
    hf_power_db_std = np.std(hf_power_db,axis=-1)
    power_upper_thresh = hf_power_db_med + win_scale*hf_power_db_std
    power_lower_thresh = hf_power_db_med - win_scale*hf_power_db_std
    # a window is masked if ANY channel leaves its [lower, upper] band
    out_of_band = np.logical_or(
        hf_power_db.T > power_upper_thresh,
        hf_power_db.T < power_lower_thresh,
    )
    sgram_time_mask = out_of_band.any(axis=1)
    print(f'{100*sgram_time_mask.mean()}% of windows masked.')

    # compute bad channels
    print('computing sgram channel mask...')
    # mean power over all frequencies and all unmasked windows, per channel
    mean_ch_power_db = 10*np.log10(sgram[:,:,~sgram_time_mask].mean(axis=(1,2)))
    mean_ch_power_db_median = np.median(mean_ch_power_db)
    mean_ch_power_db_std = np.std(mean_ch_power_db)
    ch_power_lower_thresh = mean_ch_power_db_median - mean_ch_power_db_std
    # channels more than one std below the median power are flagged as bad
    bad_ch_mask = mean_ch_power_db < ch_power_lower_thresh

    # add channel mask to standard mask file
    print('adding channel mask to mask file...')
    mask_data_in['ch'] = bad_ch_mask
    with open(mask_file,'wb') as mask_fh:
        pkl.dump(mask_data_in,mask_fh)

    # save spectrogram time mask
    print('saving sgram window mask...')
    sgram_mask_file = file_stem + '.sgram.mask.pkl'
    with open(sgram_mask_file,'wb') as mask_fh:
        # store the mask array itself, not the name of the file it goes in
        pkl.dump(sgram_time_mask,mask_fh)

    # plot masks, save to directory
    print('plotting mask visualization...')
    fig,ax = plt.subplots(2,2,dpi=75,figsize=(8,6))

    ax[0,0].plot(t_sg,hf_power_db[0,:],label='power')
    ax[0,0].axhline(power_upper_thresh[0],label='high thr.')
    ax[0,0].axhline(power_lower_thresh[0],label='low thr.')
    ax[0,0].legend(loc=0)
    ax[0,0].set_xlabel('time (s)')
    ax[0,0].set_ylabel('power (dB)')
    ax[0,0].set_title('Ch. 0 power')

    ax[0,1].plot(np.arange(num_ch)+1,mean_ch_power_db,label='power')
    ax[0,1].axhline(ch_power_lower_thresh,label='low thr.')
    ax[0,1].legend(loc=0)
    ax[0,1].set_xlabel('ch.')
    ax[0,1].set_ylabel('mean power (dB)')
    ax[0,1].set_title('Mean Ch. Power')

    ax[1,0].plot(t_sg,sgram_time_mask,label="mask")
    ax[1,0].legend(loc=0)
    ax[1,0].set_xlabel('time (s)')
    ax[1,0].set_ylabel('mask (bool)')
    ax[1,0].set_title('Time Mask, all ch.')

    # the fourth panel is unused: hide its axes (needs a real boolean, the
    # string 'false' is truthy and would leave them visible)
    ax[1,1].get_xaxis().set_visible(False)
    ax[1,1].get_yaxis().set_visible(False)

    print('saving figure...')
    fig_save_name = ntpath.join(ntpath.dirname(sgram_file),'sgram_mask.png')
    fig.savefig(fig_save_name)
    # close the figure so one is not leaked per processed file
    plt.close(fig)

    print('\n')
print('fin.')