In [None]:
import ast
import sys
from pathlib import Path

import pandas as pd

# Make the project root importable. This takes the parent of the current
# working directory, so the notebook must be run from a folder directly below
# the project root.
project_root = str(Path().absolute().parent)
sys.path.append(project_root)

# Project-local imports must come after the sys.path tweak above.
from src.data.unified import UnifiedDataset
from src.data.dataset import DatasetModality

data_root = Path(project_root) / "data"
arrhythmia_data = UnifiedDataset(
    data_root, modality=DatasetModality.ECG, dataset_key="arrhythmia"
)

# Build one row per record: the stored metadata plus the record id.
records = arrhythmia_data.get_all_record_ids()
metadata_store = arrhythmia_data.metadata_store
df = pd.DataFrame(
    [{**metadata_store.get(record_id), "record_id": record_id} for record_id in records]
)

# Attach the target labels of each preprocessed record (same order as `records`).
df["labels"] = [
    arrhythmia_data[record_id].preprocessed_record.target_labels
    for record_id in records
]

# `labels_metadata` may be stored as the string repr of a Python literal;
# parse it safely with ast.literal_eval and leave already-parsed values alone.
df["labels_metadata"] = df["labels_metadata"].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

# Group of the first label entry. Non-list values AND empty lists map to None
# (indexing an empty list with [0] would otherwise raise IndexError).
df["group"] = df["labels_metadata"].apply(
    lambda x: x[0].get("group") if isinstance(x, list) and x else None
)

# The dataset objects are no longer needed once the frame is built.
del arrhythmia_data, metadata_store
df.head()

In [None]:
# Show the column index of `df` as a quick overview of the available fields.
df.columns

In [23]:
import matplotlib.pyplot as plt
import seaborn as sns

# Render figures inline, using the 'retina' (high-resolution) figure format
%matplotlib inline
%config InlineBackend.figure_format='retina'

# Seaborn theme with the "whitegrid" style
sns.set_theme(style="whitegrid")

# Default figure size for figures that do not pass their own `figsize`
plt.rcParams["figure.figsize"] = [12, 5]

In [None]:
# Human-readable sex label; records with no recorded sex are tagged "Missing".
sex_names = {1: "Male", 0: "Female"}
df["sex_category"] = df["is_male"].map(sex_names)
df.loc[df["is_male"].isna(), "sex_category"] = "Missing"

# How many records lack demographic information?
missing_age = df["age"].isna().sum()
missing_sex = df["is_male"].isna().sum()

print(f"Total missing age values: {missing_age}")
print(f"Total missing sex values: {missing_sex}")

# Share of each sex category in percent, rounded to 2 decimals.
total_records = len(df)
sex_counts = df["sex_category"].value_counts()
percentages = (sex_counts / total_records * 100).round(2)

# Two panels side by side: age distribution (left) and category counts (right).
fig, axes = plt.subplots(1, 2)
age_ax, count_ax = axes[0], axes[1]

# ---- Left panel: violin plot of age per sex category ----
sns.violinplot(
    data=df,
    x="sex_category",
    y="age",
    palette="muted",
    inner="quartile",
    ax=age_ax,
)
age_ax.set_xlabel("Sex Category")
age_ax.set_ylabel("Age")

# ---- Right panel: number of records per sex category ----
barplot = sns.countplot(data=df, x="sex_category", palette="muted", ax=count_ax)

# Write "<percentage>% (n=<count>)" centred just above every bar.
for bar in barplot.patches:
    height = bar.get_height()
    bar_centre = bar.get_x() + bar.get_width() / 2
    barplot.annotate(
        f"{(height / total_records * 100):.2f}% (n={int(height)})",
        (bar_centre, height),
        ha="center",
        va="bottom",
        fontsize=10,
        color="black",
    )

count_ax.set_xlabel("Sex Category")
count_ax.set_ylabel("Count")

plt.tight_layout()
plt.show()

In [None]:
# Label distribution: how often each label (integration name) occurs, coloured
# by the label's group.

