<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/cosine_similarity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl

In [1]:
import torch
import triton
import triton.language as tl

In [6]:
@triton.jit
def cos_sim_kernel(x_ptr,y_ptr,cosine_ptr,n_elements,BLOCK_SIZE:tl.constexpr):
  """Accumulate one block's partial dot product and squared norms.

  Program `pid` handles elements [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE) of
  the flat inputs and atomically adds its partial sums into a 3-slot buffer:
    cosine_ptr[0] += sum(x * y)
    cosine_ptr[1] += sum(x * x)
    cosine_ptr[2] += sum(y * y)
  The caller must zero the buffer before launch and combine the slots after.

  Note: the warp count is a launch option (`kernel[grid](..., num_warps=N)`),
  not a kernel parameter, so it is intentionally absent from this signature.
  """
  pid = tl.program_id(axis=0)
  offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  # Lanes past the end load 0.0, which contributes nothing to any of the sums.
  mask = offs < n_elements
  x = tl.load(x_ptr + offs, mask=mask, other=0.0)
  y = tl.load(y_ptr + offs, mask=mask, other=0.0)
  dot = tl.sum(x * y, axis=0)
  sq_norm_x = tl.sum(x * x, axis=0)
  sq_norm_y = tl.sum(y * y, axis=0)
  # Every program adds into the same three scalars, so the adds must be atomic.
  tl.atomic_add(cosine_ptr, dot)
  tl.atomic_add(cosine_ptr + 1, sq_norm_x)
  tl.atomic_add(cosine_ptr + 2, sq_norm_y)


In [7]:
def cosine_similarity(x:torch.Tensor,y:torch.Tensor,eps:float=1e-8):
  """Cosine similarity of two CUDA tensors, computed with the Triton kernel.

  Both tensors are treated as flat vectors of ``numel()`` elements, which is
  why they must be contiguous.

  Args:
    x, y: contiguous CUDA tensors on the same device with the same number of
      elements.
    eps: lower bound applied to each norm (as torch.nn.functional does), so
      zero vectors yield 0 instead of NaN.

  Returns:
    0-d tensor of dtype ``x.dtype``:
    dot(x, y) / (max(||x||, eps) * max(||y||, eps)).

  Raises:
    AssertionError: if the inputs are not contiguous, same-device CUDA tensors
      with equal element counts.
  """
  assert x.is_cuda and y.is_cuda
  assert x.device == y.device
  assert x.is_contiguous() and y.is_contiguous()
  # The kernel reads n_elements values from BOTH buffers, so the element counts
  # must agree (len() would only compare the first dimension).
  n_elements = x.numel()
  assert n_elements == y.numel()
  # Accumulate in fp32 (fp64 for fp64 inputs): atomically summing many partial
  # results in fp16/bf16 overflows or loses precision.
  acc_dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
  cosine = torch.zeros(3,device=x.device,dtype=acc_dtype)
  if n_elements > 0:  # skip launching an empty grid
    block_size = 128 if n_elements > 128 else 32
    grid = (triton.cdiv(n_elements,block_size),)
    cos_sim_kernel[grid](x,y,cosine,n_elements,block_size)
  dot, sq_norm_x, sq_norm_y = cosine[0], cosine[1], cosine[2]
  denom = sq_norm_x.sqrt().clamp_min(eps) * sq_norm_y.sqrt().clamp_min(eps)
  return (dot / denom).to(x.dtype)

In [25]:
# Compare the Triton implementation against PyTorch's reference on fixed inputs.
torch.manual_seed(0)  # reproducible inputs across Restart & Run All
a = torch.randn(4096,device='cuda')
b = torch.randn(4096,device='cuda')
cosine = cosine_similarity(a,b)
cosine1 = torch.nn.functional.cosine_similarity(a,b,dim=0)
# The kernel's atomic accumulation order is non-deterministic, so use an explicit
# absolute tolerance rather than allclose's default atol=1e-8.
matches = torch.allclose(cosine,cosine1,rtol=1e-5,atol=1e-6)
assert matches, f"triton={cosine.item()} torch={cosine1.item()}"
print(matches)

True


In [26]:
import torch
import triton
from prettytable import PrettyTable

def torch_cosine_similarity(x: torch.Tensor, y: torch.Tensor):
    """PyTorch baseline: built-in cosine similarity taken along dimension 0."""
    reference_fn = torch.nn.functional.cosine_similarity
    return reference_fn(x, y, dim=0)

# Benchmark function
def benchmark(fn, *args, warmup=8, steps=128):
    """Mean time per call of ``fn(*args)`` in milliseconds, timed with CUDA events.

    ``fn`` is first called ``warmup`` times and the device is synchronized, so
    one-time costs (presumably e.g. Triton JIT compilation) are excluded. Then
    ``steps`` consecutive calls are bracketed by two timing events. After a final
    synchronize, the elapsed time is divided by ``steps``.
    """
    def _run(times):
        for _ in range(times):
            fn(*args)

    t_begin, t_end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

    _run(warmup)
    torch.cuda.synchronize()

    t_begin.record()
    _run(steps)
    t_end.record()
    torch.cuda.synchronize()

    total_ms = t_begin.elapsed_time(t_end)
    return total_ms / steps

# Vector lengths to benchmark.
# NOTE(review): the last entry evaluates to 1,278,976; it may have been meant as
# 4096 * 256 (1,048,576) like the other power-of-two sizes — confirm.
sizes = [1024, 4096, 16384, 65536, 4996 * 256]

def _random_vectors(lengths):
    """One contiguous float32 CUDA tensor of standard-normal samples per length."""
    return [torch.randn(n, device='cuda', dtype=torch.float32).contiguous()
            for n in lengths]

# All x tensors are drawn before all y tensors.
x_tensors = _random_vectors(sizes)
y_tensors = _random_vectors(sizes)

# Time every Triton run first, then every PyTorch run.
pairs = list(zip(x_tensors, y_tensors))
triton_times = [benchmark(cosine_similarity, x, y) for x, y in pairs]
torch_times = [benchmark(torch_cosine_similarity, x, y) for x, y in pairs]

# Tabulate the timings.
results = list(zip(sizes, triton_times, torch_times))
table = PrettyTable()
table.field_names = ["Vector Size", "Triton Time (ms)", "PyTorch Time (ms)"]
for n, t_triton, t_torch in results:
    table.add_row([f"{n:,}", f"{t_triton:.4f}", f"{t_torch:.4f}"])

print("Cosine Similarity Benchmark Results:")
print(table)

# Speedup > 1 means the Triton version was faster for that size.
print("\nSpeedup (PyTorch / Triton):")
for n, t_triton, t_torch in results:
    ratio = t_torch / t_triton if t_triton > 0 else float('inf')
    print(f"Size {n:,}: {ratio:.2f}x")

Cosine Similarity Benchmark Results:
+-------------+------------------+-------------------+
| Vector Size | Triton Time (ms) | PyTorch Time (ms) |
+-------------+------------------+-------------------+
|    1,024    |      0.2392      |       0.0947      |
|    4,096    |      0.1535      |       0.1010      |
|    16,384   |      0.1385      |       0.0977      |
|    65,536   |      0.1348      |       0.0992      |
|  1,278,976  |      0.2225      |       0.2804      |
+-------------+------------------+-------------------+

Speedup (PyTorch / Triton):
Size 1,024: 0.40x
Size 4,096: 0.66x
Size 16,384: 0.70x
Size 65,536: 0.74x
Size 1,278,976: 1.26x
