In [None]:
import os
import sys

os.getcwd()
# Put the parent directory on the import path so the project package
# (imported as `marineanomalydetection` in the next cell) can be found.
sys.path.append("../")

In [None]:
import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt

from marineanomalydetection.utils.assets import labels, labels_binary, labels_multi, labels_11_classes
from marineanomalydetection.io.tif_io import TifIO
from marineanomalydetection.io.load_data import (
    load_segmentation_map,
)
from marineanomalydetection.dataset.categoryaggregation import (
    CategoryAggregation,
)
from marineanomalydetection.dataset.aggregator import (
    aggregate_to_multi, aggregate_to_binary, aggregate_to_11_classes
)

In [None]:
# Class aggregation applied to the ground truth and used to name classes.
aggregate_classes = CategoryAggregation.MULTI
# Tile to inspect (another tile used before: 'S2_30-8-18_16PCC').
tile_name = 'S2_4-3-18_50LLR'
# Inclusive range of patch indices plotted by the loop below.
init_num, max_num = 24, 24

In [None]:
# Per its name, selects only patches where marine debris appears in both the
# prediction and the ground truth. The plotting cells also copy it into
# `focus_on_marine_debris`, which circles marine-debris pixels.
ONLY_IMGS_WITH_MARINE_DEBRIS_IN_BOTH_PRED_AND_GT = False

In [None]:
# Pick the class-name lookup that matches the chosen aggregation. It is used
# to name class indices in the printed pixel counts.
if aggregate_classes == CategoryAggregation.BINARY:
    labels_agg = labels_binary
elif aggregate_classes == CategoryAggregation.MULTI:
    labels_agg = labels_multi
elif aggregate_classes == CategoryAggregation.ELEVEN:
    labels_agg = labels_11_classes
else:
    # ValueError is still an Exception, so existing handlers keep working.
    # Including the value makes a misconfiguration obvious.
    raise ValueError(f"Wrong type of aggregation of classes: {aggregate_classes!r}")

In [None]:
# Constants
# File layout read by the plotting cells:
#   RGB patch   -> <folder_gt>/<tile>/<tile><separator><n><ext>
#   GT patch    -> <folder_gt>/<tile>/<tile><separator><n><separator><gt_name><ext>
#   prediction  -> <folder_predictions>/<tile><separator><n><separator><model_name><ext>
folder_predictions, folder_gt = '../data/predicted_unet', '../data/patches'
model_name, gt_name = 'unet', "cl"
ext, separator = '.tif', '_'

# Keep the full `labels` lookup imported from assets under an explicit name.
# Presumably these are the non-aggregated class names (confirm against assets).
original_labels = labels

In [None]:
# Reader used by the plotting cells to load each patch GeoTIFF as RGB (tif_2_rgb).
tif_io = TifIO()

In [None]:
def aggregate_classes_fn(aggregate_classes, seg_map):
    """Map a segmentation map's original classes to the requested aggregation.

    Args:
        aggregate_classes: a CategoryAggregation member (MULTI, BINARY or ELEVEN).
        seg_map: the segmentation map to convert.

    Returns:
        The aggregated map. Any other `aggregate_classes` value returns
        `seg_map` unchanged.
    """
    # Checked in this order with `==`, exactly like an if/elif chain.
    aggregators = (
        (CategoryAggregation.MULTI, aggregate_to_multi),
        (CategoryAggregation.BINARY, aggregate_to_binary),
        (CategoryAggregation.ELEVEN, aggregate_to_11_classes),
    )
    for category, aggregate in aggregators:
        if aggregate_classes == category:
            return aggregate(seg_map)
    return seg_map

In [None]:
def get_coords_marine_debris(img: np.ndarray, marine_debris_idx: int = 0):
    """Return the indices of every pixel in `img` equal to `marine_debris_idx`.

    Args:
        img: class map. For a 2-D map the result is `(rows, cols)`, so callers
            unpack it as `(y, x)`.
        marine_debris_idx: class value to look for.

    Returns:
        Tuple of index arrays, one per dimension of `img` (from `np.where`).
    """
    return np.where(img == marine_debris_idx)

