In [19]:
import torch
import triton
import triton.language as tl
import matplotlib.pyplot as plt
import cv2
import numpy as np

In [16]:
@triton.jit
def conv2d_kernel(
    input_ptr, output_ptr, filter_ptr, H, W,
    BLOCK_SIZE: tl.constexpr
):
    """Apply a 3x3 filter to an H x W image, one BLOCK_SIZE x BLOCK_SIZE tile per program.

    Each program computes, for every pixel (r, c) of its tile,
    ``out[r, c] = sum_{i,j in -1..1} in[r + i, c + j] * filter[i + 1, j + 1]``
    (the filter is applied as written, without flipping). Neighbours that fall
    outside the image are read as 0.0.

    Args:
        input_ptr: Pointer to the input image, addressed as ``row * W + col``.
        output_ptr: Pointer to the output image, addressed as ``row * W + col``.
        filter_ptr: Pointer to the 9 filter taps, addressed as ``ky * 3 + kx``.
        H: Image height in pixels.
        W: Image width in pixels (also used as the row stride).
        BLOCK_SIZE: Compile-time tile edge length; program_id(0) selects the
            tile row and program_id(1) the tile column.
    """
    tile_rows = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tile_cols = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Broadcast to a full 2-D tile: a column vector of row indices against a
    # row vector of column indices. Combining the two 1-D ranges directly
    # would pair them element-wise and touch only the tile's diagonal.
    row = tile_rows[:, None]
    col = tile_cols[None, :]

    # Tiles on the bottom/right edge may extend past the image.
    mask = (row < H) & (col < W)

    result = tl.zeros([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32)

    for i in range(-1, 2):
        for j in range(-1, 2):
            row_idx = row + i
            col_idx = col + j
            # Zero padding: out-of-image neighbours are masked off and load as 0.0.
            in_bounds = (row_idx >= 0) & (row_idx < H) & (col_idx >= 0) & (col_idx < W)
            input_val = tl.load(input_ptr + row_idx * W + col_idx, mask=in_bounds, other=0.0)

            # Scalar filter tap shared by the whole tile.
            filter_val = tl.load(filter_ptr + (i + 1) * 3 + (j + 1))
            result += input_val * filter_val

    tl.store(output_ptr + row * W + col, result, mask=mask)

In [20]:
def load_image(image_path, target_size=(128, 128)):
    """Load an image file as a resized grayscale float32 array.

    Args:
        image_path: Path of the image file, passed to ``cv2.imread``.
        target_size: Size handed to ``cv2.resize`` as its ``dsize`` argument
            (OpenCV orders this as ``(width, height)``).

    Returns:
        The resized grayscale image as ``np.float32`` with every pixel
        value divided by 255.

    Raises:
        OSError: If ``cv2.imread`` could not read the file (it signals this
            by returning ``None``, e.g. for a missing or undecodable file).
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # Fail here with the offending path instead of letting cv2.resize
        # raise an opaque error on a None input.
        raise OSError(f"Could not read image (missing or undecodable file): {image_path!r}")
    image = cv2.resize(image, target_size)
    return image.astype(np.float32) / 255.0

In [22]:
H, W = 128, 128
BLOCK_SIZE = 16
SEED = 0

# Seed the generator so the random test image (and therefore the plotted
# result) is the same on every Restart & Run All.
torch.manual_seed(SEED)
input_image = torch.rand((H, W), dtype=torch.float32, device='cuda')

# Zero-initialised rather than torch.empty: any pixel the kernel leaves
# unwritten is then a well-defined 0 instead of uninitialised memory.
output_image = torch.zeros((H, W), dtype=torch.float32, device='cuda')

# 3x3 stencil (centre 4, four direct neighbours -1), flattened so the
# kernel can index it as ky * 3 + kx.
conv_filter = torch.tensor([
    [ 0, -1,  0],
    [-1,  4, -1],
    [ 0, -1,  0]
], dtype=torch.float32, device='cuda').flatten()

# One program per BLOCK_SIZE x BLOCK_SIZE tile; ceiling division covers
# image sizes that are not a multiple of BLOCK_SIZE.
grid = ((H + BLOCK_SIZE - 1) // BLOCK_SIZE, (W + BLOCK_SIZE - 1) // BLOCK_SIZE)

conv2d_kernel[grid](input_image, output_image, conv_filter, H, W, BLOCK_SIZE=BLOCK_SIZE)

<triton.compiler.compiler.CompiledKernel at 0x7f28b91b1b10>

In [24]:
# Copy both tensors back to host memory as NumPy arrays for matplotlib.
input_cpu, output_cpu = (t.cpu().numpy() for t in (input_image, output_image))

# Side-by-side comparison: source image on the left, filtered result on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (input_cpu, "Input Image"),
    (output_cpu, "Filtered Image (Edge Detection)"),
)
for ax, (pixels, title) in zip(axes, panels):
    ax.imshow(pixels, cmap='gray')
    ax.set_title(title)
    ax.axis("off")

plt.show()