In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
from fitparse import FitFile
import neurokit2 as nk
#import tensorflow as tf
import pickle

In [13]:
class FitEKGAnalyzerWithML(object):
    """
    A class to analyze EKG data from FIT files using pre-trained ML models
    to classify heartbeats and detect abnormalities.
    """
    
    def __init__(self, file_path):
        """
        Initialize the analyzer with a FIT file path.
        
        Parameters:
        -----------
        file_path : str
            Path to the FIT file
        """
        self.file_path = file_path
        self.raw_data = None
        self.ecg_signal = None
        self.sampling_rate = None
        self.r_peaks = None
        self.heartbeats = None
        self.beat_classifications = None
        self.metrics = {}
        
        # ML model attributes
        self.model = None
        self.model_sampling_rate = None
        self.beat_length = None
        self.class_mapping = None
        
    def load_fit_file(self):
        """Load and parse the FIT file"""
        try:
            fit_file = FitFile(self.file_path)
            
            # Extract EKG data - field names may vary by device
            ecg_data = []
            timestamps = []
            
            for record in fit_file.get_messages('record'):
                # Look for ECG data fields
                for field in record:
                    if 'ecg' in field.name.lower() or 'ekg' in field.name.lower():
                        ecg_data.append(field.value)
                    if field.name == 'timestamp':
                        timestamps.append(field.value)
            
            # If no explicit ECG data is found, try to use other fields
            if not ecg_data:
                for record in fit_file.get_messages('record'):
                    # Some devices store ECG under different names
                    for field in record:
                        if field.name in ['heart_rate_raw', 'heart_waveform']:
                            ecg_data.append(field.value)
                        if field.name == 'timestamp':
                            timestamps.append(field.value)
            
            if not ecg_data:
                raise ValueError("No ECG/EKG data found in the FIT file")
                
            # Estimate sampling rate from timestamps if available
            if len(timestamps) > 1:
                time_diff = (timestamps[-1] - timestamps[0]).total_seconds()
                self.sampling_rate = len(ecg_data) / time_diff
            else:
                # Default to a common ECG sampling rate if can't be determined
                self.sampling_rate = 250  # Hz
            
            self.raw_data = pd.DataFrame({
                'timestamp': timestamps if len(timestamps) == len(ecg_data) else range(len(ecg_data)),
                'ecg': ecg_data
            })
            
            self.ecg_signal = np.array(ecg_data)
            
            print(f"Loaded ECG data with {len(ecg_data)} samples at {self.sampling_rate:.2f} Hz")
            return True
            
        except Exception as e:
            print(f"Error loading FIT file: {e}")
            return False
    
    def preprocess_ecg(self):
        """Preprocess the ECG signal by removing noise and baseline wander"""
        if self.ecg_signal is None:
            print("No ECG data loaded. Please load a FIT file first.")
            return False
        
        # Apply bandpass filter to remove noise
        # ECG typically has frequency components between 0.5 and 40 Hz
        self.ecg_signal = nk.ecg_clean(self.ecg_signal, sampling_rate=self.sampling_rate)
        return True
    
    def detect_r_peaks(self):
        """Detect R-peaks in the ECG signal"""
        if self.ecg_signal is None:
            print("No ECG data loaded. Please load a FIT file first.")
            return False
        
        # Use neurokit2 for R-peak detection
        _, info = nk.ecg_peaks(self.ecg_signal, sampling_rate=self.sampling_rate)
        self.r_peaks = info['ECG_R_Peaks']
        
        print(f"Detected {len(self.r_peaks)} R-peaks")
        return True
    
    def segment_heartbeats(self):
        """Segment the ECG signal into individual heartbeats"""
        if self.r_peaks is None:
            print("No R-peaks detected. Please run detect_r_peaks first.")
            return False
        
        # Segment heartbeats around R-peaks
        # Typical heartbeat is ~200ms before R-peak and ~400ms after
        before = int(0.2 * self.sampling_rate)
        after = int(0.4 * self.sampling_rate)
        
        self.heartbeats = []
        self.heartbeat_indices = []  # Store indices for later reference
        
        for r_peak in self.r_peaks:
            if r_peak - before >= 0 and r_peak + after < len(self.ecg_signal):
                beat = self.ecg_signal[r_peak - before : r_peak + after]
                self.heartbeats.append(beat)
                self.heartbeat_indices.append((r_peak - before, r_peak + after))
        
        print(f"Segmented {len(self.heartbeats)} heartbeats")
        return True
    
    def load_model(self, model_path='ecg_model', scaler_path=None):
        """
        Load a pre-trained ECG classification model
        
        Parameters:
        -----------
        model_path : str
            Path to the saved model
        scaler_path : str, optional
            Path to the saved scaler for feature normalization
        """
        try:
            # Load the TensorFlow/Keras model
            self.model = tf.keras.models.load_model(model_path)
            
            # Load the scaler if provided
            self.scaler = None
            if scaler_path and os.path.exists(scaler_path):
                with open(scaler_path, 'rb') as f:
                    self.scaler = pickle.load(f)
            
            # Get model configuration
            # This assumes the model expects a specific input shape
            input_shape = self.model.layers[0].input_shape
            
            if isinstance(input_shape, list):
                # If model has multiple inputs
                self.beat_length = input_shape[0][1]
            else:
                # Single input model
                self.beat_length = input_shape[1]
            
            # Default to 250Hz if model was trained on that rate
            # This is common for PhysioNet models
            self.model_sampling_rate = 250
            
            # Define class mapping based on PhysioNet 2020 Challenge
            self.class_mapping = {
                0: 'Normal',
                1: 'AF',       # Atrial Fibrillation
                2: 'IAVB',     # First-degree AV block
                3: 'LBBB',     # Left bundle branch block
                4: 'RBBB',     # Right bundle branch block
                5: 'PAC',      # Premature atrial contraction
                6: 'PVC',      # Premature ventricular contraction
                7: 'STD',      # ST-segment depression
                8: 'STE',      # ST-segment elevation
            }
            
            print(f"Loaded model expecting input length: {self.beat_length}")
            return True
            
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Since the model couldn't be loaded, will use simplified classification.")
            return False
    
    def preprocess_beat_for_model(self, beat):
        """
        Preprocess a single heartbeat for model input
        
        Parameters:
        -----------
        beat : numpy.ndarray
            Raw heartbeat signal
            
        Returns:
        --------
        numpy.ndarray
            Preprocessed beat ready for model input
        """
        # Resample to expected input length
        if len(beat) != self.beat_length:
            beat = signal.resample(beat, self.beat_length)
        
        # Normalize
        beat = (beat - np.mean(beat)) / (np.std(beat) + 1e-6)  # Add small epsilon to avoid div by zero
        
        # Apply scaler if available
        if self.scaler is not None:
            beat = self.scaler.transform(beat.reshape(1, -1)).reshape(-1)
        
        return beat
    
    def extract_beat_features(self, beat):
        """
        Extract additional features from a heartbeat for models that need them
        
        Parameters:
        -----------
        beat : numpy.ndarray
            Preprocessed heartbeat signal
            
        Returns:
        --------
        numpy.ndarray
            Feature vector
        """
        # Example features (customize based on your model's requirements)
        features = []
        
        # Statistical features
        features.append(np.mean(beat))
        features.append(np.std(beat))
        features.append(np.max(beat))
        features.append(np.min(beat))
        
        # Spectral features
        freqs, psd = signal.welch(beat, fs=self.model_sampling_rate)
        
        # Get power in relevant frequency bands
        # P wave: 0.5-3 Hz
        # QRS complex: 3-40 Hz
        # T wave: 0.5-7 Hz
        p_power = np.sum(psd[(freqs >= 0.5) & (freqs <= 3)])
        qrs_power = np.sum(psd[(freqs >= 3) & (freqs <= 40)])
        t_power = np.sum(psd[(freqs >= 0.5) & (freqs <= 7)])
        
        features.extend([p_power, qrs_power, t_power])
        
        return np.array(features)
    
    def classify_beats_with_ml(self):
        """Classify beats using the pre-trained ML model"""
        if self.heartbeats is None or len(self.heartbeats) == 0:
            print("No heartbeats segmented. Please run segment_heartbeats first.")
            return False
        
        if self.model is None:
            print("No model loaded. Using simplified classification method instead.")
            return self.classify_beats_simple()
        
        # Preprocess beats to match model input requirements
        X = []
        for beat in self.heartbeats:
            processed_beat = self.preprocess_beat_for_model(beat)
            X.append(processed_beat)
        
        # Convert to appropriate array format for model
        # Adjust shape based on your model's expected input
        X = np.array(X).reshape(-1, self.beat_length, 1)
        
        # Make predictions
        try:
            predictions = self.model.predict(X)
            
            # If model outputs probabilities for multiple classes
            if predictions.shape[1] > 1:
                # Get the class with highest probability
                predicted_classes = np.argmax(predictions, axis=1)
            else:
                # Binary classification
                predicted_classes = (predictions > 0.5).astype(int).flatten()
            
            # Map to class names
            self.beat_classifications = [
                self.class_mapping.get(pred_class, 'Unknown') 
                for pred_class in predicted_classes
            ]
            
            # Calculate statistics
            self.calculate_classification_stats()
            
            return True
            
        except Exception as e:
            print(f"Error making predictions: {e}")
            print("Falling back to simplified classification method.")
            return self.classify_beats_simple()
    
    def classify_beats_simple(self):
        """
        Simplified beat classification when no model is available
        This is a backup method
        """
        if self.heartbeats is None:
            print("No heartbeats segmented. Please run segment_heartbeats first.")
            return False
        
        # Calculate RR intervals
        rr_intervals = np.diff(self.r_peaks) / self.sampling_rate * 1000  # in ms
        
        # Simple classification
        classifications = []
        
        # First beat
        classifications.append('Normal')
        
        # For remaining beats
        for i in range(1, len(self.heartbeats)):
            beat = self.heartbeats[i]
            rr = rr_intervals[i-1] if i-1 < len(rr_intervals) else None
            
            # Simple rules based on RR interval and beat morphology
            if rr is not None and rr < 600:  # Short RR interval
                if np.max(beat) - np.min(beat) > 1.5 * np.mean([b.max() - b.min() for b in self.heartbeats]):
                    classifications.append('PVC')  # Premature Ventricular Contraction
                else:
                    classifications.append('PAC')  # Premature Atrial Contraction
            else:
                classifications.append('Normal')
        
        self.beat_classifications = classifications
        
        # Calculate statistics
        self.calculate_classification_stats()
        
        return True
    
    def calculate_classification_stats(self):
        """Calculate statistics for beat classifications"""
        if self.beat_classifications is None:
            return
        
        total_beats = len(self.beat_classifications)
        unique_classifications = set(self.beat_classifications)
        
        stats = {}
        for classification in unique_classifications:
            count = self.beat_classifications.count(classification)
            percentage = (count / total_beats) * 100
            stats[classification] = {
                'count': count,
                'percentage': percentage
            }
        
        self.metrics['beat_classifications'] = stats
        self.metrics['total_beats'] = total_beats
        
        # Calculate percentage of abnormal beats (non-Normal)
        normal_percentage = stats.get('Normal', {'percentage': 0})['percentage']
        self.metrics['abnormal_percentage'] = 100 - normal_percentage
    
    def calculate_hrv_metrics(self):
        """Calculate Heart Rate Variability metrics"""
        if self.r_peaks is None:
            print("No R-peaks detected. Please run detect_r_peaks first.")
            return False
        
        # Calculate RR intervals in seconds
        rr_intervals = np.diff(self.r_peaks) / self.sampling_rate
        
        # Time domain HRV metrics
        self.metrics['hrv'] = {}
        self.metrics['hrv']['mean_hr'] = 60 / np.mean(rr_intervals)
        self.metrics['hrv']['sdnn'] = np.std(rr_intervals) * 1000  # in ms
        self.metrics['hrv']['rmssd'] = np.sqrt(np.mean(np.square(np.diff(rr_intervals)))) * 1000  # in ms
        
        # Calculate pNN50 (percentage of successive RR intervals that differ by more than 50 ms)
        nn50 = sum(abs(np.diff(rr_intervals)) > 0.05)  # 0.05s = 50ms
        self.metrics['hrv']['pnn50'] = (nn50 / len(rr_intervals)) * 100 if len(rr_intervals) > 0 else 0
        
        return True
    
    def run_full_analysis(self, model_path=None, scaler_path=None):
        """
        Run the complete analysis pipeline
        
        Parameters:
        -----------
        model_path : str, optional
            Path to the pre-trained model
        scaler_path : str, optional
            Path to the feature scaler
        """
        if not self.load_fit_file():
            return False
        
        if not self.preprocess_ecg():
            return False
        
        if not self.detect_r_peaks():
            return False
        
        if not self.segment_heartbeats():
            return False
        
        # Try to load model if path provided
        if model_path:
            self.load_model(model_path, scaler_path)
        
        # Use ML classification if model loaded, otherwise use simple method
        if self.model:
            if not self.classify_beats_with_ml():
                return False
        else:
            if not self.classify_beats_simple():
                return False
        
        if not self.calculate_hrv_metrics():
            return False
        
        return True
    
    def generate_report(self):
        """Generate a report with the analysis results"""
        if not self.metrics:
            print("No analysis results available. Please run the analysis first.")
            return
        
        print("\n===== EKG Analysis Report =====")
        
        # Basic information
        print(f"\nFile: {os.path.basename(self.file_path)}")
        print(f"Total Duration: {len(self.ecg_signal)/self.sampling_rate:.2f} seconds")
        print(f"Total Beats Analyzed: {self.metrics['total_beats']}")
        
        # Beat classifications
        print("\nBeat Classifications:")
        for classification, data in self.metrics['beat_classifications'].items():
            print(f"  {classification}: {data['count']} beats ({data['percentage']:.2f}%)")
        
        print(f"\nAbnormal Beats: {self.metrics['abnormal_percentage']:.2f}%")
        
        # HRV metrics
        print("\nHeart Rate Variability Metrics:")
        print(f"  Mean Heart Rate: {self.metrics['hrv']['mean_hr']:.2f} bpm")
        print(f"  SDNN: {self.metrics['hrv']['sdnn']:.2f} ms")
        print(f"  RMSSD: {self.metrics['hrv']['rmssd']:.2f} ms")
        print(f"  pNN50: {self.metrics['hrv']['pnn50']:.2f}%")
    
    def plot_ecg_with_classifications(self, save_path=None):
        """Plot the ECG signal with beat classifications"""
        if self.ecg_signal is None or self.r_peaks is None or self.beat_classifications is None:
            print("Missing data for plotting. Please run the full analysis first.")
            return
        
        plt.figure(figsize=(15, 8))
        
        # Plot the ECG signal
        time = np.arange(len(self.ecg_signal)) / self.sampling_rate
        plt.plot(time, self.ecg_signal, 'b-', alpha=0.5, label='ECG Signal')
        
        # Plot R-peaks with classification colors
        colors = {
            'Normal': 'green',
            'PVC': 'red',
            'PAC': 'orange',
            'LBBB': 'purple',
            'RBBB': 'brown',
            'AF': 'magenta',
            'IAVB': 'cyan',
            'STD': 'yellow',
            'STE': 'pink',
            'Unknown': 'gray'
        }
        
        for i, r_peak in enumerate(self.r_peaks):
            if i < len(self.beat_classifications):
                beat_class = self.beat_classifications[i]
                color = colors.get(beat_class, 'blue')
                plt.plot(r_peak/self.sampling_rate, self.ecg_signal[r_peak], 'o', 
                         color=color, markersize=8)
        
        # Create legend
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 
                          markerfacecolor=color, markersize=8, label=classification)
                          for classification, color in colors.items() 
                          if classification in self.beat_classifications]
        
        plt.legend(handles=legend_elements, loc='upper right')
        plt.title('ECG Signal with Beat Classifications')
        plt.xlabel('Time (seconds)')
        plt.ylabel('Amplitude')
        plt.grid(True, alpha=0.3)
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Plot saved to {save_path}")
        
        plt.show()



