In [1]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
from sklearn.decomposition import PCA
from scipy.ndimage import distance_transform_edt
from skimage.feature import graycomatrix, graycoprops, local_binary_pattern

def compute_energy(image):
    """
    Return the gradient-magnitude energy map of an image.

    Horizontal and vertical derivatives are taken with 3x3 Sobel kernels in
    64-bit float precision; the energy of each pixel is the Euclidean norm of
    those two derivatives.

    Args:
        image (np.array): Grayscale input image, shape (height, width).

    Returns:
        energy (np.array): Float64 energy map, same shape as ``image``.
    """
    d_dx, d_dy = (
        cv2.Sobel(image, cv2.CV_64F, dx, dy, ksize=3)
        for dx, dy in ((1, 0), (0, 1))
    )
    squared_magnitude = d_dx ** 2 + d_dy ** 2
    return np.sqrt(squared_magnitude)

def genericSegEvaluation(seg1, seg2):
    """
    Evaluate non-semantic segmentation with a region-matching error.

    Implements the matching metric from https://doi.org/10.1109/TIP.2005.854491:
    build the contingency table between the regions of the two segmentations,
    find the one-to-one region assignment that maximizes the number of pixels
    on which they agree, and report the fraction of pixels that still disagree.

    Args:
        seg1 (np.array): First segmentation mask (any shape).
        seg2 (np.array): Second segmentation mask with the same number of pixels.

    Returns:
        error (float): Segmentation error in range [0, 1]; 0 means the two
            segmentations are identical up to a relabeling of their regions.
            Empty inputs give 0.0.

    Raises:
        ValueError: If the two masks contain a different number of pixels.
    """
    seg1_flat = np.asarray(seg1).ravel()
    seg2_flat = np.asarray(seg2).ravel()
    if seg1_flat.size != seg2_flat.size:
        raise ValueError(
            f"Segmentations differ in size: {seg1_flat.size} vs {seg2_flat.size} pixels"
        )
    n_pixels = seg1_flat.size
    if n_pixels == 0:
        return 0.0

    # Contingency table: cm[i, j] = number of pixels that belong to region i
    # of seg1 and region j of seg2 (regions indexed by their sorted label).
    regions1, idx1 = np.unique(seg1_flat, return_inverse=True)
    regions2, idx2 = np.unique(seg2_flat, return_inverse=True)
    n1, n2 = regions1.size, regions2.size
    cm = np.bincount(
        idx1.ravel() * n2 + idx2.ravel(), minlength=n1 * n2
    ).reshape(n1, n2)

    # Optimal one-to-one region matching; linear_sum_assignment accepts
    # rectangular matrices, so differing region counts need no padding.
    row_ind, col_ind = linear_sum_assignment(cm, maximize=True)
    quality = cm[row_ind, col_ind].sum()

    # Normalize by the pixel count N (not N - 1): identical partitions give
    # quality == N and therefore an error of exactly 0, keeping it in [0, 1].
    return float(1 - quality / n_pixels)

def iou_dice(pred_mask, exp_mask):
    """
    Return the IoU (Jaccard index) and Dice coefficient of two binary masks.

    Pixels equal to 1 are foreground and pixels equal to 0 are background;
    any other value is not counted in any category. A score whose
    denominator is zero is reported as 0.

    Args:
        pred_mask (np.array): Predicted binary mask.
        exp_mask (np.array): Expected (ground-truth) binary mask.

    Returns:
        iou (float): Intersection over Union, tp / (tp + fp + fn).
        dice (float): Dice coefficient, 2 tp / (2 tp + fp + fn).
    """
    pred_fg, pred_bg = pred_mask == 1, pred_mask == 0
    exp_fg, exp_bg = exp_mask == 1, exp_mask == 0

    hits = np.sum(pred_fg & exp_fg)
    false_alarms = np.sum(pred_fg & exp_bg)
    misses = np.sum(pred_bg & exp_fg)

    union = hits + false_alarms + misses
    dice_denominator = 2 * hits + false_alarms + misses

    if union <= 0:
        iou = 0
    else:
        iou = hits / union

    if dice_denominator <= 0:
        dice = 0
    else:
        dice = 2 * hits / dice_denominator

    return iou, dice

