# Optical character recognition (OCR)

# Previous work

A survey of existing OCR software products and technologies.


- **CuneiForm**
    - by Russian software company Cognitive Technologies
    - **Open-source**
    - [Link](https://www.linuxlinks.com/cuneiform/)


- **Tesseract**
    - Originally developed by Hewlett-Packard as proprietary software in the 1980s; released as open source in 2005
    - Development sponsored by Google from 2006
    - **Apache License**
    - [Link](https://github.com/tesseract-ocr)

- **GdPicture OCR SDK**
    - GdPicture includes an Optical Character Recognition engine to develop any kind of application requiring OCR technology.
    - **Not open-source!**
    - [Link](https://www.gdpicture.com/ocr-sdk/)


- **Transym**
    - TOCR is one of the most sophisticated OCR software packages on the market, specifically designed for ease of integration. With a free API and attractive volume pricing, it is designed for system integrators.
    - **Not open-source!**
    - [Link](https://transym.com/)


- **Computer Vision**
    - by Microsoft
    - An AI service that analyzes content in images and video
    - Boost content discoverability, automate text extraction, analyze video in real time, and create products that more people can use by embedding cloud vision capabilities in your apps with Computer Vision, part of Azure Cognitive Services. Use visual data processing to label content with objects and concepts, extract text, generate image descriptions, moderate content, and understand people's movement in physical spaces. No machine learning expertise is required.

    - **Azure service**
    - [Link](https://azure.microsoft.com/en-us/products/cognitive-services/computer-vision/#overview)

- **Google AI**
    - [Link](https://ai.google/tools/#developers)

- **Amazon Textract**
    - "Automatically extract printed text, handwriting, and data from any document"

    - **AWS component**
    - [Link](https://aws.amazon.com/textract/)


- **Kadmos**
    - from reRecognition
    - KADMOS is an ICR and OCR technology for developers built on unique “discriminatory entropy” methodology that does not depend on dictionaries. KADMOS offers exceptionally fast processing speeds in over 200 languages for both hand-written and machine printed text.
    - Provided as a feature rich SDK, KADMOS can be integrated into the most sophisticated applications where mission critical levels of accuracy and speed are required.
    - KADMOS is used extensively in document management, finance, healthcare, education, government and defence applications around the world and selected by developers for its unique capabilities.
    - **Not open-source!**
    - [Link](https://rerecognition.com/products/)

- **Mindee**
    - Python OCR SDK
    - Supports invoice, passport, receipt OCR APIs and custom-built API from the API Builder.
    - **MIT license**
    - [Getting started](https://developers.mindee.com/docs/python-getting-started)
    - [Source code](https://github.com/mindee/mindee-api-python)

- **ocr.pytorch**
    - An open-source project on GitHub
    - An OCR project implemented purely in PyTorch, covering both text detection and text recognition.
    - **MIT license**
    - [Repository](https://github.com/courao/ocr.pytorch)

- **TextRecognitionDataGenerator**
    - A synthetic data generator for text recognition
    - **MIT license**
    - [Link](https://github.com/Belval/TextRecognitionDataGenerator)

- **License plate font**
    - .ttf font
    - open for commercial use
    - [Link](https://www.1001fonts.com/license-plates-fonts.html)

- **B. Shi, X. Bai and C. Yao**
    - B. Shi, X. Bai and C. Yao, "An End-to-End Trainable Neural Network for Image-Based Sequence Recognition and Its Application to Scene Text Recognition," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 39, no. 11, pp. 2298-2304, 1 Nov. 2017, doi: 10.1109/TPAMI.2016.2646371.
     - [Link](https://ieeexplore.ieee.org/abstract/document/7801919?casa_token=l7alCR-iNckAAAAA:XjdU8cR5KEKqRJKS48UZSxoyXqM5LGl4llw-HUIh9UbSAiK9v3zPyD9gYvwjsngkrRW-HEh1_Q)


- **Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition**
    - Max Jaderberg, Karen Simonyan, Andrea Vedaldi, Andrew Zisserman
    - [Link](https://arxiv.org/abs/1406.2227)

- **Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks**
    - Alex Graves, Santiago Fernández, Faustino Gomez, and Jürgen Schmidhuber. 2006. Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks. In Proceedings of the 23rd international conference on Machine learning (ICML '06). Association for Computing Machinery, New York, NY, USA, 369–376. https://doi.org/10.1145/1143844.1143891
    - [Link](https://dl.acm.org/doi/abs/10.1145/1143844.1143891?casa_token=e19selHVFKsAAAAA:mrb7JOGyGF84MfHGUcpvhE1AzuC-_fUKIYI-yHsZa8wpuyR0ld5xzCTmUNT0hU3xgi6aJyXkA9Zn)

# Learning materials for building a PyTorch OCR application

- [Building a custom OCR using pytorch](https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html)
- [I need advice for a OCR project](https://discuss.pytorch.org/t/i-need-advice-for-a-ocr-project/113087)
- [Simon Z. Python course](https://theflyingpiano99.github.io/PythonBeginnerCourse/index.html)
- [Pytorch official tutorials](https://pytorch.org/tutorials/)



# OCR study

In this section we make notes about the architecture of the OCR system under design.
The ideas in this section originate mainly from [Building a custom OCR using pytorch](https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html#setting-up-the-data).

## Setting up the Data

In [1]:
#@title installing sample data generator

!python --version

!pip install trdg
!pip3 install Pillow
!pip3 install torch_utils
!pip3 install torchvision
!pip3 install pytorch_wrapper
!pip3 install textdistance
!cp -r ./drive/MyDrive/Adapting-OCR-master/* ./
#!pip3 install -r requirements.txt


In [2]:
#@title Generating sample data
!rm data/train/*
!rm data/test/*
!trdg -i words.txt -c 20000 --output_dir data/train -ft fonts/din1451alt.ttf
!trdg -i words.txt -c 256 --output_dir data/test -ft fonts/din1451alt.ttf




In [3]:
#@title Importing libraries

# Source code from https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html#setting-up-the-data

# Standard library and third-party imports
# (pdb, six and lmdb were imported here but never used by the code in this section)

import os
import sys
import random
from PIL import Image
import numpy as np
import math
from collections import OrderedDict
from itertools import chain
import logging


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import random_split

sys.path.insert(0, '../')
from src.utils.utils import AverageMeter, Eval, OCRLabelConverter
from src.utils.utils import EarlyStopping, gmkdir
from src.optim.optimizer import STLR
from src.utils.utils import gaussian
from tqdm import tqdm



In [None]:
#@title Dataset

# Source code from https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html#setting-up-the-data

class SynthDataset(Dataset):
    def __init__(self, opt):
        super(SynthDataset, self).__init__()
        self.path = os.path.join(opt['path'], opt['imgdir'])
        self.images = os.listdir(self.path)
        self.nSamples = len(self.images)
        self.imagepaths = [os.path.join(self.path, name) for name in self.images]
        # Convert to single-channel grayscale, then scale pixel values to [-1, 1]
        transform_list = [transforms.Grayscale(1),
                          transforms.ToTensor(),
                          transforms.Normalize((0.5,), (0.5,))]
        self.transform = transforms.Compose(transform_list)
        self.collate_fn = SynthCollator()

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imagepath = self.imagepaths[index]
        imagefile = os.path.basename(imagepath)
        img = Image.open(imagepath)
        if self.transform is not None:
            img = self.transform(img)
        item = {'img': img, 'idx':index}
        item['label'] = imagefile.split('_')[0]
        return item 
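
The label lookup in `__getitem__` depends on how the generated files are named. The following is a minimal sketch of that parsing, assuming file names of the form `<text>_<index>.jpg` as produced by the `trdg` commands above; the helper name `label_from_filename` is hypothetical and not part of the original code.

```python
import os

def label_from_filename(path):
    """Return the text before the first underscore of the file name."""
    name = os.path.basename(path)
    return name.split('_')[0]

print(label_from_filename('data/train/hello_17.jpg'))      # hello
print(label_from_filename('data/train/snake_case_3.jpg'))  # snake
```

The second call shows a limitation of this convention: a label that itself contains an underscore is cut at the first one. Splitting from the right (`name.rsplit('_', 1)[0]`) would be one way around that, if such labels occur in the word list.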

In [None]:
#@title Synth collator

# Source code from https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html#setting-up-the-data

class SynthCollator(object):
    
    def __call__(self, batch):
        width = [item['img'].shape[2] for item in batch]
        indexes = [item['idx'] for item in batch]
        imgs = torch.ones([len(batch), batch[0]['img'].shape[0], batch[0]['img'].shape[1], 
                           max(width)], dtype=torch.float32)
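        for idx, item in enumerate(batch):
            try:
                # Copy the image into the left part; the rest stays padded with ones
                imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
            except RuntimeError:
                # Shape mismatch (e.g. differing image heights): report it and keep going
                print(imgs.shape, item['img'].shape)
        collated = {'img': imgs, 'idx': indexes}
        if 'label' in batch[0].keys():
            collated['label'] = [item['label'] for item in batch]
        return collated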
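
The collator pads every image on the right so that all images in a batch share the width of the widest one. The sketch below shows the same idea without any framework, using nested lists in place of tensors. The function name `pad_to_widest` is hypothetical. The fill value of 1.0 mirrors `torch.ones` in the collator, which corresponds to white after the `Normalize((0.5,), (0.5,))` transform.

```python
def pad_to_widest(images, fill=1.0):
    """Right-pad every image (a list of pixel rows) to the widest width in the batch."""
    max_width = max(len(img[0]) for img in images)
    padded = []
    for img in images:
        padded.append([row + [fill] * (max_width - len(row)) for row in img])
    return padded

batch = [
    [[0.0, 0.0], [0.0, 0.0]],            # 2 rows x 2 columns
    [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],  # 2 rows x 3 columns
]
padded = pad_to_widest(batch)
print(padded[0])  # [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
```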

# Defining our Model

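The model follows the convolutional recurrent neural network (CRNN) architecture proposed by B. Shi et al. (see the reference under Previous work).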

![image.png](https://deepayan137.github.io/blog/images/crnn.png)
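
The width of the final feature map determines how many time steps the recurrent layers see. As a rough check, the arithmetic below follows the spatial size through the pooling layers and the last convolution, using the kernel, stride and padding values from the `CRNN` class in this section. The helper names are hypothetical; the 3x3 convolutions with padding 1 are left out because they do not change the size.

```python
def out_size(size, kernel, stride, pad):
    """Output length of a conv/pool layer along one axis (floor mode)."""
    return (size + 2 * pad - kernel) // stride + 1

def crnn_feature_size(height, width):
    # (kernel_h, kernel_w), (stride_h, stride_w), (pad_h, pad_w) for each size-changing layer
    layers = [
        ((2, 2), (2, 2), (0, 0)),  # pooling0
        ((2, 2), (2, 2), (0, 0)),  # pooling1
        ((2, 2), (2, 1), (0, 1)),  # pooling2
        ((2, 2), (2, 1), (0, 1)),  # pooling3
        ((2, 2), (1, 1), (0, 0)),  # conv6 (kernel 2, no padding)
    ]
    for (kh, kw), (sh, sw), (ph, pw) in layers:
        height = out_size(height, kh, sh, ph)
        width = out_size(width, kw, sw, pw)
    return height, width

print(crnn_feature_size(32, 128))  # (1, 33)
print(crnn_feature_size(32, 100))  # (1, 26)
```

With an input height of 32 the feature map collapses to height 1, which is what the `assert h == 1` in `CRNN.forward` requires, and the width gives roughly one time step per four input pixels.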

In [None]:
#@title Convolutional RNN layer & bidir. LSTM layer

# Source code from https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html#setting-up-the-data

class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)
    def forward(self, input):
        self.rnn.flatten_parameters()
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


class CRNN(nn.Module):

    def __init__(self, opt, leakyRelu=False):
        super(CRNN, self).__init__()

        assert opt['imgH'] % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = opt['nChannels'] if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Shape comments below assume a 1x32x128 (CxHxW) input image
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        # Stride 1 and padding 1 along the width keep the horizontal resolution
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x33
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x34
        convRelu(6, True)  # 512x1x33
        self.cnn = cnn
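        # Two stacked bidirectional LSTMs; the CNN emits 512 channels per time step,
        # which equals opt['nHidden'] * 2 for the nHidden=256 used in this section
        self.rnn = nn.Sequential(
            BidirectionalLSTM(opt['nHidden']*2, opt['nHidden'], opt['nHidden']),
            BidirectionalLSTM(opt['nHidden'], opt['nHidden'], opt['nClasses']))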


    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        # rnn features
        output = self.rnn(conv)
        output = output.transpose(1, 0)  # [T, b, h] -> [b, T, h]
        return output

# The CTC Loss

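CTC was proposed by A. Graves et al. (see the reference under Previous work).
It removes the need for an exact alignment between the input time steps and the output characters.
For a given target label, the CTC layer sums the probabilities of all alignments that collapse to that label.
The network is trained to maximize this sum, that is, to minimize its negative logarithm.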
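
To make the sum over alignments concrete, here is a brute-force sketch that enumerates every path of length T and adds up the probabilities of those that collapse to the target. It is only an illustration for tiny inputs: `torch.nn.CTCLoss` obtains the same quantity with dynamic programming. The symbol `-` for the blank and the function names are assumptions made for this example.

```python
from itertools import product

BLANK = "-"

def collapse(path):
    """Merge repeated symbols, then drop blanks."""
    out = []
    prev = None
    for s in path:
        if s != prev and s != BLANK:
            out.append(s)
        prev = s
    return "".join(out)

def ctc_probability(probs, symbols, target):
    """Sum the probabilities of every length-T path that collapses to target."""
    T = len(probs)
    total = 0.0
    for path in product(range(len(symbols)), repeat=T):
        if collapse([symbols[i] for i in path]) == target:
            p = 1.0
            for t, i in enumerate(path):
                p *= probs[t][i]
            total += p
    return total

symbols = [BLANK, "a"]
probs = [[0.5, 0.5], [0.5, 0.5]]  # two time steps, uniform over {blank, a}
print(ctc_probability(probs, symbols, "a"))  # 0.75
```

With two time steps, the paths `a-`, `-a` and `aa` all collapse to `a`, each with probability 0.25, which gives 0.75. The CTC loss for this target would be the negative logarithm of that value.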

In [None]:
#@title The CTC Loss

# Code by Jerin Philip
# https://jerinphilip.github.io/

class CustomCTCLoss(torch.nn.Module):
    # T x B x H => Softmax on dimension 2
    def __init__(self, dim=2):
        super().__init__()
        self.dim = dim
        self.ctc_loss = torch.nn.CTCLoss(reduction='mean', zero_infinity=True)

    def forward(self, logits, labels,
            prediction_sizes, target_sizes):
        # logits are expected to be log-probabilities of shape T x B x H
        loss = self.ctc_loss(logits, labels, prediction_sizes, target_sizes)
        loss = self.sanitize(loss)
        return self.debug(loss, logits, labels, prediction_sizes, target_sizes)
    
    def sanitize(self, loss):
        # Replace an infinite or NaN loss with zero so one bad batch does not break training.
        # (inf - inf is NaN, so comparing against float('inf') with a tolerance never matches.)
        value = loss.item()
        if math.isinf(value) or math.isnan(value):
            return torch.zeros_like(loss)
        return loss

    def debug(self, loss, logits, labels,
            prediction_sizes, target_sizes):
        if math.isnan(loss.item()):
            print("Loss:", loss)
            print("logits:", logits)
            print("labels:", labels)
            print("prediction_sizes:", prediction_sizes)
            print("target_sizes:", target_sizes)
            raise Exception("NaN loss obtained. But why?")
        return loss

In [None]:
#@title Training loop

class OCRTrainer(object):
    def __init__(self, opt):
        super(OCRTrainer, self).__init__()
        self.data_train = opt['data_train']
        self.data_val = opt['data_val']
        self.model = opt['model']
        self.criterion = opt['criterion']
        self.optimizer = opt['optimizer']
        self.schedule = opt['schedule']
        self.converter = OCRLabelConverter(opt['alphabet'])
        self.evaluator = Eval()
        print('Scheduling is {}'.format(self.schedule))
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=opt['epochs'])
        self.batch_size = opt['batch_size']
        self.count = opt['epoch']
        self.epochs = opt['epochs']
        self.cuda = opt['cuda']
        self.collate_fn = opt['collate_fn']
        self.init_meters()

    def init_meters(self):
        self.avgTrainLoss = AverageMeter("Train loss")
        self.avgTrainCharAccuracy = AverageMeter("Train Character Accuracy")
        self.avgTrainWordAccuracy = AverageMeter("Train Word Accuracy")
        self.avgValLoss = AverageMeter("Validation loss")
        self.avgValCharAccuracy = AverageMeter("Validation Character Accuracy")
        self.avgValWordAccuracy = AverageMeter("Validation Word Accuracy")

    def forward(self, x):
        logits = self.model(x)
        return logits.transpose(1, 0)

    def loss_fn(self, logits, targets, pred_sizes, target_sizes):
        loss = self.criterion(logits, targets, pred_sizes, target_sizes)
        return loss

    def step(self):
        self.max_grad_norm = 0.05
        clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
    
    def schedule_lr(self):
        if self.schedule:
            self.scheduler.step()

    def _run_batch(self, batch, report_accuracy=False, validation=False):
        input_, targets = batch['img'], batch['label']
        if self.cuda:
            # The Learner moves the model to the GPU, so the input has to follow
            input_ = input_.cuda()
        targets, lengths = self.converter.encode(targets)
        logits = self.forward(input_)
        logits = logits.contiguous().cpu()
        logits = torch.nn.functional.log_softmax(logits, 2)
        T, B, H = logits.size()
        # Every sequence in the batch has the full T time steps
        pred_sizes = torch.LongTensor([T for i in range(B)])
        targets = targets.view(-1).contiguous()
        loss = self.loss_fn(logits, targets, pred_sizes, lengths)
        ca, wa = None, None  # only computed when report_accuracy is set
        if report_accuracy:
            probs, preds = logits.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = self.converter.decode(preds.data, pred_sizes.data, raw=False)
            ca = np.mean((list(map(self.evaluator.char_accuracy, list(zip(sim_preds, batch['label']))))))
            wa = np.mean((list(map(self.evaluator.word_accuracy, list(zip(sim_preds, batch['label']))))))
        return loss, ca, wa

    def run_epoch(self, validation=False):
        if not validation:
            loader = self.train_dataloader()
            pbar = tqdm(loader, desc='Epoch: [%d]/[%d] Training'%(self.count, 
                self.epochs), leave=True)
            self.model.train()
        else:
            loader = self.val_dataloader()
            pbar = tqdm(loader, desc='Validating', leave=True)
            self.model.eval()
        outputs = []
        for batch_nb, batch in enumerate(pbar):
            if not validation:
                output = self.training_step(batch)
            else:
                output = self.validation_step(batch)
            pbar.set_postfix(output)
            outputs.append(output)
        if not validation:
            # Step the scheduler once per training epoch, not again after validation
            self.schedule_lr()
            result = self.train_end(outputs)
        else:
            result = self.validation_end(outputs)
        return result

    def training_step(self, batch):
        loss, ca, wa = self._run_batch(batch, report_accuracy=True)
        self.optimizer.zero_grad()
        loss.backward()
        self.step()
        output = OrderedDict({
            'loss': abs(loss.item()),
            'train_ca': ca.item(),
            'train_wa': wa.item()
            })
        return output

    def validation_step(self, batch):
        loss, ca, wa = self._run_batch(batch, report_accuracy=True, validation=True)
        output = OrderedDict({
            'val_loss': abs(loss.item()),
            'val_ca': ca.item(),
            'val_wa': wa.item()
            })
        return output

    def train_dataloader(self):
        # logging.info('training data loader called')
        loader = torch.utils.data.DataLoader(self.data_train,
                batch_size=self.batch_size,
                collate_fn=self.collate_fn,
                shuffle=True)
        return loader
        
    def val_dataloader(self):
        # logging.info('val data loader called')
        loader = torch.utils.data.DataLoader(self.data_val,
                batch_size=self.batch_size,
                collate_fn=self.collate_fn)
        return loader

    def train_end(self, outputs):
        for output in outputs:
            self.avgTrainLoss.add(output['loss'])
            self.avgTrainCharAccuracy.add(output['train_ca'])
            self.avgTrainWordAccuracy.add(output['train_wa'])

        train_loss_mean = abs(self.avgTrainLoss.compute())
        train_ca_mean = self.avgTrainCharAccuracy.compute()
        train_wa_mean = self.avgTrainWordAccuracy.compute()

        result = {'train_loss': train_loss_mean, 'train_ca': train_ca_mean,
        'train_wa': train_wa_mean}
        # result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': train_loss_mean}
        return result

    def validation_end(self, outputs):
        for output in outputs:
            self.avgValLoss.add(output['val_loss'])
            self.avgValCharAccuracy.add(output['val_ca'])
            self.avgValWordAccuracy.add(output['val_wa'])

        val_loss_mean = abs(self.avgValLoss.compute())
        val_ca_mean = self.avgValCharAccuracy.compute()
        val_wa_mean = self.avgValWordAccuracy.compute()

        result = {'val_loss': val_loss_mean, 'val_ca': val_ca_mean,
        'val_wa': val_wa_mean}
        return result
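
The trainer builds a `CosineAnnealingLR` scheduler with `T_max` set to the number of epochs. As a sketch of what that schedule does to the learning rate, the closed-form cosine annealing formula can be evaluated directly. The function name is hypothetical, and `eta_min=0.0` is assumed because the trainer does not set it.

```python
import math

def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

for epoch in range(5):
    print(epoch, round(cosine_annealing_lr(0.001, epoch, t_max=4), 6))
```

For a base rate of 0.001 and four epochs, the rate falls smoothly from 0.001 through 0.0005 at the midpoint to 0 at the last step. Note that the schedule only takes effect when `'schedule'` is set to `True` in the options.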

In [None]:
#@title Putting everything together

# Source code from https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html
# Under MIT license

class Learner(object):
    def __init__(self, model, optimizer, savepath=None, resume=False):
        self.model = model
        self.optimizer = optimizer
        self.savepath = os.path.join(savepath, 'best.ckpt')
        self.cuda = torch.cuda.is_available() 
        self.cuda_count = torch.cuda.device_count()
        if self.cuda:
            self.model = self.model.cuda()
        self.epoch = 0
        if self.cuda_count > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model)
        self.best_score = None
        if resume and os.path.exists(self.savepath):
            self.checkpoint = torch.load(self.savepath)
            self.epoch = self.checkpoint['epoch']
            self.best_score=self.checkpoint['best']
            self.load()
        elif resume:
            # Only report a missing checkpoint when resuming was actually requested
            print('checkpoint does not exist')

    def fit(self, opt):
        opt['cuda'] = self.cuda
        opt['model'] = self.model
        opt['optimizer'] = self.optimizer
        logging.basicConfig(filename="%s/%s.csv" %(opt['log_dir'], opt['name']), level=logging.INFO)
        self.saver = EarlyStopping(self.savepath, patience=15, verbose=True, best_score=self.best_score)
        opt['epoch'] = self.epoch
        trainer = OCRTrainer(opt)
        
        for epoch in range(opt['epoch'], opt['epochs']):
            trainer.count = epoch  # keep the progress-bar epoch counter in sync
            train_result = trainer.run_epoch()
            val_result = trainer.run_epoch(validation=True)
            info = '%d, %.6f, %.6f, %.6f, %.6f, %.6f, %.6f'%(epoch, train_result['train_loss'], 
                val_result['val_loss'], train_result['train_ca'],  val_result['val_ca'],
                train_result['train_wa'], val_result['val_wa'])
            logging.info(info)
            self.val_loss = val_result['val_loss']
            print(self.val_loss)
            if self.savepath:
                self.save(epoch)
            if self.saver.early_stop:
                print("Early stopping")
                break

    def load(self):
        print('Loading checkpoint at {} trained for {} epochs'.format(self.savepath, self.checkpoint['epoch']))
        self.model.load_state_dict(self.checkpoint['state_dict'])
        if 'opt_state_dict' in self.checkpoint.keys():
            print('Loading optimizer')
            self.optimizer.load_state_dict(self.checkpoint['opt_state_dict'])

    def save(self, epoch):
        self.saver(self.val_loss, epoch, self.model, self.optimizer)

In [None]:
#@title Learning

# Source code from https://deepayan137.github.io/blog/markdown/2020/08/29/building-ocr.html
# MIT license

alphabet = """Only thewigsofrcvdampbkuq.$A-210xT5'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%"""
args = {
    'name':'exp1',
    'path':'data',
    'imgdir': 'train',
    'imgH':32,
    'nChannels':1,
    'nHidden':256,
    'nClasses':len(alphabet),
    'lr':0.001,
    'epochs':4,
    'batch_size':32,
    'save_dir':'checkpoints',
    'log_dir':'logs',
    'resume':False,
    'cuda':False,
    'schedule':False
    
}

data = SynthDataset(args)
args['collate_fn'] = SynthCollator()
train_split = int(0.8*len(data))
val_split = len(data) - train_split
args['data_train'], args['data_val'] = random_split(data, (train_split, val_split))
print('Training Data Size:{}\nVal Data Size:{}'.format(
    len(args['data_train']), len(args['data_val'])))
args['alphabet'] = alphabet
model = CRNN(args)
args['criterion'] = CustomCTCLoss()
savepath = os.path.join(args['save_dir'], args['name'])
gmkdir(savepath)
gmkdir(args['log_dir'])
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
learner = Learner(model, optimizer, savepath=savepath, resume=args['resume'])
learner.fit(args)
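
The hard-coded `alphabet` string has to contain every character that occurs in the training labels, otherwise label encoding is likely to fail. A minimal sketch for deriving or checking the alphabet from a list of labels, with hypothetical helper names:

```python
def build_alphabet(labels):
    """Collect every distinct character from the labels, in first-seen order."""
    seen = []
    for label in labels:
        for ch in label:
            if ch not in seen:
                seen.append(ch)
    return "".join(seen)

def missing_characters(labels, alphabet):
    """Characters that occur in the labels but not in the alphabet."""
    return sorted(set("".join(labels)) - set(alphabet))

labels = ["Only", "the", "wigs"]
print(build_alphabet(labels))                 # Onlythewigs
print(missing_characters(labels, "Onlythe"))  # ['g', 'i', 's', 'w']
```

Running `missing_characters` on the contents of `words.txt` before training would be a cheap way to catch an incomplete alphabet early.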

# Evaluation and testing



In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_accuracy(args):
    loader = torch.utils.data.DataLoader(args['data'],
                batch_size=args['batch_size'],
                collate_fn=args['collate_fn'])
    model = args['model']
    model.eval()
    converter = OCRLabelConverter(args['alphabet'])
    evaluator = Eval()
    labels, predictions, images = [], [], []
    for iteration, batch in enumerate(tqdm(loader)):
        input_, targets = batch['img'].to(device), batch['label']
        # Drop only the channel axis and move to the CPU so the images can be plotted later
        images.extend(input_.squeeze(1).detach().cpu())
        labels.extend(targets)
        targets, lengths = converter.encode(targets)
        logits = model(input_).transpose(1, 0)
        logits = torch.nn.functional.log_softmax(logits, 2)
        logits = logits.contiguous().cpu()
        T, B, H = logits.size()
        pred_sizes = torch.LongTensor([T for i in range(B)])
        probs, pos = logits.max(2)
        pos = pos.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(pos.data, pred_sizes.data, raw=False)
        predictions.extend(sim_preds)
        
#     make_grid(images[:10], nrow=2)
    fig=plt.figure(figsize=(8, 8))
    columns = 4
    rows = 5
    # Show at most rows*columns samples, starting from the first image
    for i in range(min(columns*rows, len(images))):
        img = images[i]
        img = (img - img.min())/(img.max() - img.min())
        img = np.array(img * 255.0, dtype=np.uint8)
        fig.add_subplot(rows, columns, i + 1)
        plt.title(predictions[i])
        plt.axis('off')
        plt.imshow(img, cmap='gray')
    plt.show()
    ca = np.mean((list(map(evaluator.char_accuracy, list(zip(predictions, labels))))))
    wa = np.mean((list(map(evaluator.word_accuracy_line, list(zip(predictions, labels))))))
    return ca, wa
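
Taking the arg-max class at every time step and then calling `converter.decode(..., raw=False)` amounts to greedy CTC decoding. The converter itself lives in `src.utils.utils` and is not shown, so the sketch below only illustrates the general technique. It assumes that index 0 is the blank and that character indices start at 1, and the function name is hypothetical.

```python
def greedy_ctc_decode(indices, alphabet, blank=0):
    """Collapse repeated indices, drop blanks, and map the rest to characters."""
    chars = []
    prev = blank
    for idx in indices:
        if idx != blank and idx != prev:
            chars.append(alphabet[idx - 1])  # assumes index 0 is blank, characters start at 1
        prev = idx
    return "".join(chars)

alphabet = "abc"
print(greedy_ctc_decode([1, 1, 0, 1, 2, 2, 0, 3], alphabet))  # aabc
```

The blank between the two runs of index 1 is what allows the doubled letter in the output. Without it, the repeated indices would be merged into a single `a`.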


In [None]:
args['imgdir'] = 'test'
args['data'] = SynthDataset(args)
resume_file = os.path.join(args['save_dir'], args['name'], 'best.ckpt')
if os.path.isfile(resume_file):
    print('Loading model %s'%resume_file)
    checkpoint = torch.load(resume_file, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    args['model'] = model
    ca, wa = get_accuracy(args)
    print("Character Accuracy: %.2f\nWord Accuracy: %.2f"%(ca, wa))
else:
    print("=> no checkpoint found at '{}'".format(resume_file))
    print('Exiting')
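
The character and word accuracies come from the `Eval` class in `src.utils.utils`, which is not reproduced in this document. As a hedged illustration of what such metrics commonly compute, the sketch below defines character accuracy as one minus the normalized edit distance and word accuracy as an exact match. These definitions and function names are assumptions and may not match `Eval` exactly.

```python
def edit_distance(a, b):
    """Levenshtein distance, computed row by row with dynamic programming."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]

def char_accuracy(pred, target):
    """1 - normalized edit distance, clipped at 0 (an assumed definition)."""
    if not target:
        return 1.0 if not pred else 0.0
    return max(0.0, 1.0 - edit_distance(pred, target) / len(target))

def word_accuracy(pred, target):
    """1.0 for an exact match, otherwise 0.0."""
    return 1.0 if pred == target else 0.0

print(char_accuracy("hallo", "hello"))  # 0.8
print(word_accuracy("hallo", "hello"))  # 0.0
```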