In [6]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from sklearn.ensemble import RandomForestClassifier
import pickle
import os
import warnings
warnings.filterwarnings('ignore')

# For time series analysis
from scipy import signal, interpolate
from scipy.stats import skew, kurtosis, mode # Import mode explicitly

# Deep learning libraries
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import LSTM, Dense, Dropout, Conv1D, MaxPooling1D, Flatten, Input, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical

# Removed tsai as it's not directly used in the provided code snippets but can be added back if needed.
# from tsai.all import *

print("Libraries imported successfully!")
print(f"TensorFlow version: {tf.__version__}")

# --- Configuration ---
DATA_PATH = 'data/WESAD/'
SUBJECT_IDS = [f"S{i}" for i in [2,3,4,5,6,7,8,9,10,11,13,14,15,16,17]] # Exclude S12 if it's missing or corrupted

# Sensor sampling rates
FS_CHEST = 700 # Hz
FS_WRIST_ACC = 32 # Hz
FS_WRIST_BVP = 64 # Hz
FS_WRIST_EDA_TEMP = 4 # Hz

# Windowing parameters
CHEST_WINDOW_SIZE_SEC = 5    # 5 seconds
WRIST_WINDOW_SIZE_SEC = 2.5  # 2.5 seconds (adjusted for lower sampling rate)
OVERLAP_RATIO = 0.5          # 50% overlap

# Labels to consider (0: Not defined, 1: Baseline, 2: Stress, 3: Amusement, 4: Meditation)
VALID_LABELS = [0, 1, 2, 3, 4] # Now includes 0 (Not defined) and 4 (Meditation)
LABEL_MAP = {0: 'Not defined', 1: 'Baseline', 2: 'Stress', 3: 'Amusement', 4: 'Meditation'}

Libraries imported successfully!
TensorFlow version: 2.19.0


In [7]:
# =============================================================================
# 1. WESAD DATASET OVERVIEW AND LOADING
# =============================================================================

"""
WESAD Dataset Structure:
- Physiological signals: ECG, EDA, EMG, Respiration, Temperature
- Labels: Baseline (1), Stress (2), Amusement (3), Meditation (4)
- Sampling rates: 700Hz (chest), 64Hz (wrist)
- 15 subjects with multimodal data

Common approach for emotion detection:
1. Feature extraction from physiological signals
2. Time-series modeling (LSTM, CNN, Transformer)
3. Traditional ML with engineered features
4. Multimodal fusion approaches
"""

