<a href="https://colab.research.google.com/github/neuralsrg/EEG/blob/main/nn/Conv%26RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np 
import tensorflow as tf 
import matplotlib as plt
import pandas as pd

# Prefer an already-installed `mne`; otherwise fall back to a copy kept on
# Google Drive (Colab-only path).
try:
  import mne
except Exception:
  # NOTE(review): catching `Exception` (rather than ImportError) also hides
  # failures raised while mne itself is being imported - confirm this is intended.
  import sys

  # Mounting Drive is interactive and only works inside Google Colab.
  from google.colab import drive
  drive.mount('/content/drive')

  sys.path.append('/content/drive/MyDrive/lib') # directory expected to contain the mne package
  import mne

from typing import Optional, List, Set, Tuple

In [70]:
ORIGIN_URL = '/content/drive/MyDrive/EEG_data/relabeled_unpacked/Bashirin_phonemes_.EDF'

class WindowGenerator():
  """
  Turns one labelled EDF recording into tf.data datasets.

  The last channel of the recording is treated as the label channel; all other
  channels are scaled to [-1, 1] per channel. `create_dataset` cuts fixed-length
  windows for the "listen", "repeat" and "noise" classes (labels 1, 2 and 0) and
  exposes batched datasets through the `train`, `val` and `test` properties.
  """

  # Placement of the "noise" windows: the first one starts at a fixed sample,
  # the others start a fixed number of samples after each "repeat" window
  # (the last repeat window is not used).
  _NOISE_FIRST_START = 1000
  _NOISE_OFFSET_FROM_REPEAT = 10000

  _SPLIT_NAMES = ('train', 'val', 'test')

  def __init__(self,
               origin_url : Optional[str] = ORIGIN_URL) -> None:
    """
    Reads the recording and normalizes every EEG channel.

    Args:
    origin_url (optional) -- path of the EDF file. None falls back to ORIGIN_URL.
    """
    if origin_url is None:
      origin_url = ORIGIN_URL

    # Read the file that was actually requested (not always the module default).
    self._mne_data = mne.io.read_raw_edf(origin_url)
    self._raw_data = self._mne_data.get_data()
    self._labeled_data = mne.io.RawArray(self._raw_data,
                                         mne.create_info(self._mne_data.ch_names,
                                                         1006.04, ch_types='eeg'))
    # The last channel carries the event labels, the rest is signal.
    self._label_channel = self._raw_data[-1]
    self._raw_data = self._raw_data[:-1]

    # Per-channel max-abs scaling; shape (channels, 1) so it broadcasts over time.
    self._normalize = np.max(np.abs(self._raw_data), axis=-1)[..., np.newaxis]
    # A completely flat channel would otherwise turn into NaNs (0 / 0).
    self._normalize[self._normalize == 0] = 1.0
    self._raw_data = self._raw_data / self._normalize
    print('\n\nAll the data was normalized. Refer to normalization coefficient as WindowGenerator.normalize')

    self._df = pd.DataFrame(self._raw_data.T, columns=self._mne_data.ch_names[:-1])


  @property
  def normalize(self) -> np.ndarray:
    """ Per-channel normalization coefficients, shape (channels, 1) """
    return self._normalize


  def eeg_info(self) -> pd.DataFrame:
    """ Prints the MNE info of the recording and returns per-channel statistics """
    print(self._mne_data.info)
    return self._df.describe().transpose().head(68)


  def _generate_indices(self) -> None:
    """
    Finds the sample indices of paired event labels.

    Two consecutive non-zero labels form a pair when the first one equals the
    second one plus 10. The indices of both members of every pair are stored,
    sorted and without duplicates, in self._non_zero_label_inds.
    """
    nonzero = np.flatnonzero(self._label_channel > 0)
    labels = self._label_channel[nonzero]

    # Compare every non-zero label with the next non-zero label in one shot.
    is_pair = labels[:-1] == labels[1:] + 10

    self._non_zero_label_inds = np.union1d(nonzero[:-1][is_pair],
                                           nonzero[1:][is_pair])


  def _generate_sequences(self,
                          skip_in_repeat : Optional[int] = 100,
                          verbose : Optional[bool] = True) -> None:
    """
    Cuts one window of self._length samples per event and per class.

    Args:
    skip_in_repeat (optional) -- samples to skip after the repeat label
    verbose (optional) -- unused here, kept for interface compatibility

    Results (each of shape (n_events, channels, self._length)) are stored in
    self._listen, self._repeat and self._noise; the label value found at every
    repeat marker is stored in self._labels.

    Raises:
    ValueError if the label indices do not form listen/repeat pairs.
    """
    paired = np.asarray(self._non_zero_label_inds, dtype=int)
    if paired.shape[0] == 0 or paired.shape[0] % 2 != 0:
      raise ValueError('Expected a non-empty, even number of label indices '
                       f'(listen/repeat pairs), got {paired.shape[0]}')

    length = self._length
    listen_starts = paired[::2]
    repeat_markers = paired[1::2]
    repeat_starts = repeat_markers + skip_in_repeat
    noise_starts = np.concatenate((
        [self._NOISE_FIRST_START],
        repeat_starts[:-1] + self._NOISE_OFFSET_FROM_REPEAT)).astype(int)

    self._labels = self._label_channel[repeat_markers]
    n_events = self._labels.shape[0]
    num_electrodes = self._raw_data.shape[0]
    offsets = np.arange(length)

    def gather(starts : np.ndarray) -> np.ndarray:
      # Row k holds the `length` consecutive sample indices of event k, so each
      # output row is one contiguous window (also correct if windows overlap).
      sample_idx = starts[:, np.newaxis] + offsets
      windows = np.empty((n_events, num_electrodes, length))
      # Filled channel by channel into a preallocated array to keep peak RAM low.
      for channel in range(num_electrodes):
        windows[:, channel, :] = self._raw_data[channel][sample_idx]
      return windows

    self._listen = gather(listen_starts)
    self._repeat = gather(repeat_starts)
    self._noise = gather(noise_starts)


  def _split_windows(self) -> None:
    """
    Splits every event window into smaller windows.

    Uses self._window_size, self._shift and self._stride; replaces
    self._listen, self._repeat and self._noise with arrays of shape
    (events * windows_per_event, channels, window_size).
    """
    def window_ds(array : np.array) -> np.array: # shape (batch, channels, features)
      array = array.T # shape (features, channels, batch)
      ds = tf.data.Dataset.from_tensor_slices(array)
      win = ds.window(self._window_size,  # (num_windows, window_size, channels, batch)
                      shift=self._shift,
                      stride=self._stride,
                      drop_remainder=True)

      # Each window is itself a dataset; batching turns it back into one tensor.
      flatten = lambda x: x.batch(self._window_size, drop_remainder=True)
      win = win.flat_map(flatten)

      array = np.array(list(win.as_numpy_iterator()))
      array = np.moveaxis(array, 0, -1) # (window_size, channels, batch, num_windows)
      shape = (array.shape[0], array.shape[1], array.shape[2] * array.shape[3])
      array = array.reshape(shape) # (window_size, channels, batch * num_windows)

      return array.T # (batch * num_windows, channels, window_size)

    self._listen = window_ds(self._listen)
    self._repeat = window_ds(self._repeat)
    self._noise = window_ds(self._noise)


  def _verb(self, listen_repeat_noise : Optional[Tuple[bool, bool, bool]]) -> None:
    """ Prints the shapes of the selected listen / repeat / noise arrays """
    msg = np.array([
        f'listen.shape : {self._listen.shape}',
        f'repeat.shape : {self._repeat.shape}',
        f'noise.shape : {self._noise.shape}'
    ])
    # Explicit boolean mask: a plain tuple would be read as a multi-dim index.
    mask = np.asarray(listen_repeat_noise, dtype=bool)
    print('\n'.join(msg[mask]))


  def create_dataset(self,
                     event_length : Optional[int] = 300,
                     skip_in_repeat : Optional[int] = 100,
                     listen_repeat_noise : Optional[Tuple[bool, bool, bool]] = [True, False, True],
                     train_val_test : Optional[Tuple[float, float, float]] = [.8, .2, 0],
                     batch_size : Optional[int] = 32,
                     split_windows : Optional[bool] = False,
                     window_size : Optional[int] = 128,
                     shift : Optional[int] = None,
                     stride : Optional[int] = 1,
                     verbose : Optional[bool] = True) -> None:
    """
    Creates datasets

    Args:
    event_length (optional) -- int. Length of the event. 300ms = 300 samples
    skip_in_repeat (optional) -- samples to skip after repeat label. 100ms = 100 samples
    listen_repeat_noise (optional) -- three bools. Specifies which data to include in the dataset.
      The first bool refers to listen, the second to repeat, the third to noise.
      At least one of them must be True.
    train_val_test (optional) -- up to three non-negative floats summing to 1.
      Specifies train/val/test ratios respectively. The last non-zero ratio
      receives whatever is left after rounding the other ones down.
    batch_size (optional) -- batch size
    split_windows (optional) -- bool. Whether to break neural data sequences down into smaller ones
    window_size (optional) -- int. Used if split_windows == True. The length of the smaller windows
    shift (optional) -- int. Used if split_windows == True. Hop length used when creating smaller windows.
      If None, then smaller windows do not overlap.
    stride (optional) -- int. Used if split_windows == True. Determines the stride between input elements within a window.
      Not recommended to change!
    verbose (optional) -- bool. Whether to print logging messages

    Returns:
    None.
    Generated datasets are stored in WindowGenerator.train, WindowGenerator.val and WindowGenerator.test.
    A split whose ratio is 0 is an empty dataset.

    Raises:
    AssertionError if train_val_test is malformed.
    ValueError if listen_repeat_noise is malformed or selects nothing.
    """
    # The list defaults above are never mutated; they are copied into arrays here.
    ratios = np.asarray(train_val_test, dtype=float).ravel()
    assert 1 <= ratios.shape[0] <= 3, 'train_val_test must hold between one and three ratios'
    assert np.all(ratios >= 0), 'train_val_test ratios must be non-negative'
    # Tolerant comparison: e.g. .7 + .2 + .1 is not exactly 1 in floating point.
    assert np.isclose(ratios.sum(), 1), 'train_val_test ratios must sum to 1'
    ratios = np.pad(ratios, (0, 3 - ratios.shape[0]))

    include = np.asarray(listen_repeat_noise, dtype=bool)
    if include.shape != (3,) or not include.any():
      raise ValueError('listen_repeat_noise must be three bools with at least one True')

    self._length = event_length
    self._window_size = window_size
    self._shift = shift # The shift argument determines the number of input elements to shift between the start of each window
    self._stride = stride # The stride argument determines the stride between input elements within a window

    self._generate_indices()
    self._generate_sequences(skip_in_repeat, verbose)

    if split_windows:
      self._split_windows()

    if verbose:
      self._verb(include)

    # Class labels: noise -> 0, listen -> 1, repeat -> 2.
    self._listen_ds = tf.data.Dataset.from_tensor_slices(self._listen).map(
        lambda x: (x, 1), num_parallel_calls=tf.data.AUTOTUNE)
    self._repeat_ds = tf.data.Dataset.from_tensor_slices(self._repeat).map(
        lambda x: (x, 2), num_parallel_calls=tf.data.AUTOTUNE)
    self._noise_ds = tf.data.Dataset.from_tensor_slices(self._noise).map(
        lambda x: (x, 0), num_parallel_calls=tf.data.AUTOTUNE)

    candidates = (self._listen_ds, self._repeat_ds, self._noise_ds)
    selected = [ds for ds, keep in zip(candidates, include) if keep]

    dataset = selected[0]
    for ds in selected[1:]:
      dataset = dataset.concatenate(ds)

    # listen, repeat and noise all hold the same number of examples.
    num_examples = int(np.count_nonzero(include)) * int(self._listen.shape[0])
    print('Total number of elements in the dataset:', num_examples)

    # The order must stay fixed across iterations: with the default reshuffling,
    # take()/skip() below would carve every split out of a different permutation
    # and the same example could land in train and in val/test.
    dataset = dataset.shuffle(num_examples, reshuffle_each_iteration=False)

    sizes = [int(num_examples * ratio) for ratio in ratios]
    last_used = int(np.flatnonzero(ratios)[-1])
    sizes[last_used] = num_examples - sum(sizes[:last_used])

    splits = []
    remaining = dataset
    for size in sizes:
      splits.append(remaining.take(size))
      remaining = remaining.skip(size)

    used = [i for i in range(len(self._SPLIT_NAMES)) if ratios[i] > 0]
    if verbose:
      msg = [f'{self._SPLIT_NAMES[i]} dataset contains {sizes[i]} elements' for i in used]
      print('\n'.join(msg))

    print('\nRefer to datasets as:')
    msg = [f'\t WindowGenerator().{self._SPLIT_NAMES[i]}' for i in used]
    print('\n'.join(msg))

    # All three attributes are always (re)assigned, so no stale split survives
    # a second call and the properties never hit a missing attribute afterwards.
    self._train, self._val, self._test = (
        split.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)
        for split in splits)


  @property
  def train(self) -> "tf.data.Dataset":
    """ Batched training split (available after create_dataset) """
    return self._train


  @property
  def val(self) -> "tf.data.Dataset":
    """ Batched validation split (available after create_dataset) """
    return self._val


  @property
  def test(self) -> "tf.data.Dataset":
    """ Batched test split (available after create_dataset) """
    return self._test

