In [None]:
import os
import torch
import logging
from BudaOCR.Config import CHARSET
from BudaOCR.Modules import EasterNetwork, OCRTrainer, WylieEncoder
from BudaOCR.Utils import shuffle_data, create_dir, build_data_paths

# Let INFO-level (and higher) records through the root logger for this session.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

# Record the installed PyTorch version in the cell output.
print(torch.__version__)

# Release unoccupied memory held by PyTorch's CUDA caching allocator
# (mainly relevant when this cell is re-run in a kernel that already used the GPU).
torch.cuda.empty_cache()

In [2]:
# Root directory of the dataset. Set the BUDAOCR_DATASET_PATH environment variable
# to point the notebook at another location without editing this cell; when it is
# unset, the original hard-coded location is used.
dataset_path = os.environ.get(
    "BUDAOCR_DATASET_PATH",
    "E:/Datasets/OCR/Glomanthang/Glomanthang/Stuff/Dataset_January2024_merged_bw",
)

# Collect the image and label path lists, then shuffle them.
# NOTE(review): shuffle_data presumably keeps image/label pairs aligned — confirm.
image_paths, label_paths = build_data_paths(dataset_path)
image_paths, label_paths = shuffle_data(image_paths, label_paths)

print(f"Images: {len(image_paths)}, Labels: {len(label_paths)}")

# The two lists are handed to the trainer together, so they are expected to be the
# same length; fail here with a clear message instead of somewhere inside training.
assert len(image_paths) == len(label_paths), (
    f"Dataset mismatch in {dataset_path}: "
    f"{len(image_paths)} images vs {len(label_paths)} labels"
)

Images: 4147, Labels: 4147


In [None]:
# Character set from the project config and the label encoder built from it.
charset = CHARSET
label_encoder = WylieEncoder(charset)

# "Output" folder inside the dataset directory.
# NOTE(review): create_dir is presumably what makes sure the folder exists — confirm.
output_dir = os.path.join(dataset_path, "Output")
create_dir(output_dir)

# Image dimensions passed to the network, and the class count reported by the encoder.
image_width, image_height = 3200, 80
num_classes = label_encoder.num_classes()

# Fine Tune an Easter2 Network
network = EasterNetwork(
    num_classes=num_classes,
    image_width=image_width,
    image_height=image_height,
    mean_pooling=True,
)
workers = 4

# Relative path: presumably resolved against the notebook's working directory — confirm.
checkpoint_path = "Checkpoints/2024_8_26_19_3/OCRModel_4.pth"
network.fine_tune(checkpoint_path)

In [None]:
# Training hyper-parameters for this run.
batch_size = 32
epochs = 24

# Build the trainer from the network, encoder and settings defined in the cells above.
ocr_trainer = OCRTrainer(
    network=network,
    label_encoder=label_encoder,
    image_width=image_width,
    image_height=image_height,
    batch_size=batch_size,
    workers=workers,
    output_dir=output_dir,
    preload_labels=True,
)

# Hand the shuffled path lists to the trainer, then start training.
# NOTE(review): check_cer / export_onnx presumably enable character-error-rate
# evaluation and an ONNX export of the model — confirm against OCRTrainer.train.
ocr_trainer.init(image_paths, label_paths)
ocr_trainer.train(epochs=epochs, check_cer=True, export_onnx=True)