# MEG Data Preprocessing Pipeline - Subject 95

## Overview
This notebook implements a standardized preprocessing pipeline for OPM-MEG data analysis. The pipeline is designed for **subject sub-95** (healthy participant)
and processes one session at a time through a modular, reusable framework.

## Research Context
- **Subject**: sub-95 (healthy control participant)
- **Data Type**: OPM-MEG recordings (.fif format)


## 0 - Libraries:

In [None]:
# ---- LIBRARIES ----
import json
import os
import sys
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne
from mne.preprocessing import ICA
from mne.preprocessing import compute_proj_hfc

sys.path.append('../source')


# from find_paths import get_onedrive_path
# from plot_functions import (plot_raw_vs_processed,
#                             plot_channels_comparison,)

try:
    import utils.load_utils as load_utils
except ImportError:
    # `utils` could not be imported from the current working directory: move
    # one directory up (presumably the repository root, lid_opm — TODO
    # confirm) and retry. Only import failures are handled here, so unrelated
    # errors (e.g. KeyboardInterrupt) are no longer swallowed.
    os.chdir(os.path.dirname(os.getcwd()))
    import utils.load_utils as load_utils

## 1. Define settings and load data (incl epoched events)


### Load configuration


In [None]:
SUB_ID = 'sub-95'
SES_ID = 'Dec'


# ---- 1.2 LOAD CONFIGURATION  ----
# Subject configuration, obtained through the project helper module.
config = load_utils.load_subject_config(SUB_ID)

# Tasks listed in the config under the selected session key ("ses-<SES_ID>")
available_tasks = config['tasks'][f"ses-{SES_ID}"]
print("=" * 60, "Available tasks in the configuration file:", "=" * 60, sep="\n")
for task in available_tasks:
    print(f"- {task}")
print("=" * 60)


### Load Data:

In [None]:
# Re-execute the load_utils module in place so that changes made to it on
# disk are available without restarting the kernel.
importlib.reload(load_utils)

In [None]:
# 1.3 Define the task for this run (select one from the list above)
TASK = 'testArdgonogo3'  #'behavfederico1' 

# Session folder: <source_data>/<SUB_ID>/OPM_MEG/ses-<SES_ID>
ses_path = os.path.join(
    load_utils.get_onedrive_path('source_data'),
    SUB_ID,  # previously the undefined name SUBJECT_ID, which raised NameError
    'OPM_MEG',
    f'ses-{SES_ID}'
)

# Candidate recordings: .fif files whose name contains the task label.
files = os.listdir(ses_path)
matching_fnames = [f for f in files if TASK in f and f.endswith('.fif')]
if not matching_fnames:
    # Fail with a descriptive error instead of a bare IndexError on [0]
    raise FileNotFoundError(
        f"No .fif file containing task '{TASK}' found in: {ses_path}"
    )

# As before, the first match (in os.listdir order) is used.
sel_fname = matching_fnames[0]

file_path = os.path.join(ses_path, sel_fname)
# Explicit check rather than `assert`, which is stripped when Python runs with -O
if not os.path.exists(file_path):
    raise FileNotFoundError(f"File path does not exist: {file_path}")

In [None]:
# ---- LOAD AND EXPLORE DATA ----

# Read the selected .fif recording with its data preloaded into memory
raw = mne.io.read_raw_fif(file_path, preload=True, verbose=True)
print("File loaded successfully.")

# Show the data header (raw.info) between two separator rules
print("\n" + "=" * 60, "DATA HEADER:", "=" * 60, raw.info, sep="\n")

In [None]:
# Sampling frequency of the recording, read from the 'sfreq' entry of raw.info
SFREQ = raw.info['sfreq']
print("=" * 60, f"Sampling Frequency (sfreq): {SFREQ} Hz", "=" * 60, sep="\n")

### 1.5 Verify Sensor Geometry

The plot displays:
-   **Sensor Positions**: Each sensor's location in 3D space, shown as a black dot.
-   **Orientation Vectors**: The local coordinate system of each sensor, represented by three colored arrows:
    -   **Red**: The sensor's local X-axis.
    -   **Green**: The sensor's local Y-axis.
    -   **Blue**: The sensor's local Z-axis (normal vector, pointing away from the head).

In [None]:
print("\n" + "="*60)
print("DETAILED 3D SENSOR GEOMETRY VERIFICATION")
print("="*60)
print("Plotting sensor positions and orientation vectors from raw.info.")

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

