In [3]:
import numpy as np
from collections import namedtuple

# Homework 12 - Writing GPU Kernels


This homework is an extension of what we did in the lecture.
This time, we will implement a matrix multiplication kernel: first as a simulated CUDA-style kernel in plain Python, and then using Triton.
To simplify the problem, we will assume that the matrices are square and of the same size.

In [4]:
def matrix_multiplication(blockIdx, blockDim, threadIdx, A, B, C):
    """
    Simulated CUDA kernel for matrix multiplication C = A @ B.

    Each invocation plays the role of a single thread and computes at most
    one element of C: the dot product of one row of A with one column of B
    (see https://www.mathsisfun.com/algebra/matrix-multiplying.html,
    chapter "Multiplying a Matrix by Another Matrix").

    Blocks and threads are two-dimensional so that 2D data (matrices) can be
    indexed directly. The x-dimension indexes the rows of the output matrix C
    and the y-dimension indexes its columns, so every output element is owned
    by exactly one (block, thread) pair and could be computed in parallel.

    This is not the most efficient implementation (no tiling, no shared
    memory); it favours correctness and clarity.

    Boundary handling: the launch grid may contain more threads than C has
    elements. Threads whose (row, col) falls outside C return immediately so
    that no memory outside the matrices is touched.

    :param blockIdx: Block index in the grid (has ``.x`` and ``.y``)
    :param blockDim: Block dimension, i.e. threads per block (``.x``, ``.y``)
    :param threadIdx: Thread index in the block (has ``.x`` and ``.y``)
    :param A: First input matrix
    :param B: Second input matrix
    :param C: Output matrix; the element owned by this thread is overwritten
    :return: None, the result is written into ``C`` in place
    """
    # Calculate row and column for this thread
    row = blockIdx.x * blockDim.x + threadIdx.x
    col = blockIdx.y * blockDim.y + threadIdx.y

    # <your_code_here>
    n_rows, n_cols = C.shape

    # Threads that fall outside the output matrix have nothing to do.
    if row >= n_rows or col >= n_cols:
        return

    # Accumulate locally and write once, so the initial contents of C
    # (which may be uninitialised) never influence the result.
    acc = 0.0
    for k in range(A.shape[1]):
        acc += A[row, k] * B[k, col]
    C[row, col] = acc
    # </your_code_here>


def run_kernel(*kernel_args):
    """
    Simulate a 2D kernel launch of ``matrix_multiplication`` on the CPU.

    Every (block, thread) pair of the launch grid is visited sequentially and
    the kernel is called once per pair, with ``kernel_args`` forwarded
    unchanged after the block index, block dimensions and thread index.

    The grid is sized from the shape of the last kernel argument (the output
    matrix), so the launch no longer depends on a module-level ``n`` that
    other notebook cells may reassign.

    :param kernel_args: Arguments forwarded to the kernel; the last one must
        be the 2D output matrix (anything exposing a two-element ``.shape``)
    :return: None, the kernel writes its result into the output matrix
    """
    NUM_THREADS = 64

    # Define the dimensions of the grid and block
    DimGrid = namedtuple("block_dimensions", ["x", "y"])
    DimBlock = namedtuple("thread_dimensions", ["x", "y"])
    CurrentBlock = namedtuple("current_block", ["x", "y"])
    CurrentThread = namedtuple("current_thread", ["x", "y"])

    # The output matrix determines how many threads are needed per dimension.
    n_rows, n_cols = kernel_args[-1].shape

    # DimGrid holds the number of blocks in the x and y dimensions.
    # Integer ceil-division: enough blocks to cover every row / column.
    dim_grid = DimGrid(
        (n_rows + NUM_THREADS - 1) // NUM_THREADS,
        (n_cols + NUM_THREADS - 1) // NUM_THREADS,
    )
    # DimBlock holds the number of threads in the x and y dimensions
    dim_block = DimBlock(NUM_THREADS, NUM_THREADS)

    for block_i in range(dim_grid.x):
        for block_j in range(dim_grid.y):
            for thread_i in range(dim_block.x):
                for thread_j in range(dim_block.y):
                    matrix_multiplication(
                        CurrentBlock(block_i, block_j),
                        dim_block,
                        CurrentThread(thread_i, thread_j),
                        *kernel_args,
                    )


