In [None]:
import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    """Create a fresh, empty directory at ``file_path``.

    If something already exists at ``file_path`` it is deleted with
    ``shutil.rmtree`` first, so the resulting directory never carries over
    old contents. Missing parent directories are created as needed.

    Args:
        file_path: Path of the directory to (re)create.
    """
    if not os.path.exists(file_path):
        # Nothing to clear out; just build the directory chain.
        os.makedirs(file_path)
        return

    # Wipe the stale tree, then recreate it empty.
    rmtree(file_path)
    os.makedirs(file_path)


def main():
    """Split the flower photo dataset into train/val/test directories.

    Reads the class sub-directories of ``<cwd>/flower_data/flower_photos``
    and copies every image file into ``train1``, ``val1`` or ``test1`` under
    ``<cwd>/flower_data``, keeping one sub-directory per class. The three
    output directories are recreated from scratch on every run.

    The split is seeded with ``random.seed(0)``. Class names and image names
    are sorted, and the leftover (non-train) images are kept in list order,
    so the same input directory always yields the same split.

    Raises:
        AssertionError: If ``flower_data/flower_photos`` does not exist.
    """
    random.seed(0)

    # Adjust the split ratios
    train_ratio = 0.8
    val_ratio = 0.1
    test_ratio = 0.1

    cwd = os.getcwd()
    data_root = os.path.join(cwd, "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    # os.listdir returns entries in arbitrary order; sort so the seeded RNG
    # is consumed in the same class order on every run.
    flower_class = sorted(
        cla for cla in os.listdir(origin_flower_path)
        if os.path.isdir(os.path.join(origin_flower_path, cla))
    )

    train_root = os.path.join(data_root, "train1")
    val_root = os.path.join(data_root, "val1")
    test_root = os.path.join(data_root, "test1")

    # Recreate each split root and its per-class sub-directories.
    for split_root in (train_root, val_root, test_root):
        mk_file(split_root)
        for cla in flower_class:
            mk_file(os.path.join(split_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        # Keep only regular files (copy() cannot handle directories such as
        # notebook checkpoint folders) and sort for a deterministic order.
        images = sorted(
            name for name in os.listdir(cla_path)
            if os.path.isfile(os.path.join(cla_path, name))
        )
        num = len(images)

        # Split the names for training, validation, and test sets.
        # Sets give O(1) membership tests in the copy loop below.
        train_names = set(random.sample(images, k=int(num * train_ratio)))
        # Build the remainder as an ordered list, not a set difference:
        # set iteration order of strings varies between interpreter runs,
        # which would defeat the fixed seed.
        remaining = [name for name in images if name not in train_names]
        val_count = int(len(remaining) * val_ratio / (val_ratio + test_ratio))
        val_names = set(random.sample(remaining, k=val_count))

        for index, image in enumerate(images):
            image_path = os.path.join(cla_path, image)

            if image in train_names:
                new_path = os.path.join(train_root, cla)
            elif image in val_names:
                new_path = os.path.join(val_root, cla)
            else:
                # Whatever is neither train nor val goes to the test set.
                new_path = os.path.join(test_root, cla)

            copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")

        print()

    print("processing done!")


# Run the dataset split only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
