In [9]:
from csv import DictReader
import matplotlib

matplotlib.use("pgf")
import matplotlib.pyplot as plt


# LaTeX settings
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "EB Garamond",
        "font.size": 11,
        "pgf.texsystem": "lualatex",
    }
)

In [None]:
with open("cifar100_overfitting_train.csv", "r") as f:
    reader = DictReader(f)
    steps_train = []
    train = []
    for row in reader:
        steps_train.append(int(row["Step"]))
        train.append(float(row["Value"]))
with open("cifar100_overfitting_val.csv", "r") as f:
    reader = DictReader(f)
    steps_val = []
    val = []
    for row in reader:
        steps_val.append(int(row["Step"]))
        val.append(float(row["Value"]))

# Plotting
plt.figure(figsize=(6, 4))
plt.plot(steps_train, train, label="Train")
plt.plot(steps_val, val, label="Validation")
plt.xlabel("Steps")
plt.ylabel("Accuracy")
plt.title("CIFAR-100 Initial Training Run")
plt.legend()
plt.savefig("../thesis/figures/cifar100_overfitting.pgf", bbox_inches="tight")

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [None]:
with open("training_results/caltech256_2domain_A_train.csv", "r") as f:
    reader = DictReader(f)
    steps_train_A = []
    train_A = []
    for row in reader:
        steps_train_A.append(int(row["Step"]))
        train_A.append(float(row["Value"]))

with open("training_results/caltech256_2domain_A_val.csv", "r") as f:
    reader = DictReader(f)
    steps_val_A = []
    val_A = []
    for row in reader:
        steps_val_A.append(int(row["Step"]))
        val_A.append(float(row["Value"]))

with open("training_results/caltech256_2domain_B_train.csv", "r") as f:
    reader = DictReader(f)
    steps_train_B = []
    train_B = []
    for row in reader:
        steps_train_B.append(int(row["Step"]))
        train_B.append(float(row["Value"]))

with open("training_results/caltech256_2domain_B_val.csv", "r") as f:
    reader = DictReader(f)
    steps_val_B = []
    val_B = []
    for row in reader:
        steps_val_B.append(int(row["Step"]))
        val_B.append(float(row["Value"]))

# Create 2x2 subplots for all training runs
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Dataset configurations
datasets = [
    {
        "name": "Caltech-256 2-Domain",
        "files": ["caltech256_2domain_A", "caltech256_2domain_B"],
        "title": "Caltech-256 2-Domain Variant 1",
        "save_name": "caltech256_2domain",
    },
    {
        "name": "Caltech-256 2-Domain Variant",
        "files": ["caltech256_2domain_variant_A", "caltech256_2domain_variant_B"],
        "title": "Caltech-256 2-Domain Variant 2",
        "save_name": "caltech256_2domain_variant",
    },
    {
        "name": "Caltech-256 3-Domain",
        "files": [
            "caltech256_3domain_A",
            "caltech256_3domain_B",
            "caltech256_3domain_C",
        ],
        "title": "Caltech-256 3-Domain Variant",
        "save_name": "caltech256_3domain",
    },
    {
        "name": "CIFAR-100 2-Domain",
        "files": ["cifar100_2domain_A", "cifar100_2domain_B"],
        "title": "CIFAR-100 2-Domain Variant",
        "save_name": "cifar100_2domain",
    },
]

