In [None]:
import pandas as pd
import numpy as np
import torch
from torchvision.datasets import Caltech256
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.taxonomy_constructors import SyntheticTaxonomy, CrossPredictionsTaxonomy
from library.models import ResNetModel
from library.datasets import Caltech256MappedDataModule

# Category names are read from the local Caltech256 copy (download is disabled)
caltech256_dataset = Caltech256(root="datasets/caltech256", download=False)
caltech256_labels = caltech256_dataset.categories

# Table read from data/caltech256.csv
caltech256_targets = pd.read_csv("data/caltech256.csv")

# Trade float32 matmul precision for speed
torch.set_float32_matmul_precision("medium")

In [None]:
# Ground-truth ("true") synthetic taxonomy with 2 domains.
# NOTE(review): no random seed is visible here. If generation is stochastic, the
# cached checkpoints / prediction CSVs used when TRAIN and PREDICT are False
# presumably only match the taxonomy they were produced with — confirm.
synthetic_taxonomy_kwargs = dict(
    num_atomic_concepts=257,
    num_domains=2,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)
synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    **synthetic_taxonomy_kwargs
)

# One mapping per domain, taken from the taxonomy's first two domains
domain_A_mapping = synthetic_taxonomy.domains[0].to_mapping()
domain_B_mapping = synthetic_taxonomy.domains[1].to_mapping()

# Number of distinct mapped values per domain
num_classes_A = len(set(domain_A_mapping.values()))
num_classes_B = len(set(domain_B_mapping.values()))
print(f"Domain A classes: {num_classes_A}")
print(f"Domain B classes: {num_classes_B}")

In [None]:
# Configuration
TRAIN = False  # Set to True to train models from scratch


def train_domain_model(
    domain_mapping,
    domain_name,
    logger_name,
    model_name,
    train=None,
    max_epochs=50,
):
    """Train (or reload) a ResNet-50 for one domain and run ``Trainer.test`` on it.

    Args:
        domain_mapping: Mapping passed to ``Caltech256MappedDataModule``; the number
            of distinct values is used as the classifier's ``num_classes``.
        domain_name: Label used in the printed results line.
        logger_name: Run name for the TensorBoard logger (training runs only).
        model_name: Checkpoint file stem; the checkpoint lives at
            ``checkpoints/{model_name}.ckpt``.
        train: ``True`` trains from scratch, ``False`` loads the saved checkpoint.
            ``None`` (default) defers to the notebook-level ``TRAIN`` flag, which
            is the original behaviour.
        max_epochs: Upper bound on training epochs handed to the ``Trainer``.

    Returns:
        The value returned by ``Trainer.test`` for the evaluated model.
    """
    if train is None:
        train = TRAIN

    dataset = Caltech256MappedDataModule(mapping=domain_mapping)

    # A TensorBoard logger is only needed when a training run will write to it.
    logger = (
        pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)
        if train
        else False
    )

    # Keep only the checkpoint with the lowest val_loss, under a fixed file name
    # (no version suffix) so it can be reloaded by name later.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        filename=model_name,
        enable_version_counter=False,
    )

    trainer = Trainer(
        max_epochs=max_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
    )

    if train:
        model = ResNetModel(
            num_classes=len(set(domain_mapping.values())),
            architecture="resnet50",
            optim="sgd",
            optim_kwargs={
                "lr": 0.01,
                "momentum": 0.9,
                "weight_decay": 5e-4,
            },
        )
        trainer.fit(model, datamodule=dataset)
        # Evaluate the best checkpoint of this run, not the last-epoch weights.
        results = trainer.test(datamodule=dataset, ckpt_path="best")
    else:
        model = ResNetModel.load_from_checkpoint(f"checkpoints/{model_name}.ckpt")
        results = trainer.test(model, datamodule=dataset)

    print(f"{domain_name} Results: {results}")
    return results

In [None]:
# Domain A: train from scratch or evaluate the saved checkpoint, depending on TRAIN
print("Training/Testing Domain A Model:")
domain_A_results = train_domain_model(
    domain_mapping=domain_A_mapping,
    domain_name="Domain A",
    logger_name="caltech256_synthetic_A",
    model_name="resnet50-caltech256-synthetic-A-min-val-loss",
)

In [None]:
# Domain B: train from scratch or evaluate the saved checkpoint, depending on TRAIN
print("Training/Testing Domain B Model:")
domain_B_results = train_domain_model(
    domain_mapping=domain_B_mapping,
    domain_name="Domain B",
    logger_name="caltech256_synthetic_B",
    model_name="resnet50-caltech256-synthetic-B-min-val-loss",
)

In [None]:
# Configuration for prediction generation
PREDICT = False  # Set to True to generate predictions from scratch

if PREDICT:
    # One data module per domain mapping
    dataset_domain_A = Caltech256MappedDataModule(mapping=domain_A_mapping)
    dataset_domain_B = Caltech256MappedDataModule(mapping=domain_B_mapping)

    # Load trained models
    model_domain_A = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-synthetic-A-min-val-loss.ckpt"
    )
    model_domain_A.eval()

    model_domain_B = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-synthetic-B-min-val-loss.ckpt"
    )
    model_domain_B.eval()

    # Inference only: no logging and no checkpoint writing
    trainer = Trainer(logger=False, enable_checkpointing=False)

    # Generate cross-domain predictions: each model is run on the other
    # domain's data module
    print("Generating cross-domain predictions...")
    model_A_on_domain_B = trainer.predict(model_domain_A, datamodule=dataset_domain_B)
    model_B_on_domain_A = trainer.predict(model_domain_B, datamodule=dataset_domain_A)

    # Convert predictions to class indices (argmax over dim 1 of the
    # concatenated per-batch outputs)
    predictions_A_on_B = torch.cat(model_A_on_domain_B).argmax(dim=1)  # type: ignore
    predictions_B_on_A = torch.cat(model_B_on_domain_A).argmax(dim=1)  # type: ignore

    # Get ground truth targets
    # NOTE(review): this is a second pass over predict_dataloader(); rows only
    # line up with the predictions above if that loader yields samples in a
    # fixed order (no shuffling) — confirm in the data module.
    domain_A_targets = torch.cat(
        [label for _, label in dataset_domain_A.predict_dataloader()]
    )
    domain_B_targets = torch.cat(
        [label for _, label in dataset_domain_B.predict_dataloader()]
    )

    # Save predictions. Tensors are converted to NumPy explicitly so the CSV
    # holds plain integers regardless of how the installed pandas coerces
    # torch tensors.
    pd.DataFrame(
        {
            "domain_A": domain_A_targets.cpu().numpy(),
            "predictions_B_on_A": predictions_B_on_A.cpu().numpy(),
        }
    ).to_csv("data/caltech256_2domain_A_predictions.csv", index=False)

    pd.DataFrame(
        {
            "domain_B": domain_B_targets.cpu().numpy(),
            "predictions_A_on_B": predictions_A_on_B.cpu().numpy(),
        }
    ).to_csv("data/caltech256_2domain_B_predictions.csv", index=False)

    print("Predictions saved to CSV files.")

