# Reading note for "Attention Is All You Need"

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
from Model import Transformer, get_pad_mask, get_subsequent_mask

In [2]:
# Make some fake data
torch.manual_seed(0)

BATCH_SIZE = 16
MAX_LEN_SEQ = 100
LEN_SRC = 100
LEN_TGT = 100
D_WORD_VEC = 512

src_word = torch.rand(BATCH_SIZE, LEN_SRC).long()
tgt_word = torch.rand(BATCH_SIZE, LEN_TGT).long()
# pc = torch.rand(100, 3)

In [3]:
# Hyperparameters

# number of encoder/decoder layers
NUM_LAYER = 6

# The dimensionality of input and output for EncoderDecoder model
D_MODEL = 512

# number of heads/parallel attention layers
NUM_HEAD = 8

# The dimensionality of qurey and key in each head
D_K = D_MODEL // NUM_HEAD
# print(d_k)

# The dimensionality of value in each head (could be different from d_k)
D_V = D_K

# The dimensionality of inner-layer for Position-wise Feed-Forward Network(FFN)
D_FF = 2048

In [4]:
transformer = Transformer(len_src_vocab=LEN_SRC,
                 len_tgt_vocab=LEN_TGT,
                 d_word_vec=D_WORD_VEC,
                 d_model=D_MODEL,
                 num_head=NUM_HEAD,
                 num_layer=NUM_LAYER,
                 d_k=D_K,
                 d_v=D_V,
                 d_ff=D_FF,
)
pred = transformer(src_word, tgt_word)
print(f"pred: {pred.shape}")

enc_output: torch.Size([16, 100, 512])
enc_src_mask: torch.Size([16, 100, 100])
trans_src_word: torch.Size([16, 100])
trans_src_mask: torch.Size([16, 100, 100])
transf_enc_output: torch.Size([16, 100, 512])
trans_dec_output: torch.Size([16, 100, 512])
seq_pred: torch.Size([16, 100, 100])
pred: torch.Size([16, 100])


In [4]:
emb = nn.Embedding(LEN_SRC, D_WORD_VEC)
src_emb = emb(src_word)
print(f"src_emb: {src_emb.shape}")

# # linear project input x into corresponding dimentionalities
# Wq = nn.Linear(d_x, d_model)
# Wk = nn.Linear(d_x, d_model)
# Wv = nn.Linear(d_x, d_model)

# query = Wq(x)
# key = Wk(x)
# value = Wv(x)

# print(f"query: {query.shape}, \nkey: {key.shape}, \nvalue: {value.shape}")

src_emb: torch.Size([16, 100, 512])


In [5]:
class ScaledDotProductAttention(nn.Module):
    
    """
    Compute Scaled Dot Product Attention
    """

    def __init__(self, dropout):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
    
        d_k = query.size(-1) # [batch, len, d_k]

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        scores = scores.softmax(-1)

        if self.dropout is not None:
            scores = self.dropout(scores)

        att = torch.matmul(scores, value)

        return att, scores

In [6]:
# attention = ScaledDotProductAttention(dropout=0.1)
# att, scores = attention(src_emb, src_emb, src_emb, pad_mask)
# print(f"scores: {scores.shape} \nattention: {att.shape}, \n")

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        assert d_model % h == 0

        self.h = h
        self.d_k = d_k
        self.d_v = d_v

        self.wq = nn.Linear(d_model, h * d_k)
        self.wk = nn.Linear(d_model, h * d_k)
        self.wv = nn.Linear(d_model, h * d_v)
        self.fc = nn.Linear(h * d_v, d_model)

        self.attention = ScaledDotProductAttention(dropout=dropout)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, query, key, value, mask=None):

        batch_size = query.size(0)
        
        residual = query # [batch, len_seq, d_model]

        # print(f"query: {query.shape}, \nkey: {key.shape}, \nvalue: {value.shape}")

        # linear projection and split d_model by heads
        query = self.wq(query).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) # [batch, h, len_query, d_k]
        key = self.wk(key).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) # [batch, h, len_key, d_k]
        value = self.wv(value).view(batch_size, -1, self.h, self.d_v).transpose(1, 2) # [batch, h, len_value, d_v]

        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.h, 1, 1)

        att, scores = self.attention(query, key, value, mask) # att: [batch, h, len_seq, d_v], scores: [batch, h, len_seq, len_seq]
        # print(f"scores: {scores.shape}, \nattention: {att.shape}")

        # concat heads
        att_cat = att.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_v) # [batch, len_seq, d_model]

        # final linear projection
        output = self.fc(att_cat) # [batch, len_seq, d_model]

        # dropout
        output = self.dropout(output)

        # add residual and norm layer
        output = self.layer_norm(output + residual)

        return output, scores

