In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras.models import Model # type: ignore
from tensorflow.keras.layers import Input, Conv1D, Dense, Dropout, BatchNormalization, MaxPooling1D, GlobalAveragePooling1D, LeakyReLU # type: ignore
from tensorflow.keras.optimizers import Adam # type: ignore
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint # type: ignore
from tensorflow.keras.utils import to_categorical # type: ignore
import tensorflow.keras.backend as K # type: ignore

# ===================== GPU SETUP =====================
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Enable on-demand memory allocation on *every* visible GPU. TensorFlow
    # refuses mixed settings ("Memory growth cannot differ between GPU
    # devices"), so configuring only gpus[0] breaks multi-GPU machines.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print("Using GPU:", gpus[0])
else:
    print("No GPU found, using CPU")

# ===================== PATHS =====================
# NOTE(review): absolute, machine-specific path; the other experiments in this
# notebook read the relative "Preprocessing_Updated_Kfold" — confirm which is intended.
DATA_DIR = r"D:/FCIS Content/Fourth year/GP/GP-Epileptic-seizures/Preprocessing_Updated_Kfold"
RESULTS_DIR = "results"
os.makedirs(RESULTS_DIR, exist_ok=True)

N_FOLDS = 5      # number of pre-split folds on disk (fold_0 ... fold_{N_FOLDS-1})
NUM_CLASSES = 3  # width of the softmax output / one-hot targets

# Start a fresh summary file; per-fold lines are appended during training.
summary_path = os.path.join(RESULTS_DIR, "accuracy_summary.txt")
with open(summary_path, "w", encoding="utf-8") as f:
    f.write("Accuracy Summary per Fold\n")
    f.write("="*40 + "\n\n")

# ===================== DATA AUGMENTATION =====================
def augment_signal(segment, prob=0.5):
    """Randomly perturb one 1-D signal segment.

    With probability ``prob`` a single augmentation is drawn (weights
    0.3 / 0.3 / 0.2 / 0.2) and applied:

    * ``noise``      - add Gaussian noise with sigma ~ U(0.01, 0.05)
    * ``scale``      - multiply by a factor ~ U(0.9, 1.1)
    * ``shift``      - add a constant offset ~ U(-0.1, 0.1)
    * ``time_shift`` - circularly roll by an integer in [-20, 19]

    Otherwise the input object itself is returned unchanged.
    """
    if not np.random.rand() <= prob:
        return segment

    # Dispatch table; building it consumes no random numbers.
    transforms = {
        'noise': lambda s: s + np.random.normal(0, np.random.uniform(0.01, 0.05), s.shape),
        'scale': lambda s: s * np.random.uniform(0.9, 1.1),
        'shift': lambda s: s + np.random.uniform(-0.1, 0.1),
        'time_shift': lambda s: np.roll(s, np.random.randint(-20, 20)),
    }
    chosen = np.random.choice(['noise', 'scale', 'shift', 'time_shift'], p=[0.3, 0.3, 0.2, 0.2])
    # Every drawn key is in the table; identity is only a defensive fallback.
    return transforms.get(chosen, lambda s: s)(segment)

def mixup(X1, y1, X2, y2, alpha=0.2):
    """Blend two samples and their labels with one Beta(alpha, alpha) weight.

    Returns ``(lam * X1 + (1 - lam) * X2, lam * y1 + (1 - lam) * y2)``; a single
    ``lam`` is drawn from ``np.random.beta(alpha, alpha)`` per call.
    """
    weight = np.random.beta(alpha, alpha)
    other = 1 - weight
    mixed_input = weight * X1 + other * X2
    mixed_target = weight * y1 + other * y2
    return mixed_input, mixed_target

