# In [None]:
import numpy as np
import pandas as pd
from scipy.signal import butter, filtfilt, stft, welch, hilbert
from scipy.stats import linregress
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import make_pipeline
from mne.decoding import CSP  # Requires mne-python package
from mne.preprocessing import ICA as mne_ICA
import torch
import torch.nn as nn
from scipy.signal import gausspulse, chirp
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import pywt
from scipy import stats
import os
from sklearn.preprocessing import StandardScaler, LabelEncoder, RobustScaler
from sklearn.metrics import accuracy_score, classification_report, f1_score
from sklearn.feature_selection import SelectKBest, f_classif
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings('ignore')

# Add new imports for feature extraction
from sklearn.linear_model import LinearRegression
from scipy.signal import spectrogram
from pywt import wavedec


def load_index_csvs_MI(base_path):
    """Read the index CSVs and integer-encode the motor-imagery (MI) labels.

    A ``LabelEncoder`` is fitted on the MI rows of ``train.csv`` only and is
    then used to overwrite, in place, the ``label`` values of the MI rows in
    the train and validation frames. The test frame is returned as read.

    Args:
        base_path: Directory containing ``train.csv``, ``validation.csv``
            and ``test.csv``.

    Returns:
        ``(train_df, validation_df, test_df, le_mi)``. ``le_mi`` stays
        unfitted when ``train.csv`` has no ``label`` column or no MI rows.
    """
    train_df, validation_df, test_df = (
        pd.read_csv(os.path.join(base_path, csv_name))
        for csv_name in ('train.csv', 'validation.csv', 'test.csv')
    )
    le_mi = LabelEncoder()

    # Nothing to encode without a label column.
    if 'label' not in train_df.columns:
        return train_df, validation_df, test_df, le_mi

    # The encoder only ever sees training labels, so validation rows are
    # mapped with exactly the same class -> integer assignment.
    mi_train_labels = train_df.loc[train_df['task'] == 'MI', 'label']
    if len(mi_train_labels) == 0:
        return train_df, validation_df, test_df, le_mi
    le_mi.fit(mi_train_labels)

    for frame in (train_df, validation_df):
        if 'label' not in frame.columns:
            continue
        is_mi = frame['task'] == 'MI'
        if not is_mi.any():
            continue
        frame.loc[is_mi, 'label'] = le_mi.transform(frame.loc[is_mi, 'label'])

    return train_df, validation_df, test_df, le_mi