# One row per (record, label). Records whose `labels_metadata` is missing or an
# empty list explode to NaN; drop them so the `.get` calls below cannot fail
# with AttributeError.
df_exploded = (
    df.explode("labels_metadata")
    .dropna(subset=["labels_metadata"])
    .reset_index(drop=True)
)

# Extract relevant fields from each label dict. `group` is overwritten here
# with the group of the exploded label rather than the record's first label.
df_exploded["integration_name"] = df_exploded["labels_metadata"].apply(
    lambda x: x.get("integration_name")
)
df_exploded["group"] = df_exploded["labels_metadata"].apply(lambda x: x.get("group"))

# Count occurrences of each (integration name, group) pair, most frequent first
label_counts = (
    df_exploded.groupby(["integration_name", "group"])
    .size()
    .reset_index(name="count")
    .sort_values(by="count", ascending=False)
)

# Percentages are relative to the total number of label occurrences
total_count = label_counts["count"].sum()
label_counts["percentage"] = (label_counts["count"] / total_count) * 100

# Plot on an explicit Axes instead of relying on the implicit current figure
fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(
    data=label_counts,
    x="integration_name",
    y="count",
    hue="group",
    dodge=False,
    ax=ax,
)

# Annotate bars with percentages. `not height > 0` skips both zero-height and
# NaN-height placeholder patches that seaborn may create for the hue levels.
for p in ax.patches:
    height = p.get_height()
    if not height > 0:
        continue
    ax.text(
        p.get_x() + p.get_width() / 2.0,
        height + 3,  # small vertical offset (in count units) above the bar
        f"{height / total_count * 100:.2f}%",
        ha="center",
        va="bottom",
        fontsize=10,
    )

plt.xticks(rotation=45, ha="right")
ax.set_xlabel("Label")
ax.set_ylabel("Count")
ax.legend(title="Group")
plt.tight_layout()
plt.show()

In [8]:
# Test-set results at epoch 50, hard-coded for the plots in the cells below.
# Keys without a numeric suffix ("test/<metric>") are aggregate values; keys of
# the form "test/<metric>/<code>" are per-class values, where <code> is the
# suffix that later cells match against `integration_code`.

