In [None]:
from brpylib import NsxFile
import os
import numpy as np
import matplotlib.pyplot as plt

import direct_neural_biasing as dnb

## read data

In [None]:
# Load the recording array stored in the .npy file.
data_file = np.load('data/Patient2EEG.npy')
# Path to the marker (.mrk) file; it is parsed in a later cell.
mrk_file = 'data/Patient02_OfflineMrk.mrk'
# Keep only the first entry along the leading axis
# (presumably the first channel/row of the recording — TODO confirm).
data = data_file[0]

In [None]:
# Display the dimensions of the selected signal before plotting it.
np.shape(data)

In [None]:
def plot_data_stream_section(data_stream, n, m):
    """
    Plots an n-m section of a 1D integer data stream as a line graph.

    The requested range is clamped to the bounds of data_stream so that the
    x-axis indices always line up one-to-one with the plotted samples. When
    the clamped range contains no samples, a message is printed and no figure
    is created.

    Args:
        data_stream (list or numpy.ndarray): The 1D integer data stream.
        n (int): The starting index (inclusive).
        m (int): The ending index (inclusive).
    """
    # Clamp to valid indices: an out-of-range end would otherwise yield fewer
    # samples than x positions, and a negative start would slice from the end.
    first = max(0, n)
    last = min(len(data_stream) - 1, m)
    if last < first:
        print(f"Skipping plot: section {n} to {m} contains no samples.")
        return

    section = data_stream[first : last + 1]
    indices = range(first, last + 1)

    plt.figure(figsize=(14, 6))
    plt.plot(indices, section)
    plt.title(f'Data Section {first} to {last}')
    plt.xlabel('Index')
    plt.ylabel('Value')
    plt.grid(True)
    plt.show()

# Window of sample indices to inspect (both ends inclusive).
start_t = 2515
end_t = 4515

# Draw the selected window of the loaded signal.
plot_data_stream_section(data, start_t, end_t)

In [None]:
# --- Concise .mrk file parser ---
def parse_mrk_file_concise(filepath):
    """
    Parses a .mrk file into a dictionary (signal_type: [indices]).

    The first line is treated as a header and skipped. Every following line
    that splits into exactly three whitespace-separated fields is read as
    'index index signal_type': the first field is stored as an int under the
    third field. Lines with any other number of fields are ignored.

    Args:
        filepath (str): Path to the .mrk file.

    Returns:
        dict: {signal_type: [index, ...]} with indices in file order. Empty
        when the file holds no marker lines, including a completely empty file.

    Raises:
        ValueError: If the first field of a three-field line is not an integer.
    """
    mrk_data = {}
    with open(filepath, 'r') as f:
        # The None default keeps an empty file from raising StopIteration.
        next(f, None)
        for line in f:
            parts = line.split()
            if len(parts) != 3:
                continue
            index = int(parts[0])
            mrk_data.setdefault(parts[2], []).append(index)
    return mrk_data

def plot_marker_with_context(data_stream, marker_index, signal_type, context_window=500, output_dir="marker_plots",
                             marker_fs=512, data_fs=30000):
    """
    Plots a single marker with surrounding data context.

    Args:
        data_stream (list or numpy.ndarray): The 1D data stream to plot from.
        marker_index (int): Marker position expressed at the marker sample rate.
        signal_type (str): Label of the marker, used in the legend and title.
        context_window (int): Number of data samples shown on each side of the marker.
        output_dir (str): Currently unused; kept so existing callers keep working.
        marker_fs (float): Sample rate the marker indices are expressed in.
        data_fs (float): Sample rate of data_stream.

    The marker index is rescaled from marker_fs to data_fs before plotting.
    When the resulting window holds no samples, a message is printed and no
    figure is created.
    """
    # Adjust for the difference in sample rate between markers and data.
    marker_index = int((marker_index / marker_fs) * data_fs)
    data_length = len(data_stream)
    plot_start = max(0, marker_index - context_window)
    plot_end = min(data_length - 1, marker_index + context_window)

    # An inverted window means the marker (plus context) lies entirely outside
    # the data. Checking the bounds rather than the slice also covers a
    # negative plot_end, which would otherwise slice from the end of the data.
    if plot_end < plot_start:
        print(f"Skipping plot for marker {marker_index} ({signal_type}) due to empty data section.")
        return

    section_data = data_stream[plot_start : plot_end + 1]
    section_indices = range(plot_start, plot_end + 1)

    plt.figure(figsize=(14, 6))
    plt.plot(section_indices, section_data, label='Continuous Data', color='blue', linewidth=1.5)

    # Highlight the marker point
    plt.axvline(x=marker_index, color='red', linestyle='--', label=f'Marker: {signal_type}')
    # The window may be clipped so that the marker itself is outside the data;
    # only draw the marker dot when its sample actually exists.
    if 0 <= marker_index < data_length:
        plt.plot(marker_index, data_stream[marker_index], 'ro', markersize=8, label='Marker Location')
    else:
        print(f"Warning: Marker index {marker_index} is out of bounds for data_stream. Cannot plot marker point.")

    # '\\pm' keeps the literal backslash for mathtext without relying on an
    # invalid '\p' escape sequence.
    plt.title(f'Signal Type: {signal_type} at Index: {marker_index} (Context: $\\pm${context_window})')
    plt.xlabel('Data Index')
    plt.ylabel('Value')
    plt.grid(True, linestyle=':', alpha=0.7)
    plt.legend()
    plt.tight_layout()

    plt.show()
    plt.close()

