In [1]:
import os
import warnings

import numpy as np
import pandas as pd
import torch
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import ModelCheckpoint
from torchvision.datasets import Caltech101, Caltech256

from library.datasets import Caltech101DataModule, Caltech256DataModule
from library.models import ResNetModel
from library.taxonomy_constructors import CrossPredictionsTaxonomy

# Allow lower-precision internal computation for float32 matrix multiplications,
# trading a little numerical accuracy for faster training.
torch.set_float32_matmul_precision("medium")

# Class names come from the local torchvision copies of both datasets
# (download=False, so the data must already exist under datasets/).
caltech256_labels = Caltech256(download=False, root="datasets/caltech256").categories
caltech101_labels = Caltech101(download=False, root="datasets/caltech101").categories

print(f"Caltech-256 classes: {len(caltech256_labels)}")
print(f"Caltech-101 classes: {len(caltech101_labels)}")

Caltech-256 classes: 257
Caltech-101 classes: 101


In [2]:
# Configuration
TRAIN = False  # Set to True to train models from scratch; False reloads saved checkpoints
CHECKPOINT_DIR = "checkpoints"
# training_config keys that train_dataset_model actually consumes
_TRAINING_CONFIG_KEYS = frozenset({"max_epochs", "optim", "optim_kwargs"})


def train_dataset_model(
    dataset_module,
    dataset_name,
    logger_name,
    model_name,
    num_classes,
    training_config,
    train=None,
):
    """Train a ResNet-50 on one dataset (or reload its checkpoint) and evaluate it on the test split.

    Args:
        dataset_module: Lightning datamodule passed to ``Trainer.fit`` / ``Trainer.test``.
        dataset_name: Human-readable dataset name; only used in warning messages.
        logger_name: TensorBoard experiment name (logs are saved under ``logs/``).
        model_name: Checkpoint file stem. Training keeps the lowest-``val_loss``
            checkpoint as ``checkpoints/<model_name>.ckpt``; evaluation-only mode
            reloads that same file.
        num_classes: Number of output classes; only used when training a new model.
        training_config: Dict with ``"max_epochs"``, ``"optim"`` and ``"optim_kwargs"``.
            Any other keys are not forwarded anywhere; a warning is emitted for
            them when training so they are not silently dropped.
        train: ``True`` trains from scratch, ``False`` reloads the checkpoint.
            ``None`` (default) uses the notebook-level ``TRAIN`` flag at call time.

    Returns:
        The list of metric dicts returned by ``Trainer.test``.
    """
    if train is None:
        train = TRAIN

    ignored_keys = sorted(set(training_config) - _TRAINING_CONFIG_KEYS)
    if train and ignored_keys:
        # e.g. an "addition" entry with LR-scheduler settings would otherwise be dropped silently
        warnings.warn(
            f"{dataset_name}: training_config keys {ignored_keys} are not used by "
            "train_dataset_model and will be ignored.",
            stacklevel=2,
        )

    # Only create a TensorBoard logger when there is a training run to log.
    logger = (
        pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name) if train else False
    )

    trainer = Trainer(
        max_epochs=training_config["max_epochs"],
        logger=logger,
        callbacks=[
            ModelCheckpoint(
                dirpath=CHECKPOINT_DIR,
                monitor="val_loss",
                mode="min",
                save_top_k=1,
                filename=model_name,
                # Keep the fixed file name <model_name>.ckpt instead of adding -v1, -v2, ...
                enable_version_counter=False,
            )
        ],
    )

    if train:
        model = ResNetModel(
            num_classes=num_classes,
            architecture="resnet50",
            optim=training_config["optim"],
            optim_kwargs=training_config["optim_kwargs"],
        )
        trainer.fit(model, datamodule=dataset_module)
        # Evaluate the best (lowest val_loss) checkpoint, not the last epoch's weights.
        return trainer.test(datamodule=dataset_module, ckpt_path="best")

    model = ResNetModel.load_from_checkpoint(f"{CHECKPOINT_DIR}/{model_name}.ckpt")
    return trainer.test(model, datamodule=dataset_module)

