In [None]:
# Dataset: PhysioNet EEG Motor Movement/Imagery Dataset
# Goal: Load, filter, segment EEG data for 109 subjects
# =============================================================================

!pip install mne==1.6.1 -q
!pip install pyedflib -q
!pip install wget -q

import os
import numpy as np
import mne
import wget
import warnings
from tqdm import tqdm
import pickle
from sklearn.model_selection import train_test_split
from google.colab import drive

# Silence Python warnings and limit MNE logging to WARNING and above so the
# progress output of the cells below stays readable.
warnings.filterwarnings('ignore')
mne.set_log_level('WARNING')

print("All libraries imported successfully!")
print(f"MNE version: {mne.__version__}")

# ## 2. Configuration Parameters


class Config:
    """Preprocessing settings for the PhysioNet EEG Motor Movement/Imagery data."""

    # --- Dataset -------------------------------------------------------------
    N_SUBJECTS = 109                    # subjects S001 .. S109
    SAMPLING_RATE = 160                 # Hz, nominal PhysioNet EEG rate

    # --- Band-pass filter ----------------------------------------------------
    LOWCUT = 0.5                        # Hz, high-pass edge
    HIGHCUT = 45.0                      # Hz, low-pass edge

    # --- Windowing -----------------------------------------------------------
    SEGMENT_DURATION = 2.0              # seconds per window
    SEGMENT_SAMPLES = int(SAMPLING_RATE * SEGMENT_DURATION)  # 320 at 160 Hz
    OVERLAP = 0.5                       # fraction shared by consecutive windows

    # --- Sensorimotor channels (10-10 system) --------------------------------
    # Frontal-central row, central row (primary motor), central-parietal row.
    MOTOR_CHANNELS = (
        'Fc3 Fc1 Fcz Fc2 Fc4 '
        'C5 C3 C1 Cz C2 C4 C6 '
        'Cp3 Cp1 Cpz Cp2 Cp4'
    ).split()
    N_CHANNELS = len(MOTOR_CHANNELS)    # 17

    # --- Runs to use ---------------------------------------------------------
    # Imagery-only runs: 4/8/12 = left vs right fist,
    #                    6/10/14 = both fists vs both feet.
    MOTOR_IMAGERY_RUNS = list(range(4, 15, 2))

    # --- Train/test split ----------------------------------------------------
    TEST_SIZE = 0.2
    RANDOM_STATE = 42

    # --- Locations -----------------------------------------------------------
    DATA_DIR = '/content/physionet_eeg'
    PROCESSED_DIR = '/content/processed_data'
    PHYSIONET_URL = 'https://physionet.org/files/eegmmidb/1.0.0'


config = Config()

print("Configuration:")
for _label, _value in (
    ("Number of subjects", config.N_SUBJECTS),
    ("Sampling rate", f"{config.SAMPLING_RATE} Hz"),
    ("Segment duration", f"{config.SEGMENT_DURATION}s ({config.SEGMENT_SAMPLES} samples)"),
    ("Number of channels", config.N_CHANNELS),
    ("Motor channels", config.MOTOR_CHANNELS),
    ("Motor imagery runs", config.MOTOR_IMAGERY_RUNS),
):
    print(f"  - {_label}: {_value}")
del _label, _value

# ## 3. Download Dataset from PhysioNet

def create_directories():
    """Ensure the raw-download and processed-output folders exist (idempotent)."""
    targets = (config.DATA_DIR, config.PROCESSED_DIR)
    for folder in targets:
        os.makedirs(folder, exist_ok=True)
    print(f"Created directories: {targets[0]}, {targets[1]}")

def download_subject_data(subject_id):
    """
    Fetch every motor-imagery EDF run of one subject from PhysioNet.

    Runs whose file already exists on disk are counted as present and are
    not fetched again. A failed download prints a warning and is not counted.

    Parameters:
    -----------
    subject_id : int
        Subject number (1-109)

    Returns:
    --------
    bool : True if every run in config.MOTOR_IMAGERY_RUNS is present
           afterwards, False otherwise
    """
    tag = f'S{subject_id:03d}'
    subject_dir = os.path.join(config.DATA_DIR, tag)
    os.makedirs(subject_dir, exist_ok=True)

    present = 0
    for run in config.MOTOR_IMAGERY_RUNS:
        edf_name = f'{tag}R{run:02d}.edf'
        target = os.path.join(subject_dir, edf_name)

        if not os.path.exists(target):
            try:
                wget.download(f'{config.PHYSIONET_URL}/{tag}/{edf_name}', target, bar=None)
            except Exception as err:
                print(f"\n  Warning: Could not download {edf_name}: {err}")
                continue
        present += 1

    return present == len(config.MOTOR_IMAGERY_RUNS)

def download_all_data():
    """
    Download the motor-imagery runs for every subject in the dataset.

    Returns:
    --------
    successful : int
        Number of subjects whose runs are all present after the download
    failed_subjects : list of int
        Subject IDs with at least one run missing
    """
    banner = "=" * 60
    print(banner)
    print("DOWNLOADING PHYSIONET EEG DATA")
    print(banner)
    print(f"This will download motor imagery runs for {config.N_SUBJECTS} subjects")
    print(f"Runs to download per subject: {config.MOTOR_IMAGERY_RUNS}")
    print("-" * 60)

    create_directories()

    failed_subjects = []
    for subject_id in tqdm(range(1, config.N_SUBJECTS + 1), desc="Downloading"):
        if not download_subject_data(subject_id):
            failed_subjects.append(subject_id)
    successful = config.N_SUBJECTS - len(failed_subjects)

    print("-" * 60)
    print(f"Download complete: {successful}/{config.N_SUBJECTS} subjects")

    if failed_subjects:
        print(f"Failed subjects: {failed_subjects}")

    return successful, failed_subjects

# Run the download; files already on disk are skipped.
# Returns (number of complete subjects, list of subject IDs with missing runs).
successful, failed = download_all_data()

# ## 4. EEG Loading and Preprocessing Functions

def load_edf_file(filepath):
    """
    Read an EDF recording fully into memory with MNE.

    Parameters:
    -----------
    filepath : str
        Path to the EDF file

    Returns:
    --------
    raw : mne.io.Raw or None
        The preloaded recording, or None if reading failed (the error is
        printed, not raised)
    """
    try:
        return mne.io.read_raw_edf(filepath, preload=True, verbose=False)
    except Exception as exc:
        print(f"Error loading {filepath}: {exc}")
        return None

def select_motor_channels(raw, channels=None):
    """
    Keep only the requested sensorimotor channels, in the requested order.

    PhysioNet EDF labels carry trailing dots that pad them to a fixed width,
    so two-letter names get more than one dot (e.g. 'Fc3.' but 'C3..').
    Trying a fixed list of spellings such as ``ch + '.'`` misses the padded
    ones. Names are therefore compared after stripping trailing dots and
    case-folding.

    Parameters:
    -----------
    raw : mne.io.Raw
        Raw EEG data; picked in place
    channels : list of str, optional
        Channel names to keep. Defaults to config.MOTOR_CHANNELS.

    Returns:
    --------
    raw : mne.io.Raw or None
        The same object reduced to exactly the requested channels. Returns
        None (with a printed warning) if any requested channel is absent, so
        that every accepted recording has the same channel count.
    """
    if channels is None:
        channels = config.MOTOR_CHANNELS

    available_channels = list(raw.ch_names)

    # Canonical form (no trailing dots, case-folded) -> actual label in the file.
    lookup = {}
    for label in available_channels:
        lookup.setdefault(label.rstrip('.').casefold(), label)

    selected_channels = []
    missing = []
    for ch in channels:
        match = lookup.get(ch.rstrip('.').casefold())
        if match is None:
            missing.append(ch)
        else:
            selected_channels.append(match)

    if not selected_channels:
        print(f"Warning: No motor channels found. Available: {available_channels[:10]}...")
        return None

    if missing:
        # A partial selection would give this recording fewer channels than
        # the others and break concatenation later, so reject it outright.
        print(f"Warning: Missing motor channels {missing}; skipping recording")
        return None

    # ordered=True keeps the channel axis in the configured order.
    raw.pick_channels(selected_channels, ordered=True)

    return raw

def apply_bandpass_filter(raw, lowcut=None, highcut=None):
    """
    Band-pass filter the recording with a firwin FIR design.

    Parameters:
    -----------
    raw : mne.io.Raw
        Raw EEG data (its ``filter`` method is called on it directly)
    lowcut : float, optional
        Lower pass-band edge in Hz; defaults to config.LOWCUT
    highcut : float, optional
        Upper pass-band edge in Hz; defaults to config.HIGHCUT

    Returns:
    --------
    raw : mne.io.Raw
        The same object that was passed in
    """
    l_freq = config.LOWCUT if lowcut is None else lowcut
    h_freq = config.HIGHCUT if highcut is None else highcut

    raw.filter(l_freq, h_freq, fir_design='firwin', verbose=False)
    return raw

def segment_data(raw, segment_duration=None, overlap=None):
    """
    Cut continuous EEG into fixed-length, possibly overlapping windows.

    Window length is derived from the recording's own sampling rate
    (``raw.info['sfreq']``), so recordings at different rates produce
    windows with different sample counts.

    Parameters:
    -----------
    raw : mne.io.Raw
        Raw EEG data
    segment_duration : float, optional
        Window length in seconds; defaults to config.SEGMENT_DURATION
    overlap : float, optional
        Fraction of a window shared with the next one; defaults to
        config.OVERLAP

    Returns:
    --------
    segments : np.ndarray or None
        Shape (n_segments, n_channels, n_samples), or None if the recording
        is shorter than one window

    Raises:
    -------
    ValueError
        If the window is shorter than one sample, or the overlap leaves a
        step of less than one sample (e.g. overlap >= 1). Either case would
        otherwise loop forever.
    """
    if segment_duration is None:
        segment_duration = config.SEGMENT_DURATION
    if overlap is None:
        overlap = config.OVERLAP

    data = raw.get_data()  # (n_channels, n_total_samples)
    sfreq = raw.info['sfreq']

    segment_samples = int(segment_duration * sfreq)
    step_samples = int(segment_samples * (1 - overlap))

    if segment_samples < 1:
        raise ValueError(
            f"segment_duration={segment_duration}s at {sfreq} Hz gives "
            f"{segment_samples} samples; need at least 1"
        )
    if step_samples < 1:
        raise ValueError(
            f"overlap={overlap} leaves a step of {step_samples} samples; "
            "overlap must be small enough that the step is at least 1 sample"
        )

    # Last valid start is n_samples - segment_samples (inclusive).
    n_samples = data.shape[1]
    starts = range(0, n_samples - segment_samples + 1, step_samples)
    segments = [data[:, start:start + segment_samples] for start in starts]

    if not segments:
        return None

    return np.array(segments)

def normalize_segments(segments):
    """
    Z-score every segment with a single mean/std pooled over all of its
    channels and time points.

    Segments whose standard deviation is zero are only mean-centred.

    Parameters:
    -----------
    segments : np.ndarray
        Shape: (n_segments, n_channels, n_samples)

    Returns:
    --------
    normalized : np.ndarray
        New array of the same shape and dtype as ``segments``
    """
    normalized = np.zeros_like(segments)

    for idx, seg in enumerate(segments):
        mu = np.mean(seg)
        sigma = np.std(seg)
        centred = seg - mu
        normalized[idx] = centred / sigma if sigma > 0 else centred

    return normalized

# ## 5. Process All Subjects

def process_subject(subject_id, verbose=False):
    """
    Process all motor imagery runs for a single subject.

    For each run in config.MOTOR_IMAGERY_RUNS, this:
    - loads the EDF file;
    - reduces it to the motor channels;
    - resamples it to config.SAMPLING_RATE when the file's own rate differs;
    - band-pass filters it;
    - cuts it into windows.

    Windows from all runs are then concatenated and z-scored per window.

    Why resample: segment_data derives the window length from each file's own
    sampling rate, so a file stored at e.g. 128 Hz would give 256-sample
    windows instead of 320 and break the concatenation across subjects.

    Parameters:
    -----------
    subject_id : int
        Subject number (1-109)
    verbose : bool
        Print per-run information

    Returns:
    --------
    segments : np.ndarray or None
        All segments for this subject, shape (n_segments, n_channels,
        n_samples), or None if no run produced a segment
    """
    subject_dir = os.path.join(config.DATA_DIR, f'S{subject_id:03d}')
    all_segments = []

    for run in config.MOTOR_IMAGERY_RUNS:
        filename = f'S{subject_id:03d}R{run:02d}.edf'
        filepath = os.path.join(subject_dir, filename)

        if not os.path.exists(filepath):
            if verbose:
                print(f"  File not found: {filename}")
            continue

        # Load EDF file
        raw = load_edf_file(filepath)
        if raw is None:
            continue

        # Select motor channels
        raw = select_motor_channels(raw)
        if raw is None:
            continue

        # Bring every recording to the common rate so all windows have the
        # same number of samples (resampling after channel selection is cheaper).
        if raw.info['sfreq'] != config.SAMPLING_RATE:
            if verbose:
                print(f"  Run {run}: resampling {raw.info['sfreq']} Hz -> "
                      f"{config.SAMPLING_RATE} Hz")
            raw.resample(config.SAMPLING_RATE)

        # Apply band-pass filter
        raw = apply_bandpass_filter(raw)

        # Segment data
        segments = segment_data(raw)
        if segments is None:
            continue

        all_segments.append(segments)

        if verbose:
            print(f"  Run {run}: {segments.shape[0]} segments")

    if not all_segments:
        return None

    all_segments = np.concatenate(all_segments, axis=0)
    return normalize_segments(all_segments)

