In [None]:
!pip install mne
!pip install mne-connectivity
!pip install pyedflib
!pip install networkx
!pip install scikit-learn
!pip install tqdm
!pip install torch-geometric

In [None]:
import mne
import numpy as np
import torch
import os
import pyedflib
import glob
import matplotlib.pyplot as plt
from google.colab import files
from google.colab import drive
from scipy.signal import butter, filtfilt
from sklearn.preprocessing import StandardScaler
from mne.preprocessing import ICA
from tqdm.notebook import tqdm
import networkx as nx
from scipy import stats
import shutil
import pandas as pd
import re
import pickle

# Mount Google Drive so the dataset paths below resolve (Colab only).
drive.mount('/content/drive')

# Root of the raw dataset. Later cells expect one sub-folder per subject that holds
# the .edf recordings (e.g. S001/S001R01.edf).
base_dir = "/content/drive/My Drive/4106 Project/eeg-motor-movementimagery-dataset-1.0.0/files"


# NOTE(review): the string literal below is disabled code for saving uploaded files
# into per-subject folders. It is never executed. If re-enabled, it would fail because
# `uploaded` (presumably the dict returned by google.colab.files.upload()) is not
# defined in this notebook. It is a candidate for deletion.
'''
output_dir = "/content/processed"                               #change i guess!!!!!!
os.makedirs(base_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

#process uploaded files
subject_pattern = re.compile(r'S(\d+)')
run_pattern = re.compile(r'R(\d+)')
uploaded_files = []

for filename, content in uploaded.items():
    if filename.endswith('.edf'):
        # Extract subject ID if present in filename
        subject_match = subject_pattern.search(filename)
        subject_id = f"S{subject_match.group(1)}" if subject_match else "unknown_subject"

        #create subject directory if it dne
        subject_dir = os.path.join(base_dir, subject_id)
        os.makedirs(subject_dir, exist_ok=True)

        #save file to the appropriate location
        file_path = os.path.join(subject_dir, filename)
        with open(file_path, 'wb') as f:
            f.write(content)

        #extract run number if present
        run_match = run_pattern.search(filename)
        run_info = f"run {run_match.group(1)}" if run_match else "unknown run"

        print(f"Saved {filename} ({subject_id}, {run_info}) to {file_path}")
        uploaded_files.append(file_path)
    else:
        print(f"Skipping {filename} - not an EEG .edf file")

'''

In [None]:
# NOTE(review): this is a disabled earlier version of bandpass_filter (0.5-40 Hz defaults),
# kept as an unused string literal. It has no effect at runtime. The active
# bandpass_filter is defined later in this cell, so this literal can be deleted.
'''
def bandpass_filter(data, sfreq, low=0.5, high=40.0, order=4):
    nyq = 0.5 * sfreq
    low /= nyq
    high /= nyq
    b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, data)
'''

def load_events(event_file):
    """Parse a whitespace-separated text file of ``<timestamp> <label> ...`` rows.

    Rows with fewer than two fields are ignored silently. Rows whose first field
    cannot be parsed as a float are reported and skipped. The file is decoded as
    latin-1, which maps every byte, so decoding never fails.

    Returns:
        list of (float, str) tuples, in file order.
    """
    parsed = []
    with open(event_file, 'r', encoding='latin-1') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if len(fields) < 2:
                continue
            try:
                parsed.append((float(fields[0]), fields[1]))
            except Exception as err:
                print(f"Skipping line due to error: {err}")
    return parsed

# Utility functions

def bandpass_filter(data, sfreq, low=0.8, high=30.0, order=4):
    """Zero-phase Butterworth band-pass filter.

    Args:
        data: signal array; ``filtfilt`` filters along its last axis.
        sfreq: sampling frequency in Hz.
        low, high: pass-band edges in Hz. After normalisation by the Nyquist
            frequency (sfreq / 2), both must fall strictly between 0 and 1.
        order: Butterworth filter order.

    Returns:
        The forward-backward filtered signal, with the same shape as ``data``.
    """
    nyquist = sfreq / 2.0
    band = [low / nyquist, high / nyquist]
    numerator, denominator = butter(order, band, btype='band')
    return filtfilt(numerator, denominator, data)

