## Imports


In [None]:
import torch
import torchvision
import matplotlib.pylab as plt
from datetime import datetime
import os
import time

## Device


In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

## Paths


In [None]:
# Image roots read with torchvision's ImageFolder (one sub-directory per class).
training_data_directory_path = "../grouped-data/train/"
testing_data_directory_path = "../minimal-grouped-data/test/"
# Text file whose lines are training image file names to select from the
# training root.
image_name_list_path = "./xview2.txt"
# Root directory under which each trained model's checkpoints and info go.
models_directory_path = "./models/"

## Load Testing Data


In [None]:
# Evaluation set: each image is resized to 224x224, converted to a tensor and
# normalized per channel (the usual ImageNet mean/std), then served in a fixed
# order so evaluation runs are comparable across epochs.
testing_data = torchvision.datasets.ImageFolder(
    root=testing_data_directory_path,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    ),
)
testing_data_loader = torch.utils.data.DataLoader(
    dataset=testing_data,
    batch_size=12,
    shuffle=False,  # deterministic order for evaluation
    num_workers=8,
)

## Load Training Data


In [None]:
def get_first_n_lines(filename, n, dataset=None):
    """Build a shuffled DataLoader over the training images listed in *filename*.

    Despite its name, this returns a ``DataLoader``, not lines. The first *n*
    lines of *filename* are read as image file names. Every sample of
    *dataset* whose file name (the basename of its path) is in that set is
    kept, and the scan stops once *n* matching samples have been collected.

    Args:
        filename: Text file containing one image file name per line.
        n: Number of lines to read, and the maximum number of samples to keep.
            If the file has fewer than *n* lines, all of its lines are used.
        dataset: Dataset exposing an ImageFolder-style ``imgs`` sequence whose
            items start with the sample's file path. Defaults to the
            module-level ``training_data``, which is looked up at call time.

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``Subset`` of *dataset*
        holding the matching indices (batch size 12, shuffled, 8 workers).

    Raises:
        ValueError: If no sample in *dataset* matches any of the listed names.
    """
    if dataset is None:
        dataset = training_data

    # zip() stops as soon as range(n) is exhausted, without pulling an extra
    # line from the file. It also stops early if the file has fewer than n
    # lines. Calling next(file) inside a generator expression would instead
    # turn the StopIteration into a RuntimeError (PEP 479).
    with open(filename, "r", encoding="utf-8") as file:
        wanted_image_names = {line.strip() for _, line in zip(range(n), file)}

    subset_indices = []
    for idx, img_data in enumerate(dataset.imgs):
        # basename honours the platform separator that ImageFolder used when
        # it built the path; splitting on "/" breaks on Windows.
        if os.path.basename(img_data[0]) in wanted_image_names:
            subset_indices.append(idx)
            if len(subset_indices) == n:
                break  # Stop the loop once we have found n matches

    if not subset_indices:
        # A shuffled DataLoader over an empty Subset fails deep inside
        # RandomSampler with an opaque message, so fail early and say why.
        raise ValueError(
            f"None of the first {n} image names in {filename!r} "
            "matched any sample in the dataset."
        )

    return torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, subset_indices),
        batch_size=12,
        shuffle=True,
        num_workers=8,
    )


# Full training pool, preprocessed exactly like the evaluation set: resize to
# 224x224, convert to a tensor, and normalize per channel.
training_data = torchvision.datasets.ImageFolder(
    root=training_data_directory_path,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    ),
)

# One loader per named training subset. The number is how many image names
# get_first_n_lines reads from image_name_list_path for that subset.
named_training_data_loaders = {
    subset_name: get_first_n_lines(image_name_list_path, subset_size)
    for subset_name, subset_size in (("baseline", 4), ("bryan", 4))
}

## Dataset Size


In [None]:
# Print how many samples each training subset and the test set hold.
for name, training_data_loader in named_training_data_loaders.items():
    training_sample_count = len(training_data_loader.dataset)
    print(f'Training dataset "{name}" size:', training_sample_count)
print("--------------")
testing_sample_count = len(testing_data_loader.dataset)
print("Testing dataset size:", testing_sample_count)

## Train Model


In [None]:
import model_creation
import logging


def change_log_file_path(logger, new_path):
    """Redirect *logger*'s file output to *new_path*.

    Every ``logging.FileHandler`` currently attached to *logger* is detached
    and closed. A single new INFO-level ``FileHandler`` writing to *new_path*
    is then attached. Non-file handlers, such as console stream handlers,
    are left untouched.

    Args:
        logger: The ``logging.Logger`` to reconfigure.
        new_path: Log file to write to. ``FileHandler`` opens it in append
            mode and creates it if it does not exist.
    """
    # Iterate over a copy because removeHandler mutates logger.handlers.
    for handler in list(logger.handlers):
        if isinstance(handler, logging.FileHandler):
            logger.removeHandler(handler)
            # Detaching alone leaves the file open. Close it so that training
            # several models in a loop does not leak one descriptor per model.
            handler.close()

    file_handler = logging.FileHandler(new_path)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)


