# Reconstructing Transformer

The purpose of this notebook is a self-taught exploration of the transformer code, connecting it to the equations from the paper.

In [1]:
import torch
from matplotlib import pyplot as plt

PyTorch's default `torch.nn.Transformer` is an encoder-decoder model. It is built from stacked encoder and decoder blocks, just like the model in the original paper.

In [2]:
transformer = torch.nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    batch_first=True,
)


transformer = transformer.eval()  # turn off dropout etc.

This model has about 44 million parameters, which makes it hard to train quickly, but it can still be explored easily.

In [3]:
sum(param.numel() for param in transformer.parameters() if param.requires_grad)

44140544

Both the encoder and the decoder are containers that stack multiple encoder or decoder blocks. These blocks can also be instantiated one by one. Every block follows the same set of equations, but each layer has its own parameters.

In [4]:
print(f"Encoder: {sum(param.numel() for param in transformer.encoder.parameters() if param.requires_grad)}")
print(f"Decoder: {sum(param.numel() for param in transformer.decoder.parameters() if param.requires_grad)}")

Encoder: 18915328
Decoder: 25225216
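
These counts can be reproduced by hand from the layer shapes. The sketch below assumes the default PyTorch layout: a bias on every linear layer, a packed query/key/value projection in each attention module, and one extra `LayerNorm` at the end of each stack. The helper `linear_params` and the variable names are made up for illustration.

```python
d_model, d_ff, n_layers = 512, 2048, 6


def linear_params(n_in, n_out):
    # weight matrix plus bias vector of a linear layer
    return n_in * n_out + n_out


# attention: packed Q/K/V projection (d_model -> 3 * d_model) plus output projection
attention_params = linear_params(d_model, 3 * d_model) + linear_params(d_model, d_model)
# feed-forward network: up-projection followed by down-projection
ffn_params = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
# layer normalization: one scale and one shift vector
norm_params = 2 * d_model

encoder_layer = attention_params + ffn_params + 2 * norm_params
decoder_layer = 2 * attention_params + ffn_params + 3 * norm_params

# each stack ends with one additional LayerNorm
encoder_total = n_layers * encoder_layer + norm_params
decoder_total = n_layers * decoder_layer + norm_params

print(encoder_total, decoder_total, encoder_total + decoder_total)
# 18915328 25225216 44140544
```

The decoder is larger only because each of its blocks carries a second attention module and a third normalization.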


The encoder consists of encoder blocks, each with a self-attention sub-layer and a feed-forward sub-layer. In NLP applications, words are usually converted to tokens and then to embeddings that the model can process. This implementation skips the initial embedding, since a transformer can be applied to other data types, not only text.

In [5]:
transformer.encoder.layers[0]

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)

In [6]:
inputs = torch.rand(8, 32, 512)  # shape = (BATCH_SIZE, SEQUENCE_LEN, EMBEDDING_DIM)  EMBEDDING_DIM also known as d_model or model dimension

Processing a toy input (a random tensor) does not change its shape, whether it passes through a single encoder block or the full encoder.

In [7]:
with torch.no_grad():
    layer_outputs = transformer.encoder.layers[0](inputs)
    encoder_outputs = transformer.encoder(inputs)

layer_outputs.shape, encoder_outputs.shape

(torch.Size([8, 32, 512]), torch.Size([8, 32, 512]))
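
Because the encoder only expects float tensors of shape (BATCH_SIZE, SEQUENCE_LEN, EMBEDDING_DIM), text input needs an embedding step in front of it. Below is a minimal sketch of that step, assuming a hypothetical vocabulary of 1000 tokens and a much smaller encoder than the one above; all sizes and names are illustrative.

```python
import torch

torch.manual_seed(0)

vocab_size, d_model = 1000, 32  # hypothetical vocabulary size and a small model dimension

embedding = torch.nn.Embedding(vocab_size, d_model)  # token id -> dense vector
encoder_layer = torch.nn.TransformerEncoderLayer(
    d_model=d_model, nhead=4, dim_feedforward=64, dropout=0.0, batch_first=True
)
small_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=2)

token_ids = torch.randint(0, vocab_size, (8, 10))  # (BATCH_SIZE, SEQUENCE_LEN) integer ids

with torch.no_grad():
    embedded = embedding(token_ids)  # (BATCH_SIZE, SEQUENCE_LEN, d_model)
    encoded = small_encoder(embedded)  # the encoder keeps the shape unchanged

print(embedded.shape, encoded.shape)
```

A real text model would also add positional encodings to the embeddings, which `torch.nn.Transformer` does not do on its own.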

Each encoder block has two sub-layers: multi-head attention and a feed-forward network.

In [8]:
# create smaller encoder block
encoder_block = torch.nn.TransformerEncoderLayer(
    d_model=32,
    nhead=1,
    dim_feedforward=128,
    dropout=0.0,
    activation="relu",
    batch_first=True,
)

Self-attention is computed first.

In [9]:
with torch.no_grad():
    x = torch.rand(1, 4, 32)  # one example with 4 items and model dimension of 32
    outputs, scores = encoder_block.self_attn(query=x, key=x, value=x)  # in self attention Q, K, V are all the same

outputs.shape, scores.shape


(torch.Size([1, 4, 32]), torch.Size([1, 4, 4]))

After each learnable operation, there is a residual connection and a layer normalization.

In [10]:
with torch.no_grad():
    x = x + outputs  # skip-connection
    x = encoder_block.norm1(x)  # layer normalization

x.shape

torch.Size([1, 4, 32])

The feed-forward network is applied to the normalized output of the self-attention sub-layer.

In [11]:
with torch.no_grad():
    residual = x  # keep the sub-layer input for the skip-connection
    x = encoder_block.linear1(x)  # linear up-projection transformation
    x = encoder_block.activation(x)  # activation
    x = encoder_block.linear2(x)  # linear down-projection transformation

Each sub-layer has a residual connection around it followed by a layer normalization.

In [12]:
with torch.no_grad():
    x = residual + x  # skip-connection around the feed-forward network
    x = encoder_block.norm2(x)  # second layer normalization

x.shape

torch.Size([1, 4, 32])
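
One way to check a step-by-step reconstruction like this is to compare it with the output of the block itself. The sketch below rebuilds a block of the same size from scratch and assumes the default post-norm layout (`norm_first=False`) with dropout set to zero; the variable names are illustrative.

```python
import torch

torch.manual_seed(0)

block = torch.nn.TransformerEncoderLayer(
    d_model=32, nhead=1, dim_feedforward=128, dropout=0.0, activation="relu", batch_first=True
)
x = torch.rand(1, 4, 32)

with torch.no_grad():
    # sub-layer 1: self-attention, skip-connection, layer normalization
    attention_out, _ = block.self_attn(x, x, x, need_weights=False)
    hidden = block.norm1(x + attention_out)
    # sub-layer 2: feed-forward network, skip-connection, layer normalization
    ffn_out = block.linear2(block.activation(block.linear1(hidden)))
    manual = block.norm2(hidden + ffn_out)

    reference = block(x)  # the same computation performed by the module

matches = torch.allclose(manual, reference, atol=1e-5)
print(matches)
```

The second skip-connection has to start from the normalized attention output (`hidden`), not from the original input, and it uses `norm2` rather than `norm1`.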

# Decoder

Each decoder block has three sub-layers: two are the same as in the encoder, and the third is multi-head cross-attention over the encoder output.

In [13]:
transformer.decoder.layers[0]

TransformerDecoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (multihead_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (dropout3): Dropout(p=0.1, inplace=False)
)

In [14]:
inputs = torch.rand(8, 32, 512)
target = torch.rand(8, 3, 512)

with torch.no_grad():
    encoder_outputs = transformer.encoder(src=inputs)  # encoder output is called memory in pytorch
    decoder_outputs = transformer.decoder(tgt=target, memory=encoder_outputs)
    outputs = transformer(src=inputs, tgt=target)

encoder_outputs.shape, decoder_outputs.shape, outputs.shape

(torch.Size([8, 32, 512]), torch.Size([8, 3, 512]), torch.Size([8, 3, 512]))

In [15]:
# all masks are None by default, so the outputs are exactly the same
torch.all(outputs == decoder_outputs)

tensor(True)
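
With no masks, every target position can attend to every other target position. For language generation a causal mask is normally passed as `tgt_mask`, so that a position only sees earlier positions. The sketch below illustrates the effect on a single small decoder layer with random inputs; the sizes and the perturbation are arbitrary choices for the demonstration.

```python
import torch

torch.manual_seed(0)

layer = torch.nn.TransformerDecoderLayer(
    d_model=32, nhead=4, dim_feedforward=64, dropout=0.0, batch_first=True
)
memory = torch.rand(2, 10, 32)  # stands in for the encoder output
target = torch.rand(2, 5, 32)

# -inf above the diagonal: position i may only attend to positions <= i
causal_mask = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)

# change only the last target position
perturbed = target.clone()
perturbed[:, -1] += 1.0

with torch.no_grad():
    masked_a = layer(target, memory, tgt_mask=causal_mask)
    masked_b = layer(perturbed, memory, tgt_mask=causal_mask)
    unmasked_a = layer(target, memory)
    unmasked_b = layer(perturbed, memory)

# with the mask, earlier positions never see the changed last position
masked_same = torch.allclose(masked_a[:, :-1], masked_b[:, :-1], atol=1e-5)
# without the mask, the change leaks into every position
unmasked_same = torch.allclose(unmasked_a[:, :-1], unmasked_b[:, :-1], atol=1e-5)
print(masked_same, unmasked_same)
```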

During inference, language transformers (for example, for machine translation) are auto-regressive: they need multiple inference steps to produce the output, and the output of each step is fed back as input to the next step. For language tasks, the decoder output must be converted to tokens. A linear layer followed by a softmax activation turns the dense decoder output into a probability distribution over the vocabulary.

The decoder outputs a representation for the whole sequence, but in the standard setup only the last position is used: the token predicted from it is appended to the decoder input for the next step. The loop stops when the special end-of-sentence token is generated or when the model reaches the maximum number of steps.
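
The loop described above can be sketched as greedy decoding. Everything here is an assumption made for illustration: the model is tiny and untrained, so the generated ids are meaningless, and `embedding`, `to_vocab`, `bos_token` and `eos_token` are hypothetical names and values that are not part of `torch.nn.Transformer`.

```python
import torch

torch.manual_seed(0)

vocab_size, d_model, max_steps = 50, 32, 10
bos_token, eos_token = 1, 2  # hypothetical special token ids

model = torch.nn.Transformer(
    d_model=d_model, nhead=4, num_encoder_layers=2, num_decoder_layers=2,
    dim_feedforward=64, dropout=0.0, batch_first=True,
).eval()
embedding = torch.nn.Embedding(vocab_size, d_model)  # token id -> decoder input vector
to_vocab = torch.nn.Linear(d_model, vocab_size)  # decoder output -> vocabulary logits

source = torch.rand(1, 12, d_model)  # an already embedded source sequence
tokens = torch.tensor([[bos_token]])  # the decoder input starts with the start token

with torch.no_grad():
    memory = model.encoder(source)  # the encoder runs only once
    for _ in range(max_steps):
        steps = tokens.shape[1]
        causal_mask = torch.triu(torch.full((steps, steps), float("-inf")), diagonal=1)
        decoded = model.decoder(embedding(tokens), memory, tgt_mask=causal_mask)
        # only the last position is turned into a distribution over the vocabulary
        probabilities = torch.softmax(to_vocab(decoded[:, -1]), dim=-1)
        next_token = probabilities.argmax(dim=-1, keepdim=True)  # greedy choice
        tokens = torch.cat([tokens, next_token], dim=1)  # feed the output back in
        if next_token.item() == eos_token:
            break

print(tokens)
```

Greedy selection is the simplest choice; sampling or beam search would replace only the `argmax` line.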


In [16]:
# create smaller decoder block
decoder_block = torch.nn.TransformerDecoderLayer(
    d_model=32,
    nhead=1,
    dim_feedforward=128,
    dropout=0.0,
    activation="relu",
    batch_first=True,
)

In [87]:
inputs = torch.rand(8, 32, 32)  # (BATCH_SIZE, SEQUENCE_LENGTH, MODEL_DIMENSION)
encoder_outputs = torch.rand(8, 32, 32)
target = torch.rand(8, 24, 32)

with torch.no_grad():
    # self-attention over target which is the decoder input
    x, _ = decoder_block.self_attn(key=target, query=target, value=target)  # ignore attention weights
    x = decoder_block.dropout1(x)  # dropout
    self_attention_outputs = decoder_block.norm1(x + target)  # skip-connection and layer normalization
    # cross-attention using encoder outputs as key and value
    x, _ = decoder_block.multihead_attn(query=self_attention_outputs, key=encoder_outputs, value=encoder_outputs)
    x = decoder_block.dropout2(x)
    multihead_attention_outputs = decoder_block.norm2(x + self_attention_outputs)  # skip connection
    # linear feed forward
    x = decoder_block.linear1(multihead_attention_outputs)
    x = decoder_block.activation(x)
    x = decoder_block.linear2(x)
    x = decoder_block.dropout3(x)
    outputs = decoder_block.norm3(x + multihead_attention_outputs)  # skip connection
    
outputs.shape


torch.Size([8, 24, 32])
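
The same sequence of steps can be compared against the decoder block's own forward pass. The sketch below rebuilds a block of the same size and assumes the default post-norm layout with zero dropout and no masks; the names `h1` and `h2` are illustrative.

```python
import torch

torch.manual_seed(0)

block = torch.nn.TransformerDecoderLayer(
    d_model=32, nhead=1, dim_feedforward=128, dropout=0.0, activation="relu", batch_first=True
)
memory = torch.rand(8, 32, 32)  # stands in for the encoder output
target = torch.rand(8, 24, 32)

with torch.no_grad():
    # sub-layer 1: self-attention over the target
    x, _ = block.self_attn(target, target, target, need_weights=False)
    h1 = block.norm1(target + x)
    # sub-layer 2: cross-attention, queries from the target, keys and values from the encoder
    x, _ = block.multihead_attn(h1, memory, memory, need_weights=False)
    h2 = block.norm2(h1 + x)
    # sub-layer 3: feed-forward network
    x = block.linear2(block.activation(block.linear1(h2)))
    manual = block.norm3(h2 + x)

    reference = block(tgt=target, memory=memory)

matches = torch.allclose(manual, reference, atol=1e-5)
print(matches)
```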

### Multi-Head Attention

The key building block of the transformer model is multi-head attention. In the classical implementation it serves as both self-attention and cross-attention: the two differ only in the inputs passed to the layer, while the underlying mechanism is the same.

The multi-head attention mechanism takes three inputs (plus optional masks):
* `query`
* `key`
* `value`
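
For reference, these three inputs enter the scaled dot-product attention and multi-head attention equations of the original paper as $Q$, $K$ and $V$, where $d_k$ is the dimension of the projected keys and $h$ is the number of heads:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^{O},
\qquad
\mathrm{head}_i = \mathrm{Attention}(QW_i^{Q}, KW_i^{K}, VW_i^{V})
$$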

In [222]:
# see: https://ai.stackexchange.com/questions/35548/when-exactly-does-the-split-into-different-heads-in-multi-head-attention-occur
attention = torch.nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True, bias=False)
softmax = torch.nn.Softmax(dim=-1)
attention = attention.eval()  # turn off dropout

In [223]:
x = torch.rand(2, 16, 32) 

outputs, scores = attention(x, x, x)
outputs.shape, scores.shape

(torch.Size([2, 16, 32]), torch.Size([2, 16, 16]))

In [219]:
for name, param in attention.named_parameters():
    print(f"{name} : {param.shape}")

in_proj_weight : torch.Size([96, 32])
out_proj.weight : torch.Size([32, 32])


The code is optimized so that only a single parameter matrix is created for the query, key and value projections.

In [208]:
# self attention where query, key, value are all the same
x = torch.rand(2, 16, 32)  # batch of 2 sequences, 16 items each, model dimension 32

q_proj_weight = attention.in_proj_weight[:32, :]  # query projection weight
k_proj_weight = attention.in_proj_weight[32:64, :]  # key projection weight
v_proj_weight = attention.in_proj_weight[64:, :]  # value projection weight

# matrix multiplication broadcasts over the batch dimension
# those projections are allowed to have bias, but it is skipped for simplicity
# weights are stored as (out_features, in_features), so they are transposed as in torch.nn.Linear
query = x @ q_proj_weight.T  # query projection is simply a matrix multiplication
key = x @ k_proj_weight.T  # projected keys must have the same size as projected queries
value = x @ v_proj_weight.T

query.shape, key.shape, value.shape

(torch.Size([2, 16, 32]), torch.Size([2, 16, 32]), torch.Size([2, 16, 32]))

These projections are the inputs to the attention operation.

In [209]:
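# note: all 32 dimensions are treated as a single head here, unlike the 4-head module above
attention_scores = query @ key.transpose(1, 2)  # transpose the last two dimensions, keeping the batch dimension
attention_scores = softmax(attention_scores / torch.sqrt(torch.Tensor([32])))  # scale by sqrt of the key dimension, then softmax over keys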
attention_scores.shape

torch.Size([2, 16, 16])

In [216]:
outputs = attention_scores @ value  # output computation
outputs = outputs @ attention.out_proj.weight.transpose(0, 1)  # output projection required by multihead attention
outputs.shape

torch.Size([2, 16, 32])
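
The walkthrough above treats all 32 dimensions as one head, while the module was created with `num_heads=4`. The sketch below is one way to add the head split so that the manual result can be compared with the module's output. It assumes the default behaviour of `torch.nn.MultiheadAttention` (returned attention weights are averaged over heads) and builds its own module; `split_heads` is a hypothetical helper.

```python
import math
import torch

torch.manual_seed(0)

embed_dim, num_heads = 32, 4
head_dim = embed_dim // num_heads  # each head works on 8 dimensions

mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, bias=False).eval()
x = torch.rand(2, 16, embed_dim)

# reference result from the module itself
reference_out, reference_weights = mha(x, x, x)
reference_out, reference_weights = reference_out.detach(), reference_weights.detach()


def split_heads(t):
    # (batch, sequence, embed_dim) -> (batch, heads, sequence, head_dim)
    batch, length, _ = t.shape
    return t.reshape(batch, length, num_heads, head_dim).transpose(1, 2)


with torch.no_grad():
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)  # unpack the shared projection matrix
    q, k, v = (split_heads(x @ w.T) for w in (w_q, w_k, w_v))

    # scaled dot-product attention, computed independently for every head
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
    heads = weights @ v  # (batch, heads, sequence, head_dim)

    # concatenate the heads again and apply the output projection
    merged = heads.transpose(1, 2).reshape(x.shape[0], x.shape[1], embed_dim)
    manual_out = merged @ mha.out_proj.weight.T
    manual_weights = weights.mean(dim=1)  # average over heads for comparison

outputs_match = torch.allclose(manual_out, reference_out, atol=1e-5)
weights_match = torch.allclose(manual_weights, reference_weights, atol=1e-5)
print(outputs_match, weights_match)
```

The scaling factor is the square root of the per-head dimension (8 here), not of the full model dimension.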

### Feed-Forward Network