# Load prediction results (freshly written above, or cached from an earlier run)
df_A = pd.read_csv("data/caltech256_2domain_A_predictions.csv")
df_B = pd.read_csv("data/caltech256_2domain_B_predictions.csv")

print(f"Domain A predictions shape: {df_A.shape}")
print(f"Domain B predictions shape: {df_B.shape}")

In [None]:
# Cross-domain predictions as (first index, second index, predicted class indices);
# judging by the column names, (0, 1) pairs model A with domain B's data and
# (1, 0) pairs model B with domain A's data.
cross_predictions = [
    (0, 1, np.array(df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(df_A["predictions_B_on_A"], dtype=np.intp)),
]

# Ground-truth class indices per domain as (domain index, targets)
targets_per_domain = [
    (0, np.array(df_A["domain_A"], dtype=np.intp)),
    (1, np.array(df_B["domain_B"], dtype=np.intp)),
]

# Build the taxonomy from the cached cross-domain predictions
taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=cross_predictions,
    domain_targets=targets_per_domain,
    domain_labels=synthetic_taxonomy.domain_labels,
    relationship_type="hypothesis",
)

print("Taxonomy constructed from cross-domain predictions.")

In [None]:
# Render both taxonomies and write them out as HTML files
print("Generating taxonomy visualizations...")

# Taxonomy derived from model predictions (no explicit size passed)
model_graph = taxonomy.visualize_graph("2-Domain Caltech256 Model Taxonomy")
model_graph.save_graph("output/caltech256_2domain_synthetic_model_taxonomy.html")

# Synthetic ground truth, drawn on a 2000 x 2000 canvas
ground_truth_graph = synthetic_taxonomy.visualize_graph(
    "2-Domain Caltech256 Ground Truth Taxonomy",
    height=2000,
    width=2000,
)
ground_truth_graph.save_graph("output/caltech256_2domain_synthetic_taxonomy.html")

print("Taxonomy visualizations saved to output/ directory.")

In [None]:
# Score the prediction-derived taxonomy against the synthetic ground truth
edr = taxonomy.edge_difference_ratio(synthetic_taxonomy)
precision, recall, f1 = taxonomy.precision_recall_f1(synthetic_taxonomy)

# Assemble the report lines first, then emit them with one print call
evaluation_report = [
    "2-Domain Synthetic Taxonomy Evaluation:",
    f"Edge Difference Ratio: {edr:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1:.4f}",
]
print(*evaluation_report, sep="\n")

In [None]:
# Build the universal taxonomy for each graph, then render it to HTML
print("Building universal taxonomies...")

# NOTE(review): the return value is discarded, so this presumably updates
# `taxonomy` in place — re-running the evaluation cell afterwards would then
# score the universal graph; confirm.
taxonomy.build_universal_taxonomy()
model_universal_graph = taxonomy.visualize_graph(
    "2-Domain Caltech256 Model Universal Taxonomy"
)
model_universal_graph.save_graph(
    "output/caltech256_2domain_synthetic_model_universal_taxonomy.html"
)

synthetic_taxonomy.build_universal_taxonomy()
ground_truth_universal_graph = synthetic_taxonomy.visualize_graph(
    "2-Domain Caltech256 Ground Truth Universal Taxonomy"
)
ground_truth_universal_graph.save_graph(
    "output/caltech256_2domain_synthetic_universal_taxonomy.html"
)

print("Universal taxonomy visualizations saved to output/ directory.")

In [None]:
import pandas as pd
import numpy as np
import torch
from torchvision.datasets import Caltech256
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.taxonomy_constructors import SyntheticTaxonomy, CrossPredictionsTaxonomy
from library.models import ResNetModel
from library.datasets import Caltech256MappedDataModule

# Category names are read from the local Caltech256 copy (download is disabled)
caltech256_dataset = Caltech256(root="datasets/caltech256", download=False)
caltech256_labels = caltech256_dataset.categories

# Table read from data/caltech256.csv
caltech256_targets = pd.read_csv("data/caltech256.csv")

# Trade float32 matmul precision for speed
torch.set_float32_matmul_precision("medium")

In [None]:
# Ground-truth ("true") synthetic taxonomy with 2 domains (variant parameters).
# NOTE(review): no random seed is visible here. If generation is stochastic, the
# cached checkpoints / prediction CSVs used when TRAIN and PREDICT are False
# presumably only match the taxonomy they were produced with — confirm.
synthetic_taxonomy_kwargs = dict(
    num_atomic_concepts=257,
    num_domains=2,
    domain_class_count_mean=200,
    domain_class_count_variance=10,
    concept_cluster_size_mean=2,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)
synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    **synthetic_taxonomy_kwargs
)

# One mapping per domain, taken from the taxonomy's first two domains
domain_A_mapping = synthetic_taxonomy.domains[0].to_mapping()
domain_B_mapping = synthetic_taxonomy.domains[1].to_mapping()

