In [1]:
import json

import numpy as np
from tabulate import tabulate

# Sleep stages in the order their per-class F1 scores are read from each results file;
# when loading, the mean over these five values (macro F1) is appended as a sixth entry.
score_order = ["Wake", "N1", "N2", "N3", "REM"]

In [2]:
# Per-run test-result files, three runs per model. The order matters: the table cell below
# reads entries 0-2 of each dataset's list as U-Sleep, 3-5 as AnySleep early fusion,
# 6-8 as mid fusion and 9-11 as late fusion.
reprod_results_paths = [
    "../../logs/exp001/exp001a/sweep-2025-07-30_12-29-20_test_eval/0/usleep_test_results.json",
    "../../logs/exp001/exp001a/sweep-2025-07-30_12-29-20_test_eval/1/usleep_test_results.json",
    "../../logs/exp001/exp001a/sweep-2025-07-30_12-29-20_test_eval/2/usleep_test_results.json",
    "../../logs/exp003/exp003a/sweep-2025-08-04_10-57-43_test_eval/0/anysleep_test_results.json",
    "../../logs/exp003/exp003a/sweep-2025-08-04_10-57-43_test_eval/1/anysleep_test_results.json",
    "../../logs/exp003/exp003a/sweep-2025-08-04_10-57-43_test_eval/2/anysleep_test_results.json",
    "../../logs/exp002/exp002a/sweep-2025-07-29_14-59-03_test_eval/1/anysleep_test_results.json",
    "../../logs/exp002/exp002a/sweep-2025-07-29_14-59-03_test_eval/2/anysleep_test_results.json",
    "../../logs/exp002/exp002a/sweep-2025-07-29_14-59-03_test_eval/3/anysleep_test_results.json",
    "../../logs/exp004/exp004a/sweep-2025-08-04_10-58-11_test_eval/0/anysleep_test_results.json",
    "../../logs/exp004/exp004a/sweep-2025-08-04_10-58-11_test_eval/1/anysleep_test_results.json",
    "../../logs/exp004/exp004a/sweep-2025-08-04_10-58-11_test_eval/2/anysleep_test_results.json",
]

# dataset name -> one F1 list per results file (in reprod_results_paths order); each list holds
# the per-stage F1s in score_order followed by their mean (macro F1).
dataset_f1s = {}
for reprod_results_path in reprod_results_paths:
    with open(reprod_results_path, "r") as f:
        reprod_results = json.load(f)

    for dataset, results in reprod_results["datasets"].items():
        f1s = {}
        for channel, scores in results.items():
            f1s[channel] = [scores[score]["f1"][0] for score in score_order]
            f1s[channel].append(np.mean(f1s[channel]))

        # Use the 'majority' entry when the file has one, otherwise 'concat'.
        if "majority" in f1s:
            selected_f1s = f1s["majority"]
        elif "concat" in f1s:
            selected_f1s = f1s["concat"]
        else:
            raise KeyError(f"{reprod_results_path}: dataset '{dataset}' has neither a 'majority' nor a 'concat' "
                           f"entry (found: {sorted(f1s)})")
        dataset_f1s.setdefault(dataset, []).append(selected_f1s)

in_sample_datasets = ["abc", "ccshs", "cfs", "chat", "dcsm", "hpap", "mesa", "mros", "phys", "sedf-sc", "sedf-st",
                      "shhs", "sof"]
ood_datasets = ["dodh", "dodo", "isruc-sg1", "isruc-sg2", "isruc-sg3", "mass-c1", "mass-c3", "svuh"]

unknown_datasets = set(dataset_f1s).difference(in_sample_datasets + ood_datasets)
assert not unknown_datasets, f"Dataset not found: {unknown_datasets}"

missing_datasets = set(in_sample_datasets + ood_datasets).difference(dataset_f1s)
assert not missing_datasets, f"Expected datasets absent from all result files: {missing_datasets}"

