# Universal ResNet Model Training

This notebook demonstrates training a Universal ResNet model using the taxonomy built from the Caltech-256 and Caltech-101 datasets. The model uses universal classes as its output layer and maps domain class labels to universal class activations based on the taxonomy relationships.

In [None]:
import pandas as pd
import numpy as np
import torch
from torchvision.datasets import Caltech256, Caltech101
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

from library.taxonomy import Taxonomy
from library.models import UniversalResNetModel
from library.datasets import Caltech256DataModule, Caltech101DataModule

# Trade float32 matmul precision for speed; set once, before any model work
torch.set_float32_matmul_precision("medium")

# Category names of both source datasets. download=False: nothing is fetched
# here, so the data must already be present under datasets/.
caltech256_labels = Caltech256(root="datasets/caltech256", download=False).categories
caltech101_labels = Caltech101(root="datasets/caltech101", download=False).categories

# Summary: number of classes per dataset
print(f"Caltech-256 classes: {len(caltech256_labels)}")
print(f"Caltech-101 classes: {len(caltech101_labels)}")

# Full listing: class index -> category name, Caltech-256 first
for class_index, class_name in enumerate(caltech256_labels):
    print(f"Caltech-256 class {class_index}: {class_name}")

for class_index, class_name in enumerate(caltech101_labels):
    print(f"Caltech-101 class {class_index}: {class_name}")

In [None]:
# Universal taxonomy previously built from the two real-world datasets.
# NOTE(review): the .pkl extension suggests a pickle file; if Taxonomy.load
# unpickles it, only load files from a trusted source -- confirm.
taxonomy_file = "taxonomies/caltech256_caltech101.pkl"
taxonomy = Taxonomy.load(taxonomy_file)

In [None]:
# Configuration
# TRAIN: True trains a model from scratch; False loads the saved checkpoint
# for the selected domain and only evaluates it.
TRAIN = True
# DOMAIN_ID: 0 for Caltech-101, 1 for Caltech-256.
# NOTE(review): the taxonomy file is named caltech256_caltech101 -- confirm
# that the taxonomy's domain numbering matches this mapping.
DOMAIN_ID = 0

# Select dataset based on domain. Every value derived here (data module,
# display name, checkpoint name, logger name, class count) must describe the
# same dataset, so an unknown id is rejected instead of silently falling back
# to Caltech-256 while the model is later set to a different domain.
if DOMAIN_ID == 0:
    dataset_module = Caltech101DataModule()
    dataset_name = "Caltech-101"
    model_name = "universal-resnet50-caltech101-min-val-loss"
    logger_name = "universal_caltech101"
    num_classes = len(caltech101_labels)
elif DOMAIN_ID == 1:
    dataset_module = Caltech256DataModule()
    dataset_name = "Caltech-256"
    model_name = "universal-resnet50-caltech256-min-val-loss"
    logger_name = "universal_caltech256"
    num_classes = len(caltech256_labels)
else:
    raise ValueError(
        f"Unsupported DOMAIN_ID {DOMAIN_ID!r}: expected 0 (Caltech-101) or 1 (Caltech-256)"
    )

print(f"Training on {dataset_name} (Domain {DOMAIN_ID})")

In [None]:
# Training configuration, assembled from its optimizer and scheduler parts.
# Key order matters only for the printout below.
optimizer_options = dict(lr=0.001, weight_decay=0.001)
scheduler_options = dict(milestones=[15, 30, 40], gamma=0.1)

training_config = dict(
    max_epochs=50,
    optim="adamw",
    optim_kwargs=optimizer_options,
    lr_scheduler="multistep",
    lr_scheduler_kwargs=scheduler_options,
)

# Echo the settings so they are recorded in the notebook output
print("Training configuration:")
for option, setting in training_config.items():
    print(f"  {option}: {setting}")

In [None]:
# Build the Universal ResNet model. The optimizer and scheduler options are
# forwarded unchanged from training_config under the same keyword names.
optimizer_and_schedule = {
    option: training_config[option]
    for option in ("optim", "optim_kwargs", "lr_scheduler", "lr_scheduler_kwargs")
}
model = UniversalResNetModel(
    taxonomy=taxonomy,
    architecture="resnet50",
    **optimizer_and_schedule,
)

# Tell the model which domain's labels it will be trained on
model.set_domain(DOMAIN_ID)

print("Universal ResNet model created successfully!")
print(f"Model output size (universal classes): {model.num_universal_classes}")
print(f"Training domain set to: {DOMAIN_ID}")

In [None]:
# Build the Lightning trainer: a bare one for evaluation, or a fully
# configured one (TensorBoard logging + checkpointing) for training.
if not TRAIN:
    # Evaluation only: no logging and no checkpoint writing
    trainer = Trainer(
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=True,
    )

    print("Trainer configured for evaluation only")
else:
    tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs", name=logger_name)

    # Keep a single checkpoint, selected by minimum val_loss, under a fixed
    # file name (version counter disabled) so that the file is always
    # checkpoints/<model_name>.ckpt.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        filename=model_name,
        enable_version_counter=False,
    )

    trainer = Trainer(
        max_epochs=training_config["max_epochs"],
        logger=tb_logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=True,
        enable_model_summary=True,
    )

    print("Trainer configured for training")

In [None]:
# Train a new model, or load a previously saved checkpoint, then evaluate on
# the test split and print the resulting metrics.
if TRAIN:
    print(f"Starting training on {dataset_name}...")
    print(f"Expected training time: ~{training_config['max_epochs']} epochs")

    # Train the model
    trainer.fit(model, datamodule=dataset_module)

    print("Training completed!")

    # Test using the best checkpoint recorded during fit rather than the
    # weights from the final epoch
    print("Testing the trained model...")
    results = trainer.test(datamodule=dataset_module, ckpt_path="best")

else:
    # Load pre-trained model
    checkpoint_path = f"checkpoints/{model_name}.ckpt"
    print(f"Loading pre-trained model: {model_name}.ckpt")
    try:
        # Only the load is guarded: a FileNotFoundError raised later (for
        # example by the data pipeline during testing) must not be reported
        # as a missing checkpoint.
        model = UniversalResNetModel.load_from_checkpoint(
            checkpoint_path,
            taxonomy=taxonomy,  # Need to pass taxonomy since it's not serialized
        )
    except FileNotFoundError:
        print(f"Checkpoint file not found: {checkpoint_path}")
        print("Please set TRAIN=True to train the model first.")
        results = None
    else:
        # The domain is set again on the freshly loaded model
        model.set_domain(DOMAIN_ID)

        print("Model loaded successfully!")

        # Test the loaded model
        results = trainer.test(model, datamodule=dataset_module)

# results is None when no checkpoint was found; otherwise the metrics of the
# first entry are printed
if results:
    print(f"\nTest Results for {dataset_name}:")
    for metric_name, metric_value in results[0].items():
        print(f"  {metric_name}: {metric_value:.4f}")