# Number of distinct mapped values per domain
num_classes_A = len(set(domain_A_mapping.values()))
num_classes_B = len(set(domain_B_mapping.values()))
print(f"Domain A classes: {num_classes_A}")
print(f"Domain B classes: {num_classes_B}")

In [None]:
# Configuration
TRAIN = False  # Set to True to train models from scratch


def train_domain_model(
    domain_mapping,
    domain_name,
    logger_name,
    model_name,
    train=None,
    max_epochs=50,
):
    """Train (or reload) a ResNet-50 for one domain and run ``Trainer.test`` on it.

    Args:
        domain_mapping: Mapping passed to ``Caltech256MappedDataModule``; the number
            of distinct values is used as the classifier's ``num_classes``.
        domain_name: Label used in the printed results line.
        logger_name: Run name for the TensorBoard logger (training runs only).
        model_name: Checkpoint file stem; the checkpoint lives at
            ``checkpoints/{model_name}.ckpt``.
        train: ``True`` trains from scratch, ``False`` loads the saved checkpoint.
            ``None`` (default) defers to the notebook-level ``TRAIN`` flag, which
            is the original behaviour.
        max_epochs: Upper bound on training epochs handed to the ``Trainer``.

    Returns:
        The value returned by ``Trainer.test`` for the evaluated model.
    """
    if train is None:
        train = TRAIN

    dataset = Caltech256MappedDataModule(mapping=domain_mapping)

    # A TensorBoard logger is only needed when a training run will write to it.
    logger = (
        pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)
        if train
        else False
    )

    # Keep only the checkpoint with the lowest val_loss, under a fixed file name
    # (no version suffix) so it can be reloaded by name later.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        filename=model_name,
        enable_version_counter=False,
    )

    trainer = Trainer(
        max_epochs=max_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
    )

    if train:
        model = ResNetModel(
            num_classes=len(set(domain_mapping.values())),
            architecture="resnet50",
            optim="sgd",
            optim_kwargs={
                "lr": 0.01,
                "momentum": 0.9,
                "weight_decay": 5e-4,
            },
        )
        trainer.fit(model, datamodule=dataset)
        # Evaluate the best checkpoint of this run, not the last-epoch weights.
        results = trainer.test(datamodule=dataset, ckpt_path="best")
    else:
        model = ResNetModel.load_from_checkpoint(f"checkpoints/{model_name}.ckpt")
        results = trainer.test(model, datamodule=dataset)

    print(f"{domain_name} Results: {results}")
    return results

In [None]:
# Domain A (variant): train from scratch or evaluate the saved checkpoint,
# depending on TRAIN
print("Training/Testing Domain A Model:")
domain_A_results = train_domain_model(
    domain_mapping=domain_A_mapping,
    domain_name="Domain A",
    logger_name="caltech256_synthetic_variant_A",
    model_name="resnet50-caltech256-synthetic-variant-A-min-val-loss",
)

In [None]:
# Domain B (variant): train from scratch or evaluate the saved checkpoint,
# depending on TRAIN
print("Training/Testing Domain B Model:")
domain_B_results = train_domain_model(
    domain_mapping=domain_B_mapping,
    domain_name="Domain B",
    logger_name="caltech256_synthetic_variant_B",
    model_name="resnet50-caltech256-synthetic-variant-B-min-val-loss",
)

In [None]:
# Configuration for prediction generation
PREDICT = False  # Set to True to generate predictions from scratch

if PREDICT:
    # One data module per domain mapping
    dataset_domain_A = Caltech256MappedDataModule(mapping=domain_A_mapping)
    dataset_domain_B = Caltech256MappedDataModule(mapping=domain_B_mapping)

    # Load trained models
    model_domain_A = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-synthetic-variant-A-min-val-loss.ckpt"
    )
    model_domain_A.eval()

    model_domain_B = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-synthetic-variant-B-min-val-loss.ckpt"
    )
    model_domain_B.eval()

    # Inference only: no logging and no checkpoint writing
    trainer = Trainer(logger=False, enable_checkpointing=False)

    # Generate cross-domain predictions: each model is run on the other
    # domain's data module
    print("Generating cross-domain predictions...")
    model_A_on_domain_B = trainer.predict(model_domain_A, datamodule=dataset_domain_B)
    model_B_on_domain_A = trainer.predict(model_domain_B, datamodule=dataset_domain_A)

    # Convert predictions to class indices (argmax over dim 1 of the
    # concatenated per-batch outputs)
    predictions_A_on_B = torch.cat(model_A_on_domain_B).argmax(dim=1)  # type: ignore
    predictions_B_on_A = torch.cat(model_B_on_domain_A).argmax(dim=1)  # type: ignore

    # Get ground truth targets
    # NOTE(review): this is a second pass over predict_dataloader(); rows only
    # line up with the predictions above if that loader yields samples in a
    # fixed order (no shuffling) — confirm in the data module.
    domain_A_targets = torch.cat(
        [label for _, label in dataset_domain_A.predict_dataloader()]
    )
    domain_B_targets = torch.cat(
        [label for _, label in dataset_domain_B.predict_dataloader()]
    )

    # Save predictions. Tensors are converted to NumPy explicitly so the CSV
    # holds plain integers regardless of how the installed pandas coerces
    # torch tensors.
    pd.DataFrame(
        {
            "domain_A": domain_A_targets.cpu().numpy(),
            "predictions_B_on_A": predictions_B_on_A.cpu().numpy(),
        }
    ).to_csv("data/caltech256_2domain_variant_A_predictions.csv", index=False)

    pd.DataFrame(
        {
            "domain_B": domain_B_targets.cpu().numpy(),
            "predictions_A_on_B": predictions_A_on_B.cpu().numpy(),
        }
    ).to_csv("data/caltech256_2domain_variant_B_predictions.csv", index=False)

    print("Predictions saved to CSV files.")

# Load prediction results (freshly written above, or cached from an earlier run)
df_A = pd.read_csv("data/caltech256_2domain_variant_A_predictions.csv")
df_B = pd.read_csv("data/caltech256_2domain_variant_B_predictions.csv")