In [None]:
# --- Main Execution ---

# 1. Parse the .mrk file
parsed_markers = parse_mrk_file_concise(mrk_file)
print("\nParsed Markers:")
for signal_type, indices in parsed_markers.items():
    print(f"  {signal_type}: {indices[:5]}... ({len(indices)} total)")

# 2. Plot the first few markers with context (displayed only, nothing is saved)
context_window_size = 100000  # Data samples shown on each side of a marker
max_plots_to_show = 5

# Plain "+/-" is used here because this text goes to the console, where
# LaTeX markup such as $\pm$ would be printed literally.
print(f"\nGenerating and displaying the first {max_plots_to_show} plots with context window of +/-{context_window_size}...")

# Flatten {signal_type: [indices]} in dictionary order and keep only the first
# max_plots_to_show (signal_type, marker_index) pairs.
markers_to_plot = [
    (signal_type, marker_index)
    for signal_type, indices in parsed_markers.items()
    for marker_index in indices
][:max_plots_to_show]

plots_shown_count = 0
for signal_type, marker_index in markers_to_plot:
    plot_marker_with_context(data, marker_index, signal_type, context_window=context_window_size)
    plots_shown_count += 1

if plots_shown_count == 0:
    print("No plots were generated. Check your .mrk file or data stream.")
else:
    print(f"\nFinished displaying the first {plots_shown_count} marker plots.")

In [None]:
#!/usr/bin/env python3
"""
DirectNeuralBiasing Detection Comparison Script
Compares detection results against a ground truth marker file, providing a
detailed event-by-event summary and performance metrics with live updates.
"""

import numpy as np
import yaml
import os
import sys
import direct_neural_biasing as dnb
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import namedtuple

# --- Constants ---
DATA_FS = 30000.0  # Samples per second of the data; used to convert milliseconds to data sample counts
MRK_FS = 512.0  # Samples per second that the .mrk indices are expressed in
TOLERANCE_MS = 250  # Largest detection-to-marker distance (ms) that still counts as a match
CONTEXT_WINDOW_MS = 500 # Context window for live printing (+/- 500ms)
PLOT_CONTEXT_MS = 1000  # Context window for plots (+/- 1 second)

# Input files and the temporary YAML configuration written (and later removed) by main()
DATA_FILE_PATH = 'data/Patient2EEG.npy'
MRK_FILE_PATH = 'data/Patient02_OfflineMrk.mrk'
CONFIG_PATH = "matlab_matched_config.yaml"

# Control flags
SHOW_LIVE_PLOTS = True  # Set to False to disable live plotting
MAX_LIVE_PLOTS = 10  # Maximum number of live plots to show (set to None for all)

