<a href="https://colab.research.google.com/github/michaelgfalk/clean-ocr/blob/master/ocr_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Waves of Words: Correcting Trove's Messy OCR

One aim of the *Waves of Words* project is to extract Aboriginal wordlists from [Trove](https://trove.nla.gov.au). A challenge we face is that historical newspapers are difficult to OCR, so many of the texts are riddled with errors.

Using the training data from the ALTA 2017 OCR-correction shared task, can we create a model that cleans the text well enough for our Aboriginal word detector to work?

I have been considering whether uppercase letters and punctuation should be preserved in this model, given that its purpose is to clean up the text for our detector, which only needs lowercase letters and ignores punctuation. I think we need to keep all the characters. The extra information about sentence boundaries, for example, should help the model, as it would help a human, when it tries to correct the text. Moreover, many OCR errors involve exchanging punctuation or digits for letters, e.g. `l = 1 = !` (the letter *l*, the digit *1* and the exclamation mark are easily confused).
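To illustrate, here is a minimal sketch of what would be lost if we lowercased the text and stripped everything except letters *before* correction. The `detector_normalise` function and the example strings are hypothetical, not part of the project's detector:

```python
import re


def detector_normalise(text):
    """Hypothetical lossy preprocessing: lowercase, keep only a-z and spaces."""
    return re.sub(r"[^a-z ]", "", text.lower())


clean = "Will call at noon"
ocr = "Wi!l ca1l at noon"  # made-up OCR errors: 'l' misread as '!' and as '1'

print(detector_normalise(clean))  # will call at noon
print(detector_normalise(ocr))    # wil cal at noon
```

After normalisation the misread characters simply vanish, leaving "wil" and "cal" with no trace that a letter was ever there. Keeping `!` and `1` in the input gives the model something it can learn to map back to `l`.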

**References:**

* D. Mollá and S. Cassidy. Overview of the 2017 ALTA Shared Task: Correcting OCR Errors (2017). *Proc. ALTA 2017*. [https://aclanthology.coli.uni-saarland.de/papers/U17-1014/u17-1014](https://aclanthology.coli.uni-saarland.de/papers/U17-1014/u17-1014)

In [0]:
# Install TensorFlow 2.0 (alpha release, GPU build)
!pip install -q tensorflow-gpu==2.0.0-alpha0

In [0]:
from __future__ import absolute_import, division, print_function

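from google.colab import drive
# display() is built in under Colab/Jupyter; imported explicitly for clarity
from IPython.display import display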

import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

import pandas as pd
import re
import numpy as np
from itertools import product

In [0]:
# Mount Google Drive to get the training data, and set data_dir
drive.mount('/content/gdrive')
data_dir = '/content/gdrive/My Drive/waves_of_words/ocr_correction_data/'

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
# Import test and training data
raw_x = pd.read_csv(data_dir + 'train_input.csv')
raw_y = pd.read_csv(data_dir + 'train_output.csv')
test_x = pd.read_csv(data_dir + 'test_input.csv')

In [0]:
# What is the shape of the corpus?
print("Summary of raw_x:\n")
display(raw_x['original'].str.len().describe())
print("\n\nSummary of raw_y:\n")
display(raw_y['solution'].str.len().describe())
print(f"\n\nThere are {raw_x['original'].str.len().sum()} characters in the training set.")
print(f"\n\n90% of the training examples have {raw_x['original'].str.len().quantile(0.9)} characters or less")

Summary of raw_x:



count     6000.000000
mean      2738.592000
std       3687.978419
min        107.000000
25%        675.000000
50%       1416.000000
75%       3369.000000
max      48424.000000
Name: original, dtype: float64



Summary of raw_y:



count     6000.000000
mean      2737.897500
std       3695.634174
min         75.000000
25%        671.000000
50%       1413.000000
75%       3374.250000
max      48500.000000
Name: solution, dtype: float64



There are 16431552 characters in the training set.


90% of the training examples have 6464.0 characters or less


In [0]:
def chunk_text(article_string, chunk_size = 200, start_char = "<START>", end_char = "<END>"):
  """Chunk Trove articles from dataset.
  
  Arguments:
  ==========
  article_string (str): the entire article as a single string
  chunk_size (int): the length of the desired chunks, in characters
  start_char (str): the token for the beginning of an article
  end_char (str): the token for the end of an article
  
  Returns:
  ==========
  chunks (list): a list of chunks"""
  
  # Ensure the start and end tokens are not already present in the string
  if start_char in article_string or end_char in article_string:
    raise ValueError("Start or end token found in string")
  
  # Add one-character placeholders for the special tokens, so that each
  # token occupies a single position while chunking...
  article_string = "S" + article_string + "E"
  # ... and chunk (slicing past the end of a string is safe in Python)
  chunks = []
  num_chars = len(article_string)
  for i in range(0, num_chars, chunk_size):
    chunks.append(article_string[i:i + chunk_size])
  
  # Swap the placeholders for the real tokens. Slicing (rather than re.sub)
  # avoids escaping problems with arbitrary token strings. Note that a
  # multi-character token makes the first and last chunks longer than chunk_size.
  chunks[0] = start_char + chunks[0][1:]
  chunks[-1] = chunks[-1][:-1] + end_char
  
  return chunks

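Since each article is split across many chunks, and the chunks of the OCR text and of the corrected text don't line up exactly (OCR errors add and drop characters), we will need a 'stateful' RNN that carries its hidden state from one chunk of an article to the next during training.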

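To do this, we need to split the training data into two levels of batches. There will be $k$ hyper-batches, each containing $l$ training examples, where a training example is a chunk of $t$ characters. Within a hyper-batch, the chunks are fed to the RNN in order, one batch at a time, so that the state built up on chunk $j$ of each article carries over to chunk $j+1$.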
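A toy sketch of this two-level layout, with made-up sizes and assuming the articles are already tokenised and padded to a common length. It mirrors the `np.split` / `np.concatenate` approach used when building the hyper-batches: splitting along the time axis and stacking along the batch axis means that row $r$ of every batch belongs to the same article, which is what a stateful RNN needs in order to carry its state from one chunk to the next.

```python
import numpy as np

# Toy sizes (hypothetical): 3 articles, each already tokenised and padded
# to 6 positions, split into chunks of t = 2 positions.
batch_size, t = 3, 2
articles = np.arange(batch_size * 6).reshape(batch_size, 6)  # one row per article

num_chunks = articles.shape[1] // t
# Split along the time axis, then stack the pieces along the batch axis:
# rows [j * batch_size:(j + 1) * batch_size] hold chunk j of every article.
stacked = np.concatenate(np.split(articles, num_chunks, axis=1), axis=0)

# One batch per chunk index. Row r of every batch is the same article,
# so a stateful RNN can carry row r's hidden state into the next batch.
batches = [stacked[j * batch_size:(j + 1) * batch_size] for j in range(num_chunks)]
for j, batch in enumerate(batches):
    print(f"chunk {j}: {batch.tolist()}")
# chunk 0: [[0, 1], [6, 7], [12, 13]]
# chunk 1: [[2, 3], [8, 9], [14, 15]]
# chunk 2: [[4, 5], [10, 11], [16, 17]]
```

Concatenating the batches back along the time axis recovers the original articles, which is a quick way to check the layout is correct.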

In [0]:
num_articles = len(raw_x)
chunk_size = 200
q = 0.85
# Length at the q-th quantile (the larger of the input and output quantiles),
# plus 2 for the start and end tokens
len_at_q = max(raw_x['original'].str.len().quantile(q) + 2, raw_y['solution'].str.len().quantile(q) + 2)
max_chunks = np.ceil(len_at_q / chunk_size)
batch_size = 256
num_hyper_batches = np.ceil(num_articles/batch_size)


print(f'There are {num_articles} articles in the training data.')
print(f'Let us split each article into chunks of {chunk_size} characters,')
print(f'and cap the number of chunks at the {int(q * 100)}th percentile.')
print(f'{q * 100}% of articles have {len_at_q:.2f} characters or less (including start and end tokens).')
print(f'This equates to {max_chunks} chunks per article.')
print(f'If we choose a batch_size of {batch_size}, there will be {num_hyper_batches} hyper-batches,')
print(f'comprising {batch_size * max_chunks} training examples each.')

There are 6000 articles in the training data.
Let us split each article into chunks of 200 characters,
and cap the number of chunks at the 85th percentile.
85.0% of articles have 5108.15 characters or less (including start and end tokens).
This equates to 26.0 chunks per article.
If we choose a batch_size of 256, there will be 24.0 hyper-batches,
comprising 6656.0 training examples each.


It would be more efficient to create the hyper-batches dynamically, by first sorting the articles by length. Each hyper-batch could then have its own `max_chunks`, just large enough for its longest article, while the chunk length $t$ and the `batch_size` stay the same for every hyper-batch.
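A rough sketch of why sorting helps, using hypothetical, uniformly random article lengths rather than the real Trove data. `padded_cells` is an illustrative helper (not part of the notebook) that counts how many token slots are allocated when each hyper-batch is padded only as far as its own longest article needs:

```python
import math
import random


def padded_cells(lengths, batch_size, chunk_size):
    """Token slots allocated when articles are grouped into hyper-batches of
    batch_size in the given order, and each hyper-batch is padded to the
    smallest multiple of chunk_size that fits its longest article."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        num_chunks = math.ceil(max(batch) / chunk_size)
        total += len(batch) * num_chunks * chunk_size
    return total


# Hypothetical article lengths (uniformly random), not the real Trove data
random.seed(0)
lengths = [random.randint(100, 5000) for _ in range(1000)]

unsorted_cost = padded_cells(lengths, batch_size=256, chunk_size=200)
sorted_cost = padded_cells(sorted(lengths), batch_size=256, chunk_size=200)
print(f"real characters: {sum(lengths)}")
print(f"slots allocated, original order: {unsorted_cost}")
print(f"slots allocated, sorted by length: {sorted_cost}")
```

In random order almost every hyper-batch contains at least one very long article, so nearly all of them are padded to the maximum; after sorting, only the final hyper-batches pay that price.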

In [0]:
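# Set hyperparameters
chunk_size = 200
# Single-character start/end markers, assumed never to occur in the Trove text,
# so that the character-level tokenizer treats each one as a single token
start_char = "स" # 'sa' in Devanagari
end_char = "ए" # 'e' in Devanagari
batch_size = 256
num_articles = len(raw_x)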
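Why single characters rather than `"<START>"` and `"<END>"`? With character-level tokenisation, a multi-character marker is split into ordinary characters (`<`, `S`, `T`, ...) that the model cannot tell apart from real text, and it takes up several positions. A small sketch of both points, plus a hypothetical helper `markers_absent` for checking that the chosen markers really are absent from a corpus:

```python
# Character-level tokenisation treats every character as its own token,
# so "<START>" would cost seven positions and reuse ordinary characters.
multi = list("<START>" + "Hello" + "<END>")
single = list("स" + "Hello" + "ए")
print(len(multi), len(single))  # 17 7


def markers_absent(texts, markers):
    """True if none of the marker characters occur in any of the texts."""
    return not any(m in text for text in texts for m in markers)


print(markers_absent(["Hello", "Trove text"], ["स", "ए"]))  # True
```

In the notebook this check could be run as, for example, `markers_absent(pd.concat([raw_x['original'], raw_y['solution']]), [start_char, end_char])` before the markers are added.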

In [0]:
# Join DataFrames
train_joined = pd.merge(raw_x, raw_y, on = 'id')

# Sort articles by the longer of their input and output lengths
train_joined['max_len'] = pd.concat(
    [train_joined['original'].str.len(), train_joined['solution'].str.len()],
    axis = 1
).max(axis = 1)
train_joined = train_joined.sort_values(by = 'max_len')

In [0]:
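# Fit a character-level tokenizer to the data. Case is preserved and no
# characters are filtered out, since punctuation and digits carry OCR evidence.
tkzr = Tokenizer(
    num_words = None,
    filters = '',
    lower = False,
    char_level = True
)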

# Show it the X data
tkzr.fit_on_texts(train_joined['original'])
# Show it the Y data
tkzr.fit_on_texts(train_joined['solution'])
# Show it the special start and end characters
tkzr.fit_on_texts([start_char,end_char])
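For readers unfamiliar with Keras's `Tokenizer`, here is a pure-Python sketch of the conventions it follows in this configuration. This is an illustration, not the Keras implementation: ids start at 1 (leaving 0 free for `pad_sequences`), more frequent characters get smaller ids, and characters never seen during fitting are silently dropped because no `oov_token` is set. The helper names `build_char_index` and `to_ids` are hypothetical.

```python
from collections import Counter


def build_char_index(texts):
    """Map each character to an integer id, most frequent first, starting at 1
    so that 0 stays free for padding."""
    counts = Counter()
    for text in texts:
        counts.update(text)  # counts each character of the string
    return {char: i for i, (char, _) in enumerate(counts.most_common(), start=1)}


def to_ids(text, index):
    # Unseen characters are dropped, as with a Keras Tokenizer without oov_token.
    return [index[c] for c in text if c in index]


index = build_char_index(["Wi!l", "Will", "सए"])
print(to_ids("Wi!lZ", index))  # [2, 3, 4, 1]  ('Z' was never seen)
```

The dropping of unseen characters matters for `test_x`: any character that appears only in the test set would vanish without warning, so setting an `oov_token` may be worth considering.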

In [0]:
# Add start and end tokens, tokenize, pad and chunk
hyper_batches = []

# Start and end tokens (note: re-running this cell would add them twice)
train_joined['original'] = start_char + train_joined['original'] + end_char
train_joined['solution'] = start_char + train_joined['solution'] + end_char

# Tokenise
x_tokens = tkzr.texts_to_sequences(train_joined['original'])
y_tokens = tkzr.texts_to_sequences(train_joined['solution'])

# Iterate over hyper_batches:
for i in range(0, num_articles, batch_size):
  hyper_batch = {}
  
  # Determine slice start and end points (the last hyper-batch may be smaller)
  end = min(i + batch_size, num_articles)
  
  # Get articles for this batch
  batch_x = x_tokens[i:end]
  batch_y = y_tokens[i:end]
  n_articles = len(batch_x)
  
  # Determine max_len
  max_len = max([len(x) for x in batch_x] + [len(y) for y in batch_y])
  # Round up to the next multiple of chunk_size (without adding a spare
  # chunk of zeros when max_len is already a multiple)
  max_len = int(np.ceil(max_len / chunk_size)) * chunk_size
  
  # Pad sequences (pad_sequences pads at the start by default, padding='pre')
  x_padded = pad_sequences(batch_x, maxlen = max_len)
  y_padded = pad_sequences(batch_y, maxlen = max_len)
  
  # Split along the time axis and stack along the batch axis, so that rows
  # [j * n_articles:(j + 1) * n_articles] hold chunk j of every article
  num_chunks = max_len // chunk_size
  hyper_batch['X'] = np.concatenate(np.split(x_padded, num_chunks, axis = 1), axis = 0)
  hyper_batch['Y'] = np.concatenate(np.split(y_padded, num_chunks, axis = 1), axis = 0)
  
  # Create index of (chunk, article) pairs, one per row of X and Y
  hyper_batch['index'] = list(product(range(num_chunks), range(n_articles)))
  
  # Append
  hyper_batches.append(hyper_batch)
  
  # Sanity check
  print(f'\nHyper-batch {len(hyper_batches)}:')
  print(f'max_len = {max_len}')
  print(f'chunk_size = {chunk_size}')
  print(f'num_chunks = {num_chunks}')
  print(f'shape of X: {hyper_batch["X"].shape}')
  print(f'shape of Y: {hyper_batch["Y"].shape}')
  

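This works, but the last three hyper-batches, which hold the longest articles, are going to contain a lot of padding zeros. At the top end of the distribution article lengths vary widely (the longest is over 48,000 characters), so most articles in those hyper-batches are much shorter than the one that sets `max_len`.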
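One way to quantify this is the share of padding in each hyper-batch. Because the tokenizer's ids start at 1, every 0 in `X` or `Y` is padding. A minimal sketch, with a hypothetical `padding_fraction` helper and a toy matrix standing in for a real hyper-batch:

```python
import numpy as np


def padding_fraction(token_matrix, pad_value=0):
    """Share of positions in a padded token matrix that are padding."""
    token_matrix = np.asarray(token_matrix)
    return float(np.mean(token_matrix == pad_value))


# Toy hyper-batch: two articles pre-padded to 6 positions
toy_X = np.array([
    [0, 0, 0, 0, 5, 7],  # short article: 4 padding slots
    [3, 4, 5, 6, 7, 8],  # full-length article: no padding
])
print(padding_fraction(toy_X))  # 4 of the 12 slots are padding
```

Applied to the real data, `[padding_fraction(hb['X']) for hb in hyper_batches]` would show how much of each hyper-batch is wasted on zeros.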