<a href="https://colab.research.google.com/github/dhdlswhd34/PlayGround/blob/main/train/seq2seq_attn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import unicode_literals, print_function, division
import torch.nn.functional as F
from torch import optim
import torch.nn as nn
from io import open
import unicodedata
import random
import string
import torch
import re


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SOS_token = 0
EOS_token = 1

In [3]:
ROOT_PATH = '/content/drive/MyDrive/Colab Notebooks'

In [4]:
class Lang:
  def __init__(self, name):

    self.name = name
    self.word2idx, self.word2cnt = {}, {}
    self.idx2word = {0 : 'SOS', 1 : 'EOS'}
    self.n_words = 2

  def add_sentence(self, senteces):
    for word in senteces.split(' '):
      self.add_word(word)

  def add_word(self, word):
    if word not in self.word2idx:
      self.word2idx[word] = self.n_words
      self.word2cnt[word] = 1
      self.idx2word[self.n_words] = word
      self.n_words += 1

    else:
      self.word2cnt[word] += 1

In [5]:
def preprocessing(string):

  ## 문자열 공백제거 ##
  string = string.strip()
  string = re.sub(r'[^ ㄱ-ㅣ가-힣.!?]+', r" ", string)
  return string


In [6]:
def read_texts(text_path, reverse = False):
  lines = open(text_path, encoding='utf-8').read().strip().split('\n')
  pairs = [[preprocessing(string) for string in line.split('\t\t\t')] for line in lines]

  if reverse:
    pairs = [list(reversed(pair)) for pair in pairs]
    input_corpus = Lang('dialect')
    output_corpus = Lang('standard')

  else:
    input_corpus = Lang('standard')
    output_corpus = Lang('dialect')

  return input_corpus, output_corpus, pairs

In [7]:
def prepare_dataset(text_path, reverse = False):
  input_corpus, output_corpus, pairs = read_texts(text_path, reverse)
  for pair in pairs:
    input_corpus.add_sentence(pair[0])
    output_corpus.add_sentence(pair[1])

  print(input_corpus.name, input_corpus.n_words)
  print(output_corpus.name, output_corpus.n_words)

  return input_corpus, output_corpus, pairs

text_path = f'{ROOT_PATH}/dataset/corpuses.txt'
dialect, standard, pairs = prepare_dataset(text_path, True)

dialect 266263
standard 251671


In [8]:
random_pair = random.choice(pairs)
print(random_pair)

['왜냐면 나는 어른들을 대할 때 쫌 어려워하는 펀인데', '왜냐면 나는 어른들을 대할 때 조금 어려워하는 펀인데']


In [9]:
class Encoder(nn.Module):

  def __init__(self, input_size, hidden_size):
    super(Encoder, self).__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(input_size, hidden_size)
    self.gru = nn.GRU(hidden_size, hidden_size)

  def forward(self, input, hidden):
    embedded = self.embedding(input).view(1, 1, -1)
    output = embedded

    output, hidden = self.gru(output, hidden)
    return output, hidden

  def init_hidden(self):
    return torch.zeros(1, 1, self.hidden_size, device=device)

In [10]:
class Decoder(nn.Module):

  def __init__(self, hidden_size, output_size):
    super(Decoder, self).__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(output_size, hidden_size)
    self.gru = nn.GRU(hidden_size, hidden_size)
    self.out = nn.Linear(hidden_size, output_size)

    self.softmax = nn.LogSoftmax(dim = 1)

  def forward(self, input, hidden):
    output = self.embedding(input).view(1, 1, -1)
    output = F.relu(output)
    output, hidden = self.gru(output, hidden)
    output = self.softmax(self.out(output[0]))
    return output, hidden_size

  def init_hidden(self):
    return torch.zeros(1, 1, self.hidden_size, device = device)

In [11]:
class AttnDecoder(nn.Module):

  def __init__(self, hidden_size, output_size, dropout_p = 0.1, max_length = 10):
    super(AttnDecoder, self).__init__()
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.dropout_p = dropout_p
    self.max_length = max_length

    self.embedding = nn.Embedding(self.output_size, self.hidden_size)
    self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
    self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
    self.dropout = nn.Dropout(self.dropout_p)
    self.gru = nn.GRU(self.hidden_size, self.hidden_size)
    self.out = nn.Linear(self.hidden_size, self.output_size)

  def forward(self, input, hidden, encoder_output):
    embedded = self.embedding(input).view(1, 1, -1)
    embedded = self.dropout(embedded)

    attn_weights = F.softmax(self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
    attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_output.unsqueeze(0))

    output = torch.cat((embedded[0], attn_applied[0]), 1)
    output, hidden = self.gru(output, hidden)

    output = F.log_softmax(self.out(output[0]), dim=1)
    return output, hidden, attn_weights


  def init_hidden(self):
    return torch.zeros(1, 1, self.hidden_size, device = device)

In [12]:
def idxs_from_sentence(lang, sentence):
  return [lang.word2idx[word] for word in sentence.split(' ')]

