In [None]:
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image, ImageDraw
import torchvision.transforms as T
import torchvision.ops as ops
import torchvision.transforms.functional as F
from io import BytesIO
import numpy as np

import matplotlib.pyplot as plt

def _load_detr(model_path, num_labels, device):
    """Build the DETR processor and model, load fine-tuned weights, and move the model to ``device`` in eval mode."""
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

    # Start from the pretrained architecture; the classification head is rebuilt
    # for ``num_labels`` classes, which is why size mismatches must be ignored.
    model = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50",
        num_labels=num_labels,  # must match the number of classes used in training
        ignore_mismatched_sizes=True,
    )

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
    # only load trusted checkpoints (consider weights_only=True where supported).
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.to(device)
    model.eval()
    return processor, model


def _detect(processor, model, image, device, score_threshold, iou_threshold):
    """Run DETR on a PIL image and return ``(boxes, scores, labels)`` after score filtering and NMS.

    Boxes are in the (x1, y1, x2, y2) order expected by ``ops.nms`` and the
    drawing code, scaled to the size of ``image``.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # post-processing expects (height, width); PIL's ``size`` is (width, height).
    target_sizes = torch.tensor([image.size[::-1]], device=device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=score_threshold
    )[0]

    # DETR does not apply NMS by design, so suppress overlapping boxes here.
    # This NMS is class-agnostic: overlapping boxes of different classes also compete.
    keep = ops.nms(results["boxes"], results["scores"], iou_threshold)
    return results["boxes"][keep], results["scores"][keep], results["labels"][keep]


def _render_to_array(image, figsize=(10, 10)):
    """Render a PIL image through a borderless matplotlib figure and return the JPEG-decoded pixels as a NumPy array."""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        fig.tight_layout()
        # Draw the PIL image directly; a tensor round-trip adds nothing and can
        # shift pixel values through float->byte truncation.
        ax.imshow(image)
        ax.axis('off')
        with BytesIO() as buf:
            fig.savefig(buf, format='jpeg', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            with Image.open(buf) as rendered:
                return np.array(rendered)
    finally:
        # Always close the figure so it is not displayed inline and does not leak.
        plt.close(fig)


def detectar_y_dibujar_detr(image_path, model_path='aerial_animals_DETR.pth', score_threshold=0.1, iou_threshold=0.1,
                            num_labels=6, box_color="blue", box_width=10):
    """Detect objects with a fine-tuned DETR model and draw their bounding boxes.

    Loads the ``facebook/detr-resnet-50`` architecture with a head sized for
    ``num_labels`` classes and overwrites its weights with the state dict at
    ``model_path``. It then runs inference on ``image_path`` and keeps the
    detections scoring at least ``score_threshold``. Overlaps are removed with
    class-agnostic non-maximum suppression, and the remaining boxes are drawn
    on the image.

    Args:
        image_path: Path (or file object) of the image to analyse.
        model_path: Path to a state dict saved with ``torch.save``.
        score_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.
        num_labels: Number of classes the checkpoint was trained with.
        box_color: Outline colour of the drawn boxes.
        box_width: Outline width, in pixels, of the drawn boxes.

    Returns:
        ``(image_array, num_boxes)``. ``image_array`` is the annotated image
        re-rendered through a matplotlib figure, JPEG-encoded and decoded back
        into a NumPy array. ``num_boxes`` is the number of boxes kept after NMS.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    processor, model = _load_detr(model_path, num_labels, device)

    # The context manager closes the source file once the RGB copy is loaded.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    boxes, _, _ = _detect(processor, model, image, device, score_threshold, iou_threshold)

    draw = ImageDraw.Draw(image)
    for x1, y1, x2, y2 in boxes.tolist():
        draw.rectangle([x1, y1, x2, y2], outline=box_color, width=box_width)

    return _render_to_array(image), len(boxes)





# Run the detector on a sample image and show the annotated result as a figure
# (a bare `image_array` would only print the raw array repr).
image_array, num_boxes = detectar_y_dibujar_detr('19c019842e984b53a75251fa6a4c54e05682b762.JPG')

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_array)
ax.set_title(f"Detected boxes: {num_boxes}")
ax.axis('off')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

# The detector returns a tuple: (rendered image as a NumPy array, number of boxes).
# Unpack it; the array can be shown directly without any tensor/PIL conversion.
image_array_2, num_boxes_2 = detectar_y_dibujar_detr('308982734a08ca0092bd98b963655acea0a162b0.JPG')

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_array_2)
ax.set_title(f"Detected boxes: {num_boxes_2}")
ax.axis('off')
plt.show()