# Define a structured tuple for detection results for clarity.
# Fields, as populated by process_and_report_events_live:
#   index                     - detection position in data samples (the wave start index)
#   wave_start                - same value as index
#   wave_end                  - detector-reported wave end index (-1 when not reported)
#   peak_amplitude            - detector-reported peak z-score amplitude (0 when not reported)
#   is_match                  - True for a true positive, False for a false positive
#   matched_gt_index          - matched ground-truth index in data samples (None if unmatched)
#   matched_gt_original_index - the same marker's original .mrk index (None if unmatched)
#   latency_ms                - (index - matched_gt_index) in milliseconds (None if unmatched)
Detection = namedtuple('Detection', [
    'index', 'wave_start', 'wave_end', 'peak_amplitude', 'is_match',
    'matched_gt_index', 'matched_gt_original_index', 'latency_ms'
])

# --- Configuration ---

def get_matlab_matched_config() -> dict:
    """Builds the processor configuration dictionary that mirrors the MATLAB script.

    Returns:
        dict: Nested configuration with 'processor', 'filters', 'detectors'
        and 'triggers' sections, assembled section by section below.
    """
    # Reference values from the original MATLAB script:
    # sf = 512 Hz, fc_low = 0.25, fc_high = 4, refractory_s = 2.5
    processor_section = {
        'verbose': False,
        'fs': DATA_FS,
        'channel': 1,
        'enable_debug_logging': False
    }

    slow_wave_band = {'id': 'slow_wave_filter', 'f_low': 0.25, 'f_high': 4.0}
    ied_band = {'id': 'ied_filter', 'f_low': 80.0, 'f_high': 120.0}

    slow_wave_detector = {
        'id': 'slow_wave_detector',
        'filter_id': 'slow_wave_filter',
        'z_score_threshold': 2.0,
        'sinusoidness_threshold': 0.7,
        'check_sinusoidness': True,  # MATLAB checks correlation
        'wave_polarity': 'downwave',  # MATLAB detects negative waves
        'min_wave_length_ms': 250.0,  # MATLAB min_ZeroCrossing_s * 1000
        'max_wave_length_ms': 1000.0  # MATLAB max_ZeroCrossing_s * 1000
    }
    ied_detector = {
        'id': 'ied_detector',
        'filter_id': 'ied_filter',
        'z_score_threshold': 2.5,
        'sinusoidness_threshold': 0.0,
        'check_sinusoidness': False,
        'wave_polarity': 'upwave',
        'min_wave_length_ms': None,
        'max_wave_length_ms': None
    }

    pulse_trigger = {
        'id': 'pulse_trigger',
        'activation_detector_id': 'slow_wave_detector',
        'inhibition_detector_id': 'ied_detector',
        'inhibition_cooldown_ms': 2500.0,  # refractory_s * 1000
        'pulse_cooldown_ms': 0
    }

    return {
        'processor': processor_section,
        'filters': {'bandpass_filters': [slow_wave_band, ied_band]},
        'detectors': {'wave_peak_detectors': [slow_wave_detector, ied_detector]},
        'triggers': {'pulse_triggers': [pulse_trigger]}
    }

def create_config_file(config_path: str):
    """Writes the MATLAB-matched configuration to config_path as block-style YAML.

    Args:
        config_path: Destination file; it is created or overwritten.
    """
    # Serialise first, then write the finished text in one go.
    yaml_text = yaml.dump(get_matlab_matched_config(), default_flow_style=False, indent=2)
    with open(config_path, 'w') as handle:
        handle.write(yaml_text)
    print(f"Created MATLAB-matched configuration file: {config_path}")

# --- Data Handling ---

def parse_mrk_file(filepath: str) -> np.ndarray:
    """Parses a .mrk file and returns an array of original marker indices.

    The first line is treated as a header and skipped. For every following
    line that splits into exactly three whitespace-separated fields, the
    first field is converted to an int and collected; other lines are ignored.

    Args:
        filepath: Path to the .mrk file.

    Returns:
        Integer array of marker indices in file order. The array is empty
        (but still integer-typed) when the file holds no marker lines.

    Raises:
        ValueError: If the first field of a three-field line is not an integer.
    """
    markers = []
    with open(filepath, 'r') as f:
        # The None default keeps an empty file from raising StopIteration.
        next(f, None)
        for line in f:
            parts = line.split()
            if len(parts) == 3:
                markers.append(int(parts[0]))
    # Explicit dtype: np.array([]) would otherwise default to float64.
    return np.array(markers, dtype=int)