# The table cell indexes each dataset's list by run position, so a dataset missing from any
# single results file would silently shift the remaining runs onto the wrong model.
n_runs = len(reprod_results_paths)
incomplete_datasets = {ds: len(runs) for ds, runs in dataset_f1s.items() if len(runs) != n_runs}
assert not incomplete_datasets, \
    f"Datasets not present in all {n_runs} result files (runs found): {incomplete_datasets}"

In [3]:
# present scores as table
# Rows: in-distribution datasets, their recording-weighted "Mean", OOD datasets, their "Mean".
# Columns: 0 = dataset, 1 = number of test recordings, 2-5 = one model each.
# values from datasplit (see config/usleep_split.yaml), ordered as in_sample_datasets + ood_datasets
n_recordings = [20, 78, 92, 128, 39, 36, 100, 134, 100, 23, 8, 140, 68, 25, 55, 100, 16, 10, 53, 62, 25]
assert len(n_recordings) == len(in_sample_datasets) + len(ood_datasets), \
    "n_recordings needs exactly one entry per in-sample and OOD dataset"

n_in_sample = len(in_sample_datasets)
ood_mean_row = n_in_sample + len(ood_datasets) + 1
table_data = np.empty((ood_mean_row + 1, 6), dtype=object)

# table column -> index of that model's first run in each dataset_f1s[...] list
# (three consecutive runs per model, in the order of reprod_results_paths)
model_run_offsets = {2: 0, 3: 3, 4: 6, 5: 9}
runs_per_model = 3


def _fill_dataset_rows(table, f1s_by_dataset, datasets, first_row, recording_counts):
    """Write "mean (std)" macro-F1 cells for `datasets` into `table`, starting at `first_row`.

    Returns an array with one recording-weighted mean macro F1 per model, in the column
    order of `model_run_offsets`. The std is numpy's default population std (ddof=0)
    over the runs of a model.
    """
    weighted_sums = np.zeros(len(model_run_offsets))
    for row, (dataset, count) in enumerate(zip(datasets, recording_counts), first_row):
        f1s = f1s_by_dataset[dataset]
        for model_i, (col, offset) in enumerate(model_run_offsets.items()):
            # last entry of each run's F1 list is the macro F1
            macro_f1s = [f1s[offset + run][-1] for run in range(runs_per_model)]
            mean = np.mean(macro_f1s)
            std = np.std(macro_f1s)
            table[row, col] = f"{mean:.2f} ({std:.3f})"
            weighted_sums[model_i] += mean * count
    return weighted_sums / np.sum(recording_counts)


# in-distribution rows, followed by their mean over all datasets
model_means = _fill_dataset_rows(table_data, dataset_f1s, in_sample_datasets, 0, n_recordings[:n_in_sample])
table_data[n_in_sample, 2:] = [f"{m:.3f}" for m in model_means]

# out-of-distribution rows start after the in-distribution Mean row, followed by their own mean
model_means = _fill_dataset_rows(table_data, dataset_f1s, ood_datasets, n_in_sample + 1,
                                 n_recordings[n_in_sample:])
table_data[ood_mean_row, 2:] = [f"{m:.3f}" for m in model_means]

table_data[:, 0] = in_sample_datasets + ["Mean"] + ood_datasets + ["Mean"]
table_data[:, 1] = [str(t) for t in n_recordings[:n_in_sample] + [""] + n_recordings[n_in_sample:] + [""]]


def _cell_mean(cell):
    """Parse the mean from a "mean (std)" cell string."""
    return float(cell[:cell.index("(")])


# make maximum values bold (all tied maxima are bolded); Mean rows carry no "(std)" and are skipped
for row in range(table_data.shape[0]):
    model_cells = table_data[row, 2:]
    if "(" not in " ".join(model_cells):
        continue
    max_val = max(_cell_mean(v) for v in model_cells)
    table_data[row, 2:] = [
        "\\textbf{" + v[:v.index(" ")] + "} " + v[v.index("("):] if _cell_mean(v) == max_val else v
        for v in model_cells]