n = 16
# Seeded generator so that a Restart & Run All reproduces the same matrices
rng = np.random.default_rng(0)
# define matrix A
A = rng.standard_normal((n, n))
# define matrix B
B = rng.standard_normal((n, n))
# the result of A*B is C; start from NaN so that any element the kernel
# fails to write makes the check below fail deterministically
C = np.full((n, n), np.nan)

run_kernel(A, B, C)

# Same tolerances as np.allclose, but reports which elements differ on failure
np.testing.assert_allclose(C, A @ B, rtol=1e-5, atol=1e-8)


In [None]:
import os
os.environ["TRITON_INTERPRET"] = "1"

import triton
import triton.language as tl
import torch

@triton.jit
def matrix_multiplication_kernel(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Tiled matrix multiplication C = A @ B using 2D program IDs.

    Each program instance computes one BLOCK_SIZE x BLOCK_SIZE tile of C.
    ``tl.program_id(0)`` selects the tile row and ``tl.program_id(1)`` the
    tile column. The K dimension is walked in steps of BLOCK_SIZE; for each
    step a tile of A and a tile of B are loaded and their product is added to
    a float32 accumulator that starts at zero.

    Out-of-range rows, columns and K indices are masked: masked loads yield
    0.0 (so they do not contribute to the sum) and masked stores are skipped,
    which keeps the kernel from touching memory outside the matrices when
    M, N or K is not a multiple of BLOCK_SIZE.

    NOTE: no strides are passed in, so the pointer arithmetic assumes all
    three matrices are contiguous and row-major (row stride K for A, N for B
    and C). Call ``.contiguous()`` on the tensors if that may not hold.

    :param A_ptr: Pointer to matrix A (shape: M x K)
    :param B_ptr: Pointer to matrix B (shape: K x N)
    :param C_ptr: Pointer to output matrix C (shape: M x N)
    :param M: Number of rows in A and C
    :param N: Number of columns in B and C
    :param K: Number of columns in A and rows in B
    :param BLOCK_SIZE: Tile size for all dimensions
    """
    # <your_code_here>
    # 1) Get 2D program ID
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    # 2) Compute row/column offsets of the output tile owned by this program
    row_offsets = pid_row * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = pid_col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # 3) Create masks (x and y dimensions)
    row_mask = row_offsets < M
    col_mask = col_offsets < N

    # 4) Initialize accumulator
    acc = tl.full((BLOCK_SIZE, BLOCK_SIZE), 0.0, dtype=tl.float32)

    # 5) Loop over K dimension, one BLOCK_SIZE-wide slab at a time
    for k_tile in range(0, tl.cdiv(K, BLOCK_SIZE)):
        k_offsets = k_tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        k_mask = k_offsets < K

        # 6) Load A and B tiles (row-major addressing; masked lanes read 0.0)
        a_ptrs = A_ptr + row_offsets[:, None] * K + k_offsets[None, :]
        b_ptrs = B_ptr + k_offsets[:, None] * N + col_offsets[None, :]
        a_tile = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)
        b_tile = tl.load(b_ptrs, mask=k_mask[:, None] & col_mask[None, :], other=0.0)

        # 7) Perform matrix multiplication on the tiles and accumulate
        acc += tl.dot(a_tile, b_tile)

    # 8) Store result, skipping lanes that fall outside C
    c_ptrs = C_ptr + row_offsets[:, None] * N + col_offsets[None, :]
    tl.store(c_ptrs, acc, mask=row_mask[:, None] & col_mask[None, :])
    # </your_code_here>
    
# Test the implementation
n = 16
# TRITON_INTERPRET=1 lets the kernel run without a GPU, so only place the
# tensors on CUDA when a device is actually present.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fixed seed so that a Restart & Run All reproduces the same matrices
torch.manual_seed(0)
A = torch.randn((n, n), device=device)
B = torch.randn((n, n), device=device)
M, K = A.shape
K_b, N = B.shape
assert K == K_b, "inner dimensions of A and B must match"

# Allocate output
C = torch.empty((M, N), device=A.device, dtype=A.dtype)

# Define block size
BLOCK_SIZE = 64

# Launch kernel: one program per BLOCK_SIZE x BLOCK_SIZE tile of C
grid = lambda meta: (
    triton.cdiv(M, meta['BLOCK_SIZE']),
    triton.cdiv(N, meta['BLOCK_SIZE'])
)

matrix_multiplication_kernel[grid](
    A, B, C,
    M, N, K,
    BLOCK_SIZE=BLOCK_SIZE,
)

assert torch.allclose(C, A @ B, atol=1e-2)