In [None]:
# | default_exp bert_vectors_correction_data

In [None]:
# | export
from functools import partial
from pathlib import Path
from typing import List, Optional

import h5py
import numpy as np
import pandas as pd
import torch
from loguru import logger
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from ocrpostcorrection.error_correction import (
    BOS_IDX,
    PAD_IDX,
    SimpleCorrectionSeq2seq,
    generate_vocabs,
    get_text_transform,
    indices2string,
)

In [None]:
# | hide
import os

from datasets import Dataset as HFDataset
from transformers import AutoTokenizer, BertModel, DataCollatorWithPadding

from ocrpostcorrection.utils import set_seed

In [None]:
# Seed used for all random number generators in this notebook
# (set_seed presumably seeds python/numpy/torch — see ocrpostcorrection.utils)
SEED = 23
set_seed(SEED)

In [None]:
# | export
class BertVectorsCorrectionDataset(Dataset):
    """Character-level OCR correction dataset that also yields a BERT vector per sample.

    Each item is a tuple ``(ocr_chars, gs_chars, bert_vector)``. ``ocr_chars`` and
    ``gs_chars`` are lists with the single characters of the OCR text and the gold
    standard text. ``bert_vector`` is either read from an HDF5 file, or is a zero
    vector of size ``hidden_size`` when the vectors are calculated on the fly
    elsewhere.

    Args:
        data: DataFrame with (at least) the columns ``ocr``, ``gs``, ``len_ocr``
            and ``len_gs``.
        split_name: Name of the dataset inside the HDF5 file that holds the
            vectors for ``data`` (e.g. "train", "val" or "test"). Row ``i`` of that
            dataset must belong to the ``i``-th row of ``data``.
        bert_vectors_file: HDF5 file with the BERT vectors. Required when
            ``look_up_bert_vectors`` is True.
        max_len: Samples whose OCR or gold standard length (``len_ocr`` /
            ``len_gs``) is ``max_len`` or more are dropped.
        hidden_size: Size of the zero vector returned when vectors are not looked up.
        look_up_bert_vectors: Whether to read the vectors from ``bert_vectors_file``.

    Raises:
        ValueError: if ``look_up_bert_vectors`` is True but no vectors for
            ``split_name`` are available (no file given, or the split is missing).
    """

    def __init__(
        self,
        data: pd.DataFrame,
        split_name: str,
        bert_vectors_file: Optional[Path] = None,
        max_len: int = 11,
        hidden_size: int = 768,
        look_up_bert_vectors: bool = True,
    ):
        # Positional index 0..len(data)-1, without mutating the caller's frame
        ds = data.reset_index(drop=True)
        ds = ds.query(f"len_ocr < {max_len} and len_gs < {max_len}")
        # drop=False keeps each sample's position in `data` in the "index" column;
        # that position is the row of its vector in the HDF5 dataset
        self.ds = ds.reset_index(drop=False)

        self.hidden_size = hidden_size
        self.look_up_bert_vectors = look_up_bert_vectors

        # Keep the file handle so it can be closed explicitly (see close())
        self.h5_file = None
        self.bert_vectors = None
        if bert_vectors_file:
            self.h5_file = h5py.File(bert_vectors_file, "r")
            # None when the file has no dataset named split_name
            self.bert_vectors = self.h5_file.get(split_name)

        if self.look_up_bert_vectors and self.bert_vectors is None:
            self.close()
            raise ValueError(
                f"No BERT vectors found for split '{split_name}': pass a "
                "bert_vectors_file that contains this split, or set "
                "look_up_bert_vectors=False."
            )

    def close(self) -> None:
        """Close the HDF5 file with the BERT vectors, if one was opened."""
        if self.h5_file is not None:
            self.h5_file.close()
            self.h5_file = None

    def __len__(self):
        return self.ds.shape[0]

    def __getitem__(self, idx):
        sample = self.ds.loc[idx]
        if self.look_up_bert_vectors:
            # "index" is the sample's row position in the original `data`
            bert_vector = torch.as_tensor(np.array(self.bert_vectors[sample["index"]]))
        else:
            # Bert vectors should be calculated on the fly; return a placeholder
            bert_vector = torch.zeros(self.hidden_size)

        return list(sample.ocr), list(sample.gs), bert_vector

