In [None]:
import gc
import logging
import os
import random

import matplotlib.pyplot as plt
import torch
from huggingface_hub import snapshot_download

from BudaOCR.Config import N_CHARSET
from BudaOCR.Modules import EasterNetwork, OCRTrainer, WylieEncoder, StackEncoder
from BudaOCR.Utils import shuffle_data, create_dir, read_stack_file, build_data_paths, build_distribution_from_directory

# Notebook-wide setup: INFO-level logging, a fixed random seed, and an environment report.
SEED = 42

logging.getLogger().setLevel(logging.INFO)

# Seed Python's and PyTorch's RNGs so that network weight initialisation (torch) and,
# presumably, the shuffle in `shuffle_data` repeat across Restart & Run All.
# NOTE(review): if `shuffle_data` draws from NumPy instead, seed `np.random` as well.
random.seed(SEED)
torch.manual_seed(SEED)

print(torch.__version__)
torch.cuda.empty_cache()

In [None]:
# Build both label encoders so either training section below can choose one:
# a Wylie encoder constructed from the character set, and a stack encoder
# constructed from the stacks listed in the stack file.
charset = N_CHARSET
wylie_encoder = WylieEncoder(charset)

# Plain string (no interpolation needed); resolved relative to the kernel's working directory.
stack_file = "tib-stacks.txt"
stacks = read_stack_file(stack_file)
stack_encoder = StackEncoder(stacks)

# Report the size of each encoder's class set.
print(stack_encoder.num_classes())
print(wylie_encoder.num_classes())

#### Train from a single dataset

In [None]:
# Fetch the Karmapa8 dataset snapshot from the Hugging Face Hub (cached under ./Datasets)
# and collect the image/label file paths from its `Dataset` sub-folder.
data_path = snapshot_download(repo_id="BDRC/Karmapa8", repo_type="dataset", cache_dir="Datasets")
dataset_path = os.path.join(data_path, "Dataset")
image_paths, label_paths = build_data_paths(dataset_path)

# Images and labels are passed around as parallel lists (presumably index-aligned pairs),
# so check the invariant before shuffling and training on them.
assert image_paths, f"No images found in {dataset_path}"
assert len(image_paths) == len(label_paths), (
    f"Found {len(image_paths)} images but {len(label_paths)} labels in {dataset_path}"
)
image_paths, label_paths = shuffle_data(image_paths, label_paths)

print(f"Images: {len(image_paths)}, Labels: {len(label_paths)}")

In [None]:
# Train an Easter network on the single dataset, using the stack encoder's class set.
# NOTE(review): `dataset_path` lies inside the folder returned by `snapshot_download`
# (i.e. inside the Hugging Face cache), so the "Output" directory is created in the cache;
# consider writing training artifacts somewhere outside it.
output_dir = os.path.join(dataset_path, "Output")
create_dir(output_dir)

# Input line-image geometry and the number of output classes for the network.
image_width, image_height = 3200, 100
num_classes = stack_encoder.num_classes()

network = EasterNetwork(
    num_classes=num_classes,
    image_width=image_width,
    image_height=image_height,
    mean_pooling=True,
)
workers = 4

ocr_trainer = OCRTrainer(
    network=network,
    label_encoder=stack_encoder,
    workers=workers,
    image_width=image_width,
    image_height=image_height,
    batch_size=32,
    output_dir=output_dir,
    preload_labels=True,
)

ocr_trainer.init(image_paths, label_paths)
ocr_trainer.train(epochs=40, check_cer=True, export_onnx=True)

#### Train from a distribution assembled from multiple directories

In [None]:
# Root directory whose contents are assembled into a single training distribution.
dataset_path = "Datasets/Drutsa-Complete"
# Fail early with an explicit path instead of letting a missing (relative) directory
# surface later as an obscure error or an empty distribution.
if not os.path.isdir(dataset_path):
    raise FileNotFoundError(f"Dataset directory not found: {os.path.abspath(dataset_path)}")
distribution = build_distribution_from_directory(dataset_path)

In [None]:
# Drop the trainer/network left over from an earlier training run (if that section was
# executed in this kernel) so they can be garbage-collected and any cached CUDA blocks
# released before a new network is built; otherwise they stay referenced until
# `ocr_trainer` is reassigned below.
for _stale_name in ("ocr_trainer", "network"):
    globals().pop(_stale_name, None)
gc.collect()
torch.cuda.empty_cache()

output_dir = os.path.join(dataset_path, "Output")
create_dir(output_dir)

# Input line-image geometry; the class count comes from the Wylie encoder here.
image_width = 3200
image_height = 100
num_classes = wylie_encoder.num_classes()

network = EasterNetwork(num_classes=num_classes, image_width=image_width, image_height=image_height, mean_pooling=True)
workers = 4

ocr_trainer = OCRTrainer(
    network=network,
    label_encoder=wylie_encoder,
    workers=workers,
    image_width=image_width,
    image_height=image_height,
    batch_size=32,
    output_dir=output_dir,
    preload_labels=True,
)

ocr_trainer.init_from_distribution(distribution)
# NOTE(review): the meaning of `scheduler_start` and `patience` is defined by OCRTrainer.train
# (presumably LR scheduling from epoch 60 and early stopping after 10 epochs) — confirm.
ocr_trainer.train(epochs=80, scheduler_start=60, patience=10, check_cer=True, export_onnx=True)