In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Dense, Conv2D, BatchNormalization, Activation, AveragePooling2D, 
    Dropout, Flatten, Reshape, GlobalAveragePooling2D, LSTM, Bidirectional,
    Add, LayerNormalization, MultiHeadAttention
)
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from scipy.signal import butter, filtfilt, iirnotch, welch
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import HistGradientBoostingClassifier
import matplotlib.pyplot as plt
import os
import gc
import zipfile
import glob
import re
import time
import concurrent.futures
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import seaborn as sns

# For reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Enable memory growth for GPU
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
        print(f"Found {len(physical_devices)} GPU(s), memory growth enabled")
    except Exception as e:
        print(f"Error setting memory growth: {e}")

# Mixed precision
try:
    policy = tf.keras.mixed_precision.Policy('mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)
    print(f"Using mixed precision - Compute dtype: {policy.compute_dtype}")
except:
    print("Mixed precision not available, using default precision")

# --- Configuration ---
INPUT_DATA_PATH = '/kaggle/input/grasp-and-lift-eeg-detection'
WORKING_PATH = '/kaggle/working/'
OUTPUT_PATH = WORKING_PATH

TRAIN_SERIES = list(range(1, 9))
VALIDATION_SPLIT_SERIES = [7, 8]

LOW_CUT = 0.5
HIGH_CUT = 60.0
FS = 500

WINDOW_SIZE = 400
WINDOW_OVERLAP = 200

BATCH_SIZE = 128
MAX_EPOCHS = 40
INITIAL_LR = 1e-3
DROPOUT_RATE = 0.5

TARGET_COLS = ['HandStart', 'FirstDigitTouch', 'BothStartLoadPhase', 'LiftOff', 'Replace', 'BothReleased']

# ------------------ Data Augmentation ---------------------
def add_gaussian_noise(X, std=0.01):
    noise = np.random.normal(0, std, X.shape)
    return X + noise

def random_channel_dropout(X, dropout_prob=0.1):
    X_aug = X.copy()
    for i in range(X.shape[0]):
        if np.random.rand() < dropout_prob:
            ch = np.random.randint(X.shape[1])
            X_aug[i, ch, ...] = 0
    return X_aug

def mixup(X, y, alpha=0.2):
    idx = np.random.permutation(len(X))
    l = np.random.beta(alpha, alpha)
    X_mix = l * X + (1 - l) * X[idx]
    y_mix = l * y + (1 - l) * y[idx]
    return X_mix, y_mix

# ---------- Filtering and Feature Extraction -----------------
def butter_bandpass_filter_optimized(data, lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    chunk_size = min(10000, data.shape[0])
    n_samples = data.shape[0]
    n_channels = data.shape[1]
    filtered_data = np.zeros_like(data)
    def filter_chunk(start_idx):
        end_idx = min(start_idx + chunk_size, n_samples)
        filtered_chunk = filtfilt(b, a, data[start_idx:end_idx, :], axis=0)
        return start_idx, filtered_chunk
    if data.shape[0] > 50000 and n_channels > 16:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            for i in range(0, n_samples, chunk_size):
                futures.append(executor.submit(filter_chunk, i))
            for future in concurrent.futures.as_completed(futures):
                start_idx, chunk = future.result()
                end_idx = min(start_idx + chunk_size, n_samples)
                filtered_data[start_idx:end_idx, :] = chunk
    else:
        for i in range(0, n_samples, chunk_size):
            start_idx, chunk = filter_chunk(i)
            end_idx = min(start_idx + chunk_size, n_samples)
            filtered_data[start_idx:end_idx, :] = chunk
    return filtered_data

def apply_notch_filter(data, fs=500, f0=50.0, quality_factor=30.0):
    nyq = 0.5 * fs
    w0 = f0 / nyq
    b, a = iirnotch(w0, quality_factor)
    chunk_size = min(10000, data.shape[0])
    n_samples = data.shape[0]
    filtered_data = np.zeros_like(data)
    for i in range(0, n_samples, chunk_size):
        end = min(i + chunk_size, n_samples)
        filtered_data[i:end] = filtfilt(b, a, data[i:end], axis=0)
    return filtered_data

def extract_features_from_window(window):
    n_samples, n_channels = window.shape
    means = np.mean(window, axis=0)
    stds = np.std(window, axis=0)
    diff1 = np.diff(window, n=1, axis=0)
    diff2 = np.diff(window, n=2, axis=0)
    activity = np.var(window, axis=0)
    mobility_num = np.var(diff1, axis=0)
    mobility_den = activity.copy()
    mobility_den[mobility_den == 0] = 1e-10
    mobility = np.sqrt(mobility_num / mobility_den)
    complexity_num = np.var(diff2, axis=0)
    complexity_den = np.var(diff1, axis=0)
    complexity_den[complexity_den == 0] = 1e-10
    complexity = np.sqrt(complexity_num / complexity_den) / mobility
    mobility = np.nan_to_num(mobility)
    complexity = np.nan_to_num(complexity)
    downsample_factor = 4
    downsampled = window[::downsample_factor]
    fs_downsampled = FS / downsample_factor
    bands = [(1, 4), (4, 8), (8, 13), (13, 30), (30, 45)]
    band_powers = np.zeros((n_channels, len(bands)))
    for ch in range(n_channels):
        freqs, psd = welch(downsampled[:, ch], fs=fs_downsampled, nperseg=min(256, downsampled.shape[0]))
        for b, (low, high) in enumerate(bands):
            idx = np.logical_and(freqs >= low, freqs <= high)
            band_powers[ch, b] = np.mean(psd[idx]) if np.any(idx) else 0
    features = np.hstack([
        means, stds, activity, mobility, complexity, band_powers.flatten()
    ])
    return features

# --------------- Data Loading and Windowing -----------------
def load_series_data(target_series_ids, data_base_path, is_test=False, downsample_factor=2):
    all_data = []
    all_labels = []
    all_ids = []
    sub_dir = 'test' if is_test else 'train'
    full_sub_dir_path = os.path.join(data_base_path, sub_dir)
    if not os.path.isdir(full_sub_dir_path):
        raise FileNotFoundError(f"Directory not found: {full_sub_dir_path}")
    print(f"Loading series {target_series_ids}")
    all_files = glob.glob(os.path.join(full_sub_dir_path, 'subj*_series*.csv'))
    data_files = []
    event_files = {}
    for series_id in target_series_ids:
        data_pattern = re.compile(rf"subj\d+_series{series_id}_data\.csv$")
        series_data_files = [f for f in all_files if data_pattern.search(os.path.basename(f))]
        if not series_data_files:
            print(f"Warning: No data files found for series {series_id}")
            continue
        data_files.extend(series_data_files)
        if not is_test:
            for data_file in series_data_files:
                event_file = data_file.replace('_data.csv', '_events.csv')
                if os.path.exists(event_file):
                    event_files[data_file] = event_file
    if not data_files:
        raise FileNotFoundError(f"No data files found for any series in {target_series_ids}")
    print(f"Found {len(data_files)} data files")
    for data_file in data_files:
        try:
            if os.path.getsize(data_file) > 100 * 1024 * 1024:
                chunks = pd.read_csv(data_file, chunksize=100000)
                eeg_chunks = []
                id_chunks = []
                for chunk in chunks:
                    eeg_chunks.append(chunk.drop('id', axis=1).values)
                    id_chunks.append(chunk['id'].values)
                eeg_data = np.vstack(eeg_chunks)
                ids = np.concatenate(id_chunks)
            else:
                df = pd.read_csv(data_file)
                ids = df['id'].values
                eeg_data = df.drop('id', axis=1).values
            eeg_data = eeg_data[::downsample_factor]
            ids = ids[::downsample_factor]
            all_data.append(eeg_data)
            all_ids.append(ids)
            if not is_test and data_file in event_files:
                event_file = event_files[data_file]
                series_events = pd.read_csv(event_file)
                labels = np.zeros((len(eeg_data), len(TARGET_COLS)))
                event_map = {col: i for i, col in enumerate(TARGET_COLS)}
                id_to_index = {id_val: idx for idx, id_val in enumerate(ids)}
                for _, row in series_events.iterrows():
                    event_id = row['id']
                    if event_id in id_to_index:
                        idx = id_to_index[event_id]
                        for event_name in TARGET_COLS:
                            if event_name in row and row[event_name] == 1:
                                labels[idx, event_map[event_name]] = 1
                all_labels.append(labels)
            elif not is_test:
                all_labels.append(np.zeros((len(eeg_data), len(TARGET_COLS))))
            gc.collect()
        except Exception as e:
            print(f"Error processing {data_file}: {e}")
    if not all_data:
        raise ValueError("Failed to load any valid data")
    concatenated_data = np.vstack(all_data)
    concatenated_ids = np.concatenate(all_ids)
    if not is_test and all_labels:
        concatenated_labels = np.vstack(all_labels)
    else:
        concatenated_labels = None if is_test else np.zeros((concatenated_data.shape[0], len(TARGET_COLS)))
    print(f"Data loaded and downsampled. Shape: {concatenated_data.shape}")
    return concatenated_data, concatenated_labels, concatenated_ids

def create_windows_with_features(data, labels=None, ids=None, window_size=WINDOW_SIZE, overlap=WINDOW_OVERLAP):
    print("Creating windows with advanced features...")
    num_samples, num_channels = data.shape
    step = window_size - overlap
    num_windows = max(0, (num_samples - window_size) // step + 1)
    if num_windows <= 0:
        print(f"Not enough samples ({num_samples}) for window size {window_size}")
        empty_shape = (0, window_size, num_channels)
        return (np.array([]).reshape(empty_shape), 
                None if labels is None else np.array([]).reshape((0, labels.shape[1])), 
                np.array([]),
                np.array([]))
    print(f"Creating {num_windows} windows with size {window_size} and step {step}")
    windows = np.zeros((num_windows, window_size, num_channels))
    window_labels = np.zeros((num_windows, labels.shape[1])) if labels is not None else None
    window_ids = np.zeros(num_windows, dtype=object if ids.dtype == np.dtype('O') else ids.dtype)
    feature_list = []
    batch_size = min(1000, num_windows)
    for batch_start in range(0, num_windows, batch_size):
        batch_end = min(batch_start + batch_size, num_windows)
        for i in range(batch_start, batch_end):
            start_idx = i * step
            end_idx = start_idx + window_size
            window = data[start_idx:end_idx, :]
            windows[i] = window
            if i % 5 == 0 or i == num_windows - 1:
                features = extract_features_from_window(window)
                feature_list.append((i, features))
            if labels is not None:
                window_labels[i] = np.max(labels[start_idx:end_idx], axis=0)
            window_ids[i] = ids[end_idx - 1]
    feature_indices = [item[0] for item in feature_list]
    feature_values = [item[1] for item in feature_list]
    feature_dim = feature_values[0].shape[0]
    all_features = np.zeros((num_windows, feature_dim))
    for idx, features in zip(feature_indices, feature_values):
        all_features[idx] = features
    for i in range(num_windows):
        if i not in feature_indices:
            left_idx = max([idx for idx in feature_indices if idx < i], default=feature_indices[0])
            right_idx = min([idx for idx in feature_indices if idx > i], default=feature_indices[-1])
            if left_idx == right_idx:
                all_features[i] = all_features[left_idx]
            else:
                left_weight = (right_idx - i) / (right_idx - left_idx)
                right_weight = (i - left_idx) / (right_idx - left_idx)
                all_features[i] = left_weight * all_features[left_idx] + right_weight * all_features[right_idx]
    print(f"Created windows shape: {windows.shape}")
    print(f"Extracted features shape: {all_features.shape}")
    return windows, window_labels, window_ids, all_features

def preprocess_data_pipeline(series_ids, data_base_path, scaler=None, is_test=False, augment=False):
    print(f"Loading {'test' if is_test else 'training'} data for series {series_ids}...")
    raw_data, raw_labels, raw_ids = load_series_data(series_ids, data_base_path, is_test=is_test, downsample_factor=2)
    if raw_data.shape[0] == 0:
        print(f"No data loaded for series {series_ids}")
        return np.array([]), None, np.array([]), None, scaler
    print("Applying advanced filters...")
    filtered_data = butter_bandpass_filter_optimized(raw_data, LOW_CUT, HIGH_CUT, FS/2)
    filtered_data = apply_notch_filter(filtered_data, fs=FS/2, f0=50.0 * 2 / FS)
    del raw_data
    gc.collect()
    print("Scaling data...")
    if scaler is None:
        scaler = RobustScaler()
        scaled_data = scaler.fit_transform(filtered_data)
    else:
        scaled_data = scaler.transform(filtered_data)
    del filtered_data
    gc.collect()
    print("Creating windows and extracting features...")
    windows, window_labels, window_ids, features = create_windows_with_features(
        scaled_data, raw_labels, raw_ids, WINDOW_SIZE, WINDOW_OVERLAP
    )
    del scaled_data, raw_labels, raw_ids
    gc.collect()
    if windows.shape[0] == 0:
        print(f"No windows created for series {series_ids}")
        return np.array([]), None, np.array([]), None, scaler
    n_channels = windows.shape[2]
    # Data augmentation
    if augment:
        # Add noise, dropout, and mixup
        windows = add_gaussian_noise(windows, std=0.02)
        windows = random_channel_dropout(windows, dropout_prob=0.15)
        if window_labels is not None:
            windows, window_labels = mixup(windows, window_labels, alpha=0.2)
    windows_reshaped = windows.reshape(windows.shape[0], n_channels, WINDOW_SIZE, 1)
    if is_test:
        return windows_reshaped, window_ids, features, None, scaler
    else:
        return windows_reshaped, window_labels, window_ids, features, scaler

# ------------------ Model Architectures ---------------------
def build_eegnet_tcn(n_channels, window_size, n_outputs=6):
    inputs = Input(shape=(n_channels, window_size, 1))
    x = Conv2D(16, (1, 64), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = Conv2D(32, (n_channels, 1), padding='valid', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 4))(x)
    x = Dropout(0.3)(x)
    for dilation_rate in [1, 2, 4, 8]:
        res = x
        x = Conv2D(32, (1, 3), padding='same', dilation_rate=(1, dilation_rate))(x)
        x = BatchNormalization()(x)
        x = Activation('elu')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(32, (1, 3), padding='same')(x)
        x = BatchNormalization()(x)
        if res.shape[-1] != x.shape[-1]:
            res = Conv2D(32, (1, 1), padding='same')(res)
        x = Add()([res, x])
        x = Activation('elu')(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='elu')(x)
    x = Dropout(DROPOUT_RATE)(x)
    x = Dense(n_outputs, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=x)

def build_cnn_blstm_model(n_channels, window_size, n_outputs=6, feature_dim=None):
    eeg_input = Input(shape=(n_channels, window_size, 1), name='eeg_input')
    x = Conv2D(32, (1, 32), padding='same')(eeg_input)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 4))(x)
    x = Conv2D(64, (n_channels, 1), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = Dropout(0.3)(x)
    _, c, t, f = x.shape
    x = Reshape((t, c*f))(x)
    x = Bidirectional(LSTM(64, return_sequences=True))(x)
    x = Dropout(0.3)(x)
    x = Bidirectional(LSTM(32))(x)
    if feature_dim is not None:
        feature_input = Input(shape=(feature_dim,), name='feature_input')
        f_x = Dense(64, activation='elu')(feature_input)
        f_x = BatchNormalization()(f_x)
        f_x = Dropout(0.3)(f_x)
        f_x = Dense(32, activation='elu')(f_x)
        x = tf.keras.layers.concatenate([x, f_x])
        x = Dense(128, activation='elu')(x)
        x = Dropout(DROPOUT_RATE)(x)
        x = Dense(32, activation='elu')(x)
        x = Dense(n_outputs, activation='sigmoid')(x)
        return Model(inputs=[eeg_input, feature_input], outputs=x)
    else:
        x = Dense(128, activation='elu')(x)
        x = Dropout(DROPOUT_RATE)(x)
        x = Dense(32, activation='elu')(x)
        x = Dense(n_outputs, activation='sigmoid')(x)
        return Model(inputs=eeg_input, outputs=x)

def build_attention_eeg_model(n_channels, window_size, n_outputs=6):
    # Attention-based EEG model
    inp = Input(shape=(n_channels, window_size, 1))
    x = Conv2D(32, (1, 32), padding='same', activation='elu')(inp)
    x = BatchNormalization()(x)
    x = Reshape((window_size, n_channels*32))(x)
    x = LayerNormalization()(x)
    x = MultiHeadAttention(num_heads=4, key_dim=16)(x, x)
    x = Flatten()(x)
    x = Dense(128, activation='elu')(x)
    x = Dropout(DROPOUT_RATE)(x)
    out = Dense(n_outputs, activation='sigmoid')(x)
    return Model(inp, out)

# ------------------ Training and Validation -----------------
def train_model_with_features(X_train, y_train, X_val, y_val, features_train=None, features_val=None, model_type='cnn_blstm'):
    print(f"\n--- Building and Training {model_type} Model ---")
    n_channels = X_train.shape[1]
    if model_type == 'eegnet_tcn':
        model = build_eegnet_tcn(n_channels, WINDOW_SIZE, len(TARGET_COLS))
        train_data = X_train
        val_data = X_val
    elif model_type == 'cnn_blstm':
        if features_train is not None:
            feature_dim = features_train.shape[1]
            model = build_cnn_blstm_model(n_channels, WINDOW_SIZE, len(TARGET_COLS), feature_dim)
            train_data = [X_train, features_train]
            val_data = [X_val, features_val]
        else:
            model = build_cnn_blstm_model(n_channels, WINDOW_SIZE, len(TARGET_COLS))
            train_data = X_train
            val_data = X_val
    elif model_type == 'attention':
        model = build_attention_eeg_model(n_channels, WINDOW_SIZE, len(TARGET_COLS))
        train_data = X_train
        val_data = X_val
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model.compile(
        optimizer=Adam(learning_rate=INITIAL_LR),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )
    callbacks_list = [
        ModelCheckpoint(
            filepath=os.path.join(OUTPUT_PATH, f'best_{model_type}_model.keras'),
            monitor='val_auc', mode='max', save_best_only=True, verbose=1
        ),
        EarlyStopping(
            monitor='val_auc', mode='max', patience=10, verbose=1, restore_best_weights=True
        ),
        ReduceLROnPlateau(
            monitor='val_loss', mode='min', factor=0.5, patience=4, min_lr=1e-6, verbose=1
        )
    ]
    class_weights = None
    history = model.fit(
        train_data, y_train,
        batch_size=BATCH_SIZE,
        epochs=MAX_EPOCHS,
        validation_data=(val_data, y_val),
        callbacks=callbacks_list,
        class_weight=class_weights,
        verbose=2
    )
    return model, history

def tta_predict(model, X, features=None, num_aug=5):
    preds = []
    for _ in range(num_aug):
        X_aug = add_gaussian_noise(X, std=0.01)
        if features is not None:
            pred = model.predict([X_aug, features], batch_size=BATCH_SIZE*2, verbose=0)
        else:
            pred = model.predict(X_aug, batch_size=BATCH_SIZE*2, verbose=0)
        preds.append(pred)
    return np.mean(preds, axis=0)

def train_with_cross_validation(X, y, features=None, n_splits=5):
    print(f"\n--- Training with {n_splits}-Fold Stratified Cross-Validation ---")
    # Use stratified CV on the first label (works for multi-label if at least one event per window)
    y_strat = (y[:, 0] > 0).astype(int)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    model_types = ['eegnet_tcn', 'cnn_blstm', 'attention']
    all_val_aucs = []
    fold_models = []
    histories = []
    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y_strat)):
        print(f"\n--- Fold {fold+1}/{n_splits} ---")
        X_train_fold, X_val_fold = X[train_idx], X[val_idx]
        y_train_fold, y_val_fold = y[train_idx], y[val_idx]
        if features is not None:
            features_train, features_val = features[train_idx], features[val_idx]
        else:
            features_train = features_val = None
        model_type = model_types[fold % len(model_types)]
        print(f"Training {model_type} model for fold {fold+1}")
        model, history = train_model_with_features(
            X_train_fold, y_train_fold, X_val_fold, y_val_fold, 
            features_train, features_val, model_type
        )
        # TTA for validation
        if features is not None and model_type == 'cnn_blstm':
            val_pred = tta_predict(model, X_val_fold, features_val)
            val_scores = [
                roc_auc_score((y_val_fold[:, i] > 0.5).astype(int), val_pred[:, i]) if len(np.unique(y_val_fold[:, i])) > 1 else np.nan
                for i in range(y_val_fold.shape[1])
            ]
            val_auc = np.nanmean(val_scores)
        else:
            val_pred = tta_predict(model, X_val_fold)
            val_scores = [
                roc_auc_score((y_val_fold[:, i] > 0.5).astype(int), val_pred[:, i]) if len(np.unique(y_val_fold[:, i])) > 1 else np.nan
                for i in range(y_val_fold.shape[1])
            ]
            val_auc = np.nanmean(val_scores)
        print(f"Fold {fold+1} {model_type} - Validation AUC: {val_auc:.4f}")
        all_val_aucs.append(val_auc)
        fold_models.append((model_type, model))
        histories.append(history)
    print("\n--- Cross-Validation Results ---")
    for fold, (model_type, _) in enumerate(fold_models):
        print(f"Fold {fold+1}: {model_type} - AUC = {all_val_aucs[fold]:.4f}")
    print(f"Average AUC: {np.mean(all_val_aucs):.4f} ± {np.std(all_val_aucs):.4f}")
    best_fold = np.argmax(all_val_aucs)
    best_model_type, best_model = fold_models[best_fold]
    print(f"Best model from fold {best_fold+1}: {best_model_type} with AUC {all_val_aucs[best_fold]:.4f}")
    return (best_model_type, best_model), fold_models, histories[best_fold]

