In [1]:
import pandas as pd
import numpy as np
import torch
from torchvision.datasets import Caltech256, Caltech101, CIFAR100
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.taxonomy import Taxonomy
from library.models import UniversalEfficientNetV2Model
from library.datasets import (
    Caltech256DataModule,
    Caltech101DataModule,
    CIFAR100ScaledDataModule,
    CombinedDataModule,
)

# Read the class-name lists of the three source datasets. The data must already
# be present on disk (download=False); only the label lists are kept.
caltech256_labels = Caltech256(root="datasets/caltech256", download=False).categories
caltech101_labels = Caltech101(root="datasets/caltech101", download=False).categories
cifar100_labels = CIFAR100(root="datasets/cifar100", download=False, train=False).classes

# One summary line per dataset (a generator keeps the loop names out of the
# notebook namespace).
print(
    "\n".join(
        f"{dataset_name} classes: {len(dataset_labels)}"
        for dataset_name, dataset_labels in (
            ("Caltech-256", caltech256_labels),
            ("Caltech-101", caltech101_labels),
            ("CIFAR-100", cifar100_labels),
        )
    )
)

# Trade float32 matmul precision for speed ("medium" permits reduced-precision
# internal computation) to make training faster.
torch.set_float32_matmul_precision("medium")

Caltech-256 classes: 257
Caltech-101 classes: 101
CIFAR-100 classes: 100


In [2]:
# Load the pre-built three-domain MCFP taxonomy and report how many nodes it has.
# NOTE(review): the .pkl extension suggests Taxonomy.load unpickles the file;
# only load taxonomy files that come from a trusted source.
three_domain_mcfp_taxonomy = Taxonomy.load("taxonomies/three_domain_mcfp.pkl")
taxonomy_nodes = three_domain_mcfp_taxonomy.get_nodes()

# Nodes are bucketed by length: length-1 nodes are reported as universal
# classes and length-2 nodes as domain classes (presumably (domain, class)
# pairs; TODO confirm against library.taxonomy).
num_universal_nodes = sum(
    1 for node in taxonomy_nodes if hasattr(node, "__len__") and len(node) == 1
)
num_domain_nodes = sum(
    1 for node in taxonomy_nodes if hasattr(node, "__len__") and len(node) == 2
)

print(f"Loaded three-domain MCFP taxonomy with {len(taxonomy_nodes)} nodes")
print(f"Universal classes: {num_universal_nodes}")
print(f"Domain classes: {num_domain_nodes}")

Loaded three-domain MCFP taxonomy with 1244 nodes
Universal classes: 82
Domain classes: 1162


In [3]:
# Per-dataset data modules. Their position in the list passed below is their
# domain id: 0 -> Caltech-101, 1 -> Caltech-256, 2 -> CIFAR-100.
caltech101_dm = Caltech101DataModule(batch_size=16)
caltech256_dm = Caltech256DataModule(batch_size=16)
cifar100_dm = CIFAR100ScaledDataModule(batch_size=16)

# Merge all three sources into a single module used for training and for the
# combined evaluation.
three_domain_dataset_module = CombinedDataModule(
    dataset_modules=[caltech101_dm, caltech256_dm, cifar100_dm],
    domain_ids=list(range(3)),
    num_workers=11,
    batch_size=32,
)

print("Dataset modules created successfully")

Dataset modules created successfully


In [4]:
# Training configuration for EfficientNetV2.
TRAIN = False  # True: train the model; False: the evaluation cell loads checkpoints/<model_name>.ckpt
EFFICIENTNET_VARIANT = "s"  # EfficientNet variant: "s" (Small), "m" (Medium), "l" (Large)
MAX_EPOCHS = 50  # single source of truth for the run length and the cosine period

# Fail fast on a typo instead of deep inside model construction / checkpoint loading.
if EFFICIENTNET_VARIANT not in {"s", "m", "l"}:
    raise ValueError(
        f"EFFICIENTNET_VARIANT must be 's', 'm' or 'l', got {EFFICIENTNET_VARIANT!r}"
    )

