<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Optimized_Conv2d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl


In [74]:

@triton.jit
def Conv2d_kernel(
    image_ptr,  # Pointer to the input image tensor
    kernel_ptr,  # Pointer to the kernel (weight) tensor
    bias_ptr,  # Pointer to the bias tensor
    output_ptr,  # Pointer to the output tensor
    stride_i0,  # Stride for batch dimension in the image
    stride_i1,  # Stride for channel dimension in the image
    stride_i2,  # Stride for height dimension in the image
    stride_i3,  # Stride for width dimension in the image
    stride_k0,  # Stride for output channel dimension in the kernel
    stride_k1,  # Stride for input channel dimension in the kernel
    stride_k2,  # Stride for height dimension in the kernel
    stride_k3,  # Stride for width dimension in the kernel
    stride_o0,  # Stride for batch dimension in the output
    stride_o1,  # Stride for channel dimension in the output
    stride_o2,  # Stride for height dimension in the output
    stride_o3,  # Stride for width dimension in the output
    bs,  # Batch size
    c,  # Number of input channels
    h,  # Height of the input image
    w,  # Width of the input image
    n_k,  # Number of output channels (number of kernels)
    k_h,  # Height of the kernel
    k_w,  # Width of the kernel
    BLOCK_SIZE_ROW:tl.constexpr,  # Rows in the tile each program loads
    BLOCK_SIZE_COL:tl.constexpr,  # Columns in the tile each program loads
    num_warps=1):  # NOTE(review): unused in the body; confirm it is meant to be a kernel argument
  """Accumulate one input channel's contribution to one output element.

  Program ids:
    axis 0 -> batch index.
    axis 1 -> combined index, split into output channel (pid // c) and
              input channel (pid % c).
    axis 2 -> flattened output position, split into (row, col) using
              w // k_w positions per row.

  Each program multiplies a BLOCK_SIZE_ROW x BLOCK_SIZE_COL image tile by the
  matching weight tile, reduces the product to a scalar and atomically adds
  it into the output. The programs for the different input channels all add
  into the same output element, so the output buffer must be zero-filled
  before the launch.

  Assumes the tile size equals the kernel window (BLOCK_SIZE_ROW == k_h,
  BLOCK_SIZE_COL == k_w) and that windows do not overlap, because tile
  origins advance by BLOCK_SIZE_* -- TODO confirm with the launcher.
  """
  pid_b = tl.program_id(axis=0)
  pid_kc = tl.program_id(axis=1)
  pid_hw = tl.program_id(axis=2)
  # Decode the combined (output channel, input channel) index.
  pid_k = pid_kc // c
  pid_c = pid_kc % c
  # Decode the flattened output position.
  n_pid_w = w // k_w
  pid_h = pid_hw // n_pid_w
  pid_w = pid_hw % n_pid_w

  tile_rows = tl.arange(0, BLOCK_SIZE_ROW)
  tile_cols = tl.arange(0, BLOCK_SIZE_COL)
  row_offs = pid_h * BLOCK_SIZE_ROW + tile_rows
  col_offs = pid_w * BLOCK_SIZE_COL + tile_cols
  mask = (row_offs[:, None] < h) & (col_offs[None, :] < w)

  input_ptrs = (image_ptr + pid_b * stride_i0 + pid_c * stride_i1
                + row_offs[:, None] * stride_i2 + col_offs[None, :] * stride_i3)
  # other=0.0 keeps masked-off lanes from contributing undefined values to the sum.
  patch = tl.load(input_ptrs, mask=mask, other=0.0)

  # The weight tile is selected by output channel (pid_k) AND input channel (pid_c).
  kernel_ptrs = (kernel_ptr + pid_k * stride_k0 + pid_c * stride_k1
                 + tile_rows[:, None] * stride_k2 + tile_cols[None, :] * stride_k3)
  weights = tl.load(kernel_ptrs)

  # Only the program handling input channel 0 contributes the bias; otherwise
  # the bias would be accumulated once per input channel.
  bias = tl.load(bias_ptr + pid_k, mask=pid_c == 0, other=0.0)
  elem = tl.sum(patch * weights) + bias

  output_ptrs = (output_ptr + pid_b * stride_o0 + pid_k * stride_o1
                 + pid_h * stride_o2 + pid_w * stride_o3)
  tl.atomic_add(output_ptrs, elem)

