<a href="https://colab.research.google.com/github/christpaul94/MastersThesis_PaulChrist/blob/main/GPU_MatrixMultiplicationDemo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import numpy as np
import torch
import triton
import triton.language as tl
import time




torch.set_float32_matmul_precision('high')

# ========================================================================
# 1. TRITON MATMUL KERNEL
# ========================================================================
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
# Eigener kernel
@triton.jit
def matmul_kernel(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        offs_k = k + tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = A_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
        a_tile = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
        b_ptrs = B_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
        b_tile = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        accumulator += tl.dot(a_tile, b_tile)
    c_ptrs = C_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    tl.store(c_ptrs, accumulator, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


def matmul(a, b):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

# ========================================================================
# 2. BENCHMARK
# ========================================================================

# --- Parameter ---
M = N = K = 2**12
print(M)
dtype = torch.float32
device_gpu = 'cuda'
device_cpu = 'cpu'

N_REPEATS_GPU = N_REPEATS_CPU =  1

print(f"Benchmark MatMul: M={M}, N={N}, K={K}, dtype={dtype}")
print(f"Wiederholungen: {N_REPEATS_CPU}\n")

# --- Daten erstellen ---
a_gpu = torch.randn(M, K, device=device_gpu, dtype=dtype).contiguous()
b_gpu = torch.randn(K, N, device=device_gpu, dtype=dtype).contiguous()
a_cpu = a_gpu.to(device_cpu)
b_cpu = b_gpu.to(device_cpu)

a_np = a_cpu.numpy()
b_np = b_cpu.numpy()

# print(a_np)
# print(b_np)

# --- CUDA Events f체r pr채zises GPU-Timing ---
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)


# --- 2. Benchmark: NumPy (CPU) ---
c_np = np.zeros((M, N), dtype=a_np.dtype)

# Messung
start_time_np = time.perf_counter()
for _ in range(N_REPEATS_CPU):
    c_np = np.dot(a_np, b_np) # np.dot ist der Standard f체r MatMul
end_time_np = time.perf_counter()
np_time_ms = ((end_time_np - start_time_np) * 1000) / N_REPEATS_CPU
#print(c_np)

# --- 3. Benchmark: torch.matmul (GPU, cuBLAS Referenz) ---
c_torch = torch.empty(M, N, device=device_gpu, dtype=dtype)



# Messung
start_event.record()
for _ in range(N_REPEATS_GPU):
    c_torch = torch.matmul(a_gpu, b_gpu)
end_event.record()
torch.cuda.synchronize()
torch_time_ms = start_event.elapsed_time(end_event) / N_REPEATS_GPU


# --- 4. Benchmark: Manueller Triton-Kernel (GPU) ---
c_triton = torch.empty(M, N, device=device_gpu, dtype=dtype)


for _ in range(10):
    c_triton = matmul(a_gpu, b_gpu)
torch.cuda.synchronize()

# Messung
start_event.record()
for _ in range(N_REPEATS_GPU):
    c_triton = matmul(a_gpu, b_gpu)
end_event.record()
torch.cuda.synchronize()
triton_time_ms = start_event.elapsed_time(end_event) / N_REPEATS_GPU


# --- 5. Korrektheitspr체fung ---
try:
    # Vergleiche Torch-CPU und NumPy
    np.allclose(c_np, c_torch.cpu().numpy(), atol=1e-5, rtol=1e-4)
    print("Korrekte Berechnung PyTorch")

    # Vergleiche Triton mit der Torch-GPU-Referenz
    np.allclose(c_np, c_triton.cpu().numpy(), atol=1e-5, rtol=1e-4)
    print("Korrekte Berechnung Triton")

except Exception as e:
    print(f"  Inkorrekte Berechnung. Fehler: {e}")

# --- 6. Ergebnisse ---
print("\n--- Benchmark-Ergebnisse (MatMul) ---")

print(f"NumPy (CPU):             {np_time_ms:.4f} ms")
print(f"PyTorch (GPU, cuBLAS):   {torch_time_ms:.4f} ms")
print(f"Manueller Triton (GPU):  {triton_time_ms:.4f} ms")


4096
Benchmark MatMul: M=4096, N=4096, K=4096, dtype=torch.float32
Wiederholungen: 1

Korrekte Berechnung PyTorch
Korrekte Berechnung Triton

--- Benchmark-Ergebnisse (MatMul) ---
NumPy (CPU):             1752.0670 ms
PyTorch (GPU, cuBLAS):   47.8194 ms
Manueller Triton (GPU):  38.6560 ms