def process_all_subjects():
    """
    Process EEG data for all subjects.

    Returns:
    --------
    X : np.ndarray
        All EEG segments, shape: (n_total_segments, n_channels, n_samples)
    y : np.ndarray
        Subject labels (subject_id - 1), shape: (n_total_segments,)
    subject_segment_counts : dict
        {subject_id: number of segments} for every processed subject

    Raises:
    -------
    ValueError
        If no subject produced any segment, or if subjects produced segments
        of different (n_channels, n_samples) shapes. The message names the
        offending subjects instead of numpy's anonymous concatenate error.
    """
    print("=" * 60)
    print("PROCESSING EEG DATA")
    print("=" * 60)

    all_X = []
    all_y = []
    processed_ids = []
    subject_segment_counts = {}
    skipped_subjects = []

    for subject_id in tqdm(range(1, config.N_SUBJECTS + 1), desc="Processing"):
        segments = process_subject(subject_id)

        if segments is None or len(segments) == 0:
            skipped_subjects.append(subject_id)
            continue

        # Store segments and labels
        n_segments = len(segments)
        all_X.append(segments)
        all_y.extend([subject_id - 1] * n_segments)  # Labels: 0 to 108
        processed_ids.append(subject_id)
        subject_segment_counts[subject_id] = n_segments

    if not all_X:
        raise ValueError(
            "No segments were produced for any subject; check that the EDF "
            f"files exist under {config.DATA_DIR}"
        )

    # Verify every subject produced the same (channels, samples) shape before
    # concatenating, so a failure says which subjects are off.
    expected_shape = all_X[0].shape[1:]
    mismatched = {
        sid: arr.shape[1:]
        for sid, arr in zip(processed_ids, all_X)
        if arr.shape[1:] != expected_shape
    }
    if mismatched:
        raise ValueError(
            f"Inconsistent segment shapes: subject {processed_ids[0]} has "
            f"(channels, samples) = {expected_shape}, but these subjects "
            f"differ: {mismatched}"
        )

    # Concatenate all data
    X = np.concatenate(all_X, axis=0)
    y = np.array(all_y)

    print("-" * 60)
    print("Processing complete!")
    print(f"  - Total segments: {len(X)}")
    print(f"  - Data shape: {X.shape}")
    print(f"  - Subjects processed: {len(subject_segment_counts)}")

    if skipped_subjects:
        print(f"  - Skipped subjects: {skipped_subjects}")

    # Statistics
    segments_per_subject = list(subject_segment_counts.values())
    print(f"  - Segments per subject: min={min(segments_per_subject)}, "
          f"max={max(segments_per_subject)}, mean={np.mean(segments_per_subject):.1f}")

    return X, y, subject_segment_counts

# Process all subjects
# X: stacked windows, y: labels (subject_id - 1), segment_counts: {subject_id: n_segments}
X, y, segment_counts = process_all_subjects()

# ## 6. Data Verification and Statistics

def verify_data(X, y):
    """
    Print integrity checks for the processed dataset.

    Reports array shapes, label coverage (including gaps in the label range),
    global value statistics, NaN/Inf counts and how many segments each
    present label contributed.

    Parameters:
    -----------
    X : np.ndarray
        Segments, indexed as (segment, channel, sample)
    y : np.ndarray
        Integer subject labels, one per segment

    Returns:
    --------
    bool : always True; problems are only reported, not raised
    """
    print("=" * 60)
    print("DATA VERIFICATION")
    print("=" * 60)

    print("\nData Shapes:")
    print(f"  X (EEG data): {X.shape}")
    print(f"    - Total segments: {X.shape[0]}")
    print(f"    - Channels: {X.shape[1]}")
    print(f"    - Samples per segment: {X.shape[2]}")
    print(f"  y (labels): {y.shape}")

    print("\nLabel Statistics:")
    # Count only labels that actually occur. np.bincount would also report a
    # zero count for every absent label below the maximum (e.g. skipped
    # subjects), which made the distribution "Min" below misleading.
    unique_labels, label_counts = np.unique(y, return_counts=True)
    print(f"  - Unique subjects: {len(unique_labels)}")
    print(f"  - Label range: {unique_labels.min()} to {unique_labels.max()}")
    missing = np.setdiff1d(
        np.arange(unique_labels.min(), unique_labels.max() + 1), unique_labels
    )
    if missing.size:
        print(f"  - Missing labels in range: {missing.tolist()}")

    print("\nData Statistics:")
    print(f"  - Mean: {np.mean(X):.6f}")
    print(f"  - Std: {np.std(X):.6f}")
    print(f"  - Min: {np.min(X):.6f}")
    print(f"  - Max: {np.max(X):.6f}")

    # Check for NaN or Inf
    nan_count = np.sum(np.isnan(X))
    inf_count = np.sum(np.isinf(X))
    print("\nData Quality:")
    print(f"  - NaN values: {nan_count}")
    print(f"  - Inf values: {inf_count}")

    # Class distribution over labels that are present
    print("\nClass Distribution (samples per subject):")
    print(f"  - Min: {label_counts.min()}")
    print(f"  - Max: {label_counts.max()}")
    print(f"  - Mean: {label_counts.mean():.1f}")
    print(f"  - Std: {label_counts.std():.1f}")

    return True

# Print shapes, label/value statistics and NaN/Inf counts as a sanity check.
verify_data(X, y)

# ## 7. Train/Test Split

def create_train_test_split(X, y, test_size=None, random_state=None):
    """
    Split segments into train/test sets, stratified by subject label.

    NOTE(review): segments are cut with config.OVERLAP, so neighbouring
    windows share samples, and they are assigned to train/test at random.
    Overlapping windows can therefore land on both sides, which may make
    test accuracy optimistic. Splitting by run or time block would avoid this.

    Parameters:
    -----------
    X : np.ndarray
        EEG segments
    y : np.ndarray
        Subject labels
    test_size : float, optional
        Fraction held out for testing; defaults to config.TEST_SIZE
    random_state : int, optional
        Seed for the shuffle; defaults to config.RANDOM_STATE

    Returns:
    --------
    X_train, X_test, y_train, y_test : np.ndarray
        Split data
    """
    test_size = config.TEST_SIZE if test_size is None else test_size
    random_state = config.RANDOM_STATE if random_state is None else random_state

    # Stratify so every subject appears in both sets in proportion.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y,
    )

    rule = "=" * 60
    print(rule)
    print("TRAIN/TEST SPLIT")
    print(rule)
    print(f"Training set: {X_train.shape[0]} samples ({100*(1-test_size):.0f}%)")
    print(f"Test set: {X_test.shape[0]} samples ({100*test_size:.0f}%)")
    print(f"X_train shape: {X_train.shape}")
    print(f"X_test shape: {X_test.shape}")

    return X_train, X_test, y_train, y_test

# Create split
# Stratified split using config.TEST_SIZE and config.RANDOM_STATE.
X_train, X_test, y_train, y_test = create_train_test_split(X, y)

# ## 8. Save Processed Data

def save_processed_data(X_train, X_test, y_train, y_test, segment_counts):
    """
    Persist the split arrays and a metadata dictionary to config.PROCESSED_DIR.

    Writes X_train.npy, X_test.npy, y_train.npy, y_test.npy and metadata.pkl,
    then prints every file in that directory with its size.

    Returns:
    --------
    metadata : dict
        The dictionary written to metadata.pkl
    """
    rule = "=" * 60
    print(rule)
    print("SAVING PROCESSED DATA")
    print(rule)

    out_dir = config.PROCESSED_DIR
    os.makedirs(out_dir, exist_ok=True)

    # Arrays go to .npy files, one per split/role.
    for fname, array in (
        ('X_train.npy', X_train),
        ('X_test.npy', X_test),
        ('y_train.npy', y_train),
        ('y_test.npy', y_test),
    ):
        np.save(os.path.join(out_dir, fname), array)

    # Preprocessing settings plus dataset sizes, for the training notebook.
    metadata = dict(
        n_subjects=config.N_SUBJECTS,
        n_channels=X_train.shape[1],
        n_samples=X_train.shape[2],
        sampling_rate=config.SAMPLING_RATE,
        segment_duration=config.SEGMENT_DURATION,
        motor_channels=config.MOTOR_CHANNELS,
        lowcut=config.LOWCUT,
        highcut=config.HIGHCUT,
        segment_counts=segment_counts,
        train_size=len(X_train),
        test_size=len(X_test),
    )

    with open(os.path.join(out_dir, 'metadata.pkl'), 'wb') as handle:
        pickle.dump(metadata, handle)

    print(f"\nSaved files in {out_dir}:")
    for fname in os.listdir(out_dir):
        size_mb = os.path.getsize(os.path.join(out_dir, fname)) / (1024 * 1024)
        print(f"  - {fname}: {size_mb:.2f} MB")

    print("\nData saved successfully!")

    return metadata

# Write the four .npy arrays and metadata.pkl to config.PROCESSED_DIR.
metadata = save_processed_data(X_train, X_test, y_train, y_test, segment_counts)

# ## 9. (Optional) Mount Google Drive for Persistent Storage

# Uncomment and run this cell to save data to Google Drive
# This allows you to access the processed data in future sessions
# NOTE: the code below is wrapped in a bare triple-quoted string, so it never
# executes. To use it, remove the surrounding triple quotes (or paste the
# code into its own cell).

"""
from google.colab import drive
drive.mount('/content/drive')

# Copy processed data to Google Drive
import shutil
drive_path = '/content/drive/MyDrive/EEG_Person_ID'
os.makedirs(drive_path, exist_ok=True)

for filename in os.listdir(config.PROCESSED_DIR):
    src = os.path.join(config.PROCESSED_DIR, filename)
    dst = os.path.join(drive_path, filename)
    shutil.copy(src, dst)
    print(f"Copied {filename} to Google Drive")

print(f"\nAll data saved to: {drive_path}")
"""

# ## 10. Summary and Next Steps

print("=" * 60)
print("PREPROCESSING COMPLETE - SUMMARY")
print("=" * 60)
print(f"""
Dataset: PhysioNet EEG Motor Movement/Imagery

Preprocessing Steps:
1. ✓ Downloaded motor imagery runs for {config.N_SUBJECTS} subjects
2. ✓ Selected {config.N_CHANNELS} motor cortex channels
3. ✓ Applied band-pass filter ({config.LOWCUT}-{config.HIGHCUT} Hz)
4. ✓ Segmented into {config.SEGMENT_DURATION}s epochs ({config.SEGMENT_SAMPLES} samples)
5. ✓ Normalized segments (z-score)
6. ✓ Created train/test split (80/20)
7. ✓ Saved processed data

Final Data:
- Training samples: {len(X_train)}
- Test samples: {len(X_test)}
- Shape: (samples, {X_train.shape[1]} channels, {X_train.shape[2]} time points)
- Number of classes: {len(np.unique(y_train))} subjects

Files saved in: {config.PROCESSED_DIR}
- X_train.npy, X_test.npy (EEG data)
- y_train.npy, y_test.npy (labels)
- metadata.pkl (configuration)

Next Step: Run Notebook 2 (CNN + GRU Model) to train the classifier
""")

All libraries imported successfully!
MNE version: 1.6.1
Configuration:
  - Number of subjects: 109
  - Sampling rate: 160 Hz
  - Segment duration: 2.0s (320 samples)
  - Number of channels: 17
  - Motor channels: ['Fc3', 'Fc1', 'Fcz', 'Fc2', 'Fc4', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'Cp3', 'Cp1', 'Cpz', 'Cp2', 'Cp4']
  - Motor imagery runs: [4, 6, 8, 10, 12, 14]
DOWNLOADING PHYSIONET EEG DATA
This will download motor imagery runs for 109 subjects
Runs to download per subject: [4, 6, 8, 10, 12, 14]
------------------------------------------------------------
Created directories: /content/physionet_eeg, /content/processed_data


Downloading: 100%|██████████| 109/109 [1:12:15<00:00, 39.77s/it]


------------------------------------------------------------
Download complete: 109/109 subjects
PROCESSING EEG DATA


Processing: 100%|██████████| 109/109 [00:52<00:00,  2.07it/s]


ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 320 and the array at index 87 has size 256

In [None]:
# Architecture: 1D CNN → GRU → Dense Classification
# Task: Classify which subject (1-109) a given EEG segment belongs to
# =============================================================================

# ## 1. Setup and Imports

!pip install torch torchvision -q
!pip install scikit-learn -q
!pip install seaborn -q

import os
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import time

# Suppress Python warnings globally to keep the notebook output short.
warnings.filterwarnings('ignore')

# Check for GPU
# `device` is a CUDA GPU when one is available, otherwise the CPU; the model
# and batches below are moved to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# ## 2. Configuration

