In [1]:
import pandas as pd
import numpy as np
from torchvision.datasets import Caltech256

import sys
import os

# Make the project root (parent of the notebook's working directory) importable
# so the `library` package below resolves. The membership guard keeps the cell
# idempotent: re-running it no longer appends duplicate sys.path entries.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from library.taxonomy_constructors import (
    SyntheticTaxonomy,
    CrossPredictionsTaxonomy,
)
from library.datasets import CIFAR100Mapped

In [2]:
# Ground-truth taxonomy for the 2-domain CIFAR-100 experiment. The CIFAR-100
# class names are passed as the atomic-concept labels, and the concept count is
# derived from that same list so the two always agree.
# NOTE(review): the *_mean / *_variance arguments suggest create_synthetic_taxonomy
# samples at random, and no seed is passed here -- confirm it reproduces the exact
# taxonomy that ../data/cifar100_2domain_*_predictions.csv was generated against.
cifar100_dataset = CIFAR100Mapped(root="../datasets/cifar100", download=False)
cifar100_labels = cifar100_dataset.classes
cifar100_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=len(cifar100_labels),
    num_domains=2,
    domain_class_count_mean=50,
    domain_class_count_variance=5,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=cifar100_labels,
    relationship_type="true",
)

In [3]:
cifar100_df_A = pd.read_csv("../data/cifar100_2domain_A_predictions.csv")
cifar100_df_B = pd.read_csv("../data/cifar100_2domain_B_predictions.csv")

# (source domain, target domain, predictions array). Column predictions_X_on_Y
# is read from domain Y's CSV, i.e. the CSV of the domain being predicted on.
cifar100_cross_domain_predictions = [
    (source, target, np.array(frame[column], dtype=np.intp))
    for source, target, frame, column in (
        (0, 1, cifar100_df_B, "predictions_A_on_B"),
        (1, 0, cifar100_df_A, "predictions_B_on_A"),
    )
]

# (domain index, that domain's per-sample target column as an intp array).
cifar100_domain_targets = [
    (domain, np.array(frame[column], dtype=np.intp))
    for domain, frame, column in (
        (0, cifar100_df_A, "domain_A"),
        (1, cifar100_df_B, "domain_B"),
    )
]

In [4]:
# Load Caltech256 dataset information
caltech256_labels = Caltech256(root="../datasets/caltech256", download=False).categories

# Caltech-256 has 256 object categories plus a clutter category. Check the
# loaded label list explicitly: stray entries in the dataset directory would
# otherwise silently change the concept list that every Caltech taxonomy below
# is built from.
assert len(caltech256_labels) == 257, (
    f"expected 257 Caltech-256 categories, found {len(caltech256_labels)}"
)

# Ground-truth taxonomy for the 2-domain Caltech-256 experiment (variant 1).
# The concept count is derived from the label list (as the CIFAR-100 cell does)
# instead of being hard-coded, so the count and the labels cannot disagree.
# NOTE(review): no seed is passed to create_synthetic_taxonomy -- confirm it
# reproduces the taxonomy that the caltech256_2domain_*_predictions.csv files assume.
caltech256_2d_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=len(caltech256_labels),
    num_domains=2,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)

In [5]:
caltech256_2d_df_A = pd.read_csv("../data/caltech256_2domain_A_predictions.csv")
caltech256_2d_df_B = pd.read_csv("../data/caltech256_2domain_B_predictions.csv")

# (source domain, target domain, predictions array). Column predictions_X_on_Y
# is read from domain Y's CSV, i.e. the CSV of the domain being predicted on.
caltech256_2d_cross_domain_predictions = [
    (source, target, np.array(frame[column], dtype=np.intp))
    for source, target, frame, column in (
        (0, 1, caltech256_2d_df_B, "predictions_A_on_B"),
        (1, 0, caltech256_2d_df_A, "predictions_B_on_A"),
    )
]

# (domain index, that domain's per-sample target column as an intp array).
caltech256_2d_domain_targets = [
    (domain, np.array(frame[column], dtype=np.intp))
    for domain, frame, column in (
        (0, caltech256_2d_df_A, "domain_A"),
        (1, caltech256_2d_df_B, "domain_B"),
    )
]

