In [None]:
from ultralytics import YOLO
from PIL import Image
import gradio as gr
import numpy as np
import easyocr
import re


# Weight files for the two YOLO models.
# Alternative paths kept for reference:
#   /home/gnz/GitHub/yolo11_container/YOLO_IDs/ID_YOLO_container/weights/best.pt
#   /home/gnz/GitHub/yolo11_container/YOLO_Characters/Character_YOLO_container_finetune_large/weights/best.pt
_IDS_WEIGHTS = "/home/gonzadzz/GitHub/yolo11_container/YOLO_IDs/ID_YOLO_container/weights/best.pt"
_CHAR_WEIGHTS = "/home/gonzadzz/GitHub/yolo11_container/YOLO_Characters/Character_YOLO_container_finetune_extra_large_phase2/weights/best.pt"

# Languages handed to the EasyOCR reader.
_OCR_LANGUAGES = ['en', 'es']

# Load the models once at import time: the ID-region detector, the
# character detector, and the EasyOCR reader.
ids_model = YOLO(_IDS_WEIGHTS)
char_model = YOLO(_CHAR_WEIGHTS)
ocr_model = easyocr.Reader(_OCR_LANGUAGES)

# Regex rules used for validation.
# Each (field, pattern) pair becomes an entry of the form
# {"attribute": <field>, "regex": <pattern>} keyed by the field name.
_FIELD_PATTERNS = (
    ("code-container", r"^[A-Z]{4}\d{7}$"),
    ("cn-11", r"^[A-Z]{4}\d{7}$"),
    ("cn-4", r"^[A-Z]{4}$"),
    ("cn-7", r"^\d{7}$"),
    # Adjusted for the ISO type code: any two characters followed by two
    # uppercase alphanumerics.
    ("iso-type", r"^.{2}[A-Z0-9]{2}$"),
)

rules = {
    field: {"attribute": field, "regex": pattern}
    for field, pattern in _FIELD_PATTERNS
}

# Validation of detected fields against the regex rules
def parse_detecciones(detecciones, rules):
    """Validate each detected text against the regex of its rule.

    Args:
        detecciones: Mapping of field name (e.g. ``"cn-4"``) to the text
            read for that field.
        rules: Mapping of field name to a dict with the keys
            ``"attribute"`` (key used in the output) and ``"regex"``
            (pattern the whole text must satisfy).

    Returns:
        A dict keyed by each rule's ``"attribute"``. Every value has the
        form ``{"value": <text>, "valid": "✔️" or "❌"}``. Fields that have
        no entry in ``rules`` are omitted.
    """
    parsed = {}
    for key, value in detecciones.items():
        if key not in rules:
            continue

        rule = rules[key]

        # re.fullmatch is used instead of re.match because a `$` anchor also
        # matches just before a trailing newline, which would let "ABCD\n"
        # pass as valid. Non-string values are reported as invalid instead of
        # raising TypeError inside the regex engine.
        is_valid = (
            isinstance(value, str)
            and re.fullmatch(rule["regex"], value) is not None
        )

        # Structured result for the Gradio JSON component
        parsed[rule["attribute"]] = {
            "value": value,
            "valid": "✔️" if is_valid else "❌"
        }
    return parsed

