In [None]:
import h5py
import numpy as np
import pandas as pd
import os
from datetime import datetime, timedelta
import plotly.graph_objects as go
import neurokit2 as nk
import matplotlib.pyplot as plt
from utils.ecg_features import calculate_tmv_and_qt, create_windowed_ecg_from_mat
from utils.ecg_plots import plot_standardized_qt_tmv_with_dual_prediction, compute_refs_and_zscores
from tqdm import tqdm

In [None]:
# ------------------------------------------------------------
#                   MAIN PIPELINE (FIXED)
# ------------------------------------------------------------
# For every alarm file (minus the excluded IDs): cut the ECG into sliding
# windows with `create_windowed_ecg_from_mat`, extract per-window features
# with `calculate_tmv_and_qt`, and collect one dict per usable window in
# `all_results`. Errors are reported and skipped (best effort) so that one
# bad window or file does not abort the whole run.

sampling_rate = 240

# Windowing parameters forwarded to create_windowed_ecg_from_mat.
# NOTE(review): units presumed to be seconds (cf. the *_sec keyword names) —
# confirm against utils.ecg_features.
WINDOW_DURATION = 30
WINDOW_SHIFT = 5
PRE_BUFFER_SEC = 3600 * 5
POST_BUFFER_SEC = 1000

alarms_df = pd.read_csv("VTSampleData/alarms.csv")
exclude_ids = ['FID0001', 'FID0002', 'FID0006']

alarms_df = alarms_df[~alarms_df["Files"].isin(exclude_ids)]


def _window_record(row, features):
    """Turn one window's metadata + `calculate_tmv_and_qt` output into a result dict.

    Parameters
    ----------
    row : pd.Series
        A row of the windowed frame; "Record", "Start", "End" and "Label" are copied.
    features : tuple
        The 22-element tuple returned by `calculate_tmv_and_qt`, in its return order.

    Returns
    -------
    dict or None
        None when any core feature (QT, mean HR, RMSSD, SDNN, TMV score) is
        missing, so the caller can skip the window.
    """
    (
        qt, t_wave, raw_ecg, mean_hr, rmssd, sdnn,
        t_wave_features, tmv_score, max_hr, min_hr,
        qrs_dur, qrs_area, qrs_skew,
        st_dev_mean, st_slope_mean,
        qrs_wave,
        ac_ecg_peak, ac_ecg_lag_sec, ac_ecg_mean,
        ac_rr_peak, ac_rr_lag_beats, ac_rr_mean
    ) = features

    # Skip invalid windows
    if any(pd.isna(x) for x in [qt, mean_hr, rmssd, sdnn, tmv_score]):
        return None

    return {
        "Record": row["Record"],
        "Start": row["Start"],
        "End": row["End"],
        "Label": row["Label"],

        "QT_Interval": qt,
        "T_Wave": t_wave,
        "QRS_Wave": qrs_wave,
        "ECG_Raw": raw_ecg,

        "TMV_Score": tmv_score,
        "Mean_HR": mean_hr,
        "RMSSD": rmssd,
        "SDNN": sdnn,
        "Max_HR": max_hr,
        "Min_HR": min_hr,

        "T_Flatness": t_wave_features.get("flatness"),
        "TWAmp_Std": t_wave_features.get("TWAmp_Std"),
        "TWAmp_CV": t_wave_features.get("TWAmp_CV"),
        "T_MCS": t_wave_features.get("mcs"),

        "QRS_Duration": qrs_dur,
        "QRS_Area": qrs_area,
        "QRS_Skewness": qrs_skew,

        "ST_Deviation_Mean": st_dev_mean,
        "ST_Slope_Mean": st_slope_mean,

        "AC_ECG_Peak": ac_ecg_peak,
        "AC_ECG_Lag_Sec": ac_ecg_lag_sec,
        "AC_ECG_MeanAroundPeak": ac_ecg_mean,
        "AC_RR_Peak": ac_rr_peak,
        "AC_RR_Lag_Beats": ac_rr_lag_beats,
        "AC_RR_MeanAroundPeak": ac_rr_mean,
    }


