In [3]:
import torch
import cv2
import os
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.structures import Boxes


In [2]:

def setup_layout_detector(cfg_path, weights_path, conf_threshold=0.3):
    """
    Build a Detectron2 predictor for the Faster R-CNN layout detector.

    :param cfg_path: Path to the YAML config that is merged on top of Detectron2's defaults.
    :param weights_path: Path to the trained weights; stored in ``MODEL.WEIGHTS``.
    :param conf_threshold: Test-time ROI-head score threshold (``SCORE_THRESH_TEST``);
        detections scoring below it are discarded.
    :return: A ``DefaultPredictor`` constructed from the resulting config.
    """
    config = get_cfg()
    config.merge_from_file(cfg_path)

    # Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
    has_gpu = torch.cuda.is_available()
    config.MODEL.DEVICE = "cuda" if has_gpu else "cpu"

    config.MODEL.WEIGHTS = weights_path
    config.MODEL.ROI_HEADS.SCORE_THRESH_TEST = conf_threshold

    return DefaultPredictor(config)

def predict_image(image_path, predictor):
    """
    Run the layout detector on one image file.

    :param image_path: Path of the image to load with ``cv2.imread``.
    :param predictor: A Detectron2 ``DefaultPredictor`` (e.g. from ``setup_layout_detector``).
    :return: Tuple ``(pred_boxes, pred_classes)`` as NumPy arrays: the box coordinates
        (one row of 4 values per detection) and the predicted class index of each detection.
    :raises FileNotFoundError: If OpenCV cannot read the image.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread does not raise on failure; it returns None instead.
        raise FileNotFoundError(f"Image at {image_path} not found.")

    detections = predictor(image)["instances"]

    # Tensors may live on the GPU; bring them to host memory before converting to NumPy.
    boxes = detections.pred_boxes.tensor.cpu().numpy()
    classes = detections.pred_classes.cpu().numpy()
    return boxes, classes

def vis(image_path, pred_boxes, pred_classes, output_path, class_names=None, scale=1.2):
    """
    Draw predicted boxes with class labels on an image and save the result.

    :param image_path: Path to the original image (the one inference was run on).
    :param pred_boxes: Array of predicted bounding boxes (one row of 4 values per
        detection), as returned by ``predict_image``.
    :param pred_classes: Array of predicted class indices, one per box.
    :param output_path: Destination file. OpenCV chooses the encoder from the file
        extension, so use ``.png`` to get a PNG.
    :param class_names: Optional mapping from class index to label text. When omitted,
        the built-in five-class layout mapping below is used. Indices missing from the
        mapping are labelled ``"Class <index>"``.
    :param scale: Scale factor passed to Detectron2's ``Visualizer``. The default 1.2
        preserves the previous behaviour; use 1.0 to save at the input resolution.
    :raises FileNotFoundError: If the original image cannot be read.
    :raises OSError: If OpenCV fails to write ``output_path``.
    """
    if class_names is None:
        # Default labels; update (or pass ``class_names``) to match the trained model's classes.
        class_names = {0: "logo", 1: "input", 2: "button", 3: "label", 4: "block"}

    im = cv2.imread(image_path)
    if im is None:
        raise FileNotFoundError(f"Image at {image_path} not found.")

    # OpenCV loads BGR; the Visualizer is given the channel-reversed (RGB) view.
    visualizer = Visualizer(im[:, :, ::-1], scale=scale)
    labels = [class_names.get(int(cls), f"Class {cls}") for cls in pred_classes]
    annotated_rgb = visualizer.overlay_instances(
        boxes=Boxes(pred_boxes), labels=labels
    ).get_image()

    # Back to BGR for OpenCV; cvtColor also produces a contiguous array for the encoder.
    annotated_bgr = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)

    # cv2.imwrite signals failure (e.g. missing directory, unknown extension) by returning
    # False rather than raising, so check it instead of reporting a save that never happened.
    if not cv2.imwrite(output_path, annotated_bgr):
        raise OSError(f"Failed to write visualization to: {output_path}")
    print(f"Visualization saved at: {output_path}")


In [4]:

if __name__ == "__main__":
    # Define paths
    cfg_path = "faster_rcnn_web.yaml"  # Update to your YAML configuration path
    weights_path = "layout_detector.pth"  # Update to your model weights path
    test_image_path = "shot.png"  # Update to your test image path
    output_image_path = "annotated_image.png"  # Path to save the annotated image

    # Fail fast on missing local inputs before the (potentially slow) model build.
    # Without this, a missing test image is only reported after the predictor has been
    # constructed, and missing config/weights files surface from inside Detectron2's
    # loaders. Paths containing "://" are skipped because they may be remote URIs
    # that Detectron2 resolves itself.
    missing_inputs = [
        path
        for path in (cfg_path, weights_path, test_image_path)
        if "://" not in path and not os.path.isfile(path)
    ]
    if missing_inputs:
        raise FileNotFoundError(
            "Required input file(s) not found: " + ", ".join(missing_inputs)
        )

    # Initialize the layout detector
    predictor = setup_layout_detector(cfg_path, weights_path, conf_threshold=0.3)

    # Run inference on the test image
    print("Running inference...")
    pred_boxes, pred_classes = predict_image(test_image_path, predictor)
    print(f"Detected Boxes:\n{pred_boxes}")
    print(f"Detected Classes:\n{pred_classes}")

    # Visualize predictions
    print("Visualizing predictions...")
    vis(test_image_path, pred_boxes, pred_classes, output_image_path)


Config 'faster_rcnn_web.yaml' has no VERSION. Assuming it to be compatible with latest v2.
The checkpoint state_dict contains keys that are not used by the model:
  [35mpixel_mean[0m
  [35mpixel_std[0m


Running inference...
Detected Boxes:
[[  33.88331      3.2002234  511.58066    138.9013   ]
 [1932.86       268.82523   2069.9487     348.05768  ]
 [ 990.05273    272.19006   1862.5286     347.1824   ]
 [  16.412138   608.4681    1500.4865     688.3149   ]
 [  30.73153    361.14218    505.37436    452.51486  ]
 [  37.12949    270.08264    893.3168     346.99707  ]
 [  31.932842   202.7868     526.45355    256.99435  ]
 [ 842.10645    276.02762    907.5118     339.47083  ]
 [  18.453646   459.5318     530.69006    524.647    ]]
Detected Classes:
[0 2 1 4 2 1 4 2 2]
Visualizing predictions...
Visualization saved at: annotated_image.png