def get_stats(pred_mask, exp_mask):
    """
    Count the confusion-table cells of two binary masks.

    Value 1 is the positive class and value 0 the negative class; pixels
    with any other value fall into none of the four counts.

    Args:
        pred_mask (np.array): Predicted binary mask.
        exp_mask (np.array): Expected (ground-truth) binary mask.

    Returns:
        tp (int): Predicted 1, expected 1.
        fp (int): Predicted 1, expected 0.
        fn (int): Predicted 0, expected 1.
        tn (int): Predicted 0, expected 0.
    """
    predicted_pos, predicted_neg = pred_mask == 1, pred_mask == 0
    expected_pos, expected_neg = exp_mask == 1, exp_mask == 0

    cell_pairs = (
        (predicted_pos, expected_pos),   # tp
        (predicted_pos, expected_neg),   # fp
        (predicted_neg, expected_pos),   # fn
        (predicted_neg, expected_neg),   # tn
    )
    return tuple(np.sum(pred & exp) for pred, exp in cell_pairs)

def convert_segmented_to_labels(image, color_palette):
    """
    Map a colour-coded segmentation image to integer class labels.

    Each pixel whose three channel values exactly equal a palette colour gets
    that colour's label. Pixels matching no palette colour keep label 0. If
    two labels share a colour, the one visited last in the palette wins.

    Args:
        image (np.array): Colour segmentation image, shape (height, width, 3).
        color_palette (dict): Label -> 3-tuple colour, compared in the same
            channel order as ``image``.

    Returns:
        labels (np.array): int32 label array, shape (height, width).
    """
    rows, cols, _ = image.shape
    labels = np.zeros((rows, cols), dtype=np.int32)
    for class_index, colour in color_palette.items():
        matches_colour = (image == colour).all(axis=-1)
        labels[matches_colour] = class_index
    return labels

def calculate_multiclass_metrics_per_class(y_true, y_pred, num_classes, *, align_labels=True):
    """
    Calculate per-class multiclass segmentation metrics from a confusion matrix.

    By default the predicted labels are first re-mapped with the one-to-one
    assignment (Hungarian algorithm) that maximizes pixel agreement with the
    ground truth. This suits unsupervised / non-semantic label ids. For a
    supervised classifier whose label ids already denote the same classes as
    the ground truth, pass ``align_labels=False`` so the scores are not
    inflated by a per-image relabeling.

    Args:
        y_true (np.array): Ground-truth integer labels, shape (height, width).
        y_pred (np.array): Predicted integer labels, same shape as ``y_true``.
        num_classes (int): Number of classes; only labels in
            ``[0, num_classes)`` enter the confusion matrix.
        align_labels (bool): Keyword-only. Apply the optimal label mapping
            before scoring (default True, the historical behavior).

    Returns:
        iou_per_class (list): IoU score for each class.
        dice_per_class (list): Dice coefficient for each class.
        global_accuracy (float): Fraction of pixels whose (mapped) prediction
            equals the ground truth; 0.0 for empty input.
        precision_per_class (list): Precision score for each class.
        recall_per_class (list): Recall score for each class.
        accuracy_per_class (list): Per-class accuracy, (tp + tn) / total.

    Any per-class score whose denominator is zero is reported as 0.
    """
    y_true_flat = np.asarray(y_true).ravel()
    y_pred_flat = np.asarray(y_pred).ravel()

    def _confusion(true_flat, pred_flat):
        # cm[t, p] = number of pixels with true label t and predicted label p.
        # Pairs with either label outside [0, num_classes) are skipped, the
        # same as sklearn's confusion_matrix(..., labels=range(num_classes)).
        in_range = ((true_flat >= 0) & (true_flat < num_classes)
                    & (pred_flat >= 0) & (pred_flat < num_classes))
        codes = (true_flat[in_range].astype(np.int64) * num_classes
                 + pred_flat[in_range].astype(np.int64))
        counts = np.bincount(codes, minlength=num_classes * num_classes)
        return counts.reshape(num_classes, num_classes)

    if align_labels:
        cm = _confusion(y_true_flat, y_pred_flat)
        # Maximize the matched pixel count by minimizing its negation.
        row_ind, col_ind = linear_sum_assignment(-cm)
        # Predicted ids outside the assignment (out-of-range labels) end up
        # as class 0.
        mapped_pred = np.zeros_like(y_pred_flat)
        for true_cls, pred_cls in zip(row_ind, col_ind):
            mapped_pred[y_pred_flat == pred_cls] = true_cls
    else:
        mapped_pred = y_pred_flat

    cm_mapped = _confusion(y_true_flat, mapped_pred)

    total_pixels = y_true_flat.size
    if total_pixels > 0:
        global_accuracy = np.sum(y_true_flat == mapped_pred) / total_pixels
    else:
        global_accuracy = 0.0

    def _ratio(num, den):
        # Safe division: 0 when the denominator is empty.
        return num / den if den > 0 else 0

    iou_per_class = []
    dice_per_class = []
    precision_per_class = []
    recall_per_class = []
    accuracy_per_class = []

    for cls in range(num_classes):
        tp = cm_mapped[cls, cls]
        fp = cm_mapped[:, cls].sum() - tp
        fn = cm_mapped[cls, :].sum() - tp
        tn = total_pixels - (tp + fp + fn)

        iou_per_class.append(_ratio(tp, tp + fp + fn))
        dice_per_class.append(_ratio(2 * tp, 2 * tp + fp + fn))
        precision_per_class.append(_ratio(tp, tp + fp))
        recall_per_class.append(_ratio(tp, tp + fn))
        accuracy_per_class.append(_ratio(tp + tn, tp + fp + tn + fn))

    return iou_per_class, dice_per_class, global_accuracy, precision_per_class, recall_per_class, accuracy_per_class

