In [13]:
import logging

import dvc.api
import mlflow
from PIL import Image
from src.data.image_preprocessing import crop_image
from src.features.dataset import get_dataset
from src.features.dataset_generator import ImageDatasetType
from src.features.postprocessing import post_process_plate
from src.models.metrics import lev_dist
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from src.models.fetch_model import fetch_model
import numpy as np

In [18]:
# Registered bbox model to evaluate: registry name + version (resolved through
# the MLflow model registry by fetch_model).
model_name = "bigger_model_with_1_batch_norm"
model_version = "1"

# The same Hugging Face checkpoint supplies both the TrOCR processor and model.
tr_ocr_processor = tr_ocr_model = "microsoft/trocr-small-printed"

model = fetch_model(model_name=model_name, model_version=model_version)

def evaluate_bbox_detection():
    """Score the bbox-regression model on the test split and log it to MLflow.

    Reads the notebook globals ``model`` and ``test_set_bbox``. ``model.evaluate``
    is expected to return exactly two values, the loss and the RMSE metric
    (presumably a Keras model compiled with a single RMSE metric -- confirm).
    """
    evaluation = model.evaluate(test_set_bbox)
    loss, root_mean_squared_error = evaluation

    bbox_metrics = dict(
        loss=loss,
        root_mean_squared_error=root_mean_squared_error,
    )
    mlflow.log_metrics(bbox_metrics)


def evaluate_ocr():
    """Evaluate the end-to-end plate reader (bbox crop + TrOCR) on the test split.

    For every test sample, the bbox predicted by ``model`` is used to crop the
    image, TrOCR transcribes the crop, and the transcription is compared with
    the ground-truth plate string both as-is and after ``post_process_plate``.

    Reads the notebook globals ``model``, ``test_set_bbox``, ``test_set_plates``,
    ``tr_ocr_processor`` and ``tr_ocr_model``.

    Returns:
        dict: ``accuracy`` / ``accuracy_post_processed`` (fraction of exact
        matches) and ``lev_dist`` / ``lev_dist_post_processed`` (mean of
        ``lev_dist`` over the scored samples). The same dict is printed and
        logged to the active MLflow run.

    Raises:
        ValueError: if no test sample was scored (instead of dividing by zero
        and logging meaningless metrics).
    """
    transformer_processor = TrOCRProcessor.from_pretrained(tr_ocr_processor)
    transformer_model = VisionEncoderDecoderModel.from_pretrained(tr_ocr_model)
    bboxes = model.predict(test_set_bbox, batch_size=16)

    n_exact, n_exact_post_processed = 0, 0
    total_lev_dist, total_lev_dist_post_processed = 0, 0
    # zip() silently stops at the shorter input, so count the samples that are
    # actually scored instead of assuming there are len(bboxes) of them.
    n_samples, n_crop_failures = 0, 0

    for bbox, sample in zip(bboxes, test_set_plates):
        # Assumes each sample is an (images, plates) batch of size 1
        # (test_set_plates is built with batch_size=1), hence the [0].
        plate = sample[1][0].numpy().decode()
        image = sample[0][0].numpy().astype(np.uint8)

        try:
            image = crop_image(image, bbox)
        except Exception as exc:
            # Best effort: fall back to the uncropped image. Catching Exception
            # (not a bare except) lets KeyboardInterrupt/SystemExit through.
            n_crop_failures += 1
            logging.debug("crop_image failed on sample %d: %r", n_samples, exc)

        pixel_values = transformer_processor(image, return_tensors="pt").pixel_values
        generated_ids = transformer_model.generate(pixel_values)
        generated_text = transformer_processor.batch_decode(
            generated_ids, skip_special_tokens=True)[0]

        n_exact += int(generated_text == plate)
        total_lev_dist += lev_dist(generated_text, plate)

        post_processed_text = post_process_plate(generated_text)

        n_exact_post_processed += int(post_processed_text == plate)
        total_lev_dist_post_processed += lev_dist(post_processed_text, plate)

        n_samples += 1

    if n_samples == 0:
        raise ValueError("evaluate_ocr: the test set yielded no samples to score")
    if n_samples != len(bboxes):
        logging.warning(
            "evaluate_ocr: %d bbox predictions but %d plate samples were scored",
            len(bboxes), n_samples)
    if n_crop_failures:
        logging.warning(
            "evaluate_ocr: crop_image failed on %d/%d samples; the full image was used",
            n_crop_failures, n_samples)

    metrics = {
        "accuracy": n_exact / n_samples,
        "accuracy_post_processed": n_exact_post_processed / n_samples,
        "lev_dist": total_lev_dist / n_samples,
        "lev_dist_post_processed": total_lev_dist_post_processed / n_samples,
    }
    print(metrics)
    mlflow.log_metrics(metrics)
    return metrics

