In [1]:
import os

import torch
from torchvision.datasets import CocoDetection

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from PIL import Image
   
class myCocoDetection(CocoDetection):
    """CocoDetection variant that returns corner-format boxes and category ids.

    Each annotation's ``'bbox'`` entry is read as ``[x, y, width, height]`` and
    converted to ``[x1, y1, x2, y2]``.

    Args:
        root: Directory containing the image files.
        annFile: Path to the COCO annotation JSON file.
        remove_invalid_data: When true, image ids without any annotation are
            dropped from ``self.ids``.
        transform: Callable applied to the PIL image only, or a falsy value to
            get the raw image and plain Python lists back. The boxes are NOT
            passed through it, so it must not change the image geometry
            (no resize / crop / flip) or the boxes will no longer line up.
    """

    def __init__(
        self, root, annFile, remove_invalid_data, transform):
        super(myCocoDetection, self).__init__(root, annFile)

        if remove_invalid_data:
            # Keep only the image ids that have at least one annotation.
            self.ids = [
                img_id
                for img_id in self.ids
                if self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            ]

        # Assigned after the parent constructor so it is not reset by it.
        self.transform = transform

    def _load_image(self, idx: int):
        """Open the image with COCO id ``idx`` from ``self.root`` as RGB."""
        path = self.coco.loadImgs(idx)[0]["file_name"]
        return Image.open(os.path.join(self.root, path)).convert("RGB")

    def _load_target(self, idx):
        """Return the list of annotation dicts for the image with id ``idx``."""
        return self.coco.loadAnns(self.coco.getAnnIds(idx))

    def __getitem__(self, index: int):
        """Return the sample at position ``index`` of ``self.ids``.

        Returns:
            With a transform: ``(image, targets)`` where ``targets`` is a dict
            holding ``'boxes'`` (float32 tensor of shape ``(N, 4)``) and
            ``'labels'`` (int64 tensor of shape ``(N,)``).
            Without a transform: ``(image, bboxes, labels)`` where the image is
            a PIL image and the other two are plain Python lists.
        """
        idx = self.ids[index]
        image = self._load_image(idx)
        target = self._load_target(idx)

        bboxes, labels = [], []
        for obj in target:
            x, y, width, height = obj['bbox'][:4]
            # [x, y, w, h] -> [x1, y1, x2, y2]
            bboxes.append([x, y, x + width, y + height])
            labels.append(obj['category_id'])

        if self.transform:
            image = self.transform(image)

            # Build the tensors directly instead of running the image transform
            # over the box array: an image transform would rescale/normalise or
            # reject the coordinates, and it fails on an empty annotation list.
            targets = {
                'boxes': torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4),
                'labels': torch.as_tensor(labels, dtype=torch.int64),
            }

            return image, targets

        return image, bboxes, labels
    
    

In [3]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

# Locations of the COCO val2017 images and their instance annotations,
# relative to the notebook's working directory.
image_path = '../../data/COCO/val2017'
ann_path = '../../data/COCO/annotations/instances_val2017.json'
def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``, so samples of
    different sizes are kept side by side instead of being stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Build the dataset (images without annotations are dropped) and a loader that
# keeps variable-sized samples as tuples via collate_fn.
dataset = myCocoDetection(root=image_path, annFile=ann_path, transform=transforms.ToTensor(), remove_invalid_data=True)
dataset_loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
# Use the built-in next(): the Python-2-style `.next()` alias on the loader
# iterator is not part of the iterator protocol and is absent from recent
# PyTorch releases.
print(next(iter(dataset_loader)))

