In [None]:
# |default_exp error_correction

In [None]:
# | export
import dataclasses
import random
from functools import partial
from itertools import chain
from pathlib import Path
from typing import Dict, Iterable, List

import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm

from ocrpostcorrection.icdar_data import Text, normalized_ed

In [None]:
# | hide
import os
from pathlib import Path
from pprint import pprint

import edlib
from torch.utils.data import DataLoader

from ocrpostcorrection.icdar_data import generate_data

## Dataset Creation

A dataset for token correction consists of the OCR text and gold standard of AlignedTokens. 
These can be extracted from the Text objects using the `get_tokens_with_OCR_mistakes` function.
This function also adds data properties that can be used for calculating statistics 
about the data.

In [None]:
# | export
def get_tokens_with_OCR_mistakes(
    data: Dict[str, Text], data_test: Dict[str, Text], val_files: List[str]
) -> pd.DataFrame:
    """Return pandas dataframe with all OCR mistakes from train, val, and test

    A token counts as an OCR mistake when its OCR text differs from its gold
    standard text after stripping surrounding whitespace.

    Args:
        data: `<file name>: Text` pairs that make up the train and validation sets.
        data_test: `<file name>: Text` pairs that make up the test set.
        val_files: file names (keys of `data`) that belong to the validation
            set; all other keys of `data` are labeled as train.

    Returns:
        A dataframe with one row per mistaken token. It contains the token
        fields plus the columns `language`, `subset`, `dataset`, `len_ocr`,
        `len_gs`, and `diff`. If there are no mistakes at all, an empty
        dataframe (without columns) is returned.
    """
    # Set membership avoids scanning the list of validation files for every text.
    val_keys = set(val_files)

    tokens = []
    # Train and val
    for key, d in data.items():
        dataset = "val" if key in val_keys else "train"
        tokens.extend(_mistake_records(key, d, dataset))
    # Test
    for key, d in data_test.items():
        tokens.extend(_mistake_records(key, d, "test"))

    tdata = pd.DataFrame(tokens)
    if tdata.empty:
        # Without records there are no "ocr"/"gs" columns to derive properties from.
        return tdata

    return _add_update_data_properties(tdata)


def _mistake_records(key: str, text: Text, dataset: str) -> List[dict]:
    """Return one record per token of `text` whose OCR differs from the gold standard"""
    records = []
    for token in text.tokens:
        if token.ocr.strip() != token.gs.strip():
            r = dataclasses.asdict(token)
            # Keys look like "<language>/<subset>/<file>", e.g. "en/eng_sample/2.txt".
            r["language"] = key[:2]
            r["subset"] = key.split("/")[1]
            r["dataset"] = dataset
            records.append(r)
    return records


def _add_update_data_properties(tdata: pd.DataFrame) -> pd.DataFrame:
    """Add and update data properties for calculating statistics

    Strips the `ocr` and `gs` columns and (re)computes `len_ocr`, `len_gs` and
    `diff` (= `len_ocr - len_gs`) from the stripped strings. The dataframe is
    modified in place and returned.
    """
    # Vectorized string methods replace the row-wise `apply` calls.
    tdata["ocr"] = tdata["ocr"].str.strip()
    tdata["gs"] = tdata["gs"].str.strip()
    tdata["len_ocr"] = tdata["ocr"].str.len()
    tdata["len_gs"] = tdata["gs"].str.len()
    tdata["diff"] = tdata["len_ocr"] - tdata["len_gs"]
    return tdata

The following code example shows how to use this function. 
For simplicity, in the example below, the `data` dictionary (which contains 
`<file name>: Text` pairs) is used both as train/val and test set.

In [None]:
# Load the sample dataset that ships with the notebooks.
data_dir = Path(os.getcwd()) / "data" / "dataset_training_sample"
data, md = generate_data(data_dir)
# Keys of `data` that are assigned to the validation set.
val_files = ["en/eng_sample/2.txt"]

# The same texts are passed as train/val data and as test data.
token_data = get_tokens_with_OCR_mistakes(data, data, val_files)
token_data.head()

2it [00:00, 906.68it/s]


