In [None]:
"""
Physiological Recording Classification: Healthy vs Impaired
==========================================================

This script combines PSD, SNR, and CCN features to build classifiers that
distinguish between healthy and impaired patient recordings.

Features:
- Loads and combines PSD, SNR, and CCN feature tables
- Implements customizable filtering via filter functions
- Implements multiple simple classifiers (Decision Tree, Random Forest, SVM, etc.)
- Performs train/test split with proper evaluation
- Provides comprehensive performance metrics and visualizations
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import (classification_report, confusion_matrix, 
                           accuracy_score, precision_score, recall_score, 
                           f1_score, roc_auc_score, roc_curve)
from sklearn.inspection import permutation_importance
import warnings
# Silence every warning for the rest of the process.
# NOTE(review): this also hides scikit-learn warnings (e.g. convergence or
# undefined-metric warnings) that can indicate real problems with the models
# trained below; consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

# Global plot styling applied to every figure produced below.
plt.style.use('default')
sns.set_palette("husl")

class PhysioClassifier:
    """Main class for physiological recording classification with customizable filtering.

    ``run_complete_analysis`` drives the whole pipeline: it loads the PSD, SNR
    and CCN CSV tables, optionally filters each table, inner-joins them on the
    recording keys, makes a stratified train/test split, standardizes the
    features, then trains, evaluates and plots a set of scikit-learn models.
    """

    # Columns that identify a recording; every other column is a feature.
    KEY_COLUMNS = ['Patient', 'Recording', 'Arm Type']
    # Label treated as the positive class for precision/recall/F1/ROC.
    POSITIVE_LABEL = 'Healthy'
    # Fixed row/column order of the confusion matrices (matches the plot ticks).
    CLASS_LABELS = ['Healthy', 'Impaired']

    def __init__(self, filter_func=None):
        """
        Initialize the classifier.

        Args:
            filter_func: A function that takes a DataFrame and returns a filtered DataFrame.
                        If None, no filtering is applied. It is called once per
                        input table (PSD, SNR and CCN) inside ``load_data``.
        """
        self.filter_func = filter_func
        self.psd_data = None
        self.snr_data = None
        self.ccn_data = None
        self.combined_data = None
        # Names of the feature columns; set by prepare_features_and_labels().
        self.feature_cols = None
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.scaler = StandardScaler()
        self.models = {}
        self.results = {}

    def load_data(self, psd_file='detailed_psd_table.csv', snr_file='detailed_snrs_table.csv', ccn_file='detailed_ccn_table.csv'):
        """Load PSD, SNR, and CCN data from CSV files and optionally apply filtering.

        Args:
            psd_file: Path of the PSD feature table.
            snr_file: Path of the SNR feature table.
            ccn_file: Path of the CCN feature table.

        The filter (if any) is applied to each table independently; rows that
        survive in only some tables are dropped later by the inner joins in
        ``combine_datasets``.
        """
        print("Loading data...")

        # Load raw data
        psd_data_raw = pd.read_csv(psd_file)
        snr_data_raw = pd.read_csv(snr_file)
        ccn_data_raw = pd.read_csv(ccn_file)

        print(f"Raw PSD data shape: {psd_data_raw.shape}")
        print(f"Raw SNR data shape: {snr_data_raw.shape}")
        print(f"Raw CCN data shape: {ccn_data_raw.shape}")
        print(f"PSD columns: {list(psd_data_raw.columns[:5])}...")
        print(f"SNR columns: {list(snr_data_raw.columns)}")
        print(f"CCN columns: {list(ccn_data_raw.columns)}")

        # Display original label distribution
        print(f"\nOriginal label distribution:")
        print(psd_data_raw['Arm Type'].value_counts())

        # Apply filtering if provided
        if self.filter_func is not None:
            print("\nApplying custom filter...")
            self.psd_data = self.filter_func(psd_data_raw)
            self.snr_data = self.filter_func(snr_data_raw)
            self.ccn_data = self.filter_func(ccn_data_raw)

            print(f"Filtered PSD data shape: {self.psd_data.shape}")
            print(f"Filtered SNR data shape: {self.snr_data.shape}")
            print(f"Filtered CCN data shape: {self.ccn_data.shape}")
            print(f"Removed {len(psd_data_raw) - len(self.psd_data)} recordings")

            print(f"\nFiltered label distribution:")
            print(self.psd_data['Arm Type'].value_counts())
        else:
            print("\nNo filtering applied.")
            self.psd_data = psd_data_raw
            self.snr_data = snr_data_raw
            self.ccn_data = ccn_data_raw

    def combine_datasets(self):
        """Combine PSD, SNR, and CCN datasets on Patient, Recording, and Arm Type.

        Both joins are inner joins, so only recordings present in all three
        tables are kept.

        Returns:
            The combined DataFrame (also stored in ``self.combined_data``).
        """
        print("\nCombining datasets...")

        # First merge PSD and SNR data
        temp_combined = pd.merge(
            self.psd_data,
            self.snr_data,
            on=self.KEY_COLUMNS,
            how='inner'
        )

        # Then merge with CCN data
        self.combined_data = pd.merge(
            temp_combined,
            self.ccn_data,
            on=self.KEY_COLUMNS,
            how='inner'
        )

        print(f"Combined data shape: {self.combined_data.shape}")
        print(f"Combined label distribution:")
        print(self.combined_data['Arm Type'].value_counts())

        # Check for any missing values
        missing_values = self.combined_data.isnull().sum().sum()
        print(f"Missing values: {missing_values}")

        return self.combined_data

    def prepare_features_and_labels(self):
        """Prepare feature matrix X and target vector y.

        Every column of ``self.combined_data`` other than ``KEY_COLUMNS`` is
        treated as a feature; the 'Arm Type' column is the label.

        Returns:
            Tuple ``(X, y, feature_cols)``: the feature values as a NumPy
            array, the label values, and the list of feature column names
            (also stored in ``self.feature_cols``).
        """
        print("\nPreparing features and labels...")

        # Identify feature columns (exclude Patient, Recording, Arm Type)
        feature_cols = [col for col in self.combined_data.columns
                        if col not in self.KEY_COLUMNS]
        self.feature_cols = feature_cols

        # Extract features and labels
        X = self.combined_data[feature_cols].values
        y = self.combined_data['Arm Type'].values

        print(f"Feature matrix shape: {X.shape}")
        print(f"Number of features: {len(feature_cols)}")
        print(f"Feature types: PSD ({len([c for c in feature_cols if 'PSD' in c])}), "
              f"SNR ({len([c for c in feature_cols if 'SNR' in c])}), "
              f"CCN ({len([c for c in feature_cols if 'CCN' in c])})")

        return X, y, feature_cols

    def split_and_scale_data(self, X, y, test_size=0.15, random_state=42):
        """Split data into stratified train/test sets and standardize the features.

        The scaler is fit on the training split only and then applied to the
        test split.

        Args:
            X: Feature matrix.
            y: Label vector.
            test_size: Fraction of samples held out for testing.
            random_state: Seed for the split.
        """
        print(f"\nSplitting data (test_size={test_size})...")

        # NOTE(review): the split is per recording, so recordings from the same
        # Patient can appear in both train and test. If that is unwanted, a
        # group-aware split on 'Patient' would be needed.
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )

        # Scale the features
        self.X_train = self.scaler.fit_transform(self.X_train)
        self.X_test = self.scaler.transform(self.X_test)

        print(f"Training set: {self.X_train.shape[0]} samples")
        print(f"Test set: {self.X_test.shape[0]} samples")
        print(f"Training label distribution:")
        unique, counts = np.unique(self.y_train, return_counts=True)
        for label, count in zip(unique, counts):
            print(f"  {label}: {count} ({count/len(self.y_train)*100:.1f}%)")

    def initialize_models(self):
        """Initialize the classifier models to compare, keyed by display name."""
        print("\nInitializing models...")

        self.models = {
            'Decision Tree': DecisionTreeClassifier(
                random_state=42,
                max_depth=5,
                min_samples_split=5,
                min_samples_leaf=2
            ),
            'Random Forest': RandomForestClassifier(
                n_estimators=200,
                random_state=42,
                max_depth=8,
                min_samples_split=5,
                min_samples_leaf=2,
                max_features='sqrt'
            ),
            'SVM (RBF)': SVC(
                kernel='rbf',
                random_state=42,
                probability=True,
                C=10.0,
                gamma='scale'
            ),
            'SVM (Linear)': SVC(
                kernel='linear',
                random_state=42,
                probability=True,
                C=1.0
            ),
            'Logistic Regression': LogisticRegression(
                random_state=42,
                max_iter=2000,
                C=0.1,
                class_weight='balanced'
            ),
            'K-Nearest Neighbors': KNeighborsClassifier(
                n_neighbors=7,
                weights='distance',
                metric='minkowski',
                p=2
            ),
            'Naive Bayes': GaussianNB(var_smoothing=1e-9)
        }

        print(f"Initialized {len(self.models)} models: {list(self.models.keys())}")

    @staticmethod
    def _positive_class_scores(model, X, pos_label='Healthy'):
        """Return ``model``'s predicted probability of ``pos_label`` for each row of X.

        The probability column is located via ``model.classes_`` instead of a
        fixed index, because scikit-learn orders ``predict_proba`` columns by
        the sorted class labels.

        Returns:
            A 1-D array of probabilities, or None if the model has no
            ``predict_proba`` or was not trained on ``pos_label``.
        """
        if not hasattr(model, 'predict_proba'):
            return None
        classes = list(model.classes_)
        if pos_label not in classes:
            return None
        return model.predict_proba(X)[:, classes.index(pos_label)]

    def train_and_evaluate_models(self):
        """Train every model in ``self.models`` and evaluate it on the test set.

        For each model this fits on the scaled training data and computes test
        accuracy/precision/recall/F1 (with ``POSITIVE_LABEL`` as the positive
        class). It adds ROC AUC when probabilities are available and 5-fold
        cross-validated accuracy on the training data. Results go into
        ``self.results[name]``.
        """
        print("\nTraining and evaluating models...")

        self.results = {}
        pos_label = self.POSITIVE_LABEL
        # Binary ground truth (1 = positive label) shared by all ROC computations.
        y_test_binary = (self.y_test == pos_label).astype(int)

        for name, model in self.models.items():
            print(f"\nTraining {name}...")

            # Train the model
            model.fit(self.X_train, self.y_train)

            # Make predictions. The probability must be for the same class that
            # y_test_binary marks as 1; taking column 1 blindly would score
            # 'Impaired' (sorted second) and invert the ROC AUC.
            y_pred = model.predict(self.X_test)
            y_pred_proba = self._positive_class_scores(model, self.X_test, pos_label)

            # Calculate metrics
            accuracy = accuracy_score(self.y_test, y_pred)
            precision = precision_score(self.y_test, y_pred, pos_label=pos_label)
            recall = recall_score(self.y_test, y_pred, pos_label=pos_label)
            f1 = f1_score(self.y_test, y_pred, pos_label=pos_label)

            # ROC AUC (if probabilities available)
            roc_auc = None
            if y_pred_proba is not None:
                roc_auc = roc_auc_score(y_test_binary, y_pred_proba)

            # Cross-validation score (on clones; the fitted `model` is untouched).
            # NOTE(review): X_train was standardized on the whole training split
            # before CV, so folds see slightly leaked scaling statistics.
            cv_scores = cross_val_score(model, self.X_train, self.y_train, cv=5, scoring='accuracy')

            # Store results
            self.results[name] = {
                'model': model,
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1_score': f1,
                'roc_auc': roc_auc,
                'cv_mean': cv_scores.mean(),
                'cv_std': cv_scores.std(),
                'y_pred': y_pred,
                'y_pred_proba': y_pred_proba,
                # Explicit label order so rows/columns match the plotted tick labels.
                'confusion_matrix': confusion_matrix(self.y_test, y_pred, labels=self.CLASS_LABELS)
            }

            print(f"  Accuracy: {accuracy:.3f}")
            print(f"  Precision: {precision:.3f}")
            print(f"  Recall: {recall:.3f}")
            print(f"  F1-score: {f1:.3f}")
            # `is not None`: an AUC of 0.0 is a valid (if alarming) value.
            if roc_auc is not None:
                print(f"  ROC AUC: {roc_auc:.3f}")
            print(f"  CV Score: {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")

    def print_results_summary(self):
        """Print a summary table of all model results and the best model by test accuracy."""
        print("\n" + "="*80)
        print("MODEL PERFORMANCE SUMMARY")
        print("="*80)

        if not self.results:
            print("No results to summarize.")
            return

        # Create summary DataFrame
        summary_data = []
        for name, result in self.results.items():
            summary_data.append({
                'Model': name,
                'Accuracy': f"{result['accuracy']:.3f}",
                'Precision': f"{result['precision']:.3f}",
                'Recall': f"{result['recall']:.3f}",
                'F1-Score': f"{result['f1_score']:.3f}",
                'ROC AUC': f"{result['roc_auc']:.3f}" if result['roc_auc'] is not None else "N/A",
                'CV Score': f"{result['cv_mean']:.3f} ± {result['cv_std']:.3f}"
            })

        summary_df = pd.DataFrame(summary_data)
        print(summary_df.to_string(index=False))

        # Find best model
        best_model_name = max(self.results.keys(), key=lambda x: self.results[x]['accuracy'])
        print(f"\nBest performing model: {best_model_name}")
        print(f"Best accuracy: {self.results[best_model_name]['accuracy']:.3f}")

    def plot_results(self):
        """Plot test accuracy, F1, CV accuracy and ROC curves; save to model_comparison.png."""
        print("\nGenerating plots...")

        # Set up the plotting
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('Model Performance Comparison', fontsize=16, fontweight='bold')

        # Extract data for plotting
        model_names = list(self.results.keys())
        accuracies = [self.results[name]['accuracy'] for name in model_names]
        f1_scores = [self.results[name]['f1_score'] for name in model_names]
        cv_means = [self.results[name]['cv_mean'] for name in model_names]
        cv_stds = [self.results[name]['cv_std'] for name in model_names]

        # 1. Accuracy comparison
        axes[0, 0].bar(model_names, accuracies, color='skyblue', alpha=0.7)
        axes[0, 0].set_title('Test Accuracy by Model')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].tick_params(axis='x', rotation=45)
        axes[0, 0].grid(True, alpha=0.3)

        # 2. F1-Score comparison
        axes[0, 1].bar(model_names, f1_scores, color='lightgreen', alpha=0.7)
        axes[0, 1].set_title('F1-Score by Model')
        axes[0, 1].set_ylabel('F1-Score')
        axes[0, 1].tick_params(axis='x', rotation=45)
        axes[0, 1].grid(True, alpha=0.3)

        # 3. Cross-validation scores with error bars
        axes[1, 0].bar(model_names, cv_means, yerr=cv_stds, capsize=5,
                       color='orange', alpha=0.7)
        axes[1, 0].set_title('Cross-Validation Accuracy (5-fold)')
        axes[1, 0].set_ylabel('CV Accuracy')
        axes[1, 0].tick_params(axis='x', rotation=45)
        axes[1, 0].grid(True, alpha=0.3)

        # 4. ROC curves for models with probability predictions
        axes[1, 1].set_title('ROC Curves')
        axes[1, 1].set_xlabel('False Positive Rate')
        axes[1, 1].set_ylabel('True Positive Rate')

        # Binary labels for the same positive class the stored probabilities refer to.
        y_test_binary = (self.y_test == self.POSITIVE_LABEL).astype(int)

        for name, result in self.results.items():
            if result['y_pred_proba'] is not None:
                fpr, tpr, _ = roc_curve(y_test_binary, result['y_pred_proba'])
                axes[1, 1].plot(fpr, tpr, label=f"{name} (AUC={result['roc_auc']:.3f})")

        axes[1, 1].plot([0, 1], [0, 1], 'k--', alpha=0.5)
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('model_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()

    def plot_confusion_matrices(self):
        """Plot a confusion matrix per model in a 3-column grid; save to confusion_matrices.png."""
        if not self.results:
            print("No results to plot.")
            return

        n_models = len(self.results)
        cols = 3
        rows = (n_models + cols - 1) // cols

        # squeeze=False always returns a 2-D array of Axes, so axes[row, col]
        # is valid for any model count (including a single row).
        fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)
        fig.suptitle('Confusion Matrices', fontsize=16, fontweight='bold')

        for idx, (name, result) in enumerate(self.results.items()):
            ax = axes[idx // cols, idx % cols]

            cm = result['confusion_matrix']
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                        xticklabels=self.CLASS_LABELS,
                        yticklabels=self.CLASS_LABELS)
            ax.set_title(f'{name}\nAccuracy: {result["accuracy"]:.3f}')
            ax.set_xlabel('Predicted')
            ax.set_ylabel('Actual')

        # Hide empty subplots
        for idx in range(n_models, rows * cols):
            axes[idx // cols, idx % cols].set_visible(False)

        plt.tight_layout()
        plt.savefig('confusion_matrices.png', dpi=300, bbox_inches='tight')
        plt.show()

    def analyze_feature_importance(self):
        """Print and plot the Random Forest's impurity-based feature importances, if trained."""
        print("\nAnalyzing feature importance...")

        # Reuse the feature names captured when X was built; fall back to
        # recomputing them if prepare_features_and_labels() was not called.
        feature_cols = self.feature_cols
        if feature_cols is None:
            feature_cols = [col for col in self.combined_data.columns
                            if col not in self.KEY_COLUMNS]

        # Analyze Random Forest feature importance
        if 'Random Forest' in self.results:
            rf_model = self.results['Random Forest']['model']
            feature_importance = rf_model.feature_importances_

            # Create feature importance DataFrame
            importance_df = pd.DataFrame({
                'feature': feature_cols,
                'importance': feature_importance
            }).sort_values('importance', ascending=False)

            print("\nTop 10 Most Important Features (Random Forest):")
            print(importance_df.head(10).to_string(index=False))

            # Plot feature importance
            plt.figure(figsize=(12, 8))
            top_features = importance_df.head(15)
            plt.barh(range(len(top_features)), top_features['importance'])
            plt.yticks(range(len(top_features)), top_features['feature'])
            plt.xlabel('Feature Importance')
            plt.title('Top 15 Feature Importances (Random Forest)')
            plt.gca().invert_yaxis()
            plt.tight_layout()
            plt.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
            plt.show()

    def run_complete_analysis(self):
        """Run the complete classification analysis pipeline end to end."""
        print("Starting Physiological Recording Classification Analysis")
        print("="*60)

        # Load and combine data
        self.load_data()
        self.combine_datasets()

        # Prepare features and labels (names are kept in self.feature_cols)
        X, y, _ = self.prepare_features_and_labels()

        # Split and scale data
        self.split_and_scale_data(X, y)

        # Initialize and train models
        self.initialize_models()
        self.train_and_evaluate_models()

        # Display results
        self.print_results_summary()

        # Generate visualizations
        self.plot_results()
        self.plot_confusion_matrices()
        self.analyze_feature_importance()

        print("\nAnalysis complete! Check the generated plots:")
        print("- model_comparison.png: Overall model performance")
        print("- confusion_matrices.png: Detailed confusion matrices")
        print("- feature_importance.png: Most important features")


