In [1]:
import os

import torch

from tokenizers import BertWordPieceTokenizer

import sys
# Add the parent directory to sys.path so the project's `models` and `services` packages can be imported
sys.path.append('..')

from models.transformer_model import TransformerModel

from services.data_service import DataService
from services.vocabulary_service import VocabularyService
from services.metrics_service import MetricsService
from services.log_service import LogService
from services.tokenizer_service import TokenizerService
from services.file_service import FileService
from services.pretrained_representations_service import PretrainedRepresentationsService

In [2]:
class ArgumentService:
    """Notebook-local argument provider backed by a fixed dictionary.

    Instances are handed to the project services in this notebook as their
    ``arguments_service``; settings are looked up by name through
    ``get_argument``.
    """

    def __init__(self):
        # Fixed settings for this notebook session.
        self.values = dict(
            device='cuda',
            sentence_piece_vocabulary_size=30522,
            hidden_dimension=32,
            number_of_layers=1,
            number_of_heads=1,
            dropout=0,
            data_folder='../data',
            challenge='ocr',
            configuration='transformer-sequence',
            language='english',
            pretrained_weights='bert-base-cased',
            metric_types=['jaccard-similarity', 'levenshtein-distance'],
            checkpoint_folder=None,
            output_folder='results',
        )

    def get_argument(self, key: str) -> object:
        """Return the setting stored under ``key``.

        Raises:
            KeyError: if ``key`` is not one of the configured settings.
        """
        settings = self.values
        return settings[key]


In [3]:
# Build the service objects used by the rest of the notebook. Everything is
# driven by the single ArgumentService instance so settings live in one place.
arg_service = ArgumentService()

device = arg_service.get_argument('device')

data_service = DataService()
file_service = FileService(arg_service)

vocabulary_service = VocabularyService(
    data_service=data_service,
    file_service=file_service)

metrics_service = MetricsService()

log_service = LogService(
    arguments_service=arg_service,
    external_logging_enabled=False)

tokenizer_service = TokenizerService(
    arguments_service=arg_service,
    file_service=file_service)

# Take the device and pretrained weights from the shared configuration instead
# of repeating the literals, so this service cannot drift from `device` and
# from what the other services read out of `arg_service`.
pretrained_representations_service = PretrainedRepresentationsService(
    include_pretrained=True,
    pretrained_model_size=768,
    pretrained_weights=arg_service.get_argument('pretrained_weights'),
    pretrained_max_length=512,
    device=device)

Loaded vocabulary


In [4]:
# Collect the services the transformer is wired with, build it, then move it
# to the configured device.
model_services = dict(
    arguments_service=arg_service,
    data_service=data_service,
    vocabulary_service=vocabulary_service,
    metrics_service=metrics_service,
    log_service=log_service,
    tokenizer_service=tokenizer_service,
)
model = TransformerModel(**model_services).to(device)

In [5]:
# The file service reports the checkpoint folder as a path of its own; the
# notebook prefixes it with '..' before loading.
relative_checkpoints_path = file_service.get_checkpoints_path()
checkpoints_path = os.path.join('..', relative_checkpoints_path)

# Restore the checkpoint named 'BEST'; the return value is kept in `a`.
a = model.load(checkpoints_path, 'BEST')

Loaded BEST_checkpoint


In [6]:
# WordPiece tokenizer built from the cased BERT vocabulary file; casing is
# preserved (lowercase=False).
vocabularies_dir = os.path.join('..', 'data', 'vocabularies')
vocab_path = os.path.join(vocabularies_dir, 'bert-base-cased-vocab.txt')
tokenizer = BertWordPieceTokenizer(vocab_path, lowercase=False)

# Raw OCR text that is fed to the model as the source sequence.
ocr_text = "bk. I, 70 AN ENGLISH ANTHOLOGY. And earthly power doth then show likest God's When mercy seasons justice. Therefore, Jew, Though justice be thy plea, consider this -That in the course of justice, none of us Should see salvation we do pray for mercy And that same prayer doth teach us all to render The deeds of mercy. 1596. - Merchant of Venice, iv. 1 LVII. THE POWER OF MUSIC. How sweet the moonlight sleeps upon this bank ! Here will we sit, and let the sounds of music Creep in our ears soft stillness and the night Become the touches of sweet harmony. Sit, Jessica. Look how the floor of heaven Is thick inlaid with patines of bright gold There's not the smallest orb which thou behold'st, But in his motion like an angel sings, Still quiring to the young-eyed cherubins Such harmony is in immortal souls But whilst this muddy vesture of decay Doth grossly close it in, we cannot hear it. Enter Musicians. Come, ho ! and wake Diana with a hymn With sweetest touches pierce your mistress' ear. And draw her home with music. yes. I'm never merry when I hear sweet music. Lor. The reason is, your spirits are attentive And, do but note a wild and wanton herd Or race ot youthful and unhandled colts, Fetching mad bounds, bellowing and neighing loud Which is the hot condition of their blood"


In [7]:
# Reference (corrected) transcription corresponding to `ocr_text`.
trg = " earthly power doth then show likest God's When mercy seasons justice. Therefore, Jew, Though justice be thy plea, consider this-That in the course of justice, none of us Should see salvation we do pray for mercy And that same prayer doth teach us all to render The deeds of mercy. 1596. -Merchant of Venice, iv. 1. LVII. THE POWER OF MUSIC. HOW sweet the moonlight sleeps upon this bank ! Here will we sit, and let the sounds of music Creep in our ears soft stillness and the night Become the touches of sweet harmony. Sit, Jessica. Look how the floor of heaven Is thick inlaid with patines of bright gold There's not the smallest orb which thou behold'st, But in his motion like an angel sings, Still quiring to the young-eyed cherubins Such harmony is in immortal souls But whilst this muddy vesture of decay Doth grossly close it in, we cannot hear it. Enter Musicians. Come, ho ! and wake Diana with a hymn With sweetest touches pierce your mistress' ear. And draw her home with music. Yes. I'm never merry when I hear sweet music. Lor. The reason is, your spirits are attentive And, do but note a wild and wanton herd Or race of youthful and unhandled colts, Fetching mad bounds, bellowing and neighing loud, Which is the hot condition of their blood"

