In [3]:
import os
import cv2
import xml.etree.ElementTree as ET
import numpy as np
from matplotlib import pyplot as plt


def get_data(input_path, cat=None):
    all_imgs = []

    classes_count = {}

    class_mapping = {}

    visualise = False

    # add VOC2012 if using.
    data_paths = [os.path.join(input_path, s) for s in ['VOC2012']]
    
    print("data path:", data_paths)

    print('Parsing annotation files')
 

    for data_path in data_paths:

        annot_path = os.path.join(data_path, 'Annotations')
        imgs_path = os.path.join(data_path, 'JPEGImages')

        # load all train images or only one category.
        imgsets_path_trainval = os.path.join(
            data_path, 'ImageSets', 'Main', 'train.txt')
        imgsets_path_test = os.path.join(
            data_path, 'ImageSets', 'Main', 'val.txt')

        trainval_files = []
        test_files = []
        try:
            with open(imgsets_path_trainval) as f:
                for line in f:
                    trainval_files.append(line.strip() + '.jpg')
        except Exception as e:
            print(e)

        try:
            with open(imgsets_path_test) as f:
                for line in f:
                    test_files.append(line.strip() + '.jpg')
        except Exception as e:
            if data_path[-7:] == 'VOC2012':
                # this is expected, most pascal voc distibutions dont have the test.txt file
                pass
            else:
                print(e)

        annots = [os.path.join(annot_path, s) for s in os.listdir(annot_path)]
        idx = 0
        for annot in annots:
            try:
                idx += 1

                et = ET.parse(annot)
                element = et.getroot()

                element_objs = element.findall('object')
                element_filename = element.find('filename').text
                element_width = int(element.find('size').find('width').text)
                element_height = int(element.find('size').find('height').text)

                if len(element_objs) > 0:
                    annotation_data = {'filepath': os.path.join(imgs_path, element_filename), 'width': element_width,
                                       'height': element_height, 'bboxes': []}

                    if element_filename in trainval_files:#if the image id appears in the trainval list 
                        annotation_data['imageset'] = 'trainval'
                    elif element_filename in test_files: #if the image id appears in the test list ("val.txt")
                        annotation_data['imageset'] = 'test'
                    else: #everything else goes to the trainval 
                        annotation_data['imageset'] = 'trainval'

                for element_obj in element_objs:
                    class_name = element_obj.find('name').text
                    if class_name not in classes_count:
                        classes_count[class_name] = 1
                    else:
                        classes_count[class_name] += 1

                    if class_name not in class_mapping:
                        class_mapping[class_name] = len(class_mapping)

                    obj_bbox = element_obj.find('bndbox')
                    x1 = int(round(float(obj_bbox.find('xmin').text)))
                    y1 = int(round(float(obj_bbox.find('ymin').text)))
                    x2 = int(round(float(obj_bbox.find('xmax').text)))
                    y2 = int(round(float(obj_bbox.find('ymax').text)))
                    difficulty = 1  # parse all files.
                    annotation_data['bboxes'].append(
                        {'class': class_name, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'difficult': difficulty})

                if cat and class_name == cat:
                    all_imgs.append(annotation_data)
                elif not cat:
                    all_imgs.append(annotation_data)

                if visualise:
                    img = plt.imread(annotation_data['filepath'])
                    for bbox in annotation_data['bboxes']:
                        cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox[
                            'x2'], bbox['y2']), (0, 0, 255))
                    
                        plt.figure(figsize=(15,10))
                        plt.title('Image with Bounding Box')
                        plt.imshow(img)
                        plt.axis("off")
                        plt.show()

               

            except Exception as e:
                print(e)
                continue
    return all_imgs, classes_count, class_mapping


In [4]:
all_imgs,classes_count,class_mapping=get_data(r'D:\Computer science\FASTER_RCNN_COLAB\TRAIN\VOCdevkit')

data path: ['D:\\Computer science\\FASTER_RCNN_COLAB\\TRAIN\\VOCdevkit\\VOC2012']
Parsing annotation files


In [5]:
all_imgs

[{'filepath': 'D:\\Computer science\\FASTER_RCNN_COLAB\\TRAIN\\VOCdevkit\\VOC2012\\JPEGImages\\2007_000027.jpg',
  'width': 486,
  'height': 500,
  'bboxes': [{'class': 'person',
    'x1': 174,
    'x2': 349,
    'y1': 101,
    'y2': 351,
    'difficult': 1}],
  'imageset': 'trainval'},
 {'filepath': 'D:\\Computer science\\FASTER_RCNN_COLAB\\TRAIN\\VOCdevkit\\VOC2012\\JPEGImages\\2007_000032.jpg',
  'width': 500,
  'height': 281,
  'bboxes': [{'class': 'aeroplane',
    'x1': 104,
    'x2': 375,
    'y1': 78,
    'y2': 183,
    'difficult': 1},
   {'class': 'aeroplane',
    'x1': 133,
    'x2': 197,
    'y1': 88,
    'y2': 123,
    'difficult': 1},
   {'class': 'person',
    'x1': 195,
    'x2': 213,
    'y1': 180,
    'y2': 229,
    'difficult': 1},
   {'class': 'person',
    'x1': 26,
    'x2': 44,
    'y1': 189,
    'y2': 238,
    'difficult': 1}],
  'imageset': 'trainval'},
 {'filepath': 'D:\\Computer science\\FASTER_RCNN_COLAB\\TRAIN\\VOCdevkit\\VOC2012\\JPEGImages\\2007_000033.jpg

In [5]:
classes_count

{'chair': 1432,
 'car': 1644,
 'horse': 406,
 'person': 5447,
 'bicycle': 418,
 'cat': 389,
 'dog': 538,
 'train': 328,
 'aeroplane': 331,
 'diningtable': 310,
 'tvmonitor': 367,
 'bird': 599,
 'bottle': 634,
 'motorbike': 390,
 'pottedplant': 625,
 'boat': 398,
 'sofa': 425,
 'sheep': 353,
 'cow': 356,
 'bus': 272}