Configuration

In [None]:
import os
from dotenv import load_dotenv

# Pull variables from a .env file into the process environment before reading them.
load_dotenv()

# Both locations come from the environment; each is None when the variable is unset.
FOLDER_PATH = os.environ.get("FOLDER_PATH")
OUTPUT_PATH = os.environ.get("OUTPUT_PATH")

# Input files are named "<FOLDER_PATH>_01" through "<FOLDER_PATH>_50".
files = [f"{FOLDER_PATH}_{i:02d}" for i in range(1, 51)]

# Decimal JIS codes of the character classes kept for training.
# NOTE(review): 9250 == 0x2422, presumably the hiragana vowels; confirm against the JIS table.
selected_jis = [9250, 9252, 9254, 9256, 9258]

Preprocessing functions

In [None]:
import tensorflow as tf
from dataclasses import dataclass
import numpy as np
import cv2


@dataclass
class Record:
    """A single labelled character sample taken from an ETL9G file.

    Attributes:
        jis_code: Integer character code read from the record header.
        img: 2-D pixel array holding the character image.
    """
    jis_code: int
    img: np.ndarray


def swap_white_and_black(img: np.ndarray) -> np.ndarray:
    """
    Invert a 16-level image: every pixel value ``v`` becomes ``15 - v``.

    Returns a new array; the argument is not modified.
    """
    max_level = 15
    inverted = max_level - img
    return inverted


def clean_image_background(img: np.ndarray) -> np.ndarray:
    """
    Removes unwanted objects that touch the edges of the image.

    This helps eliminate neighboring fragments that result from scanning
    multiple samples from one sheet. Every 4-connected component of non-zero
    pixels that reaches the image border is found with a breadth-first search
    and set to 0.

    Args:
        img: 2-D array indexed as ``img[row, col]``; pixels ``> 0`` are foreground.

    Returns:
        A cleaned copy of ``img``. The input array itself is left untouched, so
        the raw image stays available to the caller and read-only arrays work.
    """
    from collections import deque  # O(1) pops from the left, unlike list.pop(0)

    # TODO: remove small objects inside the image
    cleaned = img.copy()
    h, w = cleaned.shape
    if h == 0 or w == 0:
        # Nothing to scan; also avoids indexing column/row 0 of an empty axis.
        return cleaned

    visited = np.zeros((h, w), dtype=bool)
    neighbours = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def clear_component(start_row: int, start_col: int) -> None:
        """Erase the foreground component containing the given border pixel."""
        visited[start_row, start_col] = True
        queue = deque([(start_row, start_col)])
        while queue:
            row, col = queue.popleft()
            if cleaned[row, col] > 0:
                cleaned[row, col] = 0
                for d_row, d_col in neighbours:
                    n_row, n_col = row + d_row, col + d_col
                    if 0 <= n_row < h and 0 <= n_col < w and not visited[n_row, n_col]:
                        # Mark on enqueue so each pixel enters the queue at most once.
                        visited[n_row, n_col] = True
                        queue.append((n_row, n_col))

    # Seed the search from every pixel on the left/right borders ...
    for row in range(h):
        for col in (0, w - 1):
            if not visited[row, col]:
                clear_component(row, col)
    # ... and from every pixel on the top/bottom borders.
    for col in range(w):
        for row in (0, h - 1):
            if not visited[row, col]:
                clear_component(row, col)

    return cleaned