def calculate_check_digit(container_code: str) -> str | None:
    """
    Calcula o valida el dígito de check digit de un código de contenedor ISO 6346.
    
    - Si el código tiene 10 caracteres → calcula el dígito y devuelve el código completo (11).
    - Si el código tiene 11 caracteres → valida el dígito, si es correcto devuelve el mismo,
      si es incorrecto devuelve el código corregido.
    - Si el código tiene más de 11 → toma los primeros 10, calcula el dígito y devuelve esos 11.
    - Si los primeros 4 caracteres no son letras, devuelve None.
    """
    # Validar longitud mínima
    if len(container_code) < 10:
        return None

    # Validar que los primeros 4 sean letras
    if not container_code[:4].isalpha():
        return None

    # Tomar primeros 10 caracteres
    code_10 = container_code[:10]

    # Mapeo de letras a valores ISO 6346
    letter_values = {
        'A': 10, 'B': 12, 'C': 13, 'D': 14, 'E': 15, 'F': 16, 'G': 17, 'H': 18, 'I': 19, 'J': 20,
        'K': 21, 'L': 23, 'M': 24, 'N': 25, 'O': 26, 'P': 27, 'Q': 28, 'R': 29, 'S': 30, 'T': 31,
        'U': 32, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38
    }

    # Convertir a valores numéricos
    values = []
    for char in code_10:
        if char.isalpha():
            values.append(letter_values[char.upper()])
        else:
            values.append(int(char))

    # Calcular suma ponderada con 2^(posición)
    total = sum(val * (2 ** i) for i, val in enumerate(values))

    # Resto módulo 11
    check_digit = total % 11
    if check_digit == 10:
        check_digit = 0

    # Caso 10 caracteres → devolver con check digit
    if len(container_code) == 10:
        return code_10 + str(check_digit)

    # Caso 11 caracteres → validar o corregir
    if len(container_code) == 11:
        last_digit = container_code[10]
        if last_digit.isdigit() and int(last_digit) == check_digit:
            return container_code  # es válido
        else:
            # Corregir último carácter
            return code_10 + str(check_digit)

    # Caso más de 11 caracteres → recortar y recalcular
    if len(container_code) > 11:
        return code_10 + str(check_digit)

    return None



########################################################################################################################



