In [1]:
%load_ext autoreload
%autoreload 2

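Create a dataframe with all cut information: one row per annotated segment of each recording, holding the recording name (`files`), the segment label (`EPI`) and its start/stop sample indices (`SMP_start`, `SMP_stop`), plus the same three fields for the next two segments (`*_1`, `*_2`).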

In [2]:
import os
from pathlib import Path

from cci.utils import project_dir
import polars as pl
import scipy

# Output folder for every CSV this notebook writes; created if it is missing.
DATA_DIR = project_dir() / "data"
DATA_DIR.mkdir(exist_ok=True)

# The annotation files are read from the folder named by the OOCHA_DIR
# environment variable (a missing variable raises KeyError here).
oocha_dir = Path(os.environ["OOCHA_DIR"])


def _load_mat_variable(file_name, variable):
    """Load `file_name` from `oocha_dir` and return the MATLAB variable `variable`."""
    contents = scipy.io.loadmat(oocha_dir / file_name, simplify_cells=True)
    return contents[variable]


arecs = _load_mat_variable("arecs.mat", "arecs")
oohrepr = _load_mat_variable("oohrepr.mat", "oohrepr")

# Flatten the per-recording annotations into four parallel lists holding one
# entry per annotated segment.
files, epi, smp_start, smp_stop = [], [], [], []
for record, labels, spans in zip(arecs, oohrepr["EPI"], oohrepr["SMP"]):
    # A recording whose label entry is a bare string is skipped entirely.
    if not isinstance(labels, str):
        for label, span in zip(labels, spans):
            files.append(record)
            # Labels keep their original case (no .upper()): the source mixes
            # upper and lower case, lower case presumably marking noisy
            # signal -- unconfirmed.
            epi.append(label)
            smp_start.append(span[0])
            smp_stop.append(span[1])

# Look-ahead columns: for every segment, the label and sample bounds of the
# next (`_1`) and second-next (`_2`) segment. The shifts are evaluated per
# recording via `.over("files")`, so the last rows of a recording get nulls
# instead of values borrowed from the first rows of the following recording.
LOOKAHEAD_STEPS = (1, 2)
lookahead_columns = [
    pl.col(name).shift(-step).over("files").alias(f"{name}_{step}")
    for step in LOOKAHEAD_STEPS
    for name in ("EPI", "SMP_start", "SMP_stop")
]

# One row per annotated segment, in annotation order.
original_df = (
    pl.LazyFrame(
        {
            "files": files,
            "EPI": epi,
            "SMP_start": smp_start,
            "SMP_stop": smp_stop,
        }
    )
    .filter(pl.col("files").is_not_null())
    .with_columns(lookahead_columns)
    .collect()
)

# Persist the full, unfiltered table, then preview it as the cell output.
original_df.write_csv(DATA_DIR / "original.csv")
original_df.head()

files,EPI,SMP_start,SMP_stop,EPI_1,SMP_start_1,SMP_stop_1,EPI_2,SMP_start_2,SMP_stop_2
str,str,i32,i32,str,i32,i32,str,i32,i32
"""S_1""","""un""",1,11188,"""VF""",11189,19352,"""dfb""",19353,20192
"""S_1""","""VF""",11189,19352,"""dfb""",19353,20192,"""VF""",20193,21272
"""S_1""","""dfb""",19353,20192,"""VF""",20193,21272,"""AS""",21273,22846
"""S_1""","""VF""",20193,21272,"""AS""",21273,22846,"""CAS""",22847,38680
"""S_1""","""AS""",21273,22846,"""CAS""",22847,38680,"""AS""",38681,40186


# Minimum segment length (in samples)

In [3]:
# Keep only segments whose stop-minus-start span exceeds this many samples.
MIN_SPAN_SAMPLES = 1500
span_samples = pl.col("SMP_stop") - pl.col("SMP_start")
df = original_df.filter(span_samples > MIN_SPAN_SAMPLES)

# Row counts before and after the length filter, one per line.
print(original_df.height, df.height, sep="\n")

23182
19532


In [4]:
# For each current rhythm label, the following labels that count as a good
# outcome (class 0). "PR" has no good successor yet (PR -> sROSC is still a
# TODO), so every PR row is labelled 1.
_GOOD_NEXT_RHYTHMS = {
    "AS": frozenset({"VF", "VT", "PR"}),
    "VF": frozenset({"PR"}),
    "VT": frozenset({"PR"}),
    "PE": frozenset({"PR"}),
    "PR": frozenset(),
}


def classify_class_label(df, next_epi: int) -> list:
    """Label every row of `df` by its rhythm transition: 0 is good, 1 is bad.

    Desired (class 0) transitions:
        VF/VT -> PR
        AS    -> PR / VF / VT
        PE    -> PR
        PR    -> sROSC  (TODO: not implemented, PR rows are always 1)

    Args:
        df: Frame with an "EPI" column and an "EPI_<next_epi>" column; only
            `df.select([...]).rows()` is used.
        next_epi: Which look-ahead column holds the following rhythm:
            1 for "EPI_1", 2 for "EPI_2".

    Returns:
        A list with exactly one entry per row of `df`, in row order: 0 or 1
        for the five known rhythms, and None when the row's "EPI" is not one
        of them, so the result always stays aligned with the frame.
    """
    # Resolve the column name up front; the loop uses its own variable names
    # so the `next_epi` argument is never shadowed.
    following_column = f"EPI_{next_epi}"
    class_label = []
    for current, following in df.select(["EPI", following_column]).rows():
        good_next = _GOOD_NEXT_RHYTHMS.get(current)
        if good_next is None:
            # Unknown current rhythm: keep the row's slot, but leave it unlabelled.
            class_label.append(None)
        else:
            class_label.append(0 if following in good_next else 1)
    return class_label

# Clean DF

In [5]:
# Rhythm labels kept for the transition data set.
clean_labels = ["AS", "VF", "VT", "PE", "PR"]

# Direct transitions: a clean rhythm immediately followed by another clean
# rhythm. A next segment starting at sample 1 is treated as the beginning of
# another recording and excluded -- assumes every recording starts at
# sample 1; TODO confirm.
clean_candidates = df.filter(
    pl.col("EPI").is_in(clean_labels),
    pl.col("SMP_start_1") != 1,
    pl.col("EPI_1").is_in(clean_labels),
)
class_label = classify_class_label(clean_candidates, 1)

clean_df = clean_candidates.hstack([pl.Series("Class Label", class_label)])
clean_df.write_csv(DATA_DIR / "clean_df.csv")
# Preview goes last so the notebook actually renders it.
clean_df.head(20)

# Clean DFB DF

In [6]:
# Rhythm labels kept for the transition data set.
clean_labels = ["AS", "VF", "VT", "PE", "PR"]

# Transitions across a "dfb" segment: clean rhythm -> "dfb" -> clean rhythm.
# A following segment starting at sample 1 is treated as the beginning of
# another recording and excluded -- assumes every recording starts at
# sample 1; TODO confirm.
clean_dfb_candidates = df.filter(
    pl.col("EPI").is_in(clean_labels),
    pl.col("SMP_start_1") != 1,
    pl.col("SMP_start_2") != 1,
    pl.col("EPI_1") == "dfb",
    pl.col("EPI_2").is_in(clean_labels),
)
# The label compares EPI with EPI_2, i.e. the rhythm after the "dfb" segment.
class_label = classify_class_label(clean_dfb_candidates, 2)

clean_dfb_df = clean_dfb_candidates.hstack([pl.Series("Class Label", class_label)])
clean_dfb_df.write_csv(DATA_DIR / "clean_df_dfb.csv")
# Preview goes last so the notebook actually renders it.
clean_dfb_df.head()

In [7]:
# How many rows of the dfb data set fall into each class.
clean_dfb_df["Class Label"].value_counts()

Class Label,count
i64,u32
1,941


In [8]:
# Export every row of the unfiltered table whose next segment is labelled "dfb".
followed_by_dfb = pl.col("EPI_1") == "dfb"
original_df.filter(followed_by_dfb).write_csv(DATA_DIR / "original_dfb.csv")