training_config = {
    "max_epochs": MAX_EPOCHS,
    "optim": "adamw",
    "optim_kwargs": {
        "lr": 0.00003,  # Slightly lower learning rate for EfficientNet
        "weight_decay": 0.001,
        "betas": (0.9, 0.999),
        "eps": 1e-8,
    },
    "lr_scheduler": "cosine",
    "lr_scheduler_kwargs": {
        # Derived from MAX_EPOCHS (previously a duplicated literal) so the cosine
        # schedule keeps spanning the whole run when the epoch count changes.
        # Assumes the scheduler is stepped once per epoch (TODO confirm in the model).
        "T_max": MAX_EPOCHS,
        "eta_min": 1e-7,
    },
}

# NOTE(review): the "min-val-loss" suffix does not match the checkpoint callback
# in the training cell, which keeps the highest `val_accuracy`. The name is kept
# unchanged so the existing checkpoint file still loads.
model_name = (
    f"universal-efficientnetv2-{EFFICIENTNET_VARIANT}-three-domain-mcfp-min-val-loss"
)
logger_name = f"universal_efficientnetv2_{EFFICIENTNET_VARIANT}_three_domain_mcfp"

print(
    f"Training configuration set for EfficientNetV2-{EFFICIENTNET_VARIANT.upper()} with three-domain MCFP taxonomy"
)

Training configuration set for EfficientNetV2-S with three-domain MCFP taxonomy


In [5]:
# Optionally train the Universal EfficientNetV2 model on all three domains.
#
# The model is only constructed when TRAIN is True. With TRAIN False, the
# evaluation cell loads weights from checkpoints/{model_name}.ckpt and replaces
# `model`, so building a fresh model here would just be thrown away.
if TRAIN:
    model = UniversalEfficientNetV2Model(
        taxonomy=three_domain_mcfp_taxonomy,
        optim=training_config["optim"],
        optim_kwargs=training_config["optim_kwargs"],
        lr_scheduler=training_config["lr_scheduler"],
        lr_scheduler_kwargs=training_config["lr_scheduler_kwargs"],
        efficientnet_variant=EFFICIENTNET_VARIANT,
    )

    tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)

    # Keeps a single checkpoint at checkpoints/{model_name}.ckpt
    # (enable_version_counter=False overwrites instead of appending "-v1").
    # NOTE(review): model_name ends in "min-val-loss", but this monitors the
    # *maximum* `val_accuracy`. Confirm which criterion is intended before
    # retraining; the filename is left as-is so existing checkpoints still load.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        monitor="val_accuracy",
        mode="max",
        save_top_k=1,
        filename=model_name,
        enable_version_counter=False,
    )

    trainer = Trainer(
        max_epochs=training_config["max_epochs"],
        logger=tb_logger,
        callbacks=[checkpoint_callback],
    )

    print("Starting training...")
    trainer.fit(model, datamodule=three_domain_dataset_module)

    # Evaluate the best checkpoint selected by the callback, not the last epoch.
    print("Training completed. Running final test...")
    test_results = trainer.test(
        datamodule=three_domain_dataset_module, ckpt_path="best"
    )

    print(f"Final test accuracy: {test_results[0]['eval_accuracy']:.4f}")

else:
    print("Training disabled. Set TRAIN=True to start training.")

Training disabled. Set TRAIN=True to start training.


In [6]:
# Evaluate on each domain individually and on all three domains combined.
#
# The checkpoint is loaded unconditionally. When TRAIN is True, the in-memory
# `model` from the training cell holds the *last-epoch* weights. The
# ModelCheckpoint callback, however, wrote the best weights to
# checkpoints/{model_name}.ckpt, the same weights `ckpt_path="best"` used for the
# final test. Evaluating the in-memory model would report per-domain numbers for
# different weights than that final test. When TRAIN is False this simply loads
# the pre-trained checkpoint.
print(f"Loading pre-trained model: {model_name}.ckpt")
model = UniversalEfficientNetV2Model.load_from_checkpoint(
    f"checkpoints/{model_name}.ckpt",
    taxonomy=three_domain_mcfp_taxonomy,
    efficientnet_variant=EFFICIENTNET_VARIANT,
)

EVAL_BATCH_SIZE = 64
EVAL_NUM_WORKERS = 11


