In [91]:
import torch
import math

In [92]:
class PositionalEncoding(torch.nn.Module):
    """Adds fixed sinusoidal positional encodings to a batch of embeddings.

    The table is precomputed once for ``max_seq_len`` positions. For position
    ``pos`` and even feature index ``i``::

        table[pos, i]     = sin(pos / 10000 ** (i / feat_dim))
        table[pos, i + 1] = cos(pos / 10000 ** (i / feat_dim))

    Args:
        feat_dim: number of features per sequence element (may be odd).
        max_seq_len: number of positions to precompute.
    """

    def __init__(self, feat_dim, max_seq_len=5000):
        super().__init__()

        positions = torch.arange(max_seq_len).unsqueeze(dim=1)  # (max_seq_len, 1)
        feat_indexes = torch.arange(start=0, end=feat_dim, step=2)  # (ceil(feat_dim / 2),)

        # (max_seq_len, ceil(feat_dim / 2)) matrix of sin/cos arguments
        positions_feat_idx_matrix = positions / (10000 ** (feat_indexes / feat_dim))

        positional_encoding = torch.zeros(max_seq_len, feat_dim)
        positional_encoding[:, 0::2] = torch.sin(positions_feat_idx_matrix)
        # With an odd feat_dim there is one fewer odd column than even column,
        # so drop the surplus angle column before filling the cos entries.
        positional_encoding[:, 1::2] = torch.cos(positions_feat_idx_matrix[:, :feat_dim // 2])

        # Registered as a buffer so the table follows the module across
        # .to(device) / .cuda() calls. It is non-persistent because it is
        # derived purely from the constructor arguments, which also keeps
        # state_dict() identical to what it was with a plain attribute.
        self.register_buffer("positional_encoding", positional_encoding, persistent=False)

    # x shape is (batch, sequence, elements)
    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``sequence`` positions.

        The table is sliced to the sequence length and feature count of ``x``
        and broadcast over the batch dimension.
        """
        return x + self.positional_encoding[:x.shape[-2], :x.shape[-1]].unsqueeze(dim=0)
    
class AttentionHead(torch.nn.Module):
    """Single scaled dot-product attention head with learned projections.

    Args:
        input_dimension: feature size of the incoming query/key/value tensors.
        key_dimension: feature size of the projected queries and keys.
        value_dimension: feature size of the projected values (and of the output).
    """

    def __init__(self, input_dimension, key_dimension, value_dimension):
        super().__init__()
        self.query_projection = torch.nn.Linear(in_features=input_dimension, out_features=key_dimension)
        self.key_projection = torch.nn.Linear(in_features=input_dimension, out_features=key_dimension)
        self.value_projection = torch.nn.Linear(in_features=input_dimension, out_features=value_dimension)
        # dot products are divided by sqrt(key_dimension) before the softmax
        self.attention_scale = math.sqrt(key_dimension)

    # x dimension is (..., sequence, embedding); leading batch dimensions are optional
    def forward(self, query, key, value):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        ``key`` and ``value`` must share their sequence length. No mask is
        applied. Returns a tensor of shape (..., query_sequence, value_dimension).
        """
        projected_query = self.query_projection(query) # (..., query_sequence, key_dimension)
        projected_key = self.key_projection(key) # (..., key_sequence, key_dimension)
        projected_value = self.value_projection(value) # (..., key_sequence, value_dimension)

        # Swap the last two axes rather than hard-coding axes 1 and 2, so
        # unbatched (2-D) inputs and inputs with extra leading dims also work.
        scaled_dot_product = torch.matmul(projected_query, torch.transpose(projected_key, -2, -1)) / self.attention_scale # (..., query_sequence, key_sequence)
        attention = torch.nn.functional.softmax(scaled_dot_product, dim=-1) # each query row sums to 1 over the keys

        return torch.matmul(attention, projected_value) # (..., query_sequence, value_dimension)

class MultiHeadAttention(torch.nn.Module):
    """Runs several attention heads in parallel and mixes their outputs.

    Args:
        num_head: number of independent attention heads.
        input_dimension: feature size of the incoming query/key/value tensors.
        key_dimension: per-head feature size of projected queries and keys.
        value_dimension: per-head feature size of projected values.
        output_dimension: feature size produced by the final linear layer.
    """

    def __init__(self, num_head, input_dimension, key_dimension, value_dimension, output_dimension):
        super().__init__()
        # The heads must live in a ModuleList, not a plain list. Otherwise
        # their parameters are invisible to .parameters(), .to(device),
        # state_dict() and therefore to any optimizer.
        self.attention_heads = torch.nn.ModuleList(
            AttentionHead(input_dimension=input_dimension, key_dimension=key_dimension, value_dimension=value_dimension)
            for _ in range(num_head)
        )
        self.linear_output = torch.nn.Linear(in_features=num_head * value_dimension, out_features=output_dimension)

    # x dimension is (batch, sequence, embedding)
    def forward(self, query, key, value):
        """Apply every head to the same inputs, concatenate on the feature axis, then project."""
        head_results = [attention_head(query=query, key=key, value=value) for attention_head in self.attention_heads] # each (batch, sequence, value_dimension)
        concatenated_heads = torch.cat(head_results, dim=-1) # (batch, sequence, value_dimension * num_heads)

        return self.linear_output(concatenated_heads) # (batch, sequence, output_dimension)
    
class TransformerEncoder(torch.nn.Module):
    """One encoder layer: self-attention then a feed-forward network.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalisation, so the output has the same shape as the input.

    Args:
        num_head: number of attention heads; must divide ``input_dimension``.
        input_dimension: feature size of the input (and of the output).
        ff_inner_dim: hidden size of the feed-forward network.

    Raises:
        ValueError: if ``input_dimension`` is not divisible by ``num_head``.
    """

    def __init__(self, num_head, input_dimension, ff_inner_dim):
        super().__init__()

        # ValueError (a subclass of Exception) flags this as a bad argument
        # while staying catchable by existing `except Exception` handlers.
        if input_dimension % num_head != 0:
            raise ValueError("input_dimension is not divisible by num_head!")

        # each head works on an equal slice of the feature budget
        self.multi_head_attention = MultiHeadAttention(
            num_head=num_head,
            input_dimension=input_dimension,
            key_dimension=input_dimension // num_head,
            value_dimension=input_dimension // num_head,
            output_dimension=input_dimension, # force output dimension = input dimension so that we can do residual connection
        )

        self.layer_norm_mhsa = torch.nn.LayerNorm(normalized_shape=input_dimension)

        # force output dimension = input dimension so that we can do residual connection
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_dimension, out_features=ff_inner_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=ff_inner_dim, out_features=input_dimension)
        )

        self.layer_norm_ff = torch.nn.LayerNorm(normalized_shape=input_dimension)

    # x dimension is (batch, sequence, embedding)
    def forward(self, x):
        """Return the encoded sequence, same shape as ``x``."""
        mhsa = self.multi_head_attention(query=x, key=x, value=x) # (batch, sequence, input_dimension)
        x = self.layer_norm_mhsa(x + mhsa) # (batch, sequence, input_dimension)
        feed_forward = self.feed_forward(x) # (batch, sequence, input_dimension)

        return self.layer_norm_ff(x + feed_forward) # (batch, sequence, input_dimension)
    
