In [16]:
import albumentations as A
import cv2
import os
from collections import Counter
from sklearn.utils import resample


In [17]:
# Augmentation pipeline applied to each image that is re-sampled to top up a
# class. The first three steps always run (p=1); the perspective warp runs on
# roughly half of the samples (p=0.5).
transform = A.Compose(
    [
        # Small rotation; exposed corners are filled with a constant border.
        A.Rotate(
            limit=10,
            border_mode=cv2.BORDER_CONSTANT,
            p=1,
        ),
        # Mild brightness / contrast jitter.
        A.RandomBrightnessContrast(
            brightness_limit=0.05,
            contrast_limit=0.05,
            p=1,
        ),
        # Gaussian blur with the library's default kernel settings.
        A.GaussianBlur(p=1),
        # Perspective warp with constant padding.
        # NOTE(review): pad_val=1 is almost black on 8-bit images -- confirm
        # this is the intended fill value (vs. 0 or 255).
        A.Perspective(
            scale=(0.07, 0.1),
            fit_output=True,
            pad_mode=cv2.BORDER_CONSTANT,
            pad_val=1,
            p=0.5,
        ),
    ]
)

In [18]:
# Root folder holding one sub-directory per class; each sub-directory contains
# that class's image files.
data_dir = './balanced-data/'

# Only real, non-hidden directories are classes. Stray files (e.g. .DS_Store)
# would make os.listdir() fail below, and hidden folders such as
# .ipynb_checkpoints would otherwise be counted as a bogus class.
# Sorted so the result does not depend on the file system's listing order.
class_names = sorted(
    entry
    for entry in os.listdir(data_dir)
    if not entry.startswith('.') and os.path.isdir(os.path.join(data_dir, entry))
)

# Number of directory entries currently present for each class.
class_counts = Counter()
for class_name in class_names:
    class_counts[class_name] = len(os.listdir(os.path.join(data_dir, class_name)))

# Size of the largest class -- the target every other class is topped up to.
# default=0 keeps an empty dataset from raising ValueError here.
max_count = max(class_counts.values(), default=0)


In [19]:
# Display the per-class file counts gathered in the previous cell.
class_counts

Counter({'1': 237,
         'B': 227,
         'A': 65,
         '8': 64,
         '4': 62,
         '5': 62,
         '7': 61,
         '6': 54,
         'T': 45,
         'S': 42,
         '3': 40,
         '9': 39,
         'J': 39,
         '2': 36,
         'R': 30,
         '0': 24,
         'K': 23,
         'U': 23,
         'P': 21,
         'F': 20,
         'M': 20,
         'N': 18,
         'D': 17,
         'L': 17,
         'H': 16,
         'E': 15,
         'Z': 15,
         'I': 14,
         'Y': 10,
         'V': 9,
         'O': 8,
         'W': 8,
         'G': 7,
         'Q': 7,
         'C': 4,
         'X': 4})

In [20]:
# Top up every under-represented class to `max_count` files by writing
# augmented copies of randomly re-sampled images into the class folder.
#
# Differences from a naive "{i}.jpg" writer:
#   * outputs use an "aug_<n>.jpg" name that is guaranteed not to exist yet, so
#     neither original images nor earlier augmentations are ever overwritten;
#   * the deficit is computed from the directory's *current* contents, so
#     re-running the cell does not push a class past `max_count`;
#   * source paths are sorted before sampling so random_state=42 selects the
#     same files regardless of file-system listing order;
#   * unreadable images and failed writes raise a clear error instead of an
#     opaque OpenCV assertion / silent no-op.
for class_name in class_counts:
    class_dir = os.path.join(data_dir, class_name)
    image_paths = sorted(
        os.path.join(class_dir, image_name) for image_name in os.listdir(class_dir)
    )

    diff = max_count - len(image_paths)
    if diff <= 0:
        # Already at (or above) the target size.
        continue
    if not image_paths:
        # Nothing to sample from; resample() would raise on an empty list.
        print(f'Skipping class {class_name!r}: no source images to augment.')
        continue

    # Sample with replacement: the deficit usually exceeds the class size.
    image_paths = resample(image_paths, n_samples=diff, random_state=42, replace=True)

    next_index = 0
    for image_path in image_paths:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError(f'Could not read image: {image_path}')

        # OpenCV loads BGR; the pipeline is fed RGB and the result is
        # converted back to BGR before saving.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        transformed_image = transform(image=image)['image']
        transformed_image = cv2.cvtColor(transformed_image, cv2.COLOR_RGB2BGR)

        # Advance to the first unused output name in this class folder.
        new_image_path = os.path.join(class_dir, f'aug_{next_index}.jpg')
        while os.path.exists(new_image_path):
            next_index += 1
            new_image_path = os.path.join(class_dir, f'aug_{next_index}.jpg')

        if not cv2.imwrite(new_image_path, transformed_image):
            raise OSError(f'Could not write image: {new_image_path}')
        next_index += 1
