In [1]:
import transformer as kernel_transformer

from torch.utils.cpp_extension import load
import torch

from matplotlib import pyplot as plt
# JIT-compile trampoline.cu into a Python extension module via
# torch.utils.cpp_extension.load. The module is used later in the notebook
# for trampoline.CuModuleWrapper().
trampoline = load(
    name="extension",
    sources=["trampoline.cu"],
    # "-g" is forwarded to the CUDA compiler invocation.
    extra_cuda_cflags=["-g"],
    # Link against libnvrtc and libcuda.
    extra_ldflags=["-lnvrtc", "-lcuda"],
)


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [2]:
import subprocess
import tempfile
import sass_graph

def visualize_opal_kernel(kernel, arch="sm_75", name=None):
    """Assemble ``kernel`` with ptxas and display the CFG of its SASS.

    Args:
        kernel: Callable that registers a kernel on a
            ``kernel_transformer.KernelBuilder`` when called with it (the
            same calling convention as ``simple_wmma_gemm(kernel_builder)``).
        arch: Architecture string passed to ``KernelBuilder.generate`` and
            to ``ptxas --gpu-name``.
        name: Key to look up in the result of ``sass_graph.generate_cfgs``.
            Defaults to ``kernel.__name__``. Assumes CFGs are keyed by
            kernel name -- pass ``name`` explicitly if the decorator does
            not preserve ``__name__``.

    Returns:
        Whatever ``sass_graph.display_cfg`` returns for the selected CFG.

    Raises:
        ValueError: if ``name`` is omitted and ``kernel`` has no ``__name__``.
        subprocess.CalledProcessError: if ptxas exits with a non-zero status.

    Side effects:
        Writes the assembled object to ``elf.o`` in the current directory.
    """
    if name is None:
        name = getattr(kernel, "__name__", None)
        if name is None:
            raise ValueError(
                "cannot infer the kernel name; pass name= explicitly"
            )

    kernel_builder = kernel_transformer.KernelBuilder()
    # Build the kernel that was passed in (previously a global was used).
    kernel(kernel_builder)
    kernel_code = kernel_builder.generate(arch)

    # Same location ptxas uses when no output file is given; made explicit so
    # the path handed to the disassembler is not an implicit default.
    elf_path = "elf.o"

    with tempfile.NamedTemporaryFile() as ptx_file:
        ptx_file.write(kernel_code.encode("utf-8"))
        # ptxas opens the file by path, so buffered data must hit the file
        # before it runs; otherwise it may see an empty or partial file.
        ptx_file.flush()
        # check=True: fail loudly instead of disassembling a stale elf.o.
        subprocess.run(
            ["ptxas", "--gpu-name", arch, "-o", elf_path, ptx_file.name],
            check=True,
        )

    cfgs = sass_graph.generate_cfgs(sass_graph.disassemble(elf_path))
    return sass_graph.display_cfg(cfgs[name])

In [3]:
import struct 
def float_to_int32(value):
    """Return the bit pattern of ``value`` as a single-precision float.

    ``value`` is packed with the ``struct`` format ``'f'`` (4 bytes) and the
    same bytes are read back with format ``'I'``, so the result is a
    non-negative Python ``int`` holding the float's raw bits.
    """
    # Round-trip through bytes: float -> raw bytes -> unsigned integer.
    (bits,) = struct.unpack('I', struct.pack('f', value))
    return bits