In [None]:
# ============================================================================
# FILTER FUNCTION EXAMPLES
# ============================================================================

def no_filter(df):
    """Identity filter: hand *df* back untouched so every recording is kept."""
    return df

def fma_filter_healthy_only_perfect(df):
    """
    Keep every impaired recording, but only healthy recordings whose average
    FMA score is exactly 2.

    FMA scores are read from 'detailed_fma_table.csv' and left-joined on
    Patient / Recording / Arm Type; the score column is dropped again before
    returning. If the FMA file does not exist, *df* is returned unchanged.
    """
    try:
        fma_table = pd.read_csv('detailed_fma_table.csv')
    except FileNotFoundError:
        print("Warning: FMA data not found. Returning unfiltered data.")
        return df
    print(f"Loaded FMA data for filtering: {fma_table.shape}")

    merged = df.merge(fma_table, on=['Patient', 'Recording', 'Arm Type'], how='left')

    arm = merged['Arm Type']
    is_perfect_healthy = (arm == 'Healthy') & (merged['Average FMA Score'] == 2.0)
    keep = (arm == 'Impaired') | is_perfect_healthy
    result = merged.loc[keep].drop(columns='Average FMA Score')

    n_removed = len(df) - len(result)
    print(f"FMA Filter applied: Excluded {n_removed} recordings")
    print(f"Remaining: {result['Arm Type'].value_counts().to_dict()}")

    return result