print(f"Domain A predictions shape: {df_A.shape}")
print(f"Domain B predictions shape: {df_B.shape}")

In [None]:
# Cross-domain predictions as (first index, second index, predicted class indices);
# judging by the column names, (0, 1) pairs model A with domain B's data and
# (1, 0) pairs model B with domain A's data.
cross_predictions = [
    (0, 1, np.array(df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(df_A["predictions_B_on_A"], dtype=np.intp)),
]

# Ground-truth class indices per domain as (domain index, targets)
targets_per_domain = [
    (0, np.array(df_A["domain_A"], dtype=np.intp)),
    (1, np.array(df_B["domain_B"], dtype=np.intp)),
]

# Build the taxonomy from the cached cross-domain predictions
taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=cross_predictions,
    domain_targets=targets_per_domain,
    domain_labels=synthetic_taxonomy.domain_labels,
    relationship_type="hypothesis",
)

print("Taxonomy constructed from cross-domain predictions.")

In [None]:
# Render both taxonomies and write them out as HTML files
print("Generating taxonomy visualizations...")

# Taxonomy derived from model predictions (no explicit size passed)
model_graph = taxonomy.visualize_graph("2-Domain Caltech256 Variant Model Taxonomy")
model_graph.save_graph(
    "output/caltech256_2domain_variant_synthetic_model_taxonomy.html"
)

# Synthetic ground truth, drawn on a 2000 x 2000 canvas
ground_truth_graph = synthetic_taxonomy.visualize_graph(
    "2-Domain Caltech256 Variant Ground Truth Taxonomy",
    height=2000,
    width=2000,
)
ground_truth_graph.save_graph(
    "output/caltech256_2domain_variant_synthetic_taxonomy.html"
)

print("Taxonomy visualizations saved to output/ directory.")

In [None]:
# Score the prediction-derived taxonomy against the synthetic ground truth
edr = taxonomy.edge_difference_ratio(synthetic_taxonomy)
precision, recall, f1 = taxonomy.precision_recall_f1(synthetic_taxonomy)

# Assemble the report lines first, then emit them with one print call
evaluation_report = [
    "2-Domain Synthetic Taxonomy Variant Evaluation:",
    f"Edge Difference Ratio: {edr:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1:.4f}",
]
print(*evaluation_report, sep="\n")

In [None]:
# Build the universal taxonomy for each graph, then render it to HTML
print("Building universal taxonomies...")

# NOTE(review): the return value is discarded, so this presumably updates
# `taxonomy` in place — re-running the evaluation cell afterwards would then
# score the universal graph; confirm.
taxonomy.build_universal_taxonomy()
model_universal_graph = taxonomy.visualize_graph(
    "2-Domain Caltech256 Variant Model Universal Taxonomy"
)
model_universal_graph.save_graph(
    "output/caltech256_2domain_variant_synthetic_model_universal_taxonomy.html"
)

synthetic_taxonomy.build_universal_taxonomy()
ground_truth_universal_graph = synthetic_taxonomy.visualize_graph(
    "2-Domain Caltech256 Variant Ground Truth Universal Taxonomy"
)
ground_truth_universal_graph.save_graph(
    "output/caltech256_2domain_variant_synthetic_universal_taxonomy.html"
)

print("Universal taxonomy visualizations saved to output/ directory.")

In [None]:
import pandas as pd
import numpy as np
import torch
from torchvision.datasets import Caltech256
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.taxonomy_constructors import SyntheticTaxonomy, CrossPredictionsTaxonomy
from library.models import ResNetModel
from library.datasets import Caltech256MappedDataModule

# Category names are read from the local Caltech256 copy (download is disabled)
caltech256_dataset = Caltech256(root="datasets/caltech256", download=False)
caltech256_labels = caltech256_dataset.categories

# Table read from data/caltech256.csv
caltech256_targets = pd.read_csv("data/caltech256.csv")

# Trade float32 matmul precision for speed
torch.set_float32_matmul_precision("medium")

In [None]:
# Ground-truth ("true") synthetic taxonomy with 3 domains.
# NOTE(review): no random seed is visible here. If generation is stochastic, the
# cached checkpoints / prediction CSVs used when TRAIN and PREDICT are False
# presumably only match the taxonomy they were produced with — confirm.
synthetic_taxonomy_kwargs = dict(
    num_atomic_concepts=257,
    num_domains=3,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=5,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=caltech256_labels,
    relationship_type="true",
)
synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    **synthetic_taxonomy_kwargs
)

# One mapping per domain, taken from the taxonomy's three domains
domain_A_mapping = synthetic_taxonomy.domains[0].to_mapping()
domain_B_mapping = synthetic_taxonomy.domains[1].to_mapping()
domain_C_mapping = synthetic_taxonomy.domains[2].to_mapping()

# Number of distinct mapped values per domain
num_classes_A = len(set(domain_A_mapping.values()))
num_classes_B = len(set(domain_B_mapping.values()))
num_classes_C = len(set(domain_C_mapping.values()))
print(f"Domain A classes: {num_classes_A}")
print(f"Domain B classes: {num_classes_B}")
print(f"Domain C classes: {num_classes_C}")

In [None]:
# Configuration
TRAIN = False  # Set to True to train models from scratch


