<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/diffusion_backward_process.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import torch
import triton
import triton.language as tl

In [11]:
@triton.jit
def denoising_kernel(xt_ptr,noise_ptr,z_ptr,out_ptr,stride_b,stride_c,stride_m,
                     btt_ptr,coeff_ptr,at_ptr,BLOCK_SIZE_ROW:tl.constexpr,BLOCK_SIZE_COL:tl.constexpr,
                     b,c,h,w):
  """Compute one (BLOCK_SIZE_ROW x BLOCK_SIZE_COL) tile of

      out = rsqrt(at) * (xt - coeff * noise) + sqrt(btt) * z

  Launch layout expected by this kernel:
    * program axis 0 -> batch index
    * program axis 1 -> channel index
    * program axis 2 -> flattened tile index; it is decoded below as
      ``height_id * cdiv(w, BLOCK_SIZE_COL) + width_id``, so the launcher must
      provide ``cdiv(h, BLOCK_SIZE_ROW) * cdiv(w, BLOCK_SIZE_COL)`` programs
      on this axis for the whole image to be covered.

  Args:
    xt_ptr, noise_ptr, z_ptr, out_ptr: tensors addressed with the same
      (stride_b, stride_c, stride_m) strides; the last dimension is addressed
      with an implicit stride of 1.
    stride_b, stride_c, stride_m: element strides of the batch, channel and
      row dimensions.
    btt_ptr, coeff_ptr, at_ptr: per-sample scalars, read at offset
      ``batch_id`` (i.e. one contiguous element per batch sample).
    BLOCK_SIZE_ROW, BLOCK_SIZE_COL: compile-time tile shape.
    b, c: batch and channel counts (not used in the body, kept for the
      existing call signature).
    h, w: image height and width, used for bounds masking.
  """
  batch_id = tl.program_id(axis=0)
  channel_id = tl.program_id(axis=1)
  tile_id = tl.program_id(axis=2)

  # Decode the flattened tile index into (row-tile, col-tile).
  tiles_per_row = tl.cdiv(w,BLOCK_SIZE_COL)
  width_id = tile_id % tiles_per_row
  height_id = tile_id // tiles_per_row

  offs_row = height_id * BLOCK_SIZE_ROW + tl.arange(0,BLOCK_SIZE_ROW)
  offs_col = width_id * BLOCK_SIZE_COL + tl.arange(0,BLOCK_SIZE_COL)
  # Tiles on the bottom/right edge may hang over the image; mask them off.
  mask = (offs_row[:,None] < h) & (offs_col[None,:] < w)

  # All four tensors share one layout, so the element offsets are computed once.
  offs = batch_id * stride_b + channel_id * stride_c + offs_row[:,None] * stride_m + offs_col[None,:]

  xt = tl.load(xt_ptr + offs,mask=mask)
  noise = tl.load(noise_ptr + offs,mask=mask)
  z = tl.load(z_ptr + offs,mask=mask)

  # The coefficients are per-sample: index them by batch instead of always
  # reading element 0, so every sample uses its own schedule values.
  btt = tl.load(btt_ptr + batch_id)
  coeff = tl.load(coeff_ptr + batch_id)
  at = tl.load(at_ptr + batch_id)

  mu = tl.rsqrt(at) * (xt - coeff * noise)
  out = mu + tl.sqrt(btt) * z
  tl.store(out_ptr + offs,out,mask=mask)

In [23]:
def denoising(xt:torch.Tensor,noise:torch.Tensor,t:int):
  """Run one reverse-diffusion (denoising) step on ``xt`` at timestep ``t``.

  Uses a fixed 1000-step linear beta schedule (2e-4 .. 0.02) and computes

      out = (xt - beta_t / sqrt(1 - alpha_bar_t) * noise) / sqrt(alpha_t)
            + sqrt(beta_tilde_t) * z,        z ~ N(0, I)

  where ``beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)``
  and ``beta_tilde_0`` is defined as ``beta_0``.

  Args:
    xt: contiguous CUDA tensor of shape (b, c, h, w).
    noise: contiguous CUDA tensor with the same shape as ``xt``.
    t: timestep index into the schedule, shared by every sample in the batch.

  Returns:
    A new tensor with the shape, dtype and device of ``xt``.
  """
  assert xt.is_cuda and xt.is_contiguous()
  assert noise.is_cuda and noise.is_contiguous()
  # The kernel addresses ``noise`` with the strides of ``xt``.
  assert noise.shape == xt.shape
  b,c,h,w = xt.shape

  steps = torch.full((b,),t,device=xt.device,dtype=torch.long)
  # Build the schedule on the device of the input rather than a hard-coded
  # 'cuda', so inputs living on a non-default GPU work as well.
  beta = torch.linspace(2e-4,0.02,1000,device=xt.device)
  alpha = 1 - beta
  alpha_bar = torch.cumprod(alpha,dim=0)

  # Posterior variance for the WHOLE schedule. ``roll(1)`` wraps the last
  # alpha_bar into position 0, so timestep 0 is patched explicitly here,
  # on the timestep axis, *before* gathering the per-sample values.
  beta_tilde = beta * ((1 - alpha_bar.roll(1)) / (1 - alpha_bar))
  beta_tilde[0] = beta[0]

  # Per-sample coefficients, one element per batch entry.
  beta_tilde_t = beta_tilde[steps].view(-1,1,1,1)
  alpha_bar_t = alpha_bar[steps].view(-1,1,1,1)
  alpha_t = alpha[steps].view(-1,1,1,1)
  coeff = beta[steps].view(-1,1,1,1) * torch.rsqrt(1-alpha_bar_t)

  z = torch.randn_like(xt)
  out = torch.empty_like(xt)

  BLOCK_SIZE_ROW = 32
  BLOCK_SIZE_COL = 32
  # Triton launch grids have at most three axes, and the kernel decodes
  # axis 2 as a flattened (row-tile, col-tile) index. Fold both tile counts
  # into that axis; a fourth grid entry would leave most of ``out`` unwritten.
  num_tiles = triton.cdiv(h,BLOCK_SIZE_ROW) * triton.cdiv(w,BLOCK_SIZE_COL)
  grid = (b,c,num_tiles)
  denoising_kernel[grid](xt,noise,z,out,xt.stride(0),xt.stride(1),xt.stride(2),
                         beta_tilde_t,coeff,alpha_t,BLOCK_SIZE_ROW,BLOCK_SIZE_COL,
                         b,c,h,w)
  return out


In [24]:
# Demo cell: denoise a random batch of four 3-channel 128x128 images at
# timestep 500, using freshly sampled Gaussian noise as the noise argument.
input = torch.randn(4, 3, 128, 128, device='cuda')
noise = torch.randn_like(input)
out = denoising(xt=input, noise=noise, t=500)