<a href="https://colab.research.google.com/github/lunaB/Pytorch-Study/blob/master/15_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys
import random
import datetime as dt

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

# Reproducibility: request deterministic cuDNN kernels and seed every random
# number generator used in this notebook with the same value.
torch.backends.cudnn.deterministic = True
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# With a small batch size, training on the GPU was observed to be slower than
# on the CPU (in this setup the CPU is faster).
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [None]:
# Toy parallel corpus. Each entry is "<source sentence>\t<target sentence>":
# a Korean sentence and its English translation separated by a tab, with the
# words of each sentence separated by single spaces.
raw = ["안녕 난 영채야\thi my name is youngchae",
       "안녕 넌 참 신기하다\thi you are so amazing",
       "와 정말 신기하다\twow so amazing",
       "영채야 이거 봐봐\tlook at this youngchae"]

# Reserved vocabulary indices for the start-of-sequence and end-of-sequence markers.
SOS_token = 0
EOS_token = 1

In [None]:
class Vocab:
    """Bidirectional word/index mapping with per-word occurrence counts.

    Attributes:
        vocab2index: word -> integer index; pre-populated with the
            "<SOS>" and "<EOS>" markers.
        index2vocab: integer index -> word (inverse of ``vocab2index``).
        vocab_count: word -> number of times it was passed to ``add_vocab``
            (the two markers are not counted).
        n_vocab: number of entries in the vocabulary, also the next free index.
    """

    def __init__(self):
        specials = (("<SOS>", SOS_token), ("<EOS>", EOS_token))
        self.vocab2index = dict(specials)
        self.index2vocab = {index: token for token, index in specials}
        self.vocab_count = {}
        self.n_vocab = len(self.vocab2index)

    def add_vocab(self, sentence):
        """Register every space-separated word of ``sentence``.

        Unseen words get the next free index; known words only have their
        counter incremented.
        """
        for token in sentence.split(" "):
            if token in self.vocab2index:
                self.vocab_count[token] += 1
                continue
            index = self.n_vocab
            self.vocab2index[token] = index
            self.vocab_count[token] = 1
            self.index2vocab[index] = token
            self.n_vocab = index + 1

In [None]:
# Returns False as soon as either sentence reaches its maximum length; used by
# preprocess to keep over-long pairs out of the pair set.
def filter_pair(pair, source_max_length, target_max_length):
    """Return True when both sentences of ``pair`` are strictly shorter than their limits.

    Length is measured in space-separated words. ``pair[0]`` is the source
    sentence and ``pair[1]`` the target sentence; the target is only inspected
    when the source passes.
    """
    source_words = pair[0].split(" ")
    if len(source_words) >= source_max_length:
        return False
    target_words = pair[1].split(" ")
    return len(target_words) < target_max_length

In [None]:
def preprocess(corpus, source_max_length, target_max_length):
    """Turn raw tab-separated lines into sentence pairs plus two vocabularies.

    Each line is stripped, lower-cased and split on tabs. Pairs rejected by
    ``filter_pair`` are dropped, and the remaining source/target sentences are
    registered in separate ``Vocab`` objects. Progress is printed along the way.

    Returns:
        A tuple ``(pairs, source_vocab, target_vocab)``.
    """
    print("reading corpus...")
    pairs = [line.strip().lower().split("\t") for line in corpus]
    print("Read {} sentence pairs".format(len(pairs)))

    pairs = [
        candidate
        for candidate in pairs
        if filter_pair(candidate, source_max_length, target_max_length)
    ]
    print("Trimmed to {} sentence pairs".format(len(pairs)))

    source_vocab, target_vocab = Vocab(), Vocab()

    print("Counting words...")
    for kept in pairs:
        source_vocab.add_vocab(kept[0])
        target_vocab.add_vocab(kept[1])
    print("source vocab size =", source_vocab.n_vocab)
    print("target vocab size =", target_vocab.n_vocab)

    return pairs, source_vocab, target_vocab

