# From PyTorch to Triton

Torch docs: https://docs.pytorch.org/docs/stable/index.html

Triton docs: https://triton-lang.org/main/index.html


## Why write Triton kernels at all?

| Scenario | PyTorch eager | `torch.compile` | Triton |
|-----------|--------------|-----------------|--------|
| Simple ops, plenty of kernels exist | ✅ | ✅ | ❌ (overkill) |
| Chain of ops → kernel fusion needed | ⚠️ limited | ✅ sometimes | ⭐ **full control** |
| Novel math / memory pattern | ❌ | ❌ | ⭐ **write it yourself** |

*In short:* Triton is for the last two rows, i.e., when you need **peak GPU throughput** and/or **custom data movement** that frameworks can’t fuse for you.




## Strides, contiguity, and why they matter

A 3×4 tensor `A` laid out row-major (C-contiguous):

- Flat memory -> a00 a01 a02 a03 a10 a11 a12 a13 a20 a21 a22 a23
- `A.stride(0)` -> 4
- `A.stride(1)` -> 1

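*Per-dim stride* = *number of elements to skip in memory* to move by 1 along that dim.  
A tensor is contiguous (row-major) when every stride equals the product of the sizes of the dims after it: here `stride(0) = 4 = size(1)` and `stride(1) = 1`. Views such as a transpose (or a stepped slice like `A[:, ::2]`) reuse the same memory with different strides and are generally not contiguous.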
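A minimal pure-Python sketch of the contiguity rule for row-major tensors: each stride must equal the product of the sizes of the later dims. The helpers `row_major_strides` and `is_row_major_contiguous` are hypothetical (not PyTorch API). The last check shows why decreasing strides alone are not enough: the stepped slice `A[:, ::2]` has shape 3×2 and strides (4, 2), which decrease but skip every other element.

```python
from math import prod

def row_major_strides(shape):
    """Strides (in elements) of a C-contiguous tensor with this shape."""
    # stride of dim i = product of the sizes of all later dims (empty product = 1)
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))

def is_row_major_contiguous(shape, strides):
    """Hypothetical check: do the strides match the row-major layout?
    Dims of size 1 are skipped, since their stride is never used to move."""
    expected = row_major_strides(shape)
    return all(s == e for n, s, e in zip(shape, strides, expected) if n != 1)

print(row_major_strides((3, 4)))                # (4, 1)
print(is_row_major_contiguous((3, 4), (4, 1)))  # True:  A
print(is_row_major_contiguous((4, 3), (1, 4)))  # False: A.t()
print(is_row_major_contiguous((3, 2), (4, 2)))  # False: A[:, ::2], decreasing strides but gaps
```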


For example, `A.stride(0) = 4` means you must skip 4 elements in memory (one full row of 4 columns) to reach the next row.
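The same arithmetic in code: a small pure-Python sketch (the helper `flat_offset` is hypothetical) of how an index turns into a position in the flat buffer. This is the computation a kernel has to do itself with the strides it is given. A transposed view reads the same buffer, just with the strides swapped.

```python
def flat_offset(index, strides, storage_offset=0):
    """Position in the flat buffer of the element at `index` (one int per dim)."""
    return storage_offset + sum(i * s for i, s in zip(index, strides))

storage = list(range(12))  # flat buffer of A; element a_ij is stored as the value 4*i + j

# A[1, 2]: skip 1 full row (4 elements), then 2 columns
print(storage[flat_offset((1, 2), (4, 1))])  # 6

# A.t() reuses the same buffer with strides (1, 4), so A.t()[2, 1] is A[1, 2]
print(storage[flat_offset((2, 1), (1, 4))])  # 6
```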


```python
# Hands-on with strides
import torch

B = torch.arange(12, dtype=torch.float32, device='cuda').reshape(3, 4)
print("Strides:", B.stride())            # (4, 1)
print("Contiguous:", B.is_contiguous())  # True

C = B.t()  # transpose: a 4×3 *view* of B that shares its memory; only the strides change
print("Strides:", C.stride())            # (1, 4)
print("Contiguous:", C.is_contiguous())  # False
```

```
Strides: (4, 1)
Contiguous: True
Strides: (1, 4)
Contiguous: False
```


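🚨 **Kernel warning**: 

A Triton kernel receives raw pointers, and it knows about strides only if you pass them in as arguments, so you must keep track of them yourself.

If your kernel *requires* contiguous memory, call `.contiguous()` in PyTorch **before** launching it.

For simplicity, many kernels (including the Hello World below) just assume contiguous memory; kernels that must handle arbitrary views, like the matmul launch below, take the strides as arguments instead.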
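To see what `.contiguous()` buys you, here is a pure-Python sketch that uses a plain list as storage; `make_contiguous` is a hypothetical helper, not PyTorch's implementation. It gathers the elements of a strided view in logical row-major order into a fresh buffer, so the transposed view `C` from the cell above ends up with its own memory and strides (3, 1).

```python
from itertools import product

def make_contiguous(storage, shape, strides):
    """Copy a strided view into a new row-major buffer (conceptually what .contiguous() does)."""
    # itertools.product yields indices with the last dim changing fastest, i.e. row-major order
    return [
        storage[sum(i * s for i, s in zip(index, strides))]
        for index in product(*(range(n) for n in shape))
    ]

storage = list(range(12))                          # buffer behind B (3×4, strides (4, 1))
C_data = make_contiguous(storage, (4, 3), (1, 4))  # C = B.t()
print(C_data)  # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
```

The copy costs an extra read and write of the whole tensor, which is the trade-off against writing a kernel that accepts strides directly.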

## Triton: how it sees your tensors

Triton *does not* define its own tensor class; you pass **plain `torch.Tensor`s**:

```python
matmul_kernel[grid](
    A, B, C,                     # tensors (device = CUDA)
    M, N, K,                     # scalars (ints)
    A.stride(0), A.stride(1),    # .stride(i) returns the ith stride value (int)
    ...
)
```

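Inside a Triton kernel, each tensor argument arrives as a pointer typed by its element dtype (e.g. a pointer to `fp16`), and each stride you passed arrives as an integer.

Everything else (shapes, any strides you didn't pass, layout assumptions) must be tracked via arguments you pass.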
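Below is a plain-Python emulation (no GPU needed) of the pointer arithmetic a 2D Triton kernel typically performs with those integer strides. The function `load_tile` and its parameters are hypothetical, and `base`, `offs_m`, `offs_n` in the docstring are illustrative names for the Triton idiom it mimics: broadcast row offsets against column offsets, then do a masked `tl.load`. Because the strides are arguments, the same code reads both `B` and its transpose without any copy.

```python
def load_tile(storage, M, N, stride_0, stride_1,
              row_start, col_start, BLOCK_M, BLOCK_N, other=0.0):
    """Plain-Python stand-in for a masked 2D tile load in a Triton kernel:

        offs_m = row_start + tl.arange(0, BLOCK_M)
        offs_n = col_start + tl.arange(0, BLOCK_N)
        ptrs = base + offs_m[:, None] * stride_0 + offs_n[None, :] * stride_1
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        tile = tl.load(ptrs, mask=mask, other=0.0)
    """
    tile = []
    for m in range(row_start, row_start + BLOCK_M):
        row = []
        for n in range(col_start, col_start + BLOCK_N):
            if m < M and n < N:  # in bounds: follow the strides into the flat buffer
                row.append(storage[m * stride_0 + n * stride_1])
            else:                # masked-off lane: no memory access, use `other`
                row.append(other)
        tile.append(row)
    return tile

storage = [float(v) for v in range(12)]  # flat buffer of B = arange(12).reshape(3, 4)

# Top-left 2×2 tile of B (shape 3×4, strides (4, 1))
print(load_tile(storage, M=3, N=4, stride_0=4, stride_1=1,
                row_start=0, col_start=0, BLOCK_M=2, BLOCK_N=2))
# [[0.0, 1.0], [4.0, 5.0]]

# Same buffer read as C = B.t() (shape 4×3, strides (1, 4)); this tile hangs off the right edge
print(load_tile(storage, M=4, N=3, stride_0=1, stride_1=4,
                row_start=2, col_start=2, BLOCK_M=2, BLOCK_N=2))
# [[10.0, 0.0], [11.0, 0.0]]
```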


## Hello World in Triton

The cell below is **fully working**; run it to check your setup.

Don't worry about it just yet. We will go over each line, step-by-step, in the next notebooks.

<img src="offsets.png" width="512" />

```python
import torch
import triton
import triton.language as tl

@triton.jit  # marks the function as a Triton kernel, JIT-compiled on first launch
def square_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)         # index of this program instance in the 1D grid
    offsets = tl.arange(0, BLOCK_SIZE)  # 0, 1, ..., BLOCK_SIZE - 1
    idxs = pid * BLOCK_SIZE + offsets   # the elements this program instance handles
    mask = idxs < n_elements            # guards the last block when N % BLOCK_SIZE != 0
    x = tl.load(x_ptr + idxs, mask=mask)
    tl.store(out_ptr + idxs, x * x, mask=mask)

# Launch
N = 128
BLOCK_SIZE = 16
x = torch.randn(N, device='cuda', dtype=torch.float16).contiguous()
out = torch.empty_like(x)
grid = (triton.cdiv(N, BLOCK_SIZE),)    # 128 / 16 = 8 program instances
square_kernel[grid](x, out, N, BLOCK_SIZE=BLOCK_SIZE)

# Compare with ground truth
print(torch.allclose(out, x**2, atol=1e-6))
```

```
True
```
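To see why the grid size uses `triton.cdiv` and why a kernel like this should mask its loads and stores (e.g. `mask = idxs < n_elements`), here is a sequential pure-Python emulation of the launch. It is only a sketch: real program instances run in parallel on the GPU, and `cdiv` and `emulate_launch` are hypothetical helpers written for illustration.

```python
def cdiv(a, b):
    # ceiling division; same result as triton.cdiv for positive ints
    return (a + b - 1) // b

def emulate_launch(n_elements, BLOCK_SIZE):
    """Return the grid size, the in-bounds indices, and the indices a mask must block."""
    grid = cdiv(n_elements, BLOCK_SIZE)
    touched, out_of_bounds = [], []
    for pid in range(grid):                  # one iteration per program instance
        for offset in range(BLOCK_SIZE):     # plays the role of tl.arange(0, BLOCK_SIZE)
            idx = pid * BLOCK_SIZE + offset
            if idx < n_elements:
                touched.append(idx)
            else:
                out_of_bounds.append(idx)
    return grid, touched, out_of_bounds

grid, touched, oob = emulate_launch(128, 16)
print(grid, len(touched), len(oob))  # 8 128 0
grid, touched, oob = emulate_launch(130, 16)
print(grid, len(touched), len(oob))  # 9 130 14
```

With N = 128 and BLOCK_SIZE = 16 the blocks tile the vector exactly. With N = 130 a ninth program instance is launched and 14 of its lanes fall past the end of the tensor; those are the reads and writes the mask prevents.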

## What’s next?

* Puzzle 1 – Vector Addition
* Puzzle 2 – Fused Softmax
* Puzzle 3 – Matmul (GEMM)
* Puzzle 4 – LayerNorm
* Puzzle 5 – Cross-Entropy
* Puzzle 6 – Softmax Attention
* Puzzle 7 – Sparsemax Attention

Happy hacking!  

<img src="sardine-evolution.png" width="800" />