In [8]:
from typing import List
import string

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

In [13]:
# Directory of the saved Khmer/Vietnamese tokenizer (HF `from_pretrained` format).
TOKENIZER_DIR = 'tokenizer/khmer-vi-0.4'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

In [14]:
def read_data(path: str) -> List[str]:
    """Read a UTF-8 text file and return its lines.

    Lines are split on '\\n' only. If the file ends with a newline, the single
    empty entry that produces at the end is dropped; any other blank lines are kept.

    Args:
        path: file to read.

    Returns:
        List of lines without their terminating newline.
    """
    with open(path, mode='r', encoding='utf-8') as handle:
        content = handle.read()
    rows = content.split('\n')
    # str.split always returns at least one element, so rows[-1] is safe.
    return rows[:-1] if rows[-1] == '' else rows

In [15]:
def convert_tokens(tokens: List[str]) -> List[str]:
    """Merge SentencePiece pieces back into words.

    - A piece starting with U+2581 ('▁') opens a new word; the marker is removed.
    - A piece made only of ASCII punctuation is kept as its own item.
    - Any other piece is appended to the previous word, or starts the first
      word if there is none yet.
    - An empty piece therefore adds nothing once a word exists.

    Args:
        tokens: pieces as produced by the SentencePiece-style tokenizer.

    Returns:
        A new list of words.
    """
    boundary = '\u2581'  # same character as chr(9601)
    words: List[str] = []
    for token in tokens:
        if token.startswith(boundary):
            words.append(token[1:])
        # A plain `token in string.punctuation` is a substring test: it matches
        # '' and runs like '()' but not '...'. Check every character instead.
        elif token and all(ch in string.punctuation for ch in token):
            words.append(token)
        elif words:
            words[-1] += token
        else:
            words.append(token)
    return words

In [16]:
# The "test" set evaluated in this notebook is the validation split (valid.*).
test_vi = read_data('training_data/valid.vi')
test_km = read_data('training_data/valid.km')
# Line i of each file must be a translation pair: predictions are generated from
# test_vi and scored against test_km by index, so a misalignment would silently
# corrupt BLEU.
assert len(test_vi) == len(test_km), (
    f'misaligned parallel data: {len(test_vi)} vi vs {len(test_km)} km lines'
)

In [17]:
model_path = 'train/2022-08-03_14-21-51/bestLossModel_iter=85999.pt'
# torch.load unpickles the checkpoint, which can execute arbitrary code: only
# load files you produced yourself. NOTE(review): torch>=2.6 defaults to
# weights_only=True, which cannot unpickle a full nn.Module stored under
# 'model'; pass weights_only=False there.
saved_point = torch.load(model_path, map_location=torch.device('cpu'))
model = saved_point['model'].to('cpu')
# The checkpoint presumably comes from a training run (train/.../bestLossModel_iter=...),
# so the module may still be in training mode. Switch Dropout and similar layers
# to inference behaviour before decoding.
model.eval()

In [18]:
@torch.no_grad()  # inference only: skip building an autograd graph at every step
def greedy_decode(net, src_ids: List[int], bos_id: int, eos_id: int, max_len: int) -> torch.Tensor:
    """Greedily decode a single source sentence.

    Args:
        net: seq2seq model exposing ``encode``, ``decode``, ``generator`` and
            ``generate_square_subsequent_mask`` (the interface used below).
        src_ids: source token ids, special tokens already included.
        bos_id: id the target sequence starts with.
        eos_id: id that stops decoding once generated.
        max_len: maximum length of the returned sequence, BOS included.

    Returns:
        LongTensor of shape (1, L) with L <= max_len. It starts with ``bos_id``
        and ends with ``eos_id`` if that token was produced.
    """
    src = torch.tensor([src_ids], dtype=torch.long)
    num_tokens = src.shape[1]
    # All-False mask: nothing in the source is masked (assuming PyTorch's
    # bool-mask convention where True blocks attention).
    src_mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
    memory = net.encode(src, src_mask)
    ys = torch.full((1, 1), bos_id, dtype=torch.long)

    for _ in range(max_len - 1):
        tgt_mask = net.generate_square_subsequent_mask(sz=ys.size(1))
        out = net.decode(ys, memory, tgt_mask)
        # Score only the last position; assumes batch-first decoder output.
        scores = net.generator(out[:, -1])
        next_word = torch.max(scores, dim=1)[1].item()
        ys = torch.cat([ys, torch.full((1, 1), next_word, dtype=torch.long)], dim=1)
        if next_word == eos_id:
            break
    return ys


