In [None]:
import os
import warnings

import pandas as pd
import numpy as np
import torch
from torchvision.datasets import Caltech256, Caltech101, CIFAR100
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.taxonomy_constructors import CrossPredictionsTaxonomy
from library.models.resnet import ResNetModel, _Architecture
from library.datasets.caltech256 import Caltech256DataModule
from library.datasets.caltech101 import Caltech101DataModule
from library.datasets.cifar100 import CIFAR100DataModule

# Class-name lists for each dataset, read from the local torchvision copies.
# download=False, so every dataset must already exist under datasets/.
caltech256_labels = Caltech256(root="datasets/caltech256", download=False).categories
caltech101_labels = Caltech101(root="datasets/caltech101", download=False).categories
cifar100_labels = CIFAR100(root="datasets/cifar100", train=False, download=False).classes

# One line per dataset with its number of classes
print(
    f"Caltech-256 classes: {len(caltech256_labels)}",
    f"Caltech-101 classes: {len(caltech101_labels)}",
    f"CIFAR-100 classes: {len(cifar100_labels)}",
    sep="\n",
)

# Trade float32 matmul precision for speed (allows faster internal precision)
torch.set_float32_matmul_precision("medium")

Caltech-256 classes: 257
Caltech-101 classes: 101
CIFAR-100 classes: 100


In [2]:
# Configuration
TRAIN = False  # Set to True to train models from scratch


def train_dataset_model(
    dataset_module,
    dataset_name,
    logger_name,
    model_name,
    num_classes,
    training_config,
    architecture: _Architecture = "resnet50",
    train: bool | None = None,
):
    """Train (or reload) a ResNet model for one dataset and evaluate it on its test split.

    When training, a fresh ``ResNetModel`` is fit, the checkpoint with the lowest
    ``val_loss`` is written to ``checkpoints/<model_name>.ckpt``, and that best
    checkpoint is tested. Otherwise the existing checkpoint at that path is loaded
    and tested directly.

    Args:
        dataset_module: Lightning datamodule passed to ``Trainer.fit``/``Trainer.test``.
        dataset_name: Human-readable dataset name (used in warning messages).
        logger_name: TensorBoard run name under ``logs/``; only used when training.
        model_name: Checkpoint file stem inside ``checkpoints/``.
        num_classes: Passed to ``ResNetModel`` as ``num_classes`` when training.
        training_config: Dict providing ``max_epochs``, ``optim`` and ``optim_kwargs``.
            An ``addition`` entry (e.g. LR-scheduler settings) is NOT forwarded to
            the model; a warning is emitted when training with one present.
        architecture: ResNet variant built when training.
        train: Train from scratch if True, reload the checkpoint if False.
            Defaults to the notebook-level ``TRAIN`` flag when None.

    Returns:
        The list of metric dicts returned by ``Trainer.test``.
    """
    if train is None:
        train = TRAIN

    # Nothing below consumes training_config["addition"]; make that visible
    # instead of silently training without the configured scheduler.
    if train and training_config.get("addition"):
        warnings.warn(
            f"{dataset_name}: training_config['addition'] is not passed to "
            "ResNetModel and will be ignored.",
            stacklevel=2,
        )

    # Keeps only the single best checkpoint (lowest val_loss) under a fixed name
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        filename=model_name,
        enable_version_counter=False,
    )

    trainer = Trainer(
        max_epochs=training_config["max_epochs"],
        # A logger is only needed when a run is actually trained
        logger=(
            pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)
            if train
            else False
        ),
        callbacks=[checkpoint_callback],
    )

    if train:
        model = ResNetModel(
            num_classes=num_classes,
            architecture=architecture,
            optim=training_config["optim"],
            optim_kwargs=training_config["optim_kwargs"],
        )
        trainer.fit(model, datamodule=dataset_module)
        # Evaluate the best (min val_loss) checkpoint rather than the last epoch
        return trainer.test(datamodule=dataset_module, ckpt_path="best")

    model = ResNetModel.load_from_checkpoint(f"checkpoints/{model_name}.ckpt")
    return trainer.test(model, datamodule=dataset_module)

