In [None]:
import os
import sys

# Make the project package (`anomalymarinedetection`) importable from this notebook.
# Presumably the package root is the parent of the notebook folder — TODO confirm.
project_root = os.path.abspath("../")
if project_root not in sys.path:
    # Guard so re-running the cell does not stack duplicate entries on sys.path.
    sys.path.append(project_root)

# Last expression of the cell, so the working directory is actually displayed.
os.getcwd()

In [None]:
import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt

from anomalymarinedetection.utils.assets import labels, labels_binary, labels_multi
from anomalymarinedetection.io.tif_io import TifIO
from anomalymarinedetection.io.load_data import (
    load_segmentation_map,
)
from anomalymarinedetection.dataset.categoryaggregation import (
    CategoryAggregation,
)
from anomalymarinedetection.dataset.dataloadertype import DataLoaderType
from anomalymarinedetection.dataset.aggregate_classes_to_super_class import (
    aggregate_classes_to_super_class,
)
from anomalymarinedetection.utils.assets import (
    cat_mapping,
    cat_mapping_binary,
    cat_mapping_multi,
)

In [None]:
# Analysis settings
tile_name = "S2_14-9-18_16PCC"  # patch files are named `<tile_name>_<number>.tif`
max_num = 43  # last patch number to plot; patches 0..max_num are visited
aggregate_classes = CategoryAggregation.MULTI  # class grouping: MULTI or BINARY

In [None]:
# Pick the label names that match the chosen class aggregation.
if aggregate_classes == CategoryAggregation.BINARY:
    labels_agg = labels_binary
elif aggregate_classes == CategoryAggregation.MULTI:
    labels_agg = labels_multi
else:
    # ValueError subclasses Exception, so existing `except Exception` handlers still match;
    # the message now names the offending value to make the mistake easy to spot.
    raise ValueError(f"Wrong type of aggregation of classes: {aggregate_classes!r}")

In [None]:
# Constants: where files live and how their names are built
# (file name = parts joined by `separator`, followed by `ext`).
folder_predictions = "../data/predicted_unet"  # prediction rasters: `<tile>_<n>_<model_name><ext>`
folder_gt = "../data/patches"  # one sub-folder per tile with RGB patches and `_<gt_name>` masks
model_name = "unet"  # suffix identifying prediction files
gt_name = "cl"  # suffix identifying ground-truth mask files
ext = ".tif"
separator = "_"

# Alias kept for the full, non-aggregated label list.
original_labels = labels

In [None]:
# Reader used in the plotting loop to load each patch as an RGB image (`tif_2_rgb`).
tif_io = TifIO()

In [None]:
def aggregate_classes_fn(aggregate_classes, seg_map):
    """Fold the original classes of ``seg_map`` into the super-classes of an aggregation.

    Every merge step calls ``aggregate_classes_to_super_class`` with a list of
    original class names (taken from ``labels``) and the super-class name they
    are merged into. Ids are translated with ``cat_mapping`` and the mapping of
    the chosen aggregation (``cat_mapping_multi`` or ``cat_mapping_binary``).

    Args:
        aggregate_classes: ``CategoryAggregation.MULTI`` or
            ``CategoryAggregation.BINARY``; any other value leaves ``seg_map``
            untouched.
        seg_map: segmentation map as returned by ``load_segmentation_map``.

    Returns:
        The map produced by the last merge step, or ``seg_map`` itself when
        no aggregation applies.
    """
    if aggregate_classes == CategoryAggregation.MULTI:
        # Resulting classes: Marine Water, Cloud, Ship, Marine Debris,
        # Algae/Organic Material. The steps below must follow the increasing
        # class order specified in assets.
        first_algae = labels.index("Dense Sargassum")
        last_algae = labels.index("Natural Organic Material")
        target_mapping = cat_mapping_multi
        merge_steps = [
            # 'Dense Sargassum' .. 'Natural Organic Material' -> Algae/Natural Organic Material
            (labels[first_algae : last_algae + 1], labels_multi[1]),
            # Ship keeps its name but moves to its new position
            ([labels[4]], labels[4]),
            # Clouds keep their name but move to their new position
            ([labels[5]], labels[5]),
            # Water-like classes ('Sediment-Laden Water', 'Foam', 'Turbid Water',
            # 'Shallow Water', 'Waves', 'Cloud Shadows', 'Wakes', 'Mixed Water')
            # -> 'Marine Water'. NOTE(review): that list has eight names but the
            # slice takes nine, so it presumably includes 'Marine Water' itself — confirm.
            (labels[-9:], labels[6]),
        ]
    elif aggregate_classes == CategoryAggregation.BINARY:
        # Keep Marine Debris; every other original class becomes "Other".
        # NOTE(review): the slice start is an index into `labels_binary`, so this
        # relies on "Other" sitting at the position of the first non-debris class
        # in `labels` — confirm against assets.
        other_position = labels_binary.index("Other")
        target_mapping = cat_mapping_binary
        merge_steps = [(labels[other_position:], labels_binary[other_position])]
    else:
        return seg_map

    for source_names, super_class_name in merge_steps:
        seg_map = aggregate_classes_to_super_class(
            seg_map,
            source_names,
            super_class_name,
            cat_mapping,
            target_mapping,
        )
    return seg_map

