# Imports

In [1]:
import os

from matplotlib import pyplot as plt
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from segmentation_models_pytorch import Unet
from src.utils import (
    get_data,
    get_device,
    get_truth_vs_predicted,
    load,
    loss,
    plot_image_and_prediction,
    predict_image,
    predict_patch,
    seed_everyting,
    test,
    train,
)

# Global parameters

In [2]:
# Tunable settings for the whole notebook.
patch_size = 256  # passed to get_data and used in the checkpoint filename
img_dir = "data/images"
model_dir = "models"
patch_dir = "data/patches"
results_dir = "results"
gedi_dir = "data/gedi"
random_state = 42
batch_size = 12
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 0 so get_data always receives an int (presumably forwarded to a DataLoader,
# where 0 means "load in the main process" — TODO confirm in src.utils).
num_workers = os.cpu_count() or 0
learning_rate = 1e-2
epochs = 25
is_training = False  # False -> load a saved checkpoint instead of training
bins = list(range(0, 55, 5))  # edges 0, 5, ..., 50
device = get_device()

# Note: `seed_everyting` (sic) is the name exported by src.utils.
seed_everyting(random_state)

Using mps device


# Create dataloaders

In [4]:
# Build the three loaders used below (train_dl for training, val_dl for
# per-epoch validation, test_dl for final evaluation and plots) from the
# directories and settings configured above.
train_dl, val_dl, test_dl = get_data(
    img_dir,
    patch_dir,
    gedi_dir,
    patch_size,
    batch_size,
    num_workers,
    bins,
)

# Create & train the model (or load a saved checkpoint when `is_training` is False)

In [None]:
if not is_training:
    # Skip training and restore the checkpoint stored for this patch size.
    model = load(os.path.join(model_dir, "unet", f"unet-{patch_size}.pt"), device)
else:
    # U-Net with an EfficientNet-B4 encoder trained from scratch (no pretrained
    # encoder weights), scSE attention in the decoder, and 5 input channels.
    model = Unet(
        encoder_name="efficientnet-b4",
        encoder_weights=None,
        decoder_attention_type="scse",
        in_channels=5,
    ).to(device)

    optimizer = SGD(model.parameters(), learning_rate)
    # T_max=epochs gives one cosine cycle over the run only if train() steps the
    # scheduler once per call — TODO confirm against src.utils.train.
    scheduler = CosineAnnealingLR(optimizer, epochs)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}\n-------------------------------")
        train(train_dl, model, loss, device, optimizer, scheduler)
        test(val_dl, model, loss, device)

    print("Training finished.")

    # The held-out test evaluation only runs in the training branch.
    test(test_dl, model, loss, device)

In [None]:
# NOTE(review): disabled save. The load branch above reads
# models/unet/unet-{patch_size}.pt, but this would write models/{model.name}.pt,
# so a re-enabled save would not be picked up by `load`. `save` is also not
# imported in the imports cell — TODO align the path and import before enabling.
# save(model, os.path.join(model_dir, f"{model.name}.pt"))

# Visualise results

In [None]:
# Fetch the first batch of the test loader for qualitative inspection; the
# visible cells below only use `inputs`.
inputs, targets = next(
    iter(test_dl)
)

In [None]:
# Pick one sample from the batch fetched above. A fixed index of 42 is out of
# range when the loader yields batch_size (12) samples, so wrap it into the
# actual batch length instead of raising IndexError.
idx = 42 % len(inputs)
# Second element is None — presumably the (absent) target slot expected by
# predict_patch; TODO confirm in src.utils.
patch = inputs[idx], None
img, pred = predict_patch(model, patch, device)
plot_image_and_prediction(img, pred, 3)

In [None]:
# Run the model over one whole image file from img_dir. The path is built with
# os.path.join, the same way as the checkpoint path in the train/load cell.
scene_path = os.path.join(img_dir, "L15-1059E-1348N.tif")
image, prediction = predict_image(model, device, scene_path, patch_size)

In [None]:
# Show the full-image prediction computed in the cell above. The trailing `3`
# matches the patch-level plot; its meaning is defined in src.utils
# (TODO: give it a named constant once confirmed).
plot_image_and_prediction(image, prediction, 3)

In [None]:
# Collect paired ground-truth and predicted values over the whole test loader.
truth, predicted = get_truth_vs_predicted(
    model,
    test_dl,
    device,
)

In [None]:
# Keep pairs whose ground truth lies strictly inside (0, bins[-1]). bins[-1] is
# 50 with the current config, so this matches the former hard-coded bound while
# staying in sync if `bins` is changed.
upper = bins[-1]
mask = (truth > 0) & (truth < upper)

fig, ax = plt.subplots()
ax.scatter(truth[mask], predicted[mask], alpha=0.2)
# y = x reference: points on this line are exact predictions.
ax.plot([0, upper], [0, upper], "k--", linewidth=1)
ax.set(
    xlabel="Ground truth",
    ylabel="Predicted",
    title="Predicted vs. ground truth (test set)",
)
plt.show()

In [None]:
# 2-D histogram of the masked pairs, using the config `bins` edges on both axes.
# Pairs whose prediction falls outside [bins[0], bins[-1]] lie outside the edges
# and are silently left out of the counts.
fig, ax = plt.subplots()
counts, x_edges, y_edges, mesh = ax.hist2d(truth[mask], predicted[mask], bins)
fig.colorbar(mesh, ax=ax, label="Count")
ax.set(
    xlabel="Ground truth",
    ylabel="Predicted",
    title="Predicted vs. ground truth density (test set)",
)
plt.show()