Loads the OOHCA dataset info and processes it to find relevant cuts. All outputs are written to the `data` folder.

In [1]:
import os
from pathlib import Path

import polars as pl
import scipy

from cci.utils import project_dir
import numpy as np


# Every CSV produced by this notebook is written to this folder.
DATASET_FOLDER = project_dir() / "data"
DATASET_FOLDER.mkdir(exist_ok=True)

# Directory containing the raw `.mat` files (arecs.mat, oohrepr.mat). It is read from the
# OOCHA_DIR environment variable: set that variable before starting the kernel instead of
# editing this cell.
try:
    OOCHA_DIR = Path(os.environ["OOCHA_DIR"])
except KeyError:
    raise KeyError(
        "Environment variable OOCHA_DIR is not set; point it to the directory containing the `.mat` files"
    ) from None

In [2]:
# Load info: `arecs` holds the recording file names (stored as the "file" column below);
# `oohrepr` is an annotation struct whose fields include EPI and SMP.
# Import the submodule explicitly: `import scipy` alone only exposes `scipy.io` on SciPy
# versions that lazy-load their submodules.
import scipy.io

arecs = scipy.io.loadmat(OOCHA_DIR / "arecs.mat", simplify_cells=True)["arecs"]
oohca_info = scipy.io.loadmat(OOCHA_DIR / "oohrepr.mat", simplify_cells=True)["oohrepr"]


def replace_object(x):
    """Turn one field of the loaded MATLAB struct into a list of Python lists.

    A bare string entry becomes a one-element list; any other entry is converted
    with its ``.tolist()`` method.
    """
    converted = []
    for item in x:
        if isinstance(item, str):
            converted.append([item])
        else:
            converted.append(item.tolist())
    return converted


# Normalise every field to a list of Python lists and attach the file names.
oohca_info = {field: replace_object(values) for field, values in oohca_info.items()}
oohca_info["file"] = arecs

Store the values in a dataframe, relabel the annotations and save a copy of the full dataset.

In [3]:
# Collect values: one row per (file, annotated segment), with the segment's sample
# bounds split into Start/Stop columns.
df = (
    pl.LazyFrame({column: oohca_info[column] for column in ("file", "EPI", "SMP")})
    .filter(
        # Some of the entries are missing a filename.
        pl.col("file").is_not_null(),
        # EPI and SMP differ in length for rows with 1 rhythm; drop those rows.
        pl.col("EPI").list.len() == pl.col("SMP").list.len(),
    )
    .explode("EPI", "SMP")
    # Each SMP entry is a [start, stop] pair -> struct -> two columns.
    .with_columns(pl.col("SMP").list.to_struct(fields=["Start", "Stop"]))
    .unnest("SMP")
)


# Annotate
# Rhythm codes that get a hands-off prefix: "H" for upper-case codes, "h" for lower-case.
_HANDS_OFF_CODES = frozenset({"AS", "PR", "VF", "VT", "PE", "as", "pr", "vf", "vt", "pe", "un"})


def annotate_hands_off(epi: str) -> str:
    """Mark a hands-off rhythm, e.g. AS -> HAS, vf -> hvf.

    Only the codes in ``_HANDS_OFF_CODES`` are prefixed; every other label
    (e.g. "dfb", "DVT", "UN") is returned unchanged.
    """
    if epi not in _HANDS_OFF_CODES:
        return epi
    prefix = "H" if epi.isupper() else "h"
    return prefix + epi


def map_dfb(vals: dict) -> str:
    """Relabel a 'dfb' segment after the rhythm surrounding it: VT -> dfb -> VT => VT -> DVT -> VT.

    Args:
        vals: one row of the ``epi_-1``/``epi_0``/``epi_1`` struct (``map_elements``
            passes struct rows as dicts): the previous, current and next EPI label,
            ``None`` where a neighbour is missing.

    Returns:
        ``"D" + <rhythm>`` when the current label is "dfb" and both neighbours are the
        same (non-missing) rhythm; otherwise the current label unchanged.
    """
    previous_epi = vals["epi_-1"]
    current_epi = vals["epi_0"]
    next_epi = vals["epi_1"]  # not named `next` to avoid shadowing the builtin
    # Require a real neighbour: two missing neighbours compare equal (None == None)
    # and would otherwise produce the bogus label "DNone".
    if current_epi == "dfb" and previous_epi is not None and previous_epi == next_epi:
        return f"D{previous_epi}"
    return current_epi