def make_single_domain_module(dataset_module, domain_id):
    """Wrap one dataset module in a CombinedDataModule tagged with ``domain_id``.

    ``domain_id`` must match the id the dataset has in
    ``three_domain_dataset_module`` (0 = Caltech-101, 1 = Caltech-256,
    2 = CIFAR-100).
    """
    return CombinedDataModule(
        dataset_modules=[dataset_module],
        domain_ids=[domain_id],
        batch_size=EVAL_BATCH_SIZE,
        num_workers=EVAL_NUM_WORKERS,
    )


def evaluate_on(datamodule):
    """Run ``test_trainer.test`` for ``model`` on ``datamodule``.

    Returns a ``(results, accuracy)`` tuple: the raw list returned by Lightning
    and its ``eval_accuracy`` entry.
    """
    results = test_trainer.test(model, datamodule=datamodule)
    return results, results[0]["eval_accuracy"]


caltech101_combined_dm = make_single_domain_module(caltech101_dm, 0)
caltech256_combined_dm = make_single_domain_module(caltech256_dm, 1)
cifar100_combined_dm = make_single_domain_module(cifar100_dm, 2)

# Evaluation-only trainer: no logging and no checkpoint writing.
test_trainer = Trainer(
    logger=False,
    enable_checkpointing=False,
)

print("Evaluating on individual domains...")

caltech101_results, caltech101_accuracy = evaluate_on(caltech101_combined_dm)
caltech256_results, caltech256_accuracy = evaluate_on(caltech256_combined_dm)
cifar100_results, cifar100_accuracy = evaluate_on(cifar100_combined_dm)
# The combined accuracy is computed over all test samples at once, so it is
# presumably weighted toward the largest test set (CIFAR-100); TODO confirm how
# the model aggregates eval_accuracy.
combined_results, combined_accuracy = evaluate_on(three_domain_dataset_module)

print(
    f"\n=== EfficientNetV2-{EFFICIENTNET_VARIANT.upper()} Universal Model Results ==="
)
for display_name, accuracy in (
    ("Caltech-101", caltech101_accuracy),
    ("Caltech-256", caltech256_accuracy),
    ("CIFAR-100", cifar100_accuracy),
    ("Combined", combined_accuracy),
):
    print(f"{display_name} accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

# Store results for comparison with other models.
efficientnet_results = {
    "model": f"UniversalEfficientNetV2-{EFFICIENTNET_VARIANT.upper()}",
    "taxonomy": "Three-Domain MCFP",
    "caltech101": caltech101_accuracy,
    "caltech256": caltech256_accuracy,
    "cifar100": cifar100_accuracy,
    "combined": combined_accuracy,
}

Loading pre-trained model: universal-efficientnetv2-s-three-domain-mcfp-min-val-loss.ckpt


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Evaluating on individual domains...
Testing DataLoader 0:  14%|█▍        | 2/14 [00:00<00:04,  2.58it/s]

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Testing DataLoader 0: 100%|██████████| 14/14 [00:02<00:00,  6.71it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.9446367025375366
        eval_loss            2.548962116241455
        hp_metric           0.9446367025375366
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.9446367025375366
        eval_loss            2.548962116241455
 

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 35. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 48/48 [00:05<00:00,  8.01it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.8970588445663452
        eval_loss           2.1533899307250977
        hp_metric           0.8970588445663452
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.8970588445663452
        eval_loss           2.1533899307250977
 

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 52. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 157/157 [00:17<00:00,  9.00it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.8772000074386597
        eval_loss           1.5989248752593994
        hp_metric           0.8772000074386597
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.8772000074386597
        eval_loss           1.5989248752593994

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0:   1%|          | 3/436 [00:00<00:58,  7.37it/s]

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Testing DataLoader 0: 100%|██████████| 436/436 [00:31<00:00, 13.97it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.8856178522109985
        eval_loss            1.77834951877594
        hp_metric           0.8856178522109985
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      eval_accuracy         0.8856178522109985
        eval_loss            1.77834951877594
 

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 7. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.



=== EfficientNetV2-S Universal Model Results ===
Caltech-101 accuracy: 0.9446 (94.46%)
Caltech-256 accuracy: 0.8971 (89.71%)
CIFAR-100 accuracy: 0.8772 (87.72%)
Combined accuracy: 0.8856 (88.56%)