In [None]:
class Encoder(nn.Module):
    """Embedding + GRU encoder that processes one token index per call."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=hidden_size)
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)

    def forward(self, x, hidden):
        """Embed token ``x``, run one GRU step and return ``(output, new_hidden)``.

        The embedding is reshaped to (1, 1, hidden_size) before entering the GRU.
        """
        embedded = self.embedding(x)
        step_input = embedded.view(1, 1, -1)
        output, next_hidden = self.gru(step_input, hidden)
        return output, next_hidden

In [None]:
# The attention part is written by hand instead of using a library layer.
# It implements dot-product attention over the encoder outputs.
class Decoder(nn.Module):
    """GRU decoder step with dot-product attention over the encoder outputs.

    Each call consumes one target-side token index and returns log-probabilities
    over the output vocabulary together with the updated hidden state.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        # The output layer sees the attention context concatenated with the
        # GRU output, hence the doubled input width.
        self.out = nn.Linear(hidden_size * 2, output_size)
        self.softmax = nn.Softmax(dim=0)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x, hidden, encoder_hiddens):
        """Run one decoding step.

        Args:
            x: index tensor holding a single token.
            hidden: previous GRU hidden state, shape (1, 1, hidden_size).
            encoder_hiddens: encoder outputs stacked along dim 0; the attention
                arithmetic below expects shape (source_length, 1, hidden_size).

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` has shape
            (1, output_size) and ``hidden`` is the new GRU hidden state.
        """
        x = self.embedding(x).view(1, 1, -1)
        x, hidden = self.gru(x, hidden)
        query = x[0]  # (1, hidden_size)

        # Dot product of every encoder output with the decoder output in a
        # single broadcast matmul -> (source_length, 1, 1). Computing it from
        # the inputs directly keeps the scores on the inputs' device and dtype,
        # instead of accumulating into a CPU-allocated empty tensor.
        scores = torch.matmul(encoder_hiddens, query.T)

        # Normalise over the source positions (dim 0).
        weights = self.softmax(scores)

        # Attention-weighted sum of the encoder outputs -> (1, hidden_size).
        context = (encoder_hiddens * weights).sum(0)

        combined = torch.cat((context, query), dim=1)  # (1, 2 * hidden_size)
        log_probs = self.logsoftmax(self.out(combined))

        return log_probs, hidden

In [None]:
def tensorize(vocab, sentence):
    """Convert ``sentence`` into a column tensor of vocabulary indices.

    The sentence is split on single spaces, each word is looked up in
    ``vocab.vocab2index`` and the "<EOS>" index is appended.

    Returns:
        An int64 tensor of shape (number_of_words + 1, 1) on ``device``.

    Raises:
        KeyError: if a word is not present in ``vocab.vocab2index``.
    """
    indexes = [vocab.vocab2index[word] for word in sentence.split(" ")]
    indexes.append(vocab.vocab2index["<EOS>"])
    # Build the indices as int64 directly on the target device rather than
    # round-tripping them through a float32 tensor.
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

In [None]:
def train(pairs, source_vocab, target_vocab, encoder, decoder, n_iter, print_every=1000, learning_rate=0.01):
    """Train ``encoder`` and ``decoder`` with SGD on randomly sampled pairs.

    Every iteration processes one sentence pair: the source is encoded token by
    token, then the target is decoded with teacher forcing while the NLL loss
    is summed over the target tokens.

    Args:
        pairs: sequence of ``[source_sentence, target_sentence]`` pairs.
        source_vocab: vocabulary used to tensorize source sentences.
        target_vocab: vocabulary used to tensorize target sentences.
        encoder: encoder module; must expose ``hidden_size``.
        decoder: decoder module called as ``decoder(input, hidden, encoder_outputs)``.
        n_iter: number of training iterations (one sampled pair each).
        print_every: print the average per-token loss every this many iterations.
        learning_rate: learning rate for both SGD optimizers.
    """
    loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    # Sample (with replacement) and tensorize all training examples up front.
    training_batch = [random.choice(pairs) for _ in range(n_iter)]
    training_source = [tensorize(source_vocab, pair[0]) for pair in training_batch]
    training_target = [tensorize(target_vocab, pair[1]) for pair in training_batch]

    # NLLLoss applied to log-softmax outputs is the cross-entropy loss.
    criterion = nn.NLLLoss()

    examples = zip(training_source, training_target)
    for i, (source_tensor, target_tensor) in enumerate(examples, start=1):
        encoder_hidden = torch.zeros(1, 1, encoder.hidden_size, device=device)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        target_length = target_tensor.size(0)

        loss = 0

        # Collect the per-token encoder outputs and concatenate them once, so
        # the stacked tensor lives on the same device as the model outputs
        # (the previous empty-tensor accumulator was always created on the CPU).
        step_outputs = []
        for source_token in source_tensor:
            encoder_output, encoder_hidden = encoder(source_token, encoder_hidden)
            step_outputs.append(encoder_output)
        encoder_outputs = torch.cat(step_outputs)

        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden  # connect encoder output to decoder input

        for target_token in target_tensor:
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_token)
            decoder_input = target_token  # teacher forcing

        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        # Track the loss per target token so sentences of different length
        # contribute comparably to the running average.
        loss_iter = loss.item() / target_length
        loss_total += loss_iter

        if i % print_every == 0:
            loss_avg = loss_total / print_every
            loss_total = 0
            print("[{} - {}%] loss = {:05.4f}".format(i, i / n_iter * 100, loss_avg))

