# In [None]:
import spacy
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from preprocess import en_index2word, test

# -Translate- #
# Loaded spaCy pipelines, keyed by model name, so the German model is read
# from disk at most once per process instead of once per translated sentence.
_SPACY_PIPELINES = {}


def _german_pipeline():
    """Return the spaCy 'de_core_news_sm' pipeline, loading it on first use."""
    if 'de_core_news_sm' not in _SPACY_PIPELINES:
        _SPACY_PIPELINES['de_core_news_sm'] = spacy.load('de_core_news_sm')
    return _SPACY_PIPELINES['de_core_news_sm']


def _model_device(model):
    """Return the device of the model's first parameter, or CPU if it has none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Not an nn.Module, or a module without parameters: keep the CPU default.
        return torch.device('cpu')


def translate_sentence(sentence, de_word2index, en_word2index, model, max_len=30, logging=True):
    """Translate one source sentence by greedy decoding.

    Args:
        sentence: Either a raw string (tokenized with the spaCy
            'de_core_news_sm' pipeline) or an iterable of token strings.
            Tokens are lower-cased in both cases.
        de_word2index: Mapping from source token to index; must contain
            '<UNK>', which is used for tokens missing from the mapping.
        en_word2index: Mapping from target token to index; must contain
            '<SOS>' and '<EOS>'.
        model: Object exposing ``eval``, ``make_src_mask``, ``make_trg_mask``,
            ``encoder`` and ``decoder``, called as shown in the body.
        max_len: Maximum number of decoding steps.
        logging: When true, print the source tokens and their indexes.

    Returns:
        A ``(tokens, attention)`` tuple. ``tokens`` are the predicted target
        tokens without the leading '<SOS>'; the end-of-sentence token is
        included when decoding stopped on it. ``attention`` is whatever the
        decoder returned on its last call, or ``None`` when no decoding step
        ran (``max_len <= 0``).

    Raises:
        KeyError: If a required special token is missing from a vocabulary,
            or a predicted index is missing from ``en_index2word``.
    """
    model.eval()

    if isinstance(sentence, str):
        tokens = [token.text.lower() for token in _german_pipeline()(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    tokens = ['<SOS>'] + tokens + ['<EOS>']
    if logging:
        print(f"전체 소스 토큰: {tokens}")

    unk_index = de_word2index['<UNK>']
    src_indexes = [de_word2index.get(token, unk_index) for token in tokens]
    if logging:
        print(f"소스 문장 인덱스: {src_indexes}")

    # Build inputs on the model's device so a model moved off the CPU still works.
    device = _model_device(model)
    eos_index = en_word2index['<EOS>']
    trg_indexes = [en_word2index['<SOS>']]
    # Stays None if the loop below never runs, instead of being left unbound.
    attention = None

    with torch.no_grad():
        # unsqueeze(0) adds a leading dimension of size 1 (a single-sentence batch).
        src_tensor = torch.tensor(src_indexes, dtype=torch.long, device=device).unsqueeze(0)
        src_mask = model.make_src_mask(src_tensor)
        enc_src = model.encoder(src_tensor, src_mask)

        for _ in range(max_len):
            # Feed everything generated so far back into the decoder.
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

            # Arg-max over dim 2 at the last position along dim 1
            # (presumably (batch, trg_len, vocab) -- confirm against the model).
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)

            if pred_token == eos_index:
                break

    trg_tokens = [en_index2word[i] for i in trg_indexes]

    # Drop the leading '<SOS>'.
    return trg_tokens[1:], attention

class CustomTestData:
    """Minimal indexable wrapper around an existing container.

    Indexing and ``len()`` are forwarded unchanged to the wrapped object,
    which stays reachable through the ``data`` attribute.
    """

    def __init__(self, data):
        # Keep a reference to the caller's container; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the length of the wrapped container."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``data[idx]``; any key the container accepts is allowed."""
        return self.data[idx]

# Wrap the 'de' and 'en' entries of the imported `test` object so they can be
# indexed side by side (presumably aligned source/reference pairs -- confirm
# against preprocess).
test_data_de, test_data_en = CustomTestData(test['de']), CustomTestData(test['en'])

def show_bleu(data, de_word2index, en_word2index, model, max_len=50, start=1, stop=1000):
    """Translate a slice of the test set and print its average sentence BLEU-4.

    Sources are read from the module-level ``test_data_de`` and references
    from ``test_data_en``; each source is translated with
    ``translate_sentence`` and scored against its single reference.

    Args:
        data: Only used for ``len(data)`` in the progress message.
        de_word2index: Source vocabulary, forwarded to ``translate_sentence``.
        en_word2index: Target vocabulary, forwarded to ``translate_sentence``.
        model: Translation model, forwarded to ``translate_sentence``.
        max_len: Maximum number of decoding steps per sentence.
        start: Index of the first example to evaluate. The default of 1 keeps
            the historical behaviour of skipping example 0
            (NOTE(review): confirm that skipping it is intended).
        stop: Exclusive upper bound on the example index; clamped to the
            number of available examples.

    Returns:
        None. The score is printed as a percentage.
    """
    smoothing = SmoothingFunction()

    # Never index past the end of either side of the test set.
    last = min(stop, len(test_data_de), len(test_data_en))

    trgs = []
    pred_trgs = []

    for processed, example_idx in enumerate(range(start, last), start=1):
        src = test_data_de[example_idx]
        trg = test_data_en[example_idx]

        pred_trg, _ = translate_sentence(src, de_word2index, en_word2index, model, max_len, logging=False)

        # Remove the end-of-sentence token only when it is actually there;
        # a translation cut off at max_len ends in a real word that must be
        # kept. (Assumes the target EOS token string is '<EOS>'.)
        if pred_trg and pred_trg[-1] == '<EOS>':
            pred_trg = pred_trg[:-1]

        pred_trgs.append(pred_trg)
        # sentence_bleu takes a list of references per hypothesis.
        trgs.append([trg])

        if processed % 100 == 0:
            print(f"[{processed}/{len(data)}]")
            print(f"예측: {pred_trg}")
            print(f"정답: {trg}")

    # Calculate BLEU-4 score (uniform 1- to 4-gram weights, method0 = no smoothing).
    bleu_scores = [
        sentence_bleu(
            references,
            hypothesis,
            weights=(0.25, 0.25, 0.25, 0.25),
            smoothing_function=smoothing.method0,
        )
        for references, hypothesis in zip(trgs, pred_trgs)
    ]

    # An empty slice reports 0 rather than dividing by zero.
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
    print(f'Average BLEU-4 Score = {avg_bleu*100:.2f}')