def find_edf_files(base_path, pattern="**/*.edf"):
    """Return every path under ``base_path`` that matches the glob ``pattern``.

    The search is recursive, so ``**`` matches any depth of nested directories,
    including none. Results come back in glob's order and are not sorted.
    """
    search_expr = os.path.join(base_path, pattern)
    return list(glob.iglob(search_expr, recursive=True))

def clean_channel_names(raw):
    """Strip trailing periods from every channel name (e.g. ``'Cz..'`` -> ``'Cz'``).

    The rename happens in place through ``raw.rename_channels``. The same object
    is returned so the call can be chained.
    """
    rename_map = {}
    for name in raw.info['ch_names']:
        rename_map[name] = name.rstrip('.')
    raw.rename_channels(rename_map)
    return raw

def extract_events(raw):
    """Convert the annotations of ``raw`` into an MNE events array.

    Returns:
        ``(events, event_id)`` from ``mne.events_from_annotations`` on success.
        ``(None, None)`` if that call (or unpacking its result) raises; the error
        is printed rather than re-raised.
    """
    try:
        events, event_id = mne.events_from_annotations(raw)
    except Exception as err:
        print(f"Error extracting events: {err}")
        return None, None
    return events, event_id

def apply_ica(raw, n_components=15, random_state=42, max_exclude=3, z_thresh=2.0):
    """Remove artifact components from ``raw`` with ICA, without requiring EOG channels.

    An infomax ICA is fitted on a copy of ``raw``. Components to drop are first
    sought with ``ica.find_bads_eog``. If that raises (for example because the
    recording has no EOG channel), a statistical fallback is used: a component is
    flagged when its kurtosis or variance exceeds ``mean + z_thresh * std`` across
    components, and at most ``max_exclude`` components are kept.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to clean. The pipeline passes a preloaded Raw. It is not modified.
    n_components : int
        Number of ICA components to fit.
    random_state : int
        Seed for the ICA decomposition.
    max_exclude : int
        Upper bound on the number of components removed by the statistical fallback.
    z_thresh : float
        Number of standard deviations above the mean that marks an outlier.

    Returns
    -------
    mne.io.Raw
        A cleaned copy of ``raw``.
    """
    # ica.apply works in place, so all work happens on one private copy.
    raw_cleaned = raw.copy()

    ica = ICA(n_components=n_components, random_state=random_state, method='infomax')
    ica.fit(raw_cleaned)

    try:
        eog_indices, _ = ica.find_bads_eog(raw_cleaned)
        ica.exclude = [int(i) for i in eog_indices]
        print(f"Found and excluded {len(ica.exclude)} EOG-related components")
    except Exception as e:
        # Report why the EOG route failed instead of silently swallowing it.
        print(f"EOG-based component detection unavailable ({e}); using statistics instead")
        ica.exclude = _statistical_bad_components(
            ica.get_sources(raw_cleaned).get_data(), max_exclude, z_thresh)
        print(f"No EOG channels found. Excluded {len(ica.exclude)} components based on statistics")

    ica.apply(raw_cleaned)
    return raw_cleaned


