# In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.signal import resample, correlate
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, cohen_kappa_score, matthews_corrcoef, log_loss, roc_auc_score, average_precision_score

# Load datasets
eeg_data = pd.read_csv('eeg_data_trimmed5.csv')
emg_data = pd.read_csv('emg_data_trimmed5.csv')
labels = pd.read_csv('labels_trimmed5.csv')

# Label encoding: map label values to integer class ids 0..K-1. All columns of
# the label table are flattened (presumably it has a single column).
label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(labels.values.ravel())

# Resample EEG to match EMG's length (if necessary). scipy's `resample` returns
# a bare ndarray, so wrap the result back into a DataFrame: the alignment step
# below reads `eeg_data.values`.
if len(eeg_data) != len(emg_data):
    eeg_data = pd.DataFrame(
        resample(eeg_data.values, num=len(emg_data), axis=0),
        columns=eeg_data.columns,
    )

# Cross-validation indexes `labels` with the same row indices as the combined
# EEG/EMG matrix, so there must be exactly one label per row.
if len(labels) != len(emg_data):
    raise ValueError(
        f"Expected one label per EMG row: got {len(labels)} labels "
        f"for {len(emg_data)} rows."
    )

# Sliding Window Cross-Correlation for Real-Time Alignment
def sliding_window_cross_correlation(eeg, emg, window_size=100, overlap=50, verbose=True):
    """Align EEG and EMG window by window using the cross-correlation peak.

    Windows of ``window_size`` rows start every ``window_size - overlap`` rows.
    In each window, a per-sample reference trace is built for each modality.
    For 2-D input this is the mean across channels. Both traces are z-scored,
    and the lag with the largest normalized cross-correlation is found. The
    trailing modality is then circularly shifted (``np.roll``) within the
    window so the traces line up. Where windows overlap, the later window
    overwrites the earlier one.

    Args:
        eeg: Array of shape (n_samples,) or (n_samples, n_eeg_channels).
        emg: Array with the same number of rows as ``eeg``.
        window_size: Number of rows per window.
        overlap: Rows shared by consecutive windows; must be < ``window_size``.
        verbose: If True, print each window's lag and peak correlation, and
            the average score.

    Returns:
        ``(eeg_aligned, emg_aligned, avg_sync_score)``. The first two are
        aligned copies of the inputs; rows not covered by any window are left
        unchanged. The third is the mean peak correlation over all windows,
        or NaN when no window fits.

    Raises:
        ValueError: If the inputs differ in length or ``overlap >= window_size``.
    """
    eeg = np.asarray(eeg)
    emg = np.asarray(emg)
    if len(eeg) != len(emg):
        raise ValueError(
            f"eeg and emg must have the same number of rows, got {len(eeg)} and {len(emg)}"
        )
    step = window_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than window_size")

    # Start from copies, not zeros, so rows after the last window keep their values.
    eeg_aligned = eeg.copy()
    emg_aligned = emg.copy()
    sync_scores = []  # peak normalized correlation of each window

    # `+ 1` so a final window that ends exactly at the last row is processed too.
    for i in range(0, len(eeg) - window_size + 1, step):
        eeg_block = eeg[i:i + window_size]
        emg_block = emg[i:i + window_size]

        # Use one value per sample (channel mean for 2-D input), so the lag is
        # measured in rows. A lag found on flattened (samples x channels) data
        # would be in element units and could not be applied as a row shift.
        eeg_trace = eeg_block.mean(axis=1) if eeg_block.ndim > 1 else eeg_block.astype(float)
        emg_trace = emg_block.mean(axis=1) if emg_block.ndim > 1 else emg_block.astype(float)

        eeg_std = np.std(eeg_trace)
        emg_std = np.std(emg_trace)
        if eeg_std == 0 or emg_std == 0:
            # A flat trace has no defined correlation: leave this window unshifted.
            lag, max_corr = 0, 0.0
        else:
            eeg_z = (eeg_trace - np.mean(eeg_trace)) / eeg_std
            emg_z = (emg_trace - np.mean(emg_trace)) / emg_std
            correlation = correlate(eeg_z, emg_z, mode='full')
            # scipy, mode='full': entry j is lag j - (len(emg_z) - 1), and lag k
            # pairs eeg_z[n + k] with emg_z[n].
            lags = np.arange(-(len(emg_z) - 1), len(eeg_z))
            correlation = correlation / (np.linalg.norm(eeg_z) * np.linalg.norm(emg_z))
            peak = int(np.argmax(correlation))
            lag = int(lags[peak])
            max_corr = float(correlation[peak])
        sync_scores.append(max_corr)

        if verbose:
            print(f"Window {i}: Lag = {lag}, Max Correlation = {max_corr:.4f}")

        if lag > 0:
            # EEG trails EMG by `lag` rows: shift EEG earlier.
            eeg_aligned[i:i + window_size] = np.roll(eeg_block, -lag, axis=0)
            emg_aligned[i:i + window_size] = emg_block
        else:
            # EMG trails EEG by `-lag` rows (or no lag): shift EMG earlier.
            eeg_aligned[i:i + window_size] = eeg_block
            emg_aligned[i:i + window_size] = np.roll(emg_block, lag, axis=0)

    avg_sync_score = float(np.mean(sync_scores)) if sync_scores else float('nan')
    if verbose:
        print(f"Average Synchronization Score: {avg_sync_score:.4f}")

    return eeg_aligned, emg_aligned, avg_sync_score


