In [2]:
import tensorflow as tf
import os
from machine_learning.neural_networks import tf_helpers as tfh
import pdb

from pynwb import NWBHDF5IO
import numpy as np
import os
import torch
import soundfile as sf
import scipy.stats
from process_nwb.resample import resample as resample_nwb
from scipy.stats.mstats import zscore
import samplerate

# Creating a list of dictionaries of:
# 
# `ecog_sequence`: ECoG data, clipped to token(-sequence) length
# `text_sequence`: the corresponding text token(-sequence)
# `audio_sequence`: the corresponding audio (MFCC) token sequence (left empty for now)
# `phoneme_sequence`: ditto for phonemes--with repeats
#
# Then saving them as tf_records

def transcription_to_array(trial_t0, trial_tF, onset_times, offset_times, transcription, max_length, sampling_rate):
    '''
    Expand a timed transcription into one label per sample of a trial.

    Input arguments:
    --------
    trial_t0, trial_tF:
        start and stop time of the trial, in the same units as the
        onset/offset times
    onset_times, offset_times:
        arrays holding one onset and one offset time per transcribed token
    transcription:
        an object whose `description` attribute is the space-separated
        string of tokens, or None if there is no transcription
    max_length:
        the number of samples (labels) to produce
    sampling_rate:
        output samples per unit of time

    Returns:
    --------
    A string ndarray of length max_length.  Samples at which no token is
    active are labelled 'pau'.

    Raises:
    --------
    ValueError:
        if more than one token is active at the same sample
    '''

    # if the transcription is missing (e.g. for covert trials)
    if transcription is None:
        return np.full(max_length, 'pau', dtype='<U5')

    # keep only the tokens that start and end inside this trial
    trial_inds = (onset_times >= trial_t0) & (offset_times < trial_tF)
    transcript = np.array(transcription.description.split(' '))[trial_inds]
    onset_times = np.array(onset_times[trial_inds])
    offset_times = np.array(offset_times[trial_inds])

    # (Ntokens x max_length) mask: is token k active at sample n?
    sample_times = trial_t0 + np.arange(max_length)/sampling_rate
    active = (
        (sample_times[None, :] >= onset_times[:, None]) &
        (sample_times[None, :] < offset_times[:, None])
    )

    # no more than one token should be on at once...
    if np.any(active.sum(axis=0) > 1):
        raise ValueError(
            'overlapping transcription intervals between '
            f'{trial_t0} and {trial_tF}: more than one token is active '
            'at the same sample'
        )

    # ...but there can be locations with *zero* tokens; assume 'pau' here.
    # Concatenate rather than np.insert: insert casts 'pau' to the dtype of
    # transcript, which truncates it whenever all tokens are shorter than
    # three characters.
    labels = np.concatenate((['pau'], transcript))

    # Row k of the mask maps to labels[k+1]; samples with no active token sum
    # to 0 and therefore pick up 'pau'.
    label_index = (active*np.arange(1, len(labels))[:, None]).sum(axis=0)

    return labels[label_index]

def sentence_tokenize(token_list): # token_type = word_sequence
    '''
    Turn a list of word strings into byte tokens: each word is lower-cased,
    given a trailing underscore, and UTF-8 encoded.
    '''
    encoded_tokens = []
    for word in token_list:
        marked_word = word.lower() + '_'
        encoded_tokens.append(marked_word.encode('utf-8'))
    return encoded_tokens

def write_to_Protobuf(path, example_dicts):
    '''
    Serialize each example dictionary and write it to disk as a (google)
        protocol buffer, i.e. one record per example in a TFRecord file.

    Input arguments:
    --------
    path:
        where to write the TFRecord file
    example_dicts:
        an iterable of dictionaries, each converted with
        tfh.make_feature_example before being written
    '''
    # The context manager closes (and thereby flushes) the writer even if
    # converting or serializing one of the examples raises.
    with tf.io.TFRecordWriter(path) as writer:
        for example_dict in example_dicts:
            feature_example = tfh.make_feature_example(example_dict)
            writer.write(feature_example.SerializeToString())
            
# sorting function for latent representation filenames, NOT USED FOR THIS
# def custom_sort_key(filename):
#     num_part = int(filename.split('nwb_')[1].split('.wav.pt')[0])
#     return num_part