class TransformerDecoder(torch.nn.Module):
    """One decoder layer: self-attention, encoder attention, then feed-forward.

    Each of the three sub-layers is wrapped in a residual connection followed
    by layer normalisation, so the output has the same shape as ``x``.

    NOTE: no attention mask is applied in either attention sub-layer, so every
    position of ``x`` attends to every other position of ``x``.

    Args:
        num_head: number of attention heads; must divide ``input_dimension``.
        input_dimension: feature size of ``x`` and of ``encoder_output``.
        ff_inner_dim: hidden size of the feed-forward network.

    Raises:
        ValueError: if ``input_dimension`` is not divisible by ``num_head``.
    """

    def __init__(self, num_head, input_dimension, ff_inner_dim):
        super().__init__()

        # ValueError (a subclass of Exception) flags this as a bad argument
        # while staying catchable by existing `except Exception` handlers.
        if input_dimension % num_head != 0:
            raise ValueError("input_dimension is not divisible by num_head!")

        # self-attention over the decoder's own sequence
        self.multi_head_attention = MultiHeadAttention(
            num_head=num_head,
            input_dimension=input_dimension,
            key_dimension=input_dimension // num_head,
            value_dimension=input_dimension // num_head,
            output_dimension=input_dimension, # force output dimension = input dimension so that we can do residual connection
        )

        self.layer_norm_mhsa = torch.nn.LayerNorm(normalized_shape=input_dimension)

        # attention whose keys and values come from the encoder output
        self.encoder_multi_head_attention = MultiHeadAttention(
            num_head=num_head,
            input_dimension=input_dimension,
            key_dimension=input_dimension // num_head,
            value_dimension=input_dimension // num_head,
            output_dimension=input_dimension, # force output dimension = input dimension so that we can do residual connection
        )

        self.layer_norm_encoder_mhsa = torch.nn.LayerNorm(normalized_shape=input_dimension)

        # force output dimension = input dimension so that we can do residual connection
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_dimension, out_features=ff_inner_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=ff_inner_dim, out_features=input_dimension)
        )

        self.layer_norm_ff = torch.nn.LayerNorm(normalized_shape=input_dimension)

    # x dimension is (batch, sequence, embedding)
    def forward(self, x, encoder_output):
        """Decode ``x`` while attending to ``encoder_output``.

        ``encoder_output`` may have a different sequence length from ``x``;
        the result keeps the sequence length of ``x``.
        """
        mhsa = self.multi_head_attention(query=x, key=x, value=x) # (batch, sequence, input_dimension)
        x = self.layer_norm_mhsa(x + mhsa) # (batch, sequence, input_dimension)

        # queries come from the decoder, keys/values from the encoder
        encoder_mhsa = self.encoder_multi_head_attention(query=x, key=encoder_output, value=encoder_output) # (batch, sequence, input_dimension)
        x = self.layer_norm_encoder_mhsa(x + encoder_mhsa) # (batch, sequence, input_dimension)

        feed_forward = self.feed_forward(x) # (batch, sequence, input_dimension)

        return self.layer_norm_ff(x + feed_forward) # (batch, sequence, input_dimension)

