In [1]:
import pandas as pd
import neurokit2 as nk
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import os
import wfdb
import pickle
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import h5py
from datetime import datetime, timedelta
from utils.ecg_windowing import make_baseline_windows, window_vtac_records
from utils.ecg_features import process_dataframe, create_windowed_ecg_from_mat , convert_and_relabel_windowed_df_full
from utils.ecg_plots import compute_refs_and_zscores, plot_subject_panels 

In [2]:
# Directory holding the cognitive-battery pickle. Set the COGNITIVE_BATTERY_DIR
# environment variable to run on another machine; when it is unset the original
# author's local path is used, so existing runs behave exactly as before.
base_dir = os.environ.get(
    "COGNITIVE_BATTERY_DIR",
    "/Users/muratkucukosmanoglu/Desktop/cognitive_battery/cognitive-battery_UMBC",
)
file_name = "separated_data_rt_april11.pkl"

file_path = os.path.join(base_dir, file_name)

# Load the DataFrame.
# NOTE: unpickling executes whatever the file encodes -- only load pickles from a trusted source.
healthy_subject = pd.read_pickle(file_path)

print(healthy_subject.head())

  subject_id    game_type Difficulty              Start_time  \
0       BVCX     baseline        NaN                     NaT   
1       BVCX  task switch     Medium 2024-03-13 18:27:18.396   
2       BVCX  task switch       Hard 2024-03-13 18:22:38.650   
3       BVCX  task switch       Easy 2024-03-13 18:25:06.668   
4       CVBN     baseline        NaN                     NaT   

                       End_time  Accuracy  \
0                           NaT       NaN   
1 2024-03-13 18:29:07.553999872  0.888889   
2 2024-03-13 18:26:04.896000000  0.469136   
3 2024-03-13 18:28:13.735000064  0.802469   
4                           NaT       NaN   

                                            ecg_data  \
