In [1]:
import tensorflow as tf
import numpy as np
import wfdb
import wfdb.processing
from scipy.signal import butter, filtfilt
from collections import Counter
import os

In [2]:
# --- 1. Define Helper Functions & Constants (Copied from your project) ---

# Forward map: WFDB annotation symbol -> AAMI class id.
# Each row below lists one class id and the annotation symbols that map to it.
# Symbols mapped to -1 are non-beat / noise annotations; check_ground_truth
# drops them (and treats any symbol missing from this map the same way).
# NOTE(review): the AAMI EC57 recommendation places paced ('/') and
# paced-fusion ('f') beats in class Q, whereas this table sends them to N and F.
# Confirm this matches the labels the model was trained on before changing it.
AAMI_MAPPING = {
    symbol: aami_class
    for aami_class, symbols in (
        (0, ('N', 'L', 'R', 'e', 'j', '/')),              # Class N (Normal Beats)
        (1, ('A', 'a', 'J', 'S')),                        # Class S (Supraventricular Ectopic Beats)
        (2, ('V', 'E')),                                  # Class V (Ventricular Ectopic Beats)
        (3, ('F', 'f')),                                  # Class F (Fusion Beats)
        (4, ('Q',)),                                      # Class Q (Unclassifiable/Other Beats)
        (-1, ('!', '"', '+', '[', ']', '|', 'x', '~')),   # Non-beat/noise (filtered out)
    )
    for symbol in symbols
}

# Reverse map: class id (list index 0-4) -> human-readable name used in reports.
AAMI_CLASS_NAMES = ['N (Normal)', 'S (Supraventricular)', 'V (Ventricular)', 'F (Fusion)', 'Q (Unclassifiable)']

def bandpass_filter(segment, fs=360, lowcut=0.5, highcut=45, order=4):
    """Zero-phase Butterworth band-pass filter along the first axis.

    Args:
        segment: Array whose axis 0 is time (e.g. ``(n_samples, n_channels)``).
        fs: Sampling rate in Hz.
        lowcut: Lower cutoff frequency in Hz.
        highcut: Upper cutoff frequency in Hz (must be below ``fs / 2``).
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The filtered array, same shape as ``segment``. ``filtfilt`` runs the
        filter forward and backward, so no phase shift is introduced.
    """
    nyquist = fs / 2.0
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, normalized_band, btype='band')
    return filtfilt(numerator, denominator, segment, axis=0)

def global_zscore_normalize(segment, mean, std):
    """Z-score ``segment`` with fixed (pre-computed) statistics.

    ``mean`` and ``std`` are broadcast against ``segment``, so per-channel
    arrays of shape ``(n_channels,)`` work with a ``(n_samples, n_channels)``
    segment. A tiny epsilon keeps a zero std from dividing by zero.
    """
    epsilon = 1e-8
    centered = segment - mean
    return centered / (std + epsilon)

In [3]:
# --- 2. CRITICAL: Set Your Preprocessing Stats ---

# Per-channel (channel 0, channel 1) normalization statistics, presumably
# computed over the training data. They MUST match the values used to train
# the model; copy them from the training notebook whenever the model changes.
GLOBAL_MEAN = np.array([ 0.02254995, -0.00905314])
GLOBAL_STD = np.array([ 0.41518577,  0.33936556])

# Directory holding the WFDB record files. Set the ECG_DB_PATH environment
# variable to run on another machine; otherwise the original local path is used.
DB_PATH = os.environ.get(
    "ECG_DB_PATH",
    r"C:\Me\College\4th\Ai_in_healthcare\Ai_in_healthcare_project\database",
)

RECORD_TO_ANALYZE = "234"
# WFDB record path without extension, as expected by wfdb.rdrecord / rdann.
RECORD_FILE = os.path.join(DB_PATH, RECORD_TO_ANALYZE)

# Model configuration: beat window length in samples, centred on each QRS
# peak (180 samples = 0.5 s at 360 Hz).
WINDOW_SIZE = 180



