In [1]:
# import modules
import json
import torch
import tokenization

import numpy as np

from transformer.model import Translation

In [2]:
# read configuration file
config = json.load(open('config.json'))
config

{'kor_vocab_length': 50000,
 'eng_vocab_length': 28998,
 'd_model': 768,
 'd_ff': 2048,
 'd_k': 64,
 'd_v': 64,
 'num_layers': 12,
 'num_heads': 8,
 'start_word': '[SOS]',
 'end_word': '[EOS]',
 'sep_word': '[SEP]',
 'cls_word': '[CLS]',
 'pad_word': '[PAD]',
 'mask_word': '[MASK]'}

In [3]:
# configure device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
src_tokenizer = tokenization.FullTokenizer(
    vocab_file='vocab/kor_vocab.txt', do_lower_case=False)
tgt_tokenizer = tokenization.FullTokenizer(
    vocab_file='vocab/eng_vocab.txt', do_lower_case=False)
print(tgt_tokenizer.tokenize('I love you'))
print(src_tokenizer.tokenize('나는 너를 사랑한다'))


['I', 'love', 'you']
['나는', '너를', '사랑', '##한다']


In [5]:
# 'Bible Coloring'은 성경의 아름다운 이야기를 체험 할 수 있는 컬러링 앱입니다. -> Bible Coloring' is a coloring application that allows you to experience beautiful stories in the Bible.
# 씨티은행에서 일하세요? -> Do you work at a City bank?
# 11장에서는 예수님이 이번엔 나사로를 무덤에서 불러내어 죽은 자 가운데서 살리셨습니다. -> In Chapter 11 Jesus called Lazarus from the tomb and raised him from the dead.

src_text = [
    '\'Bible Coloring\'은 성경의 아름다운 이야기를 체험 할 수 있는 컬러링 앱입니다.',
    '씨티은행에서 일하세요?', '11장에서는 예수님이 이번엔 나사로를 무덤에서 불러내어 죽은 자 가운데서 살리셨습니다.']
tgt_text = [
    'Bible Coloring\' is a coloring application that allows you to experience beautiful stories in the Bible.',
    'Do you work at a City bank?', 'In Chapter 11 Jesus called Lazarus from the tomb and raised him from the dead.']

enc_length = 40
tgt_length = 40

enc_inputs = []
for text in src_text:
    tokens = src_tokenizer.tokenize(text)
    tokens = tokens[:enc_length]
    while len(tokens) < enc_length:
        tokens.append(config['pad_word'])
    enc_input = src_tokenizer.convert_tokens_to_ids(tokens)
    enc_inputs.append(enc_input)

dec_inputs, dec_outputs = [], []
for text in tgt_text:
    tokens = [config['start_word']]
    tokens.extend(tgt_tokenizer.tokenize(text)[:tgt_length])
    tokens.append(config['end_word'])
    
    dec_input = tokens[:-1]
    dec_output = tokens[1:]

    while len(dec_input) < tgt_length + 1:
        dec_input.append(config['pad_word'])
    while len(dec_output) < tgt_length + 1:
        dec_output.append(config['pad_word'])

    dec_input = tgt_tokenizer.convert_tokens_to_ids(dec_input)
    dec_output = tgt_tokenizer.convert_tokens_to_ids(dec_output)

    dec_inputs.append(dec_input)
    dec_outputs.append(dec_output)

enc_inputs_tensor = torch.as_tensor(enc_inputs, dtype=torch.long).to(device)
dec_inputs_tensor = torch.as_tensor(dec_inputs, dtype=torch.long).to(device)
dec_outputs_tensor = torch.as_tensor(dec_outputs, dtype=torch.long).to(device)

print(enc_inputs_tensor)
print(dec_inputs_tensor)
print(dec_outputs_tensor)

