<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Flatten.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import triton
import triton.language as tl

In [24]:
@triton.jit
def flatten_kernel(x_ptr,stride_b,stride_m,stride_n,y_ptr,BLOCK_SIZE_COL:tl.constexpr,b,m,n):
  """Copy one column tile of a (b, m, n) tensor into a flat 1-D output.

  Each program handles BLOCK_SIZE_COL consecutive columns of a single
  (batch, row) pair, selected by the three program ids.

  Args:
    x_ptr: pointer to the input tensor.
    stride_b, stride_m, stride_n: element strides of the input along the
      batch, row and column axes.
    y_ptr: pointer to the output buffer holding b * m * n elements.
    BLOCK_SIZE_COL: number of columns processed per program (compile time).
    b, m, n: logical sizes of the input; b is not read by the kernel and is
      kept only so the call signature stays unchanged.
  """
  batch_id = tl.program_id(axis=0)
  row_id = tl.program_id(axis=1)
  col_id = tl.program_id(axis=2)

  # Columns covered by this program; the last tile of a row may run past n.
  col_offs = col_id * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)
  mask = col_offs < n

  # The input is addressed through its own strides.
  x_ptrs = x_ptr + batch_id * stride_b + row_id * stride_m + col_offs * stride_n
  x = tl.load(x_ptrs, mask=mask)

  # The output is a dense row-major buffer, so element (batch, row, col)
  # lives at (batch * m + row) * n + col. Reusing the input strides here
  # would only be correct when the input happens to be contiguous.
  y_ptrs = y_ptr + (batch_id * m + row_id) * n + col_offs
  tl.store(y_ptrs, x, mask=mask)


In [28]:
def flatten(x:torch.Tensor) -> torch.Tensor:
  """Return a new 1-D CUDA tensor holding the elements of x in row-major order.

  The input may have any number of dimensions: the last two axes are used as
  rows and columns and all leading axes are folded into a single batch axis
  before flatten_kernel is launched.

  Args:
    x: tensor on a CUDA device.

  Returns:
    A freshly allocated 1-D tensor with x.numel() elements and the same dtype
    and device as x.

  Raises:
    AssertionError: if x is not a CUDA tensor.
  """
  assert x.is_cuda, "flatten expects a CUDA tensor"
  y = torch.empty(x.numel(), device=x.device, dtype=x.dtype)
  if x.numel() == 0:
    # Nothing to copy; also avoids launching a grid with a zero-sized axis.
    return y

  # contiguous() returns x itself when it is already contiguous, so no
  # separate is_contiguous() check is needed.
  x = x.contiguous()

  # Fold any rank into (batch, rows, cols); missing axes get size 1.
  n = x.shape[-1] if x.dim() >= 1 else 1
  m = x.shape[-2] if x.dim() >= 2 else 1
  x = x.reshape(-1, m, n)
  b = x.shape[0]

  BLOCK_SIZE_COL = 128
  # One program per (batch, row, column tile).
  grid = (b, m, triton.cdiv(n, BLOCK_SIZE_COL))
  flatten_kernel[grid](x, x.stride(0), x.stride(1), x.stride(2), y, BLOCK_SIZE_COL, b, m, n)
  return y

In [32]:
# Random float32 test input of shape (6, 32, 1024) on the GPU.
x = torch.randn(6, 32, 1024, device="cuda", dtype=torch.float32)
# Reference result from PyTorch, then the result from the Triton wrapper.
y_torch = torch.flatten(x)
y = flatten(x)

In [34]:
# Report whether the Triton output agrees with the PyTorch reference.
outputs_match = torch.allclose(y, y_torch)
print(outputs_match)

True