def resample(
    data, source_to_target_ratio, ZSCORE, resample_method='sinc_best',
    N_channels_max=128
):
    '''
    Resample data (indexed as samples x channels) with samplerate.Resampler,
    handling the channels in groups of at most N_channels_max.

    Input arguments:
    --------
    data:
        the array to resample; channels are taken from axis 1
    source_to_target_ratio:
        source rate divided by target rate (its reciprocal is passed to the
        resampler)
    ZSCORE:
        if truthy, z-score the resampled data before returning
    resample_method:
        converter type handed to samplerate.Resampler
    N_channels_max:
        largest number of channels processed at once (capped at 128)

    Returns:
    --------
    The resampled channel groups joined along axis 1, or None (before any
    z-scoring) when data has no channels.
    '''

    ######################
    # If downsampling by an integer, just anti-alias and subsample??
    ######################

    # 128 is the max for the underlying library
    group_width = min(N_channels_max, 128)
    total_channels = data.shape[1]

    # resample one group of channels at a time
    resampled_groups = []
    for first in np.arange(0, total_channels, group_width):
        last = np.min((first+group_width, total_channels))
        converter = samplerate.Resampler(resample_method, channels=last-first)
        resampled_groups.append(converter.process(
            data[:, first:last], 1/source_to_target_ratio, end_of_input=True
        ))

    # stitch the groups back together along the channel axis
    if not resampled_groups:
        data_mat = None
    elif len(resampled_groups) == 1:
        data_mat = resampled_groups[0]
    else:
        data_mat = np.concatenate(resampled_groups, axis=1)

    return zscore(data_mat) if ZSCORE else data_mat

def downsample(data, rate_source, rate_target, ZSCORE=False):
    '''
    Downsample data from rate_source to rate_target by delegating to
    downsample_NWB; ZSCORE is passed straight through.
    '''
    # alternative implementation:
    # return resample(data, rate_source/rate_target, ZSCORE=ZSCORE)
    downsampled = downsample_NWB(
        data, rate_source, rate_target, ZSCORE=ZSCORE
    )
    return downsampled

def downsample_NWB(data, rate_source, rate_target, ZSCORE=False):
    '''
    Resample data from rate_source to rate_target, one channel at a time,
    using process_nwb's resample.

    Input arguments:
    --------
    data:
        an ndarray of the data to downsample (Nsamples_source, Nchannels)
    rate_source:
        the sampling rate of the input data
    rate_target:
        the sampling rate of the output data
    ZSCORE:
        if truthy, z-score the result with scipy.stats.mstats.zscore

    Returns:
    --------
    X:
        An ndarray (Nsamples_target, Nchannels) of the downsampled data,
        with Nsamples_target = ceil(Nsamples_source*rate_target/rate_source)
    '''

    # Note: zero padding is done in resample

    print("Downsampling signals to %s Hz; please wait..." % rate_target)

    ##############
    # 1e6 scaling helps with numerical accuracy
    # scale = 1e6
    scale = 1
    # Is this true??
    ##############

    n_source, n_channels = data.shape
    n_target = int(np.ceil(n_source*rate_target/rate_source))

    # malloc
    resampled = np.zeros((n_target, n_channels))

    # One channel at a time, to improve memory usage for long signals
    for column in range(n_channels):
        resampled[:, column] = resample_nwb(
            data[:, column]*scale, rate_target, rate_source
        )
    resampled = resampled/scale

    if not ZSCORE:
        return resampled
    return scipy.stats.mstats.zscore(resampled)

# Removing bad electrodes

def elec_layout(grid_size, grid_step):
    '''
    Return the 2-D array of electrode indices for a grid.

    The indices prod(grid_size)-1, ..., 0 are laid out row by row in an
    array of shape grid_size, which is then transposed.  Only every
    grid_step-th row and column of the result is kept.
    '''
    n_electrodes = np.prod(grid_size)
    descending = np.arange(n_electrodes-1, -1, -1)
    full_grid = descending.reshape(grid_size).T

    # now correct for subsampling the grid
    stride = slice(None, None, grid_step)
    return full_grid[stride, stride]
    
def good_electrodes(grid_size, bad_electrodes):
    '''
    Return the set of electrodes on the grid that are not marked bad.

    NB!!! bad_electrodes are 1-indexed, good_electrodes are zero-indexed!!

    The result is a set and so carries no order information; the canonical
    ordering is established with good_channels, because the data are sized
    (... x Nchannels) rather than (... x Nelectrodes).
    '''

    # bad_electrodes = [int(e.strip()) for e in bad_electrodes]
    every_electrode = set(range(np.prod(grid_size)))

    # shift the 1-indexed bad electrodes down to zero-indexing
    bad_zero_indexed = set(np.array(bad_electrodes)-1)

    return every_electrode.difference(bad_zero_indexed)
    
