<a href="https://colab.research.google.com/github/linhle32/Interactive-Models-with-Widget/blob/main/object_detection_with_detectron2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Object Detection on Images

Object detection is more complicated than image classification. Instead of assigning one label to the whole image, a model usually has to perform two tasks:
- **segmentation** (or, more simply, localization with a bounding box) - determine the groups of pixels in an image that belong to the same object
- **classification** - determine the class of each object found in the image
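
To make the two outputs concrete, here is a minimal NumPy sketch. It assumes the segmentation step has already produced a per-pixel map of instance ids (0 = background) and the classification step a class name for each id. The `instance_map` and `instance_classes` values are made up, and `summarize_instances` is a hypothetical helper. `detectron2` itself returns an `Instances` object rather than these plain arrays.

```python
import numpy as np

# Made-up segmentation output: each pixel holds an instance id (0 = background).
instance_map = np.array([
    [0, 1, 1, 0, 0],
    [0, 1, 1, 0, 2],
    [0, 0, 0, 0, 2],
    [3, 3, 0, 0, 2],
])

# Made-up classification output: one class name per instance id.
instance_classes = {1: "cat", 2: "person", 3: "dog"}


def summarize_instances(instance_map, instance_classes):
    """Return (id, class, pixel_count, (x0, y0, x1, y1)) for each instance.

    The box uses inclusive pixel coordinates of the object's extreme pixels.
    """
    results = []
    for inst_id in np.unique(instance_map):
        if inst_id == 0:
            continue  # skip background pixels
        ys, xs = np.nonzero(instance_map == inst_id)  # pixels of this object
        box = (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
        results.append((int(inst_id), instance_classes[int(inst_id)], int(len(xs)), box))
    return results


for row in summarize_instances(instance_map, instance_classes):
    print(row)
```

Each object thus ends up with a group of pixels (its mask), a box derived from that mask, and a class label, which is roughly the information the pretrained model below draws on the image.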

Because of this complexity, models for this task are usually very large and impractical to train without substantial computational resources. Instead, we will use a **pretrained model**: one that has already been trained and whose weights are publicly available for anyone to use.

In this example, we will use `detectron2`, the object detection library developed by Facebook AI Research, together with one of its pretrained Mask R-CNN models. First, we need to clone the repository from GitHub and install the package. Because of an issue with the current version, we install an older `detectron2` pinned to a specific commit.

This notebook can be used as-is; you do not need to modify anything. At the end, I added a small and simple GUI application for interacting with the model.

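```python
# Clone the repository to get the model config files (detectron2/configs/...)
!git clone https://github.com/facebookresearch/detectron2.git
# Install detectron2 pinned to a specific older commit
!pip install 'git+https://github.com/facebookresearch/detectron2.git@5aeb252b194b93dc2879b4ac34bc51a31b5aee13'
```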

The code below is adapted from the `detectron2` demo for images. I mainly removed unnecessary code and functions to keep it minimal.

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer


def setup_cfg(config_file, opts, confidence_threshold):
    """Build a frozen detectron2 config from a YAML file plus key/value overrides."""
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    cfg.merge_from_list(opts)
    # Use the same score threshold for every model family the config might use
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold
    cfg.freeze()
    return cfg


def run_on_image(predictor, image, metadata):
    """Run the predictor on a BGR image and return (predictions, visualization)."""
    vis_output = None
    predictions = predictor(image)
    image = image[:, :, ::-1]  # BGR -> RGB for the Visualizer
    visualizer = Visualizer(image, metadata, instance_mode=ColorMode.IMAGE)
    if "sem_seg" in predictions:
        vis_output = visualizer.draw_sem_seg(
            predictions["sem_seg"].argmax(dim=0).to(torch.device("cpu"))
        )
    if "instances" in predictions:
        instances = predictions["instances"].to(torch.device("cpu"))
        vis_output = visualizer.draw_instance_predictions(predictions=instances)
    return predictions, vis_output


# Mask R-CNN with a ResNet-50 FPN backbone (3x schedule), trained on COCO
cfgfile = 'detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'
opts = ['MODEL.WEIGHTS', 'detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl']
if not torch.cuda.is_available():
    opts += ['MODEL.DEVICE', 'cpu']  # fall back to CPU when no GPU runtime is attached
cfg = setup_cfg(cfgfile, opts, 0.7)
predictor = DefaultPredictor(cfg)
# Class names and colors used when drawing the predictions
metadata = MetadataCatalog.get(
    cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
)
```

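`setup_cfg` passes `opts` to `cfg.merge_from_list`, which reads the list as alternating key/value pairs, where each key is a dotted path into the nested config (here `MODEL.WEIGHTS`). The sketch below imitates that format on plain nested dicts so it is easy to see. `apply_opts` is a hypothetical helper, not the yacs/`detectron2` implementation, which additionally decodes and type-checks each value. The weights file name `model_final.pkl` is a placeholder.

```python
def apply_opts(cfg, opts):
    """Apply a flat [key1, value1, key2, value2, ...] list to a nested dict."""
    if len(opts) % 2 != 0:
        raise ValueError("opts must contain key/value pairs")
    for full_key, value in zip(opts[0::2], opts[1::2]):
        *parents, leaf = full_key.split(".")
        node = cfg
        for part in parents:
            node = node[part]  # raises KeyError if the path does not exist
        if leaf not in node:
            raise KeyError(f"unknown config key: {full_key}")
        node[leaf] = value  # overwrite the existing default
    return cfg


toy_cfg = {"MODEL": {"WEIGHTS": "", "DEVICE": "cuda"}}
apply_opts(toy_cfg, ["MODEL.WEIGHTS", "model_final.pkl", "MODEL.DEVICE", "cpu"])
print(toy_cfg)  # {'MODEL': {'WEIGHTS': 'model_final.pkl', 'DEVICE': 'cpu'}}
```

Rejecting unknown keys mirrors the strict behaviour of a real config: a typo such as `MODEL.WIEGHTS` fails loudly instead of being silently ignored.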

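The `confidence_threshold` of 0.7 passed to `setup_cfg` sets the test-time score thresholds, so detections the model is less sure about are dropped before they reach the visualizer. The NumPy sketch below shows the idea on made-up detections. It is a simplification: `detectron2` applies the threshold inside its post-processing, alongside non-maximum suppression. `filter_by_score` is a hypothetical helper.

```python
import numpy as np


def filter_by_score(boxes, scores, classes, threshold=0.7):
    """Keep only detections whose confidence score exceeds `threshold`."""
    keep = scores > threshold  # boolean mask, one entry per detection
    return boxes[keep], scores[keep], classes[keep]


# Made-up detections: boxes as (x0, y0, x1, y1), plus one score and class id each.
boxes = np.array([[10, 10, 50, 60], [30, 5, 90, 40], [0, 0, 20, 20]], dtype=float)
scores = np.array([0.95, 0.40, 0.72])
classes = np.array([0, 16, 2])

kept_boxes, kept_scores, kept_classes = filter_by_score(boxes, scores, classes)
print(kept_scores)   # [0.95 0.72]
print(kept_classes)  # [0 2]
```

Raising the threshold gives fewer but more reliable boxes; lowering it shows more objects at the cost of more false positives.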
# Application

Finally, we build a small application on top of the `detectron2` predictor: upload an image, then click **Detect** to see the detected objects.

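```python
import io

import ipywidgets as widgets
from IPython.display import display
from PIL import Image

button_predict = widgets.Button(description="Detect")
uploader = widgets.FileUpload(accept="image/*", multiple=False)
output = widgets.Output()
display(button_predict, uploader, output)


def get_uploaded_bytes(value):
    """Return the raw bytes of the first uploaded file, or None if nothing was uploaded."""
    if not value:
        return None
    if isinstance(value, dict):
        # ipywidgets 7.x: {filename: {'metadata': ..., 'content': bytes}}
        file_info = next(iter(value.values()))
    else:
        # ipywidgets 8.x: tuple of dicts whose 'content' is a memoryview
        file_info = value[0]
    return bytes(file_info["content"])


@output.capture()
def on_predict_clicked(b):
    output.clear_output()
    data = get_uploaded_bytes(uploader.value)
    if data is None:
        print('Please upload an image first')
        return
    # PIL decodes to RGB; the predictor expects BGR, so reverse the channel axis
    image = np.array(Image.open(io.BytesIO(data)).convert("RGB"))[:, :, ::-1]
    predictions, visualized_output = run_on_image(predictor, image, metadata)
    # The Visualizer returns an RGB image, which matplotlib displays directly
    plt.figure(figsize=(8, 8))
    plt.imshow(visualized_output.get_image())
    plt.show()


button_predict.on_click(on_predict_clicked)
```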

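Several `[:, :, ::-1]` slices appear in this notebook because `DefaultPredictor` expects images in BGR channel order (the default `cfg.INPUT.FORMAT`), while PIL and matplotlib work in RGB. Reversing the last axis swaps the two orders, and doing it twice restores the original array. A small NumPy check on a made-up 1×2 image:

```python
import numpy as np

# A made-up 1x2 RGB image: one pure red pixel and one pure blue pixel.
rgb = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

bgr = rgb[:, :, ::-1]         # reverse the channel axis: RGB -> BGR
print(bgr[0, 0])              # [  0   0 255]  red now sits in the last channel
round_trip = bgr[:, :, ::-1]  # reverse again: BGR -> RGB
print(np.array_equal(round_trip, rgb))  # True
```

The slice returns a view rather than a copy. Wrap it in `np.ascontiguousarray` if a downstream library requires contiguous memory.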