<a href="https://colab.research.google.com/github/cnhzgb/MachineL/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [133]:
!pip install ipdb
import torch
from torch import nn
from torch import optim
from torch.utils import data as Data
import numpy as np
import ipdb



In [134]:
# Model hyperparameters, read as module-level globals by the classes below.
d_model = 6      # embedding size
max_len = 1024   # maximum sequence length
d_ff = 2048      # hidden dimension of the feed-forward neural network
d_k = 2          # per-head dimension of keys (same as queries)
d_v = 2          # per-head dimension of values
n_heads = 2      # number of heads in multi-head attention
p_drop = 0.1     # dropout probability

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [135]:
def get_attn_pad_mask(seq_q, seq_k):
  """Build a boolean mask that hides padded key positions from every query.

  Args:
    seq_q: 2-D token-id tensor [batch, len_q]; only its size is read.
    seq_k: 2-D token-id tensor [batch, len_k]; id 0 is treated as padding.

  Returns:
    Bool tensor [batch, len_q, len_k] that is True wherever the key token
    equals 0, i.e. wherever attention to that key should be suppressed.
  """
  _, len_q = seq_q.size()
  batch, len_k = seq_k.size()
  # Mark the padded keys once, then broadcast the same row to every query.
  key_is_pad = (seq_k.data == 0)[:, None, :]  # [batch, 1, len_k]
  return key_is_pad.expand(batch, len_q, len_k)

def get_attn_subsequent_mask(seq):
  """Build the look-ahead mask that hides future positions.

  The mask is created directly on ``seq.device`` so it can be combined with
  other masks derived from ``seq`` without a CPU/GPU device mismatch.

  Args:
    seq: tensor whose first two dimensions are [batch, target_len]; only its
      size and device are read.

  Returns:
    float64 tensor [batch, target_len, target_len] holding 1 strictly above
    the main diagonal (position j > i is hidden from query i) and 0 elsewhere.
  """
  batch, target_len = seq.size(0), seq.size(1)
  ones = torch.ones(batch, target_len, target_len,
                    dtype=torch.float64, device=seq.device)
  # diagonal=1 keeps only the entries strictly above the main diagonal of
  # each [target_len, target_len] matrix in the batch.
  return torch.triu(ones, diagonal=1)

In [136]:
class ScaledDotProductAttention(nn.Module):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""

  def __init__(self):
    super(ScaledDotProductAttention, self).__init__()

  def forward(self, Q, K, V, attn_mask):
    """Attend from queries to keys and return the weighted sum of values.

    Args:
      Q: queries, [batch, n_heads, len_q, d_k].
      K: keys, [batch, n_heads, len_k, d_k].
      V: values, [batch, n_heads, len_k, d_v].
      attn_mask: bool tensor broadcastable to [batch, n_heads, len_q, len_k];
        True marks positions that must not receive attention.

    Returns:
      Tensor [batch, n_heads, len_q, d_v].
    """
    # Scale by the actual key width instead of a module-level constant so the
    # layer stays correct for any head dimension.
    head_dim = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-1, -2)) / (head_dim ** 0.5) # [batch, n_heads, len_q, len_k]
    # Out-of-place fill: masked positions get a large negative score so their
    # softmax weight is effectively zero. A row that is masked everywhere
    # ends up with uniform weights.
    scores = scores.masked_fill(attn_mask, -1e9)

    attn = torch.softmax(scores, dim=-1) # [batch, n_heads, len_q, len_k]
    return torch.matmul(attn, V) # [batch, n_heads, len_q, d_v]