def _statistical_bad_components(sources, max_exclude=3, z_thresh=2.0):
    """Return the indices (a list of int) of ICA sources that are kurtosis or variance outliers.

    ``sources`` is indexed [component, time], as returned by
    ``ica.get_sources(...).get_data()`` for Raw input.
    """
    kurt = stats.kurtosis(sources, axis=1)
    var = np.var(sources, axis=1)

    kurt_threshold = np.mean(kurt) + z_thresh * np.std(kurt)
    var_threshold = np.mean(var) + z_thresh * np.std(var)
    bad_idx = np.where((kurt > kurt_threshold) | (var > var_threshold))[0]

    if len(bad_idx) > max_exclude:
        # Rank the outliers by kurtosis plus max-normalised variance and keep the worst.
        # Slicing from len - max_exclude (not -max_exclude) keeps max_exclude=0 correct.
        scores = kurt[bad_idx] + var[bad_idx] / np.max(var)
        bad_idx = bad_idx[np.argsort(scores)[len(scores) - max_exclude:]]

    # MNE's ica.exclude is a list of ints; avoid leaking numpy arrays or np.int64.
    return [int(i) for i in bad_idx]

def calculate_connectivity(epochs, method='pli', fmin=8, fmax=30):
    """
    Calculate a symmetric connectivity matrix for EEG channels.

    Parameters:
    -----------
    epochs : mne.Epochs
        Epoched EEG data
    method : str
        Connectivity method ('pli', 'plv', etc.)
    fmin, fmax : float
        Frequency band to use for connectivity (averaged into one value per pair)

    Returns:
    --------
    conn_matrix : ndarray
        Symmetric connectivity matrix (n_channels x n_channels) with a zero diagonal.

    Raises:
    -------
    ImportError
        If neither ``mne_connectivity`` nor the legacy ``mne.connectivity`` is available.
    """
    sfreq = epochs.info['sfreq']
    ch_names = epochs.ch_names

    try:
        # Current API: the separate mne-connectivity package.
        from mne_connectivity import spectral_connectivity_epochs

        conn = spectral_connectivity_epochs(
            epochs,
            method=method,
            mode='multitaper',
            sfreq=sfreq,
            fmin=fmin,
            fmax=fmax,
            faverage=True,
            mt_adaptive=True,
            n_jobs=1
        )
        # faverage=True collapses the band into a single frequency bin.
        conn_matrix = conn.get_data(output='dense')[:, :, 0]

    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        try:
            # Legacy API that shipped inside older MNE releases.
            from mne.connectivity import spectral_connectivity

            data = epochs.get_data()
            conn, freqs, times, n_epochs, n_tapers = spectral_connectivity(
                data,
                method=method,
                mode='multitaper',
                sfreq=sfreq,
                fmin=fmin,
                fmax=fmax,
                faverage=True,
                mt_adaptive=True,
                n_jobs=1
            )

            n_channels = len(ch_names)
            conn_matrix = conn.reshape(n_channels, n_channels)

        except ImportError:
            raise ImportError("Neither mne_connectivity nor mne.connectivity is available. "
                              "Please install with: pip install mne-connectivity")

    # With the default all-to-all indices, MNE fills only the lower triangle (i > j)
    # and leaves the upper triangle zero. Mirroring the strict lower triangle gives a
    # symmetric matrix with a zero diagonal. Without this, the graph would only get
    # edges in one direction. The result is also correct if a full matrix is returned.
    lower = np.tril(np.asarray(conn_matrix), k=-1)
    return lower + lower.T

def create_graph_from_connectivity(connectivity, threshold=0.3):
    """
    Turn a connectivity matrix into PyTorch Geometric edge tensors.

    Every off-diagonal entry that is ``>= threshold`` and strictly positive becomes
    a directed edge ``i -> j``, weighted by that entry. Edges are listed in
    row-major order. If no entry survives, a placeholder pair of edges 0 <-> 1
    with weight 0.1 is returned, so the graph is never empty. The input matrix is
    not modified.

    Parameters:
    -----------
    connectivity : ndarray
        Connectivity matrix (n_channels x n_channels)
    threshold : float
        Entries below this value are discarded

    Returns:
    --------
    edge_index : torch.Tensor
        int64 tensor of shape (2, n_edges): source row, target row
    edge_attr : torch.Tensor
        float32 tensor of shape (n_edges,) with the connectivity values
    """
    kept = connectivity.copy()
    kept[kept < threshold] = 0

    # np.nonzero scans in row-major (C) order, matching a nested i/j loop.
    rows, cols = np.nonzero(kept > 0)
    off_diagonal = rows != cols
    rows, cols = rows[off_diagonal], cols[off_diagonal]

    if rows.size == 0:
        placeholder_index = torch.tensor([[0, 1], [1, 0]]).t().contiguous()
        placeholder_attr = torch.tensor([0.1, 0.1]).float()
        return placeholder_index, placeholder_attr

    edge_index = torch.tensor(np.stack([rows, cols]), dtype=torch.long)
    edge_attr = torch.tensor(kept[rows, cols]).float()
    return edge_index, edge_attr

