In [123]:
import os

In [124]:
import os.path as osp

In [125]:
import json

In [126]:
# Languages whose receipt annotations are merged below.
_lang_list = ['chinese', 'japanese', 'thai', 'vietnamese']
# Dataset root: the SM_CHANNEL_TRAIN environment variable when set, else ./data.
root_dir = os.environ.get('SM_CHANNEL_TRAIN', 'data')
split = 'train'

# Merge the per-language annotation files into a single {'images': {...}} dict.
# On a duplicate image key the language listed later wins.
total_anno = {'images': {}}
for nation in _lang_list:
    anno_path = osp.join(root_dir, '{}_receipt/ufo/{}.json'.format(nation, split))
    with open(anno_path, 'r', encoding='utf-8') as f:
        anno = json.load(f)
    total_anno['images'].update(anno['images'])

# From here on `anno` refers to the merged annotations.
anno = total_anno
# Sorted list of image file names (the keys of the merged dict).
image_fnames = sorted(anno['images'].keys())


In [127]:
# Position in image_fnames; the `for idx in range(100)` loops below rebind it.
idx = 0

In [128]:
import numpy as np

In [129]:
from shapely.geometry import Polygon


In [130]:
from PIL import Image


In [131]:
def _infer_dir(fname):
    lang_indicator = fname.split('.')[1]
    if lang_indicator == 'zh':
        lang = 'chinese'
    elif lang_indicator == 'ja':
        lang = 'japanese'
    elif lang_indicator == 'th':
        lang = 'thai'
    elif lang_indicator == 'vi':
        lang = 'vietnamese'
    else:
        raise ValueError
    return osp.join(root_dir, f'{lang}_receipt', 'img', split)


In [132]:
def filter_vertices(vertices, labels, ignore_under=0, drop_under=0):
    """Flag or drop quadrilaterals by the area of their convex hull.

    Args:
        vertices: NumPy array with one quadrilateral per row; each row holds
            8 values and is reshaped to ``(4, 2)`` to compute its area.
        labels: NumPy array with one label per row of ``vertices``.
        ignore_under: Quads whose area is below this value get label ``0``.
        drop_under: Quads whose area is below this value are removed from
            the result. ``0`` disables dropping.

    Returns:
        ``(new_vertices, new_labels)``. When both thresholds are ``0`` the
        inputs are returned as-is; otherwise filtered copies are returned and
        the inputs are left unmodified.
    """
    if drop_under == 0 and ignore_under == 0:
        return vertices, labels

    new_vertices, new_labels = vertices.copy(), labels.copy()

    areas = np.array([Polygon(v.reshape((4, 2))).convex_hull.area for v in vertices])
    # Write into the copy: writing into `labels` would mutate the caller's
    # array and leave the returned labels without the ignore flag.
    new_labels[areas < ignore_under] = 0

    if drop_under > 0:
        passed = areas >= drop_under
        new_vertices, new_labels = new_vertices[passed], new_labels[passed]

    return new_vertices, new_labels


In [133]:
# Area thresholds passed to filter_vertices as ignore_under / drop_under.
ignore_under_threshold = 10  # quads with a smaller area get label 0
drop_under_threshold = 1     # quads with a smaller area are removed

In [134]:
from dataset import resize_img, adjust_height, rotate_img, crop_img, generate_roi_mask
import albumentations as A


In [44]:
# Sizes handed to the resize and crop steps of both pipelines below.
image_size = 2048  # passed to resize_img / A.Resize
crop_size = 1024   # passed to crop_img / A.RandomCrop

In [122]:
%%time