In [8]:
# multi_head = MultiHeadAttention(NUM_HEAD, D_MODEL, D_K, D_V)
# output, scores = multi_head(src_emb, src_emb, src_emb, pad_mask)
# print(f"output of multi_head_attention: {output.shape}\nscores: {scores.shape}")

In [9]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionWiseFeedForward, self).__init__()

        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)

        # self.w1 = nn.Conv1d(d_model, d_ff, 1)
        # self.w2 = nn.Conv1d(d_model, d_ff, 1)

        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        residual = x

        x = self.w2(self.relu(self.w1(x)))
        x = self.dropout(x)
        x = self.layer_norm(x + residual)

        return x

In [10]:
# ffn = PositionWiseFeedForward(D_MODEL, D_FF)
# out_ffn = ffn(output)
# print(f"output of ffn: {out_ffn.shape}")

In [11]:
class EncoderLayer(nn.Module):

    def __init__(self, d_model, d_ff, h, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_att = MultiHeadAttention(h, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionWiseFeedForward(d_model, d_ff, dropout=dropout)

    def forward(self, enc_input, slf_att_mask=None):
        enc_output, enc_slf_att = self.slf_att(enc_input, enc_input, enc_input, mask=slf_att_mask)
        enc_output = self.pos_ffn(enc_output)
        # print(f"enc_layer_slf_mask: {slf_att_mask.shape}\nenc_layer_output: {enc_output.shape}")
        return enc_output, enc_slf_att

In [12]:
# enc_layer = EncoderLayer(D_MODEL, D_FF, NUM_HEAD, D_K, D_V)
# enc_output, enc_slf_att = enc_layer(src_emb)
# print(f"enc_output: {enc_output.shape}\nenc_slf_att: {enc_slf_att.shape}")

In [13]:
class DecoderLayer(nn.Module):

    def __init__(self, d_model, d_ff, h, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_att = MultiHeadAttention(h, d_model, d_k, d_v, dropout=dropout)
        self.enc_att = MultiHeadAttention(h, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionWiseFeedForward(d_model, d_ff, dropout=dropout)

    def forward(self, dec_input, enc_output, slf_att_mask=None, dec_enc_att_mask=None):
        dec_output, dec_slf_att = self.slf_att(dec_input, dec_input, dec_input, mask=slf_att_mask)
        dec_output, dec_enc_att = self.enc_att(dec_output,enc_output, enc_output, mask=dec_enc_att_mask)
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_att, dec_enc_att

In [14]:
# dec_layer = DecoderLayer(D_MODEL, D_FF, NUM_HEAD, D_K, D_V)
# dec_output, dec_slf_att, dec_enc_att = dec_layer(src_emb, enc_output)
# print(f"dec_output: {dec_output.shape}\ndec_slf_att: {dec_slf_att.shape}\ndec_enc_att: {dec_enc_att.shape}")

In [15]:
def get_pad_mask(seq_q, seq_k):
	# seq_k和seq_q的形状都是[batch, len_seq]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    # print(pad_mask.shape)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1)  # shape [batch, len_seq, len_seq]
    return pad_mask

In [16]:
pad_mask = get_pad_mask(src_word, src_word)
# print(f"pad_mask: {pad_mask.shape}") # [batch, len, len]

In [17]:
def get_subsequent_mask(len_q, len_k):
    "Mask out subsequent positions."
    return torch.triu(torch.ones(1, len_q, len_k), diagonal=1) == 0

In [18]:
subsequent_mask = get_subsequent_mask(LEN_SRC, LEN_SRC)
print(f"subsequent_mask: {subsequent_mask.shape}")

subsequent_mask: torch.Size([1, 100, 100])


In [19]:
# # plt.figure(figsize=(5,5))
# plt.imshow(subsequent_mask[0,:20, :20])

In [20]:
class Encoder(nn.Module):

    def __init__(self, len_src_vocab, d_word_vec, num_layer, d_model, h, d_k, d_v, d_ff, dropout=0.1):

        super(Encoder, self).__init__()

        self.src_word_emb = nn.Embedding(len_src_vocab, d_word_vec)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_ff, h, d_k, d_v, dropout) for _ in range(num_layer)
        ])
        self.layer_norm = nn.LayerNorm(d_model)
        self.d_model = d_model

    def forward(self, src_word, src_mask=None):

        enc_output = self.layer_norm(self.dropout(self.src_word_emb(src_word)))
        # print(f"enc_output: {enc_output.shape}")

        for enc_layer in self.layer_stack:
            enc_output, *_ = enc_layer(enc_output, src_mask)
        print(f"enc_output: {enc_output.shape}\nenc_src_mask: {src_mask.shape}")

        return enc_output

In [21]:
# encoder = Encoder(LEN_SRC, D_WORD_VEC, NUM_LAYER, D_MODEL, NUM_HEAD, D_K, D_V, D_FF)
# enc_output = encoder(src_word, pad_mask)
# print(f"encoder output: {enc_output.shape}")

In [22]:
class Decoder(nn.Module):

    def __init__(self, len_tgt_vocab, d_word_vec, num_layer, d_model, h, d_k, d_v, d_ff, dropout=0.1):

        super(Decoder, self).__init__()
        self.tgt_word_emb = nn.Embedding(len_tgt_vocab, d_word_vec)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_ff, h, d_k, d_v, dropout) for _ in range(num_layer)
        ])
        self.layer_norm = nn.LayerNorm(d_model)
        self.d_model = d_model

    def forward(self, tgt_word, enc_output, tgt_mask=None, src_mask=None):
        dec_output = self.layer_norm(self.dropout(self.tgt_word_emb(tgt_word)))
        # print(f"tgt_word: {tgt_word.shape}\nenc_output: {enc_output.shape}\ntgt_mask: {tgt_mask.shape}\nsrc_mask: {src_mask.shape}")

        for dec_layer in self.layer_stack:
            dec_output, *_ = dec_layer(dec_output, enc_output, slf_att_mask=tgt_mask, dec_enc_att_mask=src_mask)
        
        return dec_output