tensor([[47488, 47728,  1722,  2272, 21191,  2540, 13330, 13201,  3807,  7195,
         47488, 47459, 47492, 27321, 22775, 21385,  9246,  4037, 47558, 47467,
          1634, 33068,  7923, 48117, 22183, 47440,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [25413, 30745, 47471, 24155, 47774,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [ 1612,   213, 46686,     3,  1914,  1253, 47491,  7942,  9246, 47555,
         15713, 10552, 47668,  1239, 26620, 38688, 47461, 11548,   151,     3,
         47440,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]])
tensor([[    1,  5907, 13068,  1160,   114,  111

In [6]:
## configure model, optimizer, criterion
src_pad_index = src_tokenizer.convert_tokens_to_ids([config['pad_word']])[0]
tgt_pad_index = tgt_tokenizer.convert_tokens_to_ids([config['pad_word']])[0]

transformer = Translation(
    src_vocab_size=config['kor_vocab_length'],
    tgt_vocab_size=config['eng_vocab_length'],
    d_model=config['d_model'], d_ff=config['d_ff'],
    d_k=config['d_k'], d_v=config['d_v'], n_heads=config['num_heads'], 
    n_layers=config['num_layers'], src_pad_index=src_pad_index,
    tgt_pad_index=src_pad_index, device=device).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(transformer.parameters(), lr=5e-5)
transformer

 (enc_self_attn): MultiHeadAttention(
          (WQ): Linear(in_features=768, out_features=512, bias=True)
          (WK): Linear(in_features=768, out_features=512, bias=True)
          (WV): Linear(in_features=768, out_features=512, bias=True)
          (linear): Linear(in_features=512, out_features=768, bias=True)
          (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (pos_ffn): PoswiseFeedForwardNet(
          (l1): Linear(in_features=768, out_features=2048, bias=True)
          (l2): Linear(in_features=2048, out_features=768, bias=True)
          (relu): GELU()
          (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
      (5): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (WQ): Linear(in_features=768, out_features=512, bias=True)
          (WK): Linear(in_features=768, out_features=512, bias=True)
          (WV): Linear(in_features=768, out_features=512, bias=True)
          (lin

In [9]:
for epoch in range(100):

    optimizer.zero_grad()
    dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns = transformer(enc_inputs_tensor, dec_inputs_tensor)
    loss = criterion(
        dec_logits.view(-1, dec_logits.size(-1)),
        dec_outputs_tensor.contiguous().view(-1))
    loss.backward()
    optimizer.step()

    print(epoch, loss.item())

0 4.10388708114624
1 3.678687334060669
2 3.5360214710235596
3 3.406212091445923
4 3.2002038955688477
5 3.103100061416626
6 2.9335062503814697
7 2.7363855838775635
8 2.6097939014434814
9 2.4988210201263428
10 2.375528335571289
11 2.242288589477539
12 2.1073808670043945
13 1.9714640378952026
14 1.8304811716079712
15 1.6902967691421509
16 1.5418894290924072
17 1.386192798614502
18 1.2296375036239624
19 1.0829826593399048
20 0.9372422099113464
21 0.7987121343612671
22 0.6721315383911133
23 0.5591924786567688
24 0.4610476791858673
25 0.3770865797996521
26 0.30839186906814575
27 0.2527344822883606
28 0.20801861584186554
29 0.17255273461341858
30 0.144961878657341
31 0.12296938896179199
32 0.1052861139178276
33 0.09093289077281952
34 0.0792354866862297
35 0.06968825310468674
36 0.06177357956767082
37 0.055134180933237076
38 0.04958122596144676
39 0.04493553936481476
40 0.04098932817578316
41 0.037585072219371796
42 0.034641340374946594
43 0.03210744261741638
44 0.029924551025032997
45 0.02802

In [11]:
test_sentences = [
    '\'Bible Coloring\'은 성경의 아름다운 이야기를 체험 할 수 있는 컬러링 앱입니다.',
    '씨티은행에서 일하세요?', '11장에서는 예수님이 이번엔 나사로를 무덤에서 불러내어 죽은 자 가운데서 살리셨습니다.']

for test_sentence in test_sentences:
    orig_text = test_sentence
    print(orig_text)
    test_sentence = src_tokenizer.tokenize(test_sentence)
    test_sentence_ids = src_tokenizer.convert_tokens_to_ids(test_sentence)
    enc_token = torch.as_tensor([test_sentence_ids], dtype=torch.long)

    test_sentence_dec = ['[SOS]']
    test_sentence_dec = tgt_tokenizer.convert_tokens_to_ids(test_sentence_dec)
    eos_flag = tgt_tokenizer.convert_tokens_to_ids(['[EOS]'])

    while test_sentence_dec[-1] != eos_flag[0] or len(test_sentence_dec) > 50:
        dec_input = torch.as_tensor([test_sentence_dec], dtype=torch.long)
        enc_token, dec_input = enc_token.to(device), dec_input.to(device)
        with torch.no_grad():
            dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns = transformer(enc_token, dec_input)
        predict = torch.argmax(dec_logits, axis=2)[:, -1].squeeze().detach().cpu().numpy()
        test_sentence_dec.append(int(predict))

    predict_text = ' '.join(tgt_tokenizer.convert_ids_to_tokens(test_sentence_dec))
    predict_text = predict_text.replace(" ##", "")
    predict_text = predict_text.replace("##", "")
    print(f'orig_text    : {orig_text}')
    print(f'predict_text : {predict_text}')
    print('-----------------')

'Bible Coloring'은 성경의 아름다운 이야기를 체험 할 수 있는 컬러링 앱입니다.
orig_text    : 'Bible Coloring'은 성경의 아름다운 이야기를 체험 할 수 있는 컬러링 앱입니다.
predict_text : [SOS] Bible Coloring ' is a coloring application that allows you to experience beautiful stories in the Bible . [EOS]
-----------------
씨티은행에서 일하세요?
orig_text    : 씨티은행에서 일하세요?
predict_text : [SOS] Do you work at a City bank ? [EOS]
-----------------
11장에서는 예수님이 이번엔 나사로를 무덤에서 불러내어 죽은 자 가운데서 살리셨습니다.
orig_text    : 11장에서는 예수님이 이번엔 나사로를 무덤에서 불러내어 죽은 자 가운데서 살리셨습니다.
predict_text : [SOS] In Chapter 11 Jesus called Lazarus from the tomb and raised him from the dead . [EOS]
-----------------
