In [None]:
import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt

from anomalymarinedetection.utils.assets import labels, labels_binary, labels_multi
from anomalymarinedetection.io.tif_io import TifIO

In [None]:
# Label set used to name the predicted classes: 'binary' or 'multi'
# (any other value is rejected by the selection cell below).
aggregate_classes = 'multi'
# Patch group to visualise; used both as the sub-folder inside the ground-truth
# folder and as the file-name prefix of every patch / prediction file.
patch_name = 'S2_14-9-18_16PCC'
# Highest patch index to plot; the plotting loop iterates over range(max_num + 1),
# i.e. indices 0..max_num inclusive.
max_num = 5

In [None]:
# Select the label names that match the chosen aggregation scheme; they are
# used later to name the class values found in the prediction rasters.
if aggregate_classes == 'binary':
    labels_agg = labels_binary
elif aggregate_classes == 'multi':
    labels_agg = labels_multi
else:
    # ValueError is a subclass of Exception, so any existing `except Exception`
    # still catches it; the message now shows the bad value and the valid ones.
    raise ValueError(
        "Wrong type of aggregation of classes: "
        f"{aggregate_classes!r} (expected 'binary' or 'multi')"
    )

In [None]:
# Constants
# NOTE(review): both folders are absolute, machine-specific paths — adjust them
# for the environment the notebook runs in.
# Folder containing the prediction rasters, read as
# <patch_name>_<number>_<model_name>.tif by the plotting cell.
folder_predictions = '/data/anomaly-marine-detection/data/predicted_unet'
# Folder containing one sub-folder per patch_name with the input patches
# (<patch_name>_<number>.tif) and ground truth (<patch_name>_<number>_<gt_name>.tif).
folder_gt = '/data/anomaly-marine-detection/data/patches'
# File-name suffix identifying the model that produced a prediction raster.
model_name = 'unet'
# File-name suffix identifying a ground-truth raster.
gt_name = "cl"
# File extension and the separator placed between file-name components.
ext = '.tif'
separator = '_'

# Non-aggregated label names; used to name the class values found in the
# ground-truth rasters.
original_labels = labels

In [None]:
# Reader whose tif_2_rgb() method is used by the plotting cell to load each
# patch file as an RGB image.
tif_io = TifIO()

In [None]:
# Plot ground truth and prediction
# The Ground Truth is showing the colors of all the original 15 classes, and not the colors of the aggregated classes (binary or multi)


def _read_raster(path):
    """Return all bands of the raster at `path`, closing the dataset afterwards."""
    # The context manager releases the file handle even if read() fails.
    with rasterio.open(path) as dataset:
        return dataset.read()


def _label_name(names, value):
    """Return the label name for a 1-based pixel `value`.

    Pixel values are mapped to `names[value - 1]`. A value below 1 would give a
    negative index, which silently wraps to the end of a list, and a value past
    the end would raise; both cases return a visible placeholder instead.
    """
    index = int(value) - 1
    if index >= 0:
        try:
            return names[index]
        except (IndexError, KeyError):
            pass
    return f"<no label for value {value}>"


def _print_pixel_counts(img, names, suffix=""):
    """Print one line per distinct pixel value in `img` with its pixel count."""
    values, counts = np.unique(img, return_counts=True)
    for value, count in zip(values, counts):
        print(f"# pixels = {_label_name(names, value)} -> {count}{suffix}")


for number in range(max_num + 1):
    # Common file-name stem shared by the patch, its ground truth and its prediction
    patch_stem = patch_name + separator + str(number)

    # Read rgb image
    file_path = os.path.join(folder_gt, patch_name, patch_stem + ext)
    rgb_img, _ = tif_io.tif_2_rgb(file_path)

    # Read ground truth
    gt_img = _read_raster(os.path.join(folder_gt, patch_name, patch_stem + separator + gt_name + ext))

    # Read prediction
    pred_img = _read_raster(os.path.join(folder_predictions, patch_stem + separator + model_name + ext))

    fig, ax = plt.subplots(1, 3, figsize=(20, 11))
    ax[0].set_title("RGB Image")
    # Scale by the maximum so the values passed to imshow do not exceed 1
    ax[0].imshow(rgb_img / rgb_img.max())
    ax[1].set_title("Prediction")
    ax[1].imshow(pred_img[0, :, :])
    ax[2].set_title("Ground Truth")
    ax[2].imshow(gt_img[0, :, :])

    # Print value counts in the prediction
    print('Prediction')
    _print_pixel_counts(pred_img, labels_agg)
    # Print value counts in the ground truth
    print('Ground truth')
    _print_pixel_counts(gt_img, original_labels, suffix=" times")

    # Render this patch's figure next to its counts instead of accumulating
    # every open figure until the end of the cell.
    plt.show()