In [3]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import glob
import cv2
import re

def extract_number(filename):
    """Return the final run of digits in ``filename`` as a string.

    The digits are returned verbatim (leading zeros kept, e.g.
    ``"ISIC_0024319_bmask.png"`` -> ``"0024319"``) so the value can be used
    directly as a matching key. Returns ``None`` if ``filename`` has no digits.
    """
    digit_runs = re.findall(r'\d+', filename)
    return digit_runs[-1] if digit_runs else None

def load_png_mask(png_path):
    """Load a binary mask from an image file.

    Every non-zero pixel is foreground. Multi-channel images are collapsed to
    a single 2-D mask so they can be compared against 2-D ground truth: an
    alpha channel (modes ``RGBA`` / ``LA``) is dropped, and a pixel is
    foreground if any remaining channel is non-zero.

    Args:
        png_path: Path to the mask image.

    Returns:
        Boolean ``np.ndarray`` of shape ``(height, width)`` for single-band
        and for RGB/RGBA/LA inputs.
    """
    # Context manager releases the underlying file handle deterministically.
    with Image.open(png_path) as img:
        mode = img.mode
        pixels = np.array(img)

    mask = pixels > 0
    if mask.ndim == 3:
        if mode in ("RGBA", "LA"):
            # An opaque alpha channel is non-zero everywhere and would
            # otherwise mark every pixel as foreground.
            mask = mask[..., :-1]
        mask = mask.any(axis=-1)
    return mask

def load_npy_mask(npy_path):
    """Read a ``.npy`` array and return it as a boolean mask.

    Any non-zero (or ``True``) element becomes foreground.
    """
    stored = np.load(npy_path)
    return stored.astype(bool)

def resize_mask(mask, target_shape):
    """Resample a binary mask to ``target_shape`` with nearest-neighbour.

    Args:
        mask: Array whose truthy elements are foreground.
        target_shape: ``(height, width)`` of the result.

    Returns:
        Boolean array of the requested size.
    """
    height, width = target_shape[0], target_shape[1]
    # Scale to 0/255 uint8 for OpenCV; nearest-neighbour keeps values binary.
    as_bytes = mask.astype(np.uint8) * 255
    # OpenCV's dsize argument is (width, height), the reverse of NumPy shapes.
    scaled = cv2.resize(as_bytes, (width, height),
                        interpolation=cv2.INTER_NEAREST)
    return scaled > 0

def match_files_by_number(png_dir, npy_dir, png_pattern='*_bmask.png',
                          npy_pattern='*.npy'):
    """Pair prediction PNGs with ground-truth NPYs by the number in their names.

    Each file is keyed by ``extract_number`` of its basename (the last run of
    digits). Files without a number are skipped. If several PNGs share a key,
    the one whose path sorts last wins.

    Args:
        png_dir: Directory searched for prediction masks.
        npy_dir: Directory searched for ground-truth masks.
        png_pattern: Glob pattern for prediction files (defaults to the
            original ``*_bmask.png`` naming).
        npy_pattern: Glob pattern for ground-truth files.

    Returns:
        List of ``(png_path, npy_path)`` tuples ordered by NPY path, so that
        repeated runs process the files in the same order.
    """
    # glob returns files in arbitrary, filesystem-dependent order; sorting
    # makes both the pairing of duplicate keys and the output order stable.
    png_by_number = {}
    for png_path in sorted(glob.glob(os.path.join(png_dir, png_pattern))):
        num = extract_number(os.path.basename(png_path))
        if num is not None:
            png_by_number[num] = png_path

    matched_pairs = []
    for npy_path in sorted(glob.glob(os.path.join(npy_dir, npy_pattern))):
        num = extract_number(os.path.basename(npy_path))
        if num is not None and num in png_by_number:
            matched_pairs.append((png_by_number[num], npy_path))

    return matched_pairs

def calculate_iou(mask1, mask2, empty_score=0.0):
    """Return the intersection-over-union of two binary masks.

    Elements are interpreted as booleans (non-zero -> True).

    Args:
        mask1: First mask.
        mask2: Second mask; must have exactly the same shape as ``mask1``.
        empty_score: Value returned when both masks are empty (the union is
            zero). Defaults to 0.0, matching the previous behaviour; pass 1.0
            to count agreement on an empty mask as a perfect match.

    Returns:
        IoU in ``[0, 1]``, or ``empty_score`` when the union is empty.

    Raises:
        ValueError: If the shapes differ. Without this check NumPy would
            broadcast, e.g., a square ``(H, W)`` mask against ``(H, W, 1)``
            and silently score the wrong pixels.
    """
    a = np.asarray(mask1)
    b = np.asarray(mask2)
    if a.shape != b.shape:
        raise ValueError(f"mask shapes differ: {a.shape} vs {b.shape}")

    union = np.logical_or(a, b).sum()
    if union == 0:
        return empty_score

    intersection = np.logical_and(a, b).sum()
    return intersection / union

