# ECG/Resp Inspection
Processes ECG data (from files or from ICE parameters in the TWIX file) and displays time-domain plots with optionally detected events (R-peaks). Also plots respiratory signals and detects peaks/troughs for respiration.

### Loading packages and data

In [None]:
import yaml
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
import utils.data_ingestion as di
import utils.ecg_resp as ecg_resp
import scipy.signal as signal
# %matplotlib widget

def load_config(config_file="config.yaml"):
    """
    Load configuration from a YAML file.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file.

    Returns
    -------
    dict
        Parsed configuration. An empty YAML document yields an empty dict
        instead of ``None`` so that subsequent key lookups fail with a
        ``KeyError`` rather than a ``TypeError``.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale-default text encoding.
    with open(config_file, "r", encoding="utf-8") as f:
        parsed = yaml.safe_load(f)
    return {} if parsed is None else parsed

# Parse the YAML configuration and keep a handle on its "data" section
config = load_config()
data_cfg = config["data"]

# Mandatory inputs (a missing key raises KeyError)
twix_file = data_cfg["twix_file"]
dicom_folder = data_cfg["dicom_folder"]

# Optional recordings; each is None when its key is absent
ecg_files, event_file, resp_file = (
    data_cfg.get(key) for key in ("ecg_files", "event_file", "resp_file")
)

# Read the TWIX file, requesting only scan index -1 and skipping PMU parsing
scans = di.read_twix_file(twix_file, include_scans=[-1], parse_pmu=False)

### Extracting and analyzing ECG

In [None]:
# first_time = None
# last_time = None
# for (i,mdb) in enumerate(scans[-1]['mdb']):
#     if mdb.is_image_scan():
#         print(mdb.mdh.TimeStamp)
#         last_time = mdb.mdh.TimeStamp
#         if first_time is None:
#             first_time = mdb.mdh.TimeStamp

# print("First time: ", first_time)
# print("Last time: ", last_time)
# print("Time difference: ", (last_time - first_time) * 2.5e-3)

In [None]:
kspace = di.extract_image_data(scans)
n_samples = kspace.shape[0]

# Sampling frequency = DICOM frame rate x phase-encode lines per frame
framerate, frametime = di.get_dicom_framerate(dicom_folder)
n_phase_encodes_per_frame = n_samples // config["data"]["n_frames"]
fs = framerate * n_phase_encodes_per_frame  # ECG / respiration sampling freq
# fs = 1/(scans[-1]['hdr']['Phoenix']['alTR'][0]/1000/48/1000)

# ECG comes from the ICE parameters unless external files are configured
if not ecg_files:
    ecg_data = di.extract_iceparam_data(scans, segment_index=0, columns=np.s_[18:21])
    # Always work with a 2-D (samples, channels) array
    if ecg_data.ndim == 1:
        ecg_data = ecg_data[:, np.newaxis]
else:
    # One column per file, each resampled to the number of k-space lines
    ecg_channels = [
        signal.resample(np.loadtxt(path, skiprows=1, usecols=1), n_samples)
        for path in ecg_files
    ]
    ecg_data = np.column_stack(ecg_channels)

In [None]:
# Example (commented) R-peak detection
# r_peaks_list = ecg_resp.detect_r_peaks(ecg_data, fs)
# hr = ecg_resp.compute_average_heart_rate(r_peaks_list, fs)
# print(f"Average heart rate: {hr:.2f} BPM")

# Without an event file there is nothing to overlay
resampled_events = None

if event_file:
    n_target = kspace.shape[0]
    events = np.loadtxt(event_file, skiprows=1, usecols=1)
    # Shift so the baseline is zero; every non-zero sample is then a spike
    events = events - events.min()
    n_source = len(events)

    # Map each spike to the target grid by its relative position in the recording
    source_idx = np.nonzero(events)[0]
    target_idx = np.round(source_idx / (n_source - 1) * (n_target - 1)).astype(int)

    # Explicit loop: when two spikes land on the same target sample, the later one wins
    resampled_events = np.zeros(n_target)
    for src, dst in zip(source_idx, target_idx):
        resampled_events[dst] = events[src]

In [None]:
# Min-max scale each ECG channel (column); the epsilon avoids division by zero
channel_min = np.min(ecg_data, axis=0)
channel_span = np.ptp(ecg_data, axis=0) + 1e-9
ecg_data = (ecg_data - channel_min) / channel_span

# Sample positions of the event spikes, if an event trace was loaded
if resampled_events is None:
    spike_indices = None
else:
    spike_indices = np.nonzero(resampled_events)[0]

# One plot per channel, with the events passed as spike indices
ecg_resp.plot_ecg_signals(ecg_data, fs, spike_indices=spike_indices, mode="separate")

### Extracting and analyzing resp

In [None]:
# The respiratory trace is optional; this cell does nothing without resp_file
if resp_file:
    raw_resp = np.loadtxt(resp_file, skiprows=1, usecols=1)
    # Resample to one value per k-space line and store as a single column
    resp_data = signal.resample(raw_resp, kspace.shape[0]).reshape(-1, 1)

    # Peaks and troughs share settings; troughs are found on the negated signal
    detection_kwargs = {"method": "scipy", "height": 0.6, "prominence": 0.15}
    resp_peaks = ecg_resp.detect_resp_peaks(resp_data, fs, **detection_kwargs)
    resp_troughs = ecg_resp.detect_resp_peaks(-resp_data, fs, **detection_kwargs)

    ecg_resp.plot_resp_signal(
        resp_data, fs, resp_peaks=resp_peaks, resp_troughs=resp_troughs
    )

In [None]:
# TODO: Make peak/trough detection faster in real time (i.e., it should be more reactive, since it seems to pick up on regime changes quite late).
# There is relatively little noise at the peaks/troughs, so some basic smoothing + peak detection / derivative-based peak detection should be sufficient.

In [None]:
# t = np.arange(N)/fs
# plt.figure(figsize=(10,5))

# # Plot the offline reference fraction (unchanged)
# plt.plot(t, actual_fraction, label='Actual Fraction (Offline)', linewidth=2, color='black')

# # Instead of a single line for predicted_fraction, we segment it by predicted_phase.
# # We'll use green for inhalation (True), blue for exhalation (False), and gray if unknown.
# start_idx = 0
# for i in range(1, N):
#     # If the phase indicator changes, plot the segment from start_idx to i-1.
#     if predicted_phase[i] != predicted_phase[i-1]:
#         seg_t = t[start_idx:i]
#         seg_y = predicted_fraction[start_idx:i]
#         if predicted_phase[i-1] is True:
#             color = 'green'      # inhalation
#         elif predicted_phase[i-1] is False:
#             color = 'blue'       # exhalation
#         else:
#             color = 'gray'
#         plt.plot(seg_t, seg_y, '-', color=color, linewidth=2)
#         start_idx = i
# # Plot the last segment
# if start_idx < N:
#     seg_t = t[start_idx:]
#     seg_y = predicted_fraction[start_idx:]
#     if predicted_phase[-1] is True:
#         color = 'green'
#     elif predicted_phase[-1] is False:
#         color = 'blue'
#     else:
#         color = 'gray'
#     plt.plot(seg_t, seg_y, '-', color=color, linewidth=2)

# # Optionally, add a vertical line at calibration end
# if calibration_end_idx is not None:
#     plt.axvline(calibration_end_idx/fs, color='r', linestyle=':', linewidth=2, label='Calibration End')

# # Normalize raw resp_signal to range 0-100
# raw_norm = (resp_signal - np.min(resp_signal)) / (np.ptp(resp_signal)) * 100.0
# plt.plot(t, raw_norm, label='Raw Signal (Normalized)', color='magenta', alpha=0.5)

# plt.title("Online Peak/Trough Detection, Fraction & Phase Indicator")
# plt.xlabel("Time (s)")
# plt.ylabel("Breath Cycle Fraction (%)")
# plt.grid(True)
# plt.legend(loc='upper left')
# plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
from scipy import signal
import utils.ecg_resp as ecg_resp  # Existing utility module

def detect_resp_peaks_realtime(resp_data, fs, candidate_prominence=0.03, confirm_prominence=0.15,
                               height=None, time_constant=0.2,
                               candidate_left_threshold=0.15,
                               confirmed_left_threshold=0.4, confirmed_right_threshold=0.15):
    """
    Detect respiratory peaks in a 1D signal using a two-stage process with custom
    left/right prominence checks.

    The signal is first min-max normalized. Then:
      - Candidate peaks are found on an EMA-filtered copy of the normalized signal using
        ``candidate_prominence``. A candidate is kept only if its left-side prominence
        (measured on the filtered signal) is at least ``candidate_left_threshold``. Each
        surviving candidate is then moved to the largest normalized sample within
        ``time_constant`` seconds of it.
      - Confirmed peaks are found on the unfiltered normalized signal using
        ``confirm_prominence``. A peak is kept only if its left prominence is at least
        ``confirmed_left_threshold`` and its right prominence is at least
        ``confirmed_right_threshold``.

    For troughs, call this function on the negated signal and choose the thresholds
    as needed.

    Parameters
    ----------
    resp_data : array_like
        Respiratory signal (shape: [n_samples] or [n_samples, 1]); it is flattened.
    fs : float
        Sampling frequency in Hz.
    candidate_prominence : float, optional
        Prominence threshold for candidate detection (default is 0.03).
    confirm_prominence : float, optional
        Prominence threshold for confirmed detection (default is 0.15).
    height : float, optional
        Minimum peak height on the normalized signal. Only peaks above this value
        are considered.
    time_constant : float, optional
        Time constant (in seconds) for EMA smoothing; also the half-width of the
        window used to re-position candidate peaks.
    candidate_left_threshold : float, optional
        Minimum required left prominence for candidate peaks (only left side is checked).
    confirmed_left_threshold : float, optional
        Minimum required left prominence for confirmed peaks.
    confirmed_right_threshold : float, optional
        Minimum required right prominence for confirmed peaks.

    Returns
    -------
    tuple of np.ndarray
        candidate_peaks: Integer indices of candidate peaks (after the left prominence
        check and re-positioning). May contain duplicates if two candidates are moved
        to the same sample.
        confirmed_peaks: Integer indices of confirmed peaks (after left/right
        prominence verification).
        Both arrays are empty (integer dtype) when nothing is found or the input is empty.
    """
    # Work on a flat float array regardless of the input layout.
    resp_data = np.asarray(resp_data, dtype=float).ravel()
    if resp_data.size == 0:
        # np.min would raise on an empty signal; there is simply nothing to detect.
        return np.array([], dtype=int), np.array([], dtype=int)

    # Normalize the original signal (epsilon guards a constant signal).
    norm_resp = (resp_data - np.min(resp_data)) / (np.ptp(resp_data) + 1e-9)
    n_samples = norm_resp.size

    # ---------------- EMA Filtering ----------------
    # y[0] = x[0];  y[n] = alpha * x[n] + (1 - alpha) * y[n - 1]
    # The recurrence is evaluated by lfilter instead of a per-sample Python loop;
    # zi seeds the filter state so that the first filtered sample follows from y[0].
    alpha = (1 / fs) / (time_constant + 1 / fs)
    smoothed_resp = np.empty_like(norm_resp)
    smoothed_resp[0] = norm_resp[0]
    if n_samples > 1:
        initial_state = np.array([(1.0 - alpha) * norm_resp[0]])
        smoothed_resp[1:], _ = signal.lfilter(
            [alpha], [1.0, alpha - 1.0], norm_resp[1:], zi=initial_state
        )

    # ---------------- Candidate Peak Detection ----------------
    candidate_peaks, _ = signal.find_peaks(
        smoothed_resp, height=height, prominence=candidate_prominence
    )

    # Keep only candidates whose rise from the left base is large enough.
    if candidate_peaks.size > 0:
        _, cand_left_bases, _ = signal.peak_prominences(smoothed_resp, candidate_peaks)
        candidate_left_proms = smoothed_resp[candidate_peaks] - smoothed_resp[cand_left_bases]
        candidate_peaks = candidate_peaks[candidate_left_proms >= candidate_left_threshold]

    # Move each candidate to the highest normalized sample in
    # [peak - window_size, peak + window_size). The window always contains the
    # peak itself, so a window shorter than one sample leaves it where it is
    # instead of producing an empty slice.
    window_size = int(time_constant * fs)
    refined_peaks = []
    for peak in candidate_peaks:
        start = max(0, peak - window_size)
        stop = min(n_samples, max(peak + window_size, peak + 1))
        refined_peaks.append(start + np.argmax(norm_resp[start:stop]))
    # Explicit integer dtype keeps an empty result usable as an index array.
    candidate_peaks = np.array(refined_peaks, dtype=int)

    # ---------------- Confirmed (Actual) Peak Detection ----------------
    confirmed_peaks, _ = signal.find_peaks(
        norm_resp, height=height, prominence=confirm_prominence
    )

    # Verify confirmed peaks by checking both left and right prominences.
    if confirmed_peaks.size > 0:
        _, conf_left_bases, conf_right_bases = signal.peak_prominences(norm_resp, confirmed_peaks)
        peak_values = norm_resp[confirmed_peaks]
        confirmed_left_proms = peak_values - norm_resp[conf_left_bases]
        confirmed_right_proms = peak_values - norm_resp[conf_right_bases]
        valid_conf_mask = (
            (confirmed_left_proms >= confirmed_left_threshold)
            & (confirmed_right_proms >= confirmed_right_threshold)
        )
        confirmed_peaks = confirmed_peaks[valid_conf_mask]

    return candidate_peaks, confirmed_peaks


# ---------------- Real-Time Animation Setup ----------------

# resp_data and fs must already exist from an earlier cell.
resp_signal = resp_data.flatten()
N = resp_signal.size
time_array = np.arange(N) / fs  # sample times in seconds

# Detection settings shared by the peak and trough passes.
peak_height = 0.6             # minimum peak height
candidate_prominence = 0.03   # prominence threshold for candidates
confirm_prominence = 0.15     # prominence threshold for confirmed events
min_time = 10                 # seconds of signal required before detecting
time_constant = 0.1           # EMA time constant in seconds

# Figure with one axes holding the trace and four marker collections.
fig, ax = plt.subplots(figsize=(16, 8), dpi=40)
(line,) = ax.plot([], [], lw=2, color='gray', label='Resp Signal')

# Creation order fixes the legend order: peaks first, then troughs.
(candidate_peaks_scatter,
 confirmed_peaks_scatter,
 candidate_troughs_scatter,
 confirmed_troughs_scatter) = [
    ax.scatter([], [], color=marker_color, s=50, marker=marker_shape, label=marker_label)
    for marker_color, marker_shape, marker_label in (
        ('orange', 'x', 'Candidate Peaks'),
        ('red', 'o', 'Confirmed Peaks'),
        ('cyan', 'x', 'Candidate Troughs'),
        ('blue', 'o', 'Confirmed Troughs'),
    )
]

ax.set(
    xlabel="Time (s)",
    ylabel="Signal Amplitude",
    title="Real-Time Respiratory Signal with Two-Stage Detected Peaks/Troughs",
)
ax.legend()

# Initial view: full time span, amplitude range padded by 10% of each bound.
ymin, ymax = 0, 1
ax.set_xlim(0, time_array[-1])
ax.set_ylim(ymin - 0.1 * np.abs(ymin), ymax + 0.1 * np.abs(ymax))

def init():
    """
    Reset every animated artist to an empty state.

    Returns
    -------
    tuple
        The line followed by the four scatter collections, all emptied.
    """
    scatters = (
        candidate_peaks_scatter,
        confirmed_peaks_scatter,
        candidate_troughs_scatter,
        confirmed_troughs_scatter,
    )
    line.set_data([], [])
    for scatter in scatters:
        # A (0, 2) offsets array means "no markers".
        scatter.set_offsets(np.empty((0, 2)))
    return (line, *scatters)

def _marker_offsets(indices, times, values):
    """Return an (n, 2) array of (time, value) scatter offsets for the given sample indices."""
    indices = np.asarray(indices, dtype=int)
    if indices.size == 0:
        return np.empty((0, 2))
    return np.column_stack((times[indices], values[indices]))


def update(frame_fraction):
    """
    Update function for each frame of the animation using the two-stage detection with custom
    left/right prominence thresholds.

    The first ``frame_fraction`` of the signal is re-normalized on its own and, once it
    spans more than ``min_time`` seconds, passed to ``detect_resp_peaks_realtime`` (as-is
    for peaks, negated for troughs).

    Parameters
    ----------
    frame_fraction : float
        Fraction (between 0 and 1) of the total signal length to display.

    Returns
    -------
    tuple
        Updated plot artists.
    """
    # Always show at least two samples so the x-range is non-degenerate.
    current_idx = max(int(frame_fraction * N), 2)

    # Normalize only the part of the signal "seen so far".
    current_time = time_array[:current_idx]
    segment = resp_signal[:current_idx]
    current_signal = (segment - np.min(segment)) / (np.ptp(segment) + 1e-9)

    # Start detection only if sufficient time has passed.
    if current_time[-1] > min_time:
        # Peaks: candidates need left >= 0.15; confirmed need left >= 0.4 and right >= 0.15.
        # NOTE: the confirmed left/right values were previously passed swapped
        # (left=0.15, right=0.4), contradicting this description, the detector's
        # defaults and the trough call below.
        candidate_peaks, confirmed_peaks = detect_resp_peaks_realtime(
            current_signal, fs, candidate_prominence=candidate_prominence,
            confirm_prominence=confirm_prominence, height=peak_height, time_constant=time_constant,
            candidate_left_threshold=0.15, confirmed_left_threshold=0.4, confirmed_right_threshold=0.15
        )
        # Troughs: same detector on the negated signal.
        # Candidates need left >= 0.4; confirmed need left >= 0.4 and right >= 0.15.
        candidate_troughs, confirmed_troughs = detect_resp_peaks_realtime(
            -current_signal, fs, candidate_prominence=candidate_prominence,
            confirm_prominence=confirm_prominence, height=peak_height, time_constant=time_constant,
            candidate_left_threshold=0.4, confirmed_left_threshold=0.4, confirmed_right_threshold=0.15
        )
    else:
        # Integer-typed placeholders so they remain valid index arrays.
        candidate_peaks = np.array([], dtype=int)
        confirmed_peaks = np.array([], dtype=int)
        candidate_troughs = np.array([], dtype=int)
        confirmed_troughs = np.array([], dtype=int)

    # Update the line plot and scatter markers.
    line.set_data(current_time, current_signal)
    for scatter, event_indices in (
        (candidate_peaks_scatter, candidate_peaks),
        (confirmed_peaks_scatter, confirmed_peaks),
        (candidate_troughs_scatter, candidate_troughs),
        (confirmed_troughs_scatter, confirmed_troughs),
    ):
        scatter.set_offsets(_marker_offsets(event_indices, current_time, current_signal))

    # Update the x-axis limit to simulate real-time progress.
    ax.set_xlim(0, current_time[-1])

    return (line, candidate_peaks_scatter, confirmed_peaks_scatter,
            candidate_troughs_scatter, confirmed_troughs_scatter)

# Step through the signal in 500 evenly spaced fractions from 0 to 1.
frame_fractions = np.linspace(0, 1, 500)
anim = FuncAnimation(
    fig,
    update,
    frames=frame_fractions,
    init_func=init,
    blit=True,
    interval=10,
)

# Render to JS/HTML and show it inline in the notebook.
animation_html = anim.to_jshtml()
display(HTML(animation_html))

# Save the animation as a GIF using the Pillow writer.
# anim.save('resp_animation.gif', writer='pillow', fps=30)