In [1]:
"""python
    xml文件解析
"""

import json
import os
import torch
import random
import xml.etree.ElementTree as ET    #解析xml文件所用工具
import torchvision.transforms.functional as FT

In [2]:
# os.walk方法用于通过在目录树中游走输出在目录中的文件名
def getFileList(path):
    for root, dirs, files in os.walk(path):
        return files
file_name = getFileList('../../train/03/Annotations/')
file_name

In [3]:
#解析xml文件，最终返回这张图片中所有目标的标注框及其类别信息，以及这个目标是否是一个difficult目标
def parse_annotation(annotation_path,name):
    #解析xml
    tree = ET.parse(os.path.join(annotation_path,name))
    tree_name = os.path.splitext(name)[0]
    root = tree.getroot()

    boxes = list()    #存储bbox  分别表示目标的左上角和右下角坐标
    labels = list()    #存储bbox对应的label  目标类别
    difficulties = list()    #存储bbox对应的difficult信息  表示此目标是否是一个难以识别的目标

    #遍历xml文件中所有的object，有多少个object就有多少个目标
    for object in root.iter('object'):
        #提取每个object的difficult、label、bbox信息
        difficult = int(object.find('difficult').text == '1')
        label = object.find('name').text.lower().strip()  ## strip去除首尾的空格
        bbox = object.find('bndbox')
        xmin = int(bbox.find('xmin').text) - 1
        ymin = int(bbox.find('ymin').text) - 1
        xmax = int(bbox.find('xmax').text) - 1
        ymax = int(bbox.find('ymax').text) - 1
        #存储
        boxes.append([xmin, ymin, xmax, ymax])
        # labels.append(annotation_path[-11:-9])
        labels.append(label)
        difficulties.append(difficult)
    
    #返回包含图片标注信息的字典
    return {'name':tree_name,'boxes': boxes, 'labels': labels, 'difficulties': difficulties}

In [4]:
# test 
tmp=parse_annotation('../../../train/02/Annotations/','02_0017.xml')
tmp

{'name': '02_0017',
 'boxes': [[123, 127, 176, 154]],
 'labels': ['defect'],
 'difficulties': [0]}

In [5]:
def create_data_lists(path, output_folder):
    
    #获取数据集的绝对路径
    path = os.path.abspath(path)

    train_images = list()
    train_objects = list()
    n_objects = 0
    
    file_name = getFileList(path)
    print(file_name)
    #根据图片id，解析图片的xml文件，获取标注信息
    for id in file_name:
        # Parse annotation's XML file
        objects = parse_annotation(path, id)
        print(os.path.join(path, id))
        print(objects)
        if len(objects['boxes']) == 0:    #如果没有目标则跳过
            continue
        n_objects += len(objects['boxes'])        #统计目标总数
        train_objects.append(objects)    #存储每张图片的标注信息到列表train_objects

    with open(os.path.join(output_folder, 'TRAIN_objects.json'), 'w') as j:
        json.dump(train_objects, fp=j)

In [6]:
# create_data_lists(path='../../../train/01/Annotations',output_folder='../../../train/01')
create_data_lists(path='../../../train/02/Annotations',output_folder='../../../train/02')
# create_data_lists(path='../../../train/03/Annotations',output_folder='../../../train/03')

['000081.xml', '000082.xml', '000083.xml', '000084.xml', '000085.xml', '000086.xml', '000087.xml', '000088.xml', '000089.xml', '000090.xml', '000091.xml', '000092.xml', '000093.xml', '000094.xml', '000095.xml', '000096.xml', '000097.xml', '000098.xml', '000099.xml', '000100.xml', '000101.xml', '000102.xml', '000103.xml', '000104.xml', '000105.xml', '000106.xml', '000107.xml', '000108.xml', '000109.xml', '000110.xml', '000111.xml', '000112.xml', '000113.xml', '000114.xml', '000115.xml', '000116.xml', '000117.xml', '000118.xml', '000119.xml', '000120.xml', '000121.xml', '000122.xml', '000123.xml', '000124.xml', '000125.xml', '000126.xml', '000127.xml', '000128.xml', '000129.xml', '000130.xml', '000131.xml', '000132.xml', '000133.xml', '000134.xml', '000135.xml', '000136.xml', '000137.xml', '000138.xml', '000139.xml', '000140.xml', '000141.xml', '000142.xml', '000143.xml', '000144.xml', '000145.xml', '000146.xml', '000147.xml', '000148.xml', '000149.xml', '000150.xml', '000151.xml', '0001