In [0]:
# Tesseract OCR engine (system package) plus its Python wrapper, used in the OCR demo below.
!sudo apt-get -qq install tesseract-ocr
!pip install -q pytesseract

In [0]:
# Fuzzy string matching, used to align OCR text lines with the tag values (match_coordinate_tags).
# NOTE(review): fuzzywuzzy has been renamed to `thefuzz` upstream — consider migrating and pinning a version.
!pip install -q fuzzywuzzy[speedup]

In [0]:
# Clone the dataset only if it is not present yet; the clone's .git history is removed afterwards.
![ -d "ba_dataset" ] && echo "Dataset directory exists."
![ ! -d "ba_dataset" ] && echo "Dataset directory DOES NOT exist. Cloning from Github..." && git clone -q https://github.com/korayakan/ba_dataset.git && rm -rf ba_dataset/.git && echo "Done"

DATA

In [0]:
# for image handling
from PIL import Image
# for json
import json
# for glob file search
import glob
# for fuzzy string comparison
from fuzzywuzzy import fuzz
from fuzzywuzzy import process
# for encoding strings
from zlib import crc32
# for log configuration
import logging
# for random numbers
from random import randint

In [0]:
# Only errors reach the root logger.
# NOTE(review): presumably this silences fuzzywuzzy's "processor reduces input query to empty string" warnings — confirm.
logging.getLogger().setLevel(logging.ERROR)

# SROIE 2019 annotation folders from the dataset cloned above.
COORDINATE_PATH = 'ba_dataset/SROIE2019/0325updated.task1train(626p)'  # per-line boxes + OCR text
TAG_PATH = 'ba_dataset/SROIE2019/0325updated.task2train(626p)'  # per-receipt JSON tags
IMG_PATH = COORDINATE_PATH  # receipt images sit next to the box annotations

# Class index of each tag; '' marks a line that matches no tag.
TAG_TO_IDX = dict(zip(['', 'company', 'date', 'address', 'total'], range(5)))
IDX_TO_TAG = {idx: tag for tag, idx in TAG_TO_IDX.items()}

In [0]:
def get_filenames(path, suffix):
    """Return the base names (no directory, no extension) of the `*.suffix` files in `path`.

    Args:
        path: directory to search; a trailing '/' is optional.
        suffix: file extension without the leading dot, e.g. 'txt' or 'jpg'.

    Returns:
        List of names such as 'X00016469612', in glob order (not sorted).
    """
    import os

    path = path + '/' if not path.endswith('/') else path
    # Escape the directory part so characters like '[' in it are not treated as glob syntax.
    pattern = glob.escape(path) + '*.' + suffix
    names = []
    for file in glob.glob(pattern):
        # Strip only the directory and the final extension. str.replace would also remove
        # '.txt'/'.jpg' occurring elsewhere in the name, and split('/') ignores os.sep.
        names.append(os.path.splitext(os.path.basename(file))[0])
    return names


def get_text_filenames(path):
    """Base names of the .txt files in `path`."""
    return get_filenames(path, 'txt')


def get_image_filenames(path):
    """Base names of the .jpg files in `path`."""
    return get_filenames(path, 'jpg')


def prepare_data(filename):
    """Load and align all annotations for one receipt.

    Returns:
        tags: parsed task-2 JSON; its keys are looked up in TAG_TO_IDX by match_coordinate_tags.
        coordinate_inputs: one 9-float list per text line — the 8 box coordinates divided by
            the image width/height, followed by the hashed line text (normalize_text).
        coordinate_texts: one raw OCR text string per line.
        coordinate_tags: one tag index per line (see match_coordinate_tags).
    """
    tags = read_tags(filename)

    # Only the image size is needed; close the file handle that Image.open keeps open.
    with read_image_file(filename) as im:
        width, height = im.size

    coordinates = read_normalized_coordinates(filename, width, height)
    coordinate_inputs = [line[:9] for line in coordinates]
    coordinate_texts = [line[9] for line in coordinates]

    coordinate_tags = match_coordinate_tags(coordinate_texts, tags)
    return tags, coordinate_inputs, coordinate_texts, coordinate_tags