loading annotations into memory...
Done (t=0.70s)
creating index...
index created!
((tensor([[[0.6667, 0.6784, 0.6863,  ..., 0.2706, 0.2667, 0.2745],
         [0.6745, 0.6902, 0.6941,  ..., 0.2706, 0.2824, 0.2784],
         [0.6863, 0.6941, 0.6980,  ..., 0.2745, 0.2706, 0.2784],
         ...,
         [0.7373, 0.7176, 0.7569,  ..., 0.7294, 0.7294, 0.7333],
         [0.7294, 0.7333, 0.7294,  ..., 0.7765, 0.7647, 0.7294],
         [0.7294, 0.7333, 0.7294,  ..., 0.5059, 0.4941, 0.4196]],

        [[0.5333, 0.5569, 0.5647,  ..., 0.2980, 0.2980, 0.2784],
         [0.5529, 0.5686, 0.5725,  ..., 0.3020, 0.3137, 0.2941],
         [0.5647, 0.5725, 0.5765,  ..., 0.3059, 0.3020, 0.2941],
         ...,
         [0.7412, 0.7176, 0.7333,  ..., 0.6157, 0.6157, 0.6118],
         [0.7176, 0.7216, 0.7176,  ..., 0.5255, 0.4706, 0.3451],
         [0.7176, 0.7216, 0.7176,  ..., 0.2353, 0.2235, 0.1608]],

        [[0.2863, 0.3020, 0.3098,  ..., 0.1647, 0.1529, 0.1451],
         [0.3020, 0.3137, 0.3176,  ...

In [42]:
def coco_show(dataloader, figsize):
    """Plot the first batch of a detection dataloader with its boxes and labels.

    Args:
        dataloader: Iterable whose batches unpack to ``(images, targets)``.
            Each image is indexed as channel-first (it is permuted to
            height-width-channel for display) and each target is a dict with
            ``'boxes'`` in ``[x1, y1, x2, y2]`` form and ``'labels'`` holding
            indices into the class-name table below.
        figsize: ``(rows, cols)`` of the subplot grid; it is unpacked into
            ``plt.subplots``. The figure size itself is fixed at 15x10 inches.
            If the grid has fewer cells than the batch has images, only the
            first ``rows * cols`` images are drawn.
    """
    CLASSES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]

    images, targets = next(iter(dataloader))

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so ravel()
    # always works.
    fig, axes = plt.subplots(*figsize, figsize=(15,10), squeeze=False)
    axes = axes.ravel()

    # zip() stops at the shorter of the grid and the batch, so a grid smaller
    # than the batch no longer raises IndexError.
    for ax, image, target in zip(axes, images, targets):
        ax.imshow(image.permute(1,2,0))

        for box, label in zip(target['boxes'], target['labels']):
            x_min, y_min, x_max, y_max = box.tolist()
            # Rectangle takes (x, y), width, height -- not the opposite corner.
            rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                     linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.text(x_min, y_min, CLASSES[int(label)], fontsize=15)

        ax.axis('off')

    # Hide any grid cells left over when the batch is smaller than the grid.
    for ax in axes[len(images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [6]:
# Fetch one batch; built-in next() replaces the `.next()` alias, which is
# absent from recent PyTorch releases.
images, targets = next(iter(dataset_loader))

In [None]:
# Walk the boxes and labels of the first image in the batch. The original loop
# had no body, used an undefined index, and iterated the target dict directly
# (which yields its keys rather than box/label pairs).
for box, label in zip(targets[0]['boxes'], targets[0]['labels']):
    print(label.item(), box.tolist())

In [9]:
# Display the target dict of the first image in the batch.
targets[0]

{'boxes': tensor([[236.9800, 142.5100, 261.6800, 212.0100],
         [  7.0300, 167.7600, 156.3500, 262.6300],
         [557.2100, 209.1900, 638.5600, 287.9200],
         [358.9800, 218.0500, 414.9800, 320.8800],
         [290.6900, 218.0000, 352.5200, 316.4800],
         [413.2000, 223.0100, 443.3700, 304.3700],
         [317.4000, 219.2400, 338.9800, 230.8300],
         [412.8000, 157.6100, 465.8500, 295.6200],
         [384.4300, 172.2100, 399.5500, 207.9500],
         [512.2200, 205.7500, 526.9600, 221.7200],
         [493.1000, 174.3400, 513.3900, 282.6500],
         [604.7700, 305.8900, 619.1100, 351.6000],
         [613.2400, 308.2400, 626.1200, 354.6800],
         [447.7700, 121.1200, 461.7400, 143.0000],
         [549.0600, 309.4300, 585.7400, 399.1000],
         [350.7600, 208.8400, 362.1300, 231.3900],
         [412.2500, 219.0200, 421.8800, 231.5400],
         [241.2400, 194.9900, 255.4600, 212.6200],
         [336.7900, 199.5000, 346.5200, 216.2300],
         [321.2100, 23