In [None]:
import os
import torch
import logging
import numpy as np
from zipfile import ZipFile
from BudaOCR.Config import CHARSET
from huggingface_hub import snapshot_download
from BudaOCR.Modules import CRNNNetwork, OCRTrainer, WylieEncoder, StackEncoder
from BudaOCR.Utils import shuffle_data, create_dir, build_data_paths, build_distribution_from_file, read_stack_file
# Runtime setup: surface INFO-level messages on the root logger, release any
# cached CUDA memory held by this kernel, and report the PyTorch version.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

torch.cuda.empty_cache()
print(torch.__version__)

In [None]:
# Label encoders. The stack inventory is read from a file relative to the
# current working directory; the Wylie encoder is built from the CHARSET config.
# NOTE(review): only `wylie_encoder` is passed to the trainers in this notebook
# section; `stack_encoder` appears unused here — confirm whether it is needed.
stack_file = "tib-stacks.txt"
stacks = read_stack_file(stack_file)
stack_encoder = StackEncoder(stacks)
wylie_encoder = WylieEncoder(CHARSET)

In [None]:
# Fetch the Karmapa8 dataset snapshot from the Hugging Face Hub (cached under
# ./Datasets) and unpack its data.zip into a "Dataset" folder next to it.
data_path = snapshot_download(repo_id="BDRC/Karmapa8", repo_type="dataset", cache_dir="Datasets")
dataset_path = os.path.join(data_path, "Dataset")

# Named `archive` rather than `zip`: a notebook-level `zip` would shadow the
# builtin zip() in every later cell of this kernel session.
with ZipFile(os.path.join(data_path, "data.zip"), "r") as archive:
    archive.extractall(dataset_path)

image_paths, label_paths = build_data_paths(dataset_path)
image_paths, label_paths = shuffle_data(image_paths, label_paths)

print(f"Images: {len(image_paths)}, Labels: {len(label_paths)}")
# The two lists are passed together to the trainer (presumably paired by index),
# so a count mismatch would silently misalign training samples.
assert len(image_paths) == len(label_paths), "image and label counts differ"

In [None]:
# Where the trainer writes its artifacts (the cell below exports ONNX here).
output_dir = "Output"
create_dir(output_dir)

# Input geometry shared by the network and the trainer, plus loader settings.
image_width, image_height = 3200, 100
num_classes = wylie_encoder.num_classes()
workers = 4

network = CRNNNetwork(
    image_width=image_width,
    image_height=image_height,
    num_classes=num_classes,
)

ocr_trainer = OCRTrainer(
    network=network,
    label_encoder=wylie_encoder,
    workers=workers,
    image_width=image_width,
    image_height=image_height,
    batch_size=12,
    output_dir=output_dir,
    preload_labels=True,
)

# init() receives the shuffled path lists from the data cell; it presumably
# derives its own train/val/test split from them — confirm in OCRTrainer.
ocr_trainer.init(image_paths, label_paths)
ocr_trainer.train(epochs=48, check_cer=True, export_onnx=True)

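#### Train from a fixed distribution

Alternative to the run above: instead of the shuffled path lists, reuse the sample distribution shipped in the dataset's `data.distribution` file.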

In [None]:
# Build the trainer from the distribution file bundled with the dataset rather
# than from the shuffled path lists. NOTE(review): the file presumably assigns
# samples to train/val/test — confirm against build_distribution_from_file.
distr_file = os.path.join(data_path, "data.distribution")
distribution = build_distribution_from_file(distr_file, dataset_path)

# Defined here as well so this section runs without first executing the
# 48-epoch random-split training cell, which is the only other place that
# sets output_dir. exist_ok makes re-running this cell safe.
output_dir = "Output"
os.makedirs(output_dir, exist_ok=True)

image_width = 3200
image_height = 100
num_classes = wylie_encoder.num_classes()
workers = 4

network = CRNNNetwork(num_classes=num_classes, image_width=image_width, image_height=image_height)

ocr_trainer = OCRTrainer(
    network=network,
    label_encoder=wylie_encoder,
    workers=workers,
    image_width=image_width,
    image_height=image_height,
    batch_size=16,
    output_dir=output_dir,
    preload_labels=True,
)

ocr_trainer.init_from_distribution(distribution)

In [None]:
# Long run on the fixed distribution. `scheduler_start` is presumably the epoch
# from which OCRTrainer begins stepping its LR scheduler — confirm in train().
num_epochs, scheduler_start = 80, 62
ocr_trainer.train(
    epochs=num_epochs,
    scheduler_start=scheduler_start,
    check_cer=True,
    export_onnx=True,
    silent=True,
)

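#### Evaluate on the test set

Writes per-sample character error rates (CER) and a mean/max/min summary to the trainer's output directory.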

In [None]:
# evaluate() returns a mapping of sample -> CER value (iterated via .items()
# below); presumably computed on the trainer's test split.
cer_scores = ocr_trainer.evaluate()
cer_values = list(cer_scores.values())

# Per-sample scores, one "<sample> - <cer>" line each.
score_file = os.path.join(ocr_trainer.output_dir, "cer_scores.txt")
with open(score_file, "w", encoding="utf-8") as f:
    for sample, value in cer_scores.items():
        f.write(f"{sample} - {value}\n")

# np.max/np.min raise ValueError on an empty sequence (and np.mean returns nan
# with a RuntimeWarning), so handle "no scores" explicitly instead of crashing
# after the per-sample file has already been written.
if cer_values:
    mean_cer = np.mean(cer_values)
    max_cer = np.max(cer_values)
    min_cer = np.min(cer_values)
else:
    logging.warning("Evaluation returned no CER scores; summary statistics are NaN.")
    mean_cer = max_cer = min_cer = float("nan")

cer_summary_file = os.path.join(ocr_trainer.output_dir, "cer_summary.txt")
with open(cer_summary_file, "w", encoding="utf-8") as f:
    f.write(f"Mean CER: {mean_cer}\n")
    f.write(f"Max CER: {max_cer}\n")
    # Trailing newline so the file ends like cer_scores.txt does.
    f.write(f"Min CER: {min_cer}\n")

print(f"Mean CER: {mean_cer}, Max CER: {max_cer}, Min CER: {min_cer}")