# Data Mask Generation 
## (new machine, way more RAM)

Michael Nolan

2020.09.16

In [None]:
import os.path as path
import pickle as pkl
import time
from glob import glob

import aopy
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats  # explicit submodule import so that sp.stats.gmean resolves
import tqdm


In [None]:
# get file list
data_dir = "C:\\Users\\mickey\\aoLab\\Data\\WirelessData\\Goose_Multiscale_M1"
# Build the pattern with path.join instead of hard-coded separators; this
# matches <data_dir>/18032[0-9]/0[0-9]*/*ECOG_3.clfp_ds250.dat
file_pattern = path.join(data_dir, "18032[0-9]", "0[0-9]*", "*ECOG_3.clfp_ds250.dat")
# glob returns matches in arbitrary order; sort so that files are processed
# (and logged) in the same order on every run.
file_list = sorted(glob(file_pattern))
print(f'Files found: {len(file_list)}')
# one DataFile wrapper per matched recording
datafile_list = [aopy.data.DataFile(f) for f in file_list]

In [None]:
# --- filter parameters ---
power_window_t = 10        # window length (s) for the windowed RMS estimate
gmean_thresh_scale = 1.25  # window is "bad" if its RMS exceeds this multiple of the geometric-mean RMS
amp_thresh_scale = 10      # amplitude limits sit this many std. devs. from the per-channel mean
plot_ch_idx = 0            # channel drawn in the diagnostic figure
# figsize = (16,5)
plot_scale = 1.2           # y-limits are the amplitude thresholds multiplied by this factor


def _window_slice(win_idx, window_n, n_sample):
    """Return the slice of sample indices covered by window ``win_idx``.

    Windows are ``window_n`` samples long and tile the recording starting at
    sample 0. The final window is truncated at ``n_sample`` and includes the
    last sample of the recording.
    """
    start = win_idx * window_n
    return slice(start, min(start + window_n, n_sample))


# For every data file: flag high-power windows, then flag remaining windows
# containing out-of-range amplitudes, plot the result, and store the combined
# mask under the 'sat' key of the file's mask pickle.
for df in datafile_list:
    print(df.data_file_path)

    # get data, params
    srate = df.srate
    ch_idx = df.ch_idx
    data = df.read()
    n_ch, n_sample = data.shape

    # compute window power
    # The window length is used as an index stride and a repeat count, so it
    # must be an integer even if srate is stored as a float.
    power_window_n = int(round(srate * power_window_t))
    n_power_window = int(np.ceil(n_sample / power_window_n))
    ch_mean = data.mean(axis=-1)
    rmse_window = np.empty(n_power_window)
    for win_idx in tqdm.tqdm(range(n_power_window)):
        win_data = data[:, _window_slice(win_idx, power_window_n, n_sample)]
        # deviation from the whole-recording channel mean:
        # local variance won't catch drift well
        rmse_window[win_idx] = np.sqrt(np.mean((win_data - ch_mean[:, None]) ** 2))

    # compute, apply power thresholds
    gmean_rmse = sp.stats.gmean(rmse_window)
    bad_window = rmse_window > gmean_thresh_scale * gmean_rmse
    # expand the per-window flags to a per-sample mask
    bad_data_mask = bad_window.repeat(power_window_n)[:n_sample]

    # run amplitude filter
    # Boolean indexing copies the data, so take the "good" samples once and
    # release the copy as soon as the statistics are computed.
    good_data = data[:, ~bad_data_mask]
    good_mean = good_data.mean(axis=-1)
    good_std = good_data.std(axis=-1)
    del good_data
    amp_thresh_high = good_mean + amp_thresh_scale * good_std
    amp_thresh_low = good_mean - amp_thresh_scale * good_std
    bad_window_amp = bad_window.copy()
    for win_idx in tqdm.tqdm(np.flatnonzero(~bad_window)):  # over all "good" windows
        win_data = data[:, _window_slice(win_idx, power_window_n, n_sample)]
        # a window is rejected if any channel leaves its own amplitude range
        oor_low = np.any(win_data < amp_thresh_low[:, None])
        oor_high = np.any(win_data > amp_thresh_high[:, None])
        bad_window_amp[win_idx] = oor_low | oor_high

    # get secondary thresholds
    bad_data_amp_mask = bad_window_amp.repeat(power_window_n)[:n_sample]

    # print threshold percentages
    print(f'power thresh: {100*bad_data_mask.mean():0.3f}%')
    print(f'power, amp. thresh: {100*bad_data_amp_mask.mean():0.3f}%')

    # plot results
    fig, ax = plt.subplots()
    t_sec = np.arange(n_sample) / srate  # not named `time`: that would shadow the time module
    ch_data = data[plot_ch_idx, :]
    ax.plot(t_sec, ch_data, label='no filter')
    ax.plot(t_sec[~bad_data_mask], ch_data[~bad_data_mask], label='power filter')
    ax.plot(t_sec[~bad_data_amp_mask], ch_data[~bad_data_amp_mask], label='power, amp. filter')
    ax.axhline(good_mean[plot_ch_idx], color='k')
    ax.axhline(amp_thresh_high[plot_ch_idx], color='k', linestyle=':')
    ax.axhline(amp_thresh_low[plot_ch_idx], color='k', linestyle=':')
    ax.set_ylim(plot_scale*amp_thresh_low[plot_ch_idx], plot_scale*amp_thresh_high[plot_ch_idx])
    ax.legend(loc=0)
    ax.set_xlabel('time (s)')
    # raw string: '\m' is not a valid escape sequence in a normal string literal
    ax.set_ylabel(r'amp. ($\mu$V)')
    ax.set_title(f'ECoG data, ch. {plot_ch_idx}')

    # save figure
    file_dir = path.dirname(df.data_file_path)
    fig.savefig(path.join(file_dir, f'power_and_amplitude_filter_ch{plot_ch_idx}_ds250.png'))
    # NOTE(review): the figure is left open after saving. If this is run on
    # many files outside a notebook, consider plt.close(fig) here.

    # save new mask to data file under 'sat' value
    # NOTE: pickle.load executes arbitrary code from the file; only use on
    # mask files from a trusted source.
    with open(df.mask_file_path, 'rb') as mask_file:
        mask_dict = pkl.load(mask_file)
    mask_dict['sat'] = bad_data_amp_mask
    with open(df.mask_file_path, 'wb') as mask_file:
        pkl.dump(mask_dict, mask_file)
    print(f'mask file {df.mask_file_path} updated.')
    