In [122]:
import json
import os
import os.path as osp
from glob import glob
from PIL import Image

from tqdm import tqdm

from torch.utils.data import DataLoader, ConcatDataset, Dataset

import random
import re

In [123]:
# Two-digit font-type codes '01'..'10' that get_paths() matches inside image file names.
fonttypes = ['{:02d}'.format(code) for code in range(1, 11)]

SRC_DATASET_DIR = '/opt/ml/input/AIHUB_doc_original'  # location of the source files
DST_DATASET_DIR = '/opt/ml/input/data/AIHub_Docs'  # where the selected files are saved

NUM_WORKERS = 4  # FIXME

IMAGE_EXTENSIONS = {'.png', '.jpg', '.gif'}

In [124]:
def maybe_mkdir(x):
    """Create directory ``x`` (including missing parents) unless it already exists.

    Args:
        x: Path of the directory to create.

    Raises:
        FileExistsError: If ``x`` exists but is not a directory.

    ``exist_ok=True`` replaces the previous "check, then create" sequence. That
    sequence could raise ``FileExistsError`` when several processes (for example
    DataLoader workers) tried to create the same directory at the same time.
    """
    os.makedirs(x, exist_ok=True)

In [125]:
def get_paths(image_dir, label_dir, font_types=None):
    """Randomly pick one image per (document type, font type) pair, with its label path.

    Document types are the three-character sub-directories of ``image_dir``. Inside
    each one, images are matched by the pattern ``?????<font>????.jpg``, i.e. the
    font code occupies characters 6-7 of the file name.

    Args:
        image_dir: Directory containing one three-character sub-directory per
            document type.
        label_dir: Root directory of the label files; the returned label path is
            ``<label_dir>/<doctype>/<image stem>.json``. The label file itself is
            not checked for existence here.
        font_types: Iterable of two-character font codes to sample. Defaults to the
            module-level ``fonttypes`` list.

    Returns:
        ``(path_images, path_labels)``: two lists of equal length whose entries
        correspond position by position.
    """
    if font_types is None:
        font_types = fonttypes

    path_images = []
    path_labels = []

    # glob() returns entries in arbitrary order; sorting makes the selection
    # reproducible whenever `random` has been seeded.
    for subdir in sorted(glob(osp.join(image_dir, '???'))):   # for doc type
        doctype = osp.basename(subdir)
        for fonttype in font_types:                           # for font type
            files = sorted(glob(osp.join(subdir, '?????' + fonttype + '????.jpg')))
            if not files:
                # random.choice() raises IndexError on an empty list; a doc type
                # without images for this font is skipped instead.
                continue
            chosen = random.choice(files)
            stem = osp.splitext(osp.basename(chosen))[0]
            path_images.append(chosen)
            path_labels.append(osp.join(label_dir, doctype, stem + '.json'))

    return path_images, path_labels

