<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/DyT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [101]:
import torch
import triton
import triton.language as tl
from torch import nn
from torch.nn import functional as F

In [191]:
# Define our custom PyTorch layer
class DyT(nn.Module):
    """Dynamic-tanh layer computing ``weight * tanh(scale * x) + bias``.

    ``scale`` is a fixed Python float (not learned). ``weight`` and ``bias`` are
    learnable one-element float64 parameters, initialised to 1 and 0, that
    broadcast against an input of any shape.
    """

    def __init__(self, scaling_factor: float):
        super().__init__()
        # Fixed multiplier applied to the input before tanh.
        self.scale = scaling_factor
        # One-element float64 tensors; nn.Parameter requires grad by default.
        initial_weight = torch.ones(1, dtype=torch.float64)
        initial_bias = torch.zeros(1, dtype=torch.float64)
        self.weight = nn.Parameter(initial_weight, requires_grad=True)
        self.bias = nn.Parameter(initial_bias, requires_grad=True)

    def forward(self, x):
        # Squash the scaled input, then apply the learned affine map.
        squashed = torch.tanh(x * self.scale)
        return squashed * self.weight + self.bias

In [192]:
@triton.jit
def tanh(x):
    """Elementwise hyperbolic tangent, formulated so it cannot overflow.

    Uses tanh(x) = sign(x) * (1 - e) / (1 + e) with e = exp(-2 * |x|).
    Because e always lies in (0, 1], neither the numerator nor the denominator
    can overflow. The naive (exp(x) - exp(-x)) / (exp(x) + exp(-x)) evaluates
    to inf / inf = NaN once exp(|x|) overflows (|x| above roughly 88 in float32
    or 709 in float64). Inputs of +/-inf map to +/-1, and NaN propagates.
    """
    e = tl.exp(-2.0 * tl.abs(x))
    t = (1.0 - e) / (1.0 + e)  # tanh(|x|), in [0, 1)
    # Restore the sign: tanh is odd.
    return tl.where(x < 0, -t, t)

In [193]:
@triton.jit
def _forward_kernel(x_ptr, normalized_ptr, stride_m, stride_n, scale, weight_ptr,
                    bias_ptr, m, n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``weight * tanh(scale * x) + bias``.

    Grid axis 0 walks row tiles and axis 1 walks column tiles of an (m, n) layout.
    ``x_ptr`` and ``normalized_ptr`` are both addressed with (stride_m, stride_n),
    so the caller must give input and output identical layouts. ``weight_ptr``
    and ``bias_ptr`` are read as scalars (their element 0).

    Launch options such as ``num_warps`` belong at the call site
    (``_forward_kernel[grid](..., num_warps=4)``), not in this signature. As a
    kernel parameter, ``num_warps`` is just an unused runtime argument and does
    not change the warp count.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # Tiles on the bottom/right edges extend past (m, n); mask those lanes out.
    mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)
    # Element offsets shared by the input load and the output store.
    offsets = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # other=0.0 gives masked lanes a defined value; they are never stored.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # Scalar affine parameters.
    weight = tl.load(weight_ptr)
    bias = tl.load(bias_ptr)

    # `tanh` is the overflow-safe helper defined in this notebook.
    out = weight * tanh(scale * x) + bias

    tl.store(normalized_ptr + offsets, out, mask=mask)

In [199]:
def _forward_DyT(x:torch.tensor,scale:float,weight:torch.tensor,bias:torch.tensor):
  assert x.is_cuda and weight.is_cuda, "Inputs must be on CUDA!"
  assert x.is_contiguous(), "Input tensor must be contiguous for Triton."

  m, n = x.shape  # Get input dimensions
  # Create an empty tensor to hold the output
  normalized = torch.empty_like(x, device=x.device, dtype=x.dtype)

  # Define block sizes for the grid (tuned for performance)
  BLOCK_SIZE_M = 32
  BLOCK_SIZE_N = 64
  grid = (triton.cdiv(m, BLOCK_SIZE_M), triton.cdiv(n, BLOCK_SIZE_N))

  # Launch the kernel with our grid
  _forward_kernel[grid](
        x, normalized, x.stride(0), x.stride(1), scale, weight, bias,
        m, n, BLOCK_SIZE_M, BLOCK_SIZE_N)
  return normalized

