In [None]:
import cv2
import os
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
import numpy as np
import random
from multiprocessing import Pool
from tqdm import tqdm
from itertools import product

In [None]:

# Number of augmented variants generated for every source image.
AUG_CNT = 100

# Pool of background images that foregrounds get composited onto.
BG_DIR = "/datadrive/workspace/others/202206.tanachou/data/BG-20k"

# Source images, one sub-directory per type; OUT_DIR mirrors that layout.
IN_DIR = "/datadrive/workspace/others/202206.tanachou/data/train_no_bg"
OUT_DIR = "/datadrive/workspace/others/202206.tanachou/data/train_aug"


In [None]:
def aug_img(img_path, bg_path, out_path, target_size=512, max_shift_ratio=0.2):
    """Composite one randomly rotated and shifted foreground onto a background.

    The foreground is scaled so that its longer side is ``target_size`` pixels,
    rotated by a random angle (canvas expanded to fit the rotated image) and
    pasted at a random offset onto the background, which is stretched to the
    size of the rotated foreground. Whatever is shifted past the right/bottom
    edge is cropped by the paste.

    Args:
        img_path: Path of the foreground image (presumably a cut-out with a
            transparent background -- verify against the dataset).
        bg_path: Path of the background image.
        out_path: Destination file; Pillow picks the format from the extension.
        target_size: Length in pixels of the foreground's longer side after
            resizing.
        max_shift_ratio: Upper bound of the random paste offset, expressed as a
            fraction of the rotated foreground's width / height.

    Returns:
        None. The composite is written to ``out_path``.
    """
    with Image.open(img_path) as src:
        # Work in RGBA so the corners added by rotate(expand=True) are
        # transparent rather than opaque black when the source has no alpha.
        fg = src.convert('RGBA')

    fg_w, fg_h = fg.size
    scale = target_size / max(fg_w, fg_h)
    # Clamp to 1px: an extreme aspect ratio must not yield a zero-sized resize.
    fg = fg.resize((max(1, int(fg_w * scale)), max(1, int(fg_h * scale))))

    # rotate angle
    angle = random.randint(0, 360)
    fg = fg.rotate(angle, expand=True)

    # translate position
    pos_x = random.randint(0, int(fg.size[0] * max_shift_ratio))
    pos_y = random.randint(0, int(fg.size[1] * max_shift_ratio))

    # background
    with Image.open(bg_path) as bg_src:
        # Force RGB: JPEG cannot store alpha/palette modes, and pasting onto a
        # grayscale background would turn the foreground gray as well.
        background = bg_src.convert('RGB').resize(fg.size)

    # The RGBA foreground doubles as its own mask (its alpha band is used).
    background.paste(fg, (pos_x, pos_y), mask=fg)

    background.save(out_path)

In [None]:
# Collect the path of every entry found directly under BG_DIR.
bg_paths = []
for bg_file in os.listdir(BG_DIR):
    bg_paths.append("{}/{}".format(BG_DIR, bg_file))


In [None]:
# Build three parallel job lists: entry i of each list describes one
# augmentation job (source image, background image, output file).
total_img_paths = []
total_bg_paths = []
total_output_paths = []

for type_name in os.listdir(IN_DIR):
    in_dir = "{}/{}".format(IN_DIR, type_name)
    if not os.path.isdir(in_dir):
        # Skip stray files sitting next to the per-type folders.
        continue

    out_dir = "{}/{}".format(OUT_DIR, type_name)
    os.makedirs(out_dir, exist_ok=True)

    for img_name in os.listdir(in_dir):
        img_path = "{}/{}".format(in_dir, img_name)
        name = os.path.splitext(img_name)[0]

        # Sample into a separate name. Rebinding bg_paths here would shrink
        # the pool to the first sample, so every later image (and every re-run
        # of this cell) would reuse the same AUG_CNT backgrounds.
        sampled_bg_paths = random.sample(bg_paths, AUG_CNT)

        total_img_paths += [img_path] * AUG_CNT
        total_bg_paths += sampled_bg_paths
        total_output_paths += [
            "{}/{}_{:03d}.jpg".format(out_dir, name, index)
            for index in range(AUG_CNT)
        ]

In [None]:
def _aug_img_job(job):
    """Unpack one (img_path, bg_path, out_path) tuple and run ``aug_img``."""
    return aug_img(*job)


jobs = list(zip(total_img_paths, total_bg_paths, total_output_paths))

with Pool(15) as pool:
    # starmap() only returns once every job has finished, so wrapping its
    # result in tqdm never displayed progress. imap_unordered yields as each
    # job completes, which lets tqdm advance; total= is required because the
    # iterator has no length.
    for _ in tqdm(pool.imap_unordered(_aug_img_job, jobs), total=len(jobs)):
        pass