# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve, precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Conv1D, MaxPooling1D, Dense, Dropout, BatchNormalization, 
    GlobalAveragePooling1D, Input, Activation, SpatialDropout1D, 
    LSTM, Bidirectional, Multiply, Reshape, LeakyReLU
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.regularizers import l2
import tensorflow.keras.backend as K
import tensorflow as tf
import random

# ===================== DATA AUGMENTATION =====================
def augment_signal(segment, augmentation_prob=0.5):
    """
    Randomly perturb a single EEG segment.

    A uniform draw first decides whether anything happens at all: when it
    exceeds ``augmentation_prob`` the input object is handed back untouched.
    Otherwise exactly one transform is picked (noise 30%, scale 30%,
    DC shift 20%, circular time shift 20%) and applied.

    Args:
        segment: Input EEG segment (1D array)
        augmentation_prob: Probability of applying augmentation

    Returns:
        The augmented segment, or ``segment`` itself when no augmentation
        is applied
    """
    rng = np.random

    # Gate: leave the segment alone for this call
    if rng.random() > augmentation_prob:
        return segment

    # Draw which single transform to apply
    kinds = ['noise', 'scale', 'shift', 'time_shift']
    chosen = rng.choice(kinds, p=[0.3, 0.3, 0.2, 0.2])

    if chosen == 'noise':
        # Additive zero-mean Gaussian noise with a random standard deviation
        sigma = rng.uniform(0.01, 0.05)
        return segment + rng.normal(0, sigma, segment.shape)

    if chosen == 'scale':
        # Multiply the amplitude by a factor in [0.9, 1.1)
        return segment * rng.uniform(0.9, 1.1)

    if chosen == 'shift':
        # Add a constant (DC) offset in [-0.1, 0.1)
        return segment + rng.uniform(-0.1, 0.1)

    if chosen == 'time_shift':
        # Circular shift; randint's upper bound is exclusive, so the
        # offset lies in [-20, 19]
        return np.roll(segment, rng.randint(-20, 20))

    return segment


def augment_batch(X_batch, y_batch, augmentation_prob=0.5):
    """
    Apply augmentation to a batch of data

    Every channel of every sample is passed through ``augment_signal``.
    Samples labelled 1 (ICTAL, the minority class) use a 1.5x higher
    augmentation probability than the rest.

    Args:
        X_batch: Batch of EEG segments (batch_size, time_steps, channels)
        y_batch: Batch of labels (1 = ICTAL, 0 = ALL)
        augmentation_prob: Probability of applying augmentation

    Returns:
        Augmented batch: a new array with the same shape and dtype as
        ``X_batch`` (the input is not modified)
    """
    # Every element is written below, so no zero-initialisation is needed
    X_augmented = np.empty_like(X_batch)
    n_channels = X_batch.shape[2]

    for i in range(len(X_batch)):
        # Augment ICTAL class (minority) more aggressively
        if y_batch[i] == 1:  # ICTAL (seizure)
            # Clamp: a probability above 1.0 has no additional effect
            prob = min(augmentation_prob * 1.5, 1.0)
        else:  # ALL (NORMAL + INTERICTAL)
            prob = augmentation_prob

        # Previously only channel 0 was written and any further channels were
        # returned as zeros; now each channel is augmented (independently).
        for ch in range(n_channels):
            X_augmented[i, :, ch] = augment_signal(X_batch[i, :, ch], prob)

    return X_augmented


