In [1]:
import torchvision.transforms as transforms
import cv2
import numpy as np
import torchvision
import torch
from PIL import Image

In [2]:
coco_names = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', #'watermelon', 'tree'
]

In [3]:
COLORS = np.random.uniform(0, 255, size=(len(coco_names), 3))

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [5]:
def predict(image, model, device, detection_threshold):
    
    image = transform(image).to(device)
    image = image.unsqueeze(0) 
    outputs = model(image) 
    
    pred_classes = [coco_names[i] for i in outputs[0]['labels'].cpu().numpy()]
    
    pred_scores = outputs[0]['scores'].detach().cpu().numpy()
   
    pred_bboxes = outputs[0]['boxes'].detach().cpu().numpy()
    
    boxes = pred_bboxes[pred_scores >= detection_threshold].astype(np.int32)
    
    return boxes, pred_classes, outputs[0]['labels']

In [6]:
def draw_boxes(boxes, classes, labels, image):
    
    image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
    for i, box in enumerate(boxes):
        color = COLORS[labels[i]]
        cv2.rectangle(
            image,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color, 2
        )
        cv2.putText(image, classes[i], (int(box[0]), int(box[1]-5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2, 
                    lineType=cv2.LINE_AA)
    return image

In [7]:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, 
                                                    min_size=1024)

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


  0%|          | 0.00/160M [00:00<?, ?B/s]

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
import os
paths = os.listdir("/content/inputs")
paths = paths[31:]
print(paths)

['v7w_107947.jpg', 'v7w_107997.jpg', 'v7w_107901.jpg', 'v7w_150370.jpg', 'v7w_150394.jpg', 'v7w_150362.jpg', 'v7w_150258.jpg', 'v7w_150367.jpg', 'v7w_61524.jpg', 'v7w_150408.jpg', 'v7w_107932.jpg', 'v7w_150356.jpg', 'v7w_61544.jpg', 'v7w_107941.jpg', 'v7w_107907.jpg', 'v7w_150368.jpg', 'v7w_61593.jpg', 'v7w_107977.jpg', 'v7w_150282.jpg', 'v7w_150326.jpg', 'v7w_61526.jpg', 'v7w_150332.jpg', 'v7w_107994.jpg', 'v7w_61543.jpg', 'v7w_61603.jpg', 'v7w_150287.jpg', 'v7w_150355.jpg', 'v7w_150405.jpg', 'v7w_61583.jpg', 'v7w_150411.jpg', 'v7w_61556.jpg', 'v7w_150262.jpg', 'v7w_150372.jpg', 'v7w_150309.jpg', 'v7w_107988.jpg', 'v7w_150407.jpg', 'v7w_150349.jpg', 'v7w_61549.jpg', 'v7w_61581.jpg', 'v7w_150290.jpg', 'v7w_107913.jpg', 'v7w_150264.jpg', 'v7w_107998.jpg', 'v7w_150305.jpg', 'v7w_150391.jpg', 'v7w_150268.jpg', 'v7w_61585.jpg', 'v7w_107918.jpg', 'v7w_61607.jpg', 'v7w_107959.jpg', 'v7w_150263.jpg', 'v7w_61538.jpg', 'v7w_61531.jpg', 'v7w_150376.jpg', 'v7w_150298.jpg', 'v7w_107982.jpg', 'v7w_

In [20]:
from google.colab.patches import cv2_imshow

#Run inference
count = 30
for path in paths:
  image = Image.open(f"/content/inputs/{path}")
  model.eval().to(device)
  boxes, classes, labels = predict(image, model, device, 0.8)
  image = draw_boxes(boxes, classes, labels, image)
  #cv2_imshow(image)
  cv2.imwrite(f"/content/drive/MyDrive/11777/outputs/{count}.jpg", image)
  count += 1