# Epilepsy Data Exploration

## Load Libraries

In [None]:
import os
import re
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# MNE Libraries
import mne
from mne import Epochs, pick_types, events_from_annotations
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.preprocessing import ICA, create_eog_epochs, create_ecg_epochs

# Scipy libraries
from scipy import signal
from scipy.spatial.distance import euclidean
from scipy.stats import pearsonr

## Data Loading Function

In [None]:
def load_metadata_from_json(data_dir='data'):
    """Collect patient- and seizure-level metadata from every JSON file under *data_dir*.

    Files are found recursively (``**/*.json``) and read in sorted path order.
    Each file contributes one patient row and one seizure row per entry of its
    ``seizures`` list. If a file fails to load or lacks a required key, the
    error is printed and the rest of that file is skipped. Rows appended before
    the failure are kept.

    Seizure rows start with ``patient_id``, ``sampling_rate_hz`` and
    ``seizure_number``, followed by every key of the seizure entry. When the
    file has a ``file_name`` and the seizure entry does not, the recording's
    ``file_name`` and registration start/end times are copied onto the row.

    Args:
        data_dir: Root directory (str or path-like) to search.

    Returns:
        ``(seizures_df, patients_df)`` as pandas DataFrames.
    """
    root = Path(data_dir)
    seizure_rows = []
    patient_rows = []

    paths = sorted(root.glob('**/*.json'))
    print(f"Found {len(paths)} JSON files")

    for path in paths:
        try:
            with open(path, 'r') as handle:
                payload = json.load(handle)

            pid = payload['patient_id']
            rate = payload['sampling_rate_hz']

            patient_row = {
                'patient_id': pid,
                'sampling_rate_hz': rate,
                'num_channels': len(payload['channels']),
                'json_file_path': str(path),
            }
            has_file = 'file_name' in payload
            if has_file:
                patient_row['file_name'] = payload['file_name']
                patient_row['registration_start_time'] = payload.get('registration_start_time')
                patient_row['registration_end_time'] = payload.get('registration_end_time')
            patient_rows.append(patient_row)

            for entry in payload['seizures']:
                row = {
                    'patient_id': pid,
                    'sampling_rate_hz': rate,
                    'seizure_number': entry['seizure_number'],
                }
                # Seizure-specific keys override/extend the defaults above
                row.update(entry)

                # Inherit recording-level file info only when the seizure names no file itself
                if has_file and 'file_name' not in entry:
                    row['file_name'] = patient_row['file_name']
                    row['registration_start_time'] = patient_row['registration_start_time']
                    row['registration_end_time'] = patient_row['registration_end_time']

                seizure_rows.append(row)

        except Exception as err:
            print(f"Error loading {path}: {err}")
            continue

    return pd.DataFrame(seizure_rows), pd.DataFrame(patient_rows)

## Data Processing Utilities

In [None]:
def fill_missing_values(df):
    """Replace missing values in *df* in place and return the same object.

    - object columns: NaN/None -> "N/A"
    - numeric columns (``np.number``): +inf, -inf, '' and None are first mapped
      to NaN, then every NaN becomes 0.

    Columns of any other dtype (bool, datetime, category, ...) are left untouched.
    """
    object_columns = df.select_dtypes(include=['object']).columns
    df[object_columns] = df[object_columns].fillna("N/A")

    number_columns = df.select_dtypes(include=[np.number]).columns
    sentinels = [np.inf, -np.inf, '', None]
    cleaned = df[number_columns].replace(sentinels, np.nan)
    df[number_columns] = cleaned.fillna(0)

    return df

In [None]:
# Downcast wide numeric columns (float64/int64/int32) to narrower dtypes to save memory
def _fits_integer_range(series, target_dtype):
    """Return True when every value of *series* is representable as *target_dtype* (empty fits trivially)."""
    if series.empty:
        return True
    bounds = np.iinfo(target_dtype)
    return bounds.min <= series.min() and series.max() <= bounds.max


