In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to the project package are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import pickle
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from handwriting_recognition.dataset import HandWritingDataset
from handwriting_recognition.eval import cer, wer
from handwriting_recognition.label_converter import LabelConverter
from handwriting_recognition.model.model import HandwritingRecognitionModel
from handwriting_recognition.modelling_utils import get_device
from handwriting_recognition.modelling_utils import get_image_model
from handwriting_recognition.train import _evaluate
from handwriting_recognition.utils import TrainingConfig, get_dataset_folder_path

# Turn on cuDNN's benchmark mode, which lets cuDNN time several convolution
# algorithms and pick the fastest one.
# NOTE(review): this mainly pays off when input shapes stay constant — confirm
# it is still wanted for an evaluation-only notebook.
torch.backends.cudnn.benchmark = True

In [None]:
# Run selection: CONFIG_NAME picks the sub-folder under `model_outputs`, and
# MODEL_TO_TEST picks the checkpoint file inside that folder.
CONFIG_NAME = "default_config"
MODEL_TO_TEST = "checkpoint.pt"

# `model_outputs` sits next to the dataset folder, so start from its parent.
_dataset_parent = Path(get_dataset_folder_path()).parent
MODEL_FILE = _dataset_parent / "model_outputs" / CONFIG_NAME / MODEL_TO_TEST

In [None]:
# Display the resolved checkpoint path as a sanity check before loading it.
# (`saved_model` is only bound once the checkpoint has been loaded in a later
# cell, so it must not be referenced here on a fresh kernel.)
MODEL_FILE

In [None]:
# Locate <dataset parent>/handwriting_recognition/configs/<CONFIG_NAME>.json
# and load the training configuration from it.
_configs_dir = Path(get_dataset_folder_path()).parent / "handwriting_recognition" / "configs"
config_path = (_configs_dir / CONFIG_NAME).with_suffix(".json")
config = TrainingConfig.from_path(config_path=config_path)

In [None]:
# Build one dataset per split. All three read their CSV from the
# `pre_processed` folder and share the image size taken from the config.
_pre_processed_dir = get_dataset_folder_path() / "pre_processed"
_input_size = config.feature_extractor_config.input_size

data_train, data_val, data_test = (
    HandWritingDataset(data_path=_pre_processed_dir / split_csv, img_size=_input_size)
    for split_csv in ("train.csv", "validation.csv", "test.csv")
)

In [None]:
# Load the checkpoint onto the CPU so it can be opened even on a machine that
# lacks the device it was saved from; the model itself is moved to the target
# device in a later cell.
# NOTE(review): torch.load unpickles the file — only open trusted checkpoints.
saved_model = torch.load(MODEL_FILE, map_location="cpu")

In [None]:
# Fill in config entries that the checkpoint may not contain (presumably
# because it was written before these options existed — TODO confirm).
# Existing values are never overwritten.
_ckpt_config = saved_model['config']
_ckpt_config.setdefault('scheduler_config', None)

_optim_fallbacks = {'beta1': 0.95, 'beta2': 0.99}
for _option, _fallback in _optim_fallbacks.items():
    _ckpt_config['optim_config'].setdefault(_option, _fallback)

In [None]:
# Rebuild the training config from the dict stored in the checkpoint. This
# rebinds `config`, replacing the instance that was loaded from the JSON file
# earlier in the notebook.
config = TrainingConfig(**saved_model['config'])

In [None]:
# Instantiate the image feature extractor named in the config, then build the
# recognition model around it using the same config.
image_model = get_image_model(model_name=config.feature_extractor_config.model_name)
model = HandwritingRecognitionModel(image_feature_extractor=image_model, training_config=config)

In [None]:
# Copy the parameters stored under the checkpoint's "state" key into the model.
# As the last expression of the cell, the call's return value is displayed.
model.load_state_dict(saved_model["state"])

In [None]:
# Build the label converter from the character set and maximum text length
# that were saved alongside the weights in the checkpoint.
converter = LabelConverter(character_set=saved_model['character_set'], max_text_length=saved_model['max_text_length'])

In [None]:
# Move the model to the device reported by get_device() and switch it to
# evaluation mode in one chained call.
model = model.to(get_device()).eval()

In [None]:
# Number of samples per batch, shared by the train/validation/test loaders.
BATCH_SIZE = 64

In [None]:
# Every loader uses the same settings: fixed batch size, no shuffling (so the
# sample order is stable), no pinned memory, and the final partial batch kept.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=False,
    drop_last=False,
)