def bipolar_to_elec_map(layout):
    '''
    List the pairs of neighbouring electrodes on a 2-D layout.

    Walking the layout row by row, each electrode is paired first with its
    right-hand neighbour and then with the neighbour below it, whenever that
    neighbour exists.  Returns the pairs as an ndarray, one pair per row.
    '''
    # print('WARNING!!!!  MAKING UP bipolar_to_elec_map!!!')
    n_rows, n_cols = layout.shape[0], layout.shape[1]
    last_row, last_col = n_rows - 1, n_cols - 1

    pairs = []
    for row, col in np.ndindex(n_rows, n_cols):
        here = layout[row, col]
        if col < last_col:
            pairs.append((here, layout[row, col+1]))
        if row < last_row:
            pairs.append((here, layout[row+1, col]))
    return np.array(pairs)
    
def good_channels(elec_layout, bipolar_to_elec_map, good_electrodes):
    '''
    Pseudo-channels, constructed (on the fly) from the physical electrodes.
    For now at least, we won't USE_FIELD_POTENTIALS if we want to
    REFERENCE_BIPOLAR.

    NB!!: The *order* of these channels matters--it determines the order of
    the input data.  The channels are returned in the order of the rows of
    bipolar_to_elec_map.

    Input arguments:
    --------
    elec_layout:
        not used by the bipolar branch implemented here; kept so that
        existing callers do not have to change
    bipolar_to_elec_map:
        a sequence of electrode pairs, one pair per bipolar channel
    good_electrodes:
        a container of the (zero-indexed) usable electrodes

    Returns:
    --------
    The list of indices into bipolar_to_elec_map whose electrodes are all
    in good_electrodes.
    '''

    # The other referencing schemes, for reference (they would iterate over
    # all_electrodes = elec_layout.flatten().tolist(), which is *not* in
    # numerical order):
    # if self.USE_FIELD_POTENTIALS:
    #     M = len(all_electrodes)
    #     return (
    #         [e for e in all_electrodes if e in self.good_electrodes] +
    #         [e+M for e in all_electrodes if e in self.good_electrodes]
    #     )
    # elif self.REFERENCE_BIPOLAR:
    return [
        ch for ch, elec_pair in enumerate(bipolar_to_elec_map)
        if all(e in good_electrodes for e in elec_pair)
    ]
    # else:
    #     return [e for e in all_electrodes if e in self.good_electrodes]

2023-12-16 11:12:10.585811: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-16 11:12:10.585840: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-16 11:12:10.586854: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-16 11:12:10.591675: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.




In [4]:
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter, filtfilt, hilbert
from scipy.fft import fft, ifft

# Build one TFRecord file per recording block: band-pass the raw ECoG to
# 70-200 Hz, take the analytic amplitude, downsample it, cut it into trials,
# and pair every trial with its text tokens and per-sample phoneme labels.
all_example_dict = []
patient = 'EFC400'
blocks = [3, 23, 72]
        # [4, 41, 57, 61, 66, 69, 73, 77, 83, 87] # [3,4,6,8,10,12,14,15,19,23,28,30,38,40,42,46,57,61,72] # change this for what tf_record you're making
_bad_electrodes = [1, 2, 33, 50, 54, 64, 128, 129, 193, 194, 256]
    #[1,2,63,64,65,127,143,193,194,195,196,235,239,243,252,254,255,256]
# bad_electrodes = [i - 1 for i in bad_electrodes]
# good_electrodes = [x for x in list(np.arange(256)) if x not in bad_electrodes]

grid_size = np.array([16, 16])
grid_step = 1

_good_electrodes = good_electrodes(grid_size, _bad_electrodes)
_elec_layout = elec_layout(grid_size, grid_step)
_bipolar_to_elec_map = bipolar_to_elec_map(_elec_layout)

# print(_bad_electrodes)
# print(_good_electrodes)

_good_channels = good_channels(_elec_layout, _bipolar_to_elec_map, _good_electrodes)

# print(len(_good_channels))

# Settings that do not depend on the block, so they are defined once.
token_type = 'word_sequence'

max_seconds_dict = {
    'phoneme': 0.2,
    'word': 1.0,
    'word_sequence': 6.25,
    'word_piece_sequence': 6.25,
    'phoneme_sequence': 6.25,
    'trial': 6.25
}

makin_sr = 101.7 # 200

# longest phoneme labelling (in samples) produced for one trial
max_seconds = max_seconds_dict.get(token_type) # , 0.2) # i don't think this 0.2 default is necessary for the scope of this
max_samples = int(np.floor(makin_sr*max_seconds))

