In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [48]:
import os
import string
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from typing import * 


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, ConcatDataset, Subset, DataLoader

import torchvision.transforms as VT


from iqradre.recog.data import dataset
from iqradre.recog.models import crnn_v1 as crnn

import iqradre.recog.transforms as NT
from iqradre.recog.data.dataset import LMDBDataset, BalanceDatasetConcatenator
from iqradre.recog.data import loader
from iqradre.recog.utils import AttnLabelConverter
from iqradre.recog.trainer.task import TaskOCR

import torchvision.transforms as VT
from iqradre.recog import transforms as NT

In [3]:
# --- Data loading -----------------------------------------------------------
BATCH_SIZE, NUM_WORKERS = 4, 4
SHUFFLE = True
USAGE_RATIO = (0.5, 0.5)

# --- Text / image settings --------------------------------------------------
# string.printable with its last six characters (the whitespace group) removed.
CHARACTER = string.printable[:-6]
SENSITIVE = True
BATCH_MAX_LENGTH = 25
IMG_SIZE = (32, 100)

# --- Optimisation hyperparameters -------------------------------------------
# Not referenced by the cells visible in this notebook section.
BETA1, BETA2 = 0.9, 0.999
LRATE = 1.0
GRAD_CLIP = 5.0

In [4]:
# Dataset locations for the training and validation splits.
TRAINSET_PATH = '/data/lmdb/data_lmdb_release/training'
VALIDSET_PATH = '/data/lmdb/data_lmdb_release/validation'

# Keyword arguments that both loader factories receive unchanged.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    img_size=IMG_SIZE,
    is_sensitive=SENSITIVE,
    character=CHARACTER,
)

# Each factory returns a (loader, dataset) pair.
trainloader, trainset = loader.train_loader(
    TRAINSET_PATH,
    shuffle=SHUFFLE,
    usage_ratio=USAGE_RATIO,
    **_loader_kwargs,
)

validloader, validset = loader.valid_loader(
    VALIDSET_PATH,
    shuffle=True,
    **_loader_kwargs,
)


In [5]:
from iqradre.recog.utils import AttnLabelConverter
# Same expression as in the configuration cell: string.printable without its
# last six characters.
CHARACTER = string.printable[:-6]
# Converter used further down to decode predicted index tensors into strings.
converter = AttnLabelConverter(CHARACTER)
# The class count comes from the converter's own character table rather than
# from len(CHARACTER); presumably the converter adds special tokens such as
# "[s]" -- TODO confirm against AttnLabelConverter.
NUM_CLASS = len(converter.character)

