In [1]:
import os

import albumentations as A
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

In [2]:
def read_csv_file(fp):
    """Read a tab-separated annotation file into three parallel lists.

    The first line is treated as a header and skipped. Each remaining
    non-blank line must contain at least three tab-separated fields --
    image file name, category and color; any further fields are ignored.

    Args:
        fp: path to the ``.tsv`` file.

    Returns:
        ``(names, types, colors)``: three equal-length lists of strings,
        in file order.

    Raises:
        ValueError: if a data line has fewer than three fields.
    """
    names, types, colors = [], [], []
    with open(fp, 'r') as f:
        next(f, None)  # skip the header row (also safe on an empty file)
        for lineno, line in enumerate(f, start=2):
            if not line.strip():
                continue  # tolerate blank lines, e.g. an extra newline at EOF
            # Strip only the line terminator and trailing spaces; a bare
            # rstrip() would also eat trailing tabs, so an empty last column
            # would vanish and indexing would fail.
            fields = line.rstrip('\r\n').rstrip(' ').split('\t')
            if len(fields) < 3:
                raise ValueError(
                    "%s:%d: expected at least 3 tab-separated fields, got %d"
                    % (fp, lineno, len(fields)))
            imgname, category, color = fields[:3]
            names.append(imgname)
            types.append(category)
            colors.append(color)
    return names, types, colors

In [3]:
def common_augs():
    """Return the transforms shared by both augmentation pools.

    Flips, a rotation of up to 45 degrees, shift/scale/rotate, random
    rescaling, and synthetic snow and rain. Every transform uses ``p=1`` so
    that, once ``random_combinations`` selects it, it is always applied.
    """
    return [
        A.VerticalFlip(p=1),
        A.HorizontalFlip(p=1),
        A.Rotate(limit=45, p=1),
        A.ShiftScaleRotate(p=1),
        A.RandomScale(p=1),
        A.RandomSnow(p=1),
        A.RandomRain(p=1),
    ]

In [4]:
def color_augs():
    """Return the task-specific transforms for the color-label pool.

    Only geometric warps (optical, grid and elastic distortion) and blurs
    (motion, median, Gaussian); none of these is a hue/saturation/channel
    transform -- presumably so the color label stays valid. All use ``p=1``.
    """
    return [
        A.OpticalDistortion(p=1),
        A.GridDistortion(p=1),
        A.ElasticTransform(p=1),
        A.MotionBlur(p=1),
        A.MedianBlur(p=1),
        A.GaussianBlur(p=1),
    ]

In [5]:
def dress_augs():
    """Return the task-specific, color-altering transforms for the ``types`` pool.

    Hue/saturation/value shifts, brightness, contrast, grayscale, gamma,
    CLAHE and channel shuffling. These change pixel colors, so this pool is
    combined with ``common_augs()`` only for the category (``types``) labels.
    Each transform appears once, so ``random_combinations`` picks it with the
    same odds as the others and never applies it twice in one pipeline.
    All use ``p=1``.
    """
    return [
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=50, val_shift_limit=50, p=1),
        # RandomBrightness / RandomContrast are deprecated (removed in newer
        # albumentations); RandomBrightnessContrast with the other limit set
        # to 0 is their replacement, keeping the default 0.2 limit.
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=1),
        A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=1),
        A.ToGray(p=1),
        A.RandomGamma(p=1),
        A.CLAHE(p=1),
        A.ChannelShuffle(p=1),
    ]

In [6]:
def random_combinations(img, augs, n_times=6, how_many=5, rng=None):
    """Produce ``n_times`` augmented copies of ``img``.

    For each copy, up to ``how_many`` distinct transforms are drawn at random
    (without replacement) from ``augs``, composed in the drawn order and
    applied to ``img``. If ``how_many`` exceeds ``len(augs)``, all transforms
    are used.

    Args:
        img: image accepted by the transforms (e.g. the array returned by
            ``cv2.imread`` in ``populate``).
        augs: sequence of albumentations transforms. It is not modified.
        n_times: number of augmented copies to return.
        how_many: number of transforms composed per copy.
        rng: optional ``np.random.Generator`` or ``RandomState`` for
            reproducible draws; defaults to the global ``np.random`` state.

    Returns:
        list of ``n_times`` augmented images.
    """
    rng = np.random if rng is None else rng
    augmented = []
    for _ in range(n_times):
        # Draw a random index order instead of shuffling ``augs`` in place,
        # so the caller's list (e.g. the shared ``color``/``dress`` pools) is
        # left untouched.
        order = rng.permutation(len(augs))[:how_many]
        pipeline = A.Compose([augs[k] for k in order], p=1)
        augmented.append(pipeline(image=img)['image'])
    return augmented

In [7]:
# Transform pools handed to `populate`: each task's own transforms plus the
# shared set. `common_augs()` is called once per pool, so the two pools hold
# separate transform objects.
color, dress = color_augs() + common_augs(), dress_augs() + common_augs()

In [8]:
# Parallel per-image lists from the annotation file: file name, category, color.
names, types, colors = read_csv_file(fp="./data/train.tsv")

In [9]:
def populate(prefix, names, labels, augs, root="./data"):
    """Write each source image plus its augmented copies into per-label folders.

    For every ``(name, label)`` pair, ``<root>/images/<name>`` is read and
    ``random_combinations(img, augs)`` produces augmented copies. The original
    and then the copies are saved as ``<root>/<prefix>/<label>/<name>_<j>.png``,
    with ``j = 0`` for the original. The label folder is created if missing.

    Images that cannot be read, augmented or written are reported and skipped,
    so one bad file does not abort a long run. Interrupts (Ctrl-C) still stop
    the loop.

    Args:
        prefix: sub-directory of ``root`` for this task (e.g. ``"colors"``).
        names: image file names relative to ``<root>/images``.
        labels: one label per name; used as the output sub-folder name.
        augs: transform pool passed to ``random_combinations``.
        root: base data directory (default ``"./data"``).

    Raises:
        ValueError: if ``names`` and ``labels`` differ in length.
    """
    if len(names) != len(labels):
        raise ValueError("names and labels differ in length: %d vs %d"
                         % (len(names), len(labels)))
    for name, label in tqdm(zip(names, labels), total=len(names)):
        out_dir = os.path.join(root, prefix, label)
        try:
            img = cv2.imread(os.path.join(root, "images", name))
            if img is None:  # cv2.imread returns None for missing/unreadable files
                raise IOError("cannot read image")
            # Augment before writing, so an augmentation failure leaves no
            # partial output for this image.
            copies = [img] + random_combinations(img, augs)
            os.makedirs(out_dir, exist_ok=True)
            for j, x in enumerate(copies):
                out_path = os.path.join(out_dir, "%s_%d.png" % (name, j))
                # cv2.imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(out_path, x):
                    raise IOError("cannot write %s" % out_path)
        except Exception as exc:  # not bare: let KeyboardInterrupt stop the run
            print("Skipping %s file (%s)" % (name, exc))

In [None]:
# Color task: originals + augmented copies into ./data/colors/<color>/.
populate(prefix="colors", names=names, labels=colors, augs=color)

8585it [2:15:21,  1.01s/it]

Skipping 8585.jpg file


8593it [2:15:30,  1.09it/s]

In [None]:
# Category task: originals + augmented copies into ./data/types/<category>/.
populate(prefix="types", names=names, labels=types, augs=dress)