# 🎗️ **OneFormer: One Transformer to Rule Universal Image Segmentation** 🎗️

## CVPR 2023

[[`Project Page`](https://praeclarumjj3.github.io/oneformer/)] [[`arXiv`](https://arxiv.org/abs/2211.06220)] [[`GitHub`](https://github.com/SHI-Labs/OneFormer)] [[`HuggingFace Space`](https://huggingface.co/spaces/shi-labs/OneFormer)] [[`HuggingFace transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/oneformer)]

![Teaser](https://praeclarumjj3.github.io/oneformer/teaser.svg)

#### OneFormer is the **first** multi-task universal image segmentation framework based on transformers. OneFormer needs to be trained only once, with a single universal architecture, a single model, and a single dataset, to outperform existing frameworks across semantic, instance, and panoptic segmentation tasks!



This notebook provides a quickstart guide to using OneFormer for inference on images. We hope OneFormer inspires more research into developing train-once universal image segmentation frameworks. ✌

#### If you find OneFormer useful in your research, please consider starring ⭐ us on [[`GitHub`](https://github.com/SHI-Labs/OneFormer)] and citing 📚 our paper!

# Set Up the OneFormer Project

In [None]:
######
#@title 1. Clone OneFormer Repo
######
%cd /content/
!rm -rf OneFormer/
!git clone https://github.com/SHI-Labs/OneFormer-Colab.git
!mv OneFormer-Colab OneFormer
%cd /content/OneFormer/

In [None]:
######
#@title 2. Install Dependencies. 
#@markdown It may take several minutes for all installations to finish.
######

# Install OpenCV (required for running the demo)
!pip3 install -U opencv-python --quiet
# Install NATTEN (neighborhood attention, used by the DiNAT backbone)
!pip3 install natten -f https://shi-labs.com/natten/wheels/cu113/torch1.10.1/index.html --quiet

# Install other dependencies
!pip3 install git+https://github.com/cocodataset/panopticapi.git --quiet
!pip3 install git+https://github.com/mcordts/cityscapesScripts.git --quiet

!pip3 install -r requirements.txt --quiet
!pip3 install ipython-autotime --quiet
!pip3 install imutils --quiet

In [None]:
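import sys, os, distutils.core
# Clone detectron2 and install only its Python dependencies (no build step).
# Note: `distutils` was removed in Python 3.12, so this cell assumes an older runtime.
!git clone 'https://github.com/facebookresearch/detectron2'
dist = distutils.core.run_setup("./detectron2/setup.py")
!python -m pip install {' '.join([f"'{x}'" for x in dist.install_requires])} --quiet
# Make the cloned source tree importable.
sys.path.insert(0, os.path.abspath('./detectron2'))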

In [None]:
######
#@title 3. Import Libraries and other Utilities
######
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
setup_logger(name="oneformer")

# Import libraries
import numpy as np
import cv2
import torch
from google.colab.patches import cv2_imshow
import imutils

# Import detectron2 utilities
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data import MetadataCatalog
from demo.defaults import DefaultPredictor
from demo.visualizer import Visualizer, ColorMode


# import OneFormer Project
from oneformer import (
    add_oneformer_config,
    add_common_config,
    add_swin_config,
    add_dinat_config,
    add_convnext_config,
)

In [None]:
######
#@title 4. Define helper functions
######
cpu_device = torch.device("cpu")
SWIN_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml",
            "coco": "configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml",
            "ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",}

DINAT_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml",
            "coco": "configs/coco/oneformer_dinat_large_bs16_100ep.yaml",
            "ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",}

def setup_cfg(dataset, model_path, use_swin):
    # Register every config extension, then load the dataset-specific YAML file
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_common_config(cfg)
    add_swin_config(cfg)
    add_dinat_config(cfg)
    add_convnext_config(cfg)
    add_oneformer_config(cfg)
    if use_swin:
        cfg_path = SWIN_CFG_DICT[dataset]
    else:
        cfg_path = DINAT_CFG_DICT[dataset]
    cfg.merge_from_file(cfg_path)
    cfg.MODEL.DEVICE = 'cpu'
    cfg.MODEL.WEIGHTS = model_path
    cfg.freeze()
    return cfg

def setup_modules(dataset, model_path, use_swin):
    cfg = setup_cfg(dataset, model_path, use_swin)
    predictor = DefaultPredictor(cfg)
    metadata = MetadataCatalog.get(
        cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
    )
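    # Guard against an empty TEST_PANOPTIC tuple before indexing into it
    if len(cfg.DATASETS.TEST_PANOPTIC) and 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST_PANOPTIC[0]: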
        from cityscapesscripts.helpers.labels import labels
        stuff_colors = [k.color for k in labels if k.trainId != 255]
        metadata = metadata.set(stuff_colors=stuff_colors)
    
    return predictor, metadata

def panoptic_run(img, predictor, metadata):
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    predictions = predictor(img, "panoptic")
    panoptic_seg, segments_info = predictions["panoptic_seg"]
    out = visualizer.draw_panoptic_seg_predictions(
        panoptic_seg.to(cpu_device), segments_info, alpha=0.5
    )
    return out

def instance_run(img, predictor, metadata):
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    predictions = predictor(img, "instance")
    instances = predictions["instances"].to(cpu_device)
    out = visualizer.draw_instance_predictions(predictions=instances, alpha=0.5)
    return out

def semantic_run(img, predictor, metadata):
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    predictions = predictor(img, "semantic")
    out = visualizer.draw_sem_seg(
        predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
    )
    return out

TASK_INFER = {"panoptic": panoptic_run, 
              "instance": instance_run, 
              "semantic": semantic_run}
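
Each helper above passes `img[:, :, ::-1]` to `Visualizer`. OpenCV loads images in BGR channel order, while detectron2-style visualizers expect RGB, so the slice reverses the last axis. The inference cells apply the same slice to the rendered output to get back to BGR before calling `cv2_imshow`. A minimal NumPy sketch of that conversion, using a made-up 1x2 image rather than a real photo:

```python
import numpy as np

# Hypothetical 1x2 image in OpenCV's BGR order: one pure-blue and one pure-red pixel.
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the last axis swaps the first and third channels (BGR -> RGB).
rgb = bgr[:, :, ::-1]

# Applying the same slice again restores the original order (RGB -> BGR).
round_trip = rgb[:, :, ::-1]

print(rgb.tolist())
print(np.array_equal(round_trip, bgr))
```

The slice returns a view, so no pixel data is copied until a downstream function needs a contiguous array.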

# Run Inference using OneFormer on CPU


## ADE20K Dataset

In [None]:
######
#@markdown We use `DiNAT-L` as the default backbone. To use `Swin-L` as the backbone instead, select the checkbox below.
use_swin = False #@param {type: 'boolean'}

In [None]:
######
#@title A. Initialize Model
######
# download model checkpoint
import os
import subprocess
if not use_swin:
  if not os.path.exists("250_16_dinat_l_oneformer_ade20k_160k.pth"):
    subprocess.run('wget https://shi-labs.com/projects/oneformer/ade20k/250_16_dinat_l_oneformer_ade20k_160k.pth', shell=True)
  predictor, metadata = setup_modules("ade20k", "250_16_dinat_l_oneformer_ade20k_160k.pth", use_swin)
else:
  if not os.path.exists("250_16_swin_l_oneformer_ade20k_160k.pth"):
    subprocess.run('wget https://shi-labs.com/projects/oneformer/ade20k/250_16_swin_l_oneformer_ade20k_160k.pth', shell=True)
  predictor, metadata = setup_modules("ade20k", "250_16_swin_l_oneformer_ade20k_160k.pth", use_swin)

In [None]:
######
#@title B. Display Sample Image. You can modify the path and try your own images!
######

# change path here for another image
img = cv2.imread("samples/ade20k.jpeg")
img = imutils.resize(img, width=640)
cv2_imshow(img)

In [None]:
######
#@title C. Run Inference (CPU)
#@markdown Specify the **task**. `Default: panoptic`. Execution may take up to 2 minutes.
######
###### Specify Task Here ######
task = "panoptic" #@param
##############################
%load_ext autotime
out = TASK_INFER[task](img, predictor, metadata).get_image()
cv2_imshow(out[:, :, ::-1])
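
The cell above looks up the runner with `TASK_INFER[task]`, so a mistyped task name surfaces as a bare `KeyError`. The sketch below shows one way to guard that lookup. `resolve_task` and `demo_registry` are hypothetical names introduced for illustration, and the lambdas stand in for `panoptic_run`, `instance_run`, and `semantic_run`, which need a loaded model to execute.

```python
def resolve_task(task, registry):
    """Return the runner registered under `task`, or raise a descriptive error."""
    if task not in registry:
        valid = ", ".join(sorted(registry))
        raise ValueError(f"Unknown task {task!r}; expected one of: {valid}")
    return registry[task]


# Stand-in registry: placeholders for the real runner functions.
demo_registry = {
    "panoptic": lambda img: "panoptic overlay",
    "instance": lambda img: "instance overlay",
    "semantic": lambda img: "semantic overlay",
}

runner = resolve_task("semantic", demo_registry)
print(runner(None))

try:
    resolve_task("depth", demo_registry)
except ValueError as err:
    print(err)
```

In the notebook, the equivalent call would be `resolve_task(task, TASK_INFER)(img, predictor, metadata)`.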

## Cityscapes Dataset

In [None]:
######
#@markdown We use `DiNAT-L` as the default backbone. To use `Swin-L` as the backbone instead, select the checkbox below.
use_swin = False #@param {type: 'boolean'}

In [None]:
######
#@title A. Initialize Model
######
# download model checkpoint
import os
import subprocess
if not use_swin:
  if not os.path.exists("250_16_dinat_l_oneformer_cityscapes_90k.pth"):
    subprocess.run('wget https://shi-labs.com/projects/oneformer/cityscapes/250_16_dinat_l_oneformer_cityscapes_90k.pth', shell=True)
  predictor, metadata = setup_modules("cityscapes", "250_16_dinat_l_oneformer_cityscapes_90k.pth", use_swin)
else:
  if not os.path.exists("250_16_swin_l_oneformer_cityscapes_90k.pth"):
    subprocess.run('wget https://shi-labs.com/projects/oneformer/cityscapes/250_16_swin_l_oneformer_cityscapes_90k.pth', shell=True)
  predictor, metadata = setup_modules("cityscapes", "250_16_swin_l_oneformer_cityscapes_90k.pth", use_swin)

In [None]:
######
#@title B. Display Sample Image. You can modify the path and try your own images!
######

# change path here for another image
img = cv2.imread("samples/cityscapes.png")
img = imutils.resize(img, width=512)
cv2_imshow(img)

In [None]:
######
#@title C. Run Inference (CPU)
#@markdown Specify the **task**. `Default: panoptic`. Execution may take up to 2 minutes.
######
task = "panoptic" #@param
%load_ext autotime
out = TASK_INFER[task](img, predictor, metadata).get_image()
cv2_imshow(out[:, :, ::-1])
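
`imutils.resize(img, width=512)` takes only a target width and scales the height by the same ratio, so the aspect ratio is preserved. The sketch below reproduces that arithmetic in plain Python. `resized_shape` is a hypothetical helper, and truncating the scaled height to an integer is an assumption about how imutils rounds.

```python
def resized_shape(height, width, target_width):
    """Return (new_height, new_width) for an aspect-preserving resize to target_width."""
    ratio = target_width / float(width)
    # Assumption: the scaled height is truncated to an integer.
    return int(height * ratio), target_width


# A full-resolution Cityscapes frame is 2048 pixels wide and 1024 pixels tall.
print(resized_shape(1024, 2048, 512))  # (256, 512)
```

Smaller inputs shorten CPU inference time, which is likely why the notebook downsizes each sample before running the model.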

## COCO Dataset

In [None]:
######
#@markdown We use `DiNAT-L` as the default backbone. To use `Swin-L` as the backbone instead, select the checkbox below.
use_swin = False #@param {type: 'boolean'}

In [None]:
######
#@title A. Initialize Model
######
# download model checkpoint
import os
import subprocess
if not use_swin:
  if not os.path.exists("150_16_dinat_l_oneformer_coco_100ep.pth"):
    subprocess.run('wget https://shi-labs.com/projects/oneformer/coco/150_16_dinat_l_oneformer_coco_100ep.pth', shell=True)
  predictor, metadata = setup_modules("coco", "150_16_dinat_l_oneformer_coco_100ep.pth", use_swin)
else:
  if not os.path.exists("150_16_swin_l_oneformer_coco_100ep.pth"):
    subprocess.run('wget https://shi-labs.com/projects/oneformer/coco/150_16_swin_l_oneformer_coco_100ep.pth', shell=True)
  predictor, metadata = setup_modules("coco", "150_16_swin_l_oneformer_coco_100ep.pth", use_swin)

In [None]:
######
#@title B. Display Sample Image. You can modify the path and try your own images!
######

# change path here for another image
img = cv2.imread("samples/coco.jpeg")
img = imutils.resize(img, width=512)
cv2_imshow(img)

In [None]:
######
#@title C. Run Inference (CPU)
#@markdown Specify the **task**. `Default: panoptic`. Execution may take up to 2 minutes.
######
task = "panoptic" #@param
%load_ext autotime
out = TASK_INFER[task](img, predictor, metadata).get_image()
cv2_imshow(out[:, :, ::-1])
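
The three "Initialize Model" cells differ only in the checkpoint filename and the dataset folder in the download URL. As a rough sketch, both can be derived from the dataset name and the backbone flag. `CHECKPOINT_PARTS`, `checkpoint_name`, and `checkpoint_url` are hypothetical names, and the prefix and suffix values are copied from the filenames used in the cells above rather than from any official listing.

```python
CHECKPOINT_BASE_URL = "https://shi-labs.com/projects/oneformer"

# (filename prefix, filename suffix) per dataset, copied from the notebook cells.
CHECKPOINT_PARTS = {
    "ade20k": ("250_16", "160k"),
    "cityscapes": ("250_16", "90k"),
    "coco": ("150_16", "100ep"),
}


def checkpoint_name(dataset, use_swin):
    """Build the checkpoint filename for a dataset and backbone choice."""
    prefix, suffix = CHECKPOINT_PARTS[dataset]
    backbone = "swin_l" if use_swin else "dinat_l"
    return f"{prefix}_{backbone}_oneformer_{dataset}_{suffix}.pth"


def checkpoint_url(dataset, use_swin):
    """Build the download URL that matches checkpoint_name."""
    return f"{CHECKPOINT_BASE_URL}/{dataset}/{checkpoint_name(dataset, use_swin)}"


print(checkpoint_name("ade20k", use_swin=False))
print(checkpoint_url("coco", use_swin=True))
```

With helpers like these, a single cell could download the file if it is missing and then call `setup_modules(dataset, checkpoint_name(dataset, use_swin), use_swin)`.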

# More Information on OneFormer 🎗️
- [Project Page](https://praeclarumjj3.github.io/oneformer/)
- [GitHub Repo](https://github.com/SHI-Labs/OneFormer)
- [arXiv preprint](https://arxiv.org/abs/2211.06220)
- [HuggingFace Space](https://huggingface.co/spaces/shi-labs/OneFormer)
- [HuggingFace transformers](https://huggingface.co/docs/transformers/main/en/model_doc/oneformer)