# Post Training Pruning and Quantization of YOLOv8s

## Imports

In [1]:
import ultralytics
import torch
import torch.nn.utils.prune as prune
import os
import cv2
import numpy as np
import copy
import time
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, QuantType, CalibrationDataReader, quantize_static, QuantFormat

In [2]:
# Process images for input into models
# need for inference and for calibration during quantization
def preprocess(image):
    img = cv2.resize(image, (640, 640))
    # normalize image
    img = np.array(img).astype(np.float32) / 255.0
    # reorder channels for input to model
    img = np.transpose(img, (2,0,1))
    # add batch dimension
    img = np.expand_dims(img, axis=0)
    return img

## Loading Model

In [3]:
# load models
base_model = ultralytics.YOLO('yolov8s.pt')
prune_unstruct_model = copy.deepcopy(base_model)

## Pruning

In [4]:
# prune weights and biases of convolutional layers
for name, module in prune_unstruct_model.named_modules():
    if 'conv' in name:
        prune.l1_unstructured(module, name='weight', amount=0.15)
        prune.remove(module, 'weight')
        if module.bias is not None:
            prune.l1_unstructured(module, name='bias', amount=0.15)
            prune.remove(module, 'bias')

In [None]:
torch.save(prune_unstruct_model, 'models/pruned.pt')
pruned_model = torch.load('models/pruned.pt')
pruned_model.export(format='onnx')
os.rename('yolov8s.onnx', 'models/pruned.onnx')

# Quantization

In [None]:
# export base model in onnx format
base_model.export(format = 'onnx')
os.rename('yolov8s.onnx', 'models/yolov8s.onnx')

### Dynamic Quantization

In [7]:
# preprocess model for quantization
!python -m onnxruntime.quantization.preprocess --input models/yolov8s.onnx --output models/processed.onnx

In [8]:
# Dynamically quantize model to unsigned int8
processed_model = 'models/processed.onnx'
quant_model = 'models/dynamic_quantized.onnx'
quantize_dynamic(processed_model, quant_model, weight_type=QuantType.QUInt8)

### Static Quantization

In [9]:
# class for getting calibration data for static quantization 
# https://quark.docs.amd.com/release-0.5.0/onnx/user_guide_datareader.html
class DataReader(CalibrationDataReader):
    def __init__(self, image_paths):
        self.image_paths = image_paths
        self.iterator = iter(self.image_paths)

    def get_next(self):
        try:
            image_path = next(self.iterator)
            image = cv2.imread(image_path)
            input_data = preprocess(image)
            return {"images": input_data}
        except StopIteration:
            return None

In [None]:
calibration_set = ['data/calibrate/000000005127.jpg', 'data/calibrate/000000008447.jpg', 'data/calibrate/000000010064.jpg',
                   'data/calibrate/000000011829.jpg', 'data/calibrate/000000016280.jpg', 'data/calibrate/000000021086.jpg',
                   'data/calibrate/000000026680.jpg', 'data/calibrate/000000027726.jpg', 'data/calibrate/000000028308.jpg',
                   'data/calibrate/000000029638.jpg', 'data/calibrate/000000038137.jpg', 'data/calibrate/000000038312.jpg']

calibration_data_reader = DataReader(calibration_set)

In [10]:
# Static quantization of prepared model
quantize_static(processed_model, "models/static_quantized.onnx",
                weight_type=QuantType.QInt8,
                activation_type=QuantType.QInt8,
                calibration_data_reader=calibration_data_reader,
                quant_format=QuantFormat.QDQ, # mixed precision quantization
                # exclude nodes in detect head
                nodes_to_exclude=['/model.22/Concat_3', '/model.22/Split', '/model.22/Sigmoid'
                                 '/model.22/dfl/Reshape', '/model.22/dfl/Transpose', '/model.22/dfl/Softmax', 
                                 '/model.22/dfl/conv/Conv', '/model.22/dfl/Reshape_1', '/model.22/Slice_1',
                                 '/model.22/Slice', '/model.22/Add_1', '/model.22/Sub', '/model.22/Div_1',
                                  '/model.22/Concat_4', '/model.22/Mul_2', '/model.22/Concat_5'],
                per_channel=False,
                reduce_range=True,)

# Testing Inference

### Functions for Inference