# Configuration
# Column names of the EEG channels that EEGPreprocessor keeps from each recording.
SELECTED_CHANNELS = ['FZ', 'C3', 'CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
SAMPLING_RATE = 250  # Hz (adjust if different)
TRIAL_LENGTH = 2250  # Samples per epoch (9 s at 250 Hz)

class EEGPreprocessor(BaseEstimator, TransformerMixin):
    """Select channels, band-pass filter, epoch and z-score EEG recordings.

    Parameters:
        filter_low: Lower pass-band edge in Hz.
        filter_high: Upper pass-band edge in Hz.
    """

    def __init__(self, filter_low=8, filter_high=30):
        self.filter_low = filter_low
        self.filter_high = filter_high
        self.channel_indices = None

    def fit(self, X, y=None):
        """Nothing is learned; present for pipeline compatibility."""
        return self

    def transform(self, X):
        """Turn recordings into a stack of normalized epochs.

        X: iterable of DataFrames holding the ``SELECTED_CHANNELS`` columns.
        Returns: array of epochs, shape (n_trials, TRIAL_LENGTH, n_channels),
        i.e. (n_trials, 2250, 8) with the module defaults.
        """
        epochs = [
            epoch
            for recording in X
            for epoch in self._iter_epochs(recording)
        ]
        return np.array(epochs)

    def _iter_epochs(self, df):
        """Yield the z-scored, band-passed epochs of one recording."""
        samples = df[SELECTED_CHANNELS].values

        # 5th-order Butterworth band-pass, applied forward and backward
        # (filtfilt) along the time axis.
        nyquist = 0.5 * SAMPLING_RATE
        band = [self.filter_low / nyquist, self.filter_high / nyquist]
        b, a = butter(5, band, btype='band')
        filtered = filtfilt(b, a, samples, axis=0)

        # Only complete TRIAL_LENGTH windows are kept; a trailing partial
        # window is dropped.
        for trial_no in range(len(df) // TRIAL_LENGTH):
            first = trial_no * TRIAL_LENGTH
            window = filtered[first:first + TRIAL_LENGTH]
            # Per-channel z-score; the epsilon guards flat channels.
            yield (window - window.mean(axis=0)) / (window.std(axis=0) + 1e-8)

# Feature Extraction Options
class CSPFeatures(BaseEstimator, TransformerMixin):
    """Common Spatial Pattern features computed with ``mne.decoding.CSP``.

    Parameters:
        n_components: Number of CSP components to keep.

    ``n_components`` is stored on the instance so that the scikit-learn
    parameter machinery (``get_params``, ``set_params``, ``clone``, ``repr``)
    works; it reads every ``__init__`` argument back from an attribute of
    the same name.
    """

    def __init__(self, n_components=4):
        self.n_components = n_components
        # Kept for callers that access ``.csp`` before fitting.
        self.csp = CSP(n_components=n_components, reg=None, log=True)

    def fit(self, X, y):
        """Fit CSP on epochs ``X`` of shape (n_trials, time_points, channels)."""
        # Rebuild from the current parameter so that a value changed through
        # set_params() (e.g. by a grid search) is honoured.
        self.csp = CSP(n_components=self.n_components, reg=None, log=True)
        X_csp = X.transpose(0, 2, 1)  # MNE expects (trials, channels, time)
        self.csp.fit(X_csp, y)
        return self

    def transform(self, X):
        """Return the fitted CSP's features for epochs ``X`` (same layout as in fit)."""
        X_csp = X.transpose(0, 2, 1)
        return self.csp.transform(X_csp)

class STFTFeatures(BaseEstimator, TransformerMixin):
    """Mean alpha- and beta-band STFT power per channel.

    Parameters:
        nperseg: STFT segment length in samples.
        noverlap: Overlap between consecutive segments in samples.
    """

    def __init__(self, nperseg=250, noverlap=125):
        self.nperseg = nperseg
        self.noverlap = noverlap

    def fit(self, X, y=None):
        """Nothing is learned; present for pipeline compatibility."""
        return self

    def transform(self, X):
        """Map epochs (n_trials, time, channels) to (n_trials, 2 * channels).

        Features are ordered channel by channel as ``[alpha, beta]`` pairs.
        """
        return np.array([self._band_powers(trial) for trial in X])

    def _band_powers(self, trial):
        """Return the flat ``[alpha, beta, alpha, beta, ...]`` list for one trial."""
        powers = []
        for ch in range(trial.shape[1]):
            freqs, _times, spectrum = stft(
                trial[:, ch],
                fs=SAMPLING_RATE,
                nperseg=self.nperseg,
                noverlap=self.noverlap,
            )
            power = np.abs(spectrum) ** 2

            # Rows of ``power`` are frequency bins; average each band over
            # both its bins and all time frames.
            in_alpha = (freqs >= 8) & (freqs < 12)    # 8-12 Hz
            in_beta = (freqs >= 12) & (freqs <= 30)   # 12-30 Hz
            powers += [power[in_alpha].mean(), power[in_beta].mean()]
        return powers

class RawEEGFeatures(BaseEstimator, TransformerMixin):
    """Pass-through feature extractor that flattens each trial to one vector."""

    def fit(self, X, y=None):
        """Nothing is learned; present for pipeline compatibility."""
        return self

    def transform(self, X):
        """Reshape ``X`` to (n_trials, -1), keeping the first axis."""
        n_trials = X.shape[0]
        flat = X.reshape(n_trials, -1)
        return flat
    
    
class AutoRegressionFeatures(BaseEstimator, TransformerMixin):
    """Per-channel lagged regression slopes for lags ``1..order``.

    Each feature is the slope of a simple linear regression of the signal on
    a copy of itself delayed by ``lag`` samples. The lags are fitted
    independently; this is not a joint AR(p) model fit.

    Parameters:
        order: Largest lag to include.
    """

    def __init__(self, order=5):
        self.order = order

    def fit(self, X, y=None):
        """Nothing is learned; present for pipeline compatibility."""
        return self

    def transform(self, X):
        """Map epochs (n_trials, time, channels) to (n_trials, channels * order)."""
        lags = range(1, self.order + 1)
        rows = []
        for trial in X:
            row = []
            for ch in range(trial.shape[1]):
                series = trial[:, ch]
                # Regress x[t] on x[t - lag]: predictor is the series without
                # its last ``lag`` samples, response the series without its
                # first ``lag`` samples.
                row.extend(
                    linregress(series[:-lag], series[lag:])[0] for lag in lags
                )
            rows.append(row)
        return np.array(rows)

class ICAFeatures(BaseEstimator, TransformerMixin):
    """Flattened per-trial ICA component maps computed with MNE.

    Parameters:
        n_components: Number of independent components to estimate per trial.
        random_state: Seed forwarded to the MNE ICA. The default ``None``
            keeps the decomposition unseeded, so repeated calls may differ.

    NOTE(review): ICA components have no canonical order or sign, so the
    flattened maps are not directly comparable across trials - confirm this
    is acceptable for the downstream classifier.
    """

    def __init__(self, n_components=4, random_state=None):
        self.n_components = n_components
        self.random_state = random_state

    def fit(self, X, y=None):
        """Nothing is learned; an ICA is fitted per trial inside ``transform``."""
        return self

    def transform(self, X):
        """Map epochs (n_trials, time, channels) to (n_trials, n_components * channels)."""
        # Local import: the file only imports names from mne submodules.
        import mne

        features = []
        for trial in X:
            n_channels = trial.shape[1]

            # mne's ICA.fit() accepts Raw/Epochs objects, not bare arrays,
            # so wrap the (channels, time) data in a RawArray first.
            info = mne.create_info(
                ch_names=[f'EEG{idx:03d}' for idx in range(n_channels)],
                sfreq=SAMPLING_RATE,
                ch_types='eeg',
            )
            raw = mne.io.RawArray(trial.T, info, verbose=False)

            ica = mne_ICA(
                n_components=self.n_components,
                method='fastica',
                random_state=self.random_state,
            )
            ica.fit(raw, verbose=False)

            # get_components() is (channels, n_components); transpose so each
            # row is one component's channel weights.
            components = ica.get_components().T

            # Flatten components as features
            features.append(components.flatten())
        return np.array(features)

def higuchi_fd(x, kmax=10):
    """Compute the Higuchi fractal dimension of a 1-D time series.

    For every delay ``k`` in ``1..kmax`` the series is split into ``k``
    sub-series ``x[m], x[m + k], x[m + 2k], ...``. The normalized curve
    length of each sub-series is averaged into ``L(k)``, and the dimension
    is the slope of ``log L(k)`` against ``log(1 / k)``.

    Args:
        x: Sequence of samples.
        kmax: Largest delay to evaluate.

    Returns:
        The estimated dimension as a float (1.0 for a straight line, close
        to 2.0 for white noise). ``nan`` is returned when fewer than two
        delays give a positive curve length (e.g. a constant or very short
        signal), because no slope can be fitted then.
    """
    x = np.asarray(x, dtype=float)
    n = len(x)

    log_inv_k = []
    log_length = []
    for k in range(1, kmax + 1):
        lengths = []
        for m in range(k):
            sub = x[m::k]
            n_steps = len(sub) - 1  # equals floor((n - 1 - m) / k)
            if n_steps < 1:
                # Sub-series too short to contain a single increment.
                continue
            # Sum over *all* n_steps increments, then apply Higuchi's
            # normalization (n - 1) / (n_steps * k) and the 1 / k scaling.
            path = np.abs(np.diff(sub)).sum()
            lengths.append(path * (n - 1) / (n_steps * k) / k)

        if not lengths:
            continue
        mean_length = np.mean(lengths)
        if mean_length > 0:  # log(0) would poison the fit
            log_inv_k.append(-np.log(k))
            log_length.append(np.log(mean_length))

    if len(log_length) < 2:
        return np.nan

    # Regress against log(1 / k) so the slope is the (positive) dimension.
    return np.polyfit(log_inv_k, log_length, 1)[0]

class HiguchiFDFeatures(BaseEstimator, TransformerMixin):
    """One ``higuchi_fd`` value per channel of every trial.

    Parameters:
        kmax: Largest delay passed on to ``higuchi_fd``.
    """

    def __init__(self, kmax=10):
        self.kmax = kmax

    def fit(self, X, y=None):
        """Nothing is learned; present for pipeline compatibility."""
        return self

    def transform(self, X):
        """Map epochs (n_trials, time, channels) to (n_trials, channels)."""
        return np.array([
            [higuchi_fd(trial[:, ch], self.kmax) for ch in range(trial.shape[1])]
            for trial in X
        ])

class FBCSPFeatures(BaseEstimator, TransformerMixin):
    """Filter-bank CSP: one ``mne.decoding.CSP`` per frequency band.

    Parameters:
        n_components: Number of CSP components kept per band.
        freq_bands: List of ``(low_hz, high_hz)`` tuples. ``None`` selects
            the defaults ``[(8, 12), (12, 16), (16, 24), (24, 30)]``.

    Attributes:
        csp_models: List of ``(low, high, fitted_csp)`` tuples, filled by ``fit``.
    """

    def __init__(self, n_components=4, freq_bands=None):
        self.n_components = n_components
        if freq_bands is None:
            # Default frequency bands for motor imagery
            self.freq_bands = [(8, 12), (12, 16), (16, 24), (24, 30)]
        else:
            self.freq_bands = freq_bands
        self.csp_models = []

    def fit(self, X, y):
        """Fit one CSP per band on epochs ``X`` of shape (trials, time, channels)."""
        self.csp_models = []
        for low, high in self.freq_bands:
            filtered = self._bandpass_filter(X, low, high)

            csp = CSP(n_components=self.n_components, reg=None, log=True)
            # MNE expects (trials, channels, time)
            csp.fit(filtered.transpose(0, 2, 1), y)
            self.csp_models.append((low, high, csp))

        return self

    def transform(self, X):
        """Return the per-band CSP features concatenated along axis 1."""
        features = [
            csp.transform(self._bandpass_filter(X, low, high).transpose(0, 2, 1))
            for low, high, csp in self.csp_models
        ]
        return np.concatenate(features, axis=1)

    def _bandpass_filter(self, X, low, high):
        """Zero-phase 5th-order Butterworth band-pass along the time axis.

        All trials and channels are filtered in a single call. The result is
        the floating-point array produced by ``filtfilt``, so integer input
        is no longer truncated back to its own dtype.
        """
        nyquist = 0.5 * SAMPLING_RATE
        b, a = butter(5, [low / nyquist, high / nyquist], btype='band')
        # Axis 1 is time in the (trials, time, channels) layout.
        return filtfilt(b, a, np.asarray(X), axis=1)

class FBRFeatures(BaseEstimator, TransformerMixin):
    """Filter-bank covariance / tangent-space features.

    For each frequency band the epochs are band-passed, converted to
    covariance matrices and projected with a tangent-space mapping; the
    per-band projections are concatenated.

    NOTE(review): ``Covariances`` and ``TangentSpace`` are not imported in
    the visible part of this file (presumably from pyriemann) - confirm they
    are in scope before using this class.

    Parameters:
        freq_bands: List of ``(low_hz, high_hz)`` tuples. ``None`` selects
            the defaults ``[(8, 12), (12, 16), (16, 24), (24, 30)]``.
        metric: Passed to ``TangentSpace(metric=...)``.
    """

    def __init__(self, freq_bands=None, metric='riemann'):
        if freq_bands is not None:
            self.freq_bands = freq_bands
        else:
            # Default frequency bands for motor imagery
            self.freq_bands = [(8, 12), (12, 16), (16, 24), (24, 30)]
        self.metric = metric
        self.cov_estimators = []
        self.ts_transformers = []

    def fit(self, X, y):
        """Fit a covariance estimator and tangent-space map for every band."""
        self.cov_estimators = []
        self.ts_transformers = []

        for band in self.freq_bands:
            low, high = band
            # (trials, time, channels) -> (trials, channels, time)
            band_epochs = self._bandpass_filter(X, low, high).transpose(0, 2, 1)

            cov = Covariances(estimator='lwf').fit(band_epochs)
            band_covs = cov.transform(band_epochs)
            ts = TangentSpace(metric=self.metric).fit(band_covs)

            self.cov_estimators.append((low, high, cov))
            self.ts_transformers.append((low, high, ts))

        return self

    def transform(self, X):
        """Return the tangent-space features of all bands, concatenated on axis 1."""
        per_band = []
        for cov_entry, ts_entry in zip(self.cov_estimators, self.ts_transformers):
            low, high, cov = cov_entry
            ts = ts_entry[2]
            band_epochs = self._bandpass_filter(X, low, high).transpose(0, 2, 1)
            per_band.append(ts.transform(cov.transform(band_epochs)))

        return np.concatenate(per_band, axis=1)

    def _bandpass_filter(self, X, low, high):
        """Zero-phase 5th-order Butterworth band-pass of every trial/channel."""
        nyquist = 0.5 * SAMPLING_RATE
        b, a = butter(5, [low / nyquist, high / nyquist], btype='band')

        # Output keeps X's shape and dtype.
        out = np.zeros_like(X)
        n_trials, n_channels = X.shape[0], X.shape[2]
        for trial_idx in range(n_trials):
            for ch_idx in range(n_channels):
                out[trial_idx, :, ch_idx] = filtfilt(b, a, X[trial_idx, :, ch_idx])

        return out
    
class MatchingPursuit(BaseEstimator, TransformerMixin):
    """Greedy matching-pursuit decomposition over a fixed atom dictionary.

    Parameters:
        n_atoms: Maximum number of pursuit iterations per signal.
        max_iter: Stored but not used by the current implementation.
        epsilon: Pursuit stops once the residual norm drops below this value.

    Attributes:
        dictionary_: Array of unit-norm atoms, one per row, built by ``fit``.
    """

    def __init__(self, n_atoms=20, max_iter=100, epsilon=0.1):
        self.n_atoms = n_atoms
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.dictionary_ = None

    def _create_dictionary(self, signal_length):
        """Build the dictionary of Gaussian-pulse and linear-chirp atoms.

        The time axis is normalized to [0, 1], so atom frequencies are in
        cycles per signal length rather than Hz.
        """
        t = np.linspace(0, 1, signal_length)

        def unit_norm(wave):
            return wave / np.linalg.norm(wave)

        # 5 bandwidths x 10 centre frequencies x 10 positions of windowed
        # sinusoids ...
        atoms = [
            unit_norm(gausspulse(t - t[position], fc=freq, bw=scale))
            for scale in np.linspace(0.1, 1.0, 5)
            for freq in np.linspace(1, 30, 10)
            for position in np.linspace(0, signal_length - 1, 10, dtype=int)
        ]
        # ... followed by 3 x 3 linear chirps.
        atoms += [
            unit_norm(chirp(t, f0=f0, f1=f1, t1=1, method='linear'))
            for f0 in [1, 5, 10]
            for f1 in [20, 30, 40]
        ]

        self.dictionary_ = np.array(atoms)
        return self

    def fit(self, X, y=None):
        """Build the dictionary for the signal length found in ``X``."""
        if len(X.shape) > 1:
            signal_length = X.shape[1]
        else:
            signal_length = len(X)
        self._create_dictionary(signal_length)
        return self

    def transform(self, X):
        """Decompose each signal and return a 20-value feature row per signal.

        The row holds ``[center, freq, scale, coeff]`` for the first five
        atoms selected by the pursuit, zero-padded when fewer were selected.
        """
        if self.dictionary_ is None:
            self.fit(X)

        n_dict = len(self.dictionary_)
        rows = []
        for signal in X:
            residual = signal.copy()
            weights = np.zeros(n_dict)
            picked = []

            for _ in range(self.n_atoms):
                if np.linalg.norm(residual) < self.epsilon:
                    break

                # Atom with the largest absolute inner product wins.
                winner = np.argmax(np.abs(np.dot(self.dictionary_, residual)))
                winner_atom = self.dictionary_[winner]

                # Accumulate its projection and remove it from the residual.
                projection = np.dot(winner_atom, residual)
                weights[winner] += projection
                residual -= projection * winner_atom

                picked.append(winner)

            if not picked:
                rows.append(np.zeros(20))
                continue

            row = []
            for idx in picked[:5]:
                atom = self.dictionary_[idx]
                half_spectrum = np.fft.fft(atom)[:len(atom) // 2]
                row.extend([
                    np.argmax(np.abs(atom)),            # peak sample index
                    np.argmax(np.abs(half_spectrum)),   # dominant FFT bin
                    np.std(atom),                       # spread of the atom
                    weights[idx],                       # accumulated coefficient
                ])

            # Fixed-size output: pad to 5 atoms * 4 features = 20 values.
            shortfall = 20 - len(row)
            if shortfall > 0:
                row.extend([0] * shortfall)
            rows.append(row[:20])

        return np.array(rows)
# ============================
# Step 4: Enhanced Deep Learning Models 
# ============================


# ------------------------
# 1) Define your PyTorch deep‐models
# ------------------------
class EEGNet(nn.Module):
    """Minimal EEGNet-style classifier.

    Args:
        channels: Number of EEG channels (height of the input image).
        samples: Number of time samples per trial (width of the input image).
        num_classes: Number of output logits.

    Input to ``forward`` is ``(batch, 1, channels, samples)``.
    """

    def __init__(self, channels, samples, num_classes):
        super().__init__()
        # minimal EEGNet skeleton - adapt kernel sizes & paddings to your data
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 64), padding=(0, 32))
        self.bn1   = nn.BatchNorm2d(16)
        self.depthwise = nn.Conv2d(16, 32, kernel_size=(channels, 1), groups=16)
        self.bn2       = nn.BatchNorm2d(32)
        self.pool       = nn.AvgPool2d(kernel_size=(1, 4))
        self.dropout    = nn.Dropout(0.25)
        # classifier
        # conv1 (kernel 64, padding 32 on both sides) emits samples + 1 time
        # steps, the depthwise conv collapses the channel axis to 1, and the
        # pooling floors the width by 4. Sizing with samples // 4 instead is
        # wrong whenever samples % 4 == 3.
        pooled_width = (samples + 2 * 32 - 64 + 1) // 4
        self.classifier = nn.Linear(32 * pooled_width, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        # x: (batch, 1, channels, samples)
        x = F.elu(self.bn1(self.conv1(x)))
        x = F.elu(self.bn2(self.depthwise(x)))
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

class DeepConvNet(nn.Module):
    """Minimal DeepConvNet-style classifier with four conv/pool stages.

    Args:
        channels: Number of EEG channels (height of the input image).
        samples: Number of time samples per trial (width of the input image).
        num_classes: Number of output logits.

    Input to ``forward`` is ``(batch, 1, channels, samples)``.
    """

    def __init__(self, channels, samples, num_classes):
        super().__init__()
        # minimal DeepConvNet skeleton
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(1, 25, (1, 5)),
            nn.Conv2d(25, 25, (channels, 1)),
            nn.BatchNorm2d(25),
            nn.ELU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.5),

            nn.Conv2d(25, 50, (1, 5)),
            nn.BatchNorm2d(50),
            nn.ELU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.5),

            nn.Conv2d(50, 100, (1, 5)),
            nn.BatchNorm2d(100),
            nn.ELU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.5),

            nn.Conv2d(100, 200, (1, 5)),
            nn.BatchNorm2d(200),
            nn.ELU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.5),
        )

        # Probe the flattened feature size with a dummy forward pass. The
        # probe runs in eval mode: in training mode the all-zero batch would
        # be folded into the BatchNorm running statistics before any real
        # data has been seen.
        was_training = self.conv_blocks.training
        self.conv_blocks.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 1, channels, samples)
            out_feat = self.conv_blocks(probe).view(1, -1).shape[1]
        self.conv_blocks.train(was_training)

        self.classifier = nn.Linear(out_feat, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.conv_blocks(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

class BiLSTM(nn.Module):
    """(Bi)directional LSTM classifier reading the last time-step output.

    Args:
        input_size: Feature dimension of each time step.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        bidirectional: Whether to run the LSTM in both directions.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, bidirectional=True):
        super().__init__()
        # nn.LSTM only applies dropout *between* stacked layers; passing a
        # non-zero value with a single layer has no effect and makes PyTorch
        # emit a UserWarning, so it is disabled in that case.
        inter_layer_dropout = 0.3 if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=bidirectional,
                            dropout=inter_layer_dropout)
        factor = 2 if bidirectional else 1
        self.classifier = nn.Linear(hidden_size * factor, num_classes)

    def forward(self, x):
        """Return class logits for ``x`` of shape (batch, seq_len, input_size)."""
        out, _ = self.lstm(x)
        # take last time-step
        # NOTE(review): for the backward direction the output at the last
        # time step has only processed that single step - confirm this is
        # the intended summary of the sequence.
        out = out[:, -1, :]
        return self.classifier(out)


class CNNLSTM(nn.Module):
    """CNN front-end + 2-layer BiLSTM + attention pooling classifier.

    Args:
        n_channels: Number of EEG channels (height collapsed by the
            depthwise convolution).
        n_classes: Number of output logits.
        n_samples: Accepted for interface compatibility; not used when
            building the layers.
        dropout_rate: Dropout probability (halved after the first conv).

    Input to ``forward`` is ``(batch, n_channels, time)`` or
    ``(batch, 1, n_channels, time)``.
    """

    def __init__(self, n_channels=8, n_classes=2, n_samples=1000, dropout_rate=0.5):
        super(CNNLSTM, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(1, 16), padding=(0, 8), padding_mode='zeros'),
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.Dropout(dropout_rate/2),

            nn.Conv2d(32, 32, kernel_size=(n_channels, 1), groups=32),
            nn.Conv2d(32, 64, kernel_size=1),
            nn.BatchNorm2d(64),
            nn.ELU(),

            nn.Conv2d(64, 64, kernel_size=(1, 8), padding=(0, 4)),
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 4)),
            nn.Dropout(dropout_rate),

            nn.AvgPool2d(kernel_size=(1, 2)),
        )

        self.lstm = nn.LSTM(
            input_size=64,
            hidden_size=128,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            # Two stacked layers, so inter-layer dropout is always active.
            dropout=dropout_rate
        )

        # 256 = 2 directions * hidden_size 128
        self.attention = nn.Sequential(
            nn.Linear(256, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
            nn.Softmax(dim=1))

        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, n_classes))

        self._init_weights()

    def _init_weights(self):
        """Initialize conv, linear and LSTM parameters."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # torch.nn.init has no gain entry for 'elu' (it raises
                # ValueError), so the ReLU gain is used for the ELU layers.
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if 'weight_ih' in name:
                        nn.init.xavier_normal_(param.data)
                    elif 'weight_hh' in name:
                        nn.init.orthogonal_(param.data)
                    elif 'bias' in name:
                        param.data.fill_(0)
                        # PyTorch stores LSTM biases as four contiguous
                        # blocks [input | forget | cell | output]; set the
                        # forget-gate block (second quarter) to 1.
                        gate = param.size(0) // 4
                        param.data[gate:2 * gate].fill_(1)

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)``."""
        if x.dim() == 3:
            x = x.unsqueeze(1)  # add the singleton image-channel axis

        x = self.cnn(x)
        # (batch, 64, 1, T) -> (batch, T, 64) for the batch-first LSTM
        x = x.squeeze(2).permute(0, 2, 1)
        lstm_out, _ = self.lstm(x)
        # Softmax over the time axis gives one weight per step.
        attention_weights = self.attention(lstm_out)
        context_vector = torch.sum(attention_weights * lstm_out, dim=1)
        return self.classifier(context_vector)

# ------------------------
# 2) The unified trainer
# ------------------------
class ModelTrainer:
    """Uniform fit/predict wrapper around sklearn and PyTorch classifiers."""

    # Model types handled through the scikit-learn API.
    _SKLEARN_MODELS = ('LDA', 'SVM', 'RF')

    def __init__(self,
                 model_type: str,
                 input_shape,
                 num_classes: int,
                 device: str = None,
                 **kwargs):
        """
        model_type: one of ['LDA','SVM','RF','EEGNet','DeepConvNet','BiLSTM']
        input_shape: for traditional: (n_features,). For EEGNet/DeepConvNet: (channels, samples).
                     For BiLSTM: (seq_len, feature_dim).
        num_classes: number of classes.
        device: torch device string; defaults to 'cuda' when available, else 'cpu'.
        kwargs: extra hyperparams for each model type (forwarded to the
                sklearn constructors; 'hidden_size' / 'num_layers' for BiLSTM).

        Raises ValueError for an unknown model_type.
        """
        self.model_type = model_type
        self.num_classes = num_classes

        # device for pytorch
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # init model
        if model_type == 'LDA':
            self.model = LinearDiscriminantAnalysis(**kwargs)
        elif model_type == 'SVM':
            self.model = SVC(probability=True, **kwargs)
        elif model_type == 'RF':
            self.model = RandomForestClassifier(**kwargs)
        elif model_type == 'EEGNet':
            ch, samp = input_shape
            self.model = EEGNet(ch, samp, num_classes).to(self.device)
        elif model_type == 'DeepConvNet':
            ch, samp = input_shape
            self.model = DeepConvNet(ch, samp, num_classes).to(self.device)
        elif model_type == 'BiLSTM':
            seq_len, feat = input_shape
            hidden = kwargs.get('hidden_size', 64)
            layers = kwargs.get('num_layers', 1)
            self.model = BiLSTM(feat, hidden, layers, num_classes).to(self.device)
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

        # only used for sklearn models
        self.le = LabelEncoder() if model_type in self._SKLEARN_MODELS else None

    def fit(self, X_train, y_train, X_val=None, y_val=None, **train_kwargs):
        """
        Train the wrapped model.

        X_train, X_val: numpy arrays or PyTorch tensors
        y_train, y_val: labels (integer class indices for the PyTorch models)
        train_kwargs (PyTorch models only): 'lr' (default 1e-3),
            'batch_size' (default 32), 'epochs' (default 20).
            A 'scheduler' entry is accepted but currently ignored.

        Validation runs only when both X_val and y_val are given. For the
        PyTorch models the best-scoring epoch is written to
        best_<model_type>.pt and restored into the model after training.
        """
        has_val = X_val is not None and y_val is not None
        if self.model_type in self._SKLEARN_MODELS:
            self._fit_sklearn(X_train, y_train, X_val if has_val else None, y_val)
        else:
            self._fit_torch(X_train, y_train,
                            X_val if has_val else None, y_val, train_kwargs)

    def _fit_sklearn(self, X_train, y_train, X_val, y_val):
        """Fit the sklearn estimator on encoded labels and report validation accuracy."""
        y_enc = self.le.fit_transform(y_train)
        self.model.fit(X_train, y_enc)
        if X_val is not None:
            val_pred = self.model.predict(X_val)
            acc = accuracy_score(self.le.transform(y_val), val_pred)
            print(f"Validation accuracy: {acc:.4f}")

    def _make_loader(self, X, y, batch_size, shuffle):
        """Wrap features/labels in a DataLoader of float32 / int64 tensors."""
        dataset = torch.utils.data.TensorDataset(
            torch.as_tensor(X, dtype=torch.float32),
            torch.as_tensor(y, dtype=torch.long))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    def _evaluate(self, loader):
        """Return the accuracy of the model on ``loader`` (0.0 if it is empty)."""
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for xb, yb in loader:
                preds = self.model(xb.to(self.device)).argmax(1).cpu()
                correct += int((preds == yb).sum())
                total += len(yb)
        return correct / total if total else 0.0

    def _fit_torch(self, X_train, y_train, X_val, y_val, train_kwargs):
        """Adam / cross-entropy training loop with best-epoch restoration."""
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=train_kwargs.get('lr', 1e-3))

        bs = train_kwargs.get('batch_size', 32)
        train_loader = self._make_loader(X_train, y_train, bs, shuffle=True)
        val_loader = None
        if X_val is not None:
            val_loader = self._make_loader(X_val, y_val, bs, shuffle=False)

        checkpoint_path = f"best_{self.model_type}.pt"
        # Start below any reachable accuracy so the first validated epoch is
        # always recorded; starting at 0 with a strict '>' left no checkpoint
        # to restore when accuracy never rose above 0.
        best_acc = -1.0
        best_state = None
        for epoch in range(train_kwargs.get('epochs', 20)):
            self.model.train()
            for xb, yb in train_loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                optimizer.zero_grad()
                loss = criterion(self.model(xb), yb)
                loss.backward()
                optimizer.step()

            if val_loader is None:
                continue
            acc = self._evaluate(val_loader)
            print(f"[{self.model_type}] Epoch {epoch+1} val acc: {acc:.4f}")
            if acc > best_acc:
                best_acc = acc
                # Keep an in-memory copy for the restore below, and the
                # on-disk checkpoint for later use.
                best_state = {name: tensor.detach().clone()
                              for name, tensor in self.model.state_dict().items()}
                torch.save(self.model.state_dict(), checkpoint_path)

        # restore best
        if best_state is not None:
            self.model.load_state_dict(best_state)
            print(f"Best val acc for {self.model_type}: {best_acc:.4f}")

    def predict(self, X_test):
        """
        Return (preds, probs) for X_test.

        sklearn models: preds are in the original label space (decoded with
        the fitted LabelEncoder). PyTorch models: preds are class indices.
        probs has one column per class.
        """
        if self.model_type in self._SKLEARN_MODELS:
            proba = self.model.predict_proba(X_test)
            preds = self.le.inverse_transform(proba.argmax(1))
            return preds, proba

        self.model.eval()
        with torch.no_grad():
            X = torch.as_tensor(X_test, dtype=torch.float32).to(self.device)
            logits = self.model(X)
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)
        return preds, probs


import os
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest, f_classif
from torch.utils.data import DataLoader

# assume EEGPreprocessor, CSPFeatures, STFTFeatures, RawEEGFeatures,
# and ModelTrainer are already imported

def load_raw_eeg(df, base_path):
    """Load the EEG samples of every trial listed in an index frame.

    Each row of ``df`` names a task, subject and session; the matching
    ``EEGdata.csv`` is read and the row's trial is sliced out of it. Every
    session file is read only once, even when several rows refer to it.

    Args:
        df: Index DataFrame with columns ``id``, ``task``, ``subject_id``,
            ``trial_session`` and ``trial`` (1-based trial number).
        base_path: Root directory of the dataset.

    Returns:
        List with one DataFrame slice per row of ``df``, in row order. MI
        trials span 2250 samples, all other tasks 1750.
    """
    session_cache = {}  # file path -> full session recording
    raws = []
    for _, row in df.iterrows():
        task = row['task']

        # The split directory is derived from the row id ranges.
        if row['id'] <= 4800:
            split = 'train'
        elif row['id'] <= 4900:
            split = 'validation'
        else:
            split = 'test'
        fpath = os.path.join(base_path, task, split, row['subject_id'],
                             str(row['trial_session']), 'EEGdata.csv')

        if fpath not in session_cache:
            session_cache[fpath] = pd.read_csv(fpath)
        eeg = session_cache[fpath]

        # Extract correct trial slice
        samples_per_trial = 2250 if task == 'MI' else 1750
        start = (int(row['trial']) - 1) * samples_per_trial
        raws.append(eeg.iloc[start:start + samples_per_trial])
    return raws


# Model families understood by main_MI.
_MI_SKLEARN_MODELS = ('LDA', 'SVM', 'RF')
_MI_DEEP_MODELS = ('EEGNet', 'DeepConvNet', 'BiLSTM')
# Every feature_type value main_MI accepts.
_MI_FEATURE_TYPES = ('CSP', 'STFT', 'RAW', 'AutoRegressionFeatures',
                     'HiguchiFDFeatures', 'ICAFeatures', 'FBCSP', 'FBR',
                     'MatchingPursuit')


def _extract_mi_features(feature_type, X_tr_ep, X_val_ep, X_te_ep, y_tr,
                         n_components, order, kmax):
    """Fit the requested extractor on the train epochs and transform all three splits."""
    if feature_type in ('CSP', 'ICAFeatures', 'FBCSP', 'FBR'):
        # Supervised extractors: fitting uses the training labels.
        if feature_type == 'FBCSP':
            feat_ext = FBCSPFeatures(n_components=n_components)
        elif feature_type == 'FBR':
            feat_ext = FBRFeatures()
        else:
            # NOTE(review): 'ICAFeatures' is served by CSPFeatures, exactly as
            # before; this looks like a copy-paste slip — confirm whether a
            # dedicated ICA extractor was intended.
            feat_ext = CSPFeatures(n_components=n_components)
        X_tr_ft = feat_ext.fit_transform(X_tr_ep, y_tr)

    elif feature_type == 'STFT':
        feat_ext = STFTFeatures()
        X_tr_ft = feat_ext.fit_transform(X_tr_ep)

    elif feature_type in ('RAW', 'AutoRegressionFeatures', 'HiguchiFDFeatures'):
        # These extractors are used without a fit step.
        if feature_type == 'RAW':
            feat_ext = RawEEGFeatures()
        elif feature_type == 'AutoRegressionFeatures':
            feat_ext = AutoRegressionFeatures(order=order)
        else:
            feat_ext = HiguchiFDFeatures(kmax=kmax)
        X_tr_ft = feat_ext.transform(X_tr_ep)

    elif feature_type == 'MatchingPursuit':
        feat_ext = MatchingPursuit(n_atoms=20)
        # Matching pursuit works on one flat vector per trial.
        X_tr_ep = X_tr_ep.reshape(X_tr_ep.shape[0], -1)
        X_val_ep = X_val_ep.reshape(X_val_ep.shape[0], -1)
        X_te_ep = X_te_ep.reshape(X_te_ep.shape[0], -1)
        X_tr_ft = feat_ext.fit_transform(X_tr_ep)

    else:
        raise ValueError(
            f"feature_type must be one of {_MI_FEATURE_TYPES}, got {feature_type!r}")

    X_val_ft = feat_ext.transform(X_val_ep)
    X_te_ft = feat_ext.transform(X_te_ep)
    return X_tr_ft, X_val_ft, X_te_ft


def main_MI(model_type='LDA',
            feature_type='CSP',
            n_components=4,
            k_best=100,
            trainer_kwargs=None,
            order=5,
            kmax=10,
            base_path='/kaggle/input/mtcaic3',
            output_path='vs_submission.csv'):
    """
    Train a motor-imagery (MI) classifier and write test-set predictions to CSV.

    model_type: 'LDA','SVM','RF','EEGNet','DeepConvNet','BiLSTM'
    feature_type: 'CSP','STFT','RAW','AutoRegressionFeatures',
                  'HiguchiFDFeatures','ICAFeatures','FBCSP','FBR',
                  'MatchingPursuit'
    n_components: passed to the CSP / FBCSP extractors
    k_best: upper bound on the number of features kept by SelectKBest
            (sklearn models only)
    trainer_kwargs: extra keyword arguments forwarded to ModelTrainer for the
                    deep models only
    order: model order for 'AutoRegressionFeatures'
    kmax: kmax for 'HiguchiFDFeatures'
    base_path: dataset root containing the index CSVs and the EEG recordings
    output_path: where the submission CSV (id, label, confidence) is written

    Raises ValueError for an unknown model_type or feature_type, before any
    data is loaded.
    """
    if model_type not in _MI_SKLEARN_MODELS + _MI_DEEP_MODELS:
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of "
            f"{_MI_SKLEARN_MODELS + _MI_DEEP_MODELS}")
    if feature_type not in _MI_FEATURE_TYPES:
        raise ValueError(
            f"feature_type must be one of {_MI_FEATURE_TYPES}, got {feature_type!r}")

    trainer_kwargs = trainer_kwargs or {}

    # 1) Load index CSVs + label encoders
    print("Loading index files...")
    train_df, val_df, test_df, le_mi = load_index_csvs_MI(base_path)
    # Keep only MI rows
    train_df = train_df[train_df['task'] == 'MI'].copy()
    val_df = val_df[val_df['task'] == 'MI'].copy()
    test_df = test_df[test_df['task'] == 'MI'].copy()

    # Ensure labels are integers
    train_df['label'] = train_df['label'].astype(int)
    val_df['label'] = val_df['label'].astype(int)

    num_classes = len(le_mi.classes_)
    print(f"  Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")

    # 2) Read raw EEG data
    X_tr_raw = load_raw_eeg(train_df, base_path)
    y_tr = train_df.label.values
    X_val_raw = load_raw_eeg(val_df, base_path)
    y_val = val_df.label.values
    X_te_raw = load_raw_eeg(test_df, base_path)

    # 3) Preprocess → epochs (fitted on the training split only)
    preproc = EEGPreprocessor()
    X_tr_ep = preproc.fit_transform(X_tr_raw)
    X_val_ep = preproc.transform(X_val_raw)
    X_te_ep = preproc.transform(X_te_raw)
    print(f"  Epochs: train={len(X_tr_ep)}, val={len(X_val_ep)}, test={len(X_te_ep)}")

    # 4) Feature extraction
    X_tr_ft, X_val_ft, X_te_ft = _extract_mi_features(
        feature_type, X_tr_ep, X_val_ep, X_te_ep, y_tr,
        n_components, order, kmax)

    # 5) Scale + SelectKBest (only for sklearn models)
    if model_type in _MI_SKLEARN_MODELS:
        scaler = StandardScaler()
        X_tr_sc = scaler.fit_transform(X_tr_ft)
        X_val_sc = scaler.transform(X_val_ft)
        X_te_sc = scaler.transform(X_te_ft)

        # Never ask for more features than exist.
        k = min(k_best, X_tr_sc.shape[1])
        selector = SelectKBest(f_classif, k=k)
        X_tr_sel = selector.fit_transform(X_tr_sc, y_tr)
        X_val_sel = selector.transform(X_val_sc)
        X_te_sel = selector.transform(X_te_sc)
        print(f"  Selected features: {X_tr_sel.shape[1]} / {X_tr_sc.shape[1]}")

        # 6a) Train sklearn model (once)
        trainer = ModelTrainer(model_type,
                               input_shape=(X_tr_sel.shape[1],),
                               num_classes=num_classes)
        trainer.fit(X_tr_sel, y_tr, X_val_sel, y_val)
        preds, probs = trainer.predict(X_te_sel)

    # 6b) Train deep‐learning model
    else:
        # shape for PyTorch nets
        if model_type in ('EEGNet', 'DeepConvNet'):
            # expects (batch,1,channels,time)
            # X_tr_ep: (n_trials, time, channels) → swap axes
            X_tr_in = X_tr_ep.transpose(0, 2, 1)[:, None, :, :]
            X_val_in = X_val_ep.transpose(0, 2, 1)[:, None, :, :]
            X_te_in = X_te_ep.transpose(0, 2, 1)[:, None, :, :]
            input_shape = (X_tr_ep.shape[2], X_tr_ep.shape[1])
        else:  # 'BiLSTM' (model_type was validated on entry)
            # NOTE(review): the original comment says the net expects
            # (batch, time, features), yet the RAW case flattens each epoch
            # to a single vector — confirm against the BiLSTM definition.
            if feature_type != 'RAW':
                X_tr_in, X_val_in, X_te_in = X_tr_ft, X_val_ft, X_te_ft
            else:
                X_tr_in = X_tr_ep.reshape(len(X_tr_ep), -1)
                X_val_in = X_val_ep.reshape(len(X_val_ep), -1)
                X_te_in = X_te_ep.reshape(len(X_te_ep), -1)
            input_shape = X_tr_in.shape[1:]

        # build trainer and dataloaders
        trainer = ModelTrainer(model_type,
                               input_shape=input_shape,
                               num_classes=num_classes,
                               **trainer_kwargs)

        # .fit expects numpy arrays
        trainer.fit(X_tr_in, y_tr, X_val_in, y_val)
        preds, probs = trainer.predict(X_te_in)

    # One row per test trial; confidence is the top class probability.
    out_df = pd.DataFrame({
        'id': test_df.id.values,
        'label': preds,
        'confidence': probs.max(axis=1)
    })

    # 7) save & sanity‐check
    out_df.to_csv(output_path, index=False)
    print(f"✅ Saved {output_path}")
    missing = set(test_df.id) - set(out_df.id)
    if missing:
        print(f"⚠ Missing IDs: {missing}")
    else:
        print("✓ All test IDs covered")

if __name__ == "__main__":
    # Example run: EEGNet with feature_type='FBCSP' and 4 components.
    # For EEGNet, main_MI feeds the preprocessed epochs to the network; the
    # extracted features are computed but not used as the network input.
    main_MI(
        model_type='EEGNet',
        feature_type='FBCSP',
        n_components=4,
    )

Loading index files...
  Train=2400, Val=50, Test=50
  Epochs: train=2400, val=50, test=50
[DeepConvNet] Epoch 1 val acc: 0.4800
[DeepConvNet] Epoch 2 val acc: 0.4600
[DeepConvNet] Epoch 3 val acc: 0.5000
[DeepConvNet] Epoch 4 val acc: 0.4800
[DeepConvNet] Epoch 5 val acc: 0.4000
[DeepConvNet] Epoch 6 val acc: 0.4000
[DeepConvNet] Epoch 7 val acc: 0.4600
[DeepConvNet] Epoch 8 val acc: 0.5000
[DeepConvNet] Epoch 9 val acc: 0.4000
[DeepConvNet] Epoch 10 val acc: 0.3800
[DeepConvNet] Epoch 11 val acc: 0.3800
[DeepConvNet] Epoch 12 val acc: 0.4000
[DeepConvNet] Epoch 13 val acc: 0.4000
[DeepConvNet] Epoch 14 val acc: 0.4800
[DeepConvNet] Epoch 15 val acc: 0.5400
[DeepConvNet] Epoch 16 val acc: 0.4400
[DeepConvNet] Epoch 17 val acc: 0.3800
[DeepConvNet] Epoch 18 val acc: 0.4200
[DeepConvNet] Epoch 19 val acc: 0.4200
[DeepConvNet] Epoch 20 val acc: 0.4000
Best val acc for DeepConvNet: 0.5400
✅ Saved submission.csv
✓ All test IDs covered
