This notebook computes summary statistics for the processed cohorts stored in the MEDS format.

In [None]:
import sys
import os

current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(os.path.join(parent_dir, "src"))

from datasets import load_from_disk, Dataset
from femr.hf_utils import aggregate_over_dataset
from ehr_stats.codes import create_code_occurence_plot, create_code_stats_table, plot_codes_per_patient, extract_code_stats, combine_code_stats, get_summary_statistics
import numpy as np
from texttable import Texttable
import latextable
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


In [None]:
# Input and output locations used by the cells below.
# NOTE(review): these are absolute, machine-specific paths; adjust them before
# running the notebook on another machine.
RAW_MIMIC_MEDS_DATA_DIR = "/home/niclas/Dokumente/thesis-daten/mimic_meds_2.2/"
RAW_MIMIC_OMOP_MEDS_DATA_DIR = "/home/niclas/Dokumente/thesis-daten/mimic-omop-meds"
DATA_RESULTS_DIR = "/home/niclas/Dokumente/thesis-daten/demo_results/"
# Locations derived from DATA_RESULTS_DIR: where generated tables are written
# and where the cleaned dataset is loaded from.
OUTPUT_DIR = os.path.join(DATA_RESULTS_DIR, "calculated_output")
CLEAN_DATA_PATH = os.path.join(DATA_RESULTS_DIR, "clean")

In [None]:
def calculate_code_stats(dataset):
    """Aggregate code statistics over ``dataset``.

    Thin wrapper around ``aggregate_over_dataset`` that plugs in
    ``extract_code_stats`` and ``combine_code_stats`` and returns whatever the
    aggregation returns.

    The trailing positional arguments ``50`` and ``2`` are presumably the
    batch size and the number of worker processes -- TODO confirm against the
    ``femr.hf_utils.aggregate_over_dataset`` signature.
    """
    return aggregate_over_dataset(
        dataset,
        extract_code_stats,
        combine_code_stats,
        50,
        2,
    )

# Load the cleaned dataset from CLEAN_DATA_PATH and aggregate its code statistics.
clean_dataset = load_from_disk(CLEAN_DATA_PATH)
clean_code_stats = calculate_code_stats(clean_dataset)

# Pass the aggregated statistics and the output directory to create_code_stats_table.
create_code_stats_table(clean_code_stats, OUTPUT_DIR)

In [None]:
import numpy as np
from texttable import Texttable
import os
import latextable
from datasets import Dataset, DatasetDict

# Function to calculate min, mean, max, and IQR (displayed as Q1-Q3)
def calculate_statistics(data):
    """Summarize ``data`` with a handful of descriptive statistics.

    Args:
        data: Sequence of values accepted by the NumPy reduction functions.

    Returns:
        dict with the keys ``"min"``, ``"mean"``, ``"max"``, ``"median"``,
        ``"q1"`` (25th percentile), ``"q3"`` (75th percentile) and ``"total"``
        (sum of all values), in that order.
    """
    # The interquartile range is reported through its two end points.
    lower_quartile, upper_quartile = (np.percentile(data, pct) for pct in (25, 75))
    return dict(
        min=np.min(data),
        mean=np.mean(data),
        max=np.max(data),
        median=np.median(data),
        q1=lower_quartile,
        q3=upper_quartile,
        total=np.sum(data),
    )

# Function to format timedelta in "xyz days (x.y years)"
def format_timedelta(timedelta):
    """Render a duration as ``"<days> days (<years> years)"``.

    Args:
        timedelta: Object exposing ``days`` and ``seconds`` attributes (such as
            ``datetime.timedelta``). Only those two attributes are read, so any
            sub-second part is ignored.

    Returns:
        str: The day count with thousands separators and no decimals, followed
        by the equivalent number of years (365 days per year) with one decimal.
    """
    # Convert the leftover seconds to a fraction of a day and add the whole days.
    fractional_days = timedelta.seconds / 3600 / 24
    total_days = timedelta.days + fractional_days
    return f"{total_days:,.0f} days ({total_days / 365:,.1f} years)"


def _format_stat_cell(attr, metric, attr_stats):
    """Render the table cell for ``metric`` of ``attr`` from its statistics dict."""
    if attr == "timeline_length_per_patient":
        # Timeline lengths are durations, so every metric goes through format_timedelta.
        value = format_timedelta(attr_stats[metric])
        if metric == "mean":
            q1 = format_timedelta(attr_stats["q1"])
            q3 = format_timedelta(attr_stats["q3"])
            return f"{value} [{q1} - {q3}]"
        return value

    if metric in ("total", "median"):
        return f"{attr_stats[metric]:,.0f}"

    value = f"{attr_stats[metric]:,.2f}"
    if metric == "mean":
        # The mean is reported together with the interquartile range as [Q1 - Q3].
        return f"{value} [{attr_stats['q1']:,.2f} - {attr_stats['q3']:,.2f}]"
    return value