def tensor_from_sentence(lang, sentence):
  idxs = idxs_from_sentence(lang,sentence)
  idxs.append(EOS_token)
  return torch.tensor(idxs, dtype=torch.long, device=device).view(-1, 1)

def tensors_from_pair(pair):
  standard_tensor = tensor_from_sentence(standard, pair[1])
  dialect_tensor = tensor_from_sentence(dialect, pair[0])

  return (dialect_tensor, standard_tensor)


print(pairs[5][0], idxs_from_sentence(dialect, pairs[5][1]))

그러면 사실은 우리는 뭐 없어서 못 묵지. [46, 47, 48, 49, 50, 3, 21226]


In [23]:
teacher_forcing_ratio = 0.5

def train(input_tensor, label_tensor, encoder, decoder, encoder_optim,
          decoder_optim, loss_func, max_length = 10):
  
  encoder_hidden = encoder.init_hidden()

  encoder_optim.zero_grad()
  decoder_optim.zero_grad()

  input_length = input_tensor.size(0)
  label_length = label_tensor.size(0)

  encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device = device)
  loss = 0

  for ei in range(input_length):
    encoder_output, encoder_hidden = encoder(
        input_tensor[ei], encoder_hidden)
    encoder_outputs[ei] = encoder_output[0, 0]

  decoder_input = torch.tensor([[[SOS_token]]], device=device)
  decoder_hidden = decoder.init_hidden()

  print(f'\n\n\n {decoder_input.dim()} \n\n\n')
  use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

  ## teacher_forcing을 사용할 경우에는 레이블 값을 다음 입력으로 전달
  if use_teacher_forcing:
    for di in range(label_length):
      decoder_output, decoder_hidden, decoder_attn = decoder(
          decoder_input, decoder_hidden, encoder_outputs)
      loss += loss_func(decoder_output, label_tensor[di])

  ## teacher_forcing을 사용하지 않을 경우에는 자신의 예측을 다음 입력으로 전달
  else:
    for di in range(label_length):
      decoder_output, decoder_hidden, decoder_attn = decoder(
          decoder_input, decoder_hidden, encoder_outputs)
      topv, topi = decoder_output.topk(1)
      decoder_input = topi.squeeze().detach()

      loss += loss_func(decoder_output, label_tensor[di])
      if decoder_input.item() == EOS_token: break
  
  loss.backward()
  encoder_optim.step()
  decoder_optim.step()

  return loss.item() / label_length

In [24]:
import time, math

def as_minutes(seconds):
  minute = math.floor(secods / 60)
  seconds -= minute * 60
  return f'{minute}min {seconds}sec'

def time_since(since, percent):
  now = time.time()
  spend = now - since
  es = spend / percent
  rs = es - spend
  return f'{as_minutes(spend)} (- {as_minutes(rs)})'

In [25]:
def train_iters(encoder, decoder, epochs, lr = 1e-2, print_every = 1000, plot_every = 100):
  start = time.time()

  plot_losses = []
  print_loss_total = 0
  plot_loss_total = 0

  encoder_optim = optim.SGD(encoder.parameters(), lr = lr)
  decoder_optim = optim.SGD(decoder.parameters(), lr = lr)

  training_pairs = [tensors_from_pair(random.choice(pairs)) 
                    for _ in range(epochs)]
  print(training_pairs)
  loss_func = nn.NLLLoss()

  for epoch in range(1, epochs + 1):
    print(f'======= {epoch} start ====')
    training_pair = training_pairs[epoch - 1]
    input_tensor = training_pair[0]
    print(training_pair, type(input_tensor))
    label_tensor = training_pair[1]

    loss = train(input_tensor, label_tensor, encoder, decoder,
                 encoder_optim, decoder_optim, loss_func)
    
    if epoch % print_every == 0:
      print_loss_avg = print_loss_total / print_every
      print_loss_total = 0
      print(f'{time_since(start, epoch / epochs)} ({epoch} {epoch *100/ epochs}%) {print_loss_avg:.4f}')

    
    if epoch % plot_every == 0:
      plot_loss_avg = plot_loss_total / plot_every
      plot_losses.append(plot_loss_avg)
      plot_loss_total = 0

    show_plot(plot_losses)


In [26]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.switch_backend('agg')

def show_plot(points):
  plt.figure()
  fig, ax = plt.subplots()

  loc = ticker.MultipleLocator(base=0.2)
  ax.yaxis.set_major_locator(loc)
  plt.plot(points)

In [27]:
hidden_size = 256
encoder1 = Encoder(dialect.n_words, hidden_size).to(device)
attn_decoder1 = AttnDecoder(hidden_size, standard.n_words, dropout_p=0.1).to(device)

train_iters(encoder1, attn_decoder1, 75000, print_every = 5000)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



IndexError: ignored

In [21]:
  decoder_input = torch.tensor([[[SOS_token]]], device=device)
  decoder_input.dim()

3

In [29]:
torch.zeros(10, 1, device = device)

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')