# ECG Signal Event Detection, Inference & Classification
**Prerequisite:** train the GAE model with the training notebook before running this notebook.

In [None]:
# 📦 Import Required Modules

# Standard library
import os

# Third-party libraries
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, classification_report
from sklearn.preprocessing import label_binarize
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, clear_output

# Import from src package
from src.models import EnhancedJointECGModel
from src.data_loader import get_dataloaders, ECGDataset
from src.graph_utils import build_knn_graph
from src.inference import predict, evaluate_model

# Seed the PyTorch RNG, then the NumPy RNG, so randomly initialised weights and
# any synthetic data are identical from run to run.
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(42)

# Short environment summary (one line per item).
print("\n".join([
    "‚úÖ All modules imported successfully!",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
]))

## Configuration & Data Loading

In [None]:
# Configuration
INPUT_LENGTH = 187   # samples per segment fed to the model
LATENT_DIM = 128
NUM_CLASSES = 5
BATCH_SIZE = 512
K_NEIGHBORS = 5      # neighbours per node in the k-NN graph
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# File paths
MODEL_PATH = 'models/best_model.pth'  # Trained model
TEST_DATA_PATH = 'examples/sample_ecgs.csv'  # Test data (change to your real data)

print(f"Using device: {DEVICE}")
print(f"Model path: {MODEL_PATH}")
print(f"Test data: {TEST_DATA_PATH}")

# Load test data using src.data_loader
try:
    # For demo, we'll use the same data for train/val/test
    # In practice, use separate test file
    _, _, test_loader = get_dataloaders(
        train_path=TEST_DATA_PATH,
        test_path=TEST_DATA_PATH,
        batch_size=BATCH_SIZE,
        val_split=0.0,  # No validation for inference
        test_split=1.0   # All data as test
    )

    print("‚úÖ Test data loaded successfully!")
    print(f"Test batches: {len(test_loader)}")

    # Get full dataset for detailed analysis
    test_dataset = ECGDataset(TEST_DATA_PATH, transform=None)
    X_test_full = test_dataset.data.values
    y_test_full = test_dataset.labels.values

    print(f"Full test set: {len(X_test_full)} samples")
    print(f"Classes: {sorted(np.unique(y_test_full))}")

except FileNotFoundError:
    print("‚ùå Test data file not found. Please update TEST_DATA_PATH")
    print("For demo purposes, we'll create synthetic data")

    # There is no DataLoader for synthetic data. Define the name explicitly as
    # None so later cells can test for it instead of hitting a NameError.
    test_loader = None

    # Create synthetic test data for demonstration, shaped by the config above
    # rather than by hard-coded literals.
    np.random.seed(42)
    X_test_full = np.random.randn(1000, INPUT_LENGTH).astype(np.float32)
    y_test_full = np.random.randint(0, NUM_CLASSES, 1000)

    print("‚úÖ Synthetic test data created for demonstration")

## Load Trained Model

In [None]:
# Initialize model using src.models
model = EnhancedJointECGModel(
    input_length=INPUT_LENGTH,
    latent_dim=LATENT_DIM,
    num_classes=NUM_CLASSES
).to(DEVICE)

# Load trained weights
try:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids executing arbitrary pickled code.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    print("‚úÖ Model loaded successfully!")
except FileNotFoundError:
    print(f"‚ùå Model file not found at {MODEL_PATH}")
    print("Please train a model first or update MODEL_PATH")
    print("For demo, we'll use randomly initialized model")

# Switch to inference mode on BOTH paths. Previously the random-init fallback
# stayed in training mode, so any dropout/batch-norm layers would behave
# non-deterministically.
model.eval()
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Event Detection Functions

These functions approximate event detection with simple amplitude thresholding, because no trained U-Net detector is available here.
In production, replace them with your trained event-detection model.

