In [None]:
!pip install albumentations pyctcdecode pyewts botok huggingface_hub natsort

In [None]:
from huggingface_hub import snapshot_download
from zipfile import ZipFile
import os

# Download the dataset snapshot from the Hugging Face Hub and unpack its archive.
dataset_path = snapshot_download(repo_id="BDRC/Karmapa8", repo_type="dataset", cache_dir="Datasets")
# The archive is bound to a name that does not shadow the builtin `zip`: in a notebook the
# `with` target stays in the global namespace and would break `zip(...)` in every later cell.
with ZipFile(f"{dataset_path}/data.zip", 'r') as dataset_archive:
    dataset_archive.extractall(f"{dataset_path}/Dataset")
# From here on `dataset_path` refers to the extracted directory, not the snapshot root.
dataset_path = os.path.join(dataset_path, "Dataset")

# Download the model snapshot; its config.json and weights file are read in a later cell.
model_path = snapshot_download(repo_id="BDRC/BigUCHAN_v1", repo_type="model", cache_dir="Models")


In [None]:
import torch
import numpy as np
from glob import glob
from BudaOCR.Modules import EasterNetwork, OCRTrainer, WylieEncoder
from BudaOCR.Utils import create_dir, shuffle_data, build_data_paths, read_ctc_model_config
import time

# Read the CTC model configuration from the downloaded model directory.
model_config = f"{model_path}/config.json"
ctc_config = read_ctc_model_config(model_config)

# The label encoder is built from the configured charset and supplies the class count.
label_encoder = WylieEncoder(ctc_config.charset)
num_classes = label_encoder.num_classes()
image_width, image_height = ctc_config.input_width, ctc_config.input_height

# Collect the image/label paths of the extracted dataset and pass both through shuffle_data.
image_paths, label_paths = shuffle_data(*build_data_paths(dataset_path))

# Build the network with the configured input size and initialise it from the
# downloaded weights via fine_tune.
network = EasterNetwork(
    num_classes=num_classes,
    image_width=image_width,
    image_height=image_height,
    mean_pooling=True,
)
network.fine_tune(f"{model_path}/BigUCHAN_E_v1.pth")

# Directory that receives everything the trainer and the later cells write.
output_dir = "Output"
create_dir(output_dir)

trainer_options = dict(
    network=network,
    label_encoder=label_encoder,
    workers=4,
    image_width=image_width,
    image_height=image_height,
    batch_size=8,
    output_dir=output_dir,
    preload_labels=True,
)
ocr_trainer = OCRTrainer(**trainer_options)

ocr_trainer.init(image_paths, label_paths)


In [None]:
import shutil
import os
from glob import glob
import time
import torch

def get_latest_checkpoint(output_dir: str):
    """Locate the newest epoch checkpoint in ``output_dir``.

    Files are matched with the pattern ``checkpoint_epoch_*.pth``. The epoch number is
    the text after the last underscore of the file name (extension removed). Files whose
    suffix is not a plain integer (for example ``checkpoint_epoch_best.pth``) are skipped
    instead of raising ``ValueError``.

    Args:
        output_dir: Directory to search for checkpoint files.

    Returns:
        A ``(path, epoch)`` tuple for the checkpoint with the highest epoch number, or
        ``(None, 0)`` when no usable checkpoint exists.
    """
    candidates = []
    for path in glob(os.path.join(output_dir, "checkpoint_epoch_*.pth")):
        # Parse the file name only, so underscores or dots in the directory part
        # of the path can never influence the extracted epoch.
        stem = os.path.splitext(os.path.basename(path))[0]
        epoch_text = stem.rsplit("_", 1)[-1]
        if epoch_text.isdecimal():
            candidates.append((int(epoch_text), path))
    if not candidates:
        return None, 0
    epoch_num, latest_ckpt = max(candidates, key=lambda item: item[0])
    return latest_ckpt, epoch_num

# Prefer a checkpoint that already exists in the local output directory.
latest_ckpt, start_epoch = get_latest_checkpoint(ocr_trainer.output_dir)

if latest_ckpt:
    print(f"üîÅ Resuming from local Output checkpoint: {latest_ckpt} (epoch {start_epoch})")
    network.load_checkpoint(latest_ckpt)