def fma_filter_healthy_and_impaired_only_perfect(df):
    """
    Keep healthy recordings with a perfect FMA score (exactly 2) and impaired
    recordings with an FMA score above 1.

    Excluded:
      - healthy recordings whose average FMA score is not 2,
      - impaired recordings whose average FMA score is 1 or lower,
      - recordings with no matching FMA entry (their score is NaN after the
        left join, so both comparisons are False).

    FMA scores are read from 'detailed_fma_table.csv' and left-joined on
    Patient / Recording / Arm Type; the score column is dropped before
    returning. If the FMA file does not exist, *df* is returned unchanged.

    NOTE(review): the previous docstring claimed both that impaired recordings
    with FMA = 2 are excluded and that all impaired recordings are kept. The
    code does neither: it keeps impaired recordings with FMA > 1.0, which
    includes FMA = 2. Confirm which selection is actually intended.
    """
    # Load FMA data
    try:
        fma_data = pd.read_csv('detailed_fma_table.csv')
        print(f"Loaded FMA data for filtering: {fma_data.shape}")
    except FileNotFoundError:
        print("Warning: FMA data not found. Returning unfiltered data.")
        return df

    # Merge with FMA data
    df_with_fma = pd.merge(df, fma_data, on=['Patient', 'Recording', 'Arm Type'], how='left')

    fma_score = df_with_fma['Average FMA Score']
    # Explicit parentheses: the original relied on `&` binding tighter than `|`.
    filter_mask = (
        ((df_with_fma['Arm Type'] == 'Impaired') & (fma_score > 1.0))    # impaired with FMA > 1
        | ((df_with_fma['Arm Type'] == 'Healthy') & (fma_score == 2.0))  # only perfect healthy
    )

    filtered_df = df_with_fma[filter_mask].drop('Average FMA Score', axis=1)

    # Report filtering results
    excluded = len(df) - len(filtered_df)
    print(f"FMA Filter applied: Excluded {excluded} recordings")
    print(f"Remaining: {filtered_df['Arm Type'].value_counts().to_dict()}")

    return filtered_df