def get_ground_truth_map(original_indices: np.ndarray) -> dict:
    """Maps each marker's index at the data sampling rate to its original MRK index.

    Args:
        original_indices: Marker indices expressed at MRK_FS.

    Returns:
        dict: {index rescaled to DATA_FS (truncated to int): original index},
        in the order of original_indices.
    """
    rescaled = original_indices * DATA_FS / MRK_FS
    data_rate_indices = rescaled.astype(int)
    return {
        data_rate_index: original_index
        for data_rate_index, original_index in zip(data_rate_indices, original_indices)
    }

# --- Live Plotting Function ---

def plot_detection_context(data: np.ndarray, detection: Detection, gt_indices: np.ndarray,
                          gt_map: dict, plot_number: int, plot_context_ms: float = PLOT_CONTEXT_MS):
    """Plots the signal context around a detection with markers.

    The x-axis is time in milliseconds relative to the detection index (the
    wave start). The detection, its wave end (when positive), and every
    ground-truth marker inside the plotted window are drawn as vertical lines.

    Args:
        data: Signal samples the detection indices refer to.
        detection: The detection to centre the plot on.
        gt_indices: Ground-truth marker indices at the data sampling rate.
        gt_map: Maps each entry of gt_indices to its original .mrk index.
        plot_number: Running number shown in the figure title.
        plot_context_ms: Half-width of the plotted window in milliseconds.
    """
    context_samples = int((plot_context_ms / 1000) * DATA_FS)
    det_idx = detection.index  # Wave start of the detection

    # Clip the context window to the available data
    start_idx = max(0, det_idx - context_samples)
    end_idx = min(len(data), det_idx + context_samples)

    # Time axis in milliseconds relative to the detection (wave start)
    time_samples = np.arange(start_idx - det_idx, end_idx - det_idx)
    time_ms = time_samples * 1000 / DATA_FS

    plt.figure(figsize=(12, 6))

    # Plot signal
    plt.plot(time_ms, data[start_idx:end_idx], 'b-', alpha=0.7, linewidth=1, label='Raw Signal')

    # Mark detection point (wave start - downward zero crossing)
    plt.axvline(x=0, color='red' if not detection.is_match else 'darkgreen',
                linestyle='-', linewidth=2, alpha=0.8,
                label=f'Detection Start ({"TP" if detection.is_match else "FP"})')

    # Mark wave end (upward zero crossing)
    if detection.wave_end > 0:
        wave_end_ms = (detection.wave_end - det_idx) * 1000 / DATA_FS
        plt.axvline(x=wave_end_ms, color='orange', linestyle=':', linewidth=1.5, alpha=0.6,
                    label='Wave End')

        # Highlight the full wave span
        plt.axvspan(0, wave_end_ms, alpha=0.2, color='yellow', label='Detected Wave')

    # Plot nearby ground truth markers
    markers_in_range = gt_indices[(gt_indices >= start_idx) & (gt_indices < end_idx)]
    # Only the first *unmatched* marker carries the legend label. Keying this
    # on the loop position would lose the label whenever the matched marker
    # happens to come first in the window.
    unmatched_label_pending = True
    for gt_idx in markers_in_range:
        marker_time_ms = (gt_idx - det_idx) * 1000 / DATA_FS
        original_idx = gt_map[gt_idx]

        if detection.is_match and gt_idx == detection.matched_gt_index:
            # Highlight matched marker
            plt.axvline(x=marker_time_ms, color='green', linestyle='--', linewidth=2, alpha=0.8,
                        label=f'Matched GT (Original: {original_idx})')
        else:
            plt.axvline(x=marker_time_ms, color='gray', linestyle='--', linewidth=1, alpha=0.5,
                        label='GT Marker' if unmatched_label_pending else '')
            unmatched_label_pending = False
            # Add text label for original index
            plt.text(marker_time_ms, plt.ylim()[1] * 0.9, f'{original_idx}',
                     rotation=90, fontsize=8, color='gray', alpha=0.7)

    # Add detection info to title
    title = f"Detection #{plot_number} at wave start index {det_idx}"
    if detection.is_match:
        title += f" | ✅ TRUE POSITIVE | Latency: {detection.latency_ms:.1f} ms"
    else:
        title += " | ❌ FALSE POSITIVE"

    # NOTE(review): only positive amplitudes are shown; confirm whether
    # downward waves report a negative peak z-score that should be shown too.
    if detection.peak_amplitude > 0:
        title += f" | Peak Z-score: {detection.peak_amplitude:.2f}"

    plt.title(title)
    plt.xlabel('Time relative to wave start (ms)')
    plt.ylabel('Amplitude')
    plt.grid(True, alpha=0.3)
    plt.legend(loc='best')
    plt.tight_layout()

    # Add zero line
    plt.axhline(y=0, color='black', linestyle='-', linewidth=0.5, alpha=0.3)

    plt.show()
    # Release the figure; this function is called once per detection, so
    # unclosed figures would otherwise accumulate.
    plt.close()