The sample BERT vectors were generated with `python src/stages/create-bert-vectors.py --seed 1234 --dataset-in ../ocrpostcorrection/nbs/data/correction/dataset.csv --model-dir models/error-detection/ --model-name bert-base-multilingual-cased --batch-size 1 --out-file ../ocrpostcorrection/nbs/data/correction/bert-vectors.hdf5`, run from the ocrpostcorrection-notebooks repository with the error detection model from commit [9099e78](https://github.com/jvdzwaan/ocrpostcorrection-notebooks/commit/9099e785177a5c5207d01d80422e68d30f39636d). For each split, row *i* in the HDF5 file holds the vector of the *i*-th sample of that split in `dataset.csv`.

In [None]:
data_dir = Path.cwd() / "data" / "correction"
data_csv = data_dir / "dataset.csv"
# Empty ocr/gs fields are read as NaN by pandas; turn them back into empty strings
data = pd.read_csv(data_csv, index_col=0).fillna("")
bert_vectors_file = data_dir / "bert-vectors.hdf5"
split_name = "test"

dataset = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
    max_len=11,
    hidden_size=768,
    look_up_bert_vectors=True,
)

In [None]:
# `data` was already loaded (with NaNs replaced by "") in the previous cell;
# no BERT vector file is needed because the vectors are not looked up.
split_name = "test"

dataset_no_look_up = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=None,
    split_name=split_name,
    max_len=11,
    hidden_size=768,
    look_up_bert_vectors=False,
)

In [None]:
# | export
def collate_fn_with_text_transform(text_transform, batch):
    """Collate ``(ocr_chars, gs_chars, bert_vector)`` samples into batch tensors.

    Meant to be bound to an instantiated ``text_transform`` with
    ``functools.partial`` (see ``collate_fn``).

    Returns:
        ``(src, tgt, encoder_hidden)``. ``src`` and ``tgt`` are int64 tensors of
        shape (max_seq_len, batch_size) padded with ``PAD_IDX``. ``encoder_hidden``
        is the stacked BERT vectors with an extra leading dimension:
        (1, batch_size, hidden_size).
    """
    sources, targets, vectors = [], [], []
    for ocr_chars, gs_chars, vector in batch:
        sources.append(text_transform["ocr"](ocr_chars))
        targets.append(text_transform["gs"](gs_chars))
        vectors.append(vector)

    # pad_sequence without batch_first yields (max_seq_len, batch_size)
    padded_src = pad_sequence(sources, padding_value=PAD_IDX)
    padded_tgt = pad_sequence(targets, padding_value=PAD_IDX)

    # (batch_size, hidden_size) -> (1, batch_size, hidden_size)
    stacked = torch.stack(vectors, dim=0)
    encoder_hidden = stacked.unsqueeze(0)

    return padded_src.to(torch.int64), padded_tgt.to(torch.int64), encoder_hidden


def collate_fn(text_transform):
    """Return a DataLoader ``collate_fn`` that uses ``text_transform``.

    The result is a ``functools.partial`` of ``collate_fn_with_text_transform``
    with ``text_transform`` bound as its first argument. Unlike a nested function,
    a partial can be pickled whenever ``text_transform`` can.
    """
    bound_collate = partial(collate_fn_with_text_transform, text_transform)
    return bound_collate

In [None]:
# | hide
# Can we loop over the entire dataset?
train_data = data.query('dataset == "train"')
vocab_transform = generate_vocabs(train_data)
text_transform = get_text_transform(vocab_transform)

dataloader = DataLoader(dataset, batch_size=5, collate_fn=collate_fn(text_transform))