class ModelConfig:
    """Hyper-parameters and paths for the CNN + GRU subject classifier.

    N_CHANNELS, N_SAMPLES and N_CLASSES are starting values; load_data()
    replaces them with the dimensions of the arrays on disk.
    """

    # --- Locations -----------------------------------------------------------
    PROCESSED_DIR = '/content/processed_data'
    MODEL_DIR = '/content/models'

    # --- Input / output dimensions -------------------------------------------
    N_CHANNELS = 17                 # EEG channels produced by preprocessing
    N_SAMPLES = 2 * 160             # 2 s window at 160 Hz = 320 samples
    N_CLASSES = 109                 # one class per subject

    # --- Convolutional front end ---------------------------------------------
    CNN_FILTERS = [32, 64, 128]     # output channels of the three conv blocks
    CNN_KERNEL_SIZE = 7
    CNN_POOL_SIZE = 2
    DROPOUT_CNN = 0.3

    # --- Recurrent block -----------------------------------------------------
    GRU_HIDDEN_SIZE = 128
    GRU_NUM_LAYERS = 2
    GRU_BIDIRECTIONAL = True
    DROPOUT_GRU = 0.3               # between stacked GRU layers

    # --- Classification head -------------------------------------------------
    DENSE_HIDDEN = 256
    DROPOUT_DENSE = 0.5

    # --- Optimisation --------------------------------------------------------
    BATCH_SIZE = 64
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-4             # L2 penalty passed to the optimizer
    N_EPOCHS = 50
    EARLY_STOPPING_PATIENCE = 10
    LR_SCHEDULER_PATIENCE = 5

    # --- Reproducibility -----------------------------------------------------
    RANDOM_STATE = 42


config = ModelConfig()

# Seed NumPy and PyTorch (plus CUDA when present) for repeatable runs.
np.random.seed(config.RANDOM_STATE)
torch.manual_seed(config.RANDOM_STATE)
if torch.cuda.is_available():
    torch.cuda.manual_seed(config.RANDOM_STATE)

print("\n".join([
    "Model Configuration:",
    f"  - Input shape: ({config.N_CHANNELS}, {config.N_SAMPLES})",
    f"  - CNN filters: {config.CNN_FILTERS}",
    f"  - GRU hidden: {config.GRU_HIDDEN_SIZE}, layers: {config.GRU_NUM_LAYERS}",
    f"  - Bidirectional: {config.GRU_BIDIRECTIONAL}",
    f"  - Output classes: {config.N_CLASSES}",
]))

# ## 3. Load Preprocessed Data

def load_data():
    """
    Load the train/test arrays and metadata written by the preprocessing notebook.

    Side effect: overwrites config.N_CHANNELS, config.N_SAMPLES and
    config.N_CLASSES with values derived from the loaded arrays.

    N_CLASSES is set to (largest label + 1), not to the number of distinct
    labels. The preprocessing notebook labels segments as subject_id - 1 and
    skips subjects that yield no data, so the label set can have gaps.
    Sizing the output layer by the distinct count would leave the highest
    labels out of range for nn.CrossEntropyLoss.

    Returns:
    --------
    X_train, X_test, y_train, y_test : np.ndarray
    metadata : dict
        Contents of metadata.pkl
    """
    print("=" * 60)
    print("LOADING PREPROCESSED DATA")
    print("=" * 60)

    X_train = np.load(os.path.join(config.PROCESSED_DIR, 'X_train.npy'))
    X_test = np.load(os.path.join(config.PROCESSED_DIR, 'X_test.npy'))
    y_train = np.load(os.path.join(config.PROCESSED_DIR, 'y_train.npy'))
    y_test = np.load(os.path.join(config.PROCESSED_DIR, 'y_test.npy'))

    # metadata.pkl is written by this project's own preprocessing step;
    # never unpickle files obtained from an untrusted source.
    with open(os.path.join(config.PROCESSED_DIR, 'metadata.pkl'), 'rb') as f:
        metadata = pickle.load(f)

    n_present = len(np.unique(y_train))
    n_classes = int(np.concatenate((y_train, y_test)).max()) + 1

    print(f"Training data: {X_train.shape}")
    print(f"Test data: {X_test.shape}")
    print(f"Training labels: {y_train.shape}")
    print(f"Test labels: {y_test.shape}")
    print(f"Number of classes: {n_present}")
    if n_classes != n_present:
        print(f"  Note: labels span 0..{n_classes - 1}; "
              f"output layer sized to {n_classes} classes")

    # Update config with actual data dimensions
    config.N_CHANNELS = X_train.shape[1]
    config.N_SAMPLES = X_train.shape[2]
    config.N_CLASSES = n_classes

    return X_train, X_test, y_train, y_test, metadata

# Load arrays from config.PROCESSED_DIR; also updates config's input/output dimensions.
X_train, X_test, y_train, y_test, metadata = load_data()

# ## 4. Create PyTorch Datasets and DataLoaders

class EEGDataset(Dataset):
    """Map-style dataset pairing EEG windows with integer labels.

    Inputs are converted once, at construction, to a float32 tensor (X) and
    an int64 tensor (y). Indexing returns ``(X[idx], y[idx])``.
    """

    def __init__(self, X, y):
        """
        Parameters:
        -----------
        X : np.ndarray
            EEG data of shape (n_samples, n_channels, n_timepoints)
        y : np.ndarray
            Labels of shape (n_samples,)
        """
        super().__init__()
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)

    def __len__(self):
        """Number of labelled windows."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the (window, label) pair at position ``idx``."""
        window, label = self.X[idx], self.y[idx]
        return window, label

def create_dataloaders(X_train, X_test, y_train, y_test, batch_size=None):
    """
    Wrap the arrays in EEGDataset objects and build train/test DataLoaders.

    The training loader shuffles; the test loader keeps order. Memory pinning
    is enabled only when CUDA is available.

    The classifier head of CNN_GRU_Model contains BatchNorm1d over
    (batch, features), which raises an error in training mode on a batch of
    exactly one sample. When the training set size leaves such a one-sample
    tail, that last batch is dropped (at most one sample per epoch).

    Parameters:
    -----------
    X_train, X_test, y_train, y_test : np.ndarray
    batch_size : int, optional
        Defaults to config.BATCH_SIZE

    Returns:
    --------
    train_loader, test_loader : DataLoader
    """
    if batch_size is None:
        batch_size = config.BATCH_SIZE

    # Create datasets
    train_dataset = EEGDataset(X_train, y_train)
    test_dataset = EEGDataset(X_test, y_test)

    pin = torch.cuda.is_available()
    drop_single_tail = len(train_dataset) % batch_size == 1

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin,
        drop_last=drop_single_tail,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin,
    )

    print(f"\nDataLoaders created:")
    print(f"  - Training batches: {len(train_loader)}")
    print(f"  - Test batches: {len(test_loader)}")
    print(f"  - Batch size: {batch_size}")

    return train_loader, test_loader

# Batch size comes from config.BATCH_SIZE.
train_loader, test_loader = create_dataloaders(X_train, X_test, y_train, y_test)

# ## 5. Define CNN + GRU Model Architecture

