In [1]:
import os
import re

def update_dict(data_dict, key, img_name):
    """Record one more image for ``key`` in ``data_dict`` and return the dict.

    Each entry has the form ``{'count': <int>, 'img_names': [<str>, ...]}``.
    An entry is created on the first image seen for a key. ``data_dict`` is
    mutated in place; the same object is also returned for convenience.
    """
    entry = data_dict.setdefault(key, {'count': 0, 'img_names': []})
    entry['count'] += 1
    entry['img_names'].append(img_name)
    return data_dict

def get_data_dict(directory='./training_data'):
    """Group the training images by the number at the start of their file name.

    Every file in ``directory`` whose name begins with a run of decimal digits
    is recorded (via ``update_dict``) under the integer value of that prefix.
    Files without a numeric prefix (hidden/system files and the like) are
    skipped instead of raising IndexError.

    Args:
        directory: folder holding the original images. The default matches
            the path used elsewhere in this notebook.

    Returns:
        dict mapping number -> {'count': int, 'img_names': [file names]}.
    """
    data_dict = {}
    # sorted(): os.listdir order is arbitrary, and img_names order later drives
    # output file numbering, so sort for reproducible runs.
    for image_name in sorted(os.listdir(directory)):
        match = re.match(r"\d+", image_name)  # raw string: "\d" is not a valid str escape
        if match is None:
            continue
        data_dict = update_dict(data_dict, int(match.group()), image_name)
    return data_dict
pokemon_image_data = get_data_dict()

In [2]:
from PIL import Image, ImageOps, ImageFilter
import random
import numpy as np


def create_folder(folder_name):
    """Create ``folder_name`` (including missing parent folders) if it does not exist.

    ``exist_ok=True`` makes repeated calls a no-op and avoids the
    check-then-create race of ``isdir`` followed by ``mkdir``.

    Raises:
        FileExistsError: if ``folder_name`` exists but is not a directory.
    """
    os.makedirs(folder_name, exist_ok=True)

# Per-pokedex-number file counter shared by the saving helpers below.
pokedex = {}


def update_pokedex_counts(pokedex_num):
    """Advance the file counter for ``pokedex_num``.

    The first call for a number stores 0 (it only initialises the counter);
    every later call adds one. The value is used to build unique output
    file names.
    """
    pokedex[pokedex_num] = pokedex.get(pokedex_num, -1) + 1

def save_image(img, pokedex_num, out_dir='training_data_augmented', size=(224, 224)):
    """Resize ``img`` and save it as the next numbered file for ``pokedex_num``.

    The file is written to ``<out_dir>/<pokedex_num>_<k>.png``, where ``k`` is
    the current value of ``pokedex[pokedex_num]``. The counter is then
    advanced with ``update_pokedex_counts`` so the next call uses a new name.

    Args:
        img: PIL image; it is converted to RGB before writing.
        pokedex_num: key into the global ``pokedex`` counter.
        out_dir: destination folder, created if missing (default keeps the
            original location).
        size: (width, height) training resolution; ``img`` is resized only
            when its size differs.
    """
    create_folder(out_dir)
    if pokedex_num not in pokedex:
        # Seed the counter (the first update stores 0) instead of raising
        # KeyError for a number that was never initialised.
        update_pokedex_counts(pokedex_num)
    num_images = pokedex[pokedex_num]
    update_pokedex_counts(pokedex_num)
    size = tuple(size)
    if img.size != size:
        img = img.resize(size)  # Training resolution
    # NOTE(review): convert("RGB") discards any alpha channel, so transparent
    # regions keep whatever colour is stored under them. Confirm that is
    # acceptable for the training data.
    img.convert("RGB").save(os.path.join(out_dir, f"{pokedex_num}_{num_images}.png"))




In [3]:
import cv2

def quantizing(image, shift_amount):
    """Posterize an image by clearing the lowest ``shift_amount`` bits of each colour channel.

    For 8-bit input, each channel keeps ``2 ** (8 - shift_amount)`` distinct
    levels.

    Args:
        image: PIL image or integer array of shape (H, W, C) with C >= 3.
            Only the first three channels are kept (e.g. alpha is dropped).
        shift_amount: number of low-order bits to zero.

    Returns:
        numpy array of shape (H, W, 3) with the input's integer dtype.
    """
    # Convert once: np.asarray on a PIL image copies the whole pixel buffer,
    # and the old code did that once per channel.
    pixels = np.asarray(image)[:, :, :3]
    return (pixels >> shift_amount) << shift_amount


def add_noise(image):
    """Return a copy of ``image`` with zero-mean Gaussian noise added.

    The noise standard deviation is drawn uniformly from 0.20, 0.21, ..., 0.70
    (on the 0-255 pixel scale). Noise is added in floating point, rounded and
    clipped to [0, 255]. As a result, negative noise darkens a pixel instead of
    wrapping around to ~255, which is what the earlier unsigned cast of the
    noise did (that produced white speckles rather than Gaussian noise).

    Args:
        image: uint8 numpy array of any shape, e.g. (H, W, 3).

    Returns:
        uint8 array with the same shape as ``image``.
    """
    std = random.randint(20, 70) / 100  # number between 0.2 and 0.7
    # NOTE(review): a std below 1 on a 0-255 scale changes few pixels by more
    # than one level. Confirm this magnitude is intended.
    gaussian = np.random.normal(0, std, image.shape)
    noisy = np.rint(image.astype(np.float64) + gaussian)
    return np.clip(noisy, 0, 255).astype(np.uint8)