0  [210.0, 215.0, 210.0, 203.0, 166.0, 132.0, 133...   
1  [224.0, 227.0, 220.0, 227.0, 218.0, 221.0, 226...   
2  [165.0, 168.0, 163.0, 169.0, 167.0, 168.0, 165...   
3  [235.0, 227.0, 240.0, 229.0, 234.0, 225.0, 239...   
4  [207.0, 203.0, 200.0, 196.0, 188.0, 193.0, 186...   

       

In [3]:
# Keep only the rows whose game_type is 'baseline'; the task rows are not used below.
is_baseline = healthy_subject["game_type"].eq("baseline")
healthy_subject = healthy_subject.loc[is_baseline]
healthy_subject.head()

Unnamed: 0,subject_id,game_type,Difficulty,Start_time,End_time,Accuracy,ecg_data,eda_data,Dates,averaged_rt,time_in_seconds
0,BVCX,baseline,,NaT,NaT,,"[210.0, 215.0, 210.0, 203.0, 166.0, 132.0, 133...","[209.0, 211.0, 208.0, 201.0, 166.0, 133.0, 131...","[2024-03-13 17:57:50.229330176, 2024-03-13 17:...",,"[0.0, 0.003999, 0.008, 0.011999, 0.016, 0.0199..."
4,CVBN,baseline,,NaT,NaT,,"[207.0, 203.0, 200.0, 196.0, 188.0, 193.0, 186...","[203.0, 202.0, 197.0, 193.0, 187.0, 190.0, 184...","[2024-03-14 19:34:04.914569984, 2024-03-14 19:...",,"[0.0, 0.003999, 0.008, 0.012, 0.016, 0.019999,..."
8,DKLQ,baseline,,NaT,NaT,,"[240.0, 141.0, 151.0, 129.0, 109.0, 136.0, 150...","[687.0, 693.0, 687.0, 693.0, 687.0, 685.0, 686...","[2024-02-13 18:54:17.359229184, 2024-02-13 18:...",,"[0.0, 0.003999, 0.008, 0.012, 0.016, 0.02, 0.0..."
12,EDFR,baseline,,NaT,NaT,,"[221.0, 222.0, 205.0, 215.0, 194.0, 211.0, 208...","[221.0, 218.0, 205.0, 210.0, 193.0, 206.0, 206...","[2024-03-12 17:41:25.153460480, 2024-03-12 17:...",,"[0.0, 0.004, 0.008, 0.012, 0.015999, 0.02, 0.0..."
16,FDSA,baseline,,NaT,NaT,,"[188.0, 177.0, 159.0, 150.0, 154.0, 157.0, 159...","[184.0, 176.0, 158.0, 148.0, 152.0, 154.0, 157...","[2024-04-03 17:46:15.475637504, 2024-04-03 17:...",,"[0.0, 0.004, 0.008, 0.012, 0.016, 0.02, 0.024,..."


In [4]:
# Cut the baseline recordings into windows: 30 s long, advanced 5 s at a time,
# with the signal treated as 250 samples per second, for at most 50 subjects.
baseline_window_args = dict(sample_rate=250, win_sec=30, shift_sec=5, max_subjects=50)
baseline_df = make_baseline_windows(healthy_subject, **baseline_window_args)
baseline_df.head()

Unnamed: 0,Record,Start,End,Label,ECG
0,BVCX,0,7500,baseline,"[210.0, 215.0, 210.0, 203.0, 166.0, 132.0, 133..."
1,BVCX,1250,8750,baseline,"[212.0, 211.0, 212.0, 218.0, 210.0, 213.0, 211..."
2,BVCX,2500,10000,baseline,"[222.0, 226.0, 224.0, 232.0, 223.0, 233.0, 220..."
3,BVCX,3750,11250,baseline,"[233.0, 238.0, 236.0, 240.0, 235.0, 237.0, 236..."
4,BVCX,5000,12500,baseline,"[211.0, 213.0, 218.0, 214.0, 220.0, 212.0, 220..."


In [5]:
# --- Define constants ---
SAMPLE_RATE = 250  # Samples per second

# --- Directory containing the records ---
record_dir = 'cu-ventricular-tachyarrhythmia-database-1.0.0'

# --- List all records in the directory ---
# Sorted so the listing (and the dict below) does not depend on filesystem order.
record_files = sorted(f for f in os.listdir(record_dir) if f.endswith('.dat'))
record_names = [os.path.splitext(f)[0] for f in record_files]

# Inspect every record's header to see its lead (signal) names.
# Only the header is parsed: sig_name lives there, so the signal data is not loaded.
lead_info = {}
for rec in record_names:
    try:
        header = wfdb.rdheader(os.path.join(record_dir, rec))
        lead_info[rec] = header.sig_name
    except Exception as e:
        # Best effort: keep going and record the failure next to the record name.
        lead_info[rec] = f"Error reading: {e}"

lead_info


{'cu23': ['ECG'],
 'cu22': ['ECG'],
 'cu08': ['ECG'],
 'cu34': ['ECG'],
 'cu20': ['ECG'],
 'cu21': ['ECG'],
 'cu35': ['ECG'],
 'cu09': ['ECG'],
 'cu31': ['ECG'],
 'cu25': ['ECG'],
 'cu19': ['ECG'],
 'cu18': ['ECG'],
 'cu24': ['ECG'],
 'cu30': ['ECG'],
 'cu26': ['ECG'],
 'cu32': ['ECG'],
 'cu33': ['ECG'],
 'cu27': ['ECG'],
 'cu02': ['ECG'],
 'cu16': ['ECG'],
 'cu17': ['ECG'],
 'cu03': ['ECG'],
 'cu29': ['ECG'],
 'cu15': ['ECG'],
 'cu01': ['ECG'],
 'cu14': ['ECG'],
 'cu28': ['ECG'],
 'cu10': ['ECG'],
 'cu04': ['ECG'],
 'cu05': ['ECG'],
 'cu11': ['ECG'],
 'cu07': ['ECG'],
 'cu13': ['ECG'],
 'cu12': ['ECG'],
 'cu06': ['ECG']}

In [6]:
# Window the CU records: 30 s windows advanced 5 s at a time, at SAMPLE_RATE (250)
# samples per second. The directory and rate reuse the constants defined in the
# lead-inspection cell so the two cells cannot drift apart.
windowed_df = window_vtac_records(
    record_dir,
    sample_rate=SAMPLE_RATE,
    win_sec=30,
    shift_sec=5,
)

In [7]:
# Preview the windowed CU records (the rendered table elides the middle rows).
windowed_df 

Unnamed: 0,Record,Start,End,Label,ECG
0,cu23,0,7500,Pre-VTAC,"[-0.215, -0.2, -0.18, -0.18, -0.195, -0.195, -..."
1,cu23,1250,8750,Pre-VTAC,"[-0.18, -0.205, -0.2, -0.19, -0.195, -0.21, -0..."
2,cu23,2500,10000,Pre-VTAC,"[-0.235, -0.245, -0.235, -0.245, -0.24, -0.25,..."
3,cu23,3750,11250,Pre-VTAC,"[0.715, 0.205, -0.1, -0.21, -0.205, -0.22, -0...."
4,cu23,5000,12500,Pre-VTAC,"[0.065, -0.04, -0.08, -0.07, 0.06, 0.27, 0.58,..."
...,...,...,...,...,...
3355,cu06,113750,121250,Other,"[0.455, 0.4625, 0.4725, 0.46, 0.4425, 0.4175, ..."
3356,cu06,115000,122500,Other,"[-0.43, -0.44, -0.4425, -0.455, -0.4475, -0.42..."
3357,cu06,116250,123750,Other,"[-0.2875, -0.2975, -0.3, -0.305, -0.2975, -0.2..."
3358,cu06,117500,125000,Other,"[-0.175, -0.175, -0.17, -0.1575, -0.15, -0.147..."


In [8]:
# Distinct window labels produced for the CU records.
windowed_df["Label"].unique()

array(['Pre-VTAC', 'VTAC', 'Other'], dtype=object)

In [9]:
# Same label query as the previous cell, repeated here.
windowed_df["Label"].unique()

array(['Pre-VTAC', 'VTAC', 'Other'], dtype=object)

|------Other------|-------Pre-VTAC-------|===VTAC===|---------Other---------|
                                         ↑
                                 VTAC Event Starts


In [10]:
# Combine the CU windows with the healthy-subject baseline windows.
# windowed_df is reassigned in place, so rows that already carry one of
# baseline_df's labels are dropped first: on a fresh run this removes nothing,
# and on a re-run it prevents the baseline windows from being appended twice.
already_merged = windowed_df["Label"].isin(baseline_df["Label"].unique())
windowed_df = pd.concat([windowed_df[~already_merged], baseline_df], ignore_index=True)

# Optional: check shape and label distribution
print("Combined shape:", windowed_df.shape)
print("Label distribution:\n", windowed_df["Label"].value_counts())

Combined shape: (12633, 5)
Label distribution:
 Label
baseline    9273
Pre-VTAC    2071
VTAC         868
Other        421
Name: count, dtype: int64


In [11]:
sampling_rate = 240  # samples per second handed to the windowing helper

alarms_df = pd.read_csv("VTSampleData/alarms.csv")

# Waveform files to process: FID0000 through FID0021 (22 files).
target_files = [f"FID{file_number:04d}" for file_number in range(22)]

# --- Generate windows for the selected files only ---
all_windowed = []
for file_id in target_files:
    try:
        file_windows = create_windowed_ecg_from_mat(
            alarms_df,
            file_id,
            sampling_rate=sampling_rate,
            waveform_dir="VTSampleData/waveform",
            window_duration=30,
            window_shift=5,
            pre_buffer_sec=3600,
            post_buffer_sec=500,
        )
    except Exception as e:
        # Best effort: report the file that failed and carry on with the rest.
        print(f"[ERROR] Could not window {file_id}: {e}")
        continue

    # Files that produced nothing (None or an empty frame) are skipped.
    if file_windows is not None and not file_windows.empty:
        all_windowed.append(file_windows)

# pd.concat rejects an empty list with a generic message; fail with a pointer to the inputs instead.
if not all_windowed:
    raise ValueError(
        "No windows were created for any target file; "
        "check VTSampleData/waveform and VTSampleData/alarms.csv."
    )

# --- CONCAT ALL WINDOWS INTO ONE MASTER DF ---
windowed_df_full = pd.concat(all_windowed, ignore_index=True)
print("Total windows created:", len(windowed_df_full))

Total windows created: 6702


In [12]:
# 1. Collect VTAC intervals as (start, end) timestamp pairs, one per row of alarms_df.
#    StartTime is parsed with minute resolution; Duration is interpreted as seconds.
alarm_starts = pd.to_datetime(alarms_df["StartTime"], format="%m/%d/%y %H:%M")
alarm_ends = alarm_starts + pd.to_timedelta(alarms_df["Duration"], unit="s")
vtac_intervals = list(zip(alarm_starts, alarm_ends))

# 2. Convert and relabel, using the same sampling rate the windows were created with.
windowed_df_final = convert_and_relabel_windowed_df_full(
    windowed_df_full,
    vtac_intervals,
    sampling_rate=sampling_rate,
    win_sec=30,
    shift_sec=5
)

# 3. Check output
print(windowed_df_final.head())
print(windowed_df_final["Label"].value_counts())


    Record  Start    End     Label  \
0  FID0003      0   7200  Pre-VTAC   
1  FID0003   1200   8400  Pre-VTAC   
2  FID0003   2400   9600  Pre-VTAC   
3  FID0003   3600  10800  Pre-VTAC   
4  FID0003   4800  12000  Pre-VTAC   

                                                 ECG  
0  [-44, -38, -33, -29, -26, -23, -18, -17, -18, ...  
1  [-49, -24, -10, -2, 4, 9, 11, 10, 9, 9, 8, 8, ...  
2  [-8, 22, 68, 115, 132, 106, 36, -41, -78, -61,...  
3  [-82, -78, -78, -79, -74, -63, -51, -31, 1, 45...  
4  [-29, -35, -44, -51, -62, -74, -88, -100, -107...  
Label
Pre-VTAC    6600
VTAC         102
Name: count, dtype: int64


In [13]:
# Distinct labels present after relabelling.
windowed_df_final["Label"].unique()

array(['Pre-VTAC', 'VTAC'], dtype=object)

In [14]:
# Sanity check on window size: average number of samples per ECG segment.
# (Computed once; the value is only shown through the print below.)
mean_len = windowed_df_final["ECG"].map(len).mean()
print("Mean ECG segment length:", mean_len)

Mean ECG segment length: 7200.0


In [15]:
# Feature extraction for the windows sampled at 240 Hz.
results_240 = process_dataframe(
    windowed_df_final,
    sampling_rate=240,
    extension_sec=30,
)


=== VTAC Ratio Change Per Subject ===
         VTAC_Ratio_Before(%)  VTAC_Ratio_After(%)  Delta(After-Before)
Record                                                                 
FID0003              0.731707             2.317073             1.585366
FID0004              6.935123             9.619687             2.684564
FID0005              0.864553             2.737752             1.873199
FID0008              0.425532             2.695035             2.269504
FID0009              1.846966             3.430079             1.583113
FID0011              0.975610             2.439024             1.463415
FID0012              0.732601             2.319902             1.587302
FID0013              0.615764             2.339901             1.724138
FID0015              2.781137             4.232164             1.451028


  warn(
  warn(
  warn(
  warn(
 38%|███▊      | 2532/6702 [06:56<01:09, 60.04it/s]

[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


  warn(
  warn(
 38%|███▊      | 2556/6702 [06:57<01:15, 55.14it/s]

[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


  warn(
  warn(
  warn(
 39%|███▊      | 2588/6702 [06:59<02:37, 26.16it/s]

[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


  warn(
  warn(
  warn(
  warn(
  warn(
  warn(


[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


  warn(
  warn(
 39%|███▉      | 2613/6702 [06:59<01:44, 38.95it/s]

[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


 40%|███▉      | 2678/6702 [07:05<06:33, 10.23it/s]

[ERROR] '[7202] not in index'


100%|██████████| 6702/6702 [15:53<00:00,  7.03it/s]  


In [16]:
# Distinct labels in the 240 Hz feature table.
results_240["Label"].unique()

array(['Pre-VTAC', 'VTAC'], dtype=object)

In [17]:
# Distinct labels in the combined CU + baseline windows that are processed next.
windowed_df["Label"].unique()

array(['Pre-VTAC', 'VTAC', 'Other', 'baseline'], dtype=object)

In [18]:
# Feature extraction for the windows sampled at 250 Hz (CU records plus baseline).
results_250 = process_dataframe(
    windowed_df,
    sampling_rate=250,
    extension_sec=30,
)


=== VTAC Ratio Change Per Subject ===
        VTAC_Ratio_Before(%)  VTAC_Ratio_After(%)  Delta(After-Before)
Record                                                                
BHGY                0.000000             0.000000             0.000000
BVCX                0.000000             0.000000             0.000000
CVBN                0.000000             0.000000             0.000000
DKLQ                0.000000             0.000000             0.000000
EDFR                0.000000             0.000000             0.000000
FDSA                0.000000             0.000000             0.000000
FGHZ                0.000000             0.000000             0.000000
HGFD                0.000000             0.000000             0.000000
HJKL                0.000000             0.000000             0.000000
JFDE                0.000000             0.000000             0.000000
JHGF                0.000000             0.000000             0.000000
JKLM                0.000000          

  warn(
  warn(
  warn(
  warn(
  warn(
  4%|▎         | 458/12633 [00:48<03:27, 58.75it/s]

[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


  5%|▍         | 623/12633 [01:00<17:14, 11.61it/s]

[ERROR] '[7505] not in index'


  warn(
 10%|▉         | 1242/12633 [01:57<04:45, 39.89it/s]

[ERROR] cannot convert float NaN to integer


  warn(
 10%|█         | 1284/12633 [01:59<04:03, 46.66it/s]

[ERROR] cannot convert float NaN to integer


  warn(
 10%|█         | 1310/12633 [01:59<03:27, 54.45it/s]

[ERROR] cannot convert float NaN to integer


 13%|█▎        | 1625/12633 [02:26<19:10,  9.57it/s]

[ERROR] '[7505] not in index'


  warn(
  warn(
  warn(
 18%|█▊        | 2292/12633 [03:35<07:57, 21.66it/s]

[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


  warn(
  warn(
 18%|█▊        | 2306/12633 [03:35<04:50, 35.50it/s]

[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


  warn(
  warn(
  warn(
 19%|█▉        | 2386/12633 [03:41<08:09, 20.95it/s]

[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer
[ERROR] cannot convert float NaN to integer


  warn(


[ERROR] cannot convert float NaN to integer


  warn(
 21%|██        | 2665/12633 [04:09<09:32, 17.42it/s]

[ERROR] cannot convert float NaN to integer


 21%|██▏       | 2690/12633 [04:11<08:29, 19.50it/s]

[ERROR] '[7500] not in index'


 23%|██▎       | 2860/12633 [04:27<11:58, 13.60it/s]

[ERROR] '[7500] not in index'


  warn(
 23%|██▎       | 2867/12633 [04:27<10:12, 15.93it/s]

[ERROR] cannot convert float NaN to integer


100%|██████████| 12633/12633 [26:09<00:00,  8.05it/s] 


# Visualizations 

In [19]:
# 1) Refs, global metrics, baseline z-scores and VTAC_Label for the 250 Hz results
zscores_df_250 = compute_refs_and_zscores(
    results_250,
    sampling_rate=250,
    window_len_sec=30,
)

[SKIP] Record BHGY: No VTAC detected.
[SKIP] Record BVCX: No VTAC detected.
[SKIP] Record CVBN: No VTAC detected.
[SKIP] Record DKLQ: No VTAC detected.
[SKIP] Record EDFR: No VTAC detected.
[SKIP] Record FDSA: No VTAC detected.
[SKIP] Record FGHZ: No VTAC detected.
[SKIP] Record HGFD: No VTAC detected.
[SKIP] Record HJKL: No VTAC detected.
[SKIP] Record JFDE: No VTAC detected.
[SKIP] Record JHGF: No VTAC detected.
[SKIP] Record JKLM: No VTAC detected.
[SKIP] Record JKLZ: No VTAC detected.
[SKIP] Record KJHG: No VTAC detected.
[SKIP] Record LKJH: No VTAC detected.
[SKIP] Record LPDW: No VTAC detected.
[SKIP] Record MJKL: No VTAC detected.
[SKIP] Record MJUY: No VTAC detected.
[SKIP] Record MYTZ: No VTAC detected.
[SKIP] Record NBVC: No VTAC detected.
[SKIP] Record NMBV: No VTAC detected.
[SKIP] Record NMKL: No VTAC detected.
[SKIP] Record NVBX: No VTAC detected.
[SKIP] Record PKJH: No VTAC detected.
[SKIP] Record PLMO: No VTAC detected.
[SKIP] Record QAXY: No VTAC detected.
[SKIP] Recor

In [20]:
# 2) Refs, global metrics, baseline z-scores and VTAC_Label for the 240 Hz results
zscores_df_240 = compute_refs_and_zscores(
    results_240,
    sampling_rate=240,
    window_len_sec=30,
)

In [21]:
# Stack the 250 Hz and 240 Hz z-score tables into one frame with a fresh index.
zscore_frames = [zscores_df_250, zscores_df_240]
zscores_df = pd.concat(zscore_frames, ignore_index=True)

print("Final combined shape:", zscores_df.shape)

Final combined shape: (8803, 57)


In [None]:
rng = np.random.default_rng(42)
unique_records = zscores_df['Record'].dropna().astype(str).unique()

# Records that are always plotted.
fixed_records = ["FID0003", 'FID0004', "FID0005", "FID0011", "FID0012", "FID0015"]

# Minimum number of records to plot; only matters if the selection below ends up shorter.
k = min(2, len(unique_records))

# Ensure at least one record starts with 'cu' (chosen reproducibly via the seeded rng).
cu_candidates = [r for r in unique_records if r.startswith('cu')]
picked = []
if cu_candidates:
    picked.append(rng.choice(cu_candidates, size=1, replace=False)[0])

# picked stays a plain list until the selection is complete, so .extend() is valid here.
picked.extend(fixed_records)

# Fill the rest from the remaining pool (excluding what we already picked).
remaining_pool = [r for r in unique_records if r not in picked]
n_missing = k - len(picked)
if remaining_pool and n_missing > 0:
    picked.extend(rng.choice(remaining_pool, size=min(n_missing, len(remaining_pool)), replace=False))

picked = np.array(picked)
print("Plotting subjects:", picked)

subset = zscores_df[zscores_df['Record'].isin(picked)].copy()

# zscores_df mixes records processed at 250 Hz and at 240 Hz, so each group is
# plotted with the sampling rate its z-scores were computed with.
from_240hz = subset['Record'].isin(zscores_df_240['Record'])
for rate, group in ((250, subset[~from_240hz]), (240, subset[from_240hz])):
    if not group.empty:
        plot_subject_panels(group, sampling_rate=rate, window_len_sec=30)

# Saving the features

In [23]:
# Persist the combined z-score table, creating the output folder if it is missing.
output_dir = "data/processed"
os.makedirs(output_dir, exist_ok=True)
zscores_df.to_pickle(f"{output_dir}/zscores_df.pkl")