def plot_history(history, model_type='model'):
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title(f'{model_type} Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='lower right')
    plt.grid(True)
    plt.subplot(1, 3, 2)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title(f'{model_type} Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper right')
    plt.grid(True)
    if 'auc' in history.history:
        plt.subplot(1, 3, 3)
        plt.plot(history.history['auc'])
        plt.plot(history.history['val_auc'])
        plt.title(f'{model_type} AUC')
        plt.ylabel('AUC')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='lower right')
        plt.grid(True)
    plt.tight_layout()
    history_plot_path = os.path.join(OUTPUT_PATH, f'{model_type}_training_history.png')
    try:
        plt.savefig(history_plot_path)
        print(f"History plot saved to: {history_plot_path}")
    except Exception as e:
        print(f"Error saving history plot: {e}")
    plt.close()

# ----------- Advanced Ensemble with Gradient Boosting --------
def ensemble_predictions(preds_list, y_val=None):
    preds_array = np.stack(preds_list, axis=-1)
    n_samples, n_outputs, n_models = preds_array.shape
    meta_preds = np.zeros((n_samples, n_outputs))
    for i in range(n_outputs):
        if len(np.unique(y_val[:, i])) > 1:
            meta_model = HistGradientBoostingClassifier(max_iter=200)
            meta_model.fit(preds_array[:, i, :], y_val[:, i])
            meta_preds[:, i] = meta_model.predict_proba(preds_array[:, i, :])[:, 1]
        else:
            meta_preds[:, i] = np.mean(preds_array[:, i, :], axis=-1)
    return meta_preds