def train_domain_model(
    domain_mapping,
    domain_name,
    logger_name,
    model_name,
    train=None,
    max_epochs=50,
):
    """Train (or reload) a ResNet-50 for one domain and run ``Trainer.test`` on it.

    Args:
        domain_mapping: Mapping passed to ``Caltech256MappedDataModule``; the number
            of distinct values is used as the classifier's ``num_classes``.
        domain_name: Label used in the printed results line.
        logger_name: Run name for the TensorBoard logger (training runs only).
        model_name: Checkpoint file stem; the checkpoint lives at
            ``checkpoints/{model_name}.ckpt``.
        train: ``True`` trains from scratch, ``False`` loads the saved checkpoint.
            ``None`` (default) defers to the notebook-level ``TRAIN`` flag, which
            is the original behaviour.
        max_epochs: Upper bound on training epochs handed to the ``Trainer``.

    Returns:
        The value returned by ``Trainer.test`` for the evaluated model.
    """
    if train is None:
        train = TRAIN

    dataset = Caltech256MappedDataModule(mapping=domain_mapping)

    # A TensorBoard logger is only needed when a training run will write to it.
    logger = (
        pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)
        if train
        else False
    )

    # Keep only the checkpoint with the lowest val_loss, under a fixed file name
    # (no version suffix) so it can be reloaded by name later.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        filename=model_name,
        enable_version_counter=False,
    )

    trainer = Trainer(
        max_epochs=max_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
    )

    if train:
        model = ResNetModel(
            num_classes=len(set(domain_mapping.values())),
            architecture="resnet50",
            optim="sgd",
            optim_kwargs={
                "lr": 0.01,
                "momentum": 0.9,
                "weight_decay": 5e-4,
            },
        )
        trainer.fit(model, datamodule=dataset)
        # Evaluate the best checkpoint of this run, not the last-epoch weights.
        results = trainer.test(datamodule=dataset, ckpt_path="best")
    else:
        model = ResNetModel.load_from_checkpoint(f"checkpoints/{model_name}.ckpt")
        results = trainer.test(model, datamodule=dataset)

    print(f"{domain_name} Results: {results}")
    return results

In [None]:
# Train from scratch or evaluate the saved checkpoint for each of the three
# domains, depending on TRAIN
print("Training/Testing Domain A Model:")
domain_A_results = train_domain_model(
    domain_mapping=domain_A_mapping,
    domain_name="Domain A",
    logger_name="caltech256_synthetic_2_A",
    model_name="resnet50-caltech256-synthetic-2-A-min-val-loss",
)

print("\nTraining/Testing Domain B Model:")
domain_B_results = train_domain_model(
    domain_mapping=domain_B_mapping,
    domain_name="Domain B",
    logger_name="caltech256_synthetic_2_B",
    model_name="resnet50-caltech256-synthetic-2-B-min-val-loss",
)

print("\nTraining/Testing Domain C Model:")
domain_C_results = train_domain_model(
    domain_mapping=domain_C_mapping,
    domain_name="Domain C",
    logger_name="caltech256_synthetic_2_C",
    model_name="resnet50-caltech256-synthetic-2-C-min-val-loss",
)

In [None]:
# Configuration for prediction generation
PREDICT = False  # Set to True to generate predictions from scratch

if PREDICT:
    # One data module per domain mapping
    dataset_domain_A = Caltech256MappedDataModule(mapping=domain_A_mapping)
    dataset_domain_B = Caltech256MappedDataModule(mapping=domain_B_mapping)
    dataset_domain_C = Caltech256MappedDataModule(mapping=domain_C_mapping)

    # Load trained models for all three domains
    model_domain_A = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-synthetic-2-A-min-val-loss.ckpt"
    )
    model_domain_A.eval()

    model_domain_B = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-synthetic-2-B-min-val-loss.ckpt"
    )
    model_domain_B.eval()

    model_domain_C = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-synthetic-2-C-min-val-loss.ckpt"
    )
    model_domain_C.eval()

    # Inference only: no logging and no checkpoint writing
    trainer = Trainer(logger=False, enable_checkpointing=False)

    print("Generating cross-domain predictions...")

    # Generate all cross-domain predictions (6 combinations): every model is
    # run on the data modules of the two other domains
    model_A_on_domain_B = trainer.predict(model_domain_A, datamodule=dataset_domain_B)
    model_A_on_domain_C = trainer.predict(model_domain_A, datamodule=dataset_domain_C)
    model_B_on_domain_A = trainer.predict(model_domain_B, datamodule=dataset_domain_A)
    model_B_on_domain_C = trainer.predict(model_domain_B, datamodule=dataset_domain_C)
    model_C_on_domain_A = trainer.predict(model_domain_C, datamodule=dataset_domain_A)
    model_C_on_domain_B = trainer.predict(model_domain_C, datamodule=dataset_domain_B)

    # Convert predictions to class indices (argmax over dim 1 of the
    # concatenated per-batch outputs)
    predictions_A_on_B = torch.cat(model_A_on_domain_B).argmax(dim=1)  # type: ignore
    predictions_A_on_C = torch.cat(model_A_on_domain_C).argmax(dim=1)  # type: ignore
    predictions_B_on_A = torch.cat(model_B_on_domain_A).argmax(dim=1)  # type: ignore
    predictions_B_on_C = torch.cat(model_B_on_domain_C).argmax(dim=1)  # type: ignore
    predictions_C_on_A = torch.cat(model_C_on_domain_A).argmax(dim=1)  # type: ignore
    predictions_C_on_B = torch.cat(model_C_on_domain_B).argmax(dim=1)  # type: ignore

    # Get ground truth targets for all domains
    # NOTE(review): this is a second pass over predict_dataloader(); rows only
    # line up with the predictions above if that loader yields samples in a
    # fixed order (no shuffling) — confirm in the data module.
    domain_A_targets = torch.cat(
        [label for _, label in dataset_domain_A.predict_dataloader()]
    )
    domain_B_targets = torch.cat(
        [label for _, label in dataset_domain_B.predict_dataloader()]
    )
    domain_C_targets = torch.cat(
        [label for _, label in dataset_domain_C.predict_dataloader()]
    )

    # Save all prediction results, one CSV per data domain. Tensors are
    # converted to NumPy explicitly so the CSV holds plain integers regardless
    # of how the installed pandas coerces torch tensors.
    pd.DataFrame(
        {
            "domain_A": domain_A_targets.cpu().numpy(),
            "predictions_B_on_A": predictions_B_on_A.cpu().numpy(),
            "predictions_C_on_A": predictions_C_on_A.cpu().numpy(),
        }
    ).to_csv("data/caltech256_3domain_A_predictions.csv", index=False)

    pd.DataFrame(
        {
            "domain_B": domain_B_targets.cpu().numpy(),
            "predictions_A_on_B": predictions_A_on_B.cpu().numpy(),
            "predictions_C_on_B": predictions_C_on_B.cpu().numpy(),
        }
    ).to_csv("data/caltech256_3domain_B_predictions.csv", index=False)

    pd.DataFrame(
        {
            "domain_C": domain_C_targets.cpu().numpy(),
            "predictions_A_on_C": predictions_A_on_C.cpu().numpy(),
            "predictions_B_on_C": predictions_B_on_C.cpu().numpy(),
        }
    ).to_csv("data/caltech256_3domain_C_predictions.csv", index=False)

    print("All predictions saved to CSV files.")