In [3]:
# Evaluate the Caltech-256 ResNet-50 model (training it first when TRAIN is True).
caltech256_dataset = Caltech256DataModule()
caltech256_config = {
    "max_epochs": 50,
    "optim": "adamw",
    "optim_kwargs": {"lr": 0.001, "weight_decay": 0.001},
    # NOTE(review): train_dataset_model does not forward this entry to
    # ResNetModel, so the scheduler is currently unused. If it is wired up, with
    # max_epochs=50 the milestones 50 and 80 would not affect training
    # (assuming per-epoch scheduler steps).
    "addition": {
        "lr_scheduler": "multistep",
        "lr_scheduler_kwargs": {"milestones": [20, 50, 80], "gamma": 0.1},
    },
}
caltech256_results = train_dataset_model(
    dataset_module=caltech256_dataset,
    dataset_name="Caltech-256",
    logger_name="caltech256_three_domain",
    model_name="resnet50-caltech256-min-val-loss",
    num_classes=len(caltech256_labels),
    training_config=caltech256_config,
)
print(f"Caltech-256 results: {caltech256_results}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 48/48 [00:23<00:00,  2.06it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.6996731758117676
        eval_loss           1.6287411451339722
        hp_metric           0.6996731758117676
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Caltech-256 results: [{'eval_loss': 1.6287411451339722, 'eval_accuracy': 0.6996731758117676, 'hp_metric': 0.6996731758117676}]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
─────────────────────────────────────────────────────────────────────────────────────────

In [4]:
# Evaluate the Caltech-101 ResNet-50 model (training it first when TRAIN is True).
caltech101_dataset = Caltech101DataModule()
caltech101_config = {
    "max_epochs": 50,
    "optim": "sgd",
    "optim_kwargs": {"lr": 0.01, "momentum": 0.9, "weight_decay": 5e-4},
    # NOTE(review): train_dataset_model does not forward this entry to
    # ResNetModel, so the scheduler is currently unused. If it is wired up, with
    # max_epochs=50 the milestones 50 and 80 would not affect training
    # (assuming per-epoch scheduler steps).
    "addition": {
        "lr_scheduler": "multistep",
        "lr_scheduler_kwargs": {"milestones": [20, 50, 80], "gamma": 0.1},
    },
}
caltech101_results = train_dataset_model(
    dataset_module=caltech101_dataset,
    dataset_name="Caltech-101",
    logger_name="caltech101_three_domain",
    model_name="resnet50-caltech101-min-val-loss",
    num_classes=len(caltech101_labels),
    training_config=caltech101_config,
)
print(f"Caltech-101 results: {caltech101_results}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 14/14 [00:03<00:00,  4.01it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.9169549942016602
        eval_loss           0.3228408694267273
        hp_metric           0.9169549942016602
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Caltech-101 results: [{'eval_loss': 0.3228408694267273, 'eval_accuracy': 0.9169549942016602, 'hp_metric': 0.9169549942016602}]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
─────────────────────────────────────────────────────────────────────────────────────────

In [5]:
# Evaluate the CIFAR-100 ResNet-152 model (training it first when TRAIN is True).
cifar100_dataset = CIFAR100DataModule()
cifar100_config = {
    "max_epochs": 50,
    "optim": "sgd",
    "optim_kwargs": {"lr": 0.01, "momentum": 0.9, "weight_decay": 5e-4},
    # NOTE(review): train_dataset_model does not forward this entry to
    # ResNetModel, so the scheduler is currently unused. If it is wired up, with
    # max_epochs=50 the milestones 50 and 80 would not affect training
    # (assuming per-epoch scheduler steps).
    "addition": {
        "lr_scheduler": "multistep",
        "lr_scheduler_kwargs": {"milestones": [20, 50, 80], "gamma": 0.1},
    },
}
cifar100_results = train_dataset_model(
    dataset_module=cifar100_dataset,
    dataset_name="CIFAR-100",
    logger_name="cifar100_three_domain",
    model_name="resnet152-cifar100-min-val-loss",
    num_classes=len(cifar100_labels),
    training_config=cifar100_config,
    architecture="resnet152",
)
print(f"CIFAR-100 results: {cifar100_results}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 40/40 [00:02<00:00, 17.12it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.6047000288963318
        eval_loss           1.9386216402053833
        hp_metric           0.6047000288963318
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.6047000288963318
        eval_loss           1.9386216402053833
 

In [6]:
# Generate cross-domain predictions
PREDICT = False

if PREDICT:
    # Load trained models
    model_caltech256 = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-min-val-loss.ckpt"
    )
    model_caltech256.eval()

    model_caltech101 = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech101-min-val-loss.ckpt"
    )
    model_caltech101.eval()

    model_cifar100 = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet152-cifar100-min-val-loss.ckpt"
    )
    model_cifar100.eval()

    trainer = Trainer(logger=False, enable_checkpointing=False)

    # Cross-domain predictions
    # Caltech-256 on other domains
    caltech256_on_caltech101 = trainer.predict(
        model_caltech256, datamodule=caltech101_dataset
    )
    caltech256_on_cifar100 = trainer.predict(
        model_caltech256, datamodule=cifar100_dataset
    )

    # Caltech-101 on other domains
    caltech101_on_caltech256 = trainer.predict(
        model_caltech101, datamodule=caltech256_dataset
    )
    caltech101_on_cifar100 = trainer.predict(
        model_caltech101, datamodule=cifar100_dataset
    )

    # CIFAR-100 on other domains
    cifar100_on_caltech256 = trainer.predict(
        model_cifar100, datamodule=caltech256_dataset
    )
    cifar100_on_caltech101 = trainer.predict(
        model_cifar100, datamodule=caltech101_dataset
    )

    # Convert predictions to class indices
    predictions_caltech256_on_caltech101 = torch.cat(caltech256_on_caltech101).argmax(dim=1)  # type: ignore
    predictions_caltech256_on_cifar100 = torch.cat(caltech256_on_cifar100).argmax(dim=1)  # type: ignore
    predictions_caltech101_on_caltech256 = torch.cat(caltech101_on_caltech256).argmax(dim=1)  # type: ignore
    predictions_caltech101_on_cifar100 = torch.cat(caltech101_on_cifar100).argmax(dim=1)  # type: ignore
    predictions_cifar100_on_caltech256 = torch.cat(cifar100_on_caltech256).argmax(dim=1)  # type: ignore
    predictions_cifar100_on_caltech101 = torch.cat(cifar100_on_caltech101).argmax(dim=1)  # type: ignore

    # Get ground truth targets
    caltech101_targets = torch.cat(
        [label for _, label in caltech101_dataset.predict_dataloader()]
    )
    caltech256_targets = torch.cat(
        [label for _, label in caltech256_dataset.predict_dataloader()]
    )
    cifar100_targets = torch.cat(
        [label for _, label in cifar100_dataset.predict_dataloader()]
    )

    # Save predictions for Caltech-101
    pd.DataFrame(
        {
            "caltech101": caltech101_targets,
            "predictions_caltech256_on_caltech101": predictions_caltech256_on_caltech101,
            "predictions_cifar100_on_caltech101": predictions_cifar100_on_caltech101,
        }
    ).to_csv("data/three_domain_caltech101_predictions.csv", index=False)

    # Save predictions for Caltech-256
    pd.DataFrame(
        {
            "caltech256": caltech256_targets,
            "predictions_caltech101_on_caltech256": predictions_caltech101_on_caltech256,
            "predictions_cifar100_on_caltech256": predictions_cifar100_on_caltech256,
        }
    ).to_csv("data/three_domain_caltech256_predictions.csv", index=False)

    # Save predictions for CIFAR-100
    pd.DataFrame(
        {
            "cifar100": cifar100_targets,
            "predictions_caltech256_on_cifar100": predictions_caltech256_on_cifar100,
            "predictions_caltech101_on_cifar100": predictions_caltech101_on_cifar100,
        }
    ).to_csv("data/three_domain_cifar100_predictions.csv", index=False)

