Evaluation and Analysis

In [None]:
# Comprehensive evaluation and analysis
def evaluate_model(model, test_loader, device):
    """Run ``model`` over every batch of ``test_loader`` and collect results.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    It is expected to return a 2-tuple whose first element holds the class
    scores (logits) along dim 1; the second element is ignored.

    Args:
        model: network called as ``model(inputs)``.
        test_loader: iterable of ``(inputs, targets)`` batches.
        device: device the inputs are moved to before the forward pass
            (targets stay where they are).

    Returns:
        Tuple of numpy arrays ``(y_true, y_pred, y_probs)``: the targets,
        the argmax class of each logit row, and the softmax of each logit row.
    """
    model.eval()
    labels, predictions, probabilities = [], [], []

    with torch.no_grad():
        for inputs, targets in test_loader:
            logits, _ = model(inputs.to(device))
            # argmax over raw logits is identical to argmax over their softmax
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            probabilities.extend(F.softmax(logits, dim=1).cpu().numpy())
            labels.extend(targets.cpu().numpy())

    return np.array(labels), np.array(predictions), np.array(probabilities)

# Test evaluation: run the trained model over the held-out test loader and
# report overall accuracy plus per-class precision/recall/F1.
# Relies on globals from earlier cells: model, test_loader, device.
y_true, y_pred, y_probs = evaluate_model(model, test_loader, device)
test_accuracy = accuracy_score(y_true, y_pred)

print(f"\nFinal Test Results:")
print(f"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print("\nDetailed Classification Report:")
# `labels` pins the report to all four integer classes 0-3 (index i <-> the
# i-th target name). Without it, sklearn infers labels from the data and raises
# ValueError when a class is absent from both y_true and y_pred, because the
# inferred label count would no longer match the 4 target_names.
print(classification_report(y_true, y_pred,
                            labels=[0, 1, 2, 3],
                            target_names=['Left Hand', 'Right Hand', 'Feet', 'Tongue']))

In [None]:
# Console summary of test-set performance.
# Per-class accuracy and prediction confidence are computed HERE rather than
# borrowed from the plotting cell further down, so this cell works under
# Restart Kernel -> Run All.
# Relies on globals from earlier cells: y_true, y_pred, y_probs, test_accuracy,
# trainer, trainable_params.
print("\n" + "="*60)
print("EEG MOTOR IMAGERY CLASSIFICATION RESULTS")
print("="*60 + "\n")
print(f"Final Test Accuracy: {test_accuracy*100:.2f}%")
print(f"Best Validation Accuracy: {trainer.best_val_acc:.2f}%")
print(f"Model Parameters: {trainable_params:,}")
print("\n" + "="*60)

# Per-class accuracy (i.e. recall of class i), in percent; 0 when a class has
# no test samples. Integer label i corresponds to class_names[i].
class_names = ['Left Hand', 'Right Hand', 'Feet', 'Tongue']
class_accuracies = []
for i in range(len(class_names)):
    class_mask = y_true == i
    if np.sum(class_mask) > 0:
        class_accuracies.append(accuracy_score(y_true[class_mask], y_pred[class_mask]) * 100)
    else:
        class_accuracies.append(0)

# Confidence = the highest softmax probability assigned to each test sample.
confidence_scores = np.max(y_probs, axis=1)

# Performance summary (insertion order defines the printed order)
performance_summary = {'Overall Accuracy': f"{test_accuracy*100:.2f}%"}
for name, acc in zip(class_names, class_accuracies):
    performance_summary[name] = f"{acc:.1f}%"
performance_summary['Mean Confidence'] = f"{np.mean(confidence_scores):.3f}"
performance_summary['Model Size'] = f"{trainable_params:,} parameters"

print("\nPERFORMANCE BREAKDOWN:")
for key, value in performance_summary.items():
    print(f"{key:.<20} {value}")

print("\n" + "="*60)

Diagnostic Plots

In [None]:
import seaborn as sns

# Six-panel diagnostic figure for the trained model.
# Relies on globals from earlier cells: history, y_true, y_pred, y_probs,
# model, X_test, device.
class_names = ['Left Hand', 'Right Hand', 'Feet', 'Tongue']
class_labels = list(range(len(class_names)))  # integer label i <-> class_names[i]

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# 1. Training history
axes[0, 0].plot(history['train_losses'], label='Train Loss', alpha=0.7)
axes[0, 0].plot(history['val_losses'], label='Val Loss', alpha=0.7)
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

axes[0, 1].plot(history['train_accs'], label='Train Acc', alpha=0.7)
axes[0, 1].plot(history['val_accs'], label='Val Acc', alpha=0.7)
axes[0, 1].set_title('Training and Validation Accuracy')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].legend()
axes[0, 1].grid(True)

