# Transformer Blocks in Depth

This blog post analyzes the transformer block in depth, from the PyTorch code down to the kernel-level implementation, and discusses the details of potential optimizations.

```python
import torch
import torch.nn as nn
```

## 1. Attention Layer

### 1.1. Vanilla self-attention layer in PyTorch

```python
class SelfAtten(nn.Module):

    def __init__(self, dim=512) -> None:
        super().__init__()
        self.dim = dim
        self.qkv_proj = nn.Linear(dim, dim*3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, mask):
        # x: [B, S, D]
        # mask: [B, S, S] or [S, S]; positions where mask == 0 are not attended
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # score: [B, S, S], scaled by sqrt(D)
        score = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        score = score.masked_fill(mask == 0, -1e3)
        attn = score.softmax(dim=-1)
        # out: [B, S, D]
        out = torch.matmul(attn, v)

        out = self.out_proj(out)
        return out
```
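As a sanity check, the core of the layer (everything between the two projections) can be compared with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`, available in PyTorch 2.0 and later. This is a minimal sketch that assumes a boolean causal mask; SDPA treats `True` as "may attend", which matches the `mask == 0` convention used here. The `-1e3` fill behaves like `-inf` in this setting, because after the softmax's max subtraction `exp(-1000)` underflows to zero in float32.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, D = 2, 16, 64
q, k, v = (torch.randn(B, S, D) for _ in range(3))
# Causal mask of shape [S, S], broadcast over the batch; True = may attend.
mask = torch.tril(torch.ones(S, S, dtype=torch.bool))

def manual_attention(q, k, v, mask):
    # Same math as SelfAtten.forward without the input/output projections.
    score = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    score = score.masked_fill(mask == 0, -1e3)
    return torch.matmul(score.softmax(dim=-1), v)

ref = manual_attention(q, k, v, mask)
# SDPA's default scale is 1 / sqrt(q.size(-1)), the same as above.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(ref, fused, atol=1e-5))
```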

## PyTorch kernels

The PyTorch implementation of the layer above launches 10 kernels:

- 4 GEMM kernels
- 1 softmax kernel
- 5 elementwise kernels, such as those for the scaling and the mask computation

![kernels](./media/attention-torch-kernel-trace.png)
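A kernel list like the trace above can be collected with `torch.profiler`. The sketch below repeats the single-head layer so that it runs on its own; the input shapes are illustrative assumptions. On a CPU-only machine the table lists aten operators rather than GPU kernels; with CUDA available, the GPU kernel names are recorded as well.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

class SelfAtten(nn.Module):
    # Same computation as the single-head layer above, repeated for self-containment.
    def __init__(self, dim=512):
        super().__init__()
        self.dim = dim
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, mask):
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        score = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        score = score.masked_fill(mask == 0, -1e3)
        return self.out_proj(torch.matmul(score.softmax(dim=-1), v))

device = "cuda" if torch.cuda.is_available() else "cpu"
layer = SelfAtten().to(device).eval()
x = torch.rand(2, 32, 512, device=device)
mask = torch.tril(torch.ones(32, 32, device=device))

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)  # also record the GPU kernels

with torch.no_grad(), profile(activities=activities) as prof:
    layer(x, mask)

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
op_names = {evt.key for evt in prof.key_averages()}
```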



## ONNX graph

![onnx graph of attention layer](./media/attention.png)

## TensorRT kernels

After the ONNX graph is lowered to TensorRT (TRT), some elementwise kernels are fused with the softmax kernel, leaving only 7 kernels in total.

![kernels](./media/attention-trt-kernel.png)



Analysis of the fusions:

- The bias add of the first GEMM and the 3 slices (`tensor.chunk`) are fused into a single kernel, `_myl_bb0_3_AddSliSliSli`. The shape computation is done on the host, so no kernel is needed for it.
- The transpose in the QK^T GEMM is fused into the matmul.
- The mask and the softmax are fused into `_myl_bb0_2_*`.
- The bias add after the output-projection GEMM is not fused, which is not good: it costs an extra kernel launch and an extra read and write of the output tensor.
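Two facts behind this analysis can be checked in plain PyTorch. In this minimal sketch (shapes are illustrative assumptions), `chunk` only creates strided views of the QKV output, so slicing needs no data movement; and a bias add can be folded into the GEMM call with `torch.addmm`, the kind of bias fusion the output projection misses in the TRT trace. Whether `addmm` runs as a single kernel depends on the backend, so this only shows the API-level fusion.

```python
import torch

torch.manual_seed(0)
B, S, D = 2, 16, 512

# 1) chunk() only produces strided views into the same buffer; nothing is copied.
qkv = torch.randn(B, S, 3 * D)
q, k, v = qkv.chunk(3, dim=-1)
print(q.stride())                   # (24576, 1536, 1)
print(k.data_ptr() - q.data_ptr())  # 2048 bytes = 512 float32 elements

# 2) Bias add as a separate elementwise op vs. folded into the GEMM call.
x2d = torch.randn(B * S, D)
w = torch.randn(D, D)
b = torch.randn(D)
separate = x2d @ w.t() + b          # GEMM, then a separate add
fused = torch.addmm(b, x2d, w.t())  # computes b + x2d @ w.t() in one call
print(torch.allclose(separate, fused, atol=1e-4))
```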




## 2. Multi-Head Attention

```python
import torch
import torch.nn as nn

class MultiHeadSelfAtten(nn.Module):

    def __init__(self, dim=512, head=8) -> None:
        super().__init__()
        self.dim = dim
        self.head = head
        self.head_dim = dim // head
        self.qkv_proj = nn.Linear(dim, dim*3)
        self.out_proj = nn.Linear(dim, dim)

    # [B, S, D] -> [B, H, S, D/H]
    def reshape(self, x):
        return x.reshape(x.shape[0], -1, self.head, self.head_dim).transpose(1, 2)

    def self_attn(self, q, k, v, mask):
        # q, k, v: [B, H, S, D/H]
        # mask: broadcastable to [B, H, S, S], e.g. [B, 1, S, S] or [S, S]
        # out: [B, H, S, D/H]
        # scale by sqrt of the per-head dim (D/H), not the full model dim
        score = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        score = score.masked_fill(mask == 0, -1e3)
        attn = score.softmax(dim=-1)
        out = torch.matmul(attn, v)
        return out

    def forward(self, x, mask):
        # x: [B, S, D]
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        batch = q.shape[0]
        q, k, v = self.reshape(q), self.reshape(k), self.reshape(v)
        out = self.self_attn(q, k, v, mask)
        # merge heads: [B, H, S, D/H] -> [B, S, D]
        out = out.transpose(1, 2).reshape(batch, -1, self.dim)

        out = self.out_proj(out)
        return out
```

## ONNX graph of multi-head attention
![onnx graph of multi head attention](./media/multihead-attention-onnx.png)

Compared with single-head self-attention, multi-head attention adds a reshape and a transpose before the attention computation. These move the head dimension next to the batch dimension, so each head is computed as an independent batch of attention. Afterwards, a transpose and a reshape merge the heads back into `[B, S, D]`. As a result, the ONNX graph is more complicated.
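The head split and merge can be traced in isolation. In this minimal sketch (sizes are illustrative assumptions), splitting the heads of a contiguous `[B, S, D]` tensor is view-only, but the attention output is a fresh contiguous `[B, H, S, D/H]` tensor, so merging it back requires a real copy, which is why a transpose appears on the way out.

```python
import torch

B, S, D, H = 2, 16, 512, 8
x = torch.randn(B, S, D)

# Split heads: [B, S, D] -> [B, S, H, D/H] -> [B, H, S, D/H]; both steps are views.
heads = x.reshape(B, S, H, D // H).transpose(1, 2)
print(heads.shape, heads.is_contiguous())  # torch.Size([2, 8, 16, 64]) False

# Merge heads: [B, H, S, D/H] -> [B, S, H, D/H] -> [B, S, D] restores the input.
merged = heads.transpose(1, 2).reshape(B, S, D)
print(torch.equal(merged, x))  # True

# Simulate the attention output: a new contiguous [B, H, S, D/H] buffer.
attn_out = heads.contiguous()
# After the transpose the memory is not in [B, S, D] order, so reshape copies.
merged_out = attn_out.transpose(1, 2).reshape(B, S, D)
print(merged_out.data_ptr() == attn_out.data_ptr())  # False: data was copied
```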

The following 9 TRT kernels are launched for multi-head attention:

![multi-head attention kernels](./media/multihead-attention-trt-kernel.png)

## 3. Layer Norm

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(512)
x = norm(torch.rand(1, 32, 512))
```

In PyTorch, layer norm runs as a single kernel. The ONNX graph looks like this:
![onnx graph of layer norm](./media/layer-norm-onnx.png)
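ONNX opsets before 17, which introduced a `LayerNormalization` operator, export layer norm as a chain of reduce and elementwise nodes; whether a given graph looks like that depends on the opset used for export. The sketch below spells out that decomposition in PyTorch and checks it against `nn.LayerNorm`. The op names in the comments are a rough mapping to the ONNX nodes, not an exact export.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(1, 32, 512)
norm = nn.LayerNorm(512)  # defaults: eps=1e-5, weight=1, bias=0

with torch.no_grad():
    # Layer norm over the last dim, written as reduce/elementwise steps.
    mean = x.mean(dim=-1, keepdim=True)               # ReduceMean
    centered = x - mean                               # Sub
    var = (centered ** 2).mean(dim=-1, keepdim=True)  # Pow + ReduceMean (biased variance)
    normed = centered / torch.sqrt(var + norm.eps)    # Add + Sqrt + Div
    manual = normed * norm.weight + norm.bias         # Mul + Add (affine)
    print(torch.allclose(manual, norm(x), atol=1e-5))
```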


## 4. Whole Transformer Decoder Block

![trt kernels of transformer decoder block](./media/decoder-attention-trt-kernel.png)

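The whole decoder block needs 13 kernels in TRT, 4 more than the multi-head attention alone:

- 2 LayerNorm+Residual kernels. LayerNorm and the residual add are fused together, which is good.
- 1 GELU kernel
- 2 MLP GEMMs

These items add up to 5 kernels rather than 4, so presumably one of the standalone multi-head attention kernels is fused away in the full block.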
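For reference, a minimal sketch of a decoder block that produces the kernels listed above. It assumes a post-norm layout (residual add followed by LayerNorm, matching the LayerNorm+Residual fusion), self-attention only with no cross-attention, and a 4x MLP expansion; `DecoderBlock` and `mlp_ratio` are hypothetical names, and `nn.MultiheadAttention` stands in for the `MultiHeadSelfAtten` class above.

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    # Hypothetical post-norm block: x = LN(x + Attn(x)); x = LN(x + MLP(x)).
    def __init__(self, dim=512, head=8, mlp_ratio=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, head, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),  # MLP GEMM 1
            nn.GELU(),                        # GELU
            nn.Linear(dim * mlp_ratio, dim),  # MLP GEMM 2
        )

    def forward(self, x, causal_mask):
        attn_out, _ = self.attn(x, x, x, attn_mask=causal_mask, need_weights=False)
        x = self.norm1(x + attn_out)     # residual + LayerNorm
        x = self.norm2(x + self.mlp(x))  # residual + LayerNorm
        return x

S = 32
block = DecoderBlock().eval()
x = torch.rand(1, S, 512)
# nn.MultiheadAttention's boolean mask marks positions to *block* with True.
causal_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
with torch.no_grad():
    y = block(x, causal_mask)
print(y.shape)  # torch.Size([1, 32, 512])
```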