#------------ Confusion Matrix ---------------------------------
def plot_confusion_matrices(y_true, y_pred, class_names):
    """
    Plot confusion matrices for all classes
    
    Args:
        y_true: Ground truth labels (one-hot encoded)
        y_pred: Predicted probabilities
        class_names: List of class names
    """
    # Convert probabilities to binary predictions
    y_pred_binary = (y_pred > 0.5).astype(int)
    
    # Create a figure with subplots
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    
    # Overall metrics for reporting
    all_cms = []
    all_metrics = {
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1': []
    }
    
    for i, class_name in enumerate(class_names):
        # Calculate confusion matrix
        cm = confusion_matrix(y_true[:, i], y_pred_binary[:, i])
        all_cms.append(cm)
        
        # Calculate metrics
        acc = accuracy_score(y_true[:, i], y_pred_binary[:, i])
        prec = precision_score(y_true[:, i], y_pred_binary[:, i], zero_division=0)
        rec = recall_score(y_true[:, i], y_pred_binary[:, i], zero_division=0)
        f1 = f1_score(y_true[:, i], y_pred_binary[:, i], zero_division=0)
        
        all_metrics['accuracy'].append(acc)
        all_metrics['precision'].append(prec)
        all_metrics['recall'].append(rec)
        all_metrics['f1'].append(f1)
        
        # Plot confusion matrix
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[i])
        axes[i].set_title(f'{class_name}\nAcc: {acc:.4f}, F1: {f1:.4f}')
        axes[i].set_xlabel('Predicted')
        axes[i].set_ylabel('True')
        axes[i].set_xticklabels(['Negative', 'Positive'])
        axes[i].set_yticklabels(['Negative', 'Positive'])
    
    plt.tight_layout()
    cm_plot_path = os.path.join(OUTPUT_PATH, 'confusion_matrices.png')
    plt.savefig(cm_plot_path)
    plt.close()
    print(f"Confusion matrices saved to: {cm_plot_path}")
    
    # Return the metrics for reporting
    return all_cms, all_metrics

