In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import os

import numpy as np
import torch
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import random

from handwriting_recognition.label_converter import LabelConverter
from handwriting_recognition.model.model import HandwritingRecognitionModel
from handwriting_recognition.modelling_utils import get_image_model, get_optimizer
from handwriting_recognition.utils import TrainingConfig, get_dataset_folder_path
from handwriting_recognition.dataset import HandWritingDataset
from pathlib import Path
from handwriting_recognition.modelling_utils import get_device

torch.backends.cudnn.benchmark = True




In [3]:
CONFIG_NAME = "small_config"
MODEL_TO_TEST = "99"
MODEL_FILE = Path(get_dataset_folder_path()).parent.joinpath("model_outputs", CONFIG_NAME, MODEL_TO_TEST)

In [4]:
config_path = Path(get_dataset_folder_path()).parent.joinpath("handwriting_recognition", "configs", CONFIG_NAME).with_suffix(".json")
config = TrainingConfig.from_path(config_path=config_path)

In [5]:
data_test = HandWritingDataset(
    data_path=get_dataset_folder_path() / "pre_processed" / "validation.csv",
    img_size=config.feature_extractor_config.input_size,
)

data_train = HandWritingDataset(
    data_path=get_dataset_folder_path() / "pre_processed" / "train.csv",
    img_size=config.feature_extractor_config.input_size,
)

In [6]:
saved_model = torch.load(MODEL_FILE)

In [7]:
config = TrainingConfig(**saved_model['config'])

In [8]:
image_model = get_image_model(model_name=config.feature_extractor_config.model_name)
model = HandwritingRecognitionModel(image_feature_extractor=image_model, training_config=config)

In [9]:
model.load_state_dict(saved_model["state"])

<All keys matched successfully>

In [10]:
converter = LabelConverter(character_set=saved_model["character_set"], max_text_length=saved_model["max_text_length"])

In [11]:
model = model.to(get_device())

In [12]:
test_loader = DataLoader(
    data_test,
    batch_size=config.batch_size,
    pin_memory=False,
    drop_last=False,
)

train_loader = DataLoader(
    data_train,
    batch_size=config.batch_size,
    pin_memory=False,
    drop_last=False,
)

In [13]:
images, labels = next(iter(train_loader))

In [14]:
images.shape

torch.Size([4, 1, 224, 224])

In [15]:
text, length = converter.encode(labels)
text = text.to(device=get_device())
images = images.to(device=get_device())

In [16]:
converter.decode(text, length)

['[GO]BALTHAZAR[s][GO][GO]',
 '[GO]SIMON[s][GO][GO][GO][GO][GO][GO]',
 '[GO]BENES[s][GO][GO][GO][GO][GO][GO]',
 '[GO]LA LOVE[s][GO][GO][GO][GO]']

In [17]:
preds = model(x=images, y=text, is_train=False)

Image Features Shape: torch.Size([4, 197, 768])
LSTM Features Shape: torch.Size([4, 197, 256])
Batch H Shape: torch.Size([4, 197, 256])
Text Shape: torch.Size([4, 13])
Num Steps: 12
Step 0 Char Onehots Shape: torch.Size([4, 21])
Step 1 Char Onehots Shape: torch.Size([4, 21])
Step 2 Char Onehots Shape: torch.Size([4, 21])
Step 3 Char Onehots Shape: torch.Size([4, 21])
Step 4 Char Onehots Shape: torch.Size([4, 21])
Step 5 Char Onehots Shape: torch.Size([4, 21])
Step 6 Char Onehots Shape: torch.Size([4, 21])
Step 7 Char Onehots Shape: torch.Size([4, 21])
Step 8 Char Onehots Shape: torch.Size([4, 21])
Step 9 Char Onehots Shape: torch.Size([4, 21])
Step 10 Char Onehots Shape: torch.Size([4, 21])
Step 11 Char Onehots Shape: torch.Size([4, 21])
Prediction Shape: torch.Size([4, 12, 21])


In [18]:
preds = preds[:, :text.shape[1] - 1, :]

In [19]:
_, preds_index = preds.max(2)

In [21]:
converter.decode(preds_index, length)

['LASIE[s][s][s][s][s][s][s]',
 'LASIE[s][s][s][s][s][s][s]',
 'LASIE[s][s][s][s][s][s][s]',
 'LASIE[s][s][s][s][s][s][s]']