Unnamed: 0,ocr,gs,ocr_aligned,gs_aligned,start,len_ocr,language,subset,dataset,len_gs,diff
0,test- AAA,test-.AAA,test- AAA,test-.AAA,0,9,fr,fr_sample,train,9,0
1,test-BBB,test- BBB,test@-BBB,test- BBB,10,8,fr,fr_sample,train,9,-1
2,test-CCC,test- CCC,test-@CCC,test- CCC,19,8,fr,fr_sample,train,9,-1
3,-DDD,DDD,-DDD,DDD,33,4,fr,fr_sample,train,3,1
4,test- EEE,test-EEE,test- EEE,test-@EEE,38,9,fr,fr_sample,train,8,1


### Create vocabularies

https://pytorch.org/tutorials/beginner/translation_transformer.html

Define special symbols and indices and make sure the tokens are in order of their indices to properly insert them in the vocabulary.

In [None]:
# | export
# Indices of the special symbols; each index is the position of the symbol in `special_symbols`.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3

# Special symbols, listed in the order of their indices.
special_symbols = ["<unk>", "<pad>", "<bos>", "<eos>"]

In [None]:
# | export
def yield_tokens(data, col):
    """Yield the characters of every token in column `col` of `data`, one at a time.

    Helper for building a character-level vocabulary.
    """
    for text in data[col].to_list():
        yield from text

In [None]:
# | export


def generate_vocabs(train):
    """Build character vocabularies for the `ocr` and `gs` columns of the train set.

    Returns a dict with the keys "ocr" and "gs"; each value is a vocabulary in
    which the special symbols come first.
    """
    columns = ("ocr", "gs")

    vocab_transform = {
        name: build_vocab_from_iterator(
            yield_tokens(train, name),
            min_freq=1,
            specials=special_symbols,
            special_first=True,
        )
        for name in columns
    }

    # Unknown characters map to UNK_IDX. Without a default index, looking up
    # a character that is not in the vocabulary raises a RuntimeError.
    for name in columns:
        vocab_transform[name].set_default_index(UNK_IDX)

    return vocab_transform

Use the trainset to create the ocr and gs vocabularies:

In [None]:
# Only the train split is used, so val/test characters that never occur in train map to <unk>.
vocab_transform = generate_vocabs(token_data.query('dataset == "train"'))

In [None]:
# Vocabulary sizes (including the four special symbols) for ocr and gs.
len(vocab_transform["ocr"]), len(vocab_transform["gs"])

(46, 44)

In [None]:
# | export
class SimpleCorrectionDataset(Dataset):
    """Dataset of (ocr characters, gs characters) pairs for short tokens.

    Only rows for which both `len_ocr` and `len_gs` are at most `max_len` are
    kept. The original index of `data` is preserved in the `index` column.
    """

    def __init__(self, data, max_len=10):
        within_limit = data.query(f"len_ocr <= {max_len}")
        within_limit = within_limit.query(f"len_gs <= {max_len}")
        self.ds = within_limit.copy().reset_index(drop=False)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = self.ds.loc[idx]
        # Strings are split into lists of single characters.
        return list(row.ocr), list(row.gs)

To create a `SimpleCorrectionDataset` with a maximum token length of 10, do:

In [None]:
# Keep only train tokens for which both the ocr and the gs text have at most 10 characters.
dataset = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)

The first sample looks like this:

In [None]:
# A sample is a pair (ocr characters, gs characters).
dataset[0]

(['t', 'e', 's', 't', '-', ' ', 'A', 'A', 'A'],
 ['t', 'e', 's', 't', '-', '.', 'A', 'A', 'A'])

These character sequences need to be transformed into vectors. 

In [None]:
#| export
class BertVectorsCorrectionDataset(Dataset):
    """Dataset of (ocr string, gs string, BERT vector) triples for short tokens.

    The BERT vectors are read from the group/dataset `split_name` of an HDF5
    file. They are looked up by the position of a token in `data` *before*
    tokens that are too long are removed, so `data` should contain all tokens
    of the split, in the order in which the vectors were stored.
    """

    def __init__(self, data: pd.DataFrame, bert_vectors_file: Path, split_name: str, max_len: int=11):
        # Number the rows 0..n-1 first; this position is the key into the vectors.
        numbered = data.reset_index(drop=True)
        short = numbered.query(f'len_ocr < {max_len}').query(f'len_gs < {max_len}')
        # drop=False keeps the position from above in the "index" column.
        self.ds = short.reset_index(drop=False)

        h5_file = h5py.File(bert_vectors_file, "r")
        # NOTE: `.get` returns None when `split_name` is not present in the file.
        self.bert_vectors = h5_file.get(split_name)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = self.ds.loc[idx]
        vector = np.array(self.bert_vectors[row["index"]])

        return row.ocr, row.gs, torch.as_tensor(vector)

