<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/outer_product.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl

In [105]:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 32}, num_warps=1),
        triton.Config({'BLOCK_SIZE': 64}, num_warps=2),
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=8),
    ],
    key=['m', 'n'],
)
@triton.jit
def outer_product_kernel(a_ptr,b_ptr,out_ptr,stride_m,stride_n,m,n,BLOCK_SIZE:tl.constexpr):
  """Write out[i, j] = a[i] * b[j] for a BLOCK_SIZE-row band of the output.

  Each program (axis 0) owns rows [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)
  and walks the columns in BLOCK_SIZE-wide tiles.

  a_ptr, b_ptr: input vectors, indexed with unit stride (element i at ptr + i).
  out_ptr:      output matrix, addressed as out_ptr + i * stride_m + j * stride_n.
  m, n:         number of valid rows / columns; everything beyond is masked off.
  BLOCK_SIZE:   compile-time tile edge chosen by the autotuner.
  """
  pid = tl.program_id(axis=0)
  row_offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  row_mask = row_offs < m
  a = tl.load(a_ptr + row_offs, mask=row_mask)

  # Widen the row term to 64 bits so row * stride_m cannot wrap for very
  # large outputs (more than 2**31 elements).
  out_row_ptrs = out_ptr + row_offs[:, None].to(tl.int64) * stride_m

  col_base = tl.arange(0, BLOCK_SIZE)
  # A single tiled loop covers both the n <= BLOCK_SIZE case (one iteration)
  # and the general case; n == 0 performs no iterations at all.
  for step in range(tl.cdiv(n, BLOCK_SIZE)):
    col_offs = step * BLOCK_SIZE + col_base
    col_mask = col_offs < n
    b = tl.load(b_ptr + col_offs, mask=col_mask)
    out = a[:, None] * b[None, :]
    # Cast to the output buffer's element type instead of a hard-coded
    # float32, which would lose precision for float64 / integer outputs.
    out = out.to(out_ptr.dtype.element_ty)
    out_ptrs = out_row_ptrs + col_offs[None, :] * stride_n
    # The store must be masked: lanes past m or n would otherwise write
    # outside the output allocation (or into the wrong row).
    tl.store(out_ptrs, out, mask=row_mask[:, None] & col_mask[None, :])

In [106]:
def outer_product(a,b):
  """Compute the outer product of two CUDA vectors with the Triton kernel.

  Args:
    a: CUDA tensor holding m elements, with a.shape[0] == a.numel().
    b: CUDA tensor holding n elements, with b.shape[0] == b.numel(),
       on the same device as `a`.

  Returns:
    A new (m, n) tensor on a.device with dtype a.dtype where
    out[i, j] = a[i] * b[j].

  Raises:
    AssertionError: if an input is not on CUDA, the inputs live on different
      devices, or an input has more elements than its leading dimension.
  """
  assert a.is_cuda and b.is_cuda, "both inputs must be CUDA tensors"
  assert a.device == b.device, "inputs must be on the same device"
  m = a.shape[0]
  n = b.shape[0]
  # Validate shapes before allocating anything.
  assert m == a.numel(), "a must be a vector"
  assert n == b.numel(), "b must be a vector"

  # The kernel indexes both inputs with unit stride, so strided views
  # (e.g. x[::2]) have to be packed first. No copy is made when the tensor
  # is already contiguous.
  a = a.contiguous().view(-1)
  b = b.contiguous().view(-1)

  out = torch.empty((m,n),device=a.device,dtype=a.dtype)
  if out.numel() == 0:
    # Nothing to compute; skip autotuning / launching an empty grid.
    return out

  grid = lambda meta: (triton.cdiv(m,meta['BLOCK_SIZE']),)
  # Launch on the tensors' device rather than whatever device is current.
  with torch.cuda.device(a.device):
    outer_product_kernel[grid](
        a,b,out,out.stride(0),out.stride(1),m,n)
  return out


In [100]:
# Smoke run: push two random length-4096 CUDA vectors through the Triton wrapper.
a, b = (torch.randn(4096, device='cuda') for _ in range(2))
c = outer_product(a, b)


In [101]:
# PyTorch reference result for the same pair of vectors.
d = a.outer(b)

In [120]:


# Timing function using CUDA events
def benchmark(fn, *args, warmup=10, reps=100):
    """Return the average wall time of `fn(*args)` in milliseconds.

    Timing uses a pair of CUDA events recorded around `reps` back-to-back
    calls, so it measures time on the current CUDA stream rather than host
    time.

    Args:
        fn: callable to time.
        *args: positional arguments forwarded to `fn` on every call.
        warmup: number of untimed calls made first (lets one-off costs such
            as compilation or autotuning happen outside the timed region).
        reps: number of timed calls; must be at least 1.

    Returns:
        Elapsed milliseconds between the two events divided by `reps`.

    Raises:
        ValueError: if `reps` is less than 1.
    """
    # Fail fast, before any GPU work, instead of dividing by zero at the end.
    if reps < 1:
        raise ValueError(f"reps must be a positive integer, got {reps}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warm-up for accurate timing (important for GPU kernels)
    for _ in range(warmup):
        fn(*args)

    # Drain pending work so the warm-up does not leak into the timed region.
    torch.cuda.synchronize()
    start.record()
    for _ in range(reps):
        fn(*args)
    end.record()
    # elapsed_time() is only valid once the end event has completed.
    torch.cuda.synchronize()

    return start.elapsed_time(end) / reps  # Average time in milliseconds

# Problem size and random CUDA operands shared by both implementations
m, n = 4096, 4096
a = torch.randn(m, device='cuda')
b = torch.randn(n, device='cuda')

# Average per-call latency of the Triton wrapper
triton_time = benchmark(outer_product, a, b)
print(f"Triton kernel average time: {triton_time:.3f} ms")

# Average per-call latency of torch.outer (passed directly, no wrapper lambda)
torch_time = benchmark(torch.outer, a, b)
print(f"PyTorch outer product average time: {torch_time:.3f} ms")

# Cross-check the two results element-wise before declaring success
triton_out = outer_product(a, b)
torch_out = torch.outer(a, b)
outputs_agree = torch.allclose(triton_out, torch_out, atol=1e-5)
assert outputs_agree, "Results do not match!"

print("Triton and PyTorch outputs match!")


Triton kernel average time: 0.329 ms
PyTorch outer product average time: 0.342 ms
Triton and PyTorch outputs match!