In [11]:
# Dict for classes in COCO
classes = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train',
                        7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant',
                        11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog',
                        17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant',
                        21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag',
                        27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis',
                        31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove',
                        36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass',
                        41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple',
                        48: 'sandwich', 49: 'orange', 50: 'broccoli',
                        51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch',
                        58: 'potted plant', 59: 'bed', 60: 'dining table',
                        61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard',
                        67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster',
                        71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors',
                        77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}

# Define color palette for plotting detections
colors = np.random.uniform(0, 255, size=(len(classes), 3))

In [12]:
# Load model and initialize onnx session
def load_model(model_path):
    # run inference on CPU
    session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    model_inputs = session.get_inputs()
    input_shape = model_inputs[0].shape
    input_w, input_h = input_shape[2], input_shape[3]
    return session, input_w, input_h

In [13]:
# Run inference on image (detect objects in provided image)
def detect(session, img_data):
    ort = onnxruntime.OrtValue.ortvalue_from_numpy(img_data)
    results = session.run(["output0"], {"images": ort})
    return results

In [14]:
# draws given box and label/confidence on image 
def draw_boxes(img, box, score, class_id):
    # top left x, top left y, width, height
    x, y, w, h = box
    color = colors[class_id]
    cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
    label = f"{classes[class_id]}: {score:.2f}"
    cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_DUPLEX, 0.5, color, 1)
    return img

In [15]:
# Postprocess results
# Determines highest probability predictions, draws predictions
# Returns image with detections and dict of detections
# https://github.com/ultralytics/ultralytics/blob/e5cb35edfc3bbc9d7d7db8a6042778a751f0e39e/examples/YOLOv8-OpenCV-ONNX-Python/
def postprocess(results, image, input_w, input_h, confidence=0.35, iou=0.5):
    img_h, img_w = image.shape[:2]
    outputs = np.transpose(np.squeeze(results[0]))
    rows = outputs.shape[0]
    boxes, scores, class_ids = [], [], []
    # scales to rescale bounding boxes to image size
    x_scale, y_scale = img_w / input_w, img_h / input_h

    for i in range(rows):
        class_scores = outputs[i][4:]
        max_score = np.amax(class_scores)
        if max_score >= confidence:
            class_id = np.argmax(class_scores)
            # bounding box
            x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
            # rescale bounding box to pixels, reformat from center to top left
            left = int((x - w / 2) * x_scale)
            top = int((y - h / 2) * y_scale)
            width = int(w * x_scale)
            height = int(h * y_scale)

            boxes.append([left, top, width, height])
            scores.append(max_score)
            class_ids.append(class_id)
            
    # non maximum suppression on potential detections
    indices = cv2.dnn.NMSBoxes(boxes, scores, confidence, iou)
    detections = []
    for i in indices:
        box = boxes[i]
        score = scores[i]
        class_id = class_ids[i]
        detection = {
            'class_id': class_ids[i],
            'confidence': scores[i],
            'box': box}
        detections.append(detection)
        # draw detection on image
        image = draw_boxes(image, box, score, class_id)
    return image, detections

In [16]:
# Main inference function
# Reads image, preprocesses image, detects objects, postprocesses detections
def inference(img_path, session, input_w, input_h):
    image = cv2.imread(img_path)
    img_data = preprocess(image)
    results = detect(session, img_data)
    img_out, detections = postprocess(results, image, input_w, input_h)
    return img_out, detections

### Functions for Mean Average Precision Calculation

In [17]:
# Creates dict of labels for each image
def load_labels(label_dir):
    labels = {}
    for file in os.listdir(label_dir):
        if file.endswith('.txt'):
            image_name = file.replace('.txt', '.jpg')
            with open(os.path.join(label_dir, file), 'r') as f:
                boxes = []
                for line in f:
                    parts = line.strip().split()
                    class_id = int(parts[0])
                    box = [int(pixel) for pixel in parts[1:]]
                    boxes.append({'class_id': class_id, 'box': box})
            labels[image_name] = boxes
    return labels

In [18]:
# Reformat detections and labels structure to match input to ODMetrics
def reformat_data(detections, labels):
    formatted_detections = []
    formatted_labels = []
    
    for image_name in detections.keys():
        detection_data = detections[image_name]
        label_data = labels.get(image_name, [])
        
        # Format detections
        formatted_detections.append({
            "boxes": [item["box"] for item in detection_data],
            "labels": [int(item["class_id"]) for item in detection_data],
            "scores": [item["confidence"] for item in detection_data],
        })
        
        # Format labels
        formatted_labels.append({
            "boxes": [item["box"] for item in label_data],
            "labels": [int(item["class_id"]) for item in label_data],
        })
    
    return formatted_detections, formatted_labels