# Annotate 'dfb' segments with the surrounding rhythm (map_dfb), then mark hands-off
# rhythms (annotate_hands_off).
# NOTE: the neighbouring labels come from adjacent rows of the exploded frame, so they are
# not restricted to the same file.
df = (
    df.with_columns(
        # epi_-1 = previous row's EPI, epi_0 = this row's EPI, epi_1 = next row's EPI
        [pl.col("EPI").shift(-i).alias(f"epi_{i}") for i in range(-1, 2)],
    )
    # Explicit return_dtype: polars otherwise has to infer the output type of the Python
    # callback from its first result inside this lazy query.
    .with_columns(
        pl.struct(["epi_-1", "epi_0", "epi_1"]).map_elements(map_dfb, return_dtype=pl.Utf8).alias("EPI"),
    )
    .with_columns(pl.col("EPI").map_elements(annotate_hands_off, return_dtype=pl.Utf8))
)

# Collect and save.
# BUG:?? Dropping the helper epi_* columns inside the lazy query also dropped Start and
# Stop, so they are dropped after collecting instead.
df = df.collect().drop(
    [f"epi_{i}" for i in range(-1, 2)],
)
df.write_csv(DATASET_FOLDER / "full.csv")

# DFB Dataset
Collect the relevant DFB cuts and save them as CSV files.

In [4]:
def dfb_df(full_df: pl.DataFrame) -> pl.DataFrame:
    """Find candidate defibrillation cuts: rhythm -> D(fb)rhythm -> rhythm -> transition.

    The EPI/Start/Stop of the next three segments are attached to every row as
    ``epi{k}``/``start{k}``/``stop{k}`` (k = 1..3). Only rows whose next segment
    (``epi1``) starts with "D" are kept. Rows whose look-ahead reaches a segment starting
    at sample 1 are dropped; presumably that is the first segment of the next file, i.e.
    the cut occurs at the end of a file.
    """
    lookahead = range(1, 4)
    shifted_columns = [
        *(pl.col("EPI").shift(-k).alias(f"epi{k}") for k in lookahead),
        *(pl.col("Start").shift(-k).alias(f"start{k}") for k in lookahead),
        *(pl.col("Stop").shift(-k).alias(f"stop{k}") for k in lookahead),
    ]
    earliest_next_start = pl.min_horizontal(pl.col("start1"), pl.col("start2"), pl.col("start3"))
    return (
        full_df.with_columns(shifted_columns)
        .filter(pl.col("epi1").str.starts_with("D"))
        .filter(earliest_next_start != 1)
    )


dfb_full = dfb_df(df)
dfb_full.write_csv(DATASET_FOLDER / "dfb_full.csv")

# Keep only usable cuts:
#  - the last segment (epi3) is neither another D(fb) nor a C(ompression) segment (no
#    transition). TODO: Or should these be bad transitions?
#  - the starting rhythm is not unknown (contains "un").
#    NOTE(review): epi3 is not checked for unknowns.
#  - the transition happens < 1500 samples (3 s at 500 Hz) after the post-dfb rhythm starts.
dfb = dfb_full.filter(
    ~pl.col("epi3").str.starts_with("D"),
    ~pl.col("epi3").str.starts_with("C"),
    ~pl.col("EPI").str.contains("un"),
    (pl.col("start3") - pl.col("start2")) < 1500,
).with_row_index()

Transitions were checked visually with `verify_dataset`.
Transitions found to be incorrect are shifted below.

In [5]:
# Values of the `index` column of `dfb` for cuts whose transition overlaps with the dfb
# (found by visual verification). Their transition is moved earlier in the next step.
# NOTE: these values refer to the row numbering produced by `with_row_index()` on `dfb`;
# any change to the filtering that builds `dfb` renumbers the rows and invalidates this list.

shift = [
    9,
    12,
    14,
    16,
    22,
    28,
    29,
    34,
    37,
    41,
    46,
    47,
    52,
    55,
    59,
    62,
    70,
    71,
    72,
    99,
    101,
    103,
    110,
    112,
    116,
    117,
    125,
    126,
    133,
    142,
    143,
    147,
    148,
    164,
    168,
    169,
    170,
    171,
    172,
    173,
    176,
    177,
    178,
    180,
    181,
    182,
    183,
    193,
    200,
    201,
    202,
    203,
    204,
    210,
    212,
    214,
    216,
    224,
    233,
    237,
    240,
    241,
    244,
    247,
    252,
    255,
    265,
    278,
    281,
    291,
    294,
    296,
    303,
    308,
    313,
    315,
    322,
    323,
    324,
    329,
    332,
    334,
    335,
    338,
    341,
    345,
    349,
    351,
    352,
    354,
    366,
    369,
    383,
    386,
    387,
    390,
    392,
    393,
    398,
    405,
    412,
    414,
    416,
    421,
    427,
    428,
    430,
    431,
    435,
    436,
    437,
    449,
    451,
    453,
    454,
    456,
    459,
    460,
    462,
    464,
    465,
    467,
    475,
    487,
    491,
    497,
    499,
    500,
]