class CNN_GRU_Model(nn.Module):
    """
    CNN + GRU hybrid model for EEG person identification.

    Architecture:
    1. Three Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d blocks along the time
       axis (EEG channels are the Conv1d input features), then dropout.
    2. A (optionally bidirectional, multi-layer) GRU over the pooled time steps.
    3. A dense head (Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear) on the
       final GRU hidden state(s), producing class logits.

    Note: the head's BatchNorm1d needs more than one sample per batch in
    training mode.

    Raises:
    -------
    ValueError
        At construction, if ``n_samples`` is too short to survive the three
        pooling stages (the GRU would receive an empty sequence).
    """

    def __init__(self, n_channels, n_samples, n_classes, config):
        super(CNN_GRU_Model, self).__init__()

        self.n_channels = n_channels
        self.n_samples = n_samples
        self.n_classes = n_classes

        # Each of the three MaxPool1d stages floors the length by CNN_POOL_SIZE;
        # fail early with a clear message instead of inside the GRU.
        pooled_len = n_samples
        for _ in range(3):
            pooled_len //= config.CNN_POOL_SIZE
        if pooled_len < 1:
            raise ValueError(
                f"n_samples={n_samples} is too short for three pooling stages "
                f"of size {config.CNN_POOL_SIZE}"
            )

        # =====================================================================
        # CNN LAYERS (1D Convolution along time axis)
        # Input: (batch, channels, time_samples)
        # =====================================================================

        self.cnn_layers = nn.Sequential(
            # Conv Block 1
            nn.Conv1d(n_channels, config.CNN_FILTERS[0],
                     kernel_size=config.CNN_KERNEL_SIZE, padding='same'),
            nn.BatchNorm1d(config.CNN_FILTERS[0]),
            nn.ReLU(),
            nn.MaxPool1d(config.CNN_POOL_SIZE),

            # Conv Block 2
            nn.Conv1d(config.CNN_FILTERS[0], config.CNN_FILTERS[1],
                     kernel_size=config.CNN_KERNEL_SIZE, padding='same'),
            nn.BatchNorm1d(config.CNN_FILTERS[1]),
            nn.ReLU(),
            nn.MaxPool1d(config.CNN_POOL_SIZE),

            # Conv Block 3
            nn.Conv1d(config.CNN_FILTERS[1], config.CNN_FILTERS[2],
                     kernel_size=config.CNN_KERNEL_SIZE, padding='same'),
            nn.BatchNorm1d(config.CNN_FILTERS[2]),
            nn.ReLU(),
            nn.MaxPool1d(config.CNN_POOL_SIZE),

            nn.Dropout(config.DROPOUT_CNN)
        )

        # =====================================================================
        # GRU LAYERS
        # Input: (batch, seq_len, features) - the CNN output is transposed
        # =====================================================================

        gru_input_size = config.CNN_FILTERS[2]  # feature size = last conv width

        self.gru = nn.GRU(
            input_size=gru_input_size,
            hidden_size=config.GRU_HIDDEN_SIZE,
            num_layers=config.GRU_NUM_LAYERS,
            batch_first=True,
            bidirectional=config.GRU_BIDIRECTIONAL,
            # PyTorch applies GRU dropout only between stacked layers.
            dropout=config.DROPOUT_GRU if config.GRU_NUM_LAYERS > 1 else 0
        )

        # Both directions' final states are concatenated when bidirectional.
        gru_output_size = config.GRU_HIDDEN_SIZE * (2 if config.GRU_BIDIRECTIONAL else 1)

        # =====================================================================
        # DENSE LAYERS (Classification Head)
        # =====================================================================

        self.classifier = nn.Sequential(
            nn.Linear(gru_output_size, config.DENSE_HIDDEN),
            nn.BatchNorm1d(config.DENSE_HIDDEN),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT_DENSE),
            nn.Linear(config.DENSE_HIDDEN, n_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal conv weights, Xavier-normal linear weights, zero
        biases for linear layers, unit/zero BatchNorm affine parameters."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Forward pass.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch, n_channels, n_samples)

        Returns:
        --------
        output : torch.Tensor
            Class logits of shape (batch, n_classes)
        """
        # CNN: (batch, channels, time) -> (batch, cnn_filters, reduced_time)
        cnn_out = self.cnn_layers(x)

        # Transpose for GRU: (batch, cnn_filters, time) -> (batch, time, cnn_filters)
        gru_in = cnn_out.transpose(1, 2)

        # hidden: (num_layers * num_directions, batch, hidden)
        gru_out, hidden = self.gru(gru_in)

        if self.gru.bidirectional:
            # Last layer's forward (-2) and backward (-1) final states.
            final_hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            final_hidden = hidden[-1]

        output = self.classifier(final_hidden)

        return output

# Create model instance
model = CNN_GRU_Model(
    n_channels=config.N_CHANNELS,
    n_samples=config.N_SAMPLES,
    n_classes=config.N_CLASSES,
    config=config
).to(device)

# Print model summary
print("=" * 60)
print("MODEL ARCHITECTURE")
print("=" * 60)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Smoke-test a forward pass on random input. Run it in eval mode and without
# autograd so the BatchNorm running statistics are not updated with noise and
# no graph is kept; then switch back to train mode for training.
model.eval()
with torch.no_grad():
    test_input = torch.randn(2, config.N_CHANNELS, config.N_SAMPLES).to(device)
    test_output = model(test_input)
model.train()
print(f"\nTest forward pass:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_output.shape}")

# ## 6. Training Functions

class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Tracks the lowest validation loss seen so far and a snapshot of the
    model's weights at that point. After ``patience`` consecutive calls
    without an improvement larger than ``min_delta``, ``early_stop`` is set.
    If ``restore_best_weights`` is True, the snapshot is then loaded back
    into the model.
    """

    def __init__(self, patience=10, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
        self.early_stop = False

    @staticmethod
    def _snapshot(model):
        """Return an independent copy of ``model.state_dict()``.

        ``state_dict().copy()`` copies only the dict; its tensors still share
        storage with the live parameters, so later optimizer steps would
        overwrite the "best" weights and restoring would be a no-op. Cloning
        each tensor keeps a true snapshot.
        """
        return {name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()}

    def __call__(self, val_loss, model):
        """Record this epoch's ``val_loss``; return True once training should stop."""
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_weights = self._snapshot(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_weights = self._snapshot(model)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if self.restore_best_weights:
                    model.load_state_dict(self.best_weights)

        return self.early_stop

def train_epoch(model, train_loader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``train_loader``.

    Per batch: forward, loss via ``criterion``, backward, clip the global
    gradient norm to 1.0, optimizer step.

    Returns:
    --------
    avg_loss : float
        Mean of the per-batch losses
    accuracy : float
        Accuracy of the argmax predictions made during the pass
    """
    model.train()
    running_loss = 0
    predictions, targets = [], []

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()

        # Guard against exploding gradients in the recurrent layers.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += batch_loss.item()
        predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    avg_loss = running_loss / len(train_loader)
    return avg_loss, accuracy_score(targets, predictions)

def evaluate(model, data_loader, criterion, device):
    """Score ``model`` on every batch of ``data_loader`` without gradients.

    Returns
    -------
    tuple
        ``(mean_batch_loss, accuracy, weighted_f1, predictions, labels)``.
        The last two are flat lists in loader iteration order.
    """
    model.eval()
    loss_sum = 0
    predicted, targets = [], []

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())

    mean_loss = loss_sum / len(data_loader)
    return (
        mean_loss,
        accuracy_score(targets, predicted),
        f1_score(targets, predicted, average='weighted'),
        predicted,
        targets,
    )

# ## 7. Training Loop

def train_model(model, train_loader, test_loader, config, device):
    """
    Main training function.

    Each epoch trains on ``train_loader`` and then evaluates on
    ``test_loader``. The validation loss drives a ReduceLROnPlateau
    scheduler (factor 0.5) and early stopping. The epoch with the highest
    validation accuracy is checkpointed to ``<config.MODEL_DIR>/best_model.pt``.

    NOTE(review): ``test_loader`` is used for model selection (checkpointing
    and early stopping). If the same loader is later used for the final
    report, those numbers are optimistically biased. A separate validation
    split would give an unbiased test estimate.

    Parameters
    ----------
    model : torch.nn.Module
    train_loader, test_loader : iterables yielding ``(inputs, labels)`` batches
    config : object providing LEARNING_RATE, WEIGHT_DECAY,
        LR_SCHEDULER_PATIENCE, EARLY_STOPPING_PATIENCE, N_EPOCHS,
        BATCH_SIZE and MODEL_DIR
    device : torch.device

    Returns
    --------
    history : dict
        Per-epoch lists under 'train_loss', 'train_acc', 'val_loss',
        'val_acc' and 'val_f1'.
    """
    print("=" * 60)
    print("TRAINING CNN + GRU MODEL")
    print("=" * 60)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY
    )

    # Halve the LR when the validation loss plateaus. The scheduler's
    # ``verbose`` argument is deprecated in recent PyTorch releases, so
    # reductions are reported manually inside the loop instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5,
        patience=config.LR_SCHEDULER_PATIENCE
    )

    # Early stopping monitors val_loss. When it fires, it restores the
    # lowest-loss weights, which may differ from the best-accuracy
    # checkpoint written below.
    early_stopping = EarlyStopping(patience=config.EARLY_STOPPING_PATIENCE)

    # Training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_f1': []
    }

    best_val_acc = 0
    start_time = time.time()

    print(f"\nStarting training for {config.N_EPOCHS} epochs...")
    print(f"Batch size: {config.BATCH_SIZE}, LR: {config.LEARNING_RATE}")
    print("-" * 60)

    for epoch in range(config.N_EPOCHS):
        epoch_start = time.time()

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Evaluate
        val_loss, val_acc, val_f1, _, _ = evaluate(
            model, test_loader, criterion, device
        )

        # Update scheduler and report any learning-rate reduction
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"  Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)

        # Track best model. Always checkpoint the first epoch, so that
        # best_model.pt comes from this run even if validation accuracy
        # never exceeds 0. Otherwise a stale file from an earlier run
        # could be loaded later.
        if epoch == 0 or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(config.MODEL_DIR, 'best_model.pt'))

        epoch_time = time.time() - epoch_start

        # Print progress
        print(f"Epoch {epoch+1:3d}/{config.N_EPOCHS} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | "
              f"Val F1: {val_f1:.4f} | Time: {epoch_time:.1f}s")

        # Early stopping check
        if early_stopping(val_loss, model):
            print(f"\nEarly stopping triggered at epoch {epoch+1}")
            break

    total_time = time.time() - start_time
    print("-" * 60)
    print(f"Training completed in {total_time/60:.1f} minutes")
    print(f"Best validation accuracy: {best_val_acc:.4f}")

    return history

# Create model directory
os.makedirs(config.MODEL_DIR, exist_ok=True)

# Train the model
history = train_model(model, train_loader, test_loader, config, device)

# ## 8. Plot Training History

def plot_training_history(history):
    """Draw loss, accuracy and validation-F1 curves side by side.

    The figure is saved as ``training_history.png`` inside the module-level
    ``config.MODEL_DIR`` and then displayed.
    """
    fig, (ax_loss, ax_acc, ax_f1) = plt.subplots(1, 3, figsize=(15, 4))
    epoch_axis = range(1, len(history['train_loss']) + 1)

    # (axis, [(history key, line format, legend label), ...], y label, title)
    panels = [
        (ax_loss,
         [('train_loss', 'b-', 'Training Loss'),
          ('val_loss', 'r-', 'Validation Loss')],
         'Loss', 'Training and Validation Loss'),
        (ax_acc,
         [('train_acc', 'b-', 'Training Accuracy'),
          ('val_acc', 'r-', 'Validation Accuracy')],
         'Accuracy', 'Training and Validation Accuracy'),
        (ax_f1,
         [('val_f1', 'g-', 'Validation F1')],
         'F1 Score', 'Validation F1 Score (Weighted)'),
    ]

    for ax, curves, y_label, title in panels:
        for key, line_fmt, legend_label in curves:
            ax.plot(epoch_axis, history[key], line_fmt, label=legend_label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    out_path = os.path.join(config.MODEL_DIR, 'training_history.png')
    plt.savefig(out_path, dpi=150)
    plt.show()

    print(f"\nPlot saved to {config.MODEL_DIR}/training_history.png")

plot_training_history(history)

# ## 9. Final Model Evaluation

def final_evaluation(model, test_loader, device):
    """Comprehensive model evaluation.

    Loads ``<config.MODEL_DIR>/best_model.pt`` (module-level ``config``)
    when it exists. It then prints loss, accuracy, F1 (weighted, macro,
    micro) and top-5 accuracy on ``test_loader``.

    Returns
    -------
    all_preds : list
        Argmax predictions, aligned element-wise with ``all_labels``.
    all_labels : list
        True labels in the order the final pass over the loader yielded them.
    all_probs : np.ndarray, shape (n_samples, n_classes)
        Softmax probabilities; row i corresponds to ``all_labels[i]``.
    """
    print("=" * 60)
    print("FINAL MODEL EVALUATION")
    print("=" * 60)

    # Load best model; map_location lets a checkpoint saved on another
    # device (e.g. GPU) load on the current one.
    best_model_path = os.path.join(config.MODEL_DIR, 'best_model.pt')
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print("Loaded best model weights")

    criterion = nn.CrossEntropyLoss()
    val_loss, val_acc, val_f1, all_preds, all_labels = evaluate(
        model, test_loader, criterion, device
    )

    print(f"\nTest Set Results:")
    print(f"  - Loss: {val_loss:.4f}")
    print(f"  - Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)")
    print(f"  - F1 Score (weighted): {val_f1:.4f}")

    # Additional metrics (order-independent, so the evaluate() pass is fine)
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    f1_micro = f1_score(all_labels, all_preds, average='micro')
    print(f"  - F1 Score (macro): {f1_macro:.4f}")
    print(f"  - F1 Score (micro): {f1_micro:.4f}")

    # Top-5 accuracy needs probabilities. Gather probabilities, predictions
    # AND labels in this single pass, so the three returned sequences stay
    # aligned even if the loader yields batches in a different order on
    # each iteration (e.g. shuffle=True).
    model.eval()
    prob_batches, pred_batches, label_batches = [], [], []
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            outputs = model(batch_X.to(device))
            prob_batches.append(torch.softmax(outputs, dim=1).cpu().numpy())
            pred_batches.append(torch.argmax(outputs, dim=1).cpu().numpy())
            label_batches.append(batch_y.cpu().numpy())

    all_probs = np.concatenate(prob_batches, axis=0)
    label_array = np.concatenate(label_batches)
    all_labels = list(label_array)
    all_preds = list(np.concatenate(pred_batches))

    top5_preds = np.argsort(all_probs, axis=1)[:, -5:]
    top5_hits = np.any(top5_preds == label_array[:, None], axis=1)
    top5_acc = float(np.mean(top5_hits))
    print(f"  - Top-5 Accuracy: {top5_acc:.4f} ({top5_acc*100:.2f}%)")

    return all_preds, all_labels, all_probs

all_preds, all_labels, all_probs = final_evaluation(model, test_loader, device)

# ## 10. Save Results

def save_results(history, all_preds, all_labels, all_probs, config):
    """Pickle the training history, test outputs and key hyper-parameters.

    The payload is written to ``<config.MODEL_DIR>/results.pkl`` for the
    report notebook to read back.
    """
    hyperparams = dict(
        n_channels=config.N_CHANNELS,
        n_samples=config.N_SAMPLES,
        n_classes=config.N_CLASSES,
        cnn_filters=config.CNN_FILTERS,
        gru_hidden=config.GRU_HIDDEN_SIZE,
        gru_layers=config.GRU_NUM_LAYERS,
        bidirectional=config.GRU_BIDIRECTIONAL,
        batch_size=config.BATCH_SIZE,
        learning_rate=config.LEARNING_RATE,
    )

    payload = dict(
        history=history,
        predictions=all_preds,
        true_labels=all_labels,
        probabilities=all_probs,
        n_classes=config.N_CLASSES,
        config=hyperparams,
    )

    out_path = os.path.join(config.MODEL_DIR, 'results.pkl')
    with open(out_path, 'wb') as handle:
        pickle.dump(payload, handle)

    print(f"\nResults saved to {config.MODEL_DIR}/results.pkl")

# Persist history, predictions, labels and probabilities to MODEL_DIR/results.pkl
save_results(history, all_preds, all_labels, all_probs, config)

# ## 11. Save Final Model

def save_complete_model(model, config):
    """Save the trained weights with the hyper-parameters needed to rebuild them.

    Only ``model.state_dict()`` is stored, not the module class. A loader
    must first re-create the architecture from the ``'config'`` entry, then
    call ``load_state_dict``. Afterwards every file in ``config.MODEL_DIR``
    is listed with its size in MB.
    """
    architecture = {
        'n_channels': config.N_CHANNELS,
        'n_samples': config.N_SAMPLES,
        'n_classes': config.N_CLASSES,
        'cnn_filters': config.CNN_FILTERS,
        'cnn_kernel_size': config.CNN_KERNEL_SIZE,
        'gru_hidden_size': config.GRU_HIDDEN_SIZE,
        'gru_num_layers': config.GRU_NUM_LAYERS,
        'gru_bidirectional': config.GRU_BIDIRECTIONAL,
        'dense_hidden': config.DENSE_HIDDEN,
        'dropout_cnn': config.DROPOUT_CNN,
        'dropout_gru': config.DROPOUT_GRU,
        'dropout_dense': config.DROPOUT_DENSE,
    }
    checkpoint = {'model_state_dict': model.state_dict(), 'config': architecture}
    torch.save(checkpoint, os.path.join(config.MODEL_DIR, 'complete_model.pt'))

    print(f"Complete model saved to {config.MODEL_DIR}/complete_model.pt")

    # Inventory of everything now in the model directory
    print(f"\nSaved files in {config.MODEL_DIR}:")
    bytes_per_mb = 1024 * 1024
    for entry in os.listdir(config.MODEL_DIR):
        size_mb = os.path.getsize(os.path.join(config.MODEL_DIR, entry)) / bytes_per_mb
        print(f"  - {entry}: {size_mb:.2f} MB")

# Write complete_model.pt (state_dict + architecture hyper-parameters)
save_complete_model(model, config)

# ## 12. Summary

# Recap of the architecture, training run and artifact locations.
# NOTE(review): "3 Conv1D blocks" below is literal text, not derived from
# config; keep it in sync if the architecture changes. max()/[-1] on the
# history lists assume at least one epoch was trained.
print("=" * 60)
print("NOTEBOOK 2 COMPLETE - SUMMARY")
print("=" * 60)
print(f"""
Model: CNN + GRU Hybrid for EEG Person Identification

Architecture:
- Input: ({config.N_CHANNELS} channels, {config.N_SAMPLES} samples)
- CNN: 3 Conv1D blocks with filters {config.CNN_FILTERS}
- GRU: {config.GRU_NUM_LAYERS} layers, {config.GRU_HIDDEN_SIZE} hidden units
- Bidirectional: {config.GRU_BIDIRECTIONAL}
- Output: {config.N_CLASSES} classes (subjects)

Training:
- Epochs trained: {len(history['train_loss'])}
- Best validation accuracy: {max(history['val_acc']):.4f}
- Final F1 score: {history['val_f1'][-1]:.4f}

Saved Files:
- {config.MODEL_DIR}/best_model.pt (best weights)
- {config.MODEL_DIR}/complete_model.pt (full model)
- {config.MODEL_DIR}/results.pkl (evaluation results)
- {config.MODEL_DIR}/training_history.png (training plots)

Next Step: Run Notebook 3 to generate the performance report
""")

In [None]:
# ## 1. Setup and Load Results

import os
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    classification_report, confusion_matrix, top_k_accuracy_score
)
from collections import Counter
import warnings

# Silence all warnings for cleaner report output (this also hides sklearn
# undefined-metric warnings).
warnings.filterwarnings('ignore')

# Set style for plots
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Paths
# MODEL_DIR is where load_results() looks for results.pkl; presumably it
# must match config.MODEL_DIR from the training notebook — confirm.
MODEL_DIR = '/content/models'
# All figures produced by this notebook are saved here
REPORT_DIR = '/content/report'
os.makedirs(REPORT_DIR, exist_ok=True)

print("Libraries loaded successfully!")

def load_results():
    """Read the pickled training/evaluation results from ``MODEL_DIR``.

    Prints the number of test samples and classes, then returns the
    unpickled dict (presumably written by the training notebook's
    ``save_results``).

    Security: ``pickle.load`` can execute arbitrary code embedded in the
    file, so only load results files you produced yourself.
    """
    banner = "=" * 60
    print(banner)
    print("LOADING RESULTS")
    print(banner)

    results_path = os.path.join(MODEL_DIR, 'results.pkl')
    with open(results_path, 'rb') as handle:
        loaded = pickle.load(handle)

    print(f"Results loaded successfully!")
    print(f"  - Number of test samples: {len(loaded['true_labels'])}")
    print(f"  - Number of classes: {loaded['n_classes']}")

    return loaded

results = load_results()

# Unpack into NumPy arrays / plain objects for the analysis helpers below
y_true, y_pred, y_probs = (
    np.array(results[key]) for key in ('true_labels', 'predictions', 'probabilities')
)
history = results['history']
model_config = results['config']

# ## 2. Overall Performance Metrics

def calculate_overall_metrics(y_true, y_pred, y_probs):
    """Compute and print aggregate classification metrics.

    Parameters
    ----------
    y_true, y_pred : array-like of int
        True and predicted class indices.
    y_probs : array-like, shape (n_samples, n_classes)
        Class probabilities; column j is treated as class j.

    Returns
    -------
    dict
        Keys 'accuracy', 'f1_weighted', 'f1_macro', 'f1_micro',
        'precision_weighted', 'recall_weighted', and 'top{k}_accuracy' for
        k in (1, 3, 5, 10).
    """
    print("=" * 60)
    print("OVERALL PERFORMANCE METRICS")
    print("=" * 60)

    metrics = {}

    # Basic metrics (zero_division=0 matches the per-class analysis)
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['f1_weighted'] = f1_score(y_true, y_pred, average='weighted')
    metrics['f1_macro'] = f1_score(y_true, y_pred, average='macro')
    metrics['f1_micro'] = f1_score(y_true, y_pred, average='micro')
    metrics['precision_weighted'] = precision_score(y_true, y_pred, average='weighted',
                                                    zero_division=0)
    metrics['recall_weighted'] = recall_score(y_true, y_pred, average='weighted',
                                              zero_division=0)

    # Top-K accuracy. Passing the full label set tells scikit-learn that
    # column j of y_probs is class j. Without it, the call raises whenever
    # some class has no test samples, because y_true then has fewer classes
    # than y_probs has columns.
    class_labels = np.arange(np.asarray(y_probs).shape[1])
    for k in (1, 3, 5, 10):
        metrics[f'top{k}_accuracy'] = top_k_accuracy_score(
            y_true, y_probs, k=k, labels=class_labels
        )

    # Print results in a fixed-width box; every row is padded to the same
    # inner width as the border so the right edge lines up.
    inner = 57
    rule = "─" * inner

    def row(text):
        print(f"│{text:<{inner}}│")

    print(f"\n┌{rule}┐")
    print(f"│{'CLASSIFICATION METRICS':^{inner}}│")
    print(f"├{rule}┤")
    row(f"  Accuracy (Top-1):          {metrics['accuracy']*100:>6.2f}%")
    row(f"  Top-3 Accuracy:            {metrics['top3_accuracy']*100:>6.2f}%")
    row(f"  Top-5 Accuracy:            {metrics['top5_accuracy']*100:>6.2f}%")
    row(f"  Top-10 Accuracy:           {metrics['top10_accuracy']*100:>6.2f}%")
    print(f"├{rule}┤")
    row(f"  F1 Score (Weighted):       {metrics['f1_weighted']:.4f}")
    row(f"  F1 Score (Macro):          {metrics['f1_macro']:.4f}")
    row(f"  F1 Score (Micro):          {metrics['f1_micro']:.4f}")
    print(f"├{rule}┤")
    row(f"  Precision (Weighted):      {metrics['precision_weighted']:.4f}")
    row(f"  Recall (Weighted):         {metrics['recall_weighted']:.4f}")
    print(f"└{rule}┘")

    return metrics

metrics = calculate_overall_metrics(y_true, y_pred, y_probs)

# ## 3. Confusion Matrix

def plot_confusion_matrix(y_true, y_pred, n_classes):
    """Generate and plot the row-normalised confusion matrix and per-class accuracy.

    Saves ``confusion_matrix.png`` to REPORT_DIR and prints summary stats.
    Subject numbers are reported as class index + 1.

    Returns
    -------
    cm : np.ndarray, shape (n_classes, n_classes)
        Raw counts; rows are true classes, columns are predicted classes.
    cm_normalized : np.ndarray, shape (n_classes, n_classes)
        ``cm`` divided by its row sums; rows with no samples stay 0.
    per_class_acc : np.ndarray, shape (n_classes,)
        Diagonal of ``cm_normalized`` (per-class recall).
    """
    print("\n" + "=" * 60)
    print("CONFUSION MATRIX")
    print("=" * 60)

    # Pin the label set so the matrix is always n_classes x n_classes, even
    # if some class never appears in y_true or y_pred. Otherwise the bar plot
    # below (range(n_classes)) and the subject numbering would misalign.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))

    # Create figure: normalised matrix heatmap + per-class accuracy bars.
    # For ~100 classes the annotated full matrix is impractical, so the
    # normalised version is shown without cell labels.
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    # Row-normalise; classes with no true samples get an all-zero row
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm.astype(float), row_totals,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_totals > 0)

    im1 = axes[0].imshow(cm_normalized, cmap='Blues', aspect='auto')
    axes[0].set_title(f'Normalized Confusion Matrix (All {n_classes} Subjects)', fontsize=12)
    axes[0].set_xlabel('Predicted Subject', fontsize=10)
    axes[0].set_ylabel('True Subject', fontsize=10)
    plt.colorbar(im1, ax=axes[0], label='Proportion')

    # Diagonal analysis - per-class accuracy (recall)
    per_class_acc = np.diag(cm_normalized)

    axes[1].bar(range(n_classes), per_class_acc, color='steelblue', alpha=0.7)
    axes[1].axhline(y=np.mean(per_class_acc), color='red', linestyle='--',
                    label=f'Mean: {np.mean(per_class_acc):.3f}')
    axes[1].set_title('Per-Subject Classification Accuracy', fontsize=12)
    axes[1].set_xlabel('Subject ID', fontsize=10)
    axes[1].set_ylabel('Accuracy', fontsize=10)
    axes[1].legend()
    axes[1].set_xlim(-1, n_classes)

    plt.tight_layout()
    plt.savefig(os.path.join(REPORT_DIR, 'confusion_matrix.png'), dpi=150, bbox_inches='tight')
    plt.show()

    # Print statistics
    print(f"\nPer-class accuracy statistics:")
    print(f"  - Mean: {np.mean(per_class_acc):.4f}")
    print(f"  - Std: {np.std(per_class_acc):.4f}")
    print(f"  - Min: {np.min(per_class_acc):.4f} (Subject {np.argmin(per_class_acc)+1})")
    print(f"  - Max: {np.max(per_class_acc):.4f} (Subject {np.argmax(per_class_acc)+1})")

    return cm, cm_normalized, per_class_acc

cm, cm_normalized, per_class_acc = plot_confusion_matrix(y_true, y_pred, results['n_classes'])

def plot_detailed_confusion_subset(y_true, y_pred, n_show=20):
    """Plot an annotated confusion matrix for the most frequent classes.

    Selects up to ``n_show`` classes with the most samples in ``y_true``.
    Fewer are shown if the data contains fewer classes. Only samples whose
    true AND predicted labels both fall in that subset are kept, so
    misclassifications into a class outside the subset are not shown. A
    row's sum can therefore be smaller than that class's sample count.

    Parameters
    ----------
    y_true, y_pred : np.ndarray of int
        True and predicted class indices (subject number = index + 1).
    n_show : int
        Maximum number of classes to display.
    """
    # Select subjects with most samples for clearer visualization
    label_counts = Counter(y_true)
    top_subjects = sorted(label for label, _ in label_counts.most_common(n_show))
    n_subjects = len(top_subjects)  # may be < n_show for small datasets

    # Filter data for these subjects
    mask = np.isin(y_true, top_subjects)
    y_true_subset = y_true[mask]
    y_pred_subset = y_pred[mask]

    # Remap the selected labels to 0..n_subjects-1 for plotting. Predictions
    # outside the subset map to -1 and are dropped below.
    label_map = {old: new for new, old in enumerate(top_subjects)}
    y_true_remapped = np.array([label_map[y] for y in y_true_subset], dtype=int)
    y_pred_remapped = np.array([label_map.get(y, -1) for y in y_pred_subset], dtype=int)

    valid_mask = y_pred_remapped >= 0
    y_true_remapped = y_true_remapped[valid_mask]
    y_pred_remapped = y_pred_remapped[valid_mask]

    # Size the matrix by the classes actually selected, so it always matches
    # the tick labels below.
    cm_subset = confusion_matrix(y_true_remapped, y_pred_remapped,
                                 labels=range(n_subjects))
    tick_labels = [f'S{s+1}' for s in top_subjects]

    # Plot
    plt.figure(figsize=(14, 10))
    sns.heatmap(cm_subset, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title(f'Confusion Matrix (Top {n_subjects} Subjects by Sample Count)', fontsize=14)
    plt.xlabel('Predicted Subject', fontsize=12)
    plt.ylabel('True Subject', fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join(REPORT_DIR, 'confusion_matrix_subset.png'), dpi=150)
    plt.show()

    print(f"\nDetailed confusion matrix saved for {n_subjects} subjects")

plot_detailed_confusion_subset(y_true, y_pred, n_show=20)

# ## 4. Per-Class F1 Score Analysis

def analyze_per_class_f1(y_true, y_pred, n_classes):
    """Compute, plot and summarise per-class F1, precision and recall.

    Produces a 2x2 figure, saved to REPORT_DIR/f1_analysis.png:
    - F1 histogram
    - precision-vs-recall scatter
    - sorted F1 bars
    - best/worst subjects

    It then prints summary statistics. Subject numbers are class index + 1.

    Returns
    -------
    (f1_per_class, precision_per_class, recall_per_class) : tuple of np.ndarray
        Each has length ``n_classes``. Undefined scores (no predictions or
        no samples for a class) are reported as 0.
    """
    print("\n" + "=" * 60)
    print("PER-CLASS F1 SCORE ANALYSIS")
    print("=" * 60)

    # Calculate per-class metrics over the full label set, treating
    # undefined scores as 0 consistently for all three metrics.
    class_labels = range(n_classes)
    f1_per_class = f1_score(y_true, y_pred, average=None,
                            labels=class_labels, zero_division=0)
    precision_per_class = precision_score(y_true, y_pred, average=None,
                                          labels=class_labels, zero_division=0)
    recall_per_class = recall_score(y_true, y_pred, average=None,
                                    labels=class_labels, zero_division=0)

    # Plot F1 distribution
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # F1 Score distribution
    axes[0, 0].hist(f1_per_class, bins=20, color='steelblue', edgecolor='black', alpha=0.7)
    axes[0, 0].axvline(x=np.mean(f1_per_class), color='red', linestyle='--',
                       label=f'Mean: {np.mean(f1_per_class):.3f}')
    axes[0, 0].set_title('F1 Score Distribution Across Subjects')
    axes[0, 0].set_xlabel('F1 Score')
    axes[0, 0].set_ylabel('Number of Subjects')
    axes[0, 0].legend()

    # Precision vs Recall scatter
    axes[0, 1].scatter(precision_per_class, recall_per_class, alpha=0.6, c='steelblue')
    axes[0, 1].plot([0, 1], [0, 1], 'r--', alpha=0.5)
    axes[0, 1].set_title('Precision vs Recall per Subject')
    axes[0, 1].set_xlabel('Precision')
    axes[0, 1].set_ylabel('Recall')
    axes[0, 1].set_xlim(-0.05, 1.05)
    axes[0, 1].set_ylim(-0.05, 1.05)

    # Sorted F1 scores
    sorted_indices = np.argsort(f1_per_class)
    axes[1, 0].bar(range(n_classes), f1_per_class[sorted_indices],
                   color='steelblue', alpha=0.7)
    axes[1, 0].axhline(y=np.mean(f1_per_class), color='red', linestyle='--')
    axes[1, 0].set_title('F1 Scores Sorted (Ascending)')
    axes[1, 0].set_xlabel('Subject Rank')
    axes[1, 0].set_ylabel('F1 Score')

    # Best and worst performers. Cap n_show at half the classes so the two
    # groups don't overlap on small datasets; keep it >= 1 because a slice
    # like sorted_indices[-0:] would select every class.
    n_show = max(1, min(15, n_classes // 2))
    best_idx = sorted_indices[-n_show:][::-1]
    worst_idx = sorted_indices[:n_show]

    # A NaN-height spacer bar separates the worst group from the best group
    x_labels = [f'S{i+1}' for i in worst_idx] + ['...'] + [f'S{i+1}' for i in best_idx]
    x_values = list(f1_per_class[worst_idx]) + [np.nan] + list(f1_per_class[best_idx])
    colors = ['red']*n_show + ['white'] + ['green']*n_show

    axes[1, 1].bar(range(len(x_labels)), x_values, color=colors, alpha=0.7, edgecolor='black')
    axes[1, 1].set_xticks(range(len(x_labels)))
    axes[1, 1].set_xticklabels(x_labels, rotation=45, ha='right')
    axes[1, 1].set_title(f'Best and Worst {n_show} Subjects by F1 Score')
    axes[1, 1].set_ylabel('F1 Score')

    plt.tight_layout()
    plt.savefig(os.path.join(REPORT_DIR, 'f1_analysis.png'), dpi=150, bbox_inches='tight')
    plt.show()

    # Print statistics
    print(f"\nF1 Score Statistics:")
    print(f"  - Mean: {np.mean(f1_per_class):.4f}")
    print(f"  - Std: {np.std(f1_per_class):.4f}")
    print(f"  - Min: {np.min(f1_per_class):.4f}")
    print(f"  - Max: {np.max(f1_per_class):.4f}")

    print(f"\nTop 5 Best Performing Subjects:")
    for i, idx in enumerate(sorted_indices[-5:][::-1]):
        print(f"  {i+1}. Subject {idx+1}: F1={f1_per_class[idx]:.4f}")

    print(f"\nTop 5 Worst Performing Subjects:")
    for i, idx in enumerate(sorted_indices[:5]):
        print(f"  {i+1}. Subject {idx+1}: F1={f1_per_class[idx]:.4f}")

    return f1_per_class, precision_per_class, recall_per_class

f1_per_class, precision_per_class, recall_per_class = analyze_per_class_f1(
    y_true, y_pred, results['n_classes']
)

# ## 5. Training History Analysis

def plot_training_analysis(history):
    """Four-panel diagnostic of a training run.

    Panels:
    - loss curves
    - accuracy curves
    - train-minus-validation accuracy gap
    - validation F1

    Saves ``training_analysis.png`` to REPORT_DIR, shows it, then prints
    summary statistics.
    """
    print("\n" + "=" * 60)
    print("TRAINING ANALYSIS")
    print("=" * 60)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    ax_loss, ax_acc = axes[0]
    ax_gap, ax_f1 = axes[1]
    epochs = range(1, len(history['train_loss']) + 1)

    def _decorate(ax, title, y_label, with_legend=False):
        # Shared title/axis/grid styling for every panel
        ax.set_title(title, fontsize=12)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        if with_legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Train-vs-validation panels for loss and accuracy
    for ax, train_key, val_key, title, y_label in (
        (ax_loss, 'train_loss', 'val_loss', 'Loss Curves', 'Loss'),
        (ax_acc, 'train_acc', 'val_acc', 'Accuracy Curves', 'Accuracy'),
    ):
        ax.plot(epochs, history[train_key], 'b-', label='Training', linewidth=2)
        ax.plot(epochs, history[val_key], 'r-', label='Validation', linewidth=2)
        _decorate(ax, title, y_label, with_legend=True)

    # Positive gap means training accuracy is ahead of validation accuracy
    train_val_gap = np.subtract(history['train_acc'], history['val_acc'])
    ax_gap.plot(epochs, train_val_gap, 'purple', linewidth=2)
    ax_gap.axhline(y=0, color='gray', linestyle='--')
    ax_gap.fill_between(epochs, 0, train_val_gap, alpha=0.3, color='purple')
    _decorate(ax_gap, 'Train-Validation Accuracy Gap (Overfitting Indicator)',
              'Gap (Train - Val)')

    ax_f1.plot(epochs, history['val_f1'], 'g-', linewidth=2)
    _decorate(ax_f1, 'Validation F1 Score Progression', 'F1 Score (Weighted)')

    plt.tight_layout()
    plt.savefig(os.path.join(REPORT_DIR, 'training_analysis.png'), dpi=150, bbox_inches='tight')
    plt.show()

    val_acc = history['val_acc']
    print(f"\nTraining Statistics:")
    print(f"  - Total epochs: {len(history['train_loss'])}")
    print(f"  - Best validation accuracy: {max(val_acc):.4f} "
          f"(Epoch {np.argmax(val_acc)+1})")
    print(f"  - Best validation F1: {max(history['val_f1']):.4f}")
    print(f"  - Final train-val gap: {train_val_gap[-1]:.4f}")

    # Spread of the validation loss over (at most) the last 10 epochs
    convergence_std = np.std(history['val_loss'][-10:])
    print(f"  - Convergence (last 10 epochs loss std): {convergence_std:.6f}")

plot_training_analysis(history)

# ## 6. Error Analysis

def analyze_errors(y_true, y_pred, y_probs, n_classes):
    """Summarise misclassifications and the model's confidence.

    Plots the max-probability distributions for correct vs. incorrect
    predictions, plus the 15 most frequent (true -> predicted) confusion
    pairs. The figure is saved to REPORT_DIR/error_analysis.png. Subject
    numbers are class index + 1.

    ``n_classes`` is accepted for interface symmetry but not used.
    """
    print("\n" + "=" * 60)
    print("ERROR ANALYSIS")
    print("=" * 60)

    # Boolean mask of misclassified samples
    errors = y_true != y_pred

    print(f"\nError Statistics:")
    print(f"  - Total test samples: {len(y_true)}")
    print(f"  - Correct predictions: {np.sum(~errors)}")
    print(f"  - Incorrect predictions: {np.sum(errors)}")
    print(f"  - Error rate: {np.mean(errors)*100:.2f}%")

    # Confidence = highest class probability assigned to each sample
    confidence = y_probs.max(axis=1)
    correct_conf = confidence[~errors]
    error_conf = confidence[errors]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Confidence distribution. Empty groups are skipped, because a density
    # histogram of zero samples is undefined (e.g. a perfect classifier has
    # no incorrect predictions).
    for values, label, color in ((correct_conf, 'Correct', 'green'),
                                 (error_conf, 'Incorrect', 'red')):
        if values.size:
            axes[0].hist(values, bins=30, alpha=0.6, label=label, color=color, density=True)
    axes[0].set_title('Prediction Confidence Distribution')
    axes[0].set_xlabel('Max Probability (Confidence)')
    axes[0].set_ylabel('Density')
    axes[0].legend()

    # Most common confusion pairs, counted in sample order
    pair_counts = Counter(zip(y_true[errors], y_pred[errors]))
    top_confusions = pair_counts.most_common(15)

    if top_confusions:
        labels = [f'S{t+1}→S{p+1}' for (t, p), _ in top_confusions]
        counts = [c for _, c in top_confusions]

        axes[1].barh(range(len(labels)), counts, color='coral')
        axes[1].set_yticks(range(len(labels)))
        axes[1].set_yticklabels(labels)
        axes[1].set_xlabel('Count')
        axes[1].set_title('Most Common Confusion Pairs')
        axes[1].invert_yaxis()
    else:
        axes[1].text(0.5, 0.5, 'No misclassifications', ha='center', va='center')
        axes[1].set_axis_off()

    plt.tight_layout()
    plt.savefig(os.path.join(REPORT_DIR, 'error_analysis.png'), dpi=150, bbox_inches='tight')
    plt.show()

    def _mean_or_na(values):
        # Avoid printing "nan" (and a RuntimeWarning) for an empty group
        return f"{np.mean(values):.4f}" if values.size else "n/a"

    print(f"\nConfidence Analysis:")
    print(f"  - Mean confidence (correct): {_mean_or_na(correct_conf)}")
    print(f"  - Mean confidence (incorrect): {_mean_or_na(error_conf)}")

    if top_confusions:
        print(f"\nTop 5 Most Confused Subject Pairs:")
        for i, ((true_label, pred_label), count) in enumerate(top_confusions[:5]):
            print(f"  {i+1}. Subject {true_label+1} → Subject {pred_label+1}: {count} times")

analyze_errors(y_true, y_pred, y_probs, results['n_classes'])

In [None]:
# EEG Spectrograms and t-SNE Feature Embeddings
# =============================================================================

# ## 1. Setup


!pip install torch torchvision -q
!pip install scikit-learn -q
!pip install scipy -q


import os
import numpy as np
import pickle
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from tqdm import tqdm
import warnings

# Silence all warnings for cleaner notebook output
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-whitegrid')

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Paths
# NOTE(review): PROCESSED_DIR / MODEL_DIR presumably must match the
# directories the preprocessing and training notebooks wrote to — confirm.
PROCESSED_DIR = '/content/processed_data'
MODEL_DIR = '/content/models'
# Figures from this notebook go to VIZ_DIR
VIZ_DIR = '/content/visualizations'
os.makedirs(VIZ_DIR, exist_ok=True)
REPORT_DIR = '/content/report'
os.makedirs(REPORT_DIR, exist_ok=True)
# ## 2. Load Data and Model

# Load preprocessed data
# Assumed layout: X_test[segment, channel, sample] (see N_CHANNELS /
# N_SAMPLES below); y_test holds one zero-based subject label per segment.
X_test = np.load(os.path.join(PROCESSED_DIR, 'X_test.npy'))
y_test = np.load(os.path.join(PROCESSED_DIR, 'y_test.npy'))

# pickle.load can execute code from the file; only load trusted metadata
with open(os.path.join(PROCESSED_DIR, 'metadata.pkl'), 'rb') as f:
    metadata = pickle.load(f)

print(f"Test data shape: {X_test.shape}")
print(f"Number of subjects: {len(np.unique(y_test))}")

# Configuration
SAMPLING_RATE = metadata['sampling_rate']  # 160 Hz
MOTOR_CHANNELS = metadata['motor_channels']
N_CHANNELS = X_test.shape[1]
N_SAMPLES = X_test.shape[2]

# Load the trained model
class CNN_GRU_Model(nn.Module):
    """Recreate model architecture for loading weights.

    Layer attribute names (``cnn_layers``, ``gru``, ``classifier``) and the
    order of modules inside each ``nn.Sequential`` determine the
    ``state_dict`` keys. They must match the training-time model exactly.

    Input is (batch, n_channels, n_samples). Three Conv1d -> BatchNorm ->
    ReLU -> MaxPool1d(2) stages shorten the time axis by a factor of 8.
    The GRU runs over that shortened sequence. Its final hidden state
    (forward and backward concatenated when bidirectional) feeds a
    two-layer classifier.
    """

    def __init__(self, n_channels, n_samples, n_classes,
                 cnn_filters=(32, 64, 128), kernel_size=7,
                 gru_hidden=128, gru_layers=2, bidirectional=True,
                 dense_hidden=256, dropout_cnn=0.3, dropout_gru=0.3,
                 dropout_dense=0.5):
        """
        Parameters
        ----------
        n_channels : int
            EEG channels (Conv1d input channels).
        n_samples : int
            Samples per segment. Not used by any layer; kept for interface
            compatibility with the saved config.
        n_classes : int
            Number of output classes.
        cnn_filters : sequence of int
            Output channels of the three conv stages (first three are used).
        kernel_size : int
            Conv1d kernel size ('same' padding keeps the length per stage).
        gru_hidden, gru_layers, bidirectional :
            GRU configuration.
        dense_hidden : int
            Width of the hidden classifier layer.
        dropout_cnn, dropout_gru, dropout_dense : float
            Dropout rates. The defaults equal the previously hard-coded values
            (0.3 / 0.3 / 0.5); GRU dropout only applies with >1 layer.
        """
        super().__init__()

        # Build the three conv stages; module order (and therefore the
        # state_dict indices 0..12) is conv, bn, relu, pool per stage,
        # followed by a final dropout.
        conv_stack = []
        in_channels = n_channels
        for out_channels in cnn_filters[:3]:
            conv_stack += [
                nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding='same'),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.MaxPool1d(2),
            ]
            in_channels = out_channels
        conv_stack.append(nn.Dropout(dropout_cnn))
        self.cnn_layers = nn.Sequential(*conv_stack)

        self.gru = nn.GRU(
            input_size=cnn_filters[2],
            hidden_size=gru_hidden,
            num_layers=gru_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout_gru if gru_layers > 1 else 0
        )

        gru_output_size = gru_hidden * (2 if bidirectional else 1)

        self.classifier = nn.Sequential(
            nn.Linear(gru_output_size, dense_hidden),
            nn.BatchNorm1d(dense_hidden),
            nn.ReLU(),
            nn.Dropout(dropout_dense),
            nn.Linear(dense_hidden, n_classes)
        )

        self.gru_hidden = gru_hidden
        self.bidirectional = bidirectional

    def forward(self, x):
        """Return class logits of shape (batch, n_classes)."""
        return self.classifier(self.get_features(x))

    def get_features(self, x):
        """Extract features before classification layer.

        Returns a tensor of shape (batch, gru_hidden * (2 if bidirectional
        else 1)). For a bidirectional GRU, ``hidden[-2]`` / ``hidden[-1]``
        are the last layer's forward / backward final states.
        """
        # GRU is batch_first, so move channels last: (batch, time, features)
        gru_in = self.cnn_layers(x).transpose(1, 2)
        _, hidden = self.gru(gru_in)

        if self.bidirectional:
            return torch.cat((hidden[-2], hidden[-1]), dim=1)
        return hidden[-1]

    def get_cnn_features(self, x):
        """Extract CNN features only: (batch, cnn_filters[2], n_samples // 8)."""
        return self.cnn_layers(x)

# Load model: rebuild the architecture from the saved hyper-parameters,
# then restore the trained weights onto the current device.
checkpoint = torch.load(os.path.join(MODEL_DIR, 'complete_model.pt'), map_location=device)
model_config = checkpoint['config']

model = CNN_GRU_Model(
    n_channels=model_config['n_channels'],
    n_samples=model_config['n_samples'],
    n_classes=model_config['n_classes'],
    cnn_filters=model_config['cnn_filters'],
    # Use the recorded conv kernel size. A training value other than the
    # constructor default would otherwise make load_state_dict fail with a
    # shape mismatch. Checkpoints without the key fall back to 7.
    kernel_size=model_config.get('cnn_kernel_size', 7),
    gru_hidden=model_config['gru_hidden_size'],
    gru_layers=model_config['gru_num_layers'],
    bidirectional=model_config['gru_bidirectional']
).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode: BatchNorm uses running stats, dropout disabled
model.eval()
print("Model loaded successfully!")

# ## 3. EEG Signal Visualization

def plot_eeg_signals(X, y, subject_ids=(0, 1, 2), channel_names=None):
    """
    Plot raw EEG signals for different subjects.

    For each requested subject, the first matching segment in ``X`` is drawn.
    Up to 10 channels are stacked, each at a vertical offset of 3 units.
    The figure is saved to ``VIZ_DIR/eeg_signals.png`` and shown.

    Parameters
    ----------
    X : np.ndarray
        Segments indexed as ``X[segment, channel, sample]``.
    y : np.ndarray
        Zero-based subject label per segment (titles display label + 1).
    subject_ids : sequence of int
        Zero-based subject labels to plot.
    channel_names : sequence of str, optional
        Names shown in each panel's legend. Channels without a name fall
        back to ``'Ch<i>'``.

    Relies on module-level N_SAMPLES, N_CHANNELS, SAMPLING_RATE and VIZ_DIR.
    """
    print("=" * 60)
    print("EEG SIGNAL VISUALIZATION")
    print("=" * 60)

    n_subjects = len(subject_ids)
    # squeeze=False keeps a 2-D axes array even for a single subject
    fig, axes = plt.subplots(n_subjects, 1, figsize=(14, 3*n_subjects), squeeze=False)
    axes = axes[:, 0]

    t_axis = np.arange(N_SAMPLES) / SAMPLING_RATE

    for ax, subject_id in zip(axes, subject_ids):
        # Find a sample from this subject
        sample_indices = np.where(y == subject_id)[0]
        if len(sample_indices) == 0:
            # Explain the panel instead of leaving it blank
            ax.set_title(f'Subject {subject_id + 1} - no segments found', fontsize=12)
            ax.set_axis_off()
            continue

        eeg_data = X[sample_indices[0]]

        # Plot each channel with offset
        for ch in range(min(N_CHANNELS, 10)):  # Plot max 10 channels
            # Label with the caller-supplied channel names when available
            if channel_names is not None and ch < len(channel_names):
                label = channel_names[ch]
            else:
                label = f'Ch{ch}'
            ax.plot(t_axis, eeg_data[ch] + ch * 3, linewidth=0.8, label=label)

        ax.set_title(f'Subject {subject_id + 1} - EEG Segment', fontsize=12)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude (normalized)')
        ax.set_xlim([0, t_axis[-1]])
        # Legend outside the axes; savefig(bbox_inches='tight') keeps it in frame
        ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0), fontsize=8)

    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, 'eeg_signals.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Plot saved to {VIZ_DIR}/eeg_signals.png")

# Plot EEG signals for 3 different subjects
plot_eeg_signals(X_test, y_test, subject_ids=[0, 50, 100], channel_names=MOTOR_CHANNELS)

# ## 4. Spectrogram Visualization

def compute_spectrogram(signal_data, fs, nperseg=64, noverlap=48):
    """
    Compute spectrogram using Short-Time Fourier Transform.

    Thin wrapper around ``scipy.signal.spectrogram`` working along the last
    axis of ``signal_data``, so a single channel or a stack of segments /
    channels can be passed in one call.

    Args:
        signal_data: Time series with time along the last axis.
        fs: Sampling rate in Hz.
        nperseg: Window length in samples.
        noverlap: Overlap between consecutive windows in samples, or None for
            SciPy's default.

    Returns:
        (f, t, Sxx): frequency bins, window times, and power with shape
        ``signal_data.shape[:-1] + (len(f), len(t))``.
    """
    signal_data = np.asarray(signal_data)
    n_time = signal_data.shape[-1]
    # For inputs shorter than nperseg SciPy shrinks nperseg (with a warning)
    # but keeps noverlap, then raises "noverlap must be less than nperseg".
    # Shrink both together, preserving the requested overlap ratio.
    if nperseg > n_time:
        if noverlap is not None:
            noverlap = noverlap * n_time // nperseg
        nperseg = n_time
    f, t, Sxx = signal.spectrogram(signal_data, fs=fs,
                                    nperseg=nperseg, noverlap=noverlap)
    return f, t, Sxx

def plot_spectrograms(X, y, subject_ids=(0, 1, 2), channel_idx=5):
    """
    Plot spectrograms for EEG segments from different subjects.

    Each requested subject gets one row: the raw trace of ``channel_idx`` for
    its first segment (left) and that trace's spectrogram in dB (right). The
    figure is saved to ``VIZ_DIR/spectrograms.png`` and displayed.

    Args:
        X: EEG segments indexed as X[sample, channel, time].
        y: Subject label (0-based) for each segment.
        subject_ids: 0-based subject labels, one row each. Rows for subjects
            without any segment in ``y`` are hidden.
        channel_idx: Channel to analyse; also indexes ``MOTOR_CHANNELS`` for
            the subplot title.
    """
    print("\n" + "=" * 60)
    print("SPECTROGRAM VISUALIZATION")
    print("=" * 60)

    subject_ids = list(subject_ids)
    n_subjects = len(subject_ids)
    if n_subjects == 0:
        print("No subjects requested - nothing to plot.")
        return

    # squeeze=False keeps axes 2-D (rows x 2) even for a single subject
    fig, axes = plt.subplots(n_subjects, 2, figsize=(14, 4*n_subjects), squeeze=False)

    for (ax_raw, ax_spec), subject_id in zip(axes, subject_ids):
        # Find a sample from this subject
        sample_indices = np.flatnonzero(y == subject_id)
        if sample_indices.size == 0:
            ax_raw.set_visible(False)
            ax_spec.set_visible(False)
            continue

        eeg_segment = X[sample_indices[0], channel_idx, :]

        # Compute spectrogram
        f, t, Sxx = compute_spectrogram(eeg_segment, SAMPLING_RATE)

        # Plot raw signal
        time = np.arange(len(eeg_segment)) / SAMPLING_RATE
        ax_raw.plot(time, eeg_segment, 'b-', linewidth=1)
        ax_raw.set_title(f'Subject {subject_id+1} - Raw EEG (Channel: {MOTOR_CHANNELS[channel_idx]})')
        ax_raw.set_xlabel('Time (s)')
        ax_raw.set_ylabel('Amplitude')

        # Plot spectrogram in dB; the epsilon avoids log10(0)
        im = ax_spec.pcolormesh(t, f, 10 * np.log10(Sxx + 1e-10),
                                shading='gouraud', cmap='viridis')
        ax_spec.set_title(f'Subject {subject_id+1} - Spectrogram')
        ax_spec.set_xlabel('Time (s)')
        ax_spec.set_ylabel('Frequency (Hz)')
        ax_spec.set_ylim([0, 50])  # Focus on relevant frequencies
        plt.colorbar(im, ax=ax_spec, label='Power (dB)')

    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, 'spectrograms.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Plot saved to {VIZ_DIR}/spectrograms.png")

plot_spectrograms(X_test, y_test, subject_ids=[0, 50, 100], channel_idx=5)

def plot_average_spectrograms(X, y, n_subjects_show=6):
    """
    Plot average spectrograms per subject to show subject-specific patterns.

    A seeded random selection of subjects is drawn from ``y``. For each one,
    the spectrograms of up to 20 of its segments x the first 5 channels are
    averaged in linear power and displayed in dB on a 3-column grid. The
    figure is saved to ``VIZ_DIR/avg_spectrograms.png`` and displayed.

    Args:
        X: EEG segments indexed as X[sample, channel, time].
        y: Subject label (0-based) for each segment.
        n_subjects_show: Number of subjects to plot, capped at the number of
            distinct labels in ``y``.
    """
    print("\n" + "=" * 60)
    print("AVERAGE SPECTROGRAMS PER SUBJECT")
    print("=" * 60)

    # Randomly pick subjects to display. A local RandomState(42) draws exactly
    # the subjects np.random.seed(42) + np.random.choice would, without
    # reseeding the global RNG as a side effect.
    unique_subjects = np.unique(y)
    n_pick = min(n_subjects_show, len(unique_subjects))
    if n_pick <= 0:
        print("No subjects to plot.")
        return
    rng = np.random.RandomState(42)
    selected_subjects = np.sort(rng.choice(unique_subjects, size=n_pick, replace=False))

    n_cols = 3
    n_rows = (len(selected_subjects) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4*n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, subject_id in zip(axes, selected_subjects):
        # Up to 20 segments x the first 5 channels of this subject
        subject_data = X[y == subject_id][:20, :5]

        # compute_spectrogram works along the last (time) axis, so one call
        # gives Sxx of shape (segments, channels, freqs, times). Averaging the
        # first two axes takes the result shape from the data instead of a
        # hard-coded size, so it matches any segment length.
        f, t, Sxx = compute_spectrogram(subject_data, SAMPLING_RATE)
        avg_spectrogram = Sxx.mean(axis=(0, 1))

        # Plot in dB; the epsilon avoids log10(0)
        ax.pcolormesh(t, f, 10 * np.log10(avg_spectrogram + 1e-10),
                      shading='gouraud', cmap='magma')
        ax.set_title(f'Subject {subject_id + 1}', fontsize=11)
        ax.set_xlabel('Time (s)', fontsize=9)
        ax.set_ylabel('Frequency (Hz)', fontsize=9)
        ax.set_ylim([0, 45])

    # Hide empty subplots
    for ax in axes[len(selected_subjects):]:
        ax.set_visible(False)

    plt.suptitle('Average Spectrograms by Subject (Motor Cortex Channels)', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, 'avg_spectrograms.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Plot saved to {VIZ_DIR}/avg_spectrograms.png")

plot_average_spectrograms(X_test, y_test, n_subjects_show=9)

# ## 5. t-SNE Feature Embedding Visualization

def extract_features(model, X, batch_size=64):
    """
    Extract deep features from the trained model.

    Feeds ``X`` through ``model.get_features`` in mini-batches (no gradient
    tracking, model in eval mode) on the global ``device`` and stacks the
    per-batch outputs.

    Args:
        model: Trained network exposing ``get_features(tensor)``.
        X: Array-like of input segments, batched along its first axis.
        batch_size: Number of segments per forward pass.

    Returns:
        numpy array of the batch outputs concatenated along axis 0.
    """
    model.eval()
    total = len(X)
    chunks = []

    with torch.no_grad():
        for start in tqdm(range(0, total, batch_size), desc="Extracting features"):
            stop = min(start + batch_size, total)
            inputs = torch.FloatTensor(X[start:stop]).to(device)
            chunks.append(model.get_features(inputs).cpu().numpy())

    return np.concatenate(chunks, axis=0)

# Extract features
print("=" * 60)
print("EXTRACTING DEEP FEATURES")
print("=" * 60)

# Cap the number of test segments passed to the model (and to t-SNE below) to
# keep runtime manageable. The global seed makes the random subset reproducible.
max_samples = 5000
X_subset, y_subset = X_test, y_test
if len(X_test) > max_samples:
    np.random.seed(42)
    indices = np.random.choice(len(X_test), max_samples, replace=False)
    X_subset, y_subset = X_test[indices], y_test[indices]

features = extract_features(model, X_subset)
print(f"Feature shape: {features.shape}")

def plot_tsne(features, labels, n_components=2, perplexity=30, n_iter=1000):
    """
    Visualize features using t-SNE (with a PCA pre-reduction above 50 dims).

    Draws two panels: every subject in its own colour, and five highlighted
    subjects over a grey background. The figure is saved to
    ``VIZ_DIR/tsne_embeddings.png`` and displayed.

    Args:
        features: 2-D array, one row of features per sample.
        labels: Subject label (0-based) for each row of ``features``.
        n_components: t-SNE output dimensionality (the first two are plotted).
        perplexity: t-SNE perplexity. Lowered to ``len(features) - 1`` when it
            would not be below the sample count, which scikit-learn rejects.
        n_iter: Number of t-SNE optimisation iterations. Passed as
            ``max_iter`` on scikit-learn versions that accept it (``n_iter``
            was deprecated in its favour in 1.5) and as ``n_iter`` otherwise.

    Returns:
        The t-SNE embedding, shape ``(len(features), n_components)``.
    """
    import inspect

    print("\n" + "=" * 60)
    print("t-SNE VISUALIZATION")
    print("=" * 60)

    perplexity = min(perplexity, len(features) - 1)

    print(f"Running t-SNE on {len(features)} samples...")
    print(f"Parameters: perplexity={perplexity}, n_iter={n_iter}")

    # First reduce with PCA if features are high-dimensional
    if features.shape[1] > 50:
        print("Applying PCA reduction first...")
        # PCA cannot keep more components than there are samples
        pca = PCA(n_components=min(50, features.shape[0]))
        features_pca = pca.fit_transform(features)
        print(f"PCA variance explained: {sum(pca.explained_variance_ratio_)*100:.1f}%")
    else:
        features_pca = features

    # Apply t-SNE, naming the iteration argument the way the installed
    # scikit-learn expects it.
    tsne_params = inspect.signature(TSNE).parameters
    iter_kwarg = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'
    tsne = TSNE(n_components=n_components, perplexity=perplexity,
                random_state=42, verbose=1, **{iter_kwarg: n_iter})
    features_tsne = tsne.fit_transform(features_pca)

    print(f"t-SNE complete! Output shape: {features_tsne.shape}")

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    # Plot 1: All subjects colored. Colours are assigned by position in the
    # sorted label list, so non-contiguous labels cannot index past the palette.
    unique_labels = np.unique(labels)
    colors = plt.cm.nipy_spectral(np.linspace(0, 1, len(unique_labels)))

    for color, subject_id in zip(colors, unique_labels):
        mask = labels == subject_id
        axes[0].scatter(features_tsne[mask, 0], features_tsne[mask, 1],
                        c=[color], alpha=0.5, s=10, label=f'S{subject_id+1}')

    axes[0].set_title('t-SNE: All Subjects', fontsize=12)
    axes[0].set_xlabel('t-SNE Dimension 1')
    axes[0].set_ylabel('t-SNE Dimension 2')

    # Plot 2: Highlight specific subjects (only those present in labels)
    highlight_subjects = [s for s in [0, 25, 50, 75, 100] if s in unique_labels]

    # Plot all in gray
    axes[1].scatter(features_tsne[:, 0], features_tsne[:, 1],
                    c='lightgray', alpha=0.3, s=10, label='Other')

    # Highlight selected subjects
    highlight_colors = ['red', 'blue', 'green', 'orange', 'purple']
    for color, subject_id in zip(highlight_colors, highlight_subjects):
        mask = labels == subject_id
        axes[1].scatter(features_tsne[mask, 0], features_tsne[mask, 1],
                        c=color, alpha=0.8, s=30,
                        label=f'Subject {subject_id+1}')

    axes[1].set_title('t-SNE: Highlighted Subjects', fontsize=12)
    axes[1].set_xlabel('t-SNE Dimension 1')
    axes[1].set_ylabel('t-SNE Dimension 2')
    axes[1].legend(loc='upper right')

    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, 'tsne_embeddings.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Plot saved to {VIZ_DIR}/tsne_embeddings.png")

    return features_tsne

features_tsne = plot_tsne(features, y_subset, perplexity=30, n_iter=1000)

def plot_tsne_quality_analysis(features_tsne, labels):
    """
    Analyze the quality of t-SNE clustering.

    For each subject:
    - the centroid is the mean embedded point;
    - the spread is the mean Euclidean distance of its points to that centroid.

    The function then:
    - plots histograms of the spreads and of all pairwise centroid distances;
    - saves the figure to ``VIZ_DIR/tsne_quality.png``;
    - prints both means and their ratio (inter-centroid distance / spread).

    Args:
        features_tsne: Embedded points, one row per sample.
        labels: Subject label for each row of ``features_tsne``.
    """
    from scipy.spatial.distance import pdist

    print("\n" + "=" * 60)
    print("t-SNE CLUSTER QUALITY ANALYSIS")
    print("=" * 60)

    unique_labels = np.unique(labels)
    if len(unique_labels) < 2:
        # pdist needs at least two centroids; with fewer, every mean is NaN
        print("Need at least two subjects to compare clusters - skipping analysis.")
        return

    # Per-subject centroid and mean distance of its points to that centroid
    centroids = np.array([features_tsne[labels == label].mean(axis=0)
                          for label in unique_labels])
    spreads = np.array([
        np.linalg.norm(features_tsne[labels == label] - centroid, axis=1).mean()
        for label, centroid in zip(unique_labels, centroids)
    ])
    # Condensed vector of all pairwise centroid distances
    inter_cluster_dist = pdist(centroids)

    mean_spread = spreads.mean()
    mean_inter = inter_cluster_dist.mean()

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: Cluster spread distribution
    axes[0].hist(spreads, bins=20, color='steelblue', edgecolor='black', alpha=0.7)
    axes[0].axvline(mean_spread, color='red', linestyle='--',
                    label=f'Mean: {mean_spread:.2f}')
    axes[0].set_title('Distribution of Cluster Spreads')
    axes[0].set_xlabel('Cluster Spread (distance from centroid)')
    axes[0].set_ylabel('Number of Subjects')
    axes[0].legend()

    # Plot 2: Cluster separation analysis
    axes[1].hist(inter_cluster_dist, bins=30, color='coral', edgecolor='black', alpha=0.7)
    axes[1].axvline(mean_inter, color='red', linestyle='--',
                    label=f'Mean: {mean_inter:.2f}')
    axes[1].set_title('Distribution of Inter-Cluster Distances')
    axes[1].set_xlabel('Distance Between Cluster Centroids')
    axes[1].set_ylabel('Count')
    axes[1].legend()

    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, 'tsne_quality.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Plot saved to {VIZ_DIR}/tsne_quality.png")

    # Every cluster collapsing to a single point (spread 0) means perfect separation
    separation = mean_inter / mean_spread if mean_spread > 0 else float('inf')

    # Print statistics
    print(f"\nCluster Quality Metrics:")
    print(f"  - Mean intra-cluster spread: {mean_spread:.4f}")
    print(f"  - Mean inter-cluster distance: {mean_inter:.4f}")
    print(f"  - Separation ratio: {separation:.4f}")
    print(f"    (Higher is better - indicates well-separated clusters)")

plot_tsne_quality_analysis(features_tsne, y_subset)

# ## 6. CNN Feature Maps Visualization

def visualize_cnn_features(model, X, sample_indices=(0, 1, 2), y=None):
    """
    Visualize intermediate CNN feature maps.

    Registers forward hooks on every ``nn.Conv1d`` directly inside
    ``model.cnn_layers`` and runs each selected sample through ``model``. For
    each sample it saves ``VIZ_DIR/cnn_features_sample_<idx>.png``, showing:
    - the first 5 input channels (offset vertically);
    - the first 16 filters of every hooked conv output.

    Args:
        model: Network exposing an iterable ``cnn_layers`` attribute.
        X: EEG segments indexed as X[sample, channel, time].
        sample_indices: Indices into ``X`` to visualize; indices outside
            ``0 <= i < len(X)`` are skipped with a message.
        y: Subject labels aligned with ``X``, used in titles. Defaults to the
            notebook-level ``y_subset`` for backward compatibility.
    """
    print("\n" + "=" * 60)
    print("CNN FEATURE MAPS VISUALIZATION")
    print("=" * 60)

    if y is None:
        y = y_subset

    model.eval()

    # Latest output of each hooked layer, keyed 'conv_<position in cnn_layers>'
    activations = {}

    def make_hook(name):
        def hook(module, inputs, output):
            activations[name] = output.detach().cpu().numpy()
        return hook

    hooks = [
        layer.register_forward_hook(make_hook(f'conv_{i}'))
        for i, layer in enumerate(model.cnn_layers)
        if isinstance(layer, nn.Conv1d)
    ]
    if not hooks:
        print("No nn.Conv1d layers found in model.cnn_layers - nothing to visualize.")
        return

    # try/finally guarantees the hooks are removed even if a sample or plot
    # fails; otherwise they stay registered and keep copying activations to
    # the CPU on every later forward pass of the model.
    try:
        for sample_idx in sample_indices:
            if not 0 <= sample_idx < len(X):
                print(f"Skipping sample index {sample_idx}: X has {len(X)} samples")
                continue

            subject_id = y[sample_idx]
            sample = torch.FloatTensor(X[sample_idx:sample_idx+1]).to(device)

            # Forward pass only to trigger the hooks
            with torch.no_grad():
                _ = model(sample)

            # squeeze=False keeps `axes` indexable even if no layer fired
            n_layers = len(activations)
            fig, axes = plt.subplots(1, n_layers + 1, figsize=(16, 4), squeeze=False)
            axes = axes[0]

            # Original input (first 5 channels, offset vertically)
            time = np.arange(X.shape[2]) / SAMPLING_RATE
            for ch in range(min(5, X.shape[1])):
                axes[0].plot(time, X[sample_idx, ch] + ch*2, linewidth=0.8)
            axes[0].set_title(f'Input (Subject {subject_id+1})')
            axes[0].set_xlabel('Time (s)')

            # Feature maps: first sample of the batch, at most 16 filters
            for ax, (name, activation) in zip(axes[1:], activations.items()):
                act = activation[0]
                n_filters_show = min(16, act.shape[0])
                ax.imshow(act[:n_filters_show], aspect='auto', cmap='viridis')
                ax.set_title(f'{name} Features')
                ax.set_xlabel('Time Steps')
                ax.set_ylabel('Filter')

            plt.suptitle(f'CNN Feature Maps - Subject {subject_id+1}', fontsize=12)
            plt.tight_layout()
            plt.savefig(os.path.join(VIZ_DIR, f'cnn_features_sample_{sample_idx}.png'),
                        dpi=150, bbox_inches='tight')
            plt.show()
    finally:
        # Remove hooks
        for hook in hooks:
            hook.remove()

    print(f"Feature maps saved to {VIZ_DIR}/")

visualize_cnn_features(model, X_subset, sample_indices=[0, 50, 100], y=y_subset)

# ## 7. Power Spectrum Analysis

def plot_power_spectrum_by_subject(X, y, subject_ids=(0, 50, 100)):
    """
    Plot average power spectrum for different subjects.

    For each subject, the Welch PSDs of up to 50 segments x the first 5
    channels are averaged and drawn on one log-scale axis. Standard EEG bands
    are shaded for reference. The figure is saved to
    ``VIZ_DIR/power_spectrum.png`` and displayed.

    Args:
        X: EEG segments indexed as X[sample, channel, time].
        y: Subject label (0-based) for each segment.
        subject_ids: 0-based subject labels to plot. Subjects without any
            segment in ``y`` are reported and skipped.
    """
    print("\n" + "=" * 60)
    print("POWER SPECTRUM ANALYSIS")
    print("=" * 60)

    subject_ids = list(subject_ids)
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))

    colors = plt.cm.tab10(np.linspace(0, 1, len(subject_ids)))

    for color, subject_id in zip(colors, subject_ids):
        # Up to 50 segments x the first 5 channels of this subject
        subject_data = X[y == subject_id][:50, :5]
        if len(subject_data) == 0:
            # Averaging nothing would give NaN, and `freqs` would be unbound
            # (or left over from the previous subject).
            print(f"  Subject {subject_id + 1}: no segments found - skipped")
            continue

        # welch works along the last (time) axis: psd is (segments, channels, freqs)
        freqs, psd = signal.welch(subject_data, fs=SAMPLING_RATE, nperseg=128)
        avg_psd = psd.mean(axis=(0, 1))

        ax.semilogy(freqs, avg_psd, color=color, linewidth=2,
                    label=f'Subject {subject_id+1}')

    ax.set_xlabel('Frequency (Hz)', fontsize=12)
    ax.set_ylabel('Power Spectral Density', fontsize=12)
    ax.set_title('Average Power Spectrum by Subject', fontsize=14)
    ax.set_xlim([0, 50])
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Mark EEG frequency bands
    bands = {'Delta': (0.5, 4), 'Theta': (4, 8), 'Alpha': (8, 13),
             'Beta': (13, 30), 'Gamma': (30, 50)}

    for band_name, (low, high) in bands.items():
        ax.axvspan(low, high, alpha=0.1)
        ax.text((low + high) / 2, ax.get_ylim()[1] * 0.5, band_name,
                ha='center', fontsize=9, rotation=90, alpha=0.7)

    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, 'power_spectrum.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Plot saved to {VIZ_DIR}/power_spectrum.png")

plot_power_spectrum_by_subject(X_subset, y_subset, subject_ids=[0, 25, 50, 75, 100])

# ## 8. Summary

print("=" * 60)
print("NOTEBOOK 4 COMPLETE - VISUALIZATIONS GENERATED")
print("=" * 60)
print(f"""
Generated Visualizations:

1. EEG Signals
   - {VIZ_DIR}/eeg_signals.png

2. Spectrograms
   - {VIZ_DIR}/spectrograms.png
   - {VIZ_DIR}/avg_spectrograms.png

3. t-SNE Feature Embeddings
   - {VIZ_DIR}/tsne_embeddings.png
   - {VIZ_DIR}/tsne_quality.png

4. CNN Feature Maps
   - {VIZ_DIR}/cnn_features_sample_*.png

5. Power Spectrum Analysis
   - {VIZ_DIR}/power_spectrum.png

All files saved in: {VIZ_DIR}
""")

# List all generated files
print("\nGenerated files:")
for filename in sorted(os.listdir(VIZ_DIR)):
    filepath = os.path.join(VIZ_DIR, filename)
    size_kb = os.path.getsize(filepath) / 1024
    print(f"  - {filename}: {size_kb:.1f} KB")