The sample bert vectors have been generated using `python src/stages/create-bert-vectors.py --seed 1234 --dataset-in ../ocrpostcorrection/nbs/data/correction/dataset.csv --model-dir models/error-detection/ --model-name bert-base-multilingual-cased --batch-size 1 --out-file ../ocrpostcorrection/nbs/data/correction/bert-vectors.hdf5` (from ocrpostcorrection-notebooks, model from [9099e78](https://github.com/jvdzwaan/ocrpostcorrection-notebooks/commit/9099e785177a5c5207d01d80422e68d30f39636d))

In [None]:
# Token data and the BERT vectors that were generated for it (see the command above).
data_csv = Path(os.getcwd()) / "data" / "correction" / "dataset.csv"
data = pd.read_csv(data_csv, index_col=0)
bert_vectors_file = Path(os.getcwd()) / "data" / "correction" / "bert-vectors.hdf5"
# The split name selects both the rows of `data` and the vectors inside the HDF5 file.
split_name = "test"

dataset = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"), 
    bert_vectors_file=bert_vectors_file, 
    split_name=split_name
)

### Collation

Source: https://pytorch.org/tutorials/beginner/translation_transformer.html

In [None]:
# | export
def sequential_transforms(*transforms):
    """Compose `transforms` into one function that applies them from left to right"""

    def func(txt_input):
        result = txt_input
        for apply_step in transforms:
            result = apply_step(result)
        return result

    return func


def tensor_transform(token_ids: List[int]):
    """Append EOS to a sequence of token indices and return it as an int64 tensor

    Only EOS_IDX is appended; no BOS index is added to the sequence.

    The dtype is set explicitly, so the result is an integer tensor even when
    `token_ids` is empty (an empty token still yields `tensor([EOS_IDX])`).
    """
    return torch.tensor([*token_ids, EOS_IDX], dtype=torch.long)


def get_text_transform(vocab_transform):
    """Return, for "ocr" and "gs", a function that turns characters into an index tensor

    Each function first looks up the characters in the matching vocabulary
    (char -> idx) and then passes the indices to `tensor_transform`.
    """
    return {
        name: sequential_transforms(vocab_transform[name], tensor_transform)
        for name in ("ocr", "gs")
    }

In [None]:
text_transform = get_text_transform(vocab_transform)

# The last index of the result is EOS_IDX (3).
text_transform["ocr"](["t", "e", "s", "t", "-", " ", "A", "A", "A"])

tensor([ 4,  5,  6,  4,  7, 10, 13, 13, 13,  3])

In [None]:
text_transform = get_text_transform(vocab_transform)

# The ocr and gs vocabularies assign different indices to the same character;
# index 0 (UNK_IDX) marks a character that is not in the vocabulary.
print(text_transform["ocr"](["e", "x", "a", "m", "p", "l", "e"]))
print(text_transform["gs"](["e", "x", "a", "m", "p", "l", "e"]))

tensor([ 5,  0, 21, 34, 22, 33,  5,  3])
tensor([ 5,  0, 21, 27, 23, 26,  5,  3])


In [None]:
# | export
def collate_fn_with_text_transform(text_transform, batch):
    """Function to collate data samples into batch tensors, to be used as partial with instantiated text_transform

    Args:
        text_transform: dict with the keys "ocr" and "gs" that maps a list of
            characters to a 1-D tensor of indices.
        batch: iterable of samples `(ocr, gs, ...)`; elements after the first
            two (e.g., BERT vectors) are ignored. `ocr` and `gs` may be lists
            of characters or plain strings.

    Returns:
        Tuple `(src_batch, tgt_batch)` of int64 tensors with shape
        (longest sequence in the batch) x (batch size), padded with PAD_IDX.
    """
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample, *_ in batch:
        # SimpleCorrectionDataset yields lists of characters, whereas
        # BertVectorsCorrectionDataset yields plain strings. `list(...)` makes
        # sure the transforms receive a list of characters in both cases.
        src_batch.append(text_transform["ocr"](list(src_sample)))
        tgt_batch.append(text_transform["gs"](list(tgt_sample)))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)

    return src_batch.to(torch.int64), tgt_batch.to(torch.int64)


def collate_fn(text_transform):
    """Return a collate function for a DataLoader that uses `text_transform`"""
    # A partial of a module-level function is used rather than a closure.
    bound_collate = partial(collate_fn_with_text_transform, text_transform)
    return bound_collate