# --- Core Processing and Live Reporting ---

def process_and_report_events_live(processor: dnb.PySignalProcessor, data: np.ndarray, gt_map: dict, chunk_size: int = 4096) -> tuple:
    """
    Processes data, prints events live as they are detected, and returns final
    lists of true positives, false positives, and false negatives.

    A detection is a true positive when its nearest ground-truth marker lies
    within TOLERANCE_MS and has not already been claimed by an earlier
    detection; otherwise it is a false positive. Ground-truth markers that are
    never claimed are returned as false negatives.

    Args:
        processor: Signal processor; reset_index() is called once up front and
            run_chunk() once per chunk.
        data: Signal samples; sliced into chunks of chunk_size and handed to
            the processor as plain lists.
        gt_map: Maps ground-truth index (data sampling rate) to the original
            .mrk index. May be empty, in which case every detection is a
            false positive.
        chunk_size: Number of samples passed to the processor per call.

    Returns:
        tuple: (true_positives, false_positives, false_negatives). The first
        two are lists of Detection sorted by detection index; the last is a
        sorted list of (ground_truth_index, original_mrk_index) pairs.
    """
    key_prefix = "detectors:slow_wave_detector:"

    processor.reset_index()

    # --- Setup for live processing ---
    true_positives, false_positives = [], []
    gt_indices = np.array(list(gt_map.keys()))
    matched_gt_mask = np.zeros(len(gt_indices), dtype=bool)

    # Convert tolerances to sample counts
    tolerance_samples = int((TOLERANCE_MS / 1000) * DATA_FS)
    context_samples = int((CONTEXT_WINDOW_MS / 1000) * DATA_FS)

    num_chunks = (len(data) + chunk_size - 1) // chunk_size
    pbar = tqdm(range(0, len(data), chunk_size), total=num_chunks, desc="🧠 Processing data", unit="chunk")

    plot_count = 0

    # --- Processing Loop ---
    for chunk_start in pbar:
        chunk = data[chunk_start:chunk_start + chunk_size].tolist()
        chunk_output, _ = processor.run_chunk(chunk)

        for sample_result in chunk_output:
            if sample_result.get(key_prefix + "detected", 0.0) != 1.0:
                continue

            # Use wave_start_index to match MATLAB behavior (downward zero-crossing)
            det_idx = int(sample_result.get(key_prefix + "wave_start_index", -1))
            shared_fields = dict(
                index=det_idx,
                wave_start=det_idx,
                wave_end=int(sample_result.get(key_prefix + "wave_end_index", -1)),
                peak_amplitude=sample_result.get(key_prefix + "peak_z_score_amplitude", 0),
            )

            # Nearest ground-truth marker, accepted only when it is within
            # tolerance and not already matched. Skipped entirely when there
            # are no markers, since argmin of an empty array raises.
            match_pos = None
            if gt_indices.size > 0:
                distances = np.abs(gt_indices - det_idx)
                nearest_pos = int(np.argmin(distances))
                if distances[nearest_pos] <= tolerance_samples and not matched_gt_mask[nearest_pos]:
                    match_pos = nearest_pos

            if match_pos is not None:
                matched_gt_mask[match_pos] = True
                gt_index = gt_indices[match_pos]
                detection = Detection(
                    is_match=True,
                    matched_gt_index=gt_index,
                    matched_gt_original_index=gt_map[gt_index],
                    latency_ms=(det_idx - gt_index) * 1000 / DATA_FS,
                    **shared_fields
                )
                true_positives.append(detection)

                # --- Live Print for True Positive ---
                tqdm.write(f"  [✅ TRUE POSITIVE] Detection at index {detection.index:<9} | Matched MRK: {detection.matched_gt_index} (Original: {detection.matched_gt_original_index})")
            else:
                detection = Detection(
                    is_match=False,
                    matched_gt_index=None,
                    matched_gt_original_index=None,
                    latency_ms=None,
                    **shared_fields
                )
                false_positives.append(detection)

                # --- Live Print for False Positive ---
                # Nearby markers are only reported for false positives, so
                # they are looked up here rather than for every detection.
                in_window = (gt_indices >= det_idx - context_samples) & (gt_indices <= det_idx + context_samples)
                nearby_originals = [gt_map[gt_idx] for gt_idx in gt_indices[in_window]]
                marker_info = f"Nearby MRKs: {nearby_originals}" if nearby_originals else "No nearby MRKs"
                tqdm.write(f"  [❌ FALSE POSITIVE] Detection at index {detection.index:<9} | {marker_info}")

            # --- Live Plot ---
            if SHOW_LIVE_PLOTS and (MAX_LIVE_PLOTS is None or plot_count < MAX_LIVE_PLOTS):
                plot_count += 1
                # Temporarily clear the progress bar for clean plotting. The
                # detection built above is plotted directly; looking it up
                # again by index could pick an earlier detection that shares
                # the same index.
                pbar.clear()
                plot_detection_context(data, detection, gt_indices, gt_map, plot_count)
                pbar.refresh()

            # Update progress bar postfix
            pbar.set_postfix(TPs=len(true_positives), FPs=len(false_positives), refresh=True)

    # --- Finalize ---
    unmatched_gt_indices = gt_indices[~matched_gt_mask]
    false_negatives = sorted([(int(gt_idx), gt_map[gt_idx]) for gt_idx in unmatched_gt_indices])

    # Sort on the detection index only, so ties never fall through to
    # comparing the remaining (possibly None) fields.
    by_index = lambda det: det.index
    return sorted(true_positives, key=by_index), sorted(false_positives, key=by_index), false_negatives