# WordPiece ids of the *target* text. This previously encoded `ocr_text`,
# which contradicted the variable name and duplicated the source encoding.
trg_tokens = tokenizer.encode(trg).ids

# Ids of the OCR text and of the target text as produced by the project's
# vocabulary service.
ocr_v = vocabulary_service.string_to_ids(ocr_text)
trg_v = vocabulary_service.string_to_ids(trg)

In [9]:
# Put the model in evaluation mode before running inference.
model.eval()

# Upper bound on the number of decoding steps: the OCR text's character count
# plus a margin of 10.
max_len = 10 + len(ocr_text)

# WordPiece ids of the OCR text, used as the encoder input.
tokens = tokenizer.encode(ocr_text).ids

# Add a leading batch dimension of size 1 and move the ids to the target
# device, then derive the matching source mask from the model.
src_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
src_mask = model.make_src_mask(src_tensor)

def _split_to_chunks(list_to_split: list, chunk_size: int, overlap_size: int):
        result = [list_to_split[i:i+chunk_size]
                  for i in range(0, len(list_to_split), chunk_size-overlap_size)]
        return result
    
def _get_pretrained_representation(ocr_aligned):
    """Return pretrained representations for a sequence of token ids.

    Sequences longer than 512 ids are split into chunks of up to 512 ids that
    overlap by 2; shorter sequences are processed as a single chunk. Each
    chunk is passed to ``pretrained_representations_service`` and its output
    is written into a zero-initialised ``(num_chunks, 512, 768)`` buffer, so
    positions a chunk does not fill stay zero.

    Args:
        ocr_aligned: sequence of token ids.

    Returns:
        The buffer flattened to two dimensions with 768 columns
        (``num_chunks * 512`` rows).

    NOTE(review): overlapping chunks are concatenated as-is, so for inputs
    longer than 512 ids the rows after the first chunk are shifted relative
    to the input positions. Confirm this matches the preprocessing used at
    training time before changing it.
    """
    if len(ocr_aligned) > 512:
        chunks = _split_to_chunks(
            ocr_aligned, chunk_size=512, overlap_size=2)
    else:
        chunks = [ocr_aligned]

    padded_outputs = torch.zeros((len(chunks), 512, 768)).to(device)

    for row, chunk in enumerate(chunks):
        # Batch of one chunk, as integer ids on the target device.
        chunk_ids = torch.Tensor(chunk).unsqueeze(0).long().to(device)
        chunk_output = pretrained_representations_service.get_pretrained_representation(
            chunk_ids)

        # Assumes a 3-D service output with the sequence length on the middle
        # axis - TODO confirm against the service implementation.
        _batch, chunk_length, _hidden = chunk_output.shape
        padded_outputs[row, :chunk_length, :] = chunk_output

    return padded_outputs.view(-1, 768)

# Run the whole inference pass inside a single no_grad scope. Previously the
# pretrained-representation extraction ran with autograd enabled and the
# context was re-entered on every decoding step.
with torch.no_grad():
    # Pretrained representations for the source ids, trimmed to the source
    # length (dropping padded rows) and given a leading batch dimension.
    pretrained_representation = _get_pretrained_representation(tokens)[:len(tokens)].unsqueeze(0)

    enc_src = model.encoder(src_tensor, src_mask, pretrained_representation)

    # Greedy decoding: start from the vocabulary's CLS id and repeatedly
    # append the arg-max prediction for the last position until EOS is
    # produced or `max_len` steps have been taken.
    trg_indexes = [vocabulary_service.cls_token]

    for _ in range(max_len):
        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)

        output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

        # Arg-max over dim 2 (presumably the vocabulary axis), taken at the
        # last decoded position of the single batch element.
        pred_token = output.argmax(2)[:, -1].item()
        trg_indexes.append(pred_token)

        if pred_token == vocabulary_service.eos_token:
            print('EOS token predicted... Breaking...')
            break

# Convert the predicted ids back into text with the vocabulary service.
trg_tokens = vocabulary_service.ids_to_string(trg_indexes)
print(trg_tokens)

[CLS]The ATE AThe The The AThe AThe AThe ATE ATE AThe TE The ATE AThe The te The TE The The The  ATE The The AThe  The AThe The The The  Me  AThe The TE TE ATE AThe Mre TE The AThe ATE  Mre  The AThe The The AThe ATE te AThe A ATE  The AThe ATE AThe The The ATE The  AThe Aan he he he he he he he he ore he he he he he f he te he oure f he he he he te he he he he he he te he he he or he he ore he he he te he he he f te te he he he he te he he he he he te he he he he he f he f he pre he he he te he he he he he he he he he he he he he f he he he te ore he oure he he te he he he he ore he or he te he he he he te f he he he he an he he he ore on he he he he ore f he he he he he he he he w he f he te he ore he he te he f he he he ore he f he he te he he he ore he on he he he he oure he he he he he f he he pre he he f he he he he he he he te f he he he he he he he he he f he he te an te te f he te te te te f he te w te te te te w w te te he he te wan te w he te te te he te te f he te w te he t