<a href="https://colab.research.google.com/github/kahxuan/chinese-calligraphy-recognition/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the project repository into the current working directory of the runtime.
!git clone https://github.com/kahxuan/chinese-calligraphy-ocr.git

In [None]:
# Switch into the cloned repository: the relative paths used by the cells below
# (config.yaml, data/..., fonts/..., the `modules` package) are resolved from here.
%cd chinese-calligraphy-ocr

In [None]:
import os
import yaml
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
import matplotlib as mpl
import matplotlib.pyplot as plt
from modules.model import CRModel, IMG_SIZE, INPUT_SHAPE

# Run tf.function-decorated code eagerly (simpler to debug, at the cost of speed).
tf.config.run_functions_eagerly(True)

# Parse the whole YAML configuration once, then pull out the pieces used below.
config_path = 'config.yaml'
with open(config_path) as file:
    full_config = yaml.safe_load(file)

# Raw dataset directory as declared in the `dataset` section of the config.
data_path = full_config['dataset']['raw_dir']

# From here on `config` refers ONLY to the `train` section of config.yaml;
# every later `config[...]` lookup in this notebook is relative to that section.
config = full_config['train']

In [None]:
# load dataset
#
# The three splits are read from sibling directories with identical loader
# settings, so the shared keyword arguments are defined once and reused.
# NOTE(review): `config` was rebound to the `train` section of config.yaml, so
# the seed is read from `train -> dataset -> seed` — confirm that key exists.
loader_kwargs = dict(
    seed=config['dataset']['seed'],
    shuffle=True,
    batch_size=config['batch_size'],
    image_size=IMG_SIZE,
    label_mode='categorical',  # labels are delivered as one-hot vectors
)

train = image_dataset_from_directory('data/train', **loader_kwargs)
val = image_dataset_from_directory('data/validation', **loader_kwargs)
test = image_dataset_from_directory('data/test', **loader_kwargs)

In [None]:
# Training objective and optimizer, configured from the `train` config section.
optimizer = tf.keras.optimizers.Adam(learning_rate=config['optimizer']['lr'])
# from_logits=True makes the loss apply softmax itself.
# NOTE(review): assumes CRModel outputs raw logits (no final softmax) — confirm.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Instantiate the classifier, create its weights for the expected input shape,
# and attach the optimizer / loss / metric used by fit() and evaluate().
model = CRModel(num_class=config['num_class'])
model.build(input_shape=tuple(INPUT_SHAPE))
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

In [None]:
# Train on the training split, tracking loss/accuracy on the validation split
# after every epoch; `history` is consumed by the plotting cell below.
history = model.fit(
    train,
    epochs=config['epochs'],
    validation_data=val,
)

# evaluate() returns [loss, accuracy] because the model was compiled with a
# single metric ('accuracy').
test_metrics = model.evaluate(test)
loss, accuracy = test_metrics
print('Test accuracy:', accuracy)

In [None]:
# Per-epoch curves recorded by model.fit().
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Two stacked panels: accuracy on top, loss below.
fig, (ax_acc, ax_loss) = plt.subplots(2, 1, figsize=(8, 8))

ax_acc.plot(acc, label='Training Accuracy')
ax_acc.plot(val_acc, label='Validation Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set_ylabel('Accuracy')
# Keep the automatically chosen lower bound but pin the top of the axis at 1.
ax_acc.set_ylim([min(ax_acc.get_ylim()), 1])
ax_acc.set_title('Training and Validation Accuracy')

ax_loss.plot(loss, label='Training Loss')
ax_loss.plot(val_loss, label='Validation Loss')
ax_loss.legend(loc='upper right')
ax_loss.set_ylabel('Cross Entropy')
# Fixed loss range; values above 2.0 are clipped out of view.
ax_loss.set_ylim([0, 2.0])
ax_loss.set_title('Training and Validation Loss')
ax_loss.set_xlabel('epoch')

plt.show()

In [None]:
# Walk the test split once, batch by batch, keeping labels, model outputs and
# the input images aligned by position. Labels and images must come from the
# same pass because the dataset was created with shuffle=True.
y_true = []
y_pred = []
images = []
data = test.as_numpy_iterator()

for X, y in data:
    y_true.extend(y)                 # one label vector per sample
    y_pred.extend(model.predict(X))  # one score vector per sample
    images.extend(X)                 # the corresponding input images

In [None]:
# Stack the per-sample label / score vectors collected above into 2-D arrays.
y_true_arr = np.asarray(y_true)
y_pred_arr = np.asarray(y_pred)

# top_k_categorical_accuracy returns one 0/1 hit flag PER SAMPLE, so the flags
# are averaged to obtain a single accuracy figure for the whole test split.
acc_top1 = float(tf.reduce_mean(
    tf.keras.metrics.top_k_categorical_accuracy(y_true_arr, y_pred_arr, k=1)))
# Top-3 must use k=3 (previously k=1, which silently duplicated top-1).
acc_top3 = float(tf.reduce_mean(
    tf.keras.metrics.top_k_categorical_accuracy(y_true_arr, y_pred_arr, k=3)))

print('Top 1 acc', acc_top1)
print('Top 3 acc', acc_top3)

In [None]:
# Reduce label vectors and model scores to class indices before comparing:
# comparing the raw one-hot labels with float scores (or two Python lists)
# does not identify misclassified samples.
true_classes = np.argmax(np.asarray(y_true), axis=1)
pred_classes = np.argmax(np.asarray(y_pred), axis=1)

# Positions of the test samples whose predicted class differs from the label.
misclassified = np.flatnonzero(true_classes != pred_classes)

# Draw up to 16 misclassified samples at random (fewer if there are not 16).
idx = random.sample(list(misclassified), k=min(16, len(misclassified)))

labels = [true_classes[i] for i in idx]
preds = [pred_classes[i] for i in idx]

# Print predicted vs. true class names side by side for the sampled errors.
print('Predictions', ' '.join([train.class_names[pred] for pred in preds]))
print('Labels     ', ' '.join([train.class_names[label] for label in labels]))

In [None]:
# Font file used for the titles, which are class names from the training set.
# NOTE(review): presumably needed so CJK glyphs render — confirm.
chinese_font = mpl.font_manager.FontProperties(fname='fonts/heiti.ttf')

# 4x4 grid showing the first 16 collected test images, each titled with the
# class name the model predicted for it.
fig = plt.figure(figsize=(10, 10))
for i in range(16):
    ax = fig.add_subplot(4, 4, i + 1)
    ax.imshow(images[i].astype("uint8"))
    predicted_name = train.class_names[np.argmax(y_pred[i])]
    ax.set_title(predicted_name, fontproperties=chinese_font, fontsize=20)
    ax.axis("off")