This notebook demonstrates how to run inference with a fine-tuned RF-DETR model. RF-DETR is Roboflow's real-time, transformer-based object detection architecture, designed for efficient and accurate detection of objects, including in cluttered visual scenes. We will walk through the setup, load the trained weights, and run prediction on a test image (the same steps apply to individual video frames) to visualize the predicted bounding boxes and class labels.

In [None]:
!pip install -q rfdetr supervision roboflow

In [None]:
# Import libraries.
import glob
import io
import json
import os
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import natsort
import requests
import supervision as sv
from google.colab import drive
from PIL import Image
from rfdetr import RFDETRLarge
from rfdetr.util.coco_classes import COCO_CLASSES
warnings.filterwarnings("ignore")

# Set the CUDA / PyTorch options on the kernel process itself. A `!export ...`
# line runs in a throwaway subshell, so the variable never reaches this process.
# NOTE(review): these are set after `rfdetr` has been imported; presumably they
# are still honoured because CUDA is initialised later -- confirm, or move
# these two lines above the imports.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Connect to the google drive.
drive.mount('/content/gdrive')

# Expose the Drive root under the space-free alias /mydrive.
# Done in Python (not `!ln -s`) so that a failure actually raises and is
# reported, and so that re-running the cell does not create a nested link
# inside the already-linked directory.
try:
  if not os.path.lexists('/mydrive'):
    os.symlink('/content/gdrive/My Drive/', '/mydrive')
  if not os.path.isdir('/mydrive'):
    raise OSError('/mydrive does not resolve to a directory')
  print('Successful')
except OSError as e:
  print(e)
  print('Not successful')

## Load the class labels and the RF-DETR model.


In [None]:
# Class-id -> display-name lookup for the single category the model was trained on.
# NOTE(review): presumably id 0 matches the category id used in the training
# annotations -- confirm against the dataset.
CLASSES = {0: 'dairy_product_packet'}

In [None]:
# Build RF-DETR Large from the checkpoint stored on Drive, then run the
# library's inference-optimisation step before predicting.
CHECKPOINT_PATH = "/mydrive/LLM/rf-detr/data/output/checkpoint_best_total.pth"
model = RFDETRLarge(pretrain_weights=CHECKPOINT_PATH)
model.optimize_for_inference()

In [None]:
#@title Utils

# Default hex colours handed to `visualize_detections` when the caller does
# not supply a palette of their own.
DEFAULT_COLOR_HEX_LIST = [
    "#ffff00",
    "#ff9b00",
    "#ff66ff",
    "#3399ff",
    "#ff66b2",
    "#ff8080",
    "#b266ff",
    "#9999ff",
    "#66ffff",
    "#33ff99",
    "#66ff66",
    "#99ff00",
]

def _class_label(
    class_names: Union[List[str], Dict[int, str]],
    class_id: int
) -> str:
  """Returns the display name for `class_id`, or the id as text if unknown."""
  key = int(class_id)  # class ids may arrive as NumPy integers
  if isinstance(class_names, dict):
    return str(class_names.get(key, key))
  if 0 <= key < len(class_names):
    return str(class_names[key])
  return str(key)


def visualize_detections(
    image: Image.Image,
    detections: sv.Detections,
    class_names: Union[List[str], Dict[int, str]],
    threshold: float = 0.5,
    color_hex_list: List[str] = DEFAULT_COLOR_HEX_LIST
) -> Image.Image:
  """Visualizes bounding boxes and class labels on an image.

  This function uses the Supervision (sv) library to filter detections
  by confidence, calculate optimal annotation styles, and draw the
  annotations on a copy of the input image. The input image itself is
  not modified.

  Args:
      image (Image.Image): Input image from PIL.
      detections (sv.Detections): Output from the detection model.
      class_names (List[str] | Dict[int, str]): Lookup from class ID to
          class name, either a list indexed by class ID or a dict keyed
          by class ID. IDs missing from the lookup are shown as the
          numeric ID instead of raising.
      threshold (float, optional): Detections are kept only when their
          confidence is strictly greater than this value. Ignored when
          the detections carry no confidence scores. Defaults to 0.5.
      color_hex_list (List[str], optional): A list of hex color strings
          for the color palette. Defaults to DEFAULT_COLOR_HEX_LIST.

  Returns:
      Image.Image: A new annotated image with bounding boxes and
          class labels.
  """
  # Filter by confidence only when scores are present; comparing a missing
  # (None) confidence array against the threshold would raise a TypeError.
  if detections.confidence is not None:
    detections = detections[detections.confidence > threshold]

  # Determine optimal text scale and box thickness based on image resolution
  text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
  thickness = sv.calculate_optimal_line_thickness(resolution_wh=image.size)

  # Define color palette from the provided hex list
  color_palette = sv.ColorPalette.from_hex(color_hex_list)

  # Set up annotators
  box_annotator = sv.BoxAnnotator(
      color=color_palette,
      thickness=thickness)

  label_annotator = sv.LabelAnnotator(
      color=color_palette,
      text_color=sv.Color.BLACK,
      text_scale=text_scale,
  )

  # Generate one label per remaining detection: "<name> <confidence>".
  # Either part is left out when the detections do not carry that field.
  count = len(detections)
  class_ids = (
      detections.class_id if detections.class_id is not None
      else [None] * count)
  confidences = (
      detections.confidence if detections.confidence is not None
      else [None] * count)

  labels = []
  for class_id, confidence in zip(class_ids, confidences):
    parts = []
    if class_id is not None:
      parts.append(_class_label(class_names, class_id))
    if confidence is not None:
      parts.append(f"{confidence:.2f}")
    labels.append(" ".join(parts))

  # Draw on a copy so the caller's image is left untouched.
  annotated_image = image.copy()
  annotated_image = box_annotator.annotate(
      scene=annotated_image,
      detections=detections
  )

  annotated_image = label_annotator.annotate(
      scene=annotated_image,
      detections=detections,
      labels=labels
  )

  return annotated_image

## Read images

In [None]:
# glob returns paths in arbitrary, filesystem-dependent order; sort them
# naturally so that a positional index such as files[25] always refers to the
# same image across runs and machines.
files = natsort.natsorted(glob.glob('/mydrive/LLM/milk_pouches/data/dairy/images/*'))
assert files, "No files found under /mydrive/LLM/milk_pouches/data/dairy/images/"

In [None]:
# Which entry of `files` to run on, and the single confidence cut-off shared
# by the model call and the visualisation step.
IMAGE_INDEX = 25
CONFIDENCE_THRESHOLD = 0.50

# Normalise to 3-channel RGB so grayscale / RGBA / palette files do not reach
# the model with an unexpected channel count.
# NOTE(review): presumably the model expects RGB input -- confirm.
image = Image.open(files[IMAGE_INDEX]).convert("RGB")
detections = model.predict(image, threshold=CONFIDENCE_THRESHOLD)
annotated = visualize_detections(
    image=image,
    detections=detections,
    class_names=CLASSES,
    threshold=CONFIDENCE_THRESHOLD
)

sv.plot_image(annotated)