In [None]:
import os
import numpy as np
from tqdm import tqdm
from evaluate import load
from torch.utils.data import DataLoader
from BudaOCR.Config import CHARSET, STACK_FILE
from BudaOCR.Modules import CTCDataset, ModelTester, EasterNetwork, CRNNNetwork, WylieEncoder, StackEncoder
from BudaOCR.Utils import build_distribution_paths, get_filename, read_distribution, read_ocr_model_config

In [2]:
# Set the NO_ALBUMENTATIONS_UPDATE flag for this process (presumably the
# albumentations switch that skips its online version check -- confirm).
# NOTE(review): this runs after the BudaOCR imports in the first cell; if
# albumentations is imported there, the flag may be set too late -- confirm.
os.environ.update({"NO_ALBUMENTATIONS_UPDATE": "1"})

In [4]:
# Dataset root and the training-run output directory to evaluate.
# NOTE(review): absolute, machine-specific path -- adjust before running on
# another machine.
dataset_dir = "E:/Datasets/OCR/DbuMed/NEW/Drutsa-Complete/batch31"
chkpt_dir = f"{dataset_dir}/Output/2024_9_23_16_12"

# Files inside the run directory that the following cells read.
config_file = f"{chkpt_dir}/model_config.json"
distribution_file = f"{chkpt_dir}/data.distribution"

# Fail early and name the missing file, instead of a bare AssertionError.
assert os.path.isfile(config_file), f"Model config not found: {config_file}"
assert os.path.isfile(distribution_file), f"Data distribution file not found: {distribution_file}"

In [7]:
# Unpack the model settings stored with the training run.
model_config = read_ocr_model_config(config_file)
checkpoint, architecture, encoder, input_width, input_height, charset = model_config

# Read the train/valid/test split written to the distribution file and build
# the image and label paths for the test portion.
train_samples, valid_samples, test_samples = read_distribution(distribution_file)
test_images, test_label_paths = build_distribution_paths(dataset_dir, test_samples)

# Report the size of each split.
for split_name, split_samples in (
    ("Train", train_samples),
    ("Valid", valid_samples),
    ("Test", test_samples),
):
    print(f"{split_name} samples: {len(split_samples)}")

Train samples: 2682
Valid samples: 335
Test samples: 336


In [None]:
# Choose the label encoder named in the model config; any value other than
# "wylie" falls back to the stack encoder.
encoder_cls = WylieEncoder if encoder == "wylie" else StackEncoder
label_encoder = encoder_cls(charset)

# Read every test label through the encoder and ask it for the class count.
test_labels = [label_encoder.read_label(label_path) for label_path in test_label_paths]
num_classes = label_encoder.num_classes()

# Build the network for the configured architecture; any value other than
# "Easter2" falls back to the CRNN network.
network_cls = EasterNetwork if architecture == "Easter2" else CRNNNetwork
network = network_cls(input_width, input_height, num_classes)

test_dataset = CTCDataset(
    images=test_images,
    labels=test_labels,
    label_encoder=label_encoder,
    img_height=input_height,
    img_width=input_width,
)

# NOTE(review): `checkpoint` from the model config is not used in any visible
# cell -- confirm the trained weights are loaded somewhere before evaluating.
model_tester = ModelTester(network, label_encoder)

In [None]:
# Display the network's `device` attribute as the cell output.
network.device

In [None]:
# Evaluate the test set. The result is used as a mapping below
# (presumably sample name -> CER value; confirm against ModelTester.evaluate).
cer_scores = model_tester.evaluate(test_dataset, test_label_paths)
cer_values = list(cer_scores.values())

# Summary statistics over all per-sample scores.
print(f"Mean CER: {np.mean(cer_values)}")
print(f"Max CER: {np.max(cer_values)}")
print(f"Min CER: {np.min(cer_values)}")

# Write the per-sample scores into the run directory, one
# "<sample> - <value>" entry per line.
score_file = os.path.join(chkpt_dir, "cer_scores.txt")

with open(score_file, "w", encoding="utf-8") as f:
    for sample, value in cer_scores.items():
        # Terminate each entry with a real newline; the previous "/n" was
        # written literally and left every score on a single line.
        f.write(f"{sample} - {value}\n")