In [5]:
import torch
import torch.nn.utils.prune as prune
from ultralytics import YOLO

def prune_model(model, amount=0.1):
  for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
      prune.l1_unstructured(module, amount=amount, name='weight')
      prune.remove(module, 'weight')
  return model

model = YOLO('yolov8s.pt')

results = model.val(data="coco8.yaml")
print(f"mAP50-95: {results.box.map}")

torch_model = model.model

print(torch_model)

print("Pruning model...")
pruned_torch_model = prune_model(torch_model, amount=0.05)
print("Model pruned.")

model.model = pruned_torch_model

print("Saving pruned model...")

model.save("yolov8s_trained_pruned.pt")

print("Pruned model saved.")

# Optional: Fine-tune the pruned model
# Uncomment the following lines if you wnat to fine-tune
# results = model.train(data='your_dataset.yaml', epochs=50)
# model.save('yolov8s_trained_pruned_finetuned.pt')

# Evaluate the pruned model (and optionally the fine-tuned model)

model = YOLO('yolov8s_trained_pruned.pt')
results = model.val(data="coco8.yaml")
print(f"mAP50-95: {results.box.map}")

# If you fine-tuned, you can evaluate the fine-tuned model similarly
# fine_tuned_model = YOLO('yolov8s_trained_pruned_finetuned.pt')
# results = fine_tuned_model.val()
# print(f"Fine-tuned mAP50-95: {results.box.map}")

Ultralytics 8.3.5 🚀 Python-3.10.12 torch-2.4.1+cu121 CUDA:0 (Tesla T4, 15102MiB)
YOLOv8s summary (fused): 168 layers, 11,156,544 parameters, 0 gradients, 28.6 GFLOPs


[34m[1mval: [0mScanning /content/datasets/coco8/labels/val.cache... 4 images, 0 backgrounds, 0 corrupt: 100%|██████████| 4/4 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1/1 [00:00<00:00,  5.60it/s]


                   all          4         17       0.87       0.92      0.947      0.719
                person          3         10      0.855       0.59      0.707      0.379
                   dog          1          1      0.886          1      0.995      0.796
                 horse          1          2      0.794          1      0.995        0.8
              elephant          1          2          1      0.931      0.995       0.55
              umbrella          1          1      0.686          1      0.995      0.895
          potted plant          1          1          1          1      0.995      0.895
Speed: 0.7ms preprocess, 16.0ms inference, 0.0ms loss, 4.0ms postprocess per image
Results saved to [1mruns/detect/val5[0m
mAP50-95: 0.7193132380369528
DetectionModel(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(32, 64, kernel_s

[34m[1mval: [0mScanning /content/datasets/coco8/labels/val.cache... 4 images, 0 backgrounds, 0 corrupt: 100%|██████████| 4/4 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1/1 [00:00<00:00,  5.20it/s]


                   all          4         17      0.887      0.897      0.949       0.72
                person          3         10      0.861        0.5      0.716      0.384
                   dog          1          1      0.939          1      0.995      0.796
                 horse          1          2      0.811          1      0.995        0.8
              elephant          1          2          1       0.88      0.995       0.55
              umbrella          1          1      0.711          1      0.995      0.895
          potted plant          1          1          1          1      0.995      0.895
Speed: 0.3ms preprocess, 16.3ms inference, 0.0ms loss, 1.9ms postprocess per image
Results saved to [1mruns/detect/val6[0m
mAP50-95: 0.7201941345228444
