# Cat detection model
Prunes and quantises YOLOv8, evaluated on the COCO dataset, to create a lightweight cat detection model.

## To do
Head pruning
- ❌ Remove segmentation/pose estimation - not needed for default model
- ❌ Reduce box regression granularity - poor idea for small object detection
- ✅ Remove unused classes
- ✅ Update dataset to reflect pruning
- ❌ Remove P5 + retrain - made performance worse
- ❓ Add P2 + retrain

Whole model
- Prune
- ✅ Quantise


## Configure notebook

In [None]:
import sys
import os

# Put the parent of the current working directory at the front of the import
# path so that modules living one level above the notebook can be imported.
parent_dir = os.path.join(os.getcwd(), os.pardir)
sys.path.insert(0, os.path.abspath(parent_dir))

In [None]:
import copy

import torch
from matplotlib import pyplot as plt
from onnxruntime.quantization import quantize_dynamic, QuantType
from ultralytics import YOLO

import utils

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
# Settings shared by the evaluation, export and prediction cells below.

# image size passed as `imgsz` to val / export / predict
imgsz = 960
# cap on detections per image, passed as `max_det` to val / predict
max_dets = 5
# suffix interpolated into the dataset YAML file names (f"coco{eval}_...")
# NOTE(review): this name shadows the builtin `eval`; it is kept because the
# other cells in this notebook read it under this name.
eval = "4000"
# device chosen by the project helper; the models are moved onto it
device = utils.get_best_device()

## Load and evaluate original model

In [None]:
# Load the original YOLO checkpoint and move its network onto `device`
original_weights = os.path.join("..", "models", "yolov8n.pt")
model_original = YOLO(original_weights)
# result bound to `_`, presumably so the notebook does not echo it
_ = model_original.model.to(device)

In [None]:
# Evaluate the original model on the dataset with no large objects,
# restricted to class indices 0 and 15.
data_original = os.path.join("..", f"coco{eval}_nolarge.yaml")
metrics_original = model_original.val(
    data=data_original, classes=[0, 15], max_det=max_dets, imgsz=imgsz
)

## Structural pruning of detection head

In [None]:
# Copy original model
# A deep copy is used so that the pruning applied to `model_pruned` in the
# following cell does not modify `model_original`.
model_pruned = copy.deepcopy(model_original)
# result bound to `_`, presumably so the notebook does not echo it
_ = model_pruned.model.to(device)

In [None]:
# Structural pruning of unused classes: replace the last conv of every branch
# in the detection head's `cv3` list with a smaller conv that only outputs the
# channels listed in `keep_classes`.

# original class indices to keep; their order defines the new class indices
keep_classes = [0, 15]
new_nc = len(keep_classes)

# extract detection head (last module of the network)
detect = model_pruned.model.model[-1]

# loop over each of the multi-scale heads
for branch in detect.cv3:

    # rebuild the last layer with the same geometry but only `new_nc` output
    # channels, on the same device/dtype as the layer it replaces
    old_conv = branch[-1]
    has_bias = old_conv.bias is not None
    new_conv = torch.nn.Conv2d(
        in_channels=old_conv.in_channels,
        out_channels=new_nc,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        groups=old_conv.groups,
        bias=has_bias,
        device=old_conv.weight.device,
        dtype=old_conv.weight.dtype,
    )

    # copy the kept output channels (rows of the weight, entries of the bias)
    with torch.no_grad():
        new_conv.weight.copy_(old_conv.weight[keep_classes])
        if has_bias:
            new_conv.bias.copy_(old_conv.bias[keep_classes])
    new_conv.requires_grad_(False)

    # insert layer back into head
    branch[-1] = new_conv

# update metadata
# derive the new names from the original mapping so they cannot drift out of
# sync with `keep_classes`
old_names = model_pruned.model.names
model_pruned.model.names = {
    new_idx: old_names[old_idx] for new_idx, old_idx in enumerate(keep_classes)
}
# `no` shrinks by the number of removed classes; must be computed before `nc`
# is overwritten
detect.no = detect.no - detect.nc + new_nc
detect.nc = new_nc

# save model to disk and reload it
# NOTE(review): the save/reload round trip appears to be required for the
# pruned model to behave correctly afterwards; the exact reason is unconfirmed.
pruned_weights = os.path.join("..", "models", "yolov8n_pruned.pt")
model_pruned.save(pruned_weights)
model_pruned = YOLO(pruned_weights)
_ = model_pruned.model.to(device)

