In [1]:
from pathlib import Path
import random
import cv2
import xml.etree.ElementTree as ET
import shutil
import tqdm

In [2]:
# Start from empty train/test folders so files from a previous run cannot leak
# into the new split. The existence check lets this cell run on a fresh
# checkout, where shutil.rmtree would otherwise raise FileNotFoundError, while
# still surfacing real deletion errors (e.g. permissions).
for split_dir in (Path('training_demo/images/train'), Path('training_demo/images/test')):
    if split_dir.exists():
        shutil.rmtree(split_dir)
    split_dir.mkdir(exist_ok=True, parents=True)

In [3]:
import pdb

In [4]:
# Only images with exactly this array shape are kept; OpenCV arrays are laid
# out as (height, width, channels).
EXPECTED_SHAPE = (2208, 4608, 3)
# Fraction of each class that goes to the training split; the rest is test.
TRAIN_FRACTION = 0.9

for cls in ['banana_red', 'apple_kashmir']:
    cls_dataset = []
    # Sort the annotations: glob order depends on the filesystem, so the seeded
    # shuffle below is only reproducible when it starts from a fixed order.
    for xml_path in sorted(Path(f'object_detection_dataset/{cls}').glob('*.xml')):
        img_path = xml_path.with_suffix('.jpg')
        img = cv2.imread(str(img_path))
        # cv2.imread returns None (it does not raise) when the image is missing
        # or unreadable; such annotations are skipped instead of crashing.
        if img is not None and img.shape == EXPECTED_SHAPE:
            cls_dataset.append((img_path, xml_path))

    # Re-seed per class so each class is shuffled the same way on every run.
    random.seed(0)
    random.shuffle(cls_dataset)
    split_at = int(len(cls_dataset) * TRAIN_FRACTION)
    train_split = cls_dataset[:split_at]
    test_split = cls_dataset[split_at:]

    # Copy each image together with its annotation into the split folder.
    for split, dest_dir in (
        (train_split, 'training_demo/images/train/'),
        (test_split, 'training_demo/images/test/'),
    ):
        for img_path, xml_path in split:
            shutil.copy(img_path, dest_dir)
            shutil.copy(xml_path, dest_dir)

In [5]:
# Target resolution in pixels, read by resize_xml and passed to resize_img.
new_width, new_height = 640, 320

