We will get image segmentations for each image in our dataset

In [13]:
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import torch
import os

In [15]:
image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-cityscapes-panoptic")
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-small-cityscapes-panoptic"
)

def segment(image):
    inputs = image_processor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Model predicts class_queries_logits of shape `(batch_size, num_queries)`
    # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
    class_queries_logits = outputs.class_queries_logits
    masks_queries_logits = outputs.masks_queries_logits

    # Perform post-processing to get panoptic segmentation map
    pred_panoptic_map = image_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]["segmentation"]
    return pred_panoptic_map

In [14]:
import matplotlib.pyplot as plt
import numpy as np

def save_segmentation_image(image, pred_panoptic_map, save_path):
    # Create a color map
    unique_labels = np.unique(pred_panoptic_map)
    num_labels = len(unique_labels)
    colormap = plt.get_cmap('tab20', num_labels)

    # Create a colored segmentation map
    seg_image = colormap(pred_panoptic_map / num_labels)

    # Initialize plot
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)
    ax.imshow(seg_image, alpha=0.7)  # Set transparency to 70%

    # Calculate centroids and put labels
    for label in unique_labels:
        # Find pixels with each label and calculate their mean position
        positions = np.where(pred_panoptic_map == label)
        centroid_x, centroid_y = np.mean(positions[1]), np.mean(positions[0])
        ax.text(centroid_x, centroid_y, str(label), color='white', fontsize=12, ha='center', va='center')

    ax.axis('off')

    # Save the figure
    # ensure the directory exists
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, format='jpg', bbox_inches='tight')
    plt.close(fig)

In [8]:
def load_image(file_path):
    try:
        # Open the image file
        image = Image.open(file_path)
        return image
    except IOError:
        print("Error: Unable to load image. Please check the file path.")
        return None

In [50]:
#iterate over all images in "datasets/images" and segment them, store in "datasets/segmented" with filename "segmented_{filename}"
for filename in os.listdir("datasets/images"):
    new_filename = f"segmented_{filename}"
    #check if already segmented
    if os.path.isfile(f"datasets/segmented/{new_filename}"):
        continue
    image = load_image(f"datasets/images/{filename}")
    if filename == '.DS_Store' or filename == 'map.json':
        continue
    pred_panoptic_map = segment(image)
    save_segmentation_image(image, pred_panoptic_map, f"datasets/segmented/segmented_{filename}")

Error: Unable to load image. Please check the file path.
.DS_Store
Error: Unable to load image. Please check the file path.
map.json