# --- Reporting & Visualization ---

def print_summary_metrics(tp_count, fp_count, fn_count, total_gt):
    """Prints a summary block with event counts, sensitivity and precision.

    Args:
        tp_count: Number of true positives.
        fp_count: Number of false positives.
        fn_count: Number of false negatives.
        total_gt: Number of ground-truth events, reported as given.

    Sensitivity is tp / (tp + fn) and precision is tp / (tp + fp); each is
    reported as 0 when its denominator is not positive.
    """
    actual_positives = tp_count + fn_count
    detected_events = tp_count + fp_count

    sensitivity = 0
    if actual_positives > 0:
        sensitivity = tp_count / actual_positives

    precision = 0
    if detected_events > 0:
        precision = tp_count / detected_events

    heavy_rule = "=" * 50
    light_rule = "-" * 28
    report_lines = (
        "\n" + heavy_rule,
        "📊 FINAL SUMMARY METRICS",
        heavy_rule,
        f"Total Ground Truth Events: {total_gt}",
        f"Total Detected Events:   {detected_events}",
        light_rule,
        f"True Positives:          {tp_count}",
        f"False Positives:         {fp_count}",
        f"False Negatives:         {fn_count} (Missed Ground Truth Events)",
        light_rule,
        f"Sensitivity (Recall):    {sensitivity:.3f}",
        f"Precision:               {precision:.3f}",
        heavy_rule,
    )
    for report_line in report_lines:
        print(report_line)