In [4]:
def _extract_beat_segments(signal, qrs_indices, fs, window_size):
    """Cut, band-pass filter and z-score one window per detected QRS peak.

    Each window is ``signal[idx - half : idx + half]`` with
    ``half = window_size // 2``. Peaks whose window would run past either end
    of the recording are skipped.

    Returns:
        A NumPy array of the processed windows, stacked along axis 0. It is
        empty (``len == 0``) when no complete window exists.
    """
    half = window_size // 2
    n_samples = signal.shape[0]
    segments = []
    for sample_idx in qrs_indices:
        # Cast to a Python int so the "< 0" test cannot wrap on unsigned dtypes.
        start = int(sample_idx) - half
        stop = int(sample_idx) + half
        # The slice end is exclusive, so a window ending exactly at n_samples is
        # still complete. Only windows extending beyond the signal are skipped.
        if start < 0 or stop > n_samples:
            continue
        filtered_seg = bandpass_filter(signal[start:stop], fs=fs)
        # Normalize with the GLOBAL (training) stats, not per-segment stats.
        segments.append(global_zscore_normalize(filtered_seg, GLOBAL_MEAN, GLOBAL_STD))
    return np.array(segments)


def analyze_full_record(record_path, model_path='optimal_ecg_model.keras'):
    """Run the 5-step beat-classification pipeline on one WFDB record.

    Steps:
      1. load the record with ``wfdb.rdrecord``;
      2. detect QRS peaks on channel 0 with ``wfdb.processing.gqrs_detect``;
      3. cut a WINDOW_SIZE-sample window around each peak, band-pass filter it
         and z-score it with GLOBAL_MEAN / GLOBAL_STD;
      4. classify all windows in one batch with the saved Keras model;
      5. print a per-class summary and save it to
         ``final_report_<record>.txt`` in the working directory.

    Args:
        record_path: WFDB record path without extension. Its basename is used
            as the record name in the report and the output filename.
        model_path: Path of the saved Keras model.

    Returns:
        ``(all_predicted_labels, total_beats)``. ``all_predicted_labels`` holds
        one argmax class id per window; the ids index AAMI_CLASS_NAMES.
        ``total_beats`` is the number of windows classified. Returns
        ``(None, None)`` if the record or model cannot be loaded, or if no
        complete window can be extracted.
    """
    record_name = os.path.basename(os.fspath(record_path))
    print(f"--- Starting 5-Step Pipeline for Record: {record_name} ---")

    # --- Step 1: Load the Raw Signal ---
    print(f"\nStep 1: Loading raw signal from {record_path}...")
    try:
        record = wfdb.rdrecord(record_path)
        signal = record.p_signal
        fs = record.fs
        print(f"Signal loaded successfully. Shape: {signal.shape}, Fs: {fs} Hz")
    except Exception as e:
        print(f"Error loading record: {e}")
        return None, None

    # --- Step 2: Find All Heartbeats (QRS Detection) ---
    print("\nStep 2: Finding all heartbeats (QRS detection)...")
    # Detection uses channel 0 only. NOTE(review): this is presumably MLII for
    # this record; confirm via record.sig_name.
    qrs_indices = wfdb.processing.gqrs_detect(sig=signal[:, 0], fs=fs)
    n_detected = len(qrs_indices)
    print(f"Found {n_detected} heartbeats (QRS peaks).")

    # --- Step 3: Create a "Batch" of All Segments ---
    print("\nStep 3: Extracting, filtering, and normalizing all segments...")
    X_new_patient = _extract_beat_segments(signal, qrs_indices, fs, WINDOW_SIZE)
    if len(X_new_patient) == 0:
        print("No valid segments could be extracted from the record.")
        return None, None
    print(f"Created batch of {X_new_patient.shape[0]} valid segments.")

    # --- Step 4: Run Batch Prediction ---
    print("\nStep 4: Loading model and running batch prediction...")
    try:
        model = tf.keras.models.load_model(model_path)
        print(f"Model '{model_path}' loaded.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None

    all_predictions_proba = model.predict(X_new_patient)
    # One class id per window (highest-probability output).
    all_predicted_labels = np.argmax(all_predictions_proba, axis=1)
    print(f"Prediction complete. Got {len(all_predicted_labels)} labels.")

    # --- Step 5: Compute the Final Result (Aggregation) ---
    print("\nStep 5: Generating final arrhythmia report...")
    label_counts = Counter(all_predicted_labels)
    # Always > 0 here: an empty batch returned early in Step 3.
    total_beats = len(all_predicted_labels)

    lines = [
        "",
        "=" * 40,
        "     MODEL PREDICTION REPORT",
        f"     Record: {record_name}",
        "=" * 40,
        f"Total Beats Detected (by gqrs_detect): {n_detected}",
        f"Beats Analyzed (complete windows): {total_beats}",
        "",
    ]
    for class_id, class_name in enumerate(AAMI_CLASS_NAMES):
        count = label_counts.get(class_id, 0)
        percent = (count / total_beats) * 100
        lines.append(f"  {class_name:<25}: {count:<6} beats ({percent:.2f}%)")
    lines.append("=" * 40)
    report_string = "\n".join(lines)

    # Print the entire report at once (flush so it is not interleaved).
    print(report_string, flush=True)

    # Also save it, because notebook output may be truncated.
    report_filename = f"final_report_{record_name}.txt"
    try:
        with open(report_filename, "w", encoding="utf-8") as f:
            f.write(report_string)
        print(f"\nSUCCESS: Full report also saved to {report_filename}")
        print("Please open this file to see the untruncated output.")
    except Exception as e:
        print(f"\nWarning: Could not save report file. {e}")

    # The full label sequence (not just counts) is needed for pattern analysis.
    return all_predicted_labels, total_beats

In [5]:
# --- 4. Ground-truth class distribution from the reference annotations ---
def check_ground_truth(record_path):
    """Build the ground-truth AAMI class distribution from a record's .atr file.

    Reads the reference annotations with ``wfdb.rdann(record_path, 'atr')``.
    Each symbol goes through AAMI_MAPPING; non-beat symbols and unknown symbols
    map to -1 and are dropped. A per-class count/percentage report is then
    written to ``ground_truth_report_<record>.txt``. Only a banner and the
    save status are printed; the report body goes to the file, not the console.

    Note: these counts come from the annotation file, while the model report
    counts gqrs-detected windows, so the two totals need not match exactly.

    Args:
        record_path: WFDB record path without extension. Its basename is used
            as the record name in the report and the output filename.

    Returns:
        The list of ground-truth class ids (0-4), one per valid beat
        annotation, or ``None`` if the annotation file could not be read.
    """
    record_name = os.path.basename(os.fspath(record_path))

    print("\n" + "=" * 40)
    print("     GENERATING GROUND TRUTH REPORT")
    print(f"     Record: {record_name}")
    print("=" * 40)

    # 1. Load the ground-truth annotations.
    try:
        annotation = wfdb.rdann(record_path, 'atr')
    except Exception as e:
        print(f"Error loading ground truth annotations: {e}")
        return None
    raw_symbols = annotation.symbol

    # 2. Map symbols to AAMI ids, dropping non-beat / unknown symbols (-1).
    true_labels_int = [
        mapped_class
        for mapped_class in (AAMI_MAPPING.get(s, -1) for s in raw_symbols)
        if mapped_class != -1
    ]

    # 3. Count the mapped labels.
    label_counts = Counter(true_labels_int)
    total_beats = len(true_labels_int)        # valid beats only
    total_raw_annotations = len(raw_symbols)  # every annotation in the file

    # 4. Build the report. The header is kept even when no beats are found.
    lines = [
        "",
        "=" * 40,
        "     GROUND TRUTH REPORT",
        f"     Record: {record_name}",
        "=" * 40,
        f"Total Valid Beats (from .atr file): {total_beats}",
        f"(Total Raw Annotations in file: {total_raw_annotations})",
        "",
    ]
    if total_beats > 0:
        for class_id, class_name in enumerate(AAMI_CLASS_NAMES):
            count = label_counts.get(class_id, 0)
            percent = (count / total_beats) * 100
            lines.append(f"  {class_name:<25}: {count:<6} beats ({percent:.2f}%)")
    else:
        lines.append("  No valid beats found in .atr file.")
    lines.append("=" * 40)
    report_string = "\n".join(lines)

    # 5. Save the report to a text file (not printed, to avoid truncated output).
    report_filename = f"ground_truth_report_{record_name}.txt"
    try:
        with open(report_filename, "w", encoding="utf-8") as f:
            f.write(report_string)
        print(f"\nSUCCESS: Ground truth report saved to {report_filename}")
    except Exception as e:
        print(f"\nWarning: Could not save report file. {e}")

    return true_labels_int

In [6]:
# --- 5. Rule-based diagnostic report from the sequence of predicted labels ---
def generate_diagnostic_report(all_predicted_labels, total_beats, record_name=None,
                               v_burden_threshold=5.0, s_burden_threshold=5.0):
    """Turn per-beat predictions into a rule-based "Yes/No" diagnostic report.

    Besides overall class counts, the *order* of the predictions is scanned
    for runs of consecutive identical labels:

    * a run of exactly 2 'V' (class 2) or 'S' (class 1) beats counts as a couplet;
    * a run of 3 or more flags ventricular / supraventricular tachycardia.

    The result is "YES" if either tachycardia pattern is found, or if the V or
    S share of all beats exceeds its burden threshold. The thresholds are
    illustrative defaults, not clinically validated values.

    The report is printed and saved to ``diagnostic_report_<record>.txt``.

    Args:
        all_predicted_labels: Class ids (0=N, 1=S, 2=V, 3=F, 4=Q) in beat
            order, or ``None`` if the pipeline failed.
        total_beats: Number of classified beats (the percentage denominator),
            or ``None``/0 if the pipeline failed.
        record_name: Record id for the report header and filename. Defaults to
            the notebook-level ``RECORD_TO_ANALYZE``.
        v_burden_threshold: Percent of V beats above which the result is
            flagged as significant.
        s_burden_threshold: Percent of S beats above which the result is
            flagged as significant.

    Returns:
        The full report text.
    """
    from itertools import groupby

    if record_name is None:
        record_name = RECORD_TO_ANALYZE

    # Build the whole report first; print and save only at the end.
    parts = [
        "\n" + "=" * 50 + "\n",
        "     FINAL DIAGNOSTIC REPORT\n",
        f"     Record: {record_name}\n",
        "=" * 50 + "\n",
    ]

    if not total_beats or all_predicted_labels is None:
        parts.append("Diagnosis: Inconclusive (No beats were detected or processed).\n")
        parts.append("This may be due to an error loading the model or the record.")
    else:
        label_counts = Counter(all_predicted_labels)

        # --- 1. Scan runs of identical consecutive labels for V/S patterns ---
        # groupby yields each maximal run once, including a run at the tail.
        v_couplets = s_couplets = 0
        v_tach_found = s_tach_found = False
        for label, run in groupby(all_predicted_labels):
            run_length = sum(1 for _ in run)
            if label == 2:  # Class 'V'
                if run_length == 2:
                    v_couplets += 1
                elif run_length >= 3:
                    v_tach_found = True
            elif label == 1:  # Class 'S'
                if run_length == 2:
                    s_couplets += 1
                elif run_length >= 3:
                    s_tach_found = True

        # --- 2. Overall counts ---
        count_s = label_counts.get(1, 0)  # Supraventricular
        count_v = label_counts.get(2, 0)  # Ventricular
        count_f = label_counts.get(3, 0)  # Fusion
        count_q = label_counts.get(4, 0)  # Unclassifiable (not counted as arrhythmia)
        total_arrhythmia_beats = count_s + count_v + count_f
        percent_arrhythmia = (total_arrhythmia_beats / total_beats) * 100
        percent_v = (count_v / total_beats) * 100
        percent_s = (count_s / total_beats) * 100

        # --- 3. Yes/No detection: most severe finding wins ---
        is_significant = True
        if v_tach_found:
            parts.append("Heart Disease Detection: YES (CRITICAL: Ventricular Tachycardia detected)\n")
        elif s_tach_found:
            parts.append("Heart Disease Detection: YES (SIGNIFICANT: Supraventricular Tachycardia detected)\n")
        elif percent_v > v_burden_threshold:
            parts.append(f"Heart Disease Detection: YES (Significant Ventricular Burden: {percent_v:.2f}%)\n")
        elif percent_s > s_burden_threshold:
            parts.append(f"Heart Disease Detection: YES (Significant Supraventricular Burden: {percent_s:.2f}%)\n")
        else:
            is_significant = False
            if total_arrhythmia_beats > 0:
                parts.append("Heart Disease Detection: NO (Benign/Occasional Arrhythmia Detected)\n")
            else:
                parts.append("Heart Disease Detection: NO (Normal Sinus Rhythm Dominant)\n")

        # --- 4. Diagnostic summary ---
        parts.append("\nDiagnostic Summary:\n")
        # "All Normal" is only true if there are no Q beats either.
        if total_arrhythmia_beats == 0 and count_q == 0:
            parts.append(f"  - All {total_beats} detected beats were classified as Normal.\n")
            parts.append("  - No arrhythmias were detected in this recording.\n")
        else:
            # Ventricular findings (most serious first).
            if v_tach_found:
                parts.append("  - !!! CRITICAL FINDING: Detected Ventricular Tachycardia (3+ 'V' beats in a row).\n")
            elif v_couplets > 0:
                parts.append(f"  - Detected {v_couplets} Ventricular Couplet(s) (2 'V' beats in a row).\n")
            elif count_v > 0:
                # No runs of length >= 2 exist, so every V beat is isolated.
                parts.append(f"  - Detected {count_v} isolated Ventricular beat(s) (PVCs).\n")

            # Supraventricular findings.
            if s_tach_found:
                parts.append("  - SIGNIFICANT FINDING: Detected Supraventricular Tachycardia (3+ 'S' beats in a row).\n")
            elif s_couplets > 0:
                parts.append(f"  - Detected {s_couplets} Supraventricular Couplet(s) (2 'S' beats in a row).\n")
            elif count_s > 0:
                parts.append(f"  - Detected {count_s} isolated Supraventricular beat(s) (APBs).\n")

            if count_f > 0:
                parts.append(f"  - Detected {count_f} Fusion beat(s).\n")

            if count_q > 0:
                parts.append(f"  - {count_q} beat(s) could not be classified (Q).\n")

            if not is_significant and total_arrhythmia_beats > 0:
                parts.append(f"\n  - Clinical Note: Occasional arrhythmias were detected ({percent_arrhythmia:.2f}%), \n")
                parts.append(
                    "    but no dangerous patterns or significant burden "
                    f"(>{v_burden_threshold:g}% V, >{s_burden_threshold:g}% S) was found.\n"
                )

    parts.append("\n" + "=" * 50)
    report_string = "".join(parts)

    print(report_string, flush=True)

    report_filename = f"diagnostic_report_{record_name}.txt"
    try:
        with open(report_filename, "w", encoding="utf-8") as f:
            f.write(report_string)
        print(f"\nSUCCESS: Diagnostic report also saved to {report_filename}")
    except Exception as e:
        print(f"\nWarning: Could not save diagnostic report file: {e}")

    return report_string

In [7]:
if __name__ == "__main__":
    # Refuse to run with unusable normalization stats: mismatched mean/std
    # shapes, non-finite values, or a non-positive std. Any of these suggests
    # GLOBAL_MEAN / GLOBAL_STD were not copied correctly from training.
    stats_mean = np.asarray(GLOBAL_MEAN, dtype=float)
    stats_std = np.asarray(GLOBAL_STD, dtype=float)
    stats_ok = (
        stats_mean.shape == stats_std.shape
        and stats_mean.size > 0
        and bool(np.all(np.isfinite(stats_mean)))
        and bool(np.all(np.isfinite(stats_std)))
        and bool(np.all(stats_std > 0))
    )
    if not stats_ok:
        print("="*50)
        print("ERROR: Please update GLOBAL_MEAN and GLOBAL_STD in this script!")
        print("Copy the values from Cell 5 of your 'dataset_access_FIXED.ipynb' notebook.")
        print("="*50)
    else:
        all_pred_labels, total_beats = analyze_full_record(RECORD_FILE)
        check_ground_truth(RECORD_FILE)
        # Handles (None, None) from a failed analysis by reporting "Inconclusive".
        generate_diagnostic_report(all_pred_labels, total_beats)

--- Starting 5-Step Pipeline for Record: 234 ---

Step 1: Loading raw signal from C:\Me\College\4th\Ai_in_healthcare\Ai_in_healthcare_project\database\234...
Signal loaded successfully. Shape: (650000, 2), Fs: 360 Hz

Step 2: Finding all heartbeats (QRS detection)...
Found 2753 heartbeats (QRS peaks).

Step 3: Extracting, filtering, and normalizing all segments...
Created batch of 2753 valid segments.

Step 4: Loading model and running batch prediction...
Model 'optimal_ecg_model.keras' loaded.
[1m87/87[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 23ms/step
Prediction complete. Got 2753 labels.

Step 5: Generating final arrhythmia report...

     MODEL PREDICTION REPORT
     Record: 234
Total Beats Detected (by gqrs_detect): 2753

  N (Normal)               : 2698   beats (98.00%)
  S (Supraventricular)     : 52     beats (1.89%)
  V (Ventricular)          : 3      beats (0.11%)
  F (Fusion)               : 0      beats (0.00%)
  Q (Unclassifiable)       : 0      beats (0.00%