In [None]:
def evaluate(pairs, source_vocab, target_vocab, encoder, decoder, target_max_length):
    """Greedy-decode every pair and print source, reference and prediction.

    For each pair the source sentence is printed after ">", the reference
    target after "=" and the decoded sentence after "<". Decoding stops at the
    first "<EOS>" prediction or after ``target_max_length`` steps.

    Args:
        pairs: sequence of ``[source_sentence, target_sentence]`` pairs.
        source_vocab: vocabulary used to tensorize the source sentence.
        target_vocab: vocabulary used to map predicted indices back to words.
        encoder: encoder module; must expose ``hidden_size``.
        decoder: decoder module called as ``decoder(input, hidden, encoder_outputs)``.
        target_max_length: maximum number of decoding steps.
    """
    for pair in pairs:
        print(">", pair[0])
        print("=", pair[1])
        source_tensor = tensorize(source_vocab, pair[0])
        decoded_words = []

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            encoder_hidden = torch.zeros(1, 1, encoder.hidden_size, device=device)

            # Collect and concatenate once so the stacked outputs stay on the
            # model's device (an empty torch.Tensor([]) is always on the CPU).
            step_outputs = []
            for source_token in source_tensor:
                encoder_output, encoder_hidden = encoder(source_token, encoder_hidden)
                step_outputs.append(encoder_output)
            encoder_outputs = torch.cat(step_outputs)

            # torch.tensor accepts any device; the legacy torch.Tensor(...)
            # constructor does not accept a CUDA device argument.
            decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
            decoder_hidden = encoder_hidden

            for _ in range(target_max_length):
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
                _, top_index = decoder_output.topk(1)
                predicted = top_index.item()
                if predicted == EOS_token:
                    decoded_words.append("<EOS>")
                    break
                decoded_words.append(target_vocab.index2vocab[predicted])

                # Feed the greedy prediction back in as the next input.
                decoder_input = top_index.squeeze().detach()

        predict_sentence = " ".join(decoded_words)
        print("<", predict_sentence)
        print("")

In [None]:
# Length limits handed to preprocess() for filtering sentence pairs; the target
# limit is also used by evaluate() as the maximum number of decoding steps.
SOURCE_MAX_LENGTH = 10
TARGET_MAX_LENGTH = 10

In [None]:
# Build the filtered sentence pairs and both vocabularies from the toy corpus,
# then print one randomly chosen pair as a sanity check.
load_pairs, load_source_vocab, load_target_vocab = preprocess(raw, SOURCE_MAX_LENGTH, TARGET_MAX_LENGTH)
print(random.choice(load_pairs))

reading corpus...
Read 4 sentence pairs
Trimmed to 4 sentence pairs
Counting words...
source vocab size = 12
target vocab size = 15
['영채야 이거 봐봐', 'look at this youngchae']


