# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv1D, MaxPooling1D, LSTM, Dense, Dropout, BatchNormalization
)
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.layers import (
    Conv1D, MaxPooling1D, LSTM, Dense, Dropout,
    LayerNormalization, Add, Bidirectional,
    GlobalAveragePooling1D
)
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, LSTM, Bidirectional, Dense, Dropout
from tensorflow.keras.models import Model




# -------------------- Load Data --------------------
# os.path.join keeps the paths portable; a raw "preprocessed\ALL_X.npy"
# literal only resolves as intended on Windows.
X = np.load(os.path.join("preprocessed", "ALL_X.npy"))
y = np.load(os.path.join("preprocessed", "ALL_y.npy"))

# Convert to binary: ICTAL = 0, every other label = 1
y_encoded = np.where(y == 'ICTAL', 0, 1)

# Reshape to (samples, timesteps, 1) for the Conv1D input
X = X.reshape((X.shape[0], X.shape[1], 1))

print("Dataset shape:", X.shape)
print("Labels distribution:", np.unique(y_encoded, return_counts=True))

# -------------------- Prepare Cross Validation --------------------
# A fresh seed per run; it is printed so a run can be reproduced later.
random_state = np.random.randint(0, 10000)
print(f"🎲 Random state used for this run: {random_state}")

kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
fold_indices = list(kfold.split(X, y_encoded))

results_dir = "results(Interictal_VS_Normal_VS_Ictal)"
os.makedirs(results_dir, exist_ok=True)
# Stored as an object array of (train_idx, test_idx) pairs; reload it with
# np.load(..., allow_pickle=True).
np.save(
    os.path.join(results_dir, "fold_indices.npy"),
    np.array(fold_indices, dtype=object),
    allow_pickle=True,
)

def hybrid_focal_loss(alpha=0.25, gamma=2.0, bce_weight=0.5):
    """Build a loss blending binary cross-entropy with binary focal loss.

    The returned callable computes
    ``bce_weight * BCE + (1 - bce_weight) * FocalLoss``.

    Args:
        alpha: Weight of the focal term for label-1 targets; label-0 targets
            are weighted by ``1 - alpha``.
        gamma: Focusing exponent; larger values down-weight confidently
            classified examples more strongly.
        bce_weight: Mixing coefficient for the BCE component; the focal
            component receives ``1 - bce_weight``.

    Returns:
        A ``loss(y_true, y_pred)`` function for ``model.compile``. ``y_pred``
        is expected to contain sigmoid probabilities.
    """

    def loss(y_true, y_pred):
        # Labels may arrive as integers (e.g. int64 from np.where). TF does
        # not implicitly promote int * float, so cast them to the prediction
        # dtype. Matching y_pred's shape also prevents (batch,) vs (batch, 1)
        # broadcasting into a (batch, batch) tensor.
        y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), tf.shape(y_pred))

        # --- Binary Cross Entropy (clips its input internally) ---
        bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)

        # --- Focal Loss ---
        # Clip so log() never sees exactly 0 or 1.
        p = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        positive_term = alpha * (1 - p) ** gamma * y_true * K.log(p)
        negative_term = (1 - alpha) * p ** gamma * (1 - y_true) * K.log(1 - p)
        focal = K.mean(-positive_term - negative_term, axis=-1)

        # --- Combine Both ---
        return bce_weight * bce + (1 - bce_weight) * focal

    return loss

