In [None]:
import mne
import numpy as np
import os
import os.path as op
import matplotlib.pyplot as plt
import nibabel as nib
from mne.datasets import sample
from mne.minimum_norm import make_inverse_operator, apply_inverse_epochs
from mne.datasets import fetch_fsaverage
import scipy.io
from scipy.io import loadmat
from scipy.spatial import Delaunay

import gc

%matplotlib inline

In [None]:
# Rename table: GDF channel labels of the BCI Competition IV 2a recordings
# -> standard 10-20 electrode names. The three EOG channels keep their labels.
channels_mapping = dict([
    ("EEG-Fz", "Fz"),
    ("EEG-0", "FC3"),
    ("EEG-1", "FC1"),
    ("EEG-2", "FCz"),
    ("EEG-3", "FC2"),
    ("EEG-4", "FC4"),
    ("EEG-5", "C5"),
    ("EEG-C3", "C3"),
    ("EEG-6", "C1"),
    ("EEG-Cz", "Cz"),
    ("EEG-7", "C2"),
    ("EEG-C4", "C4"),
    ("EEG-8", "C6"),
    ("EEG-9", "CP3"),
    ("EEG-10", "CP1"),
    ("EEG-11", "CPz"),
    ("EEG-12", "CP2"),
    ("EEG-13", "CP4"),
    ("EEG-14", "P1"),
    ("EEG-Pz", "Pz"),
    ("EEG-15", "P2"),
    ("EEG-16", "POz"),
    ("EOG-left", "EOG-left"),
    ("EOG-central", "EOG-central"),
    ("EOG-right", "EOG-right"),
])

# Channel type for every renamed channel, in the same order as the renamed
# names: the EOG channels are "eog", every other channel is "eeg".
channels_type_mapping = {
    new_name: ("eog" if new_name.startswith("EOG") else "eeg")
    for new_name in channels_mapping.values()
}

In [None]:
# Brodmann-area atlas in NIfTI format. Set the BRODMANN_ATLAS_PATH environment
# variable to point at a local copy instead of editing this absolute path.
BRODMANN_ATLAS_PATH = os.environ.get(
    "BRODMANN_ATLAS_PATH",
    "/Users/ivanl/Downloads/MRIcron_windows/MRIcron/Resources/templates/brodmann.nii.gz",
)
img = nib.load(BRODMANN_ATLAS_PATH)

brodmann_data = img.get_fdata()
# Flat (C-order) boolean mask of the voxels whose atlas label is 4.
brodmann_motor = brodmann_data.reshape(-1) == 4
# Report the count; printing the mask itself only shows a truncated array.
print("Voxels labelled 4 in the Brodmann atlas:", int(np.count_nonzero(brodmann_motor)))

shape, affine = img.shape[:3], img.affine
# Integer voxel indices with shape (X, Y, Z, 3): coords[i, j, k] == (i, j, k).
coords = np.moveaxis(np.indices(shape), 0, -1)
# Same grid mapped through the image affine into world (mm) coordinates.
mm_coords = nib.affines.apply_affine(affine, coords)

def in_hull(p, hull):
    """Return a boolean mask telling which points of ``p`` lie inside ``hull``.

    Args:
        p: ``(N, K)`` array holding N points in K dimensions.
        hull: either a ready ``scipy.spatial.Delaunay`` triangulation, or an
            ``(M, K)`` array of points whose Delaunay triangulation is built here.

    Returns:
        ``(N,)`` boolean array, True where the point falls inside some simplex.
    """
    triangulation = hull if isinstance(hull, Delaunay) else Delaunay(hull)
    # find_simplex yields -1 for points that are outside every simplex.
    simplex_index = triangulation.find_simplex(p)
    return simplex_index >= 0

# Placeholders for per-hemisphere point sets; not populated in this cell
# (presumably filled by a later cell — verify).
my_left_points = my_right_points = None

In [None]:
# ---------------------------------------------------------------------------
# Label utility functions
# ---------------------------------------------------------------------------
def load_subject_labels(name="A01E.mat", dir="drive/Shareddrives/Motor Imagery/BCI competition IV dataset/2a/2a true_labels/"):
  """Load the ground-truth class labels from one subject's .mat file.

  Args:
    name: .mat file name, e.g. "A01E.mat".
    dir: directory containing the file; a trailing separator is optional.

  Returns:
    1-D numpy array holding the file's "classlabel" variable, flattened.
  """
  # os.path.join works with or without a trailing separator on `dir`;
  # os.walk yields directory paths without one, which broke `dir + name`.
  path = os.path.join(dir, name)
  return scipy.io.loadmat(path)["classlabel"].reshape(-1)