metrics = {
    "test_loss": 0.1731889545917511,
    "test/accuracy": 0.9585899114608765,
    "test/precision": 1.4069647749395141e29,  # NOTE(review): not a valid precision (expected within [0, 1]); looks like a numerical artefact in the aggregate — verify before reporting
    "test/recall": 0.5189249515533447,
    "test/f1": 0.5026060938835144,
    "test/specificity": 0.970903217792511,
    "test/auroc/10370003": 0.9716390371322632,
    "test/auroc/164889003": 0.9922537803649902,
    "test/auroc/164909002": 0.9914255142211914,
    "test/auroc/164917005": 0.9407635927200317,
    "test/auroc/251146004": 0.8816485404968262,
    "test/auroc/270492004": 0.9613152742385864,
    "test/auroc/27885002": 0.970839262008667,
    "test/auroc/284470004": 0.7667648792266846,
    "test/auroc/39732003": 0.9763144254684448,
    "test/auroc/426177001": 0.9944480657577515,
    "test/auroc/426783006": 0.9580038785934448,
    "test/auroc/427084000": 0.9854577779769897,
    "test/auroc/427172004": 0.9058138132095337,
    "test/auroc/445118002": 0.9812264442443848,
    "test/auroc/47665007": 0.9816645383834839,
    "test/auroc/55827005": 0.8618090152740479,
    "test/auroc/55930002": 0.9100780487060547,
    "test/auroc/713427006": 0.9892495274543762,
    "test/auroc/74390002": 0.9732459187507629,
    "test/auroc/89792004": 0.9881136417388916,
    "test/average_precision/10370003": 0.7705467343330383,
    "test/average_precision/164889003": 0.9695818424224854,
    "test/average_precision/164909002": 0.5808265805244446,
    "test/average_precision/164917005": 0.43110114336013794,
    "test/average_precision/251146004": 0.1984182447195053,
    "test/average_precision/270492004": 0.644189178943634,
    "test/average_precision/27885002": 0.07638506591320038,
    "test/average_precision/284470004": 0.1017557680606842,
    "test/average_precision/39732003": 0.4860231876373291,
    "test/average_precision/426177001": 0.9889308214187622,
    "test/average_precision/426783006": 0.9226680994033813,
    "test/average_precision/427084000": 0.9516099095344543,
    "test/average_precision/427172004": 0.43971166014671326,
    "test/average_precision/445118002": 0.39616161584854126,
    "test/average_precision/47665007": 0.5112760066986084,
    "test/average_precision/55827005": 0.5109458565711975,
    "test/average_precision/55930002": 0.8367711305618286,
    "test/average_precision/713427006": 0.6500297784805298,
    "test/average_precision/74390002": 0.07447796314954758,
    "test/average_precision/89792004": 0.3549346923828125,
    "test/accuracy/10370003": 0.9876543283462524,
    "test/accuracy/164889003": 0.961624264717102,
    "test/accuracy/164909002": 0.9897367358207703,
    "test/accuracy/164917005": 0.9745649099349976,
    "test/accuracy/251146004": 0.9730774760246277,
    "test/accuracy/270492004": 0.981109619140625,
    "test/accuracy/27885002": 0.9979175925254822,
    "test/accuracy/284470004": 0.9702513813972473,
    "test/accuracy/39732003": 0.966830313205719,
    "test/accuracy/426177001": 0.9574594497680664,
    "test/accuracy/426783006": 0.9308344721794128,
    "test/accuracy/427084000": 0.9570132493972778,
    "test/accuracy/427172004": 0.9756061434745789,
    "test/accuracy/445118002": 0.9869105815887451,
    "test/accuracy/47665007": 0.9815558791160583,
    "test/accuracy/55827005": 0.7396995425224304,
    "test/accuracy/55930002": 0.8594377636909485,
    "test/accuracy/713427006": 0.9839357137680054,
    "test/accuracy/74390002": 0.9985125660896301,
    "test/accuracy/89792004": 0.9980663657188416,
    "test/precision/10370003": 0.8640000224113464,
    "test/precision/164889003": 0.8880982995033264,
    "test/precision/164909002": 0.43877550959587097,
    "test/precision/164917005": 0.5166666507720947,
    "test/precision/251146004": 0.3461538553237915,
    "test/precision/270492004": 0.6928104758262634,
    "test/precision/27885002": 0.0,
    "test/precision/284470004": 0.1428571492433548,
    "test/precision/39732003": 0.5098814368247986,
    "test/precision/426177001": 0.8998499512672424,
    "test/precision/426783006": 0.94259113073349,
    "test/precision/427084000": 0.9147679209709167,
    "test/precision/427172004": 0.7567567825317383,
    "test/precision/445118002": 0.31460675597190857,
    "test/precision/47665007": 0.5546875,
    "test/precision/55827005": 0.32277143001556396,
    "test/precision/55930002": 0.7691908478736877,
    "test/precision/713427006": 0.6094420552253723,
    "test/precision/74390002": 0.0,
    "test/precision/89792004": 1.0,
    "test/recall/10370003": 0.6206896305084229,
    "test/recall/164889003": 0.9417009353637695,
    "test/recall/164909002": 0.7543859481811523,
    "test/recall/164917005": 0.35428571701049805,
    "test/recall/251146004": 0.05202312022447586,
    "test/recall/270492004": 0.5698924660682678,
    "test/recall/27885002": 0.0,
    "test/recall/284470004": 0.0051282052882015705,
    "test/recall/39732003": 0.5657894611358643,
    "test/recall/426177001": 0.9921422600746155,
    "test/recall/426783006": 0.7565379738807678,
    "test/recall/427084000": 0.8522012829780579,
    "test/recall/427172004": 0.2772277295589447,
    "test/recall/445118002": 0.5090909004211426,
    "test/recall/47665007": 0.5144927501678467,
    "test/recall/55827005": 0.8386388421058655,
    "test/recall/55930002": 0.7478567957878113,
    "test/recall/713427006": 0.893081784248352,
    "test/recall/74390002": 0.0,
    "test/recall/89792004": 0.13333334028720856,
    "test/f1/10370003": 0.7224080562591553,
    "test/f1/164889003": 0.9141145348548889,
    "test/f1/164909002": 0.5548387169837952,
    "test/f1/164917005": 0.4203389883041382,
    "test/f1/251146004": 0.09045226126909256,
    "test/f1/270492004": 0.6253687143325806,
    "test/f1/27885002": 0.0,
    "test/f1/284470004": 0.009900989942252636,
    "test/f1/39732003": 0.5363825559616089,
    "test/f1/426177001": 0.9437450766563416,
    "test/f1/426783006": 0.8393782377243042,
    "test/f1/427084000": 0.8823769092559814,
    "test/f1/427172004": 0.4057970941066742,
    "test/f1/445118002": 0.3888888955116272,
    "test/f1/47665007": 0.5338345766067505,
    "test/f1/55827005": 0.4661378860473633,
    "test/f1/55930002": 0.7583737969398499,
    "test/f1/713427006": 0.7244898080825806,
    "test/f1/74390002": 0.0,
    "test/f1/89792004": 0.23529411852359772,
    "test/specificity/10370003": 0.997404158115387,
    "test/specificity/164889003": 0.9671415090560913,
    "test/specificity/164909002": 0.9917491674423218,
    "test/specificity/164917005": 0.9911423325538635,
    "test/specificity/251146004": 0.9974045753479004,
    "test/specificity/270492004": 0.9928101301193237,
    "test/specificity/27885002": 0.9995530247688293,
    "test/specificity/284470004": 0.9990808963775635,
    "test/specificity/39732003": 0.9809083938598633,
    "test/specificity/426177001": 0.9379791021347046,
    "test/specificity/426783006": 0.9855384230613708,
    "test/specificity/427084000": 0.9814713001251221,
    "test/specificity/427172004": 0.9972397089004517,
    "test/specificity/445118002": 0.9908518195152283,
    "test/specificity/47665007": 0.9913439750671387,
    "test/specificity/55827005": 0.7241913080215454,
    "test/specificity/55930002": 0.9061181545257568,
    "test/specificity/713427006": 0.9861364960670471,
    "test/specificity/74390002": 1.0,
    "test/specificity/89792004": 1.0,
}