def create_dataset_stats(dataset, output_filename="summary_stats.txt"):
    """Print a summary-statistics table for ``dataset`` and save it as LaTeX.

    ``get_summary_statistics`` is mapped over the dataset; the resulting
    columns are summarized with ``calculate_statistics``. The table has one
    column per split for a ``DatasetDict`` (plus a combined ``all_splits``
    column) or a single ``all`` column for a plain ``Dataset``.

    Args:
        dataset: A ``Dataset`` or ``DatasetDict`` to summarize.
        output_filename: File name inside ``OUTPUT_DIR`` that receives the
            LaTeX table. Defaults to ``"summary_stats.txt"``; pass a distinct
            name per dataset to avoid overwriting earlier results.

    Returns:
        None. The table is printed and written to disk.
    """
    # A DatasetDict reports its columns as {split: [names]}, while map() expects a
    # flat list of names. Assumes all splits share the same columns.
    column_names = dataset.column_names
    if isinstance(column_names, dict):
        column_names = list(
            dict.fromkeys(name for names in column_names.values() for name in names)
        )

    summary_ds = dataset.map(
        get_summary_statistics,
        batched=True,
        batch_size=25,
        remove_columns=column_names,
        num_proc=8,
    )

    # Attributes shown in the table, in display order.
    attributes = ["events_per_patient", "visits_per_patient", "timeline_length_per_patient"]

    # Statistics are always stored as stats[split][attr] so that table generation
    # does not need to special-case a single Dataset.
    stats = {}
    if isinstance(summary_ds, DatasetDict):
        all_data = {attr: [] for attr in attributes}
        for split in summary_ds:
            stats[split] = {}
            for attr in summary_ds[split].column_names:
                data = summary_ds[split][attr]
                stats[split][attr] = calculate_statistics(data)
                all_data.setdefault(attr, []).extend(data)

        # Statistics over the values of all splits combined.
        stats["all_splits"] = {
            attr: calculate_statistics(values) for attr, values in all_data.items()
        }

        # Keep the conventional train/val/test order where those splits exist and
        # append any differently named splits instead of failing on them.
        known_splits = [name for name in ("train", "val", "test") if name in summary_ds]
        other_splits = [name for name in summary_ds if name not in known_splits]
        splits = known_splits + other_splits + ["all_splits"]
    else:
        stats["all"] = {
            attr: calculate_statistics(summary_ds[attr]) for attr in summary_ds.column_names
        }
        splits = ["all"]

    # Header row followed by one group of rows per attribute.
    rows = [["Attribute"] + [split.capitalize() for split in splits]]
    for attr in attributes:
        # Group heading built from the first word of the attribute name.
        # NOTE(review): this yields "Number of Timeline" for the timeline-length rows.
        rows.append([f"Number of {attr.split('_')[0].capitalize()}"] + [""] * len(splits))

        for metric in ["min", "mean", "median", "max", "total"]:
            cells = [_format_stat_cell(attr, metric, stats[split][attr]) for split in splits]
            rows.append([metric.capitalize()] + cells)

    # Create and display the table using Texttable
    table = Texttable()
    table.set_cols_align(["l"] + ["c"] * len(splits))
    table.set_deco(Texttable.HEADER | Texttable.VLINES)
    table.add_rows(rows)
    print(table.draw())

    # Save the table as LaTeX, creating the output directory if it is missing.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    output_path = os.path.join(OUTPUT_DIR, output_filename)
    with open(output_path, 'w') as f:
        latex_txt = latextable.draw_latex(table, caption="Number of inputs across data splits")
        f.write(latex_txt)


In [None]:
# Load everything matching data/* under the raw MIMIC MEDS directory as parquet
# and print/save its summary table.
# NOTE: create_dataset_stats writes to summary_stats.txt in OUTPUT_DIR by default,
# so each call in this notebook overwrites the table saved by the previous one.
raw_mimic_meds_dataset = Dataset.from_parquet(os.path.join(RAW_MIMIC_MEDS_DATA_DIR, "data/*"))
create_dataset_stats(raw_mimic_meds_dataset)