def load_wesad_subject(subject_id, data_path=DATA_PATH):
    """
    Load WESAD data for a specific subject.
    Returns chest and wrist sensor data with labels.
    """
    file_path = os.path.join(data_path, subject_id, f"{subject_id}.pkl")
    
    try:
        with open(file_path, 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        
        print(f"✓ Successfully loaded subject {subject_id}")
        
        # Extract chest sensor data (700Hz)
        chest_data = {
            'ACC': data['signal']['chest']['ACC'],
            'ECG': data['signal']['chest']['ECG'],
            'EMG': data['signal']['chest']['EMG'],
            'EDA': data['signal']['chest']['EDA'], 
            'Temp': data['signal']['chest']['Temp'],
            'Resp': data['signal']['chest']['Resp']
        }
        
        # Extract wrist sensor data (64Hz for BVP, 32Hz for ACC, 4Hz for EDA/TEMP)
        wrist_data = {
            'ACC': data['signal']['wrist']['ACC'],
            'BVP': data['signal']['wrist']['BVP'],
            'EDA': data['signal']['wrist']['EDA'],
            'Temp': data['signal']['wrist']['TEMP']
        }
        
        # Labels are originally sampled at 700Hz (same as chest)
        labels = data['label']

        # # Filter out unwanted labels and corresponding data based on the global VALID_LABELS
        # mask = np.isin(labels, VALID_LABELS)
        # labels = labels[mask]
        # chest_data = {k: v[mask] for k, v in chest_data.items()}
        
        # Data inspection
        print("Chest data shapes (after label masking):")
        for key, value in chest_data.items():
            print(f"  {key}: {value.shape}")
        
        print("Wrist data shapes (original - length will be aligned in feature extraction):")
        for key, value in wrist_data.items():
            print(f"  {key}: {value.shape}")
        
        print(f"Labels shape (after filtering to {VALID_LABELS}): {labels.shape}")
        print(f"Unique labels (after filtering): {np.unique(labels)}")
        print(f"🧪 Raw label length: {len(labels)}, ECG length: {len(chest_data['ECG'])}, Wrist BVP length: {len(wrist_data['BVP'])}")
        return chest_data, wrist_data, labels
        
    except FileNotFoundError:
        print(f"✗ File not found for subject {subject_id}")
        return None, None, None
    except Exception as e:
        print(f"✗ Error loading subject {subject_id}: {e}")
        return None, None, None

In [8]:
# =============================================================================
# 2. DATA PREPROCESSING AND FEATURE EXTRACTION
# =============================================================================
def extract_time_domain_features(signal_array):
    """Extract time-domain features from a 1D physiological signal segment."""
    if len(signal_array) == 0: # Handle empty windows gracefully
        return [0.0] * 12 # Return a list of zeros for all features

    return [
        np.mean(signal_array),
        np.std(signal_array),
        np.var(signal_array),
        np.min(signal_array),
        np.max(signal_array),
        np.median(signal_array),
        skew(signal_array),
        kurtosis(signal_array),
        np.percentile(signal_array, 25),
        np.percentile(signal_array, 75),
        np.ptp(signal_array),  # peak-to-peak
        np.sqrt(np.mean(signal_array**2)),  # RMS
    ]

def extract_frequency_domain_features(input_signal, fs):
    """Extract frequency-domain features from a 1D physiological signal segment."""
    if len(input_signal) < 2: # Welch requires at least 2 points
        return [0.0] * 8 # Return a list of zeros for all features

    # Ensure nperseg is not greater than the signal length
    nperseg_val = min(1024, len(input_signal) // 4)
    if nperseg_val < 4: # welch needs at least 4 points for nperseg
        return [0.0] * 8 # Return a list of zeros for all features

    freqs, psd = signal.welch(input_signal, fs=fs, nperseg=nperseg_val)
    
    features = []
    
    # Frequency bands for physiological signals (example bands, adjust as needed)
    # These bands might be more relevant for specific signals (e.g., HRV for ECG)
    # Here, they are generic examples.
    bands = {
        'very_low': (0.01, 0.04), # Example for very low frequency
        'low': (0.04, 0.15),     # Example for low frequency
        'mid': (0.15, 0.4),      # Example for mid frequency
        'high': (0.4, 2.0)       # Example for high frequency
    }
    
    for band_name, (low, high) in bands.items():
        band_mask = (freqs >= low) & (freqs <= high)
        if np.sum(band_mask) > 0:
            band_power = np.trapz(psd[band_mask], freqs[band_mask])
        else:
            band_power = 0.0 # No frequencies in this band
        features.append(band_power)
    
    # Spectral features
    features.extend([
        np.mean(psd) if len(psd) > 0 else 0.0,
        np.std(psd) if len(psd) > 0 else 0.0,
        freqs[np.argmax(psd)] if len(psd) > 0 else 0.0,  # dominant frequency
        np.sum(psd) if len(psd) > 0 else 0.0  # total power
    ])
    
    return features

def create_sliding_windows(data_stream, labels_stream, sensor_type):
    """
    Create sliding windows from a signal stream and assign majority labels.
    
    Args:
        data_stream (np.ndarray): The 1D or 2D (for ACC) sensor data stream.
        labels_stream (np.ndarray): The 1D label stream (aligned with data_stream's original rate).
        sensor_type (str): 'chest' or 'wrist' to determine window size.
        
    Returns:
        tuple: (windowed_data, windowed_labels)
    """
    
    if sensor_type == 'chest':
        fs = FS_CHEST
        window_size_samples = int(CHEST_WINDOW_SIZE_SEC * fs)
    elif sensor_type == 'wrist':
        fs = FS_WRIST_BVP # Use BVP's sampling rate as the base for windowing logic (highest wrist FS)
        window_size_samples = int(WRIST_WINDOW_SIZE_SEC * fs)
    else:
        raise ValueError(f"Invalid sensor_type: {sensor_type}. Must be 'chest' or 'wrist'.")
        
    step_size = int(window_size_samples * (1 - OVERLAP_RATIO))
    
    windowed_data = []
    windowed_labels = []
    
    # Ensure data_stream and labels_stream are sufficiently long for at least one window
    if len(data_stream) < window_size_samples or len(labels_stream) < window_size_samples:
        print(f"Warning: Data stream length ({len(data_stream)}) or label stream length ({len(labels_stream)}) is too short for window size ({window_size_samples}). Skipping windowing.")
        return np.array([]), np.array([])


    for i in range(0, len(data_stream) - window_size_samples + 1, step_size):
        window_data = data_stream[i : i + window_size_samples]
        window_labels = labels_stream[i : i + window_size_samples]
        
        # Use majority vote for window label
        unique_labels, counts = np.unique(window_labels, return_counts=True)
        
        # Ensure that the mode calculation correctly handles empty counts or all zeros if labels are sparse.
        if len(unique_labels) == 0:
            continue # Skip window if no labels found
        
        majority_label = unique_labels[np.argmax(counts)]
        
        # Only keep windows with consistent labels (>80% same label)
        # and ensure the majority label is one of the VALID_LABELS (now including 0,1,2,3,4)
        if np.max(counts) / len(window_labels) > 0.8 and majority_label in VALID_LABELS:
            windowed_data.append(window_data)
            windowed_labels.append(majority_label)
    
    return np.array(windowed_data), np.array(windowed_labels)

def extract_features_from_windows(windows, fs):
    """Helper to apply feature extraction to a set of windows."""
    features_list = []
    for window in windows:
        feats = extract_time_domain_features(window)
        feats += extract_frequency_domain_features(window, fs=fs)
        features_list.append(feats)
    return np.array(features_list)


In [9]:
 # =============================================================================
# 3. FEATURE-BASED TRADITIONAL ML APPROACH (with simplified wrist branch)
# =============================================================================
def process_subject_features(chest_data, wrist_data, labels, subject_id):
    """
    Processes chest and wrist sensor data for a single subject,
    extracting features and aligning labels.
    Returns: chest_features_df, wrist_features_df, fused_features_df (Pandas DataFrames)
    """
    print(f"\n--- Processing features for Subject {subject_id} ---")
    
    # Initialize empty dataframes for return
    chest_features_df = pd.DataFrame()
    wrist_features_df = pd.DataFrame()
    fused_features_df = pd.DataFrame()

    # --- Process Chest Data ---
    print(f"  Processing Chest Data (fs={FS_CHEST}Hz)...")
    ecg_signal = chest_data['ECG'].flatten()
    eda_signal = chest_data['EDA'].flatten()
    
    # Align labels with the chest data length (already masked in load_wesad_subject)
    labels_chest_aligned = labels[:len(ecg_signal)]
    
    ecg_windows, ecg_labels_win = create_sliding_windows(ecg_signal, labels_chest_aligned, 'chest')
    eda_windows, eda_labels_win = create_sliding_windows(eda_signal, labels_chest_aligned, 'chest')
    
    # Ensure consistent number of windows for chest data
    min_chest_windows = min(len(ecg_windows), len(eda_windows))
    if min_chest_windows == 0:
        print(f"  Warning: No valid chest windows found for subject {subject_id}. Skipping chest features.")
        chest_features_flat = np.array([])
        chest_labels_flat = np.array([])
    else:
        ecg_features = extract_features_from_windows(ecg_windows[:min_chest_windows], fs=FS_CHEST)
        eda_features = extract_features_from_windows(eda_windows[:min_chest_windows], fs=FS_CHEST)
        
        # Concatenate chest features and align labels
        chest_features_flat = np.concatenate((ecg_features, eda_features), axis=1)
        chest_labels_flat = ecg_labels_win[:min_chest_windows] # Assuming labels are consistent for both chest signals

        print(f"    Chest features shape: {chest_features_flat.shape}, Labels shape: {chest_labels_flat.shape}")

        # Create DataFrame for chest data
        num_chest_features = chest_features_flat.shape[1]
        chest_feature_col_names = [f'c_feature_{i}' for i in range(num_chest_features)]
        
        chest_features_df = pd.DataFrame(chest_features_flat, columns=chest_feature_col_names)
        chest_features_df['sid'] = int(subject_id.replace('S', ''))
        chest_features_df['label'] = chest_labels_flat
        # Reorder columns to put 'sid' first
        chest_features_df = chest_features_df[['sid'] + chest_feature_col_names + ['label']]


    # --- Process Wrist Data Only ---
    print(f"  Processing Wrist Data...")
    wrist_eda = wrist_data['EDA'].flatten()
    wrist_bvp = wrist_data['BVP'].flatten()
    wrist_temp = wrist_data['Temp'].flatten()
    wrist_acc_mag = np.linalg.norm(wrist_data['ACC'], axis=1).flatten()

    # Interpolate original (700Hz) labels to match BVP length (64Hz is highest of wrist sensors)
    # Use full original label array for interpolation (assumed to be 700Hz)
    # Ensure no prior filtering of `labels` happened
    if len(labels) != FS_CHEST * len(chest_data['ECG']):
        print(f"⚠️ Label length mismatch for subject {subject_id}. Check raw label integrity.")
        
    # Interpolate from original 700Hz label stream to wrist sampling rate (64Hz for BVP)
    interp_func = interpolate.interp1d(
        np.linspace(0, 1, len(labels)), labels, kind='nearest', fill_value="extrapolate"
    )
    labels_aligned_bvp_rate = interp_func(np.linspace(0, 1, len(wrist_bvp))).astype(int)

    # ⚠️ Print a distribution check before masking
    print(f"Pre-masking wrist labels (raw interpolated): {np.unique(labels_aligned_bvp_rate, return_counts=True)}")

    # Create sliding windows (aligned to BVP sampling rate)
    wrist_eda_windows, wrist_eda_labels_win = create_sliding_windows(wrist_eda, labels_aligned_bvp_rate, 'wrist')
    wrist_bvp_windows, _ = create_sliding_windows(wrist_bvp, labels_aligned_bvp_rate, 'wrist')
    wrist_temp_windows, _ = create_sliding_windows(wrist_temp, labels_aligned_bvp_rate, 'wrist')
    wrist_acc_windows, _ = create_sliding_windows(wrist_acc_mag, labels_aligned_bvp_rate, 'wrist')

    # Make sure all modalities have same number of windows
    min_wrist_windows = min(len(wrist_eda_windows), len(wrist_bvp_windows), 
                            len(wrist_temp_windows), len(wrist_acc_windows),
                            len(wrist_eda_labels_win))

    if min_wrist_windows == 0:
        print(f"  Warning: No valid wrist windows found for subject {subject_id}. Skipping wrist features.")
        wrist_features_flat = np.array([])
        wrist_labels_flat = np.array([])
    else:
        # Feature extraction
        eda_features_wrist = extract_features_from_windows(wrist_eda_windows[:min_wrist_windows], fs=FS_WRIST_EDA_TEMP)
        bvp_features_wrist = extract_features_from_windows(wrist_bvp_windows[:min_wrist_windows], fs=FS_WRIST_BVP)
        temp_features_wrist = extract_features_from_windows(wrist_temp_windows[:min_wrist_windows], fs=FS_WRIST_EDA_TEMP)
        acc_features_wrist = extract_features_from_windows(wrist_acc_windows[:min_wrist_windows], fs=FS_WRIST_ACC)

        # Concatenate
        wrist_features_flat = np.concatenate((eda_features_wrist, bvp_features_wrist, 
                                            temp_features_wrist, acc_features_wrist), axis=1)
        wrist_labels_flat = wrist_eda_labels_win[:min_wrist_windows]

        # ✅ NOW apply label filtering (after everything is ready)
        valid_mask = np.isin(wrist_labels_flat, VALID_LABELS)  # skip 0 ("Not defined")
        wrist_features_flat = wrist_features_flat[valid_mask]
        wrist_labels_flat = wrist_labels_flat[valid_mask]

        # Create DataFrame
        num_wrist_features = wrist_features_flat.shape[1]
        wrist_feature_col_names = [f'w_feature_{i}' for i in range(num_wrist_features)]

        wrist_features_df = pd.DataFrame(wrist_features_flat, columns=wrist_feature_col_names)
        wrist_features_df['sid'] = int(subject_id.replace('S', ''))
        wrist_features_df['label'] = wrist_labels_flat
        wrist_features_df = wrist_features_df[['sid'] + wrist_feature_col_names + ['label']]


    # --- Merge Chest and Wrist Features (for Fused Output) ---
    # Only fuse if both chest and wrist data were successfully processed
    if len(chest_features_flat) > 0 and len(wrist_features_flat) > 0:
        # Take the minimum number of windows to ensure alignment for fusion
        min_windows_overall = min(len(chest_features_flat), len(wrist_features_flat))
        
        fused_features_np = np.concatenate((chest_features_flat[:min_windows_overall], 
                                            wrist_features_flat[:min_windows_overall]), axis=1)
        
        # For labels, use the chest labels as the primary source for the merged dataset
        fused_labels_np = chest_labels_flat[:min_windows_overall]

        # Add subject ID column
        subject_id_col = np.full((fused_features_np.shape[0], 1), int(subject_id.replace('S', '')))
        fused_features_with_id_np = np.concatenate((subject_id_col, fused_features_np), axis=1)

        print(f"  Fused features shape: {fused_features_with_id_np.shape}, Fused labels shape: {fused_labels_np.shape}")
        
        # Handle NaN values after feature extraction
        nan_count = np.isnan(fused_features_with_id_np).sum()
        if nan_count > 0:
            print(f"  Replacing {nan_count} NaN values with 0 in fused features for subject {subject_id}...")
            fused_features_with_id_np = np.nan_to_num(fused_features_with_id_np)
        
        # Create DataFrame for fused data
        num_fused_features = fused_features_with_id_np.shape[1] - 1 # -1 for 'sid' column
        fused_feature_col_names = [f'feature_{i}' for i in range(num_fused_features)]
        
        fused_features_df = pd.DataFrame(fused_features_with_id_np, 
                                        columns=['sid'] + fused_feature_col_names)
        fused_features_df['label'] = fused_labels_np

    else:
        print(f"  Skipping feature fusion for Subject {subject_id} due to insufficient data from one or both modalities.")
    
    print(wrist_features_df['label'].value_counts().sort_index().rename(index=LABEL_MAP))
    
    return chest_features_df, wrist_features_df, fused_features_df

In [10]:
# --- Main Processing Loop ---
all_chest_features_df = []
all_wrist_features_df = []
all_fused_features_df = []

print("\n" + "="*50)
print("STARTING WESAD DATA PROCESSING ACROSS ALL SUBJECTS")
print("="*50)

for sub_id in SUBJECT_IDS:
    chest_data, wrist_data, labels = load_wesad_subject(sub_id)
    
    if chest_data is not None and wrist_data is not None and labels is not None:
        # Call the updated function that returns three DataFrames
        subject_chest_df, subject_wrist_df, subject_fused_df = process_subject_features(chest_data, wrist_data, labels, sub_id)
        
        if not subject_chest_df.empty:
            all_chest_features_df.append(subject_chest_df)
        if not subject_wrist_df.empty:
            all_wrist_features_df.append(subject_wrist_df)
        if not subject_fused_df.empty:
            all_fused_features_df.append(subject_fused_df)
            print(f"  Aggregated {len(subject_fused_df)} fused windows for Subject {sub_id}.")
        else:
            print(f"  No fused features generated for Subject {sub_id}.")
    else:
        print(f"  Skipping Subject {sub_id} due to loading errors.")

# Concatenate all subjects' data into separate DataFrames
if all_chest_features_df:
    final_chest_dataset = pd.concat(all_chest_features_df, ignore_index=True)
    print("\n" + "="*50)
    print("ALL CHEST DATA MERGED")
    print("="*50)
    print(f"Final merged chest dataset shape: {final_chest_dataset.shape}")
    print(f"Final merged chest dataset label distribution:\n{final_chest_dataset['label'].value_counts().sort_index().rename(index=LABEL_MAP)}")
    output_filename_chest = 'wesad_processed_chest_features.pkl'
    final_chest_dataset.to_pickle(output_filename_chest)
    print(f"\n✓ Processed chest data saved to '{output_filename_chest}'")
else:
    print("\n✗ No chest data processed and merged.")

if all_wrist_features_df:
    final_wrist_dataset = pd.concat(all_wrist_features_df, ignore_index=True)
    print("\n" + "="*50)
    print("ALL WRIST DATA MERGED")
    print("="*50)
    print(f"Final merged wrist dataset shape: {final_wrist_dataset.shape}")
    print(f"Final merged wrist dataset label distribution:\n{final_wrist_dataset['label'].value_counts().sort_index().rename(index=LABEL_MAP)}")
    output_filename_wrist = 'wesad_processed_wrist_features.pkl'
    final_wrist_dataset.to_pickle(output_filename_wrist)
    print(f"\n✓ Processed wrist data saved to '{output_filename_wrist}'")
else:
    print("\n✗ No wrist data processed and merged.")

if all_fused_features_df:
    final_fused_dataset = pd.concat(all_fused_features_df, ignore_index=True)
    print("\n" + "="*50)
    print("ALL FUSED DATA MERGED")
    print("="*50)
    print(f"Final merged fused dataset shape: {final_fused_dataset.shape}")
    print(f"Final merged fused dataset label distribution:\n{final_fused_dataset['label'].value_counts().sort_index().rename(index=LABEL_MAP)}")
    output_filename_fused = 'wesad_processed_fused_features.pkl'
    final_fused_dataset.to_pickle(output_filename_fused)
    print(f"\n✓ Processed fused data saved to '{output_filename_fused}'")
else:
    print("\n✗ No fused data processed and merged.")



STARTING WESAD DATA PROCESSING ACROSS ALL SUBJECTS
✓ Successfully loaded subject S2
Chest data shapes (after label masking):
  ACC: (4255300, 3)
  ECG: (4255300, 1)
  EMG: (4255300, 1)
  EDA: (4255300, 1)
  Temp: (4255300, 1)
  Resp: (4255300, 1)
Wrist data shapes (original - length will be aligned in feature extraction):
  ACC: (194528, 3)
  BVP: (389056, 1)
  EDA: (24316, 1)
  Temp: (24316, 1)
Labels shape (after filtering to [0, 1, 2, 3, 4]): (4255300,)
Unique labels (after filtering): [0 1 2 3 4 6 7]
🧪 Raw label length: 4255300, ECG length: 4255300, Wrist BVP length: 389056

--- Processing features for Subject S2 ---
  Processing Chest Data (fs=700Hz)...
    Chest features shape: (2361, 40), Labels shape: (2361,)
  Processing Wrist Data...
⚠️ Label length mismatch for subject S2. Check raw label integrity.
Pre-masking wrist labels (raw interpolated): (array([0, 1, 2, 3, 4, 6, 7]), array([195904,  73216,  39360,  23168,  49152,   4160,   4096]))
  Fused features shape: (301, 121), 

In [None]:
# =============================================================================
# 4. TRADITIONAL ML MODEL TRAINING
# =============================================================================

def train_traditional_ml_model(X, y):
    """Train and evaluate traditional ML model"""
    
    print("\n" + "="*50)
    print("TRAINING TRADITIONAL ML MODEL")
    print("="*50)
    
    # Train-test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    
    print(f"Training set shape: {X_train.shape}")
    print(f"Test set shape: {X_test.shape}")
    print(f"Training labels shape: {y_train.shape}")
    print(f"Test labels shape: {y_test.shape}")
    
    # Feature scaling
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    print(f"Scaled training data type: {X_train_scaled.dtype}")
    print(f"Scaled training data shape: {X_train_scaled.shape}")
    
    # Train Random Forest
    rf_model = RandomForestClassifier(
        n_estimators=100,
        max_depth=10,
        random_state=42,
        n_jobs=-1
    )
    
    print("Training Random Forest...")
    rf_model.fit(X_train_scaled, y_train)
    
    # Predictions
    y_pred = rf_model.predict(X_test_scaled)
    
    # Evaluation
    accuracy = accuracy_score(y_test, y_pred)
    print(f"\nRandom Forest Accuracy: {accuracy:.4f}")
    
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred))
    
    return rf_model, scaler, (X_train_scaled, X_test_scaled, y_train, y_test)

# Train traditional ML model
if 'features' in locals():
    rf_model, scaler, (X_train, X_test, y_train, y_test) = train_traditional_ml_model(features, feature_labels)
