In [None]:
!pip install albumentations pyctcdecode pyewts botok huggingface_hub natsort

In [None]:
from huggingface_hub import snapshot_download
from zipfile import ZipFile
import os

# Download the dataset snapshot (cached under ./Datasets).
dataset_path = snapshot_download(repo_id="BDRC/Karmapa8", repo_type="dataset", cache_dir="Datasets")

# Unpack the bundled archive into a "Dataset" folder inside the snapshot.
# The handle is deliberately not named `zip`: a top-level name in a notebook
# cell stays in the kernel namespace and would shadow the builtin zip() for
# every cell executed afterwards.
with ZipFile(f"{dataset_path}/data.zip", 'r') as archive:
    archive.extractall(f"{dataset_path}/Dataset")

# Download model
# NOTE(review): `model_path` is reassigned to the BDRC/BigUCHAN_v1 snapshot in
# the setup cell before it is read anywhere — confirm this download is needed.
model_path = snapshot_download(repo_id="BDRC/Woodblock", repo_type="model", cache_dir="Models")


In [None]:
import torch
import numpy as np
from glob import glob
from BudaOCR.Modules import EasterNetwork, OCRTrainer, WylieEncoder
from BudaOCR.Utils import create_dir, shuffle_data, build_data_paths, read_ctc_model_config
import time
import os

# --- Pretrained model + CTC configuration ---
# Fetch the BigUCHAN snapshot and read the config.json shipped with it.
model_path = snapshot_download(repo_id="BDRC/BigUCHAN_v1", repo_type="model", cache_dir="Models")
model_config = f"{model_path}/config.json"
ctc_config = read_ctc_model_config(model_config)

# The encoder is built from the configured charset; its class count and the
# configured input size are what the network and trainer are created with.
label_encoder = WylieEncoder(ctc_config.charset)
num_classes = label_encoder.num_classes()
image_width, image_height = ctc_config.input_width, ctc_config.input_height

# --- Dataset ---
# Collect image/label paths, report how many were found, then shuffle them.
image_paths, label_paths = build_data_paths(dataset_path)
print(f"Images: {len(image_paths)}, Labels: {len(label_paths)}")
image_paths, label_paths = shuffle_data(image_paths, label_paths)

# --- Network ---
# Build the network and start from the downloaded BigUCHAN weights.
network = EasterNetwork(
    num_classes=num_classes,
    image_width=image_width,
    image_height=image_height,
    mean_pooling=True,
)
network.fine_tune(f"{model_path}/BigUCHAN_E_v1.pth")

# --- Trainer ---
output_dir = "Output"
create_dir(output_dir)

trainer_kwargs = dict(
    network=network,
    label_encoder=label_encoder,
    workers=4,
    image_width=image_width,
    image_height=image_height,
    batch_size=32,
    output_dir=output_dir,
    preload_labels=True,
)
ocr_trainer = OCRTrainer(**trainer_kwargs)

ocr_trainer.init(image_paths, label_paths)


In [None]:

# === Resume from disk ===
def get_latest_checkpoint(output_dir: str):
    """Find the newest ``checkpoint_epoch_<N>.pth`` file in ``output_dir``.

    Args:
        output_dir: Directory to search (not searched recursively).

    Returns:
        A ``(path, epoch)`` tuple for the checkpoint with the highest epoch
        number, or ``(None, 0)`` when no such checkpoint exists.
    """
    import re  # local import keeps this cell runnable on its own

    # Parse the epoch from the file name only and require it to be all digits,
    # so stray files such as "checkpoint_epoch_best.pth" are skipped instead
    # of crashing int().
    name_pattern = re.compile(r"checkpoint_epoch_(\d+)\.pth")

    numbered = []
    for path in glob(os.path.join(output_dir, "checkpoint_epoch_*.pth")):
        match = name_pattern.fullmatch(os.path.basename(path))
        if match:
            numbered.append((int(match.group(1)), path))

    if not numbered:
        return None, 0

    # Stable sort on the epoch number; the last entry is the most recent one.
    numbered.sort(key=lambda item: item[0])
    epoch_num, latest_ckpt = numbered[-1]
    return latest_ckpt, epoch_num

# Look up the most recent checkpoint in the trainer's output directory and,
# if one exists, load it into the network. `start_epoch` is 0 otherwise.
latest_ckpt, start_epoch = get_latest_checkpoint(ocr_trainer.output_dir)

if not latest_ckpt:
    print("🆕 Starting training from scratch.")
else:
    print(f"🔁 Resuming from checkpoint: {latest_ckpt} (epoch {start_epoch})")
    network.load_checkpoint(latest_ckpt)


In [None]:
# Training loop with checkpointing and logging to file.
# Iteration starts at `start_epoch`, so a run that found a checkpoint continues
# from that epoch instead of from zero.
total_epochs = 64
checkpoint_every = 5  # number of epochs between periodic checkpoints
log_file = os.path.join(ocr_trainer.output_dir, "training_log.txt")


def _append_to_log(path: str, message: str) -> None:
    """Append one line to the training log.

    The file is opened as UTF-8 explicitly: the log lines contain emoji, which
    the platform default encoding (e.g. cp1252) cannot always represent.
    """
    with open(path, "a", encoding="utf-8") as log:
        log.write(f"{message}\n")


for epoch in range(start_epoch, total_epochs):
    now = time.ctime()
    print(f"[{now}] 🧪 Epoch {epoch+1}/{total_epochs}")
    _append_to_log(log_file, f"[{now}] Epoch {epoch+1}/{total_epochs}")

    loss = network.train(ocr_trainer.train_loader)
    _append_to_log(log_file, f"[{time.ctime()}] ✅ Loss: {loss:.4f}")

    # Save a checkpoint every `checkpoint_every` epochs, and also after the
    # final epoch: total_epochs need not be a multiple of the interval, and
    # without this a later resume would repeat the trailing epochs.
    completed = epoch + 1
    if completed % checkpoint_every == 0 or completed == total_epochs:
        ckpt_file = os.path.join(ocr_trainer.output_dir, f"checkpoint_epoch_{completed}.pth")
        torch.save(network.get_checkpoint(), ckpt_file)
        saved_message = f"[{time.ctime()}] 💾 Saved checkpoint: {ckpt_file}"
        print(saved_message)
        _append_to_log(log_file, saved_message)


In [None]:
# Evaluate the model, write the CER results into the trainer's output
# directory, then export the network to ONNX.
cer_scores = ocr_trainer.evaluate()
cer_values = list(cer_scores.values())

# One "<sample> - <cer>" line per evaluated sample.
scores_file = os.path.join(ocr_trainer.output_dir, "cer_scores.txt")
with open(scores_file, "w", encoding="utf-8") as f:
    f.writelines(f"{sample} - {value}\n" for sample, value in cer_scores.items())

# Aggregate statistics over all samples, written in Mean / Max / Min order.
summary_file = os.path.join(ocr_trainer.output_dir, "cer_summary.txt")
with open(summary_file, "w", encoding="utf-8") as f:
    for label, reducer in (("Mean", np.mean), ("Max", np.max), ("Min", np.min)):
        f.write(f"{label} CER: {reducer(cer_values)}\n")

# Export ONNX
network.export_onnx(out_dir=ocr_trainer.output_dir, model_name="OCRModel")