In [None]:
"""
Written by Sitong Chen
This code is for source analysis of EEG data
"""

In [None]:
# %matplotlib qt
import mne
from mne.io.constants import FIFF
import mne
from mne_bids import BIDSPath, read_raw_bids
from mne.coreg import Coregistration
from mne.viz import Brain
import os.path as op
import numpy as np
from mne.minimum_norm import make_inverse_operator, apply_inverse_epochs,apply_inverse
from mne import setup_source_space, setup_volume_source_space
import scipy.io
import pandas as pd
import os
import matplotlib.pyplot as plt
from mne.channels import make_dig_montage
import pickle

In [None]:
# Epoch window around each event onset (seconds) and the portion of the raw
# recording to keep (seconds, passed to raw.crop).
onset_start = -0.2
onset_end = 0.8
time_start = 0
time_end = 10

# Location of the BrainVision files in the BIDS dataset.
# NOTE(review): subject_name is used directly as the BIDS subject directory,
# so it presumably includes the "sub-" prefix — confirm.
bids_root = 'example_root'
subject_name = "example_name"
session = "example_session"
run = "example_run"
task = "example_task"
output_dir = 'example_path'

# NOTE(review): not referenced anywhere in this notebook.
fsaverage_data_dir = 'example_fsaverage_path'

# Fetch (if needed) the MNE sample dataset once; its `subjects` directory
# provides the fsaverage template used for BEM, source space and coregistration.
sample_data_dir = mne.datasets.sample.data_path(verbose=True)
subject_dir = os.path.join(sample_data_dir, 'subjects')
subject = 'fsaverage'

brainvision_path = os.path.join(
    bids_root,
    subject_name,
    f'ses-{session}',
    'eeg',
    f'{subject_name}_ses-{session}_task-{task}_run-{run}_eeg.vhdr'
)

def create_custom_montage(coord_file):
    """Build a DigMontage from a BIDS ``*_electrodes.tsv`` file.

    Parameters
    ----------
    coord_file : str
        Path to a tab-separated file with at least the columns
        ``name``, ``x``, ``y`` and ``z``.

    Returns
    -------
    montage : mne.channels.DigMontage
        One position per electrode name; nasion, LPA and RPA are left unset.

    Notes
    -----
    Coordinates are passed through unchanged (no unit conversion and no
    explicit coordinate frame). NOTE(review): MNE expects positions in
    metres; confirm the units declared in the matching ``*_coordsystem.json``.
    """
    table = pd.read_csv(coord_file, sep='\t')
    # (n_electrodes, 3) array of x/y/z, one row per electrode.
    xyz = np.column_stack((table['x'], table['y'], table['z']))
    # Later rows win if an electrode name is repeated.
    electrode_positions = dict(zip(table['name'].tolist(), xyz))
    return make_dig_montage(
        ch_pos=electrode_positions,
        nasion=None,
        lpa=None,
        rpa=None,
    )

# Read the BrainVision recording into a Raw object (preloaded so it can be
# re-referenced and cropped in memory).
raw = mne.io.read_raw_brainvision(brainvision_path, preload=True)

# Electrode positions from the BIDS sidecar in CapTrak space.
coord_file = f'{bids_root}/{subject_name}/ses-{session}/eeg/{subject_name}_ses-{session}_space-CapTrak_electrodes.tsv'

# Build the montage from the coordinates file and attach it to the recording.
montage = create_custom_montage(coord_file)
raw.set_montage(montage)

# Preprocess: add an average-reference projector and apply it to the data.
raw.set_eeg_reference('average', projection=True)
raw.apply_proj()

# Events are extracted from the annotations of the *full* recording; the crop
# happens afterwards, so events whose window falls outside
# [time_start, time_end] are presumably rejected by mne.Epochs (see the drop
# log) rather than filtered out here. `event_id` (description -> code) is kept
# for reference only: Epochs below is given event_id=None, i.e. all codes.
events, event_id = mne.events_from_annotations(raw)

raw.crop(tmin=time_start, tmax=time_end)

# Epochs from onset_start to onset_end seconds around every event, with no
# baseline correction.
epochs = mne.Epochs(
    raw,
    events,
    event_id=None,
    tmin=onset_start,
    tmax=onset_end,
    baseline=None,
    preload=True,
)
info = epochs.info
epochs.plot_drop_log()

# Persist the measurement info for reuse by later steps/scripts.
# NOTE(review): an existing file is never overwritten, so it can go stale if
# the recording or preprocessing changes.
os.makedirs(output_dir, exist_ok=True)
info_fname = os.path.join(output_dir, 'info.pkl')

if not os.path.exists(info_fname):
    with open(info_fname, 'wb') as f:
        pickle.dump(info, f)
    print(f"Info object saved to {info_fname}")
else:
    print(f"Info file already exists: {info_fname}")


In [None]:
# Montage as stored on `raw` after set_montage; the bare last expression is
# rendered by Jupyter for visual inspection.
montage = raw.get_montage()
montage

In [None]:
# Average all retained epochs into an evoked response (input to apply_inverse).
evoked = epochs.average()

# Summarise rejection: number of epochs kept and percentage dropped.
print(f"{len(epochs)} epochs retained; {epochs.drop_log_stats():.1f}% dropped")

In [None]:

# Precomputed fsaverage BEM solution from the subjects directory.
bem_fname = os.path.join(subject_dir, subject, 'bem', 'fsaverage-5120-5120-5120-bem-sol.fif')

if not os.path.exists(bem_fname):
    raise FileNotFoundError(f"BEM solution not found at {bem_fname}")
bem = mne.read_bem_solution(bem_fname)