In [None]:
# train()/evaluate() pass the encoder's final hidden state to the decoder as its
# initial hidden state, so both models use the same hidden size.
enc_hidden_size = 16
dec_hidden_size = enc_hidden_size
enc = Encoder(load_source_vocab.n_vocab, enc_hidden_size).to(device)
dec = Decoder(dec_hidden_size, load_target_vocab.n_vocab).to(device)

In [None]:
# Train for 5000 single-pair iterations, printing the average loss every 1000.
train(load_pairs, load_source_vocab, load_target_vocab, enc, dec, 5000, print_every=1000)

[1000 - 20.0%] loss = 0.6425
[2000 - 40.0%] loss = 0.0351
[3000 - 60.0%] loss = 0.0147
[4000 - 80.0%] loss = 0.0090
[5000 - 100.0%] loss = 0.0065


In [None]:
# Decode every training pair and print source (>), reference (=) and prediction (<).
evaluate(load_pairs, load_source_vocab, load_target_vocab, enc, dec, TARGET_MAX_LENGTH)

> 안녕 난 영채야
= hi my name is youngchae
< hi my name is youngchae <EOS>

> 안녕 넌 참 신기하다
= hi you are so amazing
< hi you are so amazing <EOS>

> 와 정말 신기하다
= wow so amazing
< wow so amazing <EOS>

> 영채야 이거 봐봐
= look at this youngchae
< look at this youngchae <EOS>



In [None]:
# load_pairs, load_source_vocab, load_target_vocab = preprocess(raw, SOURCE_MAX_LENGTH, TARGET_MAX_LENGTH)
# print(random.choice(load_pairs))

# enc_hidden_size = 16
# dec_hidden_size = enc_hidden_size
# enc = Encoder(load_source_vocab.n_vocab, enc_hidden_size).to(device)
# dec = Decoder(dec_hidden_size, load_target_vocab.n_vocab).to(device)

# train(load_pairs, load_source_vocab, load_target_vocab, enc, dec, 1, print_every=1000)

In [None]:
# test
a = torch.Tensor([[[-0.2807, -0.0313, -0.2060, -0.5300, -0.0359, -0.0324,  0.1610,
           0.0359,  0.1713, -0.0643,  0.1804,  0.1991, -0.0182,  0.0163,
           0.1620, -0.0905]],

        [[ 0.3218,  0.2815, -0.3996, -0.0347, -0.3250, -0.5432, -0.0847,
           0.2177, -0.2206, -0.4593,  0.0988,  0.0595,  0.1314, -0.2499,
          -0.0862, -0.2302]],

        [[ 0.3090,  0.2358, -0.3276, -0.0723, -0.0751,  0.0968, -0.3588,
           0.1281, -0.1979, -0.3368,  0.0836,  0.3466,  0.0074, -0.4496,
           0.1435, -0.3111]],

        [[ 0.3193,  0.3354,  0.2523,  0.0485, -0.2001, -0.5193, -0.2616,
          -0.4318, -0.0128,  0.1569,  0.2528, -0.0342,  0.2377, -0.4790,
          -0.3276, -0.0365]]])
b = torch.Tensor([[[1]],[[2]],[[3]],[[4]]])
a*b

tensor([[[-0.2807, -0.0313, -0.2060, -0.5300, -0.0359, -0.0324,  0.1610,
           0.0359,  0.1713, -0.0643,  0.1804,  0.1991, -0.0182,  0.0163,
           0.1620, -0.0905]],

        [[ 0.6436,  0.5630, -0.7992, -0.0694, -0.6500, -1.0864, -0.1694,
           0.4354, -0.4412, -0.9186,  0.1976,  0.1190,  0.2628, -0.4998,
          -0.1724, -0.4604]],

        [[ 0.9270,  0.7074, -0.9828, -0.2169, -0.2253,  0.2904, -1.0764,
           0.3843, -0.5937, -1.0104,  0.2508,  1.0398,  0.0222, -1.3488,
           0.4305, -0.9333]],

        [[ 1.2772,  1.3416,  1.0092,  0.1940, -0.8004, -2.0772, -1.0464,
          -1.7272, -0.0512,  0.6276,  1.0112, -0.1368,  0.9508, -1.9160,
          -1.3104, -0.1460]]])