def map_eeg_to_brodmann(channel_names):
    """
    Look up an approximate Brodmann-area label for each EEG channel.

    Parameters:
    -----------
    channel_names : list of str
        Channel names. Matching is case-insensitive and ignores trailing periods,
        so 'Fc5.' and 'FC5' both resolve.

    Returns:
    --------
    ba_areas : list of str
        One label per input channel: the BA number followed by L/R/M
        (Left/Right/Midline), e.g. '04L'. An unknown channel comes back as its
        original name with '?' appended.
    """
    # Channel -> BA label, grouped by electrode row (10-10 style names).
    eeg_to_ba = {
        # Frontal pole / anterior frontal
        'fp1': '10L', 'fpz': '10M', 'fp2': '10R',
        'af7': '10L', 'af3': '09L', 'afz': '09M', 'af4': '09R', 'af8': '10R',
        # Frontal
        'f7': '45L', 'f5': '45L', 'f3': '08L', 'f1': '08L', 'fz': '08M',
        'f2': '08R', 'f4': '08R', 'f6': '45R', 'f8': '45R',
        # Fronto-central / fronto-temporal
        'ft7': '44L', 'fc5': '06L', 'fc3': '06L', 'fc1': '06L', 'fcz': '06M',
        'fc2': '06R', 'fc4': '06R', 'fc6': '06R', 'ft8': '44R',
        # Central / temporal
        't7': '22L', 'c5': '04L', 'c3': '04L', 'c1': '04L', 'cz': '04M',
        'c2': '04R', 'c4': '04R', 'c6': '04R', 't8': '22R',
        # Centro-parietal / temporo-parietal
        'tp7': '21L', 'cp5': '40L', 'cp3': '40L', 'cp1': '05L', 'cpz': '05M',
        'cp2': '05R', 'cp4': '40R', 'cp6': '40R', 'tp8': '21R',
        # Parietal
        'p7': '39L', 'p5': '39L', 'p3': '07L', 'p1': '07L', 'pz': '07M',
        'p2': '07R', 'p4': '07R', 'p6': '39R', 'p8': '39R',
        # Parieto-occipital
        'po7': '19L', 'po3': '19L', 'poz': '19M', 'po4': '19R', 'po8': '19R',
        # Occipital
        'o1': '17L', 'oz': '17M', 'o2': '17R',
    }

    return [eeg_to_ba.get(name.lower().rstrip('.'), name + '?') for name in channel_names]

def normalize_data(data, method='minmax'):
    """Normalize data with one of several fixed schemes.

    Methods
    -------
    'minmax' : fixed affine map ``(data + 100) / 200``. This is not a true min-max;
               it sends [-100, 100] to [0, 1], so it suits amplitudes on a ±100
               scale such as microvolts (TODO confirm against the reference code).
    'z'      : global z-score over the whole array.
    's'      : per-channel z-score along the last (time) axis.
    other    : ``data`` is returned unchanged (the same object).

    A zero standard deviation (a flat signal) is treated as 1. The result is then
    all zeros instead of NaN/inf from a division by zero.
    """
    if method == 'minmax':
        return (data + 100) / 200  # similar to github
    if method == 'z':
        std = np.std(data)
        return (data - np.mean(data)) / (std if std > 0 else 1.0)
    if method == 's':
        mean = np.mean(data, axis=-1, keepdims=True)
        std = np.std(data, axis=-1, keepdims=True)
        return (data - mean) / np.where(std > 0, std, 1.0)
    return data

