<a href="https://colab.research.google.com/github/fatihonay/Deep-Learning-Journey/blob/main/Parkinson_Death_Zone.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Import Libraries

In [1]:
pip install mne torch torchvision braindecode



# 2. Scan All Subject and Detect Common EEG Channels in Dataset


In [None]:
import os
import mne

dataset_root = '/content/drive/MyDrive/ds007020'

all_channel_sets = []
failed_subjects_channel_check = []

# Get a list of all subject directories that start with 'sub-'
subject_dirs = [d for d in os.listdir(dataset_root) if d.startswith('sub-') and os.path.isdir(os.path.join(dataset_root, d))]

print(f"Found {len(subject_dirs)} potential subject directories for channel check.")

for subject_id in sorted(subject_dirs): # Sort for consistent processing order
    subject_path = os.path.join(dataset_root, subject_id)

    # Look for session directories (e.g., 'ses-01')
    session_dirs = [d for d in os.listdir(subject_path) if d.startswith('ses-') and os.path.isdir(os.path.join(subject_path, d))]

    if not session_dirs:
        # print(f"Warning: No session directories found for {subject_id}. Skipping channel check.")
        failed_subjects_channel_check.append(subject_id)
        continue

    # Assuming 'ses-01' is the primary session or taking the first available session
    target_session_id = 'ses-01'
    if target_session_id in session_dirs:
        session_id = target_session_id
    elif session_dirs:
        session_id = session_dirs[0] # Use the first found session if 'ses-01' is not present
        # print(f"Warning: '{target_session_id}' not found for {subject_id}, using '{session_id}' instead for channel check.")
    else:
        # This case is already covered by the 'if not session_dirs' above, but kept for clarity
        # print(f"Warning: No valid session found for {subject_id}. Skipping channel check.")
        failed_subjects_channel_check.append(subject_id)
        continue

    session_path = os.path.join(subject_path, session_id)
    eeg_folder_path = os.path.join(session_path, 'eeg')

    if not os.path.isdir(eeg_folder_path):
        # print(f"Warning: 'eeg' directory not found for {subject_id}/{session_id}. Skipping channel check.")
        failed_subjects_channel_check.append(subject_id)
        continue

    # Find the .vhdr file within the 'eeg' directory
    vhdr_files = [f for f in os.listdir(eeg_folder_path) if f.endswith('.vhdr')]

    if not vhdr_files:
        # print(f"Warning: No .vhdr file found for {subject_id}/{session_id}. Skipping channel check.")
        failed_subjects_channel_check.append(subject_id)
        continue

    # Assuming there's only one .vhdr file per eeg directory, or taking the first one
    vhdr_file_name = vhdr_files[0]
    full_vhdr_path = os.path.join(eeg_folder_path, vhdr_file_name)

    try:
        # Load the raw EEG data (preload=False to minimize memory usage during this scan)
        raw = mne.io.read_raw_brainvision(full_vhdr_path, preload=False, verbose=False)
        all_channel_sets.append(set(raw.info['ch_names']))
        # print(f"Processed {subject_id}/{session_id}: Found {len(raw.info['ch_names'])} channels.")
    except Exception as e:
        print(f"Error processing {subject_id}/{session_id} ({full_vhdr_path}) for channels: {e}")
        failed_subjects_channel_check.append(subject_id)

print(f"\nSuccessfully collected channel information from {len(all_channel_sets)} subjects.")
print(f"Failed to collect channel information from {len(failed_subjects_channel_check)} subjects: {failed_subjects_channel_check}")

if all_channel_sets:
    # Find the intersection of all channel sets to get common channels
    common_channels = set.intersection(*all_channel_sets)
    print(f"\nCommon channels across all successfully processed subjects ({len(common_channels)} channels):")
    print(sorted(list(common_channels)))
else:
    print("No channel sets were collected. Cannot determine common channels.")


