In [1]:
from opal_ptx import CuModuleWrapper, kernel_transformer

from torch.utils.cpp_extension import load
import torch

from matplotlib import pyplot as plt

In [2]:
# Tensor-core GEMM, D = A @ B, built from PTX wmma m16n16k16 instructions.
#
# A, B and D are raw device addresses. The `.row` qualifiers and the address
# arithmetic below treat all three matrices as row-major:
#   A: m x k, f16 (2 bytes/element, row stride k)
#   B: k x n, f16 (2 bytes/element, row stride n)
#   D: m x n, f32 (4 bytes/element, row stride n)
# Each block produces one 16x16 tile of D: %ctaid.x picks the column tile and
# %ctaid.y picks the row tile. There is no edge handling, so m, n and k are
# expected to be multiples of 16.
#
# Documentation is kept in `#` comments on purpose: the body is consumed by
# kernel_transformer, and a docstring would be an extra statement in it.
# NOTE(review): confirm whether the transformer tolerates docstrings.
@kernel_transformer.kernel()
def simple_wmma_gemm(A: "u64", B: "u64", D: "u64", m: "u32", n: "u32", k: "u32"):
    # Top-left corner of this block's output tile: x = column, y = row.
    x: u32 = u32("%ctaid.x") * 16
    y: u32 = u32("%ctaid.y") * 16

    # Eight 32-bit registers per fragment (A tile, B tile, accumulator).
    a: b32(8)
    b: b32(8)
    d: b32(8)

    # Start from a zero accumulator (all-zero bits are also 0.0 as f32).
    for i in range(8):
        ptx.mov.b32(d[i], 0)

    # Walk the shared dimension in steps of 16, accumulating into d.
    # NOTE(review): the element offsets below are written in terms of u32
    # values; for very large matrices they could wrap before being added to
    # the 64-bit base address -- confirm how the DSL widens them.
    KI: u32 = 0
    while KI < k:
        # A tile starts at row y, column KI; B tile at row KI, column x.
        _A: u64 = A + (y * k + KI) * 2
        _B: u64 = B + (KI * n + x) * 2

        # The last operand of each load is the row stride in elements.
        ptx.wmma.load.a.sync.aligned._global.m16n16k16.row.f16({*a}, [_A], k)
        ptx.wmma.load.b.sync.aligned._global.m16n16k16.row.f16({*b}, [_B], n)

        ptx.wmma.mma.sync.aligned.m16n16k16.row.row.f32.f32({*d}, {*a}, {*b}, {*d})
        KI += 16

    # D tile starts at row y, column x. D is m x n row-major, so its row
    # stride is n (the number of columns), matching the `y * n` term above.
    # Passing m here only worked when m == n.
    _D: u64 = D + (y * n + x) * 4
    ptx.wmma.store.d.sync.aligned.m16n16k16._global.row.f32([_D], {*d}, n)


In [3]:
# Translate the Python-embedded kernel into PTX source text.
kernel_builder = kernel_transformer.KernelBuilder()
simple_wmma_gemm(kernel_builder)
kernel_code = kernel_builder.generate()

# Problem size: A is M x K, B is K x N, D is M x N.
M = 32
N = 32
K = 32
# The kernel works on whole 16x16x16 tiles and has no edge handling.
assert M % 16 == 0 and N % 16 == 0 and K % 16 == 0, "M, N and K must be multiples of 16"

# Alternative A inputs that are handy when debugging the tile layout:
#A = torch.eye(K, dtype=torch.float16, device="cuda")
#A = torch.triu(torch.full((K, K), 3, dtype=torch.float16, device="cuda"))
A = torch.rand((M, K), dtype=torch.float16, device="cuda")
# B must be K x N to match the kernel's addressing; with ones on the main
# diagonal the expected product is simply A (for K == N).
B = torch.eye(K, N, dtype=torch.float16, device="cuda")

D = torch.zeros((M, N), dtype=torch.float32, device="cuda")

# Reference result computed by PyTorch (float16).
check = A @ B

wrapper = CuModuleWrapper()
wrapper.load_ptx_code(kernel_code)

# The kernel uses %ctaid.x as the column tile (range N) and %ctaid.y as the
# row tile (range M), so the grid is (N tiles, M tiles). One warp per block,
# as required by the warp-wide wmma instructions.
grid = (N // 16, M // 16, 1)
block = (32, 1, 1)
wrapper.launch_kernel("simple_wmma_gemm", grid, block, (A.data_ptr(), B.data_ptr(), D.data_ptr(), M, N, K))

# NOTE(review): it is not visible here whether launch_kernel waits for the
# kernel or which stream it uses; synchronize so D is complete before reading.
torch.cuda.synchronize()

print(D)
# Sum of absolute errors: a signed sum would let opposite-sign errors cancel.
print((D - check).abs().sum())

# Loose tolerance because the reference is computed in float16.
assert torch.allclose(D, check.float(), rtol=1e-2, atol=1e-2), "WMMA result differs from torch reference"

tensor([[0.2263, 0.0316, 0.0526,  ..., 0.9868, 0.8169, 0.5220],
        [0.9502, 0.2150, 0.6509,  ..., 0.4800, 0.8105, 0.3823],
        [0.7480, 0.5503, 0.3638,  ..., 0.0263, 0.2659, 0.4893],
        ...,
        [0.0878, 0.2206, 0.6890,  ..., 0.4414, 0.4863, 0.3894],
        [0.7188, 0.8145, 0.6309,  ..., 0.2413, 0.3589, 0.5679],
        [0.6646, 0.8462, 0.4480,  ..., 0.6274, 0.3765, 0.9536]],
       device='cuda:0')
tensor(0., device='cuda:0')