In [None]:
# Label distribution coloured by per-class F1.
# Bars show each label's share of all label occurrences; colour and the number
# next to each bar show that label's test F1 (in percent).

# Per-class F1 scores keyed by integration code ("test/f1/<code>" in `metrics`)
f1_scores = {
    k.split("/")[-1]: v for k, v in metrics.items() if k.startswith("test/f1/")
}

# multiply by 100 for percentage
f1_scores = {k: v * 100 for k, v in f1_scores.items()}

# Step 1: one row per (record, label). Records without labels explode to NaN;
# drop them so the `.get` calls below cannot fail with AttributeError.
df_exploded = (
    df.explode("labels_metadata")
    .dropna(subset=["labels_metadata"])
    .reset_index(drop=True)
)

# Ensure each entry is a dict. String entries are parsed with ast.literal_eval
# (imported in the first cell) instead of eval, which would execute arbitrary
# code found in the metadata.
df_exploded["labels_metadata"] = df_exploded["labels_metadata"].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

# Extract relevant information
df_exploded["integration_code"] = df_exploded["labels_metadata"].apply(
    lambda x: x.get("integration_code")
)
df_exploded["integration_name"] = df_exploded["labels_metadata"].apply(
    lambda x: x.get("integration_name")
)

# Share of each integration name among all label occurrences, in percent
label_counts = (
    df_exploded["integration_name"].value_counts(normalize=True) * 100
).reset_index()
label_counts.columns = ["Integration Name", "Distribution (%)"]

