In [1]:
import timm
import torch
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
import torchvision.transforms as transforms
from typing import List, Any
import matplotlib.pyplot as plt
import numpy as np
import  matplotlib.patches as patches

In [2]:
# Scratch helper: uncomment to list the timm model names matching the pattern.
# timm.list_models('*efficientnetv2*')

# Locations of the COCO 2017 image folders and their instance-annotation JSON
# files on the shared /datagrid filesystem. These are absolute, machine-specific
# paths; adjust them when running elsewhere.
train_root : str = "/datagrid/public_datasets/COCO/train2017"
val_root : str = "/datagrid/public_datasets/COCO/val2017"
train_annotations : str = "/datagrid/public_datasets/COCO/annotations/instances_train2017.json"
val_annotations : str = "/datagrid/public_datasets/COCO/annotations/instances_val2017.json"

class COCODataset(Dataset):
    """Map-style dataset over a COCO image folder and its annotation file.

    Indexing yields ``(image, annotations)``: ``image`` is an RGB PIL image and
    ``annotations`` is the list of annotation records pycocotools returns for
    that image. The image is returned *untransformed*; the tensor pipeline is
    kept on ``self.transforms`` so callers that need a normalized tensor can
    apply ``dataset.transforms(image)`` themselves.

    Attributes:
        root: directory containing the image files.
        coco: ``COCO`` index built from ``annFile``.
        ids: image ids taken from ``coco.imgs``, in ascending order.
        transforms: ``ToTensor`` followed by ``Normalize`` with the mean/std
            constants given in ``__init__``.
        category_names: mapping of category id -> category name.
    """

    def __init__(self, root, annFile):
        """Index the annotation file and prepare the lookup tables.

        Args:
            root: directory containing the image files.
            annFile: path to the COCO annotation JSON handed to ``COCO``.
        """
        super().__init__()
        self.root = root
        self.coco = COCO(annFile)
        # sorted() already returns a fresh list; a fixed order keeps
        # index -> image id stable between runs.
        self.ids = sorted(self.coco.imgs.keys())
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.category_names = {
            cat['id']: cat['name']
            for cat in self.coco.loadCats(self.coco.getCatIds())
        }

    def _load_image(self, id: int) -> "Image.Image":
        """Open the image with the given COCO image id and return it as RGB."""
        file_name = self.coco.loadImgs(id)[0]["file_name"]
        # convert() produces an independent in-memory image, so the file
        # opened here can be released as soon as the conversion is done.
        with Image.open(Path(self.root) / file_name) as raw:
            return raw.convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return every annotation record attached to the given image id."""
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index):
        """Return ``(PIL image, annotation list)`` for the ``index``-th image id.

        ``self.transforms`` is intentionally not applied here: the returned
        value has always been the raw PIL image, so running the transform
        would only produce a tensor that is immediately discarded.
        """
        image_id = self.ids[index]
        return self._load_image(image_id), self._load_target(image_id)

    def __len__(self) -> int:
        """Number of images indexed by the annotation file."""
        return len(self.ids)



In [5]:
# Build the training-split dataset. Constructing it loads and indexes the whole
# annotation JSON, which took about 21 s in the recorded output below.
dataset = COCODataset(train_root, train_annotations)

loading annotations into memory...
Done (t=21.01s)
creating index...
index created!


In [None]:
# Show the category-id -> name mapping. Note from the output below that the ids
# are not contiguous (e.g. there is no 12 or 26), so they cannot be used
# directly as 0..N-1 class indices.
print(dataset.category_names)

{1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse', 75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'micro

In [6]:
%matplotlib inline
import matplotlib.pyplot as plt
def plt_img_bb(img, targets, category_names=None, save_path="tesp.jpg"):
    """Draw bounding boxes and class labels over an image, then save and show it.

    The figure is sized so that one image pixel corresponds to one figure
    pixel at 80 dpi, and the axes fill the whole figure.

    Args:
        img: image exposing ``.size`` as ``(width, height)`` and accepted by
            ``Axes.imshow`` (the notebook passes the PIL image from the dataset).
        targets: iterable of annotation dicts, each providing ``"category_id"``
            and ``"bbox"`` as ``(x, y, width, height)``.
        category_names: optional mapping of category id -> display name.
            When omitted, the notebook-global ``dataset.category_names`` is
            used, so existing two-argument calls keep working.
        save_path: file the rendered figure is written to. Repeated calls with
            the same path overwrite the previous file.
    """
    if category_names is None:
        category_names = dataset.category_names

    # Materialise once so a one-shot iterable can be read for labels and boxes.
    annotations = list(targets)
    labels = [int(ann["category_id"]) for ann in annotations]
    boxes = [ann["bbox"] for ann in annotations]

    # One colour per distinct category, spread evenly over the colormap in
    # ascending category-id order; a dict gives O(1) lookup per box.
    cmap = plt.get_cmap("tab20b")
    unique_labels = sorted(set(labels))
    palette = [cmap(pos) for pos in np.linspace(0, 1, len(unique_labels))]
    color_by_label = dict(zip(unique_labels, palette))

    dpi = 80
    width, height = img.size
    figsize = width / float(dpi), height / float(dpi)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0, 0, 1, 1])
    ax.imshow(img)
    for (x1, y1, w, h), label in zip(boxes, labels):
        color = color_by_label[label]
        ax.add_patch(
            patches.Rectangle((x1, y1), w, h, linewidth=3, edgecolor=color, facecolor="none")
        )
        ax.text(x1, y1, s=category_names[label], color="white", fontsize="x-large",
                verticalalignment="top", bbox={"color": color, "pad": 0}, fontfamily="serif")
    ax.axis("off")
    fig.savefig(save_path, dpi=dpi, transparent=True)
    plt.show()
    # Release the figure explicitly: this function is called in loops, and
    # figures left open accumulate in pyplot's registry.
    plt.close(fig)

In [9]:
# Preview the first 50 samples with their annotated boxes drawn on top.
# The bound is clamped so a dataset with fewer than 50 items does not raise.
for i in range(min(50, len(dataset))):
    im, bb = dataset[i]
    plt_img_bb(im, bb)

display-im6.q16: unable to open X server `' @ error/display.c/DisplayImageCommand/432.