def apply_gabor_filter(gray_image, *, sigmas=(2.0, 4.0, 8.0), lambdas=(5.0, 10.0, 15.0),
                       num_orientations=8, ksize=21):
    """
    Apply a bank of Gabor filters to a grayscale image for texture analysis.

    One 21x21 (by default) Gabor kernel is built for every combination of
    sigma, wavelength and orientation, with gamma=0.5 and psi=0. Responses are
    ordered sigma-major, then wavelength, then orientation. With the defaults
    this gives 3 * 3 * 8 = 72 responses, and response 0 uses sigma=2.0,
    lambda=5.0, theta=0.

    Args:
        gray_image (np.array): Grayscale image, shape (height, width).
        sigmas (tuple): Keyword-only. Gaussian envelope standard deviations.
        lambdas (tuple): Keyword-only. Sinusoid wavelengths.
        num_orientations (int): Keyword-only. Number of orientations evenly
            spaced over [0, pi).
        ksize (int): Keyword-only. Kernel side length.

    Returns:
        filtered_images (np.array): uint8 array of responses,
            shape (num_filters, height, width).
    """
    orientations = np.arange(0, np.pi, np.pi / num_orientations)

    responses = []
    for sigma in sigmas:
        for lambd in lambdas:
            for theta in orientations:
                kernel = cv2.getGaborKernel((ksize, ksize), sigma, theta, lambd, 0.5, 0, ktype=cv2.CV_32F)
                # ddepth must be a *depth* (CV_8U), not a full type such as
                # CV_8UC3. OpenCV only uses the depth part of ddepth, so the
                # output has always been single-channel uint8. Saturation
                # applies: negative responses clip to 0 and large ones to 255.
                responses.append(cv2.filter2D(gray_image, cv2.CV_8U, kernel))

    return np.array(responses)

def extract_lbp_features(gray_image, *, P=20, R=2):
    """
    Extract the normalized uniform Local Binary Pattern (LBP) histogram.

    With ``method="uniform"`` skimage's ``local_binary_pattern`` produces
    codes 0..P for uniform patterns plus P + 1 for all non-uniform ones, i.e.
    P + 2 distinct values. The histogram has one unit-width bin per code so
    every pixel is counted. The previous fixed edges 0..10 silently dropped
    every pixel with a code above 10.

    Args:
        gray_image (np.array): Grayscale image, shape (height, width).
        P (int): Keyword-only. Number of circularly symmetric neighbours.
        R (float): Keyword-only. Radius of the neighbourhood circle.

    Returns:
        lbp_hist (np.array): Density-normalized histogram (sums to 1 because
            bins are unit-width), shape (P + 2,), which is (22,) for the
            default P=20.
    """
    lbp = local_binary_pattern(gray_image, P=P, R=R, method="uniform")
    # Edges 0, 1, ..., P + 2 -> P + 2 bins, one per possible uniform code.
    lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, P + 3), density=True)
    return lbp_hist

