# Transformer Blocks in Depth

This blog post analyzes transformer blocks in depth, from the PyTorch code down to the kernel-level implementation, and discusses potential optimizations in detail.

```python
import torch
import torch.nn as nn
```

## 1. Attention Layer

### 1.1. Vanilla self-attention layer in PyTorch

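```python
class SelfAtten(nn.Module):
    """Single-head self-attention with a fused QKV projection."""

    def __init__(self, dim=512) -> None:
        super().__init__()
        self.dim = dim
        # One GEMM produces Q, K and V at once: [B, S, D] -> [B, S, 3D]
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, mask):
        # x: [B, S, D]; mask: broadcastable to [B, S, S], 0 = masked out
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # Scaled dot-product scores: [B, S, S]
        score = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        # Fill masked positions with the dtype's most negative value; a fixed
        # -1e3 is not guaranteed to be negligible next to the unmasked scores
        score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)
        attn = score.softmax(dim=-1)
        # Attention-weighted sum of values: [B, S, D]
        out = torch.matmul(attn, v)

        out = self.out_proj(out)
        return out
```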
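As a sanity check, the attention core of `SelfAtten.forward` (everything between the two projections) can be compared against PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`, available since PyTorch 2.0. This is a minimal sketch assuming small random inputs and a causal mask; `vanilla_attention` is a hypothetical helper that mirrors the layer's scaling, masking and softmax, not part of the original code. For a single head, SDPA's default scale of `1/sqrt(D)` matches the layer's `dim ** 0.5` divisor.

```python
import math

import torch
import torch.nn.functional as F


def vanilla_attention(q, k, v, mask):
    """Hypothetical helper: the math of SelfAtten.forward between the projections.

    q, k, v: [B, S, D]; mask: broadcastable to [B, S, S], 0 = masked out.
    """
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    # Masked scores get the dtype's minimum value, as in the layer above
    score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)
    return torch.matmul(score.softmax(dim=-1), v)


torch.manual_seed(0)
B, S, D = 2, 8, 64
q, k, v = (torch.randn(B, S, D) for _ in range(3))
# Assumed causal mask: 1 = attend, 0 = masked; [S, S] broadcasts over the batch
mask = torch.tril(torch.ones(S, S))

ref = vanilla_attention(q, k, v, mask)
# SDPA's boolean attn_mask uses True = "takes part in attention"
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.bool())
print(torch.allclose(ref, fused, atol=1e-5))
```

With a causal mask every row keeps at least its diagonal entry, so filling with the dtype minimum and SDPA's `-inf` give the same probabilities; fully masked rows would behave differently (uniform weights versus NaN).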

#### PyTorch kernels

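The PyTorch implementation of the layer above launches 10 kernels, as shown in the trace below:

- 4 GEMM kernels: the fused QKV projection, `QK^T`, the attention-weighted sum over `V`, and the output projection.
- 1 softmax kernel.
- The remaining 5 kernels are elementwise ops, mostly for the score scaling and the mask computation.

![PyTorch kernel trace of the attention layer](./media/attention-torch-kernel-trace.png)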
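A trace like the one above can be reproduced with `torch.profiler`. This is a minimal sketch assuming PyTorch 2.x and a causal mask; it repeats the layer (with the dtype-minimum mask fill) so the snippet runs on its own. On a CUDA device the CUDA rows of the table correspond to launched kernels; on CPU only the `aten::` operator breakdown is available. Exact kernel names and counts depend on the PyTorch version and the GPU.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile


class SelfAtten(nn.Module):
    # Same layer as above, repeated so this snippet is self-contained
    def __init__(self, dim=512) -> None:
        super().__init__()
        self.dim = dim
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, mask):
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        score = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        # Masked scores are filled with the dtype's minimum value
        score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)
        return self.out_proj(torch.matmul(score.softmax(dim=-1), v))


device = "cuda" if torch.cuda.is_available() else "cpu"
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

model = SelfAtten(dim=512).to(device).eval()
x = torch.randn(2, 128, 512, device=device)
mask = torch.tril(torch.ones(128, 128, device=device))  # assumed causal mask

with torch.no_grad():
    model(x, mask)  # warm-up run so one-time setup is not profiled
    with profile(activities=activities) as prof:
        model(x, mask)

sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=20))

# Names of all recorded ops/kernels, e.g. for counting GEMMs and softmax
op_names = {evt.key for evt in prof.key_averages()}
```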



#### ONNX graph

![ONNX graph of the attention layer](./media/attention.png)

#### TensorRT kernels

After the ONNX model is lowered to TensorRT, some elementwise kernels are fused together and into the softmax kernel, leaving only 7 kernels in total:

![TensorRT kernels of the attention layer](./media/attention-trt-kernel.png)



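Analysis of the fusions:

- The bias add of the first GEMM and the three slices from `tensor.chunk` are fused into a single kernel, `_myl_bb0_3_AddSliSliSli`. The shape computation for the slices is done on the host, so it needs no kernel.
- The transpose in the `QK^T` GEMM is folded into the matmul itself, so no separate transpose kernel is launched.
- The mask and the softmax are fused into one kernel, `_myl_bb0_2_*`.
- The bias add after the output-projection GEMM is not fused. This is a missed optimization: it costs an extra kernel launch and an extra read and write of the output tensor in global memory.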
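Several of these fusions rest on properties that can be checked in eager PyTorch. The sketch below does not reproduce TensorRT's kernels; it only illustrates, under small assumed shapes, why `chunk` and `transpose` need no data movement (they are views with different offsets and strides) and why a bias add can be folded into the GEMM (`torch.addmm` computes `bias + A @ W^T` in one call).

```python
import torch

torch.manual_seed(0)
B, S, D = 2, 16, 64

# 1) chunk() returns views into the fused QKV buffer: each slice is the same
#    storage with a different start offset, so slicing launches no kernel.
qkv = torch.randn(B, S, 3 * D)
q, k, v = qkv.chunk(3, dim=-1)
offsets = [t.storage_offset() for t in (q, k, v)]

# 2) transpose() is also a view (same data pointer, last two strides swapped).
#    A GEMM can read K in this transposed layout, so QK^T needs no transpose kernel.
kt = k.transpose(-2, -1)
transpose_is_view = (
    kt.data_ptr() == k.data_ptr()
    and kt.stride() == (k.stride(0), k.stride(2), k.stride(1))
)

# 3) A bias add right after a GEMM can live in the GEMM's epilogue.
W = torch.randn(D, D)
bias = torch.randn(D)
a = torch.randn(B * S, D)
unfused = a @ W.t() + bias            # GEMM, then a separate elementwise add
fused = torch.addmm(bias, a, W.t())   # bias folded into the GEMM call
bias_fusable = torch.allclose(unfused, fused, atol=1e-4)

print(offsets, transpose_is_view, bias_fusable)
```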




## 2. Multi-head