In [None]:
'''
File: DETRModelEvaluator.ipynb
Author: Ishir Garg (ishirgarg@berkeley.edu)
Date: 3/18/24

Evaluator for DETR model
'''
import sys
sys.path.append("../")

from ModelEvaluator import ModelEvaluator
import numpy as np
import cv_utils
import torch
from PIL import Image
import torchvision.transforms as T


In [None]:
class DeepForestModelEvaluator(ModelEvaluator):
    '''
    ModelEvaluator implementation backed by the pretrained DETR (ResNet-50) detector.

    NOTE: despite its name this class does not use DeepForest; the name is kept so
    that existing references to the class keep working.
    '''

    def load_model(self):
        '''Loads the pretrained DETR (ResNet-50 backbone) model via torch.hub, in eval mode'''
        detr_model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
        detr_model.eval()
        return detr_model

    def predict_image(self, model, rgb_image) -> dict:
        '''
        Runs the model on one image and returns every predicted box with its score.

        Args:
            model: the model returned by load_model
            rgb_image: image array accepted by PIL's Image.fromarray; its first two
                axes are interpreted as (height, width)

        Returns:
            dict with
                "bboxes": numpy array with one row per prediction, rescaled to the
                    pixel dimensions of rgb_image
                "scores": numpy array holding, for each prediction, the highest class
                    probability (the last class column is excluded)
        '''
        image = Image.fromarray(rgb_image).convert("RGB")

        transform = T.Compose([
            T.Resize(800),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Add the batch dimension the model expects
        image = transform(image).unsqueeze(0)

        # Array shape is (height, width, ...) while _rescale_bboxes expects
        # (width, height), so the two leading axes must be swapped when passed on.
        img_height, img_width = rgb_image.shape[0], rgb_image.shape[1]

        # Pure inference: no_grad avoids building an autograd graph, which saves
        # memory and lets the outputs be converted to numpy directly.
        with torch.no_grad():
            detection = model(image)
            logits = detection["pred_logits"]
            bboxes = detection["pred_boxes"]

            # Drop the last class column (DETR's "no object" class) before scoring
            probabilities = logits.softmax(-1)[0, :, :-1]
            rescaled_bboxes = self._rescale_bboxes(bboxes[0], (img_width, img_height))

        return {
            "bboxes": rescaled_bboxes.numpy(),
            "scores": probabilities.max(dim=1)[0].numpy(),
        }

    # IN GENERAL, THIS FUNCTION IS NOT NECESSARY... ITS ONLY A HELPER FOR PROCESSING DETR OUTPUTS
    def _rescale_bboxes(self, out_bbox, size):
        '''
        Scales DETR's relative boxes up to pixel coordinates.

        Args:
            out_bbox: tensor of boxes as produced by DETR ("pred_boxes" for one image)
            size: (width, height) of the target image in pixels -- note the order

        Returns:
            tensor of boxes converted by cv_utils.box_cxcywh_to_xyxy (presumably to
            corner format -- confirm against cv_utils) and multiplied by
            [width, height, width, height]
        '''
        img_w, img_h = size
        b = cv_utils.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return b

In [None]:
# Machine-specific absolute paths handed to the evaluator; adjust them when running
# on another machine. (Constant names are taken from the last path component --
# confirm the meaning of each argument against ModelEvaluator's constructor.)
EVALUATION_DIR = "/Users/ishirgarg/Github/UAV_Playground/NEON/evaluation"
ANNOTATIONS_DIR = "/Users/ishirgarg/Github/UAV_Playground/NEON/annotations"

evaluator = DeepForestModelEvaluator(EVALUATION_DIR, ANNOTATIONS_DIR)

# Run the evaluation; presumably a confidence threshold of 0 keeps every
# prediction -- verify against ModelEvaluator.evaluate_model.
detections = evaluator.evaluate_model(
    confidence_threshold=0,
    iou_threshold=0.4,
)

print(detections["metrics"])

In [None]:
# Call the evaluator's annotation plot once for every index in the dataset.
# The (4, 4) tuple is presumably a figure size -- confirm against ModelEvaluator.
num_images = evaluator.dataset_len()
for image_index in range(num_images):
    evaluator.plot_image_annotations(image_index, (4, 4))