In [1]:
# Hide all GPUs so everything below runs on CPU. This must be an environment
# variable (a bare `CUDA_VISIBLE_DEVICES=-1` only creates a Python variable), and
# it has to be set before torch initialises CUDA, i.e. before the import cell.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [2]:
import torch
import torchvision

class Wrapper(torch.nn.Module):
    """End-to-end YOLO detector: preprocessing, model and NMS in one ``nn.Module``.

    Bundling the steps lets ``torch.onnx.export`` trace the whole pipeline, so the
    exported graph consumes a raw ``(1, H, W, 3)`` uint8 image and emits an
    ``(N, 6)`` tensor of ``[x1, y1, x2, y2, score, class]`` rows.
    """

    def __init__(self, model):
        super().__init__()
        # eval() switches the wrapped model to inference mode and returns it.
        self.model = model.eval()

    @staticmethod
    def get_lengths(img: torch.Tensor):
        """Return ``(h, w, side)`` for an NCHW tensor.

        ``side`` is the longest spatial side rounded up to a multiple of 32
        (presumably the model's largest stride — confirm for other models).
        """
        h, w = img.size(2), img.size(3)
        # int64 rather than int16: torch.short cannot represent sides above 32767.
        longest_side = torch.max(torch.tensor([h, w], dtype=torch.int64).detach())
        # NOTE(review): torch.tensor(...) and .item() are what the TracerWarnings in
        # the export cell point at, i.e. this size is frozen into the traced graph.
        resize_value = torch.ceil(longest_side / 32) * 32
        return h, w, resize_value.int().item()

    @staticmethod
    def preprocess(img):
        """Turn a ``(1, H, W, 3)`` uint8 image into a padded ``(1, 3, S, S)`` float tensor.

        The image is copied into the top-left corner of a zero canvas whose side
        ``S`` comes from ``get_lengths`` and is scaled from 0-255 to 0.0-1.0.
        Padding only on the right/bottom keeps predicted box coordinates in the
        original image's pixel space. Accepts a ``torch.Tensor`` or NumPy array.
        """
        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to('cpu')
        img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW
        img = img.float()  # uint8 -> float32
        h, w, resize_value = Wrapper.get_lengths(img)
        padding = torch.zeros((1, 3, resize_value, resize_value))
        padding[:, :, :h, :w] = img
        padding /= 255  # 0 - 255 to 0.0 - 1.0
        return padding

    @staticmethod
    def xywh2xyxy(x):
        """Convert ``[..., (cx, cy, w, h)]`` boxes to ``[..., (x1, y1, x2, y2)]`` (new tensor)."""
        y = x.clone()
        y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
        y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
        y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
        y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
        return y

    @staticmethod
    def _non_max_suppression(pred, orig_img, conf_threshold=0.5, iou_threshold=0.4, max_det=300):
        """Turn raw YOLO head output into final detections.

        Args:
            pred: prediction laid out as ``(1, 4 + num_classes, num_anchors)`` (a
                missing leading batch dim is also accepted); rows 0-3 are
                ``cx, cy, w, h`` and the remaining rows are per-class scores.
                The tensor is not modified.
            orig_img: the original NHWC input; only its shape is used, to clamp
                boxes to the image (dim 2 = width, dim 1 = height).
            conf_threshold: a candidate is kept only if its best class score is
                strictly greater than this.
            iou_threshold: IoU above which the lower-scoring box is suppressed
                (class-agnostic, via ``torchvision.ops.nms``).
            max_det: maximum number of detections returned.

        Returns:
            ``(N, 6)`` float tensor of ``[x1, y1, x2, y2, score, class]`` rows in
            descending score order.
        """
        # squeeze(0) instead of the in-place squeeze_(): the caller's tensor stays
        # intact and a single-anchor output keeps its anchor dimension.
        pred = pred.squeeze(0)
        class_scores = pred[4:, :]
        scores, cls = class_scores.amax(0), class_scores.argmax(0)

        # 1) Confidence filter.
        candidates = scores > conf_threshold
        boxes = Wrapper.xywh2xyxy(pred[:4, :].T[candidates])
        scores, cls = scores[candidates], cls[candidates]

        # 2) NMS, then 3) cap at max_det. Capping *before* NMS would let near-
        # duplicate boxes of a few objects use up the budget and silently drop
        # other confident detections. nms returns indices sorted by decreasing
        # score, so slicing keeps the best ones.
        keep = torchvision.ops.nms(boxes, scores, iou_threshold=iou_threshold)[:max_det]
        boxes, scores, cls = boxes[keep], scores[keep], cls[keep]

        # Clip to the original image: x against width, y against height.
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=orig_img.shape[2])
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=orig_img.shape[1])

        return torch.cat([boxes, scores.unsqueeze(1), cls.unsqueeze(1).to(boxes.dtype)], dim=1)

    @staticmethod
    def postprocess(pred, orig_img):
        """Apply NMS with the default thresholds of ``_non_max_suppression``."""
        return Wrapper._non_max_suppression(pred, orig_img)

    def forward(self, imgs):
        """Detect objects in a ``(1, H, W, 3)`` uint8 image (tensor or NumPy array).

        Returns the ``(N, 6)`` ``[x1, y1, x2, y2, score, class]`` tensor produced
        by ``_non_max_suppression``.
        """
        if not isinstance(imgs, torch.Tensor):
            imgs = torch.from_numpy(imgs)
        # preprocess() does not modify its input and only the original shape is
        # needed afterwards, so no defensive copy is required.
        preds = self.model(Wrapper.preprocess(imgs))
        # Assumes the model returns a sequence whose first element is the raw
        # prediction tensor — TODO confirm for other model versions.
        return Wrapper.postprocess(preds[0], imgs)

