<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Triton_Conv1d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import triton
import triton.language as tl

In [2]:
# Small reference setup: 4 input channels, 7 filters, kernel width 3, stride 3,
# applied to a random batch of 4 signals with 4 channels and 9 samples each.
layer = nn.Conv1d(in_channels=4, out_channels=7, kernel_size=3, stride=3).to('cuda')
input = torch.rand(4, 4, 9, device='cuda')
out = layer(input)
# Only the last expression of a cell is rendered, so the weight shape below is
# evaluated but not shown; the bias shape is the cell's displayed value.
layer.weight.shape
layer.bias.shape

torch.Size([7])

In [3]:
@triton.jit
def conv1d_kernel(input_ptr,input_batch_stride,
        input_channel_stride,
        input_col_stride,
        width,
        channels,
        kernel_ptr,
        kernel_width,
        kernel_dim_stride,
        kernel_channel_stride,
        kernel_col_stride,
        bias_ptr,
        output_ptr,
        output_width,
        output_batch_stride,
        output_channel_stride,
        output_col_stride,
        BLOCK_SIZE_COL:tl.constexpr,
        BLOCK_SIZE_CHANNELS:tl.constexpr):
  """Compute one output element of a non-overlapping 1D convolution.

  The launch grid is (batch, kernel, output column). Each program loads the
  input window that starts at column ``col_id * kernel_width`` across all
  channels, multiplies it element-wise with the selected kernel, sums the
  products, adds that kernel's bias and stores the single resulting value.

  Args:
    input_ptr: base pointer of the input; addressed with the batch, channel
      and column strides that follow it.
    width: number of input columns; used to bound the window loads.
    channels: number of valid channels; lanes at or beyond it are masked.
    kernel_ptr: base pointer of the kernels; addressed with the dim
      (per-kernel), channel and column strides.
    kernel_width: number of taps per kernel, also used as the step between
      consecutive windows.
    bias_ptr: pointer to one bias value per kernel, indexed by the kernel id.
    output_ptr: base pointer of the output; addressed with the batch, channel
      and column strides that follow it.
    output_width: not referenced by the kernel body.
    BLOCK_SIZE_COL: compile-time tile size along the tap axis.
    BLOCK_SIZE_CHANNELS: compile-time tile size along the channel axis.
  """
  batch_id = tl.program_id(axis=0)
  kernel_id = tl.program_id(axis=1)
  col_id = tl.program_id(axis=2)

  # Channel lanes as a column vector, taps as a row vector, so that the
  # pointer arithmetic below broadcasts to a (channels, taps) tile.
  chan_range = tl.arange(0,BLOCK_SIZE_CHANNELS)
  chan_idx = chan_range[:,None]
  chan_mask = chan_idx < channels
  tap_range = tl.arange(0,BLOCK_SIZE_COL)
  tap_idx = tap_range[None,:]
  tap_mask = tap_idx < kernel_width

  # Columns of the input window handled by this program.
  in_col = col_id * kernel_width + tap_idx
  # Restrict the load to this window (tap_mask) as well as to the tensor
  # bounds; otherwise padded tap lanes would read the next window's data.
  in_mask = chan_mask & tap_mask & (in_col < width)

  input_ptrs = (input_ptr + batch_id * input_batch_stride
                + chan_idx * input_channel_stride
                + in_col * input_col_stride)
  # other=0.0 gives masked lanes a defined value so they add nothing to the sum.
  window = tl.load(input_ptrs,mask=in_mask,other=0.0)

  kernel_ptrs = (kernel_ptr + kernel_id * kernel_dim_stride
                 + chan_idx * kernel_channel_stride
                 + tap_idx * kernel_col_stride)
  weights = tl.load(kernel_ptrs,mask=chan_mask & tap_mask,other=0.0)

  bias = tl.load(bias_ptr+kernel_id)
  elem = tl.sum(window * weights) + bias

  # Honour the output column stride instead of assuming it is 1.
  out_ptrs = (output_ptr + batch_id * output_batch_stride
              + kernel_id * output_channel_stride
              + col_id * output_col_stride)
  tl.store(out_ptrs,elem)

