In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Module, Linear, Dropout, Parameter
from torch.optim import Adam


class FeedForward(Module):
    def __init__(self, d_model=768, d_ff=2048, dropout=0.2):
        super(FeedForward, self).__init__()
        
        self.d_model = d_model
        self.d_ff = d_ff
        
        self.dropout = Dropout(dropout)
        self.linear1 = Linear(d_model, d_ff)
        self.linear2 = Linear(d_ff, d_model)
        
    def forward(self, inputs):
        # inputs: [batch, time, d_model]
        output = self.linear1(inputs)
        output = F.relu(output)
        output = self.dropout(output)
        output = self.linear2(output)
        
        return output
    
    
class LayerNorm(Module):
    def __init__(self, d_model=768, epsilon=1e-6):
        super(LayerNorm, self).__init__()
        
        self.d_model = d_model
        self.epsilon = epsilon
        self.gamma = Parameter(torch.ones(d_model))
        self.beta = Parameter(torch.zeros(d_model))
        
    def forward(self, inputs):
        # inputs: [batch, time, d_model]
        mean = inputs.mean(dim=2, keepdim=True)
        var = inputs.var(dim=2, keepdim=True)
        
        return self.gamma * (inputs - mean) / torch.sqrt(var + self.epsilon) + self.beta
    
    
class MultiheadAttention(Module):
    def __init__(self, d_model=768, num_heads=8, dropout=0.2):
        super(MultiheadAttention, self).__init__()
        
        self.d_model = d_model
        self.d_k = int(d_model / num_heads)
        if self.d_k * num_heads != d_model:
            raise Exception("d_model cannot be divided by num_heads.")
        self.num_heads = num_heads
            
        self.query = Linear(d_model, d_model)
        self.key = Linear(d_model, d_model)
        self.value = Linear(d_model, d_model)
        
        self.dropout = Dropout(dropout)
        
        self.output = Linear(d_model, d_model)
        
    def forward(self, query, key, value, future_mask=None, pad_mask=None):
        # query, key, value: [batch, time, d_model]
        assert len(query.size()) == 3, "input is not batch"
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        
        # query, key, value: [batch * num_heads, time, d_k]
        query = torch.cat(torch.split(query, self.d_k, dim=2), dim=0)
        key = torch.cat(torch.split(key, self.d_k, dim=2), dim=0)
        value = torch.cat(torch.split(value, self.d_k, dim=2), dim=0)
        
        # attention_score: [batch * num_heads, time, time]
        attention_score = torch.matmul(query, key.transpose(1,2)) / np.sqrt(self.d_k)
        
        # if mask is True, fill to -inf
        if future_mask is not None:
            attention_score = attention_score.masked_fill(mask=future_mask, value=-float("inf"))
        if pad_mask is not None:
            # reshape pad_mask from [batch, 1, time] to [batch * num_heads, 1, time]
            pad_mask = torch.cat([pad_mask]*self.num_heads, dim=0)
            attention_score = attention_score.masked_fill(mask=pad_mask, value=-float("inf"))
        
        # change score to probability
        attention_score = F.softmax(attention_score, dim=2)
        attention_score = self.dropout(attention_score)
        
        # probability * value: [batch * num_heads, time, d_k]
        output = torch.matmul(attention_score, value)
        
        # reshape output: [batch, time, d_model]
        batch_size = int(output.size()[0] / self.num_heads)
        output = torch.cat(torch.split(output, batch_size, dim=0), dim=2)
        
        # linear projection of output
        output = self.output(output)
        
        return output
    
    
