# Panoptic Segmentation

In panoptic segmentation, the image is segmented into known objects (things) as well as amorphous regions (stuff).

We follow these steps:

1. Load a pre-trained model
1. Test the model on one image
1. Show the segmentation result
1. Use Detectron2 utils to show more information about the results

Based on: https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/DETR_panoptic.ipynb

In [None]:
! pip install git+https://github.com/cocodataset/panopticapi.git

In [1]:
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import io
import math
import torch
import panopticapi
import requests
import torchvision.transforms as T

ModuleNotFoundError: No module named 'panopticapi'

### Load a pre-trained model

In this case we will be using the DETR pre-trained model. It is based on a transformer architecture (https://github.com/facebookresearch/detr).

In [None]:
# Load the `detr_resnet101_panoptic` entry point from torch.hub; asking for
# `return_postprocessor=True` makes the call return the post-processor too.
model, postprocessor = torch.hub.load(
    'facebookresearch/detr',
    'detr_resnet101_panoptic',
    pretrained=True,
    return_postprocessor=True,
    num_classes=250,
)
# Put the model in evaluation mode for inference.
model.eval()

### Test the model on one image

In [None]:
#url = "http://images.cocodataset.org/val2017/000000281759.jpg"
#url = "http://images.cocodataset.org/val2017/000000289222.jpg"
#url = "http://images.cocodataset.org/val2017/000000439715.jpg"
url = "http://images.cocodataset.org/val2017/000000324158.jpg"

# Use a timeout so an unresponsive server cannot hang the notebook, and fail
# with a clear HTTP error instead of an opaque PIL decoding error when the
# download does not succeed.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Force 3-channel RGB: the normalization applied later uses three per-channel
# means/stds, so grayscale or RGBA inputs would otherwise fail there.
im = Image.open(response.raw).convert("RGB")
im

In [None]:
# Preprocessing pipeline: resize, convert to a tensor, then apply the
# standard PyTorch mean-std input image normalization.
transform = T.Compose(
    [
        T.Resize(800),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Preprocess the image and add a leading batch dimension (batch-size: 1).
img = transform(im)[None]
# Run the forward pass with gradient tracking disabled.
with torch.no_grad():
    out = model(img)

In [None]:
# Inspect the model output: first the available keys, then the shape of
# each prediction tensor.
print(out.keys())
for key in ('pred_logits', 'pred_boxes', 'pred_masks'):
    print(out[key].shape)

In [None]:
# Class probabilities per query; the last class is "no-object" and is dropped
# before taking the best remaining score.
class_probs = out['pred_logits'].softmax(dim=-1)
scores = class_probs[..., :-1].max(dim=-1).values
# Keep only the confident predictions.
keep = scores > 0.85

In [None]:
# plot the masks
ncols = 5
num_kept = int(keep.sum().item())
# Always request at least one row, and pass squeeze=False so that `axs` is a
# 2-D array even when a single row is enough; otherwise the 2-D indexing below
# fails whenever `ncols` or fewer masks pass the threshold.
nrows = max(1, math.ceil(num_kept / ncols))
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(18, 10), squeeze=False)
# Hide the axes of every cell, including the ones that stay empty.
for a in axs.flat:
    a.axis('off')
for i, mask in enumerate(out["pred_masks"][keep]):
    ax = axs[i // ncols, i % ncols]
    ax.imshow(mask.detach().numpy(), cmap="cividis")
    ax.axis('off')
fig.tight_layout()

### Show the segmentation result

In [None]:
# The post-processor takes the target size of the predictions as a second
# argument; here it is the (height, width) of the model input, with a leading
# batch dimension added.
target_sizes = torch.as_tensor(img.shape[-2:]).unsqueeze(0)
# Only one image was processed, so take the first (and only) result.
result = postprocessor(out, target_sizes)[0]

In [None]:
import itertools
import seaborn as sns
from panopticapi.utils import id2rgb, rgb2id
# Endless color supply: the palette is reused once it runs out of colors.
palette = itertools.cycle(sns.color_palette())

# The segmentation is stored in a special-format png
panoptic_seg = Image.open(io.BytesIO(result['png_string']))
# np.array already returns a fresh, writable array, so no extra copy is needed
panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
# We retrieve the ids corresponding to each mask
panoptic_seg_id = rgb2id(panoptic_seg)

# Finally we color each mask individually
panoptic_seg[:, :, :] = 0
# `segment_id` rather than `id`, so the builtin `id` is not overwritten for
# the rest of the notebook.
for segment_id in range(panoptic_seg_id.max() + 1):
    panoptic_seg[panoptic_seg_id == segment_id] = np.asarray(next(palette)) * 255
plt.figure(figsize=(15,15))
plt.imshow(panoptic_seg)
plt.axis('off')
plt.show()

### Use Detectron2 utils to show more information

The function `createResultsImage` overlays the predicted segments on the original image using Detectron2's `Visualizer`.

In [None]:
!pip install 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
# function to create an image with the segments overlaid on the image
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from copy import deepcopy
def createResultsImage(im, result):
    """Draw DETR's panoptic prediction on top of the given image.

    Args:
        im: PIL image to draw on; it is resized to the size of the prediction.
        result: One entry of the post-processor output. Its
            ``"segments_info"`` list and ``"png_string"`` bytes are read;
            the argument itself is not modified.

    Returns:
        A PIL image with the segments drawn by Detectron2's ``Visualizer``.
    """
    # Work on a deep copy: the category ids are rewritten below and the
    # caller's `result` must keep its original ids.
    segments_info = deepcopy(result["segments_info"])
    # Panoptic predictions are stored in a special format png
    panoptic_png = Image.open(io.BytesIO(result['png_string']))
    final_w, final_h = panoptic_png.size
    # We convert the png into a segment id map
    panoptic_seg = torch.from_numpy(rgb2id(np.array(panoptic_png, dtype=np.uint8)))

    # Detectron2 uses a different numbering of coco classes, here we convert the class ids accordingly
    meta = MetadataCatalog.get("coco_2017_val_panoptic_separated")
    for segment in segments_info:
        id_map = (
            meta.thing_dataset_id_to_contiguous_id
            if segment["isthing"]
            else meta.stuff_dataset_id_to_contiguous_id
        )
        segment["category_id"] = id_map[segment["category_id"]]

    # Finally we visualize the prediction.
    # The Visualizer takes an RGB image and the array it produces is handed to
    # Image.fromarray, which also assumes RGB, so the channels must NOT be
    # reversed to BGR here (doing so swaps red and blue in the background).
    # convert("RGB") returns a copy and also covers grayscale/RGBA inputs.
    background = np.array(im.convert("RGB").resize((final_w, final_h)))
    v = Visualizer(background, meta, scale=1.0)
    v._default_font_size = 20
    v = v.draw_panoptic_seg_predictions(panoptic_seg, segments_info, area_threshold=0)

    return Image.fromarray(v.get_image())

In [None]:
# Build the overlay image from the downloaded image and the post-processed
# prediction, then display it as the cell output.
extendedResults = createResultsImage(im, result)
extendedResults