# --- Shift late transitions before dfb
# For every cut listed in `shift`, move the end of the starting rhythm (Stop) and the
# start of the following segment (start1) 0.01 s earlier.
shift_samples = int(500 * 0.01)  # 0.01 s at 500 samples/s
dfb_overrides_shift = (
    dfb.filter(pl.col("index").is_in(shift))
    .with_columns(
        pl.col("Stop") - shift_samples,
        pl.col("start1") - shift_samples,
        override_function=pl.lit("shift_transition"),
        override_reason=pl.lit("overlap with DFB"),
        valid=pl.lit(True),
    )
)

# --- Fix other errors
# One entry per cut: (row index in `dfb`, fix kind, seconds, reason).
#   "start":  move Start later by `seconds` (override_function "shift_start", stays valid)
#   "remove": exclude the cut from the final dataset (valid = False)
errors = [
    (105, "start", 10, "contains unannotated shock"),
    (194, "start", 3, "starts with unannotated shock"),
    (283, "remove", None, "hvf, only sample annotated as noisy"),
    (419, "remove", None, "HAS, should be CAS?"),
]

# All override columns are derived from `errors`, so they cannot drift out of sync with it
# (row order, row count, reasons). `dfb` rows are ordered by index, so sort to match.
errors_by_index = sorted(errors, key=lambda error: error[0])
error_indices = [index for index, _, _, _ in errors_by_index]
start_shift_seconds = {index: seconds for index, kind, seconds, _ in errors_by_index if kind == "start"}
override_function_names = {"start": "shift_start", "remove": "remove"}

dfb_overrides_errors = dfb.filter(pl.col("index").is_in(error_indices))
assert dfb_overrides_errors["index"].to_list() == error_indices, "every index in `errors` must exist in `dfb`"


def adjust_start(row):
    """Return ``row["Start"]`` moved later by the seconds listed in `errors` (500 samples/s).

    Rows without a "start" fix are returned unchanged.
    """
    return row["Start"] + start_shift_seconds.get(row["index"], 0) * 500


dfb_overrides_errors = dfb_overrides_errors.with_columns(
    pl.struct(["index", "Start"]).map_elements(adjust_start).alias("Start"),
    override_function=np.array([override_function_names[kind] for _, kind, _, _ in errors_by_index]),
    override_reason=np.array([reason for _, _, _, reason in errors_by_index]),
    valid=np.array([kind != "remove" for _, kind, _, _ in errors_by_index]),
)

dfb_overrides = dfb_overrides_shift.vstack(dfb_overrides_errors).sort("index")

dfb_overrides.write_csv(DATASET_FOLDER / "dfb_overrides.csv")

Create the revised dfb set to be used for training.

In [6]:
# Replace the overridden cuts with their corrected rows, then drop the cuts marked invalid.
override_indices = dfb_overrides["index"]
invalid_indices = dfb_overrides.filter(~pl.col("valid"))["index"]
(
    dfb.filter(~pl.col("index").is_in(override_indices))
    .vstack(dfb_overrides.drop(["override_function", "override_reason", "valid"]))
    .filter(~pl.col("index").is_in(invalid_indices))  # remove invalid indexes
    .sort("index")
    .write_csv(DATASET_FOLDER / "dfb.csv")
)

# Clean transition dataset

In [7]:
# Start this section from the saved CSV rather than the in-memory frame built above.
full_csv_path = DATASET_FOLDER / "full.csv"
df = pl.read_csv(full_csv_path)


def dfb_clean(full_df: pl.DataFrame) -> pl.DataFrame:
    """Pair each segment with the following one and keep hands-off -> hands-off transitions.

    Despite the name, no DFB segments are involved. The next segment's EPI/Start/Stop are
    attached as ``epi1``/``start1``/``stop1``, and a row is kept only when both its own
    label and ``epi1`` start with "H". Rows are dropped when the next segment starts at
    sample 1 (presumably the first segment of the next file) or when there is no next
    segment.
    """
    following = {"epi1": "EPI", "start1": "Start", "stop1": "Stop"}
    paired = full_df.with_columns(
        [pl.col(source).shift(-1).alias(target) for target, source in following.items()]
    )
    return paired.filter(
        pl.col("EPI").str.starts_with("H")
        & pl.col("epi1").str.starts_with("H")
        & (pl.col("start1") != 1)
    )