# ===================== CUSTOM DATA GENERATOR =====================
class AugmentedDataGenerator(tf.keras.utils.Sequence):
    """
    Custom data generator with real-time augmentation

    Serves ``(X_batch, y_batch)`` pairs where ``X_batch`` has been passed
    through ``augment_batch``; labels are returned unchanged. The sample order
    is reshuffled at construction and at the end of every epoch when
    ``shuffle`` is true.
    """
    def __init__(self, X, y, batch_size=32, augmentation_prob=0.5, shuffle=True, **kwargs):
        """
        Args:
            X: Array of segments, indexed along the first axis
            y: Array of labels aligned with ``X``
            batch_size: Number of samples per batch (last batch may be smaller)
            augmentation_prob: Base augmentation probability forwarded to
                ``augment_batch``
            shuffle: Whether to reshuffle the sample order each epoch
            **kwargs: Forwarded to the base ``Sequence`` constructor
        """
        # Initialise the Keras base class; Keras 3 expects subclasses to call
        # the parent constructor.
        super().__init__(**kwargs)
        self.X = X
        self.y = y
        self.batch_size = batch_size
        self.augmentation_prob = augmentation_prob
        self.shuffle = shuffle
        self.indices = np.arange(len(self.X))
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the final partial batch counts)."""
        return int(np.ceil(len(self.X) / self.batch_size))

    def __getitem__(self, index):
        """Return the augmented batch number ``index`` and its labels."""
        # Get batch indices
        start_idx = index * self.batch_size
        end_idx = min((index + 1) * self.batch_size, len(self.X))
        batch_indices = self.indices[start_idx:end_idx]

        # Indexing with an integer array already yields a copy, and
        # augment_batch writes into a fresh array, so no extra .copy() is needed
        X_batch = self.X[batch_indices]
        y_batch = self.y[batch_indices]

        # Apply augmentation
        X_batch = augment_batch(X_batch, y_batch, self.augmentation_prob)

        return X_batch, y_batch

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)


# ===================== SE BLOCK IMPLEMENTATION =====================
class SEBlock(tf.keras.layers.Layer):
    """
    Squeeze-and-Excitation Block for channel attention
    Paper: "Interpretable classification of epileptic EEG signals..."

    The input is averaged over its time axis, passed through a two-layer
    bottleneck (ReLU then sigmoid) to obtain one gate per channel, and the
    input is multiplied by those gates.
    """
    def __init__(self, reduction=8, **kwargs):
        """
        Args:
            reduction: Divisor applied to the channel count to size the
                bottleneck Dense layer
            **kwargs: Forwarded to ``tf.keras.layers.Layer``
        """
        super().__init__(**kwargs)
        self.reduction = reduction

    def build(self, input_shape):
        """Create the pooling and excitation sub-layers for ``input_shape``."""
        channels = input_shape[-1]
        self.squeeze = GlobalAveragePooling1D()

        # Keep at least one bottleneck unit: with fewer channels than
        # `reduction`, the integer division would request a zero-unit layer.
        bottleneck_units = max(1, channels // self.reduction)

        # Excitation network
        self.fc1 = Dense(
            bottleneck_units,
            activation='relu',
            kernel_initializer='he_normal'
        )
        self.fc2 = Dense(
            channels,
            activation='sigmoid',
            kernel_initializer='he_normal'
        )

        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` rescaled channel-wise by the learned gates."""
        # Squeeze: Global average pooling
        squeeze = self.squeeze(inputs)

        # Excitation: Learn channel importance
        excitation = self.fc1(squeeze)
        excitation = self.fc2(excitation)

        # Insert a length-1 time axis so the gates broadcast over time steps
        excitation = tf.expand_dims(excitation, axis=1)

        # Scale: broadcast multiply instead of creating a new Multiply layer
        # on every call
        return inputs * excitation

    def get_config(self):
        """Return the layer config including ``reduction`` for serialization."""
        config = super().get_config()
        config.update({"reduction": self.reduction})
        return config


# ===================== LOAD DATA =====================
data_dir = os.path.join("preprocessed")
X = np.load(os.path.join(data_dir, "ALL_X.npy"))
y = np.load(os.path.join(data_dir, "ALL_y.npy"))
file_ids = np.load(os.path.join(data_dir, "ALL_file_ids.npy"))  # Load file IDs

# Fail early with a clear message: everything below indexes y and file_ids
# with masks/indices derived from one another and from X.
if not (len(X) == len(y) == len(file_ids)):
    raise ValueError(
        f"Mismatched lengths: X={len(X)}, y={len(y)}, file_ids={len(file_ids)}"
    )

# Convert to Binary Classification: ICTAL (1) vs ALL (0)
# ICTAL = seizure (minority class)
# ALL = NORMAL + INTERICTAL (majority class)
y_encoded = np.where(y == 'ICTAL', 1, 0)

# np.unique returns (values, counts); map label -> count. (The previous code
# reversed the pair and therefore printed count -> label.)
label_values, label_counts = np.unique(y_encoded, return_counts=True)
label_distribution = {int(v): int(c) for v, c in zip(label_values, label_counts)}

print("="*60)
print("BINARY CLASSIFICATION SETUP")
print("="*60)
print(f"Original classes: {np.unique(y)}")
print(f"Binary encoding: 1 (ICTAL) vs 0 (ALL: NORMAL+INTERICTAL)")
print(f"Binary labels distribution: {label_distribution}")
print(f"Total samples: {len(X)}")
print(f"Total unique files: {len(np.unique(file_ids))}")
print("="*60)

