# Imports

In [None]:
import os

from matplotlib import pyplot as plt

from src.utils import (
    get_data,
    get_device,
    get_truth_vs_predicted,
    load,
    plot_image_and_prediction,
    plot_image_channels,
    plot_labels_in_germany,
    plot_predictions,
    predict_batch,
    predict_image,
    seed_everyting,
)

# Global parameters

In [None]:
# Side length of the model input patches, in pixels.
image_size = 256
# Data / artefact locations, relative to the notebook's working directory.
img_dir = "data/images"
model_dir = "models"
patch_dir = "data/patches"
results_dir = "results"
gedi_dir = "data/gedi"
# Seed passed to seed_everyting below.
random_state = 42
batch_size = 12
# os.cpu_count() returns None when the count cannot be determined; fall back to
# 1 so this never raises. On a single-core machine the result is 0 workers.
num_workers = (os.cpu_count() or 1) // 2
# Bin edges 0, 5, ..., 50 (inclusive), passed to get_data and reused as the 2-D
# histogram edges further down.
bins = list(range(0, 55, 5))
# Compute device chosen by the project helper; reused by every load/predict call below.
device = get_device()

# Seed before any data loading or inference happens further down.
# NOTE(review): which RNGs are seeded is defined in src.utils.seed_everyting — confirm.
seed_everyting(random_state)

In [None]:
# Inspect the channels of one sample input tile from img_dir.
sample_tile_path = f"{img_dir}/L15-1060E-1348N.tif"
plot_image_channels(sample_tile_path)

In [None]:
# Optional map of label locations over Germany. Disabled by default; set the
# flag to True to render it (kept as an explicit switch, not commented-out code).
show_label_map = False
if show_label_map:
    plot_labels_in_germany()

In [None]:
# TODO: plot a histogram of the label values (not implemented yet).

# Create dataloaders

In [None]:
# get_data returns three items; only the third (the test loader) is used in
# this notebook, so the first two are discarded.
_, _, test_dl = get_data(
    img_dir,
    patch_dir,
    gedi_dir,
    image_size,
    batch_size,
    num_workers,
    bins,
)

# Load models

In [None]:
# Checkpoint files relative to model_dir. Only "unet" sits at the top level;
# the ViT variants (plain and "-kd") live under archive/.
model_paths = {
    "unet": "u-plusplus-unetplusplus-efficientnet-b2.pt",
    "vit-base": "archive/vit-base-vit-16.pt",
    "vit-base-kd": "archive/vit-base-vit-16-kd.pt",
    "vit-medium": "archive/vit-medium-vit-16.pt",
    "vit-medium-kd": "archive/vit-medium-vit-16-kd.pt",
}
# Load each checkpoint onto `device`, in the order listed above.
models = {
    name: load(f"{model_dir}/{relative_path}", device)
    for name, relative_path in model_paths.items()
}

# Visualise results

In [None]:
# Run every model in `models` on data drawn from the test loader and plot the
# inputs alongside each model's predictions.
# NOTE(review): how many samples predict_batch takes from test_dl is defined in src.utils.
images, preds = predict_batch(models, test_dl, device)
plot_predictions(images, preds)

In [None]:
# Single-tile inference with the checkpoint stored under the "unet" key
# (u-plusplus-unetplusplus-efficientnet-b2.pt).
model = models["unet"]
tile_path = f"{img_dir}/L15-1060E-1355N.tif"
img, pred = predict_image(model, device, tile_path, image_size)
plot_image_and_prediction(img, pred)

In [None]:
# Collect ground-truth / predicted value pairs for `model` over the test loader.
truth, predicted = get_truth_vs_predicted(model, test_dl, device)

# Keep only labels strictly between these bounds (both ends exclusive).
# NOTE(review): presumably this drops zero-valued labels and the extreme tail —
# confirm the intent.
label_min, label_max = 0, 50
mask = (truth > label_min) & (truth < label_max)

# Explicit Axes with labels so the figure stands alone; plt.show() also stops
# the PathCollection repr from being echoed as the cell's output.
fig, ax = plt.subplots()
ax.scatter(truth[mask], predicted[mask], alpha=0.2)
ax.set(xlabel="Ground truth", ylabel="Predicted", title="Ground truth vs. predicted")
plt.show()

In [None]:
# Density view of the same masked pairs; `bins` supplies the edges for both axes.
fig, ax = plt.subplots()
# hist2d returns (counts, xedges, yedges, image); keep the image for the colorbar
# instead of echoing its repr as the cell output.
*_, mesh = ax.hist2d(truth[mask], predicted[mask], bins=bins)
fig.colorbar(mesh, ax=ax, label="Count")
ax.set(
    xlabel="Ground truth",
    ylabel="Predicted",
    title="Ground truth vs. predicted (2-D histogram)",
)
plt.show()