In [None]:
# Automatically reload edited local modules (e.g. utils, model) before each cell executes.
%load_ext autoreload
%autoreload 2

import random
import torch
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Seed every random number generator the notebook relies on (Python, NumPy, PyTorch CPU
# and CUDA) with one shared value, so that runs are reproducible.
SEED = 1
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)

# Load data, OCR and labels

In [None]:
# Load the document index. Each value is accessed below as a dict with at least the
# 'image_path', 'ground_truth' and 'address' keys.
# JSON text is UTF-8 (RFC 8259): state the encoding explicitly so the file is not decoded
# with the platform's locale encoding (e.g. cp1252 on Windows).
with open('documents.json', 'r', encoding='utf-8') as f:
    documents = json.load(f)

## Retrieving documents content with an OCR

The first step is to extract the text from the document images, along with the position of each token. To do so, we use pytesseract.

In [None]:
from utils import OCR

# Run the OCR on every document image and store its output on the document under 'OCR'.
for doc in tqdm(documents.values()):
    image_path = doc['image_path']
    doc['OCR'] = OCR(image_path)

The ground truth is a rectangle on the page that delimits where the address is located. In the following example it is displayed in red, and all the tokens transcribed by pytesseract are in blue.

In [None]:
from utils import display_doc
# Pick the example explicitly (the first document) instead of relying on the `doc` loop
# variable leaking out of the OCR cell, which breaks on out-of-order execution.
example_doc = next(iter(documents.values()))
display_doc(example_doc)

## Matching the ground truth to tokens

In [None]:
from utils import ground_truth_match
# Match each document's OCR tokens against its ground-truth rectangle; the result is
# stored on the document under 'labels'.
for document in tqdm(documents.values()):
    ocr_tokens, ground_truth = document['OCR'], document['ground_truth']
    document['labels'] = ground_truth_match(ocr_tokens, ground_truth)

In [None]:
# Number of OCR tokens matched to the ground-truth rectangle, for each document.
lengths = [len(doc['labels']) for doc in documents.values()]

plt.figure()
plt.hist(lengths)
plt.title('Matched tokens per document (before filtering)')
plt.xlabel('Number of matched tokens')
plt.ylabel('Number of documents')
plt.show()

We notice some outliers at 11 words, so we remove every document whose number of matched tokens differs from the number of words in its ground-truth address.

In [None]:
# Keys of documents whose matched-token count disagrees with the word count of the
# ground-truth address (too many OR too few tokens matched).
failed_matching = [
    key
    for key, doc in documents.items()
    if len(doc['address'].split()) != len(doc['labels'])
]

In [None]:
# Remove the documents whose matching failed. pop(..., None) keeps the cell idempotent:
# re-running it no longer raises KeyError for keys that were already removed.
for key in failed_matching:
    documents.pop(key, None)

In [None]:
# Same distribution as before, after dropping the documents whose matching failed.
lengths = [len(doc['labels']) for doc in documents.values()]

plt.figure()
plt.hist(lengths)
plt.title('Matched tokens per document (after filtering)')
plt.xlabel('Number of matched tokens')
plt.ylabel('Number of documents')
plt.show()

# Pre-process data

## Text pre-processing

In [None]:
from utils import text_pre_processing

In [None]:
def _preprocess_tokens(document):
    """Return the document's OCR tokens as (pre-processed text, position) pairs."""
    return [
        (text_pre_processing(token['text']), token['position'])
        for token in document['OCR']
    ]


# One (key, [(word, position), ...], labels) triple per remaining document.
data = [
    (key, _preprocess_tokens(doc), doc['labels'])
    for key, doc in documents.items()
]


# Split train / validation / test

In [None]:
N_DOCS = len(data)
split = 60, 20, 20  # train / validation / test, in percent
assert sum(split) == 100, 'split percentages must add up to 100'

# Shuffle a copy with a dedicated, seeded RNG: `data` is left untouched, and re-running
# this cell reproduces the same split regardless of how many random numbers other
# cells (or library code) have drawn from the global `random` state.
shuffled_data = list(data)
random.Random(1).shuffle(shuffled_data)  # same seed value as the global seeding cell

n_train = int(split[0] / 100 * N_DOCS)
n_val = n_train + int(split[1] / 100 * N_DOCS)  # exclusive end index of the validation slice

# Slicing already returns new lists; the test split absorbs the int() rounding remainder.
dataset_split = {
    'train': shuffled_data[:n_train],
    'validation': shuffled_data[n_train:n_val],
    'test': shuffled_data[n_val:],
}

In [None]:
# Peek at the first five (pre-processed word, position) pairs of the first training document.
first_train_tokens = dataset_split['train'][0][1]
first_train_tokens[:5]

# Mapping the characters

In [None]:
# Vocabulary: every character appearing in a training word (validation/test are excluded).
characters = set()
for _, doc_input, _ in dataset_split['train']:
    for word, _ in doc_input:
        characters.update(word)