In [7]:
# Reload the cross-domain prediction tables (one CSV per evaluated dataset).
df_caltech101 = pd.read_csv("data/three_domain_caltech101_predictions.csv")
df_caltech256 = pd.read_csv("data/three_domain_caltech256_predictions.csv")
df_cifar100 = pd.read_csv("data/three_domain_cifar100_predictions.csv")

# Report the row count of each table, one line per dataset
print(
    f"Loaded predictions for {len(df_caltech101)} Caltech-101 samples",
    f"Loaded predictions for {len(df_caltech256)} Caltech-256 samples",
    f"Loaded predictions for {len(df_cifar100)} CIFAR-100 samples",
    sep="\n",
)

Loaded predictions for 867 Caltech-101 samples
Loaded predictions for 3060 Caltech-256 samples
Loaded predictions for 10000 CIFAR-100 samples


In [8]:
# Build the three-domain "hypothesis" taxonomy from the cross-domain predictions.
# Domain indices: 0 = Caltech-101, 1 = Caltech-256, 2 = CIFAR-100.
# Each prediction entry is (model domain, data domain, predicted class indices of
# the model-domain classifier on the data-domain samples, as np.intp).
taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=[
        (model_domain, data_domain, np.array(frame[column], dtype=np.intp))
        for model_domain, data_domain, frame, column in (
            (0, 1, df_caltech256, "predictions_caltech101_on_caltech256"),
            (0, 2, df_cifar100, "predictions_caltech101_on_cifar100"),
            (1, 0, df_caltech101, "predictions_caltech256_on_caltech101"),
            (1, 2, df_cifar100, "predictions_caltech256_on_cifar100"),
            (2, 0, df_caltech101, "predictions_cifar100_on_caltech101"),
            (2, 1, df_caltech256, "predictions_cifar100_on_caltech256"),
        )
    ],
    domain_targets=[
        (domain, np.array(frame[column], dtype=np.intp))
        for domain, frame, column in (
            (0, df_caltech101, "caltech101"),
            (1, df_caltech256, "caltech256"),
            (2, df_cifar100, "cifar100"),
        )
    ],
    domain_labels={0: caltech101_labels, 1: caltech256_labels, 2: cifar100_labels},
    relationship_type="hypothesis",
    upper_bound=5,
)