# -------------------- CNN + LSTM Model --------------------
def build_cnn_lstm(input_length, learning_rate=1e-4):
    """Build and compile a 1-D CNN + LSTM binary classifier.

    Three Conv1D/BatchNorm/MaxPool/Dropout stages feed a 64-unit LSTM,
    followed by a small dense head with a single sigmoid output.

    Args:
        input_length: Number of timesteps per sample; inputs are shaped
            ``(input_length, 1)``.
        learning_rate: Adam learning rate (defaults to the original 1e-4).

    Returns:
        A compiled ``Sequential`` model using ``hybrid_focal_loss`` and
        accuracy as its metric.
    """
    model = Sequential([
        # Declaring the input with Input() rather than passing input_shape=
        # to the first Conv1D is the form Keras recommends (and avoids the
        # Keras 3 warning).
        Input(shape=(input_length, 1)),

        # --- CNN Layers ---
        Conv1D(32, kernel_size=7, activation='relu'),
        BatchNormalization(),
        MaxPooling1D(2),
        Dropout(0.2),

        Conv1D(64, kernel_size=5, activation='relu'),
        BatchNormalization(),
        MaxPooling1D(2),
        Dropout(0.3),

        Conv1D(128, kernel_size=3, activation='relu'),
        BatchNormalization(),
        MaxPooling1D(2),
        Dropout(0.4),

        # --- LSTM: summarizes the pooled feature sequence into one vector ---
        LSTM(64, return_sequences=False),

        # --- Dense Layers ---
        Dense(64, activation='relu'),
        Dropout(0.4),

        Dense(1, activation='sigmoid'),
    ])

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss=hybrid_focal_loss(alpha=0.25, gamma=2.0, bce_weight=0.5),
        metrics=['accuracy'],
    )
    return model


# -------------------- Data Augmentation for 1D EEG (Optimized Version) --------------------
def augment_signal(signal, noise_std=0.005, max_shift=5,
                   scale_range=(0.97, 1.03), rng=None):
    """Return a lightly perturbed copy of one signal segment.

    Three perturbations are applied in order: additive Gaussian noise, a
    circular shift along the first (time) axis, and a global amplitude
    scaling.

    Args:
        signal: Array-like whose first axis is time, e.g. shape
            ``(length, 1)``.
        noise_std: Standard deviation of the additive Gaussian noise.
        max_shift: Largest shift magnitude in samples. The shift is drawn
            uniformly from ``[-max_shift, max_shift]``, both ends inclusive.
        scale_range: ``(low, high)`` bounds of the uniform amplitude factor.
        rng: Optional ``np.random.Generator`` or ``RandomState`` for
            reproducible draws. Defaults to NumPy's global random state, so
            ``np.random.seed`` still controls it.

    Returns:
        A new array with the same shape as ``signal``.
    """
    signal = np.asarray(signal)
    if rng is None:
        rng = np.random
    # Generator exposes ``integers``; RandomState / np.random expose ``randint``.
    draw_int = rng.integers if hasattr(rng, "integers") else rng.randint

    # 1) Very light noise
    noisy = signal + rng.normal(0, noise_std, signal.shape)

    # 2) Small time shift. ``high`` is exclusive, hence the +1 that makes
    #    the range symmetric around zero.
    shift = draw_int(-max_shift, max_shift + 1)
    # Roll along time only, so multi-channel inputs keep their channels apart.
    shifted = np.roll(noisy, shift, axis=0)

    # 3) Gentle scaling
    low, high = scale_range
    return shifted * rng.uniform(low, high)



def augment_batch(X, y):
    """Double a dataset by pairing each sample with one augmented copy.

    For every index ``i`` the output holds ``X[i]`` immediately followed by
    ``augment_signal(X[i])``, and both entries carry the label ``y[i]``.

    Args:
        X: Indexable collection of samples.
        y: Labels aligned index-by-index with ``X``.

    Returns:
        Tuple ``(X_out, y_out)`` of NumPy arrays, each with ``2 * len(X)``
        entries.
    """
    samples_out, labels_out = [], []
    for idx, sample in enumerate(X):
        label = y[idx]
        # Original sample first, then its single weakly augmented twin.
        samples_out.extend((sample, augment_signal(sample)))
        labels_out.extend((label, label))
    return np.array(samples_out), np.array(labels_out)



# -------------------- Training --------------------
# Label 0 is ICTAL and label 1 is every other class (see y_encoded).
class_names = ['Ictal', 'Non-ictal']

acc_per_fold = []
conf_matrices = []