# Sort before numbering. The iteration order of a set of strings depends on the
# per-process hash seed (PYTHONHASHSEED), so enumerating the raw set gives a different
# char -> index mapping in every kernel session. That breaks reproducibility and makes
# checkpoints saved in one session inconsistent with the encoding of another.
characters_mapping = {char: i + 1 for i, char in enumerate(sorted(characters))}  # + 1 to account for the stop token
len(characters_mapping)

In [None]:
# Encode every word as the list of its character indices.
# characters_mapping was built from the training split only, so a character that first
# appears in validation/test has no index and used to raise KeyError here. Such
# characters are now dropped; index 0 stays reserved for the stop token, so the
# vocabulary size is unchanged.
# NOTE(review): a word made only of unseen characters becomes an empty list; check that
# utils.make_tensors handles that case.
dataset_split = {
    mode: [
        (
            key,
            [
                ([characters_mapping[c] for c in word if c in characters_mapping], position)
                for word, position in input_data
            ],
            target,
        )
        for key, input_data, target in examples
    ]
    for mode, examples in dataset_split.items()
}

## Tensorification

In [None]:
from utils import make_tensors

# Convert each split's encoded examples into tensors with utils.make_tensors.
tensors_data = {
    split_name: make_tensors(split_examples)
    for split_name, split_examples in dataset_split.items()
}


In [None]:
# Sanity check: print the shape of every tensor produced for each split.
for mode, mode_tensors in tensors_data.items():
    print('-' * 40)
    print('mode', mode)
    for field in ('words', 'positions', 'target'):
        print(field, mode_tensors[field].shape)


## Wrapping it up in a TensorDataset

In [None]:
from torch.utils.data import TensorDataset, DataLoader 
def _to_tensor_dataset(split_tensors):
    """Wrap one split's tensors in a TensorDataset yielding (keys, words, positions, target).

    `words` and `target` are converted with `.type(torch.LongTensor)`; `keys` and
    `positions` are passed through unchanged.
    """
    return TensorDataset(
        split_tensors['keys'],
        split_tensors['words'].type(torch.LongTensor),
        split_tensors['positions'],
        split_tensors['target'].type(torch.LongTensor),
    )


datasets = {mode: _to_tensor_dataset(t) for mode, t in tensors_data.items()}

In [None]:
# Inspect the encoded targets of the test split. The split is named explicitly instead of
# reusing `mode`, whose value ('test') only leaked out of an earlier loop.
tensors_data['test']['target']

# Training loop

In [None]:
from model import Model

In [None]:
# batch_size: documents per DataLoader mini-batch.
# embedding_dim, position_dim, max_seq_len: forwarded to model.Model (see its definition
# for their exact meaning).
batch_size, embedding_dim, position_dim, max_seq_len = 16, 64, 10, 10

In [None]:
# Reshuffle the training set at every epoch. DataLoader defaults to shuffle=False, which
# would present identical mini-batches in identical order every epoch. The shuffling
# draws from torch's global RNG, which the seeding cell seeds, so runs stay reproducible.
train_loader = DataLoader(
    datasets['train'],
    batch_size=batch_size,
    shuffle=True,
)

# Evaluation loaders keep a fixed order so their batches are deterministic.
val_loader = DataLoader(
    datasets['validation'],
    batch_size=batch_size,
)

test_loader = DataLoader(
    datasets['test'],
    batch_size=batch_size,
)

In [None]:
# Character indices start at 1 (0 is the stop token), hence one extra vocabulary slot.
vocab_size = len(characters_mapping) + 1
model = Model(vocab_size, embedding_dim, position_dim, max_seq_len)

n_epochs = 100
learning_rate = 5e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
from utils import train_model
# Train the model; train_model returns the training and validation loss histories.
loss_history = train_model(n_epochs, model, optimizer, train_loader, val_loader)
train_losses, val_losses = loss_history

In [None]:
import matplotlib.pyplot as plt
# Learning curves: training vs validation loss per epoch.
fig, ax = plt.subplots()
ax.plot(train_losses, color='red', label="train")
ax.plot(val_losses, color='blue', label="validation")
ax.set_xlabel('n_epoch')
ax.set_ylabel('cross_entropy_loss')
ax.legend()
plt.show()

# Best val loss

In [None]:
# Index (0-based) of the epoch with the lowest validation loss.
best_step = np.argmin(val_losses)
# NOTE(review): assumes train_model saved one whole pickled model per epoch as
# models/model_{epoch}.torch with the same 0-based numbering; confirm in utils.train_model.
# weights_only=False: since PyTorch 2.6, torch.load defaults to weights_only=True, which
# refuses to unpickle a full nn.Module. Only do this for checkpoints you produced yourself
# (unpickling can execute arbitrary code). The keyword exists from PyTorch 1.13 onwards.
model = torch.load(f'models/model_{best_step}.torch', weights_only=False)
# Everything after this cell is inference: switch off training-mode behaviour
# (dropout, batch-norm statistics updates), if the model has any.
model.eval()
min(val_losses)