def fma_filter_range(min_fma=1.5, max_fma=2.0):
    """
    Build a filter that keeps recordings whose average FMA score lies in
    the closed interval [min_fma, max_fma], for both healthy and impaired arms.

    Args:
        min_fma: Lowest FMA score to keep (inclusive).
        max_fma: Highest FMA score to keep (inclusive).

    Returns:
        A one-argument DataFrame filter usable as PhysioClassifier's
        ``filter_func``. It reads 'detailed_fma_table.csv' on every call and
        returns its input unchanged if that file is missing. Recordings
        without an FMA entry are dropped.
    """
    def filter_func(df):
        try:
            fma_table = pd.read_csv('detailed_fma_table.csv')
        except FileNotFoundError:
            print("Warning: FMA data not found. Returning unfiltered data.")
            return df

        merged = df.merge(fma_table, on=['Patient', 'Recording', 'Arm Type'], how='left')

        # Inclusive on both ends; NaN scores compare False and are dropped.
        in_range = merged['Average FMA Score'].between(min_fma, max_fma)
        result = merged.loc[in_range].drop(columns='Average FMA Score')

        n_removed = len(df) - len(result)
        print(f"FMA Range Filter ({min_fma}-{max_fma}): Excluded {n_removed} recordings")
        print(f"Remaining: {result['Arm Type'].value_counts().to_dict()}")

        return result

    return filter_func

