### Installation

In [None]:
# clone the wave github repo
!git clone https://github.com/iree-org/wave.git
# install ROCm Pytorch dependencies
!pip install -r pytorch-rocm-requirements.txt
# install wave and its dependencies
!pip install wave-lang

### Wave Tensor/Matrix Addition Kernel Tutorial

#### (1) Importing Libraries

In [None]:
import torch

import wave_lang.kernel.wave as tkw
from wave_lang.kernel._support.dtype import f16
from wave_lang.kernel._support.indexing import sym
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.lang.wave_types import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config

In [None]:
# Symbolic problem dimensions: A, B and C all have M rows and N columns.
M, N = sym.M, sym.N

# Symbolic workgroup tile sizes; concrete values are substituted at compile time.
BLOCK_M, BLOCK_N = sym.BLOCK_M, sym.BLOCK_N

# Symbolic memory address space for each buffer, so each one can be placed
# independently through the compile-time hyperparameters.
ADDRESS_SPACE_A, ADDRESS_SPACE_B, ADDRESS_SPACE_C = (
    sym.ADDRESS_SPACE_A,
    sym.ADDRESS_SPACE_B,
    sym.ADDRESS_SPACE_C,
)

#### (2) Constraints
Constraints specify how the dimensions of our problem are distributed across workgroups, waves, and the underlying hardware.

In [None]:
# Kernel constraints, assembled step by step.
# Workgroup tiling: each workgroup covers a BLOCK_M x BLOCK_N tile, with M
# mapped to workgroup dimension 0 and N to workgroup dimension 1.
constraints = [
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
]
# Wave tiling: each wave handles half of the workgroup tile in each dimension
# (presumably 2 x 2 waves per workgroup — confirm against the wave docs).
constraints += [
    tkw.WaveConstraint(M, BLOCK_M / 2),
    tkw.WaveConstraint(N, BLOCK_N / 2),
]
# Hardware description: 64 threads per wave.
# NOTE(review): vector_shapes is set to the full workgroup tile while the
# per-wave tile above is half of it — confirm this is intended.
constraints.append(
    tkw.HardwareConstraint(
        threads_per_wave=64,
        vector_shapes={M: BLOCK_M, N: BLOCK_N},
    )
)

#### (3) Kernel Definition
Our tensor addition example computes $C = A + B$, where $A$, $B$, and $C$ are 2D matrices of dimension $M \times N$ in f16 precision.

In [None]:
# The kernel: elementwise C = A + B over M x N f16 matrices.
@tkw.wave(constraints)
def matrix_add(
    # Each argument is an M x N f16 buffer whose address space is a symbol
    # resolved from the compile-time hyperparameters.
    a: Memory[M, N, ADDRESS_SPACE_A, f16],
    b: Memory[M, N, ADDRESS_SPACE_B, f16],
    c: Memory[M, N, ADDRESS_SPACE_C, f16],
):
    # Load the inputs from memory into registers.
    a_reg = tkw.read(a)
    b_reg = tkw.read(b)

    # Compute the sum directly from the loaded registers. Unlike a GEMM or a
    # reduction there is nothing to accumulate, so no zero-initialized
    # register is needed.
    c_reg = a_reg + b_reg

    # Write the result back to the output buffer.
    tkw.write(c_reg, c)

#### (4) Testing the kernel
A function that verifies our implementation by comparing its output against a PyTorch reference.

In [None]:
# The kernel testing function
def test_matrix_add(
    m: int = 128, n: int = 128, block_m: int = 64, block_n: int = 64
) -> None:
    """Compile ``matrix_add``, run it on random inputs, and check it against PyTorch.

    Args:
        m: Number of rows of A, B and C.
        n: Number of columns of A, B and C.
        block_m: Workgroup tile size substituted for BLOCK_M.
        block_n: Workgroup tile size substituted for BLOCK_N.
            NOTE(review): presumably m and n should be multiples of the
            block sizes — confirm before using other shapes.

    Raises:
        AssertionError: If the kernel output does not match ``a + b``.
    """
    # Initialize both inputs with random values. b must be non-trivial too:
    # with b == 0, a kernel that merely copied a into c would also pass.
    torch.manual_seed(0)
    a = torch.randn(m, n, dtype=torch.float16, device="cuda")
    b = torch.randn(m, n, dtype=torch.float16, device="cuda")
    c = torch.zeros(m, n, dtype=torch.float16, device="cuda")

    # Set hyperparameters for compilation
    hyperparams = {
        ADDRESS_SPACE_A: SHARED_ADDRESS_SPACE,
        ADDRESS_SPACE_B: SHARED_ADDRESS_SPACE,
        ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,
        BLOCK_M: block_m,
        BLOCK_N: block_n,
        M: m,
        N: n,
    }

    # Compile the kernel
    options = WaveCompileOptions(subs=hyperparams)
    options = set_default_run_config(options)
    compiled_add = wave_compile(options, matrix_add)

    # Run the Tensor Addition kernel; it writes its result into c.
    compiled_add(a, b, c)

    # Reference result computed by PyTorch
    expected = a + b

    # Check if results are close (accounting for floating-point precision)
    assert torch.allclose(c, expected, rtol=1e-2, atol=1e-2), \
        f"Tensor Addition result doesn't match expected output\nMax difference: {(c - expected).abs().max()}"

    print("Tensor Addition test passed!")

    # To inspect the generated MLIR, print the compiled kernel's assembly:
    # print(compiled_add.asm)
    # NOTE(review): assumes the compiled kernel exposes its MLIR via `.asm` — confirm.


# Backward-compatible alias for the previous (misleading) name.
test_gemm = test_matrix_add

# Run the test
test_matrix_add()