In [1]:
import numpy as np
import h5py
import os
from tqdm import tqdm

from cell_inference.config import paths, params
from cell_inference.utils.feature_extractors.SummaryStats2D import process_lfp

# Base name of the in-vivo HDF5 input file, given WITHOUT the '.h5' extension.
invivo_name = 'all_cell_LFP2D_Analysis_SensorimotorSpikeWaveforms_NP_SUTempFilter_NPExample'
# invivo_name = 'all_cell_LFP2D_Analysis_SensorimotorSpikes_NPUnits_large'
# Tolerate a name that still carries its extension, so DATA_PATH does not end in
# '.h5.h5' and the output names below do not contain a stray '.h5'.
if invivo_name.endswith('.h5'):
    invivo_name = invivo_name[:-len('.h5')]

INVIVO_PATH = 'cell_inference/resources/invivo'
DATA_PATH = os.path.join(INVIVO_PATH, invivo_name + '.h5')
# np.savez appends '.npz' to these output paths.
LFP_PATH = os.path.join(INVIVO_PATH, 'lfp_' + invivo_name)  # LFP and labels
STATS_PATH = os.path.join(INVIVO_PATH, 'summ_stats_' + invivo_name)  # summary statistics

save_lfp = True    # write the processed LFP archive to LFP_PATH
save_stats = True  # compute summary statistics during processing and write them to STATS_PATH

In [2]:
# Load electrode coordinates, LFP waveforms and (if present) per-sample layer labels.
with h5py.File(DATA_PATH, 'r') as hf:
    elec_pos = hf['coord'][()]
    lfp = hf['data'][()].transpose((0, 2, 1))  # (samples x channels x time) -> (samples x time x channels)
    if 'layer' in hf:
        # The 'layer' dataset stores codes; its attributes map each layer name to an
        # array whose first element is the code used for that name in the dataset.
        layer_type = {code[0]: name for name, code in hf['layer'].attrs.items()}
        # Read the dataset once rather than iterating the h5py Dataset directly,
        # which performs a separate HDF5 read for every element.
        layer_codes = hf['layer'][()]
        layer = np.array([layer_type[c] for c in layer_codes])
    else:
        layer = np.full(lfp.shape[0], 'N/A')

# The save cell indexes `layer` with sample indices, so it must align with `lfp`.
assert layer.shape[0] == lfp.shape[0], 'layer labels must have one entry per LFP sample'

In [3]:
# Run every waveform through `process_lfp`, keep the accepted samples and tally all
# samples by the status code (`bad`) it returns: codes <= 0 are kept, codes > 0 rejected.
pad_spike_window = False
# Status codes to pre-register in the tally (presumably -1 only occurs with padded windows).
# NOTE(review): `pad_spike_window` is not passed to `process_lfp`; it only selects the
# pre-registered codes here. Confirm whether process_lfp should receive it.
bad_cases = tuple(range(-1, 3)) if pad_spike_window else tuple(range(3))

bad_indices = {bad: [] for bad in bad_cases}
lfp_list = []
ycenter = []
summ_stats = []

for i in tqdm(range(lfp.shape[0])):
    bad, g_lfp, _, _, y_c, _, ss = process_lfp(lfp[i], coord=elec_pos, dt=None, calc_summ_stats=save_stats, err_msg=True)
    # setdefault: record an unexpected status code instead of failing with KeyError.
    bad_indices.setdefault(bad, []).append(i)
    if bad <= 0:
        lfp_list.append(g_lfp)
        ycenter.append(y_c)
        if save_stats:
            summ_stats.append(ss)

# Explicit integer dtype keeps these valid as index arrays even when empty
# (np.array([]) would otherwise be float64).
bad_idx = np.array([i for bad, indices in bad_indices.items() if bad > 0 for i in indices], dtype=int)
good_indices = np.sort(np.array([i for bad, indices in bad_indices.items() if bad <= 0 for i in indices], dtype=int))
print('%d good samples out of %d samples.' % (good_indices.size, lfp.shape[0]))
for bad, indices in sorted(bad_indices.items()):
    print('Bad case %d bad: %d samples.' % (bad, len(indices)))

# np.stack cannot stack an empty list; fail with a message that says why.
if not lfp_list:
    raise ValueError('No sample passed process_lfp; nothing to stack or save.')

t = params.DT * np.arange(params.WINDOW_SIZE)  # params.WINDOW_SIZE time points spaced params.DT apart
windowed_lfp = np.stack(lfp_list, axis=0)  # (samples x time window x channels)
ycenter = np.array(ycenter)
summ_stats = np.array(summ_stats)

100%|████████████████████████████████████████| 198/198 [00:03<00:00, 52.93it/s]

198 good samples out of 198 samples.
Bad case 0 bad: 198 samples.
Bad case 1 bad: 0 samples.
Bad case 2 bad: 0 samples.





In [4]:
# Persist results; np.savez appends '.npz' to LFP_PATH / STATS_PATH.
if save_lfp:
    # `x` and `yc` contain only accepted samples, while `layer` has one entry per
    # input sample: index it with `good_indices` to align it with `x`/`yc`.
    # `bad_indices` is a dict, so np.savez stores it as a pickled 0-d object array;
    # read it back with np.load(..., allow_pickle=True)['bad_indices'].item().
    lfp_payload = dict(
        t=t,
        x=windowed_lfp,
        yc=ycenter,
        layer=layer,
        bad_indices=bad_indices,
        good_indices=good_indices,
    )
    np.savez(LFP_PATH, **lfp_payload)

if save_stats:
    # Here `layer` is already restricted to the accepted samples.
    stats_payload = dict(x=summ_stats, yc=ycenter, layer=layer[good_indices])
    np.savez(STATS_PATH, **stats_payload)