9 changes: 6 additions & 3 deletions integrations/ultralytics/README.md
@@ -119,9 +119,12 @@ cp sparseml/integrations/ultralytics/deepsparse/*.py yolov5
cd yolov5

# install deepsparse and server dependencies
pip install deepsparse flask flask-cors
pip install deepsparse sparseml flask flask-cors
```

Note: on new Ubuntu systems, installing `cv2` may require running
`sudo apt-get update && sudo apt-get install -y python3-opencv`.


### Benchmarking
`benchmark.py` is a script for benchmarking sparsified and quantized YOLOv3
@@ -130,8 +133,8 @@ performance with DeepSparse. For a full list of options run `python benchmarkin
To run a benchmark run:
```bash
python benchmark.py
zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_90 \
--batch-size 32 \
zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 \
--batch-size 1 \
--quantized-inputs
```
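
This PR also adds `--num-cores` and `--num-sockets` options to `benchmark.py` (see the benchmark.py diff below). A hypothetical invocation against a local ONNX file — the file path and the core/socket counts here are placeholders, not values taken from this PR:
```bash
python benchmark.py \
    /path/to/yolov3-pruned_quant.onnx \
    -e deepsparse \
    --batch-size 1 \
    --num-cores 4 \
    --num-sockets 1 \
    --quantized-inputs
```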

7 changes: 5 additions & 2 deletions integrations/ultralytics/deepsparse/SERVER.md
@@ -49,11 +49,14 @@ pip install deepsparse sparseml flask flask-cors

### Server

First, start up the host `server.py` with your model of choice.
First, start up the host `server.py` with your model of choice; SparseZoo stubs are
also supported.

Example command:
```bash
python server.py ~/models/yolov3-pruned_quant.onnx
python server.py \
zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 \
--quantized-inputs
```

You can leave that running as a detached process or in a spare terminal.
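
For example, one minimal way to keep it running detached — the `nohup` wrapper and the `server.log` file name are just one possible setup, not part of this PR:
```bash
# run the server in the background; stdout/stderr go to server.log
nohup python server.py \
    zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 \
    --quantized-inputs > server.log 2>&1 &
```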
77 changes: 60 additions & 17 deletions integrations/ultralytics/deepsparse/benchmark.py
@@ -21,8 +21,9 @@
usage: benchmark.py [-h] [-e {deepsparse,onnxruntime,torch}]
[--data-path DATA_PATH]
[--image-shape IMAGE_SHAPE [IMAGE_SHAPE ...]]
[-b BATCH_SIZE] [-c NUM_CORES] [-i NUM_ITERATIONS]
[-w NUM_WARMUP_ITERATIONS] [-q] [--fp16] [--device DEVICE]
[-b BATCH_SIZE] [-c NUM_CORES] [-s NUM_SOCKETS]
[-i NUM_ITERATIONS] [-w NUM_WARMUP_ITERATIONS] [-q]
[--fp16] [--device DEVICE]
model_filepath

Benchmark sparsified YOLOv3 models
@@ -53,7 +54,13 @@
The batch size to run the benchmark for
-c NUM_CORES, --num-cores NUM_CORES
The number of physical cores to run the benchmark on,
defaults to all physical cores available on the system
defaults to None where it uses all physical cores
available on the system. For DeepSparse benchmarks,
this value is the number of cores per socket
-s NUM_SOCKETS, --num-sockets NUM_SOCKETS
For DeepSparse benchmarks only. The number of sockets
to run the benchmark on. Defaults to None where it
uses all sockets available on the system
-i NUM_ITERATIONS, --num-iterations NUM_ITERATIONS
The number of iterations the benchmark will be run for
-w NUM_WARMUP_ITERATIONS, --num-warmup-iterations NUM_WARMUP_ITERATIONS
@@ -110,14 +117,17 @@

from deepsparse import compile_model, cpu
from deepsparse.benchmark import BenchmarkResults
from deepsparse_utils import load_image, postprocess_nms, pre_nms_postprocess
from deepsparse_utils import (
YoloPostprocessor,
load_image,
modify_yolo_onnx_input_shape,
postprocess_nms,
)
from sparseml.onnx.utils import override_model_batch_size
from sparsezoo.models.detection import yolo_v3 as zoo_yolo_v3
from sparsezoo.utils import load_numpy_list


CORES_PER_SOCKET, AVX_TYPE, _ = cpu.cpu_details()

DEEPSPARSE_ENGINE = "deepsparse"
ORT_ENGINE = "onnxruntime"
TORCH_ENGINE = "torch"
@@ -180,10 +190,22 @@ def parse_args():
"-c",
"--num-cores",
type=int,
default=CORES_PER_SOCKET,
default=None,
help=(
"The number of physical cores to run the benchmark on, "
"defaults to all physical cores available on the system"
"defaults to None where it uses all physical cores available on the system. "
"For DeepSparse benchmarks, this value is the number of cores per socket"
),
)
parser.add_argument(
"-s",
"--num-sockets",
type=int,
default=None,
help=(
"For DeepSparse benchmarks only. The number of physical cores to run the "
"benchmark on. Defaults to None where is uses all sockets available on the "
"system"
),
)
parser.add_argument(
@@ -227,7 +249,6 @@ def parse_args():
)

args = parser.parse_args()

if args.engine == TORCH_ENGINE and args.device is None:
args.device = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -273,30 +294,46 @@ def _load_model(args) -> Any:
raise ValueError(f"half precision is not supported for {args.engine}")
if args.quantized_inputs and args.engine == TORCH_ENGINE:
raise ValueError(f"quantized inputs not supported for {args.engine}")
if args.num_cores != CORES_PER_SOCKET and args.engine == TORCH_ENGINE:
if args.num_cores is not None and args.engine == TORCH_ENGINE:
raise ValueError(
f"overriding default num_cores not supported for {args.engine}"
)
if (
args.num_cores != CORES_PER_SOCKET
args.num_cores is not None
and args.engine == ORT_ENGINE
and onnxruntime.__version__ < "1.7"
):
print(
raise ValueError(
"overriding default num_cores not supported for onnxruntime < 1.7.0. "
"If using an older build with OpenMP, try setting the OMP_NUM_THREADS "
"environment variable"
)
if args.num_sockets is not None and args.engine != DEEPSPARSE_ENGINE:
raise ValueError(f"Overriding num_sockets is not supported for {args.engine}")

# scale static ONNX graph to desired image shape
if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]:
args.model_filepath, _ = modify_yolo_onnx_input_shape(
args.model_filepath, args.image_shape
)

# load model
if args.engine == DEEPSPARSE_ENGINE:
print(f"Compiling deepsparse model for {args.model_filepath}")
model = compile_model(args.model_filepath, args.batch_size, args.num_cores)
model = compile_model(
args.model_filepath, args.batch_size, args.num_cores, args.num_sockets
)
if args.quantized_inputs and not model.cpu_vnni:
print(
"WARNING: VNNI instructions not detected, "
"quantization speedup not well supported"
)
elif args.engine == ORT_ENGINE:
print(f"loading onnxruntime model for {args.model_filepath}")

sess_options = onnxruntime.SessionOptions()
sess_options.intra_op_num_threads = args.num_cores
if args.num_cores is not None:
sess_options.intra_op_num_threads = args.num_cores
sess_options.log_severity_level = 3
sess_options.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -376,7 +413,7 @@ def _run_model(
def benchmark_yolo(args):
model = _load_model(args)
print("Loading dataset")
dataset = _load_images(args.data_path, args.image_shape)
dataset = _load_images(args.data_path, tuple(args.image_shape))
total_iterations = args.num_iterations + args.num_warmup_iterations
data_loader = _iter_batches(dataset, args.batch_size, total_iterations)

@@ -388,6 +425,12 @@ def benchmark_yolo(args):
flush=True,
)

postprocessor = (
YoloPostprocessor(args.image_shape)
if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]
else None
)

results = BenchmarkResults()
progress_bar = tqdm(total=args.num_iterations)

@@ -403,8 +446,8 @@
outputs = _run_model(args, model, batch)

# post-processing
if args.engine != TORCH_ENGINE:
outputs = pre_nms_postprocess(outputs)
if postprocessor:
outputs = postprocessor.pre_nms_postprocess(outputs)

# NMS
outputs = postprocess_nms(outputs)
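
For orientation, the refactored pieces above compose as follows. This is a minimal sketch, assuming the deepsparse scripts have been copied into the yolov5 directory (as the README describes), with a placeholder model path, random input data, and a direct `Engine.run` call standing in for the script's `_run_model` helper:

```python
import numpy
from deepsparse import compile_model

from deepsparse_utils import (
    YoloPostprocessor,
    modify_yolo_onnx_input_shape,
    postprocess_nms,
)

image_shape = (640, 640)

# rewrite the static ONNX graph for the desired input size (placeholder path);
# keep tmp_file referenced while the reshaped model is in use
model_path, tmp_file = modify_yolo_onnx_input_shape(
    "/path/to/yolov3-pruned_quant.onnx", image_shape
)

# batch_size=1; num_cores=None and num_sockets=None use everything available,
# mirroring the compile_model call in _load_model above
engine = compile_model(model_path, 1, None, None)

# random fp32 batch in NCHW layout, purely to illustrate shapes
batch = numpy.random.rand(1, 3, *image_shape).astype(numpy.float32)
outputs = engine.run([batch])  # assumed DeepSparse Engine API; benchmark.py wraps this

# anchor-grid decoding followed by NMS, mirroring the benchmark_yolo loop
postprocessor = YoloPostprocessor(image_shape)
boxes = postprocess_nms(postprocessor.pre_nms_postprocess(outputs))
print(f"detections for image 0: {boxes[0].shape[0]}")
```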
136 changes: 107 additions & 29 deletions integrations/ultralytics/deepsparse/deepsparse_utils.py
@@ -20,12 +20,17 @@
"""


from typing import List, Tuple, Union
from tempfile import NamedTemporaryFile
from typing import List, Optional, Tuple, Union

import cv2
import numpy
import onnx
import torch

from sparseml.onnx.utils import get_tensor_dim_shape, set_tensor_dim_shape
from sparsezoo import Zoo

# ultralytics/yolov5 imports
from utils.general import non_max_suppression

@@ -37,22 +42,13 @@
]


def _get_grid(size: int) -> torch.Tensor:
# adapted from yolov5.yolo.Detect._make_grid
coords_y, coords_x = torch.meshgrid([torch.arange(size), torch.arange(size)])
grid = torch.stack((coords_x, coords_y), 2)
return grid.view(1, 1, size, size, 2).float()


# Yolo V3 specific variables
_YOLO_V3_ANCHORS = [
torch.Tensor([[10, 13], [16, 30], [33, 23]]),
torch.Tensor([[30, 61], [62, 45], [59, 119]]),
torch.Tensor([[116, 90], [156, 198], [373, 326]]),
]
_YOLO_V3_ANCHOR_GRIDS = [t.clone().view(1, -1, 1, 1, 2) for t in _YOLO_V3_ANCHORS]
_YOLO_V3_OUTPUT_SHAPES = [80, 40, 20]
_YOLO_V3_GRIDS = [_get_grid(grid_size) for grid_size in _YOLO_V3_OUTPUT_SHAPES]


def load_image(
@@ -70,28 +66,53 @@ def load_image(
return img


def pre_nms_postprocess(outputs: List[numpy.ndarray]) -> torch.Tensor:
class YoloPostprocessor:
"""
:param outputs: raw outputs of a YOLOv3 model before anchor grid processing
:return: post-processed model outputs without NMS.
"""
# postprocess and transform raw outputs into single torch tensor
processed_outputs = []
for idx, pred in enumerate(outputs):
pred = torch.from_numpy(pred)
pred = pred.sigmoid()
Class for performing postprocessing of YOLOv3 model predictions

# get grid and stride
grid = _YOLO_V3_GRIDS[idx]
anchor_grid = _YOLO_V3_ANCHOR_GRIDS[idx]
stride = 640 / _YOLO_V3_OUTPUT_SHAPES[idx]
:param image_size: size of input image to model. Used to calculate stride based on
output shapes
"""

# decode xywh box values
pred[..., 0:2] = (pred[..., 0:2] * 2.0 - 0.5 + grid) * stride
pred[..., 2:4] = (pred[..., 2:4] * 2) ** 2 * anchor_grid
# flatten anchor and grid dimensions -> (bs, num_predictions, num_classes + 5)
processed_outputs.append(pred.view(pred.size(0), -1, pred.size(-1)))
return torch.cat(processed_outputs, 1)
def __init__(self, image_size: Tuple[int]):
self._image_size = image_size
self._grids = {} # Dict[Tuple[int], torch.Tensor]

def pre_nms_postprocess(self, outputs: List[numpy.ndarray]) -> torch.Tensor:
"""
:param outputs: raw outputs of a YOLOv3 model before anchor grid processing
:return: post-processed model outputs without NMS.
"""
# postprocess and transform raw outputs into single torch tensor
processed_outputs = []
for idx, pred in enumerate(outputs):
pred = torch.from_numpy(pred)
pred = pred.sigmoid()

# get grid and stride
grid_shape = pred.shape[2:4]
grid = self._get_grid(grid_shape)
anchor_grid = _YOLO_V3_ANCHOR_GRIDS[idx]
stride = self._image_size[0] / grid_shape[0]

# decode xywh box values
pred[..., 0:2] = (pred[..., 0:2] * 2.0 - 0.5 + grid) * stride
pred[..., 2:4] = (pred[..., 2:4] * 2) ** 2 * anchor_grid
# flatten anchor and grid dimensions -> (bs, num_predictions, num_classes + 5)
processed_outputs.append(pred.view(pred.size(0), -1, pred.size(-1)))
return torch.cat(processed_outputs, 1)

def _get_grid(self, grid_shape: Tuple[int]) -> torch.Tensor:
if grid_shape not in self._grids:
# adapted from yolov5.yolo.Detect._make_grid
coords_y, coords_x = torch.meshgrid(
[torch.arange(grid_shape[0]), torch.arange(grid_shape[1])]
)
grid = torch.stack((coords_x, coords_y), 2)
self._grids[grid_shape] = grid.view(
1, 1, grid_shape[0], grid_shape[1], 2
).float()
return self._grids[grid_shape]


def postprocess_nms(outputs: torch.Tensor) -> List[numpy.ndarray]:
@@ -102,3 +123,60 @@ def postprocess_nms(outputs: torch.Tensor) -> List[numpy.ndarray]:
# run nms in PyTorch, only post-process first output
nms_outputs = non_max_suppression(outputs)
return [output.cpu().numpy() for output in nms_outputs]


def modify_yolo_onnx_input_shape(
model_path: str, image_shape: Tuple[int]
) -> Tuple[str, Optional[NamedTemporaryFile]]:
"""
Creates a new YOLOv3 ONNX model from the given path that accepts the given input
shape. If the given model already has the given input shape no modifications are
made. Uses a tempfile to store the modified model file.

:param model_path: file path to YOLOv3 ONNX model or SparseZoo stub of the model
to be loaded
:param image_shape: 2-tuple of the image shape to resize this yolo model to
:return: filepath to an onnx model reshaped to the given input shape; will be the
original path if the shape is the same. Additionally returns the
NamedTemporaryFile for managing the scope of the object for file deletion
"""
original_model_path = model_path
if model_path.startswith("zoo:"):
# load SparseZoo Model from stub
model = Zoo.load_model_from_stub(model_path)
model_path = model.onnx_file.downloaded_path()
print(f"Downloaded {original_model_path} to {model_path}")

model = onnx.load(model_path)
model_input = model.graph.input[0]

initial_x = get_tensor_dim_shape(model_input, 2)
initial_y = get_tensor_dim_shape(model_input, 3)

if not (isinstance(initial_x, int) and isinstance(initial_y, int)):
return model_path, None # model graph does not have static integer input shape

if (initial_x, initial_y) == tuple(image_shape):
return model_path, None # no shape modification needed

scale_x = initial_x / image_shape[0]
scale_y = initial_y / image_shape[1]
set_tensor_dim_shape(model_input, 2, image_shape[0])
set_tensor_dim_shape(model_input, 3, image_shape[1])

for model_output in model.graph.output:
output_x = get_tensor_dim_shape(model_output, 2)
output_y = get_tensor_dim_shape(model_output, 3)
set_tensor_dim_shape(model_output, 2, int(output_x / scale_x))
set_tensor_dim_shape(model_output, 3, int(output_y / scale_y))

tmp_file = NamedTemporaryFile() # file will be deleted after program exit
onnx.save(model, tmp_file.name)

print(
f"Overwriting original model shape {(initial_x, initial_y)} to {image_shape}\n"
f"Original model path: {original_model_path}, new temporary model saved to "
f"{tmp_file.name}"
)

return tmp_file.name, tmp_file
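
One usage caveat worth noting: the reshaped model only lives as long as the returned `NamedTemporaryFile`, so keep that handle referenced. A short sketch reusing the SparseZoo stub from the README with an arbitrary 416x416 target shape (again assuming the yolov5 working directory so the `deepsparse_utils` import resolves):

```python
from deepsparse_utils import modify_yolo_onnx_input_shape

# resolve the SparseZoo stub and rewrite the ONNX graph for 416x416 inputs
stub = "zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94"
model_path, tmp_file = modify_yolo_onnx_input_shape(stub, (416, 416))

# hold on to tmp_file while model_path is in use -- the underlying file is
# deleted once the NamedTemporaryFile object is cleaned up at program exit
print(f"reshaped model written to {model_path}")
```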