In [137]:
class MultiHeadAttention(nn.Module):
  """Multi-head attention followed by a residual connection and LayerNorm.

  Head count and widths come from the module-level ``n_heads``, ``d_k``,
  ``d_v`` and ``d_model`` settings.
  """

  def __init__(self):
    super().__init__()
    self.n_heads = n_heads
    # Bias-free projections into the per-head query / key / value spaces.
    self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
    self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
    self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
    # Projects the concatenated heads back to the model width.
    self.fc = nn.Linear(d_v * n_heads, d_model, bias=False)
    self.layer_norm = nn.LayerNorm(d_model)

  def forward(self, input_Q, input_K, input_V, attn_mask):
    """Run attention over all heads.

    Args:
      input_Q: [batch, len_q, d_model]; also used as the residual.
      input_K: [batch, len_k, d_model].
      input_V: [batch, len_k, d_model].
      attn_mask: [batch, len_q, len_k]; True marks positions to hide.

    Returns:
      LayerNorm(input_Q + projected attention output), [batch, len_q, d_model].
    """
    batch = input_Q.size(0)

    def split_heads(projected, head_dim):
      # [batch, len, n_heads * head_dim] -> [batch, n_heads, len, head_dim]
      return projected.view(batch, -1, n_heads, head_dim).transpose(1, 2)

    Q = split_heads(self.W_Q(input_Q), d_k)
    K = split_heads(self.W_K(input_K), d_k)
    V = split_heads(self.W_V(input_V), d_v)

    # Every head sees the same mask: [batch, n_heads, len_q, len_k].
    per_head_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

    context = ScaledDotProductAttention()(Q, K, V, per_head_mask)

    # Concatenate the heads again: [batch, len_q, n_heads * d_v].
    merged = context.transpose(1, 2).contiguous().view(batch, -1, n_heads * d_v)

    return self.layer_norm(input_Q + self.fc(merged))


In [138]:
class FeedForwardNetwork(nn.Module):
  '''
  Position-wise feed-forward sublayer with residual connection and LayerNorm.

  The two projections are implemented with kernel-size-1 nn.Conv1d layers
  instead of nn.Linear.
  '''
  def __init__(self):
    super().__init__()
    self.ff1 = nn.Conv1d(d_model, d_ff, 1)   # expand: d_model -> d_ff
    self.ff2 = nn.Conv1d(d_ff, d_model, 1)   # contract: d_ff -> d_model
    self.relu = nn.ReLU()

    # NOTE(review): this dropout layer is constructed but never applied in
    # forward(); confirm whether that is intentional.
    self.dropout = nn.Dropout(p=p_drop)
    self.layer_norm = nn.LayerNorm(d_model)

  def forward(self, x):
    '''
    x: [batch, seq_len, d_model]
    returns: LayerNorm(x + FFN(x)), [batch, seq_len, d_model]
    '''
    # Conv1d expects channels first, so swap to [batch, d_model, seq_len]
    # on the way in and back to [batch, seq_len, d_model] on the way out.
    hidden = self.relu(self.ff1(x.transpose(1, 2)))
    projected = self.ff2(hidden).transpose(1, 2)

    return self.layer_norm(x + projected)

In [139]:
# Toy parallel corpus: one Chinese -> English sentence pair.
# 'P' (id 0) is the padding / end-of-sentence token, 'S' is the start token.
# Token ids are kept contiguous (0 .. len(vocab) - 1) so that an embedding
# table with len(vocab) rows covers every id; the previous id 5 for 'S' was
# out of range for a 5-entry vocabulary.
source_vocab = {'P' : 0, '我' : 1, '喜欢' : 2, '苹果' : 3}
target_vocab = {'P' : 0, 'i' : 1, 'like' : 2, 'apple' : 3, 'S' : 4}

encoder_input = torch.LongTensor([[1,2,3,0]]).to(device) # 我 喜欢 苹果 P; P marks the end
decoder_input = torch.LongTensor([[4,1,2,3]]).to(device) # S i like apple; the target shifted right by one for parallel training
decoder_output = torch.LongTensor([[1,2,3,0]]).to(device) # i like apple P; P marks the end

In [140]:
class Encoder(nn.Module):
  """Single-layer encoder: token embedding -> self-attention -> feed-forward."""

  def __init__(self):
    super().__init__()
    self.source_embedding = nn.Embedding(len(source_vocab), d_model)
    self.attention = MultiHeadAttention()
    self.ffn = FeedForwardNetwork()

  def forward(self, encoder_input):
    """Encode a batch of source token ids.

    Args:
      encoder_input: integer tensor [batch, source_len] of source token ids.

    Returns:
      Tensor [batch, source_len, d_model].
    """
    # Map every integer token id to a d_model-wide float vector.
    tokens = self.source_embedding(encoder_input)
    # [batch, source_len, source_len]; True in the columns of padding tokens
    # (id 0) so no query attends to them.
    pad_mask = get_attn_pad_mask(encoder_input, encoder_input)
    attended = self.attention(tokens, tokens, tokens, pad_mask)
    return self.ffn(attended)