print("Three-domain taxonomy constructed successfully!")

Three-domain taxonomy constructed successfully!


In [9]:
# Build the three-domain MCFP taxonomy from the cross-domain predictions.
# Domain indices: 0 = Caltech-101, 1 = Caltech-256, 2 = CIFAR-100.
# Each prediction entry is (model domain, data domain, predicted class indices of
# the model-domain classifier on the data-domain samples, as np.intp).
mcfp_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=[
        (model_domain, data_domain, np.array(frame[column], dtype=np.intp))
        for model_domain, data_domain, frame, column in (
            (0, 1, df_caltech256, "predictions_caltech101_on_caltech256"),
            (0, 2, df_cifar100, "predictions_caltech101_on_cifar100"),
            (1, 0, df_caltech101, "predictions_caltech256_on_caltech101"),
            (1, 2, df_cifar100, "predictions_caltech256_on_cifar100"),
            (2, 0, df_caltech101, "predictions_cifar100_on_caltech101"),
            (2, 1, df_caltech256, "predictions_cifar100_on_caltech256"),
        )
    ],
    domain_targets=[
        (domain, np.array(frame[column], dtype=np.intp))
        for domain, frame, column in (
            (0, df_caltech101, "caltech101"),
            (1, df_caltech256, "caltech256"),
            (2, df_cifar100, "cifar100"),
        )
    ],
    domain_labels={0: caltech101_labels, 1: caltech256_labels, 2: cifar100_labels},
    relationship_type="mcfp",
)

print("Three-domain MCFP taxonomy constructed successfully!")

Three-domain MCFP taxonomy constructed successfully!


In [10]:
# Build the three-domain MCFP Binary taxonomy from the cross-domain predictions.
# Domain indices: 0 = Caltech-101, 1 = Caltech-256, 2 = CIFAR-100.
# Each prediction entry is (model domain, data domain, predicted class indices of
# the model-domain classifier on the data-domain samples, as np.intp).
mcfp_binary_taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=[
        (model_domain, data_domain, np.array(frame[column], dtype=np.intp))
        for model_domain, data_domain, frame, column in (
            (0, 1, df_caltech256, "predictions_caltech101_on_caltech256"),
            (0, 2, df_cifar100, "predictions_caltech101_on_cifar100"),
            (1, 0, df_caltech101, "predictions_caltech256_on_caltech101"),
            (1, 2, df_cifar100, "predictions_caltech256_on_cifar100"),
            (2, 0, df_caltech101, "predictions_cifar100_on_caltech101"),
            (2, 1, df_caltech256, "predictions_cifar100_on_caltech256"),
        )
    ],
    domain_targets=[
        (domain, np.array(frame[column], dtype=np.intp))
        for domain, frame, column in (
            (0, df_caltech101, "caltech101"),
            (1, df_caltech256, "caltech256"),
            (2, df_cifar100, "cifar100"),
        )
    ],
    domain_labels={0: caltech101_labels, 1: caltech256_labels, 2: cifar100_labels},
    relationship_type="mcfp_binary",
)

