In [None]:
import os
import re

import gradio as gr
from PIL import Image
from ultralytics import YOLO


# Load the models.
# The repository root defaults to the original checkout location and can be
# overridden per machine with the YOLO_CONTAINER_ROOT environment variable,
# instead of keeping one commented-out copy of every path per user.
MODELS_ROOT = os.environ.get("YOLO_CONTAINER_ROOT", "/home/gnz/GitHub/yolo11_container")
IDS_WEIGHTS = os.path.join(
    MODELS_ROOT, "YOLO_IDs", "ID_YOLO_container", "weights", "best.pt"
)
CHAR_WEIGHTS = os.path.join(
    MODELS_ROOT,
    "YOLO_Characters",
    "Character_YOLO_container_finetune_large",
    "weights",
    "best.pt",
)

# Detector for the ID regions (classes such as cn-11, cn-4, cn-7, iso-type).
ids_model = YOLO(IDS_WEIGHTS)
# Character detector, used as OCR on each ID-region crop.
char_model = YOLO(CHAR_WEIGHTS)

# Regex validation rules, keyed by detection class. Every class validates into
# an output attribute with the same name as the class.
rules = {
    name: {"attribute": name, "regex": pattern}
    for name, pattern in (
        # Full container code: 4 uppercase letters followed by 7 digits.
        ("code-container", r"^[A-Z]{4}\d{7}$"),
        ("cn-11", r"^[A-Z]{4}\d{7}$"),
        # The letter part and the numeric part when detected separately.
        ("cn-4", r"^[A-Z]{4}$"),
        ("cn-7", r"^\d{7}$"),
        # ISO type code: the first two positions accept ANY character.
        # NOTE(review): ISO size/type codes are alphanumeric; consider
        # tightening to [A-Z0-9]{4} if the OCR output allows it.
        ("iso-type", r"^.{2}[A-Z0-9]{2}$"),
    )
}

# Reglas de validación
def parse_detecciones(detecciones, rules):
    parsed = {}
    for key, value in detecciones.items():
        if key in rules:
            attr = rules[key]["attribute"]
            pattern = rules[key]["regex"]

            # Validar con regex
            match = bool(re.match(pattern, value))

            # Resultado estructurado para Gradio JSON
            parsed[attr] = {
                "value": value,
                "valid": "✔️" if match else "❌"
            }
    return parsed

# Confidence threshold shared by both YOLO models.
_CONF = 0.25
# ID classes whose crops are read into text.
_OCR_CLASSES = ("cn-11", "iso-type", "cn-4", "cn-7")
# Of those, the classes that may be written vertically.
_MAYBE_VERTICAL = ("cn-11", "iso-type")


def _plot_to_pil(result):
    """Return *result*'s annotated image as an RGB PIL image.

    Ultralytics ``Results.plot()`` returns a BGR numpy array; reversing the
    channel axis gives the RGB order PIL expects (otherwise red and blue are
    swapped in the UI).
    """
    return Image.fromarray(result.plot()[..., ::-1])


def _read_characters(crop):
    """Run the character model on *crop*.

    Returns ``(annotated_crop, chars)`` where each entry of *chars* is
    ``(x1, y1, label, char_image)`` in the model's detection order.
    """
    result = char_model.predict(crop, conf=_CONF)[0]
    chars = []
    for cbox in result.boxes:
        label = char_model.names[int(cbox.cls[0].item())]
        cx1, cy1, cx2, cy2 = cbox.xyxy[0].tolist()
        chars.append((cx1, cy1, label, crop.crop((cx1, cy1, cx2, cy2))))
    return _plot_to_pil(result), chars


def _order_characters(chars, cls_name, crop):
    """Return *chars* in reading order for an ID region of class *cls_name*.

    Classes in ``_MAYBE_VERTICAL`` whose crop is more than 1.5x taller than
    wide are read top-to-bottom; other OCR classes are read left-to-right.
    Non-OCR classes keep detection order.
    """
    if cls_name not in _OCR_CLASSES:
        return chars
    if cls_name in _MAYBE_VERTICAL and crop.height > crop.width * 1.5:
        return sorted(chars, key=lambda c: c[1])
    return sorted(chars, key=lambda c: c[0])


def _passes_rule(cls_name, text):
    """True if *text* fully matches the regex rule registered for *cls_name*."""
    rule = rules.get(cls_name)
    return rule is not None and re.fullmatch(rule["regex"], text) is not None


def _stitch_characters(chars):
    """Paste character crops side by side into one black-background strip."""
    widths, heights = zip(*(c[3].size for c in chars))
    strip = Image.new("RGB", (sum(widths), max(heights)), color=(0, 0, 0))
    x_offset = 0
    for *_, char_image in chars:
        strip.paste(char_image, (x_offset, 0))
        x_offset += char_image.width
    return strip


