### Installation

In [None]:
# install Pytorch-ROCm dependencies
!pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
# install wave and its dependencies
!pip install wave-lang

### Wave Tensor/Matrix Addition Kernel Tutorial

#### (1) Importing Libraries

In [None]:
import torch

import wave_lang.kernel.wave as wave
import wave_lang.kernel.lang as tkl
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.lang.wave_types import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
from wave_lang.debugging.html_viewer import html_viewer

In [None]:
# Symbolic problem dimensions: A, B and C all have M rows and N columns
M, N = tkl.sym.M, tkl.sym.N

# Symbolic workgroup tile sizes along M and N
BLOCK_M, BLOCK_N = tkl.sym.BLOCK_M, tkl.sym.BLOCK_N

# One symbolic address space per memory buffer (a, b and c);
# concrete values are substituted when the kernel is compiled
ADDRESS_SPACE_A, ADDRESS_SPACE_B, ADDRESS_SPACE_C = (
    tkl.sym.ADDRESS_SPACE_A,
    tkl.sym.ADDRESS_SPACE_B,
    tkl.sym.ADDRESS_SPACE_C,
)

#### (2) Constraints
Constraints specify how the dimensions of our problem are distributed across workgroups, waves and threads.

In [None]:
# Define constraints for the kernel.
# The constraint classes are accessed through `wave`, the alias under which
# `wave_lang.kernel.wave` is imported in this notebook.
constraints = [
    # Workgroup tiling: M is tiled in chunks of BLOCK_M and N in chunks of
    # BLOCK_N (the last argument selects the workgroup dimension: 0 and 1).
    wave.WorkgroupConstraint(M, BLOCK_M, 0),
    wave.WorkgroupConstraint(N, BLOCK_N, 1),
    # Wave tiling: each wave handles half of the workgroup tile along M and N.
    wave.WaveConstraint(M, BLOCK_M / 2),
    wave.WaveConstraint(N, BLOCK_N / 2),
    # Hardware description: 64 threads per wave and the vector shape used
    # along each dimension.
    wave.HardwareConstraint(
        threads_per_wave=64,
        vector_shapes={M: BLOCK_M, N: BLOCK_N},
    ),
]

#### (3) Kernel Definition
Our tensor addition example computes C = A + B, where A, B and C are 2D matrices of dimension M×N in f16 precision.

In [None]:
# The kernel: element-wise addition, c = a + b.
# Each intermediate register is also passed to `wave.debug_log` with the
# `html_viewer` handler so its values can be inspected after a run.
@wave.wave(constraints)
def matrix_add(
    # M x N f16 matrices in memory; each buffer's address space is a symbol
    # that is bound to a concrete value at compile time
    a: Memory[M, N, ADDRESS_SPACE_A, tkl.f16],
    b: Memory[M, N, ADDRESS_SPACE_B, tkl.f16],
    c: Memory[M, N, ADDRESS_SPACE_C, tkl.f16],
):
    # Load the operands from memory into registers
    a_reg = wave.read(a)
    wave.debug_log(a_reg, label='Values in a_register', handler=html_viewer)

    b_reg = wave.read(b)
    wave.debug_log(b_reg, label='Values in b_register', handler=html_viewer)

    # Compute the sum. No zero-initialised accumulator is needed: the result
    # register is produced directly by the addition.
    c_reg = a_reg + b_reg
    wave.debug_log(c_reg, label='Values in c_register', handler=html_viewer)

    # Write the result back to memory
    wave.write(c_reg, c)

Once the kernel has run, you can find the generated `debug-view.html` file with the logged register values in the same directory.

#### (4) Testing the kernel 
A function to verify our implementation works by comparing with PyTorch reference.

In [None]:
# The kernel testing function
def test_matrix_add():
    """Compile ``matrix_add``, run it on random inputs and check it against PyTorch.

    Builds two random 128x128 f16 matrices on the GPU, compiles the kernel
    with concrete values for every symbol, runs it, and compares the output
    buffer with ``a + b`` computed by PyTorch.

    Raises:
        AssertionError: if the kernel output differs from the PyTorch
            reference by more than the tolerance (rtol=1e-2, atol=1e-2).
    """
    # Problem size (rows, cols)
    m, n = 128, 128

    # Initialize both input matrices with random values (seeded for
    # reproducibility). `b` must be non-trivial: with b == 0 the expected
    # result equals `a`, so a kernel that merely copied `a` would pass.
    # NOTE: "cuda" is the device name PyTorch uses for ROCm GPUs as well.
    torch.manual_seed(0)
    a = torch.randn(m, n, dtype=torch.float16, device="cuda")
    b = torch.randn(m, n, dtype=torch.float16, device="cuda")
    # Output buffer, overwritten by the kernel
    c = torch.zeros(m, n, dtype=torch.float16, device="cuda")

    # Concrete values for every symbol used by the kernel and its constraints
    hyperparams = {
        ADDRESS_SPACE_A: SHARED_ADDRESS_SPACE,
        ADDRESS_SPACE_B: SHARED_ADDRESS_SPACE,
        ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,
        BLOCK_M: 64,
        BLOCK_N: 64,
        M: m,
        N: n,
    }

    # Compile the kernel
    options = WaveCompileOptions(
        subs=hyperparams,
    )
    options = set_default_run_config(options)
    compiled_add = wave_compile(options, matrix_add)

    # Run the tensor addition kernel; the result is written into `c`
    compiled_add(a, b, c)

    # Reference result computed with PyTorch
    expected = a + b

    # Check if results are close (accounting for floating-point precision).
    # `c` is already f16, so no dtype conversion is needed before comparing.
    assert torch.allclose(c, expected, rtol=1e-2, atol=1e-2), \
        f"Matrix Addition result doesn't match expected output\nMax difference: {(c - expected).abs().max()}"

    print("Matrix Addition test passed!")

# Run the test
test_matrix_add()