# Channel names from the raw object; entry i corresponds to raw.info['chs'][i]
chnames = raw.ch_names

# Arrow length used for every orientation vector (loop-invariant)
scale = 0.01

# Walk names and channel dictionaries in lockstep. This replaces a
# `raw.ch_names.index(ch_name)` search per channel, which made the loop
# quadratic in the number of channels.
for ch_name, ch_info in zip(chnames, raw.info['chs']):
    loc = ch_info['loc']

    # 'loc' layout used here: [0:3] position, then three 3-element
    # orientation vectors (local X, Y and Z axes)
    pos = loc[:3]
    ori_x = loc[3:6]
    ori_y = loc[6:9]
    ori_z = loc[9:12]

    # Plot sensor location
    ax.scatter(*pos, c='black', s=20)

    # Only the first channel's arrows carry a legend label, so the legend
    # shows exactly one entry per axis.
    labelled = ch_name == chnames[0]
    for ori, color, label in (
        (ori_x, 'red', 'Ori-X'),
        (ori_y, 'green', 'Ori-Y'),
        (ori_z, 'blue', 'Ori-Z'),
    ):
        ax.quiver(*pos, *ori, length=scale, color=color,
                  label=label if labelled else "")

ax.set_xlabel('X (m)')
ax.set_ylabel('Y (m)')
ax.set_zlabel('Z (m)')
ax.set_title("Sensor Positions and Orientation Vectors")
ax.legend()

# Make the box proportions follow the data ranges so the geometry is not distorted
ax.set_box_aspect([np.ptp(ax.get_xlim()), np.ptp(ax.get_ylim()), np.ptp(ax.get_zlim())])

plt.show()

### 1.6 Define channel lists from the channel names:

### Note on Available MEG Components

For this specific dataset, the raw `.fif` files contain only the **Y and Z magnetic field components** from the OPM sensors. This is the expected format for this recording, and the subsequent analysis will proceed using these two available components.

In [None]:
# Display the channel names of the loaded recording as the cell output
raw.ch_names

In [None]:
# ---- 1.6 DEFINE CHANNEL LISTS ----

# Group channel names by field component: a name belongs to 'y' when it
# contains '_by' and to 'z' when it contains '_bz'.
meg_channels = {
    axis: [name for name in raw.ch_names if f'_b{axis}' in name]
    for axis in ('y', 'z')
}

# Y channels first, then Z channels
all_meg_channels = [*meg_channels['y'], *meg_channels['z']]

# Stop early when the naming convention matched nothing
if len(all_meg_channels) == 0:
    raise RuntimeError("No MEG channels with '_by' or '_bz' found in the data. Please check channel names.")

# Work on a copy restricted to the matched channels; `raw` stays untouched
meg_data = raw.copy().pick(all_meg_channels)

print(
    "\n" + "=" * 60,
    "MEG CHANNEL SELECTION",
    "=" * 60,
    f"Selected {len(meg_data.ch_names)} MEG channels out of {len(raw.ch_names)} total channels.",
    "=" * 60,
    sep="\n",
)

## 2. Preprocessing: Resampling and Filtering

### 2.1 Load preprocessing settings and resample:

In [None]:
# ---- 2.1 LOAD PREPROCESSING SETTINGS AND RESAMPLE ----

# General preprocessing settings, obtained through the project helper module
preproc_settings = load_utils.load_preproc_config()
TARGET_SFREQ = preproc_settings['TARGET_SFREQ']

# Resample to the target rate. The call is unconditional: nothing here checks
# that the original rate is actually higher than the target.
print(
    "\n" + "=" * 60,
    "RESAMPLING DATA",
    "=" * 60,
    f"Original sampling rate: {meg_data.info['sfreq']} Hz.",
    sep="\n",
)
meg_data.resample(TARGET_SFREQ, npad='auto')
print(f"New sampling rate: {meg_data.info['sfreq']} Hz")

# Multiply the channel data by 1e12 (Tesla -> picoTesla)
print("Converting MEG units from T to pT...")
meg_data.apply_function(lambda samples: samples * 1e12)

# Reference copy for the later before/after comparison plots. It is taken
# AFTER resampling and unit conversion, but BEFORE the filtering and HFC
# steps that follow.
meg_data_unprocessed = meg_data.copy()


### 2.2 Apply bandpass and notch filters:

