In [None]:
# Load dataset
# Load BERT Model
# Sample indices
# For every index:
#   - Get BERT output
#   - Load the stored (precomputed) BERT output from the dataset
#   - Compare

In [None]:
from google.colab import drive 
# Mount Google Drive at /mntDrive (Colab only); the dataset and model paths
# used further down in this notebook are located under /mntDrive/MyDrive.
drive.mount('/mntDrive')

In [None]:
!pip install transformers datasets

## Load data

In [None]:
import pandas as pd
from pathlib import Path

# Directory that holds the task 2 CSV splits (on the mounted Google Drive).
in_dir = Path('/mntDrive/MyDrive/icdar-dataset-20220207')
#in_dir = Path('icdar-dataset-20220207')

# Read each split with its first column as the index, and replace missing
# values by empty strings so that len() can be applied to the text columns.
train = pd.read_csv(in_dir / 'task2_train.csv', index_col=0).fillna('')
val = pd.read_csv(in_dir / 'task2_val.csv', index_col=0).fillna('')
test = pd.read_csv(in_dir / 'task2_test.csv', index_col=0).fillna('')

In [None]:
# Report the number of rows in each split.
print('train:', len(train), 'samples')
print('val:', len(val), 'samples')
print('test:', len(test), 'samples')

In [None]:
def add_lens(data: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``data`` extended with the lengths of its text columns.

    Two columns are added (or overwritten if already present):
    ``len_ocr`` with ``len()`` of every value in ``ocr`` and ``len_gs`` with
    ``len()`` of every value in ``gs``.

    The input frame is left untouched, so re-running a cell that calls this
    function never changes frames that were displayed or used earlier.

    Args:
        data: Frame with an ``ocr`` and a ``gs`` column whose values support
            ``len()``.

    Returns:
        A new DataFrame with the two length columns appended.
    """
    return data.assign(
        len_ocr=data['ocr'].map(len),
        len_gs=data['gs'].map(len),
    )

# Attach the length columns to every split.
train, val, test = map(add_lens, (train, val, test))

In [None]:
import torch
from torch.utils.data import Dataset

class Task2Dataset(Dataset):
    """Short OCR / gold-standard text pairs together with stored task 1 output.

    Only rows whose ``len_ocr`` and ``len_gs`` are both smaller than
    ``max_len`` are kept. The index of ``data`` is preserved in an ``index``
    column and is used to locate the stored output of a row: the row with
    original index ``i`` is expected at position ``i % batch_size`` of the
    tensor saved in ``task2_task1_output_{i // batch_size}.pt`` inside
    ``task1_data_dir``.

    Args:
        data: DataFrame with ``ocr``, ``gs``, ``len_ocr`` and ``len_gs``
            columns.
        task1_data_dir: ``pathlib.Path`` of the directory with the ``.pt``
            files.
        max_len: Exclusive upper bound on the OCR and gold-standard lengths.
        batch_size: Number of rows stored per ``.pt`` file.
    """

    def __init__(self, data, task1_data_dir, max_len=11, batch_size=8):
        # Filter on both lengths; copy so the caller's frame is never modified.
        self.ds = data.query(f'len_ocr < {max_len}').query(f'len_gs < {max_len}').copy()
        # Keep the original index as an 'index' column; rows are now
        # numbered 0..len-1.
        self.ds = self.ds.reset_index(drop=False)

        self.task1_data_dir = task1_data_dir
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of rows that passed the length filter."""
        return self.ds.shape[0]

    def __getitem__(self, idx):
        """Return ``(ocr, gs, task1_output)`` for the row at position ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range. This is what terminates a
                plain ``for sample in dataset`` loop, since the class defines
                no ``__iter__``.
        """
        # Positional lookup: .iloc raises IndexError past the end (label-based
        # .loc would raise KeyError) and also accepts negative positions.
        sample = self.ds.iloc[idx]
        original_idx = sample['index']
        print('original idx', original_idx)

        file_index = original_idx // self.batch_size
        index_in_file = original_idx % self.batch_size
        in_file = self.task1_data_dir/f'task2_task1_output_{file_index}.pt'
        # NOTE: torch.load unpickles the file; only load files from a trusted source.
        task1_output_batch = torch.load(in_file)
        # Copy the task1_output slice, so we have a new tensor
        task1_output = task1_output_batch[index_in_file].clone().detach().requires_grad_(True)

        return sample.ocr, sample.gs, task1_output

In [None]:
from pathlib import Path

# Dataset root on the mounted Google Drive.
out_dir = Path('/mntDrive/MyDrive/icdar-dataset-20220207')

#out_dir = Path('icdar-dataset-20220207')
# The stored task 1 outputs for the test split are read from this directory.
data_dir = out_dir.joinpath('task1_output', 'test')

# NOTE(review): batch_size presumably has to equal the batch size that was
# used when the .pt files were written - confirm.
ds = Task2Dataset(data=test, task1_data_dir=data_dir, max_len=11, batch_size=128)

In [None]:
# Smoke test: fetch only the first sample and show the OCR text, the gold
# standard text and the first three entries of the stored task 1 output.
for ocr, gs, hidden_input in ds:
    print(ocr, gs, hidden_input[:3])
    break

## Load BERT

In [None]:
# Directory on the mounted Google Drive, presumably holding training results.
# NOTE(review): model_dir is not used by any cell visible in this notebook;
# the tokenizer and model are loaded from model_name - confirm this is intended.
model_dir = '/mntDrive/MyDrive/results-0.3-20220207-no-checkpoints'
#model_dir = '/Users/janneke/models/results-0.3-20220207'
# Identifier passed to from_pretrained() for both the tokenizer and the model.
model_name = 'bert-base-multilingual-cased'

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to the checkpoint named by model_name.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
from transformers import BertModel

# Load the BERT weights identified by model_name.
# NOTE(review): if the stored task 1 outputs were produced by a fine-tuned
# model, comparing against these weights is expected to differ - confirm
# which weights should be loaded here.
model = BertModel.from_pretrained(model_name)
# Put the module in evaluation mode; the trailing ';' suppresses the cell output.
model.eval();

## Sample indices

In [None]:
import numpy as np

def get_indices(dataset_size, num=10, seed=None):
    """Draw distinct random positions in ``range(dataset_size)``.

    Args:
        dataset_size: Number of items in the dataset.
        num: Number of positions to draw. When the dataset holds fewer than
            ``num`` items, every position is returned (in random order)
            instead of raising ``ValueError``.
        seed: Optional seed for a reproducible draw. With ``None`` the global
            NumPy random state is used.

    Returns:
        A 1-D NumPy integer array of unique positions.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    # Sampling without replacement cannot return more items than exist.
    return rng.choice(dataset_size, size=min(num, dataset_size), replace=False)

# Draw the dataset positions to check; the bare name on the last line
# displays them as the cell output.
indices = get_indices(len(ds))
indices

In [None]:
# For every sampled position: recompute the BERT pooler output for the OCR
# text and compare it with the output stored in the dataset.
num = 5  # number of leading vector entries to print per sample

with torch.no_grad():
    for idx in indices:
        sample = ds[idx]
        ocr = sample[0]
        tokenized_ocr = tokenizer(ocr, return_tensors="pt")
        # Pass all tokenizer fields to the model, not only the input ids.
        output = model(**tokenized_ocr)
        print(ocr, tokenized_ocr)
        expected = output['pooler_output'][0].cpu().numpy()

        # The dataset returns a tensor with requires_grad=True; detach it
        # before converting to NumPy.
        actual = sample[2].detach().cpu().numpy()

        print('expected:', expected[:num])
        print('actual:', actual[:num])
        print('np allclose', np.allclose(expected, actual))
        # The largest deviation helps to judge a failing allclose.
        print('max abs diff', np.abs(expected - actual).max())