In [52]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [53]:
import argparse
import os

import numpy as np
import torch
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import random

from handwriting_recognition.label_converter import LabelConverter
from handwriting_recognition.model.model import HandwritingRecognitionModel
from handwriting_recognition.modelling_utils import get_image_model, get_optimizer
from handwriting_recognition.utils import TrainingConfig, get_dataset_folder_path
from handwriting_recognition.dataset import HandWritingDataset
from pathlib import Path
from handwriting_recognition.modelling_utils import get_device

torch.backends.cudnn.benchmark = True

In [54]:
CONFIG_NAME = "default_config"
MODEL_TO_TEST = "1"
MODEL_FILE = Path(get_dataset_folder_path()).parent.joinpath("model_outputs", CONFIG_NAME, MODEL_TO_TEST)

In [55]:
config_path = Path(get_dataset_folder_path()).parent.joinpath("handwriting_recognition", "configs", CONFIG_NAME).with_suffix(".json")
config = TrainingConfig.from_path(config_path=config_path)

In [56]:
data_test = HandWritingDataset(
    data_path=get_dataset_folder_path() / "pre_processed" / "validation.csv",
    img_size=config.feature_extractor_config.input_size,
)

In [57]:
saved_model = torch.load(MODEL_FILE)

In [58]:
config = TrainingConfig(**saved_model['config'])

In [59]:
image_model = get_image_model(model_name=config.feature_extractor_config.model_name)
model = HandwritingRecognitionModel(image_feature_extractor=image_model, training_config=config)

In [60]:
model.load_state_dict(saved_model["state"])

<All keys matched successfully>

In [61]:
converter = LabelConverter(character_set=saved_model["character_set"], max_text_length=saved_model["max_text_length"])

In [62]:
model = model.to(get_device())

In [63]:
test_loader = DataLoader(
    data_test,
    batch_size=config.batch_size,
    pin_memory=False,
    drop_last=False,
)

In [64]:
images, labels = next(iter(test_loader))

In [65]:
images.shape

torch.Size([3, 1, 224, 224])

In [66]:
text, length = converter.encode(labels)
text = text.to(device=get_device())
images = images.to(device=get_device())

In [67]:
converter.decode(text, length)

['[GO]BILEL[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]',
 '[GO]LAUMIONIER[s][GO][GO][GO][GO][GO][GO][GO][GO][GO]',
 '[GO]JEAN ROCH[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]']

In [68]:
preds = model(x=images, y=text, is_train=False)

In [69]:
preds = preds[:, :text.shape[1] - 1, :]

In [71]:
_, preds_index = preds.max(2)

In [72]:
converter.decode(preds_index, length)

['U[s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s]',
 'U[s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s]',
 'U[s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s]']