In [None]:
# ---- 2.2 APPLY BANDPASS AND NOTCH FILTERS ----

# Filter parameters come from the preprocessing settings
BANDPASS_LOW = preproc_settings['BANDPASS_LOW']
BANDPASS_HIGH = preproc_settings['BANDPASS_HIGH']
NOTCH_FREQS = preproc_settings['NOTCH_FREQS']

print("\n" + "=" * 60, "APPLYING FILTERS", "=" * 60, sep="\n")

# Band-pass (FIR), applied in place to meg_data
print(f"Applying bandpass filter between {BANDPASS_LOW} Hz and {BANDPASS_HIGH} Hz...")
meg_data.filter(BANDPASS_LOW, BANDPASS_HIGH, method='fir', verbose=False)

# Notch filter at the configured frequencies, applied in place to meg_data
print(f"Applying notch filter at {NOTCH_FREQS} Hz...")
meg_data.notch_filter(NOTCH_FREQS, verbose=False)

print("Filtering and scaling complete.", "=" * 60, sep="\n")

### 2.3 Homogeneous Field Correction (HFC)

Next, we apply Homogeneous Field Correction (HFC) to suppress external magnetic field interference.

In [None]:
# Display the loaded preprocessing settings as the cell output
preproc_settings

In [None]:
# ---- 2.3 APPLY HOMOGENEOUS FIELD CORRECTION (HFC) ----

# HFC order comes from the preprocessing settings
HFC_ORDER = preproc_settings['HFC_ORDER']

print(
    "\n" + "=" * 60,
    "APPLYING HOMOGENEOUS FIELD CORRECTION (HFC)",
    "=" * 60,
    f"Computing HFC projectors with order={HFC_ORDER}...",
    sep="\n",
)

# Build the projectors from the sensor info, attach them and apply them to
# the data (add_proj returns the instance, so the calls are chained)
proj_hfc = compute_proj_hfc(meg_data.info, order=HFC_ORDER)
meg_data.add_proj(proj_hfc).apply_proj()

print("HFC projectors have been applied to the data.", "=" * 60, sep="\n")

### 2.4 Visual Comparison: Raw vs. Processed


In [None]:
# ---- PLOT COMPONENTS COMPARISON ----

# Time window shown in the comparison plot; these values are converted to
# sample indices with time_as_index() in the plotting cell.
START_TIME = 0  # in seconds
DURATION = 300  # in seconds
STOP_TIME = START_TIME + DURATION


In [None]:
# IPython magic: select the Qt matplotlib backend for the plots that follow
# (the inline backend is kept below as a commented-out alternative).
%matplotlib qt
# %matplotlib inline

In [None]:
# ---- PLOT SINGLE-COMPONENT COMPARISON ----
# The component is selected by PLOT_AX. The `y_`-prefixed variable names
# below are historical: they hold whichever component PLOT_AX selects.
PLOT_AX = 'z'

# NOTE(review): plot_channels_comparison is not imported in any cell shown
# above (the `from plot_functions import ...` lines in the libraries cell are
# commented out) — confirm it is in scope before running this cell.

# Unprocessed reference: sample range of the time window, then data and times
start_idx_unprocessed, stop_idx_unprocessed = meg_data_unprocessed.time_as_index(
    [START_TIME, STOP_TIME]
)
y_data_unprocessed = meg_data_unprocessed.get_data(
    picks=meg_channels[PLOT_AX],
    start=start_idx_unprocessed,
    stop=stop_idx_unprocessed,
)
time_unprocessed = meg_data_unprocessed.times[start_idx_unprocessed:stop_idx_unprocessed]

# Processed data: same window, same channels
start_idx_processed, stop_idx_processed = meg_data.time_as_index(
    [START_TIME, STOP_TIME]
)
y_data_processed = meg_data.get_data(
    picks=meg_channels[PLOT_AX],
    start=start_idx_processed,
    stop=stop_idx_processed,
)
time_processed = meg_data.times[start_idx_processed:stop_idx_processed]

# One rainbow colour per plotted channel
colors_y = plt.cm.rainbow(np.linspace(0, 1, len(meg_channels[PLOT_AX])))

plot_channels_comparison(
    time_0=time_unprocessed,
    time_1=time_processed,
    raw_channels=y_data_unprocessed,
    filtered_channels=y_data_processed,
    raw_labels=meg_channels[PLOT_AX],
    filtered_labels=meg_channels[PLOT_AX],
    colors=colors_y,
    rec_label=f"{SUB_ID} - {TASK}",
    y_label="Amplitude (pT)",
    axis_label=PLOT_AX,
    sync_ylim=False,
    show_legend=False,
)