train_loader = DataLoader(data_train, **_loader_options)
val_loader = DataLoader(data_val, **_loader_options)
test_loader = DataLoader(data_test, **_loader_options)

In [None]:
# Grab the first batch from the (unshuffled) training loader for a manual
# spot check of the model's predictions.
images, labels = next(iter(train_loader))

In [None]:
# Display the shape of the image batch.
images.shape

In [None]:
# Encode the raw labels with the converter, then place both the encoded text
# and the image batch on the device the model lives on.
text, length = converter.encode(labels)
_target_device = get_device()
text, images = text.to(device=_target_device), images.to(device=_target_device)

In [None]:
# Run the model on the sample batch. The model is fed `text` without its last
# column, and `target` is `text` without its first column, i.e. the same
# sequence shifted by one position.
# Gradients are not needed for inspection, so autograd tracking is disabled
# to avoid building a graph and holding on to activations.
with torch.no_grad():
    preds = model(x=images, y=text[:, :-1], is_train=False)
target = text[:, 1:]

In [None]:
# Take the index of the highest score along the last dimension of `preds`.
predicted_classes = preds.argmax(dim=-1)

In [None]:
# Decode the predicted class indices with the converter and show the first
# three decoded predictions.
pred_decoded = converter.decode(predicted_classes, length)
pred_decoded[:3]

In [None]:
# Decode the shifted target indices the same way and show the first three,
# for a side-by-side comparison with the decoded predictions.
target_decoded = converter.decode(target, length)
target_decoded[:3]

In [None]:
# Score the loaded model on the validation split with the project's own
# evaluation routine (`_evaluate` is imported in the notebook's import cell).
# Class index 0 is excluded from the loss via ignore_index
# (presumably the padding index — TODO confirm against LabelConverter).
loss_function = CrossEntropyLoss(ignore_index=0).to(get_device())

(
    validation_loss,
    val_character_error_rate,
    val_word_error_rate,
    all_val_preds,
    all_val_ground_truths,
) = _evaluate(
    epoch=saved_model['epoch'],
    model=model,
    data_loader=val_loader,
    converter=converter,
    loss_function=loss_function,
)

In [None]:
# Display the character error rate returned for the validation split.
val_character_error_rate

In [None]:
# Display the word error rate returned for the validation split.
val_word_error_rate

In [None]:
# Same evaluation as for validation, this time over the test loader. The
# loss function and the checkpoint's epoch number are reused.
(
    test_loss,
    test_character_error_rate,
    test_word_error_rate,
    all_test_preds,
    all_test_ground_truths,
) = _evaluate(
    epoch=saved_model['epoch'],
    model=model,
    data_loader=test_loader,
    converter=converter,
    loss_function=loss_function,
)

In [None]:
# Display the character error rate returned for the test split.
test_character_error_rate

In [None]:
# Display the word error rate returned for the test split.
test_word_error_rate

In [None]:
# Persist the predictions and ground truths of both evaluated splits as
# pickle files in the current working directory, one file per collection.
# `pickle` is imported in the notebook's import cell.
_eval_outputs = {
    "updated_test_ground_truths.pkl": all_test_ground_truths,
    "updated_test_preds.pkl": all_test_preds,
    "updated_val_ground_truths.pkl": all_val_ground_truths,
    "updated_val_preds.pkl": all_val_preds,
}

for _file_name, _payload in _eval_outputs.items():
    with open(_file_name, 'wb') as f:
        pickle.dump(_payload, f)


In [None]:
# List every test sample whose character error rate is exactly 1, printing
# the prediction, the ground truth and the sample's position.
for sample_idx, (prediction, truth) in enumerate(zip(all_test_preds, all_test_ground_truths)):
    if cer(prediction, truth) != 1:
        continue
    print(prediction, truth, sample_idx)

In [None]:
# Look up the row at integer position 16070 in the test dataset's `df`.
# NOTE(review): 16070 is presumably one of the indices printed by the CER scan
# above — confirm, and consider naming it as a constant.
data_test.df.iloc[16070]

In [None]:
# Show one pre-processed test image. The path is derived from the dataset
# folder instead of a user-specific absolute path, so the cell runs on any
# checkout. `plt` and `Image` are imported in the notebook's import cell.
# NOTE(review): assumes get_dataset_folder_path() points at the project's
# `dataset/` directory, as the other `pre_processed` paths in this notebook do.
_sample_image_path = get_dataset_folder_path() / "pre_processed" / "test" / "TEST_16687.tiff"
img_arr = Image.open(_sample_image_path)
plt.imshow(img_arr)