In [None]:
def draw_circles_on_marine_debris_pixels(ax, coords_x_md, coords_y_md, radius: float = 5, color: str = 'r'):
    """Outline each given pixel on `ax` with an unfilled circle.

    Args:
        ax: matplotlib Axes to draw on.
        coords_x_md: x (column) coordinates of the pixels to mark.
        coords_y_md: y (row) coordinates, paired element-wise with `coords_x_md`.
        radius: circle radius in data units (pixels for an `imshow` axes).
            Defaults to the previously hard-coded 5.
        color: circle edge colour. Defaults to red, as before.
    """
    for xx, yy in zip(coords_x_md, coords_y_md):
        ax.add_patch(plt.Circle((xx, yy), radius, color=color, fill=False))

In [None]:
def make_plot(
    ax,
    number: int,
    rgb_img: np.ndarray,
    pred_img: np.ndarray,
    seg_map: np.ndarray,
    not_labeled_idx: int = -1,
    marine_debris_idx_seg_map: int = 0,
    focus_on_marine_debris: bool = False,
    class_labels=None,
):
    """Show RGB / prediction / ground truth side by side and print pixel counts.

    Args:
        ax: indexable of at least three matplotlib Axes (e.g. from plt.subplots(1, 3)).
        number: patch index; only printed.
        rgb_img: RGB image; it is divided by its max before display.
        pred_img: prediction array indexed as [band, row, col]; band 0 is shown,
            and all bands are counted.
        seg_map: ground-truth class map, in which `not_labeled_idx` marks
            unlabeled pixels.
        not_labeled_idx: value for unlabeled pixels. It is the lower end of the
            colour scale and is skipped in the printed counts.
        marine_debris_idx_seg_map: class value circled when
            `focus_on_marine_debris` is True.
        focus_on_marine_debris: circle marine-debris pixels in both class maps.
        class_labels: class index -> name lookup. Defaults to the notebook's
            `labels_agg`.
    """
    if class_labels is None:
        class_labels = labels_agg
    # Use one colour scale for both class maps, sized by the labels actually in
    # use. It was previously tied to labels_multi, which is wrong for the
    # BINARY and ELEVEN aggregations.
    color_range = dict(vmin=not_labeled_idx, vmax=len(class_labels))

    ax[0].set_title("RGB Image")
    ax[0].imshow(rgb_img / rgb_img.max())

    for axis, title, class_map in (
        (ax[1], "Prediction", pred_img[0, :, :]),
        (ax[2], "Ground Truth", seg_map),
    ):
        axis.set_title(title)
        axis.imshow(class_map, **color_range)
        if focus_on_marine_debris:
            coords_md_y, coords_md_x = get_coords_marine_debris(class_map, marine_debris_idx_seg_map)
            draw_circles_on_marine_debris_pixels(axis, coords_md_x, coords_md_y)

    for axis in (ax[0], ax[1], ax[2]):
        axis.axis('off')
    plt.tight_layout()
    plt.show()

    print(f"Idx: {number}")

    # Print value counts in the prediction. Unlabeled pixels are skipped, as
    # for the ground truth: class_labels[-1] would otherwise print the wrong
    # name (list) or raise KeyError (dict).
    print('Prediction')
    values, counts = np.unique(pred_img, return_counts=True)
    for value, count in zip(values, counts):
        if value != not_labeled_idx:
            print(f"# pixels = {class_labels[int(value)]} -> {count}")

    # Print value counts in the ground truth
    print('Ground truth')
    values, counts = np.unique(seg_map, return_counts=True)
    for value, count in zip(values, counts):
        if value != not_labeled_idx:
            print(f"# pixels = {class_labels[int(value)]} -> {count} times")

    print(np.unique(pred_img))
    print(np.unique(seg_map))
    print("_" * 80)