def load_all_true_labels(dataset_path):
  """Load the labels of every .mat file found (recursively) under a directory.

  Args:
    dataset_path: root directory to walk.

  Returns:
    dict {file name: 1-D label array}. Keys are bare file names, so files
    with the same name in different sub-directories overwrite each other.
  """
  labels = {}
  for root, _dirs, files in os.walk(dataset_path):
    for file in files:
      # Skip non-MATLAB files (e.g. desktop.ini on synced drives), which
      # loadmat cannot read.
      if not file.lower().endswith(".mat"):
        continue
      labels[file] = load_subject_labels(name=file, dir=root)
  return labels

# ---------------------------------------------------------------------------
# Plot graph utility functions
# ---------------------------------------------------------------------------
def plot_average_graph(subject_name="A01T.gdf", Class="left", filter_channels=None):
  """Plot the across-epoch average of one class, one stacked subplot per channel.

  Reads the module-level ``data`` dict produced by ``load_all_subject``.

  Args:
    subject_name: key into ``data`` (a GDF file name).
    Class: event class to average ("left", "right", "foot", "tongue", "unknown").
    filter_channels: optional list of channel names to plot (all channels
      when None). A name missing from the recording leaves its subplot empty.

  Raises:
    ValueError: if the subject has no epochs for ``Class``.
  """
  subject = data[subject_name]
  # Channel order of the stored epochs, taken from the subject's own info
  # instead of relying on a separately defined global.
  names = list(subject["info"]["ch_names"])

  event_data = subject["epoch_data"].get(Class)
  # np.size works for lists and arrays; `array != []` raises/warns in numpy.
  if event_data is None or np.size(event_data) == 0:
    raise ValueError(f"No epochs for class {Class!r} in {subject_name!r}")
  # (n_epochs, n_channels, n_times) -> average -> (n_times, n_channels)
  x = np.transpose(np.mean(event_data, axis=0))

  if filter_channels is None:
    channels, fig_height = names, 21
  else:
    channels, fig_height = list(filter_channels), 10.5

  # squeeze=False keeps axs indexable even when a single channel is requested.
  fig, axs = plt.subplots(len(channels), squeeze=False, gridspec_kw={'hspace': 0})
  axs = axs[:, 0]
  fig.set_size_inches(37, fig_height)
  for ax, channel_name in zip(axs, channels):
    if channel_name not in names:
      continue
    idx = names.index(channel_name)
    ax.title.set_text(channel_name)
    ax.title.set_size(20)
    ax.title.set_y(0.7)
    ax.plot(range(x.shape[0]), x[:, idx])
    # Dashed reference line at sample 250 of the epoch window.
    ax.axvline(x=250, color="r", linestyle='--')
  plt.tight_layout()