# Load the prediction results (freshly written above, or cached from an earlier run)
df_A_3domain = pd.read_csv("data/caltech256_3domain_A_predictions.csv")
df_B_3domain = pd.read_csv("data/caltech256_3domain_B_predictions.csv")
df_C_3domain = pd.read_csv("data/caltech256_3domain_C_predictions.csv")

print(f"Domain A predictions shape: {df_A_3domain.shape}")
print(f"Domain B predictions shape: {df_B_3domain.shape}")
print(f"Domain C predictions shape: {df_C_3domain.shape}")

In [None]:
# All ordered domain pairs (3 domains = 6 prediction pairs); each tuple holds the
# two domain indices of the pair followed by the predicted class indices
cross_predictions_3domain = [
    # Domain A → Domain B
    (0, 1, np.array(df_B_3domain["predictions_A_on_B"], dtype=np.intp)),
    # Domain A → Domain C
    (0, 2, np.array(df_C_3domain["predictions_A_on_C"], dtype=np.intp)),
    # Domain B → Domain A
    (1, 0, np.array(df_A_3domain["predictions_B_on_A"], dtype=np.intp)),
    # Domain B → Domain C
    (1, 2, np.array(df_C_3domain["predictions_B_on_C"], dtype=np.intp)),
    # Domain C → Domain A
    (2, 0, np.array(df_A_3domain["predictions_C_on_A"], dtype=np.intp)),
    # Domain C → Domain B
    (2, 1, np.array(df_B_3domain["predictions_C_on_B"], dtype=np.intp)),
]

# Ground-truth class indices per domain as (domain index, targets)
targets_3domain = [
    (0, np.array(df_A_3domain["domain_A"], dtype=np.intp)),
    (1, np.array(df_B_3domain["domain_B"], dtype=np.intp)),
    (2, np.array(df_C_3domain["domain_C"], dtype=np.intp)),
]

# Build the taxonomy from the cached cross-domain predictions
taxonomy_3domain = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=cross_predictions_3domain,
    domain_targets=targets_3domain,
    domain_labels=synthetic_taxonomy.domain_labels,
    relationship_type="hypothesis",
)

print("3-domain taxonomy constructed from cross-domain predictions.")

In [None]:
# Generate and save taxonomy visualizations
print("Generating 3-domain taxonomy visualizations...")

# Taxonomy derived from model predictions, written out as HTML
taxonomy_3domain.visualize_graph("3-Domain Caltech256 Model Taxonomy").save_graph(
    "output/caltech256_3domain_synthetic_model_taxonomy.html"
)

# Synthetic ground truth, written out as HTML
synthetic_taxonomy.visualize_graph(
    "3-Domain Caltech256 Ground Truth Taxonomy"
).save_graph("output/caltech256_3domain_synthetic_taxonomy.html")

# Keep the figure itself in `fig`; previously `fig` captured the return value of
# `.show()` instead of the figure. The 3D view is only displayed, not saved.
fig = synthetic_taxonomy.visualize_3d_graph(
    show_labels=False,  # Labels make graph unreadable in 3D
)
fig.show()

print("3-domain taxonomy visualizations saved to output/ directory.")

In [None]:
# Score the prediction-derived 3-domain taxonomy against the synthetic ground truth
edr_3domain = taxonomy_3domain.edge_difference_ratio(synthetic_taxonomy)
precision_3domain, recall_3domain, f1_3domain = (
    taxonomy_3domain.precision_recall_f1(synthetic_taxonomy)
)

# Assemble the report lines first, then emit them with one print call
evaluation_report_3domain = [
    "3-Domain Synthetic Taxonomy Evaluation:",
    f"Edge Difference Ratio: {edr_3domain:.4f}",
    f"Precision: {precision_3domain:.4f}",
    f"Recall: {recall_3domain:.4f}",
    f"F1 Score: {f1_3domain:.4f}",
]
print(*evaluation_report_3domain, sep="\n")

In [None]:
# Build the universal taxonomy for each 3-domain graph, then render it to HTML
print("Building 3-domain universal taxonomies...")

# NOTE(review): the return value is discarded, so this presumably updates
# `taxonomy_3domain` in place — re-running the evaluation cell afterwards would
# then score the universal graph; confirm.
taxonomy_3domain.build_universal_taxonomy()
model_universal_graph_3domain = taxonomy_3domain.visualize_graph(
    "3-Domain Caltech256 Model Universal Taxonomy"
)
model_universal_graph_3domain.save_graph(
    "output/caltech256_3domain_synthetic_model_universal_taxonomy.html"
)

synthetic_taxonomy.build_universal_taxonomy()
ground_truth_universal_graph_3domain = synthetic_taxonomy.visualize_graph(
    "3-Domain Caltech256 Ground Truth Universal Taxonomy"
)
ground_truth_universal_graph_3domain.save_graph(
    "output/caltech256_3domain_synthetic_universal_taxonomy.html"
)

print("3-domain universal taxonomy visualizations saved to output/ directory.")

In [None]:
import pandas as pd
import numpy as np
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.taxonomy_constructors import SyntheticTaxonomy, CrossPredictionsTaxonomy
from library.models import ResNetModel
from library.datasets import CIFAR100MappedDataModule, CIFAR100Mapped

