In [None]:
from setup_triton import setup_triton
setup_triton()

# From PyTorch to Triton

Torch docs: https://docs.pytorch.org/docs/stable/index.html

Triton docs: https://triton-lang.org/main/index.html


## Why write Triton kernels at all?

| Scenario | PyTorch eager | `torch.compile` | Triton |
|-----------|--------------|-----------------|--------|
| Simple ops, plenty of kernels exist | ✅ | ✅ | ❌ (overkill) |
| Chain of ops → kernel fusion needed | ⚠️ limited | ✅ sometimes | ⭐ **full control** |
| Novel math / memory pattern | ❌ | ❌ | ⭐ **write it yourself** |

*In short:* Triton is for when you need **peak GPU throughput** and/or **custom data movement** that frameworks can’t fuse for you.


## Strides, contiguity, and why they matter

A 3×4 tensor laid out in row-major order (C-contiguous):

- `A.data` -> a00 a01 a02 a03 a10 a11 a12 a13 a20 a21 a22 a23
- `A.stride(0)` -> 4
- `A.stride(1)` -> 1

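*Per-dim stride* = *number of elements to skip* in memory to move by 1 in that dim.  
A contiguous tensor's strides follow from its shape: the last dim has stride 1 and each earlier stride is the product of all later sizes (here 4 and 1). Views such as a transpose reuse the same memory with different strides, so they are generally not contiguous.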


For example, `A.stride(0) = 4` means we skip `4` elements in memory (one full row of 4 columns) to reach the same column in the next row.
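
As a small sketch of that rule, the memory offset of entry `(i, j)` is `i * stride(0) + j * stride(1)`. The helper `flat_offset` below is a hypothetical name used only for illustration; it is not part of PyTorch.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)
flat = A.flatten()  # the 12 elements in memory order (A is contiguous)

def flat_offset(i, j, strides):
    """Offset (in elements) of entry (i, j) from the start of the buffer."""
    return i * strides[0] + j * strides[1]

i, j = 2, 1
offset = flat_offset(i, j, A.stride())  # 2 * 4 + 1 * 1 = 9
assert flat[offset] == A[i, j]
```

This is the same arithmetic a kernel has to do by hand when it only has a pointer and the strides.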


In [None]:
import torch

# Hands-on with strides
B = torch.arange(12, dtype=torch.float32).reshape(3, 4)
print("Strides:", B.stride())        # (4, 1)
print("Contiguous:", B.is_contiguous())  # True

C = B.t()           # transpose: shape 4×3. This is a view of B: no data is copied, only the strides change
print("Strides:", C.stride())        # (1, 4)
print("Contiguous:", C.is_contiguous())  # False

## Kernel writing

A Triton kernel works with **raw pointers** plus the **strides** you pass in, so its index arithmetic must match the tensor's actual memory layout. If a kernel assumes contiguous memory, call `.contiguous()` in PyTorch before launching it.

> To the best of my knowledge, all kernels I've seen assume contiguous memory for simplicity.
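
A minimal sketch of what `.contiguous()` does to a transposed view, reusing the `B` and `C` tensors from the strides cell. The name `C_contig` is introduced here only for illustration.

```python
import torch

B = torch.arange(12, dtype=torch.float32).reshape(3, 4)
C = B.t()                      # view: shape (4, 3), strides (1, 4)
C_contig = C.contiguous()      # copy: shape (4, 3), strides (3, 1)

print(C_contig.stride())            # (3, 1)
print(C_contig.is_contiguous())     # True
print(torch.equal(C, C_contig))     # True: same values, different layout
print(C.data_ptr() == B.data_ptr())         # True: the view shares B's memory
print(C_contig.data_ptr() == B.data_ptr())  # False: the copy has its own buffer
```

The copy costs extra memory and bandwidth, which is the price paid for keeping the kernel's indexing simple. On a tensor that is already contiguous, `.contiguous()` returns the tensor itself without copying.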

### How Triton sees your tensors

Triton *does not* define its own tensor class; you pass **plain `torch.Tensor`s**:

```python
matmul_kernel[grid](
    A, B, C,                     # tensors (device = CUDA)
    M, N, K,                     # scalars (ints)
    A.stride(0), A.stride(1),    # .stride(i) returns the ith stride value (int)
    ...
)
```

Inside a Triton kernel, each tensor argument arrives as a pointer to its first element. Shapes and strides are not attached to that pointer, so you must pass them explicitly as integer arguments. Let's see an example.
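
To see why a pointer alone is not enough, here is a rough PyTorch-only sketch using `torch.as_strided`: the same buffer read with a different shape and strides yields a different tensor. It runs without a GPU and is only an analogy for what the kernel does with the values you pass.

```python
import torch

B = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Same underlying buffer, reinterpreted with a different shape and strides.
# Shape (4, 3) with strides (1, 4) is exactly the layout of B.t().
view = torch.as_strided(B, size=(4, 3), stride=(1, 4))

print(torch.equal(view, B.t()))          # True
print(view.data_ptr() == B.data_ptr())   # True: no data was copied
```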


## Hello World in Triton

The cell below is **fully working**; run it to check your setup.

Don't worry about it just yet. We will go over each line, step-by-step, in the next notebooks.

<img src="figs/offsets.png" width="640" />

In [None]:
import triton
import triton.language as tl

@triton.jit  # marks the function for JIT compilation (this is what makes it a kernel)
def square_kernel(
    x_ptr, 
    out_ptr,
    n_elements,  # this can vary
    BLOCK_SIZE: tl.constexpr  # this is a constant to Triton
):
    # recover the program id for the first grid axis 
    pid = tl.program_id(axis=0)

    # array containing [0, 1, ..., BLOCK_SIZE - 1]
    offsets = tl.arange(0, BLOCK_SIZE)

    # define the exact indices for a given pid (see image above)
    idxs = pid * BLOCK_SIZE + offsets

    # load the content using these indices from HBM into SRAM
    x = tl.load(x_ptr + idxs)

    # perform computation in SRAM
    x_squared = x * x

    # save x_squared into HBM
    tl.store(out_ptr + idxs, x_squared)


########################################################################

# number of elements
N = 12

# size of each chunk
BLOCK_SIZE = 4

# my data
x = torch.randn(N, dtype=torch.float16).contiguous()

# allocate output memory
out = torch.empty_like(x)

# run one program instance per block of BLOCK_SIZE elements, in parallel
grid = (triton.cdiv(N, BLOCK_SIZE),)

# launch the kernel: indexing it with the grid sets how many program instances run
square_kernel[grid](x, out, N, BLOCK_SIZE=BLOCK_SIZE)

# after the launch, the result lives in `out`,
# the tensor whose pointer the kernel wrote to with `tl.store`

# Compare with the ground truth
torch.allclose(out, x**2, atol=1e-6)
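
The index arithmetic of `square_kernel` can be emulated in plain PyTorch, which may help when reading the figure above. This is only a sequential sketch: it runs on the CPU, uses `float32` instead of `float16`, and replaces the parallel program instances with a Python loop. The name `num_programs` is introduced here for illustration.

```python
import torch

N, BLOCK_SIZE = 12, 4
num_programs = (N + BLOCK_SIZE - 1) // BLOCK_SIZE  # same result as triton.cdiv(N, BLOCK_SIZE)

x = torch.randn(N)
out = torch.empty_like(x)

# In Triton the programs run in parallel; this loop visits them one at a time.
for pid in range(num_programs):
    offsets = torch.arange(0, BLOCK_SIZE)   # [0, 1, ..., BLOCK_SIZE - 1]
    idxs = pid * BLOCK_SIZE + offsets       # pid=0 -> [0..3], pid=1 -> [4..7], pid=2 -> [8..11]
    out[idxs] = x[idxs] * x[idxs]           # load, square, store

print(torch.allclose(out, x**2))  # True
```

The sketch also exposes an assumption in the kernel: `N` must be a multiple of `BLOCK_SIZE`, otherwise the last block's `idxs` run past the end of the data. Triton's `tl.load` and `tl.store` accept a `mask` argument for that case.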

## What’s next?

* Puzzle 1 – Vector Addition
* Puzzle 2 – Fused Softmax
* Puzzle 3 – Matmul (GEMM)
* Puzzle 4 – LayerNorm
* Puzzle 5 – Cross-Entropy
* Puzzle 6 – Softmax Attention
* Puzzle 7 – Sparsemax Attention

Happy hacking!  

<img src="figs/sardine-evolution.png" width="512" />