In [1]:
import pandas as pd
import numpy as np
from torchvision.datasets import Caltech256

import sys
import os

# Get the parent directory and add it to sys.path
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(parent_dir)

from library.taxonomy_constructors import (
    SyntheticTaxonomy,
    CrossPredictionsTaxonomy,
)
from library.datasets import CIFAR100Mapped

In [2]:
cifar100_dataset = CIFAR100Mapped(root="../datasets/cifar100", download=False)
cifar100_labels = cifar100_dataset.classes
cifar100_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=len(cifar100_labels),
    num_domains=2,
    domain_class_count_mean=50,
    domain_class_count_variance=5,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=cifar100_labels,
    relationship_type="true",
)

In [3]:
cifar100_df_A = pd.read_csv("../data/cifar100_2domain_A_predictions.csv")
cifar100_df_B = pd.read_csv("../data/cifar100_2domain_B_predictions.csv")
cifar100_cross_domain_predictions = [
    (0, 1, np.array(cifar100_df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(cifar100_df_A["predictions_B_on_A"], dtype=np.intp)),
]
cifar100_domain_targets = [
    (0, np.array(cifar100_df_A["domain_A"], dtype=np.intp)),
    (1, np.array(cifar100_df_B["domain_B"], dtype=np.intp)),
]

In [4]:
# Load Caltech256 dataset information
caltech256_labels = Caltech256(root="../datasets/caltech256", download=False).categories
caltech256_2d_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=257,
    num_domains=2,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)

In [5]:
caltech256_2d_df_A = pd.read_csv("../data/caltech256_2domain_A_predictions.csv")
caltech256_2d_df_B = pd.read_csv("../data/caltech256_2domain_B_predictions.csv")
caltech256_2d_cross_domain_predictions = [
    (0, 1, np.array(caltech256_2d_df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(caltech256_2d_df_A["predictions_B_on_A"], dtype=np.intp)),
]
caltech256_2d_domain_targets = [
    (0, np.array(caltech256_2d_df_A["domain_A"], dtype=np.intp)),
    (1, np.array(caltech256_2d_df_B["domain_B"], dtype=np.intp)),
]

In [6]:
caltech256_3d_synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    num_atomic_concepts=257,
    num_domains=3,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=5,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)
caltech256_3d_df_A = pd.read_csv("../data/caltech256_3domain_A_predictions.csv")
caltech256_3d_df_B = pd.read_csv("../data/caltech256_3domain_B_predictions.csv")
caltech256_3d_df_C = pd.read_csv("../data/caltech256_3domain_C_predictions.csv")

In [7]:
caltech256_3d_cross_domain_predictions = [
    (0, 1, np.array(caltech256_3d_df_B["predictions_A_on_B"], dtype=np.intp)),
    (0, 2, np.array(caltech256_3d_df_C["predictions_A_on_C"], dtype=np.intp)),
    (1, 0, np.array(caltech256_3d_df_A["predictions_B_on_A"], dtype=np.intp)),
    (1, 2, np.array(caltech256_3d_df_C["predictions_B_on_C"], dtype=np.intp)),
    (2, 0, np.array(caltech256_3d_df_A["predictions_C_on_A"], dtype=np.intp)),
    (2, 1, np.array(caltech256_3d_df_B["predictions_C_on_B"], dtype=np.intp)),
]
caltech256_3d_domain_targets = [
    (0, np.array(caltech256_3d_df_A["domain_A"], dtype=np.intp)),
    (1, np.array(caltech256_3d_df_B["domain_B"], dtype=np.intp)),
    (2, np.array(caltech256_3d_df_C["domain_C"], dtype=np.intp)),
]