all_results = []

for file in sorted(alarms_df["Files"].unique()):
    try:
        windowed_df = create_windowed_ecg_from_mat(
            alarms_df,
            file,
            sampling_rate=sampling_rate,
            waveform_dir="VTSampleData/waveform",
            window_duration=WINDOW_DURATION,
            window_shift=WINDOW_SHIFT,
            pre_buffer_sec=PRE_BUFFER_SEC,
            post_buffer_sec=POST_BUFFER_SEC,
        )

        if windowed_df is None or windowed_df.empty:
            print(f"[SKIP] No windows for {file}")
            continue

        for _, row in tqdm(
            windowed_df.iterrows(),
            total=len(windowed_df),
            desc=f"Processing {file}"
        ):
            try:
                record = _window_record(
                    row, calculate_tmv_and_qt(row["ECG"], sampling_rate)
                )
            except Exception as e:
                # tqdm.write prints above the bar; a bare print() would
                # interleave with and break the progress-bar line.
                tqdm.write(f"[WARNING] Window skipped ({file} @ {row['Start']}): {e}")
                continue

            if record is not None:
                all_results.append(record)

    except Exception as e:
        print(f"[ERROR] while processing {file}: {e}")

In [None]:
def _wave_deviation_from_running_median(waves, min_history=60, wave_len=100):
    """Causal waveform-deviation score: MSE against a running median template.

    Entries of ``waves`` that are a list/ndarray of exactly ``wave_len``
    samples are "valid". For the k-th valid waveform (0-based), once at least
    ``min_history`` earlier valid waveforms exist, the score is
    ``mean((wave - median(earlier_valid_waves, axis=0)) ** 2)``.

    Parameters
    ----------
    waves : iterable
        One entry per window, in time order. Invalid entries (None, wrong
        length, ...) are ignored for the template and receive NaN.
    min_history : int
        Number of earlier valid waveforms required before a score is produced.
    wave_len : int
        Required waveform length.

    Returns
    -------
    np.ndarray
        Float array with one score (or NaN) per entry of ``waves``.
    """
    waves = list(waves)
    scores = np.full(len(waves), np.nan)

    positions, valid = [], []
    for pos, wave in enumerate(waves):
        if isinstance(wave, (list, np.ndarray)) and len(wave) == wave_len:
            positions.append(pos)
            valid.append(np.asarray(wave, dtype=float))

    # At least one earlier waveform is needed for a median template.
    first = max(int(min_history), 1)
    if len(valid) <= first:
        return scores

    # Stack once; bank[:k] is exactly the history seen before waveform k.
    bank = np.stack(valid)
    for k in range(first, len(valid)):
        template = np.median(bank[:k], axis=0)
        scores[positions[k]] = float(np.mean((bank[k] - template) ** 2))
    return scores


def _causal_robust_z(values, min_history=60, clip=10.0):
    """Robust z-score of each value against the values strictly before it.

    ``z[i] = (x[i] - median(past)) / (IQR(past) / 1.349 + 1e-6)``, clipped to
    ``[-clip, clip]``, where ``past`` are the non-NaN entries of ``x[:i]``.
    The result is NaN when ``i < min_history``, ``x[i]`` is NaN, fewer than 2
    past values exist, or ``IQR(past) < 1e-6``.
    """
    values = np.asarray(values, dtype=float)
    z = np.full(values.shape, np.nan)
    for i in range(max(int(min_history), 0), len(values)):
        current = values[i]
        if np.isnan(current):
            continue
        past = values[:i]
        past = past[~np.isnan(past)]
        if past.size < 2:
            continue

        median = float(np.median(past))
        q25, q75 = np.percentile(past, [25, 75])
        iqr = q75 - q25

        # Guard against degenerate scale (e.g. constant history)
        if iqr < 1e-6:
            continue

        # 1.349 is the IQR of a standard normal, so the scale is std-comparable.
        z[i] = np.clip((current - median) / (iqr / 1.349 + 1e-6), -clip, clip)
    return z