# Plot each dataset in a subplot
for idx, dataset in enumerate(datasets):
    ax = axes[idx // 2, idx % 2]

    # Load and plot data for each domain in this dataset
    for domain_file in dataset["files"]:
        # Load training data
        with open(f"training_results/{domain_file}_train.csv", "r") as f:
            reader = DictReader(f)
            steps_train = []
            train = []
            for row in reader:
                steps_train.append(int(row["Step"]))
                train.append(float(row["Value"]))

        # Load validation data
        with open(f"training_results/{domain_file}_val.csv", "r") as f:
            reader = DictReader(f)
            steps_val = []
            val = []
            for row in reader:
                steps_val.append(int(row["Step"]))
                val.append(float(row["Value"]))

        # Extract domain label (A, B, or C)
        domain_label = domain_file.split("_")[-1]

        # Plot training and validation curves
        ax.plot(steps_train, train, label=f"Train Domain {domain_label}")
        ax.plot(steps_val, val, label=f"Validation Domain {domain_label}")

    ax.set_xlabel("Steps")
    ax.set_ylabel("Accuracy")
    ax.set_title(dataset["title"])
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("../thesis/figures/all_training_runs.pgf", bbox_inches="tight")
plt.show()

# Also save individual plots for backward compatibility
for dataset in datasets:
    plt.figure(figsize=(6, 4))

    for domain_file in dataset["files"]:
        # Load training data
        with open(f"training_results/{domain_file}_train.csv", "r") as f:
            reader = DictReader(f)
            steps_train = []
            train = []
            for row in reader:
                steps_train.append(int(row["Step"]))
                train.append(float(row["Value"]))

        # Load validation data
        with open(f"training_results/{domain_file}_val.csv", "r") as f:
            reader = DictReader(f)
            steps_val = []
            val = []
            for row in reader:
                steps_val.append(int(row["Step"]))
                val.append(float(row["Value"]))

        # Extract domain label (A, B, or C)
        domain_label = domain_file.split("_")[-1]

        # Plot training and validation curves
        plt.plot(steps_train, train, label=f"Train Domain {domain_label}")
        plt.plot(steps_val, val, label=f"Validation Domain {domain_label}")

    plt.xlabel("Steps")
    plt.ylabel("Accuracy")
    plt.title(dataset["title"])
    plt.legend()
    plt.savefig(f"../thesis/figures/{dataset['save_name']}.pgf", bbox_inches="tight")
    plt.close()

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [None]:
import pandas as pd

# Load the eval data
eval_caltech256_2domain_A = pd.read_csv(
    "training_results/caltech256_2domain_A_eval.csv"
)
eval_caltech256_2domain_B = pd.read_csv(
    "training_results/caltech256_2domain_B_eval.csv"
)
eval_caltech256_2domain_variant_A = pd.read_csv(
    "training_results/caltech256_2domain_variant_A_eval.csv"
)
eval_caltech256_2domain_variant_B = pd.read_csv(
    "training_results/caltech256_2domain_variant_B_eval.csv"
)
eval_caltech256_3domain_A = pd.read_csv(
    "training_results/caltech256_3domain_A_eval.csv"
)
eval_caltech256_3domain_B = pd.read_csv(
    "training_results/caltech256_3domain_B_eval.csv"
)
eval_caltech256_3domain_C = pd.read_csv(
    "training_results/caltech256_3domain_C_eval.csv"
)
eval_cifar100_2domain_A = pd.read_csv("training_results/cifar100_2domain_A_eval.csv")
eval_cifar100_2domain_B = pd.read_csv("training_results/cifar100_2domain_B_eval.csv")

# Merge the dataframes and add new 'Domain' and 'Dataset Variant' columns
eval_caltech256_2domain_A["Domain"] = "A"
eval_caltech256_2domain_A["Dataset Variant"] = "Caltech-256 2-Domain Variant 1"
eval_caltech256_2domain_B["Domain"] = "B"
eval_caltech256_2domain_B["Dataset Variant"] = "Caltech-256 2-Domain Variant 1"
eval_caltech256_2domain_variant_A["Domain"] = "A"
eval_caltech256_2domain_variant_A["Dataset Variant"] = "Caltech-256 2-Domain Variant 2"
eval_caltech256_2domain_variant_B["Domain"] = "B"
eval_caltech256_2domain_variant_B["Dataset Variant"] = "Caltech-256 2-Domain Variant 2"
eval_caltech256_3domain_A["Domain"] = "A"
eval_caltech256_3domain_A["Dataset Variant"] = "Caltech-256 3-Domain Variant"
eval_caltech256_3domain_B["Domain"] = "B"
eval_caltech256_3domain_B["Dataset Variant"] = "Caltech-256 3-Domain Variant"
eval_caltech256_3domain_C["Domain"] = "C"
eval_caltech256_3domain_C["Dataset Variant"] = "Caltech-256 3-Domain Variant"
eval_cifar100_2domain_A["Domain"] = "A"
eval_cifar100_2domain_A["Dataset Variant"] = "CIFAR-100 2-Domain Variant"
eval_cifar100_2domain_B["Domain"] = "B"
eval_cifar100_2domain_B["Dataset Variant"] = "CIFAR-100 2-Domain Variant"

# Concatenate all dataframes
eval_data = pd.concat(
    [
        eval_caltech256_2domain_A,
        eval_caltech256_2domain_B,
        eval_caltech256_2domain_variant_A,
        eval_caltech256_2domain_variant_B,
        eval_caltech256_3domain_A,
        eval_caltech256_3domain_B,
        eval_caltech256_3domain_C,
        eval_cifar100_2domain_A,
        eval_cifar100_2domain_B,
    ],
    ignore_index=True,
)

# Delete 'Wall Time' column
eval_data.drop(columns=["Wall time"], inplace=True)

# Reorder: 'Dataset Variant', 'Domain', 'Step', 'Value'
eval_data = eval_data[["Dataset Variant", "Domain", "Step", "Value"]]

# Make 'Value' a float
eval_data["Value"] = eval_data["Value"].astype(float)

# Rename columns
eval_data.rename(columns={"Value": "Accuracy", "Step": "Steps"}, inplace=True)

latex_table = (
    eval_data.style.hide(axis="index")
    .format(precision=2)
    .to_latex(
        caption="Evaluation results on test sets. Models were checkpointed after every epoch and evaluated on the validation loss. The model with the lowest validation loss was selected for evaluation on the test set.",
        label="tab:evaluation_results",
        column_format="cccc",
        position="ht",
        position_float="centering",
        hrules=True,
    )
)

with open("../thesis/figures/evaluation_results.tex", "w") as f:
    f.write(latex_table)

In [None]:
import pandas as pd
import numpy as np
from torchvision.datasets import Caltech256

import sys
import os

# Get the parent directory and add it to sys.path
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(parent_dir)

from library.taxonomy_constructors import (
    SyntheticTaxonomy,
    CrossPredictionsTaxonomy,
)
from library.datasets import CIFAR100Mapped

In [None]:
cifar100_dataset = CIFAR100Mapped(root="../datasets/cifar100", download=False)
cifar100_labels = cifar100_dataset.classes
cifar100_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=len(cifar100_labels),
    num_domains=2,
    domain_class_count_mean=50,
    domain_class_count_variance=5,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=cifar100_labels,
    relationship_type="true",
)