In [8]:
def evaluate_taxonomy(
    constructed_taxonomy, ground_truth_taxonomy, method_name, dataset_name, **params
):
    # Calculate metrics
    edr = constructed_taxonomy.edge_difference_ratio(ground_truth_taxonomy)
    precision, recall, f1 = constructed_taxonomy.precision_recall_f1(
        ground_truth_taxonomy
    )

    # Count relationships for analysis
    num_relationships = len(constructed_taxonomy.graph.edges())
    num_nodes = len(constructed_taxonomy.graph.nodes())

    results = {
        "method": method_name,
        "dataset": dataset_name,
        "edr": edr,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "num_relationships": num_relationships,
        "num_nodes": num_nodes,
        **params,  # Include method-specific parameters
    }

    return results

In [9]:
mcfp_results = []

cifar100_mcfp_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=cifar100_cross_domain_predictions,
    domain_targets=cifar100_domain_targets,
    domain_labels=cifar100_synthetic_taxonomy.domain_labels,
    relationship_type="mcfp",
)
cifar100_mcfp_result = evaluate_taxonomy(
    cifar100_mcfp_taxonomy, cifar100_synthetic_taxonomy, "mcfp", "cifar100_2domain"
)
mcfp_results.append(cifar100_mcfp_result)

caltech256_2d_mcfp_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=caltech256_2d_cross_domain_predictions,
    domain_targets=caltech256_2d_domain_targets,
    domain_labels=caltech256_2d_synthetic_taxonomy.domain_labels,
    relationship_type="mcfp",
)
caltech256_2d_mcfp_result = evaluate_taxonomy(
    caltech256_2d_mcfp_taxonomy,
    caltech256_2d_synthetic_taxonomy,
    "mcfp",
    "caltech256_2domain",
)
mcfp_results.append(caltech256_2d_mcfp_result)

caltech256_3d_mcfp_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=caltech256_3d_cross_domain_predictions,
    domain_targets=caltech256_3d_domain_targets,
    domain_labels=caltech256_3d_synthetic_taxonomy.domain_labels,
    relationship_type="mcfp",
)
caltech256_3d_mcfp_result = evaluate_taxonomy(
    caltech256_3d_mcfp_taxonomy,
    caltech256_3d_synthetic_taxonomy,
    "mcfp",
    "caltech256_3domain",
)
mcfp_results.append(caltech256_3d_mcfp_result)

In [10]:
hypothesis_results = {
    "cifar100_2domain": {},
    "caltech256_2domain": {},
    "caltech256_3domain": {},
}

for upper_bound in range(1, 11):
    cifar100_hypothesis_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=cifar100_cross_domain_predictions,
            domain_targets=cifar100_domain_targets,
            domain_labels=cifar100_synthetic_taxonomy.domain_labels,
            relationship_type="hypothesis",
            upper_bound=upper_bound,
        )
    )
    cifar100_result = evaluate_taxonomy(
        cifar100_hypothesis_taxonomy,
        cifar100_synthetic_taxonomy,
        "hypothesis",
        "cifar100_2domain",
        upper_bound=upper_bound,
    )
    hypothesis_results["cifar100_2domain"][upper_bound] = cifar100_result

    caltech256_2d_hypothesis_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_cross_domain_predictions,
            domain_targets=caltech256_2d_domain_targets,
            domain_labels=caltech256_2d_synthetic_taxonomy.domain_labels,
            relationship_type="hypothesis",
            upper_bound=upper_bound,
        )
    )
    caltech256_2d_result = evaluate_taxonomy(
        caltech256_2d_hypothesis_taxonomy,
        caltech256_2d_synthetic_taxonomy,
        "hypothesis",
        "caltech256_2domain",
        upper_bound=upper_bound,
    )
    hypothesis_results["caltech256_2domain"][upper_bound] = caltech256_2d_result

    caltech256_3d_hypothesis_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_3d_cross_domain_predictions,
            domain_targets=caltech256_3d_domain_targets,
            domain_labels=caltech256_3d_synthetic_taxonomy.domain_labels,
            relationship_type="hypothesis",
            upper_bound=upper_bound,
        )
    )
    caltech256_3d_result = evaluate_taxonomy(
        caltech256_3d_hypothesis_taxonomy,
        caltech256_3d_synthetic_taxonomy,
        "hypothesis",
        "caltech256_3domain",
        upper_bound=upper_bound,
    )
    hypothesis_results["caltech256_3domain"][upper_bound] = caltech256_3d_result