# Display prediction

In [None]:
# Run the model on the first validation batch only, for a visual check of its predictions.
# no_grad: this is pure inference, so skip building the autograd graph.
# Calling the module (rather than .forward) also runs any registered hooks.
with torch.no_grad():
    keys, words, positions, target = next(iter(val_loader))
    overall_probabilities, peak_indices = model(words, positions)

In [None]:
from utils import display_prediction

# Draw the peak indices predicted for the first document of the batch on that document.
# The batch carries numeric keys, while `documents` is keyed by their string form.
first_key = keys[0].item()
doc = documents[str(first_key)]
peaks = peak_indices[0].tolist()
display_prediction(doc, peaks)

# Thresholding

In [None]:
from utils import get_threshold_data, get_metrics

### Validation

In [None]:
# Per-document confidence / correctness table on the validation set, and its summary metrics.
# NOTE(review): the optimizer is passed although only inference seems needed; check
# what utils.get_threshold_data uses it for.
val_threshold_data = get_threshold_data(model, optimizer, val_loader)
val_metrics = get_metrics(val_threshold_data)
val_metrics

In [None]:
# Validation confidence scores, split by the is_correct flag.
is_correct = val_threshold_data.is_correct
correct = val_threshold_data.loc[is_correct].confidence.values
incorrect = val_threshold_data.loc[~is_correct].confidence.values

fig, ax = plt.subplots(figsize=(10, 10))
for scores, colour, name in ((correct, 'green', 'correct'), (incorrect, 'red', 'incorrect')):
    ax.hist(scores, bins=20, alpha=0.5, color=colour, label=name)
ax.set_xlabel('Confidence score')
ax.set_ylabel('Number of documents per bucket')
ax.legend()
plt.show()

### Test

In [None]:
# Same table and summary metrics on the held-out test set.
test_threshold_data = get_threshold_data(model, optimizer, test_loader)
test_metrics = get_metrics(test_threshold_data)
test_metrics

In [None]:
# Test confidence scores, split by the is_correct flag.
is_correct = test_threshold_data.is_correct
correct = test_threshold_data.loc[is_correct].confidence.values
incorrect = test_threshold_data.loc[~is_correct].confidence.values

fig, ax = plt.subplots()
for scores, colour, name in ((correct, 'green', 'correct'), (incorrect, 'red', 'incorrect')):
    ax.hist(scores, bins=20, alpha=0.5, color=colour, label=name)
ax.set_xlabel('Confidence score')
ax.set_ylabel('Number of documents per bucket')
ax.legend()
plt.show()

# Automation and accuracy

### Compute threshold on validation

In [None]:
from utils import find_threshold

# Choose, on the validation set, the confidence threshold that reaches the target accuracy.
target_accuracy = 0.99
accuracies, automations, threshold_99acc = find_threshold(target_accuracy, val_threshold_data)

# x-axis grid for plotting accuracies / automations.
# NOTE(review): assumed to be the same 100-point grid that find_threshold evaluates
# internally; confirm in utils.find_threshold.
lowest_confidence = val_threshold_data.confidence.min()
thresholds = np.linspace(lowest_confidence, 1, num=100)

### Plot automation and accuracy

In [None]:
# Automation rate and accuracy as functions of the confidence threshold (validation set).
# The y-axis holds rates, not document counts, so it is labelled accordingly.
plt.figure(figsize=(10, 5))
plt.plot(thresholds, automations, color='blue', label='automation')
plt.plot(thresholds, accuracies, color='green', label='accuracy')
# Derive the legend text from target_accuracy so it stays correct if the target changes.
plt.axvline(x=threshold_99acc, color='red', linestyle='--', label=f'{target_accuracy} threshold')
plt.title('Automation and accuracy vs. confidence threshold (validation)')
plt.xlabel('Confidence threshold')
plt.ylabel('Automation / accuracy')
plt.ylim(bottom=0.6)
plt.legend()
plt.show()

### Get test automation and test accuracy at the threshold

In [None]:
# Apply the validation-derived threshold to the test set: documents whose confidence is
# strictly above it are treated as automated.
# NOTE(review): confirm that find_threshold also uses a strict '>' comparison.
is_automated = test_threshold_data.confidence > threshold_99acc
test_above_threshold = test_threshold_data.loc[is_automated]
test_accuracy = test_above_threshold.is_correct.mean()
test_automation = len(test_above_threshold) / len(test_threshold_data)
(test_accuracy, test_automation)

In [None]:
# Confidence threshold selected on the validation set for the target accuracy.
threshold_99acc