In [None]:
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
# Batches are padded index tensors of shape (sequence length) x (batch size).
train_dataloader = DataLoader(
    train, batch_size=5, collate_fn=collate_fn(text_transform)
)

In [None]:
#| hide
# Can we loop over the entire dataset?
# batch[0] is the source batch (sequence length x batch size), so shape[1] is
# the number of samples in the batch.
num_samples = 0
for batch in train_dataloader:
    num_samples += batch[0].shape[1]
assert num_samples == len(train)

In [None]:
#| hide
# Can we loop over the entire dataset?
data_csv = Path(os.getcwd()) / "data" / "correction" / "dataset.csv"
data = pd.read_csv(data_csv, index_col=0)
bert_vectors_file = Path(os.getcwd()) / "data" / "correction" / "bert-vectors.hdf5"
split_name = "test"

dataset = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"), 
    bert_vectors_file=bert_vectors_file, 
    split_name=split_name
)
dataloader = DataLoader(
    dataset, batch_size=5, collate_fn=collate_fn(text_transform)
)

# NOTE(review): the loop below iterates over `train_dataloader` and compares
# with `len(train)`, so the `dataloader` created above is never exercised.
# Presumably `dataloader` and `len(dataset)` were intended -- TODO confirm
# (and check that the collate function accepts the string samples of
# BertVectorsCorrectionDataset).
num_samples = 0
for batch in train_dataloader:
    num_samples += batch[0].shape[1]
assert num_samples == len(train)

## Neural network

Network: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html


In [None]:
# | export
class EncoderRNN(nn.Module):
    """Encoder: an embedding layer followed by a single GRU (batch first)"""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input, hidden):
        """Embed `input` and run it through the GRU; returns `(output, hidden)`"""
        return self.gru(self.embedding(input), hidden)

    def initHidden(self, batch_size, device):
        """Return an all-zero initial hidden state of size 1 x batch size x hidden size"""
        shape = (1, batch_size, self.hidden_size)
        return torch.zeros(*shape, device=device)

