# Quantization Experiments

In [None]:
import mlflow
import mlflow.pyfunc
import torch
from models.classification.resnet_classification import ResNetClassification
from models.detection.resnet_detection import ResNetDetection
from models.segmentation.resnet_segmentation import ResNetSegmentation
from scripts.quantization_methods import apply_quantization


# Send every run recorded below to this MLflow tracking server.
mlflow.set_tracking_uri("http://localhost:5000")

# Load each task's checkpoint once; the same model objects are handed to
# every quantization method in the loop below.
classification_model = ResNetClassification.load_from_checkpoint('path/to/classification_model.ckpt')
detection_model = ResNetDetection.load_from_checkpoint('path/to/detection_model.ckpt')
segmentation_model = ResNetSegmentation.load_from_checkpoint('path/to/segmentation_model.ckpt')

quantization_methods = ['dynamic', 'static', 'quantization_aware']

# Task name (as passed to `evaluate_model`) -> model to quantize.
models_by_task = {
    'classification': classification_model,
    'detection': detection_model,
    'segmentation': segmentation_model,
}

# Task name -> name under which that task's score is logged to MLflow.
metric_name_by_task = {
    'classification': 'classification_accuracy',
    'detection': 'detection_mAP',
    'segmentation': 'segmentation_mIoU',
}

# NOTE(review): `evaluate_model` is neither defined nor imported in the
# visible part of this file; it must already be in scope (e.g. from an
# earlier cell) or this loop raises NameError -- confirm.
# NOTE(review): the base models are reused across methods, which assumes
# `apply_quantization` does not modify its input in place -- TODO confirm.

# One MLflow run per quantization method.
for method in quantization_methods:
    with mlflow.start_run(run_name=method):
        # Record the method before doing any work so that a run which fails
        # during quantization or evaluation can still be identified.
        mlflow.log_param('quantization_method', method)

        # Quantize all three models first, then evaluate them, in task order.
        quantized_models = {
            task: apply_quantization(model, method)
            for task, model in models_by_task.items()
        }
        metrics = {
            metric_name_by_task[task]: evaluate_model(quantized_model, task)
            for task, quantized_model in quantized_models.items()
        }

        # Log the three task scores together.
        mlflow.log_metrics(metrics)

        print(f'Logged metrics for {method} quantization method.')