In [None]:
# Value the shifted ground truth uses for unlabeled pixels (raw value 0 after the `- 1` shift below).
not_labeled_idx = -1


def _patch_file_name(*parts):
    """Build a patch file name: `parts` joined by `separator`, followed by `ext`."""
    return separator.join(str(part) for part in parts) + ext


def _label_name(class_idx):
    """Return the aggregated label name for a 0-based class index.

    Out-of-range indices get a placeholder instead of silently wrapping around
    (negative indices) or raising IndexError.
    """
    if 0 <= class_idx < len(labels_agg):
        return labels_agg[class_idx]
    return f"<unknown class {class_idx}>"


# Plot RGB image, prediction and ground truth side by side for every patch of the tile,
# then print per-class pixel counts.
# NOTE(review): an earlier note said the ground truth shows the 15 original classes, but
# it is passed through `aggregate_classes_fn` before plotting and the label lookup below
# indexes it with `labels_agg` — confirm which is intended.
# Prediction and ground truth are drawn as 0-based class indices on one shared color
# scale (vmin/vmax), so equal colors mean equal classes across the two panels.
for number in range(max_num + 1):
    patch_dir = os.path.join(folder_gt, tile_name)

    # Read RGB image
    rgb_img, _ = tif_io.tif_2_rgb(os.path.join(patch_dir, _patch_file_name(tile_name, number)))

    # Read ground truth, aggregate it and shift it to 0-based class indices
    seg_map = load_segmentation_map(
        os.path.join(patch_dir, _patch_file_name(tile_name, number, gt_name))
    )
    seg_map = aggregate_classes_fn(aggregate_classes, seg_map)
    if np.issubdtype(seg_map.dtype, np.unsignedinteger):
        # An unsigned mask would wrap 0 - 1 around to the dtype maximum instead of -1.
        seg_map = seg_map.astype(np.int64)
    seg_map = seg_map - 1  # subtraction already returns a new array; no extra copy needed

    # Read prediction; `with` closes the raster handle on every iteration
    pred_path = os.path.join(
        folder_predictions, _patch_file_name(tile_name, number, model_name)
    )
    with rasterio.open(pred_path) as pred:
        pred_img = pred.read()
    # Prediction values are presumably 1-based (the label lookup below subtracts 1);
    # shift to 0-based so it shares the ground truth's color scale.
    pred_classes = pred_img[0].astype(np.int64) - 1

    rgb_max = rgb_img.max()
    fig, ax = plt.subplots(1, 3, figsize=(20, 11))
    fig.suptitle(f"{tile_name} - patch {number}")
    ax[0].set_title("RGB Image")
    ax[0].imshow(rgb_img / rgb_max if rgb_max > 0 else rgb_img)  # avoid 0/0 on empty patches
    ax[1].set_title("Prediction")
    ax[1].imshow(pred_classes, vmin=not_labeled_idx, vmax=len(labels_agg) - 1)
    ax[2].set_title("Ground Truth")
    ax[2].imshow(seg_map, vmin=not_labeled_idx, vmax=len(labels_agg) - 1)
    # Show now so each figure sits next to its counts, and close it so the
    # figures do not pile up (matplotlib warns past 20 open figures).
    plt.show()
    plt.close(fig)

    # Print value counts in the prediction
    print('Prediction')
    pred_values, pred_counts = np.unique(pred_img, return_counts=True)
    for value, count in zip(pred_values, pred_counts):
        print(f"# pixels = {_label_name(int(value) - 1)} -> {count}")

    # Print value counts in the ground truth (unlabeled pixels are skipped)
    print('Ground truth')
    gt_values, gt_counts = np.unique(seg_map, return_counts=True)
    for value, count in zip(gt_values, gt_counts):
        if value != not_labeled_idx:
            print(f"# pixels = {_label_name(int(value))} -> {count} times")