<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Triton_Conv2d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl
from typing import Tuple

# Module-level default element type: 32-bit floating point.
dtype = torch.float32
# Module-level default device string: the first CUDA GPU.
device = 'cuda:0'


In [6]:
@triton.jit
def conv2d_kernel(
    input_ptr,input_batch_stride,input_channel_stride,input_row_stride,
    input_col_stride, height, width,channels, kernel_ptr, kernel_height, kernel_width, kernel_dim_stride,
    kernel_channel_stride, kernel_row_stride, kernel_col_stride, bias_ptr,output_ptr, output_width, output_batch_stride,
    output_channel_stride, output_row_stride,output_col_stride, BLOCK_SIZE_ROW: tl.constexpr,  BLOCK_SIZE_COL: tl.constexpr,num_stages: tl.constexpr):
  """Non-overlapping 2D convolution: the window advances by the filter size.

  Each program instance handles one (batch, filter, output row) triple taken
  from program ids 0, 1 and 2. For every output column it reads the input
  window starting at (row_id * kernel_height, col_id * kernel_width),
  multiplies it element-wise with the filter for each input channel, sums
  everything, adds the filter's bias and stores a single output value.

  Args:
    input_ptr, input_*_stride: input base pointer and its per-dimension strides
      (batch, channel, row, column), in elements.
    height, width, channels: input spatial size and number of input channels.
    kernel_ptr, kernel_*_stride: filter base pointer and its per-dimension
      strides (filter, channel, row, column), in elements.
    kernel_height, kernel_width: spatial size of one filter.
    bias_ptr: pointer to one bias value per filter.
    output_ptr, output_*_stride: output base pointer and its per-dimension
      strides (batch, channel, row, column), in elements.
    output_width: number of output columns each program writes.
    BLOCK_SIZE_ROW, BLOCK_SIZE_COL: compile-time tile shape; must be at least
      kernel_height / kernel_width (lanes beyond the filter are masked off).
    num_stages: compile-time value forwarded to the channel loop.
  """
  batch_id = tl.program_id(axis=0)
  kernel_id = tl.program_id(axis=1)
  row_id = tl.program_id(axis=2)
  # bias holds one entry per filter, so the filter index addresses it directly
  bias = tl.load(bias_ptr + kernel_id)
  # start of this batch element inside the input
  in_batch_offs = batch_id * input_batch_stride
  # start of this (batch, filter, row) line inside the output
  out_offs = batch_id * output_batch_stride + kernel_id * output_channel_stride + row_id * output_row_stride
  kernel_row_offs = tl.arange(0,BLOCK_SIZE_ROW)
  kernel_col_offs = tl.arange(0,BLOCK_SIZE_COL)
  # the tile is padded up to BLOCK_SIZE_*; only lanes inside the real filter are valid
  kernel_mask = (kernel_row_offs[:,None] < kernel_height) & (kernel_col_offs[None,:] < kernel_width)
  # offsets of every filter element relative to the start of one filter channel
  kernel_offs = kernel_row_offs[:,None] * kernel_row_stride + kernel_col_offs[None,:] * kernel_col_stride
  # this program's window starts at input row row_id * kernel_height
  input_row_offs = row_id * kernel_height + tl.arange(0,BLOCK_SIZE_ROW)
  input_row_mask = input_row_offs[:,None] < height
  input_row_offs = input_row_offs[:,None] * input_row_stride
  # walk the output columns; each one consumes its own input window
  for col_id in range(output_width):
    elem = 0.0
    # the window for this output column starts at input column col_id * kernel_width
    input_col_offs = col_id * kernel_width + tl.arange(0,BLOCK_SIZE_COL)
    input_col_mask = input_col_offs[None,:] < width
    input_col_offs = input_col_offs[None,:] * input_col_stride
    # restrict the input load to the filter window as well as the image bounds,
    # otherwise padded lanes would read pixels that belong to a neighbouring window
    window_mask = kernel_mask & input_row_mask & input_col_mask
    # accumulate over every input channel
    for c in range(channels,num_stages=num_stages):
      input_ptrs = input_ptr + in_batch_offs + c * input_channel_stride + input_row_offs + input_col_offs
      # other=0.0 gives masked lanes a defined value so they add nothing to the sum
      input_block = tl.load(input_ptrs,mask=window_mask,other=0.0)
      # each filter has one slice per input channel
      kernel_ptrs = kernel_ptr + kernel_id * kernel_dim_stride + c * kernel_channel_stride + kernel_offs
      kernel_block = tl.load(kernel_ptrs,mask=kernel_mask,other=0.0)
      # element-wise product reduced to a scalar
      elem += tl.sum(input_block * kernel_block).to(dtype=tl.float32)

    # honour the output column stride instead of assuming it is 1
    out_ptrs = output_ptr + out_offs + col_id * output_col_stride
    tl.store(out_ptrs,elem+bias)


