In [114]:
import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data


transformer和lstm最大的区别就是lstm是串行的，而transformer是并行计算的。
Transformer 使用了位置嵌入 (Positional Encoding) 来理解语言的顺序，使用自注意力机制（Self Attention Mechanism）和全连接层进行计算
在Transformer里面我们主要有两个模块，一个encoder负责编码，一个decoder负责解码，这两个coder各有6个layer
encoderlayer只有两个sublayer，而decoderlayer有3个sublayer，如下图所示
![_Z_1B_FXSWI6O3RB~F26_8Q.png](https://i.loli.net/2021/10/07/1gklxAnyU7hSFeu.png)

In [115]:
#数据初始化
# S: Symbol that shows starting of decoding input
# E: Symbol that shows starting of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is short than time steps
sentences = [
        # enc_input           dec_input         dec_output
        ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
        ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]

# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)

src_len = 5 # enc_input max sequence length
tgt_len = 6 # dec_input(=dec_output) max sequence length

# Transformer Parameters
d_model = 512  #Embedding Size (token和位置信息 的编码维度) 字嵌入和位置嵌入的维度，他俩维度相同
d_ff = 2048 #前馈神经网络的隐藏层维度是2048 512->2048->512
d_k = d_v = 64  #V Q K 维度相同
n_layers = 6  #Encoder 和 Decoder中block的个数
n_heads = 8  #Multihead 中 head的个数

In [116]:
def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
      enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
      dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
      dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

      enc_inputs.extend(enc_input)
      dec_inputs.extend(dec_input)
      dec_outputs.extend(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)


class MyDataSet(Data.Dataset):
  def __init__(self, enc_inputs, dec_inputs, dec_outputs):
    super(MyDataSet, self).__init__()
    self.enc_inputs = enc_inputs
    self.dec_inputs = dec_inputs
    self.dec_outputs = dec_outputs
  
  def __len__(self):
    return self.enc_inputs.shape[0]
  
  def __getitem__(self, idx):
    return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)

因为Transformer是一次全部输入进去，没有考虑到词语的顺序问题，例如I saw a saw,对于Transformer来说，saw 和 saw是一样的，没有考虑位置信息，
所以我们要增加一个位置嵌入的概念，也就是 Positional Encoding，位置嵌入的维度为 [max_sequence_length, embedding_dimension], 位置嵌入的维度与词向量的维度是相同的，都是 embedding_dimension。max_sequence_length 属于超参数，指的是限定每个句子最长由多少个词构成
最后我们将位置嵌入向量和词向量相加，作为新的输入，这样这个输入就考虑了位置关系，这里就是单纯的相加
位置嵌入向量的计算公式如下所示

![22P5_7_D__~08Y__Z0J7_K8.png](https://i.loli.net/2021/10/07/XfypKAhLQbDdEna.png)
上式中的pos是指一句话中某个词的位置，范围是[0,max_sequence_length),i是指这个字向量的维度序号，[0,embedding_dimension/2),d_model就是embedding_dimension的维度

贴一个链接关于位置嵌入的理解的博客

https://wmathor.com/index.php/archives/1453/

In [117]:
#根据论文里的公式进行复现，完全写死了
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

下面这个函数就是去掉句子中没用的字符

pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) 这是这里面比较核心的一句话
看seq_k里面有没有等于0的值

如果 seq_k 某个位置的值等于 0，那么对应位置就是 True，否则即为 False。举个例子，输入为 seq_data = [1, 2, 3, 4, 0]，seq_data.data.eq(0) 就会返回 [False, False, False, False, True]

In [118]:
#除掉句子中没用的字符
def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    #每一个位置上的值和0比较，不是0就是F，是0就是True，再扩展一个维度，word_emb是三维的
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], False is masked
    #batchsize 表示这里有几句话 这里encoder和decoder都会调用
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

下面这一部分的mask是只有decoder才会用到的部分，因为在训练的时候我们需要屏蔽掉未来时刻的词向量，只关注已经出现了的
所以我们需要出现一个下三角是全零的矩阵

In [119]:
def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1) #上三角不包括中轴线 保留值
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask # [batch_size, tgt_len, tgt_len]

下面的attn表示的就是a1和其他几个输入向量的关联程度，公式的主要计算是复现论文里的ScaledDotProductAttention
![_O@@GNAH_J_EK66JP__F_7D.png](https://i.loli.net/2021/10/07/Z6k5HLlhwsVXrCI.png)

Q*K 得到相关每个向量和其他向量的相关性矩阵，再把attn_mask矩阵中true的地方转换成负无穷，这样经过softmax之后才是0(dddd)
最后再乘V矩阵得到最终的矩阵

In [120]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        #乘上k的转置 变成[lenq,lenk]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k]
        #把attn_mask矩阵中为True的地方替换为-1e9
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is True.
        
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v]
        return context, attn


MultiHeadAttention的计算公式如下：