def simplify_dtypes(df):
    """Return a copy of *df* with numeric columns downcast to narrower dtypes.

    Each column is considered once, based on its current dtype:
      - float64 -> float32 (always; precision drops to ~7 significant digits)
      - int64   -> int32 when every value fits in the int32 range
      - int32   -> int16 when every value fits in the int16 range
    All other dtypes are left unchanged. A message is printed per decision on
    int64 columns and per converted float64/int32 column.

    Args:
        df: Input DataFrame (not modified).

    Returns:
        A new DataFrame with the downcast columns.
    """
    df_simplified = df.copy()

    for column in df_simplified.columns:
        series = df_simplified[column]
        current_dtype = str(series.dtype)

        if current_dtype == 'float64':
            df_simplified[column] = series.astype('float32')
            print(f"Converted {column}: {current_dtype} -> float32")

        elif current_dtype == 'int64':
            # Range check first: silently wrapping out-of-range ints would corrupt data
            if _fits_integer_range(series, np.int32):
                df_simplified[column] = series.astype('int32')
                print(f"Converted {column}: {current_dtype} -> int32")
            else:
                print(f"Kept {column} as {current_dtype} (values too large for int32)")

        elif current_dtype == 'int32':
            if _fits_integer_range(series, np.int16):
                df_simplified[column] = series.astype('int16')
                print(f"Converted {column}: {current_dtype} -> int16")

    return df_simplified

## PSD Feature Extraction

In [None]:
def extract_psd_features(raw, fmin=0.5, fmax=50):
    """
    Extract summary power-spectral-density (PSD) features from an EEG recording.

    The PSD is estimated per channel with ``raw.compute_psd(method='multitaper')``
    restricted to [fmin, fmax] Hz, then summarised across channels.

    Parameters
    ----------
    raw : mne.io.Raw (or any object exposing the same ``compute_psd`` API)
    fmin, fmax : float
        Frequency range passed to ``compute_psd``.

    Returns
    -------
    dict
        For each band in delta/theta/alpha/beta/gamma (half-open [low, high) Hz):
          - ``{band}_power_{mean,std,median,max,min}``: statistics across channels
            of each channel's mean PSD inside the band.
          - ``{band}_relative_power``: fraction of the total spectral power (summed
            over all channels and all returned frequency bins) that lies in the
            band. NaN when the total power is zero.
        Plus ``total_power_{mean,std}``, ``peak_frequency``/``peak_power`` of the
        channel-averaged PSD, ``spectral_entropy_{mean,std}`` (per-channel
        Shannon entropy of the normalised PSD), and ``spectral_edge_95`` (lowest
        frequency below which 95 % of the channel-averaged power lies; NaN when
        the total power is zero).
    """
    psd = raw.compute_psd(method='multitaper', fmin=fmin, fmax=fmax, verbose=False)
    # psds is indexed as (channel, frequency bin); freqs gives each bin's frequency
    psds, freqs = psd.get_data(return_freqs=True)

    bands = {
        'delta': (0.5, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 30),
        'gamma': (30, 50)
    }

    features = {}
    band_masks = {}

    for band_name, (low_freq, high_freq) in bands.items():
        # Half-open interval so adjacent bands never share a bin
        freq_mask = (freqs >= low_freq) & (freqs < high_freq)
        band_masks[band_name] = freq_mask

        # Mean PSD inside the band, one value per channel
        band_power = np.mean(psds[:, freq_mask], axis=1)

        features[f'{band_name}_power_mean'] = np.mean(band_power)
        features[f'{band_name}_power_std'] = np.std(band_power)
        features[f'{band_name}_power_median'] = np.median(band_power)
        features[f'{band_name}_power_max'] = np.max(band_power)
        features[f'{band_name}_power_min'] = np.min(band_power)

    total_power = np.mean(psds, axis=1)
    features['total_power_mean'] = np.mean(total_power)
    features['total_power_std'] = np.std(total_power)

    # Relative power must be a ratio of *sums*. A ratio of per-bin means is
    # scaled by (n_total_bins / n_band_bins); on a flat spectrum it would report
    # 1.0 for every band.
    grand_total = psds.sum()
    for band_name, freq_mask in band_masks.items():
        if grand_total > 0:
            features[f'{band_name}_relative_power'] = psds[:, freq_mask].sum() / grand_total
        else:
            features[f'{band_name}_relative_power'] = np.nan

    mean_psd = np.mean(psds, axis=0)
    peak_idx = np.argmax(mean_psd)
    features['peak_frequency'] = freqs[peak_idx]
    features['peak_power'] = mean_psd[peak_idx]

    # A channel with zero total power yields NaN entropy rather than a warning storm
    with np.errstate(invalid='ignore', divide='ignore'):
        psd_norm = psds / psds.sum(axis=1, keepdims=True)
        spectral_entropy = -np.sum(psd_norm * np.log(psd_norm + 1e-15), axis=1)
    features['spectral_entropy_mean'] = np.mean(spectral_entropy)
    features['spectral_entropy_std'] = np.std(spectral_entropy)

    # Spectral edge frequency (95 % of cumulative channel-averaged power)
    cumsum_psd = np.cumsum(mean_psd)
    if cumsum_psd.size and cumsum_psd[-1] > 0:
        cumsum_psd = cumsum_psd / cumsum_psd[-1]
        # argmax on a boolean array returns the first True; the last entry is 1.0 so one exists
        features['spectral_edge_95'] = freqs[np.argmax(cumsum_psd >= 0.95)]
    else:
        features['spectral_edge_95'] = np.nan

    return features