def cut_center_and_scale(img: np.ndarray, add_margin: bool = False) -> np.ndarray:
    """
    Crops the image to its content and scales it onto a fixed 127x128 canvas.

    The bounding box of all pixels with value ``> 1`` is cut out, optionally
    padded with a 2-pixel zero margin, resized with bilinear interpolation so
    that it fits inside the canvas while keeping its aspect ratio, and pasted
    centered on a zero-filled canvas.

    Args:
        img: 2-D array indexed as ``img[row, col]``.
        add_margin: When True, a 2-pixel zero border is added around the crop
            before scaling.

    Returns:
        A ``(127, 128)`` ``uint8`` array, or ``img`` itself (unchanged) when no
        pixel is above the content threshold.
    """
    target_h, target_w = 127, 128

    # Values 0 and 1 are treated as background when locating the content.
    coords = np.argwhere(img > 1)
    if coords.size == 0:
        return img

    # argwhere yields (row, col) pairs; +1 because slice upper bounds are exclusive.
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0) + 1
    cropped_img = img[y0:y1, x0:x1]
    if add_margin:
        cropped_img = np.pad(cropped_img, pad_width=2, mode='constant', constant_values=0)

    # Measure AFTER the optional padding: the scale must be derived from the
    # array that is actually resized, otherwise the margin distorts the aspect ratio.
    ch, cw = cropped_img.shape
    scale = min(target_h / ch, target_w / cw)
    # Clamp to 1 so an extremely elongated crop cannot produce a zero-sized resize.
    new_h = max(1, int(ch * scale))
    new_w = max(1, int(cw * scale))

    scaled_img = tf.image.resize(
        cropped_img[..., np.newaxis],
        (new_h, new_w), method='bilinear'
    ).numpy()

    canvas = np.zeros((target_h, target_w, 1), dtype=scaled_img.dtype)

    # Center the scaled content on the canvas.
    y_offset = (target_h - new_h) // 2
    x_offset = (target_w - new_w) // 2
    canvas[y_offset:y_offset+new_h, x_offset:x_offset+new_w, :] = scaled_img

    return canvas.astype(np.uint8).squeeze()


def convert_to_binary_image(img: np.ndarray, threshold = 3) -> np.ndarray:
    """
    Threshold a grayscale image into two levels.

    Pixels strictly greater than ``threshold`` become 15, all others become 0.
    The result is a new ``uint8`` array of the same shape.
    """
    foreground = img > threshold
    binary_img = np.zeros(np.shape(foreground), dtype=np.uint8)
    binary_img[foreground] = 15
    return binary_img


def smooth_edges(img: np.ndarray, kernel_size: int = 7) -> np.ndarray:
    """
    Blur the image with a square Gaussian kernel to soften stroke edges.

    ``kernel_size`` is used for both kernel dimensions and sigma is passed as 0.
    """
    ksize = (kernel_size, kernel_size)
    blurred = cv2.GaussianBlur(img, ksize, 0)
    return blurred


def thin_lines(img: np.ndarray, kernel_size: int = 2) -> np.ndarray:
    """
    Apply a morphological opening with a square all-ones kernel.

    ``kernel_size`` is the side length of the structuring element.
    """
    structuring_element = np.ones((kernel_size,) * 2, dtype=np.uint8)
    return cv2.morphologyEx(img, cv2.MORPH_OPEN, structuring_element)

In [None]:

from typing import BinaryIO, List, Optional
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from dataclasses import dataclass
import struct
import numpy as np
import cv2
import matplotlib.pyplot as plt

# NOTE(review): this repeats the `Record` dataclass already defined in the
# preprocessing cell; running this cell rebinds the name to a new class.
@dataclass
class Record:
    """A single labelled character sample taken from an ETL9G file.

    Attributes:
        jis_code: Integer character code read from the record header.
        img: 2-D pixel array holding the character image.
    """
    jis_code: int
    img: np.ndarray

# Stop training once val_loss has not improved for 3 epochs and restore the best weights seen.
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)