In [6]:
# Ground-truth taxonomy for the 2-domain Caltech-256 experiment, variant 2:
# domain_class_count_mean=200 (vs. 180) and concept_cluster_size_mean=2 (vs. 3).
# The concept count is derived from the label list (as the CIFAR-100 cell does)
# instead of a hard-coded 257, so the count and the labels cannot disagree.
# NOTE(review): no seed is passed to create_synthetic_taxonomy -- confirm it
# reproduces the taxonomy the caltech256_2domain_variant_*_predictions.csv files assume.
caltech256_2d_variant_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=len(caltech256_labels),
    num_domains=2,
    domain_class_count_mean=200,
    domain_class_count_variance=10,
    concept_cluster_size_mean=2,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)
caltech256_2d_variant_df_A = pd.read_csv(
    "../data/caltech256_2domain_variant_A_predictions.csv"
)
caltech256_2d_variant_df_B = pd.read_csv(
    "../data/caltech256_2domain_variant_B_predictions.csv"
)
# (source domain, target domain, predictions); predictions_X_on_Y lives in domain Y's CSV.
caltech256_2d_variant_cross_domain_predictions = [
    (0, 1, np.array(caltech256_2d_variant_df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(caltech256_2d_variant_df_A["predictions_B_on_A"], dtype=np.intp)),
]
# (domain index, that domain's per-sample target column).
caltech256_2d_variant_domain_targets = [
    (0, np.array(caltech256_2d_variant_df_A["domain_A"], dtype=np.intp)),
    (1, np.array(caltech256_2d_variant_df_B["domain_B"], dtype=np.intp)),
]

In [7]:
# Ground-truth taxonomy for the 3-domain Caltech-256 experiment, plus one
# predictions CSV per domain (A, B, C).
# The concept count is derived from the label list (as the CIFAR-100 cell does)
# instead of a hard-coded 257, so the count and the labels cannot disagree.
# NOTE(review): no seed is passed to create_synthetic_taxonomy -- confirm it
# reproduces the taxonomy the caltech256_3domain_*_predictions.csv files assume.
caltech256_3d_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=len(caltech256_labels),
    num_domains=3,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=5,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)
caltech256_3d_df_A = pd.read_csv("../data/caltech256_3domain_A_predictions.csv")
caltech256_3d_df_B = pd.read_csv("../data/caltech256_3domain_B_predictions.csv")
caltech256_3d_df_C = pd.read_csv("../data/caltech256_3domain_C_predictions.csv")

In [8]:
# Every ordered (source, target) domain pair. Column predictions_X_on_Y is read
# from domain Y's CSV, i.e. the CSV of the domain being predicted on.
caltech256_3d_cross_domain_predictions = [
    (source, target, np.array(frame[column], dtype=np.intp))
    for source, target, frame, column in (
        (0, 1, caltech256_3d_df_B, "predictions_A_on_B"),
        (0, 2, caltech256_3d_df_C, "predictions_A_on_C"),
        (1, 0, caltech256_3d_df_A, "predictions_B_on_A"),
        (1, 2, caltech256_3d_df_C, "predictions_B_on_C"),
        (2, 0, caltech256_3d_df_A, "predictions_C_on_A"),
        (2, 1, caltech256_3d_df_B, "predictions_C_on_B"),
    )
]

# (domain index, that domain's per-sample target column as an intp array).
caltech256_3d_domain_targets = [
    (domain, np.array(frame[column], dtype=np.intp))
    for domain, frame, column in (
        (0, caltech256_3d_df_A, "domain_A"),
        (1, caltech256_3d_df_B, "domain_B"),
        (2, caltech256_3d_df_C, "domain_C"),
    )
]