def extract_psd_features_by_region(raw, fmin=0.5, fmax=50):
    """
    Extract PSD features separately for each scalp region.

    A channel joins a region when its name, after upper-casing and removing
    the substring 'EEG' and every '-', *contains* one of the region's 10-20
    labels. This is a substring match, so a bipolar name such as 'FP1-F7' is
    assigned by whichever label it contains. Regions with no matching channel
    are skipped.

    For each region, ``extract_psd_features`` runs on a copy of *raw* reduced
    to that region's channels. Every resulting key is prefixed with
    ``'{region}_'``. An error inside a region is printed and only that region
    is skipped.
    """
    region_labels = {
        'frontal': ['Fp1', 'Fp2', 'F3', 'F4', 'F7', 'F8', 'Fz'],
        'central': ['C3', 'C4', 'Cz'],
        'parietal': ['P3', 'P4', 'Pz'],
        'occipital': ['O1', 'O2'],
        'temporal': ['T3', 'T4', 'T5', 'T6']
    }

    def _normalise(name):
        # Canonical form used for label matching
        return name.upper().replace('EEG', '').replace('-', '').strip()

    features_by_region = {}
    channel_names = raw.ch_names

    for region, labels in region_labels.items():
        wanted = [label.upper() for label in labels]
        members = [
            ch for ch in channel_names
            if any(label in _normalise(ch) for label in wanted)
        ]
        if not members:
            continue

        try:
            subset = raw.copy().pick(members)
            per_region = extract_psd_features(subset, fmin, fmax)
        except Exception as exc:
            print(f"Could not process {region} region: {exc}")
            continue

        features_by_region.update(
            {f'{region}_{key}': value for key, value in per_region.items()}
        )

    return features_by_region

## Propagation Speed Calculation