In [200]:
@triton.jit
def _backward_kernel(x_ptr, dout_ptr, dweight_ptr, dbias_ptr, dx_ptr, stride_m, stride_n,
                     scale, weight_ptr, bias_ptr, m, n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    """Backward pass of ``weight * tanh(scale * x) + bias`` for one tile.

    Writes ``dx = dout * weight * scale * (1 - tanh(scale * x)**2)`` for the tile.
    It also atomically adds the tile's partial sums of ``dout * tanh(scale * x)``
    and ``dout`` into ``dweight_ptr[0]`` and ``dbias_ptr[0]``, so those buffers
    must be zero-initialised before launch.

    ``x``, ``dout`` and ``dx`` are all addressed with (stride_m, stride_n).
    ``bias_ptr`` is accepted for signature symmetry with the forward kernel but
    is not read, because the bias does not appear in any gradient. Launch
    options such as ``num_warps`` belong at the call site, not in this signature.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)
    offsets = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # other=0.0 is required for correctness. The tl.sum calls below reduce the
    # WHOLE tile, including masked-out lanes of edge tiles. With x = 0 and
    # dout = 0 in those lanes, they contribute exactly 0 to both sums instead
    # of unspecified values.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    dout = tl.load(dout_ptr + offsets, mask=mask, other=0.0)

    weight = tl.load(weight_ptr)

    # `tanh` is the overflow-safe helper defined in this notebook.
    tan = tanh(scale * x)

    # d(out)/d(weight) = tanh(scale * x) and d(out)/d(bias) = 1, summed over the tile.
    dweight = tl.sum(dout * tan)
    dbias = tl.sum(dout)

    # d(out)/dx = weight * scale * (1 - tanh^2).
    dx = dout * weight * scale * (1 - tan * tan)

    # Many tiles contribute to the same scalar gradients, so accumulate atomically.
    tl.atomic_add(dweight_ptr, dweight)
    tl.atomic_add(dbias_ptr, dbias)

    tl.store(dx_ptr + offsets, dx, mask=mask)

In [201]:
def _backward_DyT(x:torch.tensor,dout:torch.tensor,scale:float,weight:torch.tensor,bias:torch.tensor):
  assert x.is_cuda and x.is_contiguous()
  assert dout.is_cuda
  dout = dout.is_contiguous() if not dout.is_contiguous() else dout
  dweight = torch.empty(1,device='cuda',dtype=x.dtype)
  dbias = torch.empty(1,device='cuda',dtype=x.dtype)
  dx = torch.empty_like(x,device='cuda',dtype=x.dtype)
  m,n = x.shape
  BLOCK_SIZE_M = 32
  BLOCK_SIZE_N = 64
  grid = (triton.cdiv(m,BLOCK_SIZE_M),triton.cdiv(n,BLOCK_SIZE_N))
  _backward_kernel[grid](x,dout,dweight,dbias,dx,x.stride(0),
                         x.stride(1),scale,weight,bias,m,n,BLOCK_SIZE_M,BLOCK_SIZE_N)
  return (dx,dweight,dbias)

In [202]:
# Run the same DyT computation through PyTorch autograd and through the Triton
# kernels; the next cell compares the two.
# NOTE(review): 512x512 is an exact multiple of the 32x64 tile used by both
# Triton wrappers, so the masked edge-tile path is never exercised here, and
# torch.rand only produces values in [0, 1).
layer = DyT(0.5)
layer = layer.to('cuda')  # nn.Module.to moves parameters and returns the module
input = torch.rand(512, 512, device='cuda', dtype=torch.float64, requires_grad=True)

# Reference: autograd fills input.grad, layer.weight.grad and layer.bias.grad.
out = layer(input)
loss = out.sum()
loss.backward()

# Triton path. Since loss = out.sum(), d(loss)/d(out) is a tensor of ones.
out_triton = _forward_DyT(input, layer.scale, layer.weight, layer.bias)
dout = torch.ones_like(input)
grads = _backward_DyT(input, dout, layer.scale, layer.weight, layer.bias)

In [203]:
# Compare the Triton results with PyTorch autograd.
# The output and dx are elementwise, so allclose's defaults (rtol=1e-5, atol=1e-8) apply.
print("Output match between PyTorch and Triton:", torch.allclose(out, out_triton))

print("Input gradient match:", torch.allclose(input.grad, grads[0]))

# dweight and dbias are float64 sums over every element. The Triton kernel sums
# in a different order (per-tile tl.sum plus atomic_add), which only perturbs the
# result at round-off level, so a tight relative tolerance is used. A loose rtol
# such as 1e-2 would hide real bugs: a 512x512 input split into 32x64 tiles has
# 128 tiles, so losing one tile's contribution moves the sum by only ~1/128 (~0.8%).
GRAD_RTOL = 1e-6
print(f"Weight gradient match (rtol={GRAD_RTOL:g}):", torch.allclose(layer.weight.grad, grads[1], rtol=GRAD_RTOL))
print(f"Bias gradient match (rtol={GRAD_RTOL:g}):", torch.allclose(layer.bias.grad, grads[2], rtol=GRAD_RTOL))

Output match between PyTorch and Triton: True
Input gradient match: True
Weight gradient match (tol=1e-2): True
Bias gradient match (tol=1e-2): True