# Photometric augmentation + normalization. It does not depend on the image,
# so it is built once instead of on every iteration.
transform = A.Compose([
    A.ColorJitter(),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Time the dataset.py helper pipeline on the first 100 images.
for idx in range(100):
    image_fname = image_fnames[idx]
    image_fpath = osp.join(_infer_dir(image_fname), image_fname)

    # Collect word boxes as flat rows; annotations with more than 4 points are skipped.
    vertices, labels = [], []
    for word_info in anno['images'][image_fname]['words'].values():
        points = np.array(word_info['points'])
        num_pts = points.shape[0]
        if num_pts > 4:
            continue
        vertices.append(points.flatten())
        labels.append(1)
    vertices, labels = np.array(vertices, dtype=np.float32), np.array(labels, dtype=np.int64)

    vertices, labels = filter_vertices(
        vertices,
        labels,
        ignore_under=ignore_under_threshold,
        drop_under=drop_under_threshold
    )

    # Keep the file open only while PIL still needs to read pixels from it;
    # np.array() below materializes the final image before the handle closes.
    with Image.open(image_fpath) as source_image:
        # Each helper returns the updated image together with the updated vertices.
        image, vertices = resize_img(source_image, vertices, image_size)
        image, vertices = adjust_height(image, vertices)
        image, vertices = rotate_img(image, vertices)
        image, vertices = crop_img(image, vertices, labels, crop_size)

        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = np.array(image)

    image = transform(image=image)['image']
    word_bboxes = np.reshape(vertices, (-1, 4, 2))
    roi_mask = generate_roi_mask(image, vertices, labels)


CPU times: user 1min 20s, sys: 1.48 s, total: 1min 21s
Wall time: 2min 54s


In [116]:
# Albumentations pipeline meant to stand in for the dataset.py helpers.
# Keypoints are given as (x, y) pairs; remove_invisible=False keeps keypoints
# even when a transform moves them outside the image.
transform = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),           # replaces resize_img
        A.RandomCrop(height=crop_size, width=crop_size),         # replaces crop_img
        A.Rotate(limit=45),                                      # replaces rotate_img
        A.RandomBrightnessContrast(p=0.2),                       # extra brightness/contrast augmentation
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # normalize
    ],
    keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
)


In [121]:
%%time

# Time the albumentations pipeline (`transform`) on the first 100 images.
for idx in range(100):
    image_fname = image_fnames[idx]
    image_fpath = osp.join(_infer_dir(image_fname), image_fname)

    # Collect word boxes as flat rows; annotations with more than 4 points are skipped.
    vertices, labels = [], []
    for word_info in anno['images'][image_fname]['words'].values():
        points = np.array(word_info['points'])
        num_pts = points.shape[0]
        if num_pts > 4:
            continue
        vertices.append(points.flatten())
        labels.append(1)
    vertices, labels = np.array(vertices, dtype=np.float32), np.array(labels, dtype=np.int64)

    vertices, labels = filter_vertices(
        vertices,
        labels,
        ignore_under=ignore_under_threshold,
        drop_under=drop_under_threshold
    )

    # Normalize uses 3-channel statistics, so force RGB (as the helper-based
    # loop does) and release the file handle once the pixels are in memory.
    with Image.open(image_fpath) as source_image:
        if source_image.mode != 'RGB':
            image = np.array(source_image.convert('RGB'))
        else:
            image = np.array(source_image)

    # 'xy' keypoints are single (x, y) pairs, so every quad row is split into
    # its 4 corners before being handed to the pipeline.
    keypoints = vertices.reshape(-1, 2)
    augmented = transform(image=image, keypoints=keypoints)

    image = augmented['image']

    # remove_invisible=False keeps every keypoint, so the corners can be
    # regrouped into quads in their original order.
    vertices = np.array(augmented['keypoints'], dtype=np.float32).reshape(-1, 8)
    word_bboxes = vertices.reshape(-1, 4, 2)

    # Build the mask from the transformed vertices so it lines up with `image`.
    roi_mask = generate_roi_mask(image, vertices, labels)

CPU times: user 6.53 s, sys: 1.45 s, total: 7.98 s
Wall time: 9.48 s


In [118]:
augmented

{'image': array([[[0.4666667 , 0.34117648, 0.16078432],
         [0.4666667 , 0.34117648, 0.16078432],
         [0.4666667 , 0.34117648, 0.16078432],
         ...,
         [0.427451  , 0.30980393, 0.12941177],
         [0.427451  , 0.30980393, 0.14509805],
         [0.427451  , 0.30980393, 0.14509805]],
 
        [[0.4666667 , 0.34117648, 0.16078432],
         [0.4666667 , 0.34117648, 0.16078432],
         [0.4666667 , 0.34117648, 0.16078432],
         ...,
         [0.41960788, 0.3019608 , 0.12156864],
         [0.41960788, 0.3019608 , 0.12941177],
         [0.427451  , 0.30980393, 0.14509805]],
 
        [[0.4666667 , 0.34117648, 0.16078432],
         [0.4666667 , 0.34117648, 0.16078432],
         [0.4666667 , 0.34117648, 0.16078432],
         ...,
         [0.41960788, 0.3019608 , 0.12156864],
         [0.41960788, 0.3019608 , 0.12156864],
         [0.41176474, 0.29411766, 0.12156864]],
 
        ...,
 
        [[0.3803922 , 0.2627451 , 0.09803922],
         [0.3803922 , 0.2627451 