In [1]:
from opal_ptx import CuModuleWrapper, kernel_transformer

from torch.utils.cpp_extension import load
import torch

from matplotlib import pyplot as plt

In [14]:
@kernel_transformer.kernel()
def simple_wmma_gemm(A_: "u64", B_: "u64", D_: "u64", m: "u32", n: "u32", k: "u32"):
    A: u64
    ptx.cvta.to._global.u64(A, A_)

    B: u64
    ptx.cvta.to._global.u64(B, B_)
    
    D: u64
    ptx.cvta.to._global.u64(D, D_)
    
    x: u32 = u32("%ctaid.x") * 16
    y: u32 = u32("%ctaid.y") * 16

    a: b32(8)
    b: b32(8)
    d: b32(8)

    for i in range(len(d)):
        ptx.mov.b32(d[i], 0)
    
    KI: u32 = 0
    while KI < k:
        _A: u64 = A + (y * k + KI) * 2
        _B: u64 = B + (KI * n + x) * 2
        
        ptx.wmma.load.a.sync.aligned._global.m16n16k16.row.f16({*a}, [_A], k)
        ptx.wmma.load.b.sync.aligned._global.m16n16k16.row.f16({*b}, [_B], n)
    
        ptx.wmma.mma.sync.aligned.m16n16k16.row.row.f32.f32({*d}, {*a}, {*b}, {*d})
        KI += 16

    _D: u64 = D_ + (y * n + x) * 4
    ptx.wmma.store.d.sync.aligned.m16n16k16._global.row.f32([_D], {*d}, m)


In [54]:
@kernel_transformer.kernel()
def simple_wmma_gemm_f16(A_: "u64", B_: "u64", D_: "u64", m: "u32", n: "u32", k: "u32"):
    A: u64
    ptx.cvta.to._global.u64(A, A_)

    B: u64
    ptx.cvta.to._global.u64(B, B_)
    
    D: u64
    ptx.cvta.to._global.u64(D, D_)
    
    x: u32 = u32("%ctaid.x") * 16
    y: u32 = u32("%ctaid.y") * 16

    a: b32(8)
    b: b32(8)
    d: b32(4)

    for i in range(len(d)):
        ptx.mov.b32(d[i], 0)
    
    KI: u32 = 0
    while KI < k:
        _A: u64 = A + (y * k + KI) * 2
        _B: u64 = B + (KI * n + x) * 2
        
        ptx.wmma.load.a.sync.aligned._global.m16n16k16.row.f16({*a}, [_A], k)
        ptx.wmma.load.b.sync.aligned._global.m16n16k16.row.f16({*b}, [_B], n)
    
        ptx.wmma.mma.sync.aligned.m16n16k16.row.row.f16.f16({*d}, {*a}, {*b}, {*d})
        KI += 16

    _D: u64 = D_ + (y * n + x) * 2
    ptx.wmma.store.d.sync.aligned.m16n16k16._global.row.f16([_D], {*d}, m)


In [105]:
kernel_builder = kernel_transformer.KernelBuilder()
simple_wmma_gemm_f16(kernel_builder)
kernel_code = kernel_builder.generate()

M = 512
N = 512
K = 512

A = torch.randn((M, K), dtype=torch.float16, device="cuda") + 1
B = torch.randn((K, N), dtype=torch.float16, device="cuda") + 1
D = torch.zeros((M, N), dtype=torch.float16, device="cuda")

check = (A @ B)
wrapper = CuModuleWrapper()
wrapper.load_ptx_code(kernel_code)

wrapper.launch_kernel("simple_wmma_gemm_f16", ((M // 16), (N // 16), 1), (32, 1, 1), (A.data_ptr(), B.data_ptr(), D.data_ptr(), M, N, K), 0)
print("Maximum element difference", (D - check).abs().max())
print(D) 
if torch.allclose(D, check, atol=0.124, rtol=0):
    print("MATCH")
else:
    print("We fucked up")

Maximum element difference tensor(2.5000, device='cuda:0', dtype=torch.float16)
tensor([[501.0000, 400.0000, 498.0000,  ..., 496.0000, 516.0000, 407.7500],
        [502.5000, 432.2500, 535.5000,  ..., 504.0000, 545.5000, 527.5000],
        [518.5000, 452.0000, 511.2500,  ..., 520.5000, 558.0000, 465.5000],
        ...,
        [467.2500, 402.0000, 497.2500,  ..., 443.5000, 532.5000, 473.5000],
        [525.0000, 482.2500, 564.5000,  ..., 520.5000, 527.0000, 469.5000],
        [563.5000, 437.5000, 497.2500,  ..., 509.2500, 530.5000, 497.2500]],
       device='cuda:0', dtype=torch.float16)
We fucked up


In [95]:
kernel_builder = kernel_transformer.KernelBuilder()
simple_wmma_gemm(kernel_builder)
kernel_code = kernel_builder.generate()

M = 512
N = 512
K = 512

A = torch.randn((M, K), dtype=torch.float16, device="cuda") + 1
B = torch.randn((K, N), dtype=torch.float16, device="cuda") + 1
D = torch.zeros((M, N), dtype=torch.float32, device="cuda")

check = (A.to(torch.float32) @ B.to(torch.float32)).to(torch.float32)
wrapper = CuModuleWrapper()
wrapper.load_ptx_code(kernel_code)

wrapper.launch_kernel("simple_wmma_gemm", ((M // 16), (N // 16), 1), (32, 1, 1), (A.data_ptr(), B.data_ptr(), D.data_ptr(), M, N, K), 0)
print("Maximum element difference", (D - check).abs().max())

if torch.allclose(D, check, atol=0.124, rtol=0):
    print("MATCH")
else:
    print("We fucked up")

Maximum element difference tensor(0.0026, device='cuda:0')
MATCH
