# Train a new model on your own data

In [1]:
import re

import torch
from post_ocr_correction import correction
from pytorch_beam_search import seq2seq

In [2]:
# training data: every example is a list of single characters, and each
# target is its source with the final character removed

source = [list(text) for text in ("abcdefghijkl", "mnopqrstwxyz")]
target = [list(text) for text in ("abcdefghijk", "mnopqrstwxy")]

In [3]:
# preprocessing

# One index per side of the parallel data, built from the training examples;
# the same indexes then encode those examples as the tensors used for fitting.
source_index, target_index = seq2seq.Index(source), seq2seq.Index(target)
X, Y = source_index.text2tensor(source), target_index.text2tensor(target)

In [4]:
# model

# Seed PyTorch's global random number generator before the model is created,
# so that a fresh "Restart & Run All" starts from the same random state
# instead of producing a different model on every run.
# NOTE(review): assumes the library draws all of its randomness from torch's
# global RNG -- confirm if results still vary between runs.
torch.manual_seed(0)

model = seq2seq.Transformer(source_index, target_index)
# training mode; the counterpart model.eval() is called before inference
model.train()
# progress_bar = 0 is passed through unchanged from the original cell
train_log = model.fit(X, Y, epochs = 100, progress_bar = 0)



Model: Seq2Seq Transformer
Source index: <Seq2Seq Index with 28 items>
Target index: <Seq2Seq Index with 26 items>
Max sequence length: 32
Embedding dimension: 32
Feedforward dimension: 128
Encoder layers: 2
Decoder layers: 2
Attention heads: 2
Activation: relu
Dropout: 0.0
Trainable parameters: 63,130

Training started
X_train.shape: torch.Size([2, 14])
Y_train.shape: torch.Size([2, 13])
Epochs: 100
Learning rate: 0.0001
Weight decay: 0
Epoch | Train                 | Minutes
      | Loss     | Error Rate |
---------------------------------------
    1 |   3.4754 |     95.833 |     0.0
    2 |   3.4545 |     95.833 |     0.0
    3 |   3.4336 |     95.833 |     0.0
    4 |   3.4131 |     95.833 |     0.0
    5 |   3.3930 |     95.833 |     0.0
    6 |   3.3730 |     95.833 |     0.0
    7 |   3.3532 |     95.833 |     0.0
    8 |   3.3340 |     95.833 |     0.0
    9 |   3.3151 |     95.833 |     0.0
   10 |   3.2967 |     95.833 |     0.0
   11 |   3.2785 |     91.667 |     0.0
   12 

In [5]:
# test data

# A single unseen string, split into characters and wrapped in a list so it
# has the same nested-list layout as the training examples.
test = "ghijklmnopqrst"
new_source = [[character for character in test]]
X_new = source_index.text2tensor(new_source)

In [6]:
# plain beam search

model.eval()
predictions, log_probabilities = seq2seq.beam_search(model, X_new, progress_bar = 0)

# Take index 0 along the second axis for every example and decode it to text.
# NOTE(review): index 0 is presumably the highest-scoring beam -- confirm.
first_beams = predictions[:, 0, :]
decoded = target_index.tensor2text(first_beams)

# Remove the special tokens; for <END>, also remove whatever follows it
# on the same line.
just_beam = re.sub(r"<START>|<PAD>|<UNK>|<END>.*", "", decoded[0])

In [7]:
# post ocr correction

# Both correction strategies receive the same six leading positional
# arguments, so they are collected once and unpacked into each call.
# NOTE(review): the 5 is presumably the window / n-gram size -- confirm
# against the post_ocr_correction documentation.
shared_arguments = (test, model, source_index, target_index, 5, "beam_search")

disjoint_beam = correction.disjoint(*shared_arguments)
# n_grams takes one extra trailing argument and returns a pair
votes, n_grams_beam = correction.n_grams(*shared_arguments, "triangle")

In [8]:
print("\nresults")

# (label, value) pairs, printed one per line; the labels are space-padded to
# a common width so the values line up in a column.
report_rows = [
    ("  test data                      ", test),
    ("  plain beam search              ", just_beam),
    ("  disjoint windows, beam search  ", disjoint_beam),
    ("  n-grams, beam search, triangle ", n_grams_beam),
]
for label, value in report_rows:
    print(label, value)


results
  test data                       ghijklmnopqrst
  plain beam search               mny
  disjoint windows, beam search   mbbbembbsomnp
  n-grams, beam search, triangle  mbbbyyybobrsoa
