In [None]:
import os
import shutil

import torch
import pandas as pd
import matplotlib.pyplot as plt

from data import get_dataloaders
from training import train_loop, evaluate
from unet import UNet, UNet6, ConvNextUNet, UNetWithoutDEM, UNetWithDerivedFeatures, UNetWithDerivedFeaturesWithoutDEM
from metrics import Precision, Recall, HardIoU
from loss import iou_loss, bce_loss, dice_loss

# Loaders built from the local "dataset" directory. NOTE(review): presumably a
# (train, val, test) sequence — evaluate_method uses index 2 as the test split.
dataloaders = get_dataloaders("dataset")

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def _plot_training_history(directory):
    """Plot the per-epoch curves logged to ``<directory>/metrics.csv``.

    Left panel: training vs. validation loss. Right panel: the ``iou`` column.
    """
    history = pd.read_csv(os.path.join(directory, "metrics.csv"))
    fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(16, 8))

    loss_ax.plot(history["epoch"], history["train_loss"], label="Training Loss")
    loss_ax.plot(history["epoch"], history["val_loss"], label="Validation Loss")
    loss_ax.set(title=f"{directory}: loss", xlabel="Epoch", ylabel="Loss")
    loss_ax.legend()

    # NOTE(review): presumably the validation-set IoU logged by train_loop — confirm.
    iou_ax.plot(history["epoch"], history["iou"])
    iou_ax.set(title=f"{directory}: IoU", xlabel="Epoch", ylabel="IoU")

    plt.show()
    # Release the figure so repeated runs don't keep every figure alive in memory.
    plt.close(fig)


def evaluate_method(model_cls, loss_fn, directory, num_epochs=100, lr=0.001,
                    early_stop=10, seed=None):
    """Train a fresh model, plot its training curves and report test-set results.

    Deletes and recreates ``directory``, then trains ``model_cls(3, 1)`` with Adam
    via ``train_loop``. ``train_loop`` is expected to write ``metrics.csv`` and the
    best checkpoint ``best.pt`` (selected by the ``"iou"`` metric) into
    ``directory``. The logged curves are plotted, the best checkpoint is reloaded,
    and it is evaluated on ``dataloaders[2]`` (the test split).

    Args:
        model_cls: Model class, constructed as ``model_cls(3, 1)``.
        loss_fn: Loss function passed to ``train_loop`` and ``evaluate``.
        directory: Output directory for this run; existing contents are deleted.
        num_epochs: Maximum number of training epochs.
        lr: Adam learning rate.
        early_stop: Forwarded to ``train_loop`` as ``early_stop``.
            NOTE(review): presumably patience in epochs — confirm.
        seed: If not ``None``, passed to ``torch.manual_seed`` before the model is
            built so runs are repeatable and comparable across architectures.
            ``None`` (default) keeps the unseeded behaviour.

    Returns:
        ``(test_loss, test_metrics)`` exactly as returned by ``evaluate``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Start from an empty output directory. Pure Python instead of `!rm -r` /
    # `!mkdir`: no error on the first run when the directory doesn't exist yet,
    # no shell interpolation of `directory`, and it works outside IPython.
    # A failed deletion raises rather than leaving stale metrics.csv/best.pt behind.
    if os.path.isdir(directory):
        shutil.rmtree(directory)
    os.makedirs(directory, exist_ok=True)

    model = model_cls(3, 1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    metrics = {"precision": Precision(), "recall": Recall(), "iou": HardIoU()}

    train_loop(dataloaders, model, loss_fn, optimizer,
               num_epochs=num_epochs, metrics=metrics, early_stop=early_stop,
               directory=directory, best_metric="iou")

    _plot_training_history(directory)

    # load_state_dict loads the weights into `model` in place and returns a
    # (missing_keys, unexpected_keys) named tuple, so it must NOT be assigned
    # back to `model`. map_location lets a GPU-saved checkpoint load on a
    # CPU-only machine (and vice versa).
    state_dict = torch.load(os.path.join(directory, "best.pt"), map_location=device)
    model.load_state_dict(state_dict)
    # Inference mode for dropout/batch-norm; harmless if evaluate() also sets it.
    model.eval()

    test_loss, test_metrics = evaluate(dataloaders[2], model, loss_fn, metrics)
    print(f"test loss: {test_loss}")
    for key, value in test_metrics.items():
        print(f"{key}: {value}")
    return test_loss, test_metrics

In [None]:
# Each entry is (model class, loss function, output directory); runs execute in order.
EXPERIMENTS = [
    # Input-feature ablations, all trained with the IoU loss.
    (UNet, iou_loss, "unet"),
    (UNetWithoutDEM, iou_loss, "unet_without_dem"),
    (UNetWithDerivedFeatures, iou_loss, "unet_with_derived_features"),
    (UNetWithDerivedFeaturesWithoutDEM, iou_loss, "unet_with_derived_features_without_dem"),
    # Loss-function comparison on the baseline UNet.
    (UNet, bce_loss, "unet_bce"),
    (UNet, dice_loss, "unet_dice"),
    # Alternative architectures, trained with the IoU loss.
    (UNet6, iou_loss, "unet6"),
    (ConvNextUNet, iou_loss, "unet_convnext"),
]

for experiment_model_cls, experiment_loss_fn, experiment_dir in EXPERIMENTS:
    evaluate_method(experiment_model_cls, experiment_loss_fn, experiment_dir)