else:
    # No local checkpoint: fall back to the backup directory on Google Drive.
    drive_backup_dir = "/content/drive/MyDrive/ocr_checkpoints"

    # Drive is only mounted by a later cell of this notebook, so on a fresh runtime the
    # backup directory would never be visible at this point and training would silently
    # restart from epoch 0. Mount Drive on demand; outside Colab the import fails and
    # the Drive fallback is simply skipped.
    if not os.path.isdir("/content/drive/MyDrive"):
        try:
            from google.colab import drive as _colab_drive
        except ImportError:
            _colab_drive = None
        if _colab_drive is not None:
            _colab_drive.mount("/content/drive")

    if os.path.exists(drive_backup_dir):
        print("üì¶ Restoring checkpoint from Google Drive...")
        # Copy the backups next to the local outputs, then look for a checkpoint again.
        shutil.copytree(drive_backup_dir, ocr_trainer.output_dir, dirs_exist_ok=True)
        latest_ckpt, start_epoch = get_latest_checkpoint(ocr_trainer.output_dir)
        if latest_ckpt:
            print(f"üîÅ Resuming from Google Drive checkpoint: {latest_ckpt} (epoch {start_epoch})")
            network.load_checkpoint(latest_ckpt)
        else:
            print("üö® No checkpoint found after copying from Drive, starting fresh.")
            start_epoch = 0
    else:
        print("üìÇ No checkpoint found locally or in Google Drive, starting fresh.")
        start_epoch = 0


In [None]:
# –ú–æ–Ω—Ç–∏—Ä—É–µ–º Google Drive
from google.colab import drive
!fusermount -u /content/drive || true
drive.mount('/content/drive', force_remount=True)

# –ü–∞–ø–∫–∞ –¥–ª—è —Ä–µ–∑–µ—Ä–≤–Ω—ã—Ö –∫–æ–ø–∏–π –≤ Drive
drive_ckpt_dir = "/content/drive/MyDrive/ocr_checkpoints"
os.makedirs(drive_ckpt_dir, exist_ok=True)

# –û–±—É—á–µ–Ω–∏–µ –ø–æ —ç–ø–æ—Ö–∞–º —Å –∞–≤—Ç–æ—Å–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ–º –∏ –∫–æ–ø–∏—Ä–æ–≤–∞–Ω–∏–µ–º —á–µ–∫–ø–æ–∏–Ω—Ç–æ–≤ –≤ Drive
total_epochs = 64
for epoch in range(start_epoch, total_epochs):
    print(f"[{time.ctime()}] üß™ Epoch {epoch+1}/{total_epochs}")

    # –û–±—É—á–∞–µ–º –æ–¥–Ω—É —ç–ø–æ—Ö—É
    ocr_trainer.train(epochs=1, check_cer=False, export_onnx=False)

    # –°–æ—Ö—Ä–∞–Ω—è–µ–º —á–µ–∫–ø–æ–∏–Ω—Ç –∫–∞–∂–¥—ã–µ 5 —ç–ø–æ—Ö
    if (epoch + 1) % 5 == 0:
        ckpt_name = f"checkpoint_epoch_{epoch+1}.pth"
        ckpt_path = os.path.join(ocr_trainer.output_dir, ckpt_name)
        drive_ckpt_path = os.path.join(drive_ckpt_dir, ckpt_name)

        torch.save(network.get_checkpoint(), ckpt_path)
        assert os.path.isfile(ckpt_path), f"‚ùå Failed to save checkpoint locally: {ckpt_path}"

        shutil.copy(ckpt_path, drive_ckpt_path)
        print(f"[{time.ctime()}] üíæ Checkpoint saved and copied to Drive: {drive_ckpt_path}")

        with open(os.path.join(ocr_trainer.output_dir, "training_log.txt"), "a") as log:
            log.write(f"[{time.ctime()}] Epoch {epoch+1}: checkpoint saved and synced to Drive.\n")

In [None]:
# Finally, export the model to ONNX into the trainer's output directory.
network.export_onnx(out_dir=ocr_trainer.output_dir, model_name="OCRModel")