In [None]:
cifar100_df_A = pd.read_csv("data/cifar100_2domain_A_predictions.csv")
cifar100_df_B = pd.read_csv("data/cifar100_2domain_B_predictions.csv")
cifar100_cross_domain_predictions = [
    (0, 1, np.array(cifar100_df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(cifar100_df_A["predictions_B_on_A"], dtype=np.intp)),
]
cifar100_domain_targets = [
    (0, np.array(cifar100_df_A["domain_A"], dtype=np.intp)),
    (1, np.array(cifar100_df_B["domain_B"], dtype=np.intp)),
]

In [None]:
# Load Caltech256 dataset information
caltech256_labels = Caltech256(root="datasets/caltech256", download=False).categories
caltech256_2d_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=257,
    num_domains=2,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)

In [None]:
caltech256_2d_df_A = pd.read_csv("data/caltech256_2domain_A_predictions.csv")
caltech256_2d_df_B = pd.read_csv("data/caltech256_2domain_B_predictions.csv")
caltech256_2d_cross_domain_predictions = [
    (0, 1, np.array(caltech256_2d_df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(caltech256_2d_df_A["predictions_B_on_A"], dtype=np.intp)),
]
caltech256_2d_domain_targets = [
    (0, np.array(caltech256_2d_df_A["domain_A"], dtype=np.intp)),
    (1, np.array(caltech256_2d_df_B["domain_B"], dtype=np.intp)),
]

In [None]:
caltech256_2d_variant_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=257,
    num_domains=2,
    domain_class_count_mean=200,
    domain_class_count_variance=10,
    concept_cluster_size_mean=2,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)
caltech256_2d_variant_df_A = pd.read_csv(
    "data/caltech256_2domain_variant_A_predictions.csv"
)
caltech256_2d_variant_df_B = pd.read_csv(
    "data/caltech256_2domain_variant_B_predictions.csv"
)
caltech256_2d_variant_cross_domain_predictions = [
    (0, 1, np.array(caltech256_2d_variant_df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(caltech256_2d_variant_df_A["predictions_B_on_A"], dtype=np.intp)),
]
caltech256_2d_variant_domain_targets = [
    (0, np.array(caltech256_2d_variant_df_A["domain_A"], dtype=np.intp)),
    (1, np.array(caltech256_2d_variant_df_B["domain_B"], dtype=np.intp)),
]

In [None]:
caltech256_3d_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=257,
    num_domains=3,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=5,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)
caltech256_3d_df_A = pd.read_csv("data/caltech256_3domain_A_predictions.csv")
caltech256_3d_df_B = pd.read_csv("data/caltech256_3domain_B_predictions.csv")
caltech256_3d_df_C = pd.read_csv("data/caltech256_3domain_C_predictions.csv")

In [None]:
caltech256_3d_cross_domain_predictions = [
    (0, 1, np.array(caltech256_3d_df_B["predictions_A_on_B"], dtype=np.intp)),
    (0, 2, np.array(caltech256_3d_df_C["predictions_A_on_C"], dtype=np.intp)),
    (1, 0, np.array(caltech256_3d_df_A["predictions_B_on_A"], dtype=np.intp)),
    (1, 2, np.array(caltech256_3d_df_C["predictions_B_on_C"], dtype=np.intp)),
    (2, 0, np.array(caltech256_3d_df_A["predictions_C_on_A"], dtype=np.intp)),
    (2, 1, np.array(caltech256_3d_df_B["predictions_C_on_B"], dtype=np.intp)),
]
caltech256_3d_domain_targets = [
    (0, np.array(caltech256_3d_df_A["domain_A"], dtype=np.intp)),
    (1, np.array(caltech256_3d_df_B["domain_B"], dtype=np.intp)),
    (2, np.array(caltech256_3d_df_C["domain_C"], dtype=np.intp)),
]

In [None]:
def evaluate_taxonomy(
    constructed_taxonomy, ground_truth_taxonomy, method_name, dataset_name, **params
):
    # Calculate metrics
    edr = constructed_taxonomy.edge_difference_ratio(ground_truth_taxonomy)
    precision, recall, f1 = constructed_taxonomy.precision_recall_f1(
        ground_truth_taxonomy
    )

    # Count relationships for analysis
    num_relationships = len(constructed_taxonomy.graph.edges())
    num_nodes = len(constructed_taxonomy.graph.nodes())

    results = {
        "method": method_name,
        "dataset": dataset_name,
        "edr": edr,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "num_relationships": num_relationships,
        "num_nodes": num_nodes,
        **params,  # Include method-specific parameters
    }

    return results

In [None]:
mcfp_results = []

cifar100_mcfp_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=cifar100_cross_domain_predictions,
    domain_targets=cifar100_domain_targets,
    domain_labels=cifar100_synthetic_taxonomy.domain_labels,
    relationship_type="mcfp",
)
cifar100_mcfp_result = evaluate_taxonomy(
    cifar100_mcfp_taxonomy, cifar100_synthetic_taxonomy, "mcfp", "cifar100_2domain"
)
mcfp_results.append(cifar100_mcfp_result)

caltech256_2d_mcfp_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=caltech256_2d_cross_domain_predictions,
    domain_targets=caltech256_2d_domain_targets,
    domain_labels=caltech256_2d_synthetic_taxonomy.domain_labels,
    relationship_type="mcfp",
)
caltech256_2d_mcfp_result = evaluate_taxonomy(
    caltech256_2d_mcfp_taxonomy,
    caltech256_2d_synthetic_taxonomy,
    "mcfp",
    "caltech256_2domain",
)
mcfp_results.append(caltech256_2d_mcfp_result)

