In [None]:
import os
import cv2
import torch
import pyewts
import random
import numpy as np
from tqdm import tqdm
from glob import glob
from evaluate import load
from natsort import natsorted
from Modules import Easter2Inference, TrOCRInference, CRNNInference
from huggingface_hub import snapshot_download
from Utils import get_filename, read_ctc_model_config, read_label, show_image, preprare_ocr_line

In [None]:
# Use the first GPU only when CUDA is actually usable. `is_available` has to be
# called: the bare function object is always truthy, which would select
# 'cuda:0' even on a CPU-only machine.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# CER metric, used by every scoring loop below
cer_scorer = load("cer")
# pyewts converter, used below to turn ground-truth labels into Wylie
converter = pyewts.pyewts()

In [None]:
# Fetch the evaluation dataset snapshot, using "Datasets" as the cache directory
data_path = snapshot_download(repo_id="BDRC/KhyentseWangpo", repo_type="dataset",  cache_dir="Datasets")

# Both lists are naturally sorted so that image i and label i are expected to
# belong together; the scoring loops below pair them purely by position.
lines = natsorted(glob(f"{data_path}/lines/*.jpg"))
labels = natsorted(glob(f"{data_path}/transcriptions/*.txt"))

print(f"Images: {len(lines)}, Labels: {len(labels)}")

# A missing or extra file would silently shift every following image/label
# pair (zip also truncates to the shorter list), so report it up front.
# This is a warning only: the notebook keeps running as before.
line_stems = [os.path.splitext(os.path.basename(p))[0] for p in lines]
label_stems = [os.path.splitext(os.path.basename(p))[0] for p in labels]
if line_stems != label_stems:
    unmatched = set(line_stems) ^ set(label_stems)
    print(
        f"WARNING: images and labels do not pair up one-to-one "
        f"({len(unmatched)} unmatched file names); CER scores may compare the wrong pairs."
    )

In [None]:
# Visual sanity check: load one randomly chosen line image and display it
idx = random.randrange(len(lines))
img = cv2.imread(lines[idx])
show_image(img)

#### Scoring using CRNN Model

In [None]:
# CRNN checkpoint: https://huggingface.co/BDRC/GoogleBooks_C_v1
model_id = "BDRC/GoogleBooks_C_v1"
model_path = snapshot_download(repo_id=model_id, repo_type="model", local_dir=f"Models/{model_id}")
print(model_path)

# config.json from the downloaded snapshot is required to build the inference object
model_config = f"{model_path}/config.json"
assert os.path.isfile(model_config)

# Parse the config and create the CRNN inference wrapper
ocr_config = read_ctc_model_config(model_config)
crnn_inference = CRNNInference(ocr_config)

In [None]:
# CER per line image for the CRNN model, keyed by the value get_filename returns
crnn_scores = {}

for image_path, label_path in tqdm(zip(lines, labels), total=len(lines)):
    image_n = get_filename(image_path)
    image = cv2.imread(image_path)

    # Ground truth is converted to Wylie before it is compared with the prediction
    gt_lbl = read_label(label_path)
    gt_lbl = converter.toWylie(gt_lbl)
    prediction = crnn_inference.predict(image)

    # Pairs with an empty prediction or an empty label are not scored
    if prediction == "" or gt_lbl == "":
        continue

    try:
        cer_score = cer_scorer.compute(predictions=[prediction], references=[gt_lbl])
        crnn_scores[image_n] = cer_score
    # Exception (not BaseException) so that Ctrl-C / kernel interrupt still
    # stops the loop instead of being logged as a failed CER calculation.
    except Exception as e:
        print(f"Failed to calculate CER for prediction: {prediction} against label: {gt_lbl}, raised exception: {e}")


cer_values = list(crnn_scores.values())
if cer_values:
    mean_cer = np.mean(cer_values)
    max_cer = np.max(cer_values)
    min_cer = np.min(cer_values)
    print(f"Mean CER: {mean_cer}, Max CER: {max_cer}, Min CER: {min_cer}")
else:
    # np.max / np.min raise ValueError on an empty sequence
    print("No CER scores were computed, nothing to summarize.")