In [11]:
density_threshold_results = {
    "cifar100_2domain": {},
    "caltech256_2domain": {},
    "caltech256_3domain": {},
}

for threshold in np.arange(0.1, 1.05, 0.05):
    threshold = round(threshold, 2).astype(float)

    cifar100_density_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
        cross_domain_predictions=cifar100_cross_domain_predictions,
        domain_targets=cifar100_domain_targets,
        domain_labels=cifar100_synthetic_taxonomy.domain_labels,
        relationship_type="density_threshold",
        threshold=threshold,
    )
    cifar100_result = evaluate_taxonomy(
        cifar100_density_taxonomy,
        cifar100_synthetic_taxonomy,
        "density_threshold",
        "cifar100_2domain",
        threshold=threshold,
    )
    density_threshold_results["cifar100_2domain"][threshold] = cifar100_result

    caltech256_2d_density_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_cross_domain_predictions,
            domain_targets=caltech256_2d_domain_targets,
            domain_labels=caltech256_2d_synthetic_taxonomy.domain_labels,
            relationship_type="density_threshold",
            threshold=threshold,
        )
    )
    caltech256_2d_result = evaluate_taxonomy(
        caltech256_2d_density_taxonomy,
        caltech256_2d_synthetic_taxonomy,
        "density_threshold",
        "caltech256_2domain",
        threshold=threshold,
    )
    density_threshold_results["caltech256_2domain"][threshold] = caltech256_2d_result

    caltech256_3d_density_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_3d_cross_domain_predictions,
            domain_targets=caltech256_3d_domain_targets,
            domain_labels=caltech256_3d_synthetic_taxonomy.domain_labels,
            relationship_type="density_threshold",
            threshold=threshold,
        )
    )
    caltech256_3d_result = evaluate_taxonomy(
        caltech256_3d_density_taxonomy,
        caltech256_3d_synthetic_taxonomy,
        "density_threshold",
        "caltech256_3domain",
        threshold=threshold,
    )
    density_threshold_results["caltech256_3domain"][threshold] = caltech256_3d_result

In [12]:
simple_threshold_results = {
    "cifar100_2domain": {},
    "caltech256_2domain": {},
    "caltech256_3domain": {},
}

for threshold in np.arange(0.1, 1.05, 0.05):
    threshold = round(threshold, 2).astype(float)

    cifar100_simple_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
        cross_domain_predictions=cifar100_cross_domain_predictions,
        domain_targets=cifar100_domain_targets,
        domain_labels=cifar100_synthetic_taxonomy.domain_labels,
        relationship_type="threshold",
        threshold=threshold,
    )
    cifar100_result = evaluate_taxonomy(
        cifar100_simple_taxonomy,
        cifar100_synthetic_taxonomy,
        "threshold",
        "cifar100_2domain",
        threshold=threshold,
    )
    simple_threshold_results["cifar100_2domain"][threshold] = cifar100_result

    caltech256_2d_simple_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_2d_cross_domain_predictions,
            domain_targets=caltech256_2d_domain_targets,
            domain_labels=caltech256_2d_synthetic_taxonomy.domain_labels,
            relationship_type="threshold",
            threshold=threshold,
        )
    )
    caltech256_2d_result = evaluate_taxonomy(
        caltech256_2d_simple_taxonomy,
        caltech256_2d_synthetic_taxonomy,
        "threshold",
        "caltech256_2domain",
        threshold=threshold,
    )
    simple_threshold_results["caltech256_2domain"][threshold] = caltech256_2d_result

    caltech256_3d_simple_taxonomy = (
        CrossPredictionsTaxonomy.from_cross_domain_predictions(
            cross_domain_predictions=caltech256_3d_cross_domain_predictions,
            domain_targets=caltech256_3d_domain_targets,
            domain_labels=caltech256_3d_synthetic_taxonomy.domain_labels,
            relationship_type="threshold",
            threshold=threshold,
        )
    )
    caltech256_3d_result = evaluate_taxonomy(
        caltech256_3d_simple_taxonomy,
        caltech256_3d_synthetic_taxonomy,
        "threshold",
        "caltech256_3domain",
        threshold=threshold,
    )
    simple_threshold_results["caltech256_3domain"][threshold] = caltech256_3d_result