In [4]:
def Triton_Conv1d(input,kernel,bias):
  """Run a non-overlapping 1D convolution (stride equal to the kernel width) with ``conv1d_kernel``.

  Args:
    input: CUDA tensor of shape (batch, channels, width).
    kernel: CUDA tensor of shape (num_kernels, channels, kernel_width).
    bias: CUDA tensor whose first dimension equals ``num_kernels``.

  Returns:
    A tensor of shape (batch, num_kernels, width // kernel_width) on the
    input's device, with the input's dtype.

  Raises:
    AssertionError: if a tensor is not on the GPU, if ``input`` or ``kernel``
      is not 3-dimensional, if the bias length differs from the number of
      kernels, if ``width`` is not a multiple of ``kernel_width``, or if the
      channel counts of ``input`` and ``kernel`` differ.
  """
  assert input.is_cuda and kernel.is_cuda, 'Input or kernel is not on GPU'
  assert bias.is_cuda, 'Bias is not on GPU'
  assert input.ndim == 3, f'Input needs to be 3 dimensional, provided: {input.shape}'
  assert kernel.ndim == 3, f'Kernel needs to be 3 dimensional, provided: {kernel.shape}'
  assert bias.shape[0] == kernel.shape[0], 'Bias dimension should be same as the kernel 1st dimension'
  batch_size,channels,width = input.shape
  num_kernels,kernel_depth,kernel_width = kernel.shape
  assert width % kernel_width == 0,f'invalid compatibility {width} is not multiple of {kernel_width}'
  assert channels == kernel_depth, f"Kernel channel depth ({kernel_depth}) and input channel depth ({channels}) should be same"
  output_width = width // kernel_width
  # Allocate with the input's dtype rather than the global default dtype.
  output = torch.empty((batch_size,num_kernels,output_width),device=input.device,dtype=input.dtype)
  # tl.arange needs power-of-two extents; the kernel masks the padded lanes.
  BLOCK_SIZE_COL = triton.next_power_of_2(kernel_width)
  BLOCK_SIZE_CHANNELS = triton.next_power_of_2(channels)
  # One program per (batch element, kernel, output column).
  grid = (batch_size, num_kernels, output_width)
  conv1d_kernel[grid](
        input_ptr=input,
        input_batch_stride=input.stride(0),
        input_channel_stride=input.stride(1),
        input_col_stride=input.stride(2),
        width=width,
        channels=channels,
        kernel_ptr=kernel,
        kernel_width=kernel_width,
        kernel_dim_stride=kernel.stride(0),
        kernel_channel_stride=kernel.stride(1),
        kernel_col_stride=kernel.stride(2),
        bias_ptr=bias,
        output_ptr=output,
        output_width=output_width,
        output_batch_stride=output.stride(0),
        output_channel_stride=output.stride(1),
        output_col_stride=output.stride(2),
        BLOCK_SIZE_COL=BLOCK_SIZE_COL,
        BLOCK_SIZE_CHANNELS=BLOCK_SIZE_CHANNELS)

  return output

In [5]:
# Comparison setup: 4 input channels, 12 filters, kernel width 16, stride 16,
# on a random batch of 16 signals with 64 samples each.
layer = nn.Conv1d(in_channels=4, out_channels=12, kernel_size=16, stride=16).to('cuda')
input = torch.rand(16, 4, 64, device='cuda')
# Triton result computed from the layer's own weight and bias ...
out_triton = Triton_Conv1d(input, layer.weight, layer.bias)
# ... and the PyTorch reference result for the same input.
out = layer(input)

In [6]:
# Compare the PyTorch reference with the Triton result at a relative tolerance of 1e-4.
print(torch.allclose(out, out_triton, rtol=1e-4))

True
