In [None]:
import os
import sys

# Make the project root importable so `marineanomalydetection` resolves when the
# notebook is started from its own folder. Guarded so re-running this cell does
# not keep appending duplicate entries to sys.path.
if "../" not in sys.path:
    sys.path.append("../")

# Show the working directory: every data path in this notebook is relative to it.
os.getcwd()

In [None]:
import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt

from marineanomalydetection.utils.assets import labels, labels_binary, labels_multi, labels_11_classes
from marineanomalydetection.io.tif_io import TifIO
from marineanomalydetection.io.load_data import (
    load_segmentation_map,
)
from marineanomalydetection.dataset.categoryaggregation import (
    CategoryAggregation,
)
from marineanomalydetection.dataset.aggregator import (
    aggregate_to_multi, aggregate_to_binary, aggregate_to_11_classes
)

In [None]:
# Patch set to visualise: the tile identifier and the highest patch number
# (patches 0..max_num are plotted, inclusive).
tile_name = 'S2_27-1-19_16QED'
max_num = 20

# Class-aggregation scheme applied to the ground truth (BINARY, MULTI or ELEVEN).
aggregate_classes = CategoryAggregation.MULTI

In [None]:
# Class names for the selected aggregation scheme; labels_agg[i] names the
# 0-based class index i used when printing pixel counts below.
_labels_by_aggregation = {
    CategoryAggregation.BINARY: labels_binary,
    CategoryAggregation.MULTI: labels_multi,
    CategoryAggregation.ELEVEN: labels_11_classes,
}
if aggregate_classes not in _labels_by_aggregation:
    # ValueError (a subclass of Exception) signals a bad configuration value.
    raise ValueError(f"Wrong type of aggregation of classes: {aggregate_classes!r}")
labels_agg = _labels_by_aggregation[aggregate_classes]

In [None]:
# Constants
folder_predictions = '../data/predicted_unet'
folder_gt = '../data/patches'
model_name = 'unet'
gt_name = "cl"
ext = '.tif'
separator = '_'

original_labels = labels

In [None]:
# Reader used in the plotting loop to load each patch as an RGB image (tif_io.tif_2_rgb).
tif_io = TifIO()

In [None]:
def aggregate_classes_fn(aggregate_classes, seg_map):
    """Map a segmentation map onto the classes of the chosen aggregation scheme.

    Args:
        aggregate_classes: a ``CategoryAggregation`` value selecting the scheme.
        seg_map: the segmentation map to aggregate.

    Returns:
        The output of the matching aggregator for MULTI, BINARY or ELEVEN; for
        any other value ``seg_map`` is returned unchanged.
    """
    # Checked in order; the first category equal to aggregate_classes wins.
    aggregators = (
        (CategoryAggregation.MULTI, aggregate_to_multi),
        (CategoryAggregation.BINARY, aggregate_to_binary),
        (CategoryAggregation.ELEVEN, aggregate_to_11_classes),
    )
    for category, aggregate in aggregators:
        if aggregate_classes == category:
            return aggregate(seg_map)
    return seg_map

In [None]:
not_labeled_idx = -1
# Plot the RGB image, the prediction and the ground truth for patches 0..max_num.
# Both class maps are converted to 0-based class indices (unlabeled -> -1) and drawn
# with the same vmin/vmax, so a given aggregated class gets the same colour in the
# prediction and ground-truth panels.
class_vmin, class_vmax = not_labeled_idx, len(labels_agg) - 1

for number in range(max_num + 1):
    patch_stem = tile_name + separator + str(number)

    # Read rgb image
    file_path = os.path.join(folder_gt, tile_name, patch_stem + ext)
    rgb_img, _ = tif_io.tif_2_rgb(file_path)

    # Read ground truth, aggregate it, then shift to 0-based class indices so that
    # 0 (unlabeled) becomes not_labeled_idx. Casting to a signed integer first keeps
    # the shift from wrapping around if the mask is stored with an unsigned dtype.
    seg_map = load_segmentation_map(
        os.path.join(folder_gt, tile_name, patch_stem + separator + gt_name + ext)
    )
    seg_map = aggregate_classes_fn(aggregate_classes, seg_map)
    seg_map = np.asarray(seg_map).astype(np.int64) - 1

    # Read prediction (1-based class values); the context manager closes the file.
    with rasterio.open(
        os.path.join(folder_predictions, patch_stem + separator + model_name + ext)
    ) as pred:
        pred_img = pred.read()
    # Same 0-based indexing as the ground truth. A raw 0 now maps to not_labeled_idx
    # instead of silently indexing the last entry of labels_agg.
    pred_classes = pred_img.astype(np.int64) - 1

    fig, ax = plt.subplots(1, 3, figsize=(20, 11))
    ax[0].set_title("RGB Image")
    rgb_max = rgb_img.max()
    # Normalise to [0, 1] for display; skip it for an all-zero patch (division by zero).
    ax[0].imshow(rgb_img / rgb_max if rgb_max > 0 else rgb_img)
    ax[1].set_title("Prediction")
    ax[1].imshow(pred_classes[0, :, :], vmin=class_vmin, vmax=class_vmax)
    ax[2].set_title("Ground Truth")
    ax[2].imshow(seg_map, vmin=class_vmin, vmax=class_vmax)
    for axis in ax:
        axis.axis('off')
    plt.tight_layout()
    plt.show()
    # One figure is created per patch; release it explicitly.
    plt.close(fig)

    # Print value counts in the prediction
    print('Prediction')
    values_1, counts_1 = np.unique(pred_classes, return_counts=True)
    for value, count in zip(values_1, counts_1):
        if value != not_labeled_idx:
            print(f"# pixels = {labels_agg[int(value)]} -> {count}")
    # Print value counts in the ground truth
    print('Ground truth')
    values, counts = np.unique(seg_map, return_counts=True)
    for value, count in zip(values, counts):
        if value != not_labeled_idx:
            print(f"# pixels = {labels_agg[int(value)]} -> {count} times")