# Count samples via the batch dimension (dim 1) of the padded source tensors
num_samples = 0
for batch in dataloader:
    print(len(batch))
    src_batch, _tgt_batch, encoder_hidden = batch
    print(encoder_hidden.shape)

    num_samples += src_batch.shape[1]
assert num_samples == len(dataset)

3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])


In [None]:
# | hide
# Can we loop over the entire dataset?
dataloader_no_look_up = DataLoader(
    dataset_no_look_up, batch_size=5, collate_fn=collate_fn(text_transform)
)

num_samples = 0
for batch in dataloader_no_look_up:
    print(len(batch))
    print(batch[2].shape)
    # Without look-up, the dataset yields zero placeholder vectors
    assert not batch[2].any()

    num_samples += batch[0].shape[1]
assert num_samples == len(dataset_no_look_up)

3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])


## Training

In [None]:
# | export
def validate_model(model, dataloader, device):
    """Compute the average per-example loss of ``model`` on ``dataloader``.

    The model is switched to evaluation mode and gradients are disabled. Its
    original training/evaluation mode is restored afterwards, also when an
    exception is raised during evaluation.

    Args:
        model: Model called as ``model(src, encoder_hidden, tgt)``. The first
            returned value holds per-example scores whose negation is the loss
            (presumably log-likelihoods).
        dataloader: Iterable of ``(src, tgt, encoder_hidden)`` batches, as produced
            by ``collate_fn``. ``src`` has the batch on dimension 1.
        device: Device to move the batch tensors to.

    Returns:
        The summed loss divided by the number of examples.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()

    cum_loss = 0.0
    cum_examples = 0
    try:
        with torch.no_grad():
            for src, tgt, encoder_hidden in dataloader:
                src = src.to(device)
                tgt = tgt.to(device)
                encoder_hidden = encoder_hidden.to(device)

                # Sequence-first tensors: the batch is dimension 1
                batch_size = src.size(1)

                example_scores, _decoder_outputs = model(src, encoder_hidden, tgt)
                cum_loss += (-example_scores).sum().item()
                cum_examples += batch_size
    finally:
        if was_training:
            model.train()

    return cum_loss / cum_examples

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Hyperparameters
batch_size = 2
# The BERT vectors are used as the initial encoder hidden state (see the collate
# function), so this presumably has to match their size (768)
hidden_size = 768
dropout = 0.1
max_token_len = 10

model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),  # OCR (source) vocabulary size
    hidden_size,
    len(vocab_transform["gs"]),  # gold standard (target) vocabulary size
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)
model = model.to(device)
model

SimpleCorrectionSeq2seq(
  (encoder): EncoderRNN(
    (embedding): Embedding(46, 768)
    (gru): GRU(768, 768, batch_first=True)
  )
  (decoder): AttnDecoderRNN(
    (embedding): Embedding(44, 768)
    (attn): Linear(in_features=1536, out_features=11, bias=True)
    (attn_combine): Linear(in_features=1536, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (gru): GRU(768, 768)
    (out): Linear(in_features=768, out_features=44, bias=True)
  )
)

In [None]:
# Reuses `data`, `bert_vectors_file` and `text_transform` from the cells above
split_name = "val"

val = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
    max_len=11,
    hidden_size=768,
    look_up_bert_vectors=True,
)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=collate_fn(text_transform))

# Average validation loss per example of the untrained model
loss = validate_model(model, val_dataloader, device)
loss

24.875640021430122

In [None]:
# Inspect the correction data; the columns used above are ocr, gs, len_ocr, len_gs and dataset
data

Unnamed: 0,ocr,gs,ocr_aligned,gs_aligned,start,len_ocr,language,subset,dataset,len_gs,diff
0,test- AAA,test-.AAA,test- AAA,test-.AAA,0,9,fr,fr_sample,train,9,0
1,test-BBB,test- BBB,test@-BBB,test- BBB,10,8,fr,fr_sample,train,9,-1
2,test-CCC,test- CCC,test-@CCC,test- CCC,19,8,fr,fr_sample,train,9,-1
3,-DDD,DDD,-DDD,DDD,33,4,fr,fr_sample,train,3,1
4,test- EEE,test-EEE,test- EEE,test-@EEE,38,9,fr,fr_sample,train,8,1
...,...,...,...,...,...,...,...,...,...,...,...
75,species!,species.,species!,species.,111,8,en,eng_sample,test,8,0
76,Test -hyhen,Testhyhen,Test -hyhen,Test@@hyhen,120,11,en,eng_sample,test,9,2
77,error,errors,error@,errors,137,5,en,eng_sample,test,6,-1
78,C,CCC,C@@,CCC,151,1,en,eng_sample,test,3,-2


In [None]:
# | export
def train_model(
    train_dl: DataLoader[int],
    val_dl: DataLoader[int],
    model: SimpleCorrectionSeq2seq,
    optimizer: torch.optim.Optimizer,
    num_epochs: int = 5,
    valid_niter: int = 5000,
    model_save_path: Path = Path("model.rar"),
    max_num_patience: int = 5,
    max_num_trial: int = 5,
    lr_decay: float = 0.5,
    device: torch.device = torch.device("cpu"),
) -> pd.DataFrame:
    """Train ``model`` with periodic validation, checkpointing and learning rate decay.

    Every ``valid_niter`` iterations the model is validated on ``val_dl``:

    - If the validation loss is the best so far, the model and optimizer state
      are saved to ``model_save_path`` and patience is reset.
    - Otherwise patience is increased. After ``max_num_patience`` validations
      without improvement, a "trial" ends:
      - the best checkpoint is restored;
      - the learning rate is multiplied by ``lr_decay``;
      - patience is reset.
    - Training stops early after ``max_num_trial`` trials.

    Args:
        train_dl: Training batches ``(src, tgt, encoder_hidden)``.
        val_dl: Validation batches, same format as ``train_dl``.
        model: Model called as ``model(src, encoder_hidden, tgt)``; the negation
            of the first returned value is the per-example loss.
        optimizer: Optimizer for ``model``'s parameters.
        num_epochs: Maximum number of passes over ``train_dl``.
        valid_niter: Validate every this many iterations.
        model_save_path: Where the best checkpoint is written.
        max_num_patience: Validations without improvement before a trial ends.
        max_num_trial: Number of trials before stopping early.
        lr_decay: Factor applied to the learning rate at the end of a trial.
        device: Device the batches are moved to (should match ``model``).

    Returns:
        DataFrame with one row per validation. ``train_loss`` is the average
        training loss per example since the previous validation; ``val_loss`` is
        the average validation loss per example.
    """
    num_iter = 0
    report_loss = 0.0
    report_examples = 0
    val_loss_hist: List[float] = []
    train_loss_hist: List[float] = []
    num_trial = 0
    patience = 0

    def _train_log() -> pd.DataFrame:
        # Build the training log from the validation history
        return pd.DataFrame({"train_loss": train_loss_hist, "val_loss": val_loss_hist})

    model.train()

    for epoch in range(1, num_epochs + 1):
        for src, tgt, encoder_hidden in tqdm(train_dl):
            num_iter += 1

            # Sequence-first tensors: the batch is dimension 1
            batch_size = src.size(1)

            src = src.to(device)
            tgt = tgt.to(device)
            encoder_hidden = encoder_hidden.to(device)

            example_scores, _ = model(src, encoder_hidden, tgt)
            batch_loss = (-example_scores).sum()
            loss = batch_loss / batch_size

            report_loss += batch_loss.item()
            report_examples += batch_size

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if num_iter % valid_niter != 0:
                continue

            val_loss = validate_model(model, val_dl, device)
            train_loss = report_loss / report_examples
            logger.info(
                f"Epoch {epoch}, iter {num_iter}, avg. train loss "
                + f"{train_loss}, avg. val loss {val_loss}"
            )

            report_loss = 0.0
            report_examples = 0

            # Compare against earlier validations only, then record this one so
            # the log is complete even when training stops early below
            better_model = len(val_loss_hist) == 0 or val_loss < min(val_loss_hist)
            val_loss_hist.append(val_loss)
            train_loss_hist.append(train_loss)

            if better_model:
                patience = 0
                logger.info(f"Saving model and optimizer to {model_save_path}")
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    model_save_path,
                )
            elif patience < max_num_patience:
                patience += 1
                logger.info(f"Hit patience {patience}")

                if patience == max_num_patience:
                    num_trial += 1
                    logger.info(f"Hit #{num_trial} trial")
                    if num_trial == max_num_trial:
                        logger.info("Early stop!")
                        return _train_log()

                    # decay lr, and restore from previously best checkpoint
                    lr = optimizer.param_groups[0]["lr"] * lr_decay
                    logger.info(
                        f"Load best model so far and decay learning rate to {lr}"
                    )

                    checkpoint = torch.load(model_save_path, map_location=device)
                    model.load_state_dict(checkpoint["model_state_dict"])
                    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                    model = model.to(device)

                    for param_group in optimizer.param_groups:
                        param_group["lr"] = lr

                    patience = 0

    return _train_log()

In [None]:
split_name = "train"
train = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
)
train_dataloader = DataLoader(
    train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)

split_name = "val"
val = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))

hidden_size = 768
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    0.1,  # dropout
    10,  # max_token_len
    teacher_forcing_ratio=0.0,
    device=device,
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())

msp = Path(os.getcwd()) / "data" / "model_bert_vectors.rar"

try:
    train_log = train_model(
        train_dl=train_dataloader,
        val_dl=val_dataloader,
        model=model,
        optimizer=optimizer,
        model_save_path=msp,
        num_epochs=2,
        valid_niter=5,
        max_num_patience=5,
        max_num_trial=5,
        lr_decay=0.5,
        device=device,
    )
finally:
    # Remove the temporary checkpoint, also if training failed or never saved one
    msp.unlink(missing_ok=True)
train_log

 31%|███       | 4/13 [00:00<00:01,  8.61it/s]2023-09-03 19:00:06.994 | INFO     | __main__:train_model:58 - Epoch 1, iter 5, avg. train loss 27.373350143432617, avg. val loss 24.627723693847656
2023-09-03 19:00:06.995 | INFO     | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
 69%|██████▉   | 9/13 [00:01<00:00,  8.23it/s]2023-09-03 19:00:07.644 | INFO     | __main__:train_model:58 - Epoch 1, iter 10, avg. train loss 24.103273010253908, avg. val loss 24.043284098307293
2023-09-03 19:00:07.644 | INFO     | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
100%|██████████| 13/13 [00:01<00:00,  7.17it/s]
  8%|▊         | 1/13 [00:00<00:01,  8.18it/s]2023-09-03 19:00:08.502 | INFO     | __main__:train_model:58 - Epoch 2, iter 15, avg. train loss 19.344550323486327, avg. val loss 19.952612982855904
2023-09-03 19:00:08.502 | INFO     | __

Unnamed: 0,train_loss,val_loss
0,27.37335,24.627724
1,24.103273,24.043284
2,19.34455,19.952613
3,22.214739,18.826036
4,21.808087,18.297656


## Inference / prediction

The greedy decoder below is based on the [PyTorch chatbot tutorial](https://pytorch.org/tutorials/beginner/chatbot_tutorial.html?highlight=greedy%20decoding).

In [None]:
# | export

class GreedySearchDecoder(torch.nn.Module):
    """Greedy decoding with the encoder and decoder of a ``SimpleCorrectionSeq2seq``.

    The encoder and decoder modules of ``model`` are shared, not copied. At every
    decoding step the most likely token is selected and fed back as the next
    decoder input (no teacher forcing).
    """

    def __init__(self, model):
        super().__init__()
        self.max_length = model.max_length
        self.encoder = model.encoder
        self.decoder = model.decoder

        self.device = model.device

    def forward(self, input, encoder_hidden, target):
        """Greedily decode a batch.

        Args:
            input: Source token indices, shape (src_seq_len, batch_size).
            encoder_hidden: Initial encoder hidden state, e.g. the
                (1, batch_size, hidden_size) BERT vectors from ``collate_fn``.
            target: Target token indices, shape (tgt_seq_len, batch_size). Only its
                length is used: it sets the number of decoding steps.

        Returns:
            Long tensor of shape (batch_size, max_length) with the predicted token
            indices. Positions at or after ``tgt_seq_len`` stay 0.
        """
        # The encoder is fed one time step at a time: (src_seq_len, batch_size, 1)
        input_tensor = input.unsqueeze(2)
        input_length = input.size(0)
        batch_size = input.size(1)

        # Encoder: collect the output of every step for the attention decoder
        encoder_outputs = torch.zeros(
            batch_size, self.max_length, self.encoder.hidden_size, device=self.device
        )
        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei], encoder_hidden
            )
            # NOTE(review): if encoder_output is (batch_size, 1, hidden_size) (the
            # encoder GRU is batch_first per the model repr), [0, 0] takes the output
            # of the first batch element only and broadcasts it to the whole batch;
            # [:, 0] would keep one output per element. Left unchanged so decoding
            # presumably matches the training forward of SimpleCorrectionSeq2seq —
            # confirm there before changing.
            encoder_outputs[:, ei] = encoder_output[0, 0]

        # Decoder: every sequence starts with BOS, shape (batch_size, 1)
        decoder_input = torch.tensor(
            [[BOS_IDX] for _ in range(batch_size)], device=self.device
        )

        all_tokens = torch.zeros(
            batch_size, self.max_length, device=self.device, dtype=torch.long
        )
        decoder_hidden = encoder_hidden

        for di in range(target.size(0)):
            decoder_output, decoder_hidden, _attention = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # Without teacher forcing: the most likely token becomes the next input
            _, topi = decoder_output.topk(1)
            decoder_input = topi.detach()  # (batch_size, 1), detached from history

            all_tokens[:, di] = decoder_input.squeeze(1)

        return all_tokens

In [None]:
decoder = GreedySearchDecoder(model)

test_dataloader = DataLoader(dataset, batch_size=5, collate_fn=collate_fn(text_transform))

# Disable dropout so greedy decoding is deterministic
model.eval()
with torch.no_grad():
    for i, (src, tgt, encoder_hidden) in enumerate(test_dataloader):
        # The decoder allocates its tensors on the model's device
        predicted_indices = decoder(
            src.to(device), encoder_hidden.to(device), tgt.to(device)
        )
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())

tensor([[ 4,  5,  6,  4,  8,  3,  3,  3,  3,  3,  0],
        [ 4,  5,  6,  4,  8,  4,  3,  3,  3,  3,  0],
        [ 4,  4,  5,  6,  4,  3,  3,  3,  3,  3,  0],
        [14, 13,  3,  3,  3,  3,  3,  3,  3,  3,  0],
        [ 4,  5,  6,  4,  8,  3,  3,  3,  3,  3,  0]])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])


In [None]:
decoder = GreedySearchDecoder(model)

test_dataloader = DataLoader(
    dataset_no_look_up, batch_size=5, collate_fn=collate_fn(text_transform)
)

# Disable dropout so greedy decoding is deterministic
model.eval()
with torch.no_grad():
    for i, (src, tgt, encoder_hidden) in enumerate(test_dataloader):
        # The decoder allocates its tensors on the model's device
        predicted_indices = decoder(
            src.to(device), encoder_hidden.to(device), tgt.to(device)
        )
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())

tensor([[ 4,  5,  6,  4,  8,  6,  4,  3,  3,  3,  0],
        [ 4,  5,  6,  4,  8,  4,  3,  3,  3,  3,  0],
        [ 4,  4,  5,  6,  4,  3,  4,  3,  3,  3,  0],
        [14, 13,  3,  3,  3,  3,  3,  3,  3,  3,  0],
        [ 4,  5,  6,  4,  8,  3,  3,  3,  3,  3,  0]])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])


In [None]:
# | export
def predict_and_convert_to_str(model, dataloader, bert_model, dataloader_bert_vectors, tgt_vocab, device):
    """Greedily decode corrections and return them as strings.

    BERT vectors are calculated on the fly. For every batch, ``bert_model`` is run
    on the matching batch of ``dataloader_bert_vectors``, and its
    ``pooler_output`` is used as the initial encoder hidden state. The vector
    yielded by ``dataloader`` itself is ignored.

    Args:
        model: Correction model; it is put in eval mode and restored afterwards.
        dataloader: Batches ``(src, tgt, bert_vector)`` from ``collate_fn``.
        bert_model: BERT model producing ``pooler_output``. It is not moved to
            ``device`` or switched to eval mode here; the caller is responsible.
        dataloader_bert_vectors: Tokenized BERT inputs with the same samples,
            order and batch size as ``dataloader``.
        tgt_vocab: Target vocabulary providing ``get_itos()``.
        device: Device to run on.

    Returns:
        List with one predicted string per sample.

    Raises:
        ValueError: if the two dataloaders yield a different number of batches.
    """
    num_batches = len(dataloader)
    if num_batches != len(dataloader_bert_vectors):
        raise ValueError(
            f"dataloader has {num_batches} batches, but dataloader_bert_vectors "
            f"has {len(dataloader_bert_vectors)}; they must line up batch by batch."
        )

    was_training = model.training
    model.eval()

    decoder = GreedySearchDecoder(model)

    itos = tgt_vocab.get_itos()
    output_strings = []

    try:
        with torch.no_grad():
            batches = zip(dataloader, dataloader_bert_vectors)
            for (src, tgt, _bert_vector), bert_vector_input in tqdm(batches, total=num_batches):
                src = src.to(device)
                tgt = tgt.to(device)
                bert_vector_input = bert_vector_input.to(device)

                bert_vector_output = bert_model(**bert_vector_input)
                # (batch_size, hidden_size) -> (1, batch_size, hidden_size), the
                # encoder hidden shape also produced by collate_fn
                encoder_hidden = bert_vector_output["pooler_output"].detach().unsqueeze(0)

                predicted_indices = decoder(src, encoder_hidden, tgt)

                output_strings.extend(indices2string(predicted_indices, itos))
    finally:
        if was_training:
            model.train()

    return output_strings

In [None]:
model_name = "bert-base-multilingual-cased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)
# Run BERT on the same device as the correction model, without dropout
bert_model.to(device)
bert_model.eval()

# The BERT inputs must line up batch by batch with `test_dataloader`:
# same samples, same order (no shuffling) and same batch size
dataset_bert_vectors = HFDataset.from_pandas(
    test_dataloader.dataset.ds.ocr.to_frame(), preserve_index=False
)
tokenized_dataset = dataset_bert_vectors.map(
    lambda sample: tokenizer(sample["ocr"], truncation=True),
    batched=True,
)
tokenized_dataset = tokenized_dataset.remove_columns(["ocr"])

collator = DataCollatorWithPadding(tokenizer)
test_dataloader_bert_vectors = DataLoader(
    tokenized_dataset, batch_size=test_dataloader.batch_size, collate_fn=collator
)

predictions = predict_and_convert_to_str(
    model=model,
    dataloader=test_dataloader,
    bert_model=bert_model,
    dataloader_bert_vectors=test_dataloader_bert_vectors,
    tgt_vocab=vocab_transform["gs"],
    device=device,
)
predictions[:3]

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Map:   0%|          | 0/35 [00:00<?, ? examples/s]

0it [00:00, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
7it [00:00, 11.23it/s]


['test-', 'test-t', 'ttest']

In [None]:
# | hide
from nbdev import nbdev_export

# Write the `# | export` cells of this notebook to the library module
nbdev_export()