# Put Dropout etc. in inference mode; idempotent, so it is safe even if the
# loading cell already did this.
model.eval()
max_length = 128
predict_idx = []
for sent in tqdm(test_vi):
    src_ids = tokenizer.encode(sent, add_special_tokens=True)
    predict_idx.append(
        greedy_decode(model, src_ids, tokenizer.bos_token_id, tokenizer.eos_token_id, max_length)
    )

100%|██████████| 2000/2000 [33:09<00:00,  1.01it/s] 


In [19]:
# Tokenize each Khmer reference sentence into pieces, without special tokens.
targets = [tokenizer.tokenize(sent, add_special_tokens=False) for sent in tqdm(test_km)]

100%|██████████| 2000/2000 [00:00<00:00, 3657.28it/s]


In [20]:
def convert_tokens(tokens: List[str]) -> List[str]:
    """Strip the SentencePiece word-boundary marker from each piece.

    NOTE: this redefinition shadows the earlier ``convert_tokens`` that merged
    continuation pieces into whole words. This version keeps one output item
    per input piece, so metrics computed on its output are over subword pieces,
    not words.

    Args:
        tokens: pieces as returned by ``tokenizer.tokenize`` or
            ``tokenizer.convert_ids_to_tokens``.

    Returns:
        A new list of the same length. A single leading U+2581 ('▁') is removed
        from each piece that has one; all other pieces are returned unchanged.
    """
    # Punctuation needs no special case: like any other piece without '▁',
    # it is kept as-is.
    return [tok[1:] if tok.startswith('\u2581') else tok for tok in tokens]

In [21]:
# Turn each generated id sequence (row 0 of a (1, L) tensor) back into pieces,
# dropping BOS/EOS and other special tokens.
total_predicts = [
    tokenizer.convert_ids_to_tokens(ys[0], skip_special_tokens=True)
    for ys in predict_idx
]

In [28]:
# Eyeball one decoded sentence, pieces joined with spaces.
sample_idx = 0
' '.join(total_predicts[sample_idx])

'ក្រៅពីនោះ ក្រុមហ៊ុន គ្រប់គ្រង បំណុល អាក្រក់ ក៏កំពុង ដោះស្រាយ បំណុល អាក្រក់ ដើម្បី ដោះស្រាយ បំណុល អាក្រក់ តាមរយៈ ដំណោះស្រាយ ផ្សេងៗ ដើម្បីដោះស្រាយ បំណុល អាក្រក់ ។'

In [22]:
# Rebuild from predict_idx rather than overwriting total_predicts in place, so
# re-running this cell (or running it out of order) always gives the same result.
total_predicts = [
    convert_tokens(tokenizer.convert_ids_to_tokens(ys[0], skip_special_tokens=True))
    for ys in predict_idx
]

In [23]:
# BLEU takes a list of references per prediction (here exactly one). Rebuild
# from the raw Khmer sentences instead of overwriting `targets` in place:
# re-running the in-place version wrapped the lists a second time, and
# convert_tokens then failed on list elements.
targets = [
    [convert_tokens(tokenizer.tokenize(sent, add_special_tokens=False))]
    for sent in test_km
]

In [24]:
from datasets import load_metric

In [25]:
# NOTE: datasets.load_metric is deprecated (removed in datasets 3.x); the
# maintained equivalent is evaluate.load("bleu").
METRIC_NAME = "bleu"
bleu = load_metric(METRIC_NAME)

In [26]:
# Corpus BLEU over the token lists built above. With the active convert_tokens
# these are subword pieces, so the score is not word-level BLEU.
bleu.compute(references=targets, predictions=total_predicts)

{'bleu': 0.05546354409392583,
 'precisions': [0.30650861285567155,
  0.09154233043624656,
  0.03543575210351721,
  0.01479813414830304],
 'brevity_penalty': 0.8955282735259068,
 'length_ratio': 0.9006238291835659,
 'translation_length': 49519,
 'reference_length': 54983}

In [70]:
# NOTE(review): this duplicates the metric cell above, yet its recorded output
# differs (e.g. reference_length 51024 vs 54983). total_predicts/targets must
# have been changed between runs; re-run from a fresh kernel before trusting either.
metric_kwargs = dict(predictions=total_predicts, references=targets)
bleu.compute(**metric_kwargs)

{'bleu': 0.011966238520574792,
 'precisions': [0.11082148625525723,
  0.01916310111839823,
  0.005569507799019356,
  0.0017335007871508677],
 'brevity_penalty': 1.0,
 'length_ratio': 1.225560520539354,
 'translation_length': 62533,
 'reference_length': 51024}