In [None]:
def simple_event_detection(signals, window_size=40, threshold_factor=0.5):
    """
    Flag samples that rise above a local, rolling amplitude threshold.

    For each signal, a centred rolling window gives a local mean and standard
    deviation. A sample is marked as an event when it exceeds
    ``local_mean + threshold_factor * local_std``.

    Args:
        signals: Array of ECG signals (n_samples, n_timesteps)
        window_size: Size of the centred sliding window for local statistics
        threshold_factor: Multiplier for threshold (mean + factor*std)

    Returns:
        masks: Integer array of 0/1 masks (one row per signal); 1 = event sample.
    """
    masks = []

    for signal in signals:
        series = pd.Series(signal)
        rolling = series.rolling(window=window_size, center=True)

        # Incomplete windows at the edges yield NaN. Borrow the nearest
        # full-window value. .bfill()/.ffill() replace fillna(method=...),
        # which is deprecated since pandas 2.1.
        local_mean = rolling.mean().bfill().ffill()
        local_std = rolling.std().bfill().ffill()

        threshold = local_mean + threshold_factor * local_std

        # Comparisons against NaN (signal shorter than the window) are False.
        masks.append((series > threshold).to_numpy().astype(int))

    return np.array(masks)

def extend_events_with_baseline(signals, masks, baseline_duration_ms=50, sampling_rate=360):
    """
    Cut each contiguous run of detected samples out of its signal, padded with baseline context.

    Args:
        signals: Iterable of 1-D ECG signals.
        masks: Iterable of event masks aligned with ``signals``; 1 marks an event sample.
        baseline_duration_ms: Context to keep on each side of an event, in milliseconds.
        sampling_rate: Sampling rate in Hz, used to convert the context to samples.

    Returns:
        segments: List of slices ``signal[segment_start:segment_end]``.
        segment_info: One dict per segment with keys
            'original_idx' (index of the source signal),
            'event_start' / 'event_end' (first / last event sample, inclusive),
            'segment_start' / 'segment_end' (slice bounds, end exclusive),
            'baseline_before' / 'baseline_after' (context length in samples).
    """
    baseline_samples = int(baseline_duration_ms * sampling_rate / 1000)
    segments = []
    segment_info = []

    for idx, (signal, mask) in enumerate(zip(signals, masks)):
        event_indices = np.flatnonzero(np.asarray(mask) == 1)

        if event_indices.size == 0:
            continue

        # Split the sorted event indices into runs of consecutive samples: a new
        # run begins wherever two neighbouring indices are more than 1 apart.
        breaks = np.flatnonzero(np.diff(event_indices) > 1)
        run_starts = np.concatenate(([event_indices[0]], event_indices[breaks + 1]))
        run_ends = np.concatenate((event_indices[breaks], [event_indices[-1]]))

        for run_start, run_end in zip(run_starts, run_ends):
            event_start = int(run_start)
            event_end = int(run_end)  # inclusive

            # Extend with baseline on both sides. event_end is inclusive, so the
            # exclusive slice end is event_end + 1 + baseline_samples. This keeps
            # the trailing context the same length as the leading one.
            seg_start = max(0, event_start - baseline_samples)
            seg_end = min(len(signal), event_end + 1 + baseline_samples)

            segments.append(signal[seg_start:seg_end])

            segment_info.append({
                'original_idx': idx,
                'event_start': event_start,
                'event_end': event_end,
                'segment_start': seg_start,
                'segment_end': seg_end,
                'baseline_before': event_start - seg_start,
                'baseline_after': seg_end - (event_end + 1)
            })

    return segments, segment_info

def preprocess_segment(segment, target_length=187):
    """
    Linearly resample ``segment`` onto ``target_length`` evenly spaced points.

    Both the input and output grids span [0, 1], so the first and last samples
    are preserved and interior points are linearly interpolated.

    Args:
        segment: 1-D sequence of samples.
        target_length: Number of samples wanted in the result.

    Returns:
        ``segment`` itself (unchanged object) if it already has the target
        length, otherwise a new resampled NumPy array.
    """
    n_in = len(segment)
    if n_in != target_length:
        src_grid = np.linspace(0, 1, n_in)
        dst_grid = np.linspace(0, 1, target_length)
        return np.interp(dst_grid, src_grid, segment)
    return segment

# Progress marker: confirms this cell's helper definitions executed.
print("‚úÖ Event detection functions defined!")

## Model Inference Pipeline

Process detected events through the trained GAE model.

In [None]:
# Event Detection on Test Data
print("üîç Detecting events in test data...")
event_masks = simple_event_detection(X_test_full, threshold_factor=0.5)
print(f"Events detected in {np.sum(event_masks.sum(axis=1) > 0)}/{len(X_test_full)} signals")

# Extract segments with baseline
print("üìè Extracting segments with baseline context...")
segments, segment_info = extend_events_with_baseline(
    X_test_full, event_masks,
    baseline_duration_ms=50
)
print(f"Extracted {len(segments)} segments")
if not segments:
    # Nothing to classify. An empty batch is not meaningful for the k-NN graph
    # or the model below, so stop here with an actionable message.
    raise RuntimeError(
        "No events were detected in the test data. "
        "Try lowering threshold_factor in simple_event_detection()."
    )