def resize_xml(tree, target_width=None, target_height=None):
    """Rescale an annotation tree in place to a new image size.

    Reads the original size from ``<size>/<width>`` and ``<size>/<height>``,
    overwrites it with the target size, and scales the ``xmin``/``xmax``/
    ``ymin``/``ymax`` values of every ``<object>/<bndbox>`` by the same ratio.

    Args:
        tree: Parsed ``ElementTree`` of the annotation; modified in place.
        target_width: Output width in pixels (int). Defaults to the
            notebook-level ``new_width``.
        target_height: Output height in pixels (int). Defaults to the
            notebook-level ``new_height``.

    Raises:
        ValueError: If the size recorded in the annotation is not positive.
    """
    if target_width is None:
        target_width = new_width
    if target_height is None:
        target_height = new_height

    root = tree.getroot()
    size = root.find('size')
    width_node = size.find('width')
    height_node = size.find('height')
    width = int(width_node.text)
    height = int(height_node.text)
    if width <= 0 or height <= 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError(f'annotation has invalid image size {width}x{height}')

    width_node.text = str(target_width)
    height_node.text = str(target_height)

    # Each bndbox tag is paired with the (source, target) extent of its axis.
    axes = (
        ('xmin', width, target_width),
        ('xmax', width, target_width),
        ('ymin', height, target_height),
        ('ymax', height, target_height),
    )
    for obj in root.iter('object'):
        box = obj.find('bndbox')
        for tag, src_extent, dst_extent in axes:
            node = box.find(tag)
            # Integer arithmetic keeps the floor exact; the float form
            # int(x / src * dst) can land one pixel low (29/100*100 -> 28).
            node.text = str(int(node.text) * dst_extent // src_extent)

def resize_img(img, new_width, new_height, interpolation=None):
    """Return ``img`` resized to ``new_width`` x ``new_height`` pixels.

    Args:
        img: Image array as returned by ``cv2.imread``.
        new_width: Output width in pixels.
        new_height: Output height in pixels.
        interpolation: Optional OpenCV interpolation flag. When omitted,
            ``cv2.INTER_AREA`` is used if either dimension shrinks (OpenCV's
            recommendation for downscaling, which avoids aliasing) and
            ``cv2.INTER_LINEAR`` otherwise.

    Returns:
        The resized image array.

    Raises:
        ValueError: If ``img`` is None, which is what ``cv2.imread`` returns
            for a missing or unreadable file.
    """
    if img is None:
        raise ValueError(
            'resize_img received None; cv2.imread returns None when the '
            'file is missing or unreadable'
        )
    if interpolation is None:
        src_height, src_width = img.shape[:2]
        shrinking = new_width < src_width or new_height < src_height
        interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LINEAR
    return cv2.resize(img, (new_width, new_height), interpolation=interpolation)

# Train split: rescale annotations and images in place

In [13]:
# Rescale every annotation/image pair of the training split, overwriting the
# files copied into training_demo/images/train.
xml_files = sorted(Path('training_demo/images/train').glob('*.xml'))
for file in tqdm.tqdm(xml_files):
    img_path = file.with_suffix('.jpg')

    # Load the image before touching the annotation: cv2.imread returns None
    # for a missing/unreadable file, and failing here keeps the XML untouched.
    img = cv2.imread(str(img_path))
    if img is None:
        raise FileNotFoundError(f'Could not read image {img_path} for annotation {file}')

    # resize xml
    tree = ET.parse(str(file))
    root = tree.getroot()
    # Normalise the folder label to the on-disk class directory name.
    folder_node = root.find('folder')
    if folder_node is not None and folder_node.text == 'Apple Kashmir':
        folder_node.text = 'apple_kashmir'
    # Point <path> at the copied image; the element is skipped when absent.
    path_node = root.find('path')
    if path_node is not None:
        path_node.text = str(img_path.resolve())
    resize_xml(tree)
    tree.write(str(file))

    # resize img
    resized_img = resize_img(img, new_width, new_height)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(str(img_path), resized_img):
        raise OSError(f'Could not write resized image {img_path}')

100%|███████████████████████████████████████████████████████████████████████| 122/122 [00:14<00:00,  8.62it/s]


# Test split: rescale annotations and images in place

In [14]:
# Rescale every annotation/image pair of the test split, overwriting the
# files copied into training_demo/images/test.
xml_files = sorted(Path('training_demo/images/test').glob('*.xml'))
for file in tqdm.tqdm(xml_files):
    img_path = file.with_suffix('.jpg')

    # Load the image before touching the annotation: cv2.imread returns None
    # for a missing/unreadable file, and failing here keeps the XML untouched.
    img = cv2.imread(str(img_path))
    if img is None:
        raise FileNotFoundError(f'Could not read image {img_path} for annotation {file}')

    # resize xml
    tree = ET.parse(str(file))
    root = tree.getroot()
    # Normalise the folder label to the on-disk class directory name.
    folder_node = root.find('folder')
    if folder_node is not None and folder_node.text == 'Apple Kashmir':
        folder_node.text = 'apple_kashmir'
    # Point <path> at the copied image; the element is skipped when absent.
    path_node = root.find('path')
    if path_node is not None:
        path_node.text = str(img_path.resolve())
    resize_xml(tree)
    tree.write(str(file))

    # resize img
    resized_img = resize_img(img, new_width, new_height)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(str(img_path), resized_img):
        raise OSError(f'Could not write resized image {img_path}')

100%|█████████████████████████████████████████████████████████████████████████| 15/15 [00:01<00:00,  7.93it/s]
