In [None]:
import os
from statistics import mode

import matplotlib.pyplot as plt
from torch import cuda, device as Device
from torchvision.transforms import Compose

from htr_crnn_ctc.datasource import IAM_CSVDataSource
from htr_crnn_ctc.dataset import CTCDataset
from htr_crnn_ctc.transforms import Rescale, ToRGB, ToTensor, Normalise
from htr_crnn_ctc.dataloader import CTCDataLoader
from htr_crnn_ctc.model import CTCModel
from htr_crnn_ctc.learn import Learner

In [None]:
# Select the compute device: CUDA when torch reports it as available,
# otherwise the CPU. The chosen device is echoed so the run is self-describing.
backend = "cuda" if cuda.is_available() else "cpu"
dev = Device(backend)
print(f"dev: {dev}")

In [None]:
# Data source describing the dataset through its "index.csv" file.
# The root directory is assembled with os.path.join so the separator matches
# the host OS: on Windows this is exactly "tmp\dataset\test", while a literal
# backslash path would not resolve on Linux/macOS.
# NOTE(review): presumably `file` is resolved relative to `root_path` — confirm
# in IAM_CSVDataSource.
csvds = IAM_CSVDataSource(
    file="index.csv",
    root_path=os.path.join("tmp", "dataset", "test"),
    map_columns=None
)

In [None]:
# Image preprocessing steps, applied in list order once wrapped in Compose.
# NOTE(review): random padding / rotation / stretch are enabled although this
# notebook only evaluates a saved model — confirm that is intended.
rescale_step = Rescale(
    output_size=(64, 800),
    random_pad=True,
    border_pad=(10, 40),
    random_rotation=2,
    random_stretch=1.2,
    fill_space=False,
    fill_threshold=200
)
to_rgb_step = ToRGB()
to_tensor_step = ToTensor(rgb=True)
# Per-channel mean/std; these values match the widely used ImageNet statistics.
normalise_step = Normalise(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

trans = [rescale_step, to_rgb_step, to_tensor_step, normalise_step]

In [None]:
# Chain the individual transforms into a single callable, then build the
# dataset on top of the CSV data source. No character dictionary is supplied
# (char_dict=None); the dataset exposes one as `ds.char_dict` afterwards.
pipeline = Compose(trans)
ds = CTCDataset(data_source=csvds, char_dict=None, transform=pipeline)

In [None]:
# Report how many items the dataset holds.
dataset_size = len(ds)
print(f"Number of dataset: {dataset_size}")

In [None]:
# Show the dataset's character dictionary, a blank separator line, and its size.
char_dict = ds.char_dict
print(char_dict)
print()
print(f"Number of characters : {len(char_dict)}")

In [None]:
# Parse the training log into one list per metric, with one entry per
# non-blank log line, then locate the epoch with the lowest validation
# Levenshtein value.

# Read the log inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open("tmp/models/test/train.log", "r", encoding="utf-8") as log_file:
    src_data = log_file.readlines()

# Running count of parsed lines; epoch numbers are this 1-based counter,
# they are not read from the log itself.
sum_epoch = 0

epochs = []
train_loss = []
valid_loss = []
cer = []
ier = []
train_leven = []
val_leven = []

for line in src_data:
    # Skip blank / whitespace-only lines.
    if not line.strip():
        continue

    # NOTE(review): values are taken from fixed space-separated positions,
    # which assumes the exact line layout written by the training run —
    # confirm against the code that produces train.log.
    fields = line.split(" ")
    sum_epoch += 1
    epochs.append(sum_epoch)
    train_loss.append(float(fields[4]))
    valid_loss.append(float(fields[8]))
    cer.append(float(fields[11]))
    ier.append(float(fields[14]))
    train_leven.append(float(fields[17]))
    val_leven.append(float(fields[21]))

# Metric key -> (series, human-readable label used in plot legends).
metrics = {
    "train_loss": (train_loss, "Training Loss"),
    "valid_loss": (valid_loss, "Validation Loss"),
    "cer": (cer, "Character Error Rate"),
    "ier": (ier, "Item Error Rate"),
    "train_leven": (train_leven, "Training Levenshtein"),
    "val_leven": (val_leven, "Validation Levenshtein")
}

# Index of the smallest validation Levenshtein value. The `>=` comparison
# means ties resolve to the latest epoch. Stays 0 when the log is empty.
best_val_index = 0

for idx in range(1, len(val_leven)):
    if val_leven[best_val_index] >= val_leven[idx]:
        best_val_index = idx

In [None]:
# Loss curves: training and validation loss plotted against the epoch number.
plt.figure(figsize=(10, 6))

# Each metrics entry is a (series, legend label) pair.
for series, label in (metrics["train_loss"], metrics["valid_loss"]):
    plt.plot(epochs, series, label=label)

plt.title("Epoch vs Training Loss and Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Error-rate and Levenshtein curves on one shared axis, against the epoch number.
error_metric_keys = ("cer", "ier", "train_leven", "val_leven")

plt.figure(figsize=(10, 6))

# Each metrics entry is a (series, legend label) pair.
for series, label in (metrics[name] for name in error_metric_keys):
    plt.plot(epochs, series, label=label)

plt.title("Epoch vs Character Error Rate, Item Error Rate, Training and Validation Levenshtein")
plt.xlabel("Epoch")
plt.ylabel("Metrics")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Print every metric at the epoch selected by best_val_index.
# Each prefix is padded with spaces so the values line up in one column.
summary_rows = (
    ("Epochs                  ", epochs),
    ("Training Loss           ", train_loss),
    ("Validation Loss         ", valid_loss),
    ("Character Error Rate    ", cer),
    ("Item Error Rate         ", ier),
    ("Training Levenshtein    ", train_leven),
    ("Validation Levenshtein  ", val_leven),
)

for prefix, series in summary_rows:
    print(f"{prefix}{series[best_val_index]}")

In [None]:
# Data loader over the dataset, targeting the selected device.
# The tunable settings are collected in one mapping and unpacked into the call.
loader_options = {
    "train_batch_size": 120,
    "validation_batch_size": 240,
    "validation_split": 0.2,
    "shuffle": True,
    "seed": 42,
}

dl = CTCDataLoader(dataset=ds, **loader_options, device=dev)

# Output size is the number of known characters plus one.
# NOTE(review): the extra slot is presumably the CTC blank label — confirm in CTCModel.
num_outputs = len(ds.char_dict) + 1

# Build the network, then move it onto the selected device.
model = CTCModel(
    chan_in=3,
    time_step=96,
    feature_size=512,
    hidden_size=512,
    output_size=num_outputs,
    num_rnn_layers=4,
    rnn_dropout=0
)
model = model.to(dev)

# Wrap the model and data loader in a Learner. No decode map is passed here;
# load() is asked for it via load_decode=True.
learn = Learner(model=model, dataloader=dl, decode_map=None)

# Saved artefacts of the run to load.
# NOTE(review): the ".pk" file is presumably a pickle — if so, only load it
# from a trusted source; confirm in Learner.load.
weights_file = "tmp/models/test/model.pth"
decode_map_file = "tmp/models/test/decode_map.pk"

learn.load(
    f=weights_file,
    inv_f=decode_map_file,
    load_decode=True,
    keep_LSTM=True,
    freeze_conv=False
)

# Run batch prediction on the "valid" loader with image display enabled.
# NOTE(review): `up_to=20` presumably caps how many items are shown — confirm
# in Learner.batch_predict.
learn.batch_predict(dataloader="valid", show_img=True, up_to=20)