#### Scoring using Easter2 Model

In [None]:
# Easter2 checkpoint: https://huggingface.co/BDRC/GoogleBooks_E_v1
model_id = "BDRC/GoogleBooks_E_v1"
model_path = snapshot_download(repo_id=model_id, repo_type="model", local_dir=f"Models/{model_id}")
print(model_path)

# config.json from the downloaded snapshot is required to build the inference object
model_config = f"{model_path}/config.json"
assert os.path.isfile(model_config)

# Parse the config and create the Easter2 inference wrapper
ocr_config = read_ctc_model_config(model_config)
easter2_inference = Easter2Inference(ocr_config)

In [None]:
# CER per line image for the Easter2 model, keyed by the value get_filename returns
easter_cer_scores = {}

for image_path, label_path in tqdm(zip(lines, labels), total=len(lines)):
    image_n = get_filename(image_path)
    image = cv2.imread(image_path)

    # Ground truth is converted to Wylie before it is compared with the prediction
    gt_lbl = read_label(label_path)
    gt_lbl = converter.toWylie(gt_lbl)
    prediction = easter2_inference.predict(image)

    # Pairs with an empty prediction or an empty label are not scored
    if prediction == "" or gt_lbl == "":
        continue

    try:
        cer_score = cer_scorer.compute(predictions=[prediction], references=[gt_lbl])
        easter_cer_scores[image_n] = cer_score
    # Exception (not BaseException) so that Ctrl-C / kernel interrupt still
    # stops the loop instead of being logged as a failed CER calculation.
    except Exception as e:
        print(f"Failed to calculate CER for prediction: {prediction} against label: {gt_lbl}, raised exception: {e}")


cer_values = list(easter_cer_scores.values())
if cer_values:
    mean_cer = np.mean(cer_values)
    max_cer = np.max(cer_values)
    min_cer = np.min(cer_values)
    print(f"Mean CER: {mean_cer}, Max CER: {max_cer}, Min CER: {min_cer}")
else:
    # np.max / np.min raise ValueError on an empty sequence
    print("No CER scores were computed, nothing to summarize.")

#### Scoring using TrOCR

In [None]:
# TrOCR checkpoint: https://huggingface.co/BDRC/GoogleBooks_T_v1
model_id = "BDRC/GoogleBooks_T_v1"
checkpoint = snapshot_download(repo_id=model_id, repo_type="model", local_dir=f"Models/{model_id}")

# The TrOCR wrapper is built directly from the downloaded snapshot path
trocr_inference = TrOCRInference(checkpoint)

In [None]:
# CER per line image for the TrOCR model, keyed by the value get_filename returns
trocr_scores = {}

for image_path, label_path in tqdm(zip(lines, labels), total=len(lines)):
    image_n = get_filename(image_path)
    image = cv2.imread(image_path)

    # NOTE(review): unlike the CRNN and Easter2 loops, the label is NOT converted
    # to Wylie here — presumably the TrOCR model predicts in the same script as
    # the raw transcription; confirm against the model's output.
    gt_lbl = read_label(label_path)
    prediction = trocr_inference.predict(image)

    # Pairs with an empty prediction or an empty label are not scored
    if prediction == "" or gt_lbl == "":
        continue

    try:
        cer_score = cer_scorer.compute(predictions=[prediction], references=[gt_lbl])
        trocr_scores[image_n] = cer_score
    # Exception (not BaseException) so that Ctrl-C / kernel interrupt still
    # stops the loop instead of being logged as a failed CER calculation.
    except Exception as e:
        print(f"Failed to calculate CER for prediction: {prediction} against label: {gt_lbl}, raised exception: {e}")


cer_values = list(trocr_scores.values())
if cer_values:
    mean_cer = np.mean(cer_values)
    max_cer = np.max(cer_values)
    min_cer = np.min(cer_values)
    print(f"Mean CER: {mean_cer}, Max CER: {max_cer}, Min CER: {min_cer}")
else:
    # np.max / np.min raise ValueError on an empty sequence
    print("No CER scores were computed, nothing to summarize.")