# Transformer Blocks in Depth

This notebook analyzes transformer blocks in depth, from PyTorch code down to the kernel-level implementation, and discusses potential optimizations in detail.

## 1. Self-attention layer

### Vanilla self-attention layer in PyTorch

```python
import torch
import torch.nn as nn


class SelfAtten(nn.Module):

    def __init__(self, dim=512) -> None:
        super().__init__()
        self.dim = dim
        # one GEMM produces Q, K and V together: [D] -> [3D]
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, mask):
        # x: [B, S, D]; mask: broadcastable to [B, S, S], 0 = masked out
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # scaled dot-product scores: [B, S, S]
        score = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        # large negative value, so masked positions get ~0 weight after softmax
        score = score.masked_fill(mask == 0, -1e3)
        attn = score.softmax(dim=-1)
        # [B, S, D]
        out = torch.matmul(attn, v)

        out = self.out_proj(out)
        return out
```

### PyTorch kernels

The PyTorch implementation of the above layer launches 10 kernels:

- 4 GEMM kernels (QKV projection, `q @ k^T`, `attn @ v`, output projection).
- 1 softmax kernel.
- The remaining 5 are elementwise kernels, e.g. for the score scaling and the mask computation.

![kernels](./media/attention-torch-kernel-trace.png)
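To see which of the 4 GEMMs dominates, their cost can be estimated from the shapes in `SelfAtten`, counting a `[M, K] x [K, N]` GEMM as `2*M*K*N` FLOPs. This is a back-of-the-envelope sketch; the helper name `attention_gemm_flops` is hypothetical, and bias adds and elementwise kernels are ignored:

```python
def attention_gemm_flops(batch: int, seq: int, dim: int) -> dict:
    """Approximate FLOPs of the 4 GEMMs in single-head self-attention (bias ignored)."""
    m = batch * seq  # rows of the flattened [B*S, D] activation matrix
    return {
        # [B*S, D] x [D, 3D]
        "qkv_proj": 2 * m * dim * 3 * dim,
        # per sample: [S, D] x [D, S]
        "q_kT": 2 * batch * seq * dim * seq,
        # per sample: [S, S] x [S, D]
        "attn_v": 2 * batch * seq * seq * dim,
        # [B*S, D] x [D, D]
        "out_proj": 2 * m * dim * dim,
    }


flops = attention_gemm_flops(batch=1, seq=32, dim=512)
print(flops)
```

With `S = 32` and `D = 512` the two projections dominate. In total the projections cost `8*B*S*D^2` and the two attention GEMMs `4*B*S^2*D`, so the attention GEMMs take over once `S > 2*D`.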



### ONNX graph

![onnx graph of attention layer](./media/attention.png)

### TensorRT kernels

After the ONNX graph is lowered to TensorRT (TRT), some elementwise kernels and the softmax kernel are fused together, leaving only 7 kernels in total.
![kernels](./media/attention-trt-kernel.png)



Analysis of the fusions:

- The bias add of the first GEMM and the 3 slices (`tensor.chunk`) are fused into `_myl_bb0_3_AddSliSliSli`. The shape computation is done on the host, so it needs no kernels.
- The transpose in the QK^T GEMM is fused into the matmul.
- The mask and the softmax are fused into `_myl_bb0_2_*`.
- The bias add after the output-projection GEMM is not fused, which costs an extra kernel.
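Eager PyTorch behaves similarly: `chunk` only computes new offsets and strides on the host and returns views into the QKV output, so the slicing itself copies no data. A quick check, assuming a contiguous `[B, S, 3D]` tensor as a stand-in for the projection output:

```python
import torch

B, S, D = 2, 4, 8
qkv = torch.randn(B, S, 3 * D)  # stand-in for the QKV projection output
q, k, v = qkv.chunk(3, dim=-1)

elem = qkv.element_size()
# each chunk starts D elements after the previous one, in the same storage
offsets = [(t.data_ptr() - qkv.data_ptr()) // elem for t in (q, k, v)]
# strides are inherited from the packed [B, S, 3D] layout
print(offsets, q.shape, q.stride(), q.is_contiguous())
# [0, 8, 16] torch.Size([2, 4, 8]) (96, 24, 1) False
```

Because `q`, `k` and `v` are non-contiguous views, a later op that requires contiguous memory may still copy them.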


Some implementations go further and fuse matmul-softmax-matmul into a single kernel. In that case only 4 kernels are needed:
- 1 for the QKV GEMM, 1 for the slice, 1 for the fused attention, and 1 for the output-projection GEMM.
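PyTorch 2.x exposes such a fused path as `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a fused kernel (for example FlashAttention or memory-efficient attention) on supported GPUs. The sketch below assumes PyTorch >= 2.0 and a boolean causal mask. It only checks that the single call computes the same matmul-softmax-matmul as the unfused code; on CPU it may run a non-fused fallback, so it says nothing about kernel count:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, D = 2, 16, 64
q, k, v = (torch.randn(B, S, D) for _ in range(3))
# causal mask: True = may attend (lower triangle including the diagonal)
mask = torch.tril(torch.ones(S, S, dtype=torch.bool))

# unfused reference: matmul -> scale -> mask -> softmax -> matmul
score = q @ k.transpose(-2, -1) / D ** 0.5
score = score.masked_fill(~mask, float("-inf"))
ref = score.softmax(dim=-1) @ v

# one call; the default scale is 1/sqrt(last dim of q) = 1/sqrt(D)
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(ref, fused, atol=1e-5))
```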

## 2. Multi-head attention

```python
import torch
import torch.nn as nn


class MultiHeadSelfAtten(nn.Module):

    def __init__(self, dim=512, head=8) -> None:
        super().__init__()
        assert dim % head == 0, "dim must be divisible by head"
        self.dim = dim
        self.head = head
        self.head_dim = dim // head
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    # split heads: [B, S, D] -> [B, S, H, D/H] -> [B, H, S, D/H]
    def reshape(self, x):
        return x.reshape(x.shape[0], -1, self.head, self.head_dim).transpose(1, 2)

    def self_attn(self, q, k, v, mask):
        # q, k, v: [B, H, S, D/H]
        # mask: broadcastable to [B, H, S, S], e.g. [S, S] or [B, 1, S, S]
        # scale by sqrt(D/H), the per-head dim, not sqrt(D)
        score = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        score = score.masked_fill(mask == 0, -1e3)
        attn = score.softmax(dim=-1)
        # out: [B, H, S, D/H]
        out = torch.matmul(attn, v)
        return out

    def forward(self, x, mask):
        # x: [B, S, D]
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        batch = q.shape[0]
        q, k, v = self.reshape(q), self.reshape(k), self.reshape(v)
        out = self.self_attn(q, k, v, mask)

        # merge heads: [B, H, S, D/H] -> [B, S, H, D/H] -> [B, S, D]
        out = out.transpose(1, 2).reshape(batch, -1, self.dim)

        out = self.out_proj(out)
        return out
```

### ONNX graph of multi-head attention
![onnx graph of multi head attention](./media/multihead-attention-onnx.png)

Compared with single-head self-attention, multi-head attention adds **reshape and transpose** ops before the attention computation: they split the hidden dim into `H` heads of size `D/H` and move the head dim next to the batch dim, so ordinary attention runs independently per head. Afterwards the heads are merged back with **transpose->reshape**. This makes the ONNX graph more complicated.
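The head split and merge are pure layout changes. A quick check, assuming `D = 512` and `H = 8` as in the defaults of the code above, that the split `[B, S, D] -> [B, H, S, D/H]` and the transpose->reshape merge round-trip exactly, and that each head holds a contiguous `D/H` slice of every token's hidden vector:

```python
import torch

B, S, D, H = 2, 5, 512, 8
x = torch.randn(B, S, D)

# split heads: [B, S, D] -> [B, S, H, D/H] -> [B, H, S, D/H]
heads = x.reshape(B, S, H, D // H).transpose(1, 2)
# merge heads: [B, H, S, D/H] -> [B, S, H, D/H] -> [B, S, D]
merged = heads.transpose(1, 2).reshape(B, S, D)

# head h holds hidden features [h*D/H, (h+1)*D/H) of every token
h = 3
print(heads.shape, torch.equal(merged, x),
      torch.equal(heads[:, h], x[..., h * (D // H):(h + 1) * (D // H)]))
```

In the real model the attention output is a fresh contiguous `[B, H, S, D/H]` tensor, so the merging `transpose(1, 2).reshape(...)` cannot be a view and has to copy data (for `S > 1`), which is a kernel of its own unless it is fused.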

The following 9 TRT kernels are launched for multi-head attention:

![multi-head attention kernels](./media/multihead-attention-trt-kernel.png)

## 3. Layer Norm

In PyTorch, LayerNorm is a dedicated op with its own kernel, but when it is exported to ONNX it is lowered to a sequence of primitive ops (at least for opsets before 17, which introduced a `LayerNormalization` op).

![layer norm](./media/layer-norm.png)

TRT has to recognize this pattern and fuse the ops back into one kernel.



```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(512)
x = norm(torch.rand(1, 32, 512))
```

The ONNX graph looks like this:

![onnx graph of layer norm](./media/layer-norm-onnx.png)
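The lowered chain is typically ReduceMean, Sub, Pow, ReduceMean, Add (epsilon), Sqrt, Div, Mul (weight), Add (bias). The sketch below writes those steps out with plain tensor ops and checks them against `nn.LayerNorm`. It assumes the default `eps=1e-5` and the biased variance that `nn.LayerNorm` uses, and the op mapping in the comments is illustrative rather than read from the exported graph:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
norm = nn.LayerNorm(512)  # eps=1e-5, learnable weight (ones) and bias (zeros)
x = torch.rand(1, 32, 512)

# the reduce/elementwise chain LayerNorm is lowered to
mean = x.mean(dim=-1, keepdim=True)               # ReduceMean
centered = x - mean                               # Sub
var = centered.pow(2).mean(dim=-1, keepdim=True)  # Pow + ReduceMean (biased)
std = torch.sqrt(var + norm.eps)                  # Add + Sqrt
y = centered / std * norm.weight + norm.bias      # Div + Mul + Add

print(torch.allclose(y, norm(x), atol=1e-5))
```

The two reductions each need a whole row before the next elementwise step can run, which is why they act as natural break points for fusion.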


When a LayerNorm-only module is run through TRT, the reduce ops act as break points that prevent fusion, and the TRT graph looks like this:

![TRT graph of layer norm](./media/layer-norm-trt.png)

When the whole model is run, however, LayerNorm is fused into one kernel, since TRT uses a dedicated backend to handle transformer patterns.

## 4. Whole Transformer Decoder Block

The whole decoder block is a sequence of multi-head attention, layer norm and residual, feed-forward, and layer norm and residual.

![trt kernels of transformer decoder block](./media/decoder-attention-trt-kernel.png)

The whole decoder block needs 13 kernels in TRT, 4 more than the multi-head attention alone:

- 2 LayerNorm + residual kernels: each LayerNorm is fused with its residual add, which is good.
- 1 GELU kernel.
- 2 MLP GEMMs.
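A compact sketch of such a block follows. It assumes the post-norm ordering listed above (`LayerNorm(x + sublayer(x))`), a 4x feed-forward expansion with GELU, and `F.scaled_dot_product_attention` (PyTorch >= 2.0) standing in for the attention core. The class name `DecoderBlock` and these choices are illustrative, not taken from the model that was profiled:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DecoderBlock(nn.Module):
    """Hypothetical post-norm decoder block: MHA -> Add&Norm -> MLP -> Add&Norm."""

    def __init__(self, dim=512, head=8, ffn_mult=4) -> None:
        super().__init__()
        self.head = head
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim * ffn_mult)  # MLP GEMM 1
        self.fc2 = nn.Linear(dim * ffn_mult, dim)  # MLP GEMM 2
        self.norm2 = nn.LayerNorm(dim)

    def attention(self, x, mask):
        B, S, D = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # split heads: [B, S, D] -> [B, H, S, D/H]
        q, k, v = (t.reshape(B, S, self.head, -1).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # merge heads: [B, H, S, D/H] -> [B, S, D]
        return self.out_proj(out.transpose(1, 2).reshape(B, S, D))

    def forward(self, x, mask):
        # x: [B, S, D]; mask: bool [S, S], True = may attend
        x = self.norm1(x + self.attention(x, mask))        # residual + LayerNorm
        x = self.norm2(x + self.fc2(F.gelu(self.fc1(x))))  # residual + LayerNorm
        return x


S = 10
block = DecoderBlock()
mask = torch.tril(torch.ones(S, S, dtype=torch.bool))
y = block(torch.randn(2, S, 512), mask)
```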

## 5. Attention mask

A mask excludes certain positions from attention: their scores are set to a large negative value before the softmax, so their attention weights become (effectively) zero.

To prevent each token from attending to future tokens (GPT is trained and used auto-regressively, so no token may see the tokens that come after it), a causal mask is applied in the attention layer. It is a matrix with 0s above the diagonal (future positions are masked out) and 1s on and below the diagonal.

The mask shape is `[S, S]`, where `S` is the sequence length.

![decoder mask](./media/decoder-mask.png)
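For example, a causal mask for `S = 4` can be built with `torch.tril`, where 1 marks a position that may be attended to. Applying it with `masked_fill(mask == 0, -1e3)`, as in the attention code above, to all-equal scores gives each query uniform weights over the keys it is allowed to see:

```python
import torch

S = 4
# row i = query position i, column j = key position j; j > i is masked
causal = torch.tril(torch.ones(S, S, dtype=torch.int64))
print(causal)
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])

# equal scores + mask -> query i spreads its weight over keys 0..i
weights = torch.zeros(S, S).masked_fill(causal == 0, -1e3).softmax(dim=-1)
```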

When the model is fed in batches, the sentences are padded to the same length, and a binary mask of shape `[BATCH, PADDED LENGTH]` (abbreviated `[B, S]`) prevents the attention layer from attending to the padded tokens. We call it the padding mask below.

Differences between the two masks:
1. The causal attention mask is a lower-triangular `[S, S]` matrix, applied identically to every sample in the batch (broadcast along the batch dim).
2. The padding mask is not triangular and differs per sample. It is reshaped to `[B, 1, S]` and broadcast along the query dim, so every query position sees the same set of masked-out `k/v` positions.
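The two masks combine by broadcasting. The sketch below assumes right padding (real tokens first, padding last) and 1 = keep, and uses a hypothetical batch with sequence lengths 4 and 2. The padding mask `[B, S]` is unsqueezed to `[B, 1, S]` so it broadcasts over query positions, then ANDed with the causal mask `[S, S]` to give one `[B, S, S]` mask:

```python
import torch

B, S = 2, 4
lengths = torch.tensor([4, 2])  # real tokens per sample (hypothetical batch)
# padding mask [B, S]: 1 for real tokens, 0 for padding at the end
padding = (torch.arange(S)[None, :] < lengths[:, None]).long()
causal = torch.tril(torch.ones(S, S, dtype=torch.long))  # [S, S]

# [S, S] & [B, 1, S] -> [B, S, S]; padding varies along the key (last) dim
combined = causal & padding[:, None, :]
print(combined.shape)
print(combined[1])
```

Rows 2 and 3 of the second sample belong to padded queries, and their outputs are simply ignored downstream. With left padding instead, the first query rows of a padded sample would have every key masked; a finite fill value such as `-1e3` then yields a uniform softmax for those rows, whereas `-inf` would produce NaN.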


## 6. K/V Cache Optimization for inference

Without a K/V cache, the GPT model generates one token per iteration: that token is appended to the end of the sequence, and the whole sequence is fed into the model again to generate the next token.

This is inefficient for inference, because the K/V vectors (the outputs of the K/V projection GEMMs) of all earlier tokens were already computed in previous iterations.
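To make the waste concrete, count how many token positions pass through the K/V projections when generating `n` tokens from a `P`-token prompt, assuming one new token per step. This is a sketch with a hypothetical helper name:

```python
def projected_positions(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Token positions fed through the K/V projections to generate new_tokens tokens."""
    if use_cache:
        # the prompt once, then only the latest token per step
        return prompt_len + (new_tokens - 1)
    # step i re-feeds the prompt plus the i tokens generated so far
    return sum(prompt_len + i for i in range(new_tokens))


print(projected_positions(100, 50, use_cache=False))  # 6225
print(projected_positions(100, 50, use_cache=True))   # 149
```

The uncached count grows quadratically with the number of generated tokens, the cached one only linearly.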

GPT decoding has 2 stages:

1. First stage: the past K/V is empty, and the prompt tokens are fed in.
    The K/V of this context (the prompt) are computed and output.
    One token is generated as the next token.

2. Second stage, repeated until the `<end>` token is generated:
    the past K/V plus the last token are fed in, and the current K/V and the next token are generated.
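The two stages can be sketched for a single attention head of a single layer (a real model keeps one cache per layer). The sketch assumes the K/V cache is just a pair of tensors concatenated along the sequence dim, and the names `qkv` and `attend` are hypothetical. The prompt pass fills the cache; a decode step then projects only the newest token, appends its K/V, and attends over the whole cache. The final check confirms that the cached step matches recomputing causal attention over the full sequence:

```python
import torch

torch.manual_seed(0)
D = 16
qkv = torch.nn.Linear(D, 3 * D)  # hypothetical fused Q/K/V projection


def attend(q, k, v, causal):
    # q: [B, Sq, D], k/v: [B, Sk, D]
    score = q @ k.transpose(-2, -1) / D ** 0.5
    if causal:  # only used with Sq == Sk here
        keep = torch.tril(torch.ones(q.shape[1], k.shape[1], dtype=torch.bool))
        score = score.masked_fill(~keep, float("-inf"))
    return score.softmax(dim=-1) @ v


prompt = torch.randn(1, 5, D)
new_tok = torch.randn(1, 1, D)  # stand-in for the token generated in stage 1

with torch.no_grad():
    # stage 1: past K/V is empty; the prompt pass produces the cache
    q, k_cache, v_cache = qkv(prompt).chunk(3, dim=-1)
    prompt_out = attend(q, k_cache, v_cache, causal=True)

    # stage 2: only the new token goes through the projection
    q_new, k_new, v_new = qkv(new_tok).chunk(3, dim=-1)
    k_cache = torch.cat([k_cache, k_new], dim=1)  # [1, 6, D]
    v_cache = torch.cat([v_cache, v_new], dim=1)
    # the newest token may attend to every cached position
    out_cached = attend(q_new, k_cache, v_cache, causal=False)

    # reference: recompute the full sequence without a cache, keep the last row
    q_f, k_f, v_f = qkv(torch.cat([prompt, new_tok], dim=1)).chunk(3, dim=-1)
    out_full = attend(q_f, k_f, v_f, causal=True)[:, -1:]

print(torch.allclose(out_cached, out_full, atol=1e-5))
```

In the decode step the score matrix shrinks from `[S, S]` to `[1, S]`, and no causal mask is needed because the newest token may see every cached position.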


The ONNX graph of attention with a K/V cache looks as follows:

![attention with K/V cache](./media/multihead-attention-kv-cache-onnx.png)

The TRT kernels of a decoder layer whose attention uses a K/V cache look as follows:

![decoder layer with K/V cache](./media/multihead-attention-trt-kernel-kv-cache.png)

Note that all the K/V slice, concat, reshape and transpose ops are fused into one kernel, an even better fusion than in multi-head attention without a K/V cache. Without the cache, the decoder layer needs the 13 kernels shown above.