def plot_detection_comparison(data, true_positives, false_positives, false_negatives, sample_range=None):
    """Plots a section of data with detections and ground truth markers.

    Args:
        data: Signal samples the indices refer to.
        true_positives: Detections with is_match True; their matched_gt_index
            is also drawn as a ground-truth marker.
        false_positives: Detections with is_match False.
        false_negatives: (ground_truth_index, original_index) pairs of missed
            markers; only the first element is used here.
        sample_range: Optional (start, end) sample indices, end exclusive.
            Defaults to the first 10 seconds of data. The range is clamped to
            the data; when it contains no samples a message is printed and no
            figure is created.
    """
    if sample_range is None:
        sample_range = (0, min(10 * int(DATA_FS), len(data)))

    start_idx, end_idx = sample_range
    # Clamp so the time axis and the data slice always have the same length.
    start_idx = max(0, start_idx)
    end_idx = min(len(data), end_idx)
    if end_idx <= start_idx:
        print("No samples in the requested range; skipping detection comparison plot.")
        return

    time_axis = np.arange(start_idx, end_idx) / DATA_FS

    plt.figure(figsize=(18, 6))
    ax = plt.gca()
    plt.plot(time_axis, data[start_idx:end_idx], 'b-', alpha=0.6, label='Raw Data')

    # Ground truth = markers matched by a true positive plus the missed ones
    all_gt_indices = [tp.matched_gt_index for tp in true_positives] + [fn[0] for fn in false_negatives]
    gt_in_range = [idx for idx in all_gt_indices if start_idx <= idx < end_idx]

    # Use single representative lines for the legend. They sit at x=-1, which
    # the xlim set below keeps out of view.
    ax.axvline(x=-1, color='green', linestyle='--', label=f'Ground Truth ({len(all_gt_indices)})')
    ax.axvline(x=-1, color='darkgreen', linestyle='-', label=f'True Positive ({len(true_positives)})')
    ax.axvline(x=-1, color='red', linestyle='-', label=f'False Positive ({len(false_positives)})')

    for gt_idx in gt_in_range:
        ax.axvline(x=gt_idx / DATA_FS, color='green', linestyle='--', alpha=0.7)

    for det in true_positives + false_positives:
        if start_idx <= det.index < end_idx:
            color = 'darkgreen' if det.is_match else 'red'
            ax.axvline(x=det.index / DATA_FS, color=color, linestyle='-', alpha=0.8)

    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.title('Detection Comparison: Raw Data, Ground Truth, and Detections')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xlim(time_axis[0], time_axis[-1])
    plt.tight_layout()
    plt.show()
    plt.close()

# --- Main Execution ---

def main():
    """Runs the full comparison: config, detection, metrics, and final plot.

    Writes the temporary YAML configuration, loads the data and markers, runs
    the live detection, prints the summary metrics and shows the comparison
    plot. The configuration file is removed at the end, even when a step fails.
    """
    print("🧠 DirectNeuralBiasing Detection Comparison")
    print("=" * 50)

    try:
        # Created inside the try block so that a partially written file is
        # still cleaned up if writing fails.
        create_config_file(CONFIG_PATH)

        # --- Setup ---
        print("\n📊 Loading and preparing data...")
        data = np.load(DATA_FILE_PATH)[0]
        original_mrk_indices = parse_mrk_file(MRK_FILE_PATH)
        ground_truth_map = get_ground_truth_map(original_mrk_indices)
        print(f"Loaded {len(data)} data samples and {len(ground_truth_map)} ground truth markers.")

        # --- Processing ---
        print("\n🔧 Initializing signal processor and starting live detection...")
        if SHOW_LIVE_PLOTS:
            # None means "no limit"; 0 is a real limit and must not be reported as 'all'.
            plot_limit = 'all' if MAX_LIVE_PLOTS is None else MAX_LIVE_PLOTS
            print(f"📈 Live plotting is ENABLED (showing up to {plot_limit} plots)")
        else:
            print("📈 Live plotting is DISABLED")

        processor = dnb.PySignalProcessor.from_config_file(CONFIG_PATH)

        true_positives, false_positives, false_negatives = process_and_report_events_live(
            processor, data, ground_truth_map
        )

        # --- Final Reporting ---
        print("\n✅ Live processing complete.")
        print_summary_metrics(
            len(true_positives),
            len(false_positives),
            len(false_negatives),
            len(ground_truth_map)
        )

        # --- Visualization ---
        print("\n🎨 Plotting final detection comparison...")
        plot_detection_comparison(data, true_positives, false_positives, false_negatives)

    finally:
        # --- Cleanup ---
        if os.path.exists(CONFIG_PATH):
            os.remove(CONFIG_PATH)
            print(f"\n🧹 Cleaned up: Deleted {CONFIG_PATH}")


# Run the comparison when executed as a script. Inside a notebook cell
# __name__ is also "__main__", so running the cell runs main() as well.
if __name__ == "__main__":
    main()