# setup

In [None]:
import pandas as pd
import numpy as np
import ast
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os

In [None]:
# Make the project root and its src/ directory importable (appended in this order).
sys.path.extend(os.path.abspath(rel_dir) for rel_dir in ('..', '../src'))
from hparams import MIMIC_ICD_ROOT
from dataset import DynamicsDataset

In [None]:
# Label table with per-record ICD-10 diagnoses, under the MIMIC ICD root.
_LABEL_FILENAME = 'records_w_diag_icd10.csv'
LABEL_CSV = os.path.join(MIMIC_ICD_ROOT, _LABEL_FILENAME)

In [None]:
def get_cardiac_codes(code_list_str):
    """Extract ICD-10 Chapter IX (``I``-prefixed) categories from a label cell.

    The cell is expected to hold a Python literal list of ICD-10 codes, e.g.
    ``"['I48.0', 'E11.9']"``. Codes starting with ``'I'`` are kept and
    truncated to their first 3 characters (``'I48.0'`` -> ``'I48'``).

    Args:
        code_list_str: The raw cell value. Missing values (NaN/None), other
            non-string scalars and empty/blank strings yield no codes.

    Returns:
        Sorted list of unique 3-character cardiac categories. ``[]`` when the
        input is missing, blank, unparseable, or not a literal sequence.
    """
    if not isinstance(code_list_str, str):
        # NaN / None / other scalars carry no codes.
        return []
    text = code_list_str.strip()
    if not text:
        return []
    try:
        raw_codes = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the errors literal_eval documents for malformed input.
        # A bad cell counts as "no cardiac codes" instead of aborting the
        # whole column parse.
        return []

    if isinstance(raw_codes, str):
        # A bare string literal ("'I48'") is a single code. Iterating it
        # would yield individual characters.
        raw_codes = [raw_codes]
    elif not isinstance(raw_codes, (list, tuple, set, frozenset)):
        # e.g. "42" parses to an int, which has no codes to iterate.
        return []

    # Sorted so the label order is deterministic across runs. Set iteration
    # order of strings depends on hash randomization.
    return sorted(
        {
            code[:3]
            for code in raw_codes
            if isinstance(code, str) and code.startswith('I')
        }
    )

# read

In [None]:
# Load the ICD-10 label table and display it.
df = pd.read_csv(filepath_or_buffer=LABEL_CSV)
df

In [None]:
print("Parsing labels (this might take 10-20 seconds)...")
df['cardiac_labels'] = df['all_diag_all'].apply(get_cardiac_codes)

# Flatten all per-row label lists in row order. explode() turns empty lists
# into NaN, and dropna() removes those placeholders.
all_flat = df['cardiac_labels'].explode().dropna().tolist()
ALL_CLASSES = sorted(set(all_flat))
print(f"Total Unique Cardiac Classes: {len(ALL_CLASSES)}")

# splits

In [None]:
# Fold layout: 0-17 train, 18 validation.
VAL_FOLD = 18

train_mask = df['fold'].between(0, 17)
val_mask_full = df['fold'] == VAL_FOLD
# Baseline ("first ECG") filter: keep only rows with ecg_no_within_stay == 0.
# Derived from the full mask so the fold number lives in one place.
val_mask_filtered = val_mask_full & (df['ecg_no_within_stay'] == 0)

n_dropped_val = val_mask_full.sum() - val_mask_filtered.sum()
print(f"Train Set Size: {train_mask.sum():,}")
print(f"Val Set (Full): {val_mask_full.sum():,}")
print(f"Val Set (Baseline Filtered): {val_mask_filtered.sum():,}")
print(f" -> Dropped {n_dropped_val:,} later-stay ECGs in Validation.")

In [None]:
# Fold layout: 0-17 train, 19 test.
TEST_FOLD = 19

train_mask = df['fold'].between(0, 17)
tst_mask_full = df['fold'] == TEST_FOLD
# Baseline ("first ECG") filter: keep only rows with ecg_no_within_stay == 0.
# Derived from the full mask so the fold number lives in one place.
tst_mask_filtered = tst_mask_full & (df['ecg_no_within_stay'] == 0)

n_dropped_tst = tst_mask_full.sum() - tst_mask_filtered.sum()
print(f"Train Set Size: {train_mask.sum():,}")
print(f"Test Set (Full): {tst_mask_full.sum():,}")
print(f"Test Set (Baseline Filtered): {tst_mask_filtered.sum():,}")
print(f" -> Dropped {n_dropped_tst:,} later-stay ECGs in Test.")

# classes