In [9]:
def evaluate_taxonomy(
    constructed_taxonomy, ground_truth_taxonomy, method_name, dataset_name, **params
):
    """Score a constructed taxonomy against a ground-truth taxonomy.

    Args:
        constructed_taxonomy: Taxonomy under evaluation; must provide
            ``edge_difference_ratio``, ``precision_recall_f1`` and a ``graph``
            exposing ``edges()`` / ``nodes()``.
        ground_truth_taxonomy: Reference taxonomy passed to both metric methods.
        method_name: Stored under the ``"method"`` key.
        dataset_name: Stored under the ``"dataset"`` key.
        **params: Method-specific parameters (e.g. ``threshold``) merged in last,
            so they override a same-named key above.

    Returns:
        Flat dict: method, dataset, edr, precision, recall, f1,
        num_relationships (edge count), num_nodes, then ``params``.
    """
    edge_ratio = constructed_taxonomy.edge_difference_ratio(ground_truth_taxonomy)
    prec, rec, f1_score = constructed_taxonomy.precision_recall_f1(
        ground_truth_taxonomy
    )

    graph = constructed_taxonomy.graph
    return {
        "method": method_name,
        "dataset": dataset_name,
        "edr": edge_ratio,
        "precision": prec,
        "recall": rec,
        "f1": f1_score,
        "num_relationships": len(graph.edges()),
        "num_nodes": len(graph.nodes()),
        **params,
    }

In [10]:
# MCFP has no tunable parameter: build and score one taxonomy per dataset variant.
# (dataset key, cross-domain predictions, domain targets, ground-truth taxonomy)
experiment_inputs = [
    (
        "cifar100_2domain",
        cifar100_cross_domain_predictions,
        cifar100_domain_targets,
        cifar100_synthetic_taxonomy,
    ),
    (
        "caltech256_2domain",
        caltech256_2d_cross_domain_predictions,
        caltech256_2d_domain_targets,
        caltech256_2d_synthetic_taxonomy,
    ),
    (
        "caltech256_2domain_variant",
        caltech256_2d_variant_cross_domain_predictions,
        caltech256_2d_variant_domain_targets,
        caltech256_2d_variant_synthetic_taxonomy,
    ),
    (
        "caltech256_3domain",
        caltech256_3d_cross_domain_predictions,
        caltech256_3d_domain_targets,
        caltech256_3d_synthetic_taxonomy,
    ),
]

# One result dict per dataset, in the order listed above.
mcfp_results = []
for dataset_key, cross_predictions, targets, ground_truth in experiment_inputs:
    mcfp_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
        cross_domain_predictions=cross_predictions,
        domain_targets=targets,
        domain_labels=ground_truth.domain_labels,
        relationship_type="mcfp",
    )
    mcfp_results.append(
        evaluate_taxonomy(mcfp_taxonomy, ground_truth, "mcfp", dataset_key)
    )

In [11]:
# Relationship-hypothesis method: sweep upper_bound over 1..10 (inclusive) on every
# dataset variant. Results: hypothesis_results[dataset_key][upper_bound] -> metrics.
# (dataset key, cross-domain predictions, domain targets, ground-truth taxonomy)
experiment_inputs = [
    (
        "cifar100_2domain",
        cifar100_cross_domain_predictions,
        cifar100_domain_targets,
        cifar100_synthetic_taxonomy,
    ),
    (
        "caltech256_2domain",
        caltech256_2d_cross_domain_predictions,
        caltech256_2d_domain_targets,
        caltech256_2d_synthetic_taxonomy,
    ),
    (
        "caltech256_2domain_variant",
        caltech256_2d_variant_cross_domain_predictions,
        caltech256_2d_variant_domain_targets,
        caltech256_2d_variant_synthetic_taxonomy,
    ),
    (
        "caltech256_3domain",
        caltech256_3d_cross_domain_predictions,
        caltech256_3d_domain_targets,
        caltech256_3d_synthetic_taxonomy,
    ),
]

hypothesis_results = {dataset_key: {} for dataset_key, *_ in experiment_inputs}

for upper_bound in range(1, 11):
    for dataset_key, cross_predictions, targets, ground_truth in experiment_inputs:
        hypothesis_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=cross_predictions,
            domain_targets=targets,
            domain_labels=ground_truth.domain_labels,
            relationship_type="hypothesis",
            upper_bound=upper_bound,
        )
        hypothesis_results[dataset_key][upper_bound] = evaluate_taxonomy(
            hypothesis_taxonomy,
            ground_truth,
            "hypothesis",
            dataset_key,
            upper_bound=upper_bound,
        )