In [None]:
# IPython magic: select the inline matplotlib backend for the plots that
# follow (the Qt backend is kept above as a commented-out alternative).
# %matplotlib qt
%matplotlib inline

In [None]:

def plot_topoBands_fromRaw(
    rawObj,
    bands_to_plot=None,
):
    """Plot one topomap of band-averaged Welch PSD per frequency band.

    For every band, a copy of ``rawObj`` is band-pass filtered to the band
    limits, the Welch PSD of the MEG channels is computed within the band,
    averaged over frequency, and drawn as a sensor topomap. All topomaps
    share one figure (one column per band), which is shown with
    ``plt.show()``.

    Parameters
    ----------
    rawObj : mne.io.Raw
        Recording to analyse. It is not modified; filtering runs on a copy.
    bands_to_plot : dict or None
        Maps a band label to ``[fmin, fmax]`` in Hz. ``None`` (default)
        selects ``{'low-freq': [4, 12], 'beta': [15, 30],
        'gamma': [65, 85]}``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``bands_to_plot`` is an empty mapping.
    """
    # Build the default per call instead of sharing one mutable default dict
    if bands_to_plot is None:
        bands_to_plot = {'low-freq': [4, 12],
                         'beta': [15, 30],
                         'gamma': [65, 85]}
    if not bands_to_plot:
        raise ValueError("bands_to_plot must contain at least one band.")

    # Keep the exact sampling rate for the PSD frequency axis; only the FFT
    # length has to be an integer. The segment spans about one second and is
    # capped at the recording length so short recordings do not fail.
    sfreq = rawObj.info['sfreq']
    n_fft = min(int(sfreq), rawObj.n_times)

    # MEG channel selection is band-independent, so compute it once. The
    # matching Info subset keeps the values and the sensor layout passed to
    # plot_topomap aligned, even if rawObj holds non-MEG or bad channels.
    picks = mne.pick_types(rawObj.info, meg=True, eeg=False)
    info_picked = mne.pick_info(rawObj.info, picks)

    # One column per requested band; squeeze=False keeps `axes` 2-D even for
    # a single band, so indexing works for any number of bands.
    n_bands = len(bands_to_plot)
    fig, axes = plt.subplots(1, n_bands, figsize=(4 * n_bands, 4),
                             squeeze=False)

    for ax, (band, fband) in zip(axes[0], bands_to_plot.items()):
        fmin, fmax = fband[0], fband[1]

        # Band-pass filter a copy of the recording
        raw_band = rawObj.copy().filter(
            fmin, fmax,
            fir_design="firwin",
        )

        # Welch PSD of the picked channels, restricted to the band
        psds, freqs = mne.time_frequency.psd_array_welch(
            raw_band.get_data(picks=picks),
            sfreq=sfreq,
            fmin=fmin,
            fmax=fmax,
            n_fft=n_fft,
        )

        # Average across the frequencies of the band -> one value per channel
        psd_mean = psds.mean(axis=-1)

        # Draw the topomap into this band's axes
        mne.viz.plot_topomap(
            psd_mean,
            info_picked,
            cmap="viridis",
            show=False,
            axes=ax,
        )
        ax.set_title(f'{band} activity ({fmin}-{fmax} Hz)',
                     size=16,)

    plt.tight_layout()

    # plt.savefig(figpath, dpi=300, facecolor='w',)

    plt.show()

In [None]:

# One row of band topomaps per field component, Z first
for AX in ('z', 'y'):
    # Copy of the original `raw` recording (not the preprocessed meg_data),
    # restricted to this component's channels
    TEMP_DAT = raw.copy().pick(meg_channels[AX])
    print(f'TOPOGRAM for {AX}:')
    plot_topoBands_fromRaw(TEMP_DAT)



## 3. Create Epochs from Events

For the `withemgacc` task, the event trigger signals were not recorded on a standard trigger channel within the `.fif` file. Instead, they were saved to a separate `.trg` file, which is only available for this specific recording session.

To ensure a self-contained and reproducible analysis, the timestamps from this `.trg` file have been manually extracted and stored within the `config_sub95.json` file under the `event_timestamps` key. The following cells load these pre-defined timestamps from the configuration to create the event markers for epoching.

