In [None]:
import os
import random
from zipfile import ZipFile

import numpy as np
import torch
from huggingface_hub import snapshot_download

from OCR.Config import CHARSET
from OCR.Encoder import TibetanWylieEncoder
from OCR.Networks import EasterNetwork
from OCR.Trainer import OCRTrainer
from OCR.Utils import shuffle_data, create_dir, build_data_paths


# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch) so the
# dataset shuffle and network weight initialisation are reproducible.
# NOTE(review): assumes OCR.Utils.shuffle_data uses `random` or `np.random` — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(torch.__version__)
# Release unused cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

#### Train from a single dataset

In [None]:
# Download a dataset snapshot from the Hugging Face Hub (cached under the local
# `Datasets` directory) and unpack its `data.zip` archive next to it.
# See https://huggingface.co/BDRC for more datasets.
data_path = snapshot_download(repo_id="BDRC/Karmapa8", repo_type="dataset",  cache_dir="Datasets")

archive_path = os.path.join(data_path, "data.zip")
extract_dir = os.path.join(data_path, "Dataset")

# Named `zip_file` to avoid shadowing the built-in zip().
with ZipFile(archive_path, "r") as zip_file:
    zip_file.extractall(extract_dir)

In [None]:
# Collect the image/label file paths from the extracted dataset and shuffle them.
# NOTE(review): shuffle_data presumably applies the same permutation to both lists — confirm in OCR.Utils.
dataset_path = os.path.join(data_path, "Dataset")
image_paths, label_paths = build_data_paths(dataset_path)
image_paths, label_paths = shuffle_data(image_paths, label_paths)

print(f"Images: {len(image_paths)}, Labels: {len(label_paths)}")

# Fail fast here, rather than deep inside training, if extraction or pairing went wrong.
assert len(image_paths) > 0, f"No images found under {dataset_path}"
assert len(image_paths) == len(label_paths), "Image and label counts differ"

In [None]:
# Build the label encoder, network and trainer for a single training run.
# The encoder is constructed from CHARSET; its `num_classes` sizes the network output.
wylie_encoder = TibetanWylieEncoder(CHARSET)
label_encoder = wylie_encoder
num_classes = label_encoder.num_classes

# Directory handed to the trainer for its outputs.
output_dir = "Output"
create_dir(output_dir)

# Input geometry shared by the network and the trainer, plus loader settings.
image_width, image_height = 3200, 100
batch_size = 32
workers = 4

network = EasterNetwork(
    num_classes=num_classes,
    image_width=image_width,
    image_height=image_height,
    mean_pooling=True,
)

ocr_trainer = OCRTrainer(
    network=network,
    label_encoder=label_encoder,
    workers=workers,
    image_width=image_width,
    image_height=image_height,
    batch_size=batch_size,
    output_dir=output_dir,
    preload_labels=True,
)

# Hand the (shuffled) dataset paths to the trainer.
ocr_trainer.init(image_paths, label_paths)

In [None]:
# Adjust the number of epochs and the learning-rate scheduler start to your scenario;
# smaller datasets require more epochs for training.
num_epochs = 24
# NOTE(review): presumably the epoch at which the LR scheduler starts — confirm in OCRTrainer.train.
scheduler_start = 20

ocr_trainer.train(
    epochs=num_epochs,
    scheduler_start=scheduler_start,
    check_cer=True,
    export_onnx=True,
    silent=True,
)

#### Evaluate on the test set

In [None]:
# Evaluate the trained model and persist per-sample CER plus a mean/max/min summary
# into the trainer's output directory.
# NOTE(review): evaluate() is assumed to return a mapping {sample: cer} — confirm in OCRTrainer.
cer_scores = ocr_trainer.evaluate()
if not cer_scores:
    # np.max/np.min would otherwise fail with an opaque "zero-size array" error.
    raise ValueError("ocr_trainer.evaluate() returned no CER scores; nothing to summarise.")
cer_values = list(cer_scores.values())

score_file = os.path.join(ocr_trainer.output_dir, "cer_scores.txt")
with open(score_file, "w", encoding="utf-8") as f:
    for sample, value in cer_scores.items():
        f.write(f"{sample} - {value}\n")

mean_cer = np.mean(cer_values)
max_cer = np.max(cer_values)
min_cer = np.min(cer_values)

# Build the summary once so the saved file and the notebook output cannot drift apart.
summary_lines = [
    f"Mean CER: {mean_cer}",
    f"Max CER: {max_cer}",
    f"Min CER: {min_cer}",
]

cer_summary_file = os.path.join(ocr_trainer.output_dir, "cer_summary.txt")
with open(cer_summary_file, "w", encoding="utf-8") as f:
    f.write("\n".join(summary_lines) + "\n")

print("\n".join(summary_lines))