# Preprocess segments
print("üîß Preprocessing segments...")
segments_processed = [preprocess_segment(seg, INPUT_LENGTH) for seg in segments]
# Stack into one (n_segments, INPUT_LENGTH) float32 array first:
# torch.tensor() on a list of ndarrays is very slow and emits a UserWarning.
# unsqueeze(1) adds the channel axis -> (n_segments, 1, INPUT_LENGTH).
segments_array = np.stack(segments_processed).astype(np.float32)
segments_tensor = torch.from_numpy(segments_array).unsqueeze(1).to(DEVICE)

print(f"‚úÖ Preprocessing complete: {segments_tensor.shape[0]} segments of length {INPUT_LENGTH}")

# Model Inference using src.inference
print("üöÄ Running model inference...")

# Extract CNN features (no gradients needed for inference)
with torch.no_grad():
    cnn_features = model.ecg_encoder(segments_tensor)
    print(f"CNN features shape: {cnn_features.shape}")

# Build k-NN graph using src.graph_utils
edge_index = build_knn_graph(cnn_features, k=K_NEIGHBORS)
edge_index = edge_index.to(DEVICE)
print(f"Graph edges: {edge_index.shape[1]}")

# Run inference using src.inference.predict
predictions, latent_representations, reconstructions = predict(
    model=model,
    data=segments_tensor,
    edge_index=edge_index,
    batch_size=BATCH_SIZE,
    device=DEVICE
)

print("‚úÖ Inference complete!")
print(f"Predictions shape: {predictions.shape}")
print(f"Latent shape: {latent_representations.shape}")
print(f"Reconstructions shape: {reconstructions.shape}")

## Full Test Set Evaluation

Evaluate the model on the entire test dataset using `src.inference.evaluate_model`.

In [None]:
# Full test set evaluation using src.inference
print("üìä Evaluating model on full test set...")

# `test_loader` only exists when the CSV loaded successfully. The synthetic-data
# fallback in the loading cell does not create a DataLoader, so evaluate only
# when one is available.
try:
    eval_loader = test_loader
except NameError:
    eval_loader = None

if eval_loader is None:
    print("WARNING: no test DataLoader is available (synthetic-data fallback?); "
          "skipping full test-set evaluation.")
else:
    # Use the evaluate_model function from src
    metrics, visualizations = evaluate_model(
        model=model,
        test_loader=eval_loader,
        device=DEVICE,
        k_neighbors=K_NEIGHBORS,
        save_plots=True,
        plot_prefix='detection_with_src'
    )

    print("‚úÖ Evaluation complete!")
    print("\n" + "="*50)
    print("PERFORMANCE METRICS")
    print("="*50)
    print(f"Overall Accuracy: {metrics['accuracy']:.4f}")
    print(f"Macro F1-Score: {metrics['macro_f1']:.4f}")
    print(f"Weighted F1-Score: {metrics['weighted_f1']:.4f}")

    print("\nPer-Class Performance:")
    for i, (prec, rec, f1) in enumerate(zip(metrics['precision'], metrics['recall'], metrics['f1'])):
        print(f"Class {i}: Precision={prec:.3f}, Recall={rec:.3f}, F1={f1:.3f}")

    # Display confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(metrics['confusion_matrix'], annot=True, fmt='d', cmap='Blues',
                xticklabels=[f'Class {i}' for i in range(NUM_CLASSES)],
                yticklabels=[f'Class {i}' for i in range(NUM_CLASSES)])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix (Full Test Set)')
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/confusion_matrix_detection_src.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("‚úÖ Confusion matrix saved!")

## Advanced Visualizations

In [None]:
# t-SNE Visualization of Latent Space
print("üé® Computing t-SNE visualization...")

# True label of each segment = label of the signal it was cut from.
segment_labels = np.asarray(y_test_full)[[info['original_idx'] for info in segment_info]]

# scikit-learn requires perplexity < n_samples, so shrink it for small segment sets.
n_latent = len(latent_representations)
perplexity = min(30, max(1, n_latent - 1))

# t-SNE on latent representations
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, max_iter=1000)
latent_2d = tsne.fit_transform(latent_representations)

