In [5]:
import os
import sys

# Make modules in the parent of the current working directory importable.
# The path is normalized and only appended once, so re-running this cell does
# not keep growing sys.path with duplicate entries.
PARENT_DIR = os.path.abspath(os.pardir)
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)

In [6]:
import torch

# Run configuration.
# device: where the model is placed when it is loaded.
# weights: the checkpoint file to load.
device = torch.device("cpu")
weights = "yolov8l-seg.pt"

In [8]:
import os
import sys
import random
import argparse
import warnings
import onnx
import torch
import torch.nn as nn
from copy import deepcopy

from ultralytics import YOLO
from ultralytics.utils.torch_utils import select_device
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder


class WarpModel(nn.Module):
    """Wrap an Ultralytics YOLO segmentation checkpoint for ONNX export.

    The checkpoint's underlying model is deep-copied, moved to ``device``,
    frozen (``requires_grad=False``), put in eval mode, cast to float and fused.
    ``Detect``/``RTDETRDecoder`` heads get ``dynamic=False``, ``export=True`` and
    ``format="onnx"``. ``C2f`` blocks are switched to their ``forward_split``.

    Args:
        weights: Checkpoint path or name passed to ``YOLO``.
        nc: Number of class-score channels in the head output.
        device: Device to move the wrapped model to (str or ``torch.device``).
            Defaults to CPU. It is passed explicitly rather than read from a
            notebook-global variable.
    """

    def __init__(self, weights: str = "yolov8l-seg.pt", nc: int = 80, device=None):
        super().__init__()
        self.nc = nc
        device = torch.device("cpu") if device is None else torch.device(device)

        # Keep the YOLO wrapper local so it is never registered as a submodule;
        # only its inner model is kept.
        yolo = YOLO(weights)
        self.model = deepcopy(yolo.model).to(device)
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()
        self.model.float()
        self.model = self.model.fuse()

        for m in self.model.modules():
            if isinstance(m, (Detect, RTDETRDecoder)):
                # Put the head in its export configuration; presumably this
                # makes it return (preds, protos) as unpacked in forward() — verify.
                m.dynamic = False
                m.export = True
                m.format = "onnx"
            elif isinstance(m, C2f):
                # Use the split-based forward; presumably more export-friendly — verify.
                m.forward = m.forward_split

    def forward(self, x):
        """Run the wrapped model and rearrange its output for export.

        Args:
            x: Input batch, passed unchanged to the wrapped model.

        Returns:
            Tuple ``(dets, protos)``:

            - ``dets`` is the first model output with its last two axes swapped.
              Along the last axis it holds the 4 box channels, then one constant
              channel of ones, then the ``nc`` class channels, then all remaining
              channels (presumably mask coefficients).
            - ``protos`` is the second model output, returned unchanged.

        Raises:
            ValueError: If the model output has fewer than ``4 + nc`` channels.
        """
        preds, protos = self.model(x)
        # Swap axes 1 and 2 so the channels end up on the last axis.
        preds = preds.permute(0, 2, 1)

        n_channels = preds.shape[-1]
        if n_channels < 4 + self.nc:
            raise ValueError(
                f"expected at least {4 + self.nc} prediction channels "
                f"(4 box + {self.nc} class), got {n_channels}"
            )

        boxes = preds[..., :4]
        classes = preds[..., 4 : 4 + self.nc]
        masks = preds[..., 4 + self.nc :]
        # Constant column inserted after the boxes; same dtype/device as preds.
        # NOTE(review): presumably a placeholder objectness score expected downstream — confirm.
        ones = boxes.new_ones((*boxes.shape[:2], 1))
        return torch.cat((boxes, ones, classes, masks), dim=2), protos


NUM_CLASSES = 80  # number of class-score channels in the checkpoint's head

# Build the wrapper from the configured checkpoint rather than re-typing its name.
model = WarpModel(weights, NUM_CLASSES)
with torch.no_grad():
    # Fixed-seed dummy batch: reproducible on re-run without touching the global RNG.
    dummy = torch.randn((4, 3, 640, 640), generator=torch.Generator().manual_seed(0))
    out = model(dummy)

# Expected layout: 4 box + 1 constant + NUM_CLASSES class channels, followed by
# as many channels as protos has along dim 1 (117 = 4 + 1 + 80 + 32 for this checkpoint).
assert out[0].shape[-1] == 4 + 1 + NUM_CLASSES + out[1].shape[1]
out[0].shape, out[1].shape


YOLOv8l-seg summary (fused): 295 layers, 45973568 parameters, 0 gradients, 220.5 GFLOPs


(torch.Size([4, 8400, 117]), torch.Size([4, 32, 160, 160]))