# Surface source space on fsaverage (MNE default spacing); add_dist=True also
# computes inter-source distances.
src = mne.setup_source_space(subject, subjects_dir=subject_dir, add_dist=True)

# Save a copy of the BEM solution next to the other outputs.
bem_fname_saved = os.path.join(output_dir, 'bem-sol.fif')
mne.write_bem_solution(bem_fname_saved, bem, overwrite=True)
print(f"BEM solution saved to {bem_fname_saved}")


In [None]:
# Persist the source space and the 3-D view parameters as pickles so later
# steps can reload them without recomputation.
# NOTE(review): the keys match the arguments of mne.viz.set_3d_view,
# presumably where these settings are consumed — confirm.
alignment_settings = dict(
    azimuth=180,
    elevation=90,
    distance=0.30,
    focalpoint=(-0.03, -0.01, 0.03),
)

src_file = os.path.join(output_dir, 'src.pkl')
settings_file = os.path.join(output_dir, 'alignment_settings.pkl')

# Written in this order: source space first, then the view settings.
for target_path, payload in ((src_file, src), (settings_file, alignment_settings)):
    with open(target_path, 'wb') as fh:
        pickle.dump(payload, fh)

print(f"Data saved to {output_dir}")


In [None]:
from mne.datasets import fetch_fsaverage

# Ensure the fsaverage template is available locally.
# NOTE(review): the directory returned by fetch_fsaverage is not used;
# Coregistration below reads the template from `subject_dir` — confirm that
# is the intended copy.
fetch_fsaverage(verbose=True)

# Template coregistration: fiducials on fsaverage are estimated rather than
# read from a subject-specific file.
subject = 'fsaverage'
fiducials = 'estimated'
coreg = Coregistration(info, subject, subjects_dir=subject_dir, fiducials=fiducials)

# Each call refines the current transform, so the order matters:
#   1. initial alignment from fiducials,
#   2. short ICP pass,
#   3. drop digitised points farther than the threshold from the MRI head,
#   4. longer ICP pass with a higher nasion weight.
coreg.fit_fiducials(verbose=True)
coreg.fit_icp(n_iterations=6, nasion_weight=2.0, verbose=True)
coreg.omit_head_shape_points(distance=5.0 / 1000)  # threshold in metres (5 mm)
coreg.fit_icp(n_iterations=20, nasion_weight=10.0, verbose=True)

# Residual distances between digitised points and the MRI surface, metres -> mm.
dists = coreg.compute_dig_mri_distances() * 1e3
dist_mean, dist_min, dist_max = np.mean(dists), np.min(dists), np.max(dists)
print(
    f"Distance between HSP and MRI (mean/min/max):\n{dist_mean:.2f} mm "
    f"/ {dist_min:.2f} mm / {dist_max:.2f} mm"
)

# Store the fitted transform for the forward model.
trans = coreg.trans
trans_fname = os.path.join(output_dir, 'trans.fif')
mne.write_trans(trans_fname, trans, overwrite=True)
print(f"Transformation matrix saved to {trans_fname}")

In [None]:
# EEG-only forward model (lead field) on the fsaverage source space.
forward_kwargs = dict(
    meg=False,      # recording has no MEG channels
    eeg=True,
    mindist=0.0,    # no minimum-distance exclusion of sources near the inner skull
    n_jobs=8,
    verbose=True,
)
fwd = mne.make_forward_solution(info, trans, src, bem, **forward_kwargs)

In [None]:
leadfield = fwd["sol"]["data"]

In [None]:
# Noise covariance estimated from the pre-stimulus part of every epoch
# (onset_start .. 0 s). method lists candidate estimators; MNE keeps the best.
noise_cov = mne.compute_covariance(
    epochs,
    tmin=onset_start,
    tmax=0.0,
    method=["shrunk", "empirical"],
    rank=None,
    verbose=True,
    n_jobs=8,
)

In [None]:
from mne.minimum_norm import read_inverse_operator, apply_inverse_raw, apply_inverse

# Inverse method used by apply_inverse; alternatives: "sLORETA", "MNE", "dSPM".
inv_method = "eLORETA"

# Regularisation parameter. By the usual MNE convention lambda2 = 1 / SNR**2,
# so 1/6 corresponds to an assumed SNR of sqrt(6) (~2.45).
lambda2 = 1 / 6

# Inverse operator from the EEG forward model and noise covariance, with
# depth weighting exponent 0.8.
inverse_operator = make_inverse_operator(
    info, fwd, noise_cov, depth=0.8,
)

# write_inverse_operator uses the name as given (it only warns when the name
# lacks the "-inv.fif" convention), so the suffix must be part of the path.
inverse_operator_fname = os.path.join(output_dir, 'inverse_operator-inv.fif')
mne.minimum_norm.write_inverse_operator(inverse_operator_fname, inverse_operator, overwrite=True)
print(f"Inverse operator saved to {inverse_operator_fname}")

In [None]:
# Apply the inverse operator to the averaged evoked response. With
# pick_ori=None, loose/free orientations are pooled by their norm.
stc = apply_inverse(evoked, inverse_operator, lambda2, method=inv_method, pick_ori=None)

# For a surface source estimate, save() treats the name as a stem and writes
# one file per hemisphere with "-lh.stc" / "-rh.stc" suffixes.
stc_fname = os.path.join(output_dir, 'source_estimate')
stc.save(stc_fname, overwrite=True)
print(f"Source estimate saved to {stc_fname}-lh.stc and {stc_fname}-rh.stc")

In [None]:
# Interactive 3-D view of the evoked source estimate on the template's white
# surface (both hemispheres, time viewer enabled).
stc.plot(
    subjects_dir=subject_dir,
    subject=subject,
    surface='white',
    hemi='both',
    time_unit='s',
    time_viewer=True,
    src=src,
)