def read_etl9g_records_from_file(f: BinaryIO) -> Optional[Record]:
    """
    Read the next 8199-byte ETL9G record from an open binary stream.

    Returns a ``Record`` holding the JIS code and the 127x128 image, whose
    pixels are unpacked from 4-bit pairs (high nibble first). Returns ``None``
    when the stream has no more data.
    """
    record_size = 8199
    raw = f.read(record_size)
    if not raw:
        return None

    fields = struct.unpack('>HH8sIBBBBHHHHBB34s8128s7x', raw)
    jis_code, packed_image = fields[1], fields[15]

    # Each byte carries two pixels; column 0 is the high nibble, column 1 the low one,
    # so flattening row by row restores the original pixel order.
    packed = np.frombuffer(packed_image, dtype=np.uint8)
    nibbles = np.stack((packed >> 4, packed & 0x0F), axis=1)
    img = nibbles.reshape(127, 128)

    return Record(jis_code, img)


def read_etl9g_all_records() -> List[Record]:
    """
    Reads every ETL9G binary file listed in ``files``.

    Only records whose JIS code is in ``selected_jis`` are kept; each kept image
    is passed through ``simple_preprocess`` before being stored.

    Returns:
        The list of preprocessed records, in file and record order.
    """
    records = []
    for file in files:
        with open(file, "rb") as f:
            # Identity check against None: `== None` would go through the
            # dataclass-generated __eq__ instead of testing for end of file.
            while (record := read_etl9g_records_from_file(f)) is not None:
                if record.jis_code not in selected_jis:
                    continue
                img = record.img

                simple_record = Record(record.jis_code, simple_preprocess(img))
                records.append(simple_record)

                # strong_record = Record(record.jis_code, preprocess(img))
                # records.append(strong_record)

    return records


def simple_preprocess(img: np.ndarray) -> np.ndarray:
    """
    Minimal pipeline: strip border-touching fragments, crop and rescale,
    binarize, then invert black and white.
    """
    without_fragments = clean_image_background(img)
    centered = cut_center_and_scale(without_fragments)
    # Edge smoothing is intentionally skipped in this pipeline.
    binary = convert_to_binary_image(centered)
    return swap_white_and_black(binary)

def preprocess(img: np.ndarray) -> np.ndarray:
    """
    Heavier pipeline that runs two crop/blur/binarize rounds before inverting.

    Round one opens the strokes, removes border fragments, crops, blurs and
    thresholds at 3; round two re-crops with a margin, blurs again and
    thresholds at 6.
    """
    opened = thin_lines(img, kernel_size=2)
    without_fragments = clean_image_background(opened)

    first_pass = convert_to_binary_image(
        smooth_edges(cut_center_and_scale(without_fragments)), 3
    )
    second_pass = convert_to_binary_image(
        smooth_edges(cut_center_and_scale(first_pass, add_margin=True)), 6
    )

    return swap_white_and_black(second_pass)


