# In [1]:
from pathlib import Path
from PIL import Image
from torchvision import transforms
from sklearn.model_selection import train_test_split

# File extensions that are considered image files when scanning class folders.
accepted_suffix = ('.jpg', '.png', '.jpeg', '.JPG', '.PNG', '.bmp', '.gif')
root_path = Path('./')

### User settings: source folder (one sub-folder per class) and output folder.
input_dir = 'all_images'
output_dir = 'training'
###

# Fraction of each class handed to the training split (used as train_size).
split_ratio = 0.8

# Resize followed by a centre crop, both with size 64, applied to every
# image before it is written to the output tree.
resize_transform = transforms.Compose(
    [transforms.Resize(64), transforms.CenterCrop(64)]
)

original_data_path = root_path / input_dir

# Map each split name to its output directory and make sure it exists.
data_path = {}
for x in ('train', 'val'):
    data_path[x] = root_path / output_dir / x
    data_path[x].mkdir(parents=True, exist_ok=True)

# For every class sub-folder of the input directory: collect the usable
# images, split them into train/val, and save resized copies under
# data_path['train']/<class> and data_path['val']/<class>.
for original_class_path in original_data_path.iterdir():
    # Stray files next to the class folders are ignored.
    if not original_class_path.is_dir():
        continue
    class_name = original_class_path.name

    all_images = []
    # iterdir() yields entries in an OS-dependent order; sorting makes the
    # seeded split below reproducible across machines and runs.
    for filename in sorted(original_class_path.iterdir()):
        # Compare case-insensitively so '.JPEG', '.BMP', '.Gif', ... are
        # accepted just like the lowercase spellings.
        if filename.suffix.lower() not in accepted_suffix:
            continue
        try:
            with Image.open(filename) as img:
                # Only RGB images are kept; other modes are left out.
                if img.mode == 'RGB':
                    all_images.append(filename.name)
        except Exception as e:
            print(f'Error occured. {filename}: {e}')

    # train_test_split raises ValueError when a split would end up empty
    # (zero or one sample), which would abort the whole run; skip such
    # classes and carry on with the remaining ones.
    if len(all_images) < 2:
        print(f'Skipping {class_name}: found {len(all_images)} usable '
              f'image(s), need at least 2 to split.')
        continue

    images = {}
    images['train'], images['val'] = train_test_split(
        all_images,
        train_size=split_ratio,
        random_state=42,
    )

    for x in ['train', 'val']:
        class_path = data_path[x] / class_name
        class_path.mkdir(parents=True, exist_ok=True)
        for image_name in images[x]:
            src = original_class_path / image_name
            dst = class_path / image_name
            with Image.open(src) as img:
                resized_img = resize_transform(img)
                resized_img.save(dst)