caltech256_2d_variant_mcfp_taxonomy = (
    CrossPredictionsTaxonomy.from_cross_domain_predictions(
        cross_domain_predictions=caltech256_2d_variant_cross_domain_predictions,
        domain_targets=caltech256_2d_variant_domain_targets,
        domain_labels=caltech256_2d_variant_synthetic_taxonomy.domain_labels,
        relationship_type="mcfp",
    )
)
caltech256_2d_variant_mcfp_result = evaluate_taxonomy(
    caltech256_2d_variant_mcfp_taxonomy,
    caltech256_2d_variant_synthetic_taxonomy,
    "mcfp",
    "caltech256_2domain_variant",
)
mcfp_results.append(caltech256_2d_variant_mcfp_result)

caltech256_3d_mcfp_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=caltech256_3d_cross_domain_predictions,
    domain_targets=caltech256_3d_domain_targets,
    domain_labels=caltech256_3d_synthetic_taxonomy.domain_labels,
    relationship_type="mcfp",
)
caltech256_3d_mcfp_result = evaluate_taxonomy(
    caltech256_3d_mcfp_taxonomy,
    caltech256_3d_synthetic_taxonomy,
    "mcfp",
    "caltech256_3domain",
)
mcfp_results.append(caltech256_3d_mcfp_result)

In [None]:
hypothesis_results = {
    "cifar100_2domain": {},
    "caltech256_2domain": {},
    "caltech256_2domain_variant": {},
    "caltech256_3domain": {},
}

for upper_bound in range(1, 11):
    cifar100_hypothesis_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=cifar100_cross_domain_predictions,
            domain_targets=cifar100_domain_targets,
            domain_labels=cifar100_synthetic_taxonomy.domain_labels,
            relationship_type="hypothesis",
            upper_bound=upper_bound,
        )
    )
    cifar100_result = evaluate_taxonomy(
        cifar100_hypothesis_taxonomy,
        cifar100_synthetic_taxonomy,
        "hypothesis",
        "cifar100_2domain",
        upper_bound=upper_bound,
    )
    hypothesis_results["cifar100_2domain"][upper_bound] = cifar100_result

    caltech256_2d_hypothesis_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_cross_domain_predictions,
            domain_targets=caltech256_2d_domain_targets,
            domain_labels=caltech256_2d_synthetic_taxonomy.domain_labels,
            relationship_type="hypothesis",
            upper_bound=upper_bound,
        )
    )
    caltech256_2d_result = evaluate_taxonomy(
        caltech256_2d_hypothesis_taxonomy,
        caltech256_2d_synthetic_taxonomy,
        "hypothesis",
        "caltech256_2domain",
        upper_bound=upper_bound,
    )
    hypothesis_results["caltech256_2domain"][upper_bound] = caltech256_2d_result

    caltech256_2d_variant_hypothesis_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_variant_cross_domain_predictions,
            domain_targets=caltech256_2d_variant_domain_targets,
            domain_labels=caltech256_2d_variant_synthetic_taxonomy.domain_labels,
            relationship_type="hypothesis",
            upper_bound=upper_bound,
        )
    )
    caltech256_2d_variant_result = evaluate_taxonomy(
        caltech256_2d_variant_hypothesis_taxonomy,
        caltech256_2d_variant_synthetic_taxonomy,
        "hypothesis",
        "caltech256_2domain_variant",
        upper_bound=upper_bound,
    )
    hypothesis_results["caltech256_2domain_variant"][
        upper_bound
    ] = caltech256_2d_variant_result

    caltech256_3d_hypothesis_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_3d_cross_domain_predictions,
            domain_targets=caltech256_3d_domain_targets,
            domain_labels=caltech256_3d_synthetic_taxonomy.domain_labels,
            relationship_type="hypothesis",
            upper_bound=upper_bound,
        )
    )
    caltech256_3d_result = evaluate_taxonomy(
        caltech256_3d_hypothesis_taxonomy,
        caltech256_3d_synthetic_taxonomy,
        "hypothesis",
        "caltech256_3domain",
        upper_bound=upper_bound,
    )
    hypothesis_results["caltech256_3domain"][upper_bound] = caltech256_3d_result

In [None]:
density_threshold_results = {
    "cifar100_2domain": {},
    "caltech256_2domain": {},
    "caltech256_2domain_variant": {},
    "caltech256_3domain": {},
}

