## GPUs and challenges in programming them. Meet Triton

<img src="https://openaicom.imgix.net/778bccdf-6cb5-4d9f-8811-f8df6da52d84/gpu-architecture.svg?fm=auto&auto=compress,format&fit=min&w=1919&h=612" width="800" height="500">
Source: https://openaicom.imgix.net/778bccdf-6cb5-4d9f-8811-f8df6da52d84/gpu-architecture.svg?fm=auto&auto=compress,format&fit=min&w=1919&h=612

The architecture of modern GPUs can be roughly divided into three major components—DRAM, SRAM and ALUs—each of which must be considered when optimizing CUDA code:

- Memory transfers from DRAM must be coalesced into large transactions to leverage the large bus width of modern memory interfaces.
- Data must be manually stashed to SRAM prior to being re-used, and managed so as to minimize shared memory bank conflicts upon retrieval.
- Computations must be partitioned and scheduled carefully, both across and within Streaming Multiprocessors (SMs), so as to promote instruction/thread-level parallelism and leverage special-purpose ALUs (e.g., tensor cores).

Triton tries to abstract most of these challenges away. For example:

- The only way to interact with DRAM is through `tl.load`/`tl.store` operations. One still has to be careful and cognizant of memory layouts; more on that later.
- The result of `tl.load` is managed automatically: the compiler decides how to stage it in on-chip memory (SRAM and registers).
- All block-level computations then operate on this on-chip data.
- Scheduling within each SM is done automatically. This matters because SM resources (shared memory, registers, number of blocks to be processed) are limited, and in hand-written CUDA it is very easy to make a mistake that results in hardware underutilization.

Fundamentally, CUDA C kernels are written from the perspective of a single thread, while Triton kernels are written from the perspective of a whole block of data (per-block parallelism).
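To make the per-block model concrete, here is a CPU-only sketch of how one might picture a Triton launch of a vector addition: every "program" owns a whole block of offsets and uses a mask to guard the ragged tail. The function `emulate_block_add` and its Python loop over program ids are hypothetical illustrations written with plain PyTorch tensors, not Triton's actual execution model (on a GPU the programs run in parallel).

```python
import torch

def emulate_block_add(x: torch.Tensor, y: torch.Tensor, BLOCK: int = 4) -> torch.Tensor:
    """Hypothetical CPU emulation of a per-block (Triton-style) vector add."""
    numel = x.numel()
    out = torch.empty_like(x)
    num_programs = (numel + BLOCK - 1) // BLOCK  # same as triton.cdiv(numel, BLOCK)
    for pid in range(num_programs):  # on a GPU these programs would run in parallel
        # One program owns a whole block of offsets, not a single element.
        offsets = pid * BLOCK + torch.arange(BLOCK)
        mask = offsets < numel  # guard the last, partially filled block
        valid = offsets[mask]
        out[valid] = x[valid] + y[valid]
    return out

x = torch.arange(10, dtype=torch.float32)
y = torch.ones(10)
print(emulate_block_add(x, y))  # equals x + y
```

A CUDA C kernel for the same task would instead be written for one element at a time, with each thread computing its own single index.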

## Import and fix seeds

In [1]:
import torch
import triton
import triton.language as tl

TORCH_SEED = 17
TRITON_SEED = 13

torch.manual_seed(TORCH_SEED)

<torch._C.Generator at 0x7fabd8f4ab70>

## Example 1: unary element-wise operations.

Unary element-wise operations are among the easiest to parallelize, which makes them a good starting point for learning Triton.
Let's implement a kernel that fuses ReLU (Rectified Linear Unit) with dropout.

Recall that

\begin{align}
ReLU(x) & = \max(x, 0), \\
dropout(x, p) & = \begin{cases}
 0 \text{ w. p. } p, \\
 \frac{x}{1-p} \text{ w. p. } 1 - p,
\end{cases}
\text{, so that } \mathbb{E}[dropout(x, p)] = x.
\end{align}
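As a sanity check of these formulas, here is a plain PyTorch (CPU) sketch of the fused operation. Keep decisions come from uniform random numbers in $[0, 1)$, with `probs > p` meaning "keep", as in the kernel below. The name `reference_relu_dropout` is our own; averaging many independent samples should approximately recover $ReLU(x)$, since $\mathbb{E}[dropout(x, p)] = x$.

```python
import torch

def reference_relu_dropout(x: torch.Tensor, p: float, generator=None) -> torch.Tensor:
    """Hypothetical reference: ReLU followed by (inverted) dropout."""
    probs = torch.rand(x.shape, generator=generator)  # uniform in [0, 1)
    keep = (x > 0) & (probs > p)  # zero out negatives and dropped elements
    return torch.where(keep, x / (1.0 - p), torch.zeros_like(x))

g = torch.Generator().manual_seed(0)
x = torch.tensor([-1.0, 0.5, 2.0])
# Average many independent samples: E[dropout(ReLU(x), p)] == ReLU(x).
samples = torch.stack([reference_relu_dropout(x, p=0.3, generator=g) for _ in range(20000)])
print(samples.mean(dim=0))  # close to ReLU(x) = [0.0, 0.5, 2.0]
```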

In [2]:
@triton.jit
def _triton_relu_dropout_kernel(in_ptr, out_ptr, numel, p, seed, BLOCK: tl.constexpr):
    # Get block id.
    block_id = tl.program_id(axis=0)
    # Pointer offsets per element in a block.
    offset = block_id * BLOCK + tl.arange(0, BLOCK)
    # Guard against pointer access past numel.
    mask = offset < numel
    # Load data. NOTE: in_ptr + offset is a block of pointers!
    res = tl.load(in_ptr + offset, mask=mask)
    # Sample uniform random numbers in [0, 1), one per element.
    probs = tl.rand(seed, offset)
    # Apply relu and dropout; `&` is the element-wise logical AND on blocks.
    res = tl.where((res > 0) & (probs > p), res / (1. - p), 0.)
    # Write to DRAM, again masking out-of-range offsets.
    tl.store(out_ptr + offset, res, mask=mask)

# Public API
def triton_relu_dropout(x, p=0.5, seed=TRITON_SEED, BLOCK=1024):
    # ND array -> 1D array
    x_vector = x.reshape(-1)
    out_vector = torch.empty_like(x_vector)
    grid = (triton.cdiv(x_vector.numel(), BLOCK),)
    _triton_relu_dropout_kernel[grid](x_vector, out_vector, x_vector.numel(), p, seed, BLOCK=BLOCK)

    return out_vector.reshape(x.shape)

# PyTorch eager counterpart
def pytorch_relu_dropout(x, p=0.5):
    import torch.nn.functional as F
    return F.dropout(F.relu(x), p=p, inplace=True)

Some basic benchmarks:

In [3]:
x = torch.rand(1000, 1000, device='cuda')  # float32 by default
%timeit triton_relu_dropout(x); torch.cuda.synchronize()  # NEVER FORGET to sync devices!!!
%timeit pytorch_relu_dropout(x); torch.cuda.synchronize()

x = x.to(torch.half)
%timeit triton_relu_dropout(x); torch.cuda.synchronize()
%timeit pytorch_relu_dropout(x); torch.cuda.synchronize()

# bfloat16 requires GPUs with compute capability >= 8.0
# x = x.to(torch.bfloat16)
# %timeit triton_relu_dropout(x); torch.cuda.synchronize()
# %timeit pytorch_relu_dropout(x); torch.cuda.synchronize()

47.5 µs ± 292 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
119 µs ± 54.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
36.6 µs ± 66.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
55.3 µs ± 50.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


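##### Not bad! Now we know how to fuse several element-wise kernels into one!
The fused kernel is still memory bound (it does very little arithmetic per byte loaded), but it reads from and writes to DRAM once instead of once per operation.
Alternatively, one could use `torch.jit` or the newest feature of PyTorch 2, `torch.compile`, which itself generates Triton kernels for GPUs.
We will not discuss these much here, however.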
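A rough way to see where the speed-up of fusion comes from is to count DRAM bytes. The sketch below assumes, as a simplification, that each unfused element-wise kernel reads its input once and writes its output once, and it ignores anything else PyTorch may do internally (for example, materializing a dropout mask). The helper `elementwise_bytes` is hypothetical.

```python
def elementwise_bytes(numel: int, bytes_per_elem: int, num_ops: int) -> int:
    """Bytes moved by `num_ops` chained element-wise kernels (1 read + 1 write each)."""
    return num_ops * 2 * numel * bytes_per_elem

numel = 1000 * 1000  # the size of the benchmark tensor above
for name, nbytes in (("float32", 4), ("float16", 2)):
    unfused = elementwise_bytes(numel, nbytes, num_ops=2)  # relu kernel, then dropout kernel
    fused = elementwise_bytes(numel, nbytes, num_ops=1)    # one fused kernel
    print(f"{name}: unfused {unfused / 1e6:.0f} MB, fused {fused / 1e6:.0f} MB")
```

Under this simple model fusion halves the DRAM traffic; the measured float32 gap above is larger, so other overheads presumably play a role as well.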

## Example 2: operations with reductions.

Implementing efficient reductions in CUDA C is far from trivial and requires expert-level skills.
But first, some ...

### Preliminaries.
#### Addresses of elements in an ND array.
Since Triton uses pointer arithmetic, we need to know how to offset pointers to the right positions in an ND array.

If $A$ is an array of dimension $d$, then the address of its element $A[i] = A[i_0, ..., i_{d-1}]$, denoted id($A[i]$), is
$\mathbf{id(A) + i_0 \cdot A.stride(0) + ... + i_{d-1} \cdot A.stride(d-1)}$, where addresses and strides are counted in elements rather than bytes.

For example:

In [4]:
t = torch.arange(3 * 4 * 5)
t3 = t.reshape(3, 4, 5)
assert t3[2, 3, 1] == t[2 * t3.stride(0) + 3 * t3.stride(1) + 1 * t3.stride(2)]
assert t3[1, 2, 3] == t[1 * t3.stride(0) + 2 * t3.stride(1) + 3 * t3.stride(2)]
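The same formula also describes views that share storage with `t`, which is exactly what makes strides useful for pointer arithmetic. In the sketch below, `storage_offset()` plays the role of id($A$) measured from the start of `t`, and `flat_offset` is a hypothetical helper.

```python
import itertools
import torch

def flat_offset(index, strides):
    """Offset (in elements, not bytes) of A[index] relative to id(A)."""
    return sum(i * s for i, s in zip(index, strides))

t = torch.arange(3 * 4 * 5)
t3 = t.reshape(3, 4, 5)
views = (t3, t3.transpose(0, 2), t3[:, ::2, 1:])  # contiguous, permuted, sliced
for v in views:
    for index in itertools.product(*(range(s) for s in v.shape)):
        # Offsets are relative to the view's own first element (its storage offset).
        assert v[index] == t[v.storage_offset() + flat_offset(index, v.stride())]
```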

#### Memory contiguity.
Moving data between DRAM and the chip is fast when the chunk of memory being read or written is contiguous, because the hardware can then serve it with so-called DRAM bursts.

SRAM, however, does not have such a restriction and can be accessed in a non-contiguous manner.

We should be cognizant of this because, even though Triton handles most cases of data movement efficiently, it does not handle them all. You will see examples of this below.

We say that an ND array $A$ is $\mathbf{contiguous}$ if for any indices
$i = (i_0, ..., i_{d-1})$ and $j = (j_0, ..., j_{d-1})$ such that $i < j$ lexicographically we have that

\begin{align}
\text{id}(A) &\le \text{id}(A[i]) < \text{id}(A[j]), \text{ and} \\
\text{id}(A[j]) & < \text{id}(A) + A.\text{size (NumPy), or } \text{id}(A) + A.\text{numel() (PyTorch)}.
\end{align}

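Equivalently, using strides (and assuming $A$ has no dimensions of size 1, whose strides never affect any address), $A$ is $\mathbf{contiguous}$ if
\begin{align}
A.stride(-1) &= 1, \\
A.stride(-i-1) &= A.shape[-i] \cdot A.stride(-i), \text{ for } i \in \{1, ..., A.dim() - 1\}.
\end{align}
$\mathbf{NOTE:}$ strides are assumed to be non-negative, which is always the case in PyTorch.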

In [5]:
t = torch.rand(2, 3, 4, 5)
assert t.is_contiguous()

def check_contiguous(t):
    isc = (t.stride(-1) == 1)
    for i in range(1, t.dim()):
        isc = (t.stride(-i-1) == t.shape[-i] * t.stride(-i)) and isc
    return isc

assert check_contiguous(t) == True
assert check_contiguous(t.transpose(-1, -2)) == False
assert check_contiguous(t.transpose(0, -1)) == False

x = torch.rand(3, 4)
print(x.stride())  # x is row-major
y = x.mT.contiguous().mT
print(y.stride())  # y is column-major

(4, 1)
(1, 3)


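$A$ is $\mathbf{Triton-contiguous}$ if its elements fill the memory range $[\text{id}(A), \text{id}(A) + A.\text{size})$ exactly once, in any order. That is, we only require $\text{id}(A) \le \text{id}(A[i]) < \text{id}(A) + A.\text{size}$ for all valid indices $i$ (with distinct indices at distinct addresses), but not the lexicographic ordering. For example, the column-major `y` above is Triton-contiguous but not contiguous.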
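Both notions can be checked directly from strides by enumerating element offsets. This is a sketch of our own, with hypothetical helper names, meant only for small tensors since it visits every element; it reads "Triton-contiguous" as "the elements cover the range exactly once, in any order".

```python
import itertools
import torch

def element_offsets(t: torch.Tensor):
    """Offsets (in elements) of every t[i] relative to id(t), in lexicographic order of i."""
    return [sum(i * s for i, s in zip(idx, t.stride()))
            for idx in itertools.product(*(range(n) for n in t.shape))]

def is_contiguous_by_addresses(t: torch.Tensor) -> bool:
    # Lexicographically increasing indices must map to 0, 1, ..., numel - 1.
    return element_offsets(t) == list(range(t.numel()))

def is_triton_contiguous(t: torch.Tensor) -> bool:
    # Elements must fill [id(t), id(t) + numel) exactly once, in any order.
    return sorted(element_offsets(t)) == list(range(t.numel()))

x = torch.rand(3, 4)
print(is_contiguous_by_addresses(x), is_triton_contiguous(x))        # True True
print(is_contiguous_by_addresses(x.mT), is_triton_contiguous(x.mT))  # False True
print(is_triton_contiguous(x[:, :2]))                                # False (gaps)
print(is_triton_contiguous(torch.rand(3, 1).expand(3, 4)))           # False (overlaps)
```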

As long as the DRAM block that we read from/write to is $\mathbf{Triton-contiguous}$, all `tl.load/tl.store` operations with it are going to be efficient as well!

In most practical cases (see below), it is also fine if a single dimension has stride 1, provided that it is large enough compared to the other dimensions. The Triton compiler is smart and will organize contiguous DRAM-SRAM transfers whenever the strides permit it! For example, when working with square 2D blocks, it does not matter whether they are row-major or column-major.

### Numerically stable softmax

Recall that a softmax is a map
\begin{align}
\text{softmax}: x=(x_1, ..., x_n) \mapsto 
    \bigg( \frac{\exp(x_1)}{\sum_{i=1}^n \exp(x_i)}, ...,
           \frac{\exp(x_n)}{\sum_{i=1}^n \exp(x_i)} \bigg).
\end{align}

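We can see that softmax($x) = $ softmax($x + \alpha$) for any scalar $\alpha$ added to every component, since the factor $\exp(\alpha)$ cancels between the numerator and the denominator. To avoid overflow in $\exp$, $\alpha$ is chosen to be $\alpha = -\max(x_1, ..., x_n)$: every exponent is then at most $0$, so each $\exp(x_i + \alpha)$ lies in $(0, 1]$ and the denominator is at least $1$.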
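A quick CPU check of why the shift matters: in float32, `exp(1000.)` overflows to `inf`, so the naive formula produces `inf / inf = nan`, while the shifted version is fine. The function names are our own.

```python
import torch

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    num = torch.exp(x)
    return num / num.sum()

def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    # Shift by alpha = -max(x): all exponents are <= 0, so exp() lies in (0, 1].
    num = torch.exp(x - x.max())
    return num / num.sum()

x = torch.tensor([1000.0, 1000.0])
print(naive_softmax(x))   # tensor([nan, nan])
print(stable_softmax(x))  # tensor([0.5000, 0.5000])
```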

This is the stable kernel we want to implement, but we will further generalize it to ND arrays by supporting a `dim` argument, similar to what PyTorch's softmax has.

In [14]:
@triton.jit
def _triton_softmax_kernel(
    in_ptr, out_ptr,
    in_row_stride,
    out_row_stride,
    dim_size,
    BLOCK: tl.constexpr):
    # One program per row; columns are assumed to have stride 1.
    row_idx = tl.program_id(axis=0)
    in_row_ptr = in_ptr + row_idx * in_row_stride
    out_row_ptr = out_ptr + row_idx * out_row_stride

    ##################################
    # Stage 1. Find max in each row. #
    ##################################
    # Sweep the row left to right, one BLOCK-sized tile at a time.
    col_offsets = tl.arange(0, BLOCK)
    col_tile = tl.load(in_row_ptr + col_offsets, mask=col_offsets < dim_size, other=-float('inf'))
    max_col_val = tl.max(col_tile, axis=0)
    for _ in range(BLOCK, dim_size, BLOCK):
        col_offsets += BLOCK
        col_tile = tl.load(in_row_ptr + col_offsets, mask=col_offsets < dim_size, other=-float('inf'))
        curr_col_max = tl.max(col_tile, axis=0)
        max_col_val = tl.where(max_col_val > curr_col_max, max_col_val, curr_col_max)

    ##############################
    # Stage 2. Find denominator. #
    ##############################
    # Sweep right to left: the last tile of Stage 1 is still on-chip,
    # so we start from it instead of re-loading it.
    # Masked-out lanes hold -inf, and exp(-inf) = 0.
    num = tl.exp(col_tile - max_col_val)
    denom = tl.sum(num, axis=0)
    for _ in range(BLOCK, dim_size, BLOCK):
        col_offsets -= BLOCK
        col_tile = tl.load(in_row_ptr + col_offsets, mask=col_offsets < dim_size, other=-float('inf'))
        num = tl.exp(col_tile - max_col_val)
        denom += tl.sum(num, axis=0)

    #############################
    # Stage 3. Populate output. #
    #############################
    # Sweep left to right again: `num` now holds the first tile,
    # so it can be stored right away.
    tl.store(out_row_ptr + col_offsets, num / denom, mask=col_offsets < dim_size)
    for _ in range(BLOCK, dim_size, BLOCK):
        col_offsets += BLOCK
        col_tile = tl.load(in_row_ptr + col_offsets, mask=col_offsets < dim_size, other=-float('inf'))
        num = tl.exp(col_tile - max_col_val)
        tl.store(out_row_ptr + col_offsets, num / denom, mask=col_offsets < dim_size)

# Public API
def triton_softmax(x, dim=-1, BLOCK=None):
    assert -x.dim() <= dim < x.dim()
    dim = dim + x.dim() if dim < 0 else dim
    dim_size = x.shape[dim]

    # Make sure scans/reductions are done over a contiguous dim.
    xt = x.transpose(dim, -1)
    if xt.stride(-1) != 1:
        xt = xt.contiguous()
    xt_shape = xt.shape
    # Squash the first xt.dim() - 1 dims.
    xt_2d = xt.reshape(-1, dim_size)

    # Allocate a contiguous output. Its row stride may differ from xt_2d's
    # (e.g. when x is a slice), so it is passed to the kernel separately.
    res_2d = torch.empty(xt_2d.shape, dtype=xt_2d.dtype, device=xt_2d.device)

    # NOTE: a Triton BLOCK (elements per program) may be larger than
    # the number of threads in the underlying CUDA block.
    if BLOCK is None:
        BLOCK = triton.next_power_of_2(xt_2d.size(-1))

    # In NVIDIA GPUs, a warp is a group of 32 threads that execute
    # the same instruction in lockstep (SIMT).
    # The larger the block, the more threads we put to work (i.e. a larger CUDA block).
    num_warps = 4
    if BLOCK >= 2048:
        num_warps = 8
    if BLOCK >= 4096:
        num_warps = 16
    # Launch kernel, parallelization over rows.
    grid = (xt_2d.size(0),)
    _triton_softmax_kernel[grid](xt_2d, res_2d, xt_2d.stride(0), res_2d.stride(0), dim_size,
                                 BLOCK=BLOCK, num_warps=num_warps)

    return res_2d.reshape(xt_shape).transpose(dim, -1)

def pytorch_softmax(x, dim=-1, **kwargs):                                                                                                                                                     
    return torch.nn.functional.softmax(x, dim=dim)                                             
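To check that the three stages compute the right thing, here is a plain PyTorch sketch that mirrors them on the CPU for a single row, one BLOCK-sized tile at a time. `blocked_softmax_row` is a hypothetical name, and unlike the kernel it does not try to save loads by changing the sweep direction.

```python
import torch

def blocked_softmax_row(row: torch.Tensor, BLOCK: int) -> torch.Tensor:
    """CPU mirror of the kernel's three stages for one row (hypothetical helper)."""
    n = row.numel()
    starts = range(0, n, BLOCK)
    # Stage 1: running max over tiles.
    row_max = torch.tensor(float("-inf"))
    for s in starts:
        row_max = torch.maximum(row_max, row[s:s + BLOCK].max())
    # Stage 2: accumulate the denominator tile by tile.
    denom = torch.tensor(0.0)
    for s in starts:
        denom = denom + torch.exp(row[s:s + BLOCK] - row_max).sum()
    # Stage 3: recompute the numerators and write the normalized tiles.
    out = torch.empty_like(row)
    for s in starts:
        out[s:s + BLOCK] = torch.exp(row[s:s + BLOCK] - row_max) / denom
    return out

row = torch.randn(1000)
print(torch.allclose(blocked_softmax_row(row, BLOCK=128), torch.softmax(row, dim=0)))  # True
```

Recomputing the exponentials in Stage 3, rather than storing them, trades extra arithmetic for less memory traffic, which is usually the right trade for a memory-bound kernel.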


In [7]:
x = torch.rand(3000, 3000, device='cuda')
for dim in (0, 1):
    x_pt = pytorch_softmax(x, dim=dim)
    for block in (None, 1024):
        x_tri = triton_softmax(x, dim=dim, BLOCK=block)
        print((x_tri - x_pt).abs().max())

tensor(1.3388e-09, device='cuda:0')
tensor(1.3388e-09, device='cuda:0')
tensor(1.7462e-10, device='cuda:0')
tensor(2.3283e-10, device='cuda:0')


Some benchmarks:

In [8]:
from torch.utils.benchmark import Timer, Compare
ns = (100, 500, 1000, 2000, 3000)
bench_results = []
for n in ns:
    x = torch.rand(n, n, device='cuda')
    label = f"softmax {n} x {n}"
    for dim in (0, 1):
        for f in (triton_softmax, pytorch_softmax):
            if f is triton_softmax:
                blocks = (512, 1024, None)
            else:
                blocks = (None,)
            for block in blocks:
                sub_label = f"dim={dim}, block={block}"
                smpt = f"{f.__name__}(x, dim={dim}, BLOCK={block})"
                timer = Timer(smpt,
                                globals=globals(),
                                label=label,
                                description=f"{f.__name__}",
                                sub_label=sub_label)
                bench_results.append(timer.blocked_autorange())

compare = Compare(bench_results)
compare.trim_significant_figures()
compare.print()

[-------------------- softmax 100 x 100 ---------------------]
                         |  triton_softmax  |  pytorch_softmax
1 threads: ---------------------------------------------------
      dim=0, block=512   |        30        |                 
      dim=0, block=1024  |        33        |                 
      dim=0, block=None  |        31        |        32.5     
      dim=1, block=512   |        20        |                 
      dim=1, block=1024  |        19        |                 
      dim=1, block=None  |        20        |         5.2     

Times are in microseconds (us).

[-------------------- softmax 500 x 500 ---------------------]
                         |  triton_softmax  |  pytorch_softmax
1 threads: ---------------------------------------------------
      dim=0, block=512   |        30        |                 
      dim=0, block=1024  |        30        |                 
      dim=0, block=None  |        33        |        245      
      dim=1, block=51

## Example 3: matrix multiplication

Given matrices $A, B$ of shape $(..., m, k), (..., k, n)$, respectively, the goal is to produce their matrix product of shape $(..., m, n)$.

We split dimensions $m, n, k$ into tiles of size BLOCK_M, BLOCK_N, BLOCK_K and compute a block-wise matrix multiplication

\begin{align}
C[..., b_m(i), b_n(j)] &= \sum_l A[..., b_m(i), b_k(l)] @ B[..., b_k(l), b_n(j)], \text{ where} \\
b_m(i) &= i \cdot \text{BLOCK_M}:i \cdot \text{BLOCK_M} + \text{BLOCK_M}, \\
b_n(j) &= j \cdot \text{BLOCK_N}:j \cdot \text{BLOCK_N} + \text{BLOCK_N}, \\
b_k(l) &= l \cdot \text{BLOCK_K}:l \cdot \text{BLOCK_K} + \text{BLOCK_K}, \\
i & \in \text{range}\bigg(0, \bigg\lceil \frac{m}{\text{BLOCK_M}} \bigg\rceil\bigg), \\
j & \in \text{range}\bigg(0, \bigg\lceil \frac{n}{\text{BLOCK_N}} \bigg\rceil\bigg), \\
l & \in \text{range}\bigg(0, \bigg\lceil \frac{k}{\text{BLOCK_K}} \bigg\rceil\bigg).
\end{align}
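The block-wise formula can be checked on the CPU with plain PyTorch slicing, which conveniently truncates slices at the matrix boundary (the Triton kernel below needs masks for that). `blocked_matmul` is a hypothetical helper for 2D inputs only; the block index over the $k$ dimension is called `l` so that it does not clash with the size `k`, and the loop order is not meant to reflect how the GPU schedules blocks.

```python
import torch

def blocked_matmul(A, B, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16):
    m, k = A.shape
    k2, n = B.shape
    assert k == k2
    C = torch.zeros(m, n, dtype=A.dtype)
    for i in range((m + BLOCK_M - 1) // BLOCK_M):          # i in range(ceil(m / BLOCK_M))
        bm = slice(i * BLOCK_M, i * BLOCK_M + BLOCK_M)
        for j in range((n + BLOCK_N - 1) // BLOCK_N):      # j in range(ceil(n / BLOCK_N))
            bn = slice(j * BLOCK_N, j * BLOCK_N + BLOCK_N)
            for l in range((k + BLOCK_K - 1) // BLOCK_K):  # sum over blocks of the k dimension
                bk = slice(l * BLOCK_K, l * BLOCK_K + BLOCK_K)
                C[bm, bn] += A[bm, bk] @ B[bk, bn]
    return C

A, B = torch.randn(50, 37), torch.randn(37, 29)
print(torch.allclose(blocked_matmul(A, B), A @ B, atol=1e-5))  # True
```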

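We parallelize over a 3D grid of size $(\lceil n / \text{BLOCK_N} \rceil, \lceil m / \text{BLOCK_M} \rceil, \text{batch})$. Programs are typically launched with the left-most grid coordinate advancing first, so the output blocks are visited in row-major order! There is a problem with that, however.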

For example, in the following matmul where each matrix is 9 blocks by 9 blocks, we can see that if we compute the output in row-major ordering, we need to load 90 blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped ordering, we only need to load 54 blocks.

<img src="https://triton-lang.org/main/_images/grouped_vs_row_major_ordering.png" width="900" height="600">
Source: https://triton-lang.org/main/_images/grouped_vs_row_major_ordering.png

Luckily, Triton has a function for such orderings, `tl.swizzle2d`. We will use it to improve cache locality in the code below!
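The 90 vs. 54 count from the figure can be reproduced with a few lines of Python. The sketch assumes a 9 x 9 grid of output blocks, the $k$ dimension split into 9 blocks, and a group size of 3 as in the figure. `grouped_order` is our own re-implementation of a grouped ordering in the spirit of the figure, not Triton's source code, and we only count distinct blocks of A and B touched by the first 9 programs (ignoring cache effects).

```python
def row_major_order(pid, n_rows, n_cols):
    return pid // n_cols, pid % n_cols

def grouped_order(pid, n_rows, n_cols, group):
    """Walk down `group` rows of output blocks before moving to the next column."""
    group_id = pid // (group * n_cols)
    first_row = group_id * group
    rows_in_group = min(n_rows - first_row, group)  # the last group may be shorter
    local = pid % (group * n_cols)
    return first_row + local % rows_in_group, local // rows_in_group

def blocks_loaded(order, num_programs, k_blocks):
    a_blocks, b_blocks = set(), set()
    for pid in range(num_programs):
        i, j = order(pid)
        a_blocks.update((i, l) for l in range(k_blocks))  # block row i of A
        b_blocks.update((l, j) for l in range(k_blocks))  # block column j of B
    return len(a_blocks) + len(b_blocks)

print(blocks_loaded(lambda p: row_major_order(p, 9, 9), 9, 9))   # 90
print(blocks_loaded(lambda p: grouped_order(p, 9, 9, 3), 9, 9))  # 54
```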

In [9]:
@triton.jit
def _triton_matmul_kernel(
    x_ptr, y_ptr, out_ptr,
    x_b_stride, x_m_stride, x_k_stride,
    y_b_stride, y_k_stride, y_n_stride,
    out_b_stride, out_m_stride, out_n_stride,
    m, n, k,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr):
    batch_idx = tl.program_id(axis=2)
    block_row_idx = tl.program_id(axis=1)
    block_col_idx = tl.program_id(axis=0)
    n_block_rows = tl.num_programs(axis=1)
    n_block_cols = tl.num_programs(axis=0)

    # Row-major order to grouped order (groups of 9 block rows) for better L2 re-use.
    block_row_idx, block_col_idx = tl.swizzle2d(
        block_row_idx, block_col_idx,
        n_block_rows, n_block_cols,
        9
    )

    block_m_offsets = tl.arange(0, BLOCK_M)
    block_n_offsets = tl.arange(0, BLOCK_N)
    block_k_offsets = tl.arange(0, BLOCK_K)

    # Rows past m and columns past n must not be read (or written).
    mask_m = (block_row_idx * BLOCK_M + block_m_offsets) < m
    mask_n = (block_col_idx * BLOCK_N + block_n_offsets) < n

    # x_row = &x[b, b_m(block_row_idx), b_k(0)]
    x_row = (x_ptr + batch_idx * x_b_stride
           + block_row_idx * BLOCK_M * x_m_stride
           + block_m_offsets[:, None] * x_m_stride
           + block_k_offsets[None, :] * x_k_stride)

    # y_col = &y[b, b_k(0), b_n(block_col_idx)]
    y_col = (y_ptr + batch_idx * y_b_stride
           + block_col_idx * BLOCK_N * y_n_stride
           + block_k_offsets[:, None] * y_k_stride
           + block_n_offsets[None, :] * y_n_stride)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for block_offset in range(0, k, BLOCK_K):
        mask_k = (block_offset + block_k_offsets) < k
        x_block = tl.load(x_row, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
        y_block = tl.load(y_col, mask=mask_k[:, None] & mask_n[None, :], other=0.0)
        acc += tl.dot(x_block, y_block)
        # Advance pointers to the next BLOCK_K slice of the k dimension.
        x_row += BLOCK_K * x_k_stride
        y_col += BLOCK_K * y_k_stride

    # These offsets are not used in the loop above. Re-creating them here,
    # rather than keeping them alive through the loop, is meant to hint the
    # compiler that their registers can be re-used (it may do this on its own).
    # Registers are super fast, but high register pressure lowers occupancy
    # or causes slow spills to local memory.
    block_m_offsets = tl.arange(0, BLOCK_M)
    block_n_offsets = tl.arange(0, BLOCK_N)

    out = (out_ptr + batch_idx * out_b_stride
         + block_row_idx * BLOCK_M * out_m_stride
         + block_col_idx * BLOCK_N * out_n_stride
         + block_m_offsets[:, None] * out_m_stride
         + block_n_offsets[None, :] * out_n_stride)
    tl.store(out, acc.to(out_ptr.dtype.element_ty), mask=(mask_m[:, None] & mask_n[None, :]))

def triton_matmul(x, y, BLOCK_M=None, BLOCK_N=None, BLOCK_K=None):
    assert x.dim() >= 2 and y.dim() >= 2
    assert x.dtype == y.dtype and x.device == y.device
    m, k = x.shape[-2:]
    k1, n = y.shape[-2:]
    assert k == k1
    x_shape = list(x.shape)
    y_shape = list(y.shape)
    x_shape[-1] = 1
    y_shape[-2] = 1
    out_shape = torch.broadcast_shapes(x_shape, y_shape)
    out = torch.empty(*out_shape, dtype=x.dtype, device=x.device)

    # make sure that x and y are "Triton-contiguous"
    def make_triton_contiguous(t):
        # return t if it is already row-/col-major
        if t.is_contiguous() or t.mT.is_contiguous():
            return t
        else:
            return t.contiguous()

    x = make_triton_contiguous(x)
    y = make_triton_contiguous(y)

    # Broadcast the batch dims so that x, y and out share one batch shape
    # (broadcast dims get stride 0, so expand itself copies no data).
    batch_shape = out_shape[:-2]
    x = x.expand(*batch_shape, m, k)
    y = y.expand(*batch_shape, k, n)

    def to_3d(t):
        return t.reshape(-1, *t.shape[-2:])

    x_3d = to_3d(x)
    y_3d = to_3d(y)
    out_3d = to_3d(out)  # will always be a view!

    # Launch kernel
    BLOCK_M, BLOCK_N, BLOCK_K = map(lambda t: 64 if t is None else t, (BLOCK_M, BLOCK_N, BLOCK_K))
    b = out_3d.size(0)
    grid = (triton.cdiv(n, BLOCK_N), triton.cdiv(m, BLOCK_M), b)
    _triton_matmul_kernel[grid](x_3d, y_3d, out_3d,
                                *x_3d.stride(), *y_3d.stride(), *out_3d.stride(),
                                m, n, k,
                                BLOCK_M=BLOCK_M,
                                BLOCK_N=BLOCK_N,
                                BLOCK_K=BLOCK_K,
                                num_stages=1)

    return out_3d.reshape(out_shape)
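The `out_shape` computation at the top of `triton_matmul` relies on a small trick: setting the last dim of `x` and the second-to-last dim of `y` to 1 lets a single `torch.broadcast_shapes` call broadcast the batch dimensions and produce `(..., m, n)`. A standalone sketch (the helper name `matmul_out_shape` is hypothetical):

```python
import torch

def matmul_out_shape(x_shape, y_shape):
    x_shape, y_shape = list(x_shape), list(y_shape)
    x_shape[-1] = 1  # (..., m, k) -> (..., m, 1)
    y_shape[-2] = 1  # (..., k, n) -> (..., 1, n)
    # Broadcasting (..., m, 1) with (..., 1, n) gives (..., m, n), batch dims included.
    return torch.broadcast_shapes(x_shape, y_shape)

print(matmul_out_shape((4, 1, 8, 16), (3, 16, 5)))  # torch.Size([4, 3, 8, 5])
print(torch.matmul(torch.rand(4, 1, 8, 16), torch.rand(3, 16, 5)).shape)  # same shape
```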


In [10]:
BLOCK=64
n = 4000
x = torch.rand(4, n, n, device='cuda', dtype=torch.half) / BLOCK
y = triton_matmul(x, x.mT)
z = x @ x.mT
print((y - z).abs().max())

tensor(0.0002, device='cuda:0', dtype=torch.float16)


Some benchmarks:

In [11]:
from torch.utils.benchmark import Timer, Compare
# NOTE: `triton.ops` ships with Triton 2.x; it was removed in Triton 3.0.
from triton.ops import matmul

def pytorch_matmul(x, y, **kwargs):
    return x @ y

def triton_native_matmul(x, y, **kwargs):
    return matmul(x, y)

ns = (100, 500, 1000, 2000, 3000, 4000, 5000)
bench_results = []
for dtype in (torch.float, torch.half):
    for n in ns:
        x = torch.rand(n, n, device='cuda', dtype=dtype)
        label = f"mm {n} x {n}"
        for f in (triton_matmul, pytorch_matmul, triton_native_matmul):
            if dtype is torch.float:
                blocks = (16, 32, None)
            else:
                blocks = (32, None, 128)
                
            if f is triton_native_matmul:
                # Warm-up!
                triton_native_matmul(x, x.mT)
            
            if f is not triton_matmul:
                blocks = (None,)
                kwarg_str = str()
            else:
                kwarg_str = "BLOCK_M=block, BLOCK_N=block, BLOCK_K=block"
                
            for block in blocks:
                sub_label = f"dtype={dtype}, block={block}"
                smpt = f"{f.__name__}(x, x.mT, " + kwarg_str + ")" 
                timer = Timer(smpt,
                                globals=globals(),
                                label=label,
                                description=f"{f.__name__}",
                                sub_label=sub_label)
                bench_results.append(timer.blocked_autorange())

compare = Compare(bench_results)
compare.trim_significant_figures()
compare.print()

[------------------------------------------ mm 100 x 100 -----------------------------------------]
                                       |  triton_matmul  |  pytorch_matmul  |  triton_native_matmul
1 threads: ----------------------------------------------------------------------------------------
      dtype=torch.float32, block=16    |       40.4      |                  |                      
      dtype=torch.float32, block=32    |       40.8      |                  |                      
      dtype=torch.float32, block=None  |       42.9      |       9.4        |          39.7        
      dtype=torch.float16, block=32    |       40.1      |                  |                      
      dtype=torch.float16, block=None  |       40.2      |       9.8        |          40.0        
      dtype=torch.float16, block=128   |       46.6      |                  |                      

Times are in microseconds (us).

[------------------------------------------ mm 500 x 500 ---------

## References

### Triton
- [Triton library documentation](https://triton-lang.org/main/index.html)
- [Triton GitHub page](https://github.com/openai/triton)
- [Tillet, Philippe G. 2020. Blocked Algorithms for Neural Networks: Design and Implementation on GPUs. Doctoral dissertation, Harvard University Graduate School of Arts and Sciences](https://dash.harvard.edu/bitstream/handle/1/37368966/ptillet-dissertation-final.pdf)

### GPU programming with CUDA C
- [Programming Massively Parallel Processors. A Hands-on Approach by Wen-mei Hwu, David Kirk, Izzat El Hajj](https://www.elsevier.com/books/programming-massively-parallel-processors/hwu/978-0-323-91231-0)
- [Professional CUDA C Programming by John Cheng, Max Grossman, Ty McKercher](https://www.wiley.com/en-fr/Professional+CUDA+C+Programming-p-9781118739327)