In [4]:
# Render the table as HTML for the notebook: swap the LaTeX bold markup for <b> tags.
table_data_html = []
for table_row in table_data:
    table_data_html.append([cell.replace("\\textbf{", "<b>").replace("}", "</b>") for cell in table_row])

# Visually separate the in-distribution block (including its Mean row) from the OOD block
# with an all-empty row.
separator_row = [""] * len(table_data_html[0])
table_data_html.insert(len(in_sample_datasets) + 1, separator_row)

html_headers = ["", "", "U-Sleep", "AnySleep early fusion", "AnySleep mid fusion", "AnySleep late fusion"]
tabulate(table_data_html, headers=html_headers, tablefmt="unsafehtml")

Unnamed: 0,Unnamed: 1,U-Sleep,AnySleep early fusion,AnySleep mid fusion,AnySleep late fusion
abc,20.0,0.76 (0.009),0.77 (0.002),0.80 (0.006),0.78 (0.006)
ccshs,78.0,0.86 (0.003),0.85 (0.003),0.87 (0.001),0.86 (0.002)
cfs,92.0,0.82 (0.004),0.81 (0.002),0.83 (0.001),0.82 (0.002)
chat,128.0,0.84 (0.007),0.82 (0.007),0.86 (0.001),0.85 (0.002)
dcsm,39.0,0.81 (0.005),0.80 (0.002),0.81 (0.005),0.80 (0.011)
hpap,36.0,0.77 (0.002),0.74 (0.005),0.79 (0.006),0.77 (0.001)
mesa,100.0,0.78 (0.006),0.76 (0.007),0.80 (0.001),0.79 (0.002)
mros,134.0,0.76 (0.007),0.75 (0.003),0.78 (0.001),0.77 (0.002)
phys,100.0,0.79 (0.005),0.76 (0.004),0.79 (0.002),0.78 (0.005)
sedf-sc,23.0,0.80 (0.003),0.80 (0.007),0.81 (0.004),0.81 (0.002)


In [5]:
# Print the table as LaTeX. "latex_raw" leaves the \textbf{...} markup unescaped.
# One header per column (dataset, #recordings, four models), matching the HTML table above,
# instead of relying on tabulate's implicit padding of a too-short header list.
print(tabulate(table_data,
               headers=["", "", "U-Sleep", "AnySleep early fusion", "AnySleep mid fusion", "AnySleep late fusion"],
               tablefmt="latex_raw"))

\begin{tabular}{llllll}
\hline
           &     & U-Sleep               & AnySleep early fusion   & AnySleep mid fusion   & AnySleep late fusion   \\
\hline
 abc       & 20  & 0.76 (0.009)          & 0.77 (0.002)            & \textbf{0.80} (0.006) & 0.78 (0.006)           \\
 ccshs     & 78  & 0.86 (0.003)          & 0.85 (0.003)            & \textbf{0.87} (0.001) & 0.86 (0.002)           \\
 cfs       & 92  & 0.82 (0.004)          & 0.81 (0.002)            & \textbf{0.83} (0.001) & 0.82 (0.002)           \\
 chat      & 128 & 0.84 (0.007)          & 0.82 (0.007)            & \textbf{0.86} (0.001) & 0.85 (0.002)           \\
 dcsm      & 39  & \textbf{0.81} (0.005) & 0.80 (0.002)            & \textbf{0.81} (0.005) & 0.80 (0.011)           \\
 hpap      & 36  & 0.77 (0.002)          & 0.74 (0.005)            & \textbf{0.79} (0.006) & 0.77 (0.001)           \\
 mesa      & 100 & 0.78 (0.006)          & 0.76 (0.007)            & \textbf{0.80} (0.001) & 0.79 (0.002)           \\
 mros     