<a href="https://colab.research.google.com/github/comapi5/yolox-notebooks/blob/main/predict_image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%bash
# Install the pinned yolox package used by this notebook, plus logzero (used for logging in the cells below).
pip install yolox==0.3.0 logzero
# Download the pretrained YOLOX-X checkpoint into the working directory; the model cell loads it as "yolox_x.pth".
wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x.pth

In [None]:
import torch
from logzero import logger
from yolox.exp import get_exp
from yolox.utils import get_model_info

# Build the experiment (model definition + hyper-parameters) for the YOLOX-X variant.
exp = get_exp(None, "yolox-x")
logger.info(f"exp name: {exp.exp_name}")

# Inference settings consumed by the preprocessing / postprocessing / visualization cells.
exp.test_conf = 0.25  # confidence threshold used when filtering and drawing detections
exp.nmsthre = 0.45  # threshold passed to postprocess() for non-maximum suppression
exp.test_size = (640, 640)  # network input size as (height, width)

model = exp.get_model()
logger.info(f"Model Summary: {get_model_info(model, exp.test_size)}")

# Load the pretrained weights into the FP32 model first; the checkpoint keeps them under "model".
logger.info("loading checkpoint.")
ckpt = torch.load("yolox_x.pth", map_location="cpu")
model.load_state_dict(ckpt["model"])
logger.info("loaded checkpoint done.")

# Run on the GPU in FP16 when one is available. Otherwise stay on the CPU in FP32,
# because half-precision convolutions are poorly supported on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if device.type == "cuda":
    model.half()  # to FP16
model.eval()

In [None]:
import cv2
from yolox.data.data_augment import ValTransform

# Input image (here on a mounted Google Drive); point this at your own file.
img_path = "/content/drive/MyDrive/Colab Notebooks/dog.jpg"
raw_img = cv2.imread(img_path)
# cv2.imread returns None instead of raising when the file is missing or unreadable.
# Fail here with a clear message rather than an AttributeError on `.shape`.
if raw_img is None:
    raise FileNotFoundError(f"could not read image: {img_path}")
logger.info(f"input img shape; {raw_img.shape}")

# Resize/pad the image to the network input size.
preproc = ValTransform(legacy=False)
img, _ = preproc(raw_img, None, exp.test_size)
# Scale factor between the raw image and the network input. The visualization cell
# divides the predicted boxes by it to map them back to raw-image coordinates.
ratio = min(exp.test_size[0] / raw_img.shape[0], exp.test_size[1] / raw_img.shape[1])
logger.info(f"after ValTransform img shape; {img.shape}")

# Add a batch dimension, then match the model's device and dtype (e.g. cuda/FP16)
# instead of hard-coding them, so this cell works with any model-cell configuration.
img = torch.from_numpy(img).unsqueeze(0).float()
_model_param = next(model.parameters())
img = img.to(device=_model_param.device, dtype=_model_param.dtype)

In [None]:
from yolox.utils import get_model_info, postprocess, vis
from yolox.data.datasets import COCO_CLASSES

# Forward pass without autograd tracking. Then decode the raw predictions into
# per-image detections using the confidence and NMS thresholds configured on `exp`.
# Class-agnostic mode is requested explicitly.
with torch.no_grad():
    raw_preds = model(img)
    outputs = postprocess(
        raw_preds,
        exp.num_classes,
        exp.test_conf,
        exp.nmsthre,
        class_agnostic=True,
    )

In [None]:
import numpy as np
from PIL import Image

# Detections for the single (batch index 0) input image. yolox's postprocess leaves
# an entry as None when no box survives the confidence threshold.
output = outputs[0]

# yolox's vis() draws onto the array it is given, so draw on a copy. That keeps
# raw_img clean and lets this cell be re-run without stacking boxes.
canvas = raw_img.copy()

if output is None:
    logger.info("no detections above the confidence threshold.")
    vis_res = canvas
else:
    output = output.cpu()

    # Boxes are in network-input coordinates; divide by the preprocessing scale
    # factor to map them back onto raw_img. Out-of-place division so `output` is untouched.
    bboxes = output[:, 0:4] / ratio

    cls = output[:, 6]
    # Final score = objectness confidence * class confidence.
    scores = output[:, 4] * output[:, 5]

    vis_res = vis(canvas, bboxes, scores, cls, exp.test_conf, COCO_CLASSES)
logger.info(f"result shape: {vis_res.shape}")

# visualize: OpenCV images are BGR while PIL expects RGB, so swap the channel order.
vis_res = cv2.cvtColor(vis_res, cv2.COLOR_BGR2RGB)
Image.fromarray(np.uint8(vis_res))