In [13]:
import matplotlib

matplotlib.use("pgf")
import matplotlib.pyplot as plt

# LaTeX settings
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "EB Garamond",
        "font.size": 11,
        "pgf.texsystem": "lualatex",
    }
)

In [14]:
# Plot 1: Hypothesis method (upper_bound parameter)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
datasets = ["cifar100_2domain", "caltech256_2domain", "caltech256_3domain"]
dataset_titles = ["CIFAR-100 2-Domain", "Caltech-256 2-Domain", "Caltech-256 3-Domain"]

for i, (dataset, title) in enumerate(zip(datasets, dataset_titles)):
    ax = axes[i]

    # Extract data for this dataset
    precisions = []
    recalls = []
    upper_bounds = []

    for upper_bound, result in hypothesis_results[dataset].items():
        precisions.append(result["precision"])
        recalls.append(result["recall"])
        upper_bounds.append(upper_bound)

    # Create scatter plot with colormap
    scatter = ax.scatter(recalls, precisions, c=upper_bounds, cmap="viridis", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Upper Bound")

plt.tight_layout()
plt.savefig(
    "../../thesis/figures/hypothesis_method_precision_recall.pgf", bbox_inches="tight"
)
plt.show()

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [15]:
# Plot 2: Naive thresholding method (threshold parameter)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for i, (dataset, title) in enumerate(zip(datasets, dataset_titles)):
    ax = axes[i]

    # Extract data for this dataset
    precisions = []
    recalls = []
    thresholds = []

    for threshold, result in simple_threshold_results[dataset].items():
        precisions.append(result["precision"])
        recalls.append(result["recall"])
        thresholds.append(threshold)

    # Create scatter plot with colormap
    scatter = ax.scatter(recalls, precisions, c=thresholds, cmap="plasma", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Threshold")

plt.tight_layout()
plt.savefig(
    "../../thesis/figures/naive_thresholding_precision_recall.pgf", bbox_inches="tight"
)
plt.show()

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [16]:
# Plot 3: Density thresholding method (density_threshold parameter)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for i, (dataset, title) in enumerate(zip(datasets, dataset_titles)):
    ax = axes[i]

    # Extract data for this dataset
    precisions = []
    recalls = []
    thresholds = []

    for threshold, result in density_threshold_results[dataset].items():
        precisions.append(result["precision"])
        recalls.append(result["recall"])
        thresholds.append(threshold)

    # Create scatter plot with colormap
    scatter = ax.scatter(recalls, precisions, c=thresholds, cmap="inferno", s=50)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Density Threshold")

plt.tight_layout()
plt.savefig(
    "../../thesis/figures/density_thresholding_precision_recall.pgf",
    bbox_inches="tight",
)
plt.show()

Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown font: EB Garamond
Ignoring unknown fon

In [17]:
# Create LaTeX table with best EDR results
import pandas as pd


# Helper function to find best EDR result for each dataset and method
def find_best_edr_result(results_dict, dataset):
    if dataset not in results_dict:
        return None

    best_result = None
    best_edr = float("inf")

    for param_value, result in results_dict[dataset].items():
        if result["edr"] < best_edr:
            best_edr = result["edr"]
            best_result = result

    return best_result


# Prepare data for the table
table_data = []

for dataset, dataset_name in zip(datasets, dataset_titles):
    row = {"Dataset Variant": dataset_name}

    # MCFP method (no parameters)
    mcfp_result = next((r for r in mcfp_results if r["dataset"] == dataset), None)
    if mcfp_result:
        row["MCFP EDR"] = f"{mcfp_result['edr']:.3f}"
        row["MCFP F1"] = f"{mcfp_result['f1']:.3f}"
        row["MCFP Param"] = "N/A"

    # Naive thresholding method
    naive_result = find_best_edr_result(simple_threshold_results, dataset)
    if naive_result:
        row["Naive EDR"] = f"{naive_result['edr']:.3f}"
        row["Naive F1"] = f"{naive_result['f1']:.3f}"
        row["Naive Param"] = f"{naive_result['threshold']:.2f}"

    # Density thresholding method
    density_result = find_best_edr_result(density_threshold_results, dataset)
    if density_result:
        row["Density EDR"] = f"{density_result['edr']:.3f}"
        row["Density F1"] = f"{density_result['f1']:.3f}"
        row["Density Param"] = f"{density_result['threshold']:.2f}"

    # Hypothesis method
    hypothesis_result = find_best_edr_result(hypothesis_results, dataset)
    if hypothesis_result:
        row["Hypothesis EDR"] = f"{hypothesis_result['edr']:.3f}"
        row["Hypothesis F1"] = f"{hypothesis_result['f1']:.3f}"
        row["Hypothesis Param"] = f"{hypothesis_result['upper_bound']}"

    table_data.append(row)

# Create DataFrame
df = pd.DataFrame(table_data)

# Generate LaTeX table
latex_table = df.to_latex(
    index=False,
    escape=False,
    caption="Best EDR Results for Relationship Discovery Methods. For each dataset variant and method, the parameter values that yielded the lowest Edge Difference Ratio (EDR) are shown along with the corresponding F1-score.",
    label="tab:relationship_methods_best_edr",
    column_format="l"
    + "ccc" * 4,  # Left align dataset names, center align method columns
    position="ht",
    float_format="%.3f",
)

# Add multi-column headers manually
latex_lines = latex_table.split("\n")
header_idx = None
for i, line in enumerate(latex_lines):
    if "Dataset Variant" in line:
        header_idx = i
        break

if header_idx:
    # Insert multi-column header line
    multicolumn_header = "& \\multicolumn{3}{c}{MCFP} & \\multicolumn{3}{c}{Naive Thresholding} & \\multicolumn{3}{c}{Density Thresholding} & \\multicolumn{3}{c}{Relationship Hypothesis} \\\\"
    latex_lines.insert(header_idx, multicolumn_header)

    # Modify the existing header line
    latex_lines[header_idx + 1] = (
        "Dataset Variant & EDR & F1-score & Param & EDR & F1-score & Param & EDR & F1-score & Param & EDR & F1-score & Param \\\\"
    )

# Rejoin the lines
latex_table = "\n".join(latex_lines)

# Save to file
with open("../../thesis/figures/relationship_methods_results.tex", "w") as f:
    f.write(latex_table)

print("LaTeX table saved to relationship_methods_results.tex")
print("\nTable preview:")
print(df.to_string(index=False))

LaTeX table saved to relationship_methods_results.tex

Table preview:
     Dataset Variant MCFP EDR MCFP F1 MCFP Param Naive EDR Naive F1 Naive Param Density EDR Density F1 Density Param Hypothesis EDR Hypothesis F1 Hypothesis Param
  CIFAR-100 2-Domain    0.634   0.636        N/A     0.525    0.770        0.15       0.562      0.688          0.40          0.534         0.595                4
Caltech-256 2-Domain    0.670   0.526        N/A     0.459    0.798        0.15       0.508      0.662          0.70          0.482         0.737                5
Caltech-256 3-Domain    0.707   0.443        N/A     0.377    0.864        0.10       0.420      0.740          0.75          0.391         0.830                6