In [None]:
def detect_channel_onsets(data, sfreq, threshold=2.0, window_size=1.0,
                         seizure_start=None, seizure_end=None):
    """Estimate, per channel, the time (s) at which activity first rises above baseline.

    Each channel's amplitude envelope (|Hilbert transform|) is smoothed with a
    cubic Savitzky-Golay filter of ``int(window_size * sfreq)`` samples, but
    only when that window is longer than 3 samples. The detection level is
    ``mean(reference) + threshold * std(reference)``. The reference is the
    envelope up to 1 s before ``seizure_start`` when ``seizure_start > 10``;
    otherwise it is the first 10 s.

    - With ``seizure_start``: the onset is the first sample above the level
      inside [seizure_start - 5 s, seizure_start + 10 s).
    - Without it: the onset is the first above-level sample whose following
      window (of the same length) fits in the signal and also has a mean
      envelope above the level.

    Parameters
    ----------
    data : ndarray indexed as (channel, sample)
    sfreq : float, sampling rate in Hz
    threshold : float, number of reference standard deviations
    window_size : float, smoothing / confirmation window in seconds
    seizure_start, seizure_end : float or None, seconds from the start of *data*
        (``seizure_end`` is accepted but not used)

    Returns
    -------
    dict
        channel index -> onset time in seconds. Channels without a detected onset
        or with an empty reference segment are omitted.
    """
    _, total_samples = data.shape
    win = int(window_size * sfreq)
    onsets = {}

    for idx, trace in enumerate(data):
        env = np.abs(signal.hilbert(trace))
        if win > 3:
            env = signal.savgol_filter(env, win, 3)

        # Reference period: up to 1 s before a late-enough seizure, else the first 10 s
        late_start = seizure_start is not None and seizure_start > 10
        stop = int((seizure_start - 1) * sfreq) if late_start else int(10 * sfreq)
        reference = env[:stop]
        if len(reference) == 0:
            continue
        level = np.mean(reference) + threshold * np.std(reference)

        if seizure_start is not None:
            lo = max(0, int((seizure_start - 5) * sfreq))
            hi = min(total_samples, int((seizure_start + 10) * sfreq))
            hits = np.flatnonzero(env[lo:hi] > level)
            if hits.size:
                onsets[idx] = (lo + hits[0]) / sfreq
            continue

        # No annotation: require the crossing to be sustained over one full window
        for start in np.flatnonzero(env > level):
            if start + win >= total_samples:
                continue
            if np.mean(env[start:start + win]) > level:
                onsets[idx] = start / sfreq
                break

    return onsets

def estimate_electrode_positions(ch_names):
    """Map each channel index to approximate 2-D scalp coordinates (10-20 system).

    Each name is upper-cased and stripped of '-' and the substring 'EEG'. The
    first label in the table below (in insertion order) that the cleaned name
    contains supplies the position. Coordinates are unitless and roughly within
    [-1, 1]; odd-numbered (left) electrodes have negative x, and frontal
    electrodes have positive y.

    Both the classic labels T3/T4/T5/T6 and their 10-10 equivalents T7/T8/P7/P8
    are recognised. The aliases come last in the table, so they only affect
    names that matched no classic label. This matters for bipolar CHB-MIT-style
    names such as 'T7-P7'.

    Channels matching no label are placed on the unit circle at angle
    ``2*pi*idx/len(ch_names)``, so every index receives a position.

    Returns
    -------
    dict
        channel index -> (x, y)
    """
    standard_positions = {
        'FP1': (-0.3, 0.9), 'FP2': (0.3, 0.9),
        'F3': (-0.5, 0.6), 'F4': (0.5, 0.6),
        'F7': (-0.8, 0.5), 'F8': (0.8, 0.5),
        'C3': (-0.5, 0), 'C4': (0.5, 0),
        'T3': (-0.9, 0), 'T4': (0.9, 0),
        'T5': (-0.8, -0.5), 'T6': (0.8, -0.5),
        'P3': (-0.5, -0.6), 'P4': (0.5, -0.6),
        'O1': (-0.3, -0.9), 'O2': (0.3, -0.9),
        'FZ': (0, 0.7), 'CZ': (0, 0), 'PZ': (0, -0.7),
        # 10-10 (modified combinatorial) names for the same sites: T7=T3, T8=T4, P7=T5, P8=T6.
        # Kept last so they never override a match on the labels above.
        'T7': (-0.9, 0), 'T8': (0.9, 0),
        'P7': (-0.8, -0.5), 'P8': (0.8, -0.5),
    }

    positions = {}
    for idx, ch_name in enumerate(ch_names):
        ch_clean = ch_name.upper().replace('-', '').replace('EEG', '').strip()

        for std_name, pos in standard_positions.items():
            if std_name in ch_clean:
                positions[idx] = pos
                break
        else:
            # Unknown electrode: spread it on the unit circle so it still has coordinates
            angle = 2 * np.pi * idx / len(ch_names)
            positions[idx] = (np.cos(angle), np.sin(angle))

    return positions