def plot_multiple_graph(subject_name="A02T.gdf", classes=("left", "right", "foot", "tongue"), filter_channels=None):
  """Overlay the per-class epoch averages of one subject, one subplot per channel.

  Reads the module-level ``data`` dict produced by ``load_all_subject``.

  Args:
    subject_name: key into ``data`` (a GDF file name).
    classes: iterable of event classes to overlay; each gets a fixed colour.
    filter_channels: optional list of channel names to plot (all channels
      when None). A name missing from the recording leaves its subplot empty.

  Raises:
    ValueError: if the subject has no epochs for one of ``classes``.
  """
  # One colour per class. The original dict listed "tongue" twice, so
  # "tongue" silently became yellow and "unknown" had no colour at all.
  color = {"left": "b", "right": "g", "foot": "c", "tongue": "m", "unknown": "y"}
  classes = list(classes)

  subject = data[subject_name]
  # Channel order of the stored epochs, taken from the subject's own info.
  names = list(subject["info"]["ch_names"])

  averages = []
  for event_class in classes:
    event_data = subject["epoch_data"].get(event_class)
    # np.size works for lists and arrays; `array != []` raises/warns in numpy.
    if event_data is None or np.size(event_data) == 0:
      raise ValueError(f"No epochs for class {event_class!r} in {subject_name!r}")
    # (n_epochs, n_channels, n_times) -> average -> (n_times, n_channels)
    averages.append(np.transpose(np.mean(event_data, axis=0)))

  if filter_channels is None:
    channels, fig_height = names, 21
  else:
    channels, fig_height = list(filter_channels), 10.5

  # squeeze=False keeps axs indexable even when a single channel is requested.
  fig, axs = plt.subplots(len(channels), squeeze=False, gridspec_kw={'hspace': 0})
  axs = axs[:, 0]
  fig.set_size_inches(37, fig_height)
  for ax, channel_name in zip(axs, channels):
    if channel_name not in names:
      continue
    idx = names.index(channel_name)
    ax.title.set_text(channel_name)
    ax.title.set_size(20)
    ax.title.set_y(0.7)
    # Dashed reference line at sample 250 of the epoch window.
    ax.axvline(x=250, color="r", linestyle='--')
    for event_class, avg in zip(classes, averages):
      ax.plot(range(avg.shape[0]), avg[:, idx], color=color[event_class])
  plt.tight_layout()

# ---------------------------------------------------------------------------
# Load data functions
# ---------------------------------------------------------------------------
# GDF event codes of the BCI Competition IV 2a recordings, in the order in
# which their counts are reported by load_subject's ``class_info``.
_GDF_EVENT_DESCRIPTIONS = [
  (276, "Idling EEG (eyes open)"),
  (277, "Idling EEG (eyes closed)"),
  (768, "Start of a trial"),
  (769, "Cue onset left (class 1)"),
  (770, "Cue onset right (class 2)"),
  (771, "Cue onset foot (class 3)"),
  (772, "Cue onset tongue (class 4)"),
  (783, "Cue unknown"),
  (1023, "Rejected trial"),
  (1072, "Eye movements"),
  (32766, "Start of a new run"),
]


