In [5]:
import os

# Allow more than one copy of the OpenMP runtime in this process.
# This is set BEFORE importing the native libraries below (OpenCV, NumPy,
# detectron2/torch) so the flag is already in the environment when their
# OpenMP runtime is loaded; setting it after the imports can be too late.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import cv2
import numpy as np

from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2 import model_zoo

### Check versions

In [2]:
import torch

In [3]:
# Installed PyTorch version string.
print(torch.__version__)

1.12.0+cu113


In [4]:
# Check that PyTorch can see a usable CUDA device.
torch.cuda.is_available()

True

In [5]:
# CUDA version this PyTorch build was compiled against.
print(torch.version.cuda)

11.3


### Create Detector

In [9]:
class Detector:
    """Thin wrapper around a pretrained detectron2 COCO model.

    Builds a ``DefaultPredictor`` from a model-zoo config and offers a helper
    that runs it on an image file and shows the annotated result.

    Supported ``model_type`` values:
        "OD" -- object detection (Faster R-CNN R101-FPN 3x)
        "IS" -- instance segmentation (Mask R-CNN R50-FPN 3x)
    """

    # Maps each supported model_type to its model-zoo config path. The same
    # path is used for both the config file and the checkpoint URL.
    _MODEL_CONFIGS = {
        "OD": "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml",
        "IS": "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    }

    def __init__(self, model_type="OD", score_thresh=0.7, device=None):
        """Load the config and pretrained weights and build the predictor.

        Args:
            model_type: "OD" for object detection or "IS" for instance
                segmentation.
            score_thresh: Value assigned to
                ``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
            device: Value assigned to ``cfg.MODEL.DEVICE`` ("cuda" or "cpu").
                When None, "cuda" is used if PyTorch reports an available
                CUDA device, otherwise "cpu".

        Raises:
            ValueError: If ``model_type`` is not one of the supported keys.
        """
        # Validate first: previously an unknown model_type skipped both
        # branches and silently produced a config with no model or weights.
        if model_type not in self._MODEL_CONFIGS:
            supported = ", ".join(repr(key) for key in sorted(self._MODEL_CONFIGS))
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of: {supported}"
            )
        config_path = self._MODEL_CONFIGS[model_type]

        if device is None:
            # Local import keeps this class usable regardless of which
            # notebook cell imported torch.
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file(config_path))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_path)
        self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
        self.cfg.MODEL.DEVICE = device

        self.predictor = DefaultPredictor(self.cfg)

    def onImage(self, imagePath):
        """Run the predictor on an image file and display the results.

        Opens a "Results" window with the predictions drawn on the image,
        waits for a key press, then opens an "image" window with the
        unmodified input and waits for another key press. All OpenCV windows
        are closed before returning.

        Args:
            imagePath: Path of the image file to read with ``cv2.imread``.

        Raises:
            FileNotFoundError: If ``imagePath`` is not an existing file.
            ValueError: If OpenCV cannot decode the file as an image.
        """
        if not os.path.isfile(imagePath):
            raise FileNotFoundError(f"Image file not found: {imagePath!r}")

        image = cv2.imread(imagePath)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"OpenCV could not decode image: {imagePath!r}")

        predictions = self.predictor(image)

        # OpenCV loads BGR; the channel axis is reversed for the Visualizer
        # and reversed back before handing the drawing to cv2.imshow.
        viz = Visualizer(
            image[:, :, ::-1],
            metadata=MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0]),
            scale=1.2,
        )
        # Optional: pass instance_mode=ColorMode.IMAGE_BW to the Visualizer.
        output = viz.draw_instance_predictions(predictions["instances"].to("cpu"))

        try:
            cv2.imshow("Results", output.get_image()[:, :, ::-1])
            cv2.waitKey(0)

            cv2.imshow("image", image)
            cv2.waitKey(0)
        finally:
            # Without this the windows outlive the call and hang the kernel.
            cv2.destroyAllWindows()

In [12]:
# "IS" selects the instance-segmentation (Mask R-CNN) config; "OD" would
# select the object-detection (Faster R-CNN) config instead.
detector = Detector(model_type="IS")

In [16]:
# Runs the predictor on image1.jpg and shows two OpenCV windows in turn
# ("Results", then "image"); each waits for a key press before continuing.
detector.onImage("image1.jpg") # Problem: Crashes