# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.color import label2rgb

def plot_segmentation_mask(image, mask, class_names=None):
    """
    Plot an image side by side with its segmentation mask.

    Parameters:
    image (numpy.ndarray): The input image, in any layout ``plt.imshow`` accepts.
    mask (numpy.ndarray): The segmentation mask. When ``class_names`` is given it
        may be either a per-channel mask (channels on the last axis), which is
        reduced to class indices with ``argmax`` over that axis, or a 2-D array
        that already holds non-negative integer class indices. Without
        ``class_names`` the mask is passed to ``imshow`` unchanged.
    class_names (list, optional): A list of class names for each channel in the mask.
        When given, every class index is drawn with a fixed color (index 0 in
        black, as background) and a matching legend is added.
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    ax[0].imshow(image)
    ax[0].set_title("Original Image")

    if class_names is None:
        ax[1].imshow(mask)
    else:
        if mask.ndim == 2:
            # Already a map of class indices; argmax over the last axis would
            # wrongly collapse the image width.
            labeled_mask = np.asarray(mask).astype(int)
        else:
            labeled_mask = np.argmax(mask, axis=-1)

        # Size the palette to cover both the named classes and any larger index
        # that actually occurs, so palette[labeled_mask] cannot go out of range.
        n_colors = max(len(class_names), int(labeled_mask.max(initial=0)) + 1)

        # Fixed palette indexed by class id. skimage's label2rgb cycles its colors
        # over the labels present in each image, so its colors would not line up
        # with a legend built from class indices.
        palette = plt.get_cmap("tab20")(np.arange(n_colors) % 20)[:, :3]
        palette[0] = 0.0  # class 0 is background: draw it black (as bg_label=0 did)
        ax[1].imshow(palette[labeled_mask])

        # Legend swatches use exactly the colors drawn in the mask.
        handles = [
            plt.Rectangle((0, 0), 1, 1, color=palette[i])
            for i in range(len(class_names))
        ]
        ax[1].legend(handles, class_names, loc="lower right")

    ax[1].set_title("Segmentation Mask")

    plt.show()
    # Release the figure so repeated calls (e.g. in a loop over a dataset) do not
    # accumulate open figures.
    plt.close(fig)

def analyze_segmentation_masks(image_dir, mask_dir, class_names=None):
    """
    Plot every image/mask pair found in two directories and print a short summary.

    Only ``.npy`` files are considered. Images and masks are paired by sorted
    filename: the i-th image (in sorted order) is matched with the i-th mask.
    The two directories should therefore use parallel naming
    (e.g. ``image_0.npy`` / ``mask_0.npy``).

    Parameters:
    image_dir (str): The directory containing the original images.
    mask_dir (str): The directory containing the segmentation masks.
    class_names (list, optional): A list of class names for each channel in the mask.

    Raises:
    ValueError: If the two directories contain different numbers of ``.npy`` files.
    """
    # os.listdir returns entries in arbitrary order; sort both lists so that
    # image i is deterministically paired with mask i.
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(".npy"))
    mask_files = sorted(f for f in os.listdir(mask_dir) if f.endswith(".npy"))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(image_files) != len(mask_files):
        raise ValueError(
            f"Found {len(image_files)} .npy image files in {image_dir!r} but "
            f"{len(mask_files)} .npy mask files in {mask_dir!r}"
        )

    for i, (image_file, mask_file) in enumerate(zip(image_files, mask_files)):
        image = np.load(os.path.join(image_dir, image_file))
        mask = np.load(os.path.join(mask_dir, mask_file))

        plot_segmentation_mask(image, mask, class_names)

        # TODO: further analysis (size distribution, class distribution of objects).
        # NOTE(review): this reduces over the first two (presumably spatial) axes
        # and counts mask channels containing at least one positive pixel, i.e.
        # classes present in a channel-last mask — not connected-component objects.
        num_objects = np.sum(np.any(mask > 0, axis=(0, 1)))
        print(f"Image {i}: {num_objects} objects detected")

# Example usage: set output_dir to the directory that contains
# coco/segmentation/images and coco/segmentation/masks.
output_dir = "/path/to/output/directory"
image_dir = os.path.join(output_dir, "coco", "segmentation", "images")
mask_dir = os.path.join(output_dir, "coco", "segmentation", "masks")

class_names = ["Background", "Person", "Vehicle", "Animal", "Accessory", "Indoor"]

# Only run the analysis when executed as a script or notebook cell (where
# __name__ == "__main__"), not when this module is imported; the placeholder
# path above would otherwise make every import fail.
if __name__ == "__main__":
    analyze_segmentation_masks(image_dir, mask_dir, class_names)