# Class names are read from the local CIFAR100 copy (download is disabled)
cifar100_dataset = CIFAR100Mapped(root="datasets/cifar100", download=False)
cifar100_labels = cifar100_dataset.classes
num_cifar100_classes = len(cifar100_labels)
print(f"Total CIFAR100 classes: {num_cifar100_classes}")

# Lookup table pairing each class id (0 to num_classes-1) with its class name
cifar100_targets = pd.DataFrame(
    {
        "class_id": range(num_cifar100_classes),
        "class_name": cifar100_labels,
    }
)

# Trade float32 matmul precision for speed
torch.set_float32_matmul_precision("medium")

In [None]:
# Ground-truth ("true") synthetic taxonomy with 2 domains for CIFAR100.
# NOTE(review): no random seed is visible here. If generation is stochastic, the
# cached checkpoints / prediction CSVs used when TRAIN and PREDICT are False
# presumably only match the taxonomy they were produced with — confirm.
synthetic_taxonomy_kwargs = dict(
    num_atomic_concepts=len(cifar100_labels),
    num_domains=2,
    domain_class_count_mean=50,
    domain_class_count_variance=5,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
    no_prediction_class=True,
    atomic_concept_labels=cifar100_labels,
    relationship_type="true",
)
synthetic_taxonomy = SyntheticTaxonomy.create_synthetic_taxonomy(
    **synthetic_taxonomy_kwargs
)

# One mapping per domain, taken from the taxonomy's first two domains
domain_A_mapping = synthetic_taxonomy.domains[0].to_mapping()
domain_B_mapping = synthetic_taxonomy.domains[1].to_mapping()

# Distinct mapped values per domain, followed by the number of mapping entries
num_classes_A = len(set(domain_A_mapping.values()))
num_classes_B = len(set(domain_B_mapping.values()))
num_entries_A = len(domain_A_mapping)
num_entries_B = len(domain_B_mapping)
print(f"Domain A classes: {num_classes_A}")
print(f"Domain B classes: {num_classes_B}")
print(f"Domain A original classes: {num_entries_A}")
print(f"Domain B original classes: {num_entries_B}")

In [None]:
# Configuration
TRAIN = False  # Set to True to train models from scratch


def train_domain_model(
    domain_mapping,
    domain_name,
    logger_name,
    model_name,
    train=None,
    max_epochs=100,
    batch_size=256,
):
    """Train (or reload) a ResNet-50 on mapped CIFAR100 and run ``Trainer.test`` on it.

    Training uses AdamW with weight decay, gradient-norm clipping and a multi-step
    learning-rate schedule (the original docstring describes this setup as
    overfitting mitigation).

    Args:
        domain_mapping: Mapping passed to ``CIFAR100MappedDataModule``; the number
            of distinct values is used as the classifier's ``num_classes``.
        domain_name: Label used in the printed results line.
        logger_name: Run name for the TensorBoard logger (training runs only).
        model_name: Checkpoint file stem; the checkpoint lives at
            ``checkpoints/{model_name}.ckpt``.
        train: ``True`` trains from scratch, ``False`` loads the saved checkpoint.
            ``None`` (default) defers to the notebook-level ``TRAIN`` flag, which
            is the original behaviour.
        max_epochs: Upper bound on training epochs handed to the ``Trainer``. The
            scheduler milestones below are fixed and are not rescaled with it.
        batch_size: Batch size passed to the data module.

    Returns:
        The value returned by ``Trainer.test`` for the evaluated model.
    """
    if train is None:
        train = TRAIN

    dataset = CIFAR100MappedDataModule(mapping=domain_mapping, batch_size=batch_size)

    # A TensorBoard logger is only needed when a training run will write to it.
    logger = (
        pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)
        if train
        else False
    )

    callbacks = [
        # Keep only the checkpoint with the lowest val_loss, under a fixed file
        # name (no version suffix) so it can be reloaded by name later.
        ModelCheckpoint(
            dirpath="checkpoints",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
            filename=model_name,
            enable_version_counter=False,
        ),
        LearningRateMonitor(logging_interval="epoch"),
    ]

    trainer = Trainer(
        max_epochs=max_epochs,
        logger=logger,
        callbacks=callbacks,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
    )

    if train:
        model = ResNetModel(
            num_classes=len(set(domain_mapping.values())),
            architecture="resnet50",
            optim="adamw",
            optim_kwargs={
                "lr": 0.001,
                "weight_decay": 1e-3,
            },
            # Multi-step learning rate schedule: gamma 0.1 at the listed milestones
            lr_scheduler="multistep",
            lr_scheduler_kwargs={
                "milestones": [30, 60, 80],
                "gamma": 0.1,
            },
        )
        trainer.fit(model, datamodule=dataset)
        # Evaluate the best checkpoint of this run, not the last-epoch weights.
        results = trainer.test(datamodule=dataset, ckpt_path="best")
    else:
        model = ResNetModel.load_from_checkpoint(f"checkpoints/{model_name}.ckpt")
        results = trainer.test(model, datamodule=dataset)

    print(f"{domain_name} Results: {results}")
    return results

In [None]:
# Domain A: train from scratch or evaluate the saved checkpoint, depending on TRAIN
print("Training/Testing Domain A Model:")
domain_A_results = train_domain_model(
    domain_mapping=domain_A_mapping,
    domain_name="Domain A",
    logger_name="cifar100_synthetic_A",
    model_name="resnet50-cifar100-synthetic-A-min-val-loss",
)

In [None]:
# Domain B: train from scratch or evaluate the saved checkpoint, depending on TRAIN
print("Training/Testing Domain B Model:")
domain_B_results = train_domain_model(
    domain_mapping=domain_B_mapping,
    domain_name="Domain B",
    logger_name="cifar100_synthetic_B",
    model_name="resnet50-cifar100-synthetic-B-min-val-loss",
)

In [None]:
# Configuration for prediction generation
PREDICT = False  # Set to True to generate predictions from scratch