In [12]:
# Density-threshold method: sweep the threshold over 0.10, 0.15, ..., 1.00 on every
# dataset variant. Results: density_threshold_results[dataset_key][threshold] -> metrics.
# (dataset key, cross-domain predictions, domain targets, ground-truth taxonomy)
experiment_inputs = [
    (
        "cifar100_2domain",
        cifar100_cross_domain_predictions,
        cifar100_domain_targets,
        cifar100_synthetic_taxonomy,
    ),
    (
        "caltech256_2domain",
        caltech256_2d_cross_domain_predictions,
        caltech256_2d_domain_targets,
        caltech256_2d_synthetic_taxonomy,
    ),
    (
        "caltech256_2domain_variant",
        caltech256_2d_variant_cross_domain_predictions,
        caltech256_2d_variant_domain_targets,
        caltech256_2d_variant_synthetic_taxonomy,
    ),
    (
        "caltech256_3domain",
        caltech256_3d_cross_domain_predictions,
        caltech256_3d_domain_targets,
        caltech256_3d_synthetic_taxonomy,
    ),
]

# linspace pins the 19 grid points explicitly. np.arange with a float step leaves
# endpoint inclusion to rounding. Values are rounded to 2 decimals and converted to
# plain Python floats so they make clean dict keys and table entries.
threshold_grid = [float(round(value, 2)) for value in np.linspace(0.1, 1.0, 19)]

density_threshold_results = {dataset_key: {} for dataset_key, *_ in experiment_inputs}

for threshold in threshold_grid:
    for dataset_key, cross_predictions, targets, ground_truth in experiment_inputs:
        density_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=cross_predictions,
            domain_targets=targets,
            domain_labels=ground_truth.domain_labels,
            relationship_type="density_threshold",
            threshold=threshold,
        )
        density_threshold_results[dataset_key][threshold] = evaluate_taxonomy(
            density_taxonomy,
            ground_truth,
            "density_threshold",
            dataset_key,
            threshold=threshold,
        )

In [13]:
# Naive (simple) threshold method: sweep the threshold over 0.10, 0.15, ..., 1.00 on
# every dataset variant. Results: simple_threshold_results[dataset_key][threshold].
# (dataset key, cross-domain predictions, domain targets, ground-truth taxonomy)
experiment_inputs = [
    (
        "cifar100_2domain",
        cifar100_cross_domain_predictions,
        cifar100_domain_targets,
        cifar100_synthetic_taxonomy,
    ),
    (
        "caltech256_2domain",
        caltech256_2d_cross_domain_predictions,
        caltech256_2d_domain_targets,
        caltech256_2d_synthetic_taxonomy,
    ),
    (
        "caltech256_2domain_variant",
        caltech256_2d_variant_cross_domain_predictions,
        caltech256_2d_variant_domain_targets,
        caltech256_2d_variant_synthetic_taxonomy,
    ),
    (
        "caltech256_3domain",
        caltech256_3d_cross_domain_predictions,
        caltech256_3d_domain_targets,
        caltech256_3d_synthetic_taxonomy,
    ),
]

# linspace pins the 19 grid points explicitly. np.arange with a float step leaves
# endpoint inclusion to rounding. Values are rounded to 2 decimals and converted to
# plain Python floats so they make clean dict keys and table entries.
threshold_grid = [float(round(value, 2)) for value in np.linspace(0.1, 1.0, 19)]

simple_threshold_results = {dataset_key: {} for dataset_key, *_ in experiment_inputs}

for threshold in threshold_grid:
    for dataset_key, cross_predictions, targets, ground_truth in experiment_inputs:
        simple_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=cross_predictions,
            domain_targets=targets,
            domain_labels=ground_truth.domain_labels,
            relationship_type="threshold",
            threshold=threshold,
        )
        simple_threshold_results[dataset_key][threshold] = evaluate_taxonomy(
            simple_taxonomy,
            ground_truth,
            "threshold",
            dataset_key,
            threshold=threshold,
        )

In [14]:
import matplotlib

# Render through the PGF backend: figures are saved as .pgf files for inclusion
# in the LaTeX thesis (typeset with LuaLaTeX, see pgf.texsystem below).
matplotlib.use("pgf")
import matplotlib.pyplot as plt

