In [None]:
# | default_exp bert_vectors_correction_data

In [None]:
# | export
from functools import partial
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from ocrpostcorrection.error_correction import (
    PAD_IDX,
    generate_vocabs,
    get_text_transform,
)

In [None]:
# | hide
import os
from pathlib import Path

from torch.utils.data import DataLoader

In [None]:
#| export
class BertVectorsCorrectionDataset(Dataset):
    """Character-level OCR correction samples paired with precomputed BERT vectors.

    Each item is a tuple ``(ocr_chars, gs_chars, bert_vector)``. The OCR text and
    the gold-standard text are returned as lists of characters. The BERT vector is
    read from the HDF5 dataset named ``split_name`` and returned as a tensor.

    Row *i* of ``data`` (counted after its index is reset) is assumed to
    correspond to vector *i* of the split in the HDF5 file. Rows whose ``len_ocr``
    or ``len_gs`` is ``max_len`` or larger are skipped. The position of each kept
    row is recorded in the ``index`` column of ``self.ds``.

    The HDF5 file is opened when the dataset is created, so a missing file or
    split fails early. Open file handles are not pickled with the dataset;
    pickled copies (e.g. sent to ``DataLoader`` worker processes) reopen the file
    on first item access. Call :meth:`close` to release the file handle.

    Args:
        data: frame with at least the columns ``ocr``, ``gs``, ``len_ocr`` and
            ``len_gs``.
        bert_vectors_file: path to the HDF5 file with the BERT vectors.
        split_name: name of the HDF5 dataset holding the vectors for ``data``.
        max_len: exclusive upper bound on ``len_ocr`` and ``len_gs``.

    Raises:
        KeyError: if ``split_name`` is not present in ``bert_vectors_file``.
    """

    def __init__(self, data: pd.DataFrame, bert_vectors_file: Path, split_name: str, max_len: int=11):
        self.ds = self._select_rows(data, max_len)

        self.bert_vectors_file = bert_vectors_file
        self.split_name = split_name
        self._h5_file = None
        self.bert_vectors = None
        self._open()

    @staticmethod
    def _select_rows(data: pd.DataFrame, max_len: int) -> pd.DataFrame:
        """Keep rows shorter than ``max_len``; store each row's original position in ``index``."""
        # Positions within ``data`` are the row numbers used to look up vectors in the HDF5 split.
        positional = data.reset_index(drop=True)
        short = positional.query(f"len_ocr < {max_len}").query(f"len_gs < {max_len}")
        return short.reset_index(drop=False)

    def _open(self) -> None:
        """Open the HDF5 file and look up the vectors of ``self.split_name``."""
        h5_file = h5py.File(self.bert_vectors_file, "r")
        vectors = h5_file.get(self.split_name)
        if vectors is None:
            h5_file.close()
            raise KeyError(
                f"Split '{self.split_name}' not found in {self.bert_vectors_file}"
            )
        self._h5_file = h5_file
        self.bert_vectors = vectors

    def close(self) -> None:
        """Close the HDF5 file if it is open; it is reopened on the next item access."""
        h5_file = getattr(self, "_h5_file", None)
        if h5_file is not None:
            h5_file.close()
        self._h5_file = None
        self.bert_vectors = None

    def __getstate__(self):
        # h5py objects cannot be pickled (needed when DataLoader workers are spawned),
        # so drop the handles; the copy reopens the file lazily in __getitem__.
        state = self.__dict__.copy()
        state["_h5_file"] = None
        state["bert_vectors"] = None
        return state

    def __len__(self):
        return self.ds.shape[0]

    @staticmethod
    def _to_chars(text) -> list:
        """Split ``text`` into characters; missing values (None/NaN) become an empty list."""
        # pandas reads empty CSV fields as NaN, which is a float and not iterable.
        if text is None or (isinstance(text, float) and np.isnan(text)):
            return []
        return list(text)

    def __getitem__(self, idx):
        if self.bert_vectors is None:
            self._open()

        sample = self.ds.iloc[idx]
        original_idx = int(sample["index"])
        bert_vector = torch.as_tensor(np.array(self.bert_vectors[original_idx]))

        return self._to_chars(sample.ocr), self._to_chars(sample.gs), bert_vector

