In [1]:
import torch
import yaml
from dataset import SynthTextDataset
from model import CRNN
import numpy as np
from utils import collate_pad, decode, convert2str

# config
# Load the YAML experiment configuration. The `with` block guarantees the file
# handle is closed even if parsing fails; `yaml.full_load` accepts the open
# stream directly, so no intermediate string variable is needed.
CONFIG_PATH = 'config.yaml'
with open(CONFIG_PATH, 'r', encoding='utf-8') as config_file:
    config = yaml.full_load(config_file)

# dict
# Character lexicon: one list entry per item of `config['lexicon']['chars']`
# (one entry per character when that value is a string).
lexicon = list(config['lexicon']['chars'])

# dataset
BATCH_SIZE = 32


def make_loader(mode, shuffle):
    """Build a SynthTextDataset split and wrap it in a DataLoader.

    Args:
        mode: split name forwarded to ``SynthTextDataset`` (e.g. 'train', 'val').
        shuffle: whether the DataLoader reshuffles the split every epoch.

    Returns:
        ``(dataset, dataloader)`` for the requested split. All splits share the
        batch size, the padding collate function and the worker settings.
    """
    dataset = SynthTextDataset(config, mode=mode)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=shuffle,
                                         drop_last=False,
                                         num_workers=0,
                                         collate_fn=collate_pad)
    return dataset, loader


train_dataset, train_dataloader = make_loader('train', shuffle=True)
# Validation order is kept fixed so re-running the notebook yields comparable output.
val_dataset, val_dataloader = make_loader('val', shuffle=False)

# model
# CRNN hyperparameters are read from the `crnn` section of the config; the
# blank-symbol index comes from the `lexicon` section.
crnn_cfg = config['crnn']
imgH, nc, nClass, nh = (crnn_cfg[key] for key in ('imgH', 'nc', 'nClass', 'nh'))
# NOTE(review): blank_index is not used in the visible cells — presumably meant
# for the decoder / CTC loss; confirm.
blank_index = config['lexicon']['blank']

# NOTE(review): no trained weights are loaded here, so the predictions below come
# from randomly initialised parameters — confirm whether a checkpoint should be
# restored with model.load_state_dict(...) before evaluation.
model = CRNN(imgH, nc, nClass, nh)
model.eval()

CRNN(
  (cnn): Sequential(
    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu0): ReLU(inplace=True)
    (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU(inplace=True)
    (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU(inplace=True)
    (pooling2): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.

In [6]:
import matplotlib.pyplot as plt

# Preview raw model outputs and their decoded strings on validation batches.
# - torch.no_grad() skips autograd bookkeeping, which is unnecessary for inference.
# - Calling model(img) instead of model.forward(img) runs any registered module hooks.
# - islice caps how many batches are printed so the output stays readable.
from itertools import islice

MAX_PREVIEW_BATCHES = 1  # number of validation batches to print; None = whole set

with torch.no_grad():
    for img, label, label_encode in islice(val_dataloader, MAX_PREVIEW_BATCHES):
        out = model(img)
        print(out)

        decoded = decode(out.cpu().numpy())
        print(decoded)


tensor([[[-1.7868, -1.7935, -1.7703, -1.8150, -1.7703, -1.8156],
         [-1.7868, -1.7935, -1.7703, -1.8150, -1.7703, -1.8157],
         [-1.7868, -1.7935, -1.7703, -1.8150, -1.7703, -1.8156],
         ...,
         [-1.7868, -1.7935, -1.7703, -1.8150, -1.7703, -1.8157],
         [-1.7868, -1.7935, -1.7703, -1.8150, -1.7703, -1.8157],
         [-1.7868, -1.7935, -1.7703, -1.8150, -1.7703, -1.8156]],

        [[-1.7904, -1.7888, -1.7663, -1.8197, -1.7675, -1.8192],
         [-1.7904, -1.7888, -1.7663, -1.8197, -1.7675, -1.8192],
         [-1.7904, -1.7888, -1.7663, -1.8197, -1.7675, -1.8192],
         ...,
         [-1.7904, -1.7888, -1.7663, -1.8198, -1.7674, -1.8192],
         [-1.7904, -1.7888, -1.7663, -1.8197, -1.7675, -1.8192],
         [-1.7904, -1.7888, -1.7663, -1.8197, -1.7675, -1.8192]],

        [[-1.7916, -1.7861, -1.7646, -1.8229, -1.7657, -1.8213],
         [-1.7916, -1.7861, -1.7646, -1.8229, -1.7657, -1.8213],
         [-1.7916, -1.7861, -1.7646, -1.8229, -1.7657, -1.