In [3]:
from ultralytics import YOLO
import torch

# Random uint8 image in NHWC layout (1, H, W, 3), used only as a tracing/export
# example. torch.randint's upper bound is exclusive, so 256 covers the full
# 0-255 range; the fixed seed makes the demo outputs reproducible.
torch.manual_seed(0)
image = torch.randint(0, 256, (1, 1080, 810, 3), dtype=torch.uint8)
yolo = YOLO(
    "yolov8m.pt",
    task="detect",
)  # pretrained YOLOv8m weights; `yolo.names` later supplies class labels
model = yolo.model  # the underlying torch module that gets wrapped for export

In [4]:
# End-to-end module (preprocess -> YOLO -> NMS) that is timed and exported below.
wrapped = Wrapper(model=model)

In [5]:
from time import time

from time import perf_counter

# One forward pass on the dummy image. no_grad() skips autograd bookkeeping that
# inference does not need (it would otherwise inflate the measured time), and
# perf_counter() is the monotonic high-resolution clock meant for durations.
with torch.no_grad():
    start = perf_counter()
    result = wrapped(image)
    elapsed = perf_counter() - start

print(elapsed)
result

1.53065824508667


tensor([], size=(0, 6))

In [6]:
# Only the spatial dims of the (1, H, W, 3) input and the number of detections in
# the (N, 6) output are declared dynamic; batch size and channels stay fixed.
# NOTE(review): the TracerWarnings printed by this cell point at get_lengths, which
# suggests the padded canvas size is frozen at trace time (1088 for this example),
# so "height"/"width" may only vary up to that size — confirm before relying on it.
dynamic = {
    "image": {1: "height", 2: "width"},
    "output": {0: "num_boxes"},
}

torch.onnx.export(
    wrapped,
    image,
    "wrapped_model.onnx",
    input_names=["image"],
    output_names=["output"],
    dynamic_axes=dynamic,  # always non-empty, so no None fallback is needed
    opset_version=17,
)

  longest_side = torch.max(torch.tensor([h, w], dtype=torch.short).detach())
  return h, w, resize_value.int().item()
  elif self.dynamic or self.shape != shape:


verbose: False, log level: Level.ERROR



In [9]:
from torchvision.io import read_image, ImageReadMode
from torchvision.utils import draw_bounding_boxes, save_image
import onnxruntime as ort
from pathlib import Path
from typing import Optional
from tqdm import tqdm

class ONNXSession:
    """Run an exported end-to-end ONNX detector over images and save annotated copies.

    Each image is fed to the model as a ``(1, H, W, 3)`` uint8 array; the returned
    ``[x1, y1, x2, y2, score, class]`` rows are drawn onto the image, which is
    written to ``results/<original file name>``.
    """

    def __init__(self, model: str, data: str, names: Optional[dict] = None) -> None:
        """
        Args:
            model: path of the ``.onnx`` file to load.
            data: a single image file, or a directory whose entries are all read
                as images.
            names: class-index -> label mapping. When ``None``, the notebook-level
                ``yolo.names`` is used.
        """
        self.model = model
        self.data = Path(data)
        self.names = names

    @staticmethod
    def _get_coco_labels_for_boxes(result, names=None):
        """Map the class-index column (index 5) of ``result[0]`` to label strings.

        Falls back to the notebook-level ``yolo.names`` when ``names`` is None.
        """
        if names is None:
            names = yolo.names
        return [names[k] for k in result[0][:, 5].astype('int')]

    def run(self, num_examples: Optional[int] = None):
        """Annotate the first ``num_examples`` images (all of them when ``None``).

        Raises:
            FileNotFoundError: if ``data`` is neither an existing file nor directory.
        """
        if self.data.is_file():
            paths = [self.data]
        elif self.data.is_dir():
            paths = sorted(self.data.iterdir())
        else:
            raise FileNotFoundError(f"No such file or directory: '{self.data}'")

        # Build the ONNX Runtime session once, from self.model, outside the
        # per-image loop (loading the model is the expensive part).
        sess = ort.InferenceSession(self.model, providers=["CPUExecutionProvider"])
        input_name = sess.get_inputs()[0].name

        results_dir = Path('results/')
        results_dir.mkdir(parents=True, exist_ok=True)

        for path in tqdm(paths[:num_examples]):
            img = read_image(str(path), ImageReadMode.RGB)  # (3, H, W) uint8
            # The exported model expects a batched NHWC array: (1, H, W, 3).
            batch = img.permute(1, 2, 0).unsqueeze(0).numpy()
            result = sess.run(None, {input_name: batch})

            # Extract labels from class indices and slice out the box coordinates.
            labels = ONNXSession._get_coco_labels_for_boxes(result, self.names)
            boxes = torch.from_numpy(result[0][:, :4])
            # draw_bounding_boxes returns a new tensor, so `img` needs no copy.
            annotated = img
            if boxes.numel() > 0:
                annotated = draw_bounding_boxes(img, boxes, width=4, labels=labels)
            save_image(annotated / 255.0, f'{results_dir}/{path.name}')

In [10]:
# Annotate every COCO128 training image with the exported model; outputs go to results/.
sess = ONNXSession(
    model='wrapped_model.onnx',
    data='data/coco128/images/train2017/',
)
sess.run()

100%|██████████| 128/128 [03:30<00:00,  1.64s/it]
