In [None]:
from mlmodule.utils import list_files_in_dir
from mlmodule.torch.data.images import ImageDataset
from mlmodule.contrib.vinvl import VinVLDetector
from mlmodule.box import BBoxOutput
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torch
from PIL import Image
import os
%matplotlib inline

In [None]:
# Load VinVL model
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs end-to-end (more slowly) on machines without CUDA.
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): score_threshold is presumably the minimum detection confidence
# the detector keeps — confirm against the VinVLDetector documentation.
vinvl = VinVLDetector(device=torch_device, score_threshold=0.5)
# Pretrained model
vinvl.load()

In [None]:
# Getting data
# List the files under the test fixtures directory using the 'jpg' extension
# filter, keep at most the first 50 entries, and wrap them in an ImageDataset.
base_path = os.path.join("../tests", "fixtures", "objects")
jpg_files = list_files_in_dir(base_path, allowed_extensions=('jpg',))
file_names = jpg_files[:50]
dataset = ImageDataset(file_names)

In [None]:
# Get the detections
# bulk_inference returns a pair: the visualisation cell below treats `indices`
# as the image paths and `detections[i]` as the detections for `indices[i]`.
loader_options = {'batch_size': 10}
indices, detections = vinvl.bulk_inference(
    dataset,
    data_loader_options=loader_options,
)

In [None]:
# Get labels and attributes
# `labels` is indexed with a detection's predicted label id when plotting;
# `attribute_labels` is handed to postprocess_attr to name the attributes.
labels, attribute_labels = vinvl.get_labels(), vinvl.get_attribute_labels()

In [None]:
from mlmodule.contrib.vinvl.utils import postprocess_attr

# Only attributes scoring strictly above this value are listed for an object.
ATTR_SCORE_THRESHOLD = 0.5

# For every image: print its detections (attributes, label, probability) and
# draw each detection's bounding box and label on top of the image.
for i, img_path in enumerate(indices):
    print(f'Object with attributes detected for {img_path}')
    # Release the source file as soon as the RGB copy has been made.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    # Explicit figure/axes instead of relying on pyplot's "current axes" state.
    fig, ax = plt.subplots()
    ax.imshow(img)
    for k, det in enumerate(detections[i]):
        label = labels[det.labels[0]]
        # Build the mask once and reuse it so the selected attribute ids and
        # their scores are guaranteed to stay aligned.
        keep = det.attr_scores > ATTR_SCORE_THRESHOLD
        attributes = postprocess_attr(
            attribute_labels, det.attributes[keep], det.attr_scores[keep]
        )
        print(f'{k+1}: {",".join(attributes[0])} {label} ({det.probability:.2f})')
        # bounding_box holds two corner points: the rectangle is anchored at
        # the first one and sized by the difference to the second.
        x0, y0 = det.bounding_box[0].x, det.bounding_box[0].y
        x1, y1 = det.bounding_box[1].x, det.bounding_box[1].y
        ax.add_patch(Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                               edgecolor='red', linewidth=2, alpha=0.5))
        ax.text(x0, y0, f'{label}', color='white', fontsize=12)
    # Render now so each figure appears right after its printed detections.
    # With the inline backend this also closes the figure, so the loop never
    # holds dozens of open figures (avoids matplotlib's max-open warning).
    plt.show()