In [None]:
# ---- 3.1 LOAD EVENT TIMESTAMPS FROM CONFIG ----

print(
    "\n" + "=" * 60,
    f"LOADING EVENT TIMESTAMPS FOR TASK: '{TASK}'",
    "=" * 60,
    sep="\n",
)

# Best-effort load: on any failure a message is printed and event_times is
# left as an empty array, so the following cells can skip gracefully.
try:
    # Timestamps stored in the config for the current task
    event_times = np.array(config['event_timestamps'][TASK])

    # An empty entry is reported through the generic handler below
    if not event_times.size:
        raise ValueError(f"No event timestamps found for task '{TASK}' in the config file.")

    print(
        f"Successfully loaded {len(event_times)} event timestamps from config.",
        "Event times (in seconds):",
        event_times,
        sep="\n",
    )

except KeyError:
    # Either the 'event_timestamps' section or the task entry is missing
    print(f"Error: 'event_timestamps' or task '{TASK}' not found in the config file.")
    event_times = np.array([])
except Exception as err:
    print(f"An error occurred: {err}")
    event_times = np.array([])

print("=" * 60)

In [None]:
# ---- 3.2 CREATE MNE-COMPATIBLE EVENTS ARRAY ----

print("\n" + "="*60)
print("CREATING MNE EVENTS ARRAY FROM TIMESTAMPS")
print("="*60)

if event_times.size > 0:
    # MNE requires event markers as sample indices, not seconds.
    # meg_data_unprocessed was copied from meg_data after resampling, so its
    # sampling frequency is the one of the data that gets epoched.
    sfreq_unprocessed = meg_data_unprocessed.info['sfreq']

    # Round to the nearest sample before casting. A plain astype(int)
    # truncates, so floating-point error in the product (e.g. 0.29 * 100 ->
    # 28.999999999999996) could place an event one sample too early.
    event_samples = np.round(event_times * sfreq_unprocessed).astype(int)
    print(f"Converted {len(event_times)} event timestamps (in seconds) to sample indices using sfreq={sfreq_unprocessed} Hz.")

    # NOTE(review): MNE event samples are expected to include the recording's
    # first_samp offset; if meg_data.first_samp is non-zero it may need to be
    # added here — confirm against the origin of the timestamps.

    # The MNE events array has 3 columns: [sample_index, previous_event_id, event_id].
    # We use a single event_id (1) for all our triggers.
    events_array = np.column_stack([
        event_samples,
        np.zeros_like(event_samples),
        np.ones_like(event_samples),
    ])

    print(f"\nCreated events array with shape: {events_array.shape}")
    print("This array contains ALL events.")
    print("Showing the full array for verification:")
    print(events_array)
else:
    print("Skipping MNE event array creation because no event times were loaded.")
    # Empty placeholder so the epoching cell can test events_array.size
    events_array = np.array([])

print("="*60)

In [None]:
# ---- 3.3 CREATE EPOCHS FROM EVENTS ----

print("\n" + "=" * 60, "CREATING EPOCHS FROM EVENTS ARRAY", "=" * 60, sep="\n")

if events_array.size == 0:
    # Nothing to epoch: keep `epochs` defined so later code can test for None
    print("Skipping epoch creation because no events were found.")
    epochs = None

else:
    # Epoch window and baseline from the preprocessing settings, with
    # fallbacks of -1.0 s / 3.0 s / (-1.0, 0) when a key is absent
    TMIN = preproc_settings.get('EPOCH_TMIN', -1.0)
    TMAX = preproc_settings.get('EPOCH_TMAX', 3.0)
    BASELINE = tuple(preproc_settings.get('EPOCH_BASELINE', [-1.0, 0]))

    # Cut epochs out of the preprocessed data around every event
    epochs = mne.Epochs(
        meg_data,
        events=events_array,
        tmin=TMIN,
        tmax=TMAX,
        baseline=BASELINE,
        preload=True,
        verbose=False,
    )

    print(
        f"Successfully created {len(epochs)} epochs.",
        f"Each epoch runs from {TMIN}s to {TMAX}s relative to the event.",
        f"Baseline correction was applied using the interval {BASELINE}s.",
        "=" * 60,
        sep="\n",
    )

    # Summary of the resulting Epochs object
    print("\nEpochs object info:", epochs, sep="\n")

print("=" * 60)