# Apply sliding window cross-correlation to the raw EEG/EMG sample matrices,
# keeping the average synchronization score for the metrics report.
eeg_aligned, emg_aligned, sync_score = sliding_window_cross_correlation(
    eeg_data.values,
    emg_data.values,
)

# Combine aligned data side by side: EEG columns first, then EMG columns.
aligned_frames = [pd.DataFrame(eeg_aligned), pd.DataFrame(emg_aligned)]
combined_data = pd.concat(aligned_frames, axis=1)

# Define a custom Dataset
class EEGEMGDataset(Dataset):
    """Map-style dataset over a feature matrix and its integer labels.

    ``X`` is copied into a float32 tensor and ``y`` into an int64 tensor.
    Item ``idx`` is the pair ``(X[idx], y[idx])``.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X, self.y = features, targets

    def __len__(self):
        """Return the number of samples (the length of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at position ``idx``."""
        features, label = self.X[idx], self.y[idx]
        return features, label

# Transformer model with real-time synchronization and cross-modality attention
class EEGEMGTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_classes, dropout_rate=0.5):
        super(EEGEMGTransformer, self).__init__()
        self.align_eeg = nn.Linear(input_dim, input_dim)
        self.align_emg = nn.Linear(input_dim, input_dim)
        
        self.eeg_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, batch_first=True, dropout=dropout_rate),
            num_layers=2)
        self.emg_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, batch_first=True, dropout=dropout_rate),
            num_layers=2)
        
        self.eeg_projector = nn.Linear(input_dim, hidden_dim)
        self.emg_projector = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        
        # Initialize cross-modality attention weights with higher weight for EEG
        self.cross_attention_weights = nn.Parameter(torch.tensor([[0.7], [0.3]]), requires_grad=True)
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)
        
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, eeg, emg):
        eeg = self.align_eeg(eeg)
        emg = self.align_emg(emg)
        eeg_features = self.eeg_encoder(eeg)
        emg_features = self.emg_encoder(emg)
        
        eeg_features = self.eeg_projector(eeg_features)
        emg_features = self.emg_projector(emg_features)
        
        eeg_features = self.dropout(eeg_features)
        emg_features = self.dropout(emg_features)
        
        # Weighted combination of EEG and EMG
        combined_features = (
            self.cross_attention_weights[0] * eeg_features + self.cross_attention_weights[1] * emg_features
        )
        combined, _ = self.cross_attention(combined_features, combined_features, combined_features)
        
        output = self.fc(combined.mean(dim=1))
        return output, self.cross_attention_weights


# Dynamic Learning Rate Adjustment
def online_adaptation_with_regularization(
    model, optimizer, buffer_X, buffer_y, criterion, val_loader, num_cycles=2, batch_size=8, lr=0.00001
):
    """Fine-tune ``model`` on new samples mixed with replayed training samples.

    ``buffer_X``/``buffer_y`` are concatenated with the first 100 rows of the
    module-level ``X_train``/``y_train`` as a replay buffer. Training uses
    Adam (weight decay 0.01) with a triangular2 cyclic learning rate between
    ``lr / 10`` and ``lr``. After each pass over the buffer the model is
    evaluated on ``val_loader``. Adaptation stops at the first cycle whose
    validation loss does not improve, and the weights with the lowest
    validation loss are restored at the end.

    Args:
        model: Model whose forward returns ``(logits, weights)``.
        optimizer: Ignored; a fresh Adam optimizer is created here.
        buffer_X, buffer_y: New feature rows and their integer labels.
        criterion: Loss applied to ``(logits, labels)``.
        val_loader: DataLoader used for validation after every cycle.
        num_cycles: Maximum number of passes over the buffer.
        batch_size: Batch size of the buffer DataLoader.
        lr: Upper learning rate of the cyclic schedule.

    NOTE(review): depends on the module-level ``X_train``, ``y_train`` and
    ``input_dim`` (the column where EMG features start) that the
    cross-validation loop sets.
    """
    # Create a replay buffer with original training data
    replay_buffer_X = X_train[:100]  # Store 100 samples from training data
    replay_buffer_y = y_train[:100]
    buffer_X = np.concatenate([buffer_X, replay_buffer_X])
    buffer_y = np.concatenate([buffer_y, replay_buffer_y])

    buffer_dataset = EEGEMGDataset(buffer_X, buffer_y)
    buffer_loader = DataLoader(buffer_dataset, batch_size=batch_size, shuffle=True)

    # Lower learning rate for stable online updates (replaces the passed-in optimizer).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)

    # Cycle only the learning rate. With the default cycle_momentum=True,
    # CyclicLR also cycles momentum. For Adam that means its betas, and on
    # PyTorch versions that require a `momentum` hyper-parameter it raises.
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=lr / 10, max_lr=lr, step_size_up=5,
        mode='triangular2', cycle_momentum=False,
    )

    def _snapshot():
        # state_dict() returns references to the live tensors; clone them so the
        # saved "best" weights are not overwritten by later optimizer steps.
        return {name: t.detach().clone() for name, t in model.state_dict().items()}

    best_val_loss = float('inf')
    best_model_state = _snapshot()

    for cycle in range(num_cycles):
        print(f"Online Adaptation Cycle {cycle + 1}/{num_cycles}")

        # Re-enter training mode every cycle: the validation pass below
        # switches the model to eval mode, which disables dropout.
        model.train()
        for X_batch, y_batch in buffer_loader:
            eeg = X_batch[:, :input_dim].unsqueeze(1)
            emg = X_batch[:, input_dim:].unsqueeze(1)
            optimizer.zero_grad()
            outputs, _ = model(eeg, emg)  # Ignore attention weights during training
            loss = criterion(outputs, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            scheduler.step()  # Update learning rate once per batch

        # Evaluate on the validation set after each cycle
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                eeg = X_batch[:, :input_dim].unsqueeze(1)
                emg = X_batch[:, input_dim:].unsqueeze(1)
                outputs, _ = model(eeg, emg)  # Ignore attention weights during validation
                loss = criterion(outputs, y_batch)
                val_loss += loss.item()

        val_loss /= len(val_loader)
        print(f"Validation Loss After Cycle {cycle + 1}: {val_loss:.4f}")

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = _snapshot()
            print("New best model saved.")
        else:
            print("Validation loss increased. Early stopping triggered.")
            break  # Stop adaptation if validation loss increases

    # Load the best model state after adaptation
    model.load_state_dict(best_model_state)
    print("Best model loaded after online adaptation.")



from sklearn.metrics import roc_curve, precision_recall_curve, auc

def plot_roc_curve(y_true, y_scores, num_classes):
    """Plot one-vs-rest ROC curves, one per class, with their AUC in the legend.

    Args:
        y_true: Integer class labels (converted to a NumPy array).
        y_scores: 2-D score array; column ``c`` is the score for class ``c``.
        num_classes: Classes 0..num_classes-1 are plotted.
    """
    plt.figure(figsize=(10, 8))
    truth = np.array(y_true)

    for cls in range(num_classes):
        fpr, tpr, _ = roc_curve(truth == cls, y_scores[:, cls])
        area = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'Class {cls} (AUC = {area:.2f})')

    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower center")
    plt.show()

def plot_precision_recall_curve(y_true, y_scores, num_classes):
    """Plot one-vs-rest precision-recall curves per class.

    The legend value is the trapezoidal area under each curve
    (``sklearn.metrics.auc(recall, precision)``).

    Args:
        y_true: Integer class labels (converted to a NumPy array).
        y_scores: 2-D score array; column ``c`` is the score for class ``c``.
        num_classes: Classes 0..num_classes-1 are plotted.
    """
    plt.figure(figsize=(10, 8))
    truth = np.array(y_true)

    for cls in range(num_classes):
        is_positive = (truth == cls).astype(int)
        precision, recall, _ = precision_recall_curve(is_positive, y_scores[:, cls])
        area = auc(recall, precision)
        plt.plot(recall, precision, label=f'Class {cls} (AUC = {area:.2f})')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc="lower center")
    plt.show()


from sklearn.metrics import hamming_loss

def calculate_hamming_loss(y_true, y_pred):
    """Return sklearn's Hamming loss of ``y_pred`` against ``y_true``.

    For single-label multiclass input this is the fraction of mismatched labels.
    """
    loss = hamming_loss(y_true, y_pred)
    return loss

from sklearn.metrics import top_k_accuracy_score

def calculate_top_k_accuracy(y_true, y_scores, k=3):
    """Return sklearn's top-``k`` accuracy.

    This is the fraction of samples whose true class is among the ``k``
    highest-scored classes.
    """
    score = top_k_accuracy_score(y_true, y_scores, k=k)
    return score

def analyze_attention_weights(attention_weights):
    """Average per-batch EEG/EMG weights over axis 0, print them, and return the mean.

    After averaging, element ``[0][0]`` is reported as EEG and ``[1][0]`` as
    EMG, e.g. for input of shape (n_batches, 2, 1).

    NOTE(review): an identical definition of this function appears again
    further down this file; that later one is the binding in effect.
    """
    mean_weights = np.mean(attention_weights, axis=0)
    eeg_w, emg_w = mean_weights[0][0], mean_weights[1][0]
    print(f"Average Attention Weights - EEG: {eeg_w:.4f}, EMG: {emg_w:.4f}")
    return mean_weights

def plot_learning_curves(train_losses, val_losses):
    """Plot the per-epoch training and validation loss histories on one figure."""
    plt.figure(figsize=(10, 6))
    for series, name in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
        plt.plot(series, label=name)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Learning Curves')
    plt.legend()
    plt.show()

def ablation_study(model, val_loader, ablation_type='eeg'):
    """Report validation accuracy with one modality's input replaced by zeros.

    ``ablation_type`` ``'eeg'`` zeroes the EEG features and ``'emg'`` zeroes
    the EMG features. Any other value runs without ablation. Prints the
    accuracy; returns nothing.

    NOTE(review): reads the module-level ``input_dim`` to split each row into
    EEG and EMG features.
    """
    model.eval()
    labels_seen = []
    labels_predicted = []

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            eeg = X_batch[:, :input_dim].unsqueeze(1)
            emg = X_batch[:, input_dim:].unsqueeze(1)

            if ablation_type == 'eeg':
                eeg = torch.zeros_like(eeg)
            elif ablation_type == 'emg':
                emg = torch.zeros_like(emg)

            logits, _ = model(eeg, emg)
            _, predicted = torch.max(logits, 1)
            labels_seen.extend(y_batch.cpu().numpy())
            labels_predicted.extend(predicted.cpu().numpy())

    accuracy = accuracy_score(labels_seen, labels_predicted)
    print(f"Ablation Study ({ablation_type.upper()} ablated): Accuracy = {accuracy:.4f}")

import time

def measure_inference_time(model, val_loader):
    """Time one no-grad forward pass over every batch of ``val_loader``.

    The timing covers the whole loop, so it includes DataLoader batching as
    well as the model forward calls. It uses ``time.perf_counter`` (monotonic,
    high resolution) rather than ``time.time``, which follows the wall clock
    and can jump if the system time is adjusted.

    NOTE(review): reads the module-level ``input_dim`` to split each row into
    EEG and EMG features.

    Returns:
        Elapsed time in seconds (also printed).
    """
    model.eval()
    start_time = time.perf_counter()

    with torch.no_grad():
        for X_batch, _ in val_loader:
            eeg = X_batch[:, :input_dim].unsqueeze(1)
            emg = X_batch[:, input_dim:].unsqueeze(1)
            model(eeg, emg)

    inference_time = time.perf_counter() - start_time
    print(f"Inference Time: {inference_time:.4f} seconds")
    return inference_time

def measure_model_size(model):
    """Count the trainable parameters of ``model``, print the count, and return it."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not counted
        total += param.numel()
    print(f"Model Size: {total} parameters")
    return total


def compute_average_metrics(metrics_list):
    """Average each metric across a list of per-fold metric dicts.

    Keys are taken from the first dict. For each key, ``None`` values and
    dicts that lack the key are skipped, and the rest are averaged with
    ``np.mean``. A key with no usable value maps to ``None``. An empty
    ``metrics_list`` returns ``{}``.

    NOTE(review): array-valued entries (e.g. ``'y_true'``,
    ``'confusion_matrix'``) are reduced by ``np.mean`` over all their
    elements, or may raise if their shapes differ between folds.
    """
    if not metrics_list:
        return {}
    average_metrics = {}
    for metric in metrics_list[0]:
        values = [result.get(metric) for result in metrics_list]
        values = [value for value in values if value is not None]
        average_metrics[metric] = np.mean(values) if values else None
    return average_metrics



def analyze_attention_weights(attention_weights):
    """Average per-batch EEG/EMG weights over axis 0, print them, and return the mean.

    After averaging, element ``[0][0]`` is reported as EEG and ``[1][0]`` as
    EMG, e.g. for input of shape (n_batches, 2, 1) as built by
    ``evaluate_model``.
    """
    mean_weights = np.mean(attention_weights, axis=0)
    eeg_w, emg_w = mean_weights[0][0], mean_weights[1][0]
    print(f"Average Attention Weights - EEG: {eeg_w:.4f}, EMG: {emg_w:.4f}")
    return mean_weights

def plot_modality_contribution(attention_weights):
    """Bar-plot the EEG and EMG weights.

    ``attention_weights`` is flattened and plotted against the labels
    ``['EEG', 'EMG']``, so it presumably holds exactly two values (e.g. the
    (2, 1) mean returned by ``analyze_attention_weights``).
    """
    bar_heights = attention_weights.flatten()
    bar_labels = ['EEG', 'EMG']

    plt.figure(figsize=(6, 4))
    sns.barplot(x=bar_labels, y=bar_heights, palette="viridis")
    plt.title('Modality Contribution Based on Attention Weights')
    plt.ylabel('Contribution Weight')
    plt.xlabel('Modality')
    plt.show()

def evaluate_model(model, val_loader, train_losses=None, val_losses=None):
    """Evaluate ``model`` on ``val_loader``, print and plot a report, and return metrics.

    The model runs in eval mode without gradients. The function collects
    predictions and softmax probabilities and computes classification metrics.
    It plots the confusion matrix, ROC and precision-recall curves, and the
    EEG/EMG weighting. It also runs both ablations and measures inference time
    and model size.

    Relies on module-level names defined elsewhere in this file: ``input_dim``
    (the column where EMG features start), ``label_encoder`` (confusion-matrix
    tick labels) and ``sync_score`` (copied into the result).

    Args:
        model: Model whose forward returns ``(logits, modality_weights)``.
        val_loader: DataLoader yielding ``(X_batch, y_batch)``.
        train_losses, val_losses: Optional per-epoch loss histories, added to
            the result when given.

    Returns:
        dict of metrics and raw data. ``auroc``, ``auprc`` and
        ``top_k_accuracy`` are ``None`` when sklearn cannot compute them for
        this split.
    """
    model.eval()
    y_true = []
    y_pred = []
    y_scores = []  # softmax probabilities, one row per sample
    attention_weights = []  # modality weights returned by the model, one entry per batch

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            eeg = X_batch[:, :input_dim].unsqueeze(1)
            emg = X_batch[:, input_dim:].unsqueeze(1)
            outputs, weights = model(eeg, emg)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(y_batch.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
            y_scores.extend(probs.cpu().numpy())
            # The model returns its nn.Parameter here (requires_grad=True);
            # detach so the NumPy conversion does not depend on grad mode.
            attention_weights.append(weights.detach().cpu().numpy())

    # Convert to NumPy arrays
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_scores = np.array(y_scores)
    attention_weights = np.array(attention_weights)

    # Derive class ids from the model's output columns, not from the classes
    # that happen to occur in y_true. This keeps log loss, the confusion matrix
    # and the per-class curves well-defined when a split lacks a class.
    num_classes = y_scores.shape[1]
    class_ids = np.arange(num_classes)

    # Compute classification metrics
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    accuracy = accuracy_score(y_true, y_pred)
    balanced_acc = balanced_accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)
    log_loss_value = log_loss(y_true, y_scores, labels=class_ids)
    hamming_loss_value = calculate_hamming_loss(y_true, y_pred)
    try:
        top_k_accuracy_value = calculate_top_k_accuracy(y_true, y_scores, k=3)
    except ValueError:
        # sklearn rejects some label/score layouts (e.g. a class missing from
        # y_true); report N/A, as for AUROC/AUPRC below.
        top_k_accuracy_value = None

    # AUROC and AUPRC (only for binary or multiclass with probability scores)
    try:
        auroc = roc_auc_score(y_true, y_scores, multi_class='ovr', average='weighted')
        auprc = average_precision_score(y_true, y_scores, average='weighted')
    except ValueError:
        auroc = None
        auprc = None

    # Mean Per-Class Error (MPCE), over classes that actually occur in y_true
    # (a class with no true samples has no defined error rate).
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    support = cm.sum(axis=1)
    has_support = support > 0
    per_class_error = 1 - np.diag(cm)[has_support] / support[has_support]
    mpce = np.mean(per_class_error) if per_class_error.size else float('nan')

    # Build the optional parts first. Adjacent f-strings inside a conditional
    # expression would be concatenated and the whole line dropped when N/A.
    auroc_text = f"AUROC: {auroc:.2f}" if auroc is not None else "AUROC: N/A"
    auprc_text = f"AUPRC: {auprc:.2f}" if auprc is not None else "AUPRC: N/A"
    top_k_text = f"{top_k_accuracy_value:.4f}" if top_k_accuracy_value is not None else "N/A"

    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}, Accuracy: {accuracy:.2f}")
    print(f"Balanced Accuracy: {balanced_acc:.2f}, Cohen's Kappa: {kappa:.2f}, MCC: {mcc:.2f}")
    print(f"Log Loss: {log_loss_value:.4f}, {auroc_text} {auprc_text}")
    print(f"MPCE: {mpce:.4f}")
    print(f"Hamming Loss: {hamming_loss_value:.4f}")
    print(f"Top-3 Accuracy: {top_k_text}")

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Confusion Matrix")
    plt.show()

    # Plot ROC and Precision-Recall curves (one per model output class)
    plot_roc_curve(y_true, y_scores, num_classes=num_classes)
    plot_precision_recall_curve(y_true, y_scores, num_classes=num_classes)

    # Contribution analysis: EEG vs. EMG
    avg_attention_weights = analyze_attention_weights(attention_weights)

    # Plot modality contribution
    plot_modality_contribution(avg_attention_weights)

    # Ablation studies
    ablation_study(model, val_loader, ablation_type='eeg')
    ablation_study(model, val_loader, ablation_type='emg')

    # Computational efficiency metrics
    inference_time = measure_inference_time(model, val_loader)
    model_size = measure_model_size(model)

    # Return all relevant metrics and data
    metrics = {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'cohen_kappa': kappa,
        'mcc': mcc,
        'log_loss': log_loss_value,
        'auroc': auroc,
        'auprc': auprc,
        'mpce': mpce,
        'hamming_loss': hamming_loss_value,
        'top_k_accuracy': top_k_accuracy_value,
        'sync_score': sync_score,
        'attention_weights': avg_attention_weights,
        'inference_time': inference_time,
        'model_size': model_size,
        'y_true': y_true,
        'y_scores': y_scores,
        'confusion_matrix': cm,
    }

    # Add train_losses and val_losses if provided
    if train_losses is not None:
        metrics['train_losses'] = train_losses
    if val_losses is not None:
        metrics['val_losses'] = val_losses

    return metrics

