In [None]:
import numpy as np
import pandas as pd
import torch
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from torchvision.datasets import Caltech101, Caltech256

from library.datasets import (
    Caltech101DataModule,
    Caltech256DataModule,
    CombinedDataModule,
)
from library.models import UniversalResNetModel
from library.taxonomy import Taxonomy

# Let float32 matrix multiplications use faster, lower-precision internal
# arithmetic where the hardware supports it ("medium" setting) to speed up training.
torch.set_float32_matmul_precision("medium")

# Only the category names are needed here. download=False means both datasets
# must already be present under datasets/.
caltech256_labels = Caltech256(root="datasets/caltech256", download=False).categories
caltech101_labels = Caltech101(root="datasets/caltech101", download=False).categories

for _label_source, _label_list in (
    ("Caltech-256", caltech256_labels),
    ("Caltech-101", caltech101_labels),
):
    print(f"{_label_source} classes: {len(_label_list)}")

Caltech-256 classes: 257
Caltech-101 classes: 101
Caltech-256 class 0: 001.ak47
Caltech-256 class 1: 002.american-flag
Caltech-256 class 2: 003.backpack
Caltech-256 class 3: 004.baseball-bat
Caltech-256 class 4: 005.baseball-glove
Caltech-256 class 5: 006.basketball-hoop
Caltech-256 class 6: 007.bat
Caltech-256 class 7: 008.bathtub
Caltech-256 class 8: 009.bear
Caltech-256 class 9: 010.beer-mug
Caltech-256 class 10: 011.billiards
Caltech-256 class 11: 012.binoculars
Caltech-256 class 12: 013.birdbath
Caltech-256 class 13: 014.blimp
Caltech-256 class 14: 015.bonsai-101
Caltech-256 class 15: 016.boom-box
Caltech-256 class 16: 017.bowling-ball
Caltech-256 class 17: 018.bowling-pin
Caltech-256 class 18: 019.boxing-glove
Caltech-256 class 19: 020.brain-101
Caltech-256 class 20: 021.breadmaker
Caltech-256 class 21: 022.buddha-101
Caltech-256 class 22: 023.bulldozer
Caltech-256 class 23: 024.butterfly
Caltech-256 class 24: 025.cactus
Caltech-256 class 25: 026.cake
Caltech-256 class 26: 027.ca

In [2]:
# Universal taxonomy over the real-world datasets (presumably Caltech-256 +
# Caltech-101, going by the filename).
# NOTE(review): a .pkl file is likely unpickled by Taxonomy.load — only load
# files from a trusted source.
TAXONOMY_PATH = "taxonomies/caltech256_caltech101.pkl"
taxonomy = Taxonomy.load(TAXONOMY_PATH)

In [3]:
# Configuration for Multi-Domain Training

# Training configuration
# True:  train a new model; the checkpoint with the lowest val_loss is kept.
# False: skip training and evaluate checkpoints/<model_name>.ckpt instead.
TRAIN = True

# Seed Python, NumPy and torch RNGs (and dataloader workers) so that any random
# splitting, shuffling and weight initialisation is reproducible across runs.
SEED = 42
seed_everything(SEED, workers=True)

# NOTE(review): the per-dataset modules are given batch size 32 while the
# combined module gets 64; Lightning inferred a batch size of 64 during sanity
# checking, so the combined value appears to be the effective one — confirm
# and drop whichever setting is unused.
PER_DATASET_BATCH_SIZE = 32
COMBINED_BATCH_SIZE = 64

# Create individual dataset modules
caltech101_dm = Caltech101DataModule(batch_size=PER_DATASET_BATCH_SIZE)
caltech256_dm = Caltech256DataModule(batch_size=PER_DATASET_BATCH_SIZE)

# Create combined data module with domain IDs
# Domain 0: Caltech-101, Domain 1: Caltech-256
# (domain_ids are presumably paired positionally with dataset_modules)
dataset_module = CombinedDataModule(
    dataset_modules=[caltech101_dm, caltech256_dm],
    domain_ids=[0, 1],
    batch_size=COMBINED_BATCH_SIZE,
)

dataset_name = "Caltech-101 + Caltech-256 (Multi-Domain)"
# Also used as the checkpoint filename: checkpoints/<model_name>.ckpt
model_name = "universal-resnet50-multi-domain-min-val-loss"
logger_name = "universal_multi_domain"