for threshold in np.arange(0, 1.05, 0.05):
    threshold = round(threshold, 2).astype(float)

    cifar100_density_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
        cross_domain_predictions=cifar100_cross_domain_predictions,
        domain_targets=cifar100_domain_targets,
        domain_labels=cifar100_synthetic_taxonomy.domain_labels,
        relationship_type="density_threshold",
        threshold=threshold,
    )
    cifar100_result = evaluate_taxonomy(
        cifar100_density_taxonomy,
        cifar100_synthetic_taxonomy,
        "density_threshold",
        "cifar100_2domain",
        threshold=threshold,
    )
    density_threshold_results["cifar100_2domain"][threshold] = cifar100_result

    caltech256_2d_density_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_cross_domain_predictions,
            domain_targets=caltech256_2d_domain_targets,
            domain_labels=caltech256_2d_synthetic_taxonomy.domain_labels,
            relationship_type="density_threshold",
            threshold=threshold,
        )
    )
    caltech256_2d_result = evaluate_taxonomy(
        caltech256_2d_density_taxonomy,
        caltech256_2d_synthetic_taxonomy,
        "density_threshold",
        "caltech256_2domain",
        threshold=threshold,
    )
    density_threshold_results["caltech256_2domain"][threshold] = caltech256_2d_result

    caltech256_2d_variant_density_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_variant_cross_domain_predictions,
            domain_targets=caltech256_2d_variant_domain_targets,
            domain_labels=caltech256_2d_variant_synthetic_taxonomy.domain_labels,
            relationship_type="density_threshold",
            threshold=threshold,
        )
    )
    caltech256_2d_variant_result = evaluate_taxonomy(
        caltech256_2d_variant_density_taxonomy,
        caltech256_2d_variant_synthetic_taxonomy,
        "density_threshold",
        "caltech256_2domain_variant",
        threshold=threshold,
    )
    density_threshold_results["caltech256_2domain_variant"][
        threshold
    ] = caltech256_2d_variant_result

    caltech256_3d_density_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_3d_cross_domain_predictions,
            domain_targets=caltech256_3d_domain_targets,
            domain_labels=caltech256_3d_synthetic_taxonomy.domain_labels,
            relationship_type="density_threshold",
            threshold=threshold,
        )
    )
    caltech256_3d_result = evaluate_taxonomy(
        caltech256_3d_density_taxonomy,
        caltech256_3d_synthetic_taxonomy,
        "density_threshold",
        "caltech256_3domain",
        threshold=threshold,
    )
    density_threshold_results["caltech256_3domain"][threshold] = caltech256_3d_result

In [None]:
simple_threshold_results = {
    "cifar100_2domain": {},
    "caltech256_2domain": {},
    "caltech256_2domain_variant": {},
    "caltech256_3domain": {},
}

for threshold in np.arange(0, 1.05, 0.05):
    threshold = round(threshold, 2).astype(float)

    cifar100_simple_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
        cross_domain_predictions=cifar100_cross_domain_predictions,
        domain_targets=cifar100_domain_targets,
        domain_labels=cifar100_synthetic_taxonomy.domain_labels,
        relationship_type="threshold",
        threshold=threshold,
    )
    cifar100_result = evaluate_taxonomy(
        cifar100_simple_taxonomy,
        cifar100_synthetic_taxonomy,
        "threshold",
        "cifar100_2domain",
        threshold=threshold,
    )
    simple_threshold_results["cifar100_2domain"][threshold] = cifar100_result

    caltech256_2d_simple_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_cross_domain_predictions,
            domain_targets=caltech256_2d_domain_targets,
            domain_labels=caltech256_2d_synthetic_taxonomy.domain_labels,
            relationship_type="threshold",
            threshold=threshold,
        )
    )
    caltech256_2d_result = evaluate_taxonomy(
        caltech256_2d_simple_taxonomy,
        caltech256_2d_synthetic_taxonomy,
        "threshold",
        "caltech256_2domain",
        threshold=threshold,
    )
    simple_threshold_results["caltech256_2domain"][threshold] = caltech256_2d_result

    caltech256_2d_variant_simple_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_variant_cross_domain_predictions,
            domain_targets=caltech256_2d_variant_domain_targets,
            domain_labels=caltech256_2d_variant_synthetic_taxonomy.domain_labels,
            relationship_type="threshold",
            threshold=threshold,
        )
    )
    caltech256_2d_variant_result = evaluate_taxonomy(
        caltech256_2d_variant_simple_taxonomy,
        caltech256_2d_variant_synthetic_taxonomy,
        "threshold",
        "caltech256_2domain_variant",
        threshold=threshold,
    )
    simple_threshold_results["caltech256_2domain_variant"][
        threshold
    ] = caltech256_2d_variant_result

    caltech256_3d_simple_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_3d_cross_domain_predictions,
            domain_targets=caltech256_3d_domain_targets,
            domain_labels=caltech256_3d_synthetic_taxonomy.domain_labels,
            relationship_type="threshold",
            threshold=threshold,
        )
    )
    caltech256_3d_result = evaluate_taxonomy(
        caltech256_3d_simple_taxonomy,
        caltech256_3d_synthetic_taxonomy,
        "threshold",
        "caltech256_3domain",
        threshold=threshold,
    )
    simple_threshold_results["caltech256_3domain"][threshold] = caltech256_3d_result

In [None]:
import matplotlib

matplotlib.use("pgf")
import matplotlib.pyplot as plt

# LaTeX settings
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "EB Garamond",
        "font.size": 11,
        "pgf.texsystem": "lualatex",
    }
)

In [None]:
# Plot 1: Hypothesis method (upper_bound parameter)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
datasets = [
    "caltech256_2domain",
    "caltech256_2domain_variant",
    "caltech256_3domain",
    "cifar100_2domain",
]
dataset_titles = [
    "Caltech-256 2-Domain Variant 1",
    "Caltech-256 2-Domain Variant 2",
    "Caltech-256 3-Domain Variant",
    "CIFAR-100 2-Domain Variant",
]