for fold_no, (train_val_idx, test_idx) in enumerate(fold_indices, start=1):
    print(f"\n🔹 Fold {fold_no}")

    # Drop the previous fold's model/graph state so memory does not keep
    # growing across folds.
    K.clear_session()

    # Split into train/val/test
    X_train_val, X_test = X[train_val_idx], X[test_idx]
    y_train_val, y_test = y_encoded[train_val_idx], y_encoded[test_idx]

    # 17.65% of the fold's training share (80% of the data) is ~14% of all
    # samples used for validation.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.1765, stratify=y_train_val, random_state=42
    )

    # ------------ APPLY DATA AUGMENTATION (training split only) ------------
    X_train, y_train = augment_batch(X_train, y_train)

    print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

    # Build a fresh model for each fold
    model = build_cnn_lstm(X_train.shape[1])
    model.summary()

    # Handle class imbalance; key the weights by the classes actually present
    classes = np.unique(y_train)
    cw = compute_class_weight('balanced', classes=classes, y=y_train)
    class_weights = {int(c): float(w) for c, w in zip(classes, cw)}

    # Train model
    history = model.fit(
        X_train, y_train,
        epochs=40,
        batch_size=32,
        validation_data=(X_val, y_val),
        class_weight=class_weights,
        verbose=1
    )

    # Evaluate
    test_loss, test_acc = model.evaluate(X_test, y_test)
    acc_per_fold.append(test_acc)
    print(f"Fold {fold_no} - Test Accuracy: {test_acc:.4f}")

    # Save weights
    weight_path = f"results(Interictal_VS_Normal_VS_Ictal)/cnn_lstm_fold{fold_no}.weights.h5"
    model.save_weights(weight_path)
    print(f"✅ Weights saved to {weight_path}")

    # Confusion matrix. labels=[0, 1] forces a 2x2 matrix, so the per-fold
    # matrices can always be summed and unpacked later.
    y_pred = (model.predict(X_test) > 0.5).astype("int32").flatten()
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    conf_matrices.append(cm)

    plt.figure(figsize=(5, 4))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names
    )
    # sklearn's confusion_matrix puts true labels on rows, predictions on columns
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Fold {fold_no} Confusion Matrix")
    plt.tight_layout()
    plt.savefig(f"results(Interictal_VS_Normal_VS_Ictal)/cnn_lstm_conf_fold{fold_no}.png")
    plt.close()

print(f"\n📊 Mean Accuracy: {np.mean(acc_per_fold):.4f} ± {np.std(acc_per_fold):.4f}")

# -------------------- Overall Confusion Matrix --------------------
# The test folds are disjoint, so summing per-fold matrices counts each
# sample exactly once.
total_cm = np.sum(conf_matrices, axis=0)

plt.figure(figsize=(5, 4))
sns.heatmap(
    total_cm, annot=True, fmt='d', cmap='Greens',
    # Label 0 is ICTAL, label 1 is every other class.
    xticklabels=['Ictal', 'Non-ictal'],
    yticklabels=['Ictal', 'Non-ictal']
)
# sklearn's confusion_matrix puts true labels on rows, predictions on columns
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Overall Confusion Matrix (All Folds)")
plt.tight_layout()
plt.savefig("results(Interictal_VS_Normal_VS_Ictal)/cnn_lstm_conf_overall.png")
plt.close()  # release the figure once it is written to disk


def _safe_ratio(num, den):
    """Return num / den as a float, or 0.0 when den is zero."""
    return float(num) / float(den) if den else 0.0


# ravel() of a 2x2 matrix with label 1 as the positive class. Note that
# label 1 means NON-ictal here, so precision/recall/F1 below describe that
# class; ictal recall (sensitivity) is reported separately.
tn, fp, fn, tp = total_cm.ravel()

overall_accuracy = _safe_ratio(tp + tn, np.sum(total_cm))
overall_precision = _safe_ratio(tp, tp + fp)
overall_recall = _safe_ratio(tp, tp + fn)
overall_f1 = _safe_ratio(2 * overall_precision * overall_recall,
                         overall_precision + overall_recall)
# Row 0 = true ictal samples: tn were detected, fp were missed.
ictal_recall = _safe_ratio(tn, tn + fp)

print("\n📊 Overall Metrics (positive class = label 1, non-ictal):")
print(f"  Accuracy : {overall_accuracy:.4f}")
print(f"  Precision: {overall_precision:.4f}")
print(f"  Recall   : {overall_recall:.4f}")
print(f"  F1-score : {overall_f1:.4f}")
print(f"  Ictal recall (sensitivity): {ictal_recall:.4f}")

print("✅ CNN+LSTM Training Completed!")