for block in blocks:

    tfrecord_path = f'/home/bayuan/Documents/fall23/ecog2vec/wav2vec_tfrecords/ecog2txt/word_sequence/tf_records_orig_400/{patient}_B{block}.tfrecord'

    nwb_filepath = folder_path = f"/NWB/{patient}/{patient}_B{block}.nwb"

    example_dicts = []

    # Everything read from the NWB file is pulled into memory inside this
    # `with`, so the HDF5 handle is closed before the next block is opened
    # (and also if processing this block raises).
    with NWBHDF5IO(nwb_filepath, load_namespaces=True, mode='r') as io:
        nwbfile = io.read()

        electrode_table = nwbfile.acquisition['ElectricalSeries'].\
                                            electrodes.table[:]

        # columns of the series that belong to the left or right 256 grid
        indices = np.where(np.logical_or(
            electrode_table['group_name'] == 'L256GridElectrode electrodes',
            electrode_table['group_name'] == 'R256GridElectrode electrodes'
        ))[0]

        raw_data = nwbfile.acquisition['ElectricalSeries'].\
                                        data[:,indices]

        # keep the good electrodes only, in ascending electrode order
        raw_data = raw_data[:,sorted(_good_electrodes)]

        # high_gamma = downsample(high_gamma, 400, 200, ZSCORE=True)

        nwb_sr = nwbfile.acquisition['ElectricalSeries'].\
                                    rate

        # 5th-order Butterworth band-pass, 70-200 Hz
        w_l = 70 / (nwb_sr / 2) # Normalize the frequency
        w_h = 200 / (nwb_sr / 2)
        b, a = butter(5, [w_l,w_h], 'band')

        # NOTE(review): the results are written back into raw_data, so they
        # are cast to raw_data's dtype -- confirm it is floating point.
        for ch in range(raw_data.shape[1]):
            raw_data[:,ch] = filtfilt(b,
                                    a,
                                    raw_data[:,ch])

            #analytic amp
            raw_data[:,ch] = np.abs(hilbert(raw_data[:,ch]))

        high_gamma = raw_data

        phoneme_transcriptions = nwbfile.processing['behavior'].data_interfaces['BehavioralEpochs'].interval_series #['phoneme transcription'].timestamps[:]

        if 'phoneme transcription' in phoneme_transcriptions:
            print(f'Phoneme transcription for block {block} exists.')
            phoneme_transcript = phoneme_transcriptions['phoneme transcription']
            # data == 1 marks an onset timestamp, data == -1 an offset
            phoneme_onset_times = phoneme_transcript.timestamps[
                phoneme_transcript.data[()] == 1]
            phoneme_offset_times = phoneme_transcript.timestamps[
                phoneme_transcript.data[()] == -1]
        else:
            phoneme_transcript = None
            phoneme_onset_times = None
            phoneme_offset_times = None

        high_gamma = downsample(high_gamma, nwb_sr, makin_sr, ZSCORE=True)

        # starts = list(nwbfile.trials[:]['start_time']) # * nwb_sr)
        # stops = list(nwbfile.trials[:]['stop_time']) # * nwb_sr)

        # print(starts[0], stops[0])

        for index, trial in enumerate(nwbfile.trials):
            t0 = float(trial.iloc[0].start_time)
            tF = float(trial.iloc[0].stop_time)

            # trial boundaries in samples of the downsampled signal
            i0 = np.rint(makin_sr*t0).astype(int)
            iF = np.rint(makin_sr*tF).astype(int)

            # ECOG (C) SEQUENCE
            c = high_gamma[i0:iF,:]
            # print(c.shape)
            # plt.plot(c[:,0])
            # break

            print(c.shape)
            nsamples = c.shape[0]

            # TEXT SEQUENCE
            speech_string = trial['transcription'].values[0]
            text_sequence = sentence_tokenize(speech_string.split(' ')) # , 'text_sequence')

            # AUDIO SEQUENCE
            audio_sequence = []

            # PHONEME SEQUENCE

            M = iF - i0

            max_length = min(M, max_samples)

            phoneme_array = transcription_to_array(
                            t0, tF, phoneme_onset_times, phoneme_offset_times,
                            phoneme_transcript, max_length, makin_sr
                        )

            phoneme_sequence = [ph.encode('utf-8') for ph in phoneme_array]

            # Make the labels exactly as long as the ECoG slice: drop any
            # excess, or repeat the last label to fill the gap.  (c itself
            # is not clipped to max_samples; only the labelling is.)
            n_missing = nsamples - len(phoneme_sequence)
            if n_missing < 0:
                phoneme_sequence = phoneme_sequence[:nsamples]
            elif n_missing > 0:
                phoneme_sequence.extend([phoneme_sequence[-1]]*n_missing)

            print('\n------------------------')
            print(f'For sentence {index}: ')
            print(c[0:5,0:5])
            print(f'Latent representation shape: {c.shape} (should be [samples, nchannel])')
            print(text_sequence)
            print(f'Audio sequence: {audio_sequence}')
            print(f'Length of phoneme sequence: {len(phoneme_sequence)}')
            print(phoneme_sequence)
            print('------------------------\n')

            example_dicts.append({'ecog_sequence': c, 'text_sequence': text_sequence, 'audio_sequence': [], 'phoneme_sequence': phoneme_sequence,})


            # break

        # break

    all_example_dict.extend(example_dicts)
    print(len(example_dicts))
    print(len(all_example_dict))
    write_to_Protobuf(tfrecord_path, example_dicts)