def visualize_connectivity(connectivity, ch_names, title="Brain Connectivity"):
    """Draw ``connectivity`` as a heatmap (left) and a circular-layout graph (right).

    Args:
        connectivity: square matrix (n_channels x n_channels) of connection strengths.
        ch_names: channel labels used for the heatmap ticks.
        title: figure-level title.

    Returns:
        The matplotlib Figure holding both panels.
    """
    fig, (ax_matrix, ax_graph) = plt.subplots(1, 2, figsize=(15, 6))

    heatmap = ax_matrix.imshow(connectivity, cmap='viridis')
    ax_matrix.set_title("Connectivity Matrix")

    # With more than 20 channels, label roughly every tenth one so the ticks stay legible.
    n_channels = len(ch_names)
    if n_channels > 20:
        tick_positions = np.arange(0, n_channels, n_channels // 10)
    else:
        tick_positions = np.arange(n_channels)
    tick_labels = [ch_names[i] for i in tick_positions]
    ax_matrix.set_xticks(tick_positions)
    ax_matrix.set_yticks(tick_positions)
    ax_matrix.set_xticklabels(tick_labels, rotation=90)
    ax_matrix.set_yticklabels(tick_labels)
    fig.colorbar(heatmap, ax=ax_matrix)

    graph = nx.from_numpy_array(connectivity)
    nx.draw_networkx(graph, pos=nx.circular_layout(graph), ax=ax_graph, with_labels=True,
                     node_color='lightblue', node_size=500,
                     font_size=10, font_weight='bold')
    ax_graph.set_title("Connectivity Graph")

    fig.tight_layout()
    fig.suptitle(title, fontsize=16)
    fig.subplots_adjust(top=0.85)  # leave room for the suptitle
    return fig

In [None]:
# Inspect one recording (subject S001, run R01) before running the batch pipeline.
edf_path = os.path.join(
    "/content/drive/My Drive/4106 Project/eeg-motor-movementimagery-dataset-1.0.0/files",
    "S001",
    "S001R01.edf",
)
raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)

print("Channels:", raw.ch_names)
print("Annotations:", raw.annotations)  # empty if the file carries no annotations

# Annotation descriptions -> integer event codes, as used for epoching later.
events, event_id = mne.events_from_annotations(raw)
print("Event dictionary:", event_id)
print("Events array:\n", events)

In [None]:
# 4. Preprocess every subject: select the motor runs, band-pass filter, remove
#    artifacts with ICA, epoch around events and build a connectivity graph.

# Runs to keep (the original notes describe these as the motor imagery recordings).
valid_runs = ['R03', 'R04', 'R07', 'R08', 'R11', 'R12']
tmin, tmax = 0.0, 2.0  # seconds for each epoch
connectivity_methods = ['pli']  # Phase Lag Index - can add 'plv', 'wpli' etc.
normalization_method = 'minmax'  # Same as in official implementation
batch_size = 10  # subjects per saved batch file

base_dir = "/content/drive/My Drive/4106 Project/eeg-motor-movementimagery-dataset-1.0.0/files"
output_dir = "/content/drive/My Drive/4106 Project/eeg_data"

# The run ID must sit right before the extension (e.g. "S001R03.edf"). A plain
# substring test could match elsewhere in the file name.
valid_run_suffixes = tuple(f"{run}.edf" for run in valid_runs)

# Sorted because os.listdir order is arbitrary, and the seeded subject split in the
# next cell depends on the order of all_subjects_data.
subject_folders = sorted(
    d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
)
print(f"Found {len(subject_folders)} subject folders")

