# NN.TRANSFORMER로 언어 번역 (eng-fra.txt 사용)
https://tutorials.pytorch.kr/beginner/translation_transformer.html

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import os
import unicodedata
import re
import random
import math
import time
from tqdm import tqdm

In [2]:
# Vocabulary indices reserved for the special tokens. They line up with the
# first four entries of Lang.index2word: <unk>, <pad>, <bos>, <eos>.
UNK, PAD, BOS, EOS = 0, 1, 2, 3


class Lang:
    """Vocabulary of one language: word/index maps plus word frequencies.

    Indices 0-3 are reserved for the special tokens ``<unk>``, ``<pad>``,
    ``<bos>`` and ``<eos>``; real words are numbered from 4 upwards in the
    order in which they are first added.

    Attributes:
        name: Label of the language (free-form string).
        word2index: Maps each known word to its integer index.
        word2count: Maps each known word to the number of times it was added.
        index2word: Inverse of ``word2index``, including the special tokens.
        n_words: Size of the vocabulary, special tokens included.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "<unk>", 1: "<pad>", 2: "<bos>", 3: "<eos>"}
        self.n_words = 4  # the four special tokens above are already counted

    def addSentence(self, sentence):
        """Add every whitespace-separated token of ``sentence`` to the vocabulary."""
        # split() with no argument ignores leading, trailing and repeated
        # blanks, so the empty string never ends up as a vocabulary word.
        for word in sentence.split():
            self.addWord(word)

    def addWord(self, word):
        """Register ``word``: assign the next free index or bump its count."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [24]:
def unicodeToAscii(s):
    """Return ``s`` NFD-decomposed with all combining marks (category Mn) removed."""
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)


def normalizeString(s):
    """Lowercase ``s``, drop accents and keep only letters and ``. ! ?``.

    A blank is inserted before each ``.``, ``!`` or ``?`` so punctuation
    becomes its own token, and every run of other characters collapses to a
    single blank. No final strip is applied, so the result may still begin
    or end with a blank.
    """
    ascii_text = unicodeToAscii(s.lower().strip())
    spaced = re.sub(r"([.!?])", r" \1", ascii_text)
    return re.sub(r"[^a-zA-Z.!?]+", r" ", spaced)


# Length bound used by filterPair: each side of a pair must split into fewer
# than this many space-separated tokens.
MAX_LENGTH = 30

# filterPair keeps a pair only if its second element starts with one of these
# prefixes. normalizeString turns apostrophes into blanks, hence the
# "i m " / "he s " spellings.
eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    """Tell whether the sentence pair ``p`` should be kept.

    Both sentences must split into fewer than MAX_LENGTH space-separated
    tokens and the second one must start with one of ``eng_prefixes``.
    """
    if len(p[0].split(' ')) >= MAX_LENGTH:
        return False
    second = p[1]
    if len(second.split(' ')) >= MAX_LENGTH:
        return False
    return second.startswith(eng_prefixes)


def readLangs(lang1, lang2, reverse=False):
    """Read ``../data/<lang1>-<lang2>.txt`` and return empty vocabularies plus pairs.

    Every line of the file is split on tabs and each field is passed through
    ``normalizeString``.

    Args:
        lang1: Name of the language in the first column of the file.
        lang2: Name of the language in the second column of the file.
        reverse: If true, the fields of every pair are reversed and the
            roles of the two languages are swapped.

    Returns:
        ``(input_lang, output_lang, pairs)`` where the two ``Lang`` objects
        are still empty and ``pairs`` is a list of lists of normalized
        strings.

    Raises:
        OSError: If the data file cannot be opened.
    """
    print("Reading lines...")

    # Read the file and split it into lines; the context manager makes sure
    # the handle is closed even when reading fails.
    with open('../data/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as f:
        lines = f.read().strip().split('\n')

    # Split every line into its fields and normalize them.
    pairs = [[normalizeString(s) for s in line.split('\t')] for line in lines]

    # Reverse the pairs if requested and create the Lang instances.
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


def filterPairs(pairs):
    """Return, in their original order, the pairs accepted by ``filterPair``."""
    return list(filter(filterPair, pairs))


def prepareData(lang1, lang2, reverse=False):
    """Load the corpus, filter it and build both vocabularies.

    Reads the pairs with ``readLangs``, keeps those accepted by
    ``filterPairs``, then feeds the first element of each remaining pair to
    the input vocabulary and the second to the output vocabulary. Progress
    is reported on stdout.

    Returns:
        ``(input_lang, output_lang, pairs)`` with populated vocabularies and
        the filtered pair list.
    """
    input_lang, output_lang, all_pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(all_pairs))

    kept_pairs = filterPairs(all_pairs)
    print("Trimmed to %s sentence pairs" % len(kept_pairs))

    print("Counting words...")
    for entry in kept_pairs:
        input_lang.addSentence(entry[0])
        output_lang.addSentence(entry[1])

    print("Counted words:")
    for lang in (input_lang, output_lang):
        print(lang.name, lang.n_words)

    return input_lang, output_lang, kept_pairs