def compute_refs_and_zscores(results_df, sampling_rate=250, window_len_sec=30,
                             min_history=60, wave_len=100, z_clip=10.0):
    """
    REAL-TIME / CAUSAL STANDARDIZATION (INDEX-BASED)

    NOTE(review): this definition shadows the ``compute_refs_and_zscores``
    imported from ``utils.ecg_plots`` in the imports cell; any call made
    after this cell runs uses this local version.

    - No baseline period.
    - VTAC labeling is non-causal (allowed). Within each Record, every window
      gets VTAC_Label=1 when its Start lies between the first
      "VTAC"-labelled window's Start and the last one's Start plus
      ``window_len_sec``. The Label match is case-insensitive.
    - TMV_Global / QRS_Global give the MSE of the window's T_Wave / QRS_Wave
      against the element-wise median of all earlier valid waves of the same
      Record, once ``min_history`` earlier valid waves exist.
    - Z-scores are robust (median / IQR) z-scores of each field, computed
      against earlier windows of the same Record only. They start at the
      ``min_history``-th window; the warm-up is index-based, not time-based.

    Parameters
    ----------
    results_df : pd.DataFrame
        One row per window. Uses "Record", "Start", "Label", the optional
        "T_Wave" / "QRS_Wave" columns, and whichever z-score fields are
        present. Missing fields leave their ``*_Z`` column NaN.
    sampling_rate : int
        Unused; kept for call-site compatibility.
    window_len_sec : float
        Seconds added to the last VTAC window's Start to close the VTAC span.
        "Start" must support adding a pandas Timedelta (e.g. datetimes).
    min_history : int, default 60
        Earlier windows / valid waves required before a reference or z-score
        is produced.
    wave_len : int, default 100
        Required length of T_Wave / QRS_Wave for them to be used.
    z_clip : float, default 10.0
        Z-scores are clipped to ``[-z_clip, z_clip]``.

    Returns
    -------
    pd.DataFrame
        Copy of ``results_df`` sorted by (Record, Start) with a fresh
        RangeIndex. It adds the columns TMV_Global, QRS_Global and VTAC_Label,
        plus ``<field>_Z`` for every z-scored field. An empty input is
        returned with these columns added.
    """
    z_fields = [
        "TMV_Score",
        "QT_Interval",
        "Mean_HR",
        "Max_HR",
        "Min_HR",
        "RMSSD",
        "SDNN",
        "T_Flatness",
        "TWAmp_Std",
        "TWAmp_CV",
        "QRS_Duration",
        "QRS_Area",
        "QRS_Skewness",
        "ST_Deviation_Mean",
        "ST_Slope_Mean",
        "AC_ECG_Peak",
        "AC_ECG_Lag_Sec",
        "AC_ECG_MeanAroundPeak",
        "AC_RR_Peak",
        "AC_RR_Lag_Beats",
        "AC_RR_MeanAroundPeak",
    ]
    # Derived reference columns and the waveform column each is built from.
    global_sources = (("TMV_Global", "T_Wave"), ("QRS_Global", "QRS_Wave"))
    global_fields = [name for name, _ in global_sources]

    df = results_df.copy()
    # An empty frame (e.g. every window skipped) may lack Record/Start entirely.
    if not df.empty:
        df = df.sort_values(["Record", "Start"]).reset_index(drop=True)

    # -----------------------------
    # Init columns
    # -----------------------------
    df["TMV_Global"] = np.nan
    df["QRS_Global"] = np.nan
    df["VTAC_Label"] = 0
    for field in z_fields + global_fields:
        df[f"{field}_Z"] = np.nan

    if df.empty:
        return df

    # -----------------------------
    # Process per Record
    # -----------------------------
    for _record_id, g in df.groupby("Record"):
        g = g.sort_values("Start")
        idx = g.index.to_numpy()  # labels into df

        # ---------- VTAC labeling (non-causal) ----------
        is_vtac = g["Label"].astype(str).str.upper() == "VTAC"
        if is_vtac.any():
            vtac_times = g.loc[is_vtac, "Start"]
            vtac_start = vtac_times.min()
            vtac_end = vtac_times.max() + pd.to_timedelta(window_len_sec, unit="s")
            in_span = (g["Start"] >= vtac_start) & (g["Start"] <= vtac_end)
            df.loc[g.index[in_span.to_numpy()], "VTAC_Label"] = 1

        # ---------- Evolving waveform references ----------
        per_field = {}
        for global_field, wave_col in global_sources:
            waves = g[wave_col] if wave_col in g.columns else [None] * len(g)
            scores = _wave_deviation_from_running_median(waves, min_history, wave_len)
            df.loc[idx, global_field] = scores
            per_field[global_field] = scores

        for field in z_fields:
            if field in g.columns:
                per_field[field] = (
                    pd.to_numeric(g[field], errors="coerce").to_numpy(dtype=float)
                )

        # ---------- Causal robust z-scores ----------
        for field, values in per_field.items():
            df.loc[idx, f"{field}_Z"] = _causal_robust_z(values, min_history, z_clip)

    return df