def load_subject(name="A01T.gdf", dir='drive/Shareddrives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/', debug=None, true_labels_map=None, start=0, stop=500):
  """Load one GDF recording and cut cue-locked epochs for every class.

  Pipeline:
    1. Rename the channels and set their types.
    2. Apply the standard 10-20 montage.
    3. Add an average-reference projector.
    4. Drop the EOG channels.
    5. For every cue whose trial is not flagged as rejected, slice
       ``raw_data[:22, cue + start : cue + stop]``.

  Args:
    name: GDF file name, e.g. "A01T.gdf". Its first four characters plus
      ".mat" select the true-label entry used for "unknown" (783) cues.
    dir: directory holding the file; a trailing separator is optional.
    debug: None, "important" (print event counts and epochs per class) or
      "all" (additionally print raw.info and the raw event table).
    true_labels_map: dict {label file name: label array} used to assign
      "unknown" cues to a class. Defaults to the module-level ``true_labels``
      dict, which is only looked up when the file contains unknown cues.
    start, stop: epoch window in samples relative to the cue (default 0..500,
      i.e. 2.0 s if the sampling rate is 250 Hz — the old comment said 3.0 s).

  Returns:
    dict with keys:
      "raw": the loaded recording.
      "info": its info.
      "class_info": text summary of the event counts.
      "epoch_data": {class: array}. Each array is (n_epochs, 22, stop - start)
        for complete windows; a class without epochs maps to an empty array.
  """
  rule = "-------------------------------------------------------------------------------------------------------"
  subject_data = {}

  raw = mne.io.read_raw_gdf(os.path.join(dir, name))
  raw.rename_channels(channels_mapping)
  raw.set_channel_types(channels_type_mapping)
  raw.set_montage(mne.channels.make_standard_montage('standard_1020'))
  # Average reference is added as a projector, not applied to the data.
  raw.set_eeg_reference('average', projection=True, verbose=False)
  raw.drop_channels(["EOG-left", "EOG-central", "EOG-right"])

  subject_data["raw"] = raw
  subject_data["info"] = raw.info
  if debug == "all":
    print(rule)
    for key, item in raw.info.items():
      print(key, item)
    print(rule)

  custom_mapping = {str(code): code for code, _ in _GDF_EVENT_DESCRIPTIONS}
  events_from_annot, event_dict = mne.events_from_annotations(raw, event_id=custom_mapping)
  event_codes = events_from_annot[:, 2]

  # Previously compared against " all" (stray space), so this never printed.
  if debug == "all":
    print(rule)
    print(event_dict)
    print(events_from_annot)
    print(rule)

    print(rule)
    # zip: annotations not in custom_mapping produce no event row, so the
    # two sequences can differ in length (pairing is positional).
    for event, annotation in zip(events_from_annot, raw.annotations):
      print(event, annotation)
    print(rule)

  class_info = "\n".join(
    f"{description}: {np.count_nonzero(event_codes == code)}"
    for code, description in _GDF_EVENT_DESCRIPTIONS
  )
  subject_data["class_info"] = class_info

  if debug in ("all", "important"):
    print(rule)
    print(class_info)
    print(rule)

  # Sample positions of the "Rejected trial" (1023) markers.
  rejected_trial = events_from_annot[event_codes == 1023, 0]
  class_dict = {"left": 769, "right": 770, "foot": 771, "tongue": 772, "unknown": 783}
  class_dict_labels = {1: "left", 2: "right", 3: "foot", 4: "tongue"}
  raw_data = raw.get_data()  # (n_channels, n_samples)
  epochs_by_class = {event_class: [] for event_class in class_dict}

  for event_class, event_id in class_dict.items():
    cue_samples = events_from_annot[event_codes == event_id, 0]
    # A trial is rejected when a 1023 marker sits exactly 500 samples before
    # its cue (presumably the trial-start -> cue delay; confirm).
    keep = ~np.isin(cue_samples - 500, rejected_trial)

    if event_class == "unknown":
      if len(cue_samples) == 0:
        continue
      labels_map = true_labels if true_labels_map is None else true_labels_map
      subject_true_labels = labels_map[name[:4] + ".mat"]
      # Label i belongs to the i-th unknown cue (rejected ones included), so
      # enumerate all cues and skip the rejected ones afterwards.
      for i, cue in enumerate(cue_samples):
        if keep[i]:
          target_class = class_dict_labels[subject_true_labels[i]]
          epochs_by_class[target_class].append(raw_data[:22, cue + start:cue + stop])
    else:
      for cue in cue_samples[keep]:
        epochs_by_class[event_class].append(raw_data[:22, cue + start:cue + stop])

  # Stack each class; an empty list becomes an empty array.
  epoch_data = {event_class: np.array(epochs) for event_class, epochs in epochs_by_class.items()}

  if debug in ("all", "important"):
    print(rule)
    for event_class, epochs in epoch_data.items():
      print(event_class, len(epochs))
    print(rule)

  subject_data["epoch_data"] = epoch_data
  return subject_data

def load_all_subject(dataset_path):
  """Load every GDF recording found (recursively) under ``dataset_path``.

  Args:
    dataset_path: root directory to walk.

  Returns:
    dict {file name: load_subject(...) result}, in sorted file order per
    directory. Keys are bare file names, so duplicates in different
    sub-directories overwrite each other.
  """
  subjects = {}
  for root, _dirs, files in os.walk(dataset_path):
    for file in sorted(files):
      # Skip anything that is not a GDF recording (e.g. desktop.ini).
      if not file.lower().endswith(".gdf"):
        continue
      # os.walk yields `root` without a trailing separator; appending one
      # keeps the path correct even when load_subject concatenates dir + name.
      subjects[file] = load_subject(name=file, dir=os.path.join(root, ""))
  return subjects

In [None]:
# Work from the root of the synced Google Drive: the dataset paths used in
# later cells are relative to it. Override with the MI_DATA_ROOT env variable.
DATA_ROOT = os.environ.get("MI_DATA_ROOT", "G:")
os.chdir(DATA_ROOT)

# Download (if needed) the fsaverage template files.
fs_dir = fetch_fsaverage(verbose=True)
subjects_dir = op.dirname(fs_dir)

# Template subject, transform, source space and BEM solution.
subject = 'fsaverage'
trans = 'fsaverage'  # MNE has a built-in fsaverage transformation
src = op.join(fs_dir, 'bem', 'fsaverage-ico-5-src.fif')
bem = op.join(fs_dir, 'bem', 'fsaverage-5120-5120-5120-bem-sol.fif')