def augment_batch(X, y, prob=0.6):
    """Return an augmented copy of one training batch; ``X`` and ``y`` are not modified.

    Two passes over copies of the batch:
      1. every sample's channel-0 signal goes through ``augment_signal``;
         samples whose label argmax is not class 0 use probability ``prob * 1.5``;
      2. each sample in the first half of the batch is mixed via ``mixup`` with a
         partner drawn uniformly from the whole batch (possibly itself or an
         already-mixed sample), updating both signal and label.
    """
    signals = X.copy()
    labels = y.copy()
    batch_len = len(signals)

    for pos, label in enumerate(labels):
        p_sample = prob
        if np.argmax(label) > 0:
            p_sample = prob * 1.5  # augment non-class-0 samples more often
        signals[pos, :, 0] = augment_signal(signals[pos, :, 0], p_sample)

    for pos in range(batch_len // 2):
        partner = np.random.randint(batch_len)
        new_x, new_y = mixup(signals[pos, :, 0], labels[pos], signals[partner, :, 0], labels[partner])
        signals[pos, :, 0] = new_x
        labels[pos] = new_y

    return signals, labels

class AugmentedGenerator(tf.keras.utils.Sequence):
    """Keras batch generator that augments every training batch on the fly.

    Each batch is taken from a shuffled index permutation and passed through
    ``augment_batch`` (per-sample signal augmentation, then mixup on the first
    half of the batch). The permutation is re-shuffled after every epoch.

    Args:
        X: training inputs, indexed as ``X[i, :, 0]`` by ``augment_batch``
            (the training loop builds them with a trailing channel axis).
        y: one-hot targets aligned with ``X``.
        batch_size: samples per batch; the final batch may be smaller.
        prob: base augmentation probability forwarded to ``augment_batch``.
        **kwargs: forwarded to the Keras base class (e.g. ``workers``,
            ``use_multiprocessing``, ``max_queue_size`` in Keras 3).
    """

    def __init__(self, X, y, batch_size=16, prob=0.6, **kwargs):
        # The Keras 3 PyDataset base must be initialised; skipping this is what
        # emits the `_warn_if_super_not_called` warning seen in the training log.
        super().__init__(**kwargs)
        self.X, self.y = X, y
        self.bs = batch_size
        self.prob = prob
        self.idx = np.arange(len(X))
        # Shuffle up front too, so the first epoch is not served in storage order.
        np.random.shuffle(self.idx)

    def __len__(self):
        # Batches per epoch, including a partial final batch.
        return int(np.ceil(len(self.X) / self.bs))

    def on_epoch_end(self):
        np.random.shuffle(self.idx)

    def __getitem__(self, i):
        batch_idx = self.idx[i * self.bs:(i + 1) * self.bs]
        Xb, yb = self.X[batch_idx], self.y[batch_idx]
        # augment_batch operates on copies, so self.X / self.y stay untouched.
        return augment_batch(Xb, yb, self.prob)

# ===================== HYBRID FOCAL LOSS =====================
def hybrid_focal_loss(alpha=(1, 1.5, 2), gamma=1.5):
    """Build a class-weighted focal loss for one-hot or soft (mixup) targets.

    Per sample: ``sum_c alpha_c * (1 - p_c) ** gamma * (-y_c * log p_c)``.

    Args:
        alpha: per-class weights, one per output column (length NUM_CLASSES).
            A tuple default avoids the shared-mutable-default pitfall; lists
            are still accepted.
        gamma: focusing exponent; larger values down-weight confident predictions.

    Returns:
        ``loss(y_true, y_pred)`` returning one value per sample.
    """
    alpha = K.constant(list(alpha), dtype=K.floatx())

    def loss(y_true, y_pred):
        # Clip to keep log() finite.
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        ce = -y_true * K.log(y_pred)
        focal = (1 - y_pred) ** gamma
        # Return the per-sample loss and let Keras reduce it. Keras multiplies
        # the returned values by per-sample weights (from class_weight) before
        # averaging; a pre-reduced scalar would only be scaled by the batch-mean
        # weight, silently disabling per-sample class weighting.
        return K.sum(alpha * focal * ce, axis=-1)

    return loss

# ===================== MODEL =====================
def build_model(input_shape):
    """Build the 1-D CNN classifier for the NUM_CLASSES-way task.

    Three conv stages (filters/kernel 64/7, 128/5, 256/3), each
    Conv1D -> BatchNorm -> LeakyReLU(0.1) -> MaxPool(2) -> Dropout
    (0.3, 0.3, 0.4), then global average pooling, Dense(128, relu),
    Dropout(0.5) and a NUM_CLASSES-way softmax.

    Args:
        input_shape: per-sample input shape, e.g. ``(T, 1)``.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def conv_stage(tensor, n_filters, kernel_size, drop_rate):
        tensor = Conv1D(n_filters, kernel_size, padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = LeakyReLU(0.1)(tensor)
        tensor = MaxPooling1D(2)(tensor)
        return Dropout(drop_rate)(tensor)

    inputs = Input(shape=input_shape)
    features = inputs
    for spec in ((64, 7, 0.3), (128, 5, 0.3), (256, 3, 0.4)):
        features = conv_stage(features, *spec)

    pooled = GlobalAveragePooling1D()(features)
    hidden = Dense(128, activation='relu')(pooled)
    hidden = Dropout(0.5)(hidden)
    outputs = Dense(NUM_CLASSES, activation='softmax')(hidden)
    return Model(inputs, outputs)

# ===================== PAD SEGMENTS =====================
def pad_segments(X_list, max_len=None):
    """Right-pad every segment with zeros so all rows share one length.

    Each element of ``X_list`` is converted to a NumPy array; a 1-D segment is
    treated as a single row ``(1, length)``. Rows are zero-padded on the right
    (or truncated) to ``max_len`` and stacked with ``np.vstack``.

    Args:
        X_list: iterable of segments (e.g. an object array loaded with
            ``allow_pickle=True``), each 1-D or 2-D ``(rows, length)``.
        max_len: target length. ``None`` (default, original behaviour) uses the
            longest segment in ``X_list``. Pass an explicit value (e.g. the
            training length) to give another split the same width; longer
            segments are truncated to it.

    Returns:
        2-D array of shape ``(total_rows, max_len)``. Note that a 2-D segment
        with several rows contributes several output rows.

    Raises:
        ValueError: if ``X_list`` is empty and ``max_len`` is not given.
    """
    cleaned = []
    for x in X_list:
        x = np.asarray(x)
        if x.ndim == 1:
            x = x.reshape(1, -1)  # 1-D segment -> one row
        cleaned.append(x)

    if not cleaned:
        if max_len is None:
            raise ValueError("pad_segments() needs at least one segment when max_len is not given")
        return np.empty((0, max_len))

    if max_len is None:
        max_len = max(seg.shape[1] for seg in cleaned)

    padded = []
    for seg in cleaned:
        seg = seg[:, :max_len]  # only has an effect when an explicit max_len is shorter
        pad_width = max_len - seg.shape[1]
        if pad_width > 0:
            seg = np.pad(seg, ((0, 0), (0, pad_width)), mode='constant')
        padded.append(seg)

    return np.vstack(padded)

# ===================== TRAINING =====================
acc_folds, auc_folds, conf_matrices = [], [], []
# NOTE(review): despite its name this collects each fold's (y_train, y_test)
# label arrays, not sample indices — confirm what downstream code expects.
fold_indices_list = []
# Fixed label set so every per-fold confusion matrix is NUM_CLASSES x NUM_CLASSES;
# otherwise a fold missing a class yields a smaller matrix and the sum over folds fails.
class_labels = list(range(NUM_CLASSES))

for fold in range(N_FOLDS):
    print(f"\n===== Fold {fold+1}/{N_FOLDS} =====")
    # Drop the previous fold's model/graph state so memory does not accumulate.
    tf.keras.backend.clear_session()

    # Load data with allow_pickle (segments may be stored as ragged object arrays).
    # NOTE(review): allow_pickle can execute code embedded in the file; only load trusted data.
    X_train = np.load(os.path.join(DATA_DIR,f"fold_{fold}_X_train.npy"), allow_pickle=True)
    X_test  = np.load(os.path.join(DATA_DIR,f"fold_{fold}_X_test.npy"),  allow_pickle=True)
    # Integer class ids, as to_categorical and the metrics below expect.
    y_train = np.load(os.path.join(DATA_DIR,f"fold_{fold}_y_train.npy"), allow_pickle=True).astype(np.int64)
    y_test  = np.load(os.path.join(DATA_DIR,f"fold_{fold}_y_test.npy"), allow_pickle=True).astype(np.int64)

    # Pad segments to same length and add a channel axis -> (N, T, 1)
    X_train = pad_segments(X_train)[..., None].astype(np.float32)
    X_test  = pad_segments(X_test)[..., None].astype(np.float32)

    # pad_segments pads each split to its own longest segment, so T can differ
    # between train and test while the model is built for a single T.
    # Zero-pad both splits on the right to the longer of the two.
    seq_len = max(X_train.shape[1], X_test.shape[1])
    X_train = np.pad(X_train, ((0, 0), (0, seq_len - X_train.shape[1]), (0, 0)))
    X_test  = np.pad(X_test,  ((0, 0), (0, seq_len - X_test.shape[1]),  (0, 0)))

    fold_indices_list.append((y_train.copy(), y_test.copy()))

    # One-hot encoding
    y_train_cat = to_categorical(y_train, NUM_CLASSES)
    y_test_cat  = to_categorical(y_test, NUM_CLASSES)

    # Stratified train-validation split (15% of the training fold for validation)
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_train, y_train_cat, test_size=0.15, stratify=y_train, random_state=42
    )

    # "Balanced" class weights keyed by the actual label value (not by position
    # in `classes`), so a fold lacking a class cannot shift weights onto another label.
    classes = np.unique(y_train)
    cw = compute_class_weight("balanced", classes=classes, y=y_train)
    class_weight = {int(c): w for c, w in zip(classes, cw)}

    # Build model
    model = build_model((X_train.shape[1],1))
    model.compile(Adam(1e-4), loss=hybrid_focal_loss(), metrics=['accuracy'])

    ckpt_path = f"{RESULTS_DIR}/model_fold{fold+1}.weights.h5"

    # Training
    history = model.fit(
        AugmentedGenerator(X_tr, y_tr),
        validation_data=(X_val, y_val),
        epochs=80,
        callbacks=[
            EarlyStopping(patience=25, restore_best_weights=True),
            ReduceLROnPlateau(patience=8, factor=0.6, min_lr=1e-7),
            ModelCheckpoint(ckpt_path, save_best_only=True, save_weights_only=True)
        ],
        class_weight=class_weight,
        verbose=1
    )

    # Reload the best checkpoint written by ModelCheckpoint (save_best_only=True)
    model.load_weights(ckpt_path)

    # Predict
    y_pred_prob = model.predict(X_test)
    y_pred = np.argmax(y_pred_prob, axis=1)

    acc = np.mean(y_pred == y_test)
    # One-vs-rest AUC is undefined for a class that never occurs in this test fold.
    if np.unique(y_test).size == NUM_CLASSES:
        auc = roc_auc_score(y_test_cat, y_pred_prob, multi_class='ovr')
    else:
        print(f"Fold {fold+1}: not every class is present in the test split, AUC set to NaN")
        auc = float("nan")

    acc_folds.append(acc)
    auc_folds.append(auc)

    # Confusion matrix
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    conf_matrices.append(cm)

    # Plot fold confusion matrix
    plt.figure(figsize=(5,4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f"Fold {fold+1} Confusion Matrix | ACC={acc:.3f}")
    plt.savefig(f"{RESULTS_DIR}/confusion_fold{fold+1}.png", dpi=300)
    plt.close()

    with open(summary_path,"a", encoding="utf-8") as f:
        f.write(f"Fold {fold+1}: ACC={acc:.4f}, AUC={auc:.4f}\n")

# ===================== TOTAL RESULTS =====================
total_cm = np.sum(conf_matrices, axis=0)
mean_acc = np.mean(acc_folds)
mean_auc = np.nanmean(auc_folds)  # skips folds whose AUC was undefined

plt.figure(figsize=(6,5))
sns.heatmap(total_cm, annot=True, fmt='d', cmap='Greens')
plt.title(f"Total Confusion Matrix | Mean ACC={mean_acc:.3f}")
plt.savefig(f"{RESULTS_DIR}/confusion_total.png", dpi=300)
plt.close()

with open(summary_path,"a", encoding="utf-8") as f:
    f.write("\n" + "="*40 + "\n")
    f.write(f"Mean ACC: {mean_acc:.4f}\n")
    f.write(f"Mean AUC: {mean_auc:.4f}\n")

# Save the per-fold label arrays (see NOTE on fold_indices_list above)
np.save(os.path.join(RESULTS_DIR,"fold_indices.npy"), np.array(fold_indices_list, dtype=object))

print("\n==============================")
print("TRAINING COMPLETED SUCCESSFULLY")
print(f"ALL RESULTS SAVED IN {RESULTS_DIR}/")
print("==============================")


No GPU found, using CPU

===== Fold 1/5 =====
Epoch 1/80


  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 21ms/step - accuracy: 0.6353 - loss: 0.6359 - val_accuracy: 0.4000 - val_loss: 0.8991 - learning_rate: 1.0000e-04
Epoch 2/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 23ms/step - accuracy: 0.7515 - loss: 0.4415 - val_accuracy: 0.5083 - val_loss: 0.6033 - learning_rate: 1.0000e-04
Epoch 3/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 22ms/step - accuracy: 0.8062 - loss: 0.3735 - val_accuracy: 0.8292 - val_loss: 0.2880 - learning_rate: 1.0000e-04
Epoch 4/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 21ms/step - accuracy: 0.8390 - loss: 0.3350 - val_accuracy: 0.9208 - val_loss: 0.1811 - learning_rate: 1.0000e-04
Epoch 5/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 21ms/step - accuracy: 0.8669 - loss: 0.3066 - val_accuracy: 0.9125 - val_loss: 0.1637 - learning_rate: 1.0000e-04
Epoch 6/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m

  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 48ms/step - accuracy: 0.6246 - loss: 0.6091 - val_accuracy: 0.4000 - val_loss: 0.9704 - learning_rate: 1.0000e-04
Epoch 2/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 43ms/step - accuracy: 0.7688 - loss: 0.4456 - val_accuracy: 0.5771 - val_loss: 0.5608 - learning_rate: 1.0000e-04
Epoch 3/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 44ms/step - accuracy: 0.8099 - loss: 0.3819 - val_accuracy: 0.9167 - val_loss: 0.2352 - learning_rate: 1.0000e-04
Epoch 4/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 43ms/step - accuracy: 0.8338 - loss: 0.3342 - val_accuracy: 0.8938 - val_loss: 0.1959 - learning_rate: 1.0000e-04
Epoch 5/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 42ms/step - accuracy: 0.8511 - loss: 0.3149 - val_accuracy: 0.9083 - val_loss: 0.1908 - learning_rate: 1.0000e-04
Epoch 6/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0

  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 50ms/step - accuracy: 0.6250 - loss: 0.6332 - val_accuracy: 0.4000 - val_loss: 1.1439 - learning_rate: 1.0000e-04
Epoch 2/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 52ms/step - accuracy: 0.7482 - loss: 0.4645 - val_accuracy: 0.4250 - val_loss: 0.8897 - learning_rate: 1.0000e-04
Epoch 3/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 52ms/step - accuracy: 0.8026 - loss: 0.3840 - val_accuracy: 0.8313 - val_loss: 0.3256 - learning_rate: 1.0000e-04
Epoch 4/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 51ms/step - accuracy: 0.8379 - loss: 0.3517 - val_accuracy: 0.9333 - val_loss: 0.1767 - learning_rate: 1.0000e-04
Epoch 5/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 51ms/step - accuracy: 0.8533 - loss: 0.3298 - val_accuracy: 0.9417 - val_loss: 0.1673 - learning_rate: 1.0000e-04
Epoch 6/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━

  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 64ms/step - accuracy: 0.5813 - loss: 0.6952 - val_accuracy: 0.4000 - val_loss: 0.9744 - learning_rate: 1.0000e-04
Epoch 2/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 55ms/step - accuracy: 0.7640 - loss: 0.4394 - val_accuracy: 0.4750 - val_loss: 0.6496 - learning_rate: 1.0000e-04
Epoch 3/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 51ms/step - accuracy: 0.8000 - loss: 0.3879 - val_accuracy: 0.8396 - val_loss: 0.2977 - learning_rate: 1.0000e-04
Epoch 4/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 58ms/step - accuracy: 0.8390 - loss: 0.3322 - val_accuracy: 0.9167 - val_loss: 0.1979 - learning_rate: 1.0000e-04
Epoch 5/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 58ms/step - accuracy: 0.8412 - loss: 0.3234 - val_accuracy: 0.9146 - val_loss: 0.1751 - learning_rate: 1.0000e-04
Epoch 6/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━

  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 22ms/step - accuracy: 0.6397 - loss: 0.5882 - val_accuracy: 0.4000 - val_loss: 1.1407 - learning_rate: 1.0000e-04
Epoch 2/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 21ms/step - accuracy: 0.7540 - loss: 0.4557 - val_accuracy: 0.4417 - val_loss: 0.8897 - learning_rate: 1.0000e-04
Epoch 3/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 21ms/step - accuracy: 0.7985 - loss: 0.3857 - val_accuracy: 0.8333 - val_loss: 0.2924 - learning_rate: 1.0000e-04
Epoch 4/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 21ms/step - accuracy: 0.8430 - loss: 0.3362 - val_accuracy: 0.9229 - val_loss: 0.1831 - learning_rate: 1.0000e-04
Epoch 5/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 21ms/step - accuracy: 0.8614 - loss: 0.3114 - val_accuracy: 0.9250 - val_loss: 0.1653 - learning_rate: 1.0000e-04
Epoch 6/80
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m

In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras.models import Model # type: ignore
from tensorflow.keras.layers import ( # type: ignore
    Input, Conv1D, Dense, Dropout, BatchNormalization,
    MaxPooling1D, GlobalAveragePooling1D, Add, Activation
)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint # type: ignore
from tensorflow.keras.optimizers import AdamW # type: ignore

# =============================================================
#                 DEVICE INFO
# =============================================================
print("TensorFlow version:", tf.__version__)
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # TensorFlow requires one memory-growth setting for all visible GPUs
    # ("Memory growth cannot differ between GPU devices"), so set it on each.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print("GPU detected:", gpus)
else:
    print("No GPU detected. Using CPU.")

# =============================================================
#                RESULTS DIRECTORY
# =============================================================
results_dir = "results_binary_fixed"
os.makedirs(results_dir, exist_ok=True)
summary_path = os.path.join(results_dir, "accuracy_summary.txt")

# Start a fresh summary file; per-fold lines are appended during training.
with open(summary_path, "w", encoding="utf-8") as f:
    f.write("BINARY (NORMAL vs ALL) – FIXED PREPROCESSING\n")
    f.write("="*70 + "\n\n")

# =============================================================
#                  PAD SEGMENTS (FROM 3-CLASS CODE)
# =============================================================
def pad_segments(X_list, max_len=None):
    """Right-pad every segment with zeros so all rows share one length.

    Each element of ``X_list`` is converted to a NumPy array; a 1-D segment is
    treated as a single row ``(1, length)``. Rows are zero-padded on the right
    (or truncated) to ``max_len`` and stacked with ``np.vstack``.

    Args:
        X_list: iterable of segments (e.g. an object array loaded with
            ``allow_pickle=True``), each 1-D or 2-D ``(rows, length)``.
        max_len: target length. ``None`` (default, original behaviour) uses the
            longest segment in ``X_list``. Pass an explicit value (e.g. the
            training length) to give another split the same width; longer
            segments are truncated to it.

    Returns:
        2-D array of shape ``(total_rows, max_len)``. Note that a 2-D segment
        with several rows contributes several output rows.

    Raises:
        ValueError: if ``X_list`` is empty and ``max_len`` is not given.
    """
    cleaned = []
    for x in X_list:
        x = np.asarray(x)
        if x.ndim == 1:
            x = x.reshape(1, -1)  # 1-D segment -> one row
        cleaned.append(x)

    if not cleaned:
        if max_len is None:
            raise ValueError("pad_segments() needs at least one segment when max_len is not given")
        return np.empty((0, max_len))

    if max_len is None:
        max_len = max(seg.shape[1] for seg in cleaned)

    padded = []
    for seg in cleaned:
        seg = seg[:, :max_len]  # only has an effect when an explicit max_len is shorter
        pad_width = max_len - seg.shape[1]
        if pad_width > 0:
            seg = np.pad(seg, ((0, 0), (0, pad_width)), mode="constant")
        padded.append(seg)

    return np.vstack(padded)

# =============================================================
#                      DATA LOCATION
# =============================================================
# Folder with the preprocessed per-fold arrays (fold_{k}_X_train.npy,
# fold_{k}_X_test.npy, fold_{k}_y_train.npy, fold_{k}_y_test.npy),
# resolved relative to the current working directory.
DATA_DIR = "Preprocessing_Updated_Kfold"

# =============================================================
#                    DATA AUGMENTATION
# =============================================================
def augment_signal(seg, p=0.5):
    """With probability ``p``, apply one uniformly chosen perturbation to ``seg``.

    Options: additive Gaussian noise (sigma ~ U(0.01, 0.05)), amplitude scaling
    by U(0.9, 1.1), a constant offset U(-0.1, 0.1), or a circular shift by an
    integer in [-20, 19]. When no augmentation is drawn, ``seg`` itself is returned.
    """
    if np.random.rand() <= p:
        kind = np.random.choice(['noise', 'scale', 'shift', 'roll'])
        if kind == 'noise':
            sigma = np.random.uniform(0.01, 0.05)
            return seg + np.random.normal(0, sigma, seg.shape)
        if kind == 'scale':
            return seg * np.random.uniform(0.9, 1.1)
        if kind == 'shift':
            return seg + np.random.uniform(-0.1, 0.1)
        if kind == 'roll':
            return np.roll(seg, np.random.randint(-20, 20))
    return seg

def mixup_binary(x1, y1, x2, y2, alpha=0.2):
    """Mix two samples and their (binary) labels with one Beta(alpha, alpha) weight.

    Returns ``(lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2)``; for 0/1
    labels the mixed label is a soft value in [0, 1].
    """
    lam = np.random.beta(alpha, alpha)

    def blend(a, b):
        return lam * a + (1 - lam) * b

    return blend(x1, x2), blend(y1, y2)

class AugGen(tf.keras.utils.Sequence):
    """Binary-label batch generator with class-aware augmentation and mixup.

    Per batch: each sample's channel-0 signal goes through ``augment_signal``
    with probability ``0.4 * p`` for label 0 and ``1.2 * p`` otherwise; then each
    sample in the first half of the batch is mixed with a random partner from the
    batch via ``mixup_binary``, which produces soft labels in [0, 1].

    Args:
        X: inputs indexed as ``X[k, :, 0]`` (the training loop builds them with a
            trailing channel axis).
        y: binary labels aligned with ``X``.
        batch: samples per batch; the final batch may be smaller.
        p: base augmentation probability.
        **kwargs: forwarded to the Keras base class (e.g. ``workers``,
            ``use_multiprocessing``, ``max_queue_size`` in Keras 3).
    """

    def __init__(self, X, y, batch=16, p=0.6, **kwargs):
        # The Keras 3 PyDataset base must be initialised; skipping this is what
        # emits the `_warn_if_super_not_called` warning seen in the training log.
        super().__init__(**kwargs)
        self.X, self.y = X, y
        self.batch = batch
        self.p = p
        self.idx = np.arange(len(X))
        np.random.shuffle(self.idx)

    def __len__(self):
        # Batches per epoch, including a partial final batch.
        return int(np.ceil(len(self.X)/self.batch))

    def on_epoch_end(self):
        np.random.shuffle(self.idx)

    def __getitem__(self, i):
        ids = self.idx[i*self.batch:(i+1)*self.batch]
        Xb = self.X[ids].copy()
        # Labels must be float because mixup produces soft labels. The training
        # loop passes int64 labels; copying them as-is truncated every
        # lam*1 + (1-lam)*0 back to 0, silently relabelling mixed abnormal
        # samples as normal. astype() also returns a copy.
        yb = self.y[ids].astype(np.float32)

        for k in range(len(Xb)):
            prob = self.p*0.4 if yb[k] == 0 else self.p*1.2
            Xb[k,:,0] = augment_signal(Xb[k,:,0], prob)

        for k in range(len(Xb)//2):
            j = np.random.randint(len(Xb))
            Xb[k,:,0], yb[k] = mixup_binary(
                Xb[k,:,0], yb[k],
                Xb[j,:,0], yb[j]
            )

        return Xb, yb

# =============================================================
#                    WEIGHTED BCE
# =============================================================
def weighted_bce(w_pos=2.8, w_neg=1.0):
    """Build a binary cross-entropy loss that up-weights the positive class.

    Each sample's BCE is multiplied by ``y_true * w_pos + (1 - y_true) * w_neg``,
    so soft (mixup) labels receive an interpolated weight.

    Args:
        w_pos: weight applied to positive (label 1) targets.
        w_neg: weight applied to negative (label 0) targets.

    Returns:
        ``loss(y_true, y_pred)`` returning one value per sample; Keras then
        applies any ``class_weight``/``sample_weight`` and averages over the batch.
    """
    def loss(y_true, y_pred):
        # Flatten both to (batch,). Keras may pass y_true as (batch, 1) next to the
        # (batch, 1) sigmoid output; binary_crossentropy then returns (batch,), and
        # multiplying that by (batch, 1) weights broadcasts to (batch, batch),
        # pairing every sample's weight with every other sample's loss.
        y_pred = tf.reshape(y_pred, [-1])
        y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
        # binary_crossentropy averages over the last axis; a size-1 axis makes the
        # result the plain per-sample BCE.
        bce = tf.keras.losses.binary_crossentropy(y_true[:, None], y_pred[:, None])
        weights = y_true * w_pos + (1 - y_true) * w_neg
        # Per-sample values (no reduce_mean) so Keras' class_weight still applies per sample.
        return weights * bce
    return loss

# =============================================================
#                    RESIDUAL MODEL
# =============================================================
def res_block(x, filters):
    """1-D residual block: two Conv1D(k=3)+BatchNorm layers with ReLU between.

    The input is added back before the final ReLU. When its channel count
    differs from ``filters``, it is first projected with a 1x1 Conv1D.
    """
    skip = x
    h = x
    for stage in range(2):
        h = Conv1D(filters, 3, padding="same")(h)
        h = BatchNormalization()(h)
        if stage == 0:
            h = Activation("relu")(h)

    if skip.shape[-1] != filters:
        skip = Conv1D(filters, 1, padding="same")(skip)

    merged = Add()([h, skip])
    return Activation("relu")(merged)

def build_model(input_shape):
    """Build the residual 1-D CNN with a single sigmoid output (binary task).

    The stem is Conv1D(64, 7) -> BatchNorm -> ReLU. It is followed by four
    stages of ``res_block`` -> MaxPool(2) -> Dropout, with widths
    64/128/256/512 and dropout 0.3/0.3/0.4/0.4. The head is global average
    pooling, Dense(128, relu), Dropout(0.3) and Dense(1, sigmoid).

    Args:
        input_shape: per-sample input shape, e.g. ``(T, 1)``.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inp = Input(shape=input_shape)

    # Stem
    h = Conv1D(64, 7, padding="same")(inp)
    h = BatchNormalization()(h)
    h = Activation("relu")(h)

    stages = ((64, 0.3), (128, 0.3), (256, 0.4), (512, 0.4))
    for width, drop in stages:
        h = res_block(h, width)
        h = MaxPooling1D(2)(h)
        h = Dropout(drop)(h)

    # Head
    h = GlobalAveragePooling1D()(h)
    h = Dense(128, activation="relu")(h)
    h = Dropout(0.3)(h)
    return Model(inp, Dense(1, activation="sigmoid")(h))

# =============================================================
#                    TRAINING LOOP
# =============================================================
acc_list, auc_list, confs = [], [], []
# Fixed binary label set so every fold's confusion matrix is 2x2, even when a
# test fold or its predictions contain a single class.
binary_labels = [0, 1]

for fold in range(5):
    print(f"\n========== FOLD {fold+1} ==========")
    start = time.time()
    # Drop the previous fold's model/graph state so memory does not accumulate.
    tf.keras.backend.clear_session()

    # LOAD THE PREPROCESSED FOLD FILES (segments may be ragged object arrays)
    # NOTE(review): allow_pickle can execute code embedded in the file; only load trusted data.
    X_tr_val = np.load(os.path.join(DATA_DIR, f"fold_{fold}_X_train.npy"), allow_pickle=True)
    X_te = np.load(os.path.join(DATA_DIR, f"fold_{fold}_X_test.npy"),  allow_pickle=True)
    y_tr_val = np.load(os.path.join(DATA_DIR, f"fold_{fold}_y_train.npy"), allow_pickle=True)
    y_te = np.load(os.path.join(DATA_DIR, f"fold_{fold}_y_test.npy"),  allow_pickle=True)

    # PAD: each split is padded to its own longest segment ...
    X_tr_val = pad_segments(X_tr_val).astype(np.float32)
    X_te = pad_segments(X_te).astype(np.float32)
    # ... so zero-pad both to a common length; otherwise the test inputs can
    # differ in length from the input the model is built for.
    seq_len = max(X_tr_val.shape[1], X_te.shape[1])
    X_tr_val = np.pad(X_tr_val, ((0, 0), (0, seq_len - X_tr_val.shape[1])))
    X_te = np.pad(X_te, ((0, 0), (0, seq_len - X_te.shape[1])))

    # BINARY LABELS: original class 0 -> 0 (Normal), every other class -> 1 (Abnormal)
    y_tr_val_bin = np.where(y_tr_val == 0, 0, 1).astype(np.int64)
    y_te_bin = np.where(y_te == 0, 0, 1).astype(np.int64)

    # Conv1D shape (N, T, 1)
    X_tr_val = X_tr_val[..., None]
    X_te = X_te[..., None]

    X_tr, X_val, y_tr, y_val = train_test_split(
        X_tr_val, y_tr_val_bin,
        test_size=0.15,
        stratify=y_tr_val_bin,
        random_state=42
    )

    # CLASS WEIGHTS ("balanced"), keyed by the label value itself so a missing
    # class cannot shift a weight onto the wrong label.
    present_classes = np.unique(y_tr)
    class_weights = compute_class_weight(
        class_weight="balanced",
        classes=present_classes,
        y=y_tr
    )
    class_weight_dict = {int(c): w for c, w in zip(present_classes, class_weights)}

    model = build_model((X_tr.shape[1], 1))
    model.compile(
        optimizer=AdamW(2e-4, weight_decay=1e-5),
        loss=weighted_bce(2.8,1.0),
        metrics=["accuracy"]
    )

    ckpt_path = os.path.join(results_dir, f"model_fold{fold+1}.weights.h5")
    history = model.fit(
        AugGen(X_tr, y_tr),
        validation_data=(X_val, y_val),
        epochs=100,
        class_weight=class_weight_dict,
        callbacks=[
            EarlyStopping(patience=20, restore_best_weights=True),
            ReduceLROnPlateau(patience=10, factor=0.5, min_lr=1e-7),
            ModelCheckpoint(
                ckpt_path,
                save_best_only=True,
                save_weights_only=True
            )
        ],
        verbose=1
    )

    # Reload the best checkpoint written by ModelCheckpoint (save_best_only=True)
    model.load_weights(ckpt_path)

    prob = model.predict(X_te).flatten()
    pred = (prob > 0.5).astype(int)

    acc = np.mean(pred == y_te_bin)
    # ROC AUC is undefined when the test fold contains only one class.
    if np.unique(y_te_bin).size == 2:
        auc = roc_auc_score(y_te_bin, prob)
    else:
        print(f"Fold {fold+1}: only one class in the test split, AUC set to NaN")
        auc = float("nan")

    acc_list.append(acc)
    auc_list.append(auc)

    # Includes wall-clock time for the whole fold (load + train + evaluate).
    print(f"Fold {fold+1} | ACC={acc:.4f} | AUC={auc:.4f} | time={time.time() - start:.1f}s")

    cm = confusion_matrix(y_te_bin, pred, labels=binary_labels)
    confs.append(cm)

    plt.figure(figsize=(5,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Normal","Abnormal"],
                yticklabels=["Normal","Abnormal"])
    plt.title(f"Fold {fold+1} Confusion Matrix")
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, f"confusion_fold{fold+1}.png"))
    plt.close()

    with open(summary_path,"a", encoding="utf-8") as f:
        f.write(f"Fold {fold+1}: ACC={acc:.4f}, AUC={auc:.4f}\n")

