In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from glob import glob

sources = glob("/content/drive/MyDrive/ml-class-rhseung/data/kor-eng/*")
sources

['/content/drive/MyDrive/ml-class-rhseung/data/kor-eng/kor.txt',
 '/content/drive/MyDrive/ml-class-rhseung/data/kor-eng/_about.txt']

In [3]:
import string

string.punctuation

'!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'

In [4]:
with open(sources[0], 'r', encoding='utf-8') as f:
    contents = f.readlines()

contents[:5]

['Go.\t가.\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #8363271 (Eunhee)\n',
 'Hi.\t안녕.\tCC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #8355888 (Eunhee)\n',
 'Run!\t뛰어!\tCC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #8355891 (Eunhee)\n',
 'Run.\t뛰어.\tCC-BY 2.0 (France) Attribution: tatoeba.org #4008918 (JSakuragi) & #8363273 (Eunhee)\n',
 'Who?\t누구?\tCC-BY 2.0 (France) Attribution: tatoeba.org #2083030 (CK) & #6820074 (yesjustryan)\n']

In [5]:
lines = []

with open(sources[0], 'r', encoding='utf-8') as f:
    for line in f:
        eng, kor = line.split('\t')[:2]

        def clear(s):
            # 전부 소문자로 바꾸고 특수 문자 제거
            return s.lower().strip(string.punctuation)

        lines.append((clear(eng), clear(kor)))

lines[:5]

[('go', '가'), ('hi', '안녕'), ('run', '뛰어'), ('run', '뛰어'), ('who', '누구')]

In [6]:
len(lines)

5870

In [7]:
lines[1000:1005]

[("i'm heartbroken", '제 마음이 아파요'),
 ("i'm just sleepy", '나 졸려'),
 ("i'm not sulking", '나 삐친 거 아니야'),
 ("i'm on the list", '나는 명단에 있다'),
 ('is tom with you', '톰이랑 같이 있어')]

In [8]:
# Bag Of Words
SOS = 0
EOS = 1

def get_BOW(sentences):
    BOW = {"<SOS>": SOS, "<EOS>": EOS}

    for sentence in sentences:
        for word in sentence.split():
            if word not in BOW:
                BOW[word] = len(BOW)

    return BOW

In [17]:
import torch
import numpy as np
from torch.utils.data import Dataset

class Eng2Kor(Dataset):
    def __init__(self, path, device='cpu'):
        self.device = device

        self.eng_sentences = []
        self.kor_sentences = []

        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                eng, kor = line.split('\t')[:2]

                def clear(s):
                    # 전부 소문자로 바꾸고 특수 문자 제거
                    return s.lower().strip(string.punctuation)

                eng, kor = clear(eng), clear(kor)

                # 단어 수가 너무 많아지면 메모리 사용량이 너무 커져서 10단어 이하의 짧은 문장만 저장
                if len(eng.split()) <= 10 and len(kor.split()) <= 10:
                    self.eng_sentences.append(eng)
                    self.kor_sentences.append(kor)

        self.eng_BOW = get_BOW(self.eng_sentences)
        self.kor_BOW = get_BOW(self.kor_sentences)

    def gen_seq(self, sentence):
        return sentence.split() + ['<EOS>']

    def __len__(self):
        return len(self.eng_sentences)

    def __getitem__(self, idx: int):
        x = np.array([self.eng_BOW[e] for e in self.gen_seq(self.eng_sentences[idx])])
        y = np.array([self.kor_BOW[e] for e in self.gen_seq(self.kor_sentences[idx])])

        x = torch.tensor(x, dtype=torch.long).to(self.device)
        y = torch.tensor(y, dtype=torch.long).to(self.device)

        return x, y

In [18]:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Encoder, self).__init__()

        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim)

    def forward(self, x, h):
        x = self.embedding(x).reshape(1, 1, -1)
        output, h_next = self.gru(x, h)

        return output, h_next