def to_vec(lang, s, max_len):
    """Encode sentence ``s`` as a column tensor of vocabulary indices.

    The sentence is split on whitespace, wrapped in BOS/EOS and right-padded
    with PAD up to ``max_len`` entries. A sequence that is already longer
    than ``max_len`` is returned unpadded and untruncated.

    Args:
        lang: Vocabulary object providing ``word2index``.
        s: Sentence to encode.
        max_len: Minimum length of the result, reached by padding.

    Returns:
        Integer tensor of shape ``(length, 1)``.
    """
    # Words missing from the vocabulary map to UNK instead of raising
    # KeyError, so sentences with unseen words can still be encoded.
    ids = [lang.word2index.get(word, UNK) for word in s.split()]
    vec = [BOS] + ids + [EOS]
    vec += [PAD] * (max_len - len(vec))
    return torch.tensor(vec).view(-1, 1)


def to_sentence(lang, vec):
    """Join the words that ``lang.index2word`` gives for each index in ``vec``.

    ``vec`` is iterated element by element and every element must support
    ``.item()`` (for example a 1-D integer tensor).
    """
    words = []
    for token in vec:
        words.append(lang.index2word[token.item()])
    return " ".join(words)


def src_tgt_pair(input_lang, output_lang, pair, src_len=32, tgt_len=32):
    """Encode both sentences of ``pair`` with ``to_vec``.

    The first sentence is encoded with ``input_lang`` and padded to
    ``src_len``; the second with ``output_lang`` and padded to ``tgt_len``.
    Returns the two column tensors as ``(source, target)``.
    """
    source_text, target_text = pair
    source_vec = to_vec(input_lang, source_text, src_len)
    target_vec = to_vec(output_lang, target_text, tgt_len)
    return source_vec, target_vec


def get_batch(input_lang, output_lang, pairs, batch_size=64):
    """Draw ``batch_size`` pairs at random (with replacement) and batch them.

    Each drawn pair is encoded with ``src_tgt_pair``; the resulting column
    tensors are concatenated along dimension 1, so column ``k`` of each
    returned tensor holds the ``k``-th sampled sentence.

    Returns:
        ``(src, tgt)`` tensors.
    """
    sampled = random.choices(pairs, k=batch_size)
    encoded = [src_tgt_pair(input_lang, output_lang, pair) for pair in sampled]
    src_columns = [source for source, _ in encoded]
    tgt_columns = [target for _, target in encoded]
    return torch.cat(src_columns, 1), torch.cat(tgt_columns, 1)


def generate_square_subsequent_mask(sz):
    """Build a ``(sz, sz)`` float mask: 0 on and below the diagonal, -inf above."""
    mask = torch.ones(sz, sz).triu(diagonal=1)
    above_diagonal = mask == 1
    mask[above_diagonal] = float("-inf")
    return mask