def extract_glcm_features(gray_image):
    """
    Summarize image texture with gray-level co-occurrence matrix statistics.

    A symmetric, normalized 256-level GLCM is computed for pixel offsets of
    1, 3 and 5 at 0, 45, 90 and 135 degrees. Each property is then averaged
    over all 12 (distance, angle) combinations.

    Args:
        gray_image (np.array): Grayscale image, shape (height, width).
            Presumably uint8 so that values fit the 256 levels. TODO confirm.

    Returns:
        contrast (float): Mean contrast across GLCMs.
        homogeneity (float): Mean homogeneity across GLCMs.
        energy (float): Mean energy across GLCMs.
        correlation (float): Mean correlation across GLCMs.
    """
    co_occurrence = graycomatrix(
        gray_image,
        distances=[1, 3, 5],
        angles=[0, np.pi / 4, np.pi / 2, 3 * np.pi / 4],
        levels=256,
        symmetric=True,
        normed=True,
    )
    properties = ('contrast', 'homogeneity', 'energy', 'correlation')
    return tuple(np.mean(graycoprops(co_occurrence, prop)) for prop in properties)

def extract_features(image):
    """
    Build a 13-dimensional per-pixel feature vector for KNN classification.

    Args:
        image (np.array): Image as loaded by ``cv2.imread``, i.e. a uint8
            array in BGR channel order, shape (height, width, 3).

    Returns:
        features (np.array): Float64 feature matrix, shape (height * width, 13),
            one row per pixel in row-major order. Columns:
              0-2   colour channels (B, G, R) scaled to [0, 1]
              3     grayscale of the 11x11 box-filtered (local mean) colour
              4     Canny edge map (0 or 1)
              5     distance to the nearest edge pixel, scaled to [0, 1]
              6-9   GLCM contrast, homogeneity, energy, correlation
                    (whole-image values, identical for every pixel)
              10-11 first two bins of the whole-image LBP histogram
              12    response of the first Gabor filter of apply_gabor_filter
                    (sigma=2, theta=0, lambda=5), saturated uint8 range 0-255

    NOTE(review): columns 6-9 and 12 are not scaled to [0, 1] (e.g. the 0-255
    Gabor response; GLCM contrast can be much larger than 1). They can
    therefore dominate KNN's Euclidean distance. Left unchanged because
    trained models depend on this exact feature layout.
    """
    height, width, _ = image.shape
    features = np.zeros((height, width, 13))

    # Channels 0-2: normalized colour channels.
    image_norm = image.astype(np.float32) / 255.0
    features[:, :, :3] = image_norm

    # Channel 3: local mean colour, converted to gray.
    mean_local = cv2.blur(image_norm, (11, 11))
    features[:, :, 3] = cv2.cvtColor(mean_local, cv2.COLOR_BGR2GRAY)

    # Channel 4: Canny edges (0/255 scaled to 0/1).
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray_image, 100, 200)
    features[:, :, 4] = edges / 255.0

    # Channel 5: the EDT measures distance to the nearest zero, and edge pixels
    # are the zeros of (1 - edges), so this is distance-to-edge. It is
    # normalized by its maximum when that is non-zero.
    dist_transform = distance_transform_edt(1 - (edges / 255.0))
    max_dist = np.max(dist_transform)
    features[:, :, 5] = dist_transform / max_dist if max_dist > 0 else dist_transform

    # Channels 6-9: global GLCM statistics, broadcast to every pixel.
    contrast, homogeneity, energy, correlation = extract_glcm_features(gray_image)
    features[:, :, 6] = contrast
    features[:, :, 7] = homogeneity
    features[:, :, 8] = energy
    features[:, :, 9] = correlation

    # Channels 10-11: first two bins of the global LBP histogram.
    lbp_hist = extract_lbp_features(gray_image)
    features[:, :, 10] = lbp_hist[0]
    features[:, :, 11] = lbp_hist[1]

    # Channel 12: only the first Gabor response was ever used. Compute just
    # that filter, which is the first combination generated by
    # apply_gabor_filter: ksize 21, sigma=2.0, theta=0, lambda=5.0, gamma=0.5,
    # psi=0. This avoids building all 72 responses and discarding 71.
    gabor_kernel = cv2.getGaborKernel((21, 21), 2.0, 0.0, 5.0, 0.5, 0, ktype=cv2.CV_32F)
    features[:, :, 12] = cv2.filter2D(gray_image, cv2.CV_8U, gabor_kernel)

    return features.reshape((-1, 13))

