In [None]:
import pathlib

# Project root (a Google Drive path as mounted under /content/drive);
# every other location below is derived from it.
ROOT = pathlib.Path("/content/drive/MyDrive/Colab Notebooks/KanjiLookup")

# Inputs: the font files to render with, and the text files listing the kanji.
INPUT_FONTS_FOLDER = ROOT.joinpath("fonts")
INPUT_KANJI_FOLDER = ROOT.joinpath("text")

# Outputs: everything this notebook produces lives under `generated/`.
GENERATED = ROOT.joinpath("generated")
GENERATED_EMBEDDINGS_FOLDER = GENERATED.joinpath("embeddings")

# HuggingFace Hub identifier of the checkpoint whose encoder produces the embeddings.
MODEL = "kha-white/manga-ocr-base"

MODEL_EMBEDDING_SIZE = 768  # embedding width the notebook asserts for MODEL
MODEL_IMAGE_SIZE = 224      # side, in pixels, of the square canvas each kanji is drawn on
FONT_SIZE = 184             # size passed to `ImageFont.truetype` when loading a font

In [None]:
import numpy as np
from PIL import Image, ImageFont, ImageDraw

def get_standard_kanji_set() -> set[str]:
    """Return the distinct lines of `kanji_joyo.txt` from the input kanji folder.

    The file is read as UTF-8 and split on line boundaries; each line becomes
    one element of the returned set.
    """
    joyo_path = INPUT_KANJI_FOLDER / "kanji_joyo.txt"
    joyo_lines = joyo_path.read_text(encoding="UTF-8").splitlines()
    return {line for line in joyo_lines}

def load_kanji_list(folder: "pathlib.Path | None" = None) -> list[str]:
    """Collect the kanji listed in every `*.txt` file of a folder.

    Each file is read as UTF-8 with one entry per line. Surrounding whitespace
    is stripped and blank lines are skipped, so that an empty string is never
    handed to the renderer (it would draw nothing and be reported as an
    unsupported character).

    Args:
        folder: Directory to scan (non-recursively). Defaults to
            `INPUT_KANJI_FOLDER` when omitted.

    Returns:
        The entries of all files, with files taken in sorted path order and
        lines kept in file order. Entries that appear more than once are kept.
    """
    if folder is None:
        folder = INPUT_KANJI_FOLDER

    kanji_list = []
    # `glob` yields files in an arbitrary, filesystem-dependent order; sorting
    # keeps the list (and therefore the batch numbering) stable between runs.
    for file in sorted(folder.glob("*.txt")):
        for line in file.read_text(encoding="UTF-8").splitlines():
            entry = line.strip()
            if entry:
                kanji_list.append(entry)
    return kanji_list


def draw_kanji(font: ImageFont.FreeTypeFont, kanji: str):
    """Render `kanji` with `font` in the middle of a square grayscale canvas.

    The canvas is a mode 'L' image of MODEL_IMAGE_SIZE x MODEL_IMAGE_SIZE
    pixels filled with 255; the text is drawn with fill 0 and anchored by its
    middle ("mm") at the canvas centre. Returns the resulting image.
    """
    side = MODEL_IMAGE_SIZE
    centre = (side // 2, side // 2)
    canvas = Image.new('L', (side, side), color=255)
    ImageDraw.Draw(canvas).text(centre, kanji, font=font, anchor="mm", fill=0)
    return canvas

def check_has_text(image: "Image.Image") -> bool:
    """Tell whether anything at all was drawn on `image`.

    An image whose pixels all share a single value is considered blank.

    Args:
        image: The image to inspect (anything `np.asarray` accepts).

    Returns:
        True when the image holds at least two distinct pixel values,
        False when it is uniform or contains no pixels at all.
    """
    pixels = np.asarray(image)
    if pixels.size == 0:
        # `min()` / `max()` raise ValueError on a zero-size array; an image
        # without pixels certainly has no text.
        return False
    return bool(pixels.min() != pixels.max())


def list_fonts() -> dict[str, ImageFont.FreeTypeFont]:
    """Load every `.ttf` file found, recursively, under INPUT_FONTS_FOLDER.

    Returns a dictionary of `font_name -> ImageFont`, where `font_name` is the
    file name without its extension and each font is loaded at FONT_SIZE.
    """
    loaded = {}
    for ttf_path in INPUT_FONTS_FOLDER.glob("**/*.ttf"):
        # The path is converted to `str` on purpose: handing the `Path` object
        # itself to `truetype` worked on the author's machine, but not on Colab.
        loaded[ttf_path.stem] = ImageFont.truetype(str(ttf_path), FONT_SIZE)
    return loaded

def generate_images_for_font(font: ImageFont.FreeTypeFont, kanji_list: list[str]) -> dict[str, Image.Image]:
    """Render each entry of `kanji_list` with `font`.

    Returns a dictionary of `kanji -> Image` holding only the characters whose
    rendering is not blank. Characters that came out blank are left out of the
    result and reported together in a single printed message.
    """
    rendered = {}
    unsupported = []
    for character in kanji_list:
        glyph = draw_kanji(font, character)
        if not check_has_text(glyph):
            # A blank canvas is taken to mean the font lacks this character.
            unsupported.append(character)
            continue
        rendered[character] = glyph

    if len(unsupported) > 0:
        print(f"Font {font.getname()} does not seems to support {unsupported} characters, skipping them for this font")
    return rendered

In [None]:
from PIL import Image
import torch
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,  # Load extractor
    ViTModel,  # Load ViT encoder
)