all_subjects_data = []
os.makedirs(output_dir, exist_ok=True)
n_batches = (len(subject_folders) + batch_size - 1) // batch_size

for batch_idx in range(0, len(subject_folders), batch_size):
    batch_number = batch_idx // batch_size + 1
    batch_folders = subject_folders[batch_idx:batch_idx + batch_size]
    batch_subjects = []  # subjects from this batch only
    print(f"Processing batch {batch_number}/{n_batches}")

    for subject_folder in tqdm(batch_folders):
        subject_path = os.path.join(base_dir, subject_folder)
        subject_data = {
            'subject_id': subject_folder,
            'epochs_data': [],
            'labels': [],
            'connectivity': [],
            'edge_indices': [],
            'edge_attrs': []
        }

        edf_files = sorted(f for f in os.listdir(subject_path) if f.endswith(valid_run_suffixes))

        for edf_file in edf_files:
            edf_path = os.path.join(subject_path, edf_file)
            try:
                raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
                raw = clean_channel_names(raw)

                # NOTE(review): event IDs come from each file's own annotations. Confirm
                # that every kept run has the same annotation set, so IDs mean the same
                # class in every file.
                events, event_id = extract_events(raw)
                if events is None or len(events) == 0:
                    print(f"No events found in {edf_file}")
                    continue

                raw.pick_types(eeg=True)
                raw.filter(0.5, 40.0, fir_design='firwin')
                raw_cleaned = apply_ica(raw)

                epochs = mne.Epochs(raw_cleaned, events, event_id=event_id,
                                    tmin=tmin, tmax=tmax, baseline=None,
                                    preload=True, verbose=False)

                # MNE returns EEG in volts, but 'minmax' maps ±100 onto [0, 1], i.e. it
                # expects microvolts. Without the 1e6 factor every sample lands at about
                # 0.5 and the signal is lost once cast to float32.
                epochs_data_norm = normalize_data(epochs.get_data() * 1e6,
                                                  method=normalization_method)
                labels = epochs.events[:, -1]

                for method in connectivity_methods:
                    connectivity = calculate_connectivity(epochs, method=method)
                    edge_index, edge_attr = create_graph_from_connectivity(connectivity)

                    # One entry per (run, method), so all lists stay index-aligned.
                    subject_data['epochs_data'].append(epochs_data_norm)
                    subject_data['labels'].append(labels)
                    subject_data['connectivity'].append(connectivity)
                    subject_data['edge_indices'].append(edge_index)
                    subject_data['edge_attrs'].append(edge_attr)

                print(f"Processed {edf_file}: {len(epochs)} samples")

            except Exception as e:
                print(f"Failed to process {edf_file}: {e}")

        if subject_data['epochs_data']:
            batch_subjects.append(subject_data)

    all_subjects_data.extend(batch_subjects)

    # Write only this batch's subjects. Writing the cumulative list made batch file N
    # repeat every earlier subject. The file holds an object array of dicts; load it
    # with np.load(..., allow_pickle=True).
    batch_filename = os.path.join(output_dir, f"batch_{batch_number}.npz")
    np.savez_compressed(batch_filename, batch_data=np.array(batch_subjects, dtype=object))
    print(f"Saved batch {batch_number} with {len(batch_subjects)} subjects")

print(f"Processed {len(all_subjects_data)} subjects in total")

In [None]:
# 5. Split data by subjects to prevent data leakage.
# The cell shuffles a seeded copy of the subject order instead of all_subjects_data
# itself. Shuffling the list in place meant a re-run shuffled an already-shuffled
# list, so every run gave a different split. np.random.permutation(n) draws the same
# swap sequence as np.random.shuffle, so the split equals the in-place version's
# first run.
np.random.seed(42)  # For reproducibility
subject_order = np.random.permutation(len(all_subjects_data))
shuffled_subjects = [all_subjects_data[i] for i in subject_order]

