In [None]:
%load_ext autoreload
%autoreload 2

import json
import os
import traceback
from collections import defaultdict

from ethos.constants import PROJECT_ROOT
from ethos.datasets import ReadmissionDataset, TimelineDataset
from ethos.metrics.fidelity import (
    convert_numpy,
    fetch_all_codes,
    fidelity_evaluation,
    fidelity_from_halo_evaluation,
    transform_data_matrix,
)

# Base directory of the tokenized datasets; each real/synthetic dataset is loaded
# from `data_dir / "mimic_synth/<key>"` by load_and_prepare_data.
data_dir = PROJECT_ROOT / "data/tokenized_datasets"

# This notebook requires the tokenized datasets, which we cannot share due to the MIMIC
# data-sharing policy. However, everything can be recreated using this repository and
# the result files on Google Drive.

In [None]:
def load_and_prepare_data(real_key, synth_key, data_dir, dataset_type="readmission"):
    """Load a real/synthetic dataset pair and build code matrices for both.

    Args:
        real_key: Name of the real dataset directory under ``data_dir / "mimic_synth"``.
        synth_key: Name of the synthetic dataset directory under ``data_dir / "mimic_synth"``.
        data_dir: Base directory of the tokenized datasets (joined with ``/``).
        dataset_type: ``"readmission"`` (loads ``ReadmissionDataset``) or
            ``"timeline"`` (loads ``TimelineDataset``).

    Returns:
        A 6-tuple ``(real_data, synth_data, real_matrices, synth_matrices,
        real_data_all_codes, synthetic_data_all_codes)``. Each ``*_matrices`` dict maps
        ``"binary"``, ``"count"`` and ``"probability"`` to the output of
        ``transform_data_matrix`` for that matrix type.

    Raises:
        ValueError: If ``dataset_type`` is unsupported, or if the synthetic dataset's
            ``patient_ids`` and ``static_data`` keys are not the same set of patients.
    """
    if dataset_type == "readmission":
        DatasetClass = ReadmissionDataset
    elif dataset_type == "timeline":
        DatasetClass = TimelineDataset
    else:
        raise ValueError(f"Unsupported dataset_type: {dataset_type}")

    pair = f"[{real_key} vs {synth_key}]"

    real_data = DatasetClass(data_dir / f"mimic_synth/{real_key}")
    synth_data = DatasetClass(data_dir / f"mimic_synth/{synth_key}")

    # Build each ID set once; patient_ids may be large, so avoid re-converting it.
    synth_patient_ids = set(synth_data.patient_ids.tolist())
    synth_static_ids = set(synth_data.static_data.keys())
    print("number of patient_ids in synth_data:", len(synth_patient_ids))
    print("number of static_data in synth_data:", len(synth_static_ids))

    # Consistency check of the synthetic dataset with itself: every patient in the
    # timeline must have static data and vice versa.
    # NOTE(review): real_data is not checked the same way — confirm that is intended.
    if synth_patient_ids != synth_static_ids:
        raise ValueError(f"{pair} Patient IDs do not match")
    print(f"{pair} Patient IDs match")

    same_vocab = real_data.vocab.stoi == synth_data.vocab.stoi
    print(f"{pair} Vocabularies are {'the same' if same_vocab else 'different'}")

    # Both datasets are encoded against the real vocabulary's size, whether or not the
    # vocabularies agree. NOTE(review): when they differ, synthetic code indices may not
    # line up with the real ones — confirm transform_data_matrix tolerates that.
    code_vocab_size = len(real_data.vocab.stoi)

    real_data_all_codes = fetch_all_codes(real_data, dataset_type)
    synthetic_data_all_codes = fetch_all_codes(synth_data, dataset_type)

    def prepare_matrices(all_codes):
        """Build the binary, count and probability matrices for one set of codes."""
        return {
            mt: transform_data_matrix(all_codes, code_vocab_size, mt, dataset_type=dataset_type)
            for mt in ("binary", "count", "probability")
        }

    return (
        real_data,
        synth_data,
        prepare_matrices(real_data_all_codes),
        prepare_matrices(synthetic_data_all_codes),
        real_data_all_codes,
        synthetic_data_all_codes,
    )