def crop_image(im, size=(224, 224)):
    """Randomly crop ``im`` and resize the crop to ``size``.

    The crop bounds are drawn as pixel coordinates for a 224x224 reference
    image: left 10-30, right 150-200, top 0-40, bottom 85-200. They are then
    scaled to the image's actual width and height, so the crop covers the
    same relative region at any input resolution. For a 224x224 input the box
    equals the drawn coordinates exactly.

    Args:
        im: PIL image.
        size: (width, height) of the returned image.

    Returns:
        The cropped and resized image.
    """
    reference = 224  # resolution the crop ranges were designed for
    width, height = im.size
    left = random.randint(10, 30)
    right = random.randint(150, 200)
    top = random.randint(0, 40)
    bottom = random.randint(85, 200)
    box = (
        round(left * width / reference),
        round(top * height / reference),
        round(right * width / reference),
        round(bottom * height / reference),
    )
    return im.crop(box).resize(tuple(size))

In [4]:
def get_augmentation(im, prior_augmentation):
    """Apply one or more randomly chosen augmentations to ``im`` and return the result.

    Options (each drawn uniformly from the allowed range):
        1 - mirror horizontally
        2 - Gaussian blur, radius 1-5
        3 - rotate by 1-359 degrees
        4 - colour quantization, clearing 4-7 low bits per channel
        5 - Gaussian pixel noise
        6 - random crop (``crop_image``)
    When ``prior_augmentation`` is true only options 2-5 are drawn. So mirroring
    and cropping happen at most once, and only as the first augmentation. This
    avoids problems such as cropping the pokemon out of an image that has
    already been transformed. After each augmentation there is a 20% chance
    of stacking another one.

    Args:
        im: PIL image.
        prior_augmentation: whether ``im`` has already been augmented.

    Returns:
        The augmented PIL image.
    """
    while True:
        option = random.randint(2, 5) if prior_augmentation else random.randint(1, 6)
        if option == 1:
            im = ImageOps.mirror(im)  # flips/mirrors the image
        elif option == 2:
            radius = random.randint(1, 5)
            im = im.filter(ImageFilter.GaussianBlur(radius=radius))
        elif option == 3:
            degree = random.randint(1, 359)
            im = im.rotate(degree)
        elif option == 4:
            shift_amount = random.randint(4, 7)
            im = Image.fromarray(quantizing(im, shift_amount))
        elif option == 5:
            im = Image.fromarray(add_noise(np.array(im)))
        elif option == 6:
            im = crop_image(im)

        # 80%: done. 20%: apply another augmentation, excluding mirror/crop.
        if random.random() >= 0.2:
            return im
        prior_augmentation = True

In [5]:
from tqdm import tqdm


def save_old_images():
    """Copy every original image into the augmented folder using the sequential naming scheme.

    For each number in ``pokemon_image_data``, the counter in ``pokedex`` is
    reset to 0 and the originals are saved through ``save_image`` as
    ``<num>_0.png``, ``<num>_1.png``, ... Resetting (rather than advancing) the
    counter makes re-running this cell overwrite the same files instead of
    shifting the numbering and duplicating the originals.
    """
    for pokedex_num, item in tqdm(pokemon_image_data.items()):
        pokedex[pokedex_num] = 0  # (re)initialise the file counter
        for img_name in item['img_names']:
            # The dict key already is the number parsed from img_name, so there
            # is no need to re-parse the file name here.
            with Image.open(f'./training_data/{img_name}') as img:
                save_image(img, pokedex_num)


def augment_images(target_count=151):
    """Top up each pokemon's image set with randomly augmented copies.

    For every pokemon, a random original image is repeatedly picked,
    augmented with ``get_augmentation`` and saved with ``save_image``. This
    continues until ``pokemon_image_data[num]['count']`` reaches
    ``target_count``. Pokemon that already have at least ``target_count``
    originals get no augmented images. The count is updated in place, so
    calling this again afterwards is a no-op.

    Args:
        target_count: total images (originals + augmented) per pokemon. The
            default of 151 reproduces the original loop bound ``count <= 150``.
            NOTE(review): an earlier description said "200 images total".
            Confirm which target is intended.
    """
    for pokedex_num, item in tqdm(pokemon_image_data.items()):
        while item['count'] < target_count:
            # random.choice picks a random original without re-shuffling
            # (and thereby reordering) the shared img_names list each time.
            img_name = random.choice(item['img_names'])
            with Image.open(f"./training_data/{img_name}") as source:
                im = source.convert('RGB')
            item['count'] += 1
            augmented = get_augmentation(im, False)
            save_image(augmented, pokedex_num)


# Order matters: save_old_images initialises each pokedex counter and writes
# the originals first, so the augmented images saved by augment_images
# continue the numbering after them.
save_old_images()
augment_images()


100%|██████████| 898/898 [13:56<00:00,  1.07it/s]
100%|██████████| 898/898 [20:06<00:00,  1.34s/it]
