# Tile Language in BitBLAS

A more flexible and more efficient tile programming language compared with Triton.

## Features

- **Simplified Syntax**: Write GPU kernels with a more straightforward and expressive syntax.
- **High Performance**: Achieve performance comparable to manually optimized implementations.
- **Advanced Operations**: Support for complex operations like convolutions, flash-attention, and normalizations.
- **Compatibility**: Works with modern CUDA architectures.

## OP Examples

- [Matrix Multiplication](#quick-start)
- [Flash Attention](#flash-attention)
- [Dequantization GEMM](#dequantization-gemm)
- [RetNet](#retina-net)
- [MAMBA](#mamba)



In [2]:
# Import Tile Language from bitblas
from bitblas import tvm as tvm
from tvm import tl
import tvm.tl.language as T

## Get Started with a GEMM Example

In [8]:
# Problem size for the GEMM example: C (M x N) is computed from A (M x K) and
# B stored as (N, K).
M, N, K = 256, 256, 256

A_shape = (M, K)
B_shape = (N, K)
C_shape = (M, N)

# Input, output and accumulator all use half precision in this example.
in_dtype = "float16"
out_dtype = in_dtype
accum_dtype = in_dtype

# Tiling and launch configuration.
block_M, block_N, block_K = 128, 128, 32
threads = 128
num_stages = 2

# Shapes of the per-K-step staging tiles for A and B.
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
# Tiled GEMM kernel: C = A @ B^T, with A of shape A_shape (M, K), B of shape
# B_shape (N, K) and C of shape (M, N).
# Each kernel instance computes one (block_M, block_N) tile of C, stepping over K
# in block_K-wide chunks inside a T.Pipelined loop with `num_stages` stages.
# (Only `#` comments are used here: the body is parsed by TVMScript.)
@T.prim_func
def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(
    (M, N), out_dtype)):
    # Launch grid: bx walks output tiles along N, by walks output tiles along M.
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        # Shared-memory staging tiles for the current K-chunk of A and B.
        A_shared = T.alloc_shared(A_shared_shape, in_dtype)
        B_shared = T.alloc_shared(B_shared_shape, in_dtype)
        # Accumulator fragment for this output tile, zeroed before the K loop.
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
            # Load the k-th (block_M, block_K) tile of A and (block_N, block_K) tile of B.
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[bx * block_N, k * block_K], B_shared)
            # C_local += A_shared @ B_shared^T (B is stored as (N, K), hence transpose_B).
            T.gemm(A_shared, B_shared, C_local, transpose_A=False, transpose_B=True)
        # Write the finished tile back to its position in C.
        T.copy(C_local, C[by * block_M, bx * block_N])

# Alias consumed by tl.lower(func).
func = main

In [13]:
# Lower the prim_func; the result unpacks into a runtime module and its params.
# Then print the source of the first imported module (the generated CUDA kernel
# shown in the cell output).
rt_mod, params = tl.lower(func)
print(rt_mod.imported_modules[0].get_source())

#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>

extern "C" __global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  half_t C_local[128];
  #pragma unroll
  for (int i = 0; i < 64; ++i) {
    *(uint1*)(C_local + (i * 2)) = make_uint1(__pack_half2(half_t(0.000000e+00f), half_t(0.000000e+00f)));
  }
  #pragma unroll
  for (int i_1 = 0; i_1 < 4; ++i_1) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((int)blockIdx.y) * 32768) + (i_1 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + ((((int)threadIdx.x) & 3) 

In [15]:
# Wrap the lowered module in a tl.Profiler, which is used for the correctness check.
# NOTE(review): [2] presumably marks argument index 2 (C) as the kernel output, and
# TensorSupplyType.Integer presumably fills the inputs with integer-valued data.
# Confirm both against the tl.Profiler API.
mod = tl.Profiler(rt_mod, params, [2], tl.TensorSupplyType.Integer)

def ref_program(A, B):
    """Return the reference GEMM result ``A @ B.T`` cast to ``out_dtype``.

    The product is computed in float32 so the reference is at least as accurate
    as the kernel under test. The result is then cast to the module-level
    ``out_dtype`` string (e.g. ``"float16"``).

    Args:
        A: Tensor of shape (M, K).
        B: Tensor of shape (N, K). It is transposed here to mirror the kernel's
            ``transpose_B=True`` GEMM.

    Returns:
        Tensor of shape (M, N) with dtype ``getattr(torch, out_dtype)``.
    """
    import torch

    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    # getattr is the idiomatic dynamic attribute lookup; the original called
    # torch.__getattribute__ directly.
    return C.to(getattr(torch, out_dtype))

# Compare the kernel output against ref_program with atol = rtol = 1e-2.
# NOTE(review): assert_allclose presumably raises on mismatch, so reaching the
# print means the check passed. Confirm against tl.Profiler.
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print("Assert Pass")


Assert Pass


## Manipulate Data Layout and Pipeline

TL also provides interfaces that let users manipulate the memory layout and the pipeline, and enable rasterization for better L2 cache locality. Here is an example of how to use the memory layout and rasterization:

In [1]:
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float",
           num_stages=3, threads=128, enable_rasterization=True):
    """Build a tiled GEMM prim_func computing ``C = A @ B``.

    The kernel demonstrates three TL features: swizzled shared-memory layouts,
    block rasterization (``T.use_swizzle``), and a ``T.Parallel`` copy as an
    alternative to ``T.copy``.

    Args:
        M, N, K: Problem sizes. A is (M, K), B is (K, N), C is (M, N).
        block_M, block_N, block_K: Tile sizes of the output tile and the K-step.
        dtype: Element dtype of A, B and C.
        accum_dtype: Dtype of the accumulator fragment.
        num_stages: Number of stages for ``T.Pipelined`` (previously fixed at 3).
        threads: Threads per kernel instance (previously fixed at 128).
        enable_rasterization: Forwarded to ``T.use_swizzle(enable=...)``.
            The original read an undefined global of this name.

    Returns:
        The ``T.prim_func`` kernel.
    """
    @T.prim_func
    def main(
        A: T.Buffer((M, K), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Apply memory layout optimizations (swizzled shared-memory layouts).
            # Or you can define your own memory layout.
            T.annotate_layout({
                A_shared: tl.layout.make_swizzled_layout(A_shared),
                B_shared: tl.layout.make_swizzled_layout(B_shared),
            })

            # Enable rasterization for better L2 Cache Locality
            T.use_swizzle(panel_size=10, enable=enable_rasterization)

            # Clear the local buffer
            T.clear(C_local)

            # Auto pipeline the computation. The K-step index is `ko` so it is not
            # shadowed by the element index `k` of the Parallel loop below.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Instead of using
                # T.copy(B[ko * block_K, bx * block_N], B_shared)
                # we can also use Parallel to auto map the thread
                # bindings and vectorize the copy operation.
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


## Implement Dequantized GEMM with Simple Syntax

In [2]:
# Tile-level dequantize-GEMM: Ct = dequant(B) @ A^T, producing the transposed
# output Ct of shape (N, M).
# B is packed along K: each storage element holds `num_elems_per_byte` logical
# values, so its K offsets are divided by that factor.
# Globals such as storage_dtype, num_bits, num_elems_per_byte,
# B_dequantize_shared_shape and the *_shape tuples must be defined in an
# earlier cell.
# (Only `#` comments are used here: the body is parsed by TVMScript.)
@T.prim_func
def dequant_matmul(
    A: T.Buffer(A_shape, in_dtype),
    B: T.Buffer(B_shape, storage_dtype),
    Ct: T.Buffer((N, M), out_dtype),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        A_shared = T.alloc_shared(A_shared_shape, in_dtype)
        B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
        B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
        B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
        Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)

        T.clear(Ct_local)
        for k in T.Pipelined(
            T.ceildiv(K, block_K),
            num_stages=num_stages
        ):
            T.copy(A[by * block_M, k * block_K], A_shared)
            # Packed B: the K offset is scaled down by the packing factor.
            T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
            T.copy(B_shared, B_local)
            # Unpack: logical column j lives in storage column
            # j // num_elems_per_byte at sub-position j % num_elems_per_byte. This
            # uses the same packing factor as the B copy above; a fixed 2 would be
            # wrong for any other bit width.
            # NOTE(review): the ("int", 8) converter assumes storage_dtype is int8.
            # Confirm this.
            for i, j in T.Parallel(block_N, block_K):
                B_dequantize_local[i, j] = _tir_packed_to_unsigned_convert("int", 8)(
                    num_bits,
                    B_local[i, j // num_elems_per_byte],
                    j % num_elems_per_byte,
                    dtype=in_dtype,
                )
            # Ct_local += B_dequantize_local @ A_shared^T
            T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
        T.copy(Ct_local, Ct[bx * block_N, by * block_M])

NameError: name 'T' is not defined

## If you want fine-grained control over dequantization at the thread level

In [None]:
# Dequantize-GEMM with explicit per-thread control. For each K-step, every thread:
#   1. gathers a chunk of packed B values from shared memory into B_local,
#   2. unpacks them one at a time into B_dequantize_local,
#   3. scatters the unpacked values into the shared tile B_dequantize_shared,
# which T.gemm then consumes to compute C += A_tile @ B_dequantized_tile^T.
# Globals such as local_size, local_size_compressed, storage_type, storage_nbit,
# num_bits and num_elems_per_byte must be defined in an earlier cell.
# NOTE(review): the index math presumably assumes
# local_size == local_size_compressed * num_elems_per_byte, and that the tile
# sizes divide evenly by threads * local_size_compressed. Confirm where these
# are defined.
# (Only `#` comments are used here: the body is parsed by TVMScript.)
@T.prim_func
def main(
        A: T.Buffer(A_shape, in_dtype),
        B: T.Buffer(B_shape, storage_dtype),
        C: T.Buffer((M, N), out_dtype),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        A_shared = T.alloc_shared(A_shared_shape, in_dtype)
        # Packed B tile: its columns are block_K // num_elems_per_byte wide (see vj below).
        B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
        # Per-thread chunk of packed values, and the same chunk after unpacking.
        B_local = T.alloc_local([local_size_compressed], storage_dtype)
        B_dequantize_local = T.alloc_local([local_size], in_dtype)
        # Dequantized B tile assembled cooperatively by all threads.
        B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

        # Thread index, used to select each thread's chunk in the loops below.
        tx = T.thread_binding(0, threads, thread="threadIdx.x")

        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)

            # Each i step lets all threads together cover
            # threads * local_size_compressed packed elements of the tile.
            for i in T.serial(block_N * block_K // num_elems_per_byte //
                              (threads * local_size_compressed)):
                # Gather: map the flat packed index back to (row, packed column).
                for v in T.vectorized(0, local_size_compressed):
                    index = i * threads * local_size_compressed + tx * local_size_compressed + v
                    vi = index // (block_K // num_elems_per_byte)
                    vj = index % (block_K // num_elems_per_byte)
                    B_local[v] = B_shared[vi, vj]
                # Unpack: element v comes from packed slot v // num_elems_per_byte,
                # sub-position v % num_elems_per_byte.
                for v in T.serial(0, local_size):
                    B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
                        storage_type, storage_nbit)(
                            num_bits,
                            B_local[v // num_elems_per_byte],
                            v % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                # Scatter: flat unpacked index -> (row, column) of the block_K-wide tile.
                for v in T.vectorized(0, local_size):
                    index = i * threads * local_size + tx * local_size + v
                    vi = index // block_K
                    vj = index % block_K
                    B_dequantize_shared[vi, vj] = B_dequantize_local[v]

            # C_local += A_shared @ B_dequantize_shared^T
            T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)

        T.copy(C_local, C[by * block_M, bx * block_N])

## Flash Attention V3

In [None]:
# Attention forward kernel using an online (streaming) softmax.
# One kernel instance handles one query tile (bx) for one head (by) of one batch
# entry (bz). Q/K/V/Output are indexed as [batch, seq, head, dim], as the T.copy
# slices below show.
# NOTE(review): softmax uses exp2, so `scale` is presumably
# 1/sqrt(dim) * log2(e). Confirm where it is defined.
# (Only `#` comments are used here: the body is parsed by TVMScript.)
@T.prim_func
def flash_attention_v3(
    Q: T.Buffer(shape, dtype),
    K: T.Buffer(shape, dtype),
    V: T.Buffer(shape, dtype),
    Output: T.Buffer(shape, dtype),
):
    with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
        Q_shared = T.alloc_shared([block_M, dim], dtype)
        K_shared = T.alloc_shared([block_N, dim], dtype)
        V_shared = T.alloc_shared([block_N, dim], dtype)
        acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
        acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
        acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
        scores_max = T.alloc_fragment([block_M], accum_dtype)
        scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
        scores_scale = T.alloc_fragment([block_M], accum_dtype)
        scores_sum = T.alloc_fragment([block_M], accum_dtype)
        logsum = T.alloc_fragment([block_M], accum_dtype)

        T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
        T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
        T.fill(acc_o, 0)
        T.fill(logsum, 0)
        T.fill(scores_max, -T.infinity(accum_dtype))
        # Causal: only key tiles up to the one containing the diagonal contribute.
        loop_range = (
            T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)
        )
        for k in T.Pipelined(loop_range, num_stages=num_stages):
            T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
            if is_casual:
                # Pre-seed masked (future) positions with -inf; the GEMM accumulates on top.
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(
                        bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
                    )
            else:
                T.clear(acc_s)
            # acc_s += Q_tile @ K_tile^T
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
            T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
            # acc_s is (block_M, block_N), so iterate over block_N, not dim.
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] *= scale
            # Running row max = max(tile max, previous max). Without the merge, a row
            # that is fully masked in this tile (causal, block_N < block_M) gets
            # max = -inf. The rescale then becomes exp2(prev + inf) = inf and
            # exp2(-inf + inf) = NaN.
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            for i in T.Parallel(block_M):
                scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
            # Factor that rescales everything accumulated under the previous max.
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
            T.copy(acc_s, acc_s_cast)
            # acc_o += P_tile @ V_tile, with P cast to the input dtype.
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        # Final normalization by the accumulated softmax denominator.
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] /= logsum[i]
        T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])