if PREDICT:
    # One data module per domain mapping
    dataset_domain_A = CIFAR100MappedDataModule(
        mapping=domain_A_mapping, batch_size=128
    )
    dataset_domain_B = CIFAR100MappedDataModule(
        mapping=domain_B_mapping, batch_size=128
    )

    # Load trained models
    model_domain_A = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-cifar100-synthetic-A-min-val-loss.ckpt"
    )
    model_domain_A.eval()

    model_domain_B = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-cifar100-synthetic-B-min-val-loss.ckpt"
    )
    model_domain_B.eval()

    # Inference only: no logging and no checkpoint writing
    trainer = Trainer(logger=False, enable_checkpointing=False)

    # Generate cross-domain predictions: each model is run on the other
    # domain's data module
    print("Generating cross-domain predictions...")
    model_A_on_domain_B = trainer.predict(model_domain_A, datamodule=dataset_domain_B)
    model_B_on_domain_A = trainer.predict(model_domain_B, datamodule=dataset_domain_A)

    # Convert predictions to class indices (argmax over dim 1 of the
    # concatenated per-batch outputs)
    predictions_A_on_B = torch.cat(model_A_on_domain_B).argmax(dim=1)  # type: ignore
    predictions_B_on_A = torch.cat(model_B_on_domain_A).argmax(dim=1)  # type: ignore

    # Get ground truth targets
    # NOTE(review): this is a second pass over predict_dataloader(); rows only
    # line up with the predictions above if that loader yields samples in a
    # fixed order (no shuffling) — confirm in the data module.
    domain_A_targets = torch.cat(
        [label for _, label in dataset_domain_A.predict_dataloader()]
    )
    domain_B_targets = torch.cat(
        [label for _, label in dataset_domain_B.predict_dataloader()]
    )

    # Save predictions. Tensors are converted to NumPy explicitly so the CSV
    # holds plain integers regardless of how the installed pandas coerces
    # torch tensors.
    pd.DataFrame(
        {
            "domain_A": domain_A_targets.cpu().numpy(),
            "predictions_B_on_A": predictions_B_on_A.cpu().numpy(),
        }
    ).to_csv("data/cifar100_2domain_A_predictions.csv", index=False)

    pd.DataFrame(
        {
            "domain_B": domain_B_targets.cpu().numpy(),
            "predictions_A_on_B": predictions_A_on_B.cpu().numpy(),
        }
    ).to_csv("data/cifar100_2domain_B_predictions.csv", index=False)

    print("Predictions saved to CSV files.")

# Load prediction results (freshly written above, or cached from an earlier run)
df_A = pd.read_csv("data/cifar100_2domain_A_predictions.csv")
df_B = pd.read_csv("data/cifar100_2domain_B_predictions.csv")

print(f"Domain A predictions shape: {df_A.shape}")
print(f"Domain B predictions shape: {df_B.shape}")

In [None]:
# Cross-domain predictions as (first index, second index, predicted class indices);
# judging by the column names, (0, 1) pairs model A with domain B's data and
# (1, 0) pairs model B with domain A's data.
cross_predictions = [
    (0, 1, np.array(df_B["predictions_A_on_B"], dtype=np.intp)),
    (1, 0, np.array(df_A["predictions_B_on_A"], dtype=np.intp)),
]

# Ground-truth class indices per domain as (domain index, targets)
targets_per_domain = [
    (0, np.array(df_A["domain_A"], dtype=np.intp)),
    (1, np.array(df_B["domain_B"], dtype=np.intp)),
]

# Build the taxonomy using the "density_threshold" relationship type at 0.6
taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=cross_predictions,
    domain_targets=targets_per_domain,
    domain_labels=synthetic_taxonomy.domain_labels,
    relationship_type="density_threshold",
    threshold=0.6,
)

print("Taxonomy constructed from cross-domain predictions.")

In [None]:
# Render both taxonomies and write them out as HTML files
print("Generating taxonomy visualizations...")

# Taxonomy derived from model predictions (no explicit size passed)
model_graph = taxonomy.visualize_graph("2-Domain CIFAR100 Model Taxonomy")
model_graph.save_graph("output/cifar100_2domain_synthetic_model_taxonomy.html")

# Synthetic ground truth, drawn on a 2000 x 2000 canvas
ground_truth_graph = synthetic_taxonomy.visualize_graph(
    "2-Domain CIFAR100 Ground Truth Taxonomy",
    height=2000,
    width=2000,
)
ground_truth_graph.save_graph("output/cifar100_2domain_synthetic_taxonomy.html")

print("Taxonomy visualizations saved to output/ directory.")

In [None]:
# Score the prediction-derived taxonomy against the synthetic ground truth
edr = taxonomy.edge_difference_ratio(synthetic_taxonomy)
precision, recall, f1 = taxonomy.precision_recall_f1(synthetic_taxonomy)

# Assemble the report lines first, then emit them with one print call
evaluation_report = [
    "2-Domain CIFAR100 Taxonomy Evaluation:",
    f"Edge Difference Ratio: {edr:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1:.4f}",
]
print(*evaluation_report, sep="\n")

In [None]:
# Build the universal taxonomy for each graph, then render it to HTML
print("Building universal taxonomies...")

# NOTE(review): the return value is discarded, so this presumably updates
# `taxonomy` in place — re-running the evaluation cell afterwards would then
# score the universal graph; confirm.
taxonomy.build_universal_taxonomy()
model_universal_graph = taxonomy.visualize_graph(
    "2-Domain CIFAR100 Model Universal Taxonomy"
)
model_universal_graph.save_graph(
    "output/cifar100_2domain_synthetic_model_universal_taxonomy.html"
)

synthetic_taxonomy.build_universal_taxonomy()
ground_truth_universal_graph = synthetic_taxonomy.visualize_graph(
    "2-Domain CIFAR100 Ground Truth Universal Taxonomy"
)
ground_truth_universal_graph.save_graph(
    "output/cifar100_2domain_synthetic_universal_taxonomy.html"
)

print("Universal taxonomy visualizations saved to output/ directory.")