In [None]:
import pandas as pd
from torchvision.datasets import Caltech256


# Category names are read from the local Caltech-256 copy under
# "datasets/caltech256"; download=False means no download is attempted here.
caltech256_dataset = Caltech256(root="datasets/caltech256", download=False)
caltech256_labels = caltech256_dataset.categories

# Table loaded from output/caltech256.csv (its schema is not shown in this cell).
caltech256_targets = pd.read_csv("output/caltech256.csv")

In [None]:
from synthetic_taxonomy import SyntheticTaxonomy


# Generation parameters for the synthetic taxonomy, gathered in one place.
# Presumably one atomic concept per Caltech-256 category (257) — TODO confirm.
# NOTE(review): no seed is passed; if SyntheticTaxonomy samples randomly, the
# domain mappings may differ between kernel restarts — confirm before
# evaluating saved checkpoints against them.
taxonomy_params = dict(
    num_atomic_concepts=257,
    num_domains=2,
    domain_class_count_mean=180,
    domain_class_count_variance=10,
    concept_cluster_size_mean=3,
    concept_cluster_size_variance=1,
)
synthetic_taxonomy = SyntheticTaxonomy(**taxonomy_params)

# Mappings of the first and second generated domain, in that order.
domain_A, domain_B = (
    synthetic_taxonomy.domains[index].to_mapping() for index in (0, 1)
)

In [None]:
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers
from models import ResNetMappedClassModel
from datasets import Caltech256MappedClassDataModule


# Trade float32 matmul precision for faster training
torch.set_float32_matmul_precision("medium")

# Switch: True trains a new model, False evaluates the saved checkpoint
TRAIN = False

model_name = "resnet50-caltech256-synthetic-1-min-val-loss"
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs", name="caltech256_synthetic_1")
dataset = Caltech256MappedClassDataModule(mapping=domain_A)

# Keep only the checkpoint with the lowest validation loss, saved as
# checkpoints/<model_name>.ckpt with no version suffix appended
best_checkpoint = ModelCheckpoint(
    dirpath="checkpoints",
    filename=model_name,
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    enable_version_counter=False,
)
trainer = Trainer(max_epochs=20, logger=tb_logger, callbacks=[best_checkpoint])

if not TRAIN:
    # Evaluation only: restore the model from the checkpoint file on disk
    checkpoint_path = f"checkpoints/{model_name}.ckpt"
    model = ResNetMappedClassModel.load_from_checkpoint(checkpoint_path)
    results = trainer.test(model, datamodule=dataset)
else:
    sgd_settings = {
        "lr": 0.01,
        "momentum": 0.9,
        "weight_decay": 5e-4,
    }
    model = ResNetMappedClassModel(
        architecture="resnet50",
        optim="sgd",
        optim_kwargs=sgd_settings,
        mapping=domain_A,
    )
    trainer.fit(model, datamodule=dataset)

    # Test against the best checkpoint tracked by the trainer
    results = trainer.test(datamodule=dataset, ckpt_path="best")

print(results)

In [None]:
# Switch: True trains a new model, False evaluates the saved checkpoint
TRAIN = False

model_name = "resnet50-caltech256-synthetic-2-min-val-loss"
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs", name="caltech256_synthetic_2")
dataset = Caltech256MappedClassDataModule(mapping=domain_B)

# Keep only the checkpoint with the lowest validation loss, saved as
# checkpoints/<model_name>.ckpt with no version suffix appended
best_checkpoint = ModelCheckpoint(
    dirpath="checkpoints",
    filename=model_name,
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    enable_version_counter=False,
)
trainer = Trainer(max_epochs=20, logger=tb_logger, callbacks=[best_checkpoint])

if not TRAIN:
    # Evaluation only: restore the model from the checkpoint file on disk
    checkpoint_path = f"checkpoints/{model_name}.ckpt"
    model = ResNetMappedClassModel.load_from_checkpoint(checkpoint_path)
    results = trainer.test(model, datamodule=dataset)
else:
    sgd_settings = {
        "lr": 0.01,
        "momentum": 0.9,
        "weight_decay": 5e-4,
    }
    model = ResNetMappedClassModel(
        architecture="resnet50",
        optim="sgd",
        optim_kwargs=sgd_settings,
        mapping=domain_B,
    )
    trainer.fit(model, datamodule=dataset)

    # Test against the best checkpoint tracked by the trainer
    results = trainer.test(datamodule=dataset, ckpt_path="best")

print(results)