Image segmentation models divide an image into regions that correspond to different objects or areas of interest. They do this by assigning a label to every pixel. There are three main types of segmentation: semantic segmentation, instance segmentation, and panoptic segmentation.

In this guide, we will:

1. Look at the different types of segmentation.
2. Walk through an end-to-end example of fine-tuning a model for semantic segmentation.

# Libraries

In [None]:
%pip install -q datasets transformers evaluate accelerate

In [None]:

import json
import torch
import requests
import evaluate
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from datasets import load_dataset
from huggingface_hub import cached_download, hf_hub_url
from torchvision.transforms import ColorJitter
from transformers import pipeline, AutoImageProcessor, AutoModelForSemanticSegmentation, TrainingArguments, Trainer


# Types of Segmentation

## Semantic Segmentation

Semantic segmentation assigns a label, or class, to every pixel in an image. A semantic segmentation model assigns the same class to every instance of an object it finds in an image: for example, all cats are labeled “cat” rather than “cat-1”, “cat-2”. We can use the Transformers image segmentation pipeline to quickly run inference with a semantic segmentation model. Let’s take a look at the example image.

The model we will use is NVIDIA's SegFormer, nvidia/segformer-b1-finetuned-cityscapes-1024-1024.

In [None]:
# Get the image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/segmentation_input.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image

In [None]:
# Get segmentation pipeline output results
semantic_segmentation = pipeline("image-segmentation", "nvidia/segformer-b1-finetuned-cityscapes-1024-1024")
results = semantic_segmentation(image)
results

In [None]:
# Look at the mask for the building class: all buildings share a single mask, because semantic segmentation does not separate instances
labels = [seg_dict['label'] for seg_dict in results]
required_label = 'building'
results[labels.index(required_label)]["mask"]
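Because every pixel gets exactly one class, the per-class masks returned by the pipeline can be merged back into a single label map. The sketch below assumes each result is a dict with "label" and "mask" keys, where the mask is a PIL image (or array) that is nonzero on that class's pixels; `masks_to_label_map` is a hypothetical helper, not part of Transformers.

```python
import numpy as np

def masks_to_label_map(results):
    """Merge per-class binary masks into one (H, W) map of indices into the returned label list.

    Assumes each entry of `results` has a "label" and a "mask" that is nonzero on the
    pixels of that class (PIL images are converted with np.asarray).
    """
    labels = [r["label"] for r in results]
    masks = [np.asarray(r["mask"]) > 0 for r in results]
    # -1 marks pixels that no mask covers
    label_map = np.full(masks[0].shape, -1, dtype=np.int64)
    for idx, mask in enumerate(masks):
        # If masks overlapped, later ones would overwrite earlier ones
        label_map[mask] = idx
    return label_map, labels
```

With the pipeline output above, `label_map, names = masks_to_label_map(results)` gives one class index per pixel, and `names[label_map[y, x]]` recovers the class name at a given location.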

## Instance Segmentation

In instance segmentation, the goal is not to classify every pixel, but to predict a separate mask for every instance of an object in a given image. It works much like object detection, except that instead of a bounding box for every instance, there’s a segmentation mask.

We will use Facebook's facebook/mask2former-swin-large-cityscapes-instance for this.

In [None]:
instance_segmentation = pipeline("image-segmentation", "facebook/mask2former-swin-large-cityscapes-instance")
results = instance_segmentation(image)
results

In [None]:
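# Check out one of the car instances (the order of the results can vary, so select by label)
car_results = [r for r in results if r["label"] == "car"]
car_results[0]["mask"]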
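To make the comparison with object detection concrete, any instance mask can be reduced to the bounding box a detector would predict for the same object. This is a minimal NumPy sketch; `mask_to_bbox` is a hypothetical helper, and it assumes the mask is nonzero exactly on the instance's pixels.

```python
import numpy as np

def mask_to_bbox(mask):
    """Return the inclusive (x_min, y_min, x_max, y_max) box around a mask's nonzero pixels, or None if it is empty."""
    m = np.asarray(mask) > 0
    rows = np.flatnonzero(m.any(axis=1))  # row indices that contain at least one mask pixel
    cols = np.flatnonzero(m.any(axis=0))  # column indices that contain at least one mask pixel
    if rows.size == 0:
        return None
    return int(cols[0]), int(rows[0]), int(cols[-1]), int(rows[-1])
```

The mask carries strictly more information: the box can always be derived from it, but not the other way around.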

## Panoptic Segmentation

Panoptic segmentation combines semantic segmentation and instance segmentation: every pixel is assigned a class, and pixels that belong to countable objects (such as cars or people) are additionally assigned to a specific instance of that class, so each instance gets its own mask. We'll use Facebook's facebook/mask2former-swin-large-cityscapes-panoptic for panoptic segmentation.

In [None]:
panoptic_segmentation = pipeline("image-segmentation", "facebook/mask2former-swin-large-cityscapes-panoptic")
results = panoptic_segmentation(image)

# The results contain more classes than the instance segmentation output, because panoptic
# segmentation also returns masks for uncountable regions such as road, sky, and vegetation
results
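One way to store a panoptic result is a single integer per pixel that packs both the class and the instance. The sketch below uses a `label_divisor` scheme similar to conventions found in some panoptic datasets and codebases; the function names, the divisor of 1000, and the class ids in the example are illustrative assumptions, not what the pipeline returns.

```python
import numpy as np

def encode_panoptic(semantic_map, instance_map, label_divisor=1000):
    """Pack a per-pixel class map and a per-pixel instance map into one panoptic id map.

    Assumes instance_map is 0 for uncountable "stuff" pixels (road, sky, ...) and
    1, 2, ... for individual "thing" instances (car 1, car 2, ...).
    """
    semantic_map = np.asarray(semantic_map, dtype=np.int64)
    instance_map = np.asarray(instance_map, dtype=np.int64)
    return semantic_map * label_divisor + instance_map

def decode_panoptic(panoptic_map, label_divisor=1000):
    """Recover the (class, instance) pair for every pixel."""
    return panoptic_map // label_divisor, panoptic_map % label_divisor
```

A semantic map alone would give both cars below the same value; the instance component is what tells them apart.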

# Fine-tuning a model for Semantic Segmentation

Now that we’ve seen all three types of segmentation, let’s take a deep dive into fine-tuning a model for semantic segmentation. We will:

1. Fine-tune SegFormer on the SceneParse150 dataset.
2. Use the fine-tuned model for inference.

Common real-world applications of semantic segmentation include training self-driving cars to identify pedestrians and important traffic information, identifying cells and abnormalities in medical imagery, and monitoring environmental changes from satellite imagery.

In [None]:
# Load SceneParse150 dataset
ds = load_dataset("scene_parse_150", split="train[:50]") # Load subset first, for experimentation

ds = ds.train_test_split(test_size=0.2)
train_ds = ds["train"]
test_ds = ds["test"]

In [None]:
# Inspect the dataset
# image: a PIL image of the scene.
# annotation: a PIL image of the segmentation map, which is also the model’s target.
# scene_category: a category id that describes the image scene like “kitchen” or “office”.
train_ds[0]

In [None]:
# In this guide, you’ll only need image and annotation, both of which are PIL images.
train_ds[0]["image"]

In [None]:
# Create dictionaries that map label ids to label classes and back
# Download the ADE20K mappings from the Hub and create the id2label and label2id dictionaries
# hf_hub_download replaces cached_download, which was deprecated and removed from recent huggingface_hub releases
from huggingface_hub import hf_hub_download

repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
with open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r") as f:
    id2label = json.load(f)
id2label = {int(k): v for k, v in id2label.items()}  # JSON keys are strings, so convert them to ints
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

## Preprocess

In [None]:
# Load a SegFormer image processor to prepare the images and annotations for the model
# Some datasets, like this one, use the zero-index as the background class
# However, the background class isn’t actually included in the 150 classes...
# so you’ll need to set do_reduce_labels=True to subtract one from all the labels
# The zero-index is replaced by 255 so it’s ignored by SegFormer’s loss function

checkpoint = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
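To see what `do_reduce_labels=True` does to an annotation map, here is a NumPy sketch of the shift described in the comments above: background 0 becomes the ignore value 255 and every other class id drops by one. `reduce_labels` is a hypothetical stand-in written for illustration, not the image processor's own code.

```python
import numpy as np

def reduce_labels(annotation):
    """Shift ADE20K-style labels: background 0 -> 255 (ignored), classes 1..150 -> 0..149."""
    label = np.array(annotation, dtype=np.uint8)  # copy, so the input is left untouched
    label[label == 0] = 255    # mark background with the ignore value first
    label = label - 1          # shift every id down by one (the 255s become 254)
    label[label == 254] = 255  # restore the ignore value for background pixels
    return label
```

After this shift the 150 real classes line up with `id2label` indices 0 to 149, and background pixels carry 255, which the loss and the `ignore_index=255` argument in the metric below skip.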

In [None]:
# It is common to apply some data augmentations to an image dataset to make a model more robust against overfitting
# In this guide, you’ll use the ColorJitter function from torchvision to randomly change the color properties of an image

jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

In [None]:
# We now need to create two preprocessing functions to prepare the images and annotations for the model
# These functions convert the images into pixel_values and annotations to labels

# For the training set, jitter is applied before providing the images to the image processor
def train_transforms(example_batch):
    images = [jitter(x) for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs

# For the test set, the image processor resizes and normalizes the images, and resizes the labels accordingly
# NB: No data augmentation is applied during evaluation
def val_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs

In [None]:
# Apply the transforms to the datasets with the Datasets set_transform function
# Transforms run on the fly whenever examples are accessed, so nothing is preprocessed up front or written to disk
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

## Evaluation

In [None]:
# Load the mean Intersection over Union (IoU) metric 
metric = evaluate.load("mean_iou")

In [None]:
# Create a function to compute the metrics
# The predictions are logits at a lower resolution than the labels, so upsample them to the label size
# and take the argmax over the class dimension to get a predicted class id per pixel
def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = torch.nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )
        for key, value in metrics.items():
            if isinstance(value, np.ndarray):
                metrics[key] = value.tolist()
        return metrics
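For intuition about what `mean_iou` reports, this simplified NumPy sketch computes it for a single prediction/label pair: per class, IoU = intersection / union over non-ignored pixels, then the mean over classes present in either map. It is an illustration under those assumptions; the `evaluate` metric additionally accumulates intersections and unions over the whole evaluation set and returns per-class values and accuracies.

```python
import numpy as np

def mean_iou(pred, target, num_labels, ignore_index=255):
    """Mean intersection-over-union for one prediction/label pair, skipping ignore_index pixels."""
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    valid = target != ignore_index  # drop pixels whose label is the ignore value
    pred, target = pred[valid], target[valid]
    ious = []
    for c in range(num_labels):
        p, t = pred == c, target == c
        union = np.logical_or(p, t).sum()
        if union == 0:
            continue  # class absent from both maps: leave it out of the mean
        ious.append(np.logical_and(p, t).sum() / union)
    return float(np.mean(ious)) if ious else float("nan")
```

Because each class contributes equally to the mean, a model that ignores a small class is penalized even if overall pixel accuracy is high.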

## Training

In [None]:
# Load SegFormer with AutoModelForSemanticSegmentation
# Pass the mapping between label ids and label classes
model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)

In [None]:
# Final 3 steps: define hyperparameters, instantiate Trainer, then call train on the instantiated object
# It's important you don’t remove unused columns because this’ll drop the image column
training_args = TrainingArguments(
    output_dir="segformer_scene_parse_150_model",
    learning_rate=6e-5,
    num_train_epochs=50,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    remove_unused_columns=False
)

# Pass the training arguments to Trainer along with the model, the datasets, and the compute_metrics function
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

## Inference

In [None]:
# Reload the raw dataset (without the transforms) and load an image for inference
ds = load_dataset("scene_parse_150", split="train[:50]")
ds = ds.train_test_split(test_size=0.2)
test_ds = ds["test"]
image = test_ds[0]["image"]  # the test split has 10 examples, so valid indices are 0-9
image

In [None]:
# Inference without a pipeline

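# Process the image with an image processor
encoding = image_processor(image, return_tensors="pt")

# Place the model and the pixel_values on a GPU if available, else use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # disable dropout for inference
pixel_values = encoding.pixel_values.to(device)

# Pass your input to the model and return the logits (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
logits = outputs.logits.cpu()

# Next, rescale the logits to the original image size
# PIL's image.size is (width, height), so reverse it to get (height, width)
upsampled_logits = torch.nn.functional.interpolate(
    logits,
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)

# Take the most likely class for each pixel and convert to a NumPy array for plotting
pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()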


In [None]:
# To visualize the results, define the ADE20K color palette as ade_palette()
# It maps each class to an RGB value
def ade_palette():
    return np.asarray([
      [0, 0, 0],
      [120, 120, 120],
      [180, 120, 120],
      [6, 230, 230],
      [80, 50, 50],
      [4, 200, 3],
      [120, 120, 80],
      [140, 140, 140],
      [204, 5, 255],
      [230, 230, 230],
      [4, 250, 7],
      [224, 5, 255],
      [235, 255, 7],
      [150, 5, 61],
      [120, 120, 70],
      [8, 255, 51],
      [255, 6, 82],
      [143, 255, 140],
      [204, 255, 4],
      [255, 51, 7],
      [204, 70, 3],
      [0, 102, 200],
      [61, 230, 250],
      [255, 6, 51],
      [11, 102, 255],
      [255, 7, 71],
      [255, 9, 224],
      [9, 7, 230],
      [220, 220, 220],
      [255, 9, 92],
      [112, 9, 255],
      [8, 255, 214],
      [7, 255, 224],
      [255, 184, 6],
      [10, 255, 71],
      [255, 41, 10],
      [7, 255, 255],
      [224, 255, 8],
      [102, 8, 255],
      [255, 61, 6],
      [255, 194, 7],
      [255, 122, 8],
      [0, 255, 20],
      [255, 8, 41],
      [255, 5, 153],
      [6, 51, 255],
      [235, 12, 255],
      [160, 150, 20],
      [0, 163, 255],
      [140, 140, 140],
      [250, 10, 15],
      [20, 255, 0],
      [31, 255, 0],
      [255, 31, 0],
      [255, 224, 0],
      [153, 255, 0],
      [0, 0, 255],
      [255, 71, 0],
      [0, 235, 255],
      [0, 173, 255],
      [31, 0, 255],
      [11, 200, 200],
      [255, 82, 0],
      [0, 255, 245],
      [0, 61, 255],
      [0, 255, 112],
      [0, 255, 133],
      [255, 0, 0],
      [255, 163, 0],
      [255, 102, 0],
      [194, 255, 0],
      [0, 143, 255],
      [51, 255, 0],
      [0, 82, 255],
      [0, 255, 41],
      [0, 255, 173],
      [10, 0, 255],
      [173, 255, 0],
      [0, 255, 153],
      [255, 92, 0],
      [255, 0, 255],
      [255, 0, 245],
      [255, 0, 102],
      [255, 173, 0],
      [255, 0, 20],
      [255, 184, 184],
      [0, 31, 255],
      [0, 255, 61],
      [0, 71, 255],
      [255, 0, 204],
      [0, 255, 194],
      [0, 255, 82],
      [0, 10, 255],
      [0, 112, 255],
      [51, 0, 255],
      [0, 194, 255],
      [0, 122, 255],
      [0, 255, 163],
      [255, 153, 0],
      [0, 255, 10],
      [255, 112, 0],
      [143, 255, 0],
      [82, 0, 255],
      [163, 255, 0],
      [255, 235, 0],
      [8, 184, 170],
      [133, 0, 255],
      [0, 255, 92],
      [184, 0, 255],
      [255, 0, 31],
      [0, 184, 255],
      [0, 214, 255],
      [255, 0, 112],
      [92, 255, 0],
      [0, 224, 255],
      [112, 224, 255],
      [70, 184, 160],
      [163, 0, 255],
      [153, 0, 255],
      [71, 255, 0],
      [255, 0, 163],
      [255, 204, 0],
      [255, 0, 143],
      [0, 255, 235],
      [133, 255, 0],
      [255, 0, 235],
      [245, 0, 255],
      [255, 0, 122],
      [255, 245, 0],
      [10, 190, 212],
      [214, 255, 0],
      [0, 204, 255],
      [20, 0, 255],
      [255, 255, 0],
      [0, 153, 255],
      [0, 41, 255],
      [0, 255, 204],
      [41, 0, 255],
      [41, 255, 0],
      [173, 0, 255],
      [0, 245, 255],
      [71, 0, 255],
      [122, 0, 255],
      [0, 255, 184],
      [0, 92, 255],
      [184, 255, 0],
      [0, 133, 255],
      [255, 214, 0],
      [25, 194, 194],
      [102, 255, 0],
      [92, 0, 255],
  ])

In [None]:
# Then combine and plot image + predicted segmentation map
color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
palette = np.array(ade_palette())
for label, color in enumerate(palette):
    color_seg[pred_seg == label, :] = color
color_seg = color_seg[..., ::-1]  # convert to BGR

img = np.array(image) * 0.5 + color_seg * 0.5  # plot the image with the segmentation map
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()
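The per-class loop above can also be written as a single NumPy indexing step: indexing the palette array with the (H, W) array of class ids yields an (H, W, 3) color image directly. This sketch keeps the colors in RGB order, which is what `plt.imshow` expects; `colorize` and `overlay` are hypothetical helper names.

```python
import numpy as np

def colorize(pred_seg, palette):
    """Map an (H, W) array of class ids to an (H, W, 3) RGB image using one fancy-indexing step."""
    palette = np.asarray(palette, dtype=np.uint8)
    return palette[np.asarray(pred_seg)]

def overlay(image, color_seg, alpha=0.5):
    """Alpha-blend a color map over an RGB image with the same height and width."""
    blended = np.asarray(image, dtype=np.float32) * (1 - alpha) + color_seg.astype(np.float32) * alpha
    return blended.astype(np.uint8)  # truncates toward zero when casting back to uint8
```

With the variables above, `plt.imshow(overlay(image, colorize(pred_seg, ade_palette())))` produces the same kind of overlay as the loop, assuming every predicted id is smaller than the palette length.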

# Fine-tuning on a custom dataset


In [None]:
# You can also create and use your own dataset
# Instead of a notebook, you can train with the run_semantic_segmentation.py script from the Transformers examples
# The script requires two things:
# 1. a DatasetDict with two Image columns, “image” and “label”, and
# 2. an id2label dictionary mapping the class integers to their class names

In [None]:

# Example DatasetDict with two Image columns, “image” and “label”
# Import the Image feature under another name so it doesn't shadow PIL's Image imported earlier
from datasets import Dataset, DatasetDict, Image as ImageFeature

# Placeholder paths: replace these lists with your own files
image_paths_train = ["path/to/image_1.jpg", "path/to/image_2.jpg", ..., "path/to/image_n.jpg"]
label_paths_train = ["path/to/annotation_1.png", "path/to/annotation_2.png", ..., "path/to/annotation_n.png"]

image_paths_validation = [...]
label_paths_validation = [...]

def create_dataset(image_paths, label_paths):
    # Sorting both lists pairs images and annotations by position,
    # so this assumes their file names sort into the same order
    dataset = Dataset.from_dict({"image": sorted(image_paths),
                                 "label": sorted(label_paths)})
    dataset = dataset.cast_column("image", ImageFeature())
    dataset = dataset.cast_column("label", ImageFeature())
    return dataset

# Step 1: create Dataset objects
train_dataset = create_dataset(image_paths_train, label_paths_train)
validation_dataset = create_dataset(image_paths_validation, label_paths_validation)

# Step 2: create a DatasetDict
dataset = DatasetDict({
    "train": train_dataset,
    "validation": validation_dataset,
})

# Step 3: push to the Hub (assumes you have run the huggingface-cli login command in a terminal/notebook)
dataset.push_to_hub("your-name/dataset-repo")

# Optionally, you can push to a private repo on the Hub
# dataset.push_to_hub("name of repo on the hub", private=True)
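Pairing images and annotations by sorted position breaks silently if one file is missing or the names sort differently. If your images and annotations share a base name (for example `scene_001.jpg` and `scene_001.png`; this naming is an assumption, unlike the `image_n`/`annotation_n` placeholders above), a stricter option is to match them by file stem. `pair_by_stem` is a hypothetical helper.

```python
from pathlib import Path

def pair_by_stem(image_paths, label_paths):
    """Pair each image with the annotation that has the same file stem (name without extension)."""
    labels_by_stem = {Path(p).stem: p for p in label_paths}
    pairs = []
    for img in sorted(image_paths):
        stem = Path(img).stem
        if stem not in labels_by_stem:
            # Fail loudly instead of silently shifting every later pair
            raise ValueError(f"No annotation found for {img}")
        pairs.append((img, labels_by_stem[stem]))
    return pairs
```

The resulting pairs can be unzipped with `images, labels = zip(*pairs)` and passed to `Dataset.from_dict` as lists, without sorting them again.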

In [None]:

# Example id2label dictionary mapping the class integers to their class names
import json

# Simple example
id2label = {0: 'cat', 1: 'dog'}
with open('id2label.json', 'w') as fp:
    json.dump(id2label, fp)
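JSON object keys are always strings, so integer class ids do not survive a save/load round trip unchanged. That is why the ADE20K mapping earlier in this notebook is converted with `int(k)` after loading. This self-contained sketch writes to a temporary directory instead of the working directory.

```python
import json
import os
import tempfile

id2label = {0: "cat", 1: "dog"}
path = os.path.join(tempfile.mkdtemp(), "id2label.json")
with open(path, "w") as fp:
    json.dump(id2label, fp)

with open(path) as fp:
    loaded = json.load(fp)  # keys come back as strings: {"0": "cat", "1": "dog"}

# Convert the keys back to ints and build the reverse mapping
id2label_restored = {int(k): v for k, v in loaded.items()}
label2id = {v: k for k, v in id2label_restored.items()}
```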