# Sequence to Sequence Learning with Neural Networks [1]

## Part 1. 논문 정리와 모델 구현

# Summary

Deep Neural Networks (DNNs)는 입력과 출력 데이터를 고정 길이의 벡터로 인코딩하는 방식으로 음성 인식과 객체 탐지 등의 분야에서 우수한 성과를 내고 있었다. 그러나 이러한 방법론은 길이가 제각각인 시퀀스 처리에는 도입하기 어렵다는 한계가 있었다. 해당 논문에서는 일반적인 시퀀스 처리에 도입할 수 있도록 두 개의 Recurrent Neural Networks (RNNs)를 사용한 네트워크 구조를 제안하였다.

이 네트워크는 두 개의 RNNs로 이루어져 있다. 하나는 한 번에 한 단위씩 (논문에서는 단어 기준) 읽는 방식으로 여러 번에 걸쳐 하나의 시퀀스를 입력받아 고정 길이의 벡터로 인코딩한다. 다른 하나는 인코딩된 벡터를 읽고 한 번에 한 단위씩 출력하는 방식으로 하나의 시퀀스를 만들어낸다.

해당 논문에서는 Long Short-Term Memory (LSTM)가 긴 시간 동안의 의존성이 필요한 문제에서 유용하다는 점에 착안하여 네트워크의 RNN 구조로 LSTM을 선택하였다. (구체적으로는 [2]에서 제안한 LSTM을 사용하였다.)

그리고 제안한 방법을 WMT'14 English to French Machine Translation 작업에 적용하였다. 그 결과 전체 테스트 세트에서 34.8에 달하는 BLEU 점수를 얻었다.

<figure align="center">
  <img src="https://drive.google.com/uc?export=view&id=1dDZOhoDTm_m8kgyXr2qDZ67UQ2wSmU3X" width=300 />
  <figcaption>Encoder and Decoder Architecture</figcaption>
</figure>

<figure align="center">
  <img src="https://drive.google.com/uc?export=view&id=1fwFTSzQm0yqZD_rWsISIePdg4UDsUcyo" width=900 />
  <figcaption>Example of Sequence to Sequence Network Flow</figcaption>
</figure>

# Models

In [14]:
import torch
import torch.nn as nn

## LSTM Module

<figure align="center">
  <img src="https://drive.google.com/uc?export=view&id=19Au229q45Z6hCb6qMFepzRKsl_4RrRyw" width=400 />
  <figcaption>LSTM Architecture [2]</figcaption>
</figure>

$$
i_t = \sigma(W_{xi}x_t + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) \\
f_t = \sigma(W_{xf}x_t + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) \\
c_t = f_t c_{t-1} + i_t \tanh(W_{xc}x_t + W_{hc}h_{t-1} + b_c) \\
o_t = \sigma(W_{xo}x_t + W_{ho}h_{t-1} + W_{co}c_t + b_o) \\
h_t = o_t \tanh(c_t) \\
\quad\\
\text{The above equations are equivalent to:} \\
\quad\\
i_t = \sigma(W_i \cdot [x_t, h_{t-1}, c_{t-1}] + b_i) \\
f_t = \sigma(W_f \cdot [x_t, h_{t-1}, c_{t-1}] + b_f) \\
c_t = f_t c_{t-1} + i_t \tanh(W_c \cdot [x_t, h_{t-1}] + b_c) \\
o_t = \sigma(W_o \cdot [x_t, h_{t-1}, c_t] + b_o) \\
h_t = o_t \tanh(c_t)
$$
<div align="center">Equations of LSTM [2]</div>