In [7]:
def conv2d_triton(
    input: torch.Tensor,
    kernel: torch.Tensor,
    bias: torch.Tensor
) -> torch.Tensor:
    """Run a non-overlapping 2D convolution (stride == filter size, no padding) with Triton.

    Args:
        input: CUDA tensor of shape (batch, channels, height, width).
        kernel: CUDA tensor of shape (num_kernels, channels, kernel_height, kernel_width).
        bias: CUDA tensor whose first dimension equals ``num_kernels``.

    Returns:
        Tensor of shape (batch, num_kernels, height // kernel_height,
        width // kernel_width), allocated on ``input``'s device with the
        module-level ``dtype``.

    Raises:
        AssertionError: if a tensor is not on the GPU, the tensors live on
            different devices, the ranks are wrong, the channel counts differ,
            or the spatial size is not a multiple of the filter size.
    """
    assert input.is_cuda and kernel.is_cuda, 'Input or kernel is not on GPU'
    assert bias.is_cuda, 'Bias is not on GPU'
    assert input.device == kernel.device == bias.device, 'Input, kernel and bias should be on the same device'
    assert len(input.shape) == 4, f'Input needs to be 4 dimensional, provided: {input.shape}'
    assert len(kernel.shape) == 4, f'Kernel size needs to be 4 dimensional, provided: {kernel.shape}'
    assert bias.shape[0] == kernel.shape[0], f'Bias dimension should be same as the kernel 1st dimension'

    batch_size, channels, height, width = input.shape
    # num_kernels == number of output channels, kernel_depth == number of input channels
    num_kernels, kernel_depth, kernel_height, kernel_width = kernel.shape
    # windows do not overlap, so the image must tile exactly into filter-sized patches
    assert height%kernel_height == 0 and width%kernel_width == 0, f"Input height and width should be divisible by the kernel height and width"
    assert channels == kernel_depth, f"Kernel channel depth ({kernel_depth}) and input channel depth ({channels}) should be same"

    out_height = height // kernel_height
    out_width = width // kernel_width
    # allocate next to the input rather than on a hard-coded device
    output = torch.empty((batch_size, num_kernels, out_height, out_width), device=input.device, dtype=dtype)
    # tl.arange needs power-of-two extents, so round the filter size up
    BLOCK_SIZE_ROW = triton.next_power_of_2(kernel_height)
    BLOCK_SIZE_COL = triton.next_power_of_2(kernel_width)
    # one program per (batch element, filter, output row)
    grid = (batch_size, num_kernels, out_height)

    # launch on the device that actually holds the tensors
    with torch.cuda.device(input.device):
        conv2d_kernel[grid](
            input_ptr=input,
            input_batch_stride=input.stride(0),
            input_channel_stride=input.stride(1),
            input_row_stride=input.stride(2),
            input_col_stride=input.stride(3),
            height=height,
            width=width,
            channels=channels,
            kernel_ptr=kernel,
            kernel_height=kernel_height,
            kernel_width=kernel_width,
            kernel_dim_stride=kernel.stride(0),
            kernel_channel_stride=kernel.stride(1),
            kernel_row_stride=kernel.stride(2),
            kernel_col_stride=kernel.stride(3),
            bias_ptr=bias,
            output_ptr=output,
            output_width=out_width,
            output_batch_stride=output.stride(0),
            output_channel_stride=output.stride(1),
            output_row_stride=output.stride(2),
            output_col_stride=output.stride(3),
            BLOCK_SIZE_ROW=BLOCK_SIZE_ROW,
            BLOCK_SIZE_COL=BLOCK_SIZE_COL,
            num_stages=4,)

    return output


In [16]:
# Reference layer: 3 input channels, 8 filters of 4x4, stride 4 (non-overlapping windows).
layer = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=(4, 4), stride=4).to('cuda')
# Single random image whose 40x40 spatial size is a multiple of the filter size.
input = torch.rand(1, 3, 40, 40, device='cuda')
# PyTorch result and the Triton result computed from the same weights and bias.
out = layer(input)
out_triton = conv2d_triton(input, layer.weight, layer.bias)

In [32]:
# Third positional argument of torch.allclose is the relative tolerance.
print(torch.allclose(out, out_triton, rtol=1e-4))

True
