In [38]:
from torchvision.datasets import VOCDetection
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_tensor, to_pil_image
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import numpy as np
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt

In [2]:
# VOC 2007 dataset을 저장할 위치
path2data = './data'
if not os.path.exists(path2data):
    os.mkdir(path2data)

In [25]:
classes = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor"
]

In [45]:
class MyCustomDatasets(VOCDetection):
    def __getitem__(self, idx):
        img = np.array(Image.open(self.images[idx]).convert('RGB'))
        target = self.parse_voc_xml(ET.parse(self.annotations[idx]).getroot())

        coordinates = []
        labels = []

        for obj in target['annotation']['object']:
            coordinates.append([val for key, val in obj['bndbox'].items()])
            labels.append(classes.index(obj['name']))

        if self.transforms:
            img, target = self.transforms(img, coordinates)

        return img, coordinates, labels


In [46]:
train_ds = MyCustomDatasets(path2data, year='2007', image_set='train', download=True)
val_ds = MyCustomDatasets(path2data, year='2007', image_set='test', download=True)

Using downloaded and verified file: ./data/VOCtrainval_06-Nov-2007.tar
Extracting ./data/VOCtrainval_06-Nov-2007.tar to ./data
Using downloaded and verified file: ./data/VOCtest_06-Nov-2007.tar
Extracting ./data/VOCtest_06-Nov-2007.tar to ./data


In [68]:
img, coordinates, labels = train_ds[2]
colors = np.random.randint(0, 255, size=(80,3), dtype=np.uint8)

img = to_pil_image(img)
draw = ImageDraw.Draw(img)
coordinates = np.array(coordinates)
for bbox, label in zip(coordinates, labels):
    color = [int(c) for c in colors[label]]
    name = classes[label]

    draw.rectangle((bbox[0], bbox[1], bbox[2],bbox[3]), outline=(255,0,255), width=3)
    draw.text((bbox[0],bbox[1]), name, fill=(255,255,255,0))
plt.imshow(np.array(img))

TypeError: ignored

In [58]:
colors = np.random.randint(0, 255, size=(80,3), dtype=np.uint8)
colors[2].dtype

dtype('uint8')

In [50]:
coordinates

[['9', '230', '245', '500'],
 ['230', '220', '334', '500'],
 ['2', '178', '90', '500'],
 ['2', '1', '117', '369'],
 ['3', '2', '243', '462'],
 ['225', '1', '334', '486']]