In [None]:
!pip install albumentations pyctcdecode pyewts botok huggingface_hub natsort

In [None]:
from huggingface_hub import snapshot_download
from zipfile import ZipFile
import os

# Download the dataset snapshot from the Hugging Face Hub into the local "Datasets" cache.
dataset_root = snapshot_download(repo_id="BDRC/Karmapa8", repo_type="dataset", cache_dir="Datasets")

# Extract data.zip from the snapshot into a "Dataset" sub-directory. The handle
# is called `archive` (not `zip`) so the built-in zip() is not shadowed for the
# rest of the notebook session.
dataset_path = os.path.join(dataset_root, "Dataset")
with ZipFile(os.path.join(dataset_root, "data.zip"), "r") as archive:
    archive.extractall(dataset_path)

# Download the model snapshot into the local "Models" cache.
model_path = snapshot_download(repo_id="BDRC/BigUCHAN_v1", repo_type="model", cache_dir="Models")


In [None]:
import torch
import numpy as np
from glob import glob
from BudaOCR.Modules import EasterNetwork, OCRTrainer, WylieEncoder
from BudaOCR.Utils import create_dir, shuffle_data, build_data_paths, read_ctc_model_config
import time

# Read the CTC configuration stored next to the downloaded model files.
model_config = f"{model_path}/config.json"
ctc_config = read_ctc_model_config(model_config)

# The encoder is built from the configured charset; its class count and the
# configured input size are passed to the network and the trainer below.
label_encoder = WylieEncoder(ctc_config.charset)
num_classes = label_encoder.num_classes()
image_width, image_height = ctc_config.input_width, ctc_config.input_height

# Build the image/label path lists from the extracted dataset and hand them to shuffle_data.
image_paths, label_paths = shuffle_data(*build_data_paths(dataset_path))

# Create the network and initialise it from the downloaded weights file.
network = EasterNetwork(
    num_classes=num_classes,
    image_width=image_width,
    image_height=image_height,
    mean_pooling=True,
)
network.fine_tune(f"{model_path}/BigUCHAN_E_v1.pth")

# The trainer's output directory; later cells read it via ocr_trainer.output_dir.
output_dir = "Output"
create_dir(output_dir)

ocr_trainer = OCRTrainer(
    network=network,
    label_encoder=label_encoder,
    image_width=image_width,
    image_height=image_height,
    batch_size=32,
    workers=4,
    preload_labels=True,
    output_dir=output_dir,
)
ocr_trainer.init(image_paths, label_paths)


In [None]:
def get_latest_checkpoint(output_dir: str):
    """Find the highest-numbered ``checkpoint_epoch_<N>.pth`` file in ``output_dir``.

    Files that match the glob pattern but whose ``<N>`` part is not a plain
    decimal number (for example ``checkpoint_epoch_best.pth``) are ignored
    instead of raising ``ValueError``.

    Args:
        output_dir: Directory to search (sub-directories are not searched).

    Returns:
        A ``(path, epoch)`` tuple for the checkpoint with the largest epoch
        number, or ``(None, 0)`` when no numbered checkpoint exists.
    """
    prefix, suffix = "checkpoint_epoch_", ".pth"
    latest_ckpt, latest_epoch = None, 0

    for path in glob(os.path.join(output_dir, f"{prefix}*{suffix}")):
        # Parse the epoch from the file name only, so nothing in the
        # directory part of the path can influence the result.
        epoch_text = os.path.basename(path)[len(prefix):-len(suffix)]
        if not epoch_text.isdecimal():
            continue
        epoch_num = int(epoch_text)
        if latest_ckpt is None or epoch_num >= latest_epoch:
            latest_ckpt, latest_epoch = path, epoch_num

    return latest_ckpt, latest_epoch

# Look for a checkpoint that is already present in the trainer's output
# directory; start_epoch is 0 when none is found.
latest_ckpt, start_epoch = get_latest_checkpoint(ocr_trainer.output_dir)

if latest_ckpt:
    print(f"🔁 Resuming from checkpoint: {latest_ckpt} (epoch {start_epoch})")
    network.load_checkpoint(latest_ckpt)
else:
    print("🆕 Starting training from scratch.")


