In [None]:
import itertools
import os

import pandas as pd
import torch

from utils import (
    DatasetConfig,
    DatasetTypeEnum,
    DatasetResisc45,
    DataLoader,
    get_pretrained_model,
    get_cross_entropy_loss,
    get_sgd_optimizer,
    CNNHyperParams,
    train_one_epoch,
    val_one_epoch,
)
from utils import *

In [None]:
# Hyperparameter sets to train; each one is run once per validation fold below.
# The get_cnn_hyper_params_* factories are not imported by name in this notebook,
# so they are expected to arrive through `from utils import *`.
cnn_models = [
    get_cnn_hyper_params_1(),
    get_cnn_hyper_params_3(),
]
# Number of validation folds iterated over (fold indices 0 .. kfolds - 1).
kfolds = 5
# Directory passed to train_model_on_fold as the save location for its artifacts.
models_path = "./cls_models"


In [None]:
# Show the selected hyperparameter sets so the run configuration is visible in the notebook output.
print(cnn_models)

In [None]:
def train_val_dataloaders(hyperparams: CNNHyperParams, val_fold: int):
    """Create the training and validation data loaders for one fold.

    Both splits are read from ``hyperparams.dataset_file`` with the same
    ``val_fold``; they differ only in dataset type and transform. The training
    loader shuffles its samples, the validation loader does not.

    Args:
        hyperparams: Supplies ``dataset_file``, ``train_transforms``,
            ``val_transforms`` and ``batch_size``.
        val_fold: Fold identifier forwarded unchanged to both ``DatasetConfig``s.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple.
    """
    # (dataset type, transform) for each split, training split first.
    split_specs = (
        (DatasetTypeEnum.train, hyperparams.train_transforms),
        (DatasetTypeEnum.val, hyperparams.val_transforms),
    )
    # Both configs are built before either dataset is instantiated.
    train_cfg, val_cfg = (
        DatasetConfig(
            dataset_file=hyperparams.dataset_file,
            transform=split_transform,
            dataset_type=split_type,
            val_fold=val_fold,
        )
        for split_type, split_transform in split_specs
    )

    train_ds = DatasetResisc45(train_cfg)
    val_ds = DatasetResisc45(val_cfg)

    batch_size = hyperparams.batch_size
    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size, shuffle=False),
    )


In [None]:
def train_model_on_fold(hyperparams: CNNHyperParams, val_fold: int, save_path: str):
    """Train the model from ``get_pretrained_model`` on one fold and save the results.

    Three artifacts are written under ``save_path``, all sharing the prefix
    ``model_{hyperparams.dataset_file}_{val_fold}``:

    * ``<prefix>``          - text record of the hyperparameters, fold and save path
    * ``<prefix>.pt``       - the model object after the final epoch (``torch.save``)
    * ``<prefix>_loss.csv`` - per-epoch ``train_losses`` / ``val_losses`` columns

    Args:
        hyperparams: Run configuration. This function reads ``model_name``,
            ``dataset_file``, ``criterion_name``, ``optimizer_name``,
            ``feature_extract`` and ``num_epochs``; the helpers it is handed to
            may read further fields.
        val_fold: Fold identifier forwarded to ``train_val_dataloaders``.
        save_path: Directory that receives the artifacts; created if missing.

    Raises:
        ValueError: If ``hyperparams.criterion_name`` is not ``"cross_entropy"``
            or ``hyperparams.optimizer_name`` is not ``"sgd"``.
    """
    print(hyperparams.model_name, hyperparams.dataset_file, val_fold, save_path)

    # Reject unsupported settings before any file is written or model is loaded;
    # an unknown name would otherwise leave `criterion` / `optimizer` unbound and
    # only fail once the epoch loop tries to use them.
    if hyperparams.criterion_name != "cross_entropy":
        raise ValueError(f"Unsupported criterion_name: {hyperparams.criterion_name!r}")
    if hyperparams.optimizer_name != "sgd":
        raise ValueError(f"Unsupported optimizer_name: {hyperparams.optimizer_name!r}")

    # All artifacts are placed under the caller-supplied `save_path`, not the
    # notebook-level `models_path` global.
    # NOTE(review): the prefix omits `hyperparams.model_name`, so two configurations
    # that share a dataset_file write to the same files — confirm this is intended.
    artifact_prefix = f"{save_path}/model_{hyperparams.dataset_file}_{val_fold}"
    # Make sure the target directory exists so neither the record below nor the
    # final saves fail on a fresh checkout.
    os.makedirs(os.path.dirname(artifact_prefix) or ".", exist_ok=True)

    with open(artifact_prefix, "w") as f:
        f.write(f"{str(hyperparams)}; {val_fold}; {save_path}")

    train_data, val_data = train_val_dataloaders(hyperparams, val_fold)

    # The second returned value (bound to `input_size` by name) is not needed here.
    cnn_model, _ = get_pretrained_model(hyperparams)
    device = get_device()  # resolved once and reused for every epoch
    cnn_model.to(device)

    criterion = get_cross_entropy_loss()
    # NOTE(review): lr and momentum are fixed here rather than read from hyperparams.
    optimizer = get_sgd_optimizer(cnn_model, feature_extract=hyperparams.feature_extract, lr=0.001, momentum=0.9)

    train_losses = []
    val_losses = []
    for epoch in range(hyperparams.num_epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")
        train_loss = train_one_epoch(criterion, optimizer, train_data, cnn_model, device=device)
        val_loss = val_one_epoch(criterion, val_data, cnn_model, device=device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

    # Saves the model as it stands after the last epoch (no best-epoch selection).
    torch.save(cnn_model, f"{artifact_prefix}.pt")

    pd.DataFrame({
        'train_losses': train_losses,
        'val_losses': val_losses,
    }).to_csv(f"{artifact_prefix}_loss.csv")

    print("done.")
    # TODO: save best torch model in save_path / dataset_file / model_name / val_fold


In [None]:
# Train every hyperparameter set on every fold; the fold index varies fastest.
for hyperparams in cnn_models:
    for val_fold in range(kfolds):
        train_model_on_fold(hyperparams, val_fold, save_path=models_path)