In [None]:
# Directory paths for KNN data
input_path = "../../experiments_data/knn_training/input"
target_path = "../../experiments_data/knn_training/target"

def train_multiclass_knn():
    """
    Train a multiclass KNN classifier for image segmentation.

    Every ``.tif`` image in ``input_path`` is paired with the ``.png`` mask of
    the same name in ``target_path``. Per-pixel features come from
    ``extract_features`` and per-pixel class indices come from the mask
    colours. The colour palette is the sorted set of all mask colours seen
    across *all* training masks. A given colour therefore always maps to the
    same class index, regardless of which image it appears in or the order in
    which files are listed.

    Args:
        None.

    Returns:
        knn (KNeighborsClassifier): Trained KNN model.
        color_palette (dict): Class index -> colour tuple. Colours are in BGR
            channel order because masks are loaded with ``cv2.imread``.

    Raises:
        ValueError: If no usable image/mask pair is found.
    """
    train_features = []
    per_image_colors = []   # unique colours of each mask (np.unique order)
    per_image_inverse = []  # per-pixel index into that mask's unique colours

    # Sorted for a deterministic processing order (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(input_path)):
        if not filename.endswith(".tif"):
            continue

        img_path = os.path.join(input_path, filename)
        label_path = os.path.join(target_path, filename.replace(".tif", ".png"))

        image = cv2.imread(img_path)
        label = cv2.imread(label_path)
        if image is None or label is None:
            print(f"Error loading {filename}")
            continue
        if label.shape[:2] != image.shape[:2]:
            # Features and labels must align pixel-for-pixel.
            print(f"Size mismatch between {filename} and its mask, skipping")
            continue

        train_features.append(extract_features(image))

        colors, inverse = np.unique(label.reshape(-1, 3), axis=0, return_inverse=True)
        per_image_colors.append([tuple(color) for color in colors])
        # ravel(): some NumPy 2.x releases return a 2-D inverse when axis is given.
        per_image_inverse.append(inverse.ravel())

    if not train_features:
        raise ValueError(f"No training image/mask pairs could be loaded from {input_path!r}")

    # np.unique's inverse indices are only meaningful within one mask. Using
    # them directly as labels would give the same colour different class ids
    # in masks that contain different colour subsets. Build one global
    # palette and remap each mask's local indices onto it.
    palette_colors = sorted({color for colors in per_image_colors for color in colors})
    color_to_class = {color: idx for idx, color in enumerate(palette_colors)}
    train_labels = np.concatenate([
        np.array([color_to_class[color] for color in colors], dtype=np.int64)[inverse]
        for colors, inverse in zip(per_image_colors, per_image_inverse)
    ])
    train_features = np.vstack(train_features)

    # Train KNN classifier
    knn = KNeighborsClassifier(n_neighbors=20, weights='distance')
    knn.fit(train_features, train_labels)

    color_palette = dict(enumerate(palette_colors))

    print("Multiclass KNN training completed!")
    return knn, color_palette

knn, color_palette = train_multiclass_knn()
color_palette

Multiclass KNN training completed!


{0: (np.uint8(0), np.uint8(156), np.uint8(255)),
 1: (np.uint8(0), np.uint8(255), np.uint8(0)),
 2: (np.uint8(77), np.uint8(255), np.uint8(255)),
 3: (np.uint8(255), np.uint8(0), np.uint8(0)),
 4: (np.uint8(255), np.uint8(0), np.uint8(255))}

