In [None]:
import torch
from torchvision.models import detection
import numpy as np
from PIL import Image

def load_image(image_path, mode="RGB"):
    """
    Load an image from *image_path* with PIL and return it as a NumPy array.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or file object).
        mode: PIL mode to convert the image to before building the array.
            Defaults to "RGB", so grayscale, palette and RGBA files all come
            back as a uint8 array of shape (H, W, 3), which is the layout the
            detection code expects. Pass ``None`` to keep the file's native
            mode (e.g. "L" gives (H, W), "RGBA" gives (H, W, 4)).

    Returns:
        numpy.ndarray holding the pixel data.
    """
    # Image.open is lazy and keeps the file open. The context manager closes
    # the handle deterministically. np.array() runs inside the block, so the
    # pixels are fully loaded before the file is closed.
    with Image.open(image_path) as image:
        if mode is not None:
            # convert() returns a new in-memory image. The `with` still closes
            # the original file-backed image on exit.
            image = image.convert(mode)
        return np.array(image)

# Define model factory
def model_factory(model_name):
    """
    Build a pretrained detection model by name and switch it to eval mode.

    Supported names (case-sensitive):
        'RetinaNet'  -> torchvision retinanet_resnet50_fpn
        'FasterRCNN' -> torchvision fasterrcnn_resnet50_fpn
        'SSDLite'    -> torchvision ssdlite320_mobilenet_v3_large
        'Yolov5'     -> 'yolov5s' loaded from the 'ultralytics/yolov5' repo
                        via torch.hub.load

    Args:
        model_name: One of the names above.

    Returns:
        The constructed model after ``model.eval()`` has been called on it.

    Raises:
        ValueError: if *model_name* is not one of the supported names.
    """
    # Builders are zero-argument callables, so only the requested model's
    # weights are constructed/downloaded.
    # NOTE(review): pretrained=/pretrained_backbone= are the legacy torchvision
    # keywords (deprecated in favour of weights= from torchvision 0.13). They
    # are kept because the installed torchvision version is unknown.
    builders = {
        'RetinaNet': lambda: detection.retinanet_resnet50_fpn(
            pretrained=True, pretrained_backbone=True),
        'FasterRCNN': lambda: detection.fasterrcnn_resnet50_fpn(
            pretrained=True, pretrained_backbone=True),
        # The torchvision constructor is ssdlite320_mobilenet_v3_large. The
        # previous name, ssd_lite_mobilenet_v3_large, does not exist and
        # raised AttributeError.
        'SSDLite': lambda: detection.ssdlite320_mobilenet_v3_large(
            pretrained=True, pretrained_backbone=True),
        'Yolov5': lambda: torch.hub.load(
            'ultralytics/yolov5', 'yolov5s', pretrained=True),
    }
    builder = builders.get(model_name)
    if builder is None:
        raise ValueError('Invalid model name')
    model = builder()
    # Put the model in inference mode (affects dropout/batch-norm behaviour).
    model.eval()
    return model


# Define function to detect with a given model
def detect_with_model(image, model_name):
    """
    Run object detection on a single image with the model named *model_name*.

    A fresh model is built through model_factory on every call.

    Args:
        image: NumPy array of shape (H, W, 3), such as load_image returns.
            A (H, W) grayscale array is replicated into 3 channels, and a
            4th (alpha) channel is dropped. uint8 arrays are treated as 0-255
            pixel values and rescaled to [0, 1]. Any other dtype is passed
            through unscaled (NOTE(review): assumes such input is already in
            [0, 1] -- confirm with callers).
        model_name: Any name accepted by model_factory.

    Returns:
        For 'RetinaNet', 'FasterRCNN' and 'SSDLite': a list with one dict
        per input image, holding 'boxes', 'labels' and 'scores'.
        For 'Yolov5': whatever the torch.hub model returns for an HWC NumPy
        image. This is presumably YOLOv5's post-processed ``Detections``
        object; confirm against the installed ultralytics/yolov5 version.

    Raises:
        ValueError: if *model_name* is not supported (raised by model_factory).
    """
    model = model_factory(model_name)

    # Normalise the channel layout to (H, W, 3). Both the CHW permute below
    # and the models' 3-channel input layers require it.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = image[..., :3]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if device.type == "cuda":
        print("using gpu")

    with torch.no_grad():
        if model_name == 'Yolov5':
            # Given an HWC NumPy image, the hub model letterboxes it to a
            # stride-compatible size, rescales it and applies NMS itself.
            # Passing a raw tensor would skip all of that.
            return model(image)

        tensor = torch.from_numpy(image).permute(2, 0, 1).float()
        if image.dtype == np.uint8:
            # torchvision detection models expect float pixels in [0, 1].
            # Their internal normalisation assumes that range.
            tensor = tensor / 255.0
        # torchvision detection models take a list of (C, H, W) tensors.
        return model([tensor.to(device)])

In [None]:
# Demo: run RetinaNet on the sample photo and dump the raw detections.
arr = load_image(
    "/workspace/tests/pexels-pixabay-45201.jpg"
)
result = detect_with_model(arr, model_name="RetinaNet")
print(result)

In [None]:
# Demo: run Faster R-CNN on the sample photo and dump the raw detections.
arr = load_image(
    "/workspace/tests/pexels-pixabay-45201.jpg"
)
result = detect_with_model(arr, model_name="FasterRCNN")
print(result)

In [None]:
# Demo: run SSDLite on the sample photo and dump the raw detections.
arr = load_image(
    "/workspace/tests/pexels-pixabay-45201.jpg"
)
result = detect_with_model(arr, model_name="SSDLite")
print(result)

In [None]:
# Demo: run YOLOv5 on the sample photo and dump the raw detections.
arr = load_image(
    "/workspace/tests/pexels-pixabay-45201.jpg"
)
result = detect_with_model(arr, model_name="Yolov5")
print(result)