# Prepare Data: append a trailing channel axis of size 1
X = X.reshape((X.shape[0], X.shape[1], 1))
print(f"Dataset shape: {X.shape}")

# Compute Class Weights with stronger emphasis on minority class
class_weights = compute_class_weight('balanced', classes=np.unique(y_encoded), y=y_encoded)
# Boost minority class (ICTAL) weight even more
class_weights[1] = class_weights[1] * 1.5  # 50% increase for ICTAL
class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}
print(f"Class weights (adjusted): {class_weight_dict}")

# ===================== PREPARE CROSS VALIDATION (FILE-LEVEL) =====================
# A fresh seed is drawn on every run; it is printed so a run can be repeated.
random_state = np.random.randint(0, 10000)
print(f"ðŸŽ² Random state used for this run: {random_state}")

# KEY FIX: Split by unique file IDs, not by segments
unique_file_ids = np.unique(file_ids)
print(f"Total unique files for splitting: {len(unique_file_ids)}")

# Create a mapping from file_id to its label (use majority vote if needed)
file_id_to_label = {}
for fid in unique_file_ids:
    mask = file_ids == fid
    labels_in_file = y_encoded[mask]
    # Use the most common label (should be the same for all segments from one file)
    file_id_to_label[fid] = np.bincount(labels_in_file).argmax()

file_labels = np.array([file_id_to_label[fid] for fid in unique_file_ids])

# Stratified K-Fold on FILE level
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
file_fold_indices = list(kfold.split(unique_file_ids, file_labels))

# Convert file-level indices to segment-level indices
fold_indices = []
for train_val_files_idx, test_files_idx in file_fold_indices:
    train_val_files = unique_file_ids[train_val_files_idx]
    test_files = unique_file_ids[test_files_idx]
    
    # Get all segments from these files
    train_val_mask = np.isin(file_ids, train_val_files)
    test_mask = np.isin(file_ids, test_files)
    
    train_val_idx = np.where(train_val_mask)[0]
    test_idx = np.where(test_mask)[0]
    
    fold_indices.append((train_val_idx, test_idx))
    
    print(f"Fold: Train/Val files: {len(train_val_files)}, Test files: {len(test_files)}")
    print(f"      Train/Val segments: {len(train_val_idx)}, Test segments: {len(test_idx)}")

# Save indices for reproducibility
os.makedirs("results", exist_ok=True)
np.save(os.path.join("results", "fold_indices.npy"), np.array(fold_indices, dtype=object), allow_pickle=True)
np.save(os.path.join("results", "file_fold_indices.npy"), np.array(file_fold_indices, dtype=object), allow_pickle=True)


# ===================== LOSS FUNCTIONS =====================
def focal_loss(alpha=0.75, gamma=2.0):
    """
    Build a binary focal loss function.

    Args:
        alpha: Weight given to positive (label 1) samples; negatives
            get ``1 - alpha``
        gamma: Exponent of the ``(1 - p_t)`` modulating factor

    Returns:
        A ``loss(y_true, y_pred)`` callable returning the mean of
        ``alpha_t * (1 - p_t) ** gamma * cross_entropy``
    """
    def loss(y_true, y_pred):
        # Keep predictions away from 0 and 1 so the logs stay finite
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Labels may arrive as integers; match the prediction dtype so the
        # arithmetic below does not fail on mixed dtypes
        y_true = tf.cast(y_true, y_pred.dtype)

        cross_entropy = -y_true * K.log(y_pred) - (1 - y_true) * K.log(1 - y_pred)
        # p_t: predicted probability assigned to the true class
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_term = K.pow(1 - p_t, gamma)
        alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)

        return K.mean(alpha_t * focal_term * cross_entropy)

    return loss


def hybrid_focal_loss(alpha=0.75, gamma=1.7, focal_weight=0.55):
    """
    Hybrid: Focal + BCE

    Args:
        alpha: Weight given to positive (label 1) samples in the focal term;
            negatives get ``1 - alpha``
        gamma: Exponent of the ``(1 - p_t)`` modulating factor
        focal_weight: Mixing coefficient; the result is
            ``focal_weight * focal + (1 - focal_weight) * bce``

    Returns:
        A ``loss(y_true, y_pred)`` callable returning the mean combined loss
    """
    def loss(y_true, y_pred):
        # Keep predictions away from 0 and 1 so the logs stay finite
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Labels may arrive as integers; match the prediction dtype so the
        # arithmetic below does not fail on mixed dtypes
        y_true = tf.cast(y_true, y_pred.dtype)

        bce = -y_true * K.log(y_pred) - (1 - y_true) * K.log(1 - y_pred)

        # p_t: predicted probability assigned to the true class
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_term = K.pow(1 - p_t, gamma)
        alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
        focal = alpha_t * focal_term * bce

        combined = focal_weight * focal + (1 - focal_weight) * bce
        return K.mean(combined)

    return loss