print("Three-domain MCFP Binary taxonomy constructed successfully!")

Three-domain MCFP Binary taxonomy constructed successfully!


In [11]:
# Render the MCFP Binary taxonomy graph and write it to an HTML file.
mcfp_binary_graph = mcfp_binary_taxonomy.visualize_graph(
    "Three-Domain MCFP Binary Taxonomy: Caltech-256, Caltech-101, and CIFAR-100"
)
mcfp_binary_graph.save_graph("output/three_domain_mcfp_binary_taxonomy.html")

print(
    "MCFP Binary taxonomy visualization saved to output/three_domain_mcfp_binary_taxonomy.html"
)

MCFP Binary taxonomy visualization saved to output/three_domain_mcfp_binary_taxonomy.html


In [12]:
# Build the MCFP Binary universal taxonomy, render it to HTML, and save it.
# NOTE(review): build_universal_taxonomy() is called on the existing object and
# its return value is unused, so it presumably updates mcfp_binary_taxonomy in
# place; re-running this cell would apply it twice — confirm that is safe.
mcfp_binary_taxonomy.build_universal_taxonomy()
mcfp_binary_universal_graph = mcfp_binary_taxonomy.visualize_graph(
    "Three-Domain MCFP Binary Universal Taxonomy: Caltech-256, Caltech-101, and CIFAR-100"
)
mcfp_binary_universal_graph.save_graph(
    "output/three_domain_mcfp_binary_universal_taxonomy.html"
)
mcfp_binary_taxonomy.save("taxonomies/three_domain_mcfp_binary.pkl")

print("MCFP Binary universal taxonomy built and saved!")

MCFP Binary universal taxonomy built and saved!


In [13]:
# Render the MCFP taxonomy graph and write it to an HTML file.
mcfp_graph = mcfp_taxonomy.visualize_graph(
    "Three-Domain MCFP Taxonomy: Caltech-256, Caltech-101, and CIFAR-100"
)
mcfp_graph.save_graph("output/three_domain_mcfp_taxonomy.html")

print("MCFP taxonomy visualization saved to output/three_domain_mcfp_taxonomy.html")

MCFP taxonomy visualization saved to output/three_domain_mcfp_taxonomy.html


In [14]:
# Build the MCFP universal taxonomy, render it to HTML, and save it.
# NOTE(review): build_universal_taxonomy() is called on the existing object and
# its return value is unused, so it presumably updates mcfp_taxonomy in place;
# re-running this cell would apply it twice — confirm that is safe.
mcfp_taxonomy.build_universal_taxonomy()
mcfp_universal_graph = mcfp_taxonomy.visualize_graph(
    "Three-Domain MCFP Universal Taxonomy: Caltech-256, Caltech-101, and CIFAR-100"
)
mcfp_universal_graph.save_graph("output/three_domain_mcfp_universal_taxonomy.html")
mcfp_taxonomy.save("taxonomies/three_domain_mcfp.pkl")

print("MCFP universal taxonomy built and saved!")

MCFP universal taxonomy built and saved!


In [15]:
# Render the hypothesis taxonomy graph and write it to an HTML file.
taxonomy_graph = taxonomy.visualize_graph(
    "Three-Domain Taxonomy: Caltech-256, Caltech-101, and CIFAR-100"
)
taxonomy_graph.save_graph("output/three_domain_taxonomy.html")

print("Taxonomy visualization saved to output/three_domain_taxonomy.html")

Taxonomy visualization saved to output/three_domain_taxonomy.html


In [None]:
# Build the hypothesis universal taxonomy, render it to HTML, and save it.
# NOTE(review): build_universal_taxonomy() is called on the existing object and
# its return value is unused, so it presumably updates taxonomy in place;
# re-running this cell would apply it twice — confirm that is safe.
taxonomy.build_universal_taxonomy()
universal_taxonomy_graph = taxonomy.visualize_graph(
    "Three-Domain Universal Taxonomy: Caltech-256, Caltech-101, and CIFAR-100"
)
universal_taxonomy_graph.save_graph("output/three_domain_universal_taxonomy.html")

# Save the taxonomy
taxonomy.save("taxonomies/three_domain_hypothesis.pkl")

print("Universal taxonomy built and saved!")

Universal taxonomy built and saved!