def get_all_filenames():
    """Return, sorted, the base names that have an image, a coordinate file and a tag file."""
    available = set(get_image_filenames(IMG_PATH))
    available &= set(get_text_filenames(COORDINATE_PATH))
    available &= set(get_text_filenames(TAG_PATH))
    return sorted(available)


def prepare_training_data():
    """Build (coordinate_inputs, coordinate_tags) pairs for the first 80% of the sorted files."""
    filenames = get_all_filenames()
    print('found {} files with coordinate and tag data'.format(len(filenames)))

    training_size = int(len(filenames) * 0.8)
    print('using {} files for training'.format(training_size))

    samples = []
    for name in filenames[:training_size]:
        _tags, inputs, _texts, labels = prepare_data(name)
        samples.append((inputs, labels))
    return samples


def get_random_test_file(filenames=None):
    """Pick a random filename from the held-out test split.

    prepare_training_data() trains on the first int(0.8 * n) files of the sorted list, so the
    test split is every file from that index to the end.

    Args:
        filenames: optional sorted list of candidate names; defaults to get_all_filenames().

    Returns:
        One filename from the test split.

    Raises:
        ValueError: if there are no files to choose from.
    """
    if filenames is None:
        filenames = get_all_filenames()
    print('found {} files with coordinate and tag data'.format(len(filenames)))

    # Must mirror the training split exactly; starting anywhere below training_size
    # would return files the model was trained on.
    training_size = int(len(filenames) * 0.8)
    test_size = len(filenames) - training_size
    print('using {} files for testing'.format(test_size))

    if test_size == 0:
        raise ValueError('no files available for testing')
    return filenames[randint(training_size, len(filenames) - 1)]


def read_text_file(path, filename):
    """Return the whole content of `<path>/<filename>.txt` as one string."""
    if not path.endswith('/'):
        path += '/'
    with open(path + filename + '.txt') as handle:
        return handle.read()


def read_text_file_lines(path, filename):
    """Return the lines of `<path>/<filename>.txt`, each without its trailing newline."""
    if not path.endswith('/'):
        path += '/'
    with open(path + filename + '.txt') as handle:
        return [row.rstrip('\n') for row in handle]


def read_coordinates(filename, path='ba_dataset/SROIE2019/0325updated.task1train(626p)'):
    """Parse a task-1 annotation file into one [x0, y0, ..., x3, y3, text] list per line.

    Each non-blank line is expected to hold eight comma-separated integers followed by
    the OCR text. The text may itself contain commas, so everything after the eighth
    field is joined back together.
    """
    coordinates = []
    for line in read_text_file_lines(path, filename):
        # Blank lines (e.g. a trailing empty line) carry no box; int('') would raise.
        if not line.strip():
            continue
        tokens = line.split(',')
        line_coordinates = [int(token) for token in tokens[0:8]]
        line_coordinates.append(','.join(tokens[8:]))
        coordinates.append(line_coordinates)
    return coordinates


def read_normalized_coordinates(filename, width, height, path='ba_dataset/SROIE2019/0325updated.task1train(626p)'):
    """Read a task-1 file and express its box coordinates as fractions of the image size.

    Returns one list per text line: the eight coordinates (even indices divided by
    `width`, odd indices by `height`), then normalize_text(text) at index 8, then the raw
    text at index 9. The lists from read_coordinates are updated in place.
    """
    coordinates = read_coordinates(filename, path=path)
    for entry in coordinates:
        raw_text = entry[8]
        entry[:8] = [value / (height if idx % 2 else width) for idx, value in enumerate(entry[:8])]
        entry[8] = normalize_text(raw_text)
        entry.append(raw_text)
    return coordinates


def read_tags(filename, path='ba_dataset/SROIE2019/0325updated.task2train(626p)'):
    """Parse the task-2 JSON annotation stored in `<path>/<filename>.txt`.

    The result is presumably a dict of tag name -> text (it is used that way by
    match_coordinate_tags).
    """
    raw_json = read_text_file(path, filename)
    return json.loads(raw_json)


def read_image_file(filename, path='ba_dataset/SROIE2019/0325updated.task1train(626p)'):
    """Open `<path>/<filename>.jpg` with PIL and return the Image.

    The caller owns the returned image (and its file handle) and should close it when done.
    """
    separator = '' if path.endswith('/') else '/'
    return Image.open(path + separator + filename + '.jpg')