# LaTeX settings
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "EB Garamond",
        "font.size": 11,
        "pgf.texsystem": "lualatex",
        # With pgf.rcfonts enabled (the default), the pgf backend tries to map
        # "EB Garamond" to a font in Matplotlib's own font cache. When it is not
        # installed there, it logs "Ignoring unknown font" for every text element
        # and emits no font command. With rcfonts disabled, text is emitted with
        # \fontfamily{\familydefault}, so it inherits the including document's main
        # font, and the warning flood disappears.
        # NOTE(review): text is then *measured* with LaTeX's default font. If label
        # layout looks off, load the thesis font via "pgf.preamble" -- confirm the
        # font package is available in the TeX installation first.
        "pgf.rcfonts": False,
    }
)

In [15]:
# Plot 1: Hypothesis method (upper_bound parameter)
# 2x2 grid of precision/recall scatter plots, one panel per dataset variant;
# each point is one upper_bound value from the sweep, coloured by that value.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
datasets = [
    "caltech256_2domain",
    "caltech256_2domain_variant",
    "caltech256_3domain",
    "cifar100_2domain",
]
dataset_titles = [
    "Caltech-256 2-Domain Variant 1",
    "Caltech-256 2-Domain Variant 2",
    "Caltech-256 3-Domain Variant",
    "CIFAR-100 2-Domain Variant",
]