In [None]:
# | export
class AttnDecoderRNN(nn.Module):
    """Decoder that performs a single decoding step with attention over the encoder outputs"""

    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=11):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        # Attention weights are computed over a fixed number of positions (max_length).
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one step.

        Returns the log probabilities over the output vocabulary, the new
        hidden state, and the attention weights.
        """
        embedded = self.dropout(self.embedding(input))
        # Swap the first two dimensions and select the (single) decoding step.
        step_embedding = embedded.permute(1, 0, 2)[0]

        attn_scores = self.attn(torch.cat((step_embedding, hidden[0]), 1))
        attn_weights = F.softmax(attn_scores, dim=1)

        # Weighted sum of the encoder outputs for every sample in the batch.
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        combined = self.attn_combine(torch.cat((step_embedding, context), 1))
        gru_input = F.relu(combined.unsqueeze(0))
        output, hidden = self.gru(gru_input, hidden)

        log_probs = F.log_softmax(self.out(output[0]), dim=1)

        return log_probs, hidden, attn_weights

    def initHidden(self, batch_size, device):
        """Return an all-zero initial hidden state of size 1 x batch size x hidden size"""
        shape = (1, batch_size, self.hidden_size)
        return torch.zeros(*shape, device=device)

In [None]:
# | export
class SimpleCorrectionSeq2seq(nn.Module):
    """Character-level sequence-to-sequence model (GRU encoder, attention decoder)

    The forward pass returns, for every sample in the batch, the summed log
    probability of the gold standard characters.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        dropout,
        max_length,
        teacher_forcing_ratio,
        device="cpu",
    ):
        """
        Args:
            input_size: size of the input (ocr) vocabulary.
            hidden_size: size of the embeddings and GRU hidden states.
            output_size: size of the output (gs) vocabulary.
            dropout: dropout probability used in the decoder.
            max_length: maximum token length; one position is added to it
                (presumably for the EOS index appended to every sequence).
            teacher_forcing_ratio: probability that a forward pass uses
                teacher forcing.
            device: device on which intermediate tensors are created.
        """
        super().__init__()

        self.device = device

        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.max_length = max_length + 1

        self.encoder = EncoderRNN(input_size, hidden_size)
        self.decoder = AttnDecoderRNN(
            hidden_size, output_size, dropout_p=dropout, max_length=self.max_length
        )

    def forward(self, input, encoder_hidden, target):
        """Encode `input` and score `target`.

        Args:
            input: source indices, src seq len x batch size.
            encoder_hidden: initial encoder hidden state
                (see `EncoderRNN.initHidden`).
            target: target indices, tgt seq len x batch size.

        Returns:
            Tuple `(scores, encoder_outputs)`. `scores` holds one value per
            sample: the sum of the log probabilities of the target characters
            (padding excluded). `encoder_outputs` has size
            batch size x max_length x hidden size.
        """
        # input is src seq len x batch size; the encoder processes one step
        # at a time, so its input must be src seq len x batch size x 1
        input_tensor = input.unsqueeze(2)

        input_length = input.size(0)

        batch_size = input.size(1)

        # Encoder part
        encoder_outputs = torch.zeros(
            batch_size, self.max_length, self.encoder.hidden_size, device=self.device
        )

        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei], encoder_hidden
            )
            # encoder_output is batch size x 1 x hidden size. Select the output
            # of *every* sample; indexing with [0, 0] would copy the output of
            # the first sample to all samples in the batch.
            encoder_outputs[:, ei] = encoder_output[:, 0]

        # Decoder part
        # Target = seq len x batch size
        # The first decoder input is batch size x 1 and contains BOS
        target_length = target.size(0)

        decoder_input = torch.tensor(
            [[BOS_IDX] for _ in range(batch_size)], device=self.device
        )

        decoder_outputs = torch.zeros(
            batch_size, self.max_length, self.decoder.output_size, device=self.device
        )

        decoder_hidden = encoder_hidden

        # The choice is made once per forward pass (i.e., per batch).
        use_teacher_forcing = random.random() < self.teacher_forcing_ratio

        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            if use_teacher_forcing:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target[di, :].unsqueeze(1)
            else:
                # Without teacher forcing: use its own predictions as the next input
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.detach()  # detach from history as input

            decoder_outputs[:, di] = decoder_output

        # Zero out probabilities for padded chars
        target_masks = (target != PAD_IDX).float()

        # Compute log probability of generating true target characters
        target_gold_std_log_prob = torch.gather(
            decoder_outputs, index=target.transpose(0, 1).unsqueeze(-1), dim=-1
        ).squeeze(-1) * target_masks.transpose(0, 1)
        scores = target_gold_std_log_prob.sum(dim=1)

        return scores, encoder_outputs

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 2
hidden_size = 5
dropout = 0.1
max_token_len = 10

model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)
# The parameters and the inputs must live on the same device as the tensors
# the model creates internally (which are created on `device`).
model.to(device)

# Source and target batches: sequence length x batch size. The second sample
# is shorter and is padded with PAD_IDX (1) after EOS_IDX (3).
src = torch.tensor(
    [[6, 4], [22, 30], [0, 6], [18, 4], [11, 3], [3, 1]], device=device
)
encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)

target = torch.tensor(
    [[6, 4], [23, 5], [16, 6], [16, 4], [11, 4], [3, 1]], device=device
)

# One (negative) log probability score per sample in the batch.
losses, _ = model(src, encoder_hidden, target)
losses

tensor([-22.0969, -19.2172], grad_fn=<SumBackward1>)

## Training

In [None]:
# | export
def validate_model(model, dataloader, device):
    """Return the average loss per example of `model` on the batches of `dataloader`

    The model is put in evaluation mode while the loss is calculated. If it
    was in training mode before, it is put back in training mode afterwards.
    """
    total_loss = 0
    num_examples = 0

    restore_train_mode = model.training
    model.eval()

    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)

            # Batches are sequence length x batch size.
            n_samples = src.size(1)
            hidden = model.encoder.initHidden(batch_size=n_samples, device=device)

            # The model returns log probabilities; the loss is their negation.
            scores, _ = model(src, hidden, tgt)
            total_loss += (-scores).sum().item()
            num_examples += n_samples

    if restore_train_mode:
        model.train()

    return total_loss / num_examples

In [None]:
val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=collate_fn(text_transform))

# Average loss per example of the (still untrained) model on the validation set.
loss = validate_model(model, val_dataloader, device)
loss

24.570460001627605