def train_model_with_weight_decay(model, train_loader, val_loader, criterion, epochs=15, patience=5, lr=0.0005):
    """Train ``model`` with Adam (weight decay 0.01) and validation-based early stopping.

    Each epoch trains on ``train_loader`` and then computes the mean batch
    loss on ``val_loader``. Training stops after ``patience`` consecutive
    epochs without a new best validation loss. The weights with the lowest
    validation loss are restored before returning, whether or not early
    stopping fired.

    NOTE(review): reads the module-level ``input_dim`` to split each row into
    EEG and EMG features.

    Returns:
        ``(train_losses, val_losses)``: the mean batch loss of every completed epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
    best_loss = float('inf')
    patience_counter = 0
    train_losses = []  # To store training losses for each epoch
    val_losses = []    # To store validation losses for each epoch

    def _snapshot():
        # state_dict() returns references to the live tensors; clone them so the
        # saved "best" weights are not overwritten by later optimizer steps.
        return {name: t.detach().clone() for name, t in model.state_dict().items()}

    # Start from the initial weights so a state to restore always exists, even
    # if no epoch improves (e.g. the validation loss is NaN).
    best_model_state = _snapshot()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for X_batch, y_batch in train_loader:
            eeg = X_batch[:, :input_dim].unsqueeze(1)
            emg = X_batch[:, input_dim:].unsqueeze(1)
            optimizer.zero_grad()
            outputs, _ = model(eeg, emg)  # Ignore attention weights during training
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Record average training loss for the epoch
        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation phase
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                eeg = X_batch[:, :input_dim].unsqueeze(1)
                emg = X_batch[:, input_dim:].unsqueeze(1)
                outputs, _ = model(eeg, emg)  # Ignore attention weights during validation
                loss = criterion(outputs, y_batch)
                val_loss += loss.item()

        # Record average validation loss for the epoch
        val_loss /= len(val_loader)
        val_losses.append(val_loss)

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Early stopping logic
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            best_model_state = _snapshot()
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    # Keep the best-validation weights whether or not early stopping fired.
    model.load_state_dict(best_model_state)

    # Return training and validation losses
    return train_losses, val_losses


# Initialize lists to store aggregated data across folds
all_y_true = []
all_y_scores = []
all_train_losses = []
all_val_losses = []
all_confusion_matrices = []

# k-Fold Cross-Validation with Online Adaptation
k = 5
kf = KFold(n_splits=k, shuffle=True, random_state=42)
fold_results = []
before_adaptation_metrics = []  # To store metrics before online adaptation
after_adaptation_metrics = []   # To store metrics after online adaptation
online_adaptation_percentage = 0.3  # 30% data for online adaptation

# `model` is rebuilt every fold, so after the loop it only holds the last
# fold's weights. Keep a cloned copy of the best fold's weights instead.
# Strict `>` keeps the first fold on ties, matching np.argmax below.
best_fold_accuracy = -np.inf
best_fold_state = None

for fold, (train_index, val_index) in enumerate(kf.split(combined_data)):
    print(f"Fold {fold + 1}/{k}")
    # X_train, y_train and input_dim are module-level names that
    # online_adaptation_with_regularization and the evaluation helpers read as globals.
    X_train, X_val = combined_data.values[train_index], combined_data.values[val_index]
    y_train, y_val = labels[train_index], labels[val_index]

    train_dataset = EEGEMGDataset(X_train, y_train)
    val_dataset = EEGEMGDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # NOTE(review): splitting the columns in half assumes EEG and EMG have the
    # same number of channels — confirm against the input files.
    input_dim = X_train.shape[1] // 2
    hidden_dim = 256
    num_heads = 4
    num_classes = len(np.unique(labels))

    model = EEGEMGTransformer(input_dim=input_dim, hidden_dim=hidden_dim, num_heads=num_heads, num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()

    # Train model and get training/validation losses
    train_losses, val_losses = train_model_with_weight_decay(model, train_loader, val_loader, criterion)

    print("Before Online Adaptation:")
    metrics_before = evaluate_model(model, val_loader, train_losses=train_losses, val_losses=val_losses)
    before_adaptation_metrics.append(metrics_before)

    # Collect data for average plots
    all_y_true.extend(metrics_before['y_true'])
    all_y_scores.extend(metrics_before['y_scores'])
    all_train_losses.append(metrics_before['train_losses'])
    all_val_losses.append(metrics_before['val_losses'])
    all_confusion_matrices.append(metrics_before['confusion_matrix'])

    # Online adaptation
    # NOTE(review): the adaptation samples are the first 30% of X_val. The same
    # val_loader drives early stopping inside the adaptation and is used for
    # the "after" evaluation, so after-adaptation metrics are partly measured
    # on data the model was just tuned on.
    online_data_size = int(len(X_val) * online_adaptation_percentage)
    online_X, online_y = X_val[:online_data_size], y_val[:online_data_size]
    print(f"Online Adaptation using {online_data_size} samples.")

    # online_adaptation_with_regularization builds its own Adam optimizer and
    # never uses this one; it is passed only to satisfy the signature.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000001, weight_decay=0.01)
    online_adaptation_with_regularization(
        model, optimizer, online_X, online_y, criterion, val_loader, num_cycles=5, batch_size=16, lr=0.00001
    )

    print("After Online Adaptation:")
    metrics_after = evaluate_model(model, val_loader)
    after_adaptation_metrics.append(metrics_after)

    # Combine metrics for this fold
    fold_results.append({
        'before_adaptation': metrics_before,
        'after_adaptation': metrics_after
    })

    # Snapshot (clone) this fold's weights if it is the best so far.
    if metrics_after['accuracy'] > best_fold_accuracy:
        best_fold_accuracy = metrics_after['accuracy']
        best_fold_state = {name: t.detach().clone() for name, t in model.state_dict().items()}

# Debugging: Print metrics before and after adaptation
print("Before Adaptation Metrics:")
for i, metrics in enumerate(before_adaptation_metrics):
    print(f"Fold {i + 1}:")
    for key, value in metrics.items():
        print(f"  {key}: {value}")

print("After Adaptation Metrics:")
for i, metrics in enumerate(after_adaptation_metrics):
    print(f"Fold {i + 1}:")
    for key, value in metrics.items():
        print(f"  {key}: {value}")


# Save the best fold's weights (not the last fold's `model`).
best_model_index = np.argmax([result['after_adaptation']['accuracy'] for result in fold_results])
torch.save(best_fold_state, 'EEGEMGTransformer_best.pth')
print(f"Best model (Fold {best_model_index + 1}) saved as EEGEMGTransformer_best.pth")
