<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Mish.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [68]:
import torch
import triton
import triton.language as tl

In [69]:
@triton.jit
def tanh(x):
    """Element-wise hyperbolic tangent, written to avoid overflow.

    Uses tanh(|x|) = (1 - e^{-2|x|}) / (1 + e^{-2|x|}) and restores the sign
    afterwards. The exponent is never positive, so the exponential stays in
    [0, 1] and the quotient cannot become inf/inf (NaN) for large |x|, which
    the direct (e^x - e^-x) / (e^x + e^-x) form does.
    """
    decay = tl.exp(-2.0 * tl.abs(x))
    magnitude = (1.0 - decay) / (1.0 + decay)
    # tanh is odd: tanh(-x) == -tanh(x)
    return tl.where(x < 0, -magnitude, magnitude)
@triton.jit
def softplus(x):
  """Element-wise softplus, log(1 + e^x), written to avoid overflow.

  Uses the identity log(1 + e^x) = max(x, 0) + log(1 + e^{-|x|}). The
  exponent is never positive, so the exponential cannot overflow to inf for
  large positive x the way a direct e^x does.
  """
  return tl.maximum(x, 0.0) + tl.log(1 + tl.exp(-tl.abs(x)))

In [70]:
@triton.jit
def forward_kernel(x_ptr,out_ptr,stride_m,stride_n,m,n,BLOCK_SIZE_ROW:tl.constexpr,BLOCK_SIZE_COL:tl.constexpr,num_warps=64):
  """Mish forward pass: out = x * tanh(softplus(x)) on one 2D tile.

  Each program handles a BLOCK_SIZE_ROW x BLOCK_SIZE_COL tile selected by its
  program ids on grid axes 0 (rows) and 1 (columns).

  x_ptr / out_ptr     -- input and output base pointers; both are indexed with
                         the same (stride_m, stride_n), so the output must
                         share the input's layout.
  stride_m / stride_n -- element strides along rows / columns.
  m / n               -- number of valid rows / columns.

  NOTE(review): `num_warps` is declared as an ordinary argument and is never
  read in the body -- confirm whether it was meant as a launch option.
  """
  pid_m = tl.program_id(axis=0)
  pid_n = tl.program_id(axis=1)
  offs_m = pid_m * BLOCK_SIZE_ROW + tl.arange(0,BLOCK_SIZE_ROW)
  offs_n = pid_n * BLOCK_SIZE_COL + tl.arange(0,BLOCK_SIZE_COL)
  # Lanes past the matrix edge (when m or n is not a multiple of the block
  # size) must neither read nor write.
  mask = (offs_m[:,None] < m) & (offs_n[None,:] < n)
  offsets = offs_m[:,None] * stride_m + offs_n[None,:] * stride_n
  x = tl.load(x_ptr + offsets,mask=mask,other=0.0)
  out = x * tanh(softplus(x))
  # The store is masked as well: an unmasked store writes outside the
  # output buffer on edge tiles.
  tl.store(out_ptr + offsets,out,mask=mask)


In [71]:
def _mish_forward(x:torch.tensor):
  """Launch `forward_kernel` over a 2D CUDA tensor and return the result.

  `x` must be a contiguous CUDA tensor with exactly two dimensions (its shape
  is unpacked into rows and columns). The output is a new tensor allocated
  with `torch.empty_like(x)`.
  """
  assert x.is_cuda and x.is_contiguous()
  rows,cols = x.shape
  # Tile shape handled by each program instance.
  tile_rows,tile_cols = 32,64
  result = torch.empty_like(x)
  # One program per tile; cdiv rounds up so partial edge tiles are covered.
  launch_grid = (triton.cdiv(rows,tile_rows),triton.cdiv(cols,tile_cols))
  row_stride,col_stride = x.stride(0),x.stride(1)
  forward_kernel[launch_grid](x,result,row_stride,col_stride,rows,cols,tile_rows,tile_cols)
  return result

