In [14]:
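# First download the models: uncomment and run the command below once.
# The following cells read the model files from data/models/en/.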

# !python ../download_data.py

In [3]:
import pickle
import torch
from pytorch_beam_search import seq2seq
from post_ocr_correction import correction
import re
from pprint import pprint

In [4]:
# Load the pickled architecture description of the English model and build
# the source and target indexes from the vocabularies stored inside it.
# NOTE: pickle.load can execute arbitrary code while unpickling, so only
# open model files that come from a source you trust.

with open("data/models/en/model_en.arch", mode="rb") as arch_file:
    architecture = pickle.load(arch_file)
# the source symbols are the keys of "in_vocabulary", while the target
# symbols are the values of "out_vocabulary"
source = [*architecture["in_vocabulary"].keys()]
target = [*architecture["out_vocabulary"].values()]
source_index, target_index = seq2seq.Index(source), seq2seq.Index(target)

In [5]:
# Remove the keys left over from the old API of pytorch_beam_search, so the
# remaining entries of `architecture` can be passed to seq2seq.Transformer
# as keyword arguments, then read the saved weights from disk.

# `architecture` is modified in place: pop(k, None) rather than pop(k) keeps
# this cell re-runnable, since a second run finds the keys already removed
# and a plain pop would raise KeyError.
for k in ("in_vocabulary", "out_vocabulary", "model", "parameters"):
    architecture.pop(k, None)
model = seq2seq.Transformer(source_index, target_index, **architecture)
# NOTE: torch.load unpickles the file, so only load weights from a source
# you trust.
state_dict = torch.load(
    "data/models/en/model_en.pt",
    map_location = torch.device("cpu") # comment this line if you have a GPU
)

Model: Seq2Seq Transformer
Source index: <Seq2Seq Index with 164 items>
Target index: <Seq2Seq Index with 164 items>
Max sequence length: 110
Embedding dimension: 256
Feedforward dimension: 1024
Encoder layers: 2
Decoder layers: 2
Attention heads: 8
Activation: relu
Dropout: 0.5
Trainable parameters: 3,841,700





In [6]:
# Change the names of the embedding weights from the old API of
# pytorch_beam_search to the names the current model uses, then load the
# weights and put the model in evaluation mode.

# `state_dict` is modified in place, so each rename is applied only while
# the old key is still present; an unconditional pop would raise KeyError
# when this cell is run a second time.
for old_name, new_name in [
    ("in_embeddings.weight", "source_embeddings.weight"),
    ("out_embeddings.weight", "target_embeddings.weight"),
]:
    if old_name in state_dict:
        state_dict[new_name] = state_dict.pop(old_name)
model.load_state_dict(state_dict)
# last expression of the cell: its return value is what the cell displays
model.eval()

Transformer(
  (source_embeddings): Embedding(164, 256)
  (target_embeddings): Embedding(164, 256)
  (positional_embeddings): Embedding(110, 256)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=1024, bias=True)
          (dropout): Dropout(p=0.5, inplace=False)
          (linear2): Linear(in_features=1024, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.5, inplace=False)
          (dropout2): Dropout(p=0.5, inplace=False)
        )
      )
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): Tr

In [7]:
# Test data: a clean reference string and a corrupted version of it. The
# corrupted string is split into a list of characters, wrapped in a
# one-element list, and handed to source_index.text2tensor.

reference = "this is a corrupted string"
test = "th1s 1s a c0rrupted str1ng"
new_source = [[*test]]
X_new = source_index.text2tensor(new_source)

In [8]:
# Plain beam search over the whole input, used as a baseline.

predictions, log_probabilities = seq2seq.beam_search(model, X_new, progress_bar = 0)
# keep index 0 along the second axis of the predictions
# NOTE(review): presumably the top-scoring beam; confirm against the
# pytorch_beam_search documentation
first_hypothesis = predictions[:, 0, :]
decoded = target_index.tensor2text(first_hypothesis)[0]
# drop <START>, <PAD> and <UNK> wherever they occur, and <END> together
# with the rest of the line that follows it
just_beam = re.sub(r"<START>|<PAD>|<UNK>|<END>.*", "", decoded)

In [9]:
# Post-OCR correction of the test string with correction.disjoint, using
# "beam_search" as the decoding option.
# NOTE(review): the positional 5 is presumably the window size; confirm
# against the signature of correction.disjoint.

disjoint_beam = correction.disjoint(
    test, model, source_index, target_index, 5, "beam_search"
)

In [10]:
# Post-OCR correction of the test string with correction.n_grams, using the
# "beam_search" and "triangle" options. The call returns a pair, unpacked
# here as `votes` and the corrected text `n_grams_beam`.
# NOTE(review): the positional 5 is presumably the window size and
# "triangle" the weighting scheme; confirm against correction.n_grams.
votes, n_grams_beam = correction.n_grams(
    test, model, source_index, target_index, 5, "beam_search", "triangle"
)

In [11]:
# Run correction.full_evaluation on the single test/reference pair; both
# strings are passed wrapped in one-element lists.
evaluation = correction.full_evaluation(
    [test], [reference], model, source_index, target_index
)

evaluating all methods...
  disjoint window...
    greedy_search...
    beam_search...
  sliding
    greedy...
      uniform...
      triangle...
      bell...
    beam...
      uniform...
      triangle...
      bell...



In [13]:
# Summary: print the reference, the corrupted input and each corrected
# version one per line, then show the evaluation table as the cell output.
print("results")
for label, value in [
    ("  reference                      ", reference),
    ("  test data                      ", test),
    ("  plain beam search              ", just_beam),
    ("  disjoint windows, beam search  ", disjoint_beam),
    ("  n-grams, beam search, triangle ", n_grams_beam),
]:
    print(label, value)
print()
evaluation

results
  reference                       this is a corrupted string
  test data                       th1s 1s a c0rrupted str1ng
  plain beam search               this Is a corrupted 
  disjoint windows, beam search   this 1s a corrupted string. 1.
  n-grams, beam search, triangle  this 1s a corrupted string



Unnamed: 0,window,decoding,window_size,weighting,inference_seconds,cer_before,cer_after,improvement
0,disjoint,greedy,20,,0.098444,15.384615,19.230769,-25.0
1,disjoint,greedy,10,,0.167863,15.384615,57.692308,-275.0
2,disjoint,beam,20,,0.131956,15.384615,19.230769,-25.0
3,disjoint,beam,10,,0.312966,15.384615,57.692308,-275.0
4,sliding,greedy,10,uniform,0.143648,15.384615,3.846154,75.0
5,sliding,greedy,10,triangle,0.14617,15.384615,3.846154,75.0
6,sliding,greedy,10,bell,0.144272,15.384615,3.846154,75.0
7,sliding,beam,10,uniform,0.594327,15.384615,3.846154,75.0
8,sliding,beam,10,triangle,0.60738,15.384615,3.846154,75.0
9,sliding,beam,10,bell,0.752322,15.384615,3.846154,75.0
