<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Transpose.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl

In [2]:
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256}, num_warps=4, num_stages=2),
        # NOTE(review): a 256x256 tile is 65536 elements per program; check for
        # register spilling on the target GPU before relying on this config.
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_warps=2, num_stages=1),
    ],
    key=['m', 'n'],
)
@triton.jit
def transpose_kernel(A_ptr, out_ptr, stride_am, stride_an, stride_om, stride_on, m, n,
                     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    """Copy one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of A, transposed, into out.

    Program (pid_m, pid_n) reads A[rows offs_m, cols offs_n] and writes it to
    out[rows offs_n, cols offs_m], so `out` is expected to have shape (n, m).

    Args:
        A_ptr: pointer to the (m, n) input.
        out_ptr: pointer to the (n, m) output.
        stride_am, stride_an: element strides of A along its rows / columns.
        stride_om, stride_on: element strides of out along its rows / columns.
        m, n: logical extent of A.
        BLOCK_SIZE_M, BLOCK_SIZE_N: tile size, chosen by the autotuner.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # Bounds mask for the input tile, shape (BLOCK_SIZE_M, BLOCK_SIZE_N).
    in_mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)
    tile = tl.load(A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an,
                   mask=in_mask)

    # Output tile has shape (BLOCK_SIZE_N, BLOCK_SIZE_M): rows come from offs_n,
    # columns from offs_m.
    out_ptrs = out_ptr + offs_n[:, None] * stride_om + offs_m[None, :] * stride_on
    # The store must be masked as well. Without it, edge tiles (when m or n is not
    # a multiple of the block size) write past the end of `out` and also store the
    # undefined values that the masked load left in out-of-range lanes.
    out_mask = (offs_n[:, None] < n) & (offs_m[None, :] < m)
    tl.store(out_ptrs, tile.T, mask=out_mask)

In [3]:
def transpose(A: torch.Tensor) -> torch.Tensor:
    """Return a new contiguous tensor holding the transpose of the 2-D CUDA tensor A.

    Unlike ``A.T`` (a strided view of the same storage), the result is written
    by ``transpose_kernel`` into freshly allocated memory of shape ``(n, m)``.

    Args:
        A: 2-D tensor on a CUDA device. Arbitrary strides are accepted because
            the kernel is given ``A.stride(0)`` and ``A.stride(1)`` explicitly.

    Returns:
        Tensor of shape ``(n, m)`` with ``A``'s dtype and device.

    Raises:
        ValueError: if ``A`` is not a CUDA tensor or is not 2-D.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not A.is_cuda:
        raise ValueError("transpose expects a CUDA tensor")
    if A.dim() != 2:
        raise ValueError(f"transpose expects a 2-D tensor, got shape {tuple(A.shape)}")

    m, n = A.shape
    out = torch.empty((n, m), device=A.device, dtype=A.dtype)
    if out.numel() == 0:
        # Nothing to copy; avoid autotuning/launching a zero-sized grid.
        return out

    # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of A; axis 0 walks rows of A.
    grid = lambda meta: (triton.cdiv(m, meta['BLOCK_SIZE_M']), triton.cdiv(n, meta['BLOCK_SIZE_N']))
    transpose_kernel[grid](A, out, A.stride(0), A.stride(1), out.stride(0), out.stride(1), m, n)
    return out

In [4]:
# Random 1024x1024 float32 input on the default CUDA device.
A = torch.randn(1024, 1024, device='cuda')

In [8]:
# Reference: PyTorch's transpose is a strided view of A (no data is copied).
A_tr = torch.transpose(A, 0, 1)
# Triton result: a freshly materialised copy.
B = transpose(A)

In [10]:
# A transpose only moves data, so the Triton result must match the reference
# bit-for-bit. torch.equal checks exact equality (and shape); allclose would
# tolerate small numerical differences that should never occur here.
print(f'torch transpose is equal to triton tranaspose {torch.equal(A_tr,B)} \n')
print(f'triton transpose results a contiguous tensor {B.is_contiguous()}\n')
print(f'torch transpose results a contiguous tensor {A_tr.is_contiguous()}')

torch transpose is equal to triton tranaspose True 

triton transpose results a contiguous tensor True

torch transpose results a contiguous tensor False