# Save common channels to a text file for later use
common_channels_list = sorted(list(common_channels))
output_file = os.path.join(dataset_root, 'common_channels.txt')
with open(output_file, 'w') as f:
    for ch in common_channels_list:
        f.write(ch + '\n')
print(f"Common channels saved to: {output_file}")


# 3. Check All Subjects with the Common EEG Channels

In [3]:
failed_subjects_set = set(failed_subjects_channel_check)
valid_subject_dirs = [s for s in subject_dirs if s not in failed_subjects_set]

print(f"Original number of subject directories: {len(subject_dirs)}")
print(f"Number of subjects that failed channel check: {len(failed_subjects_channel_check)}")
print(f"Number of valid subject directories after filtering: {len(valid_subject_dirs)}")
print(f"First 5 valid subject directories: {valid_subject_dirs[:5]}")

Original number of subject directories: 94
Number of subjects that failed channel check: 0
Number of valid subject directories after filtering: 94
First 5 valid subject directories: ['sub-PD1541', 'sub-PD1341', 'sub-PD1281', 'sub-PD1051', 'sub-PD1591']


# 4. Load and Clean Data

ASR for artifact removal process, and then perform epoching on the resting dataset.

In [None]:
import os
import numpy as np
import pandas as pd
import mne
from mne import make_fixed_length_epochs
from asrpy import ASR  # Import ASR

# --- CONFIGURATION ---
SFREQ_TARGET = 100
L_FREQ, H_FREQ = 0.5, 45.0
WINDOW_SIZE_SEC = 5
OVERLAP = 5.0  # Changed to 5.0 seconds (50% overlap) for better data augmentation
ASR_CUTOFF = 20 # Standard cutoff for artifact removal (lower = more aggressive)

# --- 1. LOAD LABELS ---
participants_path = os.path.join(dataset_root, 'participants.tsv')
df_participants = pd.read_csv(participants_path, sep='\t')

label_map = {}
for idx, row in df_participants.iterrows():
    sub_id = row['participant_id']
    status = row['survival_status']
    if status == 'living':
        label_map[sub_id] = 0
    elif status == 'deceased':
        label_map[sub_id] = 1

print(f"Loaded labels for {len(label_map)} subjects.")

# --- 2. DATA PROCESSING LOOP ---
X_list = []
y_list = []
groups_list = []

# Ensure channels are consistent
target_channels = sorted(list(common_channels))
print(f"Starting processing using {len(target_channels)} channels...")

for subject_id in sorted(subject_dirs):
    if subject_id not in label_map:
        continue

    try:
        # Construct path
        session_dirs = [d for d in os.listdir(os.path.join(dataset_root, subject_id)) if d.startswith('ses-')]
        if not session_dirs: continue
        session_id = 'ses-01' if 'ses-01' in session_dirs else session_dirs[0]

        eeg_folder = os.path.join(dataset_root, subject_id, session_id, 'eeg')
        vhdr_files = [f for f in os.listdir(eeg_folder) if f.endswith('.vhdr')]
        if not vhdr_files: continue
        full_path = os.path.join(eeg_folder, vhdr_files[0])

        # A. LOAD
        raw = mne.io.read_raw_brainvision(full_path, preload=True, verbose='ERROR')

        # B. PICK CHANNELS
        raw.pick_channels(target_channels)

        # C. FILTER & RESAMPLE
        raw.filter(L_FREQ, H_FREQ, verbose='ERROR')
        raw.resample(SFREQ_TARGET, verbose='ERROR')

        # --- NEW: ASR CLEANING (Inserted Here) ---
        # ASR needs the continuous raw data to learn what "clean" looks like.
        try:
            asr = ASR(sfreq=SFREQ_TARGET, cutoff=ASR_CUTOFF)
            asr.fit(raw)            # Learn clean statistics from this subject
            raw = asr.transform(raw) # Repair artifacts
        except Exception as asr_e:
            print(f"  ASR failed for {subject_id}, using standard data: {asr_e}")
        # -----------------------------------------

        # D. CREATE WINDOWS (EPOCHS)
        epochs = make_fixed_length_epochs(
            raw,
            duration=WINDOW_SIZE_SEC,
            overlap=OVERLAP,
            preload=True,
            verbose='ERROR'
        )

        # E. EXTRACT DATA
        data = epochs.get_data(copy=True)

        # F. APPEND
        label = label_map[subject_id]
        n_windows = data.shape[0]

        if n_windows > 0:
            X_list.append(data)
            y_list.extend([label] * n_windows)
            groups_list.extend([subject_id] * n_windows)
            print(f"Processed {subject_id}: {n_windows} windows (Label: {label})")

    except Exception as e:
        print(f"Skipping {subject_id}: {e}")