epoch_count = 2
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
# basicConfig() does nothing when the root logger already has handlers (a
# notebook kernel may have installed one). In that case the root level would
# stay at WARNING and every logger.info() below would be dropped, including
# the messages the per-model FileHandler is meant to capture. Set it explicitly.
logger.setLevel(logging.INFO)


def _utc_timestamp():
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SS``.

    Microseconds are dropped and no UTC offset is included; callers append a
    literal ``Z``. Produces the same string as the deprecated
    ``datetime.utcnow().replace(microsecond=0).isoformat()``.
    """
    from datetime import timezone

    now = datetime.now(timezone.utc)
    return now.replace(microsecond=0, tzinfo=None).isoformat()


def _export_metrics(
    training_losses,
    training_accuracies,
    testing_losses,
    testing_accuracies,
    model_descriptor,
    model_info_directory_path,
):
    """Write the loss/accuracy history and redraw the loss plot for one model."""
    model_creation.export_loss_and_accuracy(
        training_losses,
        training_accuracies,
        testing_losses,
        testing_accuracies,
        model_descriptor,
        model_info_directory_path,
    )
    model_creation.plot_loss(
        training_losses, testing_losses, model_descriptor, model_info_directory_path
    )


for dataset_name, training_data_loader in named_training_data_loaders.items():
    dataset_length = len(training_data_loader.dataset)
    # NOTE(review): the ':' characters in the ISO timestamp are not allowed in
    # Windows file/directory names, so this layout assumes a POSIX filesystem.
    model_descriptor = (
        dataset_name
        + f"_size{dataset_length}"
        + f"_targetEpochCount{epoch_count}"
        + f"_creationStart{_utc_timestamp()}Z"
    )
    model_directory_path = os.path.join(models_directory_path, model_descriptor)
    model_info_directory_path = os.path.join(model_directory_path, "_info")

    # Create the model directory and its _info subdirectory if missing, then
    # send this model's file log to _info/durations.log.
    os.makedirs(model_info_directory_path, exist_ok=True)
    change_log_file_path(
        logger, os.path.join(model_info_directory_path, "durations.log")
    )
    logger.info(f"Creating {model_descriptor}")

    # Fresh ResNet-18 with torchvision's default pretrained weights. The
    # classifier head is replaced by a 5-output linear layer.
    model = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.DEFAULT
    )
    input_feature_count = model.fc.in_features
    output_feature_count = 5
    model.fc = torch.nn.Linear(input_feature_count, output_feature_count)
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    training_losses = []
    training_accuracies = []
    testing_losses = []
    testing_accuracies = []
    model_creation_start_time = time.time()
    for epoch in range(1, epoch_count + 1):
        epoch_start_time = time.time()
        logger.info("Epoch {} running.".format(epoch))

        # --- training phase ---
        training_start_time = time.time()
        loss, accuracy = model_creation.train_model(
            training_data_loader, model, criterion, optimizer
        )
        model_creation.log_duration(
            logger, True, epoch, training_start_time, model_creation_start_time
        )
        training_losses.append(loss)
        training_accuracies.append(accuracy)
        model_name = (
            dataset_name
            + f"_size{dataset_length}"
            + f"_epoch{epoch}Of{epoch_count}"
            + f"_{_utc_timestamp()}Z"
            + ".pth"
        )
        model_creation.export_model(model, model_name, model_directory_path)
        # Export metrics before testing (presumably so a failure during
        # evaluation still leaves this epoch's training metrics on disk). At
        # this point the testing lists are one entry shorter than the
        # training lists.
        _export_metrics(
            training_losses,
            training_accuracies,
            testing_losses,
            testing_accuracies,
            model_descriptor,
            model_info_directory_path,
        )

        # --- testing phase ---
        testing_start_time = time.time()
        loss, accuracy = model_creation.test_model(
            model, testing_data_loader, criterion
        )
        # Log immediately, as the training phase does, so the testing duration
        # measures evaluation only and not the metric export / plotting I/O.
        model_creation.log_duration(
            logger, False, epoch, testing_start_time, model_creation_start_time
        )
        testing_losses.append(loss)
        testing_accuracies.append(accuracy)
        _export_metrics(
            training_losses,
            training_accuracies,
            testing_losses,
            testing_accuracies,
            model_descriptor,
            model_info_directory_path,
        )

        logger.info(
            "Epoch {} done. Epoch Duration: {}, Total Duration: {}".format(
                epoch,
                model_creation.format_duration(time.time() - epoch_start_time),
                model_creation.format_duration(time.time() - model_creation_start_time),
            )
        )
        logger.info("--------------------------------------------")
    logger.info("Done creating model")
    plt.ioff()
    plt.show()