# Import libraries

In [None]:
import os
import os.path as op
import numpy as np
import matplotlib.pyplot as plt
  
import mne
import nibabel as nib
from mne.datasets import sample
from mne.minimum_norm import make_inverse_operator, apply_inverse_epochs
from mne.datasets import fetch_fsaverage
from scipy.io import loadmat
from scipy.spatial import Delaunay

import gc

%matplotlib qt

In [None]:
# Load the Brodmann atlas volume and flag every voxel whose label equals 4
# (the label this notebook treats as the motor region).
img = nib.load("/Users/ivanl/Downloads/MRIcron_windows/MRIcron/Resources/templates/brodmann.nii.gz")

brodmann_data = img.get_fdata()
brodmann_motor = brodmann_data.ravel() == 4
print(brodmann_motor)

# Build an (i, j, k) index triple for every voxel, laid out like the volume
# itself with the index triple on the last axis.
shape = img.shape[:3]
affine = img.affine
coords = np.stack(np.meshgrid(*map(range, shape), indexing='ij'))
coords = np.moveaxis(coords, 0, -1)
# Push the voxel indices through the image affine to obtain the coordinates
# that are later compared against the MNI vertex positions.
mm_coords = nib.affines.apply_affine(affine, coords)

def in_hull(p, hull):
    """
    Report, for each point in `p`, whether it lies inside `hull`.

    `p` is an `NxK` array holding `N` points in `K` dimensions.
    `hull` is either a ready-made scipy.spatial.Delaunay triangulation or an
    `MxK` array of `M` points in `K` dimensions, in which case the
    triangulation is computed here.

    Returns a boolean array with one entry per point in `p`.
    """
    triangulation = hull if isinstance(hull, Delaunay) else Delaunay(hull)

    # A point belongs to the hull when it falls in some simplex, i.e. the
    # reported simplex index is non-negative.
    simplex_index = triangulation.find_simplex(p)
    return simplex_index >= 0

# Motor-region vertex masks; left as None until the first processed
# recording fills them in.
my_left_points = my_right_points = None

In [None]:
# Switch the working directory to the Google Drive volume so the relative
# dataset paths used further down resolve.
os.chdir("G:")

# Download fsaverage files; the folder that contains them is used as the
# subjects directory.
fs_dir = fetch_fsaverage(verbose=True)
subjects_dir = op.split(fs_dir)[0]

# The files live in:
# ('fsaverage' doubles as the trans because MNE has a built-in fsaverage
# transformation.)
subject = trans = 'fsaverage'
src, bem = (
    op.join(fs_dir, 'bem', fname)
    for fname in ('fsaverage-ico-5-src.fif',
                  'fsaverage-5120-5120-5120-bem-sol.fif')
)

In [None]:
# Read the fsaverage source spaces; the first entry is treated as the left
# hemisphere and the second as the right.
source = mne.read_source_spaces(src)
left, right = source[0], source[1]

# Keep only the positions whose "inuse" flag is set.
left_pos, right_pos = (
    hemi["rr"][hemi["inuse"] == 1] for hemi in (left, right)
)

# Transformation file shipped alongside the fsaverage BEM files.
transformation = mne.read_trans(op.join(fs_dir, "bem", "fsaverage-trans.fif"))

In [None]:
# Root output folder for the per-subject source estimates, resolved against
# the current working directory.
save_path = op.join(
    os.getcwd(), *("Shared drives", "Motor Imagery", "Source Estimate"))

# Load preprocessed data

In [None]:
# Per-recording pipeline: load a preprocessed EEGLAB file, epoch it, build a
# forward model on fsaverage, compute an sLORETA inverse solution for every
# epoch, keep only the source vertices that fall inside the motor-region
# hull, and store the left/right hemisphere time courses as a compressed
# .npz file under <save_path>/<subject id>/.
data_path = "Shared drives/Motor Imagery/resting state eeg & fmri/EEG_MPILMBB_LEMON/EEG_Preprocessed_BIDS_ID/EEG_Preprocessed/"
# Directory listing of the dataset; not used by the loop below, kept for
# interactive inspection (and it fails early if the path is unreachable).
dirs = os.listdir(data_path)
files = ["sub-010006_EO.set", "sub-010006_EC.set"]

# Settings shared by every recording; defined once instead of on every
# iteration of the loop.
tmin, tmax = 0., 2.  # in s
baseline = None
method = "sLORETA"
snr = 3.
lambda2 = 1. / snr ** 2