source = mne.read_source_spaces(src)
left, right = source[0], source[1]
# Positions of the vertices flagged as in use in each source space.
left_pos = left["rr"][left["inuse"] == 1]
right_pos = right["rr"][right["inuse"] == 1]

transformation = mne.read_trans(op.join(fs_dir, "bem", "fsaverage-trans.fif"))

# Output directory, resolved against the working directory chosen above.
save_path = op.join(os.getcwd(), "Shared drives", "Motor Imagery", "Source Estimate")

In [None]:
# Load the ground-truth labels first: load_subject reads the module-level
# `true_labels` dict to assign the "unknown" (783) cues to classes.
# Both paths are relative to the current working directory.
true_labels_path = "Shared drives/Motor Imagery/BCI competition IV dataset/2a/2a true_labels/"
true_labels = load_all_true_labels(true_labels_path)

# {GDF file name: load_subject(...) result} for every recording found.
dataset_path = 'Shared drives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/'
data = load_all_subject(dataset_path)

In [None]:
# Channel names in the order of the stored epochs. load_subject applies the
# same renaming/montage to every file, so A01T's list is assumed to hold for
# all subjects. Defined live because plotting code may look it up as a global.
ch_names = data["A01T.gdf"]["info"]["ch_names"]

# Overview of the loaded data structure plus example average plots.
# Previously kept as a disabled string literal; flip the flag to run it.
SHOW_DATA_OVERVIEW = False
if SHOW_DATA_OVERVIEW:
  for key in data:
    print(key)

  print(ch_names)
  print(data["A01T.gdf"]["class_info"])

  for key, value in data.items():
    print(key)
    for event_class, event_data in value["epoch_data"].items():
      print(event_class, len(event_data))
    print()

  plot_average_graph("A01T.gdf", "left", ["C3", "Cz", "C4"])
  plot_multiple_graph("A02T.gdf", ["left", "right"], ["C3", "Cz", "C4"])

In [None]:
def create_epochs(data):
  """Wrap each subject's per-class epoch arrays into ``mne.EpochsArray`` objects.

  Args:
    data: {subject: {"epoch_data": {class: array}, "info": info}}, as
      returned by ``load_all_subject``.

  Returns:
    {subject: {class: mne.EpochsArray}}. Classes without epochs are omitted.
  """
  subjects_data = {}
  for subject, subject_data in data.items():
    info = subject_data["info"]
    subjects_data[subject] = {
      event: mne.EpochsArray(event_data, info)
      for event, event_data in subject_data["epoch_data"].items()
      # Test emptiness by size: `.any()` tests for non-zero values and would
      # also drop a non-empty all-zero array.
      if np.size(event_data) > 0
    }
  return subjects_data

# {GDF file name: {class: mne.EpochsArray}} for every loaded recording.
epochs = create_epochs(data)

In [None]:
# Source-localise the averaged "right" epochs of A01T with sLORETA, project the
# estimate back to the sensors, and compare the two topographies.
my_epochs = epochs["A01T.gdf"]["right"]
my_evoked = my_epochs.average().pick("eeg")

# The epochs start at the cue (create_epochs uses EpochsArray's default
# tmin=0), so there is no pre-cue baseline. compute_covariance(tmax=0.)
# therefore estimated "noise" from a single sample per epoch. Without a
# baseline, use MNE's diagonal ad-hoc noise covariance instead.
noise_cov = mne.make_ad_hoc_cov(my_epochs.info)

fwd = mne.make_forward_solution(my_epochs.info, trans=trans, src=src,
                                bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1)
# Fixed-orientation (surface-normal) forward model, used for re-projection.
fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                         use_cps=True)

inverse_operator = make_inverse_operator(
    my_epochs.info, fwd, noise_cov, loose=0.2, depth=0.8)

method = "sLORETA"
snr = 3.
lambda2 = 1. / snr ** 2
stc = mne.minimum_norm.apply_inverse(my_evoked, inverse_operator, lambda2,
                                     method=method, pick_ori="normal", verbose=True)

# Forward-project the normal-orientation source estimate back to the sensors.
reconstruct_evoked = mne.apply_forward(fwd_fixed, stc, my_evoked.info)
my_evoked.plot_topomap()
reconstruct_evoked.plot_topomap()