# Function to print detailed metrics report
def print_metrics_report(metrics, class_names):
    """Print a detailed metrics report"""
    print("\n===== Detailed Metrics Report =====")
    print(f"{'Class':<20} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1 Score':<10}")
    print("-" * 60)
    
    for i, class_name in enumerate(class_names):
        print(f"{class_name:<20} {metrics['accuracy'][i]:.4f}     {metrics['precision'][i]:.4f}     "
              f"{metrics['recall'][i]:.4f}     {metrics['f1'][i]:.4f}")
    
    # Print averages
    print("-" * 60)
    print(f"{'Average':<20} {np.mean(metrics['accuracy']):.4f}     {np.mean(metrics['precision']):.4f}     "
          f"{np.mean(metrics['recall']):.4f}     {np.mean(metrics['f1']):.4f}")



# ------------------- Main Pipeline --------------------------
if __name__ == "__main__":
    start_time = time.time()
    print("=== Starting Advanced EEG Ensemble Pipeline ===")
    if os.path.exists(os.path.join(INPUT_DATA_PATH, 'train.zip')):
        print("Unzipping data files...")
        with zipfile.ZipFile(os.path.join(INPUT_DATA_PATH, 'train.zip'), 'r') as zip_ref:
            zip_ref.extractall(WORKING_PATH)
    print("\n--- Preparing Training Data ---")
    train_ids = [s for s in TRAIN_SERIES if s not in VALIDATION_SPLIT_SERIES]
    X_train, y_train, _, features_train, scaler = preprocess_data_pipeline(
        train_ids, WORKING_PATH, is_test=False, augment=True
    )
    if X_train.shape[0] == 0:
        raise ValueError("No training samples after preprocessing")
    print(f"Train shapes: X={X_train.shape}, y={y_train.shape}, features={features_train.shape}")
    print("\n--- Preparing Validation Data ---")
    X_val, y_val, _, features_val, _ = preprocess_data_pipeline(
        VALIDATION_SPLIT_SERIES, WORKING_PATH, scaler=scaler, is_test=False
    )
    if X_val.shape[0] == 0:
        raise ValueError("No validation samples after preprocessing")
    print(f"Validation shapes: X={X_val.shape}, y={y_val.shape}, features={features_val.shape}")
    gc.collect()
    print("\n--- Using Cross-Validation Training Strategy ---")
    X_full = np.concatenate([X_train, X_val])
    y_full = np.concatenate([y_train, y_val])
    features_full = np.concatenate([features_train, features_val])
    best_model_info, fold_models, history = train_with_cross_validation(
        X_full, y_full, features=features_full, n_splits=5
    )
    # Ensemble predictions on validation set using TTA
    preds_list = []
    for model_type, model in fold_models:
        if model_type == 'cnn_blstm' and features_val is not None:
            val_pred = tta_predict(model, X_val, features_val)
        else:
            val_pred = tta_predict(model, X_val)
        preds_list.append(val_pred)
    blended_preds = ensemble_predictions(preds_list, y_val=y_val)
    aucs = [roc_auc_score(y_val[:, i], blended_preds[:, i]) if len(np.unique(y_val[:, i])) > 1 else np.nan for i in range(blended_preds.shape[1])]
    avg_auc = np.nanmean(aucs)
    print(f"\n=== Ensemble Model Performance ===")
    
    # Calculate and plot confusion matrices
    cms, detailed_metrics = plot_confusion_matrices(y_val, blended_preds, TARGET_COLS)
    print_metrics_report(detailed_metrics, TARGET_COLS)
    
    print(f"\Validaiton AUC Scores by Class:")
    for i, event_name in enumerate(TARGET_COLS):
        print(f"{event_name}: {aucs[i]:.4f}")
    print(f"Average Ensemble AUC: {avg_auc:.4f}")
    
    # Plot ROC curves for ensemble
    plt.figure(figsize=(12, 10))
    for i, event_name in enumerate(TARGET_COLS):
        try:
            if len(np.unique(y_val[:, i])) > 1:
                fpr, tpr, _ = roc_curve(y_val[:, i], blended_preds[:, i])
                plt.plot(fpr, tpr, label=f'{event_name} (AUC = {aucs[i]:.4f})')
        except Exception as e:
            print(f"{event_name} ROC error: {e}")
    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Ensemble ROC Curves - Avg AUC: {avg_auc:.4f}')
    plt.legend(loc='lower right')
    plt.grid(True)
    plot_path = os.path.join(OUTPUT_PATH, f'ensemble_roc_curves.png')
    try:
        plt.savefig(plot_path)
        print(f"Ensemble ROC plot saved to: {plot_path}")
    except Exception as e:
        print(f"Error saving ensemble ROC plot: {e}")
    plt.close()
    plot_history(history, model_type='ensemble')
    print(f"\n=== Final Model Performance ===")
    print(f"Model type: ensemble")
    print(f"Validation AUC: {avg_auc:.4f}")
    end_time = time.time()
    execution_time = end_time - start_time
    hours, remainder = divmod(execution_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"\nTotal execution time: {int(hours)}h {int(minutes)}m {int(seconds)}s")
    print("\n=== Process Complete ===")