for file in files:
    # EEGLAB .fdt files are the binary companions of .set files; only the
    # .set file is passed to the reader.
    if file.endswith("fdt"):
        continue

    # Work out the recording condition up front. Without the explicit
    # fallback an unexpected file name would either raise a NameError or
    # silently reuse the event_id of the previous file.
    if "EO" in file:
        condition, event_id = "EO", {"eyes open": 1}
    elif "EC" in file:
        condition, event_id = "EC", {"eyes close": 2}
    else:
        print(f"Skipping {file}: neither an EO nor an EC recording")
        continue

    # Characters 4..9 of "sub-XXXXXX_..." hold the subject number.
    subject_id = file[4:10]
    save_folder = op.join(save_path, subject_id)
    os.makedirs(save_folder, exist_ok=True)

    input_fname = op.join(data_path, file)

    # Load preprocessed data
    raw = mne.io.read_raw_eeglab(input_fname, preload=True, verbose=False)

    # Set montage
    # Read and set the EEG electrode locations
    montage = mne.channels.make_standard_montage('standard_1005')
    raw.set_montage(montage)

    # Set common average reference (added as a projector)
    raw.set_eeg_reference('average', projection=True, verbose=False)
    print(raw.info)

    # Construct epochs
    # The events are handed to mne.Epochs explicitly. They are deliberately
    # NOT written into raw.info["events"]: nothing reads that key here and
    # newer MNE releases reject direct assignment to it.
    events, _ = mne.events_from_annotations(raw, verbose=False)
    epochs = mne.Epochs(
        raw, events=events,
        event_id=event_id, tmin=tmin,
        tmax=tmax, baseline=baseline, verbose=False)
    #epochs.plot()
    print(len(epochs.events))
    # NOTE(review): the epochs start at tmin=0 and the covariance is limited
    # to tmax=0, so the noise covariance appears to be estimated from a very
    # short window -- confirm this is intended.
    noise_cov = mne.compute_covariance(epochs, tmax=0., method=['shrunk', 'empirical'], rank=None, verbose=False)
    del raw # save memory

    # Check that the locations of EEG electrodes is correct with respect to MRI
    #mne.viz.plot_alignment(
    #    epochs.info, src=src, eeg=['original', 'projected'], trans=trans,
    #    show_axes=True, mri_fiducials=True, dig='fiducials')

    fwd = mne.make_forward_solution(epochs.info, trans=trans, src=src,
                                bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1)
    print(fwd)

    # Use fwd to compute the sensitivity map for illustration purposes
    #eeg_map = mne.sensitivity_map(fwd, ch_type='eeg', mode='fixed')
    #brain = eeg_map.plot(time_label='EEG sensitivity', subjects_dir=subjects_dir,
    #                     clim=dict(lims=[5, 50, 100]))

    inverse_operator = make_inverse_operator(
        epochs.info, fwd, noise_cov, loose=0.2, depth=0.8)
    del fwd # save memory

    # One source estimate per epoch.
    stc = apply_inverse_epochs(epochs, inverse_operator, lambda2,
                                  method=method, pick_ori=None, verbose=True)
    del epochs # save memory

    # get motor region points (once)
    # `or` (rather than `and`) so a half-initialised state is repaired
    # instead of indexing the data with None further down.
    if my_left_points is None or my_right_points is None:
        my_source = stc[0]
        mni_lh = mne.vertex_to_mni(my_source.vertices[0], 0, subject)
        print(mni_lh.shape)
        mni_rh = mne.vertex_to_mni(my_source.vertices[1], 1, subject)
        print(mni_rh.shape)

        # Optional visual check: atlas voxels against all source vertices.
        #fig = plt.figure(figsize=(8, 8))
        #ax = fig.add_subplot(projection='3d')
        #ax.scatter(mm_coords.reshape(-1, 3)[brodmann_motor][:, 0], mm_coords.reshape(-1, 3)[brodmann_motor][:, 1], mm_coords.reshape(-1, 3)[brodmann_motor][:, 2], s=15, marker='|')
        #ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='_')
        #ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_')
        #ax.set_xlabel('X Label')
        #ax.set_ylabel('Y Label')
        #ax.set_zlabel('Z Label')
        #plt.show()

        # Triangulate the atlas voxels once and reuse the result for both
        # hemispheres instead of rebuilding it for every in_hull call.
        # NOTE(review): the hull is built from every voxel labelled 4; if
        # the atlas uses that label in both hemispheres, the hull also
        # covers the space between them -- confirm this is intended.
        motor_hull = Delaunay(mm_coords.reshape(-1, 3)[brodmann_motor])
        my_left_points = in_hull(mni_lh, motor_hull)
        my_right_points = in_hull(mni_rh, motor_hull)

        # The selected vertices are a subset of those already converted, so
        # their coordinates can be taken by masking.
        mni_left_motor = mni_lh[my_left_points]
        print(mni_left_motor.shape)
        mni_right_motor = mni_rh[my_right_points]
        print(mni_right_motor.shape)

        # Optional visual check: selected motor vertices against all
        # source vertices.
        #fig = plt.figure(figsize=(8, 8))
        #ax = fig.add_subplot(projection='3d')
        #ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='|')
        #ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_')
        #ax.scatter(mni_left_motor[:, 0], mni_left_motor[:, 1], mni_left_motor[:, 2], s=15, marker='o')
        #ax.scatter(mni_right_motor[:, 0], mni_right_motor[:, 1], mni_right_motor[:, 2], s=15, marker='^')
        #ax.set_xlabel('X Label')
        #ax.set_ylabel('Y Label')
        #ax.set_zlabel('Z Label')
        #plt.show()

    # slice data
    # The leading rows of each estimate are taken as the left hemisphere and
    # the trailing rows as the right one, then masked down to the motor
    # vertices. The comprehension variable is local, so the module-level
    # `source` (the source spaces) is no longer overwritten.
    left_hemi_data = np.array(
        [epoch_stc.data[:len(epoch_stc.vertices[0])][my_left_points]
         for epoch_stc in stc])
    right_hemi_data = np.array(
        [epoch_stc.data[-len(epoch_stc.vertices[1]):][my_right_points]
         for epoch_stc in stc])
    print(left_hemi_data.shape, right_hemi_data.shape)

    # Output name: <subject id>_<EO|EC>.npz with arrays "left" and "right".
    np.savez_compressed(
        op.join(save_folder, subject_id + "_" + condition + ".npz"),
        left=left_hemi_data, right=right_hemi_data)
    del left_hemi_data, right_hemi_data, stc # save memory
    gc.collect()