def patient_filter(patient_list):
    """
    Build a filter that keeps only recordings from the listed patients.

    Args:
        patient_list: Patient identifiers to keep (matched against the
            'Patient' column with ``Series.isin``).

    Returns:
        A one-argument DataFrame filter usable as PhysioClassifier's
        ``filter_func``.
    """
    def filter_func(df):
        kept = df.loc[df['Patient'].isin(patient_list)]

        dropped_count = len(df) - len(kept)
        print(f"Patient Filter (patients {patient_list}): Excluded {dropped_count} recordings")
        print(f"Remaining: {kept['Arm Type'].value_counts().to_dict()}")

        return kept

    return filter_func

def snr_threshold_filter(min_snr=5.0):
    """
    Build a filter that keeps recordings with 'SNR Mean (dB)' >= min_snr.

    Args:
        min_snr: Inclusive lower bound on the 'SNR Mean (dB)' column.

    Returns:
        A one-argument DataFrame filter usable as PhysioClassifier's
        ``filter_func``. Tables without an 'SNR Mean (dB)' column are
        passed through unchanged (with a warning).
    """
    def filter_func(df):
        if 'SNR Mean (dB)' in df.columns:
            kept = df.loc[df['SNR Mean (dB)'] >= min_snr]

            dropped_count = len(df) - len(kept)
            print(f"SNR Filter (≥{min_snr} dB): Excluded {dropped_count} recordings")
            print(f"Remaining: {kept['Arm Type'].value_counts().to_dict()}")

            return kept

        print("Warning: SNR data not found in DataFrame. Returning unfiltered data.")
        return df

    return filter_func