for i, (dataset, title) in enumerate(zip(datasets, dataset_titles)):
    ax = axes[i // 2, i % 2]

    # Extract data for this dataset
    precisions = []
    recalls = []
    upper_bounds = []

    for upper_bound, result in hypothesis_results[dataset].items():
        precisions.append(result["precision"])
        recalls.append(result["recall"])
        upper_bounds.append(upper_bound)

    # Create scatter plot with colormap
    scatter = ax.scatter(recalls, precisions, c=upper_bounds, cmap="viridis", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Upper Bound")

plt.tight_layout()
plt.savefig(
    "../thesis/figures/hypothesis_method_precision_recall.pgf", bbox_inches="tight"
)
plt.show()

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [None]:
# Plot 2: Naive threshold method (threshold parameter)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
datasets = [
    "caltech256_2domain",
    "caltech256_2domain_variant",
    "caltech256_3domain",
    "cifar100_2domain",
]
dataset_titles = [
    "Caltech-256 2-Domain Variant 1",
    "Caltech-256 2-Domain Variant 2",
    "Caltech-256 3-Domain Variant",
    "CIFAR-100 2-Domain Variant",
]

for i, (dataset, title) in enumerate(zip(datasets, dataset_titles)):
    ax = axes[i // 2, i % 2]

    # Extract data for this dataset
    precisions = []
    recalls = []
    thresholds = []

    for threshold, result in simple_threshold_results[dataset].items():
        precisions.append(result["precision"])
        recalls.append(result["recall"])
        thresholds.append(threshold)

    # Create scatter plot with colormap
    scatter = ax.scatter(recalls, precisions, c=thresholds, cmap="plasma", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Threshold")

plt.tight_layout()
plt.savefig(
    "../thesis/figures/naive_threshold_method_precision_recall.pgf",
    bbox_inches="tight",
)
plt.show()

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [None]:
# Plot 3: Density threshold method (threshold parameter)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
datasets = [
    "caltech256_2domain",
    "caltech256_2domain_variant",
    "caltech256_3domain",
    "cifar100_2domain",
]
dataset_titles = [
    "Caltech-256 2-Domain Variant 1",
    "Caltech-256 2-Domain Variant 2",
    "Caltech-256 3-Domain Variant",
    "CIFAR-100 2-Domain Variant",
]

for i, (dataset, title) in enumerate(zip(datasets, dataset_titles)):
    ax = axes[i // 2, i % 2]

    # Extract data for this dataset
    precisions = []
    recalls = []
    thresholds = []

    for threshold, result in density_threshold_results[dataset].items():
        precisions.append(result["precision"])
        recalls.append(result["recall"])
        thresholds.append(threshold)

    # Create scatter plot with colormap
    scatter = ax.scatter(recalls, precisions, c=thresholds, cmap="inferno", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Threshold")

plt.tight_layout()
plt.savefig(
    "../thesis/figures/density_threshold_method_precision_recall.pgf",
    bbox_inches="tight",
)
plt.show()

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [None]:
# Create LaTeX table with best EDR results
import pandas as pd


# Helper function to find best EDR result for each dataset and method
def find_best_edr_result(results_dict, dataset):
    if dataset not in results_dict:
        return None

    best_result = None
    best_edr = float("inf")

    for param_value, result in results_dict[dataset].items():
        if result["edr"] < best_edr:
            best_edr = result["edr"]
            best_result = result

    return best_result


# Prepare data for the table with methods as rows
table_data = []

# Method definitions
methods = [
    ("MCFP", mcfp_results, None, "N/A"),
    ("Naive Thresholding", simple_threshold_results, "threshold", "threshold"),
    (
        "Density Thresholding",
        density_threshold_results,
        "threshold",
        "density_threshold",
    ),
    ("Relationship Hypothesis", hypothesis_results, "upper_bound", "upper_bound"),
]

for dataset, dataset_name in zip(datasets, dataset_titles):
    for method_name, results_source, param_key, param_name in methods:
        if method_name == "MCFP":
            # Handle MCFP results (stored differently)
            mcfp_result = next(
                (r for r in results_source if r["dataset"] == dataset), None
            )
            if mcfp_result:
                table_data.append(
                    {
                        "Dataset Variant": dataset_name,
                        "Method": method_name,
                        "EDR": f"{mcfp_result['edr']:.3f}",
                        "F1-score": f"{mcfp_result['f1']:.3f}",
                        "Parameter": "N/A",
                    }
                )
        else:
            # Handle parametric methods
            best_result = find_best_edr_result(results_source, dataset)
            if best_result:
                param_value = best_result[param_key] if param_key else "N/A"
                param_display = (
                    f"{param_value:.2f}"
                    if isinstance(param_value, float)
                    else str(param_value)
                )

                table_data.append(
                    {
                        "Dataset Variant": dataset_name,
                        "Method": method_name,
                        "EDR": f"{best_result['edr']:.3f}",
                        "F1-score": f"{best_result['f1']:.3f}",
                        "Parameter": param_display,
                    }
                )

# Create DataFrame with simple column structure
df = pd.DataFrame(table_data)


# Function to identify best values for each metric within each dataset variant
def get_best_metric_indices_per_dataset(df):
    """Find the row indices with the best values for each metric within each dataset variant"""
    best_edr_indices = []
    best_f1_indices = []

    # Group by dataset variant and find the best values for each metric
    for dataset_variant in df["Dataset Variant"].unique():
        dataset_rows = df[df["Dataset Variant"] == dataset_variant]

        # Best EDR (lowest value)
        edr_values = dataset_rows["EDR"].astype(float)
        best_edr_idx = edr_values.idxmin()
        best_edr_indices.append(best_edr_idx)

        # Best F1-score (highest value)
        f1_values = dataset_rows["F1-score"].astype(float)
        best_f1_idx = f1_values.idxmax()
        best_f1_indices.append(best_f1_idx)

    return best_edr_indices, best_f1_indices


# Function to identify dataset boundary rows
def get_dataset_boundaries(df):
    """Identify row indices where dataset variant changes"""
    boundaries = []
    current_dataset = None

    for i, row in df.iterrows():
        dataset_name = row["Dataset Variant"]
        if current_dataset is not None and dataset_name != current_dataset:
            boundaries.append(i)
        current_dataset = dataset_name

    return boundaries


# Get best metric indices and dataset boundaries
best_edr_indices, best_f1_indices = get_best_metric_indices_per_dataset(df)
dataset_boundaries = get_dataset_boundaries(df)


# Function to apply bold formatting to best values independently for each metric
def apply_bold_formatting(df, best_edr_indices, best_f1_indices):
    """Apply bold formatting to best values for each metric independently"""
    # Create a copy of the dataframe with formatting applied
    formatted_df = df.copy()

    # Bold the best EDR values
    for idx in best_edr_indices:
        formatted_df.loc[idx, "EDR"] = f"\\textbf{{{formatted_df.loc[idx, 'EDR']}}}"

    # Bold the best F1-score values
    for idx in best_f1_indices:
        formatted_df.loc[idx, "F1-score"] = (
            f"\\textbf{{{formatted_df.loc[idx, 'F1-score']}}}"
        )

    return formatted_df


# Apply bold formatting to the DataFrame
df_formatted = apply_bold_formatting(df, best_edr_indices, best_f1_indices)

# Create styler with formatted DataFrame
styler = df_formatted.style.hide(axis="index")

# Generate LaTeX table without clines first
latex_table = styler.to_latex(
    caption="Best EDR results for relationship discovery methods. For each dataset variant and method, the parameter values that yielded the lowest Edge Difference Ratio (EDR) are shown along with the corresponding F1-score.",
    label="tab:relationship_methods_best_edr",
    column_format="llccc",  # Left align dataset and method, center align metrics
    position="ht",
    position_float="centering",
    hrules=True,
)


# Post-process LaTeX to add hlines at dataset boundaries
def insert_hlines_at_boundaries(latex_str, boundaries):
    """Insert \\hline commands at dataset boundaries in LaTeX table"""
    lines = latex_str.split("\n")
    new_lines = []
    data_row_count = 0
    in_table_body = False

    for line in lines:
        # Track when we enter the table body
        if "\\midrule" in line:
            in_table_body = True
            new_lines.append(line)
            continue

        # Check if this is a data row (contains ampersands and ends with \\)
        if in_table_body and " & " in line and line.strip().endswith(" \\\\"):
            # If this row index is a boundary, add hline before it
            if data_row_count in boundaries:
                new_lines.append("\\hline")

            data_row_count += 1

        new_lines.append(line)

    return "\n".join(new_lines)


# Apply the hline insertion
latex_table = insert_hlines_at_boundaries(latex_table, dataset_boundaries)

# Save to file
with open("../thesis/figures/relationship_methods_results.tex", "w") as f:
    f.write(latex_table)

In [None]:
# Create table with globally optimal parameters and averaged performance metrics
import pandas as pd


# Function to find globally best parameter for each method across all datasets
def find_globally_best_parameter(results_dict, param_key, metric="edr"):
    """Find the parameter value that gives the best average performance across all datasets"""
    all_param_values = set()

    # Collect all possible parameter values
    for dataset in results_dict:
        all_param_values.update(results_dict[dataset].keys())

    best_param = None
    best_avg_metric = float("inf") if metric == "edr" else 0

    # For each parameter value, calculate average metric across all datasets
    for param_value in all_param_values:
        metric_values = []

        for dataset in results_dict:
            if param_value in results_dict[dataset]:
                metric_values.append(results_dict[dataset][param_value][metric])

        if metric_values:  # Only consider if we have data for this parameter
            avg_metric = sum(metric_values) / len(metric_values)

            # For EDR, lower is better; for F1, higher is better
            if metric == "edr" and avg_metric < best_avg_metric:
                best_avg_metric = avg_metric
                best_param = param_value
            elif metric == "f1" and avg_metric > best_avg_metric:
                best_avg_metric = avg_metric
                best_param = param_value

    return best_param


# Function to calculate average metrics for a given parameter across all datasets
def calculate_average_metrics(results_dict, param_value):
    """Calculate average EDR, F1, precision, and recall for a given parameter across all datasets"""
    metrics = {"edr": [], "f1": [], "precision": [], "recall": []}

    for dataset in results_dict:
        if param_value in results_dict[dataset]:
            result = results_dict[dataset][param_value]
            metrics["edr"].append(result["edr"])
            metrics["f1"].append(result["f1"])
            metrics["precision"].append(result["precision"])
            metrics["recall"].append(result["recall"])

    # Calculate averages
    avg_metrics = {}
    for metric_name, values in metrics.items():
        if values:
            avg_metrics[metric_name] = sum(values) / len(values)
        else:
            avg_metrics[metric_name] = None

    return avg_metrics


# Find globally optimal parameters for each method

# For simple threshold method
best_simple_threshold = find_globally_best_parameter(
    simple_threshold_results, "threshold", "edr"
)
simple_avg_metrics = calculate_average_metrics(
    simple_threshold_results, best_simple_threshold
)

# For density threshold method
best_density_threshold = find_globally_best_parameter(
    density_threshold_results, "threshold", "edr"
)
density_avg_metrics = calculate_average_metrics(
    density_threshold_results, best_density_threshold
)

# For hypothesis method
best_upper_bound = find_globally_best_parameter(
    hypothesis_results, "upper_bound", "edr"
)
hypothesis_avg_metrics = calculate_average_metrics(hypothesis_results, best_upper_bound)

# For MCFP method (no parameters to optimize)
mcfp_metrics = {"edr": [], "f1": [], "precision": [], "recall": []}
for result in mcfp_results:
    mcfp_metrics["edr"].append(result["edr"])
    mcfp_metrics["f1"].append(result["f1"])
    mcfp_metrics["precision"].append(result["precision"])
    mcfp_metrics["recall"].append(result["recall"])

mcfp_avg_metrics = {}
for metric_name, values in mcfp_metrics.items():
    mcfp_avg_metrics[metric_name] = sum(values) / len(values)

# Create table data for globally optimal results
global_table_data = [
    {
        "Method": "MCFP",
        "Parameter": "N/A",
        "EDR": f"{mcfp_avg_metrics['edr']:.3f}",
        "Precision": f"{mcfp_avg_metrics['precision']:.3f}",
        "Recall": f"{mcfp_avg_metrics['recall']:.3f}",
        "F1-score": f"{mcfp_avg_metrics['f1']:.3f}",
    },
    {
        "Method": "Naive Thresholding",
        "Parameter": f"{best_simple_threshold:.2f}",
        "EDR": f"{simple_avg_metrics['edr']:.3f}",
        "Precision": f"{simple_avg_metrics['precision']:.3f}",
        "Recall": f"{simple_avg_metrics['recall']:.3f}",
        "F1-score": f"{simple_avg_metrics['f1']:.3f}",
    },
    {
        "Method": "Density Thresholding",
        "Parameter": f"{best_density_threshold:.2f}",
        "EDR": f"{density_avg_metrics['edr']:.3f}",
        "Precision": f"{density_avg_metrics['precision']:.3f}",
        "Recall": f"{density_avg_metrics['recall']:.3f}",
        "F1-score": f"{density_avg_metrics['f1']:.3f}",
    },
    {
        "Method": "Relationship Hypothesis",
        "Parameter": f"{best_upper_bound}",
        "EDR": f"{hypothesis_avg_metrics['edr']:.3f}",
        "Precision": f"{hypothesis_avg_metrics['precision']:.3f}",
        "Recall": f"{hypothesis_avg_metrics['recall']:.3f}",
        "F1-score": f"{hypothesis_avg_metrics['f1']:.3f}",
    },
]

# Create DataFrame
global_df = pd.DataFrame(global_table_data)


# Function to find best values for each metric column independently
def apply_bold_to_best_values(df):
    """Apply bold formatting to the best value in each metric column independently"""
    formatted_df = df.copy()

    # Find best EDR (lowest value)
    edr_values = [float(row["EDR"]) for row in global_table_data]
    best_edr_value = min(edr_values)

    # Find best Precision (highest value)
    precision_values = [float(row["Precision"]) for row in global_table_data]
    best_precision_value = max(precision_values)

    # Find best Recall (highest value)
    recall_values = [float(row["Recall"]) for row in global_table_data]
    best_recall_value = max(recall_values)

    # Find best F1-score (highest value)
    f1_values = [float(row["F1-score"]) for row in global_table_data]
    best_f1_value = max(f1_values)

    # Apply bold formatting to best values
    for idx, row in formatted_df.iterrows():
        if float(row["EDR"]) == best_edr_value:
            formatted_df.loc[idx, "EDR"] = f"\\textbf{{{row['EDR']}}}"

        if float(row["Precision"]) == best_precision_value:
            formatted_df.loc[idx, "Precision"] = f"\\textbf{{{row['Precision']}}}"

        if float(row["Recall"]) == best_recall_value:
            formatted_df.loc[idx, "Recall"] = f"\\textbf{{{row['Recall']}}}"

        if float(row["F1-score"]) == best_f1_value:
            formatted_df.loc[idx, "F1-score"] = f"\\textbf{{{row['F1-score']}}}"

    return formatted_df


# Apply bold formatting to best values in each column
global_df_formatted = apply_bold_to_best_values(global_df)

# Create styler and generate LaTeX table
global_styler = global_df_formatted.style.hide(axis="index")

global_latex_table = global_styler.to_latex(
    caption="Average performance metrics for relationship discovery methods with globally optimal parameters. Each method uses the parameter value that minimizes the average EDR across all dataset variants. Performance metrics are then averaged across all dataset variants using these optimal parameters.",
    label="tab:relationship_methods_global_optimal",
    column_format="lccccc",  # Left align method, center align others
    position="ht",
    position_float="centering",
    hrules=True,
)

# Save to file
with open("../thesis/figures/relationship_methods_global_optimal.tex", "w") as f:
    f.write(global_latex_table)