# ===================== MODEL BUILDING =====================
def build_model(input_shape):
    """
    Assemble the (uncompiled) CNN + LSTM binary classifier.

    Layout:
    - Four convolutional stages, each Conv1D -> BatchNorm -> LeakyReLU ->
      SEBlock -> MaxPool(2) -> SpatialDropout1D
    - A single unidirectional LSTM (64 units, last state only)
    - Two regularised Dense layers with dropout
    - A one-unit sigmoid output

    Args:
        input_shape: Shape of one sample, passed to ``Input(shape=...)``

    Returns:
        A Keras ``Model`` mapping the input to a single sigmoid output
    """
    # (filters, kernel_size, spatial_dropout_rate) for each conv stage, in
    # order: local patterns, mid-level, high-level and deep features
    conv_stages = (
        (48, 7, 0.28),
        (96, 5, 0.32),
        (128, 3, 0.38),
        (160, 3, 0.38),
    )

    inputs = Input(shape=input_shape)
    features = inputs

    # ========== Convolutional feature extractor ==========
    for n_filters, kernel, drop_rate in conv_stages:
        features = Conv1D(n_filters, kernel_size=kernel, padding='same', kernel_regularizer=l2(0.002))(features)
        features = BatchNormalization()(features)
        features = LeakyReLU(alpha=0.1)(features)
        features = SEBlock(reduction=8)(features)
        features = MaxPooling1D(pool_size=2)(features)
        features = SpatialDropout1D(drop_rate)(features)

    # ========== Temporal modeling (regular, not bidirectional, LSTM) ==========
    features = LSTM(64, return_sequences=False, kernel_regularizer=l2(0.001))(features)
    features = Dropout(0.4)(features)

    # ========== Dense classification head ==========
    for n_units in (64, 32):
        features = Dense(n_units, activation='relu', kernel_regularizer=l2(0.003))(features)
        features = Dropout(0.48)(features)

    outputs = Dense(1, activation='sigmoid')(features)

    return Model(inputs=inputs, outputs=outputs)


# ===================== TRAINING =====================
# For each file-level fold: split train/val by file, train a fresh model,
# reload the checkpointed weights, evaluate on the held-out test files and
# save per-fold plots. A cross-fold summary and the best model are written
# at the end.
USE_DATA_AUGMENTATION = True  

acc_per_fold = []
auc_per_fold = []
conf_matrices = []
class_names = ['ALL (NORMAL+INTERICTAL)', 'ICTAL (SEIZURE)']

# Track metrics for best model selection
fold_metrics = {
    'fold_no': [],
    'test_acc': [],
    'test_loss': [],
    'test_auc': [],
    'val_acc': [],
    'train_acc': [],
    'f1_score': [],
    'precision': [],
    'recall': []
}
best_model = None
best_fold = None
best_acc = 0


