In [2]:
import torch
import torch.nn as nn
import torch.functional as F
import math

In [3]:
# Random standard-normal dummy input of shape (128, 64, 512); a later cell feeds it
# to the attention module, which unpacks the shape as (batch, time, dimension).
X = torch.randn(128, 64, 512)
print(X.shape) 

torch.Size([128, 64, 512])


In [4]:
# Hyper-parameters used by the attention smoke test in a later cell.
d_model = 512  # model width; equals the last dimension of X
n_head = 8     # number of attention heads; d_model // n_head is the per-head width

In [None]:
class multi_head_attention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects queries, keys and values with separate linear layers, splits the
    feature dimension into ``n_head`` heads, applies scaled dot-product
    attention per head, then concatenates the heads and applies a final
    linear projection.

    Args:
        d_model: Feature width of the inputs and of the output.
        n_head: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(multi_head_attention, self).__init__()
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        """Compute attention of ``q`` over ``k``/``v``.

        Args:
            q: Queries, shape ``(batch, q_len, d_model)``.
            k: Keys, shape ``(batch, k_len, d_model)``.
            v: Values, shape ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_head, q_len, k_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch, q_len, _ = q.shape
        # Keys/values may have a different sequence length than the queries
        # (cross-attention), so each tensor is reshaped with its own length.
        k_len, v_len = k.shape[1], v.shape[1]
        n_d = self.d_model // self.n_head

        q, k, v = self.W_q(q), self.W_k(k), self.W_v(v)
        # (batch, len, d_model) -> (batch, n_head, len, n_d)
        q = q.view(batch, q_len, self.n_head, n_d).permute(0, 2, 1, 3)
        k = k.view(batch, k_len, self.n_head, n_d).permute(0, 2, 1, 3)
        v = v.view(batch, v_len, self.n_head, n_d).permute(0, 2, 1, 3)

        score = q @ k.transpose(2, 3) / math.sqrt(n_d)
        if mask is not None:
            # Use the most negative finite value rather than -inf: a row whose
            # keys are all masked would otherwise softmax to NaN, and that NaN
            # would spread to every position in subsequent layers.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        context = self.softmax(score) @ v
        # (batch, n_head, q_len, n_d) -> (batch, q_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(batch, q_len, self.d_model)

        output = self.W_combine(context)
        return output

# Smoke test: self-attention over the random batch X (query, key and value are all X),
# using the d_model / n_head values set in an earlier cell.
attention = multi_head_attention(d_model, n_head)
output = attention(X, X, X)
print(output.shape)
# NOTE(review): this prints the full tensor repr; the shape above is usually enough.
print(output)

torch.Size([128, 64, 512])
tensor([[[-3.5616e-01,  2.2980e-01, -1.0412e-01,  ...,  4.2845e-01,
           4.5223e-01, -2.4949e-01],
         [-1.6509e-02, -5.9100e-02,  1.3037e-01,  ...,  3.1069e-02,
           1.7961e-01, -1.2885e-01],
         [ 1.4183e-01, -1.8829e-01,  3.1000e-01,  ...,  1.3032e-01,
          -1.2106e-01, -2.0772e-01],
         ...,
         [ 4.5330e-02,  3.2771e-03,  1.5398e-01,  ...,  8.2911e-02,
          -6.8392e-02, -2.4243e-02],
         [ 4.0261e-02,  2.4556e-04,  1.2942e-01,  ...,  9.0920e-02,
          -5.6196e-02,  1.1676e-02],
         [ 1.6857e-02, -5.3559e-03,  1.5256e-01,  ...,  1.0340e-01,
          -6.6616e-02, -1.7736e-03]],

        [[ 8.9919e-02, -3.7148e-02, -6.1425e-01,  ..., -3.9366e-02,
          -6.7160e-02,  3.2894e-01],
         [ 4.1554e-02, -5.7711e-02, -5.6047e-01,  ..., -1.7539e-01,
          -2.8162e-01, -2.3067e-01],
         [-3.8199e-02, -1.8102e-01, -5.2488e-01,  ..., -1.5993e-01,
          -2.3844e-01,  7.6380e-02],
         ...

In [6]:
class TokenEmbedding(nn.Embedding):
    """Token-id to vector lookup table built on ``nn.Embedding``.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Width of each embedding vector.
        padding_idx: Index of the padding token. Defaults to 1, the value
            that was previously hard-coded.
    """

    def __init__(self, vocab_size, embed_size, padding_idx=1):
        super(TokenEmbedding, self).__init__(vocab_size, embed_size, padding_idx=padding_idx)

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding table.

    Even feature columns hold ``sin(pos / 10000 ** (2i / d_model))`` and odd
    columns hold the matching ``cos`` term.

    Args:
        d_model: Width of each positional vector.
        max_len: Number of positions to precompute.
        device: Device on which the table is created.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEmbedding, self).__init__()

        # Create every intermediate tensor on the target device; mixing a CPU
        # position index with device-resident frequencies raises on non-CPU devices.
        pos = torch.arange(0, max_len, device=device).unsqueeze(1).float()
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        angle = pos / 10000 ** (_2i / d_model)

        encoding = torch.zeros(max_len, d_model, device=device)
        encoding[:, 0::2] = torch.sin(angle)
        # For odd d_model there is one fewer odd column than even columns.
        encoding[:, 1::2] = torch.cos(angle[:, : d_model // 2])

        # Register as a non-persistent buffer: it follows .to()/.cuda() with the
        # module, is never trained, and leaves state_dict keys unchanged.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the first ``x.shape[1]`` rows of the table.

        Args:
            x: Tensor whose second dimension is the sequence length.

        Returns:
            Tensor of shape ``(seq_len, d_model)``.
        """
        seq_len = x.shape[1]
        return self.encoding[:seq_len, :]