In [None]:
from google.colab import drive
!fusermount -u /content/drive || true  # safely unmount first in case the drive is already mounted
drive.mount('/content/drive', force_remount=True)

In [None]:
import shutil

# Google Drive folder that holds the backed-up contents of the output directory.
drive_backup_dir = "/content/drive/MyDrive/ocr_checkpoints"

if os.path.exists(drive_backup_dir):
    print("📦 Restoring checkpoint from Google Drive...")
    shutil.copytree(drive_backup_dir, ocr_trainer.output_dir, dirs_exist_ok=True)

    # Checkpoint detection ran before Drive was mounted, so it could not see
    # the files restored just now. Re-scan the output directory and, if a newer
    # checkpoint arrived, load it and move the resume point forward.
    restored_ckpt, restored_epoch = get_latest_checkpoint(ocr_trainer.output_dir)
    if restored_ckpt and restored_epoch > start_epoch:
        print(f"🔁 Resuming from restored checkpoint: {restored_ckpt} (epoch {restored_epoch})")
        network.load_checkpoint(restored_ckpt)
        latest_ckpt, start_epoch = restored_ckpt, restored_epoch
else:
    print("📂 No previous checkpoint found in Google Drive.")


In [None]:
# Train one epoch at a time, with automatic checkpointing and synchronization to Google Drive
import time
import torch
import os
import shutil

total_epochs = 64
checkpoint_every = 5  # write and sync a checkpoint every N epochs
drive_backup_dir = "/content/drive/MyDrive/ocr_checkpoints"
log_path = os.path.join(ocr_trainer.output_dir, "training_log.txt")

for epoch in range(start_epoch, total_epochs):
    epoch_number = epoch + 1
    print(f"[{time.ctime()}] 🧪 Epoch {epoch_number}/{total_epochs}")

    # Train a single epoch per call so a checkpoint can be written in between.
    ocr_trainer.train(epochs=1, check_cer=False, export_onnx=False)

    # Checkpoint on the regular interval and always after the last epoch;
    # otherwise the epochs after the last multiple of `checkpoint_every`
    # would be lost and retrained on resume.
    if epoch_number % checkpoint_every == 0 or epoch_number == total_epochs:
        ckpt_path = os.path.join(ocr_trainer.output_dir, f"checkpoint_epoch_{epoch_number}.pth")
        torch.save(network.get_checkpoint(), ckpt_path)
        print(f"[{time.ctime()}] 💾 Saved: {ckpt_path}")

        # Write the log entry before syncing so it is part of the Drive copy.
        with open(log_path, "a", encoding="utf-8") as log:
            log.write(f"[{time.ctime()}] Epoch {epoch_number} checkpoint saved and synced.\n")

        # Synchronize the whole output directory with Google Drive.
        shutil.copytree(ocr_trainer.output_dir, drive_backup_dir, dirs_exist_ok=True)

# Final export of the model to ONNX, followed by one more sync so the exported
# model is also backed up to Google Drive.
network.export_onnx(out_dir=ocr_trainer.output_dir, model_name="OCRModel")
shutil.copytree(ocr_trainer.output_dir, drive_backup_dir, dirs_exist_ok=True)


In [None]:
# Evaluate the trained model. The result is consumed as a mapping from a
# sample identifier to its CER value.
cer_scores = ocr_trainer.evaluate()
cer_values = list(cer_scores.values())

# One "<sample> - <CER>" line per evaluated sample. The line terminator is a
# real newline ("\n"); an escaped backslash would put every record on one line.
with open(os.path.join(ocr_trainer.output_dir, "cer_scores.txt"), "w", encoding="utf-8") as f:
    for sample, value in cer_scores.items():
        f.write(f"{sample} - {value}\n")

# Aggregate statistics over all per-sample CER values.
with open(os.path.join(ocr_trainer.output_dir, "cer_summary.txt"), "w", encoding="utf-8") as f:
    f.write(f"Mean CER: {np.mean(cer_values)}\n")
    f.write(f"Max CER: {np.max(cer_values)}\n")
    f.write(f"Min CER: {np.min(cer_values)}\n")

# Export ONNX
network.export_onnx(out_dir=ocr_trainer.output_dir, model_name="OCRModel")