def match_coordinate_tags(coordinate_texts, tags):
    """Label every OCR line with the tag whose value it fuzzily matches.

    For each line, the tag values are scored with `fuzz.partial_ratio`. The line receives
    the tag of the best match when its score is at least 90, otherwise the empty tag ''.
    Labels are returned encoded through encode_tag.
    """
    tag_by_value = {}
    for tag_name, tag_value in tags.items():
        # If two tags share a value, the later key wins.
        tag_by_value[tag_value] = tag_name
    candidates = list(tags.values())

    def _label(text):
        best = process.extractOne(text, candidates, scorer=fuzz.partial_ratio, score_cutoff=90)
        return encode_tag(tag_by_value[best[0]] if best is not None else '')

    return [_label(text) for text in coordinate_texts]


def encode_tag(tag):
    """Map a tag name ('' for untagged) to its class index in TAG_TO_IDX."""
    index = TAG_TO_IDX[tag]
    return index


def decode_tag(tag_idx):
    """Map a class index back to its tag name via IDX_TO_TAG."""
    name = IDX_TO_TAG[tag_idx]
    return name


def normalize_text(input_text, encoding="utf-8"):
    """Hash a string to a float in [0, 1) using CRC-32.

    The unsigned 32-bit checksum is divided by 2**32, so equal strings always map to the
    same value. Based on
    https://stackoverflow.com/questions/40351791/how-to-hash-strings-into-a-float-in-01
    """
    checksum = crc32(input_text.encode(encoding)) & 0xFFFFFFFF
    return checksum / float(2 ** 32)


def combine_predicted_tags(texts, tags):
    """Join the OCR lines predicted for each tag into one string per field.

    Args:
        texts: OCR text per line.
        tags: predicted tag index per line (indices as in TAG_TO_IDX; 0 / unknown are ignored).

    Returns:
        {'company': ..., 'date': ..., 'address': ..., 'total': ...}. Each value is the
        space-joined lines for that tag. A line is skipped if it is empty or repeats the
        tail of what was already collected.
    """
    tag_names = {1: 'company', 2: 'date', 3: 'address', 4: 'total'}
    combined = {name: '' for name in tag_names.values()}

    for text, tag in zip(texts, tags):
        name = tag_names.get(tag)
        if name is None or not text:
            continue
        # Compare at a word boundary: a bare endswith(text) would also drop '9.00'
        # after '19.00' or 'BHD' after 'SDN BHD'.
        if combined[name].endswith(' ' + text):
            continue
        combined[name] += ' ' + text

    return {name: value.strip() for name, value in combined.items()}

LSTM Net

In [0]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as opt
import torch.nn.functional as fn

In [0]:
SERIALIZED_MODEL_NAME = 'ba_model.pt'

INPUT_SIZE = 9   # 8 normalized box coordinates + 1 hashed text value per line
OUTPUT_SIZE = 5  # one score per entry of TAG_TO_IDX