def evaluate_pair(real_key, synth_key, data_dir, save_dir, dataset_type="readmission"):
    """Run every fidelity metric for one real vs synthetic dataset pair.

    Args:
        real_key: Name of the real dataset (see ``load_and_prepare_data``).
        synth_key: Name of the synthetic dataset to compare against it.
        data_dir: Base directory of the tokenized datasets.
        save_dir: Directory handed to ``fidelity_from_halo_evaluation``.
        dataset_type: ``"readmission"`` or ``"timeline"``.

    Returns:
        A nested ``defaultdict`` with two entries, in this order:
        ``"fidelity_from_halo"`` (the return value of ``fidelity_from_halo_evaluation``)
        and ``"fidelity"`` (the dicts returned by ``fidelity_evaluation`` for the binary,
        count and probability matrices, merged into one).
    """
    (
        _real_data,
        _synth_data,
        real_matrices,
        synth_matrices,
        real_codes,
        synthetic_codes,
    ) = load_and_prepare_data(real_key, synth_key, data_dir, dataset_type=dataset_type)

    results = defaultdict(lambda: defaultdict(dict))
    results["fidelity_from_halo"] = fidelity_from_halo_evaluation(
        real_codes, synthetic_codes, save_dir, dataset_type=dataset_type
    )

    # NOTE(review): the per-matrix results are merged into one flat dict, so any key that
    # fidelity_evaluation returns for more than one matrix type is overwritten by the later
    # type — confirm its keys are already distinct per matrix type.
    merged_fidelity = results["fidelity"]
    for matrix_type in ("binary", "count", "probability"):
        print(f"Evaluating {real_key} vs {synth_key} with {matrix_type} matrix...")
        metrics = fidelity_evaluation(
            real_matrices[matrix_type], synth_matrices[matrix_type], matrix_type
        )
        merged_fidelity.update(metrics)

    return results

In [None]:
# Results are written to result_dir/<dataset_type>/<synth_key>/.
result_dir = PROJECT_ROOT / "results/fidelity"
dataset_type_to_save_dir = {
    "readmission": result_dir / "readmission",
    "timeline": result_dir / "timeline",
}

# Each real dataset is compared with every synthetic variant named real_key + suffix.
real_types = ["big", "small", "little"]
synth_suffixes = ["_synth", "_synth_temp0.7", "_synth_temp0.9", "_synth_temp1.1"]

all_results = {}
# result_key -> formatted traceback for every pair that failed. The sweep is best-effort
# (one failing pair must not stop the others), so failures are collected and summarised
# at the end instead of disappearing into the cell output.
failed_results = {}

for dataset_type, base_save_dir in dataset_type_to_save_dir.items():
    print(f"\n========== Starting evaluations for dataset_type: {dataset_type} ==========")
    for real_key in real_types:
        for suffix in synth_suffixes:
            synth_key = real_key + suffix
            print(f"\nStarting evaluation for {real_key} vs {synth_key}")
            result_key = f"{dataset_type}_{real_key}_vs_{synth_key}"

            current_save_dir = base_save_dir / synth_key
            current_save_dir.mkdir(parents=True, exist_ok=True)

            try:
                all_results[result_key] = evaluate_pair(
                    real_key, synth_key, data_dir, current_save_dir, dataset_type=dataset_type
                )

                # Serialize before opening the file: json.dump streams into the file, so a
                # serialization error halfway through would leave a truncated JSON on disk.
                serialized = json.dumps(all_results[result_key], indent=4, default=convert_numpy)
                with (current_save_dir / f"{result_key}_results.json").open("w") as f:
                    f.write(serialized)

                print(f"Finished evaluation for {result_key}")
                print(f"Results for {result_key}: {all_results[result_key]}")

            except Exception as e:
                # Keep the full traceback; str(e) alone is often empty or ambiguous.
                failed_results[result_key] = traceback.format_exc()
                print(f"Error occurred during evaluation for {result_key}: {e}")
                print(failed_results[result_key])

n_pairs = len(dataset_type_to_save_dir) * len(real_types) * len(synth_suffixes)
print(f"\n{n_pairs - len(failed_results)}/{n_pairs} evaluations completed successfully.")
for failed_key in failed_results:
    print(f"  FAILED: {failed_key}")