# Generator

In [None]:
# Generator

import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import xml.etree.ElementTree as ET

class KarutaGenerator:
    """Load per-image polygon annotations from an annotation XML file.

    For every ``<image>`` element directly under the XML root, the image
    file name, width, height and the labels/points of its ``<polygon>``
    children are stored in parallel lists, iterable via :meth:`gen_series`.

    Args:
        imgdir: Directory holding the image files named in the annotations.
            It is only stored here; callers open the images themselves.
        annotations_path: Path of the annotation XML file to parse.
    """

    def __init__(self,
                 imgdir="../karuta-classifier/dataset/kyad",
                 annotations_path="dataset/annotations.xml"):
        # Image folder (callers join it with the names below).
        self.imgdir = imgdir
        self.annotations_path = annotations_path
        # Parallel per-image lists: index i of each list describes one image.
        self.names = []
        self.widths = []
        self.heights = []
        # Each entry is [labels, points]: labels[k] is a polygon's label and
        # points[k] its raw points string, so callers can ``zip(*gt)``.
        self.gts = []
        root = ET.parse(annotations_path).getroot()
        # findall() with a bare tag matches direct children only.
        for image in root.findall("image"):
            self.names.append(image.attrib["name"])
            self.widths.append(int(image.attrib["width"]))
            self.heights.append(int(image.attrib["height"]))
            polygons = image.findall("polygon")
            labels = [polygon.attrib["label"] for polygon in polygons]
            points = [polygon.attrib["points"] for polygon in polygons]
            self.gts.append([labels, points])
        # Index of the image most recently yielded by gen_series().
        self.id = 0

    def gen_series(self):
        """Yield ``(index, name, width, height, gt)`` for each annotated image.

        ``self.id`` is set to the index being yielded, so it tracks the
        current position during iteration.
        """
        records = zip(self.names, self.widths, self.heights, self.gts)
        for index, (name, width, height, gt) in enumerate(records):
            self.id = index
            yield index, name, width, height, gt

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.names)

In [None]:
# Generatorのテスト

def generator_test(max_images=2):
    """Visually spot-check the annotations of the first few images.

    For each image yielded by ``KarutaGenerator().gen_series()``, every
    polygon's label is printed and the polygon's axis-aligned bounding box
    is cropped from the image and shown in a new matplotlib figure.

    Args:
        max_images: Number of images to inspect before stopping (at least one
            image is always processed). Defaults to 2, the previous
            hard-coded limit.
    """
    gen = KarutaGenerator()

    for idx, name, _width, _height, gt in gen.gen_series():
        # Copy the pixels into numpy, then release the file handle right away.
        with Image.open(os.path.join(gen.imgdir, name)) as pil_img:
            img: np.ndarray = np.array(pil_img)
        for label, point in zip(*gt):
            # Points are "x,y;x,y;..." strings.
            xs, ys = [], []
            for xy in point.split(';'):
                x, y = map(float, xy.split(','))
                xs.append(x)
                ys.append(y)
            x0, x1 = int(min(xs)), int(max(xs))
            y0, y1 = int(min(ys)), int(max(ys))
            print(label)
            img_fuda: np.ndarray = img[y0:y1, x0:x1]
            _fig, ax = plt.subplots()
            ax.imshow(img_fuda)
        if idx + 1 >= max_images:
            break


# True both when run as a script and inside a notebook kernel.
if __name__ == "__main__":
    generator_test()

In [None]:
# Classificationタスク用画像を生成

from tqdm import tqdm

def _polygon_coords(points):
    """Split a ``"x,y;x,y;..."`` points string into lists of x and y floats."""
    xs, ys = [], []
    for xy in points.split(';'):
        x, y = map(float, xy.split(','))
        xs.append(x)
        ys.append(y)
    return xs, ys


def _upright_angle(xs, ys):
    """Return the counter-clockwise rotation (0, 90, 180 or 270) for a crop.

    Vertices 0, 1 and 3 decide which image corner the polygon starts from
    (top-left, top-right, bottom-right, bottom-left); the matching rotation
    is returned. Presumably annotators start each polygon at the card's own
    top-left corner, so this puts the card upright -- TODO confirm against
    the annotation guidelines.

    Raises:
        ValueError: if the polygon has fewer than 4 points.
    """
    if len(xs) < 4 or len(ys) < 4:
        raise ValueError(
            "expected a polygon with at least 4 points, got {}".format(len(xs)))
    scores = [
        (xs[1] - xs[0]) + (ys[3] - ys[0]),  # starts at top-left
        (xs[0] - xs[3]) + (ys[1] - ys[0]),  # starts at top-right
        (xs[0] - xs[1]) + (ys[0] - ys[3]),  # starts at bottom-right
        (xs[3] - xs[0]) + (ys[0] - ys[1]),  # starts at bottom-left
    ]
    return 90 * int(np.argmax(scores))


def gen_cls_imgs(dataset_dir="dataset_cls", kimariji_path="kimariji_abbrev.txt"):
    """Crop every annotated card into a per-class PNG for a classification task.

    Class ids are the 1-based line numbers of ``kimariji_path``. Each polygon's
    bounding box is cropped, rotated by ``_upright_angle`` and saved as
    ``<dataset_dir>/<id:03d>_<label>/<image:04d>_<poly:02d>_<id:03d>_<label>.png``.

    Args:
        dataset_dir: Root output directory (created as needed).
        kimariji_path: Text file listing one class label per line.

    Raises:
        ValueError: if a polygon label is not listed in ``kimariji_path`` or a
            polygon has fewer than 4 points.
    """
    gen = KarutaGenerator()

    # NOTE(review): assumes the list is UTF-8 so entries match the XML labels.
    with open(kimariji_path, encoding="utf-8") as f:
        kimarijis = [line.strip() for line in f]
    # 1-based ids; a repeated line keeps the id of its last occurrence.
    kimariji_to_i = {kimariji: i for i, kimariji in enumerate(kimarijis, start=1)}

    for img_id, name, _width, _height, gt in tqdm(gen.gen_series(), total=len(gen)):
        # Close each source image once all of its crops are written.
        with Image.open(os.path.join(gen.imgdir, name)) as pil_img:
            for j, (label, point) in enumerate(zip(*gt)):
                if label not in kimariji_to_i:
                    raise ValueError(
                        "label {!r} not found in {}".format(label, kimariji_path))
                i = kimariji_to_i[label]
                xs, ys = _polygon_coords(point)
                box = (int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)))
                pil_img_fuda = pil_img.crop(box)
                angle = _upright_angle(xs, ys)
                if angle:
                    pil_img_fuda = pil_img_fuda.rotate(angle, expand=True)

                # Write the image file.
                subdir = os.path.join(dataset_dir, "{:03d}_{}".format(i, label))
                os.makedirs(subdir, exist_ok=True)
                fname = "{:04d}_{:02d}_{:03d}_{}.png".format(img_id, j, i, label)
                pil_img_fuda.save(os.path.join(subdir, fname))


# True both when run as a script and inside a notebook kernel.
if __name__ == "__main__":
    gen_cls_imgs()