In [2]:
import h5py
from dataset import SceneTextDataset
import albumentations as A
from torch.utils.data import Dataset

import os
import os.path as osp
import json
import numpy as np
from shapely.geometry import Polygon
from dataset import resize_img, adjust_height, rotate_img, crop_img, generate_roi_mask
import albumentations as A
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from PIL import Image

from east_dataset import generate_score_geo_maps
import cv2
import torch


INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.20 (you have 1.4.12). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.


In [3]:
_lang_list = ['chinese', 'japanese', 'thai', 'vietnamese']
root_dir = os.environ.get('SM_CHANNEL_TRAIN', 'data')
split = 'train'
total_anno = dict(images=dict())
for nation in _lang_list:
    with open(osp.join(root_dir, '{}_receipt/ufo/{}.json'.format(nation, split)), 'r', encoding='utf-8') as f:
        anno = json.load(f)
    for im in anno['images']:
        total_anno['images'][im] = anno['images'][im]

anno = total_anno
image_fnames = sorted(anno['images'].keys())

In [4]:
def _infer_dir(fname):
    lang_indicator = fname.split('.')[1]
    if lang_indicator == 'zh':
        lang = 'chinese'
    elif lang_indicator == 'ja':
        lang = 'japanese'
    elif lang_indicator == 'th':
        lang = 'thai'
    elif lang_indicator == 'vi':
        lang = 'vietnamese'
    else:
        raise ValueError
    return osp.join(root_dir, f'{lang}_receipt', 'img', split)


In [5]:
ignore_under_threshold=10
drop_under_threshold=1
image_size=2048
crop_size=1024

In [6]:
def filter_vertices(vertices, labels, ignore_under=0, drop_under=0):
    if drop_under == 0 and ignore_under == 0:
        return vertices, labels

    new_vertices, new_labels = vertices.copy(), labels.copy()

    areas = np.array([Polygon(v.reshape((4, 2))).convex_hull.area for v in vertices])
    labels[areas < ignore_under] = 0

    if drop_under > 0:
        passed = areas >= drop_under
        new_vertices, new_labels = new_vertices[passed], new_labels[passed]

    return new_vertices, new_labels


In [7]:
class SceneTextWoAugDataset(Dataset):
    def __init__(self, root_dir,
                 split='train',
                 ignore_under_threshold=10,
                 drop_under_threshold=1,
                 ):
        self._lang_list = ['chinese', 'japanese', 'thai', 'vietnamese']
        self.root_dir = root_dir
        self.split = split
        total_anno = dict(images=dict())
        for nation in self._lang_list:
            with open(osp.join(root_dir, '{}_receipt/ufo/{}.json'.format(nation, split)), 'r', encoding='utf-8') as f:
                anno = json.load(f)
            for im in anno['images']:
                total_anno['images'][im] = anno['images'][im]

        self.anno = total_anno
        self.image_fnames = sorted(self.anno['images'].keys())

        self.drop_under_threshold = drop_under_threshold
        self.ignore_under_threshold = ignore_under_threshold

        self.vertices_list = []
        self.labels_list = []
        self.images_list = []

        for idx in range(len(self.image_fnames)):
            image_fname = self.image_fnames[idx]
            image_fpath = osp.join(self._infer_dir(image_fname), image_fname)

            vertices, labels = [], []
            for word_info in self.anno['images'][image_fname]['words'].values():
                num_pts = np.array(word_info['points']).shape[0]
                if num_pts > 4:
                    continue
                vertices.append(np.array(word_info['points']).flatten())
                labels.append(1)
            vertices, labels = np.array(vertices, dtype=np.float32), np.array(labels, dtype=np.int64)

            vertices, labels = filter_vertices(
                vertices,
                labels,
                ignore_under=self.ignore_under_threshold,
                drop_under=self.drop_under_threshold
            )

            image = Image.open(image_fpath)
      
            
            self.vertices_list.append(vertices)
            self.labels_list.append(labels)
            self.images_list.append(image)


        self.vertices = np.array(self.vertices_list, dtype=object)
        self.labels = np.array(self.labels_list, dtype=object)
        self.images = np.array(self.images_list, dtype=object)


    def _infer_dir(self, fname):
        lang_indicator = fname.split('.')[1]
        if lang_indicator == 'zh':
            lang = 'chinese'
        elif lang_indicator == 'ja':
            lang = 'japanese'
        elif lang_indicator == 'th':
            lang = 'thai'
        elif lang_indicator == 'vi':
            lang = 'vietnamese'
        else:
            raise ValueError
        return osp.join(self.root_dir, f'{lang}_receipt', 'img', self.split)
    
    def __len__(self):
        return len(self.image_fnames)

    def __getitem__(self, idx):


        image, vertices, labels = self.images[idx], self.vertices[idx], self.labels[idx]

        # image, vertices = resize_img(image, vertices, self.image_size)
        # image, vertices = adjust_height(image, vertices)
        # image, vertices = rotate_img(image, vertices)
        # image, vertices = crop_img(image, vertices, labels, self.crop_size)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        image = np.array(image)

        funcs = []
        if True:
            funcs.append(A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        transform = A.Compose(funcs)

        image = transform(image=image)['image']

        word_bboxes = np.array(vertices).reshape(-1, 4, 2)
        roi_mask = generate_roi_mask(image, vertices, labels)

        # print("shape",image.shape)

        return image, word_bboxes, roi_mask


In [8]:
train_dataset = SceneTextWoAugDataset(
root_dir,
split='train')

  self.images = np.array(self.images_list, dtype=object)


In [9]:
from east_dataset import EASTDataset

In [10]:
dataset = EASTDataset(train_dataset,to_tensor=False)

In [11]:
from tqdm import tqdm

In [12]:
import h5py
from torch.utils.data import DataLoader


In [13]:
def save_dataset_to_h5py(dataset, h5py_filename="east_dataset.h5", batch_size=16):
    # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # h5py 파일 생성
    with h5py.File(h5py_filename, 'w') as h5file:
        # 각 데이터 타입의 그룹 생성
        image_group = h5file.create_group('images')
        score_map_group = h5file.create_group('score_maps')
        geo_map_group = h5file.create_group('geo_maps')
        roi_mask_group = h5file.create_group('roi_masks')

        # 인덱스 초기화
        idx = 0

        # 배치 단위로 데이터 저장
        for idx in tqdm(range(len(dataset))):
            image, score_map, geo_map, roi_mask = dataset[idx]


            # 각 배치의 데이터 저장
            image_group.create_dataset(str(idx), data=image)
            score_map_group.create_dataset(str(idx), data=score_map)
            geo_map_group.create_dataset(str(idx), data=geo_map)
            roi_mask_group.create_dataset(str(idx), data=roi_mask)

    print(f"Dataset saved to {h5py_filename}")


In [None]:
save_dataset_to_h5py(dataset, h5py_filename="east_dataset.h5")

  1%|▏         | 5/400 [00:46<1:06:48, 10.15s/it]