In [1]:
import os

CUDA_VISIBLE_DEVICES=-1
# A bare Python variable is invisible to the CUDA runtime; the value has to be
# exported through the process environment (and before CUDA is initialised)
# for the GPUs to actually be hidden.
os.environ["CUDA_VISIBLE_DEVICES"] = str(CUDA_VISIBLE_DEVICES)

In [8]:
import torch
import torchvision

class Wrapper(torch.nn.Module):
    """Detection wrapper that bundles preprocessing, the model call and NMS.

    ``forward`` takes one channels-last image batch of shape (1, H, W, 3) and
    returns an (N, 6) tensor whose columns are
    ``x1, y1, x2, y2, score, class_index``.
    """

    def __init__(self, model):
        """Store ``model`` after switching it to eval mode."""
        super().__init__()
        self.model = model.eval()

    @staticmethod
    def get_lengths(img : torch.Tensor):
        """Return ``(h, w, resize_value)`` for a channels-first image batch.

        ``h`` and ``w`` are dims 2 and 3 of ``img``; ``resize_value`` is the
        longer of the two rounded up to the next multiple of 32.
        """
        h, w = img.size(2), img.size(3)
        # int64 instead of int16: torch.short cannot represent a side longer
        # than 32767 pixels.
        longest_side = torch.max(torch.tensor([h, w], dtype=torch.int64))
        resize_value = torch.ceil(longest_side / 32) * 32
        return h, w, resize_value.int().item()

    @staticmethod
    def preprocess(img):
        """Convert a (1, H, W, 3) image into a zero-padded square float batch.

        Accepts a tensor or a numpy array. The image is moved to CPU, permuted
        to channels-first, cast to float, pasted into the top-left corner of a
        ``(1, 3, S, S)`` zero canvas (``S`` from :meth:`get_lengths`) and
        divided by 255.
        """
        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to('cpu')
        img = img.permute(0,3,1,2)  # NHWC -> NCHW
        img = img.float()
        h, w, resize_value = Wrapper.get_lengths(img)
        # The canvas is fixed to one image with three channels.
        padding = torch.zeros((1, 3, resize_value, resize_value))
        padding[:, :, :h, :w] = img
        padding /= 255  # 0 - 255 to 0.0 - 1.0
        return padding

    @staticmethod
    def xywh2xyxy(x):
        """Convert boxes from ``(cx, cy, w, h)`` to ``(x1, y1, x2, y2)``.

        Works on the last dimension and returns a new tensor; ``x`` is left
        unmodified.
        """
        y = x.clone()
        y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
        y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
        y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
        y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
        return y

    @staticmethod
    def _non_max_suppression(pred, orig_img, conf_threshold=0.5, iou_threshold=0.4, max_det=300):
        """Filter raw predictions down to an (N, 6) detection tensor.

        Args:
            pred: raw prediction whose first four rows (after dropping the
                leading batch dim) are ``cx, cy, w, h`` and whose remaining
                rows are per-class scores, one column per candidate.
            orig_img: original (1, H, W, 3) image; only its height and width
                are read, to clamp the boxes.
            conf_threshold: minimum best-class score for a candidate to be kept.
            iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.
            max_det: number of top-scoring candidates kept before NMS.

        Returns:
            Tensor of shape (N, 6): ``x1, y1, x2, y2, score, class_index``.
        """
        # Drop only the batch dim, and out of place: the in-place squeeze_()
        # altered the caller's tensor and collapsed every size-1 dimension.
        pred = pred.squeeze(0)
        boxes, scores, cls = pred[:4, :].T, pred[4:, :].amax(0), pred[4:, :].argmax(0).to(torch.int)
        keep = scores.argsort(0, descending=True)[:max_det]
        boxes, scores, cls = boxes[keep], scores[keep], cls[keep]
        boxes = Wrapper.xywh2xyxy(boxes)
        candidate_idx = torch.arange(0, scores.shape[0])
        candidate_idx = candidate_idx[scores > conf_threshold]

        boxes, scores, cls = boxes[candidate_idx], scores[candidate_idx], cls[candidate_idx]
        # NMS is run over all boxes together, regardless of class.
        final_idx = torchvision.ops.nms(boxes, scores, iou_threshold=iou_threshold)

        boxes = boxes[final_idx]
        scores = scores[final_idx]
        cls = cls[final_idx]

        boxes[:, [0,2]] = boxes[:, [0,2]].clamp(min=0, max=orig_img.size(2)) # width for x
        boxes[:, [1,3]] = boxes[:, [1,3]].clamp(min=0, max=orig_img.size(1)) # height for y

        # Cast the integer class ids explicitly instead of relying on
        # torch.cat's implicit dtype promotion.
        return torch.cat([boxes, scores.unsqueeze(1), cls.unsqueeze(1).to(boxes.dtype)], dim=1)

    @staticmethod
    def postprocess(pred, orig_img):
        """Run NMS with the default thresholds and return the detections."""
        result = Wrapper._non_max_suppression(pred, orig_img)
        return result

    def forward(self, imgs):
        """Detect objects in one (1, H, W, 3) image; returns an (N, 6) tensor."""
        # Normalise numpy input once so that it works here as it does in
        # preprocess(). No copy is needed: preprocess() never writes into its
        # argument and postprocess() only reads the image's height and width.
        imgs = imgs if isinstance(imgs, torch.Tensor) else torch.from_numpy(imgs)
        preds = self.model(Wrapper.preprocess(imgs))
        # Only the first element of the model output is used.
        result = Wrapper.postprocess(preds[0], imgs)
        return result