In [None]:
# Directory paths for test data
test_input_path = "../../experiments_data/test_input"
test_target_path = "../../experiments_data/test_target"
output_path = "test_results_knn_multi"

# Create output directory if it doesn't exist (exist_ok avoids the
# check-then-create race of a separate os.path.exists test).
os.makedirs(output_path, exist_ok=True)

def test_multiclass_knn():
    """
    Segment every test image with the trained KNN and save the results.

    Each ``.tif`` in ``test_input_path`` that has a matching ``.png`` mask in
    ``test_target_path`` is processed in sorted filename order. Per-pixel
    classes are predicted with the global ``knn`` and painted with the global
    ``color_palette``. The function writes ``resultado_knn_<name>.png`` and a
    side-by-side ``comparacao_<name>.png`` (original | prediction | target)
    into ``output_path``.

    Args:
        None

    Returns:
        None
    """
    for filename in sorted(os.listdir(test_input_path)):
        if not filename.endswith(".tif"):
            continue

        png_name = filename.replace(".tif", ".png")
        test_image_path = os.path.join(test_input_path, filename)
        test_labels_path = os.path.join(test_target_path, png_name)
        output_image_path = os.path.join(output_path, f"resultado_knn_{png_name}")
        comparison_path = os.path.join(output_path, f"comparacao_{png_name}")

        # Load test image and ground truth mask
        test_image = cv2.imread(test_image_path)
        test_labels_rgb = cv2.imread(test_labels_path)
        if test_image is None or test_labels_rgb is None:
            print(f"Error loading {filename} or its corresponding mask!")
            continue

        # One prediction per pixel, reshaped back onto the image grid.
        test_features = extract_features(test_image)
        predicted_labels = knn.predict(test_features).reshape(test_image.shape[:2])

        # Paint each pixel with its class colour (palette colours are BGR).
        segmented_image = np.zeros((*predicted_labels.shape, 3), dtype=np.uint8)
        for label_idx, color in color_palette.items():
            segmented_image[predicted_labels == label_idx] = color

        # cv2.imwrite can signal failure by returning False instead of raising.
        if not cv2.imwrite(output_image_path, segmented_image):
            print(f"Failed to write {output_image_path}")
            continue
        print(f"Result saved at: {output_image_path}")

        panels = (
            (test_image, "Original Image"),
            (segmented_image, "Segmentation (KNN)"),
            (test_labels_rgb, "Target Image"),
        )
        fig = plt.figure(figsize=(18, 6))
        try:
            for position, (panel, title) in enumerate(panels, start=1):
                ax = fig.add_subplot(1, 3, position)
                # Matplotlib expects RGB; the arrays are BGR.
                ax.imshow(cv2.cvtColor(panel, cv2.COLOR_BGR2RGB))
                ax.set_title(title)
                ax.axis("off")
            fig.savefig(comparison_path)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
        print(f"Comparison saved at: {comparison_path}")

    print("Testing completed!")

test_multiclass_knn()

In [None]:
# Directory paths
test_target_path = "../../experiments_data/test_target"
output_path = "test_results_knn_multi"
metrics_file = os.path.join(output_path, "metrics.txt")