![0__0__~OSD_PMZTWTR23D_D.png](https://i.loli.net/2021/10/07/uaWNUtxYRZH4w1j.png)

MultiHeadAttention的计算步骤如下：

![XJE__O__Y`JC_Z1T6TW6KZQ.png](https://i.loli.net/2021/10/02/GrJ7ZObV8tcapMl.png)



1.在下面的部分，我们定义三个矩阵WQ,WK,WV，通过对位置词向量进行三次线性变换，我们得到三个新的向量

2.我们通过ScaledDotProductAttention计算attention和最后的输出context

3.最后我们把context传入全连接神经网络，再用层归一化进行处理。Layer Normalization 的作用是把神经网络中隐藏层归一为标准正态分布，以起到加快训练速度，加速收敛的作用

Layer Normalization 的作用：(然后用每一列的每一个元素减去这列的均值，再除以这列的标准差，从而得到归一化后的数值)

![DHR2S_1FX_ILO_@N_EVWKER.png](https://i.loli.net/2021/10/07/Su7PGsletynk9i4.png)

完整代码中一定会有三处地方调用 MultiHeadAttention()，Encoder Layer 调用一次，传入的 input_Q、input_K、input_V 全部都是 enc_inputs；Decoder Layer 中两次调用，第一次传入的全是 dec_inputs，第二次传入的分别是 dec_outputs，enc_outputs，enc_outputs

In [121]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        #下面这一步我们给他做一个维度的变换，这是为了后面计算好算
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]
        #中间增加的维度扩城到n_heads个
        #对于encoder这里的attn_mask就是去除的填充字符
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v]
        outputs = self.fc(context) # [batch_size, len_q, d_model]
        return nn.LayerNorm(d_model)(outputs + residual), attn


In [122]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model)(output + residual) # [batch_size, seq_len, d_model]

这里的话我们可以总结一下Transformer Encoder的整体结构了

![W8Y_Q0_XE5U38L_0BQ__TZU.png](https://i.loli.net/2021/10/07/TIQ2JnmflkgRrHb.png)

In [123]:
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn


In [124]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        #单词转换成输入向量的维度
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])#_ 是占位符， 表示不在意变量 的 值 只是用于循环遍历n次

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]
        #这里不可以随便命名，这一次的输出，可以当作下一次的输入
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1) # [batch_size, src_len, d_model]
        #要把输入进去的多余字符p去掉，同时保证矩阵的大小
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

In [125]:
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs) # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn


In [126]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)#单词转换成输入向量的维度
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_intpus: [batch_size, src_len]
        enc_outputs: [batsh_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1) # [batch_size, tgt_len, d_model]
        #要把输入进去的多余字符p去掉，同时保证矩阵的大小
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) # [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs) # [batch_size, tgt_len, tgt_len]
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0) # [batch_size, tgt_len, tgt_len]
        #如果为0 就是T，不是0就是F
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # [batc_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

![1H2PG_~T_P@4_5GX6XZ_9_U.png](https://i.loli.net/2021/10/07/65cZzHyMJEvd2mR.png)

In [127]:
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)#转换维度 因为我们要判断概率最大的值是哪一个
    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        '''
        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)
        
        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outpus: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [n_layers, batch_size, tgt_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs) # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        #最后这一步，.view(-1)表示把所有子列表拼成一个长列表 比如说有两句话，每句话有6个向量，这样的话就把这12个向量拼成一个数组，
        # 每个向量的大小是一个tgt_vocab_size
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns


In [128]:
model = Transformer()
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

In [129]:
for epoch in range(30):
    for enc_inputs, dec_inputs, dec_outputs in loader:
      '''
      enc_inputs: [batch_size, src_len]
      dec_inputs: [batch_size, tgt_len]
      dec_outputs: [batch_size, tgt_len]
      '''
      #enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
      enc_inputs, dec_inputs, dec_outputs = enc_inputs, dec_inputs, dec_outputs
      # outputs: [batch_size * tgt_len, tgt_vocab_size]
      outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
      loss = criterion(outputs, dec_outputs.view(-1))
      print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

Epoch: 0001 loss = 2.572361
Epoch: 0002 loss = 2.348653
Epoch: 0003 loss = 2.054852
Epoch: 0004 loss = 1.753708
Epoch: 0005 loss = 1.580555
Epoch: 0006 loss = 1.340775
Epoch: 0007 loss = 1.140786
Epoch: 0008 loss = 0.945785
Epoch: 0009 loss = 0.756428
Epoch: 0010 loss = 0.648802
Epoch: 0011 loss = 0.542791
Epoch: 0012 loss = 0.389108
Epoch: 0013 loss = 0.286723
Epoch: 0014 loss = 0.244752
Epoch: 0015 loss = 0.200260
Epoch: 0016 loss = 0.141691
Epoch: 0017 loss = 0.121220
Epoch: 0018 loss = 0.091178
Epoch: 0019 loss = 0.079356
Epoch: 0020 loss = 0.058770
Epoch: 0021 loss = 0.054729
Epoch: 0022 loss = 0.045721
Epoch: 0023 loss = 0.042151
Epoch: 0024 loss = 0.037746
Epoch: 0025 loss = 0.042066
Epoch: 0026 loss = 0.027881
Epoch: 0027 loss = 0.028147
Epoch: 0028 loss = 0.020305
Epoch: 0029 loss = 0.021631
Epoch: 0030 loss = 0.013160


In [130]:
def greedy_decoder(model, enc_input, start_symbol):
    """
    For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
    target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.
    Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
    :param model: Transformer Model
    :param enc_input: The encoder input
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 4
    :return: The target input
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)
    terminal = False
    next_symbol = start_symbol
    while not terminal:         
        dec_input=torch.cat([dec_input.detach(),torch.tensor([[next_symbol]],dtype=enc_input.dtype)],-1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:
            terminal = True
        print(next_word)            
    return dec_input



In [131]:
# Test
enc_inputs, _, _ = next(iter(loader))
for i in range(len(enc_inputs)):
    greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])
    predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)
    predict = predict.data.max(1, keepdim=True)[1]
    print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])

tensor(1)
tensor(2)
tensor(3)
tensor(4)
tensor(8)
tensor([1, 2, 3, 4, 0]) -> ['i', 'want', 'a', 'beer', '.']
tensor(1)
tensor(2)
tensor(3)
tensor(5)
tensor(8)
tensor([1, 2, 3, 5, 0]) -> ['i', 'want', 'a', 'coke', '.']