In [111]:
# Tiled GEMM, D = A @ B, built from warp-level WMMA PTX instructions.
#
# A, B and D are device addresses; m, n, k are the matrix dimensions. The
# address arithmetic below treats A as row-major m x k with 2-byte elements,
# B as row-major k x n with 2-byte elements, and D as row-major m x n with
# 4-byte elements. Each CTA produces one 16x16 tile of D: %ctaid.x selects
# the column tile and %ctaid.y the row tile. There is no bounds handling, so
# m, n and k are expected to be multiples of 16.
#
# Documented with comments rather than a docstring because the body is
# handed to kernel_transformer; whether it accepts a string statement is not
# visible here.
@kernel_transformer.kernel()
def simple_wmma_gemm(A: "u64", B: "u64", D: "u64", m: "u32", n: "u32", k: "u32"):
    # Column (x) and row (y) of the top-left element of this CTA's tile.
    x: u32 = u32("%ctaid.x") * 16
    y: u32 = u32("%ctaid.y") * 16

    # Fragment storage: 8 b32 values each for A, B and the accumulator.
    a: b32(8)
    b: b32(8)
    d: b32(8)

    # Start from a zero accumulator.
    for i in range(8):
        ptx.mov.b32(d[i], 0)
    
    # Walk the shared K dimension one 16-wide slab at a time.
    KI: u32 = 0
    while KI < k:
        # Byte address of A[y, KI] (row pitch k) and B[KI, x] (row pitch n).
        _A: u64 = A + (y * k + KI) * 2
        _B: u64 = B + (KI * n + x) * 2
        
        # The last operand is the row stride in elements.
        ptx.wmma.load.a.sync.aligned._global.m16n16k16.row.f16({*a}, [_A], k)
        ptx.wmma.load.b.sync.aligned._global.m16n16k16.row.f16({*b}, [_B], n)
    
        # d = a @ b + d
        ptx.wmma.mma.sync.aligned.m16n16k16.row.row.f32.f32({*d}, {*a}, {*b}, {*d})
        KI += 16

    # Byte address of D[y, x]; rows of D are n elements apart.
    _D: u64 = D + (y * n + x) * 4
    # The store stride must match D's row pitch (n), consistent with the
    # offset above. Using m here only worked when m == n.
    ptx.wmma.store.d.sync.aligned.m16n16k16._global.row.f32([_D], {*d}, n)


In [149]:
# Build the PTX for simple_wmma_gemm, run it on the GPU and compare the
# result against a PyTorch reference.
kernel_builder = kernel_transformer.KernelBuilder()
simple_wmma_gemm(kernel_builder)
kernel_code = kernel_builder.generate()

M = 32
N = 32
K = 32
TILE = 16  # edge length of the m16n16k16 WMMA tile used by the kernel

# The kernel has no bounds handling, so every dimension must tile exactly.
assert M % TILE == 0 and N % TILE == 0 and K % TILE == 0, \
    "M, N and K must be multiples of 16"

# Alternative A inputs that make layout mistakes easier to spot:
#A = torch.eye(K, dtype=torch.float16, device="cuda")
#A = torch.triu(torch.full((K, K), 3, dtype=torch.float16, device="cuda"))
A = torch.rand((M, K), dtype=torch.float16, device="cuda")
# B is K x N; eye(K) alone is only the right shape when N == K.
B = torch.eye(K, N, dtype=torch.float16, device="cuda")

D = torch.zeros((M, N), dtype=torch.float32, device="cuda")

# Reference in float32, matching the kernel's f32 accumulator and D's dtype.
check = A.float() @ B.float()

wrapper = trampoline.CuModuleWrapper()
wrapper.load_ptx_code(kernel_code)

# The kernel maps %ctaid.x to column tiles (N) and %ctaid.y to row tiles (M).
grid = (N // TILE, M // TILE, 1)
# One warp per CTA: the wmma.* instructions are warp-wide.
block = (32, 1, 1)
wrapper.launch_kernel("simple_wmma_gemm", grid, block, (A.data_ptr(), B.data_ptr(), D.data_ptr(), M, N, K))
# Make sure the kernel has finished before D is read back.
torch.cuda.synchronize()

print(D)
# Max absolute error: a plain sum of differences lets signed errors cancel.
print((D - check).abs().max())
torch.testing.assert_close(D, check, rtol=1e-3, atol=1e-3)

tensor([[0.3228, 0.0168, 0.1490,  ..., 0.5068, 0.2299, 0.5762],
        [0.5312, 0.4934, 0.1993,  ..., 0.3250, 0.9209, 0.8271],
        [0.6807, 0.9517, 0.7168,  ..., 0.2417, 0.2417, 0.2766],
        ...,
        [0.2944, 0.0128, 0.0894,  ..., 0.2976, 0.8716, 0.1316],
        [0.1379, 0.5337, 0.7734,  ..., 0.1713, 0.2522, 0.0244],
        [0.0586, 0.2920, 0.5508,  ..., 0.2754, 0.8848, 0.6543]],
       device='cuda:0')
tensor(0., device='cuda:0')
