In [76]:
# refer: github.com/pbcquoc
import torch, os, math, copy
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Positional Encoder
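(This trick is a big part of what makes the Transformer and its variants work so well — but how? Self-attention has no built-in notion of token order, so a position signal is added to the embeddings.)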

In [79]:
class Embedder(nn.Module):
    """Learned lookup table that maps token ids to `d_model`-wide vectors.

    Parameters:
    -----------
    vocab_size: int, number of distinct token ids
    d_model: int, width of each embedding vector
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size, self.d_model = vocab_size, d_model
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Look up the embedding of every token id in `x`.

        `x` holds integer token ids; the result has one extra trailing
        dimension of size `d_model`.
        """
        return self.embedding(x)

class PosisionalEncoder(nn.Module):
    """Scale token embeddings and add fixed sinusoidal position signals.

    The table follows "Attention Is All You Need":
        PE[pos, j]     = sin(pos / 10000 ** (j / d_model))   for even j
        PE[pos, j + 1] = cos(pos / 10000 ** (j / d_model))
    The class name keeps its original spelling so existing callers still work.

    Parameters:
    -----------
    d_model: int, embedding width
    max_seq_len: int, number of positions pre-computed in the table
    dropout: float, dropout probability applied after the table is added
    """

    def __init__(self, d_model=768, max_seq_len=256, dropout=0.1):
        super(PosisionalEncoder, self).__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.dropout = nn.Dropout(dropout)

        # Build the whole table at once instead of a Python double loop.
        # float64 keeps the same precision as the math-module computation
        # before the values are stored in the float32 table.
        position = torch.arange(max_seq_len, dtype=torch.float64).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float64)
        # `even_dims` already holds the even dimension index (the paper's 2i),
        # so the exponent is even_dims / d_model -- multiplying by 2 again
        # would double the exponent and distort the wavelengths.
        angle = position / torch.pow(10000.0, even_dims / d_model)

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angle)
        # For an odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angle[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        # A buffer is saved with the module and moved by .to()/.cuda(),
        # but is never updated by the optimizer.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return `dropout(x * sqrt(d_model) + pe[:, :seq_len])`.

        Parameters:
        -----------
        x: tensor shape `(batch_size, seq_len, d_model)`, `seq_len <= max_seq_len`
        Return:
        -------
        tensor shape `(batch_size, seq_len, d_model)`
        """
        # Scale the embeddings so the position signal does not dominate them.
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        # Buffers never require grad, so no wrapper is needed; follow the
        # input's device in case the module itself was not moved.
        pe = self.pe[:, :seq_len].to(x.device)
        x = self.dropout(x + pe)
        return x
# Smoke test: run a random (5, 30, 512) batch through the encoder and display the output shape.
PosisionalEncoder(512)(torch.rand(5, 30, 512)).shape    

torch.Size([5, 30, 512])

# Multihead Attention operator
(awesome feature extractor)

In [46]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Parameters:
    -----------
    n_heads: int, number of attention heads; must divide `d_model`
    d_model: int, model width, split evenly across the heads
    dropout: float or None, dropout probability applied to the attention
        weights; `None` (or 0) disables it
    """

    def __init__(self, n_heads, d_model, dropout=None):
        super(MultiheadAttention, self).__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None

        # weight matrices for the query, key and value projections
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """
        Parameters:
        -----------
        q: tensor shape `(batch_size, q_len, d_model)`
        k: tensor shape `(batch_size, kv_len, d_model)`
        v: tensor shape `(batch_size, kv_len, d_model)`
        mask: optional tensor, shape `(batch_size, 1, kv_len)` or
            `(batch_size, q_len, kv_len)`; it is broadcast over the heads and
            key positions where `mask == 0` receive no attention
        Return:
        -------
        output: tensor shape `(batch_size, q_len, d_model)`
        """
        # project, then split the model width into (n_heads, d_k)
        batch_size = q.size(0)
        q = self.q_linear(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # scaled dot-product attention scores: (batch, heads, q_len, kv_len)
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # add a head axis so the same mask is shared by every head
            mask = mask.unsqueeze(1)
            score = score.masked_fill(mask == 0, -1e9)
        score = F.softmax(score, -1)
        if self.dropout is not None:
            # Dropout must act on the attention weights that are actually
            # multiplied with `v`; assigning it to another name discards it.
            score = self.dropout(score)
        output = torch.matmul(score, v)
        # merge the heads back into a single d_model-wide vector
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.out(output)
        return output

# Smoke test: 8-head attention (dropout 0.1) over random (8, 30, 512) q/k/v, no mask; displays the output shape.
MultiheadAttention(8, 512, 0.1)(torch.rand(8, 30, 512), torch.rand(8, 30, 512), torch.rand(8, 30, 512)).shape

torch.Size([8, 30, 512])

# Residual connection and Layer normalization
(faster convergence and avoids losing information)

In [20]:
class Norm(nn.Module):
    """Normalise the last dimension, then apply a learnable scale and shift.

    Parameters:
    -----------
    d_model: int, size of the last dimension
    eps: float, added to the standard deviation to avoid division by zero
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()

        # learnable gain (starts at 1) and offset (starts at 0)
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Return `alpha * (x - mean) / (std + eps) + bias` over the last dim."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias

# Smoke test: normalise a random (8, 128, 512) batch and display the output shape.
Norm(512)(torch.randn(8, 128, 512)).shape

torch.Size([8, 128, 512])

In [34]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Parameters:
    -----------
    d_model: int, input and output width
    d_ff: int, width of the hidden layer
    dropout: float, dropout probability applied after the ReLU
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.d_model, self.d_ff = d_model, d_ff
        # The two projections are created in this order so that seeded
        # initialisation stays the same.
        expand = nn.Linear(d_model, d_ff)
        shrink = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(expand, nn.ReLU(), nn.Dropout(dropout), shrink)

    def forward(self, x):
        """Apply the network to the last dimension of `x`; shape is preserved."""
        return self.ff(x)
# Smoke test: default feed-forward network on a random (8, 128, 512) batch; displays the output shape.
FeedForward()(torch.randn(8,128,512)).shape

torch.Size([8, 128, 512])

# Encoder, Decoder block

In [98]:
class EncoderBlock(nn.Module):
    """One pre-norm encoder layer: self-attention and feed-forward sub-layers,
    each wrapped in dropout and a residual connection.

    Parameters:
    -----------
    d_model: int, model width
    n_heads: int, number of attention heads
    d_ff: int, hidden width of the feed-forward sub-layer
    dropout: float, dropout probability used in every sub-layer
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(EncoderBlock, self).__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.attn = MultiheadAttention(n_heads, d_model, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        Parameters:
        -----------
        x: tensor shape `(batch_size, seq_len, model_dim)`
        mask: tensor shape `(batch_size, 1, seq_len)` for the self-attention,
            or `None` for no masking
        Return:
        -------
        out: tensor shape `(batch_size, seq_len, model_dim)`
        """
        # sub-layer 1: self-attention on the normalised input, added back to x
        x_norm = self.norm_1(x)
        x = x + self.dropout_1(self.attn(x_norm, x_norm, x_norm, mask))
        # sub-layer 2: feed-forward; the residual `x +` must be kept here too,
        # otherwise the block throws away its input and the attention output
        x_norm = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x_norm))
        return x
# Smoke test: one encoder block on a random (8, 30, 512) batch; displays the output shape.
# The class is spelled `EncoderBlock`; the misspelled name only resolved through stale kernel state.
# An all-ones mask keeps every position visible (positions are hidden only where the mask equals 0).
net = EncoderBlock(512, 8, 2048)
net(torch.randn(8, 30, 512), torch.ones(8, 1, 30)).shape

torch.Size([8, 30, 512])

In [94]:
class DecoderBlock(nn.Module):
    """One pre-norm decoder layer with three residual sub-layers:
    self-attention, attention over the encoder output, and feed-forward.

    Parameters:
    -----------
    d_model: int, model width
    n_heads: int, number of attention heads
    d_ff: int, hidden width of the feed-forward sub-layer
    dropout: float, dropout probability used in every sub-layer
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # one normalisation per sub-layer
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)

        # attn_1: self-attention, attn_2: attention over the encoder output
        self.attn_1 = MultiheadAttention(n_heads, d_model, dropout)
        self.attn_2 = MultiheadAttention(n_heads, d_model, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)

        # one dropout per residual branch
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """
        Parameters:
        -----------
        x: tensor, embedded target sentences
            shape `(batch_size, seq_len, d_model)`
        encoder_output: tensor, output of the encoder
            shape `(batch_size, seq_len, d_model)`
        src_mask: tensor, mask passed to the attention over `encoder_output`
            shape `(batch_size, 1, seq_len)`
        tgt_mask: tensor, mask passed to the self-attention over `x`
            shape `(batch_size, 1, seq_len)`
        Return:
        -------
        out: tensor shape `(batch_size, seq_len, d_model)`
        """
        # sub-layer 1: self-attention over the target sequence
        normed = self.norm_1(x)
        self_attended = self.attn_1(normed, normed, normed, tgt_mask)
        x = x + self.dropout_1(self_attended)

        # sub-layer 2: queries come from the decoder, keys/values from the encoder
        normed = self.norm_2(x)
        cross_attended = self.attn_2(normed, encoder_output, encoder_output, src_mask)
        x = x + self.dropout_2(cross_attended)

        # sub-layer 3: position-wise feed-forward
        normed = self.norm_3(x)
        transformed = self.ff(normed)
        return x + self.dropout_3(transformed)
# Smoke test: one decoder block on random inputs; displays the output shape.
# NOTE(review): both masks are drawn from randn, so `mask == 0` (the masking test in
# MultiheadAttention) will essentially never hold -- nothing is actually masked here.
net = DecoderBlock(512, 8, 2048)
net(torch.randn(8, 30, 512), torch.randn(8, 30, 512), torch.randn(8, 1, 30), torch.randn(8, 1, 30)).shape

torch.Size([8, 30, 512])

# Build Transformers

In [96]:
def get_clones(module, N):
    """Return an `nn.ModuleList` holding `N` independent deep copies of `module`."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

In [166]:
class Encoder(nn.Module):
    """Encoder stack: embedding, positional encoding, `num_layer` encoder
    blocks and a final normalisation.

    Parameters:
    -----------
    vocab_size: int, number of source token ids
    max_seq_len: int, longest sequence the positional table covers
    d_model: int, model width
    n_heads: int, attention heads per block
    d_ff: int, hidden width of each feed-forward sub-layer
    num_layer: int, number of stacked encoder blocks
    dropout: float, dropout probability
    """

    def __init__(self, vocab_size, max_seq_len, d_model, n_heads, d_ff, num_layer, dropout=0.1):
        super().__init__()
        self.N = num_layer
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PosisionalEncoder(d_model, max_seq_len, dropout)
        # every layer starts as a deep copy of the same freshly built block
        prototype = EncoderBlock(d_model, n_heads, d_ff, dropout)
        self.layers = get_clones(prototype, num_layer)
        self.norm = Norm(d_model)

    def forward(self, x, mask):
        """
        Parameters:
        -----------
        x: tensor of token ids for the input sentences
            shape `(batch_size, seq_len)`
        mask: tensor, shape `(batch_size, 1, seq_len)`, passed to every block
        Return:
        -------
        out: tensor, shape `(batch_size, seq_len, d_model)`
        """
        hidden = self.pe(self.embed(x))
        for layer_idx in range(self.N):
            encoder_block = self.layers[layer_idx]
            hidden = encoder_block(hidden, mask)
        return self.norm(hidden)
# Smoke test: a 6-layer encoder on random token ids; displays the output shape.
# The mask is filled with torch.rand values and only exercises the call signature.
en_vocab_size, max_seq_len, d_model, n_heads, d_ff, num_layer = 256, 30, 512, 8, 2048, 6
net = Encoder(en_vocab_size, max_seq_len, d_model, n_heads, d_ff, num_layer)
net(torch.LongTensor(8, max_seq_len).random_(0, en_vocab_size), torch.rand(8, 1, max_seq_len)).shape

torch.Size([8, 30, 512])

In [167]:
class Decoder(nn.Module):
    """Decoder stack: embedding, positional encoding, `num_layer` decoder
    blocks and a final normalisation.

    Parameters:
    -----------
    vocab_size: int, number of target token ids
    max_seq_len: int, longest sequence the positional table covers
    d_model: int, model width
    n_heads: int, attention heads per block
    d_ff: int, hidden width of each feed-forward sub-layer
    num_layer: int, number of stacked decoder blocks
    dropout: float, dropout probability
    """

    def __init__(self, vocab_size, max_seq_len, d_model, n_heads, d_ff, num_layer, dropout=0.1):
        super().__init__()
        self.N = num_layer
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PosisionalEncoder(d_model, max_seq_len, dropout)
        # every layer starts as a deep copy of the same freshly built block
        prototype = DecoderBlock(d_model, n_heads, d_ff, dropout)
        self.layers = get_clones(prototype, num_layer)
        self.norm = Norm(d_model)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """
        Parameters:
        -----------
        x: tensor of token ids for the target sentences
            shape `(batch_size, seq_len)`
        encoder_output: tensor, contextual embedding of the input sentences
            shape `(batch_size, seq_len, d_model)`
        src_mask: tensor, shape `(batch_size, 1, seq_len)`, forwarded to every
            block for the attention over `encoder_output`
        tgt_mask: tensor, shape `(batch_size, 1, seq_len)`, forwarded to every
            block for the self-attention over the target sequence
        Return:
        -------
        out: tensor, contextual embedding of the target sentences
        """
        hidden = self.pe(self.embed(x))
        for layer_idx in range(self.N):
            decoder_block = self.layers[layer_idx]
            hidden = decoder_block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)
# Smoke test: a 6-layer decoder on random token ids and a random stand-in encoder output; displays the output shape.
# NOTE(review): both masks are drawn from randn, so they only exercise the call signature.
de_vocab_size, max_seq_len, d_model, n_heads, d_ff, num_layer = 256, 30, 512, 8, 2048, 6
net = Decoder(de_vocab_size, max_seq_len, d_model, n_heads, d_ff, num_layer)
net(torch.LongTensor(8, max_seq_len).random_(0, de_vocab_size), torch.rand(8, max_seq_len, d_model), torch.randn(8, 1, max_seq_len), torch.randn(8, 1, max_seq_len)).shape    

torch.Size([8, 30, 512])

In [168]:
class Transformer(nn.Module):
    """Encoder-decoder model with a linear projection onto the target vocabulary.

    Parameters:
    -----------
    en_config: dict of keyword arguments for `Encoder`
    de_config: dict of keyword arguments for `Decoder`; its `"d_model"` and
        `"vocab_size"` entries also size the output projection
    """

    def __init__(self, en_config, de_config):
        super().__init__()
        self.encoder = Encoder(**en_config)
        self.decoder = Decoder(**de_config)
        self.fc = nn.Linear(
            in_features=de_config["d_model"],
            out_features=de_config["vocab_size"],
        )

    def forward(self, src_sent, tgt_sent, src_mask, tgt_mask):
        """Encode `src_sent`, decode `tgt_sent` against it, and return the
        output of the final linear layer (no softmax is applied here)."""
        memory = self.encoder(src_sent, src_mask)
        hidden = self.decoder(tgt_sent, memory, src_mask, tgt_mask)
        return self.fc(hidden)

In [169]:
# Hyper-parameters for the two halves of the model. The keys are unpacked as
# keyword arguments, so they must match the Encoder / Decoder constructor names.
en_config = dict(
    vocab_size=256,
    max_seq_len=30,
    d_model=512,
    n_heads=8,
    d_ff=2048,
    num_layer=6,
)
de_config = dict(
    vocab_size=128,
    max_seq_len=18,
    d_model=512,
    n_heads=8,
    d_ff=2048,
    num_layer=6,
)

In [172]:
# End-to-end smoke test: build the full model from the two configs, print the
# number of trainable parameters, then run one random batch through it.
batch_size = 8
en_seq_len = en_config["max_seq_len"]
de_seq_len = de_config["max_seq_len"]
en_vocab_size = en_config["vocab_size"]
de_vocab_size = de_config["vocab_size"]

net = Transformer(en_config, de_config)
# count only the parameters the optimizer would update
print(sum(p.numel() for p in net.parameters() if p.requires_grad))

# Random token ids stand in for real sentences.
# NOTE(review): both masks are drawn from randn, so `mask == 0` will essentially
# never hold and nothing is actually masked in this run.
prob_map = net(src_sent=torch.LongTensor(batch_size, en_seq_len).random_(0, en_vocab_size),\
                tgt_sent=torch.LongTensor(batch_size, de_seq_len).random_(0, de_vocab_size),\
                src_mask=torch.randn(batch_size, 1, en_seq_len),\
                tgt_mask=torch.randn(batch_size, 1, de_seq_len))
# NOTE(review): despite the name, these are raw outputs of the final linear layer, not probabilities.
prob_map.shape

44402816


torch.Size([8, 18, 128])

In [174]:
# Inspect the output vector for the first target position of the first sample.
prob_map[0,0,:]

tensor([-0.3601,  0.4778, -0.8045, -0.0144, -0.0967, -0.4138,  0.8579, -0.4447,
         0.2281, -0.0091,  0.3790,  0.0720,  0.0519,  0.3575,  0.4493, -0.5355,
        -0.6952,  0.1924,  0.0304,  0.0606, -0.2941, -0.3592, -1.0357,  0.4034,
         0.2400,  0.2867, -1.2320,  0.2998,  0.2474,  0.1706, -0.6095,  0.1451,
         0.5043, -0.0268, -0.0201, -0.3259, -0.4269, -0.0205,  0.2256,  0.0126,
         0.3072, -0.7930, -0.7144,  0.4032,  0.1989,  0.8907, -0.8826, -0.3180,
        -0.7040, -0.0209,  0.2044,  0.4262, -1.5077, -0.2173, -0.7608, -0.6945,
         0.4776, -0.0646,  0.7162,  0.4895, -0.3851, -1.2172,  0.6286, -0.4064,
        -0.2850, -0.3122,  0.7426,  0.7348,  0.5608,  0.2874, -0.4218,  0.2649,
         0.0022,  1.0576,  0.6075,  0.7453,  0.4029,  0.8586, -0.2010,  0.4629,
        -0.1588,  0.8472,  0.2970,  0.0262,  0.2182,  0.3083,  0.3395,  0.1802,
         0.3352,  0.4177,  0.4831, -0.4374,  0.4496,  0.0138,  0.2590, -0.9276,
         0.4266, -0.8178,  0.3704,  0.23

# Load dataset