# ECG/Resp Inspection
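Processes ECG data (from external files or from the ICE parameters in the TWIX file) and plots it in the time domain, optionally overlaid with event markers (e.g., R-peak triggers) loaded from an event file. Also loads the respiratory signal, detects its peaks and troughs, and animates a simulated real-time version of that detection.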

### Loading packages and data

In [None]:
import yaml
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
import utils.data_ingestion as di
import utils.ecg_resp as ecg_resp
import scipy.signal as signal
# %matplotlib widget

def load_config(config_file="config.yaml"):
    """
    Load configuration from a YAML file.

    Parameters
    ----------
    config_file : str
        Path to the YAML configuration file. The default ``"config.yaml"``
        is resolved relative to the current working directory.

    Returns
    -------
    dict
        Parsed configuration as returned by ``yaml.safe_load`` (a dict for a
        mapping-style file; ``None`` if the file is empty).
    """
    # Decode explicitly as UTF-8: relying on the platform default encoding
    # (e.g. cp1252 on Windows) can mis-decode non-ASCII paths or comments.
    with open(config_file, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

# Read config (load_config defaults to "config.yaml" in the working directory)
config = load_config()

# Paths and optional file references (optional keys default to None)
twix_file = config["data"]["twix_file"]
dicom_folder = config["data"]["dicom_folder"]
ecg_files = config["data"].get("ecg_files", None)
event_file = config["data"].get("event_file", None)
resp_file = config["data"].get("resp_file", None)

# Read the TWIX file only; k-space extraction and the sampling-frequency
# derivation happen in the ECG section below.
# NOTE(review): include_scans=[-1] presumably keeps only the last measurement
# of a multi-measurement file — confirm against di.read_twix_file.
scans = di.read_twix_file(twix_file, include_scans=[-1], parse_pmu=False)

### Extracting and analyzing ECG

In [None]:
# first_time = None
# last_time = None
# for (i,mdb) in enumerate(scans[-1]['mdb']):
#     if mdb.is_image_scan():
#         print(mdb.mdh.TimeStamp)
#         last_time = mdb.mdh.TimeStamp
#         if first_time is None:
#             first_time = mdb.mdh.TimeStamp

# print("First time: ", first_time)
# print("Last time: ", last_time)
# print("Time difference: ", (last_time - first_time) * 2.5e-3)

In [None]:
kspace = di.extract_image_data(scans)

# Physiological sampling rate = frames per second * k-space lines per frame.
# File-based signals below are resampled to kspace.shape[0] samples, i.e. one
# value per line (assuming axis 0 of kspace indexes the acquired lines).
framerate, frametime = di.get_dicom_framerate(dicom_folder)
n_phase_encodes_per_frame = kspace.shape[0] // config["data"]["n_frames"]
fs = framerate * n_phase_encodes_per_frame  # ECG / respiration sampling freq
# fs = 1/(scans[-1]['hdr']['Phoenix']['alTR'][0]/1000/48/1000)

# Load ECG data either from external files or from the ICE parameters
if ecg_files:
    # A single path given as a plain string would otherwise be iterated
    # character by character; treat it as a one-element list.
    if isinstance(ecg_files, str):
        ecg_files = [ecg_files]
    # One channel per file: read the second column (header row skipped),
    # resample it to the k-space length, and stack the files as columns,
    # giving shape (kspace.shape[0], len(ecg_files)).
    ecg_data = np.vstack(
        [
            signal.resample(
                np.loadtxt(ecg_file, skiprows=1, usecols=1), kspace.shape[0]
            )
            for ecg_file in ecg_files
        ]
    ).T
else:
    # Columns 18-20 of the ICE parameters of segment 0.
    # NOTE(review): presumably the ECG channels — confirm against the sequence.
    ecg_columns = np.s_[18:21]
    ecg_data = di.extract_iceparam_data(scans, segment_index=0, columns=ecg_columns)
    # Force 2D shape (samples, channels) so a single channel is handled like many
    if ecg_data.ndim == 1:
        ecg_data = ecg_data.reshape(-1, 1)

In [None]:
# Example (commented) R-peak detection
# r_peaks_list = ecg_resp.detect_r_peaks(ecg_data, fs)
# hr = ecg_resp.compute_average_heart_rate(r_peaks_list, fs)
# print(f"Average heart rate: {hr:.2f} BPM")

# Optional event/trigger trace (e.g. R-peak triggers). Without an event file
# there is nothing to overlay on the ECG plot.
if not event_file:
    resampled_events = None
else:
    resampled_length = kspace.shape[0]
    raw_events = np.loadtxt(event_file, skiprows=1, usecols=1)
    # Shift the baseline (minimum) to zero: only samples above the baseline
    # stay non-zero and are treated as spikes.
    raw_events = raw_events - raw_events.min()
    raw_length = len(raw_events)

    # Each spike keeps its relative position in the recording, so the first and
    # last raw samples land on the first and last resampled samples.
    raw_spike_indices = np.flatnonzero(raw_events)
    resampled_spike_indices = np.round(
        raw_spike_indices / (raw_length - 1) * (resampled_length - 1)
    ).astype(int)

    # Spike-only output: zeros everywhere except the mapped positions, which
    # carry the spike amplitudes (a later spike overwrites an earlier one that
    # rounds to the same index).
    resampled_events = np.zeros(resampled_length)
    for (raw_idx, resampled_idx) in zip(raw_spike_indices, resampled_spike_indices):
        resampled_events[resampled_idx] = raw_events[raw_idx]

In [None]:
# Min-max normalize every ECG channel (column) independently to about [0, 1];
# the tiny epsilon keeps a flat channel from dividing by zero.
ecg_data = (ecg_data - ecg_data.min(axis=0)) / (np.ptp(ecg_data, axis=0) + 1e-9)

# Plot each channel separately; event spikes (if an event file was loaded) are
# passed as sample indices for vertical markers.
ecg_resp.plot_ecg_signals(
    ecg_data,
    fs,
    spike_indices=(
        None if resampled_events is None else np.flatnonzero(resampled_events)
    ),
    mode="separate",
)

### Extracting and analyzing respiration

In [None]:
# Respiration: only processed when a respiratory trace file is configured.
if resp_file:
    # Second column of the file (header row skipped), resampled to the k-space
    # length and shaped as a single-channel (N, 1) column.
    resp_data = signal.resample(
        np.loadtxt(resp_file, skiprows=1, usecols=1), kspace.shape[0]
    ).reshape(-1, 1)

    # Troughs are found as peaks of the negated trace, with the same thresholds.
    # NOTE(review): height/prominence are applied to the un-normalized trace
    # unless detect_resp_peaks rescales internally — confirm.
    resp_peaks = ecg_resp.detect_resp_peaks(
        resp_data, fs, method='scipy', height=0.6, prominence=0.2
    )
    resp_troughs = ecg_resp.detect_resp_peaks(
        -resp_data, fs, method='scipy', height=0.6, prominence=0.2
    )

    ecg_resp.plot_resp_signal(
        resp_data, fs, resp_peaks=resp_peaks, resp_troughs=resp_troughs
    )

In [None]:
# TODO: Make real-time peak/trough detection more responsive (it currently seems to pick up regime changes quite late).
# There is relatively little noise at the peaks/troughs, so basic smoothing followed by peak detection (or derivative-based peak detection) should be sufficient.

In [None]:
# t = np.arange(N)/fs
# plt.figure(figsize=(10,5))

# # Plot the offline reference fraction (unchanged)
# plt.plot(t, actual_fraction, label='Actual Fraction (Offline)', linewidth=2, color='black')

# # Instead of a single line for predicted_fraction, we segment it by predicted_phase.
# # We'll use green for inhalation (True), blue for exhalation (False), and gray if unknown.
# start_idx = 0
# for i in range(1, N):
#     # If the phase indicator changes, plot the segment from start_idx to i-1.
#     if predicted_phase[i] != predicted_phase[i-1]:
#         seg_t = t[start_idx:i]
#         seg_y = predicted_fraction[start_idx:i]
#         if predicted_phase[i-1] is True:
#             color = 'green'      # inhalation
#         elif predicted_phase[i-1] is False:
#             color = 'blue'       # exhalation
#         else:
#             color = 'gray'
#         plt.plot(seg_t, seg_y, '-', color=color, linewidth=2)
#         start_idx = i
# # Plot the last segment
# if start_idx < N:
#     seg_t = t[start_idx:]
#     seg_y = predicted_fraction[start_idx:]
#     if predicted_phase[-1] is True:
#         color = 'green'
#     elif predicted_phase[-1] is False:
#         color = 'blue'
#     else:
#         color = 'gray'
#     plt.plot(seg_t, seg_y, '-', color=color, linewidth=2)

# # Optionally, add a vertical line at calibration end
# if calibration_end_idx is not None:
#     plt.axvline(calibration_end_idx/fs, color='r', linestyle=':', linewidth=2, label='Calibration End')

# # Normalize raw resp_signal to range 0-100
# raw_norm = (resp_signal - np.min(resp_signal)) / (np.ptp(resp_signal)) * 100.0
# plt.plot(t, raw_norm, label='Raw Signal (Normalized)', color='magenta', alpha=0.5)

# plt.title("Online Peak/Trough Detection, Fraction & Phase Indicator")
# plt.xlabel("Time (s)")
# plt.ylabel("Breath Cycle Fraction (%)")
# plt.grid(True)
# plt.legend(loc='upper left')
# plt.show()

In [None]:
"""
Real-Time Respiratory Peak/Trough Detection Animation

This cell simulates real-time detection of respiratory peaks and troughs. It progressively
displays an increasing portion of the respiratory signal and, at each frame, re-detects the
peaks and troughs on the current segment. This allows false positives to be corrected as more
data becomes available.

Assumptions:
    - The full respiratory signal is stored in the variable `resp_data` (shape: [N, 1] or [N])
    - The sampling frequency `fs` is already defined.
    - The detection function `detect_resp_peaks` (and its counterpart for troughs by negating the signal)
      from utils/ecg_resp.py is available.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import utils.ecg_resp as ecg_resp

# Flatten the respiratory array from the respiration cell ((N, 1) or (N,)) to 1D.
resp_signal = resp_data.flatten()
N = len(resp_signal)
time_array = np.arange(N) / fs  # time in seconds, one entry per sample

# Detection thresholds used for both peaks and troughs (same values as the
# detect_resp_peaks calls in the respiration cell).
peak_height = 0.6
peak_prominence = 0.2

# Figure plus initially empty artists that the animation updates in place.
fig, ax = plt.subplots(figsize=(16, 8), dpi=40)
line, = ax.plot([], [], lw=2, color='gray', label='Resp Signal')
peaks_scatter = ax.scatter([], [], color='red', s=50, label='Peaks')
troughs_scatter = ax.scatter([], [], color='blue', s=50, label='Troughs')
ax.set_xlabel("Time (s)")
ax.set_ylabel("Signal Amplitude")
ax.set_title("Real-Time Respiratory Signal with Detected Peaks/Troughs")
ax.legend()

# Start with the full time range; update() narrows it to the revealed portion.
ax.set_xlim(0, time_array[-1])
# Pad the y-limits by 10% of the signal's span rather than of each limit's
# magnitude: a magnitude-based margin vanishes when a limit is 0 (markers at
# that extreme sit on the axis edge) and is oversized for signals with a large
# offset. A flat signal has zero span, so use a fixed margin to avoid
# identical limits.
ymin, ymax = np.min(resp_signal), np.max(resp_signal)
y_pad = 0.1 * (ymax - ymin)
if y_pad == 0:
    y_pad = 1.0
ax.set_ylim(ymin - y_pad, ymax + y_pad)

def init():
    """
    Reset every animated artist to an empty state before the first frame.
    """
    line.set_data([], [])
    # Scatter offsets are (k, 2) arrays; an empty (0, 2) array clears the markers
    # without the index errors a plain empty list could cause.
    for scatter in (peaks_scatter, troughs_scatter):
        scatter.set_offsets(np.empty((0, 2)))
    return line, peaks_scatter, troughs_scatter

def update(frame_fraction):
    """
    Advance the animation to the given fraction of the respiratory signal.

    The revealed prefix grows with ``frame_fraction``, and peaks/troughs are
    re-detected on the whole prefix every frame, so earlier detections can
    appear or disappear as more samples arrive.

    Parameters
    ----------
    frame_fraction : float
        Portion of the signal (0 to 1) to reveal in this frame.

    Returns
    -------
    tuple
        The artists changed for this frame (line, peak scatter, trough scatter).
    """
    # Samples revealed so far; at least 2 so that detection has data to work on.
    n_visible = max(int(frame_fraction * N), 2)
    seg_signal = resp_signal[:n_visible]
    seg_time = time_array[:n_visible]

    def _offsets(indices):
        # (k, 2) array of (time, amplitude) points; (0, 2) when nothing was found.
        if len(indices) == 0:
            return np.empty((0, 2))
        return np.column_stack((seg_time[indices], seg_signal[indices]))

    detection_kwargs = dict(method='scipy', height=peak_height,
                            prominence=peak_prominence)
    # Troughs are detected as peaks of the negated segment.
    peak_idx = ecg_resp.detect_resp_peaks(seg_signal, fs, **detection_kwargs)
    trough_idx = ecg_resp.detect_resp_peaks(-seg_signal, fs, **detection_kwargs)

    peak_offsets = _offsets(peak_idx)
    trough_offsets = _offsets(trough_idx)

    line.set_data(seg_time, seg_signal)
    peaks_scatter.set_offsets(peak_offsets)
    troughs_scatter.set_offsets(trough_offsets)

    # Stretch the x-axis to the newest sample to mimic a live display.
    ax.set_xlim(0, seg_time[-1])

    return line, peaks_scatter, troughs_scatter

# Create the animation with 500 frames spanning the entire signal: each frame
# value is a fraction in [0, 1] passed to update(); interval is the delay
# between frames in milliseconds.
anim = FuncAnimation(fig, update, frames=np.linspace(0, 1, 500), init_func=init,
                     blit=True, interval=10)

# Display the animation in the notebook as jshtml (an HTML/JavaScript player);
# the HTML object is shown because it is the cell's last expression.
HTML(anim.to_jshtml())