In [66]:
# Load and normalize the default recording (ORIGIN_URL).
generator = WindowGenerator()

Extracting EDF parameters from /content/drive/MyDrive/EEG_data/relabeled_unpacked/Bashirin_phonemes_.EDF...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Creating RawArray with float64 data, n_channels=69, n_times=660636
    Range : 0 ... 660635 =      0.000 ...   656.669 secs
Ready.


All the data was normalized. Refer to normalization coefficient as WindowGenerator.normalize


In [67]:
# Print the MNE info and display per-channel summary statistics of the normalized data.
generator.eeg_info()

<Info | 7 non-empty values
 bads: []
 ch_names: EEG O1-M1, EEG O2-M2, EEG P3-M1, EEG P4-M2, EEG C3-M1, EEG ...
 chs: 69 EEG
 custom_ref_applied: False
 highpass: 0.5 Hz
 lowpass: 503.0 Hz
 meas_date: 2020-02-15 16:49:00 UTC
 nchan: 69
 projs: []
 sfreq: 1006.0 Hz
>


Unnamed: 0,count,mean,std,min,25%,50%,75%,max
EEG O1-M1,660636.0,-7.249706e-05,0.139873,-1.000000,-0.082073,0.002160,0.088553,0.831533
EEG O2-M2,660636.0,4.125636e-06,0.148530,-0.948370,-0.089674,0.005435,0.095109,1.000000
EEG P3-M1,660636.0,5.010845e-07,0.083417,-1.000000,-0.050000,0.001724,0.051724,0.817241
EEG P4-M2,660636.0,9.249510e-05,0.155021,-0.947205,-0.093168,0.003106,0.099379,1.000000
EEG C3-M1,660636.0,9.846589e-06,0.054143,-1.000000,-0.030067,0.000000,0.030067,0.694878
...,...,...,...,...,...,...,...,...
v 7 3,660636.0,-1.685534e-05,0.109172,-1.000000,-0.066030,0.000796,0.067621,0.808274
v 7 4,660636.0,-4.897718e-06,0.087615,-1.000000,-0.051539,0.000000,0.052878,0.791165
v 7 5,660636.0,8.611919e-07,0.076819,-1.000000,-0.045154,0.000000,0.045756,0.711017
v 7 6,660636.0,4.365781e-06,0.073819,-1.000000,-0.043428,0.000000,0.044007,0.691372


In [68]:
# All three classes, 64-sample sub-windows hopping by 32 samples, 60/20/20 train/val/test split.
generator.create_dataset(split_windows=True, window_size=64, shift=32, train_val_test=[.6, .2, .2], listen_repeat_noise=[True, True, True])

listen.shape : (928, 68, 64)
repeat.shape : (928, 68, 64)
noise.shape : (928, 68, 64)
Total number of elements in the dataset: 2784
train dataset contains 1670 elements
val dataset contains 556 elements
test dataset contains 558 elements

Refer to datasets as:
	 WindowGenerator().train
	 WindowGenerator().val
	 WindowGenerator().test


In [72]:
# Sanity check: sum the batch sizes of the test split to count its examples.
n = 0
for x, y in generator.test:
  n += x.shape[0]


n

558