# In [1]:
import numpy as np
import random
import os
from glob import glob
import pathlib
from PIL import Image

def img_train_test_split(img_dir, split_size=0.25):
    """Split a folder of per-class image directories into train/test copies.

    ``img_dir`` must contain one sub-directory per class. For each class,
    ``int(n * (1 - split_size))`` randomly selected files are copied into
    ``<img_dir>/train/<class>`` and the remaining files into
    ``<img_dir>/test/<class>``. The original class folders are not modified.

    Files are copied byte-for-byte rather than re-encoded, so image quality,
    format and file extension always stay consistent with the source.

    Args:
        img_dir: Path (``str`` or ``os.PathLike``) of the dataset root.
        split_size: Fraction of each class assigned to the test set; must be
            in the closed interval ``[0, 1]``.

    Raises:
        ValueError: If ``split_size`` is outside ``[0, 1]``.

    Note:
        Files already present in a destination folder are never overwritten.
        Running the function again on an already populated ``train``/``test``
        tree draws a new random split and only adds the missing files, so
        remove those folders first if a clean split is required.
    """
    import shutil  # local import: the file's import block does not provide it

    if not 0 <= split_size <= 1:
        raise ValueError(
            "split_size must be between 0 and 1, got %r" % (split_size,)
        )

    def create_dir(path):
        """Create ``path`` and report the outcome; an existing dir is fine."""
        try:
            os.mkdir(path)
        except FileExistsError:
            print("Directory %s already exists" % path)
        except OSError:
            print("Creation of the directory %s failed" % path)
        else:
            print("Successfully created the directory %s " % path)

    def copy_images(names, src_dir, dst_dir):
        """Copy each named file from ``src_dir`` to ``dst_dir`` if missing."""
        for name in names:
            destination = os.path.join(dst_dir, name)
            if not os.path.exists(destination):
                shutil.copy2(os.path.join(src_dir, name), destination)

    img_dir = pathlib.Path(img_dir)

    # Only real directories are classes. The output folders are skipped so a
    # second run does not treat 'train' and 'test' as classes of their own.
    output_dirs = {'train', 'test'}
    CLASS_NAMES = np.array(sorted(
        item.name
        for item in img_dir.glob('*')
        if item.is_dir() and item.name not in output_dirs
    ))
    print('Found {} classes:'.format(len(CLASS_NAMES)), CLASS_NAMES)

    train_path = os.path.join(img_dir, 'train')
    test_path = os.path.join(img_dir, 'test')

    create_dir(train_path)
    create_dir(test_path)

    for class_name in CLASS_NAMES:
        # The original folder containing the images of this class.
        original_class_path = os.path.join(img_dir, class_name)

        # Sort before shuffling so that a seeded ``random`` gives the same
        # split regardless of the order the filesystem lists entries in.
        images = sorted(
            entry
            for entry in os.listdir(original_class_path)
            if os.path.isfile(os.path.join(original_class_path, entry))
        )
        random.shuffle(images)
        num_train = int(len(images) * (1 - split_size))

        # The folder where the images used for training are stored.
        train_class_path = os.path.join(train_path, class_name)
        create_dir(train_class_path)
        copy_images(images[:num_train], original_class_path, train_class_path)

        # The folder where the images used for testing are stored.
        test_class_path = os.path.join(test_path, class_name)
        create_dir(test_class_path)
        copy_images(images[num_train:], original_class_path, test_class_path)