def calculate_propagation_speeds(onset_times, ch_names):
    """Summarise how quickly onset spreads between successively activated channels.

    Channels are ordered by onset time. For each consecutive pair, the onset
    delay (s) is always recorded. When both channels have an estimated position
    and the delay is strictly positive, a speed (distance / delay) is also
    recorded. Distance is the Euclidean distance between
    ``estimate_electrode_positions`` coordinates multiplied by 100. Those
    coordinates are normalised, so the "mm" unit is only approximate.

    Parameters
    ----------
    onset_times : dict, channel index -> onset time in seconds
    ch_names : sequence of channel names (indexed consistently with onset_times)

    Returns
    -------
    dict
        Empty when fewer than two onsets are given. Otherwise contains speed
        statistics (only if at least one speed was computed) and the mean/max
        onset delay.
    """
    if len(onset_times) < 2:
        return {}

    ordered = sorted(onset_times.items(), key=lambda item: item[1])
    coords = estimate_electrode_positions(ch_names)

    delays = []
    speeds = []
    for (first_ch, first_t), (second_ch, second_t) in zip(ordered, ordered[1:]):
        gap = second_t - first_t
        delays.append(gap)

        if first_ch not in coords or second_ch not in coords:
            continue
        dist = euclidean(coords[first_ch], coords[second_ch]) * 100  # normalised units x 100 (~mm)
        if gap > 0:
            speeds.append(dist / gap)

    summary = {}
    if speeds:
        summary.update({
            'mean_propagation_speed': np.mean(speeds),
            'median_propagation_speed': np.median(speeds),
            'std_propagation_speed': np.std(speeds),
            'max_propagation_speed': np.max(speeds),
            'min_propagation_speed': np.min(speeds),
            'num_propagation_events': len(speeds),
        })

    if delays:
        summary['mean_onset_delay'] = np.mean(delays)
        summary['max_onset_delay'] = np.max(delays)

    return summary