def balanced_dataset_filter(random_state=42):
    """
    Create a filter function for balanced dataset with equal numbers of healthy and impaired recordings.

    Both classes are downsampled (without replacement) to the size of the
    smaller class. Rows with any other 'Arm Type' value are dropped.

    PhysioClassifier applies the filter to the PSD, SNR and CCN tables
    separately. Positional sampling would then pick different recordings from
    each table whenever their row orders differ, and the later inner join
    would discard most rows. To avoid that, rows are first sorted by the
    available recording keys ('Patient', 'Recording'). Tables that contain the
    same recordings then yield the same sample.

    Args:
        random_state: Random seed for sampling

    Returns:
        A filter function that can be used with PhysioClassifier
    """
    def filter_func(df):
        # Canonical row order so the sample does not depend on input ordering.
        key_cols = [col for col in ('Patient', 'Recording') if col in df.columns]
        ordered = df.sort_values(key_cols, kind='mergesort') if key_cols else df

        healthy_data = ordered[ordered['Arm Type'] == 'Healthy']
        impaired_data = ordered[ordered['Arm Type'] == 'Impaired']

        min_samples = min(len(healthy_data), len(impaired_data))

        balanced_df = pd.concat([
            healthy_data.sample(n=min_samples, random_state=random_state),
            impaired_data.sample(n=min_samples, random_state=random_state)
        ])

        excluded = len(df) - len(balanced_df)
        print(f"Balanced Filter: Excluded {excluded} recordings")
        print(f"Remaining: {balanced_df['Arm Type'].value_counts().to_dict()}")

        return balanced_df

    return filter_func