# The rest of the notebook is written against this one checkpoint and its
# embedding width, so fail early if the configuration says otherwise.
assert MODEL == "kha-white/manga-ocr-base", (
    "Other models are not natively supported, "
    "    you may have to change a lot of things to get it to work"
)
assert MODEL_EMBEDDING_SIZE == 768, (
    "The only model embedding size supported is 768"
)

def load_model() -> tuple[ViTImageProcessor, ViTModel]:
    """Fetch the image processor and the encoder of the checkpoint named by `MODEL`.

    Only the `encoder` part of the vision encoder-decoder checkpoint is kept.
    It is moved to the GPU when CUDA is available; which device is used is
    printed.

    Returns the `feature_extractor` and the `encoder`, in this order.
    """
    print("Loading Image Processor from HuggingFace Hub")
    # NOTE(review): `requires_grad` is not an obvious image-processor option;
    # confirm whether this keyword has any effect.
    processor: ViTImageProcessor = ViTImageProcessor.from_pretrained(MODEL, requires_grad=False)

    print("Loading ViT Model from HuggingFace Hub")
    vit_encoder: ViTModel = VisionEncoderDecoderModel.from_pretrained(MODEL).encoder

    use_cuda = torch.cuda.is_available()
    print('Using CUDA' if use_cuda else 'Using CPU')
    if use_cuda:
        vit_encoder.cuda()

    return processor, vit_encoder


def get_embeddings(feature_extractor: ViTImageProcessor, encoder: ViTModel, images: list[Image.Image]) -> torch.Tensor:
    """Run `images` through the processor and the encoder and return the embeddings.

    Every image is converted to RGB first. The whole list is processed as a
    single batch under `torch.inference_mode()`, on the encoder's device, and
    the encoder's `pooler_output` is returned as a CPU tensor.
    """
    rgb_images = []
    for picture in images:
        rgb_images.append(picture.convert("RGB"))

    with torch.inference_mode():
        processed = feature_extractor(rgb_images, return_tensors="pt")
        pixel_values: torch.Tensor = processed["pixel_values"]
        # The inputs must live on the same device as the encoder weights.
        encoded = encoder(pixel_values.to(encoder.device))
        return encoded["pooler_output"].cpu()


In [None]:
import typing

from tqdm import tqdm
import torch

T = typing.TypeVar("T")

# Maximum number of kanji placed in one batch.
BATCH_SIZE = 1024

def batched(original: list[T], group_size: int) -> list[list[T]]:
    """Split `original` into consecutive slices of at most `group_size` items.

    Order is preserved and only the last slice may be shorter than
    `group_size`. An empty input gives an empty list.
    """
    starts = range(0, len(original), group_size)
    # Debug tip: returning only the first slice keeps test runs short.
    return [original[start : start + group_size] for start in starts]


In [None]:
# Load the image processor and the ViT encoder once; both are reused for every batch.
extractor, encoder = load_model()

# `font name -> loaded font`, one entry per `.ttf` file found under the fonts folder.
fonts = list_fonts()
print(f"Using the {len(fonts)} following fonts: {fonts.keys()}")

# Read the kanji once and pre-split them, so that every font in the next cell
# is processed with exactly the same batches.
kanji_list = load_kanji_list()
kanji_batches = batched(kanji_list, BATCH_SIZE)
print(f"Processing a total of {len(kanji_list)} Kanji in {len(kanji_batches)} batches of {BATCH_SIZE}")

In [None]:
# For every font, render each kanji batch, embed the rendered images and store
# the result as a pair of files: `batch_<i>.pt` (tensor) and `batch_<i>.txt` (labels).
for font_name, font in tqdm(fonts.items()):
    font_embeddings_folder = GENERATED_EMBEDDINGS_FOLDER / font_name
    font_embeddings_folder.mkdir(parents=True, exist_ok=True)

    print(f"Generating Embeddings for font {font.getname()}")

    for batch_index, kanji_batch in enumerate(tqdm(kanji_batches)):
        images_dict = generate_images_for_font(font, kanji_batch)

        if not images_dict:
            # Not a single character of this batch could be rendered with this font.
            print(f"The font {font_name} did not support any of the characters for batch {batch_index}")
            continue

        tensor_out_file = font_embeddings_folder / f"batch_{batch_index}.pt"
        labels_out_file = font_embeddings_folder / f"batch_{batch_index}.txt"

        # The labels come from `images_dict`, not from `kanji_batch`:
        # `generate_images_for_font` may have skipped some characters, and the
        # labels must line up one-to-one with the rows of the tensor.
        labels = "\n".join(images_dict)
        tensor = get_embeddings(extractor, encoder, list(images_dict.values()))

        labels_out_file.write_text(labels, encoding="UTF-8")
        torch.save(tensor, tensor_out_file)