In [None]:
# spike and sharp-wave detection
def detect_spike_sharpwave_events(data, sfreq, threshold=2.0, min_duration_ms=20, max_duration_ms=250):
    """
    Count spike-like and sharp-wave-like transients per channel.

    For each channel:
    1. Compute the amplitude envelope (|Hilbert transform|) and smooth it with a
       50 ms cubic Savitzky-Golay filter when that window exceeds 3 samples.
    2. Detect envelope peaks above ``baseline_mean + threshold * baseline_std``.
       The baseline is the first 10 s, and peaks are at least 20 ms apart.
    3. Measure each event's duration as the span around the peak where the
       envelope stays above the baseline mean.

    Events with ``min_duration_ms <= duration < 70 ms`` count as spikes.
    Events with ``70 ms <= duration <= max_duration_ms`` count as sharp waves.
    Anything outside [min_duration_ms, max_duration_ms] is ignored.

    Parameters:
    - data: 2D numpy array of shape (n_channels, n_samples)
    - sfreq: Sampling frequency in Hz
    - threshold: Number of baseline standard deviations above the baseline mean a peak must reach
    - min_duration_ms: Events shorter than this (ms) are ignored
    - max_duration_ms: Events longer than this (ms) are ignored

    Returns:
    - dict mapping channel index -> {'num_spikes': int, 'num_sharpwaves': int}
    """
    n_channels, n_samples = data.shape
    results = {}

    spike_sharpwave_boundary_ms = 70  # conventional spike / sharp-wave duration boundary
    window_samples = int(0.05 * sfreq)  # 50 ms smoothing window
    # find_peaks rejects distance < 1, which int(0.02 * sfreq) yields for sfreq < 50 Hz
    min_peak_distance = max(1, int(0.02 * sfreq))

    for ch_idx in range(n_channels):
        envelope = np.abs(signal.hilbert(data[ch_idx, :]))

        if window_samples > 3:
            envelope_smooth = signal.savgol_filter(envelope, window_samples, 3)
        else:
            envelope_smooth = envelope

        baseline = envelope_smooth[:int(10 * sfreq)]
        baseline_mean = np.mean(baseline)
        baseline_std = np.std(baseline)
        peak_threshold = baseline_mean + threshold * baseline_std

        peaks, _ = signal.find_peaks(envelope_smooth, height=peak_threshold, distance=min_peak_distance)

        spike_count = 0
        sharpwave_count = 0
        for peak in peaks:
            # Walk outwards from the peak until the envelope drops back to the baseline mean
            left_base = peak
            while left_base > 0 and envelope_smooth[left_base] > baseline_mean:
                left_base -= 1

            right_base = peak
            while right_base < n_samples and envelope_smooth[right_base] > baseline_mean:
                right_base += 1

            duration_samples = right_base - left_base
            duration_ms = (duration_samples / sfreq) * 1000

            if duration_ms < min_duration_ms or duration_ms > max_duration_ms:
                continue
            if duration_ms < spike_sharpwave_boundary_ms:
                spike_count += 1
            else:
                sharpwave_count += 1

        results[ch_idx] = {
            'num_spikes': spike_count,
            'num_sharpwaves': sharpwave_count
        }

    return results

## Comprehensive Data Extraction

In [None]:
def parse_seizure_time(time_str):
    """Convert a seizure time value to seconds.

    Accepted forms:
      - 'HH:MM:SS' (parts may be fractional) -> H*3600 + M*60 + S
      - 'MM:SS'                              -> M*60 + S
      - anything ``float()`` accepts (e.g. '12.5', 12) -> that many seconds

    Returns None for missing values (None, NaN, ''), for colon-separated strings
    with any other number of parts, and for values that cannot be parsed.

    NOTE(review): whether the result is an offset from the recording start or a
    wall-clock time depends on the metadata source. Confirm this before using it
    as a sample index.
    """
    if pd.isna(time_str) or time_str == '':
        return None

    text = str(time_str)
    try:
        if ':' in text:
            parts = [float(part) for part in text.split(':')]
            if len(parts) == 3:
                hours, minutes, seconds = parts
                return hours * 3600 + minutes * 60 + seconds
            if len(parts) == 2:
                minutes, seconds = parts
                return minutes * 60 + seconds
            # Unsupported number of colon-separated fields
            return None
        return float(time_str)
    except (TypeError, ValueError):
        # Unparseable input is treated as "no time available", not as an error
        return None

