In [43]:
import numpy as np
import torch
import triton
import triton.language as tl
import time

In [44]:
N = 2 ** 8  # matrix dimension: every matrix in this notebook is N x N (256), stored flat row-major

In [45]:
@triton.jit
def init(x_ptr, stride_x, stride_y, N):
    """Fill a flat, row-major N x N buffer with ``buf[r * N + c] = r * c``.

    Each program handles a 32 x 32 tile whose origin is
    ``(program_id(0) * 32, program_id(1) * 32)``; it then revisits the tile
    shifted by multiples of ``stride_x`` rows / ``stride_y`` columns until it
    passes N. Out-of-range elements are masked off.

    NOTE(review): tiles are only all visited if the launched grid covers
    ``stride_x`` x ``stride_y`` elements or N itself; the caller in this
    notebook passes 32*256 with a grid of cdiv(N, 32) programs per axis.
    """
    row_offs = tl.program_id(axis=0) * 32 + tl.arange(0, 32)
    col_offs = tl.program_id(axis=1) * 32 + tl.arange(0, 32)

    for row_start in range(0, N, stride_x):
        # Column vector of absolute row indices for this tile, shape (32, 1).
        rows = row_start + row_offs[:, None]
        for col_start in range(0, N, stride_y):
            # Row vector of absolute column indices, shape (1, 32).
            cols = col_start + col_offs[None, :]
            in_bounds = (rows < N) & (cols < N)
            tl.store(x_ptr + rows * N + cols, rows * cols, mask=in_bounds)

In [46]:
@triton.jit
def matmul_gpu(a_ptr, b_ptr, c_ptr, N):
    """Compute ``C = A @ B`` for flat, row-major N x N matrices.

    Each program computes 32 x 32 output tiles starting at
    ``(program_id(0) * 32, program_id(1) * 32)`` and steps by the extent the
    whole launch grid covers, so every tile is produced exactly once for any
    grid size. For each tile, it accumulates the outer product of a column of
    A and a row of B over all ``k``.

    Args:
        a_ptr, b_ptr: inputs A and B, N*N elements each.
        c_ptr: output C, N*N elements; the accumulator uses C's element type.
        N: matrix dimension.
    """
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)

    x = pid_x * 32 + tl.arange(0, 32)
    y = pid_y * 32 + tl.arange(0, 32)

    # Stride by the span actually covered by the launched grid. A hard-coded
    # 32*256 assumed exactly 256 programs per axis: with a smaller grid (and
    # N larger than it covers) some tiles were never computed, and with
    # N > 8192 tiles were recomputed by several programs.
    stride_i = tl.num_programs(axis=0) * 32
    stride_j = tl.num_programs(axis=1) * 32

    for i in range(0, N, stride_i):
        for j in range(0, N, stride_j):
            i_idx = i + x[:, None]  # (32, 1) absolute row indices
            j_idx = j + y[None, :]  # (1, 32) absolute column indices

            mask = (i_idx < N) & (j_idx < N)
            # Accumulate in C's element type (int32 for this notebook) rather
            # than a hard-coded int32, so float/int64 buffers also work.
            val = tl.zeros((32, 32), dtype=c_ptr.dtype.element_ty)

            for k in range(0, N):
                a_val = tl.load(a_ptr + i_idx * N + k, mask=i_idx < N, other=0)
                b_val = tl.load(b_ptr + k * N + j_idx, mask=j_idx < N, other=0)
                # Cast keeps the loop-carried accumulator type stable when
                # the input dtype differs from C's; a no-op for int32 inputs.
                val += (a_val * b_val).to(c_ptr.dtype.element_ty)

            tl.store(c_ptr + i_idx * N + j_idx, val, mask=mask)