# Step 2: lookup table integration code -> integration name
integration_map = (
    df_exploded[["integration_code", "integration_name"]]
    .drop_duplicates()
    .set_index("integration_code")["integration_name"]
    .to_dict()
)

# Create a DataFrame for F1 scores and match with integration names
df_f1 = pd.DataFrame(list(f1_scores.items()), columns=["Integration Code", "F1-score"])
df_f1["Integration Name"] = df_f1["Integration Code"].map(integration_map)

# Merge with label distribution data; labels without a score are dropped
df_final = pd.merge(label_counts, df_f1, on="Integration Name", how="left").dropna()

# Step 3: Sort by occurrence percentage
df_final = df_final.sort_values(by="Distribution (%)", ascending=False)

# Normalize F1-score for color mapping
norm = plt.Normalize(df_final["F1-score"].min(), df_final["F1-score"].max())
cmap = plt.cm.RdYlGn  # Colormap from red to green

# Create figure and axes
fig, ax = plt.subplots(figsize=(12, 6))

# Create horizontal barplot using distribution as x-axis
bars = ax.barh(
    df_final["Integration Name"],
    df_final["Distribution (%)"],
    color=cmap(norm(df_final["F1-score"])),
)

# Annotate each bar with the (rounded) F1-score
for bar, score in zip(bars, df_final["F1-score"]):
    width = bar.get_width()
    ax.text(
        width + 0.3,
        bar.get_y() + bar.get_height() / 2,
        f"{round(score)}",
        ha="left",
        va="center",
        fontsize=10,
        color="black",
    )

# Add colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # a mappable needs an array before it can drive a colorbar
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("F1-score (%)")

# Axis labels
ax.set_xlabel("Distribution (%)")
ax.set_ylabel("Label")

# Adjust layout
ax.set_xlim(0, df_final["Distribution (%)"].max() + 5)  # Slight padding on x-axis
ax.invert_yaxis()  # Highest occurrence at top
plt.tight_layout()
plt.show()

In [None]:
from scipy import stats

# Linear relationship between label prevalence and per-class F1.
# NOTE: this cell reads `df_final` as produced by the F1 cell directly above.
# The recall and precision cells further down overwrite `df_final` without an
# "F1-score" column, so run this cell before them.
x = df_final["Distribution (%)"]
y = df_final["F1-score"]
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
r_squared = r_value**2
equation = f"y = {slope:.2f}x + {intercept:.2f}"

# These settings are global: they also affect figures created by later cells.
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams["figure.figsize"] = (16, 8)
plt.rcParams["font.size"] = 12
sns.set_palette("colorblind")

# Create figure
fig, ax = plt.subplots()

# Scatter plot with fitted regression line and confidence band, drawn
# explicitly on `ax` rather than on whatever the current axes happens to be
sns.regplot(
    x=x,
    y=y,
    scatter_kws={"s": 80, "alpha": 0.8, "edgecolor": "w", "linewidths": 0.5},
    line_kws={"color": "#d62728", "linestyle": "--", "linewidth": 1.5},
    ci=95,  # 95% confidence interval
    ax=ax,
)

# Add statistical annotations
stats_text = (
    f"Pearson $r$ = {r_value:.2f}\n" f"$p$ = {p_value:.4f}\n" f"$n$ = {len(x)} classes"
)

ax.text(
    0.95,
    0.15,
    stats_text,
    transform=ax.transAxes,  # position in axes coordinates (0..1)
    fontfamily="monospace",
    fontsize=10,
    verticalalignment="top",
    horizontalalignment="right",
    bbox=dict(facecolor="white", alpha=0.9, edgecolor="0.8"),
)

# Axis labels with units. "Distribution (%)" is computed from the exploded
# label rows, i.e. it is a share of label occurrences, not of samples.
ax.set_xlabel("Label Prevalence (% of all label occurrences)", fontsize=12, labelpad=10)
ax.set_ylabel("F1-Score (%)", fontsize=12, labelpad=10)