def process_single_edf(edf_path, seizure_info):
    """
    Load one EDF recording and extract spectral and (optionally) propagation features.

    Steps:
    1. Load the EDF with MNE (preloaded) and record basic file info.
    2. Apply a 0.5-50 Hz FIR band-pass, then compute global and per-region PSD
       features.
    3. When *seizure_info* yields a parseable ``seizure_start_time``, detect
       per-channel onsets on a 3-30 Hz filtered copy of the data and summarise
       propagation speeds and seizure timing.

    Parameters
    ----------
    edf_path : str or path-like
    seizure_info : mapping-like with ``.get`` (a pandas Series row or a dict), or None
        Keys read: 'seizure_start_time', 'seizure_end_time'.

    Returns
    -------
    dict
        Always contains 'processing_success'. On failure it also contains
        'error_message', and features gathered before the failure are kept.
    """
    features = {}

    try:
        raw = mne.io.read_raw_edf(str(edf_path), preload=True, verbose=False)

        # Basic file info
        features['file_path'] = str(edf_path)
        features['num_channels'] = len(raw.ch_names)
        features['sampling_rate'] = raw.info['sfreq']
        features['duration_seconds'] = raw.n_times / raw.info['sfreq']

        # Broadband band-pass before spectral analysis
        raw.filter(0.5, 50, fir_design='firwin', verbose=False)

        features.update(extract_psd_features(raw))
        features.update(extract_psd_features_by_region(raw))

        # Check that metadata is present, not whether its values are truthy.
        # This works for both a Series row and a plain dict.
        if seizure_info is not None and len(seizure_info) > 0:
            seizure_start = parse_seizure_time(seizure_info.get('seizure_start_time'))
            seizure_end = parse_seizure_time(seizure_info.get('seizure_end_time'))

            if seizure_start is not None:
                # NOTE(review): seizure_start is used as seconds from the start of this
                # recording; confirm the metadata stores offsets rather than clock times.
                data = raw.get_data()

                # Narrower 3-30 Hz band for onset detection
                data_filtered = mne.filter.filter_data(
                    data, raw.info['sfreq'], l_freq=3, h_freq=30, verbose=False
                )

                onset_times = detect_channel_onsets(
                    data_filtered,
                    raw.info['sfreq'],
                    threshold=2.0,
                    window_size=1.0,
                    seizure_start=seizure_start,
                    seizure_end=seizure_end
                )

                features.update(calculate_propagation_speeds(onset_times, raw.ch_names))

                features['seizure_start_seconds'] = seizure_start
                features['seizure_end_seconds'] = seizure_end
                # Compare with None: a start time of 0.0 s is valid and must not be skipped
                if seizure_end is not None:
                    features['seizure_duration'] = seizure_end - seizure_start

        features['processing_success'] = True

    except Exception as e:
        # Per-file boundary: one bad recording must not abort the whole batch
        print(f"Error processing {edf_path}: {e}")
        features['processing_success'] = False
        features['error_message'] = str(e)

    return features

In [None]:
def _find_edf_path(data_root_paths, patient_id, file_name):
    """Return the first existing ``<root>/<patient_id variant>/<file_name>`` path, or None."""
    # Missing metadata (NaN / None / '') cannot be resolved to a file
    if not isinstance(file_name, str) or not file_name:
        return None

    pid = str(patient_id)
    for root_path in data_root_paths:
        # Directory names may differ in case from the metadata patient_id
        for candidate_id in (pid, pid.lower(), pid.upper()):
            candidate = os.path.join(root_path, candidate_id, file_name)
            if os.path.exists(candidate):
                return candidate
    return None


def build_comprehensive_dataset(seizures_df, patients_df,
                              data_root_paths=('data/seina_scalp', 'data/chb-mit')):
    """
    Build one feature record per seizure row by joining metadata and EDF-derived features.

    For each row of *seizures_df*:
      - every seizure column is copied as ``seizure_<col>``;
      - the FIRST row of *patients_df* with the same ``patient_id`` is copied
        as ``patient_<col>`` (excluding ``patient_id``);
      - ``<root>/<patient_id>/<file_name>`` is looked up under each root in
        *data_root_paths*, trying the id as-is, lower-case and upper-case. If
        found, it is processed with ``process_single_edf``; otherwise the
        record is marked as failed with 'EDF file not found'.

    NOTE(review): if a patient has several JSON files (one per recording), only
    the first matching patient row is joined. Confirm this is intended.

    Parameters
    ----------
    seizures_df, patients_df : pandas.DataFrame
    data_root_paths : iterable of str
        Root directories to search, in priority order.

    Returns
    -------
    pandas.DataFrame
        One row per seizure record.
    """
    all_records = []

    print(f"Processing {len(seizures_df)} seizure records...")
    print("="*60)

    for _, seizure_row in tqdm(seizures_df.iterrows(), total=len(seizures_df)):
        record = {f'seizure_{col}': seizure_row[col] for col in seizure_row.index}

        patient_id = seizure_row['patient_id']
        patient_info = patients_df[patients_df['patient_id'] == patient_id]

        if not patient_info.empty:
            first_patient = patient_info.iloc[0]
            for col in patient_info.columns:
                if col != 'patient_id':  # already present as seizure_patient_id
                    record[f'patient_{col}'] = first_patient[col]

        edf_path = _find_edf_path(data_root_paths, patient_id, seizure_row.get('file_name'))

        if edf_path:
            record.update(process_single_edf(edf_path, seizure_row))
        else:
            record['processing_success'] = False
            record['error_message'] = 'EDF file not found'

        all_records.append(record)

    comprehensive_df = pd.DataFrame(all_records)

    # An empty input has no 'processing_success' column; report zeros instead of raising
    if 'processing_success' in comprehensive_df:
        success = comprehensive_df['processing_success'].astype(bool)
    else:
        success = pd.Series(dtype=bool)

    print("\n" + "="*60)
    print("PROCESSING COMPLETE")
    print("="*60)
    print(f"Total records: {len(comprehensive_df)}")
    print(f"Successfully processed: {success.sum()}")
    print(f"Failed: {(~success).sum()}")

    return comprehensive_df

