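Hi! Make your own copy of this Colab notebook (File → Save a copy in Drive) so you can run and edit it yourself. The cells below allocate tensors on `cuda`, so select a GPU runtime before running them.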

# Part 1: ReLU

In [2]:
!pip install triton

Collecting triton
  Downloading triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.7 kB)
Downloading triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (188.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.3/188.3 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.6.0


In [2]:
import triton
import triton.language as tl
import torch

In [3]:
@triton.jit
def relu_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
  """Element-wise ReLU over a flat buffer: out[i] = max(in[i], 0).

  Each program instance handles one run of BLOCK_SIZE consecutive elements.

  Args:
    in_ptr: pointer to the first input element.
    out_ptr: pointer to the first output element.
    n_elements: number of valid elements in both buffers.
    BLOCK_SIZE: compile-time number of elements handled per program.
  """
  # Which block of the 1-D launch grid this program is responsible for.
  pid = tl.program_id(axis=0)
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)

  # The last block may extend past the end of the buffers when n_elements is
  # not a multiple of BLOCK_SIZE; mask those lanes out so they are neither
  # read nor written.
  mask = offsets < n_elements
  in_tile = tl.load(in_ptr + offsets, mask=mask)

  tl.store(out_ptr + offsets, tl.maximum(in_tile, 0), mask=mask)

In [4]:
def relu(x: torch.Tensor) -> torch.Tensor:
  """Apply ReLU element-wise to `x` using the Triton `relu_kernel`.

  Args:
    x: input tensor of any shape; the kernel is launched on it directly, so
      it is expected to live on a CUDA device.

  Returns:
    A new tensor with the same shape and dtype as `x` holding max(x, 0).
  """
  # The kernel walks memory as one flat array of numel() elements, so input
  # and output must share a dense row-major layout. contiguous() is a no-op
  # for tensors that already are, and copies strided views (e.g. x[:, ::2]).
  x = x.contiguous()
  output = torch.empty_like(x)
  n_elements = x.numel()
  if n_elements == 0:
    # Nothing to compute; avoid launching an empty grid.
    return output

  BLOCK_SIZE = 256
  # One program per BLOCK_SIZE elements, rounding up to cover the tail.
  grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

  relu_kernel[grid](x, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
  return output

In [5]:
# Compare the Triton ReLU with PyTorch's reference on a random CUDA vector.
test_input = torch.randn(256 * 100, device="cuda")
expected_output = torch.relu(test_input)
your_output = relu(test_input)

# Report success, or dump both tensors for inspection on mismatch.
outputs_match = torch.allclose(your_output, expected_output)
if outputs_match:
    print("Yay!")
else:
    print(your_output, expected_output)

Yay!


# Part 2: MatMul

In [15]:
@triton.jit
def matmul_kernel(A, B, C, M, N, K, MT: tl.constexpr, NT: tl.constexpr, KT: tl.constexpr):
  """Tiled matrix multiply C = A @ B; each program computes one (MT, NT) tile of C.

  Addresses are computed as row * row_length + col, i.e. A is indexed as a
  row-major (M, K) buffer, B as (K, N) and C as (M, N).

  Args:
    A, B, C: pointers to the first element of each matrix.
    M, N, K: matrix dimensions.
    MT, NT, KT: compile-time tile sizes along M, N and K.
  """
  c_m, c_n = tl.program_id(0), tl.program_id(1)
  offset_m = c_m * MT + tl.arange(0, MT)
  offset_n = c_n * NT + tl.arange(0, NT)
  offset_k = tl.arange(0, KT)

  # Edge tiles can hang over the matrix when M or N is not a multiple of the
  # tile size; these masks keep loads and stores inside the real bounds.
  row_mask = offset_m < M
  col_mask = offset_n < N

  acc = tl.zeros((MT, NT), dtype=C.dtype.element_ty)

  for k in range(0, tl.cdiv(K, KT)):
    k_step = k * KT + offset_k
    # The final K tile may be ragged as well.
    k_mask = k_step < K

    a_ptrs = A + (offset_m[:, None] * K + k_step[None, :])
    b_ptrs = B + (k_step[:, None] * N + offset_n[None, :])
    # Out-of-range lanes are loaded as 0 so they add nothing to the dot product.
    a_tile = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)
    b_tile = tl.load(b_ptrs, mask=k_mask[:, None] & col_mask[None, :], other=0.0)
    acc = tl.dot(a_tile, b_tile, acc)

  c_ptrs = C + (offset_m[:, None] * N + offset_n[None, :])
  tl.store(c_ptrs, acc, mask=row_mask[:, None] & col_mask[None, :])

def matmul(a: torch.Tensor, b: torch.Tensor):
    """Compute `a @ b` for two 2-D tensors using the Triton `matmul_kernel`.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N).

    Returns:
        A new (M, N) tensor on `a`'s device with `a`'s dtype.

    Raises:
        ValueError: if either operand is not 2-D or the inner dimensions differ.
    """
    if a.dim() != 2 or b.dim() != 2:
        raise ValueError(
            f"matmul expects 2-D tensors, got shapes {tuple(a.shape)} and {tuple(b.shape)}"
        )
    M, K = a.shape
    K_b, N = b.shape
    if K != K_b:
        raise ValueError(
            f"inner dimensions must match, got {tuple(a.shape)} and {tuple(b.shape)}"
        )

    # The kernel addresses elements as row * row_length + col, so both
    # operands must be row-major contiguous; this is a no-op when they are.
    a = a.contiguous()
    b = b.contiguous()
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    if M == 0 or N == 0:
        # Empty result; avoid launching an empty grid.
        return c

    MT, NT, KT = 128, 128, 32

    # One program per (MT, NT) output tile, rounding up at the edges.
    grid = (triton.cdiv(M, MT), triton.cdiv(N, NT))

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        MT=MT, NT=NT, KT=KT
    )
    return c

In [16]:
# Square 512x512 operands on the GPU.
M = N = K = 512
a = torch.randn((M, K), device='cuda')
b = torch.randn((K, N), device='cuda')

# Run the Triton matmul and PyTorch's reference on the same inputs.
your_output = matmul(a, b)
expected_output = torch.matmul(a, b)

# Report success, or dump both tensors for inspection on mismatch.
results_agree = torch.allclose(your_output, expected_output)
if results_agree:
    print("Yay!")
else:
    print(your_output, expected_output)

Yay!


# Part 3: MNIST Inference

Now let's use our kernels to run a real neural network! We have a pre-trained 4-layer MLP:
- **Layer 1:** 784 → 256 (input: flattened 28×28 image)
- **Layer 2:** 256 → 128
- **Layer 3:** 128 → 64
- **Layer 4:** 64 → 10 (output: 10 digit classes)

In [None]:
import urllib.request
import os
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Download pretrained weights from GitHub
WEIGHTS_URL = "https://raw.githubusercontent.com/kartva/gpu_workshop/main/learning/mnist_mlp_weights.pt"
WEIGHTS_PATH = "mnist_mlp_weights.pt"

if not os.path.exists(WEIGHTS_PATH):
    print("Downloading weights...")
    # Fetch into a temporary sibling file and rename it only once the download
    # has finished, so an interrupted transfer never leaves a truncated file
    # that the existence check above would treat as a valid cache next run.
    partial_path = WEIGHTS_PATH + ".part"
    urllib.request.urlretrieve(WEIGHTS_URL, partial_path)
    os.replace(partial_path, WEIGHTS_PATH)
    print("Done!")

# weights_only=True restricts unpickling to tensors and plain containers.
weights = torch.load(WEIGHTS_PATH, map_location="cuda", weights_only=True)

# List every entry in the loaded weights with its shape.
print("Network Architecture:")
for name, tensor in weights.items():
    print(f"  {name:15} → {tuple(tensor.shape)}")

In [None]:
def forward(x: torch.Tensor, weights: dict) -> torch.Tensor:
    """Run the 4-layer MLP on `x` using the Triton `matmul` and `relu`.

    Args:
        x: input tensor; it is moved to CUDA and cast to float32, and a 1-D
            input is given a leading dimension of size 1.
        weights: mapping with "fcN.weight" / "fcN.bias" entries for N = 1..4.

    Returns:
        The output of the fourth linear layer (no activation applied).
    """
    activations = x.cuda().float()
    if activations.dim() == 1:
        activations = activations.unsqueeze(0)

    # (weight key, bias key) for each linear layer, in execution order.
    layer_keys = (
        ("fc1.weight", "fc1.bias"),
        ("fc2.weight", "fc2.bias"),
        ("fc3.weight", "fc3.bias"),
        ("fc4.weight", "fc4.bias"),
    )
    final_index = len(layer_keys) - 1

    for index, (weight_key, bias_key) in enumerate(layer_keys):
        # Transpose each stored weight so it can be the right-hand matmul operand.
        weight_t = weights[weight_key].T.contiguous()
        bias = weights[bias_key]
        activations = matmul(activations, weight_t) + bias
        # Every layer except the last is followed by ReLU.
        if index != final_index:
            activations = relu(activations)

    return activations

In [None]:
# Preprocessing: convert to a tensor, then normalise with mean 0.1307 / std 0.3081.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Load the MNIST test split, downloading it into ./data if needed.
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

print(f"Loaded {len(test_dataset)} test images")

In [None]:
import random

# Draw one test example uniformly at random
idx = random.randrange(len(test_dataset))
image, true_label = test_dataset[idx]

# Flatten to a single 784-element row and classify it
flat_image = image.view(1, 784)
logits = forward(flat_image, weights)
predicted_label = logits.argmax(dim=1).item()
probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

# Highlight the predicted class in green; every other bar is blue
colors = []
for digit in range(10):
    colors.append("green" if digit == predicted_label else "steelblue")

# Left panel: the input digit. Right panel: per-class probabilities.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

ax1.imshow(image.squeeze().numpy(), cmap="gray")
ax1.set_title(f"True Label: {true_label}", fontsize=14)
ax1.axis("off")

ax2.barh(range(10), probs, color=colors)
ax2.set_yticks(range(10))
ax2.set_xlabel("Probability")
ax2.set_title(f"Predicted: {predicted_label}", fontsize=14)
ax2.set_xlim(0, 1)

plt.tight_layout()
plt.show()

# One-line verdict under the figure
verdict = (
    f"Correct! Predicted {predicted_label}"
    if predicted_label == true_label
    else f"Wrong! Predicted {predicted_label}, actual {true_label}"
)
print(verdict)