# --- 3. CONVERT TO NUMPY ---
if len(X_list) > 0:
    X = np.concatenate(X_list, axis=0)
    y = np.array(y_list)
    groups = np.array(groups_list)

    # --- 4. SCALING ---
    mean = np.mean(X, axis=(0, 2), keepdims=True)
    std = np.std(X, axis=(0, 2), keepdims=True)
    std[std == 0] = 1.0
    X_scaled = (X - mean) / (std + 1e-6)

    print("\n--- DATA LOADING COMPLETE ---")
    print(f"Final Data Shape (X): {X.shape}")
    print(f"Final Labels Shape (y): {y.shape}")
    print(f"Class Balance: {np.sum(y==0)} Living vs {np.sum(y==1)} Deceased")
else:
    print("Error: No data was processed.")

# 5. Save pre-processed Data for Later Use

In [None]:
# --- 5. SAVE DATA ---
import pickle  # Add this import at the top if not already present

output_dir = os.path.join(dataset_root, 'processed_data')
os.makedirs(output_dir, exist_ok=True)

# Save with pickle (recommended for large numpy arrays)
data_dict = {
    'X': X,
    'X_scaled': X_scaled,
    'y': y,
    'groups': groups,
    'mean': mean,
    'std': std,
    'target_channels': target_channels,
    'config': {
        'SFREQ_TARGET': SFREQ_TARGET,
        'L_FREQ': L_FREQ,
        'H_FREQ': H_FREQ,
        'WINDOW_SIZE_SEC': WINDOW_SIZE_SEC,
        'OVERLAP': OVERLAP
    }
}

pickle_file = os.path.join(output_dir, 'eeg_data.pkl')
with open(pickle_file, 'wb') as f:
    pickle.dump(data_dict, f)

print(f"Data saved to: {pickle_file}")
print(f"Filesize: {os.path.getsize(pickle_file) / (1024**2):.1f} MB")


# 6. Import Channel List and Cleaned EEG Data


In [12]:
# Import Channels List

import os

input_file = os.path.join(dataset_root, 'common_channels.txt')

with open(input_file, 'r') as f:
    common_channels = [line.strip() for line in f if line.strip()]

print(common_channels)
print(type(common_channels))  # list

# Import Data

import os
import pickle
import numpy as np

# --- LOAD PREVIOUSLY PROCESSED DATA ---
dataset_root = '/content/drive/MyDrive/ds007020'
processed_dir = os.path.join(dataset_root, 'processed_data')
pickle_file = os.path.join(processed_dir, 'eeg_data.pkl')

print("Loading processed EEG data...")
with open(pickle_file, 'rb') as f:
    data_dict = pickle.load(f)

# Extract all your variables
X = data_dict['X']
X_scaled = data_dict['X_scaled']
y = data_dict['y']
groups = data_dict['groups']
mean = data_dict['mean']
std = data_dict['std']
target_channels = data_dict['target_channels']
config = data_dict['config']

# Print loaded info
print(f"Loaded X shape: {X.shape}")
print(f"Loaded X_scaled shape: {X_scaled.shape}")
print(f"Labels: {np.sum(y==0)} Living vs {np.sum(y==1)} Deceased")
print(f"Channels: {len(target_channels)}")
print(f"SFREQ: {config['SFREQ_TARGET']} Hz")

Data scaled to mean 0, std 1.