In [None]:
# Load everything matching data/* under the raw MIMIC OMOP MEDS directory as
# parquet and print/save its summary table.
raw_mimic_omop_meds_dataset = Dataset.from_parquet(os.path.join(RAW_MIMIC_OMOP_MEDS_DATA_DIR, "data/*"))
create_dataset_stats(raw_mimic_omop_meds_dataset)

In [None]:
# Load the dataset saved under biased_cohort/cohort and summarize it. It serves
# as the "before" side of the before/after comparison further down.
# NOTE(review): absolute, machine-specific path.
guo_reproduction_ds_before = load_from_disk("/home/niclas/Dokumente/cluster_data/biased_cohort/cohort")
create_dataset_stats(guo_reproduction_ds_before)

In [None]:
# Load the dataset saved under biased_cohort/preprocessed and summarize it. It
# serves as the "after" side of the before/after comparison further down.
# NOTE(review): absolute, machine-specific path.
guo_reproduction_ds_after = load_from_disk("/home/niclas/Dokumente/cluster_data/biased_cohort/preprocessed")
create_dataset_stats(guo_reproduction_ds_after)

In [None]:
# Compute the summary columns for the "before" and "after" datasets with
# identical map() settings. The generator is consumed in order, so the
# "before" dataset is processed first.
before_filtering, after_filtering = (
    cohort.map(
        get_summary_statistics,
        batched=True,
        batch_size=25,
        remove_columns=cohort.column_names,
        num_proc=8,
    )
    for cohort in (guo_reproduction_ds_before, guo_reproduction_ds_after)
)

In [None]:
# Create the figure and a single axis; both density curves are drawn on ax1
fig, ax1 = plt.subplots(figsize=(10, 6))

# NumPy copies of the per-patient event counts.
# NOTE(review): after_filtering_data is not used anywhere in this cell.
before_filtering_data = np.array(before_filtering['events_per_patient'])
after_filtering_data = np.array(after_filtering['events_per_patient'])

# Define logarithmic bins spanning the min..max range of the "before" data.
# NOTE(review): `bins` is only referenced by the commented-out histogram calls
# below, so it currently has no effect on the figure.
bins = np.logspace(np.log10(before_filtering_data.min()), 
                   np.log10(before_filtering_data.max()), 
                   100)  # You can adjust the number of bins if necessary

# Disabled: histograms for 'events_per_patient' before and after filtering with logarithmic bins
#ax1.hist(before_filtering['events_per_patient'], bins=bins, alpha=0.4, label='Before Filtering (Count)', color='blue')
#ax1.hist(after_filtering['events_per_patient'], bins=bins, alpha=0.4, label='After Filtering (Count)', color='orange')

# Set log scale for the x-axis
ax1.set_xscale('log')

# Set the axis labels
ax1.set_xlabel('Measurements per Patient (Log Scale)')
ax1.set_ylabel('Density')

# No secondary y-axis is created; the KDE curves share ax1.

# Plot KDE for 'events_per_patient' of the "before" dataset
sns.kdeplot(before_filtering['events_per_patient'], ax=ax1, label="Before Code Translation", color="blue", linewidth=2, alpha=0.5, fill=True)

# Plot KDE for 'events_per_patient' of the "after" dataset
sns.kdeplot(after_filtering['events_per_patient'], ax=ax1, label="After Code Translation", color="orange", linewidth=2, alpha=0.5, fill=True)

# The y-axis label was already set above.

# Build the legend from the labelled artists on ax1
handles1, labels1 = ax1.get_legend_handles_labels()
ax1.legend(handles1, labels1, loc='upper right')

# Tight layout ensures that all plot elements fit within the figure area
plt.tight_layout()

# Save the figure under a relative path, i.e. in the current working directory
plt.savefig("kde_code_translation.png")

# Show the plot
plt.show()



In [None]:
# Load the dataset saved under adjusted_mapping_reduced_cohort/preprocessed and
# print/save its summary table.
# NOTE(review): absolute, machine-specific path.
extended_mapping = load_from_disk("/home/niclas/Dokumente/cluster_data/adjusted_mapping_reduced_cohort/preprocessed")
create_dataset_stats(extended_mapping)

In [None]:
# Load the dataset saved under correct_reduced_cohort/preprocessed and
# print/save its summary table.
# NOTE(review): absolute, machine-specific path.
standard_mapping = load_from_disk("/home/niclas/Dokumente/cluster_data/correct_reduced_cohort/preprocessed")
create_dataset_stats(standard_mapping)