In [None]:
def save_to_parquet(df, output_path='comprehensive_eeg_features.parquet'):
    """
    Write *df* to a snappy-compressed Parquet file plus a CSV of ``describe()`` statistics.

    Object columns that ``pd.to_numeric`` converts cleanly are converted first.
    This happens on a copy, so the caller's DataFrame is not modified. The
    summary is written next to the Parquet file as ``<stem>_summary.csv``. It
    never overwrites the Parquet output, even when *output_path* does not end
    in '.parquet'.

    Parameters
    ----------
    df : pandas.DataFrame
    output_path : str or path-like
        Destination Parquet file. Writing requires a Parquet engine (pyarrow or
        fastparquet).
    """
    out = df.copy()
    for col in out.select_dtypes(include=['object']).columns:
        try:
            out[col] = pd.to_numeric(out[col])
        except (ValueError, TypeError):
            # Genuinely non-numeric column: keep it as-is
            pass

    output_path = Path(output_path)
    out.to_parquet(output_path, index=False, compression='snappy')
    print(f"Dataset saved to: {output_path}")
    print(f"File size: {os.path.getsize(output_path) / 1024 / 1024:.2f} MB")

    # Derive the summary name from the stem; str.replace('.parquet', ...) is a no-op
    # for other suffixes and would overwrite the Parquet file with the CSV
    summary_path = output_path.with_name(f"{output_path.stem}_summary.csv")
    out.describe().to_csv(summary_path)
    print(f"Summary statistics saved to: {summary_path}")

## Main Execution

In [None]:
# Read every JSON metadata file under ./data into seizure- and patient-level tables
print("Loading metadata from JSON files...")
seizures_df, patients_df = load_metadata_from_json(data_dir='data')
print(f"Loaded {len(patients_df)} patients and {len(seizures_df)} seizures")

In [None]:
# Extract EDF-derived features for every seizure record, searching both dataset roots
comprehensive_df = build_comprehensive_dataset(
    seizures_df=seizures_df,
    patients_df=patients_df,
    data_root_paths=['data/seina_scalp', 'data/chb-mit'],
)

## Fill Missing Values and Simplify Dtypes

In [None]:
# Object columns: missing -> "N/A"; numeric columns: NaN/inf -> 0
comprehensive_df = fill_missing_values(df=comprehensive_df)

In [None]:
# Downcast float64/int64/int32 columns where the values allow it (returns a new frame)
comprehensive_df = simplify_dtypes(df=comprehensive_df)

## Export to Parquet

In [None]:
# Persist the feature table (Parquet file plus a describe() summary CSV)
save_to_parquet(comprehensive_df, output_path='comprehensive_eeg_features.parquet')

# Peek at the first 20 feature columns and the overall table size
print("\nSample of extracted features:")
print(comprehensive_df.columns[:20].tolist())

print("\nDataset shape:", comprehensive_df.shape)