# In [2]:
import shutil
import os
import numpy as np
import argparse

def get_files_from_folder(path: str) -> np.ndarray:
    """
    List every entry (files and subdirectories alike) in the given folder.

    path: Path to the folder to list.
    return: ndarray of the entry names (bare names, not full paths), in the
        arbitrary order produced by ``os.listdir``. An empty folder yields an
        empty array (float64 dtype, since ``np.asarray([])`` has no strings to
        infer a string dtype from).
    """
    entries = os.listdir(path)
    return np.asarray(entries)

def main(path_to_data: str, path_to_test_data: str, train_ratio: float):
    """
    Split a class-per-folder dataset into train and test sets by moving files.

    Every immediate subdirectory of ``path_to_data`` is treated as one class.
    For each class with ``n`` entries, ``round(n * (1 - train_ratio))`` of them
    are moved into ``path_to_test_data/<class>``; the entries left behind form
    the train set.

    path_to_data: Path to the folder holding the full dataset (one subfolder
        per class). After the split it contains only the train portion.
    path_to_test_data: Path to the folder where the test dataset will be
        stored. Missing class subfolders are created.
    train_ratio: Fraction of each class kept for training, in [0, 1]
        (0.9 keeps 90 % for training and moves 10 % to test). Numeric strings
        such as "0.9" are accepted and converted with ``float``.

    raise FileNotFoundError: if ``path_to_data`` is not an existing directory.
    raise ValueError: if ``train_ratio`` is not a number in [0, 1].

    NOTE: the split is not random -- the entries moved are the first ones in
    ``os.listdir`` order, which is arbitrary but not shuffled. Subdirectories
    nested inside a class folder are counted and moved like files.
    """
    train_ratio = float(train_ratio)
    # Outside [0, 1] the test count is negative (nothing moved) or exceeds the
    # number of files (IndexError); NaN also fails this check.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    # os.walk yields nothing for a missing path, so next() would otherwise
    # fail with an opaque StopIteration.
    if not os.path.isdir(path_to_data):
        raise FileNotFoundError(f"Data directory not found: {path_to_data}")

    # Immediate subdirectories of the dataset root are the classes.
    _, class_dirs, _ = next(os.walk(path_to_data))

    for class_name in class_dirs:
        path_to_original = os.path.join(path_to_data, class_name)
        path_to_save = os.path.join(path_to_test_data, class_name)

        files = get_files_from_folder(path_to_original)
        # np.round rounds halves to even (0.5 -> 0, 1.5 -> 2); with the ratio
        # validated above, test_count never exceeds len(files).
        test_count = int(np.round(len(files) * (1 - train_ratio)))

        os.makedirs(path_to_save, exist_ok=True)
        for name in files[:test_count]:
            shutil.move(os.path.join(path_to_original, name),
                        os.path.join(path_to_save, name))


def parse_args(argv=None):
    """
    Parse the command-line interface used when running this module as a script.

    argv: Argument list to parse; ``None`` (the default) means ``sys.argv[1:]``.
    return: argparse.Namespace with ``data_path`` and
        ``test_data_path_to_save`` (str) and ``train_ratio`` (float).
    """
    parser = argparse.ArgumentParser(description="Dataset divider")
    parser.add_argument("--data_path", required=True,
                        help="Path to data")
    parser.add_argument("--test_data_path_to_save", required=True,
                        help="Path to test data where to save")
    # type=float: main() does arithmetic on the ratio, and argparse would
    # otherwise hand back a str. '%%' is required because argparse applies
    # %-formatting to help strings; a bare '%' crashes --help.
    parser.add_argument("--train_ratio", required=True, type=float,
                        help="Train ratio - 0.9 means splitting data in "
                             "90 %% train and 10 %% test")
    return parser.parse_args(argv)

if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        # CLI mode: paths and ratio come from the command line. float() keeps
        # this working even if --train_ratio is parsed as a string.
        args = parse_args()
        main(args.data_path, args.test_data_path_to_save,
             float(args.train_ratio))
    else:
        # No arguments: keep the original hard-coded behaviour of splitting
        # ./train 90/10, moving the 10 % into ./val.
        main('train', 'val', 0.9)