print(len(all_example_dict))


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Phoneme transcription for block 3 exists.
Downsampling signals to 101.7 Hz; please wait...
(138, 245)

------------------------
For sentence 0: 
[[ 0.07585072  0.31488337  0.23808057  0.30820591  0.98758648]
 [-1.64120539 -1.84417401 -1.73916661 -1.36783566  1.9998449 ]
 [-0.41667863  0.0694105   0.21159387 -0.48622319  0.36795554]
 [-1.94118787 -1.91691413 -1.6200376  -0.41769632  0.96471845]
 [ 0.68023269  1.05807324  1.71942254  2.48887566 -0.6908969 ]]
Latent representation shape: (138, 245) (should be [samples, nchannel])
[b'this_', b'was_', b'easy_', b'for_', b'us_']
Audio sequence: []
Length of phoneme sequence: 138
[b'dh', b'dh', b'dh', b'dh', b'ih', b'ih', b'ih', b'ih', b'ih', b'ih', b'ih', b's', b's', b's', b's', b's', b's', b's', b's', b's', b's', b's', b's', b'w', b'aa', b'aa', b'z', b'z', b'z', b'z', b'z', b'z', b'z', b'z', b'z', b'z', b'z', b'z', b'z', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'z', b'z', b'z', b'z',

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Phoneme transcription for block 23 exists.
Downsampling signals to 101.7 Hz; please wait...
(292, 245)

------------------------
For sentence 0: 
[[ 0.40931971  0.44673779 -0.28209702 -1.29939047 -0.83728231]
 [ 3.01114877  3.12470635  2.5194325   1.31126521  0.22670828]
 [ 0.28733601  1.35165244  1.15476848  0.96706364  1.09840194]
 [-0.01991355 -0.03897973 -0.16313629 -0.1249096   0.80687267]
 [-0.49438388 -0.15210176 -0.17361471 -1.79249963  0.72068759]]
Latent representation shape: (292, 245) (should be [samples, nchannel])
[b'young_', b'people_', b'participate_', b'in_', b'athletic_', b'activities_']
Audio sequence: []
Length of phoneme sequence: 292
[b'y', b'y', b'y', b'y', b'y', b'ah', b'ah', b'ah', b'ah', b'ah', b'ah', b'ah', b'ah', b'ah', b'ah', b'ah', b'ng', b'ng', b'ng', b'ng', b'ng', b'ng', b'ng', b'ng', b'p', b'p', b'p', b'p', b'p', b'p', b'p', b'p', b'p', b'p', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'iy', b'p', b'p', b'p', b'p', b'ax', b'ax', b'l', b'l',

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Phoneme transcription for block 72 exists.
Downsampling signals to 101.7 Hz; please wait...
(259, 245)

------------------------
For sentence 0: 
[[-0.5455242   1.27007407  1.0938933   1.24677523  2.71933477]
 [ 0.54938078 -0.76484219 -1.08386255 -1.105022    1.46344822]
 [ 0.44974788 -0.10119321  0.15946003  0.05690019  0.01840509]
 [-0.42780259 -0.0127936  -0.16869403 -0.56820779 -0.3148242 ]
 [ 0.72466244  0.77090613  0.98307853  0.93807207  0.46212541]]
Latent representation shape: (259, 245) (should be [samples, nchannel])
[b'beg_', b'that_', b'guard_', b'for_', b'one_', b'gallon_', b'of_', b'petrol_']
Audio sequence: []
Length of phoneme sequence: 259
[b'b', b'b', b'b', b'eh', b'eh', b'eh', b'eh', b'eh', b'eh', b'eh', b'eh', b'eh', b'eh', b'eh', b'g', b'g', b'g', b'g', b'g', b'g', b'g', b'g', b'g', b'g', b'g', b'dh', b'dh', b'dh', b'dh', b'dh', b'dh', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b'ae', b't', b't', b'g', b'g', b