2025-05-02 15:07:42.692922: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746198462.944313      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746198463.010362      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Found 2 GPU(s), memory growth enabled
Using mixed precision - Compute dtype: float16
=== Starting Advanced EEG Ensemble Pipeline ===
Unzipping data files...

--- Preparing Training Data ---
Loading training data for series [1, 2, 3, 4, 5, 6]...
Loading series [1, 2, 3, 4, 5, 6]
Found 72 data files
Data loaded and downsampled. Shape: (7409687, 32)
Applying advanced filters...
Scaling data...
Creating windows and extracting features...
Creating windows with advanced features...
Creating 37047 windows with size 400 and step 200
Created windows shape: (37047, 400, 32)
Extracted features shape: (37047, 320)
Train shapes: X=(37047, 32, 400, 1), y=(37047, 6), features=(37047, 320)

--- Preparing Validation Data ---
Loading training data for series [7, 8]...
Loading series [7, 8]
Found 24 data files
Data loaded and downsampled. Shape: (1583211, 32)
Applying advanced filters...
Scaling data...
Creating windows and extracting features...
Creating windows with advanced features...
Creating 7915 w

I0000 00:00:1746199591.225841      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1746199591.226555      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


Epoch 1/40


I0000 00:00:1746199615.021211     112 service.cc:148] XLA service 0x7c871804d230 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1746199615.023720     112 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1746199615.023744     112 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1746199616.428862     112 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1746199630.768764     112 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.