# =============================================================
#                    OVERALL RESULTS
# =============================================================
total_cm = np.sum(confs, axis=0)
mean_acc = np.mean(acc_list)
mean_auc = np.nanmean(auc_list)  # skips folds whose AUC was undefined

plt.figure(figsize=(6,5))
sns.heatmap(
    total_cm, annot=True, fmt="d", cmap="Greens",
    xticklabels=["Normal","Abnormal"],
    yticklabels=["Normal","Abnormal"]
)
plt.title(f"Overall Confusion Matrix | Mean ACC={mean_acc:.4f}")
plt.tight_layout()
plt.savefig(os.path.join(results_dir, "confusion_overall.png"), dpi=300)
plt.close()

with open(summary_path,"a", encoding="utf-8") as f:
    f.write("\n"+"="*70+"\n")
    f.write(f"Mean ACC = {mean_acc:.4f}\n")
    f.write(f"Mean AUC = {mean_auc:.4f}\n")

print("\n==============================")
print(" BINARY TRAINING COMPLETED ")
print("==============================")

TensorFlow version: 2.20.0
No GPU detected. Using CPU.

Epoch 1/100


  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 99ms/step - accuracy: 0.7890 - loss: 0.9466 - val_accuracy: 0.6000 - val_loss: 1.2878 - learning_rate: 2.0000e-04
