# In [1]:
from sklearn.model_selection import train_test_split
import shutil
import os

def _copy_splits(splits, image_dir, annotation_dir, root):
    """Copy every split's files into ``root/<split>/images`` and ``root/<split>/labels``.

    Args:
        splits: Mapping of split name (e.g. ``'train'``) to a pair
            ``(image_file_names, label_file_names)``.
        image_dir: Folder the image files are copied from.
        annotation_dir: Folder the label files are copied from.
        root: Folder under which the ``<split>/images`` and ``<split>/labels``
            folders are created (existing folders are reused).
    """
    for split_name, (split_images, split_labels) in splits.items():
        for files, source_dir, kind in (
            (split_images, image_dir, 'images'),
            (split_labels, annotation_dir, 'labels'),
        ):
            target_dir = os.path.join(root, split_name, kind)
            # Created even when `files` is empty so the folder layout is always complete.
            os.makedirs(target_dir, exist_ok=True)
            for name in files:
                shutil.copy(os.path.join(source_dir, name), os.path.join(target_dir, name))


def split_dataset(image_dir, annotation_dir, input_directory, train_ratio=0.8, *,
                  random_state=None, overwrite_yolo_dataset=None):
    """Split paired image/label files into train and valid folders.

    An image is kept only if ``annotation_dir`` contains a file with the same
    base name and a ``.txt`` extension (e.g. ``dog_01.jpg`` -> ``dog_01.txt``).
    Kept pairs are split with scikit-learn's ``train_test_split`` and copied to::

        {input_directory}/train/images   {input_directory}/train/labels
        {input_directory}/valid/images   {input_directory}/valid/labels

    When ``overwrite_yolo_dataset`` is truthy, the same layout is also written
    under ``{input_directory}/dataset/``. The existing ``train`` and ``valid``
    folders there are deleted first.

    Args:
        image_dir: Folder containing the images (.png/.jpg/.jpeg, any letter case).
        annotation_dir: Folder containing the ``.txt`` label files.
        input_directory: Root folder the split is written into.
        train_ratio: Passed to ``train_test_split`` as ``train_size``. A float
            is the fraction of pairs assigned to the train split.
        random_state: Seed passed to ``train_test_split``. Pass an int for a
            reproducible split. ``None`` (default) gives a different split on
            every call.
        overwrite_yolo_dataset: Whether to (re)build ``{input_directory}/dataset/``.
            ``None`` (default) falls back to a module-level global of the same
            name, or ``False`` if no such global is defined.

    Raises:
        ValueError: If no image has a matching annotation file. sklearn may
            also raise it for an invalid ``train_ratio``.
    """
    if overwrite_yolo_dataset is None:
        # Backward compatibility: this flag used to be read only from a
        # module-level global rather than passed in as an argument.
        overwrite_yolo_dataset = globals().get('overwrite_yolo_dataset', False)

    images = sorted(
        f for f in os.listdir(image_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    annotations = sorted(os.listdir(annotation_dir))
    # A set makes each per-image membership test O(1) instead of a list scan.
    annotation_set = set(annotations)

    # Ensure corresponding annotation files exist (FILE NAMES NEED TO MATCH
    # EXACTLY, apart from the image extension being replaced by .txt).
    images_with_annotations = []
    annotations_filtered = []
    for image in images:
        annotation = image.rsplit('.', 1)[0] + '.txt'
        if annotation in annotation_set:
            images_with_annotations.append(image)
            annotations_filtered.append(annotation)

    print(f'Found {len(images)} images.')
    print(f'Found {len(annotations)} annotations.')
    print(f'Found {len(images_with_annotations)} images with annotations.')
    print(f'Found {len(annotations_filtered)} matching annotations.')

    if not images_with_annotations:
        raise ValueError(
            f'No image in {image_dir!r} has a matching .txt file in {annotation_dir!r}.'
        )

    # Both lists are split in lockstep, so each image stays with its label.
    train_images, valid_images, train_annotations, valid_annotations = train_test_split(
        images_with_annotations, annotations_filtered,
        train_size=train_ratio, random_state=random_state,
    )
    splits = {
        'train': (train_images, train_annotations),
        'valid': (valid_images, valid_annotations),
    }

    # NOTE(review): files already present in these folders are kept. Files from
    # an earlier, different split can therefore remain next to the new ones.
    _copy_splits(splits, image_dir, annotation_dir, input_directory)

    if overwrite_yolo_dataset:
        dataset_root = os.path.join(input_directory, 'dataset')
        # Delete the previous split first. shutil.copy only replaces files
        # with the same name, so otherwise a re-run with a new random split
        # would leave stale files behind, and the same image could end up in
        # both train and valid. Only the split folders are removed; other
        # files in dataset/ are left untouched.
        for split_name in splits:
            split_path = os.path.join(dataset_root, split_name)
            if os.path.isdir(split_path):
                shutil.rmtree(split_path)
        _copy_splits(splits, image_dir, annotation_dir, dataset_root)

# --- Run configuration -----------------------------------------------------
# split_dataset() consults this module-level flag to decide whether it also
# writes the split under {input_directory}/dataset/.
overwrite_yolo_dataset = True

# Project root; the source images/ and labels/ folders live directly beneath it.
input_directory = '/Users/eric/Desktop/2-Career/Projects/ObjectDetection/dog_park/'
images_dir = f'{input_directory}/images'
annotations_dir = f'{input_directory}/labels'

# 80% of the matched image/label pairs go to train, the rest to valid.
split_dataset(
    image_dir=images_dir,
    annotation_dir=annotations_dir,
    input_directory=input_directory,
    train_ratio=0.8,
)


# Output:
# Found 330 images.
# Found 330 annotations.
# Found 330 images with annotations.
# Found 330 annotations.
