In [1]:
import torch
import torch.nn.functional as F

import transformers
from models.seq_seq import Config, Transformer

In [2]:
# KoELECTRA tokenizer, with '[EOS]' registered as an additional special token
# so it is kept as a single token rather than split into word pieces.
tokenizer = transformers.ElectraTokenizer.from_pretrained("monologg/koelectra-base-v3-discriminator")
special_tokens_dict = {'additional_special_tokens': ['[EOS]']}
tokenizer.add_special_tokens(special_tokens_dict)

config = Config('models/seq_seq_config.json')
model = Transformer(config)
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine (inference below runs on 'cpu').
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('ckpt/best_16.80705.pt', map_location='cpu'))

<All keys matched successfully>

In [3]:
def translate_sentence(sentence, model, device, max_len=50, max_src_len=64):
    """Greedily decode a response to ``sentence`` with the seq2seq ``model``.

    Relies on the module-level ``tokenizer`` (created in an earlier cell) to
    encode the source and to look up the ``[CLS]`` / ``[EOS]`` token ids.

    Args:
        sentence: Raw input text.
        model: Seq2seq model exposing ``embedding``, ``encoder`` and
            ``decoder``. It is switched to eval mode here; the caller is
            responsible for having placed it on ``device``.
        device: Device on which all input tensors are created.
        max_len: Maximum number of tokens to generate.
        max_src_len: Fixed source length. Shorter inputs are padded; longer
            inputs are truncated while keeping the trailing ``[EOS]``.

    Returns:
        List of token ids. It starts with the ``[CLS]`` id and ends with the
        ``[EOS]`` id if one was generated within ``max_len`` steps.
    """
    model.eval()

    eos_token = '[EOS]'
    # Look the id up rather than hard-coding it: '[EOS]' is appended to the
    # vocab by add_special_tokens, so its id depends on the base vocab size.
    eos_id = tokenizer.convert_tokens_to_ids(eos_token)
    cls_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)
    # Padding id; padded positions are masked out via src_mask.
    # NOTE(review): assumes id 0 is [PAD] in this vocab -- confirm.
    pad_id = 0

    tokens = [tokenizer.cls_token] + tokenizer.tokenize(sentence) + [eos_token]
    if len(tokens) > max_src_len:
        # Truncate, but keep [EOS] so the encoder still sees an end marker.
        tokens = tokens[:max_src_len - 1] + [eos_token]
    src_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    n_real = len(src_ids)
    n_pad = max_src_len - n_real
    src_ids = src_ids + [pad_id] * n_pad

    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
    # Additive attention mask of shape (1, 1, 1, max_src_len):
    # 0 for real tokens, -10000 for padding.
    src_keep = torch.tensor([[1.0] * n_real + [0.0] * n_pad], device=device)
    src_mask = ((1.0 - src_keep) * -10000.0)[:, None, None, :]

    trg_indexes = [cls_id]
    with torch.no_grad():
        enc_src = model.encoder(model.embedding(src_tensor), src_mask)

        for _ in range(max_len):
            trg_tensor = torch.tensor([trg_indexes], dtype=torch.long, device=device)
            # All-zero additive mask: the target has no padding.
            # NOTE(review): no causal mask is applied here; this assumes
            # model.decoder builds its own look-ahead mask -- confirm in
            # models.seq_seq.
            trg_mask = torch.zeros(1, len(trg_indexes), device=device)
            output = model.decoder(model.embedding(trg_tensor), enc_src, trg_mask, src_mask)
            # Greedy choice: the most likely token at the last position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_id:
                break

    return trg_indexes

In [10]:
text = "넌 누구야?"
translation = translate_sentence(text, model, 'cpu')
# Drop special tokens (e.g. [CLS] / [EOS]) before merging word pieces back into text.
translated_tokens = tokenizer.convert_ids_to_tokens(translation, skip_special_tokens=True)
print(tokenizer.convert_tokens_to_string(translated_tokens))

저는 위로봇입니다 .
