In [13]:
# -*- coding: utf-8 -*-  

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm


def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to a non-empty image.

    ARGS:
        imageBin : raw encoded image bytes (e.g. the content of an image
                   file read in binary mode), or None
    RETURNS:
        False when ``imageBin`` is None or empty, when decoding raises or
        yields None, or when the decoded array has a zero-sized dimension;
        True otherwise.
    """
    if imageBin is None:
        return False

    try:
        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
        if imageBuf.size == 0:
            # Nothing to decode; skip the decoder call entirely.
            return False
        img = cv2.imdecode(imageBuf, cv2.IMREAD_COLOR)
    except Exception:
        # Only ordinary errors mean "invalid image". KeyboardInterrupt and
        # SystemExit are deliberately not caught so that a long validation
        # run can still be aborted.
        return False

    if img is None:
        # The decoder signalled failure without raising.
        return False

    # size is the product of all dimensions, so it is 0 exactly when any
    # of height / width / channels is 0.
    return img.size != 0

def writeCache(env, cache):
    """Put every key/value pair of ``cache`` into ``env`` in one write transaction.

    ARGS:
        env   : object whose ``begin(write=True)`` returns a context-managed
                transaction exposing ``put(key, value)``
        cache : dict of key -> value to store

    Best effort: when a ``put`` raises an ordinary exception, the offending
    key is printed and the remaining entries are skipped. What happens to
    the entries already put is decided by the transaction's context manager
    when the ``with`` block exits.
    """
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            try:
                txn.put(k, v)
            except Exception:
                # KeyboardInterrupt / SystemExit are deliberately not caught,
                # so interrupting a long export is not mistaken for a bad key.
                print(k)
                break

In [14]:
def korean2label(letter):
    """Split a single character into three indices relative to '가'.

    The character's offset from '가' is broken down as
    ``offset == 588 * ch1 + 28 * ch2 + ch3`` with ``0 <= ch2 < 21`` and
    ``0 <= ch3 < 28``.

    ARGS:
        letter : a one-character string
    RETURNS:
        tuple ``(ch1, ch2, ch3)`` of ints
    """
    offset = ord(letter) - ord('가')
    block_index, within_block = divmod(offset, 588)
    row_index, column_index = divmod(within_block, 28)
    return block_index, row_index, column_index
    
def get_all_korean():
    """Return every character from '가' through '힣', in code-point order.

    RETURNS:
        list of 11172 one-character strings; the first is '가' (U+AC00) and
        the last is '힣' (U+D7A3).

    The range is enumerated directly by code point. Every code point in
    this interval is a valid character, so no encode/decode probing is
    needed to find the next one.
    """
    firstCodePoint = ord('가')
    lastCodePoint = ord('힣')
    return [chr(codePoint) for codePoint in range(firstCodePoint, lastCodePoint + 1)]


In [15]:
def createDataset(outputPath, imagePathList, labelList, writerIDList, lexiconList=None, checkValid=True, mapSize=1099511627776):
    """
    Create LMDB dataset for CRNN training.

    For the n-th sample that is actually written (n starts at 1) the
    following keys are stored: ``image-%09d`` (raw file bytes),
    ``label-%09d``, ``writerID-%09d`` and, when ``lexiconList`` is given,
    ``lexicon-%09d`` (the lexicon items joined by single spaces). The key
    ``num-samples`` holds the number of samples written. Missing files and
    (with ``checkValid``) undecodable images are reported and skipped, so
    n can differ from the sample's position in the input lists.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts (str)
        writerIDList  : list of corresponding writer IDs (str), indexed in
                        parallel with imagePathList
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
        mapSize       : (optional) map_size passed to lmdb.open, in bytes;
                        defaults to 1 TiB
    """
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)

    env = lmdb.open(outputPath, map_size=mapSize)
    try:
        cache = {}
        cnt = 1
        for i in range(nSamples):
            imagePath = imagePathList[i]
            label = labelList[i]
            writerID = writerIDList[i]

            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue

            imageKey = ('image-%09d' % cnt).encode('utf-8')
            labelKey = ('label-%09d' % cnt).encode('utf-8')
            writerIDKey = ('writerID-%09d' % cnt).encode('utf-8')

            cache[imageKey] = imageBin
            cache[labelKey] = label.encode('utf-8')
            cache[writerIDKey] = writerID.encode('utf-8')
            if lexiconList:
                lexiconKey = ('lexicon-%09d' % cnt).encode('utf-8')
                cache[lexiconKey] = " ".join([str(ch) for ch in lexiconList[i]]).encode('utf-8')

            # Flush in batches of 1000 samples so the pending cache stays small.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples), end="\r", flush=True)
            cnt += 1

        # cnt is one past the last written sample.
        nSamples = cnt-1
        cache['num-samples'.encode()] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Release the environment even when reading or writing fails;
        # previously it was never closed.
        env.close()
    print('Created dataset with %d samples' % nSamples)

In [16]:
# Lookup table: every character returned by get_all_korean() -> its
# (ch1, ch2, ch3) index triple from korean2label().
ak = get_all_korean()
cache = {k:korean2label(k) for k in ak}
datafd = "/home/jupyter/ai_font/data/train/pngs"
# NOTE(review): unpickling can execute arbitrary code; only load this file
# if it comes from a trusted source.
font_mapper = pd.read_pickle("/home/jupyter/ai_font/data/pickle/font_mapper.pickle")
fonts = font_mapper.index
# os.listdir returns entries in arbitrary order. Sort them so that the
# sample numbering written to LMDB is the same on every run.
img_file_list = [f"{datafd}/{f}" for f in sorted(os.listdir(datafd)) if f.endswith(".png")]

In [17]:
# Parallel per-image lists that are later handed to createDataset.
img_path_list = []
label_list = []
ID_list = []
lexicon_list = []

# The file name is split on "__": the last field (minus ".png") is the label
# and the second-to-last field is the writer ID.
for img_path in tqdm(img_file_list):
    name_fields = img_path.split("/")[-1].split("__")
    img_path_list.append(img_path)
    label = name_fields[-1].replace(".png","")
    label_list.append(label)
    lexicon = cache[label]
    lexicon_list.append(lexicon)
    writerID_str = name_fields[-2]
    # NOTE(review): the looked-up position is never used; the string ID is
    # what gets stored. Presumably this call only serves to fail on a writer
    # ID missing from `fonts` — confirm before removing it.
    writerID = fonts.get_loc(writerID_str)
    ID_list.append(writerID_str)

100%|██████████| 2569554/2569554 [00:08<00:00, 304266.96it/s]


In [18]:
# Report how many samples were collected, then write all of them to LMDB.
print('total sample: %d' % len(img_path_list))

createDataset(
    outputPath='/home/jupyter/ai_font/data/train/lmdb',
    imagePathList=img_path_list,
    labelList=label_list,
    writerIDList=ID_list,
    lexiconList=lexicon_list,
)

total sample: 2569554
Created dataset with 2569554 samples