In [3]:
# Caltech-256: AdamW for up to 100 epochs; the checkpoint with the lowest val_loss is kept.
caltech256_dataset = Caltech256DataModule()
caltech256_config = {
    "max_epochs": 100,
    "optim": "adamw",
    "optim_kwargs": {"lr": 0.001, "weight_decay": 0.001},
    # NOTE(review): train_dataset_model only reads "max_epochs", "optim" and
    # "optim_kwargs", so this multistep learning-rate schedule is not applied as written.
    "addition": {
        "lr_scheduler": "multistep",
        "lr_scheduler_kwargs": {"milestones": [20, 50, 80], "gamma": 0.1},
    },
}
caltech256_results = train_dataset_model(
    dataset_module=caltech256_dataset,
    dataset_name="Caltech-256",
    logger_name="caltech256_real",
    model_name="resnet50-caltech256-min-val-loss",
    num_classes=len(caltech256_labels),
    training_config=caltech256_config,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 48/48 [00:14<00:00,  3.35it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.6964052319526672
        eval_loss            1.452818751335144
        hp_metric           0.6964052319526672
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [4]:
# Caltech-101: SGD with momentum for up to 50 epochs.
# Whether this trains or reloads the checkpoint follows the TRAIN flag from the
# configuration cell. It is deliberately not re-assigned here, so a single switch
# controls both datasets and re-running cells cannot silently flip it.
caltech101_dataset = Caltech101DataModule()
caltech101_config = {
    "max_epochs": 50,
    "optim": "sgd",
    "optim_kwargs": {
        "lr": 0.01,
        "momentum": 0.9,
        "weight_decay": 5e-4,
    },
}
caltech101_results = train_dataset_model(
    dataset_module=caltech101_dataset,
    dataset_name="Caltech-101",
    logger_name="caltech101_real",
    model_name="resnet50-caltech101-min-val-loss",
    num_classes=len(caltech101_labels),
    training_config=caltech101_config,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 14/14 [00:03<00:00,  4.57it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.9146482348442078
        eval_loss           0.3319704532623291
        hp_metric           0.9146482348442078
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [5]:
PREDICT = False  # Set to True to recompute the cross-domain predictions from the checkpoints

# Cached prediction files; the taxonomy cells below only need these CSVs.
CALTECH101_PREDICTIONS_CSV = "data/caltech256_caltech101_caltech101_real_predictions.csv"
CALTECH256_PREDICTIONS_CSV = "data/caltech256_caltech101_caltech256_real_predictions.csv"


def predict_classes(trainer, model, datamodule):
    """Return the arg-max class index for every sample of ``datamodule``'s predict split.

    ``trainer.predict`` returns one output tensor per batch. They are concatenated
    and reduced over dim 1, which assumes ``ResNetModel`` predicts
    (batch, num_classes) scores (TODO: confirm). The result is moved to CPU and
    returned as a NumPy array.
    """
    batch_outputs = trainer.predict(model, datamodule=datamodule)
    return torch.cat(batch_outputs).argmax(dim=1).cpu().numpy()  # type: ignore


def collect_targets(datamodule):
    """Concatenate the ground-truth labels of ``datamodule``'s predict split into a NumPy array."""
    return torch.cat([labels for _, labels in datamodule.predict_dataloader()]).cpu().numpy()


if PREDICT:
    # Load trained models
    model_caltech256 = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech256-min-val-loss.ckpt"
    )
    model_caltech256.eval()

    model_caltech101 = ResNetModel.load_from_checkpoint(
        "checkpoints/resnet50-caltech101-min-val-loss.ckpt"
    )
    model_caltech101.eval()

    trainer = Trainer(logger=False, enable_checkpointing=False)

    # Each model predicts on the *other* dataset, in its own label space.
    predictions_caltech256_on_caltech101 = predict_classes(
        trainer, model_caltech256, caltech101_dataset
    )
    predictions_caltech101_on_caltech256 = predict_classes(
        trainer, model_caltech101, caltech256_dataset
    )

    # NOTE(review): predictions and targets are paired by position, which assumes
    # predict_dataloader() yields samples in a fixed, unshuffled order. Confirm in library.datasets.
    caltech101_targets = collect_targets(caltech101_dataset)
    caltech256_targets = collect_targets(caltech256_dataset)

    os.makedirs(os.path.dirname(CALTECH101_PREDICTIONS_CSV), exist_ok=True)
    pd.DataFrame(
        {
            "caltech101": caltech101_targets,
            "predictions_caltech256_on_caltech101": predictions_caltech256_on_caltech101,
        }
    ).to_csv(CALTECH101_PREDICTIONS_CSV, index=False)

    os.makedirs(os.path.dirname(CALTECH256_PREDICTIONS_CSV), exist_ok=True)
    pd.DataFrame(
        {
            "caltech256": caltech256_targets,
            "predictions_caltech101_on_caltech256": predictions_caltech101_on_caltech256,
        }
    ).to_csv(CALTECH256_PREDICTIONS_CSV, index=False)

# Load prediction results (freshly written above, or cached from an earlier run)
df_caltech101 = pd.read_csv(CALTECH101_PREDICTIONS_CSV)
df_caltech256 = pd.read_csv(CALTECH256_PREDICTIONS_CSV)

In [6]:
# Construct taxonomy from cross-domain predictions
taxonomy = CrossPredictionsTaxonomy.from_cross_domain_predictions(
    cross_domain_predictions=[
        (
            0,
            1,
            np.array(
                df_caltech256["predictions_caltech101_on_caltech256"], dtype=np.intp
            ),
        ),  # Caltech-101 -> Caltech-256
        (
            1,
            0,
            np.array(
                df_caltech101["predictions_caltech256_on_caltech101"], dtype=np.intp
            ),
        ),  # Caltech-256 -> Caltech-101
    ],
    domain_targets=[
        (0, np.array(df_caltech101["caltech101"], dtype=np.intp)),
        (1, np.array(df_caltech256["caltech256"], dtype=np.intp)),
    ],
    domain_labels={0: caltech101_labels, 1: caltech256_labels},
    relationship_type="mcfp",
    threshold=0.4,
)

In [7]:
# Generate and save taxonomy visualizations
taxonomy.visualize_graph(
    "Caltech-256 and Caltech-101 Cross-Domain Taxonomy"
).save_graph("output/caltech256_caltech101_real_taxonomy.html")

In [8]:
# Build universal taxonomy
taxonomy.build_universal_taxonomy()
taxonomy.visualize_graph("Caltech-256 and Caltech-101 Universal Taxonomy").save_graph(
    "output/caltech256_caltech101_real_universal_taxonomy.html"
)