In [None]:
# | default_exp simple_correction_data

In [None]:
# | export
from functools import partial

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm

from ocrpostcorrection.error_correction import (
    BOS_IDX,
    PAD_IDX,
    generate_vocabs,
    get_text_transform,
    get_tokens_with_OCR_mistakes,
    indices2string,
)

In [None]:
# | hide
import os
from pathlib import Path

from torch.utils.data import DataLoader

from ocrpostcorrection.error_correction import SimpleCorrectionSeq2seq
from ocrpostcorrection.icdar_data import generate_data
from ocrpostcorrection.utils import set_seed

In [None]:
# Fix the seed up front via the project's `set_seed` helper; presumably this
# makes the stochastic steps below reproducible (helper body not shown here).
set_seed(23)

In [None]:
# The sample dataset is resolved relative to the current working directory.
data_dir = Path(os.getcwd()) / "data" / "dataset_training_sample"
data, md = generate_data(data_dir)
# Presumably the file(s) assigned to the "val" split — later cells query `dataset == "val"`.
val_files = ["en/eng_sample/2.txt"]

# NOTE(review): `data` is passed as both of the first two arguments; confirm
# that reusing the same sample for both is intentional.
token_data = get_tokens_with_OCR_mistakes(data, data, val_files)
# Vocabularies are generated from the rows of the "train" split only.
vocab_transform = generate_vocabs(token_data.query('dataset == "train"'))

2it [00:00, 798.15it/s]


In [None]:
# | export
class SimpleCorrectionDataset(Dataset):
    """Dataset yielding (`ocr`, `gs`) token pairs as lists of characters.

    Rows of `data` whose `len_ocr` or `len_gs` exceeds `max_len` are dropped.
    The remaining rows are renumbered from 0 and the previous index is kept as
    a regular column.
    """

    def __init__(self, data, max_len=10):
        within_limit = data.query(f"len_ocr <= {max_len}").query(f"len_gs <= {max_len}")
        # Copy before re-indexing so the frame passed in by the caller is left untouched.
        self.ds = within_limit.copy().reset_index(drop=False)

    def __len__(self):
        """Number of rows that passed the length filter."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the `ocr` and `gs` values of row `idx`, each split into a list."""
        row = self.ds.loc[idx]
        return list(row.ocr), list(row.gs)

To create a `SimpleCorrectionDataset` with a maximum token length of 10, do:

In [None]:
# Rows whose `len_ocr` or `len_gs` exceeds `max_len` are filtered out by the dataset.
dataset = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)

The first sample looks like this:

In [None]:
# A sample is a pair of character lists: (`ocr` characters, `gs` characters).
dataset[0]

(['t', 'e', 's', 't', '-', ' ', 'A', 'A', 'A'],
 ['t', 'e', 's', 't', '-', '.', 'A', 'A', 'A'])

In [None]:
# | export
def collate_fn_with_text_transform(text_transform, batch):
    """Collate data samples into padded batch tensors.

    Intended to be used through `functools.partial` with an instantiated
    `text_transform` bound as the first argument.

    Args:
        text_transform: mapping with the keys "ocr" and "gs"; each value is
            called on one sample and must return a tensor that can be padded.
        batch: iterable of (source sample, target sample) pairs.

    Returns:
        Tuple `(src_batch, tgt_batch)` of `torch.int64` tensors, padded with
        `PAD_IDX` (sequence dimension first, the `pad_sequence` default).
    """
    # Encode source and target of each pair together, in batch order.
    encoded = [
        (text_transform["ocr"](src_sample), text_transform["gs"](tgt_sample))
        for src_sample, tgt_sample in batch
    ]

    src_padded = pad_sequence([src for src, _ in encoded], padding_value=PAD_IDX)
    tgt_padded = pad_sequence([tgt for _, tgt in encoded], padding_value=PAD_IDX)

    return src_padded.to(torch.int64), tgt_padded.to(torch.int64)


def collate_fn(text_transform):
    """Return a collate function with `text_transform` already bound.

    The result is a `functools.partial` around
    `collate_fn_with_text_transform` that only needs the batch as argument.
    """
    bound_collate = partial(collate_fn_with_text_transform, text_transform)
    return bound_collate

In [None]:
# The collate function indexes the result with "ocr" and "gs" to encode each sample.
text_transform = get_text_transform(vocab_transform)

In [None]:
# Batch the training tokens; `collate_fn(text_transform)` encodes and pads every batch.
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=5, collate_fn=collate_fn(text_transform)
)

In [None]:
#| hide
# Can we loop over the entire dataset?
# Dimension 1 of the source batch is the number of samples in that batch, so
# the per-batch counts must add up to the size of the dataset.
num_samples = 0
for batch in train_dataloader:
    num_samples += batch[0].shape[1]
assert num_samples == len(train)

## Training

In [None]:
# | export
def validate_model(model, dataloader, device):
    """Return the average loss per example of `model` over `dataloader`.

    The model is switched to evaluation mode and run without gradient
    tracking. If it was in training mode on entry, training mode is restored
    before returning, also when a batch raises an exception.

    Args:
        model: module exposing `encoder.initHidden(batch_size=..., device=...)`
            and callable as `model(src, encoder_hidden, tgt)`. The first
            element it returns is negated and summed to obtain the batch loss
            (presumably per-example log-likelihoods — confirm against the model).
        dataloader: iterable of `(src, tgt)` tensor pairs; dimension 1 of
            `src` is taken as the batch size.
        device: device the tensors are moved to.

    Returns:
        Summed loss divided by the number of examples seen.

    Raises:
        ZeroDivisionError: if `dataloader` yields no examples.
    """
    total_loss = 0.0
    total_examples = 0

    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for src, tgt in dataloader:
                src = src.to(device)
                tgt = tgt.to(device)

                batch_size = src.size(1)

                encoder_hidden = model.encoder.initHidden(
                    batch_size=batch_size, device=device
                )

                # Only the losses are needed here; the decoder outputs are discarded.
                example_losses, _ = model(src, encoder_hidden, tgt)

                total_loss += (-example_losses).sum().item()
                total_examples += batch_size
    finally:
        # Never leave a model that was training stuck in eval mode.
        if was_training:
            model.train()

    return total_loss / total_examples

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 2
hidden_size = 5
dropout = 0.1
max_token_len = 10

model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)
# Move the parameters to the selected device explicitly: the model is used with
# tensors on `device` below. This is a no-op if they are already there.
model.to(device)

encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)

In [None]:
# Average validation loss per example of the model as it currently stands.
val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=collate_fn(text_transform))

loss = validate_model(model, val_dataloader, device)
loss

25.545663621690537

In [None]:
# | export
def train_model(
    train_dl,
    val_dl,
    model=None,
    optimizer=None,
    num_epochs=5,
    valid_niter=5000,
    model_save_path="model.rar",
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
    device="cpu",
):
    """Train `model` with periodic validation, checkpointing, lr decay and early stopping.

    Every `valid_niter` iterations the model is validated on `val_dl`. When the
    validation loss is the lowest seen so far, the model and optimizer state
    are saved to `model_save_path`. Otherwise a patience counter is increased;
    once it reaches `max_num_patience`, the best checkpoint is restored and the
    learning rate is multiplied by `lr_decay`. After `max_num_trial` such
    trials, training stops early.

    Args:
        train_dl: iterable of `(src, tgt)` training batches; dimension 1 of
            `src` is taken as the batch size.
        val_dl: iterable of validation batches, passed to `validate_model`.
        model: module to train (required).
        optimizer: optimizer over the parameters of `model` (required).
        num_epochs: number of passes over `train_dl`.
        valid_niter: validate every this many iterations.
        model_save_path: where the best checkpoint is written and read back.
        max_num_patience: non-improving validations tolerated before a trial.
        max_num_trial: number of trials after which training stops.
        lr_decay: factor applied to the learning rate at each trial.
        device: device the batches and the restored checkpoint are placed on.

    Returns:
        None, both after the last epoch and on early stopping.

    Raises:
        ValueError: if `model` or `optimizer` is not provided.
    """
    if model is None or optimizer is None:
        raise ValueError("train_model requires both a model and an optimizer")

    num_iter = 0
    report_loss = 0
    report_examples = 0
    val_loss_hist = []
    num_trial = 0
    patience = 0

    model.train()

    for epoch in range(1, num_epochs + 1):
        for src, tgt in train_dl:
            num_iter += 1

            batch_size = src.size(1)

            src = src.to(device)
            tgt = tgt.to(device)
            encoder_hidden = model.encoder.initHidden(
                batch_size=batch_size, device=device
            )

            example_losses, _ = model(src, encoder_hidden, tgt)
            batch_loss = (-example_losses).sum()
            # Optimise the mean loss per example; report the summed loss.
            loss = batch_loss / batch_size

            report_loss += batch_loss.item()
            report_examples += batch_size

            optimizer.zero_grad()
            loss.backward()

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

            optimizer.step()

            if num_iter % valid_niter != 0:
                continue

            val_loss = validate_model(model, val_dl, device)
            print(
                f"Epoch {epoch}, iter {num_iter}, avg. train loss {report_loss/report_examples}, avg. val loss {val_loss}"
            )

            report_loss = 0
            report_examples = 0

            better_model = len(val_loss_hist) == 0 or val_loss < min(val_loss_hist)
            if better_model:
                # NOTE(review): `patience` is not reset when the model improves,
                # so it counts non-improving validations since the last decay,
                # not consecutive ones — confirm this is intended.
                print(f"Saving model and optimizer to {model_save_path}")
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    model_save_path,
                )
            elif patience < max_num_patience:
                patience += 1
                print(f"hit patience {patience}")

                if patience == max_num_patience:
                    num_trial += 1
                    print(f"hit #{num_trial} trial")
                    if num_trial == max_num_trial:
                        print("early stop!")
                        # Return rather than exit(): a library function must not
                        # terminate the interpreter or the notebook kernel.
                        return

                    # decay lr, and restore from previously best checkpoint
                    lr = optimizer.param_groups[0]["lr"] * lr_decay
                    print(
                        f"load previously best model and decay learning rate to {lr}"
                    )

                    # load model; map_location keeps the restored tensors on `device`
                    checkpoint = torch.load(model_save_path, map_location=device)
                    model.load_state_dict(checkpoint["model_state_dict"])
                    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

                    model = model.to(device)

                    # set new lr (the checkpoint restored the old one)
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = lr

                    # reset patience
                    patience = 0

            val_loss_hist.append(val_loss)

In [None]:
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)

val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))

hidden_size = 5
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    0.1,
    10,
    teacher_forcing_ratio=0.0,
    device=device,
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())

msp = Path(os.getcwd()) / "data" / "model.rar"

# Pass `device` explicitly: `train_model` defaults to "cpu", which would put the
# batches on a different device than the model whenever CUDA is available.
train_model(
    train_dl=train_dataloader,
    val_dl=val_dataloader,
    model=model,
    optimizer=optimizer,
    model_save_path=msp,
    num_epochs=2,
    valid_niter=5,
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
    device=device,
)

Epoch 1, iter 5, avg. train loss 25.21373109817505, avg. val loss 25.264954460991753
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 1, iter 10, avg. train loss 27.308312225341798, avg. val loss 25.19587156507704
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 15, avg. train loss 25.64889602661133, avg. val loss 25.134972466362846
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 20, avg. train loss 26.240159034729004, avg. val loss 25.078634050157333
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 25, avg. train loss 22.31423110961914, avg. val loss 25.014130486382378
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar


## Inference / prediction

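The greedy decoder below is adapted from the greedy decoding example in the PyTorch chatbot tutorial: https://pytorch.org/tutorials/beginner/chatbot_tutorial.html?highlight=greedy%20decoding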

In [None]:
# | export


class GreedySearchDecoder(torch.nn.Module):
    """Greedy decoder built from the encoder and decoder of a seq2seq model.

    At every decoding step the highest-scoring token is selected and fed back
    as the next decoder input.
    """

    def __init__(self, model):
        """Take `max_length`, `encoder`, `decoder` and `device` from `model`."""
        super().__init__()
        self.max_length = model.max_length
        self.encoder = model.encoder
        self.decoder = model.decoder

        self.device = model.device

    def forward(self, input, target):
        """Greedily decode a batch.

        Args:
            input: source indices, src seq len x batch size.
            target: target indices, seq len x batch size. Only its length is
                used: it sets the number of decoding steps.

        Returns:
            Long tensor of shape batch size x `max_length` with the selected
            token index per step; positions beyond the number of decoding
            steps stay 0.
        """
        # The encoder input for a single step must be: input seq len x batch size x 1
        input_tensor = input.unsqueeze(2)

        input_length = input.size(0)
        batch_size = input.size(1)
        encoder_hidden = self.encoder.initHidden(batch_size, self.device)

        # Encoder part
        encoder_outputs = torch.zeros(
            batch_size, self.max_length, self.encoder.hidden_size, device=self.device
        )

        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei], encoder_hidden
            )
            # NOTE(review): `encoder_output[0, 0]` is a single vector that is
            # broadcast to every item in the batch. If the encoder output holds
            # one vector per batch item, this may need `encoder_output[0]` or
            # `encoder_output[:, 0]` — confirm against the encoder's output shape.
            encoder_outputs[:, ei] = encoder_output[0, 0]

        # Decoder part
        # Target = seq len x batch size
        # The decoder input must be: batch_size x 1 (holding the first token = BOS)
        target_length = target.size(0)

        decoder_input = torch.tensor(
            [[BOS_IDX] for _ in range(batch_size)], device=self.device
        )

        all_tokens = torch.zeros(
            batch_size, self.max_length, device=self.device, dtype=torch.long
        )
        decoder_hidden = encoder_hidden

        for di in range(target_length):
            decoder_output, decoder_hidden, _ = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Without teacher forcing: use its own predictions as the next input
            _, topi = decoder_output.topk(1)
            decoder_input = topi.detach()  # detach from history as input

            # Record token
            all_tokens[:, di] = decoder_input.squeeze(1)

        return all_tokens

In [None]:
decoder = GreedySearchDecoder(model)

max_len = 10

test = SimpleCorrectionDataset(token_data.query('dataset == "test"'), max_len=max_len)
test_dataloader = DataLoader(test, batch_size=5, collate_fn=collate_fn(text_transform))

with torch.no_grad():
    for i, (src, tgt) in enumerate(test_dataloader):
        # Move each batch to `device`, where the model was placed, before decoding.
        predicted_indices = decoder(src.to(device), tgt.to(device))
        # Show the predictions for the first batch only; sizes for the rest.
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())

torch.Size([1, 5, 5])
tensor([[27, 27, 27, 17,  7, 17,  7,  7, 17, 27,  0],
        [18, 27, 27, 27, 27, 27, 17, 17, 27, 17,  0],
        [18,  3, 18, 27, 17, 26, 27, 27, 27, 27,  0],
        [18, 26, 27, 18, 27, 27, 27, 27, 27, 27,  0],
        [ 6, 27, 27, 27, 27, 17, 17,  7, 17,  7,  0]])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])


In [None]:
# | export
def predict_and_convert_to_str(model, dataloader, tgt_vocab, device):
    """Greedily decode every batch in `dataloader` and return the predictions as strings.

    The model is switched to evaluation mode and run without gradient
    tracking. If it was in training mode on entry, training mode is restored
    before returning, also when decoding raises an exception.

    Args:
        model: seq2seq model wrapped in a `GreedySearchDecoder`.
        dataloader: iterable of `(src, tgt)` tensor pairs.
        tgt_vocab: target vocabulary; its `get_itos()` result is handed to
            `indices2string` to turn indices into strings.
        device: device the batches are moved to.

    Returns:
        List with the strings of all batches, in dataloader order.
    """
    was_training = model.training
    model.eval()

    output_strings = []

    try:
        decoder = GreedySearchDecoder(model)
        itos = tgt_vocab.get_itos()

        with torch.no_grad():
            for src, tgt in tqdm(dataloader):
                src = src.to(device)
                tgt = tgt.to(device)

                predicted_indices = decoder(src, tgt)

                output_strings.extend(indices2string(predicted_indices, itos))
    finally:
        # Never leave a model that was training stuck in eval mode.
        if was_training:
            model.train()

    return output_strings

In [None]:
# Decode the whole test set and show the first three predicted strings.
output_strings = predict_and_convert_to_str(
    model, test_dataloader, vocab_transform["gs"], device
)
output_strings[0:3]

100%|██████████| 7/7 [00:00<00:00, 352.93it/s]

torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])





['mmmmmmmmmm', 'Fmmmmmmmmm', 'Fmmmmmmmmm']

In [None]:
# Re-apply the length filter that `SimpleCorrectionDataset` uses, so that the
# rows line up one-to-one with `output_strings` (the test dataloader is not shuffled).
max_len = 10
test_data = (
    token_data.query('dataset == "test"')
    .query(f"len_ocr <= {max_len}")
    .query(f"len_gs <= {max_len}")
    .copy()
)

test_data["pred"] = output_strings

In [None]:
# | hide
import nbdev

# Run nbdev's export step for this notebook's `# | export` cells.
nbdev.nbdev_export()