https://teddylee777.github.io/pytorch/pytorch-image-transforms/

https://pytorch.org/vision/main/transforms.html
- https://pytorch.org/vision/main/generated/torchvision.transforms.AutoAugment.html

In [1]:
from torch.utils.data import Dataset
from torchvision import transforms
from pandas.core.common import flatten
import glob

In [5]:
# Root directory the source images are read from, and the root directory the
# augmented copies are written to.
original_dataset_path = './datasets/recaptcha-dataset'
augmented_dataset_path = './datasets/augmented_recaptcha-dataset'

# Augmentation pipeline: resize to a fixed 384x384, then apply torchvision's
# AutoAugment using the policy set labelled CIFAR10.
image_augmentation = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
])

In [3]:
from PIL import Image

class RecaptchaDataset(Dataset):
    """Folder-per-class image dataset yielding ``(RGB PIL image, label index)``.

    Every sub-directory of ``dataset_path`` is treated as one class whose name
    is the directory name, except for the names in ``EXCLUDED_ENTRIES``.

    Attributes:
        labels: Class names, in sorted order.
        image_paths: Flat list of image file paths, grouped by class.
        idx_to_label: Mapping from label index to class name.
        label_to_idx: Mapping from class name to label index.
    """

    # Top-level entries of the dataset directory that are never used as classes.
    EXCLUDED_ENTRIES = ('Mountain', 'Other', 'readme.txt')

    def __init__(self, dataset_path):
        """Scan ``dataset_path`` and index every image file found per class.

        Args:
            dataset_path: Directory containing one sub-directory per class.
        """
        # Imported locally so the class does not depend on a later cell
        # having imported ``os`` already.
        import os

        super().__init__()

        labels, image_paths, targets = [], [], []
        # glob returns entries in filesystem order, which is arbitrary; sorting
        # keeps label indices and sample order stable between runs/machines.
        for label_dir in sorted(glob.glob(os.path.join(dataset_path, '*'))):
            # basename instead of split('/') so Windows separators work too.
            label = os.path.basename(label_dir)
            if label in self.EXCLUDED_ENTRIES or not os.path.isdir(label_dir):
                continue

            label_idx = len(labels)
            labels.append(label)

            for image_path in sorted(glob.glob(os.path.join(label_dir, '*'))):
                # Nested directories cannot be opened as images; skip them.
                if os.path.isfile(image_path):
                    image_paths.append(image_path)
                    targets.append(label_idx)

        self.labels, self.image_paths = labels, image_paths
        # Label index of each entry in ``image_paths`` (same order).
        self._targets = targets

        self.idx_to_label = dict(enumerate(labels))
        self.label_to_idx = {label: idx for idx, label in self.idx_to_label.items()}

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx``.

        Returns:
            Tuple ``(image, label_idx)`` where ``image`` is a PIL image
            converted to RGB and ``label_idx`` is its integer class index.

        Raises:
            IndexError: If ``idx`` is out of range (this is also what ends a
                plain ``for`` loop over the dataset).
        """
        # The context manager releases the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(self.image_paths[idx]) as source:
            image = source.convert('RGB')

        return image, self._targets[idx]

In [8]:
import hashlib

def image_hash(img_pil):
    """Return a SHA-256 hex digest identifying the content of a PIL image.

    The digest covers the image mode and dimensions as well as the raw pixel
    bytes, so two images whose byte streams happen to match but whose shape
    or mode differ (e.g. a 2x8 and a 4x4 single-channel image) do not collide.

    Args:
        img_pil: Object exposing ``mode``, ``size`` (width, height) and
            ``tobytes()``, as PIL images do.

    Returns:
        64-character hexadecimal digest string.
    """
    width, height = img_pil.size
    header = f'{img_pil.mode}|{width}x{height}|'.encode('ascii')

    hash_obj = hashlib.sha256(header)
    hash_obj.update(img_pil.tobytes())
    return hash_obj.hexdigest()

In [18]:
import os
import random

import torch

# Seed both generators: `random` decides which image gets augmented, while
# torchvision's AutoAugment draws its random choices from torch's global RNG,
# which random.seed() does not affect.
random.seed(42)
torch.manual_seed(42)

# Number of images every label is padded up to.
target_count = 4000
# Stop padding a label after this many duplicate results in a row, so the
# loop cannot spin forever once no new unique image can be produced.
max_consecutive_duplicates = 1000
augmented_images = {}
# Hashes of every image accepted so far, shared across all labels.
image_hashes = set()

original_dataset = RecaptchaDataset(original_dataset_path)

# Pass 1: collect the unique original images, grouped by label name.
for img, label in original_dataset:
    label_str = original_dataset.idx_to_label[label]
    images = augmented_images.setdefault(label_str, [])

    img_hash = image_hash(img)
    if img_hash not in image_hashes:
        images.append(img)
        image_hashes.add(img_hash)

# Pass 2: pad every label up to target_count with unique augmented images.
for label_str, images in augmented_images.items():
    if not images:
        # Every original of this label duplicated an image already seen, so
        # there is nothing to sample from (random.choice would raise).
        print(f'No unique images for {label_str}, skipping augmentation')
        continue

    consecutive_duplicates = 0
    while len(images) < target_count:
        # NOTE(review): the pool being sampled grows with accepted augmented
        # images, so augmentations can be stacked on earlier augmentations --
        # confirm this is intended rather than sampling originals only.
        img_to_augment = random.choice(images)
        img_augmented = image_augmentation(img_to_augment)
        img_augmented_hash = image_hash(img_augmented)

        if img_augmented_hash in image_hashes:
            print(f'Duplicated found, regenerating ... {label_str}: {len(images)}')
            consecutive_duplicates += 1
            if consecutive_duplicates >= max_consecutive_duplicates:
                print(
                    f'Giving up on {label_str} after {consecutive_duplicates} '
                    f'consecutive duplicates: {len(images)}/{target_count}'
                )
                break
            continue

        images.append(img_augmented)
        image_hashes.add(img_augmented_hash)
        consecutive_duplicates = 0

# Pass 3: write every collected image (originals and augmented alike) to
# <augmented_dataset_path>/<label>/augmented_<i>.jpg.
for label, images in augmented_images.items():
    save_path = os.path.join(augmented_dataset_path, label)
    os.makedirs(save_path, exist_ok=True)

    for i, img in enumerate(images):
        img.save(os.path.join(save_path, f'augmented_{i}.jpg'))

Duplicated found, regenerating ... Car: 3696
Duplicated found, regenerating ... Car: 3716
Duplicated found, regenerating ... Car: 3788
Duplicated found, regenerating ... Car: 3802
Duplicated found, regenerating ... Car: 3802
Duplicated found, regenerating ... Car: 3861
Duplicated found, regenerating ... Car: 3871
Duplicated found, regenerating ... Car: 3939
Duplicated found, regenerating ... Car: 3945
Duplicated found, regenerating ... Motorcycle: 121
Duplicated found, regenerating ... Motorcycle: 122
Duplicated found, regenerating ... Motorcycle: 130
Duplicated found, regenerating ... Motorcycle: 136
Duplicated found, regenerating ... Motorcycle: 137
Duplicated found, regenerating ... Motorcycle: 138
Duplicated found, regenerating ... Motorcycle: 145
Duplicated found, regenerating ... Motorcycle: 157
Duplicated found, regenerating ... Motorcycle: 162
Duplicated found, regenerating ... Motorcycle: 163
Duplicated found, regenerating ... Motorcycle: 163
Duplicated found, regenerating ...