In [5]:
import itertools

from utils import (
    DatasetConfig,
    DatasetTypeEnum,
    DatasetResisc45,
    DataLoader,
    PretrainedModelsEnum,
    resnet18_img_transforms_train,
    resnet18_img_transforms_validation,
    get_pretrained_model,
    get_cross_entropy_loss,
    get_sgd_optimizer,
    CNNHyperParams,
    train_one_epoch,
    val_one_epoch,
)

In [2]:
# Experiment grid: the training loop runs every (dataset, model, fold) combination.
datasets = ["dataset_ucmerced.csv", "dataset_resisc45.csv"]
# Must be an iterable of models: it is passed to itertools.product in the training
# loop. A bare enum member would either raise TypeError there (plain Enum) or be
# iterated character by character (str-mixin Enum) instead of acting as one model.
models = [PretrainedModelsEnum.resnet18]
kfolds = 5  # number of cross-validation folds; val_fold runs over range(kfolds)
models_path = "./models"  # root directory intended for saved models

# TODO: store a CSV table of model configurations (one PretrainedModelConfig per
# row) in models_config.csv.


# Training hyperparameters shared by every run in the grid.
# NOTE(review): num_classes=45 is used for both datasets, including UC Merced
# (usually distributed with 21 classes). The dataset objects expose
# `num_classes`; confirm how the classifier head should be sized per dataset.
cnn_hyper_params = CNNHyperParams(
    model_name=PretrainedModelsEnum.resnet18,
    num_classes=45,
    batch_size=64,
    num_epochs=20,
    criterion_name="cross_entropy",
    optimizer_name="sgd",
)


In [None]:
def train_val_dataloaders(dataset_file: str, batch_size: int, val_fold: int):
    """Build the training and validation DataLoaders for one cross-validation fold.

    Both splits are read from the same ``dataset_file`` and ``val_fold``. The
    training split uses ``DatasetTypeEnum.train`` with the training transforms.
    The validation split uses ``DatasetTypeEnum.val`` with the validation
    transforms.

    Args:
        dataset_file: CSV file describing the dataset, forwarded to ``DatasetConfig``.
        batch_size: Number of samples per batch for both loaders.
        val_fold: Index of the fold held out for validation.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple. Only the training loader
        shuffles; the validation loader keeps dataset order so evaluation is stable.
    """
    train_config = DatasetConfig(
        dataset_file=dataset_file,
        transform=resnet18_img_transforms_train(),
        dataset_type=DatasetTypeEnum.train,
        val_fold=val_fold,
    )
    val_config = DatasetConfig(
        dataset_file=dataset_file,
        transform=resnet18_img_transforms_validation(),
        dataset_type=DatasetTypeEnum.val,
        val_fold=val_fold,
    )

    # NOTE(review): DatasetResisc45 is used for every dataset file (UC Merced
    # included); presumably it is a generic CSV-backed dataset. Confirm.
    train_set = DatasetResisc45(train_config)
    val_set = DatasetResisc45(val_config)

    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    return train_dataloader, val_dataloader


In [3]:
def train_model_on_fold(model_name, dataset_file, val_fold, save_path, hyper_params=None):
    """Train one model on one cross-validation fold of one dataset.

    Args:
        model_name: Model identifier passed to ``get_pretrained_model``.
        dataset_file: CSV file describing the dataset, forwarded to
            ``train_val_dataloaders``.
        val_fold: Index of the fold held out for validation.
        save_path: Root directory for saved models. It is only printed for now;
            saving is still a TODO (see the end of the function).
        hyper_params: Object providing ``batch_size``, ``num_epochs``,
            ``criterion_name`` and ``optimizer_name`` (e.g. a ``CNNHyperParams``).
            Defaults to the notebook-level ``cnn_hyper_params``.

    Raises:
        ValueError: If the configured criterion or optimizer name is unsupported.
    """
    if hyper_params is None:
        hyper_params = cnn_hyper_params

    print(model_name, dataset_file, val_fold, save_path)
    train_data, val_data = train_val_dataloaders(
        dataset_file,
        batch_size=hyper_params.batch_size,
        val_fold=val_fold,
    )

    # TODO(review): num_classes from hyper_params is not passed here; confirm
    # whether get_pretrained_model sizes the classifier head itself.
    model = get_pretrained_model(model_name)

    # Fail fast on an unknown name rather than leaving the local unbound.
    if hyper_params.criterion_name == "cross_entropy":
        criterion = get_cross_entropy_loss()
    else:
        raise ValueError(f"Unsupported criterion: {hyper_params.criterion_name!r}")

    if hyper_params.optimizer_name == "sgd":
        optimizer = get_sgd_optimizer(model, feature_extract=True, lr=0.001, momentum=0.9)
    else:
        raise ValueError(f"Unsupported optimizer: {hyper_params.optimizer_name!r}")

    for _epoch in range(hyper_params.num_epochs):
        # TODO(review): these helpers are called without the model, dataloaders,
        # criterion or optimizer built above; confirm their signatures in utils.
        train_one_epoch()
        val_one_epoch()
    print("done.")
    # TODO: save best torch model in save_path / dataset_file / model_name / val_fold


In [4]:
# Run the full experiment grid: datasets outermost, then models, with folds varying fastest.
training_grid = itertools.product(datasets, models, range(kfolds))
for dataset_file, model_name, val_fold in training_grid:
    train_model_on_fold(
        model_name,
        dataset_file,
        val_fold,
        save_path=models_path,
    )


resnet18 dataset_ucmerced.csv 0 ./models
resnet18 dataset_ucmerced.csv 1 ./models
resnet18 dataset_ucmerced.csv 2 ./models
resnet18 dataset_ucmerced.csv 3 ./models
resnet18 dataset_ucmerced.csv 4 ./models
resnet18 dataset_resisc45.csv 0 ./models
resnet18 dataset_resisc45.csv 1 ./models
resnet18 dataset_resisc45.csv 2 ./models
resnet18 dataset_resisc45.csv 3 ./models
resnet18 dataset_resisc45.csv 4 ./models