Epoch 2/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 76ms/step - accuracy: 0.8062 - loss: 0.8736 - val_accuracy: 0.9229 - val_loss: 0.7926 - learning_rate: 2.0000e-04
Epoch 3/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 78ms/step - accuracy: 0.8305 - loss: 0.7981 - val_accuracy: 0.9479 - val_loss: 0.6649 - learning_rate: 2.0000e-04
Epoch 4/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 78ms/step - accuracy: 0.8309 - loss: 0.7804 - val_accuracy: 0.9500 - val_loss: 0.7459 - learning_rate: 2.0000e-04
Epoch 5/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 75ms/step - accuracy: 0.8199 - loss: 0.8169 - val_accuracy: 0.9688 - val_loss: 0.6746 - learning_rate: 2.0000e-04
Epoch 6/100
[1m170/170[0m [32m━━━━━━━━━━━━━━

  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 77ms/step - accuracy: 0.7511 - loss: 1.0345 - val_accuracy: 0.6000 - val_loss: 1.2692 - learning_rate: 2.0000e-04
Epoch 2/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 80ms/step - accuracy: 0.7963 - loss: 0.8903 - val_accuracy: 0.8813 - val_loss: 0.8925 - learning_rate: 2.0000e-04
Epoch 3/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 80ms/step - accuracy: 0.8279 - loss: 0.8156 - val_accuracy: 0.9146 - val_loss: 0.8607 - learning_rate: 2.0000e-04
Epoch 4/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 80ms/step - accuracy: 0.8287 - loss: 0.7685 - val_accuracy: 0.9604 - val_loss: 0.4309 - learning_rate: 2.0000e-04
Epoch 5/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 143ms/step - accuracy: 0.8324 - loss: 0.7576 - val_accuracy: 0.9271 - val_loss: 0.7440 - learning_rate: 2.0000e-04
Epoch 6/100
[1m170/170[0m [32m━━━━━━━━━━━━━

  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 77ms/step - accuracy: 0.7827 - loss: 0.9634 - val_accuracy: 0.6021 - val_loss: 1.2699 - learning_rate: 2.0000e-04
Epoch 2/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 78ms/step - accuracy: 0.8195 - loss: 0.8314 - val_accuracy: 0.9083 - val_loss: 0.9187 - learning_rate: 2.0000e-04
Epoch 3/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 78ms/step - accuracy: 0.8265 - loss: 0.7948 - val_accuracy: 0.9438 - val_loss: 0.7781 - learning_rate: 2.0000e-04
Epoch 4/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 78ms/step - accuracy: 0.8279 - loss: 0.7843 - val_accuracy: 0.7542 - val_loss: 1.4370 - learning_rate: 2.0000e-04
Epoch 5/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 78ms/step - accuracy: 0.8254 - loss: 0.7981 - val_accuracy: 0.9292 - val_loss: 0.7005 - learning_rate: 2.0000e-04
Epoch 6/100
[1m170/170[0m [32m━━━━━━━━━━━━━━

  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 76ms/step - accuracy: 0.7551 - loss: 1.0379 - val_accuracy: 0.6000 - val_loss: 1.2347 - learning_rate: 2.0000e-04
Epoch 2/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 75ms/step - accuracy: 0.8158 - loss: 0.8460 - val_accuracy: 0.9208 - val_loss: 0.9911 - learning_rate: 2.0000e-04
Epoch 3/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 75ms/step - accuracy: 0.8176 - loss: 0.8263 - val_accuracy: 0.9229 - val_loss: 0.8090 - learning_rate: 2.0000e-04
Epoch 4/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 75ms/step - accuracy: 0.8103 - loss: 0.8408 - val_accuracy: 0.8667 - val_loss: 0.8969 - learning_rate: 2.0000e-04
Epoch 5/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 76ms/step - accuracy: 0.8254 - loss: 0.8088 - val_accuracy: 0.9292 - val_loss: 0.5862 - learning_rate: 2.0000e-04
Epoch 6/100
[1m170/170[0m [32m━━━━━━━━━━━━━━

  self._warn_if_super_not_called()


[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 87ms/step - accuracy: 0.7651 - loss: 0.9992 - val_accuracy: 0.8188 - val_loss: 1.2642 - learning_rate: 2.0000e-04
Epoch 2/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 88ms/step - accuracy: 0.8136 - loss: 0.8286 - val_accuracy: 0.9187 - val_loss: 1.0512 - learning_rate: 2.0000e-04
Epoch 3/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 86ms/step - accuracy: 0.8243 - loss: 0.8083 - val_accuracy: 0.8375 - val_loss: 1.0184 - learning_rate: 2.0000e-04
Epoch 4/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 84ms/step - accuracy: 0.8140 - loss: 0.8047 - val_accuracy: 0.9062 - val_loss: 0.7664 - learning_rate: 2.0000e-04
Epoch 5/100
[1m170/170[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 83ms/step - accuracy: 0.8371 - loss: 0.7314 - val_accuracy: 0.9312 - val_loss: 0.7075 - learning_rate: 2.0000e-04
Epoch 6/100
[1m170/170[0m [32m━━━━━━━━━━━━━━

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras.models import Model # type: ignore
from tensorflow.keras.layers import ( # type: ignore
    Input, Conv1D, Dense, Dropout, BatchNormalization,
    MaxPooling1D, GlobalAveragePooling1D, Activation
)
from tensorflow.keras.optimizers import Adam # pyright: ignore[reportMissingImports]
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint # type: ignore

# =============================================================
#                    PATHS
# =============================================================
# Input folder with the preprocessed per-fold arrays (relative to the working directory).
DATA_DIR = "Preprocessing_Updated_Kfold"
# Output folder for this experiment; created if missing.
RESULTS_DIR = "results_ictal_binary_fixed"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Start a fresh summary file (truncates any earlier run's contents).
summary_path = os.path.join(RESULTS_DIR, "accuracy_summary.txt")
with open(summary_path, "w", encoding="utf-8") as f:
    f.write("ICTAL vs NON-ICTAL (Binary) – Fixed Preprocessing\n" + "=" * 70 + "\n\n")

# =============================================================
#                PAD SEGMENTS (IMPORTANT)
# =============================================================
def pad_segments(X_list, max_len=None):
    """Right-pad every segment with zeros so all rows share one length.

    Each element of ``X_list`` is converted to a NumPy array; a 1-D segment is
    treated as a single row ``(1, length)``. Rows are zero-padded on the right
    (or truncated) to ``max_len`` and stacked with ``np.vstack``.

    Args:
        X_list: iterable of segments (e.g. an object array loaded with
            ``allow_pickle=True``), each 1-D or 2-D ``(rows, length)``.
        max_len: target length. ``None`` (default, original behaviour) uses the
            longest segment in ``X_list``. Pass an explicit value (e.g. the
            training length) to give another split the same width; longer
            segments are truncated to it.

    Returns:
        2-D array of shape ``(total_rows, max_len)``. Note that a 2-D segment
        with several rows contributes several output rows.

    Raises:
        ValueError: if ``X_list`` is empty and ``max_len`` is not given.
    """
    cleaned = []
    for x in X_list:
        x = np.asarray(x)
        if x.ndim == 1:
            x = x.reshape(1, -1)  # 1-D segment -> one row
        cleaned.append(x)

    if not cleaned:
        if max_len is None:
            raise ValueError("pad_segments() needs at least one segment when max_len is not given")
        return np.empty((0, max_len))

    if max_len is None:
        max_len = max(seg.shape[1] for seg in cleaned)

    padded = []
    for seg in cleaned:
        seg = seg[:, :max_len]  # only has an effect when an explicit max_len is shorter
        pad_width = max_len - seg.shape[1]
        if pad_width > 0:
            seg = np.pad(seg, ((0, 0), (0, pad_width)), mode="constant")
        padded.append(seg)

    return np.vstack(padded)

# =============================================================
#                  DATA AUGMENTATION
# =============================================================
def augment(seg, p=0.6):
    """Apply one random augmentation to a 1-D segment with probability ``p``.

    The operation is drawn uniformly from:
      noise   - additive Gaussian noise, sigma 0.03
      scale   - amplitude factor ~ U(0.85, 1.15)
      shift   - constant offset ~ U(-0.1, 0.1)
      stretch - resample on a grid of step ~ U(0.9, 1.1) via linear interpolation
      roll    - circular shift by an integer in [-30, 29]
    The result is truncated, or zero-padded on the right, back to ``len(seg)``.
    If no augmentation is drawn, ``seg`` itself is returned.
    """
    if np.random.random() > p:
        return seg

    length = len(seg)

    def fit_length(arr):
        # Restore the original length (only "stretch" can change it).
        surplus = len(arr) - length
        if surplus > 0:
            return arr[:length]
        if surplus < 0:
            return np.pad(arr, (0, -surplus))
        return arr

    op = np.random.choice(["noise", "scale", "shift", "stretch", "roll"])
    result = seg
    if op == "noise":
        result = seg + np.random.normal(0, 0.03, seg.shape)
    elif op == "scale":
        result = seg * np.random.uniform(0.85, 1.15)
    elif op == "shift":
        result = seg + np.random.uniform(-0.1, 0.1)
    elif op == "roll":
        result = np.roll(seg, np.random.randint(-30, 30))
    elif op == "stretch":
        step = np.random.uniform(0.9, 1.1)
        # step > 1 yields fewer samples (compress), step < 1 yields more (stretch).
        result = np.interp(np.arange(0, length, step), np.arange(length), seg)

    return fit_length(result)

# =============================================================
#              OVERSAMPLE ICTAL
# =============================================================
def oversample_ictal(X, y, factor=3):
    """Append ``factor`` augmented copies of every positive (y == 1) sample.

    Each positive segment ``X[i, :, 0]`` is passed through ``augment(..., p=1.0)``
    (an augmentation is always applied) once per repetition, reshaped back to
    ``(T, 1)``, and appended after the original data with label 1.

    Args:
        X: array of shape (N, T, 1) — the channel axis is required by ``seg[:, 0]``.
        y: length-N binary labels.
        factor: augmented copies per positive sample; 0 adds none.

    Returns:
        ``(X_new, y_new)``: the originals followed by the augmented positives.
        ``X_new`` keeps ``X``'s dtype; ``y_new`` is float because the appended
        labels come from ``np.ones``.
    """
    X_pos = X[y == 1]
    # Same order as the original nested loops: repetition outer, sample inner.
    X_aug_list = [
        augment(seg[:, 0], p=1.0).reshape(-1, 1)
        for _ in range(factor)
        for seg in X_pos
    ]

    if X_aug_list:
        # augment() returns float64; cast back so the whole set is not upcast.
        X_aug = np.stack(X_aug_list).astype(X.dtype, copy=False)
    else:
        # np.array([]) would be 1-D and make np.concatenate fail when there are
        # no positives (or factor == 0); keep an explicit (0, T, 1) shape instead.
        X_aug = np.empty((0,) + X.shape[1:], dtype=X.dtype)
    y_aug = np.ones(len(X_aug))

    X_new = np.concatenate([X, X_aug])
    y_new = np.concatenate([y, y_aug])

    return X_new, y_new

# =============================================================
#                    CNN MODEL
# =============================================================
def build_model(input_shape):
    """Build the binary 1-D CNN (uncompiled).

    Three convolutional stages, each Conv1D("same") -> BatchNorm -> ReLU ->
    MaxPool(2) -> Dropout, followed by global average pooling, a 128-unit
    ReLU layer with dropout, and a single sigmoid output.

    Args:
        input_shape: shape of one sample, e.g. ``(segment_length, 1)``.

    Returns:
        A ``Model`` mapping a segment to one sigmoid probability.
    """
    # NOTE(review): `Activation` is not in the top-of-file import list;
    # confirm it is imported in this cell.
    # (filters, kernel_size, dropout_rate) for each convolutional stage.
    conv_stages = (
        (64, 7, 0.3),
        (128, 5, 0.3),
        (256, 3, 0.4),
    )

    inputs = Input(input_shape)
    features = inputs
    for n_filters, kernel, drop_rate in conv_stages:
        features = Conv1D(n_filters, kernel, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
        features = MaxPooling1D(2)(features)
        features = Dropout(drop_rate)(features)

    features = GlobalAveragePooling1D()(features)
    features = Dense(128, activation="relu")(features)
    features = Dropout(0.4)(features)

    probability = Dense(1, activation="sigmoid")(features)
    return Model(inputs, probability)

# =============================================================
#                    TRAINING
# =============================================================
# Per-fold results; aggregated in the OVERALL RESULTS section below.
acc_list, auc_list, confs = [], [], []

for fold in range(N_FOLDS):
    print(f"\n========== FOLD {fold+1} ==========")

    # Start each fold from a clean Keras state so models/graphs from earlier
    # folds are released instead of accumulating in memory.
    tf.keras.backend.clear_session()

    # =============================
    # LOAD PREPROCESSED FILES
    # =============================
    # allow_pickle=True: the saved arrays presumably hold variable-length
    # segments (object dtype) that pad_segments equalises -- TODO confirm.
    # Only load files generated locally; unpickling can execute code.
    X_tr_val, X_te, y_tr_val, y_te = (
        np.load(os.path.join(DATA_DIR, f"fold_{fold}_{part}.npy"), allow_pickle=True)
        for part in ("X_train", "X_test", "y_train", "y_test")
    )

    # =============================
    # PAD + SHAPE
    # =============================
    X_tr_val = pad_segments(X_tr_val).astype(np.float32)
    X_te = pad_segments(X_te).astype(np.float32)

    # =============================
    # MULTI → BINARY
    # (ICTAL=1 , NORMAL+INTERICTAL=0)
    # =============================
    y_tr_val = np.where(y_tr_val == 2, 1, 0)
    y_te = np.where(y_te == 2, 1, 0)

    # Add a trailing channel axis for Conv1D: (N, L) -> (N, L, 1).
    X_tr_val = X_tr_val[..., None]
    X_te = X_te[..., None]

    # =============================
    # TRAIN / VAL SPLIT
    # =============================
    # Split before oversampling so augmented copies never reach validation.
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_tr_val,
        y_tr_val,
        test_size=0.15,
        stratify=y_tr_val,
        random_state=42
    )

    # =============================
    # OVERSAMPLING
    # =============================
    X_tr, y_tr = oversample_ictal(X_tr, y_tr, factor=3)

    # =============================
    # CLASS WEIGHTS
    # =============================
    # Computed AFTER oversampling. Balancing on the pre-oversampling labels
    # and then adding 3x extra ictal copies compensates the imbalance twice
    # and biases the model towards predicting ictal.
    classes = np.unique(y_tr)
    class_weights = compute_class_weight(
        class_weight="balanced",
        classes=classes,
        y=y_tr
    )
    # Key by the actual label value, not its position in `classes`.
    class_weight_dict = {int(c): float(w) for c, w in zip(classes, class_weights)}

    # =============================
    # MODEL
    # =============================
    model = build_model((X_tr.shape[1], 1))
    model.compile(
        optimizer=Adam(1e-4),
        loss="binary_crossentropy",
        metrics=["accuracy"]
    )

    ckpt_path = os.path.join(RESULTS_DIR, f"model_fold{fold+1}.weights.h5")
    history = model.fit(
        X_tr, y_tr,
        validation_data=(X_val, y_val),
        epochs=60,
        batch_size=32,
        class_weight=class_weight_dict,
        callbacks=[
            EarlyStopping(patience=25, restore_best_weights=True),
            ReduceLROnPlateau(patience=8, factor=0.5),
            ModelCheckpoint(
                ckpt_path,
                save_best_only=True,
                save_weights_only=True
            )
        ],
        verbose=1
    )

    # =============================
    # EVALUATION
    # =============================
    # Reload the best (lowest val_loss) checkpoint before testing.
    model.load_weights(ckpt_path)

    prob = model.predict(X_te).flatten()
    pred = (prob > 0.5).astype(int)

    acc = np.mean(pred == y_te)
    # ROC AUC is undefined when the test fold contains a single class;
    # record NaN instead of aborting the whole cross-validation run.
    if np.unique(y_te).size == 2:
        auc = roc_auc_score(y_te, prob)
    else:
        auc = float("nan")
        print(f"Fold {fold+1}: y_te contains a single class, AUC undefined")

    acc_list.append(acc)
    auc_list.append(auc)

    # Fixed labels keep every fold's matrix 2x2 so the matrices can be summed
    # across folds even if a class is absent from y_te and pred.
    cm = confusion_matrix(y_te, pred, labels=[0, 1])
    confs.append(cm)

    print(f"Fold {fold+1} | ACC={acc:.4f} | AUC={auc:.4f}")

    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.title(f"Fold {fold+1} Confusion Matrix")
    plt.savefig(os.path.join(RESULTS_DIR, f"confusion_fold{fold+1}.png"))
    plt.close()

    with open(summary_path, "a", encoding="utf-8") as f:
        f.write(f"Fold {fold+1}: ACC={acc:.4f}, AUC={auc:.4f}\n")

# =============================================================
#                OVERALL RESULTS
# =============================================================
# Pool the per-fold confusion matrices; average the per-fold scores.
total_cm = np.sum(confs, axis=0)
mean_acc, mean_auc = (np.mean(scores) for scores in (acc_list, auc_list))

# Pooled confusion matrix, titled with the mean accuracy.
sns.heatmap(total_cm, annot=True, fmt="d", cmap="Greens")
plt.title(f"Overall CM | Mean ACC={mean_acc:.4f}")
plt.savefig(os.path.join(RESULTS_DIR, "confusion_overall.png"))
plt.close()

# Append a separator plus the cross-fold means after the per-fold lines.
with open(summary_path, "a", encoding="utf-8") as f:
    f.write(
        "\n" + "=" * 70 + "\n"
        + f"Mean ACC = {mean_acc:.4f}\n"
        + f"Mean AUC = {mean_auc:.4f}\n"
    )

print(
    "\n==============================",
    " BINARY TRAINING COMPLETED ",
    "==============================",
    sep="\n",
)



Epoch 1/60
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 32ms/step - accuracy: 0.5269 - loss: 0.7375 - val_accuracy: 0.2000 - val_loss: 0.8794 - learning_rate: 1.0000e-04
Epoch 2/60
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 31ms/step - accuracy: 0.6898 - loss: 0.4922 - val_accuracy: 0.2354 - val_loss: 0.9795 - learning_rate: 1.0000e-04
Epoch 3/60
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 31ms/step - accuracy: 0.8171 - loss: 0.3301 - val_accuracy: 0.6396 - val_loss: 0.5903 - learning_rate: 1.0000e-04
Epoch 4/60
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 32ms/step - accuracy: 0.8950 - loss: 0.2404 - val_accuracy: 0.9333 - val_loss: 0.2595 - learning_rate: 1.0000e-04
Epoch 5/60
[1m136/136[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 31ms/step - accuracy: 0.9081 - loss: 0.2038 - val_accuracy: 0.9688 - val_loss: 0.1077 - learning_rate: 1.0000e-04
Epoch 6/60
[1m136/136[0m [32m━━━━━━━━━━━━