In [126]:
class AIHubDocsDataset(Dataset):
    """Dataset over a random sample of AIHub document images and their word labels.

    At construction time ``get_paths`` selects the images and every selected label
    file is parsed. Each item is ``(image_fname, sample_info_ufo)``, where the second
    element is a dict with the keys ``img_h``, ``img_w``, ``words``, ``tags`` and
    ``license_tag``. Reading an item also saves the image under ``copy_images_to``
    when that path is set.
    """

    # Compiled once for the whole class instead of on every parse_label_file() call.
    _RE_KO = re.compile('[ㄱ-ㅣ가-힣]')
    _RE_EN = re.compile('[a-zA-Z]')

    def __init__(self, image_dir, label_dir, copy_images_to):
        """Select the samples and parse their label files.

        Args:
            image_dir: Root directory of the images (passed to ``get_paths``).
            label_dir: Root directory of the label JSON files (passed to ``get_paths``).
            copy_images_to: Directory into which each image is saved when it is read
                through ``__getitem__``; a falsy value disables saving.
        """
        # Use the directories given by the caller. Previously the images were always
        # looked up under SRC_DATASET_DIR and `image_dir` was ignored.
        image_paths, label_paths = get_paths(image_dir, label_dir)
        assert len(image_paths) == len(label_paths)

        sample_ids, samples_info = list(), dict()
        # get_paths() returns the two lists position-aligned, so the label path does
        # not need to be rebuilt and searched for in the list.
        for image_path, label_path in zip(image_paths, label_paths):
            sample_id = osp.splitext(osp.basename(image_path))[0]    # ('00510012043', '.jpg')

            # get word dict
            words_info = self.parse_label_file(label_path)

            sample_ids.append(sample_id)
            samples_info[sample_id] = dict(image_path=image_path, label_path=label_path,
                                           words_info=words_info)

        self.sample_ids, self.samples_info = sample_ids, samples_info

        self.copy_images_to = copy_images_to

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """Return ``(image_fname, sample_info_ufo)`` for the sample at position ``idx``.

        Side effect: when ``copy_images_to`` is set, the image is written there under
        its original file name.
        """
        sample_info = self.samples_info[self.sample_ids[idx]]

        image_path = sample_info['image_path']
        image_fname = osp.basename(image_path)

        # The context manager closes the underlying file handle once we are done.
        with Image.open(image_path) as image:
            img_w, img_h = image.size

            if self.copy_images_to:
                # exist_ok avoids a race when several DataLoader workers create the
                # directory at the same time.
                os.makedirs(self.copy_images_to, exist_ok=True)
                # Written through PIL's save(), not copied byte for byte.
                image.save(osp.join(self.copy_images_to, image_fname))

        license_tag = dict()
        sample_info_ufo = dict(img_h=img_h, img_w=img_w, words=sample_info['words_info'], tags=["document"],
                               license_tag=license_tag)

        return image_fname, sample_info_ufo

    def parse_label_file(self, label_path):
        """Parse one label JSON file into a dict of per-word annotations.

        Args:
            label_path: Path of a JSON file with a list at ``['text']['word']`` whose
                entries provide ``wordbox`` (four numbers, unpacked as upper-left x/y
                then lower-right x/y) and ``value`` (the transcription).

        Returns:
            Dict keyed by the word index as a string (``'0'``, ``'1'``, ...). Each
            value holds ``points`` (four corners starting at the upper-left, then
            upper-right, lower-right, lower-left), ``transcription``, ``language``
            (a list containing ``'ko'`` and/or ``'en'``, possibly empty),
            ``illegibility``, ``orientation`` and ``word_tags``.
        """
        # The labels contain Korean text; read them as UTF-8 instead of relying on
        # the platform's default encoding.
        with open(label_path, 'r', encoding='utf-8') as f:
            label = json.load(f)

        words_info = dict()
        for i, label_word in enumerate(label['text']['word']):
            # points
            x_ul, y_ul, x_lr, y_lr = label_word['wordbox']
            points = [[x_ul, y_ul], [x_lr, y_ul], [x_lr, y_lr], [x_ul, y_lr]]
            # transcription
            transcription = label_word['value']
            # language
            languagelist = []
            if self._RE_KO.search(transcription):
                languagelist.append('ko')
            if self._RE_EN.search(transcription):
                languagelist.append('en')
            # illegibility
            illegibility = False
            # orientation: the labels carry no orientation info, so this value is
            # arbitrary; vertical text does exist in the data.
            orientation = 'Horizontal'
            # word_tags
            word_tags = None
            words_info[str(i)] = dict(
                points=points, transcription=transcription, language=languagelist,
                illegibility=illegibility, orientation=orientation, word_tags=word_tags
            )
        return words_info

In [128]:
# Build the dataset from the source tree; iterating it below also saves each
# selected image under DST_DATASET_DIR/images.
dataset = AIHubDocsDataset(
    osp.join(SRC_DATASET_DIR, 'images'),
    osp.join(SRC_DATASET_DIR, 'labels'),
    copy_images_to=osp.join(DST_DATASET_DIR, 'images'),
)

# Identity collate: each batch stays a plain list holding one
# (image_fname, sample_info) pair, since the default batch size is 1.
loader = DataLoader(dataset, num_workers=NUM_WORKERS, collate_fn=lambda x: x)

# Collect the per-image annotations, keyed by image file name.
anno = {'images': {}}
with tqdm(total=len(dataset)) as pbar:
    for batch in loader:
        fname, info = batch[0]
        anno['images'][fname] = info
        pbar.update(1)

# Write all annotations to a single JSON file under DST_DATASET_DIR/ufo.
ufo_dir = osp.join(DST_DATASET_DIR, 'ufo')
maybe_mkdir(ufo_dir)
anno_path = osp.join(ufo_dir, 'AIHub_Docs.json')
with open(anno_path, 'w') as f:
    json.dump(anno, f, indent=4)

100%|██████████| 500/500 [00:12<00:00, 40.59it/s]
