In [79]:
import os

import torch
from torchvision.io import ImageReadMode
from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

In [80]:
# Step 1: Initialize model with the best available weights
# Detections scoring at or below this confidence are discarded by the model
# itself, so everything it returns is drawn.
BOX_SCORE_THRESH = 0.9

weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=BOX_SCORE_THRESH)
# Switch to evaluation mode: the detector then returns predictions
# (boxes / labels / scores) instead of training losses.
model.eval()

# Step 2: Initialize the inference transforms
# The preprocessing preset that ships with (and matches) the chosen weights.
preprocess = weights.transforms()

In [81]:
class TestDataset(torch.utils.data.Dataset):
    """Dataset over the image files found directly inside ``root``.

    Each item is the decoded image as returned by ``read_image`` in RGB mode.
    ``transforms`` is stored on the instance but is NOT applied in
    ``__getitem__``; callers are expected to run it themselves.
    """

    # File extensions treated as images (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root, transforms=None):
        """Index the image files in ``root``.

        Args:
            root: Directory containing the images.
            transforms: Optional callable kept as ``self.transforms``.
        """
        self.root = root
        self.transforms = transforms
        # Keep only files with an image extension. The comparison is done on
        # the lower-cased name so that e.g. 'photo.PNG' or 'photo.JPG' are not
        # silently skipped. Sorting gives a deterministic item order.
        self.imgs = sorted(
            name for name in os.listdir(root)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __getitem__(self, idx):
        """Return the image at position ``idx`` of ``self.imgs``."""
        img_path = os.path.join(self.root, self.imgs[idx])
        # read the image and remove alpha channel
        return read_image(img_path, ImageReadMode.RGB)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.imgs)

# Dataset over the local "eagle_images" folder; the weights' preprocessing
# preset is handed over as the dataset's `transforms`.
test_data = TestDataset(
    root="eagle_images",
    transforms=preprocess,
)

In [82]:
# Sanity check: show which filenames the dataset picked up from the folder.
print(test_data.imgs)

['Duke Farms_2022_03_08_13_58_46.png', 'Duke Farms_2022_03_08_14_57_27.png', 'Duke Farms_2022_03_14_23_06_55.png', 'Duke Farms_2022_03_16_17_01_15.png', 'Duke Farms_2022_04_02_09_26_42.png', 'Duke Farms_2022_04_26_07_53_24.png', 'Duke Farms_2022_05_07_18_35_13.png', 'National Arboretum A_2022_03_21_19_05_20.png', 'National Arboretum A_2022_03_29_19_43_05.png', 'National Arboretum A_2022_04_04_06_28_34.png', 'National Arboretum A_2022_04_04_08_37_52.png', 'National Arboretum A_2022_04_28_11_13_18.png', 'National Arboretum A_2022_05_02_01_24_18.png', 'National Arboretum A_2022_05_07_03_23_45.png', 'National Arboretum A_2022_05_07_10_34_48.png', 'National Arboretum A_2022_05_09_03_41_02.png', 'National Arboretum A_2022_05_15_08_33_14.png', 'eagle_in_nest.jpg', 'eagle_in_nest2.jpg', 'empty_nest.jpg', 'hawk.jpg']


In [83]:
# Iterate the dataset one image per batch, in dataset order, without
# worker subprocesses.
test_data_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [84]:
# Run the detector over every batch and show each image with its detections.
categories = weights.meta["categories"]

for batch in test_data_loader:
    # Pure inference: disable autograd so no graph is built or kept around.
    with torch.no_grad():
        predictions = model(preprocess(batch))

    # The model returns one prediction dict per input image; pair each image
    # with its own prediction so this stays correct for any batch size.
    for image, prediction in zip(batch, predictions):
        # Caption per box: "<category name> <score rounded to 3 decimals>".
        labels = [
            categories[label] + " " + str(round(score, 3))
            for label, score in zip(
                prediction["labels"].tolist(),
                prediction["scores"].tolist(),
            )
        ]

        # Draw on the original image from the batch, not the preprocessed copy.
        box = draw_bounding_boxes(image, boxes=prediction["boxes"],
                                labels=labels,
                                colors="red",
                                width=4)
        im = to_pil_image(box)
        im.show()



KeyboardInterrupt: 