<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Tanh_Kernel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import triton.language as tl
import triton
import torch
import torch.nn as nn

In [15]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [16]:
@triton.jit
def tanh(a):
  """Element-wise hyperbolic tangent of `a`.

  Mathematically equal to (exp(a) - exp(-a)) / (exp(a) + exp(-a)), but
  expressed through exp(-2*|a|), which always lies in (0, 1]. The direct
  form overflows exp() for large |a| and produces inf / inf = NaN; this
  form saturates to +/-1 instead.
  """
  # tanh(|a|) = (1 - e^(-2|a|)) / (1 + e^(-2|a|)); the exponent is never positive.
  t = tl.exp(-2.0 * tl.abs(a))
  magnitude = (1.0 - t) / (1.0 + t)
  # tanh is an odd function, so restore the sign of the input.
  return tl.where(a < 0, -magnitude, magnitude)

In [17]:
@triton.jit
def forward_tanh_kernel(a_ptr,out_ptr,n_elements,BLOCK_SIZE:tl.constexpr):
  """Apply `tanh` to the elements at `a_ptr` and store the results at `out_ptr`.

  Each program instance covers BLOCK_SIZE consecutive offsets; offsets at or
  beyond `n_elements` are masked out of both the load and the store.
  """
  # First offset handled by this program instance.
  start = tl.program_id(axis=0) * BLOCK_SIZE
  idx = start + tl.arange(0, BLOCK_SIZE)
  in_bounds = idx < n_elements
  values = tl.load(a_ptr + idx, mask=in_bounds)
  tl.store(out_ptr + idx, tanh(values), mask=in_bounds)

In [18]:
def forward_tanh(a):
  """Compute tanh(a) element-wise with the Triton forward kernel.

  Args:
    a: input tensor. The kernel addresses it as a flat buffer of
      `a.numel()` elements, so a non-contiguous input is first copied
      into contiguous memory.

  Returns:
    A new tensor with the same shape and dtype as `a`, allocated on the
    same device as `a`.
  """
  # The kernel walks memory linearly from the base pointer; a strided view
  # (e.g. x[::2]) would otherwise be read at the wrong addresses.
  a = a.contiguous()
  # Follow the input's device instead of hard-coding 'cuda'.
  out = torch.empty_like(a)
  n_elements = a.numel()
  if n_elements == 0:
    # Nothing to compute; skip the kernel launch.
    return out
  grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
  forward_tanh_kernel[grid](a, out, n_elements, BLOCK_SIZE=1024)
  return out


In [19]:
@triton.jit
def backward_tanh_kernel(a_ptr,grad_ptr,n_elements,BLOCK_SIZE:tl.constexpr):
  """Store 1 - tanh(a)**2 for each element at `a_ptr` into `grad_ptr`.

  Only the local derivative is written; no incoming gradient is multiplied
  in. Each program instance covers BLOCK_SIZE consecutive offsets, and
  offsets at or beyond `n_elements` are masked out of the load and the store.
  """
  # First offset handled by this program instance.
  start = tl.program_id(axis=0) * BLOCK_SIZE
  idx = start + tl.arange(0, BLOCK_SIZE)
  in_bounds = idx < n_elements
  t = tanh(tl.load(a_ptr + idx, mask=in_bounds))
  tl.store(grad_ptr + idx, 1.0 - (t * t), mask=in_bounds)

In [20]:
def backward_tanh(a):
  """Compute the element-wise derivative of tanh at `a` with the Triton kernel.

  The result is 1 - tanh(a)**2; it is not multiplied by any upstream
  gradient.

  Args:
    a: input tensor. The kernel addresses it as a flat buffer of
      `a.numel()` elements, so a non-contiguous input is first copied
      into contiguous memory.

  Returns:
    A new tensor with the same shape and dtype as `a`, allocated on the
    same device as `a`.
  """
  # The kernel walks memory linearly from the base pointer; a strided view
  # (e.g. x[::2]) would otherwise be read at the wrong addresses.
  a = a.contiguous()
  # Follow the input's device instead of hard-coding 'cuda'.
  grad = torch.empty_like(a)
  n_elements = a.numel()
  if n_elements == 0:
    # Nothing to compute; skip the kernel launch.
    return grad
  grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
  backward_tanh_kernel[grid](a, grad, n_elements, BLOCK_SIZE=1024)
  return grad

In [21]:
# Small random CUDA input that tracks gradients; run both Triton wrappers on it.
a = torch.randn(1, 5, requires_grad=True, device='cuda')
b, c = forward_tanh(a), backward_tanh(a)

In [22]:
import torch.nn as nn
# Reference path: PyTorch's Tanh module, then backpropagate the scalar sum
# of its output so that a.grad is populated.
m = nn.Tanh()
b_pytorch = m(a)
d = torch.sum(b_pytorch)
d.backward()


In [23]:
# Compare the Triton results against the PyTorch reference values.
for reference, triton_result in ((b_pytorch, b), (a.grad, c)):
  print(torch.allclose(reference, triton_result))

True
True
