In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import os
import string
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from typing import * 


import torch
from torch.utils.data import Dataset, ConcatDataset, Subset, DataLoader

import torchvision.transforms as VT


from ocrnune.data import dataset
from ocrnune.models import crnn

import ocrnune.transforms as NT
from ocrnune.data.dataset import LMDBDataset, BalanceDatasetConcatenator


In [3]:
# enc = crnn.EncoderOCR()
# test_data = torch.rand(3,1,224,224)
# out = enc(test_data)
# print(out.shape)


In [11]:
batch_size = 4
batch_max_length = 25
character = string.printable[:-6]
img_size = (32,100)

trn_transform = VT.Compose([
    NT.ResizeRatioWithRightPad(size=img_size),
    VT.ToTensor(),
    VT.Normalize(mean=(0.5), std=(0.5))  
])

val_transform = VT.Compose([
    NT.ResizeRatioWithRightPad(size=img_size),
    VT.ToTensor(),
    VT.Normalize(mean=(0.5), std=(0.5))  
])


trn_path = '/data/lmdb/data_lmdb_release/training'
val_path = '/data/lmdb/data_lmdb_release/validation'



train_bdc = BalanceDatasetConcatenator(trn_path, dataset_class=LMDBDataset, 
                                       transform=trn_transform,
                                       subdir=('ST', 'MJ'), usage_ratio=(0.5, 0.5),
                                       im_size=img_size, is_sensitive=False)
trainset = train_bdc.get_dataset()


valid_bdc = BalanceDatasetConcatenator(val_path, dataset_class=LMDBDataset, 
                                       transform=val_transform,
                                       im_size=img_size, is_sensitive=False)
validset = valid_bdc.get_dataset()


In [12]:
len(trainset), len(validset)

(7221024, 6992)

In [13]:
from ocrnune.utils import AttnLabelConverter
converter = AttnLabelConverter(character)

train_loader = DataLoader(trainset, batch_size=4)
imgs, texts =  next(iter(train_loader))



In [14]:
texts_encode, length = converter.encode(texts)
print(texts)
print(texts_encode)

('with', 'nancy', 'bob', 'the')
tensor([[ 0, 34, 20, 31, 19,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0, 25, 12, 25, 14, 36,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0, 13, 26, 13,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0, 31, 19, 16,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0]])


In [15]:
converter.decode(texts_encode, length)

['[GO]with[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]',
 '[GO]nancy[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]',
 '[GO]bob[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]',
 '[GO]the[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]']

In [16]:
num_class = len(converter.character)
num_class

96

In [17]:
model = crnn.OCR(num_class=num_class, im_size=img_size)

In [18]:
out = model(imgs, texts_encode[:, :-1])

In [19]:
out[0].shape

torch.Size([26, 96])

In [28]:
import torch.nn.functional as NNF

batch_size = 4 
batch_max_length = 25
# pred_length = batch_size * batch_max_length
pred_length = torch.IntTensor([batch_max_length] * batch_size)
# preds_prob = NNF.softmax(out, dim=2)

preds_prob = NNF.softmax(out, dim=2)
_, preds_index = preds_prob.max(dim=2)
# preds_max_prob[0]

preds_str = converter.decode(preds_index[:, 1:], pred_length)
# preds_index.shape
preds_str

['/////////////////////////',
 '/////////////////////////',
 '/////////////////////////',
 '/////////////////////////']