In [17]:
import time
import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(A, B, C, M, N, K, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    col = tl.program_id(1)

    a = tl.load(A + row * K + tl.arange(0, BLOCK_SIZE))  # Load A tile
    b = tl.load(B + tl.arange(0, BLOCK_SIZE) * N + col)  # Load B tile

    c = tl.sum(a * b, axis=0)  # Compute partial sum
    tl.store(C + row * N + col, c)  # Store result

def matmul_triton(A, B):
    M, K = A.shape
    K, N = B.shape
    C = torch.empty((M, N), device=A.device, dtype=A.dtype)
    grid = (M, N)
    matmul_kernel[grid](A, B, C, M, N, K, BLOCK_SIZE=16)
    return C

D = 1024*2**5

A = torch.randn(D, D, device="cuda", dtype=torch.float32)
B = torch.randn(D, D, device="cuda", dtype=torch.float32)

start = time.time()
C = matmul_triton(A, B)
print(C[0,0])
print(f"time: {(time.time() - start)*1000:0.4f}")

torch.set_float32_matmul_precision('high')
@torch.compile
def torch_matmul(A, B):
  return A @ B

start = time.time()
C = torch_matmul(A, B)
print(C[0,0])
print(f"time: {(time.time() - start)*1000:0.4f}")

tensor(1.9950, device='cuda:0')
time: 843.2209
tensor(126.4425, device='cuda:0')
time: 626.4899