In [None]:
def count_classes(mask, description, *, data=None, classes=None):
    """Count cardiac label occurrences in a row subset and report gaps.

    Args:
        mask: Boolean mask selecting the rows of ``data`` to count.
        description: Heading printed above the summary.
        data: Frame with a ``cardiac_labels`` column of per-row code lists.
            Defaults to the notebook-level ``df``.
        classes: Class vocabulary to check coverage against. Defaults to the
            notebook-level ``ALL_CLASSES``.

    Returns:
        ``collections.Counter`` mapping each code to its number of
        occurrences across the subset's label lists. Codes absent from the
        subset are not keys.
    """
    if data is None:
        data = df
    if classes is None:
        classes = ALL_CLASSES

    subset = data[mask]
    flat_labels = [code for labels in subset['cardiac_labels'] for code in labels]
    counts = Counter(flat_labels)

    # Indexing a Counter with an unseen key returns 0 without inserting it,
    # so this check leaves the returned counts untouched.
    missing = [cls for cls in classes if counts[cls] == 0]

    print(f"\n--- {description} ---")
    print(f"Total Labels: {len(flat_labels):,}")
    print(f"Missing Classes: {len(missing)}")
    if missing:
        print(f"Classes with 0 samples: {missing}")
    return counts

In [None]:
# Class coverage for the training folds and both validation variants,
# computed in order: train, val (full), val (filtered).
train_counts, val_counts_full, val_counts_filtered = (
    count_classes(split_mask, split_name)
    for split_mask, split_name in (
        (train_mask, "TRAIN SET"),
        (val_mask_full, "VAL SET (FULL)"),
        (val_mask_filtered, "VAL SET (FILTERED)"),
    )
)

In [None]:
# Train is re-counted here (same call as in the validation cell) so this cell
# is self-contained. Test headings follow the "TRAIN SET"/"VAL SET" style.
train_counts = count_classes(train_mask, "TRAIN SET")
tst_counts_full = count_classes(tst_mask_full, "TEST SET (FULL)")
tst_counts_filtered = count_classes(tst_mask_filtered, "TEST SET (FILTERED)")

In [None]:
# Classes seen in the full test fold that vanish once only the first ECG per
# stay is kept.
full_set = set(tst_counts_full.keys())
filt_set = set(tst_counts_filtered.keys())

disappearing_classes = full_set - filt_set
print("\n" + "="*40)
print(f"CLASSES LOST DUE TO 'FIRST ECG' FILTER (TEST): {len(disappearing_classes)}")
print("="*40)
# Sorted for a stable, readable listing. Counts come from the TEST split,
# not validation.
for cls in sorted(disappearing_classes):
    print(f"- {cls} (Count in Full Test: {tst_counts_full[cls]})")

In [None]:
# Classes seen in the full validation fold that vanish once only the first
# ECG per stay is kept.
full_set = set(val_counts_full.keys())
filt_set = set(val_counts_filtered.keys())

disappearing_classes = full_set - filt_set
print("\n" + "="*40)
print(f"CLASSES LOST DUE TO 'FIRST ECG' FILTER (VAL): {len(disappearing_classes)}")
print("="*40)
# Sorted for a stable, readable listing (set iteration order is arbitrary).
for cls in sorted(disappearing_classes):
    print(f"- {cls} (Count in Full Val: {val_counts_full[cls]})")

# bottom 20 (rarest classes)

In [None]:
# Rank every known cardiac class by its positive count in the filtered
# validation split. Iterating ALL_CLASSES (not just the Counter's keys) keeps
# classes with zero positives, the rarest of all. Indexing the Counter with
# an unseen class returns 0. Ties are broken alphabetically for a
# reproducible ranking.
# To inspect the test split instead, rank tst_counts_filtered here.
N_BOTTOM = 20
sorted_val = sorted(
    ((cls, val_counts_filtered[cls]) for cls in ALL_CLASSES),
    key=lambda item: (item[1], item[0]),
)
bottom_20 = sorted_val[:N_BOTTOM]

keys = [cls for cls, _ in bottom_20]
vals = [count for _, count in bottom_20]

In [None]:
# Display the bottom-ranked class codes alongside their positive counts.
(keys, vals)

In [None]:
# Bar chart of the rarest classes, one magma colour per bar in bar order.
# Passing `palette` to sns.barplot without `hue` is deprecated in seaborn
# >= 0.13, so bars are drawn with matplotlib using seaborn's palette helper.
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(keys, vals, color=sns.color_palette("magma", n_colors=len(keys)))
ax.set_title("Bottom 20 Rare Classes in Filtered Validation Set")
ax.set_ylabel("Count of Positives")
ax.set_xlabel("ICD Class (3-digit)")
ax.axhline(y=1, color='r', linestyle='--', label="Danger Zone (1 sample)")
ax.axhline(y=0, color='black', linestyle='-')
ax.legend()
plt.show()