In [None]:
# Evaluate the pruned model on the sub-class dataset, where the two kept
# classes are indexed 0 and 1.
data_pruned = os.path.join("..", f"coco{eval}_nolarge_subclass_downsample.yaml")
metrics_pruned = model_pruned.val(
    data=data_pruned, classes=[0, 1], max_det=max_dets, imgsz=imgsz
)

In [None]:
# # Structural pruning of unnecessary large object detection

# # extract detection head
# detect = model_pruned.model.model[-1]

# # remove last layer - corresponding to large object detection
# detect.nl = 2
# detect.cv2 = detect.cv2[: detect.nl]
# detect.cv3 = detect.cv3[: detect.nl]
# detect.f = detect.f[: detect.nl]
# detect.stride = torch.tensor([8, 16, 32][: detect.nl], dtype=torch.float32)

# # save model to disk (also syncs some important thing seemingly)
# model_pruned.save(os.path.join("..", "models", "yolov8n_pruned.pt"))
# model_pruned = YOLO(os.path.join("..", "models", "yolov8n_pruned.pt"))
# _ = model_pruned.model.to(device)

In [None]:
# # Evaluate pruned model
# metrics_pruned = model_pruned.val(
#     data=os.path.join("..", f"coco{eval}_nolarge_subclass_downsample.yaml"),
#     classes=[0, 1],
#     max_det=max_dets,
#     imgsz=imgsz,
# )

In [None]:
# # Retrain the model head

# # reload model to avoid inference mode lockin
# model_pruned = YOLO(os.path.join("..", "models", "yolov8n_pruned.pt"))
# _ = model_pruned.model.to(device)

# # unfreeze head only
# for p in model_pruned.model.parameters():
#     p.requires_grad = False
# for p in model_pruned.model.model[-1].parameters():
#     p.requires_grad = True

# # retrain the model
# model_pruned.train(
#     data=os.path.join("..", "coco_nolarge_subclass_downsample.yaml"),
#     epochs=30,
#     imgsz=imgsz,
#     lr0=1e-4,
#     patience=3,
#     device=device
# )

## Quantise model

In [None]:
# Quantise model: export the pruned model to ONNX, apply dynamic quantisation
# with unsigned 8-bit weights, then load the quantised file for evaluation.

# paths are named once so the file written by `quantize_dynamic` is the same
# file that is loaded afterwards (previously a different file name was loaded)
onnx_path = "../models/yolov8n_pruned.onnx"
quantised_path = "../models/yolov8n_quantised.onnx"

# NOTE(review): `onnx_path` assumes the export is written next to the
# `yolov8n_pruned.pt` checkpoint the model was loaded from — confirm.
model_pruned.export(format="onnx", imgsz=imgsz, opset=13, simplify=True)
quantize_dynamic(
    onnx_path,
    quantised_path,
    weight_type=QuantType.QUInt8,
)
model_quant = YOLO(quantised_path, task="detect")

In [None]:
# Evaluate quantised model
# Stored under its own name so the pruned-model metrics in `metrics_pruned`
# are not overwritten and the two can still be compared.
metrics_quant = model_quant.val(
    data=os.path.join("..", f"coco{eval}_nolarge_subclass_downsample.yaml"),
    classes=[0, 1],
    max_det=max_dets,
    imgsz=imgsz,
)

## Plot predictions made on sample images

In [None]:
# Run the quantised model on the sample images (at most one per panel of the
# 2x3 grid) and display the annotated predictions.
fig, ax_all = plt.subplots(ncols=3, nrows=2, figsize=(15, 6))
axes = ax_all.flatten()
sample_images_dir = os.path.join("..", "sample_images")

# os.listdir returns entries in arbitrary order; sorting makes the figure
# reproducible between runs and machines
image_names = sorted(os.listdir(sample_images_dir))

for ax, img_name in zip(axes, image_names):
    res = model_quant.predict(
        os.path.join(sample_images_dir, img_name), max_det=max_dets, imgsz=imgsz
    )[0]
    # reverse the channel axis before display
    # (presumably BGR -> RGB; confirm against the Results.plot documentation)
    ax.imshow(res.plot()[:, :, ::-1])

# hide the axes on every panel, including any left empty when there are fewer
# images than panels
for ax in axes:
    ax.axis("off")

fig.tight_layout()