In [49]:
class TextPredictor(object):
    """Recognize text in word images with a pretrained ``crnn.OCRNet``.

    The class relies on names imported at the top of the notebook:
    ``crnn`` (model module), ``AttnLabelConverter``, ``VT``
    (torchvision transforms) and ``NT`` (project transforms).

    Args:
        weight_path: Path to the saved ``state_dict`` of the network.
        device: Device string understood by ``torch.device`` (default ``'cpu'``).
    """

    def __init__(self, weight_path, device='cpu'):
        self.weight_path = weight_path
        self.device = device
        self._load_config()
        # The model must exist before any prediction can run, so it is loaded
        # eagerly; a bad weight path therefore fails at construction time.
        self._load_model()

    def _load_config(self):
        """Set the fixed recognition settings and build the label converter."""
        self.character = string.printable[:-6]
        self.converter = AttnLabelConverter(self.character)
        # Use this instance's converter, not a notebook-level global.
        self.num_class = len(self.converter.character)
        self.batch_max_length = 25
        self.img_size = (32, 100)
        self.hidden_size = 256

    def _load_model(self):
        """Build the network, load its weights and switch it to inference mode."""
        device = torch.device(self.device)
        state_dict = torch.load(self.weight_path, map_location=device)
        self.model = crnn.OCRNet(num_class=self.num_class,
                                 im_size=self.img_size,
                                 hidden_size=self.hidden_size)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        # Inference only: put dropout / batch-norm layers in evaluation mode.
        self.model.eval()

    def predict(self, images):
        """Public entry point; see ``_predict`` for details."""
        return self._predict(images)

    def _predict(self, images: list):
        """Run recognition on ``images`` and return one cleaned string per image.

        Args:
            images: A list/tuple of images accepted by the transform pipeline,
                or a single such image.

        Returns:
            list of str: Predicted text, truncated at the first ``"[s]"``.
            An empty input yields an empty list.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        if len(images) == 0:
            return []

        data = self._transform(images).to(self.device)
        batch_size = data.shape[0]

        length = torch.IntTensor([self.batch_max_length] * batch_size)
        with torch.no_grad():
            # Feed the transformed batch, not the raw input images.
            preds = self.model(data)
        preds = preds[:, :self.batch_max_length, :]
        _, preds_index = preds.max(2)
        preds_str = self.converter.decode(preds_index, length)
        return self._clean_prediction(preds_str)

    @staticmethod
    def _clean_prediction(preds_str):
        """Cut every decoded string at its first ``"[s]"`` marker.

        Declared static so that both ``self._clean_prediction(x)`` and
        ``TextPredictor._clean_prediction(x)`` work.
        """
        return [text.split("[s]")[0] for text in preds_str]

    def _transform(self, images):
        """Convert one image or a list of images into a single batched tensor.

        The pipeline is applied per image and the results are stacked along a
        new first dimension. Stacking assumes every transformed image has the
        same shape -- presumably guaranteed by ``ResizeRatioWithRightPad``;
        TODO confirm.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        transform = VT.Compose([
            NT.ResizeRatioWithRightPad(size=self.img_size),
            VT.ToTensor(),
            # One-element tuples: ``(0.5)`` is just the float 0.5.
            VT.Normalize(mean=(0.5,), std=(0.5,)),
        ])

        return torch.stack([transform(image) for image in images])
        

In [32]:
# Load the pretrained weights onto the CPU and build the matching network.
weight_path = '../weights/recog/ocrnet_pretrained_ktp.pth'
state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
# The model is only used for inference below, so switch it to evaluation mode
# (affects layers such as dropout / batch-norm). eval() returns the module.
model = crnn.OCRNet(num_class=NUM_CLASS, im_size=IMG_SIZE, hidden_size=256).eval()
# Kept as the last expression so the cell still shows the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [33]:
# Take the first batch from the training loader as sample input for the model.
images, labels = next(iter(trainloader))

In [39]:
# Size the length tensor from the batch actually drawn instead of BATCH_SIZE,
# so a smaller batch is handled too (assumes `images` is a batch-first tensor).
length_for_pred = torch.IntTensor([BATCH_MAX_LENGTH] * images.size(0))
# No gradients are needed for prediction.
with torch.no_grad():
    preds = model(images)
# Keep at most BATCH_MAX_LENGTH steps, then take the arg-max over the last dimension.
preds = preds[:, :BATCH_MAX_LENGTH, :]
_, preds_index = preds.max(2)
preds_str = converter.decode(preds_index, length_for_pred)

In [43]:
def clean_prediction_string(pred_str):
    """Cut each decoded prediction at its first ``"[s]"`` marker.

    Args:
        pred_str: Iterable of decoded prediction strings.

    Returns:
        list of str: For every input string, the text before the first
        ``"[s]"``; a string without the marker is returned unchanged.
    """
    # Iterate over the argument itself, not a notebook-level global.
    return [text.split("[s]")[0] for text in pred_str]

# Keep only the text before the first "[s]" in each decoded prediction.
clean_preds_str = clean_prediction_string(preds_str)

In [44]:
# Raw decoder output; the strings still contain the "[s]" markers.
preds_str

['Jemasan[s]an[s]asta[s][s][s][s]a[s][s]a[s][s]',
 'CRI[s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s]',
 'L/n[s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s][s]',
 'PING[s][s][s][s][s]G[s][s][s][s][s][s][s][s][s][s][s][s][s][s][s]']

In [45]:
# Predictions after truncation at the first "[s]".
clean_preds_str

['Jemasan', 'CRI', 'L/n', 'PING']