for fold_no, (train_val_idx, test_idx) in enumerate(fold_indices, start=1):
    
    print(f"\n{'='*60}")
    print(f" FOLD {fold_no}")
    print(f"{'='*60}\n")

    # Split into train/val/test
    X_train_val, X_test = X[train_val_idx], X[test_idx]
    y_train_val, y_test = y_encoded[train_val_idx], y_encoded[test_idx]
    file_ids_train_val = file_ids[train_val_idx]
    
    # KEY FIX: Split train/val by files, not segments
    unique_train_val_files = np.unique(file_ids_train_val)
    train_val_file_labels = np.array([file_id_to_label[fid] for fid in unique_train_val_files])
    
    train_files, val_files = train_test_split(
        unique_train_val_files, 
        test_size=0.15, 
        stratify=train_val_file_labels, 
        random_state=42
    )
    
    train_mask = np.isin(file_ids_train_val, train_files)
    val_mask = np.isin(file_ids_train_val, val_files)
    
    X_train = X_train_val[train_mask]
    y_train = y_train_val[train_mask]
    X_val = X_train_val[val_mask]
    y_val = y_train_val[val_mask]

    # np.unique returns (values, counts); map label -> count. (The previous
    # code reversed the pair and therefore printed count -> label.)
    test_values, test_counts = np.unique(y_test, return_counts=True)
    test_distribution = {int(v): int(c) for v, c in zip(test_values, test_counts)}

    print(f"Train files: {len(train_files)}, Val files: {len(val_files)}, Test files: {len(np.unique(file_ids[test_idx]))}")
    print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")
    print(f"Test set distribution: {test_distribution}")

    # Build model
    model = build_model(input_shape=(X.shape[1], 1))
    
    # Print model summary only for first fold
    if fold_no == 1:
        print("\n" + "="*60)
        print("MODEL ARCHITECTURE")
        print("="*60)
        model.summary()
        print("="*60 + "\n")
    
    # Compile with adjusted hyperparameters
    model.compile(
        optimizer=Adam(learning_rate=1e-4),  # Increased learning rate
        loss=hybrid_focal_loss(alpha=0.80, gamma=2.0, focal_weight=0.7),  # More aggressive focal loss
        metrics=['accuracy']
    )
    
    # Callbacks 
    early_stop = EarlyStopping(
        monitor='val_loss', 
        patience=25, 
        restore_best_weights=True, 
        verbose=1, 
        mode='min'
    )
    
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,  # More aggressive reduction
        patience=5,  # Reduce faster
        min_lr=1e-7,
        verbose=1,
        mode='min'
    )
    
    # NOTE(review): the checkpoint tracks val_accuracy while early stopping
    # tracks val_loss; the checkpointed weights are the ones loaded below.
    checkpoint = ModelCheckpoint(
        os.path.join("results", f"model_fold{fold_no}.weights.h5"),
        monitor='val_accuracy', 
        save_best_only=True, 
        save_weights_only=True,
        verbose=1,  # Changed to 1 for better monitoring
        mode='max'
    )
    
    callbacks = [early_stop, reduce_lr, checkpoint]
    
    # Train Model
    print(f"\nðŸš€ Training Fold {fold_no}...")
    
    if USE_DATA_AUGMENTATION:
        print("ðŸ“Š Using data augmentation with real-time generator")
        
        # Create data generators
        train_generator = AugmentedDataGenerator(
            X_train, y_train, 
            batch_size=12, 
            augmentation_prob=0.55,
            shuffle=True
        )
        
        # Note: Validation data is NOT augmented
        history = model.fit(
            train_generator,
            epochs=200,
            validation_data=(X_val, y_val),
            callbacks=callbacks,
            class_weight=class_weight_dict,
            verbose=1
        )
    else:
        print("ðŸ“Š Training without data augmentation")
        history = model.fit(
            X_train, y_train,
            epochs=150,
            batch_size=32,
            validation_data=(X_val, y_val),
            callbacks=callbacks,
            class_weight=class_weight_dict,
            verbose=1
        )

    # Load best weights (the best-val_accuracy checkpoint; this replaces the
    # weights restored by early stopping)
    model.load_weights(os.path.join("results", f"model_fold{fold_no}.weights.h5"))

    # Evaluate
    test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
    acc_per_fold.append(test_acc)
    
    # Predictions (probabilities thresholded at 0.5)
    y_pred_prob = model.predict(X_test, verbose=0)
    y_pred = (y_pred_prob > 0.5).astype(int).flatten()
    
    # AUC Score
    test_auc = roc_auc_score(y_test, y_pred_prob)
    auc_per_fold.append(test_auc)
    
    print(f"\nâœ… Fold {fold_no} Results:")
    print(f"  Test Accuracy: {test_acc:.4f}")
    print(f"  Test AUC: {test_auc:.4f}")
    print(f"  Test Loss: {test_loss:.4f}")

    # Confusion Matrix: pin the label order so every fold yields a 2x2 matrix,
    # which the tick labels and the cross-fold sum below rely on
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    conf_matrices.append(cm)
    
    # Classification Metrics
    precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted')
    
    # Store metrics
    fold_metrics['fold_no'].append(fold_no)
    fold_metrics['test_acc'].append(test_acc)
    fold_metrics['test_loss'].append(test_loss)
    fold_metrics['test_auc'].append(test_auc)
    fold_metrics['val_acc'].append(max(history.history['val_accuracy']))
    fold_metrics['train_acc'].append(max(history.history['accuracy']))
    fold_metrics['f1_score'].append(f1)
    fold_metrics['precision'].append(precision)
    fold_metrics['recall'].append(recall)
    
    # Check if this is the best model
    # NOTE(review): selection uses test accuracy, so the saved "best" model is
    # chosen on the same data that is reported as held-out performance.
    if test_acc > best_acc:
        best_acc = test_acc
        best_fold = fold_no
        best_model = model
        best_cm = cm
        best_y_test = y_test
        best_y_pred = y_pred
        best_y_pred_prob = y_pred_prob
        print(f"ðŸŒŸ New best model! Fold {fold_no} with accuracy: {test_acc:.4f}")
    
    print(f"\nFold {fold_no} Classification Report:")
    # labels=[0, 1] keeps the report aligned with class_names even if a class
    # is absent from this fold
    print(classification_report(y_test, y_pred, labels=[0, 1], target_names=class_names, digits=4))

    # Plot Confusion Matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names,
                annot_kws={'size': 14})
    plt.title(f"Fold {fold_no} Confusion Matrix\nAcc: {test_acc:.4f} | AUC: {test_auc:.4f}", 
              fontsize=14)
    plt.xlabel("Predicted", fontsize=12)
    plt.ylabel("True", fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join("results", f"confusion_fold{fold_no}.png"), dpi=300)
    plt.close()

    # Plot Training History 
    plt.figure(figsize=(14, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], label='Train Accuracy', linewidth=2)
    plt.plot(history.history['val_accuracy'], label='Val Accuracy', linewidth=2)
    plt.axhline(y=test_acc, color='r', linestyle='--', label=f'Test Acc: {test_acc:.4f}')
    plt.title(f'Fold {fold_no} - Model Accuracy', fontsize=13)
    plt.xlabel('Epoch', fontsize=11)
    plt.ylabel('Accuracy', fontsize=11)
    plt.legend(fontsize=10)
    plt.grid(alpha=0.3)
    
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Train Loss', linewidth=2)
    plt.plot(history.history['val_loss'], label='Val Loss', linewidth=2)
    plt.title(f'Fold {fold_no} - Model Loss', fontsize=13)
    plt.xlabel('Epoch', fontsize=11)
    plt.ylabel('Loss', fontsize=11)
    plt.legend(fontsize=10)
    plt.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join("results", f"training_history_fold{fold_no}.png"), dpi=300)
    plt.close()