def create_cnn(with_augmentation=True):
    """
    Build and compile the CNN classifier for 127x128 single-channel images.

    The network is three Conv2D/MaxPooling2D stages (32, 64, 128 filters),
    a 256-unit dense layer with dropout, and a softmax over ``selected_jis``.
    When ``with_augmentation`` is True, a random rotation/zoom block is placed
    directly after the input.
    """
    model = tf.keras.Sequential()
    model.add(layers.Input(shape=(127, 128, 1)))

    if with_augmentation:
        # Translation is not needed because images are centered, and
        # horizontal flips shouldn't be used for writing.
        augmentation = tf.keras.Sequential([
            layers.RandomRotation(0.1),
            layers.RandomZoom(height_factor=(-0.1, 0.0), width_factor=(-0.1, 0.0)),
        ])
        model.add(augmentation)

    # Convolutional feature extractor: each stage is followed by a pooling layer.
    for filters in (32, 64, 128):
        model.add(layers.Conv2D(filters, (3, 3), activation='relu', padding='same'))
        model.add(layers.MaxPooling2D())

    # Classification head.
    model.add(layers.Flatten())
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.Dropout(0.5))
    n_classes = len(selected_jis)
    model.add(layers.Dense(n_classes, activation='softmax'))

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def split_data(records: List[Record]) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Turn records into arrays and make a stratified 80/20 train/test split.

    Labels are the position of each record's JIS code in ``selected_jis``;
    images get a trailing channel axis. Returns
    ``(images_train, images_test, labels_train, labels_test)``.
    """
    image_list = []
    label_list = []
    for record in records:
        image_list.append(record.img)
        label_list.append(selected_jis.index(record.jis_code))

    # Append the channel dimension expected by the network input.
    images = np.expand_dims(np.array(image_list), axis=-1)
    labels = np.array(label_list)

    return tuple(train_test_split(
        images, labels, test_size=0.2, random_state=42, stratify=labels
    ))


def save_model(model: tf.keras.Model, path: str):
    """
    Export the model to ``path`` and write a TFLite conversion next to it.

    The TFLite file is created at ``<path>.tflite`` from the export written to ``path``.
    """
    model.export(path)

    tflite_bytes = tf.lite.TFLiteConverter.from_saved_model(path).convert()
    tflite_path = f"{path}.tflite"
    with open(tflite_path, "wb") as out_file:
        out_file.write(tflite_bytes)


if __name__ == "__main__":
    # Load and preprocess every selected sample, then split into train/test arrays.
    records = read_etl9g_all_records()
    images_train, images_test, labels_train, labels_test = split_data(records)

    # Train with the augmentation layers enabled.
    model = create_cnn(with_augmentation=True)
    # TODO: check if summary needed
    # model.summary()
    # NOTE(review): the test split is also the validation data that drives early
    # stopping, so no untouched hold-out set remains for a final evaluation.
    history = model.fit(
        images_train, labels_train,
        validation_data=(images_test, labels_test),
        epochs=50,
        shuffle=True,
        callbacks=[early_stop]
    )
    # Plot the validation loss recorded for each epoch.
    plt.plot(history.history['val_loss'])

    # Copy the trained weights into the same architecture built without the
    # augmentation block, and export that model.
    model_no_augmentation = create_cnn(with_augmentation=False)
    model_no_augmentation.set_weights(model.get_weights())
    # NOTE(review): the file name says "smooth", but simple_preprocess has its
    # smooth_edges step commented out; confirm the name matches the pipeline used.
    save_model(model_no_augmentation, f"{OUTPUT_PATH}/cnn_etl9g_5s_50e_nomargin_smooth_simple")



Epoch 1/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 187ms/step - accuracy: 0.5775 - loss: 4.4747 - val_accuracy: 0.8400 - val_loss: 0.6131
Epoch 2/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 189ms/step - accuracy: 0.8225 - loss: 0.4667 - val_accuracy: 0.8750 - val_loss: 0.4090
Epoch 3/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 168ms/step - accuracy: 0.9050 - loss: 0.2850 - val_accuracy: 0.9050 - val_loss: 0.3073
Epoch 4/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 162ms/step - accuracy: 0.8988 - loss: 0.2652 - val_accuracy: 0.9100 - val_loss: 0.3485
Epoch 5/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 178ms/step - accuracy: 0.9225 - loss: 0.2128 - val_accuracy: 0.9250 - val_loss: 0.3283
Epoch 6/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 176ms/step - accuracy: 0.9337 - loss: 0.1863 - val_accuracy: 0.9350 - val_loss: 0.3141
INFO:tensorflow:Assets writt

INFO:tensorflow:Assets written to: models/cnn_etl9g_5s_50e_nomargin_smooth_simple\assets


Saved artifact at 'models/cnn_etl9g_5s_50e_nomargin_smooth_simple'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 127, 128, 1), dtype=tf.float32, name='keras_tensor_896')
Output Type:
  TensorSpec(shape=(None, 5), dtype=tf.float32, name=None)
Captures:
  2724535486224: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535488528: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535491792: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535492560: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535485264: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535498896: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535495632: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535490448: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535495824: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2724535484688: TensorSpec(shape=(), dtype=tf.resource, name=None)
