In [None]:
# | default_exp bert_vectors_correction_data

In [None]:
# | export
import sys
from functools import partial
from pathlib import Path
from typing import List

import h5py
import numpy as np
import pandas as pd
import torch
from loguru import logger
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from ocrpostcorrection.error_correction import (
    PAD_IDX,
    SimpleCorrectionSeq2seq,
    generate_vocabs,
    get_text_transform,
)

In [None]:
# | hide
import os

from ocrpostcorrection.utils import set_seed

In [None]:
# Fixed seed handed to the project's `set_seed` helper (presumably seeds the
# random number generators used below -- see ocrpostcorrection.utils).
SEED = 23
set_seed(SEED)

In [None]:
# | export
class BertVectorsCorrectionDataset(Dataset):
    """Correction samples paired with precomputed vectors stored in an HDF5 file.

    Every item is a tuple ``(ocr_chars, gs_chars, bert_vector)``: the ``ocr`` and
    ``gs`` strings of a sample as lists of characters, plus the vector stored for
    that sample, converted to a tensor.

    Args:
        data: Samples of one split. Must contain the columns ``ocr``, ``gs``,
            ``len_ocr`` and ``len_gs``. The frame itself is not modified.
        bert_vectors_file: Path of the HDF5 file that contains the vectors.
        split_name: Name of the HDF5 dataset to read from. Row ``i`` of that
            dataset is used for the ``i``-th row of ``data`` (counted before
            the length filter is applied).
        max_len: Samples whose ``len_ocr`` or ``len_gs`` is not strictly smaller
            than this value are dropped.

    Raises:
        KeyError: If there are samples to serve but the file has no entry named
            ``split_name``.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        bert_vectors_file: Path,
        split_name: str,
        max_len: int = 11,
    ):
        # Number the rows 0..n-1 *before* filtering: the "index" column created
        # by the second reset_index then still points at the matching row of
        # the HDF5 dataset.
        ds = data.reset_index(drop=True)
        ds = ds.query(f"len_ocr < {max_len}").query(f"len_gs < {max_len}")
        self.ds = ds.reset_index(drop=False)

        # Vectors are read on demand in __getitem__, so the file has to stay
        # open; keep the handle so that it can be released with close().
        self._file = h5py.File(bert_vectors_file, "r")
        self.bert_vectors = self._file.get(split_name)
        if self.bert_vectors is None and len(self.ds) > 0:
            self._file.close()
            self._file = None
            raise KeyError(
                f"No dataset named '{split_name}' found in {bert_vectors_file}"
            )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds.loc[idx]
        # Position of the sample before length filtering == row in the HDF5 data
        original_idx = int(sample["index"])
        bert_vector = torch.as_tensor(np.array(self.bert_vectors[original_idx]))

        return list(sample.ocr), list(sample.gs), bert_vector

    def close(self) -> None:
        """Close the HDF5 file; items can no longer be read afterwards."""
        file = getattr(self, "_file", None)
        if file is not None:
            file.close()
            self._file = None

The sample BERT vectors were generated by running `python src/stages/create-bert-vectors.py --seed 1234 --dataset-in ../ocrpostcorrection/nbs/data/correction/dataset.csv --model-dir models/error-detection/ --model-name bert-base-multilingual-cased --batch-size 1 --out-file ../ocrpostcorrection/nbs/data/correction/bert-vectors.hdf5` in the ocrpostcorrection-notebooks repository, using the model from commit [9099e78](https://github.com/jvdzwaan/ocrpostcorrection-notebooks/commit/9099e785177a5c5207d01d80422e68d30f39636d).

In [None]:
# Load the sample correction data and wrap its "test" split in a dataset
split_name = "test"
correction_dir = Path(os.getcwd()) / "data" / "correction"
data_csv = correction_dir / "dataset.csv"
bert_vectors_file = correction_dir / "bert-vectors.hdf5"

data = pd.read_csv(data_csv, index_col=0)
test_samples = data.query(f"dataset == '{split_name}'")

dataset = BertVectorsCorrectionDataset(
    data=test_samples, bert_vectors_file=bert_vectors_file, split_name=split_name
)

In [None]:
# | export
def collate_fn_with_text_transform(text_transform, batch):
    """Collate data samples into batch tensors.

    Meant to be used as a partial with an instantiated ``text_transform``, a
    mapping with the keys ``"ocr"`` and ``"gs"`` whose values turn a sample
    into a tensor. Returns the padded source batch, the padded target batch
    (both as int64) and the stacked vectors with a leading dimension of size 1.
    """
    converted = [
        (text_transform["ocr"](ocr_chars), text_transform["gs"](gs_chars), vector)
        for ocr_chars, gs_chars, vector in batch
    ]

    padded_src = pad_sequence([item[0] for item in converted], padding_value=PAD_IDX)
    padded_tgt = pad_sequence([item[1] for item in converted], padding_value=PAD_IDX)

    # encoder_hidden must have size 1 x batch_size x hidden_size
    stacked_vectors = torch.stack([item[2] for item in converted], dim=0)
    encoder_hidden = stacked_vectors.unsqueeze(0)

    return padded_src.to(torch.int64), padded_tgt.to(torch.int64), encoder_hidden


def collate_fn(text_transform):
    """Return a collate function for batches that has ``text_transform`` bound to it"""
    bound_collator = partial(collate_fn_with_text_transform, text_transform)
    return bound_collator

In [None]:
# | hide
# Can we loop over the entire dataset?
split_name = "test"
correction_dir = Path(os.getcwd()) / "data" / "correction"
data_csv = correction_dir / "dataset.csv"
bert_vectors_file = correction_dir / "bert-vectors.hdf5"

# Missing values are replaced by empty strings before the vocabularies are built
data = pd.read_csv(data_csv, index_col=0)
data = data.fillna("")

vocab_transform = generate_vocabs(data.query('dataset == "test"'))
text_transform = get_text_transform(vocab_transform)

test_samples = data.query(f"dataset == '{split_name}'")
dataset = BertVectorsCorrectionDataset(
    data=test_samples, bert_vectors_file=bert_vectors_file, split_name=split_name
)
dataloader = DataLoader(dataset, batch_size=5, collate_fn=collate_fn(text_transform))

# Every sample has to show up in exactly one batch
batch_sizes = []
for batch in dataloader:
    src_batch, _tgt_batch, encoder_hidden = batch
    print(len(batch))
    print(encoder_hidden.shape)

    # Dimension 1 of the padded source batch is the batch dimension
    batch_sizes.append(src_batch.shape[1])
num_samples = sum(batch_sizes)
assert num_samples == len(dataset)

3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])
3
torch.Size([1, 5, 768])


## Training

In [None]:
# | export
def validate_model(model, dataloader, device):
    """Compute the average loss per example of ``model`` on ``dataloader``.

    The model is switched to eval mode and run without gradient tracking. When
    it was in training mode on entry, training mode is restored afterwards,
    also if an exception is raised while validating.

    Args:
        model: Module called as ``model(src, encoder_hidden, tgt)``. The first
            element of its return value is negated and summed to obtain the
            loss of a batch.
        dataloader: Iterable yielding ``(src, tgt, encoder_hidden)`` tensors;
            the batch size is read from dimension 1 of ``src``.
        device: Device the tensors of every batch are moved to.

    Returns:
        The summed loss divided by the number of examples seen.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no examples.
    """
    cum_loss = 0
    cum_examples = 0

    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for src, tgt, encoder_hidden in dataloader:
                src = src.to(device)
                tgt = tgt.to(device)
                encoder_hidden = encoder_hidden.to(device)

                example_losses, _decoder_outputs = model(src, encoder_hidden, tgt)

                # The model's first output is negated to get the loss
                cum_loss += (-example_losses).sum().item()
                cum_examples += src.size(1)
    finally:
        # Do not leave the caller's model stuck in eval mode, not even on errors
        if was_training:
            model.train()

    return cum_loss / cum_examples

In [None]:
# Settings for the demo model; hidden_size equals the size of the vectors
# printed above (768)
hidden_size = 768
dropout = 0.1
max_token_len = 10
batch_size = 2

# Use a CUDA device when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

ocr_vocab_size = len(vocab_transform["ocr"])
gs_vocab_size = len(vocab_transform["gs"])

model = SimpleCorrectionSeq2seq(
    ocr_vocab_size,
    hidden_size,
    gs_vocab_size,
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)

In [None]:
# Validation loss of the model created above (not trained yet) on the "val" split
split_name = "val"
val_samples = data.query(f"dataset == '{split_name}'")

val = BertVectorsCorrectionDataset(
    data=val_samples, bert_vectors_file=bert_vectors_file, split_name=split_name
)
val_collator = collate_fn(text_transform)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=val_collator)

loss = validate_model(model, val_dataloader, device)
loss

24.723804897732204

In [None]:
# | export
def train_model(
    train_dl: DataLoader[int],
    val_dl: DataLoader[int],
    model: SimpleCorrectionSeq2seq,
    optimizer: torch.optim.Optimizer,
    num_epochs: int = 5,
    valid_niter: int = 5000,
    model_save_path: Path = Path("model.rar"),
    max_num_patience: int = 5,
    max_num_trial: int = 5,
    lr_decay: float = 0.5,
    device: torch.device = torch.device("cpu"),
) -> pd.DataFrame:
    """Train ``model`` with periodic validation, checkpointing and lr decay.

    Every ``valid_niter`` iterations the model is validated on ``val_dl``:

    * If the validation loss is lower than all earlier ones, the model and
      optimizer state dicts are saved to ``model_save_path`` and the patience
      counter is reset.
    * Otherwise the patience counter is increased. When it reaches
      ``max_num_patience`` a "trial" is counted: the best checkpoint is loaded
      again, the learning rate is multiplied by ``lr_decay`` and the patience
      counter is reset.
    * When ``max_num_trial`` trials have been used, training stops early and
      the log collected so far is returned.

    Args:
        train_dl: Yields ``(src, tgt, encoder_hidden)`` training batches; the
            batch size is read from dimension 1 of ``src``.
        val_dl: Batches passed to ``validate_model``.
        model: Called as ``model(src, encoder_hidden, tgt)``; the first element
            of the result is negated and summed to obtain the batch loss.
        optimizer: Optimizer that updates the parameters of ``model``.
        num_epochs: Number of passes over ``train_dl``.
        valid_niter: Number of iterations between two validation rounds.
        model_save_path: File the best checkpoint is written to.
        max_num_patience: Validation rounds without improvement before a trial
            is counted.
        max_num_trial: Number of trials after which training stops early.
        lr_decay: Factor applied to the learning rate on every trial.
        device: Device the batches (and a reloaded checkpoint) are moved to.
            The model itself is expected to be on this device already.

    Returns:
        A data frame with one row per validation round and the columns
        ``train_loss`` (average over the iterations since the previous round)
        and ``val_loss``. It is empty if no validation round was reached.
    """
    num_iter = 0
    report_loss = 0
    report_examples = 0
    val_loss_hist: List[float] = []
    train_loss_hist: List[float] = []
    num_trial = 0
    patience = 0

    model.train()

    for epoch in range(1, num_epochs + 1):
        for src, tgt, encoder_hidden in tqdm(train_dl):
            num_iter += 1

            batch_size = src.size(1)

            src = src.to(device)
            tgt = tgt.to(device)
            encoder_hidden = encoder_hidden.to(device)

            example_losses, _ = model(src, encoder_hidden, tgt)
            example_losses = -example_losses
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            report_loss += batch_loss.item()
            report_examples += batch_size

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            if num_iter % valid_niter != 0:
                continue

            val_loss = validate_model(model, val_dl, device)
            train_loss = report_loss / report_examples
            logger.info(
                f"Epoch {epoch}, iter {num_iter}, avg. train loss "
                + f"{train_loss}, avg. val loss {val_loss}"
            )

            report_loss = 0
            report_examples = 0

            # Compare against the earlier rounds only, then record this round
            # right away so that it is also part of the log on an early stop.
            better_model = len(val_loss_hist) == 0 or val_loss < min(val_loss_hist)
            val_loss_hist.append(val_loss)
            train_loss_hist.append(train_loss)

            if better_model:
                # Patience counts rounds without improvement since the last
                # best model, so an improvement starts the count again.
                patience = 0
                logger.info(f"Saving model and optimizer to {model_save_path}")
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    model_save_path,
                )
            elif patience < max_num_patience:
                patience += 1
                logger.info(f"Hit patience {patience}")

                if patience == max_num_patience:
                    num_trial += 1
                    logger.info(f"Hit #{num_trial} trial")
                    if num_trial == max_num_trial:
                        logger.info("Early stop!")
                        # Hand the log back to the caller instead of ending
                        # the whole Python process.
                        return pd.DataFrame(
                            {"train_loss": train_loss_hist, "val_loss": val_loss_hist}
                        )

                    # decay lr, and restore from previously best checkpoint
                    lr = optimizer.param_groups[0]["lr"] * lr_decay
                    logger.info(
                        f"Load best model so far and decay learning rate to {lr}"
                    )

                    # load model; map_location keeps the tensors loadable when
                    # the checkpoint was written from another device
                    checkpoint = torch.load(model_save_path, map_location=device)
                    model.load_state_dict(checkpoint["model_state_dict"])
                    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

                    model = model.to(device)

                    # set new lr (loading the optimizer state restored the old one)
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = lr

                    # reset patience
                    patience = 0

    # Create train log
    df = pd.DataFrame({"train_loss": train_loss_hist, "val_loss": val_loss_hist})
    return df

In [None]:
# Train for two epochs on the sample data. The same `device` has to be used for
# the model, its internal tensors and the batches; mixing an MPS model with CPU
# batches fails with "Placeholder storage has not been allocated on MPS device!".
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
split_name = "train"
train = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
)
train_dataloader = DataLoader(
    train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)

split_name = "val"
val = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))

hidden_size = 768
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    0.1,  # dropout
    10,  # max_token_len
    teacher_forcing_ratio=0.0,
    device=device,
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())

msp = Path(os.getcwd()) / "data" / "model_bert_vectors.rar"

train_log = train_model(
    train_dl=train_dataloader,
    val_dl=val_dataloader,
    model=model,
    optimizer=optimizer,
    model_save_path=msp,
    num_epochs=2,
    valid_niter=5,
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
    device=device,
)
train_log

  0%|          | 0/13 [00:00<?, ?it/s]


RuntimeError: Placeholder storage has not been allocated on MPS device!

In [None]:
# | hide
from nbdev import nbdev_export

# Run nbdev's export step for this project
nbdev_export()