Epoch 1: val_auc improved from -inf to 0.59750, saving model to /kaggle/working/best_eegnet_tcn_model.keras
282/282 - 48s - 171ms/step - accuracy: 0.1964 - auc: 0.5614 - loss: 0.4505 - val_accuracy: 0.1956 - val_auc: 0.5975 - val_loss: 0.4313 - learning_rate: 0.0010
Epoch 2/40

Epoch 2: val_auc improved from 0.59750 to 0.60997, saving model to /kaggle/working/best_eegnet_tcn_model.keras
282/282 - 6s - 21ms/step - accuracy: 0.2255 - auc: 0.5919 - loss: 0.4265 - val_accuracy: 0.2789 - val_auc: 0.6100 - val_loss: 0.4212 - learning_rate: 0.0010
Epoch 3/40

Epoch 3: val_auc improved from 0.60997 to 0.61593, saving model to /kaggle/working/best_eegnet_tcn_model.keras
282/282 - 6s - 21ms/step - accuracy: 0.2407 - auc: 0.6076 - loss: 0.4186 - val_accuracy: 0.2462 - val_auc: 0.6159 - val_loss: 0.4177 - learning_rate: 0.0010
Epoch 4/40

Epoch 4: val_auc improved from 0.61593 to 0.61669, saving model to /kaggle/working/best_eegnet_tcn_model.keras
282/282 - 6s - 21ms/step - accuracy: 0.2516 - auc