def visualize_comparison(png_path, npy_path, output_dir, target_shape=None):
    """Render and save a prediction-vs-ground-truth comparison figure.

    Loads the prediction mask from ``png_path`` and the ground-truth mask from
    ``npy_path`` and resizes both to ``target_shape`` if needed (defaults to
    the ground-truth shape). It then saves ``<output_dir>/comparison_<number>.png``
    showing both masks, the source image and a colour-coded overlay:

    * yellow (red + green): true positive
    * red: false positive (prediction only)
    * green: false negative (ground truth only)

    The source image is looked up beside ``npy_path`` with a ``.png``
    extension. If it does not exist, a placeholder panel is drawn and a
    warning is printed instead of aborting.

    Args:
        png_path: Prediction mask image.
        npy_path: Ground-truth mask array.
        output_dir: Existing directory the figure is written to.
        target_shape: Optional ``(height, width)`` used for the comparison.

    Returns:
        Dict with ``png_file``, ``npy_file``, ``number``, ``iou`` and the
        pixel counts ``true_positive``, ``false_positive``, ``false_negative``.
    """
    print(f"Processing: {os.path.basename(png_path)} and {os.path.basename(npy_path)}")

    pred_mask = load_png_mask(png_path)
    gt_mask = load_npy_mask(npy_path)

    # Print original shapes for debugging
    print(f"Original shapes - Prediction: {pred_mask.shape}, Ground Truth: {gt_mask.shape}")

    # Default to the ground-truth resolution. Normalise to a tuple so that a
    # list argument still compares equal to ndarray.shape below.
    target_shape = tuple(gt_mask.shape if target_shape is None else target_shape)

    if pred_mask.shape != target_shape:
        pred_mask = resize_mask(pred_mask, target_shape)
    if gt_mask.shape != target_shape:
        gt_mask = resize_mask(gt_mask, target_shape)

    iou = calculate_iou(pred_mask, gt_mask)

    # Confusion regions, computed once and reused for the pixel counts.
    overlap = np.logical_and(pred_mask, gt_mask)
    pred_only = np.logical_and(pred_mask, np.logical_not(gt_mask))
    gt_only = np.logical_and(np.logical_not(pred_mask), gt_mask)

    true_pos = np.sum(overlap)
    false_pos = np.sum(pred_only)
    false_neg = np.sum(gt_only)

    # Red channel = prediction, green channel = ground truth. Where both are
    # set the pixel is yellow (TP); red-only is FP and green-only is FN.
    viz = np.zeros((*target_shape, 3), dtype=np.uint8)
    viz[..., 0] = np.where(pred_mask, 255, 0)
    viz[..., 1] = np.where(gt_mask, 255, 0)

    number = extract_number(os.path.basename(png_path)) or extract_number(os.path.basename(npy_path))
    output_path = os.path.join(output_dir, f"comparison_{number}.png")

    fig = plt.figure(figsize=(12, 10))
    try:
        plt.subplot(2, 3, 1)
        plt.imshow(pred_mask, cmap='gray')
        plt.title('Prediction')
        plt.axis('off')

        plt.subplot(2, 3, 2)
        plt.imshow(gt_mask, cmap='gray')
        plt.title('Ground Truth')
        plt.axis('off')

        plt.subplot(2, 3, 3)
        # Swap only the extension; str.replace('.npy', '.png') would also
        # rewrite any '.npy' appearing earlier in the directory path.
        image_path = os.path.splitext(npy_path)[0] + '.png'
        if os.path.isfile(image_path):
            with Image.open(image_path) as original:
                plt.imshow(original)
        else:
            print(f"Warning: original image not found: {image_path}")
            plt.text(0.5, 0.5, 'Image not found', ha='center', va='center',
                     transform=plt.gca().transAxes)
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(2, 1, 2)
        plt.imshow(viz)
        plt.title(f'Comparison (IoU: {iou:.4f})\n'
                  f'Yellow: True Positive ({true_pos} px)\n'
                  f'Red: False Positive ({false_pos} px)\n'
                  f'Green: False Negative ({false_neg} px)')
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(output_path, dpi=100)
    finally:
        # Always release the figure so an error mid-batch does not leak figures.
        plt.close(fig)

    print(f"Saved comparison to {output_path}")

    return {
        'png_file': os.path.basename(png_path),
        'npy_file': os.path.basename(npy_path),
        'number': number,
        'iou': iou,
        'true_positive': true_pos,
        'false_positive': false_pos,
        'false_negative': false_neg
    }