def predict(image):
    """Detect container ID regions, OCR each one and validate the readings.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A 4-tuple matching the Gradio outputs: the image annotated with the ID
        boxes, the list of annotated region crops, the list of stitched
        character strips, and a dict of readings per class plus
        ``"code-container"`` (when it can be built) and ``"validation"``.
    """
    if image is None:
        # Nothing uploaded: return empty outputs instead of crashing on crop().
        return None, [], [], {}

    crops_con_labels = []
    texto_reconstruido_imgs = []
    # class name -> ((passes_regex, box_confidence), text). When a class is
    # detected several times, keep the reading that validates, then the most
    # confident one, instead of whichever box happened to come last.
    best = {}

    # 1. Detect ID regions with the first model.
    result_id = ids_model.predict(image, conf=_CONF)[0]
    img_with_boxes_pil = _plot_to_pil(result_id)

    # 2. OCR every detected region.
    for box in result_id.boxes:
        cls_name = ids_model.names[int(box.cls[0].item())]
        confidence = float(box.conf[0].item())
        crop = image.crop(tuple(box.xyxy[0].tolist()))

        annotated_crop, chars = _read_characters(crop)
        chars = _order_characters(chars, cls_name, crop)
        text_pred = "".join(c[2] for c in chars) if cls_name in _OCR_CLASSES else ""

        rank = (_passes_rule(cls_name, text_pred), confidence)
        if cls_name not in best or rank > best[cls_name][0]:
            best[cls_name] = (rank, text_pred)

        crops_con_labels.append(annotated_crop)
        # The stitched strip is only produced when characters were found.
        if chars:
            texto_reconstruido_imgs.append(_stitch_characters(chars))

    detecciones = {name: text for name, (_, text) in best.items()}

    # 3. Build code-container: prefer a non-empty cn-11, else cn-4 + cn-7.
    cn11_code = detecciones.get("cn-11")
    cn4_code = detecciones.get("cn-4")
    cn7_code = detecciones.get("cn-7")
    if cn11_code:
        detecciones["code-container"] = cn11_code
    elif cn4_code and cn7_code:
        detecciones["code-container"] = cn4_code + cn7_code

    # 4. Validate the readings against the regex rules.
    detecciones["validation"] = parse_detecciones(detecciones, rules)

    return img_with_boxes_pil, crops_con_labels, texto_reconstruido_imgs, detecciones



# Gradio UI: one PIL image in; annotated image, two galleries and a JSON
# panel out, in the same order as the tuple returned by predict().
demo = gr.Interface(
    title="Container OCR Detector",
    description="Detecta IDs de contenedores. Si hay clase cn-11 se usa como code-container; si no, se genera con cn-4 + cn-7.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        # Full image with the ID bounding boxes drawn.
        gr.Image(label="Detección IDs", type="pil"),
        # Each ID crop annotated with its character detections.
        gr.Gallery(label="Crops con OCR", height="auto", columns=2),
        # Character crops stitched into one row per region.
        gr.Gallery(label="Texto reconstruido en renglón", height="auto", columns=1),
        # Readings per class plus their regex validation.
        gr.JSON(label="Resultados OCR"),
    ],
)

# Start the local server only when run as a script / notebook, not on import.
if __name__ == "__main__":
    demo.launch()


* Running on local URL:  http://127.0.0.1:7866
* To create a public link, set `share=True` in `launch()`.



0: 448x640 2 cn-11s, 441.5ms
Speed: 4.3ms preprocess, 441.5ms inference, 1.2ms postprocess per image at shape (1, 3, 448, 640)

0: 128x640 2 2s, 1 3, 1 4, 2 5s, 2 6s, 3 7s, 1 C, 1 S, 1 U, 194.8ms
Speed: 1.0ms preprocess, 194.8ms inference, 1.8ms postprocess per image at shape (1, 3, 128, 640)

0: 128x640 3 0s, 3 1s, 2 2s, 2 5s, 1 7, 1 9, 194.9ms
Speed: 1.1ms preprocess, 194.9ms inference, 2.1ms postprocess per image at shape (1, 3, 128, 640)

0: 448x640 1 cn-11, 1 iso-type, 310.4ms
Speed: 6.1ms preprocess, 310.4ms inference, 1.0ms postprocess per image at shape (1, 3, 448, 640)

0: 416x640 (no detections), 288.1ms
Speed: 1.7ms preprocess, 288.1ms inference, 0.8ms postprocess per image at shape (1, 3, 416, 640)

0: 128x640 4 1s, 1 4, 2 8s, 1 A, 1 E, 2 Ls, 1 T, 1 U, 202.2ms
Speed: 1.1ms preprocess, 202.2ms inference, 3.3ms postprocess per image at shape (1, 3, 128, 640)

0: 384x640 1 cn-11, 1 iso-type, 301.5ms
Speed: 8.5ms preprocess, 301.5ms inference, 1.8ms postprocess per image at sh