def predict(image):
    """Run the two-stage container-ID reading pipeline on one image.

    Steps, as implemented below:

    1. ``ids_model`` detects ID regions; for each class only the box with
       the highest confidence is kept.
    2. Every kept region is cropped and read twice: by ``char_model``
       (per-character detections concatenated in reading order) and by
       EasyOCR (``ocr_model``).
    3. For each reader a ``code-container`` value is built from ``cn-11``
       or, failing that, from ``cn-4`` + ``cn-7``.
    4. Both readings are validated with ``parse_detecciones`` and passed
       through ``calculate_check_digit``.

    Args:
        image: Input image, used through ``.crop`` and passed to
            ``ids_model.predict`` (the Gradio input is declared with
            ``type="pil"``). ``None`` yields empty outputs.

    Returns:
        A 4-tuple ``(img_with_boxes_pil, crops_con_labels,
        texto_reconstruido_imgs, salida_json)``:

        - the input image annotated with all ID detections (PIL image),
        - a list with one annotated crop per processed region,
        - a list of single-row images rebuilt from character crops (only
          for regions where that reconstruction was used),
        - a dict with the keys ``"output_yolo_char"``,
          ``"output_easy_ocr"``, ``"validation"`` and
          ``"validated_code_container"``.
    """
    # Nothing to process (e.g. the form was submitted without an image):
    # return empty outputs instead of failing later on image.crop.
    if image is None:
        return None, [], [], {}

    detecciones_yolo = {}
    detecciones_easy = {}
    crops_con_labels = []
    texto_reconstruido_imgs = []

    # Auxiliary variables used to assemble code-container
    cn11_code_yolo, cn4_code_yolo, cn7_code_yolo = None, None, None
    cn11_code_easy, cn4_code_easy, cn7_code_easy = None, None, None

    # 1. Detection with the first model (IDs)
    results_id = ids_model.predict(image, conf=0.5)

    # 1a. Collect every detection together with its confidence
    detections = []
    for box in results_id[0].boxes:
        cls_id = int(box.cls[0].item())
        cls_name = ids_model.names[cls_id]
        conf = float(box.conf[0].item())
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        detections.append({
            "cls_name": cls_name,
            "conf": conf,
            "coords": (x1, y1, x2, y2)
        })

    # 1b. Keep only the highest-confidence detection of each class
    best_detections = {}
    for det in detections:
        cls_name = det["cls_name"]
        if cls_name not in best_detections or det["conf"] > best_detections[cls_name]["conf"]:
            best_detections[cls_name] = det

    # Image with all the original bounding boxes.
    # Results.plot() returns a BGR array; reverse the channel axis so the
    # PIL image is built from RGB data and colours are not swapped.
    img_with_boxes = results_id[0].plot()
    img_with_boxes_pil = Image.fromarray(img_with_boxes[..., ::-1])

    # 2. Process each kept detection
    for cls_name, det in best_detections.items():
        x1, y1, x2, y2 = det["coords"]
        crop = image.crop((x1, y1, x2, y2))

        # 2a. Run the character model (YOLO chars) on the crop
        results_char = char_model.predict(crop, conf=0.5)
        chars_detected = []

        for cbox in results_char[0].boxes:
            c_cls_id = int(cbox.cls[0].item())
            c_cls_name = char_model.names[c_cls_id]
            cx1, cy1, cx2, cy2 = cbox.xyxy[0].tolist()
            char_crop = crop.crop((cx1, cy1, cx2, cy2))
            chars_detected.append((cx1, cy1, c_cls_name, char_crop))

        # 2b. Sort and concatenate the characters for the YOLO reading.
        # Any class other than the four below keeps an empty text and its
        # characters stay in detection order.
        text_pred = ""
        if cls_name in ["cn-11", "iso-type"]:
            if crop.height > crop.width * 1.5:  # vertical: top to bottom
                chars_detected = sorted(chars_detected, key=lambda x: x[1])
            else:  # horizontal: left to right
                chars_detected = sorted(chars_detected, key=lambda x: x[0])
            text_pred = "".join([c[2] for c in chars_detected])
        elif cls_name in ["cn-4", "cn-7"]:
            chars_detected = sorted(chars_detected, key=lambda x: x[0])
            text_pred = "".join([c[2] for c in chars_detected])

        # Store the YOLO reading
        detecciones_yolo[cls_name] = text_pred
        if cls_name == "cn-11":
            cn11_code_yolo = text_pred
        elif cls_name == "cn-4":
            cn4_code_yolo = text_pred
        elif cls_name == "cn-7":
            cn7_code_yolo = text_pred

        # 2c. Store the annotated crop (BGR -> RGB, as above)
        crop_with_boxes = results_char[0].plot()
        crops_con_labels.append(Image.fromarray(crop_with_boxes[..., ::-1]))

        # 2d. EasyOCR
        # Rule: crops of the known classes that are wider than tall are read
        # directly from the original crop.
        if cls_name in ["cn-4", "cn-7", "cn-11", "iso-type"] and crop.width > crop.height:
            ocr_text = ocr_model.readtext(np.array(crop), detail=0)
        else:
            if chars_detected:
                # Rebuild the text as a single horizontal row by pasting the
                # character crops side by side on a black canvas.
                widths, heights = zip(*(c[3].size for c in chars_detected))
                total_width = sum(widths)
                max_height = max(heights)
                new_img = Image.new("RGB", (total_width, max_height), color=(0,0,0))
                x_offset = 0
                for _, _, _, char_crop in chars_detected:
                    new_img.paste(char_crop, (x_offset,0))
                    x_offset += char_crop.width
                texto_reconstruido_imgs.append(new_img)

                ocr_text = ocr_model.readtext(np.array(new_img), detail=0)
            else:
                # No characters detected -> OCR on the original crop
                ocr_text = ocr_model.readtext(np.array(crop), detail=0)

        # Strip spaces, keep only A-Z/0-9 and store in detecciones_easy.
        # A class is only recorded when EasyOCR returned some text.
        if ocr_text:
            ocr_text_clean = "".join(ocr_text).replace(" ", "")
            ocr_text_clean = re.sub(r'[^A-Z0-9]', '', ocr_text_clean.upper())
            detecciones_easy[cls_name] = ocr_text_clean
            if cls_name == "cn-11":
                cn11_code_easy = ocr_text_clean
            elif cls_name == "cn-4":
                cn4_code_easy = ocr_text_clean
            elif cls_name == "cn-7":
                cn7_code_easy = ocr_text_clean

    # 3. Build code-container for the YOLO reading
    if cn11_code_yolo:
        detecciones_yolo["code-container"] = cn11_code_yolo
    elif cn4_code_yolo and cn7_code_yolo:
        detecciones_yolo["code-container"] = cn4_code_yolo + cn7_code_yolo

    # 4. Build code-container for the EasyOCR reading
    if cn11_code_easy:
        detecciones_easy["code-container"] = cn11_code_easy
    elif cn4_code_easy and cn7_code_easy:
        detecciones_easy["code-container"] = cn4_code_easy + cn7_code_easy

    # 5. Validate both readings
    parsed_yolo = parse_detecciones(detecciones_yolo, rules)
    parsed_easy = parse_detecciones(detecciones_easy, rules) if detecciones_easy else {}

    # 6. Compute validated_code_container for both readings
    validated_yolo = None
    validated_easy = None

    if "code-container" in detecciones_yolo:
        validated_yolo = calculate_check_digit(detecciones_yolo["code-container"])

    if "code-container" in detecciones_easy:
        validated_easy = calculate_check_digit(detecciones_easy["code-container"])

    # 7. Assemble the final output
    salida_json = {
        "output_yolo_char": detecciones_yolo,
        "output_easy_ocr": detecciones_easy,
        "validation": {
            "yolo_char": parsed_yolo,
            "easy_ocr": parsed_easy
        },
        "validated_code_container": {
            "yolo_char": validated_yolo,
            "easy_ocr": validated_easy
        }
    }

    return img_with_boxes_pil, crops_con_labels, texto_reconstruido_imgs, salida_json

   