def create_mask(src, tgt):
    """Build the attention and padding masks for one batch.

    The sequence length is taken from dimension 0 of ``src`` and ``tgt``.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
        an all-False square boolean mask for the source, the mask from
        ``generate_square_subsequent_mask`` for the target, and boolean masks
        that are True where the token equals PAD, with dimensions 0 and 1
        swapped relative to the inputs.
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    src_mask = torch.zeros(src_len, src_len, dtype=torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    src_is_pad = src == PAD
    tgt_is_pad = tgt == PAD
    return src_mask, tgt_mask, src_is_pad.transpose(0, 1), tgt_is_pad.transpose(0, 1)


In [4]:
# Load the eng-fra corpus with reverse=True: the fields of every pair are
# reversed and 'fra' becomes the input language, 'eng' the output language.
# Then show one randomly chosen pair as a sanity check.
input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))

Reading lines...
Read 135842 sentence pairs
Trimmed to 13067 sentence pairs
Counting words...
Counted words:
fra 5173
eng 3391
['je pense cloturer mon compte d epargne .', 'i am thinking of closing my savings account .']


In [5]:
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position signal to embeddings, then apply dropout.

    The table is stored as the buffer ``pos_embedding`` with shape
    ``(maxlen, 1, emb_size)``: sines in the even feature columns, cosines in
    the odd ones.
    """

    def __init__(self,
                 emb_size,
                 dropout,
                 maxlen):
        super().__init__()
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        inv_freq = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        angles = positions * inv_freq

        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(dropout)
        # A buffer, not a parameter: saved with the module but never trained.
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding):
        """Add the first ``token_embedding.size(0)`` table rows and apply dropout."""
        seq_len = token_embedding.size(0)
        with_positions = token_embedding + self.pos_embedding[:seq_len, :]
        return self.dropout(with_positions)

class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        """Cast ``tokens`` to long, look them up and scale the result."""
        scale = math.sqrt(self.emb_size)
        looked_up = self.embedding(tokens.long())
        return looked_up * scale

class NN(nn.Module):
    """Translation model: token + positional embeddings around ``nn.Transformer``.

    Source and target tokens have separate embeddings and share a single
    positional-encoding module. A final linear layer maps the transformer
    output to scores over the target vocabulary.
    """

    def __init__(self, n_src_vocab, n_tgt_vocab, n_emb, dropout=0.1, max_seq_size=32):
        super().__init__()
        self.src_emb = TokenEmbedding(n_src_vocab, n_emb)
        self.tgt_emb = TokenEmbedding(n_tgt_vocab, n_emb)
        self.pos_enc = PositionalEncoding(n_emb, dropout, max_seq_size)
        self.transformer = nn.Transformer(
            d_model=n_emb,
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=512,
        )
        self.generator = nn.Linear(n_emb, n_tgt_vocab)

    def forward(self, src, tgt, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask, memory_key_pad_mask):
        """Run encoder and decoder and return scores over the target vocabulary.

        No memory attention mask is used (``memory_mask`` is None).
        """
        embedded_src = self.pos_enc(self.src_emb(src))
        embedded_tgt = self.pos_enc(self.tgt_emb(tgt))
        hidden = self.transformer(
            embedded_src,
            embedded_tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=memory_key_pad_mask,
        )
        return self.generator(hidden)

    def encode(self, src, src_mask):
        """Run only the encoder on ``src``; no key-padding mask is applied."""
        embedded_src = self.pos_enc(self.src_emb(src))
        return self.transformer.encoder(embedded_src, mask=src_mask)

    def decode(self, tgt, memory, tgt_mask):
        """Run only the decoder on ``tgt`` against the encoder output ``memory``."""
        embedded_tgt = self.pos_enc(self.tgt_emb(tgt))
        return self.transformer.decoder(embedded_tgt, memory, tgt_mask=tgt_mask)

In [7]:
# Size the model from the two vocabularies built by prepareData and use an
# embedding width of 128.
n_src_vocab = input_lang.n_words
n_tgt_vocab = output_lang.n_words
n_emb = 128
model = NN(n_src_vocab, n_tgt_vocab, n_emb)

# Notebook display of the two vocabulary sizes.
n_src_vocab, n_tgt_vocab