class Transformer(torch.nn.Module):
    """Encoder-decoder model operating on already-embedded sequences.

    Both inputs get the same positional encoding added. The source sequence
    then goes through the encoder stack, and the target sequence through the
    decoder stack, where every decoder layer attends to the final encoder
    output. A linear layer maps each decoder position to ``output_dimension``.

    Args:
        input_dimension: feature size of both input embeddings.
        encoder_layer_num: number of stacked encoder layers.
        encoder_head_num: attention heads per encoder layer.
        encoder_ff_inner_dim: feed-forward hidden size in the encoder layers.
        decoder_layer_num: number of stacked decoder layers.
        decoder_head_num: attention heads per decoder layer.
        decoder_ff_inner_dim: feed-forward hidden size in the decoder layers.
        output_dimension: feature size of the final output.
    """

    def __init__(self,
                 input_dimension,
                 encoder_layer_num,
                 encoder_head_num,
                 encoder_ff_inner_dim,
                 decoder_layer_num,
                 decoder_head_num,
                 decoder_ff_inner_dim,
                 output_dimension
                ):
        super().__init__()

        self.positional_encoding = PositionalEncoding(feat_dim=input_dimension)

        self.encoders = torch.nn.Sequential(
            *[TransformerEncoder(num_head=encoder_head_num, input_dimension=input_dimension, ff_inner_dim=encoder_ff_inner_dim) for _ in range(encoder_layer_num)]
        )
        # Decoder layers take two inputs, so they cannot go in a Sequential.
        # They still need a ModuleList rather than a plain list; otherwise
        # their parameters are never registered (missing from .parameters(),
        # state_dict() and .to(device)) and would not be trained.
        self.decoders = torch.nn.ModuleList(
            TransformerDecoder(num_head=decoder_head_num, input_dimension=input_dimension, ff_inner_dim=decoder_ff_inner_dim)
            for _ in range(decoder_layer_num)
        )

        self.output_linear = torch.nn.Linear(in_features=input_dimension, out_features=output_dimension)

    def forward(self, input_embedding, output_embedding):
        """Return a tensor shaped (batch, output sequence, output_dimension).

        Args:
            input_embedding: source embeddings, (batch, sequence, input_dimension).
            output_embedding: target embeddings, (batch, sequence, input_dimension);
                the sequence length may differ from the source.
        """
        input_embedding = self.positional_encoding(input_embedding)
        encoder_output = self.encoders(input_embedding)

        decoder_output = self.positional_encoding(output_embedding)

        # every decoder layer attends to the same final encoder output
        for decoder in self.decoders:
            decoder_output = decoder(decoder_output, encoder_output)

        return self.output_linear(decoder_output) # (batch, sequence, output_dimension)

In [93]:
# input_embeddings_sequence has shape (batch, sequence, embedding_dimension)
BATCH_SIZE = 3
EMBEDDING_DIMENSION = 512
SEQ_LEN = 3

# random stand-in for a batch of already-embedded source sequences
input_embeddings_sequence = torch.randn((BATCH_SIZE, SEQ_LEN, EMBEDDING_DIMENSION))
input_embeddings_sequence.shape

torch.Size([3, 3, 512])

In [94]:
# Hyper-parameters for a 6-layer, 8-head encoder/decoder stack; the model
# width is taken from the source embeddings so the two cannot drift apart.
transformer_config = dict(
    input_dimension=input_embeddings_sequence.shape[-1],
    encoder_layer_num=6,
    encoder_head_num=8,
    encoder_ff_inner_dim=2048,
    decoder_layer_num=6,
    decoder_head_num=8,
    decoder_ff_inner_dim=2048,
    output_dimension=3,
)
model = Transformer(**transformer_config)

# Decoder-side input: random embeddings for a batch of 3 sequences of length 5.
# Drawn after the model is built so the random stream is consumed in the same order.
target_embeddings_sequence = torch.randn((3, 5, EMBEDDING_DIMENSION))
result = model(input_embeddings_sequence, target_embeddings_sequence)

print(result.shape)
print(result)

torch.Size([3, 5, 3])
tensor([[[ 0.1533, -0.2047, -0.0925],
         [ 0.5067, -0.6973, -0.4356],
         [ 0.1026, -0.2700, -0.4674],
         [-0.1781, -0.1886,  0.1556],
         [-0.0872,  0.2352,  0.1347]],

        [[ 0.9160, -0.3647, -0.4792],
         [ 0.7410, -0.4025, -0.2174],
         [ 0.4471,  0.0610, -0.0307],
         [ 1.0257, -0.6982, -0.5125],
         [ 0.5325, -0.2835, -0.1735]],

        [[-0.6674,  0.0289,  0.5842],
         [-0.0039, -0.8921,  0.2027],
         [ 0.0155, -0.2702,  0.9253],
         [-0.3243, -0.2298,  0.8721],
         [-0.5108, -0.6590,  0.1511]]], grad_fn=<ViewBackward0>)