# Quick one-line filter examples. Defined with `def` instead of assigning
# lambdas to names (PEP 8, E731) so each has a meaningful __name__ in
# tracebacks; the call signature and results are the same as before.
def exclude_patient_1(df):
    """Drop every recording whose 'Patient' is 1."""
    return df[df['Patient'] != 1]


def only_first_10_patients(df):
    """Keep only recordings whose 'Patient' value is 10 or lower."""
    return df[df['Patient'] <= 10]


def high_snr_only(df):
    """Keep rows with 'SNR Mean (dB)' above 10; return *df* as-is if that column is absent."""
    if 'SNR Mean (dB)' in df.columns:
        return df[df['SNR Mean (dB)'] > 10]
    return df

In [None]:
# Run the full pipeline with FMA filtering: healthy recordings whose average
# FMA score is not exactly 2 are dropped, impaired recordings are all kept.
print(
    "=" * 80,
    "RUNNING ANALYSIS WITH FMA FILTERING",
    "Filter: Exclude healthy recordings with FMA score ≠ 2",
    "=" * 80,
    sep="\n",
)

classifier_fma_filtered = PhysioClassifier(filter_func=fma_filter_healthy_only_perfect)
classifier_fma_filtered.run_complete_analysis()


In [None]:
# Baseline: run the full pipeline on all recordings, with no filter applied.
print("=" * 80, "RUNNING ANALYSIS WITHOUT FILTERING", "=" * 80, sep="\n")

classifier_unfiltered = PhysioClassifier(filter_func=None)
classifier_unfiltered.run_complete_analysis()