# 2. Confusion Matrix. Passing `labels` keeps it 4x4 (matching the tick
# labels) even if some class never appears in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            ax=axes[0, 2])
axes[0, 2].set_title('Confusion Matrix')
axes[0, 2].set_ylabel('True Label')
axes[0, 2].set_xlabel('Predicted Label')

# 3. Class-wise accuracy (recall of each class), in percent; 0 if no samples.
class_accuracies = []
for i in class_labels:
    class_mask = y_true == i
    if np.sum(class_mask) > 0:
        class_acc = accuracy_score(y_true[class_mask], y_pred[class_mask])
        class_accuracies.append(class_acc * 100)
    else:
        class_accuracies.append(0)

bars = axes[1, 0].bar(class_names, class_accuracies, color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4'])
axes[1, 0].set_title('Class-wise Accuracy')
axes[1, 0].set_ylabel('Accuracy (%)')
# Headroom above 100 so the value label drawn over a 100% bar stays inside the axes.
axes[1, 0].set_ylim([0, 110])
for bar, acc in zip(bars, class_accuracies):
    axes[1, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                    f'{acc:.1f}%', ha='center', va='bottom')

# 4. Prediction confidence distribution (max softmax probability per sample)
confidence_scores = np.max(y_probs, axis=1)
axes[1, 1].hist(confidence_scores, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
axes[1, 1].set_title('Prediction Confidence Distribution')
axes[1, 1].set_xlabel('Confidence Score')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].axvline(np.mean(confidence_scores), color='red', linestyle='--',
                   label=f'Mean: {np.mean(confidence_scores):.3f}')
axes[1, 1].legend()

# 5. Mean absolute output logit per class.
# NOTE: this is NOT a feature-importance / attribution measure. It is the mean
# |logit| the model emits for each output class over the first 10 test samples
# only, so the title and axis label say exactly that.
# Assumes X_test is a tensor fed to the model the same way as test_loader batches — TODO confirm.
model.eval()
with torch.no_grad():
    sample_input = X_test[:10].to(device)
    main_output, _ = model(sample_input)
    feature_importance = torch.mean(torch.abs(main_output), dim=0).cpu().numpy()

axes[1, 2].bar(class_names, feature_importance, color=['#FF9999', '#66B2FF', '#99FF99', '#FFCC99'])
axes[1, 2].set_title('Mean |Logit| per Output Class (first 10 test samples)')
axes[1, 2].set_ylabel('Mean |logit|')

plt.tight_layout()
plt.savefig('enhanced_eeg_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# --- AUC Evaluation ---
# One-vs-rest ROC AUC per class plus macro/weighted averages.
# Relies on globals from earlier cells: y_true, y_probs (integer labels 0-3,
# column i of y_probs is the probability of class i).
print("\n" + "="*60)
print("AUC EVALUATION")
print("="*60)

# Only the metric computation is guarded: sklearn raises ValueError when a
# class has no positive (or no negative) samples in y_true, or when the number
# of classes in y_true differs from the number of y_probs columns. Plotting
# errors are not masked as AUC failures.
try:
    macro_auc_score = roc_auc_score(y_true, y_probs, multi_class='ovr', average='macro')
    weighted_auc_score = roc_auc_score(y_true, y_probs, multi_class='ovr', average='weighted')
except ValueError as e:
    print(f"Could not compute AUC score: {e}")
    print("This typically happens when a class is missing from y_true, so its one-vs-rest ROC curve is undefined.")
else:
    print(f"Macro-averaged AUC (One-vs-Rest): {macro_auc_score:.4f}")
    print(f"Weighted-averaged AUC (One-vs-Rest): {weighted_auc_score:.4f}")

    # Plotting ROC curves for each class (one-vs-rest)
    fig, ax = plt.subplots(figsize=(10, 8))
    class_names = ['Left Hand', 'Right Hand', 'Feet', 'Tongue']

    for i, class_name in enumerate(class_names):
        # Binary problem: class i vs. all others, scored by P(class i)
        fpr, tpr, _ = roc_curve(y_true == i, y_probs[:, i])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f'ROC curve for {class_name} (area = {roc_auc:.2f})')

    ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Chance level (AUC = 0.5)')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve - One-vs-Rest')
    ax.legend(loc="lower right")
    ax.grid(True)
    fig.savefig('multi_class_roc_curve.png', dpi=300, bbox_inches='tight')
    plt.show()