class LSTM(nn.Module):
    """Per-line tagger: an LSTM over a receipt's text lines followed by a linear classifier.

    Args:
        hidden_size: width of the LSTM hidden state.
        num_of_layers: number of stacked LSTM layers.
    """

    def __init__(self, hidden_size=6, num_of_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        # Coordinates in, hidden states out.
        self.lstm = nn.LSTM(INPUT_SIZE, hidden_size, num_layers=num_of_layers)
        # Hidden state -> tag scores.
        self.hidden2tag = nn.Linear(hidden_size, OUTPUT_SIZE)

    def forward(self, input_seq):
        """Return log-probabilities of shape (seq_len, OUTPUT_SIZE).

        Callers pass (seq_len, 1, INPUT_SIZE) tensors (they unsqueeze(1) the batch axis);
        the flatten below assumes that batch size of one.
        """
        lstm_out, _state = self.lstm(input_seq)
        flat = lstm_out.view(input_seq.size(0), -1)
        return fn.log_softmax(self.hidden2tag(flat), dim=1)


def load_model(model_path=SERIALIZED_MODEL_NAME):
    """Load the whole pickled model written by train() and switch it to eval mode.

    The model is mapped to the CPU on load, so a checkpoint written on a GPU machine can be
    read anywhere. evaluate() moves it to the active device afterwards.

    Raises:
        FileNotFoundError: if `model_path` does not exist (the model was not trained yet).
    """
    if not os.path.isfile(model_path):
        # Raise instead of sys.exit(0): in a notebook that only raises SystemExit with a
        # misleading "success" status.
        raise FileNotFoundError('Model was not trained yet! No file at {}'.format(model_path))
    # train() saves the full module with torch.save(model, ...), so it must be unpickled
    # with weights_only=False on torch >= 2.6. Unpickling runs code: only load trusted files.
    try:
        model = torch.load(model_path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older torch releases do not accept the `weights_only` argument.
        model = torch.load(model_path, map_location='cpu')
    model.eval()
    return model


def evaluate(coordinate_inputs, model_path=SERIALIZED_MODEL_NAME):
    """Tag every line of one receipt with the saved model.

    The model is reloaded from `model_path` on each call and run without gradients on CUDA
    when available.

    Returns:
        One (probability, tag_index) tuple per input line: the exponentiated best
        log-softmax score and its class index.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        model = load_model(model_path).to(device)
        # Add the batch axis of size one: (seq_len, features) -> (seq_len, 1, features).
        batch = torch.tensor(coordinate_inputs).unsqueeze(1).to(device)
        best_log_probs, best_tags = model(batch).topk(1)
        return [(torch.exp(log_prob).item(), tag.item())
                for log_prob, tag in zip(best_log_probs, best_tags)]

TRAIN

In [0]:
def _to_input_tensor(coordinate_inputs, device):
    """Turn one receipt's per-line feature lists into a (seq_len, 1, features) tensor on `device`."""
    return torch.tensor(coordinate_inputs).unsqueeze(1).to(device)


def _to_target_tensor(coordinate_tags, device):
    """Turn one receipt's per-line tag indices into a (seq_len,) tensor on `device`."""
    return torch.tensor(coordinate_tags).to(device)


def _validation_loss(model, loss_function, validation_data, device):
    """Mean loss per receipt over `validation_data` (NaN if empty); leaves the model in train mode."""
    if not validation_data:
        return float('nan')
    total = 0.0
    model.eval()
    try:
        with torch.no_grad():
            for coordinate_inputs, coordinate_tags in validation_data:
                tag_scores = model(_to_input_tensor(coordinate_inputs, device))
                total += loss_function(tag_scores, _to_target_tensor(coordinate_tags, device)).item()
    finally:
        model.train()
    return total / len(validation_data)


def train(epochs, print_every=1, learning_rate=0.1, hidden_size=6, num_of_layers=1):
    """Train the LSTM tagger with plain SGD and save it to SERIALIZED_MODEL_NAME.

    The samples from prepare_training_data() are dealt round-robin into four folds. Epoch
    `e` validates on fold `e % 4` and trains on the other three, taking one optimizer step
    per receipt.

    Args:
        epochs: number of epochs.
        print_every: report and record losses every this many epochs (epoch 1 is always reported).
        learning_rate: SGD step size.
        hidden_size, num_of_layers: forwarded to LSTM.

    Returns:
        (train_losses, test_losses), one entry per report. The first is the mean NLL per
        training receipt since the previous report; the second is the mean NLL per receipt
        on the current validation fold.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using {} for training'.format(device))

    model = LSTM(hidden_size=hidden_size, num_of_layers=num_of_layers)
    model.to(device)  # move first so the optimizer is built on the final parameters
    loss_function = nn.NLLLoss()
    optimizer = opt.SGD(model.parameters(), lr=learning_rate)

    data_set = prepare_training_data()
    folds = [[], [], [], []]
    for i, sample in enumerate(data_set):
        folds[i % 4].append(sample)

    train_losses, test_losses = [], []
    running_loss = 0.0
    samples_since_report = 0
    for epoch in range(epochs):
        steps = epoch + 1
        # NOTE(review): the validation fold rotates every epoch, so from epoch 4 on every
        # fold has also been trained on and the validation loss is not truly held out.
        validation_idx = epoch % 4
        validation_data = folds[validation_idx]
        training_data = [sample
                         for idx, fold in enumerate(folds) if idx != validation_idx
                         for sample in fold]

        for coordinate_inputs, coordinate_tags in training_data:
            model.zero_grad()
            tag_scores = model(_to_input_tensor(coordinate_inputs, device))
            loss = loss_function(tag_scores, _to_target_tensor(coordinate_tags, device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            samples_since_report += 1

        if steps == 1 or steps % print_every == 0:
            test_loss = _validation_loss(model, loss_function, validation_data, device)
            # Average over the receipts actually processed since the last report; a fixed
            # len(training_data) * print_every is wrong for the first report and when
            # folds differ in size.
            train_loss = running_loss / samples_since_report if samples_since_report else float('nan')
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            print(f"Epoch {epoch + 1}/{epochs}.. "
                  f"Train loss: {train_loss:.3f}.. "
                  f"Test loss: {test_loss:.3f}.. ")
            running_loss = 0.0
            samples_since_report = 0

    torch.save(model, SERIALIZED_MODEL_NAME)
    return train_losses, test_losses

PREDICTION

In [0]:
def predict(filename, model_path=SERIALIZED_MODEL_NAME):
    """Print expected vs. predicted per-line tags for one receipt, then the combined fields."""
    print('File:')
    print(filename, '\n')

    tags, coordinate_inputs, coordinate_texts, coordinate_tags = prepare_data(filename)

    print('Expected category tags:')
    print(tags, '\n')

    print(coordinate_inputs)
    print(coordinate_texts)
    print('Expected coordinate tags:')
    print(coordinate_tags)

    predictions = evaluate(coordinate_inputs, model_path)
    predicted_coordinate_tags = [tag_idx for _prob, tag_idx in predictions]
    probabilities = ['{:.0%}'.format(prob) for prob, _tag_idx in predictions]
    print('Predicted coordinate tags:')
    print(predicted_coordinate_tags)
    print('Confidence:')
    print(probabilities, '\n')

    print(combine_predicted_tags(coordinate_texts, predicted_coordinate_tags))


def get_expected_tags(filename):
    """Return the ground-truth tag index of every text line of `filename`."""
    prepared = prepare_data(filename)
    return prepared[3]


def get_predicted_tags(filename, model_path=SERIALIZED_MODEL_NAME):
    """Return the model's predicted tag index for every text line of `filename`."""
    coordinate_inputs = prepare_data(filename)[1]
    return [tag_idx for _confidence, tag_idx in evaluate(coordinate_inputs, model_path)]

OCR DEMO

In [0]:
# Display one sample receipt (the bare `img` expression renders it as the cell output).
img = read_image_file('X00016469612')
img

In [0]:
import pytesseract
# Plain-text OCR of the sample receipt; '--oem 1' selects Tesseract's LSTM engine.
print(pytesseract.image_to_string(img, lang='eng', config='--oem 1'))

In [0]:
# Word-level OCR output as TSV (box position/size, confidence and text per detected word).
print(pytesseract.image_to_data(img, lang='eng', config='--oem 1'))

AI (KI) TRAINING DEMO

In [0]:
from datetime import datetime, timezone
import time
# `start` (epoch seconds) is read by the "training finished" cell to compute the duration.
start = time.time()
# Current time converted to the local timezone, for the log line.
now = datetime.now(timezone.utc).astimezone()
print('start model training at {}'.format(now))

In [0]:
# Longer configuration (1000 epochs, 3 layers), kept for reference:
#train_losses, test_losses = train(1000, print_every = 100, learning_rate = 0.05, hidden_size=256, num_of_layers=3)
# Short run: 10 epochs, losses reported after every epoch.
train_losses, test_losses = train(10, print_every = 1, learning_rate = 0.05, hidden_size=256, num_of_layers=1)

In [0]:
# Report wall-clock training time measured from `start` (set in the timer-start cell).
now = datetime.now(timezone.utc).astimezone()
done = time.time()
elapsed = int(round(done - start))
# divmod gives whole minutes plus remaining seconds; round(elapsed / 60) overstated the
# minutes (e.g. 90 s printed as "2 min 30 sec").
minutes, seconds = divmod(elapsed, 60)
print('training finished at {}, duration was {} min {} sec'.format(now, minutes, seconds))

In [0]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt

plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.legend(frameon=False)
plt.show()

In [0]:
# Run the saved model on a file chosen by get_random_test_file() and compare with its labels.
predict(get_random_test_file())