# Set axis limits with buffer
ax.set_xlim(left=-1, right=x.max() * 1.1)
ax.set_ylim(bottom=y.min() - 2, top=100 + 2)

# Custom grid
ax.grid(True, linestyle="--", alpha=0.6, which="both")


# print values
print(
    "slope, intercept, r_value, p_value, std_err",
    slope,
    intercept,
    r_value,
    p_value,
    std_err,
)

# Use tight layout
plt.tight_layout()
plt.show()

In [None]:
# Label distribution coloured by per-class recall.
# Bars show each label's share of all label occurrences; colour and the number
# next to each bar show that label's test recall (in percent).

# Per-class recall keyed by integration code ("test/recall/<code>" in `metrics`)
recall_scores = {
    k.split("/")[-1]: v for k, v in metrics.items() if k.startswith("test/recall/")
}

# multiply by 100 for percentage
recall_scores = {k: v * 100 for k, v in recall_scores.items()}

# Step 1: one row per (record, label). Records without labels explode to NaN;
# drop them so the `.get` calls below cannot fail with AttributeError.
df_exploded = (
    df.explode("labels_metadata")
    .dropna(subset=["labels_metadata"])
    .reset_index(drop=True)
)

# Ensure each entry is a dict. String entries are parsed with ast.literal_eval
# (imported in the first cell) instead of eval, which would execute arbitrary
# code found in the metadata.
df_exploded["labels_metadata"] = df_exploded["labels_metadata"].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

# Extract relevant information
df_exploded["integration_code"] = df_exploded["labels_metadata"].apply(
    lambda x: x.get("integration_code")
)
df_exploded["integration_name"] = df_exploded["labels_metadata"].apply(
    lambda x: x.get("integration_name")
)

# Share of each integration name among all label occurrences, in percent
label_counts = (
    df_exploded["integration_name"].value_counts(normalize=True) * 100
).reset_index()
label_counts.columns = ["Integration Name", "Distribution (%)"]

# Step 2: lookup table integration code -> integration name
integration_map = (
    df_exploded[["integration_code", "integration_name"]]
    .drop_duplicates()
    .set_index("integration_code")["integration_name"]
    .to_dict()
)

# Create a DataFrame for recall scores and match with integration names.
# (The variable keeps the name `df_f1` used by the F1 cell, although it holds
# recall values here.)
df_f1 = pd.DataFrame(
    list(recall_scores.items()), columns=["Integration Code", "Recall"]
)
df_f1["Integration Name"] = df_f1["Integration Code"].map(integration_map)

# Merge with label distribution data; labels without a score are dropped.
# This overwrites the `df_final` produced by the F1 cell.
df_final = pd.merge(label_counts, df_f1, on="Integration Name", how="left").dropna()

# Step 3: Sort by occurrence percentage
df_final = df_final.sort_values(by="Distribution (%)", ascending=False)

# Normalize recall for color mapping
norm = plt.Normalize(df_final["Recall"].min(), df_final["Recall"].max())
cmap = plt.cm.RdYlGn  # Colormap from red to green

# Create figure and axes
fig, ax = plt.subplots(figsize=(12, 6))

# Create horizontal barplot using distribution as x-axis
bars = ax.barh(
    df_final["Integration Name"],
    df_final["Distribution (%)"],
    color=cmap(norm(df_final["Recall"])),
)

# Annotate each bar with the (rounded) recall
for bar, score in zip(bars, df_final["Recall"]):
    width = bar.get_width()
    ax.text(
        width + 0.3,
        bar.get_y() + bar.get_height() / 2,
        f"{round(score)}",
        ha="left",
        va="center",
        fontsize=10,
        color="black",
    )

# Add colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # a mappable needs an array before it can drive a colorbar
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("Recall (%)")

# Axis labels
ax.set_xlabel("Distribution (%)")
ax.set_ylabel("Label")