# Gradio interface: one image in; annotated image, two galleries and a JSON
# panel out (same order as the tuple returned by predict).
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Image(type="pil", label="Detección IDs"),
    gr.Gallery(label="Crops con OCR", columns=2, height="auto"),
    gr.Gallery(label="Texto reconstruido en renglón", columns=1, height="auto"),
    gr.JSON(label="Resultados OCR"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_input_component,
    outputs=_output_components,
    title="Container OCR Detector",
    description="Detecta IDs de contenedores. Si hay clase cn-11 se usa como code-container; si no, se genera con cn-4 + cn-7.",
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()


* Running on local URL:  http://127.0.0.1:7864
* To create a public link, set `share=True` in `launch()`.



0: 384x640 1 cn-11, 1 iso-type, 57.1ms
Speed: 4.6ms preprocess, 57.1ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)

0: 256x640 (no detections), 48.1ms
Speed: 1.1ms preprocess, 48.1ms inference, 0.6ms postprocess per image at shape (1, 3, 256, 640)

0: 128x640 2 0s, 1 3, 1 4, 1 7, 1 9, 1 T, 1 U, 71.9ms
Speed: 1.5ms preprocess, 71.9ms inference, 2.8ms postprocess per image at shape (1, 3, 128, 640)

0: 384x640 1 cn-11, 1 iso-type, 89.4ms
Speed: 4.8ms preprocess, 89.4ms inference, 2.7ms postprocess per image at shape (1, 3, 384, 640)

0: 640x96 1 1, 2 2s, 1 G, 72.3ms
Speed: 0.8ms preprocess, 72.3ms inference, 1.4ms postprocess per image at shape (1, 3, 640, 96)

0: 640x64 1 1, 1 2, 1 3, 1 4, 2 8s, 1 9, 1 A, 63.4ms
Speed: 0.7ms preprocess, 63.4ms inference, 1.4ms postprocess per image at shape (1, 3, 640, 64)

0: 384x640 1 cn-11, 1 iso-type, 71.5ms
Speed: 7.3ms preprocess, 71.5ms inference, 2.1ms postprocess per image at shape (1, 3, 384, 640)

0: 128x640 1 0, 1 1, 1 