### Data augmentation initialization

# In [2]:
import cv2
from tqdm import tqdm
from pathlib import Path
import notebook_utils as nutils
import albumentations as A
import shutil
import numpy as np
### Augment one file

def augment_file(file_path, output_dir, aug_transform, num_samples=20):
    """Copy one labelled image into ``output_dir`` and write augmented variants.

    The label file is the image path with its suffix replaced by ``.txt``.
    Every non-blank label line must start with eight comma-separated numbers
    ``x1,y1,x2,y2,x3,y3,x4,y4`` (four corner points); any fields after the
    eighth are ignored and are not carried over to the augmented label files.

    Files written to ``output_dir`` (created if missing):
      * the original image and its label file, under their original names;
      * for each index ``i`` in ``1 .. num_samples - 1`` whose augmentation
        kept every keypoint: ``<stem>_<i>.jpg`` and ``<stem>_<i>.txt``, the
        latter holding one line of eight integer coordinates per box.

    Args:
        file_path: Path of the source image (``str`` or ``pathlib.Path``).
        output_dir: Directory receiving the copies and augmented samples.
        aug_transform: Callable invoked as
            ``aug_transform(image=..., keypoints=...)`` and returning a mapping
            with ``'image'`` and ``'keypoints'`` entries.
        num_samples: Upper bound on samples per image *including* the original,
            so at most ``num_samples - 1`` augmented files are produced.

    Returns:
        None. If the image has no label file, nothing is read or written.

    Raises:
        ValueError: If the image cannot be decoded, or a non-blank label line
            does not start with eight numeric fields.
    """
    file_path = Path(file_path)
    label_path = file_path.with_suffix('.txt')

    # Check for the label first so unlabelled images are never decoded.
    if not label_path.is_file():
        return

    image = cv2.imread(str(file_path))
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {file_path}")
    # The transform is fed RGB; remember to convert back before cv2.imwrite.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Copy original file
    shutil.copyfile(str(file_path), str(out_dir / file_path.name))
    shutil.copyfile(str(label_path), str(out_dir / label_path.name))

    # Flatten every box into consecutive (x, y) corner keypoints.
    keypoints = []
    with open(str(label_path), 'r') as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                continue  # tolerate blank / trailing empty lines
            coords = [float(value) for value in stripped.split(',')[:8]]
            if len(coords) != 8:
                raise ValueError(
                    f"Expected 8 coordinates in {label_path}, got: {stripped!r}"
                )
            keypoints.extend(zip(coords[0::2], coords[1::2]))

    image_name = file_path.stem

    # Index 0 is implicitly the original, so augmented samples start at 1.
    for index in range(1, num_samples):
        transformed = aug_transform(image=image, keypoints=keypoints)
        new_points = np.array(transformed['keypoints'])

        # A different keypoint count means at least one box lost a corner
        # (presumably moved out of the frame by the transform - confirm against
        # the transform's keypoint settings); skip the sample, leaving a gap
        # in the numbering.
        if len(new_points) != len(keypoints):
            continue

        quads = new_points.reshape(len(new_points) // 4, 4, 2)
        sample_stem = f"{image_name}_{index}"

        # cv2.imwrite expects BGR channel order; the working image is RGB.
        cv2.imwrite(
            str(out_dir / f"{sample_stem}.jpg"),
            cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR),
        )

        # newline="" writes the explicit "\r\n" verbatim on every platform
        # (otherwise Windows text mode would expand it to "\r\r\n").
        with open(out_dir / f"{sample_stem}.txt", "w", newline="") as f:
            for quad in quads.tolist():
                # Coordinates are truncated to integers.
                fields = [str(int(value)) for point in quad for value in point]
                f.write(",".join(fields) + "\r\n")
   

### Augment whole folder

# In [3]:
# Source images to augment. augment_file() skips any image that has no
# same-named .txt label file next to it.
im_fns = nutils.list_files("datasets/SROIE2019/0325updated.task1train(626p)", "*.jpg")
output_aug = 'datasets/SROIE2019/task1train/aug'

# Make sure the destination exists up front; otherwise every file would fail.
Path(output_aug).mkdir(parents=True, exist_ok=True)

# Geometric jitter (applied with probability 0.8) followed by two colour
# transforms. Keypoints are plain (x, y) pixel coordinates.
aug_transform = A.Compose(
    [
        A.ShiftScaleRotate(
            shift_limit=0.0625,
            scale_limit=0.1,
            rotate_limit=20,
            interpolation=1,  # OpenCV flag value 1 == cv2.INTER_LINEAR
            border_mode=1,  # OpenCV flag value 1 == cv2.BORDER_REPLICATE
            value=None,
            mask_value=None,
            always_apply=False,
            p=0.8,
        ),
        A.RGBShift(),
        A.RandomBrightnessContrast(),
    ],
    keypoint_params=A.KeypointParams(format='xy'),
)

# Best-effort run: one bad file must not abort the batch, but failures are
# reported instead of being silently discarded. Catching Exception (not a bare
# except) keeps Ctrl-C / interrupting the cell working.
failed_files = []
for im_fn in tqdm(im_fns):
    try:
        augment_file(im_fn, output_aug, aug_transform, 20)
    except Exception as exc:
        failed_files.append(im_fn)
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write(f"Skipping {im_fn}: {exc!r}")

if failed_files:
    print(f"{len(failed_files)} of {len(im_fns)} files could not be augmented.")
 

# Output: 100%|██████████| 712/712 [13:48<00:00,  1.16s/it]
