In [None]:
from beam import Image, endpoint, env, Volume, function
from PIL import Image as PILImage
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image

# Directory handed to `from_pretrained(..., cache_dir=...)` for both the DETR
# processor and model, so downloaded weights are reused between loads.
BEAM_VOLUME_CACHE_PATH: str = "./weights"

In [None]:
    # Load a notebook-local copy of the DETR processor and model.
    # `show_results` looks up class names through the global
    # `model.config.id2label`, so `model` must remain defined at notebook scope.
    _detr_kwargs = {"revision": "no_timm", "cache_dir": BEAM_VOLUME_CACHE_PATH}
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101", **_detr_kwargs)
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101", **_detr_kwargs)

    # Prefer the GPU when one is visible to this kernel.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model.to(device)
    print(f"Using device: {device}")


In [None]:
@function(
    gpu="T4",
)
def detector(image='soco.jpg', threshold=0.6):
    """Run DETR (facebook/detr-resnet-101) object detection on one image file.

    The processor and model are loaded inside the function on every call,
    using ``cache_dir=BEAM_VOLUME_CACHE_PATH``.

    NOTE(review): ``Volume`` is imported in this notebook but not attached to
    this function; if the weights cache should persist between remote runs,
    a volume mounted at ``BEAM_VOLUME_CACHE_PATH`` is presumably required —
    confirm against the Beam documentation.

    Args:
        image: Path to an image file that PIL can open.
        threshold: Minimum confidence score for a detection to be kept.
            Defaults to 0.6, the value that was previously hard-coded.

    Returns:
        The dict produced by ``post_process_object_detection`` for this
        image (``"scores"``, ``"labels"``, ``"boxes"``), with every tensor
        moved to the CPU. Boxes are rescaled to the image's
        (height, width) via ``target_sizes``.
    """
    processor = DetrImageProcessor.from_pretrained(
        "facebook/detr-resnet-101",
        revision="no_timm",
        cache_dir=BEAM_VOLUME_CACHE_PATH,
    )
    model = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-101",
        revision="no_timm",
        cache_dir=BEAM_VOLUME_CACHE_PATH,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    print(f"Using device: {device}")

    # Close the file handle promptly and normalise the pixel mode: RGBA,
    # greyscale or palette images would otherwise not be 3-channel RGB.
    with PILImage.open(image) as img:
        rgb_image = img.convert("RGB")

    # Process the image for object detection
    processed_inputs = processor(images=rgb_image, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**processed_inputs)

    # Convert outputs to COCO API format and filter with threshold.
    # PIL reports size as (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([rgb_image.size[::-1]], device=device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    # Hand back CPU tensors so callers without CUDA can still use the result.
    return {key: value.cpu() for key, value in results.items()}



In [None]:
def show_results(results, id2label=None):
    """Print one line per detection in a DETR post-processing result.

    Args:
        results: Mapping with parallel ``"scores"``, ``"labels"`` and
            ``"boxes"`` sequences of tensors, e.g. the value returned by
            ``detector``.
        id2label: Optional mapping from integer class id to class name.
            When omitted, the notebook-global ``model.config.id2label`` is
            used (the previous, implicit behaviour); pass it explicitly when
            no global ``model`` has been loaded.
    """
    names = id2label
    # Prepare detection results
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        if names is None:
            # Resolved lazily so empty results never touch the global `model`.
            names = model.config.id2label
        box = [round(coord, 2) for coord in box.tolist()]
        label_name = names[label.item()]
        confidence = round(score.item(), 3)

        print(
            f"Detected {label_name} with confidence {confidence} at location {box}"
        )


In [None]:
def show_jpg(filename):
    """Display an image file inline in the notebook.

    Despite the name, any format PIL can open works, not only JPG. A relative
    ``filename`` is resolved against the current working directory.

    Uses the explicit ``PILImage`` alias: the bare name ``Image`` is bound
    twice by this notebook's imports (Beam's, then PIL's), so relying on it
    depends on import order.
    """
    # The context manager closes the underlying file once the image has been
    # rendered. `display` is presumably the builtin injected by the IPython
    # kernel, so this helper is intended for notebook use.
    with PILImage.open(filename) as img:
        display(img)

In [None]:
# Sample input; swap in 'waffles_and_jack.jpg' to try the other bundled picture.
image = 'soco.jpg'

show_jpg(image)                   # preview the input picture
results = detector(image=image)   # run DETR object detection
show_results(results)             # print one line per detection