In [None]:
not_labeled_idx = -1
# Marine-debris class value after the "- 1" shift applied to both maps below.
marine_debris_idx_seg_map = 0
focus_on_marine_debris = ONLY_IMGS_WITH_MARINE_DEBRIS_IN_BOTH_PRED_AND_GT
# Plot ground truth and prediction for every patch index in [init_num, max_num].
# NOTE(review): a previous comment said the GT shows the original 15 classes,
# but the GT is passed through aggregate_classes_fn below. Confirm which is true.
for number in range(init_num, max_num + 1):
    patch_name = tile_name + separator + str(number)

    # Read rgb image
    rgb_img, _ = tif_io.tif_2_rgb(os.path.join(folder_gt, tile_name, patch_name + ext))

    # Read ground truth, aggregate it, and shift by -1 so a raw 0 becomes
    # not_labeled_idx (-1).
    seg_map = load_segmentation_map(
        os.path.join(folder_gt, tile_name, patch_name + separator + gt_name + ext)
    )
    seg_map = np.copy(aggregate_classes_fn(aggregate_classes, seg_map) - 1)

    # Read prediction. The context manager closes the dataset handle.
    pred_path = os.path.join(folder_predictions, patch_name + separator + model_name + ext)
    with rasterio.open(pred_path) as pred:
        pred_img = pred.read()
    # An unsigned raster would wrap 0 - 1 around to its max value instead of -1.
    if np.issubdtype(pred_img.dtype, np.unsignedinteger):
        pred_img = pred_img.astype(np.int64)
    pred_img = pred_img - 1

    # Honour the flag. Previously both branches plotted, so nothing was filtered.
    if ONLY_IMGS_WITH_MARINE_DEBRIS_IN_BOTH_PRED_AND_GT and not (
        np.any(seg_map == marine_debris_idx_seg_map)
        and np.any(pred_img == marine_debris_idx_seg_map)
    ):
        continue

    fig, ax = plt.subplots(1, 3, figsize=(20, 11))
    make_plot(
        ax,
        number,
        rgb_img,
        pred_img,
        seg_map,
        not_labeled_idx,
        marine_debris_idx_seg_map,
        focus_on_marine_debris=focus_on_marine_debris,
    )
        

In [None]:
# Patch indices (within tile_name) whose RGB / prediction / ground-truth
# figures are saved as PNGs under ../res by the next cell.
IDX_TO_SAVE = [30, 32]

In [None]:
not_labeled_idx = -1
# Defined here instead of relying on the loop variable of the plotting cell,
# so this cell does not depend on that loop having run.
marine_debris_idx_seg_map = 0
focus_on_marine_debris = ONLY_IMGS_WITH_MARINE_DEBRIS_IN_BOTH_PRED_AND_GT
output_dir = "../res"
os.makedirs(output_dir, exist_ok=True)  # savefig fails if the folder is missing
# One colour scale for prediction and GT, sized by the labels actually in use.
class_color_range = dict(vmin=not_labeled_idx, vmax=len(labels_agg))


def _save_panel(image, title, filename, mark_marine_debris=False, **imshow_kwargs):
    """Draw `image` alone on a new figure and save it to output_dir/filename.

    Circles marine-debris pixels only when both `mark_marine_debris` and the
    cell-level `focus_on_marine_debris` are True. Returns the new figure.
    """
    fig, ax = plt.subplots(1)
    ax.set_aspect('equal')
    ax.set_title(title)
    ax.axis('off')
    ax.imshow(image, **imshow_kwargs)
    if mark_marine_debris and focus_on_marine_debris:
        coords_md_y, coords_md_x = get_coords_marine_debris(image, marine_debris_idx_seg_map)
        draw_circles_on_marine_debris_pixels(ax, coords_md_x, coords_md_y)
    fig.savefig(os.path.join(output_dir, filename), bbox_inches='tight')
    return fig


# Save RGB, prediction and ground truth of each selected patch as separate PNGs.
for number in IDX_TO_SAVE:
    patch_name = tile_name + separator + str(number)

    # Read rgb image
    rgb_img, _ = tif_io.tif_2_rgb(os.path.join(folder_gt, tile_name, patch_name + ext))

    # Read ground truth, aggregate it, and shift so a raw 0 becomes not_labeled_idx.
    seg_map = load_segmentation_map(
        os.path.join(folder_gt, tile_name, patch_name + separator + gt_name + ext)
    )
    seg_map = np.copy(aggregate_classes_fn(aggregate_classes, seg_map) - 1)

    # Read prediction. The context manager closes the dataset handle.
    pred_path = os.path.join(folder_predictions, patch_name + separator + model_name + ext)
    with rasterio.open(pred_path) as pred:
        pred_img = pred.read()
    # An unsigned raster would wrap 0 - 1 around to its max value instead of -1.
    if np.issubdtype(pred_img.dtype, np.unsignedinteger):
        pred_img = pred_img.astype(np.int64)
    pred_img = pred_img - 1

    _save_panel(rgb_img / rgb_img.max(), "RGB Image", f"RGB{number}.png")
    _save_panel(
        pred_img[0, :, :],
        "Semi-supervised prediction",
        f"Prediction_ssl{number}.png",
        mark_marine_debris=True,
        **class_color_range,
    )
    _save_panel(
        seg_map,
        "Ground Truth",
        f"gt{number}.png",
        mark_marine_debris=True,
        **class_color_range,
    )