class Decoder(nn.Module):
  """Single-layer decoder: masked self-attention, encoder attention, feed-forward."""

  def __init__(self):
    super(Decoder, self).__init__()
    # Size the table by the largest token id rather than len(): the ids in
    # target_vocab are not guaranteed to be contiguous, and an id >= the
    # table size makes nn.Embedding raise an index error.
    vocab_size = max(target_vocab.values(), default=-1) + 1
    self.target_embedding = nn.Embedding(vocab_size, d_model)
    # NOTE(review): the same attention module (same weights) serves both the
    # self-attention and the encoder-decoder attention step below; confirm
    # the sharing is intentional.
    self.attention = MultiHeadAttention()
    self.ffn = FeedForwardNetwork()

  def forward(self, decoder_input, encoder_input, encoder_output):
    """Decode one batch.

    Args:
      decoder_input: integer tensor [batch, target_len] of target token ids.
      encoder_input: integer tensor [batch, source_len] of source token ids;
        only used to locate source padding.
      encoder_output: [batch, source_len, d_model] from the encoder.

    Returns:
      Tensor [batch, target_len, d_model].
    """
    embedded = self.target_embedding(decoder_input) # [batch, target_len, d_model]

    pad_mask = get_attn_pad_mask(decoder_input, decoder_input) # [batch, target_len, target_len]
    # Bring the look-ahead mask onto the pad mask's device and dtype before
    # combining, otherwise the two cannot be merged when running on GPU.
    subsequent_mask = get_attn_subsequent_mask(decoder_input).to(
        device=pad_mask.device, dtype=torch.bool) # [batch, target_len, target_len]
    # A position is hidden if it is padding OR lies in the future.
    decoder_self_mask = pad_mask | subsequent_mask

    # Self-attention runs over the embedded tokens, not the raw integer ids.
    decoder_output = self.attention(embedded, embedded, embedded, decoder_self_mask)

    decoder_encoder_attn_mask = get_attn_pad_mask(decoder_input, encoder_input) # [batch, target_len, source_len]
    decoder_output = self.attention(decoder_output, encoder_output, encoder_output, decoder_encoder_attn_mask)

    decoder_output = self.ffn(decoder_output) # [batch, target_len, d_model]

    return decoder_output

In [141]:
class Transformer(nn.Module):
  """Encoder-decoder model with a linear projection onto the target vocabulary."""

  def __init__(self):
    super(Transformer, self).__init__()
    self.encoder = Encoder()
    self.decoder = Decoder()
    self.fc = nn.Linear(d_model, len(target_vocab), bias=False)

  def forward(self, encoder_input, decoder_input):
    """Compute per-position logits over the target vocabulary.

    Args:
      encoder_input: integer tensor [batch, source_len] of source token ids.
      decoder_input: integer tensor [batch, target_len] of target token ids.

    Returns:
      Logits flattened to [batch * target_len, len(target_vocab)], the layout
      nn.CrossEntropyLoss expects against flattened targets.
    """
    encoder_output = self.encoder(encoder_input)
    # Decoder.forward returns a single tensor, so bind it directly instead of
    # unpacking three values.
    decoder_output = self.decoder(decoder_input, encoder_input, encoder_output)
    decoder_logits = self.fc(decoder_output) # [batch, target_len, len(target_vocab)]

    return decoder_logits.view(-1, decoder_logits.size(-1))

In [142]:
# Seed before building the model so the weight initialisation, and therefore
# the printed losses, are identical on every Restart & Run All.
torch.manual_seed(0)

model = Transformer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-1)

for epoch in range(3):
  # output: [batch * target_len, vocab]; the targets are flattened to match.
  output = model(encoder_input, decoder_input)
  loss = criterion(output, decoder_output.view(-1))

  # .item() extracts the Python float from the 0-dim loss tensor.
  print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss.item()))

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

> [0;32m<ipython-input-135-0b36bf5b0c61>[0m(3)[0;36mget_attn_pad_mask[0;34m()[0m
[0;32m      2 [0;31m  [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m  [0mbatch[0m[0;34m,[0m [0mlen_q[0m [0;34m=[0m [0mseq_q[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m  [0mbatch[0m[0;34m,[0m [0mlen_k[0m [0;34m=[0m [0mseq_k[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> n
> [0;32m<ipython-input-135-0b36bf5b0c61>[0m(4)[0;36mget_attn_pad_mask[0;34m()[0m
[0;32m      3 [0;31m  [0mbatch[0m[0;34m,[0m [0mlen_q[0m [0;34m=[0m [0mseq_q[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 4 [0;31m  [0mbatch[0m[0;34m,[0m [0mlen_k[0m [0;34m=[0m [0mseq_k[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m 