### Seq2Seq

<img src="https://wikidocs.net/images/page/24996/인코더디코더모델.PNG" />

#### Teacher forcing

We can use ground truth as a decoder's input for parallelization. Otherwise the output value of a decoder is used as the next input by default.

In [372]:
import torch
import torch.nn as nn
import random

input_dim = 100
hidden_dim = 200
output_dim = 100

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, batch_first=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_first = batch_first

        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=batch_first)
    
    def forward(self, x):
        _, hx = self.rnn(x)
        return hx

class Decoder(nn.Module): # actually a generator
    def __init__(self, input_dim, hidden_dim, output_dim, batch_first=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.batch_first = batch_first
    
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=batch_first)
        self.linear = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x, hx):
        x, hx = self.rnn(x, hx)
        output = self.linear(hx)
        return output, x, hx

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, batch_first=False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, source, target, teacher_forcing_ratio=0.8):
        x = target[0].unsqueeze(0)
        hx = self.encoder(source)

        predicts = []
        for i in range(len(target)):
            output, x, hx = self.decoder(x, hx)
            predicts.append(output)

            if random.random() < teacher_forcing_ratio:
                x = target[i].unsqueeze(0)
            else:
                x = output
        predicts = torch.concat(predicts, dim=0)
        return predicts


encoder = Encoder(input_dim=100, hidden_dim=200)
decoder = Decoder(input_dim=100, hidden_dim=200, output_dim=100)

seq2seq = Seq2Seq(encoder, decoder)

source = torch.randn(56, 128, 100)
target = torch.randn(28, 128, 100)

seq2seq(source, target).shape

torch.Size([28, 128, 100])

### Beam search

$$
\argmax_{y} \prod_{t=1}^{T_y} p(y^{<t>} | x, y^{<1>}, \cdots, y^{<t-1>}) \\

\argmax_{y} \frac{1}{T_y^{\alpha}} \sum_{t=1}^{T_y} \log p(y^{<t>} | x, y^{<1>}, \cdots, y^{<t-1>}) \quad \text{where,}\ 0 \leq \alpha \leq 1
$$

In [373]:
import random
import numpy as np

chars = list(map(chr, range(65,91)))

def model(x: str) -> np.array:
    np.random.seed(hash(x) % 10**9)
    x = np.random.randn(len(chars))
    x = np.exp(x) / np.exp(x).sum()
    np.random.seed()
    return x # random softmax (proba) from the input x

# imagine something like a genetic algorithm selecting the top-k for each iterations
def beam(breadth=3, depth=10, log=True, alpha=1):
    choices = [("", np.log(1.0))] if log else [("", 1.0)]
    for d in range(depth):
        candidates = []
        for seq, proba in choices:
            preds = model(seq)
            if log:
                candidates.extend(zip([seq+c for c in chars], proba+np.log(preds)))
            else:
                candidates.extend(zip([seq+c for c in chars], proba*preds))
        candidates = sorted(candidates, key=lambda x: x[1], reverse=True)
        choices = candidates[:breadth]
    if log:
        choices = [(seq, 1/depth**alpha * proba) for seq, proba in choices]
    return choices

# for b in range(1, len(chars)+1):
#     print(f"{b = :>2}", beam(breadth=b, depth=20, log=False)[0])

for b in range(1, len(chars)+1):
    print(f"{b = :>2}", beam(breadth=b, depth=20, log=True)[0])

b =  1 ('YASANOGDGDHQFPETDNRU', -1.798972939125831)
b =  2 ('YASACGAWCMNQGNZOVDTS', -1.6200365243245871)
b =  3 ('YASACGABGCHKPRBIELIR', -1.630091767739598)
b =  4 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b =  5 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b =  6 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b =  7 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b =  8 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b =  9 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 10 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 11 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 12 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 13 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 14 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 15 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 16 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 17 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 18 ('YASACGAWCMNCNAODDOIA', -1.5992422080213697)
b = 19 ('YASACGAWCMNCNAODDOIA', -1.5992422080213

#### Error analysis on beam search

Human ($y^*$): "Jane visits Africa in September." \
Model ($\hat y$): "Jane visited Africa last September."

$$
\begin{cases}
p(y^*|x) > p(\hat y|x) \rightarrow \text{beam search is at fault} \\
p(y^*|x) \leq p(\hat y|x) \rightarrow \text{model is at fault} \\
\end{cases}
$$