# Keep transitions whose starting segment is longer than 1500 samples (3 s at 500 Hz).
segment_length = pl.col("Stop") - pl.col("Start")
clean_df_full = dfb_clean(df).filter(segment_length > 1500)
clean_df_full.write_csv(DATASET_FOLDER / "clean_full.csv")

In [8]:
# Template for reviewing the clean set (only checked when length > 3 seconds). The
# review columns are presumably filled in by hand; the result is read back from
# override_df.csv below.
review_columns = {"Valid": True, "New Start": None, "New Stop": None, "comment": None}
override_df = clean_df_full.with_columns(
    [pl.lit(default).alias(name) for name, default in review_columns.items()]
).with_row_index()
override_df.write_csv(DATASET_FOLDER / "override_df_template.csv")

# Apply the manual overrides to the clean set after checking

In [9]:
# Load the reviewed overrides and keep only the rows marked Valid.
clean_df = pl.read_csv(DATASET_FOLDER / "override_df.csv")
clean_df = clean_df.filter(pl.col("Valid"))

# Use the reviewed bounds where given, otherwise keep the original ones. The result must
# be assigned: without it the New Start/New Stop corrections never reach the training set.
clean_df = clean_df.with_columns(
    pl.when(pl.col("New Stop").is_not_null())
    .then(pl.col("New Stop"))
    .otherwise(pl.col("Stop"))
    .cast(int)
    .alias("Stop"),
    pl.when(pl.col("New Start").is_not_null())
    .then(pl.col("New Start"))
    .otherwise(pl.col("Start"))
    .cast(int)
    .alias("Start"),
)
clean_df

index,file,EPI,Start,Stop,epi1,start1,stop1,Valid,New Start,New Stop,comment
i64,str,str,i64,i64,str,i64,i64,bool,str,i64,str
1,"""a_2""","""HPE""",457313,466699,"""HVT""",466700,470691,true,,,
2,"""a_2""","""HPE""",826803,829500,"""HPR""",829501,936774,true,,,
3,"""a_2""","""HPR""",1114462,1121524,"""HPE""",1121525,1141336,true,,,
4,"""a_2""","""HPE""",1121525,1141336,"""HVT""",1141337,1148379,true,,,
5,"""a_2""","""HVT""",1141337,1148379,"""HVF""",1148380,1190708,true,,,
…,…,…,…,…,…,…,…,…,…,…,…
655,"""s_383""","""HVF""",162427,171755,"""HAS""",171756,176902,true,,,
656,"""s_383""","""HPE""",200564,205199,"""HAS""",205200,218862,true,,,
657,"""s_383""","""HPE""",262358,265438,"""HAS""",265439,267223,true,,,
658,"""s_386""","""HPE""",418190,419747,"""HAS""",419748,428167,true,,,


# Create training/test/validation sets

In [10]:
# Classify clean set


# Desired (class 0) next rhythms for each hands-off starting rhythm; any other next
# rhythm is a bad transition (class 1).
_GOOD_TRANSITIONS = {
    "HAS": frozenset({"HVF", "HVT", "HPR"}),  # AS -> PR / VF / VT
    "HVF": frozenset({"HPR"}),  # VF -> PR
    "HVT": frozenset({"HPR"}),  # VT -> PR
    "HPE": frozenset({"HPR"}),  # PE -> PR
    "HPR": frozenset(),  # TODO: PR -> sROSC; currently always class 1
}


def classify_class_label(df):
    """Label every transition as good (0) or bad (1).

    Args:
        df: frame with an ``EPI`` (starting rhythm) and ``EPI_NEXT`` (rhythm after the
            transition) column, read through ``df.select([...]).rows()``.

    Returns:
        list[int]: one label per row of ``df``, in row order, so it can be attached
        directly as a column.

    Raises:
        ValueError: if a row's ``EPI`` has no entry in ``_GOOD_TRANSITIONS``. Silently
            skipping such rows would make the label list shorter than ``df`` and
            misalign it.
    """
    class_label = []
    for epi, next_epi in df.select(["EPI", "EPI_NEXT"]).rows():
        if epi not in _GOOD_TRANSITIONS:
            raise ValueError(f"Unexpected starting rhythm {epi!r}; cannot assign a class label")
        class_label.append(0 if next_epi in _GOOD_TRANSITIONS[epi] else 1)
    return class_label