(5173, 3391)

In [10]:
# Training configuration: n_epochs rounds of n_itr randomly sampled batches
# holding b_size sentence pairs each.
n_epochs = 2
n_itr = 100
b_size = 64
# Target positions equal to PAD are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [11]:
# Training loop: every iteration samples a random batch, feeds the target
# without its last position to the decoder, and trains it to predict the
# target shifted by one position. A checkpoint is written after each epoch.
s_time = time.time()

for epoch in range(n_epochs):
    # Make sure dropout is active even if the model was switched to eval mode.
    model.train()
    total_loss = 0.0
    for _ in tqdm(range(n_itr)):
        src, tgt = get_batch(input_lang, output_lang, pairs, b_size)
        tgt_in = tgt[:-1, :]   # decoder input: everything except the last position
        tgt_out = tgt[1:, :]   # expected output: everything except the first position
        src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_mask(src, tgt_in)
        optimizer.zero_grad()
        out = model(src, tgt_in, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask, src_pad_mask)
        loss = loss_fn(out.reshape(-1, out.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()
        # Accumulate a plain float; adding the tensor itself would keep
        # autograd history alive across iterations.
        total_loss += loss.item()
    torch.save(model.state_dict(), "transformer.pth")
    print("\n%d %.3fs loss - %.4f" % (epoch, time.time() - s_time, total_loss / n_itr))

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:36<00:00,  2.72it/s]



0 36.848s loss - 4.9024


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:37<00:00,  2.64it/s]


1 74.755s loss - 3.9639





In [26]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Generate a target sequence by taking the highest-scoring token each step.

    Decoding starts from ``start_symbol`` and stops once EOS is produced or
    the sequence holds ``max_len`` tokens. Only one sentence is decoded at a
    time, because the chosen token is read with ``.item()``.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``generator``.
        src: Encoded source sentence (assumed shape ``(src_len, 1)`` -
            TODO confirm against callers).
        src_mask: Attention mask passed to ``model.encode``.
        max_len: Upper bound on the number of generated tokens, start symbol
            included.
        start_symbol: Index of the token that opens the sequence.

    Returns:
        Long tensor of shape ``(generated_len, 1)`` beginning with
        ``start_symbol``.
    """
    device = src.device
    # Inference only: skip autograd bookkeeping for every decoding step.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        # Create the running sequence on the same device as the source so the
        # concatenation below never mixes devices.
        ys = torch.tensor([[start_symbol]], dtype=torch.long, device=device)
        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            # Score only the last position: it predicts the next token.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            next_token = torch.tensor([[next_word]], dtype=torch.long, device=device)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == EOS:
                break
    return ys


# Translate an input sentence into the target language.
def translate(model, input_lang, output_lang, src_sentence):
    """Translate ``src_sentence`` with greedy decoding and return the text.

    The sentence is normalized, encoded with ``input_lang`` and decoded with
    ``greedy_decode``; the resulting indices are mapped back to words with
    ``output_lang`` and the ``<bos>`` / ``<eos>`` markers are removed.

    The model is switched to evaluation mode for the duration of the call
    and its previous training/evaluation mode is restored afterwards.
    """
    src_sentence = normalizeString(src_sentence)
    # max_len=0 makes to_vec emit only <bos> words <eos> with no <pad> tail.
    # model.encode() takes no key-padding mask, so padding here would be
    # attended to as if it were real input. It would also inflate the
    # decoding budget below.
    src = to_vec(input_lang, src_sentence, 0)
    num_tokens = src.shape[0]
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)

    was_training = model.training
    model.eval()  # disable dropout so the translation is deterministic
    try:
        with torch.no_grad():
            tgt_tokens = greedy_decode(
                model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS).flatten()
    finally:
        model.train(was_training)
    return to_sentence(output_lang, tgt_tokens).replace("<bos>", "").replace("<eos>", "")

In [28]:
# Try the trained model on a sample sentence in the input language.
translate(model, input_lang, output_lang, "Je ne m'en vais pas .")

' i m not . '