# axes.flat walks the grid row by row, matching the order of `datasets`.
for ax, dataset, title in zip(axes.flat, datasets, dataset_titles):
    sweep = hypothesis_results[dataset]
    upper_bounds = list(sweep)
    precisions = [result["precision"] for result in sweep.values()]
    recalls = [result["recall"] for result in sweep.values()]

    scatter = ax.scatter(recalls, precisions, c=upper_bounds, cmap="viridis", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label("Upper Bound")

fig.tight_layout()
fig.savefig(
    "../../thesis/figures/hypothesis_method_precision_recall.pgf", bbox_inches="tight"
)
# The pgf backend is non-interactive, so plt.show() cannot display the figure.
# Close it instead, so repeated runs do not accumulate open figures.
plt.close(fig)

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [16]:
# Plot 2: Naive threshold method (threshold parameter)
# 2x2 grid of precision/recall scatter plots, one panel per dataset variant;
# each point is one threshold value from the sweep, coloured by that value.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
datasets = [
    "caltech256_2domain",
    "caltech256_2domain_variant",
    "caltech256_3domain",
    "cifar100_2domain",
]
dataset_titles = [
    "Caltech-256 2-Domain Variant 1",
    "Caltech-256 2-Domain Variant 2",
    "Caltech-256 3-Domain Variant",
    "CIFAR-100 2-Domain Variant",
]

# axes.flat walks the grid row by row, matching the order of `datasets`.
for ax, dataset, title in zip(axes.flat, datasets, dataset_titles):
    sweep = simple_threshold_results[dataset]
    thresholds = list(sweep)
    precisions = [result["precision"] for result in sweep.values()]
    recalls = [result["recall"] for result in sweep.values()]

    scatter = ax.scatter(recalls, precisions, c=thresholds, cmap="plasma", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label("Threshold")

fig.tight_layout()
fig.savefig(
    "../../thesis/figures/naive_threshold_method_precision_recall.pgf",
    bbox_inches="tight",
)
# The pgf backend is non-interactive, so plt.show() cannot display the figure.
# Close it instead, so repeated runs do not accumulate open figures.
plt.close(fig)

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [17]:
# Plot 3: Density threshold method (threshold parameter)
# 2x2 grid of precision/recall scatter plots, one panel per dataset variant;
# each point is one threshold value from the sweep, coloured by that value.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
datasets = [
    "caltech256_2domain",
    "caltech256_2domain_variant",
    "caltech256_3domain",
    "cifar100_2domain",
]
dataset_titles = [
    "Caltech-256 2-Domain Variant 1",
    "Caltech-256 2-Domain Variant 2",
    "Caltech-256 3-Domain Variant",
    "CIFAR-100 2-Domain Variant",
]

# axes.flat walks the grid row by row, matching the order of `datasets`.
for ax, dataset, title in zip(axes.flat, datasets, dataset_titles):
    sweep = density_threshold_results[dataset]
    thresholds = list(sweep)
    precisions = [result["precision"] for result in sweep.values()]
    recalls = [result["recall"] for result in sweep.values()]

    scatter = ax.scatter(recalls, precisions, c=thresholds, cmap="inferno", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label("Threshold")

fig.tight_layout()
fig.savefig(
    "../../thesis/figures/density_threshold_method_precision_recall.pgf",
    bbox_inches="tight",
)
# The pgf backend is non-interactive, so plt.show() cannot display the figure.
# Close it instead, so repeated runs do not accumulate open figures.
plt.close(fig)

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [18]:
# Create LaTeX table with best EDR results
import pandas as pd


# Helper function to find best EDR result for each dataset and method
def find_best_edr_result(results_dict, dataset):
    """Pick the sweep entry with the lowest EDR for one dataset.

    Args:
        results_dict: Mapping dataset -> {parameter value: result dict}; each
            result dict has an ``"edr"`` entry.
        dataset: Key to look up in ``results_dict``.

    Returns:
        The first result dict (in iteration order) with the strictly smallest
        finite EDR, or None when the dataset is absent, has no entries, or none
        of its EDRs compare below +inf (e.g. all NaN).
    """
    if dataset not in results_dict:
        return None

    winner = None
    winner_edr = float("inf")
    for candidate in results_dict[dataset].values():
        candidate_edr = candidate["edr"]
        # Strict "<" keeps the earliest entry on ties and never selects NaN.
        if candidate_edr < winner_edr:
            winner, winner_edr = candidate, candidate_edr
    return winner


# Prepare data for the table with methods as rows
table_data = []

# One entry per method: (display name, results container, key of the tuned
# parameter inside each result dict, parameter name). MCFP has no tuned
# parameter and stores its results as a flat list.
methods = [
    ("MCFP", mcfp_results, None, "N/A"),
    ("Naive Thresholding", simple_threshold_results, "threshold", "threshold"),
    (
        "Density Thresholding",
        density_threshold_results,
        "threshold",
        "density_threshold",
    ),
    ("Relationship Hypothesis", hypothesis_results, "upper_bound", "upper_bound"),
]

for dataset, dataset_name in zip(datasets, dataset_titles):
    for method_name, results_source, param_key, param_name in methods:
        if method_name == "MCFP":
            # Flat list: take the first entry recorded for this dataset.
            row_result = next(
                (r for r in results_source if r["dataset"] == dataset), None
            )
            param_display = "N/A"
        else:
            # Parametric sweep: report the parameter value with the lowest EDR.
            row_result = find_best_edr_result(results_source, dataset)
            if row_result:
                param_value = row_result[param_key] if param_key else "N/A"
                if isinstance(param_value, float):
                    param_display = f"{param_value:.2f}"
                else:
                    param_display = str(param_value)

        if not row_result:
            continue

        table_data.append(
            {
                "Dataset Variant": dataset_name,
                "Method": method_name,
                "EDR": f"{row_result['edr']:.3f}",
                "F1-score": f"{row_result['f1']:.3f}",
                "Parameter": param_display,
            }
        )

# Create DataFrame with simple column structure
df = pd.DataFrame(table_data)


# Function to identify best EDR scores for each dataset variant
def get_best_edr_indices(df):
    """Return the row labels holding the best (lowest) EDR of each dataset variant.

    Every row that ties for its variant's minimum is returned, so tied methods
    are all highlighted. Previously ``idxmin`` kept only the first of them.
    EDR values are the 3-decimal strings built for the table, so ties are judged
    after that rounding. Rows whose EDR or dataset variant is NaN are never
    selected, and an all-NaN variant no longer breaks the lookup.

    Args:
        df: Table with a "Dataset Variant" column and an "EDR" column of
            numeric strings.

    Returns:
        List of index labels in table order; empty for an empty table.
    """
    if df.empty:
        return []

    # Convert EDR strings to float for comparison
    edr_values = df["EDR"].astype(float)
    # Broadcast each variant's minimum EDR back onto all of its rows.
    variant_minimum = edr_values.groupby(df["Dataset Variant"], sort=False).transform(
        "min"
    )
    return edr_values[edr_values == variant_minimum].index.tolist()


# Function to identify dataset boundary rows
def get_dataset_boundaries(df):
    """Return the index labels of rows whose "Dataset Variant" differs from the row above.

    The first row is never a boundary. A row that follows a ``None`` variant
    is not reported either.
    """
    if len(df.index) == 0:
        return []

    labels = df.index
    variants = df["Dataset Variant"].tolist()
    return [
        labels[pos]
        for pos in range(1, len(variants))
        if variants[pos - 1] is not None and variants[pos] != variants[pos - 1]
    ]


# Both lookups read the unformatted table (plain numeric EDR strings, one row
# per dataset/method pair); neither modifies it.
dataset_boundaries = get_dataset_boundaries(df)
best_edr_indices = get_best_edr_indices(df)


# Function to apply bold formatting to best EDR scores
def apply_bold_to_best_edr(df, best_indices):
    """Return a copy of ``df`` with the EDR cells at ``best_indices`` wrapped in \\textbf{...}.

    The input frame is not modified. A label listed twice is wrapped twice.
    """
    bolded = df.copy()
    for row_label in best_indices:
        current_edr = bolded.loc[row_label, "EDR"]
        bolded.loc[row_label, "EDR"] = f"\\textbf{{{current_edr}}}"
    return bolded


# Bold the best EDR of each dataset variant, then render the table without its
# index column. Dataset separators are added afterwards by post-processing.
df_formatted = apply_bold_to_best_edr(df, best_edr_indices)
styler = df_formatted.style.hide(axis="index")

latex_table = styler.to_latex(
    label="tab:relationship_methods_best_edr",
    caption="Best EDR Results for Relationship Discovery Methods. For each dataset variant and method, the parameter values that yielded the lowest Edge Difference Ratio (EDR) are shown along with the corresponding F1-score.",
    position="ht",
    position_float="centering",
    # Dataset variant and method left-aligned; EDR, F1-score, parameter centred.
    column_format="llccc",
    hrules=True,
)


# Post-process LaTeX to add hlines at dataset boundaries
def insert_hlines_at_boundaries(latex_str, boundaries, rule="\\hline"):
    r"""Insert a horizontal rule before selected data rows of a LaTeX table.

    Data rows are the lines after the first ``\midrule`` that contain `` & ``
    and, once stripped, end with `` \\``. They are numbered from 0 in order of
    appearance.

    Args:
        latex_str: LaTeX source of the table.
        boundaries: 0-based data-row numbers; ``rule`` is placed on its own line
            directly before each of those rows.
        rule: LaTeX command to insert. Defaults to ``\hline`` (the previous
            hard-coded value). Pass ``"\\midrule"`` for booktabs-consistent
            spacing when the table was rendered with ``hrules=True``.

    Returns:
        The LaTeX source with the rules inserted, lines joined with newlines.
    """
    # Set lookup instead of scanning the boundary list for every row.
    boundary_rows = set(boundaries)
    new_lines = []
    data_row_count = 0
    in_table_body = False

    for line in latex_str.split("\n"):
        # Everything after the header rule is table body.
        if "\\midrule" in line:
            in_table_body = True
            new_lines.append(line)
            continue

        # A data row contains column separators and ends with a LaTeX row break.
        if in_table_body and " & " in line and line.strip().endswith(" \\\\"):
            if data_row_count in boundary_rows:
                new_lines.append(rule)
            data_row_count += 1

        new_lines.append(line)

    return "\n".join(new_lines)


# Apply the hline insertion: separate dataset variants with horizontal rules.
latex_table = insert_hlines_at_boundaries(latex_table, dataset_boundaries)

# Save to file. Write UTF-8 explicitly (the encoding LuaLaTeX reads) rather than
# relying on the platform's locale-dependent default encoding.
with open(
    "../../thesis/figures/relationship_methods_results.tex", "w", encoding="utf-8"
) as f:
    f.write(latex_table)