# Combine the reviewed clean transitions with the dfb transitions.
# Minimum 3 seconds (1500 samples) for the starting segment.
# New names (instead of reassigning `clean_df` / `dfb_df`) keep this cell re-runnable and
# avoid shadowing the `dfb_df()` function defined earlier.
clean_transitions = (
    clean_df.filter((pl.col("Stop") - pl.col("Start")) > 1500)
    .select(["file", "Start", "Stop", "EPI", "epi1"])
    .rename({"epi1": "EPI_NEXT"})
)
dfb_transitions = (
    pl.read_csv(DATASET_FOLDER / "dfb.csv")
    .filter((pl.col("Stop") - pl.col("Start")) > 1500)
    .select(["file", "Start", "Stop", "EPI", "epi3"])
    .rename({"epi3": "EPI_NEXT"})
)

df = clean_transitions.vstack(dfb_transitions)
class_label = classify_class_label(df)
df = df.hstack([pl.Series("Class", class_label)]).sort("file", "Start")
df.write_csv(DATASET_FOLDER / "test_train_val.csv")

In [11]:
from sklearn.model_selection import StratifiedGroupKFold, train_test_split

RANDOM_STATE = 0
df = pl.read_csv(DATASET_FOLDER / "test_train_val.csv").with_row_index()

# --- Hold out 10% as the test set, stratified by class.
# NOTE(review): this split is not grouped by `file`, so segments from the same recording
# can end up in both the test set and the train/val set (the folds below ARE grouped).
train_val_positions, test_positions = train_test_split(
    range(len(df)),
    stratify=df["Class"].to_numpy(),
    test_size=0.1,
    random_state=RANDOM_STATE,
)

# Renumber each subset from 0 so the fold positions below match the "index" column.
train_val_df = df.filter(pl.col("index").is_in(train_val_positions)).drop("index").with_row_index()
test_df = df.filter(pl.col("index").is_in(test_positions)).drop("index").with_row_index()
test_df.write_csv(DATASET_FOLDER / "full_test.csv")

# --- Split train/validation into 5 class-stratified folds, grouped by file.
splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
fold_splits = splitter.split(
    np.zeros(len(train_val_df)),
    train_val_df["Class"].to_numpy(),
    train_val_df["file"].to_numpy(),
)
for fold_number, (train_positions, val_positions) in enumerate(fold_splits):
    train_val_df.filter(pl.col("index").is_in(train_positions)).write_csv(DATASET_FOLDER / f"full_train_{fold_number}.csv")
    train_val_df.filter(pl.col("index").is_in(val_positions)).write_csv(DATASET_FOLDER / f"full_val_{fold_number}.csv")

# Create training sets for each rhythm

In [12]:
# Split the combined dataset by starting rhythm; HVF and HVT are pooled into one set.
starting_rhythm = pl.col("EPI")
vft_df = df.filter(starting_rhythm.is_in(["HVF", "HVT"]))
as_df = df.filter(starting_rhythm == "HAS")
pe_df = df.filter(starting_rhythm == "HPE")
pr_df = df.filter(starting_rhythm == "HPR")

In [13]:
# Stratified 90/10 train/test split of the pooled VF/VT transitions.
# NOTE(review): the loop in the next cell writes vft_test.csv again (and vft_train_0.csv).
train_idx, test_idx = train_test_split(
    range(len(vft_df)),
    stratify=vft_df["Class"].to_numpy(),
    test_size=0.1,
    random_state=RANDOM_STATE,
)
# Renumber rows so the positional split indices match the "index" column.
vft_df = vft_df.drop("index").with_row_index()
vft_df.filter(pl.col("index").is_in(train_idx)).write_csv(DATASET_FOLDER / "vft_train.csv")
vft_df.filter(pl.col("index").is_in(test_idx)).write_csv(DATASET_FOLDER / "vft_test.csv")

In [14]:
# One stratified 90/10 train/test split per starting rhythm (VF/VT pooled).
rhythm_subsets = {"vft": vft_df, "as": as_df, "pe": pe_df, "pr": pr_df}
for rhythm_name, rhythm_df in rhythm_subsets.items():
    train_idx, test_idx = train_test_split(
        range(len(rhythm_df)),
        stratify=rhythm_df["Class"].to_numpy(),
        test_size=0.1,
        random_state=RANDOM_STATE,
    )
    # Renumber rows so the positional split indices match the "index" column.
    renumbered = rhythm_df.drop("index").with_row_index()
    renumbered.filter(pl.col("index").is_in(train_idx)).write_csv(DATASET_FOLDER / f"{rhythm_name}_train_0.csv")
    renumbered.filter(pl.col("index").is_in(test_idx)).write_csv(DATASET_FOLDER / f"{rhythm_name}_test.csv")