In [19]:
class Decoder(nn.Module):
    def __init__(self, hidden_dim, output_dim, max_length=11):
        super(Decoder, self).__init__()

        self.max_length = max_length

        # output_dim -> hidden_dim?
        self.embedding = nn.Embedding(output_dim, hidden_dim)
        self.attention = nn.Linear(hidden_dim*2, self.max_length)
        self.feature_extract = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(p=0.1)
        self.gru = nn.GRU(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, h, encoder_outputs):
        x = self.embedding(x).reshape(1, 1, -1)
        x = self.dropout(x)

        concat = torch.cat((x[0], h[0]), dim=-1)
        attention_weights = self.softmax(self.attention(concat))
        attention_applied = (attention_weights @ encoder_outputs).unsqueeze(0)  # torch.bmm() 으로도 가능

        output = torch.cat((x[0], attention_applied[0]), dim=1)
        output = self.feature_extract(output).unsqueeze(0)
        output = self.relu(output)

        output, h_next = self.gru(output, h)

        output = self.out(output[0])

        return output

In [20]:
from torch.utils.data import DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataset = Eng2Kor(path=sources[0], device=device)
train_loader = DataLoader(dataset)

In [21]:
from torchsummary import summary

input_dim = len(dataset.eng_BOW)
hidden_dim = 64
output_dim = len(dataset.kor_BOW)

encoder = Encoder(input_dim, hidden_dim).to(device)
decoder = Decoder(hidden_dim, output_dim).to(device)

In [22]:
# summary(encoder, (1, 1, input_dim))

In [23]:
import torch.optim as optim

encoder_optim = optim.Adam(encoder.parameters(), lr=1e-4)
decoder_optim = optim.Adam(decoder.parameters(), lr=1e-4)
cost = nn.CrossEntropyLoss().to(device)

In [24]:
from tqdm import tqdm
from random import random

for epoch in range(200):
    total_loss = 0

    with tqdm(train_loader) as loader:
        for x, y in loader:
            encoder_h = torch.zeros(1, 1, hidden_dim).to(device)
            encoder_outputs = torch.zeros(decoder.max_length, hidden_dim).to(device)

            encoder_optim.zero_grad()
            decoder_optim.zero_grad()

            for i in range(len(x)):
                encoder_output, encoder_h = encoder(x[i], encoder_h)
                encoder_outputs[i] = encoder_output[0, 0]

            decoder_input = torch.tensor([[SOS]]).to(device)
            decoder_h = encoder_h

            use_teacher_forcing = random() < 0.5
            loss = 0

            for i in range(len(y)):
                decoder_output = decoder(decoder_input, decoder_h, encoder_outputs)

                if not use_teacher_forcing:
                    maxidx = decoder_output.argmax()
                    decoder_input = maxidx.squeeze().detach()

                target = y[i].unsqueeze(0)
                loss += cost(decoder_output, target)

                if not use_teacher_forcing:
                    if decoder_input.item() == EOS:   # <EOS>
                        break
                else:
                    decoder_input = target


            loss.backward()
            encoder_optim.step()
            decoder_optim.step()

            total_loss += loss.item() / len(dataset)

            loader.set_postfix(loss=total_loss)

torch.save(encoder.state_dict(), "attention_encoder.pth")
torch.save(decoder.state_dict(), "attention_decoder.pth")

  0%|          | 0/5684 [00:00<?, ?it/s]


RuntimeError: ignored

In [None]:
from random import randint

encoder.load_state_dict(torch.load("attention_encoder.pth", map_location=device))
decoder.load_state_dict(torch.load("attention_decoder.pth", map_location=device))

idx = randint(0, len(dataset))
input_sentence = dataset.eng_sentences[idx]
pred_sentence = ""

x, y = dataset[idx]

In [None]:
encoder_h = torch.zeros(1, 1, hidden_dim).to(device)
encoder_outputs = torch.zeros(decoder.max_length, hidden_dim).to(device)

for i in range(len(x)):
    encoder_output, encoder_h = encoder(x[i], encoder_h)
    encoder_outputs[i] = encoder_output[0, 0]

In [None]:
decoder_input = torch.tensor([[SOS]]).to(device)
decoder_h = encoder_h

for i in range(decoder.max_length):
    decoder_output = decoder(decoder_input, decoder_h, encoder_outputs)

    maxidx = decoder_output.argmax()
    decoder_output = maxidx.squeeze().detach()

    if decoder_input.item() == EOS:
        break

    pred_sentence += list(dataset.kor_BOW.keys())[decoder_input] + " "

In [None]:
print(input_sentence)
print(pred_sentence)