# Side-by-side panels: true classes vs. predicted classes on the same embedding.
# (Previously the true-class scatter went into an unlabeled, unsaved figure.)
fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(20, 8))
panels = (
    (ax_true, segment_labels, 'viridis', 'True Class',
     'Latent Space Visualization (Colored by True Class)'),
    (ax_pred, predictions, 'plasma', 'Predicted Class',
     'Latent Space Visualization (Colored by Predictions)'),
)
for ax, colour_values, cmap, cbar_label, title in panels:
    points = ax.scatter(latent_2d[:, 0], latent_2d[:, 1],
                        c=colour_values, cmap=cmap, alpha=0.6, s=30)
    fig.colorbar(points, ax=ax, label=cbar_label)
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
os.makedirs('results', exist_ok=True)  # savefig does not create directories
fig.savefig('results/tsne_predictions_src.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ t-SNE visualizations saved!")

In [None]:
# Sample Predictions Visualization
print("üìã Analyzing sample predictions...")

n_cols = 4  # examples shown per class row
unique_classes = np.unique(segment_labels)
pred_array = np.asarray(predictions)

if len(unique_classes) == 0:
    print("No labelled segments to plot.")
else:
    # squeeze=False keeps `axes` 2-D even when there is only one class row.
    fig, axes = plt.subplots(len(unique_classes), n_cols,
                             figsize=(20, 5 * len(unique_classes)), squeeze=False)

    for i, class_idx in enumerate(unique_classes):
        # Segments whose true label is this class, split by prediction outcome
        class_indices = np.flatnonzero(segment_labels == class_idx)
        class_preds = pred_array[class_indices]
        correct_idx = class_indices[class_preds == class_idx]
        wrong_idx = class_indices[class_preds != class_idx]

        # Use up to half the columns for correct predictions and half for errors,
        # then top up from whichever group still has samples.
        half = n_cols // 2
        picks = [(j, True) for j in correct_idx[:half]] + [(j, False) for j in wrong_idx[:half]]
        spare = ([(j, True) for j in correct_idx[half:n_cols]]
                 + [(j, False) for j in wrong_idx[half:n_cols]])
        picks += spare[:n_cols - len(picks)]

        for col, ax in enumerate(axes[i]):
            if col >= len(picks):
                ax.axis('off')  # no sample for this slot
                continue

            idx, is_correct = picks[col]
            if is_correct:
                title = f'Class {class_idx} (Correct)'
            else:
                title = f'Class {class_idx} (Wrong‚Üí{int(pred_array[idx])})'

            # Flatten in case reconstructions keep a channel axis, e.g. (1, L).
            # Otherwise ax.plot would draw L one-point lines.
            original = np.ravel(segments_processed[idx])
            recon = np.ravel(reconstructions[idx])

            ax.plot(original, 'b-', alpha=0.7, label='Original')
            ax.plot(recon, 'r--', alpha=0.7, label='Reconstructed')
            ax.set_title(f'{title}\nMSE: {np.mean((original - recon)**2):.4f}')
            ax.legend()
            ax.grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs('results', exist_ok=True)  # savefig does not create directories
    plt.savefig('results/sample_predictions_src.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("‚úÖ Sample predictions visualization saved!")

In [None]:
# ROC Curves for Multi-Class
print("üìà Computing ROC curves...")

# ROC curves need per-class scores. `predict` in the inference cell only returns
# hard predictions, latents and reconstructions, and nothing in this notebook
# defines `logits_all`. Skip gracefully unless the caller has provided it.
try:
    roc_logits = logits_all
except NameError:
    roc_logits = None

if roc_logits is None:
    print("WARNING: `logits_all` is not defined, so ROC curves are skipped. "
          "Store per-segment class logits (n_segments, NUM_CLASSES) in `logits_all` to enable them.")
else:
    # Binarize labels for multi-class ROC (one column per class)
    y_true_bin = label_binarize(segment_labels, classes=range(NUM_CLASSES))
    logits_tensor = torch.as_tensor(roc_logits, dtype=torch.float32).detach().cpu()
    y_pred_prob = torch.softmax(logits_tensor, dim=1).numpy()
    if len(y_pred_prob) != len(y_true_bin):
        raise ValueError(
            f"logits_all has {len(y_pred_prob)} rows but there are {len(y_true_bin)} segments"
        )

    # Compute ROC curve for each class
    plt.figure(figsize=(12, 8))
    for i in range(NUM_CLASSES):
        y_class = y_true_bin[:, i]
        # ROC/AUC are undefined unless the class has both positives and negatives.
        if np.unique(y_class).size < 2:
            print(f"Class {i}: skipped (ROC needs both positive and negative samples)")
            continue
        fpr, tpr, _ = roc_curve(y_class, y_pred_prob[:, i])
        auc_score = roc_auc_score(y_class, y_pred_prob[:, i])

        plt.plot(fpr, tpr, color=plt.cm.tab10(i % 10),
                 label=f'Class {i} (AUC = {auc_score:.3f})')

    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves (Multi-Class, One-vs-Rest)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    os.makedirs('results', exist_ok=True)  # savefig does not create directories
    plt.savefig('results/roc_curves_src.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("‚úÖ ROC curves saved!")

## Interactive Prediction Viewer

Browse individual predictions with an interactive interface.

In [None]:
class InteractiveECGViewer:
    """Browse per-segment predictions with ipywidgets controls and Plotly plots.

    ``current_idx`` is a position within ``filtered_indices``, i.e. the list
    of segment indices that pass the current dropdown filters. It is not an
    index into ``segments`` directly. The slider drives this position.

    Args:
        segments: Indexable collection of input segments.
        reconstructions: Reconstructions aligned with ``segments``.
        predictions: Predicted class per segment.
        true_labels: True class per segment.
        segment_info: Per-segment metadata dicts. This class reads the keys
            'event_start', 'event_end', 'baseline_before', 'baseline_after'.
    """

    def __init__(self, segments, reconstructions, predictions, true_labels, segment_info):
        self.segments = segments
        self.reconstructions = reconstructions
        self.predictions = predictions
        self.true_labels = true_labels
        self.segment_info = segment_info

        # Position within self.filtered_indices of the sample being shown
        self.current_idx = 0

        # State read by the callbacks, created before any observer is attached
        self.filtered_indices = list(range(len(segments)))
        self.output = widgets.Output()

        # Create widgets
        self.idx_slider = widgets.IntSlider(
            value=0, min=0, max=self._slider_max(len(segments)), step=1,
            description='Sample:', continuous_update=False
        )

        self.prev_button = widgets.Button(description='Previous')
        self.next_button = widgets.Button(description='Next')
        self.random_button = widgets.Button(description='Random')

        self.true_class_filter = widgets.Dropdown(
            options=[('All', -1)] + [(f'Class {i}', i) for i in range(NUM_CLASSES)],
            value=-1, description='True Class:'
        )

        self.pred_class_filter = widgets.Dropdown(
            options=[('All', -1)] + [(f'Class {i}', i) for i in range(NUM_CLASSES)],
            value=-1, description='Pred Class:'
        )

        self.status_filter = widgets.Dropdown(
            options=[('All', -1), ('Correct', 1), ('Incorrect', 0)],
            value=-1, description='Status:'
        )

        # Connect callbacks. The slider must update current_idx before redrawing;
        # previously it redrew without reading its value, so dragging did nothing.
        self.idx_slider.observe(self._on_slider_change, names='value')
        self.prev_button.on_click(self.prev_sample)
        self.next_button.on_click(self.next_sample)
        self.random_button.on_click(self.random_sample)
        self.true_class_filter.observe(self.filter_samples, names='value')
        self.pred_class_filter.observe(self.filter_samples, names='value')
        self.status_filter.observe(self.filter_samples, names='value')

    @staticmethod
    def _slider_max(n_items):
        """Largest valid slider value; never below min=0 (ipywidgets rejects max < min)."""
        return max(0, n_items - 1)

    def _on_slider_change(self, change):
        """Slider moved (by the user or programmatically): show that position."""
        self.current_idx = int(change['new'])
        self.update_display(change)

    def filter_samples(self, change):
        """Filter samples based on criteria"""
        true_filter = self.true_class_filter.value
        pred_filter = self.pred_class_filter.value
        status_filter = self.status_filter.value

        indices = []
        for i in range(len(self.segments)):
            true_class = self.true_labels[i]
            pred_class = self.predictions[i]
            is_correct = (true_class == pred_class)

            if true_filter != -1 and true_class != true_filter:
                continue
            if pred_filter != -1 and pred_class != pred_filter:
                continue
            if status_filter != -1 and is_correct != bool(status_filter):
                continue

            indices.append(i)

        self.filtered_indices = indices

        # Keep the current position if still valid, otherwise restart at 0.
        # Capture it before touching the slider: shrinking `max` may clip the
        # slider value and fire _on_slider_change.
        target = self.current_idx if self.current_idx < len(indices) else 0
        self.idx_slider.max = self._slider_max(len(indices))
        self.idx_slider.value = target
        self.current_idx = target

        self.update_display(None)

    def prev_sample(self, b):
        """Step back one filtered sample (wrapping around)."""
        if len(self.filtered_indices) > 0:
            self.current_idx = (self.current_idx - 1) % len(self.filtered_indices)
            self.idx_slider.value = self.current_idx

    def next_sample(self, b):
        """Step forward one filtered sample (wrapping around)."""
        if len(self.filtered_indices) > 0:
            self.current_idx = (self.current_idx + 1) % len(self.filtered_indices)
            self.idx_slider.value = self.current_idx

    def random_sample(self, b):
        """Jump to a random filtered sample."""
        if len(self.filtered_indices) > 0:
            self.current_idx = np.random.randint(0, len(self.filtered_indices))
            self.idx_slider.value = self.current_idx

    def update_display(self, change):
        """Redraw the plot and metadata for the current filtered position."""
        with self.output:
            clear_output(wait=True)

            if len(self.filtered_indices) == 0:
                print("No samples match the current filters.")
                return

            # Defensive clamp in case the position went stale.
            pos = min(self.current_idx, len(self.filtered_indices) - 1)
            idx = self.filtered_indices[pos]

            # Get data. Flatten in case arrays carry a channel axis, e.g. (1, L).
            original = np.ravel(self.segments[idx])
            reconstructed = np.ravel(self.reconstructions[idx])
            true_class = self.true_labels[idx]
            pred_class = self.predictions[idx]
            info = self.segment_info[idx]

            # Create plot
            fig = make_subplots(
                rows=2, cols=1,
                subplot_titles=('ECG Signal', 'Reconstruction Error'),
                shared_xaxes=True
            )

            # Original vs Reconstructed
            fig.add_trace(
                go.Scatter(y=original, mode='lines', name='Original',
                           line=dict(color='blue')),
                row=1, col=1
            )

            fig.add_trace(
                go.Scatter(y=reconstructed, mode='lines', name='Reconstructed',
                           line=dict(color='red', dash='dash')),
                row=1, col=1
            )

            # Reconstruction error
            error = original - reconstructed
            fig.add_trace(
                go.Scatter(y=error, mode='lines', name='Error',
                           line=dict(color='green')),
                row=2, col=1
            )

            fig.update_layout(height=600, title_text=f'Sample {idx} (Filtered Index {pos})')

            # Print metadata
            print(f"Sample {idx} (showing {pos+1}/{len(self.filtered_indices)} filtered)")
            print(f"True Class: {true_class}, Predicted: {pred_class}")
            print(f"Correct: {'‚úÖ' if true_class == pred_class else '‚ùå'}")
            print(f"MSE: {np.mean(error**2):.6f}")
            print(f"Correlation: {np.corrcoef(original, reconstructed)[0,1]:.4f}")
            print(f"Event Info: Start={info['event_start']}, End={info['event_end']}")
            # baseline_* are sample counts (slice offsets), not milliseconds.
            print(f"Baseline: {info['baseline_before']} samples before, {info['baseline_after']} samples after")

            fig.show()

    def display(self):
        """Display the interactive interface"""
        controls = widgets.HBox([
            self.idx_slider,
            widgets.VBox([self.prev_button, self.next_button, self.random_button])
        ])

        filters = widgets.HBox([
            self.true_class_filter,
            self.pred_class_filter,
            self.status_filter
        ])

        # Inside a method body, `display` resolves to the module-level
        # IPython.display.display, not to this method.
        display(widgets.VBox([controls, filters, self.output]))

        # Initial display
        self.update_display(None)

# Build the widget-based browser over the segments analysed above and render it.
print("üéÆ Creating interactive ECG viewer...")
viewer = InteractiveECGViewer(
    segments=segments_processed,
    reconstructions=reconstructions,
    predictions=predictions,
    true_labels=segment_labels,
    segment_info=segment_info,
)
viewer.display()

print("‚úÖ Interactive viewer ready! Use the controls to explore predictions.")

### Saved Files

- `results/confusion_matrix_detection_src.png`
- `results/tsne_predictions_src.png`
- `results/sample_predictions_src.png`
- `results/roc_curves_src.png`