# Adjust layout
ax.set_xlim(0, df_final["Distribution (%)"].max() + 5)  # Slight padding on x-axis
ax.invert_yaxis()  # Highest occurrence at top
plt.tight_layout()
plt.show()

In [None]:
# Label distribution coloured by per-class precision.
# Bars show each label's share of all label occurrences; colour and the number
# next to each bar show that label's test precision (in percent).

# Per-class precision keyed by integration code ("test/precision/<code>" in
# `metrics`); the trailing slash in the prefix excludes the aggregate
# "test/precision" entry.
precision_scores = {
    k.split("/")[-1]: v for k, v in metrics.items() if k.startswith("test/precision/")
}

# multiply by 100 for percentage
precision_scores = {k: v * 100 for k, v in precision_scores.items()}

# Step 1: one row per (record, label). Records without labels explode to NaN;
# drop them so the `.get` calls below cannot fail with AttributeError.
df_exploded = (
    df.explode("labels_metadata")
    .dropna(subset=["labels_metadata"])
    .reset_index(drop=True)
)

# Ensure each entry is a dict. String entries are parsed with ast.literal_eval
# (imported in the first cell) instead of eval, which would execute arbitrary
# code found in the metadata.
df_exploded["labels_metadata"] = df_exploded["labels_metadata"].apply(
    lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

# Extract relevant information
df_exploded["integration_code"] = df_exploded["labels_metadata"].apply(
    lambda x: x.get("integration_code")
)
df_exploded["integration_name"] = df_exploded["labels_metadata"].apply(
    lambda x: x.get("integration_name")
)

# Share of each integration name among all label occurrences, in percent
label_counts = (
    df_exploded["integration_name"].value_counts(normalize=True) * 100
).reset_index()
label_counts.columns = ["Integration Name", "Distribution (%)"]

# Step 2: lookup table integration code -> integration name
integration_map = (
    df_exploded[["integration_code", "integration_name"]]
    .drop_duplicates()
    .set_index("integration_code")["integration_name"]
    .to_dict()
)

# Create a DataFrame for precision scores and match with integration names.
# (The variable keeps the name `df_f1` used by the F1 cell, although it holds
# precision values here.)
df_f1 = pd.DataFrame(
    list(precision_scores.items()), columns=["Integration Code", "Precision"]
)
df_f1["Integration Name"] = df_f1["Integration Code"].map(integration_map)

# Merge with label distribution data; labels without a score are dropped.
# This overwrites the `df_final` produced by earlier cells.
df_final = pd.merge(label_counts, df_f1, on="Integration Name", how="left").dropna()

# Step 3: Sort by occurrence percentage
df_final = df_final.sort_values(by="Distribution (%)", ascending=False)

# Normalize precision for color mapping
norm = plt.Normalize(df_final["Precision"].min(), df_final["Precision"].max())
cmap = plt.cm.RdYlGn  # Colormap from red to green

# Create figure and axes
fig, ax = plt.subplots(figsize=(12, 6))

# Create horizontal barplot using distribution as x-axis
bars = ax.barh(
    df_final["Integration Name"],
    df_final["Distribution (%)"],
    color=cmap(norm(df_final["Precision"])),
)

# Annotate each bar with the (rounded) precision
for bar, score in zip(bars, df_final["Precision"]):
    width = bar.get_width()
    ax.text(
        width + 0.3,
        bar.get_y() + bar.get_height() / 2,
        f"{round(score)}",
        ha="left",
        va="center",
        fontsize=10,
        color="black",
    )

# Add colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # a mappable needs an array before it can drive a colorbar
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("Precision (%)")

# Axis labels
ax.set_xlabel("Distribution (%)")
ax.set_ylabel("Label")

# Adjust layout
ax.set_xlim(0, df_final["Distribution (%)"].max() + 5)  # Slight padding on x-axis
ax.invert_yaxis()  # Highest occurrence at top
plt.tight_layout()
plt.show()