In [None]:
"""
data_path = "Shared drives/Motor Imagery/resting state eeg & fmri/EEG_MPILMBB_LEMON/EEG_Preprocessed_BIDS_ID/EEG_Preprocessed/"
dirs = os.listdir(data_path)
print(len(dirs[740:]))
print(dirs[740:])
"""

In [None]:
#my_load_data = np.load(op.join(save_folder, file[4:10]+"_EO.npz"), allow_pickle=True)
#print(my_load_data["left"].shape)
#print(my_load_data["right"].shape)

In [None]:
# forward matrix
# Convert the forward solution to a surface-oriented, fixed-orientation one
# and extract its gain matrix.
# NOTE(review): `fwd` is removed with `del fwd` inside the per-file
# processing loop, so this cell raises a NameError unless `fwd` is recreated
# beforehand.
fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                         use_cps=True)
leadfield = fwd_fixed['sol']['data']
# The format string consumes the two entries of the shape tuple, reported as
# sensors x dipoles.
print("Leadfield size : %d sensors x %d dipoles" % leadfield.shape)

In [None]:
# numpy array
#reconstruct_evoked = np.dot(leadfield, stc.data)

# mne data structure
# Project the source estimate back to sensor space with the fixed-orientation
# forward solution and plot one trace per row of the result.
# NOTE(review): `evoked` is not defined in any cell shown here, and `stc` is
# deleted at the end of the per-file processing loop (where it held a list of
# per-epoch estimates) -- both must be provided before this cell can run.
reconstruct_evoked = mne.apply_forward(fwd_fixed, stc, evoked.info)
print(reconstruct_evoked.data.shape)
# Take the time axis from the reconstructed object itself. The previous
# hard-coded np.arange(0, 2.004, 0.004) assumed a 250 Hz, 2 s signal and
# relied on a float-step arange whose length is not guaranteed to match the
# number of samples.
for channel_trace in reconstruct_evoked.data:
    plt.plot(reconstruct_evoked.times, channel_trace)
plt.show()

In [None]:
# Plot the sensor-space evoked response, presumably for comparison with the
# reconstruction above.
# NOTE(review): `evoked` is not defined in any cell shown here -- confirm
# where it is supposed to come from.
evoked.plot()
plt.show()