In [89]:
@triton.jit
def backward_kernel(x_ptr,din_ptr,dout_ptr,stride_m,stride_n,m,n,
                    BLOCK_SIZE_ROW:tl.constexpr,BLOCK_SIZE_COL:tl.constexpr,num_warps=64):
  """Mish backward pass on one 2D tile: din = dout * d(mish)/dx.

  With t = tanh(softplus(x)), the local derivative is
      d(mish)/dx = t + x * (1 - t^2) * sigmoid(x)
  which is multiplied by the upstream gradient `dout` (chain rule).

  x_ptr    -- forward input.
  din_ptr  -- output: gradient with respect to x.
  dout_ptr -- upstream gradient.
  All three pointers are indexed with the same (stride_m, stride_n), so the
  tensors must share one layout. m / n are the valid row / column counts.

  NOTE(review): `num_warps` is declared as an ordinary argument and is never
  read in the body -- confirm whether it was meant as a launch option.
  """
  pid_m = tl.program_id(axis=0)
  pid_n = tl.program_id(axis=1)
  offs_m = pid_m * BLOCK_SIZE_ROW + tl.arange(0,BLOCK_SIZE_ROW)
  offs_n = pid_n * BLOCK_SIZE_COL + tl.arange(0,BLOCK_SIZE_COL)
  # Lanes past the matrix edge must neither read nor write.
  mask = (offs_m[:,None] < m) & (offs_n[None,:] < n)
  offsets = offs_m[:,None] * stride_m + offs_n[None,:] * stride_n
  x = tl.load(x_ptr + offsets,mask=mask,other=0.0)
  dout = tl.load(dout_ptr + offsets,mask=mask,other=0.0)
  tanh_sp = tanh(softplus(x))
  comp = x * (1 - tanh_sp*tanh_sp) * tl.sigmoid(x)
  # Scale the local derivative by the upstream gradient; leaving out `dout`
  # yields d(mish)/dx rather than the gradient of the loss.
  din = dout * (tanh_sp + comp)
  # Masked store: an unmasked store writes outside the buffer on edge tiles.
  tl.store(din_ptr + offsets,din,mask=mask)

In [90]:
def _mish_backward(x:torch.tensor,dout:torch.tensor):
  """Launch `backward_kernel` and return the gradient with respect to `x`.

  `x` must be a contiguous 2D CUDA tensor; `dout` must be a CUDA tensor and
  is made contiguous before the launch. The kernel receives only `x`'s
  strides, which it also applies to `dout` and to the freshly allocated
  result.
  """
  assert x.is_cuda and x.is_contiguous()
  assert dout.is_cuda
  upstream = dout.contiguous()
  rows,cols = x.shape
  # Tile shape handled by each program instance.
  tile_rows,tile_cols = 32,64
  grad_x = torch.empty_like(x)
  # One program per tile; cdiv rounds up so partial edge tiles are covered.
  launch_grid = (triton.cdiv(rows,tile_rows),triton.cdiv(cols,tile_cols))
  row_stride,col_stride = x.stride(0),x.stride(1)
  backward_kernel[launch_grid](x,grad_x,upstream,row_stride,col_stride,rows,cols,tile_rows,tile_cols)
  return grad_x

In [91]:
class Mish(torch.autograd.Function):
  """Autograd wrapper that routes Mish through the Triton launchers.

  `forward` stores its input on the context and returns `_mish_forward` of
  it; `backward` hands the stored input and the incoming gradient to
  `_mish_backward`.
  """
  @staticmethod
  def forward(ctx,input):
    # The input itself is what backward needs to rebuild the derivative.
    ctx.save_for_backward(input)
    return _mish_forward(input)
  @staticmethod
  def backward(ctx,dout):
    (saved_input,) = ctx.saved_tensors[:1]
    return _mish_backward(saved_input,dout)

In [96]:
# Compute the gradient of a mean loss twice on the same random input: once
# through the Triton-backed Mish, once through torch.nn.functional.mish.
input = torch.rand(128,128,dtype=torch.float32,device='cuda',requires_grad=True)

# Triton path.
out = Mish.apply(input)
loss = out.mean()
loss.backward()
triton_grad = input.grad
# Drop the stored gradient so the reference pass starts from scratch rather
# than accumulating into it.
input.grad = None

# PyTorch reference path.
out_torch = torch.nn.functional.mish(input)
loss1 = out_torch.mean()
loss1.backward()
torch_grad = input.grad

In [97]:
# Compare the two gradients computed in the previous cell, which names them
# `triton_grad` and `torch_grad`.
print(torch.allclose(triton_grad,torch_grad))

True