In [None]:
def train_model(
    train_dl,
    val_dl,
    model=None,
    optimizer=None,
    num_epochs=5,
    valid_niter=5000,
    model_save_path="model.rar",
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
    device="cpu",
):
    """Train `model` and save the best checkpoint to `model_save_path`

    Every `valid_niter` iterations, the model is validated on `val_dl`. If the
    validation loss is lower than all previous validation losses, the model
    and optimizer state are saved. Otherwise, patience is increased. When
    patience reaches `max_num_patience`, the best checkpoint is restored and
    the learning rate is multiplied by `lr_decay`. After this has happened
    `max_num_trial` times, training stops early.

    Args:
        train_dl: iterable of `(src, tgt)` training batches.
        val_dl: iterable of `(src, tgt)` validation batches.
        model: model to train (required).
        optimizer: optimizer for the parameters of `model` (required).
        num_epochs: maximum number of passes over `train_dl`.
        valid_niter: number of iterations between validations.
        model_save_path: file the best checkpoint is written to/read from.
        max_num_patience: number of non-improving validations before the
            learning rate is decayed.
        max_num_trial: number of learning rate decays before early stopping.
        lr_decay: factor the learning rate is multiplied with.
        device: device the batches (and restored checkpoints) are moved to.

    Raises:
        ValueError: if `model` or `optimizer` is not provided.
    """
    if model is None or optimizer is None:
        raise ValueError("train_model requires both a model and an optimizer")

    num_iter = 0
    report_loss = 0
    report_examples = 0
    val_loss_hist = []
    num_trial = 0
    patience = 0

    model.train()

    for epoch in range(1, num_epochs + 1):
        for src, tgt in train_dl:
            num_iter += 1

            # Batches are sequence length x batch size.
            batch_size = src.size(1)

            src = src.to(device)
            tgt = tgt.to(device)
            encoder_hidden = model.encoder.initHidden(
                batch_size=batch_size, device=device
            )

            # The model returns log probabilities; the loss is their negation.
            example_losses, _ = model(src, encoder_hidden, tgt)
            example_losses = -example_losses
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            report_loss += batch_loss.item()
            report_examples += batch_size

            optimizer.zero_grad()
            loss.backward()

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

            optimizer.step()

            if num_iter % valid_niter == 0:
                val_loss = validate_model(model, val_dl, device)
                print(
                    f"Epoch {epoch}, iter {num_iter}, avg. train loss {report_loss/report_examples}, avg. val loss {val_loss}"
                )

                report_loss = 0
                report_examples = 0

                better_model = len(val_loss_hist) == 0 or val_loss < min(val_loss_hist)
                if better_model:
                    # NOTE(review): patience is not reset when the model
                    # improves, so it counts all non-improving validations
                    # since the last decay, not only consecutive ones --
                    # confirm that this is intended.
                    print(f"Saving model and optimizer to {model_save_path}")
                    torch.save(
                        {
                            "model_state_dict": model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                        },
                        model_save_path,
                    )
                elif patience < max_num_patience:
                    patience += 1
                    print(f"hit patience {patience}")

                    if patience == max_num_patience:
                        num_trial += 1
                        print(f"hit #{num_trial} trial")
                        if num_trial == max_num_trial:
                            print("early stop!")
                            # Return instead of calling exit(), which would
                            # terminate the interpreter (or notebook kernel).
                            return

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]["lr"] * lr_decay
                        print(
                            f"load previously best model and decay learning rate to {lr}"
                        )

                        # load model; map_location makes sure a checkpoint
                        # can be restored on the device used for training
                        checkpoint = torch.load(model_save_path, map_location=device)
                        model.load_state_dict(checkpoint["model_state_dict"])
                        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

                        model = model.to(device)

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group["lr"] = lr

                        # reset patience
                        patience = 0

                val_loss_hist.append(val_loss)

In [None]:
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)

val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))

hidden_size = 5
# The model, its parameters and the batches must all use the same device, so
# `device` is passed to the model and to `train_model`.
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    0.1,
    10,
    teacher_forcing_ratio=0.0,
    device=device,
)
model.to(device)
optimizer = optim.Adam(model.parameters())

msp = Path(os.getcwd()) / "data" / "model.rar"

train_model(
    train_dl=train_dataloader,
    val_dl=val_dataloader,
    model=model,
    optimizer=optimizer,
    model_save_path=msp,
    num_epochs=2,
    valid_niter=5,
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
    device=device,
)