In [4]:
# Optimisation hyper-parameters. max_epochs is consumed by the Trainer; the
# remaining entries are forwarded to the model.
training_config = dict(
    max_epochs=50,
    optim="adamw",
    optim_kwargs=dict(lr=0.001, weight_decay=0.001),
    # Presumably maps to torch's MultiStepLR: LR is multiplied by `gamma`
    # at each milestone epoch — confirm in UniversalResNetModel.
    lr_scheduler="multistep",
    lr_scheduler_kwargs=dict(milestones=[15, 30, 40], gamma=0.1),
)

# Config entries forwarded verbatim (same keyword names) to the model.
_MODEL_CONFIG_KEYS = ("optim", "optim_kwargs", "lr_scheduler", "lr_scheduler_kwargs")

# Create the Universal ResNet model
model = UniversalResNetModel(
    taxonomy=taxonomy,
    architecture="resnet50",
    **{config_key: training_config[config_key] for config_key in _MODEL_CONFIG_KEYS},
)

In [5]:
# Build the Lightning Trainer.
# - Evaluation mode: no logger, no checkpointing.
# - Training mode: TensorBoard logging plus a single "best val_loss" checkpoint.
if not TRAIN:
    trainer = Trainer(
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=True,
    )
    print("Trainer configured for evaluation only")
else:
    tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)

    best_val_loss_checkpoint = ModelCheckpoint(
        dirpath="checkpoints",
        filename=model_name,
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        # Overwrite checkpoints/<model_name>.ckpt rather than appending a
        # version suffix, so the evaluation path can find it by a fixed name.
        enable_version_counter=False,
    )

    trainer = Trainer(
        max_epochs=training_config["max_epochs"],
        logger=tb_logger,
        callbacks=[best_val_loss_checkpoint],
        enable_progress_bar=True,
        enable_model_summary=True,
    )
    print("Trainer configured for training")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Trainer configured for training


In [None]:
# Train (or load) the model, then evaluate it on the test split.
# `results` stays None when no model could be evaluated (missing checkpoint).
results = None

if TRAIN:
    print(f"Starting training on {dataset_name}...")
    print(f"Training for {training_config['max_epochs']} epochs")

    # Train the model
    trainer.fit(model, datamodule=dataset_module)
    print("Training completed!")

    # ckpt_path="best" evaluates the weights saved by the ModelCheckpoint
    # callback (lowest val_loss), not the weights from the final epoch.
    print("Testing the trained model...")
    results = trainer.test(datamodule=dataset_module, ckpt_path="best")

else:
    checkpoint_path = f"checkpoints/{model_name}.ckpt"
    print(f"Loading pre-trained model: {model_name}.ckpt")

    # Keep the `try` narrow: only a failure to load the checkpoint means the
    # checkpoint is missing. A FileNotFoundError raised while testing (e.g.
    # missing dataset files) must propagate instead of being misreported.
    try:
        model = UniversalResNetModel.load_from_checkpoint(
            checkpoint_path,
            taxonomy=taxonomy,  # Need to pass taxonomy since it's not serialized
        )
    except FileNotFoundError:
        print(f"Checkpoint file not found: {checkpoint_path}")
        print("Please set TRAIN=True to train the model first.")
    else:
        print("Model loaded successfully!")
        results = trainer.test(model, datamodule=dataset_module)

if results:
    print(f"\nTest Results for {dataset_name}:")
    # trainer.test returns one metrics dict per test dataloader; report all of
    # them, labelling each only when there is more than one.
    multiple_dataloaders = len(results) > 1
    for dataloader_idx, metrics in enumerate(results):
        if multiple_dataloaders:
            print(f"  [test dataloader {dataloader_idx}]")
        for key, value in metrics.items():
            print(f"  {key}: {value:.4f}")

print(
    "\nNote: All training now uses multi-domain format with (domain_id, domain_class_id) targets"
)

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/bjoern/dev/master-thesis/project/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting training on Caltech-101 + Caltech-256 (Multi-Domain)...
Expected training time: ~50 epochs



  | Name      | Type      | Params | Mode 
------------------------------------------------
0 | model     | ResNet    | 26.3 M | train
1 | criterion | KLDivLoss | 0      | train
------------------------------------------------
26.3 M    Trainable params
0         Non-trainable params
26.3 M    Total params
105.366   Total estimated model params size (MB)
162       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  1.71it/s]

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


                                                                           

/home/bjoern/miniconda3/envs/master-thesis/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0:   8%|▊         | 39/492 [00:16<03:07,  2.41it/s, v_num=5]