In [15]:
class LSTMLayer(nn.Module):
  """A single LSTM layer whose gates also read the cell state.

  The input (write) and forget gates are computed from
  ``[x_t, h_{t-1}, c_{t-1}]``, the candidate cell value from
  ``[x_t, h_{t-1}]`` and the output gate from ``[x_t, h_{t-1}, c_t]``.
  The cell state has the same size as the hidden state.
  """

  def __init__(self, input_size, hidden_size, dtype=torch.float, device='cpu'):
    """Args:
        input_size: size of the last dimension of the input
        hidden_size: size of the hidden state (and of the cell state)
        dtype (optional): dtype the parameters are created with
        device (optional): device the parameters are created on
    """
    super(LSTMLayer, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.factory_kwargs = {'dtype': dtype, 'device': device}

    # cell_size == hidden_size, hence the `2 * hidden_size` for the gates
    # that read both the hidden state and the cell state.
    self.linear_write = nn.Linear(input_size + 2 * hidden_size, hidden_size,
                                  **self.factory_kwargs)
    self.linear_forget = nn.Linear(input_size + 2 * hidden_size, hidden_size,
                                   **self.factory_kwargs)
    self.linear_cell = nn.Linear(input_size + hidden_size, hidden_size,
                                 **self.factory_kwargs)
    self.linear_output = nn.Linear(input_size + 2 * hidden_size, hidden_size,
                                   **self.factory_kwargs)

  def forward(self, input, states=None):
    """Run the layer over a whole sequence, one time step at a time.

    Args:
        input: torch.Tensor, [seq_len, input_size] or
          [seq_len, batch_size, input_size]
        states (optional): a tuple of two torch.Tensor; zeros when omitted
            hidden: torch.Tensor, [hidden_size]  or [batch_size, hidden_size]
            cell: torch.Tensor, [hidden_size] or [batch_size, hidden_size]

    Return:
        output: torch.Tensor, [seq_len, hidden_size] or
            [seq_len, batch_size, hidden_size]
        states: a tuple of two torch.Tensor
            hidden: torch.Tensor, [hidden_size] or [batch_size, hidden_size]
            cell: torch.Tensor, [hidden_size] or [batch_size, hidden_size]

    Raises:
        AssertionError: if `input`, `hidden` or `cell` has an unexpected
            shape.
    """
    assert (2 <= len(input.shape) <= 3) and input.size(-1) == self.input_size, \
      "The shape of the `input` should be [seq_len, input_size] or " \
      "[seq_len, batch_size, input_size]"

    # Allocate buffers with the dtype/device of the *current* parameters
    # rather than the constructor arguments, so the layer keeps working after
    # `.to(...)`, `.double()`, `.cuda()`, etc.
    ref = self.linear_cell.weight
    buffer_kwargs = {'dtype': ref.dtype, 'device': ref.device}

    seq_len = input.size(0)
    # () for unbatched input, (batch_size,) for batched input
    batch_dims = tuple(input.shape[1:-1])
    state_shape = batch_dims + (self.hidden_size,)

    outputs = torch.zeros((seq_len,) + state_shape, **buffer_kwargs)
    if states is None:
      hidden = torch.zeros(state_shape, **buffer_kwargs)
      cell = torch.zeros(state_shape, **buffer_kwargs)
    else:
      hidden, cell = states

    assert (1 <= len(hidden.shape) <= 2) and \
      hidden.size(-1) == self.hidden_size, \
      "The shape of the `hidden` should be [hidden_size] or " \
      "[batch_size, hidden_size]"
    assert (1 <= len(cell.shape) <= 2) and \
      cell.size(-1) == self.hidden_size, \
      "The shape of the `cell` should be [hidden_size] or " \
      "[batch_size, hidden_size]"

    for step in range(seq_len):
      # x_t is [input_size] or [batch_size, input_size]
      x_t = input[step]

      # Input and forget gates see the previous hidden and cell states.
      gate_in = torch.cat((x_t, hidden, cell), dim=-1)
      write = torch.sigmoid(self.linear_write(gate_in))
      forget = torch.sigmoid(self.linear_forget(gate_in))

      # The candidate cell value does not read the cell state.
      candidate = torch.tanh(
        self.linear_cell(torch.cat((x_t, hidden), dim=-1)))
      cell = forget * cell + write * candidate

      # The output gate sees the freshly updated cell state.
      gate_in = torch.cat((x_t, hidden, cell), dim=-1)
      output = torch.sigmoid(self.linear_output(gate_in))
      hidden = output * torch.tanh(cell)
      outputs[step] = hidden

    return outputs, (hidden, cell)

In [16]:
class LSTM(nn.Module):
  """A stack of `LSTMLayer`s; each layer consumes the previous one's output."""

  def __init__(self, input_size, hidden_size, num_layers, dtype=torch.float, device='cpu'):
    """Args:
        input_size: size of the last dimension of the input
        hidden_size: size of the hidden state of every layer
        num_layers: number of stacked layers
        dtype (optional): dtype the parameters are created with
        device (optional): device the parameters are created on
    """
    super(LSTM, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.factory_kwargs = {'dtype': dtype, 'device': device}

    # The first layer maps input_size -> hidden_size; the remaining layers map
    # hidden_size -> hidden_size.
    layers = [LSTMLayer(input_size, hidden_size, **self.factory_kwargs)] + \
      [LSTMLayer(hidden_size, hidden_size, **self.factory_kwargs)
      for _ in range(num_layers - 1)]
    self.layers = nn.ModuleList(layers)

  def forward(self, input, states=None):
    """Run the whole stack over a sequence.

    Args:
        input: torch.Tensor, [seq_len, input_size] or
          [seq_len, batch_size, input_size]
        states (optional): a tuple of two torch.Tensor; zeros when omitted
            hidden: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]

    Return:
        output: torch.Tensor, [seq_len, hidden_size] or
            [seq_len, batch_size, hidden_size]
        states: a tuple of two torch.Tensor
            hidden: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]

    Raises:
        AssertionError: if `input`, `hidden` or `cell` has an unexpected
            shape.
    """
    assert (2 <= len(input.shape) <= 3) and input.size(-1) == self.input_size, \
      "The shape of the `input` should be [seq_len, input_size] or " \
      "[seq_len, batch_size, input_size]"

    if states is None:
      # Use the dtype/device of the *current* parameters rather than the
      # constructor arguments, so the default states stay valid after the
      # module has been moved or cast with `.to(...)`.
      ref = next(self.parameters())
      # input.shape[1:-1] is () when unbatched and (batch_size,) when batched
      state_shape = (self.num_layers,) + tuple(input.shape[1:-1]) + \
        (self.hidden_size,)
      hiddens = torch.zeros(state_shape, dtype=ref.dtype, device=ref.device)
      cells = torch.zeros(state_shape, dtype=ref.dtype, device=ref.device)
    else:
      hiddens, cells = states

    assert (2 <= len(hiddens.shape) <= 3) and \
      hiddens.size(0) == self.num_layers and \
      hiddens.size(-1) == self.hidden_size, \
      "The shape of the `hidden` should be [num_layers, hidden_size] or " \
      "[num_layers, batch_size, hidden_size]"
    assert (2 <= len(cells.shape) <= 3) and \
      cells.size(0) == self.num_layers and \
      cells.size(-1) == self.hidden_size, \
      "The shape of the `cell` should be [num_layers, hidden_size] or " \
      "[num_layers, batch_size, hidden_size]"

    # Final states are written into fresh buffers so the caller's `states`
    # tensors are never modified.
    next_hiddens = torch.zeros_like(hiddens)
    next_cells = torch.zeros_like(cells)

    output = input
    for i in range(self.num_layers):
      # hidden and cell are [hidden_size] or [batch_size, hidden_size]
      output, (hidden, cell) = self.layers[i](output, (hiddens[i], cells[i]))
      next_hiddens[i], next_cells[i] = hidden, cell

    return output, (next_hiddens, next_cells)

## Encoder Module

In [17]:
class Encoder(nn.Module):
  """Embeds a sequence of token indices and feeds it through an LSTM stack."""

  def __init__(self, input_size, embed_size, hidden_size, num_rnn_layers,
               padding_index, dtype=torch.float, device='cpu'):
    """Args:
        input_size: number of rows of the embedding table
        embed_size: size of each embedding vector
        hidden_size: hidden size of the recurrent layers
        num_rnn_layers: number of stacked recurrent layers
        padding_index: index passed to the embedding as its padding index
        dtype (optional): dtype the parameters are created with
        device (optional): device the parameters are created on
    """
    super().__init__()
    self.factory_kwargs = {'dtype': dtype, 'device': device}
    self.input_size = input_size
    self.embed_size = embed_size
    self.hidden_size = hidden_size
    self.num_rnn_layers = num_rnn_layers

    self.embedding = nn.Embedding(
      input_size, embed_size, padding_index, **self.factory_kwargs)
    self.rnn = LSTM(
      embed_size, hidden_size, num_rnn_layers, **self.factory_kwargs)

  def forward(self, input, states=None):
    """Encode a source sequence.

    Args:
        input: torch.Tensor, [seq_len] or [seq_len, batch_size]
        states (optional): a tuple of two torch.Tensor
            hidden: torch.Tensor, [num_rnn_layers, hidden_size] or
                [num_rnn_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_rnn_layers, hidden_size] or
                [num_rnn_layers, batch_size, hidden_size]

    Return:
        output: torch.Tensor, [seq_len, hidden_size] or
            [seq_len, batch_size, hidden_size]
        states: a tuple of two torch.Tensor
            hidden: torch.Tensor, [num_rnn_layers, hidden_size] or
                [num_rnn_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_rnn_layers, hidden_size] or
                [num_rnn_layers, batch_size, hidden_size]
    """
    embedded = self.embedding(input)
    # Only forward `states` when the caller supplied them; otherwise let the
    # recurrent module fall back to its own default initial states.
    if states is None:
      output, final_states = self.rnn(embedded)
    else:
      output, final_states = self.rnn(embedded, states)
    hidden, cell = final_states
    return output, (hidden, cell)

## Decoder Module

In [18]:
class Decoder(nn.Module):
  """Generates an output sequence one token at a time from given states.

  Each step embeds the current token, runs it through the LSTM stack and
  projects the result to `output_size` scores with a linear layer.
  """

  def __init__(self, embed_size, hidden_size, output_size, num_rnn_layers,
               padding_index, dtype=torch.float, device='cpu'):
    """Args:
        embed_size: size of each embedding vector
        hidden_size: hidden size of the recurrent layers
        output_size: number of output classes; also the number of rows of
            the embedding table, because outputs are fed back as inputs
        num_rnn_layers: number of stacked recurrent layers
        padding_index: index passed to the embedding as its padding index
        dtype (optional): dtype the parameters are created with
        device (optional): device the parameters are created on
    """
    super(Decoder, self).__init__()
    self.embed_size = embed_size
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.num_rnn_layers = num_rnn_layers
    self.factory_kwargs = {'dtype': dtype, 'device': device}

    input_size = output_size
    self.embedding = nn.Embedding(input_size, embed_size, padding_index,
                                  **self.factory_kwargs)
    self.rnn = LSTM(embed_size, hidden_size, num_rnn_layers,
                    **self.factory_kwargs)
    self.linear = nn.Linear(hidden_size, output_size, **self.factory_kwargs)
    # NOTE(review): is a linear projection the right choice here?

  def forward(self, input, states=None, beam_size=1, max_len=50,
              teacher_forcing_ratio=0.):
    """Decode greedily, optionally with teacher forcing while training.

    Args:
        input: torch.Tensor, [seq_len] or [seq_len, batch_size]; `input[0]`
            is the first token fed to the decoder
        states: a tuple of two torch.Tensor (required despite the default)
            hidden: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
        beam_size (optional): only 1 (greedy decoding) is supported
        max_len (optional): number of output positions; replaced by
            `input.size(0)` while the module is in training mode
        teacher_forcing_ratio (optional): a float number between 0 and 1,
            the per-step probability of feeding the target token instead of
            the predicted one; only used in training mode

    Return:
        output: torch.Tensor, [max_len, output_size] or
            [max_len, batch_size, output_size]; position 0 is left as zeros
        states: a tuple of two torch.Tensor
            hidden: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]

    Raises:
        NotImplementedError: if `beam_size` is not 1.
        AssertionError: if `states` is None.
    """
    #TODO: sample until all rows have more than one EOS
    #TODO: implement forward with beam search
    if beam_size != 1:
      raise NotImplementedError()

    assert states is not None, "You should give hidden states and cell " \
      "states into the decoder"
    hidden, cell = states

    # While training, the target sequence dictates the output length.
    if self.training:
      max_len = input.size(0)

    targets = input
    batch_dims = tuple(targets.shape[1:])  # () or (batch_size,)
    step_shape = (1,) + batch_dims         # [1] or [1, batch_size]

    # Allocate with the dtype/device of the *current* parameters so the
    # buffer stays compatible after the module is moved with `.to(...)`.
    ref = self.linear.weight
    outputs = torch.zeros((max_len,) + batch_dims + (self.output_size,),
                          dtype=ref.dtype, device=ref.device)

    token = targets[0].view(step_shape)
    for i in range(1, max_len):
      embedded = self.embedding(token)
      output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
      scores = self.linear(output)
      outputs[i] = scores.view(outputs.shape[1:])

      # `torch.rand` is uniform on [0, 1), so the comparison is true with
      # probability `teacher_forcing_ratio`. (A normal sample, as produced by
      # `torch.randn`, would enable teacher forcing about half of the time
      # even with a ratio of 0.)
      use_teacher_forcing = self.training and \
        torch.rand(1).item() < teacher_forcing_ratio
      if use_teacher_forcing:
        token = targets[i].view(step_shape)
      else:
        # feed back the highest-scoring class as the next input
        token = scores.argmax(-1).view(step_shape)

    return outputs, (hidden, cell)

## A Whole Seq2Seq Module

In [19]:
class Seq2SeqNetwork(nn.Module):
  """Encoder-decoder network: the encoder's final states seed the decoder."""

  def __init__(self, input_size, embed_size, hidden_size, output_size,
               num_rnn_layers, padding_index, dtype=torch.float, device='cpu'):
    """Args:
        input_size: passed to the encoder as its `input_size`
        embed_size: embedding size shared by the encoder and the decoder
        hidden_size: hidden size shared by the encoder and the decoder
        output_size: passed to the decoder as its `output_size`
        num_rnn_layers: number of recurrent layers in each of the two parts
        padding_index: padding index shared by both embeddings
        dtype (optional): dtype the parameters are created with
        device (optional): device the parameters are created on
    """
    super().__init__()
    self.factory_kwargs = {'dtype': dtype, 'device': device}
    self.input_size = input_size
    self.embed_size = embed_size
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.num_rnn_layers = num_rnn_layers

    self.encoder = Encoder(
      input_size, embed_size, hidden_size, num_rnn_layers, padding_index,
      **self.factory_kwargs)
    self.decoder = Decoder(
      embed_size, hidden_size, output_size, num_rnn_layers, padding_index,
      **self.factory_kwargs)

  def forward(self, src, trg, beam_size=1, max_len=50,
              teacher_forcing_ratio=0.):
    """Encode `src`, then decode starting from the encoder's final states.

    Args:
        src: torch.Tensor, [src_len] or [src_len, batch_size]
        trg: torch.Tensor, [trg_len] or [trg_len, batch_size]
        beam_size (optional): a non-negative integer
        max_len (optional): a non-negative integer
        teacher_forcing_ratio (optional): a float number between 0 and 1

    Return:
        output: torch.Tensor, [trg_len, output_size] or
            [trg_len, batch_size, output_size]
    """
    # The encoder's per-step outputs are not used, only its final states.
    encoder_states = self.encoder(src)[1]
    hidden, cell = encoder_states
    decoded = self.decoder(trg, (hidden, cell), beam_size, max_len,
                           teacher_forcing_ratio=teacher_forcing_ratio)
    return decoded[0]

  def encode(self, input, states=None):
    """Run only the encoder.

    Args:
        input: torch.Tensor, [seq_len] or [seq_len, batch_size]
        states (optional): a tuple of two torch.Tensor
            hidden: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]

    Return:
        output: torch.Tensor, [seq_len, hidden_size] or
            [seq_len, batch_size, hidden_size]
        states: a tuple of two torch.Tensor
            hidden: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
    """
    encoded = self.encoder(input, states)
    return encoded

  def decode(self, input, states=None, beam_size=1, max_len=50,
            teacher_forcing_ratio=0.):
    """Run only the decoder and drop its final states.

    Args:
        input: torch.Tensor, [seq_len] or [seq_len, batch_size]
        states: a tuple of two torch.Tensor, forwarded to the decoder as is
            hidden: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
            cell: torch.Tensor, [num_layers, hidden_size] or
                [num_layers, batch_size, hidden_size]
        beam_size (optional): a non-negative integer
        max_len (optional): a non-negative integer
        teacher_forcing_ratio (optional): a float number between 0 and 1

    Return:
        output: torch.Tensor, [max_len, output_size] or
            [max_len, batch_size, output_size]
    """
    decoded = self.decoder(input, states, beam_size, max_len,
                           teacher_forcing_ratio)
    return decoded[0]

# References

[1] Sequence to Sequence Learning with Neural Networks [[link]](
https://doi.org/10.48550/arXiv.1409.3215)

[2] Generating Sequences With Recurrent Neural Networks [[link]](
https://doi.org/10.48550/arXiv.1308.0850)