In [None]:
class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Embedding width.
        max_len: Number of positions in the positional table.
        drop_prob: Dropout probability applied to the summed embeddings.
        device: Device passed to the positional embedding.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionalEmbedding(d_model, max_len, device)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed token ids ``x`` and add positional information.

        Args:
            x: Integer token ids; the second dimension is the sequence length.

        Returns:
            Dropout applied to ``tok_emb(x) + pos_emb(x)``.
        """
        tok_emb = self.tok_emb(x)
        # The positional table is (seq_len, d_model) and broadcasts over the batch.
        pos_emb = self.pos_emb(x)
        return self.drop_out(tok_emb + pos_emb)

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        d_model: Size of the last dimension being normalised.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-10):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` to zero mean and unit variance along the last axis.

        Returns:
            ``gamma * (x - mean) / sqrt(var + eps) + beta``, same shape as ``x``.
        """
        mean = x.mean(-1, keepdim=True)
        # Must be the (biased) variance: taking sqrt of the standard deviation
        # would divide by std ** 0.5 and leave the output incorrectly scaled.
        var = x.var(-1, unbiased=False, keepdim=True)
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
&&&&&&&&&&
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])


In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``fc2(dropout(relu(fc1(x))))``.

    Args:
        d_model: Input and output width.
        hidden: Width of the intermediate layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden, dropout=0.1):
        # nn.Module must be initialised before submodules can be assigned.
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` (last dimension ``d_model``)."""
        x = self.fc1(x)
        # The notebook binds F to torch.functional, which has no relu;
        # torch.relu is available regardless of that alias.
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is followed by dropout, a residual connection and
    layer normalisation (post-norm arrangement).
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        # Self-attention sub-layer.
        self.attention = multi_head_attention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.drop1 = nn.Dropout(drop_prob)

        # Position-wise feed-forward sub-layer.
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm2 = LayerNorm(d_model)
        self.drop2 = nn.Dropout(drop_prob)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is forwarded to the attention."""
        attended = self.drop1(self.attention(x, x, x, mask))
        hidden = self.norm1(attended + x)

        transformed = self.drop2(self.ffn(hidden))
        return self.norm2(transformed + hidden)
        

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, optional cross-attention, feed-forward.

    Each sub-layer is followed by dropout, a residual connection and layer
    normalisation.

    Args:
        d_model: Feature width.
        ffn_hidden: Hidden width of the feed-forward sub-layer.
        n_head: Number of attention heads.
        drop_prob: Dropout probability used after every sub-layer.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(DecoderLayer, self).__init__()
        self.attention1 = multi_head_attention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(drop_prob)

        self.cross_attention = multi_head_attention(d_model, n_head)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(drop_prob)

        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = LayerNorm(d_model)
        self.dropout3 = nn.Dropout(drop_prob)

    def forward(self, dec, enc, t_mask, s_mask):
        """Run the decoder block.

        Args:
            dec: Decoder-side input; used as query, key and value of the
                self-attention.
            enc: Encoder output used as key/value of the cross-attention, or
                ``None`` to skip the cross-attention sub-layer.
            t_mask: Mask passed to the self-attention.
            s_mask: Mask passed to the cross-attention.

        Returns:
            Output of the final layer normalisation.
        """
        _x = dec
        x = self.attention1(dec, dec, dec, t_mask)

        x = self.dropout1(x)
        x = self.norm1(x + _x)

        if enc is not None:
            _x = x
            x = self.cross_attention(x, enc, enc, s_mask)

            x = self.dropout2(x)
            x = self.norm2(x + _x)

        _x = x
        x = self.ffn(x)

        x = self.dropout3(x)
        x = self.norm3(x + _x)
        return x

    # The method was originally misspelled, which left nn.Module.forward
    # unimplemented; the old name is kept as an alias for existing callers.
    fromward = forward

In [None]:
class Encoder(nn.Module):
    """Stack of encoder layers on top of a token + positional embedding.

    Args:
        env_voc_size: Size of the encoder-side vocabulary.
        max_len: Number of positions in the positional table.
        d_model: Feature width.
        ffn_hidden: Hidden width of each feed-forward sub-layer.
        n_head: Number of attention heads.
        n_layers: Number of stacked ``EncoderLayer`` blocks.
        drop_prob: Dropout probability.
        device: Device passed to the embedding.
    """

    def __init__(self, env_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
        super(Encoder, self).__init__()
        # Keyword arguments guard against the positional mix-up where d_model
        # was passed as the vocabulary size and the vocabulary size as max_len.
        self.embedding = TransformerEmbedding(
            vocab_size=env_voc_size,
            d_model=d_model,
            max_len=max_len,
            drop_prob=drop_prob,
            device=device,
        )
        self.layers = nn.ModuleList([EncoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layers)])

    def forward(self, x, s_mask):
        """Embed token ids ``x`` and pass them through every layer with ``s_mask``."""
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, s_mask)
        return x

In [None]:
class Decoder(nn.Module):
    """Stack of decoder layers with an embedding and a vocabulary projection.

    Args:
        dec_voc_size: Size of the decoder-side vocabulary; also the width of
            the final linear projection.
        max_len: Number of positions in the positional table.
        d_model: Feature width.
        ffn_hidden: Hidden width of each feed-forward sub-layer.
        n_head: Number of attention heads.
        n_layers: Number of stacked ``DecoderLayer`` blocks.
        drop_prob: Dropout probability.
        device: Device passed to the embedding.
    """

    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
        super(Decoder, self).__init__()
        # Keyword arguments guard against the positional mix-up where d_model
        # was passed as the vocabulary size and the vocabulary size as max_len.
        self.embedding = TransformerEmbedding(
            vocab_size=dec_voc_size,
            d_model=d_model,
            max_len=max_len,
            drop_prob=drop_prob,
            device=device,
        )
        self.layers = nn.ModuleList([DecoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layers)])

        self.fc = nn.Linear(d_model, dec_voc_size)

    def forward(self, dec, enc, t_mask, s_mask):
        """Decode token ids ``dec`` against encoder output ``enc``.

        Args:
            dec: Decoder-side token ids.
            enc: Encoder output handed to every layer.
            t_mask: Mask for the decoder self-attention.
            s_mask: Mask for the cross-attention.

        Returns:
            Output of the final linear layer, with last dimension ``dec_voc_size``.
        """
        dec = self.embedding(dec)
        for layer in self.layers:
            dec = layer(dec, enc, t_mask, s_mask)

        dec = self.fc(dec)
        return dec

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that builds its own padding and causal masks.

    Args:
        src_pad_idx: Padding token id on the source side.
        trg_pad_idx: Padding token id on the target side.
        enc_voc_size: Source vocabulary size.
        dec_voc_size: Target vocabulary size.
        max_len: Number of positions in the positional tables.
        d_model: Feature width.
        n_heads: Number of attention heads.
        ffn_hidden: Hidden width of the feed-forward sub-layers.
        n_layers: Number of encoder layers and of decoder layers.
        drop_prob: Dropout probability.
        device: Device passed to the encoder and decoder.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, max_len, d_model, n_heads, ffn_hidden, n_layers, drop_prob, device):
        super(Transformer, self).__init__()

        self.encoder = Encoder(enc_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, drop_prob, device)
        self.decoder = Decoder(dec_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, drop_prob, device)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        # Stored so the attribute exists for any code that reads model.device.
        self.device = device

    def make_pad_mask(self, q, k, pad_idx_q, pad_idx_k):
        """Build a boolean mask that is True where neither token is padding.

        Args:
            q: Query-side token ids, shape ``(batch, len_q)``.
            k: Key-side token ids, shape ``(batch, len_k)``.
            pad_idx_q: Padding id for ``q``.
            pad_idx_k: Padding id for ``k``.

        Returns:
            Bool tensor of shape ``(batch, 1, len_q, len_k)``. Rows belonging to
            padded query positions are entirely False.
        """
        len_q, len_k = q.shape[1], k.shape[1]

        # (batch, 1, len_q, len_k)
        q = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)
        q = q.repeat(1, 1, 1, len_k)
        k = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)
        k = k.repeat(1, 1, len_q, 1)

        mask = q & k

        return mask

    def make_casual_mask(self, q, k):
        """Build a lower-triangular (causal) bool mask of shape ``(len_q, len_k)``.

        The mask is created on the device of ``q`` so it can be combined with
        masks derived from the same batch.
        """
        len_q, len_k = q.shape[1], k.shape[1]
        mask = torch.ones(len_q, len_k, device=q.device).tril().bool()
        return mask

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` against it.

        Args:
            src: Source token ids, shape ``(batch, src_len)``.
            trg: Target token ids, shape ``(batch, trg_len)``.

        Returns:
            Whatever the decoder returns for the target sequence.
        """
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # Target self-attention: padding mask combined with the causal mask.
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) & self.make_casual_mask(trg, trg)
        # Cross-attention scores are (trg_len, src_len), so the mask is too.
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

        enc = self.encoder(src, src_mask)
        # The decoder takes (dec, enc, t_mask, s_mask); its s_mask feeds the
        # cross-attention and therefore must be the target-by-source mask.
        dec = self.decoder(trg, enc, trg_mask, src_trg_mask)
        return dec

