In [17]:
from ultralytics import YOLO
from onnxruntime.quantization import quantize_dynamic, QuantType
import onnx
import onnxsim
import os

In [None]:
def create_yolo11_cls_model(variant='n'):
    """Build a YOLO model from the YOLO11 classification checkpoint for *variant*.

    Args:
        variant: Size suffix inserted into the checkpoint name
            ``yolo11<variant>-cls.pt`` (e.g. ``'n'`` or ``'s'``).

    Returns:
        The ``YOLO`` object constructed from that checkpoint name.
    """
    checkpoint_name = f'yolo11{variant}-cls.pt'
    yolo_model = YOLO(checkpoint_name)
    return yolo_model

def train_yolo_model(model, data_path='.', epochs=10, imgsz=256, batch=32):
    """Run ``model.train`` with the given settings and hand back model and results.

    Args:
        model: Object exposing a ``train(**kwargs)`` method (a ``YOLO`` model here).
        data_path: Dataset location passed as ``data``.
        epochs: Number of training epochs.
        imgsz: Training image size.
        batch: Batch size.

    Returns:
        A ``(model, results)`` tuple: the same ``model`` object that was passed
        in, plus whatever ``model.train`` returned. Callers must unpack it.
    """
    train_kwargs = dict(
        data=data_path,
        epochs=epochs,
        imgsz=imgsz,
        batch=batch,
        verbose=False,  # keep notebook output quiet during training
    )
    training_results = model.train(**train_kwargs)
    return model, training_results

def simplify_onnx_model(input_path, output_path):
    """Simplify the ONNX model at ``input_path`` with onnx-simplifier and save it.

    Args:
        input_path: Path to the source ``.onnx`` file.
        output_path: Destination path for the simplified model.

    Raises:
        RuntimeError: If onnxsim reports that the simplified model failed its
            validation check. Nothing is written to ``output_path`` in that case.
    """
    model = onnx.load(input_path)
    model_simp, check = onnxsim.simplify(model)
    # onnxsim.simplify returns a validation flag alongside the model. Ignoring
    # it would silently save an unvalidated graph and feed it to quantization.
    if not check:
        raise RuntimeError(
            f"onnxsim could not validate the simplified model for {input_path!r}"
        )
    onnx.save(model_simp, output_path)

def dynamic_quantization(input_model_path, output_model_path):
    """Apply onnxruntime dynamic quantization to an ONNX model file.

    Weights are quantized to ``QuantType.QUInt8`` per channel, with
    ``reduce_range`` enabled (see the onnxruntime ``quantize_dynamic`` docs).

    Args:
        input_model_path: Path to the float ONNX model to quantize.
        output_model_path: Path where the quantized model is written.
    """
    quant_settings = {
        'weight_type': QuantType.QUInt8,
        'per_channel': True,
        'reduce_range': True,
    }
    quantize_dynamic(
        model_input=input_model_path,
        model_output=output_model_path,
        **quant_settings,
    )
# Instantiate the nano ('n') and small ('s') YOLO11 classification models,
# in that order, for the training cells below.
yolo11_n, yolo11_s = (create_yolo11_cls_model(size) for size in ('n', 's'))

In [None]:
# train_yolo_model returns (model, results). Unpack it so yolo11_n stays the
# model instead of being overwritten by the tuple.
yolo11_n, yolo11_n_results = train_yolo_model(yolo11_n)

In [None]:
# train_yolo_model returns (model, results). Unpack it so yolo11_s stays the
# model instead of being overwritten by the tuple.
yolo11_s, yolo11_s_results = train_yolo_model(yolo11_s)

In [None]:
import shutil

models_dir = "../models/classification"
os.makedirs(models_dir, exist_ok=True)

# NOTE(review): machine-specific absolute path to an earlier training run;
# confirm it is the run you intend to ship (e.g. one of the models trained above).
best_weights_path = '/home/semyon/runs/classify/train5/weights/best.pt'
model = YOLO(best_weights_path)

# NOTE(review): presumably model.val() evaluates the validation split unless
# `split` is passed; confirm the "Test" label below matches the split used.
metrics = model.val(data='.')
print(f"Test accuracy: {metrics.top1:.3f}")

model_path = os.path.join(models_dir, "model.onnx")
simplified_path = os.path.join(models_dir, "model_simplified.onnx")
quantized_path = os.path.join(models_dir, "model_dynamic_quant.onnx")

# model.onnx was previously never produced, so simplification failed on a fresh
# run. Export the checkpoint to ONNX (export returns the written file's path)
# and copy it into models_dir under the name the pipeline expects.
exported_path = model.export(format="onnx")
shutil.copy(exported_path, model_path)

simplify_onnx_model(model_path, simplified_path)
dynamic_quantization(simplified_path, quantized_path)