Epoch 1, iter 5, avg. train loss 24.269323539733886, avg. val loss 25.217035081651474
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 1, iter 10, avg. train loss 24.12647247314453, avg. val loss 25.133172776963974
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 15, avg. train loss 25.964556884765624, avg. val loss 25.080279880099827
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 20, avg. train loss 28.973776626586915, avg. val loss 25.027193705240887
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 25, avg. train loss 29.202961349487303, avg. val loss 24.984923468695747
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar


## Evaluation

In [None]:
# Same file the training example above saved the best checkpoint to.
model_save_path = Path(os.getcwd()) / "data" / "model.rar"

# The checkpoint contains the state of both the model and the optimizer.
checkpoint = torch.load(model_save_path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# Evaluation mode disables dropout.
model.eval()

SimpleCorrectionSeq2seq(
  (encoder): EncoderRNN(
    (embedding): Embedding(46, 5)
    (gru): GRU(5, 5, batch_first=True)
  )
  (decoder): AttnDecoderRNN(
    (embedding): Embedding(44, 5)
    (attn): Linear(in_features=10, out_features=11, bias=True)
    (attn_combine): Linear(in_features=10, out_features=5, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (gru): GRU(5, 5)
    (out): Linear(in_features=5, out_features=44, bias=True)
  )
)

### Predict

https://pytorch.org/tutorials/beginner/chatbot_tutorial.html?highlight=greedy%20decoding

In [None]:
# | export


class GreedySearchDecoder(nn.Module):
    """Greedy decoder: at every step the most probable character is selected

    Uses the encoder and decoder of a trained `SimpleCorrectionSeq2seq` model.
    """

    def __init__(self, model):
        super().__init__()
        self.max_length = model.max_length
        self.encoder = model.encoder
        self.decoder = model.decoder

        self.device = model.device

    def forward(self, input, target):
        """Predict output indices for `input`.

        Args:
            input: source indices, src seq len x batch size.
            target: target indices, tgt seq len x batch size. Only its length
                is used: it determines the number of decoding steps.

        Returns:
            Tensor of size batch size x max_length with the predicted indices.
            Positions after the last decoding step are 0.
        """
        # input is src seq len x batch size; the encoder processes one step
        # at a time, so its input must be src seq len x batch size x 1
        input_tensor = input.unsqueeze(2)

        input_length = input.size(0)

        batch_size = input.size(1)
        encoder_hidden = self.encoder.initHidden(batch_size, self.device)

        # Encoder part
        encoder_outputs = torch.zeros(
            batch_size, self.max_length, self.encoder.hidden_size, device=self.device
        )

        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei], encoder_hidden
            )
            # encoder_output is batch size x 1 x hidden size. Select the output
            # of *every* sample; indexing with [0, 0] would copy the output of
            # the first sample to all samples in the batch.
            encoder_outputs[:, ei] = encoder_output[:, 0]

        # Decoder part
        # Target = seq len x batch size
        # The first decoder input is batch size x 1 and contains BOS
        target_length = target.size(0)

        decoder_input = torch.tensor(
            [[BOS_IDX] for _ in range(batch_size)], device=self.device
        )

        all_tokens = torch.zeros(
            batch_size, self.max_length, device=self.device, dtype=torch.long
        )
        decoder_hidden = encoder_hidden

        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Without teacher forcing: use its own predictions as the next input
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.detach()  # detach from history as input

            # Record token
            all_tokens[:, di] = decoder_input.clone().squeeze(1)

        return all_tokens

In [None]:
decoder = GreedySearchDecoder(model)

max_len = 10

test = SimpleCorrectionDataset(token_data.query('dataset == "test"'), max_len=max_len)
test_dataloader = DataLoader(test, batch_size=5, collate_fn=collate_fn(text_transform))

# Print the predicted indices of the first batch and only the size of the
# predictions for all other batches.
with torch.no_grad():
    for i, (src, tgt) in enumerate(test_dataloader):
        predicted_indices = decoder(src, tgt)
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())

tensor([[19, 40, 40,  9, 40,  9,  9,  6, 40, 40,  0],
        [ 9, 40,  9,  9, 40,  9,  9, 40,  9,  9,  0],
        [ 9,  9, 40,  9,  9, 40,  9,  9, 40,  9,  0],
        [ 9,  9,  9,  6, 40,  9, 40,  9, 40,  9,  0],
        [ 9, 40,  9,  9, 40,  9,  9, 40,  9,  9,  0]])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])


