# Transformer Encoder

![jupyter](../images/10/transformer.svg)

## Encoder Block

In [1]:
import torch
from torch import nn
import d2l
import math

In [2]:
#@save
class EncoderBlock(nn.Module):
    """One Transformer encoder layer.

    Runs multi-head self-attention, then a position-wise feed-forward
    network (FFN). Each sublayer's output is merged with that sublayer's
    input through a ``d2l.AddNorm`` layer.

    Args:
        key_size, query_size, value_size: forwarded unchanged to
            ``d2l.MultiHeadAttention``.
        num_hiddens: model width. It is passed to the attention layer and
            as the first argument of ``d2l.PositionWiseFFN``.
        norm_shape: normalized shape given to both ``d2l.AddNorm`` layers.
        ffn_num_hiddens: hidden width of the position-wise FFN.
        num_heads: number of attention heads.
        dropout: dropout rate for the attention and both ``AddNorm`` layers.
        use_bias: forwarded to ``d2l.MultiHeadAttention``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False):
        super().__init__()
        # Submodules are created in a fixed order: attention, addnorm1, ffn,
        # addnorm2. That order decides parameter initialisation and the
        # state_dict key order, so keep it stable.
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens,
            num_heads, dropout, use_bias)
        self.addnorm1 = d2l.AddNorm(norm_shape, dropout)
        self.ffn = d2l.PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        self.addnorm2 = d2l.AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        """Apply the attention sublayer, then the FFN sublayer.

        Args:
            X: input of shape (`batch_size`, `num_steps`, `num_hiddens`).
            valid_lens: forwarded to the attention layer. It presumably
                masks padded positions (TODO confirm against
                ``d2l.MultiHeadAttention``).

        Returns:
            A tensor with the same shape as ``X``. In the FFN the width goes
            `num_hiddens` -> `ffn_num_hiddens` -> `num_hiddens`.
        """
        # Self-attention: queries, keys and values are all X.
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)

## Transformer Encoder

In [3]:
#@save
class TransformerEncoder(d2l.Encoder):
    """Transformer encoder: embedding + positional encoding + block stack.

    Token indices are embedded into `num_hiddens`-dimensional vectors and
    scaled by ``sqrt(num_hiddens)``. They then go through
    ``d2l.PositionalEncoding`` and `num_layers` ``EncoderBlock`` layers in
    sequence. Every block uses `num_hiddens` for its key, query, value and
    hidden sizes.

    Args:
        vocab_size: number of rows in the token embedding table.
        num_hiddens: model width used throughout the encoder.
        norm_shape: normalized shape passed to each block's ``AddNorm``.
        ffn_num_hiddens: hidden width of each block's position-wise FFN.
        num_heads: number of attention heads per block.
        num_layers: number of stacked ``EncoderBlock`` layers.
        dropout: dropout rate for the positional encoding and every block.
        use_bias: forwarded to each block's attention layer.
    """

    def __init__(self, vocab_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads,
                 num_layers, dropout, use_bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        # nn.Sequential is used only as a named container. forward() calls
        # each block by hand because a block needs `valid_lens` as well as X.
        # The child names ("block0", "block1", ...) become state_dict keys.
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(num_hiddens, num_hiddens, num_hiddens,
                                 num_hiddens, norm_shape, ffn_num_hiddens,
                                 num_heads, dropout, use_bias)
            self.blks.add_module(f"block{layer_idx}", block)

    def forward(self, X, valid_lens):
        """Encode a batch of token indices.

        Args:
            X: integer token indices. The demo in this notebook passes a
                (`batch_size`, `num_steps`) ``torch.long`` tensor.
            valid_lens: forwarded unchanged to every ``EncoderBlock``.

        Returns:
            The output of the last block, of shape
            (`batch_size`, `num_steps`, `num_hiddens`).
        """
        # Scale the embeddings by sqrt(num_hiddens) before the positional
        # encoding is applied.
        scale = math.sqrt(self.num_hiddens)
        hidden = self.pos_encoding(self.embedding(X) * scale)
        for block in self.blks:
            hidden = block(hidden, valid_lens)
        return hidden

In [4]:
# Small smoke test: 2 sequences of 100 tokens each, model width 24.
# norm_shape=[100, 24] makes each AddNorm normalise over both the step and
# hidden dimensions, which fixes the sequence length at 100 for this model.
encoder = TransformerEncoder(vocab_size=500, num_hiddens=24,
                             norm_shape=[100, 24], ffn_num_hiddens=50,
                             num_heads=8, num_layers=2, dropout=0.5)
encoder.eval()  # evaluation mode, so dropout is inactive for the shape check
valid_lens = torch.tensor([3, 2])
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape

torch.Size([2, 100, 24])