In [47]:
def matmul_cpu(a_cpu, b_cpu, c_cpu, n=None):
    """Pure-Python reference: fill A and B with ``M[i, j] = i * j``, then set ``C = A @ B``.

    All three buffers are flat, row-major n x n arrays and are modified in
    place (A and B are (re)initialized here, so their fill time is part of any
    timing taken around this call).

    Args:
        a_cpu, b_cpu, c_cpu: indexable buffers with at least ``n * n`` elements.
        n: matrix dimension; defaults to the module-level ``N`` so existing
            calls keep working.

    NOTE(review): with np.int32 buffers and n = 256, single products such as
    65025 * 65025 exceed the int32 range, so NumPy scalar arithmetic wraps
    around (and emits overflow RuntimeWarnings).
    """
    if n is None:
        n = N

    for i in range(n):
        for j in range(n):
            a_cpu[i * n + j] = i * j
            b_cpu[i * n + j] = i * j

    for i in range(n):
        for j in range(n):
            val = 0
            for k in range(n):
                val += a_cpu[i * n + k] * b_cpu[k * n + j]
            c_cpu[i * n + j] = val

In [48]:
def main():
    """Benchmark the Triton, pure-Python and NumPy matmuls and check GPU vs CPU.

    Builds N x N matrices with ``M[i, j] = i * j`` on the GPU (via ``init``)
    and on the CPU (inside ``matmul_cpu``), then times each multiplication.
    Finally, it compares the GPU result element-wise with the pure-Python
    reference and prints the first mismatch in row-major order, or "Success".

    NOTE(review): all buffers are int32; for N = 256 the element products
    exceed the int32 range, so both results wrap around. "Success" therefore
    means the two implementations agree modulo 2**32, not that the values are
    mathematically exact. Switching to int64 needs an int64-accumulating kernel.
    """
    a = torch.zeros(N * N, dtype=torch.int32, device='cuda')
    b = torch.zeros(N * N, dtype=torch.int32, device='cuda')
    c = torch.zeros(N * N, dtype=torch.int32, device='cuda')

    grid = (triton.cdiv(N, 32), triton.cdiv(N, 32))
    init[grid](a, 32 * 256, 32 * 256, N)
    init[grid](b, 32 * 256, 32 * 256, N)

    # Warm-up launch: the first call of a @triton.jit kernel compiles it (or
    # loads it from Triton's cache). Without this, the timed run below would
    # measure that one-off cost instead of the kernel.
    matmul_gpu[grid](a, b, c, N)
    torch.cuda.synchronize()

    # perf_counter is the monotonic, high-resolution clock meant for intervals.
    start = time.perf_counter()
    matmul_gpu[grid](a, b, c, N)
    torch.cuda.synchronize()  # kernel launches are async; wait before stopping the clock
    gpu_time = time.perf_counter() - start

    a_cpu = np.zeros(N * N, dtype=np.int32)
    b_cpu = np.zeros(N * N, dtype=np.int32)
    c_cpu = np.zeros(N * N, dtype=np.int32)

    start = time.perf_counter()
    matmul_cpu(a_cpu, b_cpu, c_cpu)  # also fills a_cpu / b_cpu; that time is included
    cpu_time = time.perf_counter() - start

    a_np = a_cpu.reshape(N, N)
    b_np = b_cpu.reshape(N, N)
    start = time.perf_counter()
    np.matmul(a_np, b_np)  # timed only; the result is not checked
    numpy_time = time.perf_counter() - start

    print(f"GPU matmul time: {gpu_time:.4f} seconds")
    print(f"CPU loop time: {cpu_time:.4f} seconds")
    print(f"NumPy matmul time: {numpy_time:.4f} seconds")

    c_gpu = c.cpu().numpy()
    # Flat indices of all mismatches, in row-major order; report only the first.
    mismatches = np.flatnonzero(c_gpu != c_cpu)
    if mismatches.size:
        i, j = divmod(int(mismatches[0]), N)
        print(f"Error in c[{i}][{j}]: CPU={c_cpu[i * N + j]}, GPU={c_gpu[i * N + j]}")
    else:
        print("Success")

In [50]:
# Run the benchmark and correctness check; needs a CUDA device (main allocates its tensors on 'cuda').
main()

  val += a_cpu[i * N + k] * b_cpu[k * N + j]
  val += a_cpu[i * N + k] * b_cpu[k * N + j]


GPU matmul time: 0.0006 seconds
CPU loop time: 5.6488 seconds
NumPy matmul time: 0.0093 seconds
Success