In [23]:
# tgt_mask = get_pad_mask(src_word, tgt_word) & get_subsequent_mask(LEN_SRC, LEN_TGT)
# print(f"tgt_mask: {tgt_mask.shape}")
# src_mask = get_pad_mask(src_word, tgt_word)
# print(f"src_mask: {src_mask.shape}")

In [24]:
# decoder = Decoder(LEN_TGT, D_WORD_VEC, NUM_LAYER, D_MODEL, NUM_HEAD, D_K, D_V, D_FF)
# dec_output = decoder(tgt_word, enc_output, tgt_mask, src_mask)
# print(f"dec_output: {dec_output.shape}")

In [35]:
class Transformer(nn.Module):

    def __init__(self, len_src_vocab, len_tgt_vocab, d_word_vec, d_model, d_ff, num_layer, h, d_k, d_v, dropout=0.1):

        super(Transformer, self).__init__()
        
        self.len_src_vocab = len_src_vocab
        self.len_tgt_vocab = len_tgt_vocab
        self.d_model = d_model

        self.encoder = Encoder(
            len_src_vocab=len_src_vocab, 
            d_word_vec=d_word_vec, 
            num_layer=num_layer, 
            d_model=d_model, 
            h=h, d_k=d_k, d_v=d_v, d_ff=d_ff, 
            dropout=dropout
        )

        self.decoder = Decoder(
            len_tgt_vocab=len_tgt_vocab,
            d_word_vec=d_word_vec,
            num_layer=num_layer,
            d_model=d_model,
            h=h, d_k=d_k, d_v=d_v, d_ff=d_ff,
            dropout=dropout
        )

        self.tgt_word_prj = nn.Linear(d_model, len_tgt_vocab)

        assert d_word_vec == d_model

    def forward(self, src_word, tgt_word):

        src_mask = get_pad_mask(src_word, src_word)
        tgt_mask = get_pad_mask(tgt_word, tgt_word) & get_subsequent_mask(self.len_tgt_vocab, self.len_tgt_vocab)
        # print(f"trans_src_mask: {src_mask.shape}\ntrans_tgt_mask: {tgt_mask.shape}")

        enc_output = self.encoder(src_word=src_word, src_mask=src_mask)
        # print(f"trans_src_word: {src_word.shape}\ntrans_src_mask: {src_mask.shape}\ntransf_enc_output: {enc_output.shape}\n")
        dec_output = self.decoder(tgt_word=tgt_word, enc_output=enc_output, tgt_mask=tgt_mask, src_mask=src_mask)
        print(f"trans_src_word: {src_word.shape}\ntrans_src_mask: {src_mask.shape}\ntransf_enc_output: {enc_output.shape}\ntrans_dec_output: {dec_output.shape}")
        seq_pred = self.tgt_word_prj(dec_output)
        print(f"seq_pred: {seq_pred.shape}")

        return F.log_softmax(seq_pred, dim=-1).max(-1)[0]

In [36]:
transformer = Transformer(
    len_src_vocab=LEN_SRC, len_tgt_vocab=LEN_TGT, d_word_vec=D_WORD_VEC, d_model=D_MODEL, d_ff=D_FF, num_layer=NUM_LAYER, h=NUM_HEAD, d_k=D_K, d_v=D_V
)
pred = transformer(src_word, tgt_word)
print(f"pred: {pred.shape}")

enc_output: torch.Size([16, 100, 512])
enc_src_mask: torch.Size([16, 100, 100])
trans_src_word: torch.Size([16, 100])
trans_src_mask: torch.Size([16, 100, 100])
transf_enc_output: torch.Size([16, 100, 512])
trans_dec_output: torch.Size([16, 100, 512])
seq_pred: torch.Size([16, 100, 100])
pred: torch.Size([16, 100])