In [9]:
# Sample image used for the eager PyTorch run and as the example input for the ONNX export.
path = 'data/samples/image1.jpg'

In [10]:
from ultralytics import YOLO
from torchvision.io import read_image

# Read the sample as RGB, add a batch dimension and move channels last so the
# tensor has the (1, H, W, 3) layout the wrapper takes as input.
image = (
    read_image(path, mode=torchvision.io.ImageReadMode.RGB)
    .unsqueeze(0)
    .permute(0, 2, 3, 1)
)

# Keep the YOLO object for its class-name table; the wrapper only needs the
# underlying torch module.
yolo = YOLO("yolov8n.pt", task='detect')
model = yolo.model

In [11]:
# Wrap the raw detection model so pre- and post-processing run inside forward().
wrapped = Wrapper(model)

In [12]:
from time import perf_counter, time

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed intervals; time() is wall-clock time and can jump when the system
# clock is adjusted.
start = perf_counter()
result = wrapped(image)

print(perf_counter() - start)
result

0.30008721351623535


tensor([[2.2067e+02, 3.9544e+02, 3.4574e+02, 8.6122e+02, 8.5562e-01, 0.0000e+00],
        [6.6942e+02, 4.0530e+02, 8.0967e+02, 8.7953e+02, 8.4762e-01, 0.0000e+00],
        [0.0000e+00, 2.1678e+02, 8.1000e+02, 7.5988e+02, 7.9169e-01, 5.0000e+00],
        [5.4363e+01, 4.0058e+02, 2.0823e+02, 8.9836e+02, 7.8034e-01, 0.0000e+00],
        [0.0000e+00, 5.5170e+02, 6.6415e+01, 8.7339e+02, 5.7889e-01, 0.0000e+00]])

In [13]:
# Axes that stay symbolic in the exported graph, keyed by input/output name.
dynamic = {
    'image': {1: 'height', 2: 'width'},  # Input shape: (1, H, W, 3)
    'output': {0: 'num_boxes'},          # Output shape: (N, 6)
}

# Trace the wrapper with the sample image and write the ONNX file.
torch.onnx.export(
    wrapped,
    image,
    'wrapped_model.onnx',
    input_names=['image'],
    output_names=['output'],
    dynamic_axes=dynamic or None,
    opset_version=17,
)

  longest_side = torch.max(torch.tensor([h, w], dtype=torch.short).detach())
  return h, w, resize_value.int().item()
  elif self.dynamic or self.shape != shape:


verbose: False, log level: Level.ERROR



In [14]:
import onnxruntime as ort

# Name the execution provider explicitly: GPU-enabled onnxruntime builds
# (1.9 and later) refuse to create a session without one, and this notebook
# is meant to run on CPU.
sess = ort.InferenceSession('wrapped_model.onnx', providers=['CPUExecutionProvider'])
inputs = sess.get_inputs()
outputs = sess.get_outputs()

In [15]:
# Second sample image, fed to the exported ONNX model (rebinds `path`).
path = 'data/samples/image2.jpg'

In [16]:
from torchvision.io import read_image
from torchvision.utils import draw_bounding_boxes
from time import time

# Keep `image` channels-first with a batch dim; it is reused for drawing later.
image = read_image(path, mode=torchvision.io.ImageReadMode.RGB).unsqueeze(0)

# The exported graph takes channels-last input, so permute before handing the
# array to ONNX Runtime. Passing None requests every model output.
result = sess.run(None, {inputs[0].name: image.permute(0, 2, 3, 1).numpy()})
result

[array([[     132.98,      41.823,      671.31,      684.01,     0.73314,           0],
        [     52.443,       730.7,       158.1,         840,     0.50371,           0]], dtype=float32)]

In [17]:
def get_yolo_labels_for_image(result):
    """Map the class-index column of the first output array to class names.

    Column 5 of ``result[0]`` is cast to int and each value is looked up in
    the module-level ``yolo.names`` table.
    """
    class_ids = result[0][:, 5].astype('int')
    return [yolo.names[class_id] for class_id in class_ids]
# Class names for the detections returned by the ONNX session, in row order.
labels = get_yolo_labels_for_image(result)
labels

['person', 'person']

In [18]:
# Draw the detections on the image: drop the batch dim and pass only the four
# box-coordinate columns of the first output array.
plotted = draw_bounding_boxes(
    image.squeeze(),
    torch.from_numpy(result[0][:, :4]),
    labels=labels,
    width=4,
)

In [19]:
from pathlib import Path
import imageio.v3 as iio

# Save the annotated image into the working directory under the input's file
# name, permuted back to channels-last for the writer.
output_name = Path(path).name
iio.imwrite(output_name, plotted.permute(1, 2, 0).numpy())