In [19]:
from od_metrics import ODMetrics

# Calculate mean average precision for detections
def calc_map(detections, label_dir):
    # Load labels for all images
    labels = load_labels(label_dir)

    f_detections, f_labels = reformat_data(detections, labels)
    metrics = ODMetrics()
    output = metrics.compute(f_labels, f_detections)

    return output['mAP@[.5 | all | 100]']

In [20]:
# Loads model, runs inference and calculates statistics

def evaluate_model(model_path, image_path, label_path, out_path):
    # Load model
    session, input_w, input_h = load_model(model_path)
    
    # Setup for outputs
    model_name = os.path.splitext(os.path.basename(model_path))[0]
    model_output_dir = os.path.join(out_path, model_name)
    os.makedirs(model_output_dir, exist_ok=True)
    
    # Get image files
    image_files = [os.path.join(image_path, f) for f in os.listdir(image_path) if f.endswith(('.jpg', '.png'))]
    
    total_time = 0
    detections = {}

    # Run inference on each image
    for img in image_files:
        start_time = time.perf_counter()
        img_out, dets = inference(img, session, input_w, input_h)
        end_time = time.perf_counter()
        
        total_time += (end_time - start_time)
        detections[os.path.basename(img)] = dets
        
        # Save image
        output_file_path = os.path.join(model_output_dir, os.path.basename(img))
        cv2.imwrite(output_file_path, img_out)

    # Mean average precision
    mean_ap = calc_map(detections, label_path)
    
    # Average inference time per image
    avg_inference_time = total_time / len(image_files) if image_files else 0
    
    return mean_ap, total_time, avg_inference_time

## Testing Models

In [21]:
image_path = 'data/images'
label_path = 'data/labels'
out_path = 'outputs'
os.makedirs(out_path, exist_ok=True)

In [22]:
base_model = 'models/yolov8s.onnx'
pruned_model = 'models/pruned.onnx'
quantized_dynamic_model = 'models/dynamic_quantized.onnx'
quantized_static_model = 'models/static_quantized.onnx'

### Base Model

In [23]:
mean_ap, total_time, avg_time = evaluate_model(base_model, image_path, label_path, out_path)
print(f"Mean Average Precision (mAP): {mean_ap}")
print(f"Total Inference Time: {total_time} seconds")
print(f"Average Inference Time: {avg_time} seconds")

Mean Average Precision (mAP): 0.39466801486037867
Total Inference Time: 16.473916871938854 seconds
Average Inference Time: 0.13074537199951472 seconds


### Pruned Model

In [24]:
mean_ap, total_time, avg_time = evaluate_model(pruned_model, image_path, label_path, out_path)
print(f"Mean Average Precision (mAP): {mean_ap}")
print(f"Total Inference Time: {total_time} seconds")
print(f"Average Inference Time: {avg_time} seconds")

Mean Average Precision (mAP): 0.2811851637021608
Total Inference Time: 16.180145043879747 seconds
Average Inference Time: 0.12841384955460117 seconds


### Dynamic Quantized Model

In [25]:
mean_ap, total_time, avg_time = evaluate_model(quantized_dynamic_model, image_path, label_path, out_path)
print(f"Mean Average Precision (mAP): {mean_ap}")
print(f"Total Inference Time: {total_time} seconds")
print(f"Average Inference Time: {avg_time} seconds")

Mean Average Precision (mAP): 0.39286253774346397
Total Inference Time: 23.82644220686052 seconds
Average Inference Time: 0.1890987476734962 seconds


### Static Quantized Model

In [26]:
mean_ap, total_time, avg_time = evaluate_model(quantized_static_model, image_path, label_path, out_path)
print(f"Mean Average Precision (mAP): {mean_ap}")
print(f"Total Inference Time: {total_time} seconds")
print(f"Average Inference Time: {avg_time} seconds")

Mean Average Precision (mAP): 0.32395280241791236
Total Inference Time: 11.88469374878332 seconds
Average Inference Time: 0.09432296626018508 seconds