In [14]:
analyzer = FitEKGAnalyzerWithML('/Users/emccullough/Downloads/Activity_on_20240813_052850_by_Etienne_5010176_94deb87f66d7_FITFILE.fit')

In [15]:
# Function to download and prepare a PhysioNet model
def download_physionet_model(output_dir='ecg_model'):
    """
    Download and prepare a pre-trained PhysioNet ECG model
    
    This is a simplified example that shows how to access a model
    from PhysioNet's 2020 Challenge. In practice, you would need
    to adapt this to the specific model you want to use.
    
    Parameters:
    -----------
    output_dir : str
        Directory to save the model
    """
    import requests
    import zipfile
    import io
    import json
    
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    print("This function would normally download a model from PhysioNet.")
    print("For demonstration purposes, we'll create a placeholder model structure.")
    
    # In a real implementation, you would download the model files
    # Example:
    # url = "https://physionet.org/files/challenge-2020/1.0.1/models/team_name_model.zip"
    # response = requests.get(url)
    # with zipfile.ZipFile(io.BytesIO(response.content)) as z:
    #     z.extractall(output_dir)
    
    # Instead, let's create a simple Keras model as a placeholder
    # This is just for demonstration - use a real pre-trained model in practice
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(250, 1)),
        tf.keras.layers.Conv1D(filters=32, kernel_size=5, activation='relu'),
        tf.keras.layers.MaxPooling1D(pool_size=2),
        tf.keras.layers.Conv1D(filters=64, kernel_size=5, activation='relu'),
        tf.keras.layers.MaxPooling1D(pool_size=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(9, activation='softmax')  # 9 classes from PhysioNet 2020
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    # Save the model
    model.save(output_dir)
    
    # Save model metadata
    metadata = {
        'name': 'PhysioNet_2020_ECG_Classifier',
        'description': 'Placeholder model for ECG classification',
        'input_shape': [250, 1],
        'output_shape': 9,
        'classes': {
            '0': 'Normal',
            '1': 'AF',
            '2': 'IAVB',
            '3': 'LBBB',
            '4': 'RBBB',
            '5': 'PAC',
            '6': 'PVC',
            '7': 'STD',
            '8': 'STE'
        }
    }
    
    with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=4)
    
    print(f"Created placeholder model in {output_dir}")
    print("Note: This is not a trained model and won't make useful predictions.")
    print("In practice, you should download a real pre-trained model.")
    
    return os.path.abspath(output_dir)