In [None]:
# ------------------------------------------------------------
#               FINAL REAL-TIME STANDARDIZATION
# ------------------------------------------------------------
# Gather the per-window feature dicts into one frame, then let
# compute_refs_and_zscores add the evolving references, VTAC labels and
# causal z-scores. Rebuilding from `all_results` keeps this cell re-runnable.
raw_results_df = pd.DataFrame(all_results)
results_df = compute_refs_and_zscores(
    raw_results_df, sampling_rate=sampling_rate, window_len_sec=30
)

print("DONE — results_df shape:", results_df.shape)

In [None]:
# Save as pickle. Create the output directory first so the cell also works
# on a fresh checkout where model/ does not exist yet.
os.makedirs("model", exist_ok=True)
results_df.to_pickle("model/results_df_FID.pkl")
print("✅ results_df saved to model/results_df_FID.pkl")

In [None]:
# Reload the standardized results written by the save cell, so the cells
# below can run without recomputing the feature pipeline.
# Note: unpickling can execute arbitrary code — only load files this notebook wrote.
results_pickle_path = "model/results_df_FID.pkl"
results_df = pd.read_pickle(results_pickle_path)

In [None]:
# Show every column: raw features, the global references and the *_Z scores.
results_df.keys()

In [None]:
# Model input columns (z-scored features). The order is preserved exactly;
# NOTE(review): presumably it must match the feature order the saved models
# were trained with — confirm before reordering.
_feature_bases = (
    "QT_Interval",
    "TMV_Score",
    "Mean_HR",
    "QRS_Global",
    "T_Flatness",
    "TWAmp_Std",
    "TWAmp_CV",
    "TMV_Global",
    "QRS_Duration",
    "QRS_Area",
    "QRS_Skewness",
    "ST_Deviation_Mean",
    "ST_Slope_Mean",
    "AC_ECG_Peak",
    "AC_ECG_Lag_Sec",
    "AC_ECG_MeanAroundPeak",
    "AC_RR_Peak",
    "AC_RR_Lag_Beats",
    "AC_RR_MeanAroundPeak",
)
features = [f"{base}_Z" for base in _feature_bases]


In [None]:
# Saved model files handed to the plotting helper (binary and regression
# variants of the random-forest and XGBoost models, as named on disk).
vtac_model_paths = {
    "rf_model_path": "model/random_forest_vtac_model_binary.joblib",
    "xgb_model_path": "model/xgboost_vtac_model_binary.joblib",
    "rf_model_path_reg": "model/random_forest_vtac_model_regression.joblib",
    "xgb_model_path_reg": "model/xgboost_vtac_model_regression.joblib",
}

plot_standardized_qt_tmv_with_dual_prediction(
    results_df,
    alarms_df,
    features,
    **vtac_model_paths,
    z_threshold=1.5,
    output_dir="plots",
    smooth_kernel=5,
)