# Both test splits are read unshuffled with batch_size=1. evaluate_ocr zips the
# bbox predictions made on test_set_bbox with the samples of test_set_plates,
# so presumably both generators must yield the same images in the same order.
test_set_bbox = get_dataset(
    "test",
    dataset_generator_type=ImageDatasetType.BboxImagesDatasetGenerator,
    batch_size=1,
    shuffle=False,
)
test_set_plates = get_dataset(
    "test",
    dataset_generator_type=ImageDatasetType.PlateImagesDatasetGenerator,
    batch_size=1,
    shuffle=False,
)

# One MLflow run per evaluated model version: "test_<model_name>_v<model_version>".
run_name = "test_{}_v{}".format(model_name, model_version)
with mlflow.start_run(run_name=run_name):
    evaluate_bbox_detection()  # loss + RMSE of the bbox model
    evaluate_ocr()             # OCR accuracy / Levenshtein distance on the crops


2023-01-21 01:06:37,540 - urllib3.connectionpool - DEBUG - https://dagshub.com:443 "GET /gianfrancodemarco/plate-recognition.mlflow/api/2.0/mlflow/model-versions/get-download-uri?name=bigger_model_with_1_batch_norm&version=1 HTTP/1.1" 200 None
2023-01-21 01:06:37,847 - urllib3.connectionpool - DEBUG - https://dagshub.com:443 "GET /gianfrancodemarco/plate-recognition.mlflow/api/2.0/mlflow-artifacts/artifacts?path=3bc5e8d24d35422db879e9f1cda81180%2F708fd1f2dace4076a4ccb0cbf7b26cf6%2Fartifacts%2Fmodel HTTP/1.1" 200 None
2023-01-21 01:06:38,100 - urllib3.connectionpool - DEBUG - https://dagshub.com:443 "GET /gianfrancodemarco/plate-recognition.mlflow/api/2.0/mlflow-artifacts/artifacts?path=3bc5e8d24d35422db879e9f1cda81180%2F708fd1f2dace4076a4ccb0cbf7b26cf6%2Fartifacts%2Fmodel HTTP/1.1" 200 None
2023-01-21 01:06:38,462 - urllib3.connectionpool - DEBUG - https://dagshub.com:443 "GET /gianfrancodemarco/plate-recognition.mlflow/api/2.0/mlflow-artifacts/artifacts?path=3bc5e8d24d35422db879e9f1cd

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


2023-01-21 01:06:43,770 - urllib3.connectionpool - DEBUG - Starting new HTTPS connection (1): huggingface.co:443
2023-01-21 01:06:44,401 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /microsoft/trocr-small-printed/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2023-01-21 01:06:44,626 - urllib3.connectionpool - DEBUG - Starting new HTTPS connection (1): huggingface.co:443
2023-01-21 01:06:45,118 - urllib3.connectionpool - DEBUG - https://dagshub.com:443 "POST /gianfrancodemarco/plate-recognition.mlflow/api/2.0/mlflow/runs/update HTTP/1.1" 200 None


KeyboardInterrupt: 

In [None]:
# Attach the evaluation configuration to an already-existing MLflow run.
# `params` was previously undefined in this notebook, so the cell raised
# NameError on a fresh kernel; build it explicitly from the config values.
# NOTE(review): the run id is hard-coded from an earlier execution -- confirm it
# is the evaluation run you intend to annotate (it won't exist elsewhere).
params = {
    "model_name": model_name,
    "model_version": model_version,
    "tr_ocr_processor": tr_ocr_processor,
    "tr_ocr_model": tr_ocr_model,
}
with mlflow.start_run(
    run_id='b5c95c76e5c54bd688eab7ef15008425'
):

    mlflow.log_params(params)