In [None]:
# | export
def indices2string(indices, itos):
    """Convert sequences of indices into strings

    For every sequence in `indices`, the characters are looked up in `itos`
    (index -> character). The unknown, padding and BOS indices are skipped and
    a sequence ends at the first EOS index.
    """
    skipped = (UNK_IDX, PAD_IDX, BOS_IDX)

    words = []
    for idxs in indices:
        chars = []
        for idx in idxs:
            if idx in skipped:
                continue
            if idx == EOS_IDX:
                break
            chars.append(itos[idx])
        words.append("".join(chars))
    return words

In [None]:
# Five sequences of gs vocabulary indices, padded with PAD_IDX (1).
indices = torch.tensor(
    [
        [20, 34, 22, 6, 1, 1, 1, 1, 1, 1],
        [22, 6, 1, 1, 1, 1, 1, 1, 1, 1],
        [21, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [4, 5, 6, 4, 1, 1, 1, 1, 1, 1],
        [29, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    ]
)
indices2string(indices, vocab_transform["gs"].get_itos())

['This', 'is', 'a', 'test', '!']

In [None]:
# | export
def predict_and_convert_to_str(model, dataloader, tgt_vocab, device):
    """Return the greedy predictions of `model` for all samples in `dataloader` as strings

    The model is put in evaluation mode while predicting. If it was in
    training mode before, it is put back in training mode afterwards.
    """
    restore_train_mode = model.training
    model.eval()

    greedy = GreedySearchDecoder(model)
    itos = tgt_vocab.get_itos()

    predictions = []
    with torch.no_grad():
        for src, tgt in tqdm(dataloader):
            batch_indices = greedy(src.to(device), tgt.to(device))
            predictions.extend(indices2string(batch_indices, itos))

    if restore_train_mode:
        model.train()

    return predictions

In [None]:
# Predict a string for every sample in the test set and show the first three.
output_strings = predict_and_convert_to_str(
    model, test_dataloader, vocab_transform["gs"], device
)
output_strings[0:3]

100%|██████████| 7/7 [00:00<00:00, 227.77it/s]


['ISSESEEsSS', 'ESEESEESEE', 'EESEESEESE']

In [None]:
# Apply the same length filter as SimpleCorrectionDataset, so the rows line up
# with the predictions (the test dataloader above does not shuffle).
max_len = 10
test_data = (
    token_data.query('dataset == "test"')
    .query(f"len_ocr <= {max_len}")
    .query(f"len_gs <= {max_len}")
    .copy()
)

test_data["pred"] = output_strings

### Performance measure: mean normalized edit distance

- Mean (normalized) edit distance.
    - Option: ignore -
    - option: ignore case

In [None]:
# Baseline: edit distance between the uncorrected OCR text and the gold standard.
test_data["ed"] = test_data.apply(
    lambda row: edlib.align(row.ocr, row.gs)["editDistance"], axis=1
)
test_data.ed.describe()

count    35.000000
mean      1.971429
std       1.773758
min       1.000000
25%       1.000000
50%       1.000000
75%       2.000000
max       8.000000
Name: ed, dtype: float64

In [None]:
# Baseline: normalized edit distance between the OCR text and the gold standard.
test_data["ed_norm"] = test_data.apply(
    lambda row: normalized_ed(row.ed, row.ocr, row.gs), axis=1
)
test_data.ed_norm.describe()

count    35.000000
mean      0.390952
std       0.326909
min       0.100000
25%       0.125000
50%       0.250000
75%       0.583333
max       1.000000
Name: ed_norm, dtype: float64

In [None]:
# Edit distance between the predicted correction and the gold standard.
test_data["ed_pred"] = test_data.apply(
    lambda row: edlib.align(row.pred, row.gs)["editDistance"], axis=1
)
test_data.ed_pred.describe()

count    35.000000
mean      9.857143
std       1.004193
min       7.000000
25%       9.000000
50%      10.000000
75%      11.000000
max      11.000000
Name: ed_pred, dtype: float64

In [None]:
# Normalized edit distance between the predicted correction and the gold standard.
test_data["ed_norm_pred"] = test_data.apply(
    lambda row: normalized_ed(row.ed_pred, row.pred, row.gs), axis=1
)
test_data.ed_norm_pred.describe()

count    35.000000
mean      0.972092
std       0.061538
min       0.700000
25%       1.000000
50%       1.000000
75%       1.000000
max       1.000000
Name: ed_norm_pred, dtype: float64

In [None]:
# | hide
# Export the cells marked with `# | export` to the module named by `default_exp`.
import nbdev

nbdev.nbdev_export()