# Calculate split indices
train_ratio, val_ratio = 0.7, 0.15  # 70% train, 15% val, 15% test
n_subjects = len(shuffled_subjects)
train_idx = int(n_subjects * train_ratio)
val_idx = train_idx + int(n_subjects * val_ratio)

# Split data (each subject lands in exactly one split)
train_data = shuffled_subjects[:train_idx]
val_data = shuffled_subjects[train_idx:val_idx]
test_data = shuffled_subjects[val_idx:]
assert len(train_data) + len(val_data) + len(test_data) == n_subjects

print(f"Split data: {len(train_data)} train, {len(val_data)} validation, {len(test_data)} test subjects")

In [None]:
# 6. Prepare PyTorch Geometric data
from torch_geometric.data import Data, Dataset

def prepare_geometric_data(subject_data, label_map=None):
    """
    Convert processed subject data into PyTorch Geometric ``Data`` objects, one per epoch.

    Args:
        subject_data: dict with index-aligned lists 'epochs_data' (arrays indexed
            [epoch, channel, time]), 'labels', 'edge_indices' and 'edge_attrs',
            with one entry per recording/connectivity method.
        label_map: optional dict mapping event IDs to class labels. When given,
            every event ID is translated before y, s and d are built; an unmapped
            ID raises KeyError. When None, the raw event IDs are used.

    Returns:
        dataset_list: list of ``Data`` objects. Each has x of shape
        (n_channels, n_times) as float32, y/s/d as int64 tensors of shape (1,),
        and the graph (edge_index, edge_attr) of its recording.
    """
    dataset_list = []

    for i in range(len(subject_data['epochs_data'])):
        epochs = subject_data['epochs_data'][i]
        labels = subject_data['labels'][i]
        edge_index = subject_data['edge_indices'][i]
        edge_attr = subject_data['edge_attrs'][i]

        for j in range(epochs.shape[0]):
            x = torch.from_numpy(epochs[j]).float()  # (n_channels, n_times)

            label = int(labels[j])
            if label_map is not None:
                label = int(label_map[label])
            y = torch.tensor([label]).long()

            # NOTE(review): s and d are derived arithmetically from the label. With raw
            # event IDs (assumed 1/2/3 for rest/T1/T2 — confirm), label % 2 puts rest and
            # T2 in the same class, so s is not a clean left-vs-right target. Pass a
            # label_map to control the encoding.
            s = torch.tensor([label % 2]).long()  # Binary task (left vs right)
            d = torch.tensor([label % 3]).long()  # 3-class task

            dataset_list.append(
                Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, s=s, d=d))

    return dataset_list

# Flatten each split's subjects into a single list of per-epoch graphs.
train_dataset = [graph for subject in train_data for graph in prepare_geometric_data(subject)]
val_dataset = [graph for subject in val_data for graph in prepare_geometric_data(subject)]
test_dataset = [graph for subject in test_data for graph in prepare_geometric_data(subject)]

print(f"Prepared datasets: {len(train_dataset)} train, {len(val_dataset)} validation, {len(test_dataset)} test samples")

In [None]:
# 7. Save datasets in a format compatible with the official implementation
torch.save(train_dataset, os.path.join(output_dir, "train_dataset_id"))
torch.save(val_dataset, os.path.join(output_dir, "val_dataset_id"))
torch.save(test_dataset, os.path.join(output_dir, "test_dataset_id"))