# ===================== SUMMARY =====================
print("\n" + "="*60)
print(" CROSS-VALIDATION SUMMARY")
print("="*60)
print(f"\nðŸ“Š Mean Test Accuracy across folds: {np.mean(acc_per_fold):.4f} (Â±{np.std(acc_per_fold):.4f})")
print(f"ðŸ“Š Mean Test AUC across folds: {np.mean(auc_per_fold):.4f} (Â±{np.std(auc_per_fold):.4f})")

# Combine confusion matrices (element-wise sum over folds)
total_cm = np.sum(conf_matrices, axis=0)
plt.figure(figsize=(8, 6))
sns.heatmap(total_cm, annot=True, fmt='d', cmap='Greens', 
            xticklabels=class_names, yticklabels=class_names,
            annot_kws={'size': 14})
plt.title("Overall Confusion Matrix (All Folds)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.savefig(os.path.join("results", "confusion_overall.png"), dpi=300)
plt.close()

# Print Fold Metrics
print("\nðŸ“‹ Fold-wise Metrics Summary:")
for i in range(len(fold_metrics['fold_no'])):
    print(f"\nðŸ”¸ Fold {fold_metrics['fold_no'][i]} Metrics:")
    print(f"  Train Accuracy : {fold_metrics['train_acc'][i]:.4f}")
    print(f"  Val Accuracy   : {fold_metrics['val_acc'][i]:.4f}")
    print(f"  Test Accuracy  : {fold_metrics['test_acc'][i]:.4f}")
    print(f"  Test Loss      : {fold_metrics['test_loss'][i]:.4f}")
    print(f"  Precision      : {fold_metrics['precision'][i]:.4f}")
    print(f"  Recall         : {fold_metrics['recall'][i]:.4f}")
    print(f"  F1 Score       : {fold_metrics['f1_score'][i]:.4f}")
    print(f"  Test AUC       : {fold_metrics['test_auc'][i]:.4f}")

# Save best model; best_model stays None if no fold beat the initial best_acc
if best_model is not None:
    best_model.save(os.path.join("results", "best_model.keras"))
    print(f"\nðŸ’¾ Best model (Fold {best_fold}) saved as 'results/best_model.keras'")
else:
    print("\nNo best model was selected; nothing saved.")
print("="*60)