In [15]:
import os.path as osp
import torch
import torch.utils.data as data
import cv2
import numpy as np
import xml.etree.ElementTree as ET
from torchvision import datasets
from torchvision.transforms import transforms

# torchvision's built-in VOC2007 'val' split, read from ../datasets (no download);
# each image goes through a Compose pipeline containing only ToTensor.
dataset = datasets.VOCDetection(
    root='../datasets',
    download=False,
    year='2007',
    transform=transforms.Compose([transforms.ToTensor()]),
    image_set='val',
)
# First (image, target) sample, pulled through the iterator protocol.
a = next(iter(dataset))



<class 'dict'>


In [24]:
# The 20 PASCAL VOC object categories, listed alphabetically. A category's
# position in this tuple is the integer label used by the annotation
# transform (0 = 'aeroplane', ..., 19 = 'tvmonitor').
VOC_CLASSES = (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)

In [None]:
class VOCAnnotationTransform():
    """Convert a VOC XML annotation into normalized boxes with integer labels.

    Args:
        class_to_ind (dict, optional): mapping from lower-cased class name to
            integer label index. Defaults to each name's position in
            ``VOC_CLASSES``. An empty dict also falls back to that default.
        keep_difficult (bool, optional): whether to keep objects whose
            ``<difficult>`` flag is 1 (default: False).
    """

    def __init__(self, class_to_ind=None, keep_difficult=False):
        # `or` rather than `is None` is kept on purpose: any falsy mapping
        # (including {}) falls back to the default VOC class indexing.
        self.class_to_ind = class_to_ind or {
            name: idx for idx, name in enumerate(VOC_CLASSES)}
        self.keep_difficult = keep_difficult

    def __call__(self, target, width, height):
        """Extract ``[xmin, ymin, xmax, ymax, label_idx]`` rows from an annotation.

        Args:
            target (xml.etree.ElementTree.Element): root of a VOC annotation;
                every ``<object>`` element found by ``target.iter`` is inspected.
            width (int or float): divisor applied to x coordinates.
            height (int or float): divisor applied to y coordinates.

        Returns:
            list[list]: one ``[xmin, ymin, xmax, ymax, label_idx]`` entry per
            kept object, in document order. Each coordinate has 1 subtracted
            and is then divided by ``width`` (x) or ``height`` (y). Returns
            ``[]`` when no object is kept.

        Raises:
            KeyError: if an object's class name is not in ``class_to_ind``.
        """
        res = []
        for obj in target.iter('object'):
            difficult_node = obj.find('difficult')
            # Some annotation files omit <difficult>; treat that as "not difficult"
            # instead of crashing on None.text.
            difficult = difficult_node is not None and int(difficult_node.text) == 1
            if difficult and not self.keep_difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            bndbox = []
            for i, pt in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
                # float() also accepts non-integer coordinate strings ("48.5");
                # for integer strings the result is identical to int(). The -1
                # presumably shifts VOC's 1-based pixel coordinates to 0-based.
                cur_pt = float(bbox.find(pt).text) - 1
                # Even positions (xmin, xmax) are x coordinates, odd ones are y.
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            bndbox.append(self.class_to_ind[name])
            res.append(bndbox)  # [xmin, ymin, xmax, ymax, label_idx]

        return res

In [None]:
class VOCDetection(data.Dataset):
    """Map-style PASCAL VOC detection dataset that loads images with OpenCV.

    Sample ids are read from ``<root>/VOC<year>/ImageSets/Main/<split>.txt``.
    For each id, the annotation is parsed from ``Annotations/<id>.xml`` and the
    image is read from ``JPEGImages/<id>.jpg`` under the same ``VOC<year>`` dir.

    Args:
        root (str): directory containing the ``VOC<year>`` folder.
        split (str): split file name in ``ImageSets/Main`` without ``.txt``
            (default ``'trainval'``).
        transform (callable, optional): joint transform called as
            ``transform(img, boxes, labels)`` that must return
            ``(img, boxes, labels)``. This is NOT a torchvision-style
            single-argument transform.
        target_transform (callable, optional): called as
            ``target_transform(annotation_root, width, height)``. Defaults to a
            new ``VOCAnnotationTransform()`` per dataset; pass ``None`` to keep
            the raw ``ET.Element``.
        year (str or int): dataset year, e.g. ``'2007'``.
    """

    # Sentinel default: builds the default target transform per instance
    # instead of once at class-definition time (shared default-argument pitfall).
    _DEFAULT_TARGET_TRANSFORM = object()

    def __init__(self, root, split='trainval', transform=None,
                 target_transform=_DEFAULT_TARGET_TRANSFORM, year='2007'):
        if target_transform is self._DEFAULT_TARGET_TRANSFORM:
            target_transform = VOCAnnotationTransform()
        self.root = root
        # '%s' placeholders are filled with the (rootpath, image_id) tuples in self.ids.
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        self.transform = transform
        self.target_transform = target_transform
        rootpath = osp.join(root, 'VOC' + str(year))
        split_file = osp.join(rootpath, 'ImageSets', 'Main', split + '.txt')
        # Close the split file deterministically and skip blank lines (e.g. a
        # trailing newline) that would otherwise become bogus sample ids.
        with open(split_file) as f:
            self.ids = [(rootpath, line.strip()) for line in f if line.strip()]

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``; see ``pull_item``."""
        img, target, _, _ = self.pull_item(index)
        return img, target

    def pull_item(self, index):
        """Load sample ``index``.

        Returns:
            tuple: ``(img, target, height, width)``. ``img`` is a torch tensor
            permuted to channel-first order. ``target`` is whatever
            ``target_transform`` returned, or, when ``transform`` is set, an
            ``(N, 5)`` array of ``[x1, y1, x2, y2, label]`` rows. ``height`` and
            ``width`` are those of the image as read from disk, before
            ``transform``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_id = self.ids[index]

        target = ET.parse(self._annopath % img_id).getroot()
        img_path = self._imgpath % img_id
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError('Could not read image: %s' % img_path)
        height, width = img.shape[:2]

        if self.target_transform is not None:
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            target = np.array(target)
            if target.size == 0:
                # No kept objects (e.g. all "difficult"): keep a 2-D (0, 5)
                # array so the column slicing below does not raise IndexError.
                target = target.reshape(0, 5)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            # OpenCV loads BGR; reorder channels to RGB.
            img = img[:, :, (2, 1, 0)]
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        # NOTE(review): without a transform the image is returned in OpenCV's
        # BGR order (no RGB swap), unlike the transform path — confirm intended.
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width

    def __len__(self):
        """Number of sample ids read from the split file."""
        return len(self.ids)

    def pull_image(self, index):
        """Return the raw image for ``index`` as read by ``cv2.imread`` (BGR, no transform)."""
        img_id = self.ids[index]
        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)

    def pull_anno(self, index):
        """Return ``(image_id, target_transform(annotation, 1, 1))`` for ``index``.

        Passing width = height = 1 means coordinates are not normalized.
        Requires ``target_transform`` to be set.
        """
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        gt = self.target_transform(anno, 1, 1)
        return img_id[1], gt

    def pull_tensor(self, index):
        """Return the raw image for ``index`` as a float tensor with a leading batch dim.

        The layout is not permuted (``(1, H, W, C)``) and the channels stay in
        ``cv2.imread`` order.
        """
        return torch.Tensor(self.pull_image(index)).unsqueeze_(0)