def main():
    """
    Main function to demonstrate the EKG Analyzer with ML model
    """
    print("EKG Analysis with ML-based Classification")
    print("========================================")
    
    # Option to download a model
    download_model = input("Do you want to download/setup a pre-trained model? (y/n): ")
    model_path = None
    
    if download_model.lower() == 'y':
        model_path = download_physionet_model()
    else:
        model_path = input("Enter path to existing model (leave blank to use simple classification): ")
        if not model_path or not os.path.exists(model_path):
            model_path = None
    
    # Get file path from user
    file_path = input("Enter the path to your FIT file: ")
    
    # Create analyzer and run analysis
    analyzer = FitEKGAnalyzerWithML(file_path)
    if analyzer.run_full_analysis(model_path=model_path):
        analyzer.generate_report()
        
        # Ask if user wants to save a plot
        save_plot = input("Do you want to save a plot of the ECG with classifications? (y/n): ")
        if save_plot.lower() == 'y':
            plot_path = input("Enter the path to save the plot (or press Enter for default): ")
            if not plot_path:
                plot_path = os.path.splitext(file_path)[0] + "_ecg_analysis.png"
            analyzer.plot_ecg_with_classifications(save_path=plot_path)
        else:
            analyzer.plot_ecg_with_classifications()
    else:
        print("Analysis failed. Please check the file and try again.")


if __name__ == "__main__":
    main()

EKG Analysis with ML-based Classification


Do you want to download/setup a pre-trained model? (y/n):  y


This function would normally download a model from PhysioNet.
For demonstration purposes, we'll create a placeholder model structure.


NameError: name 'tf' is not defined