The sample BERT vectors (`data/correction/bert-vectors.hdf5`) were generated by running the following command from the [ocrpostcorrection-notebooks](https://github.com/jvdzwaan/ocrpostcorrection-notebooks) repository, using the error detection model from commit [9099e78](https://github.com/jvdzwaan/ocrpostcorrection-notebooks/commit/9099e785177a5c5207d01d80422e68d30f39636d):

```bash
python src/stages/create-bert-vectors.py --seed 1234 --dataset-in ../ocrpostcorrection/nbs/data/correction/dataset.csv --model-dir models/error-detection/ --model-name bert-base-multilingual-cased --batch-size 1 --out-file ../ocrpostcorrection/nbs/data/correction/bert-vectors.hdf5
```

In [None]:
# Wrap the test split of the sample data in a BertVectorsCorrectionDataset.
# pandas reads empty OCR/GS strings as NaN, so replace them with "" before building samples.
data_dir = Path.cwd() / "data" / "correction"
data = pd.read_csv(data_dir / "dataset.csv", index_col=0).fillna("")
bert_vectors_file = data_dir / "bert-vectors.hdf5"
split_name = "test"

dataset = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
)
len(dataset)

In [None]:
# | export
def collate_fn_with_text_transform(text_transform, batch):
    """Turn a list of ``(ocr_chars, gs_chars, bert_vector)`` samples into batch tensors.

    Intended to be bound to an instantiated text transform with ``functools.partial``
    (as ``collate_fn`` does) and passed to a ``DataLoader``.

    Args:
        text_transform: mapping with ``"ocr"`` and ``"gs"`` entries, each a callable
            that turns a list of characters into a tensor (presumably 1-D token ids).
        batch: iterable of ``(ocr_chars, gs_chars, bert_vector)`` tuples.

    Returns:
        ``(src, tgt, bert_vectors)``: the OCR and gold-standard tensors padded with
        ``PAD_IDX`` along a new batch dimension 1 (``pad_sequence`` default) and cast
        to int64, plus the BERT vectors stacked along dimension 1.
    """
    ocr_tensors, gs_tensors, vectors = [], [], []
    for ocr_chars, gs_chars, vector in batch:
        ocr_tensors.append(text_transform["ocr"](ocr_chars))
        gs_tensors.append(text_transform["gs"](gs_chars))
        vectors.append(vector)

    padded_src = pad_sequence(ocr_tensors, padding_value=PAD_IDX)
    padded_tgt = pad_sequence(gs_tensors, padding_value=PAD_IDX)
    src_ids = padded_src.to(torch.int64)
    tgt_ids = padded_tgt.to(torch.int64)
    return src_ids, tgt_ids, torch.stack(vectors, dim=1)


def collate_fn(text_transform):
    """Build a ``DataLoader`` collate function bound to ``text_transform``.

    Returns ``collate_fn_with_text_transform`` with ``text_transform`` bound as its
    first argument via ``functools.partial``. Unlike a closure, the result can be
    pickled as long as ``text_transform`` can.
    """
    bound_collate = partial(collate_fn_with_text_transform, text_transform)
    return bound_collate

In [None]:
#| hide
# Can we loop over the entire dataset?
data_dir = Path.cwd() / "data" / "correction"
split_name = "test"
# pandas reads empty OCR/GS strings as NaN; replace them with "" (without mutating in place).
data = pd.read_csv(data_dir / "dataset.csv", index_col=0).fillna("")
split_data = data.query(f"dataset == '{split_name}'")
bert_vectors_file = data_dir / "bert-vectors.hdf5"

# Build the vocabularies from the same split that is loaded below.
vocab_transform = generate_vocabs(split_data)
text_transform = get_text_transform(vocab_transform)

dataset = BertVectorsCorrectionDataset(
    data=split_data,
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
)
dataloader = DataLoader(
    dataset, batch_size=5, collate_fn=collate_fn(text_transform)
)

num_samples = 0
for src, tgt, bert_vectors in dataloader:
    # The collate function puts the batch on dimension 1 for all three tensors.
    batch_size = src.shape[1]
    assert tgt.shape[1] == batch_size
    assert bert_vectors.shape[1] == batch_size
    num_samples += batch_size
assert num_samples == len(dataset)

In [None]:
# | hide
from nbdev import nbdev_export

# Write the `# | export` cells of this notebook to the library module.
nbdev_export()