class PositionalEncoding(Module):
    def __init__(self, d_model=768, max_len=150, pad_id=0):
        super(PositionalEncoding, self).__init__()
        
        self.d_model = d_model
        self.max_len = max_len
        self.pad_id = pad_id
        
        self.pe = torch.zeros([max_len, d_model])
        for pos in range(max_len):
            for i in range(0, d_model, 2):
                self.pe[pos, i] = np.sin(pos / 10000 ** (i / d_model))
                self.pe[pos, i+1] = np.cos(pos / 10000 ** (i / d_model))
        
    def forward(self, inputs):
        # inputs: [batch, time]
        batch_size = inputs.size()[0]
        seq_len = inputs.size()[1]
        
        # pad_mask: [batch, time, 1]
        pad_mask = (inputs == self.pad_id)
        pad_mask = pad_mask.view(batch_size, seq_len, 1)
        
        # pe: [max_len, d_model] => [batch, seq_len, d_model]
        pe = torch.stack([self.pe[:seq_len, :]]*batch_size, dim=0)
        pe = pe.masked_fill(mask=pad_mask, value=0)
        
        return pe

In [2]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, num_heads=8, dropout=0.2):
        super(DecoderBlock, self).__init__()
        
        self.self_attention = MultiheadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, inputs, future_mask=None, tgt_pad_mask=None):
        output = self.self_attention(inputs, inputs, inputs, future_mask=future_mask, pad_mask=tgt_pad_mask)
        output = self.dropout(output)
        output_ = self.norm1(output + inputs)

        output = self.feedforward(output_)
        output = self.dropout(output)
        output = self.norm2(output + output_)
        
        return output


class TransformerDecoder(nn.Module):
    def __init__(self, shared_embedding, d_model=512, d_ff=2048, num_heads=8, num_layers=6, max_len=100, dropout=0.2, pad_id=0):
        super(TransformerDecoder, self).__init__()
        
        self.linear = nn.Linear(d_model*2, d_model)
        layers = []
        layer = DecoderBlock(d_model, d_ff, num_heads, dropout)
        for i in range(num_layers):
            layers.append(layer)
        self.layers = nn.ModuleList(layers)
        self.embedding = shared_embedding
        self.pe = PositionalEncoding(self.embedding.embedding_dim, max_len, pad_id)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, decoder_inputs, encoder_output, future_mask=None, tgt_pad_mask=None):
        embedding = self.embedding(decoder_inputs)
        pe = self.pe(decoder_inputs)
        embedding = embedding + pe
        length = decoder_inputs.size(1)
        inputs = torch.cat([embedding, encoder_output[:, 0, :].unsqueeze(dim=1).repeat(1, length, 1)], dim=2)
        inputs = self.linear(self.dropout(inputs))
        output = self.dropout(inputs)
        for layer in self.layers:
            output = layer(output, future_mask, tgt_pad_mask)
        
        return output

In [3]:
vocab_size = 100
hidden_size = 128
embedding = nn.Embedding(vocab_size, hidden_size)
transformer = TransformerDecoder(embedding, 128, 512, 4, 3)
optim = Adam(transformer.parameters(), lr=3e-5)

In [4]:
encoder_output = torch.rand(4, 10, 128)
decoder_inputs = torch.randint(0, 100, size=(4, 10))
outputs = transformer(decoder_inputs, encoder_output)
loss = outputs.sum()

In [5]:
loss.backward()
optim.step()

In [6]:
list(transformer.layers[0].parameters()) == list(transformer.layers[1].parameters())

True

In [7]:
transformer

TransformerDecoder(
  (linear): Linear(in_features=256, out_features=128, bias=True)
  (layers): ModuleList(
    (0): DecoderBlock(
      (self_attention): MultiheadAttention(
        (query): Linear(in_features=128, out_features=128, bias=True)
        (key): Linear(in_features=128, out_features=128, bias=True)
        (value): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (output): Linear(in_features=128, out_features=128, bias=True)
      )
      (norm1): LayerNorm()
      (feedforward): FeedForward(
        (dropout): Dropout(p=0.2, inplace=False)
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
      )
      (norm2): LayerNorm()
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (1): DecoderBlock(
      (self_attention): MultiheadAttention(
        (query): Linear(in_features=128, out_features=128, bias=True)
        (