In [75]:
def Conv2d(image:torch.tensor,kernel:torch.tensor,bias:torch.tensor):
  """Launch ``Conv2d_kernel`` for a non-overlapping 2-D convolution.

  The kernel window is moved by its own size, so the output has shape
  (bs, n_k, h // k_h, w // k_w).

  Args:
    image: input of shape (bs, c, h, w); must be a contiguous CUDA tensor.
    kernel: weights of shape (n_k, c, k_h, k_w); must be a CUDA tensor.
    bias: one value per output channel (n_k elements); must be a CUDA tensor.

  Returns:
    A new tensor with the convolution result, on image's device and with
    image's dtype.

  Raises:
    AssertionError: if the shapes are incompatible, the kernel dimensions are
      not powers of two, or an input is not a (contiguous) CUDA tensor.
  """
  bs, c, h, w = image.shape
  n_k, kernel_c, k_h, k_w = kernel.shape
  assert kernel_c == c, "kernel and image must have the same number of input channels"
  assert h % k_h == 0, "image height must be divisible by the kernel height"
  assert w % k_w == 0, "image width must be divisible by the kernel width"
  # The tile handled by one program is exactly one kernel window, and
  # tl.arange only accepts power-of-two extents.
  assert k_h & (k_h - 1) == 0 and k_w & (k_w - 1) == 0, \
      "kernel height and width must be powers of two"
  assert bias.numel() == n_k, "bias must hold one value per output channel"
  assert image.is_cuda and kernel.is_cuda and bias.is_cuda, "all inputs must be CUDA tensors"
  assert image.is_contiguous(), "image must be contiguous"
  # NOTE(review): kept from the original; the reason for this restriction is not evident here.
  assert image.numel() % 16 == 0

  # Zero-filled because the kernel accumulates into the output with atomic adds.
  output = torch.zeros((bs, n_k, h // k_h, w // k_w), device=image.device, dtype=image.dtype)

  # One program per (batch, output channel x input channel, output position).
  grid = (bs, n_k * c, (h // k_h) * (w // k_w))
  Conv2d_kernel[grid](
      image, kernel, bias, output,
      image.stride(0), image.stride(1), image.stride(2), image.stride(3),
      kernel.stride(0), kernel.stride(1), kernel.stride(2), kernel.stride(3),
      output.stride(0), output.stride(1), output.stride(2), output.stride(3),
      bs, c, h, w, n_k, k_h, k_w,
      BLOCK_SIZE_ROW=k_h, BLOCK_SIZE_COL=k_w)
  return output

In [96]:
# Random float32 input batch and a reference PyTorch layer whose stride equals its kernel size.
image = torch.randn(4, 8, 32, 32, dtype=torch.float32, device='cuda')
convlayer = torch.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(8, 8), stride=(8, 8)).to('cuda')

In [97]:
# Reference result from the PyTorch layer.
out = convlayer(image)
print(out.size())

torch.Size([4, 16, 4, 4])


In [98]:
# Run the Triton implementation with the PyTorch layer's own parameters.
out_triton = Conv2d(image=image, kernel=convlayer.weight, bias=convlayer.bias)

In [99]:
# The two results are float32 sums accumulated in different orders (the Triton
# path uses atomic adds), so compare with an absolute tolerance instead of the
# very strict defaults (rtol=1e-5, atol=1e-8).
print(torch.allclose(out, out_triton, rtol=0, atol=1e-2))

True


In [104]:
def benchmark(fn,*args,warmup=10,steps=100):
  """Return the average time of ``fn(*args)`` in milliseconds, measured with CUDA events.

  Args:
    fn: callable to time.
    *args: positional arguments forwarded to ``fn`` on every call.
    warmup: number of untimed calls made before the measurement.
    steps: number of timed calls; the total elapsed time is divided by this.
  """
  tic, toc = (torch.cuda.Event(enable_timing=True) for _ in range(2))
  # Untimed calls first, then wait for them to finish before starting the clock.
  for _ in range(warmup):
    fn(*args)
  torch.cuda.synchronize()
  tic.record()
  for _ in range(steps):
    fn(*args)
  toc.record()
  # Wait until the end event has actually been reached before reading the timer.
  torch.cuda.synchronize()
  total_ms = tic.elapsed_time(toc)
  return total_ms/steps
# Time both implementations on the same input. The layer is passed directly
# so that neither side pays for an extra Python wrapper call.
triton_time = benchmark(Conv2d, image, convlayer.weight, convlayer.bias)
torch_time = benchmark(convlayer, image)

In [105]:
# Report the average per-call time of each implementation.
print(
    f'time required for the conv2d by triton {triton_time:.4f} ms',
    f'time required for the conv2d by torch {torch_time:.4f} ms',
    sep='\n',
)

time required for the conv2d by triton 0.0889 ms
time required for the conv2d by torch 0.2944 ms