def batch_visualize(png_dir, npy_dir, output_dir, target_shape=None):
    """Compare every matched prediction/ground-truth pair and write a summary.

    Pairs come from ``match_files_by_number``. Each pair is rendered by
    ``visualize_comparison`` into ``output_dir``, which is created if
    missing. If at least one pair was processed, the mean IoU is printed and
    ``summary_report.txt`` is written to ``output_dir``, listing per-image IoU
    from best to worst.

    Args:
        png_dir: Directory of prediction masks.
        npy_dir: Directory of ground-truth masks.
        output_dir: Destination for comparison figures and the report.
        target_shape: Optional ``(height, width)`` forwarded to
            ``visualize_comparison``.

    Returns:
        List of per-pair result dicts, in processing order.
    """
    os.makedirs(output_dir, exist_ok=True)

    matched_pairs = match_files_by_number(png_dir, npy_dir)
    print(f"Found {len(matched_pairs)} matching file pairs")

    results = [
        visualize_comparison(png_path, npy_path, output_dir, target_shape)
        for png_path, npy_path in matched_pairs
    ]

    if not results:
        return results

    avg_iou = sum(r['iou'] for r in results) / len(results)
    print(f"\nProcessed {len(results)} images")
    print(f"Average IoU: {avg_iou:.4f}")

    report_path = os.path.join(output_dir, "summary_report.txt")
    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) cannot
    # encode every filename that may appear in the report.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("Summary Report\n")
        f.write("=============\n\n")
        f.write(f"Total images processed: {len(results)}\n")
        f.write(f"Average IoU: {avg_iou:.4f}\n\n")
        f.write("Individual Results:\n")
        f.write(f"{'Number':<10} {'IoU':<10} {'PNG File':<30} {'NPY File':<30}\n")
        f.write(f"{'-'*80}\n")

        # Best IoU first.
        for r in sorted(results, key=lambda x: x['iou'], reverse=True):
            # !s keeps the padding valid even if 'number' were ever None.
            f.write(f"{r['number']!s:<10} {r['iou']:<10.4f} {r['png_file']:<30} {r['npy_file']:<30}\n")

    print(f"Summary report saved to {report_path}")
    return results

# Example usage: run the whole comparison when executed as a script/cell.
if __name__ == "__main__":
    # Input/output locations (adjust to your layout).
    png_masks_dir = "results"  # Prediction masks (352x352)
    npy_masks_dir = "datasets/dataset/test"  # Ground truth masks (512x512)
    output_dir = "comparison_results"

    # Compare at the ground-truth resolution (512x512).
    target_shape = (512, 512)

    results = batch_visualize(
        png_dir=png_masks_dir,
        npy_dir=npy_masks_dir,
        output_dir=output_dir,
        target_shape=target_shape,
    )

Found 681 matching file pairs
Processing: ISIC_0024319_bmask.png and ISIC_0024319.npy
Original shapes - Prediction: (352, 352), Ground Truth: (512, 512)
Saved comparison to comparison_results\comparison_0024319.png
Processing: ISIC_0024320_bmask.png and ISIC_0024320.npy
Original shapes - Prediction: (352, 352), Ground Truth: (512, 512)
Saved comparison to comparison_results\comparison_0024320.png
Processing: ISIC_0024321_bmask.png and ISIC_0024321.npy
Original shapes - Prediction: (352, 352), Ground Truth: (512, 512)
Saved comparison to comparison_results\comparison_0024321.png
Processing: ISIC_0024322_bmask.png and ISIC_0024322.npy
Original shapes - Prediction: (352, 352), Ground Truth: (512, 512)
Saved comparison to comparison_results\comparison_0024322.png
Processing: ISIC_0024323_bmask.png and ISIC_0024323.npy
Original shapes - Prediction: (352, 352), Ground Truth: (512, 512)
Saved comparison to comparison_results\comparison_0024323.png
Processing: ISIC_0024324_bmask.png and ISIC_0