def evaluate_multiclass_knn_segmentation():
    """
    Evaluate saved multiclass KNN segmentations and write averaged metrics.

    Every ``resultado_knn_*.png`` in ``output_path`` is paired with the mask
    of the same name in ``test_target_path``. Both are converted to class
    indices with the global ``color_palette``. The function then averages,
    over all images:
      - per-class IoU, Dice, precision, recall and accuracy;
      - the global pixel accuracy;
      - the ``genericSegEvaluation`` error.
    Results are written to ``metrics_file``.

    Per-class scores come from ``calculate_multiclass_metrics_per_class``,
    which reports 0 for a class absent from both masks of an image. Such
    images pull that class's average down.

    Args:
        None

    Returns:
        None
    """
    num_classes = len(color_palette)
    iou_all_images = [[] for _ in range(num_classes)]
    dice_all_images = [[] for _ in range(num_classes)]
    precision_all_images = [[] for _ in range(num_classes)]
    recall_all_images = [[] for _ in range(num_classes)]
    accuracy_per_class_all_images = [[] for _ in range(num_classes)]
    global_accuracy_all_images = []
    generic_seg_errors = []

    # Sorted for a deterministic processing order.
    for filename in sorted(os.listdir(output_path)):
        if not (filename.startswith("resultado_knn_") and filename.endswith(".png")):
            continue

        original_filename = filename.replace("resultado_knn_", "").replace(".png", ".tif")
        test_labels_path = os.path.join(test_target_path, original_filename.replace(".tif", ".png"))
        segmented_image_path = os.path.join(output_path, filename)

        test_labels_rgb = cv2.imread(test_labels_path)
        segmented_image = cv2.imread(segmented_image_path)
        if test_labels_rgb is None or segmented_image is None:
            print(f"Error loading {original_filename} or its segmented result!")
            continue

        # Convert images to class labels
        test_labels = convert_segmented_to_labels(test_labels_rgb, color_palette)
        predicted_labels = convert_segmented_to_labels(segmented_image, color_palette)

        (iou_per_class, dice_per_class, global_accuracy,
         precision_per_class, recall_per_class, accuracy_per_class) = calculate_multiclass_metrics_per_class(
            test_labels, predicted_labels, num_classes)
        error = genericSegEvaluation(test_labels, predicted_labels)

        for cls in range(num_classes):
            iou_all_images[cls].append(iou_per_class[cls])
            dice_all_images[cls].append(dice_per_class[cls])
            precision_all_images[cls].append(precision_per_class[cls])
            recall_all_images[cls].append(recall_per_class[cls])
            accuracy_per_class_all_images[cls].append(accuracy_per_class[cls])
        global_accuracy_all_images.append(global_accuracy)
        generic_seg_errors.append(error)

    # np.mean of an empty list warns and yields NaN; don't write a NaN report.
    if not global_accuracy_all_images:
        print(f"No segmentation results found in {output_path}; metrics not written.")
        return

    # Average metrics across images
    mean_iou_per_class = [np.mean(values) for values in iou_all_images]
    mean_dice_per_class = [np.mean(values) for values in dice_all_images]
    mean_precision_per_class = [np.mean(values) for values in precision_all_images]
    mean_recall_per_class = [np.mean(values) for values in recall_all_images]
    mean_accuracy_per_class = [np.mean(values) for values in accuracy_per_class_all_images]
    mean_global_accuracy = np.mean(global_accuracy_all_images)
    mean_generic_seg_error = np.mean(generic_seg_errors)

    # Display names by class index. They assume the palette order printed
    # after training (sorted BGR colours: orange, green, light yellow, blue,
    # magenta). NOTE(review): confirm if the training masks change. Extra
    # classes fall back to their index and colour instead of raising IndexError.
    class_names = ["Orange", "Green", "Light Yellow (Empty Space)", "Blue", "Magenta"]

    with open(metrics_file, 'w', encoding='utf-8') as f:
        f.write("Metric Results (KNN Multiclass)\n")
        f.write("==================================\n")

        for cls in range(num_classes):
            if cls < len(class_names):
                class_name = class_names[cls]
            else:
                class_name = f"Class {cls} {tuple(int(v) for v in color_palette[cls])}"
            f.write(f"Class: {class_name}\n")
            f.write(f"Mean IoU: {mean_iou_per_class[cls]:.4f}\n")
            f.write(f"Mean Dice: {mean_dice_per_class[cls]:.4f}\n")
            f.write(f"Mean Precision: {mean_precision_per_class[cls]:.4f}\n")
            f.write(f"Mean Recall: {mean_recall_per_class[cls]:.4f}\n")
            f.write(f"Mean Accuracy: {mean_accuracy_per_class[cls]:.4f}\n")
            f.write("----------------------------------\n")

        f.write(f"Global Accuracy (average across images): {mean_global_accuracy:.4f}\n")
        f.write(f"Mean genericSegEvaluation Error: {mean_generic_seg_error:.4f}\n")

    print(f"Multiclass metrics saved at: {metrics_file}")

evaluate_multiclass_knn_segmentation()

Multiclass metrics saved at: test_results_MULTI/metrics.txt