In [None]:
# Mount Google Drive at /content/drive, forcing a remount.
from google.colab import drive
!fusermount -u /content/drive || true  # safely unmount first if a mount already exists
drive.mount('/content/drive', force_remount=True)

In [None]:
import shutil

# Location of the checkpoint backups on the mounted Google Drive.
drive_backup_dir = "/content/drive/MyDrive/ocr_checkpoints"

# Copy the backed-up files into the trainer's output directory when a backup exists.
# This cell only copies files; it does not load a checkpoint into the network.
backup_available = os.path.exists(drive_backup_dir)
if not backup_available:
    print("üìÇ No previous checkpoint found in Google Drive.")
else:
    print("üì¶ Restoring checkpoint from Google Drive...")
    shutil.copytree(drive_backup_dir, ocr_trainer.output_dir, dirs_exist_ok=True)


In [None]:
import time
import torch
import os
import shutil

total_epochs = 64
checkpoint_every = 5
drive_ckpt_dir = "/content/drive/MyDrive/ocr_checkpoints"
os.makedirs(drive_ckpt_dir, exist_ok=True)

# NOTE(review): `start_epoch` is not assigned in this cell; it must already exist from the
# checkpoint-resume cell earlier in the notebook, and it is not refreshed here.
for epoch in range(start_epoch, total_epochs):
    completed_epochs = epoch + 1
    print(f"[{time.ctime()}] üß™ Epoch {epoch+1}/{total_epochs}")

    # One epoch of training.
    ocr_trainer.train(epochs=1, check_cer=False, export_onnx=False)

    # Save a checkpoint every `checkpoint_every` epochs and also after the final epoch.
    # total_epochs is not a multiple of 5, so without the second condition the last
    # epochs (61-64) would never be checkpointed or backed up.
    if completed_epochs % checkpoint_every == 0 or completed_epochs == total_epochs:
        ckpt_name = f"checkpoint_epoch_{completed_epochs}.pth"
        ckpt_path = os.path.join(ocr_trainer.output_dir, ckpt_name)
        drive_ckpt_path = os.path.join(drive_ckpt_dir, ckpt_name)

        # Save to local disk and verify that the file exists.
        torch.save(network.get_checkpoint(), ckpt_path)
        assert os.path.isfile(ckpt_path), f"‚ùå Failed to save checkpoint locally: {ckpt_path}"

        # Copy to Google Drive.
        shutil.copy(ckpt_path, drive_ckpt_path)
        print(f"[{time.ctime()}] üíæ Checkpoint saved and copied to Drive: {drive_ckpt_path}")

        # Append a line to the training log.
        with open(os.path.join(ocr_trainer.output_dir, "training_log.txt"), "a") as log:
            log.write(f"[{time.ctime()}] Epoch {epoch+1}: checkpoint saved and synced to Drive.\n")


In [None]:
# Standalone ONNX export - safe to run separately from the training cells.
network.export_onnx(out_dir=ocr_trainer.output_dir, model_name="OCRModel")

In [None]:
# Run the trainer's evaluation; the result is used as a mapping of sample -> CER value.
cer_scores = ocr_trainer.evaluate()
cer_values = list(cer_scores.values())

# One record per line. The escape must be "\n": the doubled backslash ("\\n") wrote a
# literal backslash followed by "n", so every record ended up on a single line.
with open(os.path.join(ocr_trainer.output_dir, "cer_scores.txt"), "w", encoding="utf-8") as f:
    for sample, value in cer_scores.items():
        f.write(f"{sample} - {value}\n")

with open(os.path.join(ocr_trainer.output_dir, "cer_summary.txt"), "w", encoding="utf-8") as f:
    if cer_values:
        f.write(f"Mean CER: {np.mean(cer_values)}\n")
        f.write(f"Max CER: {np.max(cer_values)}\n")
        f.write(f"Min CER: {np.min(cer_values)}\n")
    else:
        # np.max / np.min raise ValueError on an empty sequence; record the situation instead.
        f.write("No CER scores available\n")

# Export ONNX
network.export_onnx(out_dir=ocr_trainer.output_dir, model_name="OCRModel")
