### **数据预处理**

**一、数据集label格式转化：xml to yolo**

In [2]:
import xml.etree.ElementTree as ET
import glob
import os
from tqdm import tqdm

**xml格式处理简单学习**

In [3]:
import xml.etree.ElementTree as ET

# Glob pattern matching every .xml file under VOC2012/Annotations.
Annotations_path = "VOC2012/Annotations/*.xml"

# Parse only the first match: this cell just inspects what a parsed root looks like.
for filename in glob.glob(Annotations_path)[:1]:
    tree = ET.parse(filename)
    root = tree.getroot()
    print(root)

<Element 'annotation' at 0x7582dc2e4040>


In [4]:
# Walk the root's direct children and show each one's tag, attributes and text.
for child in root:
    summary = f"元素名: {child.tag}, 属性: {child.attrib}, 文本: {child.text}"
    print(summary)

元素名: filename, 属性: {}, 文本: 2011_007141.jpg
元素名: folder, 属性: {}, 文本: VOC2011
元素名: object, 属性: {}, 文本: 
		
元素名: segmented, 属性: {}, 文本: 0
元素名: size, 属性: {}, 文本: 
		
元素名: source, 属性: {}, 文本: 
		


In [5]:
# Look up a single child element by its tag name (find() returns the first match).
filename = root.find('filename')
# Show the type of the element's text, then the text itself, one per line.
print(type(filename.text), filename.text, sep="\n")

<class 'str'>
2011_007141.jpg


In [7]:
# A tag such as <object> can occur several times under the root, so iterate
# over every match with iter() instead of find(), which returns only the first.
# Collect one record per annotated object.
objects = []
for obj in root.iter('object'):
    obj_data = {
        'name': obj.find('name').text,
        # Box corners in the order [xmin, ymin, xmax, ymax], converted to int.
        'bbox': [
            int(obj.find('bndbox/xmin').text),
            int(obj.find('bndbox/ymin').text),
            int(obj.find('bndbox/xmax').text),
            int(obj.find('bndbox/ymax').text)
        ]
    }
    # Append inside the loop so every object is kept, not just the last one;
    # an annotation without any <object> then simply yields an empty list.
    objects.append(obj_data)
print(objects)

[{'name': 'person', 'bbox': [260, 91, 356, 306]}]


**下面开始处理数据集**

将每个xml中的图片filename、目标的类别及bbox提取出来（一张图可能包含多个目标）

In [15]:
import random
import xml.etree.ElementTree as ET


def write_split(xml_files, out_file, class_map):
    """Write one annotation line per XML file to ``out_file``.

    Each line starts with the image file name followed by a space, then one
    ``"<xmin> <ymin> <xmax> <ymax> <class_idx> "`` group per ``<object>``
    found in the XML, and ends with a newline.

    Args:
        xml_files: iterable of paths to annotation XML files.
        out_file: writable text file object that receives the lines.
        class_map: dict mapping class name -> integer index. Class names not
            seen before are added in place, numbered in order of first
            appearance.
    """
    for xml_path in xml_files:
        root = ET.parse(xml_path).getroot()
        img_name = root.find("filename").text
        out_file.write(f"{img_name} ")
        for obj in root.iter("object"):
            obj_name = obj.find("name").text
            if obj_name not in class_map:
                # Next free index == number of classes registered so far.
                class_map[obj_name] = len(class_map)
            x_min = obj.find("bndbox/xmin").text
            y_min = obj.find("bndbox/ymin").text
            x_max = obj.find("bndbox/xmax").text
            y_max = obj.find("bndbox/ymax").text
            # Field order: xmin ymin xmax ymax class_idx (the fourth field is
            # y_max; writing y_min twice would give every box zero height).
            out_file.write(f"{x_min} {y_min} {x_max} {y_max} {class_map[obj_name]} ")
        out_file.write("\n")


class_map = {}  # class name -> integer index (class2idx)

# Glob pattern of the annotation XML files.
Annotations_path = "VOC2012/Annotations/*.xml"

# Shuffle with a fixed seed, then split 80% / 20% into train and test.
# NOTE(review): glob.glob() returns paths in arbitrary order, so the seed only
# reproduces the same split when the directory listing order is the same;
# consider sorting all_xml_files first if cross-machine reproducibility matters.
all_xml_files = glob.glob(Annotations_path)
random.seed(666)
shuffled_data = random.sample(all_xml_files, len(all_xml_files))  # shuffled copy; the input list is untouched
split_idx = int(len(all_xml_files) * 0.8)
train_xml_files = shuffled_data[:split_idx]
test_xml_files = shuffled_data[split_idx:]

# "w" truncates any existing file; the with-block closes both files even if
# parsing one of the annotations raises.
with open("train.txt", "w", encoding='utf-8') as f_train, \
        open("test.txt", "w", encoding='utf-8') as f_test:
    # Train is processed first, so class indices are assigned in that order.
    write_split(train_xml_files, f_train, class_map)
    write_split(test_xml_files, f_test, class_map)

class_nums = len(class_map)

# Sanity check on the discovered classes.
print(class_map)
print(class_nums)  # expected: 20 classes

{'person': 0, 'bird': 1, 'bus': 2, 'car': 3, 'sofa': 4, 'cat': 5, 'chair': 6, 'cow': 7, 'motorbike': 8, 'bicycle': 9, 'bottle': 10, 'diningtable': 11, 'dog': 12, 'pottedplant': 13, 'aeroplane': 14, 'train': 15, 'horse': 16, 'boat': 17, 'tvmonitor': 18, 'sheep': 19}
20


**二、生成Dataset，可用于迭代器生成**