if len(train_dataset) > 0:
    # Graph of the first training sample only; each recording has its own graph.
    edge_index = train_dataset[0].edge_index.numpy()
    # '%d' writes plain integers instead of the default '%.18e' floats.
    np.savetxt(os.path.join(output_dir, "edge_index.txt"), edge_index, fmt='%d')
    print("Saved edge index to file")

    try:
        n_nodes = train_dataset[0].x.shape[0]

        # The processed dataset does not keep channel names. Placeholder names
        # ("Ch0", ...) would map to unknown Brodmann areas, so the real labels are
        # recovered from one recording header. Only the header is read (preload=False).
        ch_names = None
        header_files = sorted(find_edf_files(base_dir, pattern="*/*.edf"))
        if header_files:
            header = mne.io.read_raw_edf(header_files[0], preload=False, verbose=False)
            header = clean_channel_names(header)
            eeg_picks = mne.pick_types(header.info, eeg=True)
            eeg_names = [header.ch_names[i] for i in eeg_picks]
            if len(eeg_names) == n_nodes:
                ch_names = eeg_names
        if ch_names is None:
            print("Could not recover channel names; using placeholder names")
            ch_names = ['Ch' + str(i) for i in range(n_nodes)]

        brodmann_labels = map_eeg_to_brodmann(ch_names)

        metadata = {
            'channel_names': ch_names,
            'brodmann_labels': brodmann_labels
        }

        with open(os.path.join(output_dir, "metadata.pkl"), 'wb') as f:
            pickle.dump(metadata, f)

        print("Saved metadata")
    except Exception as e:
        print(f"Failed to save metadata: {e}")

In [None]:
# 8. Zip the processed outputs (and offer a download when running in Colab).
# The archive goes next to output_dir, not inside it. Archiving a directory that
# contains the zip (being written, or left over from an earlier run) can pack the
# archive into itself.
zip_base = os.path.join(os.path.dirname(output_dir), "eeg_gnn_data")
try:
    zip_path = shutil.make_archive(zip_base, 'zip', output_dir)
    print(f"Created zip archive at {zip_path}")
except Exception as e:
    zip_path = None
    print(f"Failed to create zip archive: {e}")

if zip_path is not None:
    try:
        from google.colab import files
        files.download(zip_path)
    except Exception as e:
        # Outside Colab, or if the browser download fails, the archive stays on disk.
        print(f"Download skipped: {e}")

In [None]:
# 9. Visualize one training sample: its weighted connectivity graph and a few signals.
if len(train_dataset) > 0:
    sample_data = train_dataset[0]
    edge_index = sample_data.edge_index.numpy()
    n_nodes = sample_data.x.shape[0]

    # Plot the thresholded connectivity weights rather than bare 0/1 presence.
    # Unit weights are used if the sample carries no edge attributes.
    if sample_data.edge_attr is not None:
        edge_weights = sample_data.edge_attr.numpy()
    else:
        edge_weights = np.ones(edge_index.shape[1])

    # Symmetric adjacency for display. Edge weights are positive, so max() mirrors
    # each edge without dropping the stronger direction.
    adj_matrix = np.zeros((n_nodes, n_nodes))
    adj_matrix[edge_index[0], edge_index[1]] = edge_weights
    adj_matrix = np.maximum(adj_matrix, adj_matrix.T)

    # Placeholder labels: the dataset does not store real channel names.
    ch_names = [f'Ch{i}' for i in range(n_nodes)]

    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(adj_matrix, cmap='viridis')
    fig.colorbar(im, ax=ax, label='Edge weight')
    ax.set_xticks(np.arange(n_nodes))
    ax.set_xticklabels(ch_names, rotation=90, fontsize=6)
    ax.set_yticks(np.arange(n_nodes))
    ax.set_yticklabels(ch_names, fontsize=6)
    ax.set_title("Connectivity Matrix (thresholded edge weights)")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "sample_connectivity_matrix.png"))
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 6))
    for i in range(min(5, n_nodes)):
        signal = sample_data.x[i].numpy()
        ax.plot(signal - np.mean(signal), label=f'Channel {i}')  # remove DC offset
    ax.set_title("Sample EEG Signals (First 5 Channels)")
    ax.set_xlabel("Time (samples)")
    ax.set_ylabel("Normalized amplitude (DC removed)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "sample_signals.png"))
    plt.show()
