# Intro to Triton: softmax

## Summary

In this example, we will look at softmax, written in Triton. 

Here is the calculation we will be implementing. For a 2D tensor $X$, compute softmax on each row, where $r$ denotes the row index and $c$ the column index.

For each row $r$, compute:
$$\operatorname{rowmax}(X_r) = \max_c X_{r,c}$$
$$\operatorname{softmax}(X_r)_c = \frac{\exp(X_{r,c} - \operatorname{rowmax}(X_r))}{\sum_{c'} \exp(X_{r,c'} - \operatorname{rowmax}(X_r))}$$
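As a point of reference, the two formulas above can be written row-wise in NumPy. This is a sketch for checking results only, not part of the Triton implementation; `softmax_reference` is a hypothetical helper name. Subtracting the row maximum does not change the result (the factor $\exp(-\operatorname{rowmax}(X_r))$ cancels between numerator and denominator), but it keeps `exp` from overflowing on large inputs.

```python
import numpy

def softmax_reference(x):
    """Row-wise softmax of a 2D array, following the formulas above."""
    row_max = numpy.max(x, axis=1, keepdims=True)              # rowmax(X_r) for every row r
    numerator = numpy.exp(x - row_max)                         # exp(X_{r,c} - rowmax(X_r))
    denominator = numpy.sum(numerator, axis=1, keepdims=True)  # sum over c'
    return numerator / denominator

x = numpy.array([[1.0, 2.0, 3.0],
                 [1000.0, 1001.0, 1002.0]])  # exp(1000.0) on its own would overflow to inf
y = softmax_reference(x)
print(y)
print(y.sum(axis=1))  # each row sums to 1 (up to rounding)
```

Shifting a row by a constant does not change its softmax, so both rows above produce the same result. Without the max subtraction, `numpy.exp(1000.0)` overflows to `inf` and the second row would come out as `nan`.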

## Triton implementation

We are going to write a softmax in Triton step-by-step. In each step, we'll have a working Triton program. Each version will introduce a new concept.
Here are the steps we'll take:
1. Compute softmax on a single row
2. Compute softmax on all the rows, with each row computed in parallel
3. Generalize so that it works for inputs whose number of columns (`x.shape[1]`) is not a power of two

At each step, we will also
- demonstrate how to call the Triton kernel from within a Pytorch program
- check that we have a correct implementation by comparing the results against those of Pytorch

## Version 1: Compute softmax on a single row

Here is one way to compute the softmax using Triton, with the caveat that it computes softmax on just the first row.

```python
import triton
import triton.language as tl

# this annotation marks the function as a Triton program
@triton.jit(interpret=True)
def softmax_kernel_v1(
    output, # pointer to the Pytorch tensor where the result will go
    input,  # pointer to the Pytorch tensor containing the input
    NUM_COLUMNS: tl.constexpr # number of columns in the input and output tensors; limitation: must be a power of 2
):
    # form an array of pointers to the elements of the input tensor
    col_offsets = tl.arange(0, NUM_COLUMNS)
    input_pointer_array = input + col_offsets

    # load the row from the input tensor into a block
    row = tl.load(input_pointer_array)

    ############################
    # compute softmax on the row
    # Triton uses similar syntax for operations on blocks
    # as is used for numpy arrays and Pytorch tensors
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_of_row = numerator / denominator
    ############################

    # form an array of pointers to the elements of the output tensor
    output_pointer_array = output + col_offsets

    # store the result block to the output tensor
    tl.store(output_pointer_array, softmax_of_row)
```

The softmax calculation itself (in the middle) looks similar to how you would write softmax in Pytorch or Numpy. However, the rest of the function contains some additional details: pointers, loads, and stores.

To understand why the additional details are necessary, it is helpful to look at a diagram representing the architecture of a typical GPU. 

<img src="img/architecture_load_store1.png" alt="Diagram of a typical GPU architecture: device memory, processing units with local memory, and load/store between them" width="1024"/>

The "device memory" is where the Pytorch tensors live. The computation is done by a number of "processing units"; we'll focus on a single processing unit for now. Before a processing unit can compute on data, it must first "load" the data from device memory into its "local memory". To save the results of the computation, the processing unit must "store" them from local memory back into device memory.

Triton provides `tl.load` and `tl.store` for loading and storing. In the program above, `row` is a "block" of data residing in local memory: a copy of the first row of the input Pytorch tensor.

To tell the system which location in device memory to load from or store to, we need to provide a "pointer". A Triton pointer refers to the starting location of some data in device memory (i.e., a location within a Pytorch tensor).
A Triton program does not return values directly; instead, it stores its result into a Pytorch tensor that was created in device memory before the kernel was called.

One more detail. While the pointers passed into the function really just refer to the start of the Pytorch tensor, `tl.load`/`tl.store` actually take an array of pointers. By "array" we mean a regular, N-dimensional structure, much like a numpy array or Pytorch tensor. To create this array of pointers, we add the indices of the columns. In this case we want all the columns (indices 0 to `NUM_COLUMNS` - 1). Here is an example:

```python
# using numpy to examine your Triton kernel's pointer calculations is a helpful debugging trick!
import numpy
numpy.random.seed(0)

def pointer_example1(input, n_rows, NUM_COLUMNS):
    input_data = numpy.round(numpy.random.rand(n_rows, NUM_COLUMNS), 3)
    print("input (tensor) = {}".format(input_data))

    # input represents the pointer, which is just a single integer representing the starting location of the input tensor living in device memory
    print("input (pointer) = {}".format(input))
    print("NUM_COLUMNS = {}".format(NUM_COLUMNS))

    col_offsets = numpy.arange(0, NUM_COLUMNS)
    print("col_offsets = {}".format(col_offsets))

    input_pointer_array = input + col_offsets

    print("row will be a block of shape {}, containing a copy of the data at device memory locations {}\nwhich is the data {}\n".format(input_pointer_array.shape, input_pointer_array, input_data[0][col_offsets]))

pointer_example1(400, 4, 6)
pointer_example1(1300, 3, 10)
# try your own!
```

```
input (tensor) = [[0.549 0.715 0.603 0.545 0.424 0.646]
 [0.438 0.892 0.964 0.383 0.792 0.529]
 [0.568 0.926 0.071 0.087 0.02  0.833]
 [0.778 0.87  0.979 0.799 0.461 0.781]]
input (pointer) = 400
NUM_COLUMNS = 6
col_offsets = [0 1 2 3 4 5]
row will be a block of shape (6,), containing a copy of the data at device memory locations [400 401 402 403 404 405]
which is the data [0.549 0.715 0.603 0.545 0.424 0.646]

input (tensor) = [[0.118 0.64  0.143 0.945 0.522 0.415 0.265 0.774 0.456 0.568]
 [0.019 0.618 0.612 0.617 0.944 0.682 0.36  0.437 0.698 0.06 ]
 [0.667 0.671 0.21  0.129 0.315 0.364 0.57  0.439 0.988 0.102]]
input (pointer) = 1300
NUM_COLUMNS = 10
col_offsets = [0 1 2 3 4 5 6 7 8 9]
row will be a block of shape (10,), containing a copy of the data at device memory locations [1300 1301 1302 1303 1304 1305 1306 1307 1308 1309]
which is the data [0.118 0.64  0.143 0.945 0.522 0.415 0.265 0.774 0.456 0.568]
```
To ground this in a more familiar analogy: if this were purely numpy/Pytorch code and `input` were an array/tensor (rather than a pointer), the code would be `row = input[0][0:NUM_COLUMNS]`.
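Putting the pieces together, the sketch below imitates `softmax_kernel_v1` in plain NumPy. It assumes a simplified model in which device memory is a single flat array and a pointer is an integer index into it (real Triton pointers are typed, and offsets count elements of that type). `softmax_kernel_v1_model` is a hypothetical name, not part of Triton. Like the real kernel, it returns nothing and writes its result into an output region that was set aside before the call.

```python
import numpy

# Model "device memory" as one flat 1D array; a "pointer" is just an index into it.
device_memory = numpy.zeros(64, dtype=numpy.float32)

def softmax_kernel_v1_model(output, input, NUM_COLUMNS):
    """NumPy imitation of softmax_kernel_v1: load, compute, store, return nothing."""
    col_offsets = numpy.arange(0, NUM_COLUMNS)
    row = device_memory[input + col_offsets]               # like tl.load: gather through an array of pointers
    numerator = numpy.exp(row - numpy.max(row))
    softmax_of_row = numerator / numpy.sum(numerator)
    device_memory[output + col_offsets] = softmax_of_row   # like tl.store: write to the output location

# Place a 2x4 input "tensor" at pointer 8 and reserve space for the output at pointer 32.
x = numpy.array([[1.0, 2.0, 3.0, 4.0],
                 [5.0, 6.0, 7.0, 8.0]], dtype=numpy.float32)
input_ptr, output_ptr = 8, 32
device_memory[input_ptr:input_ptr + x.size] = x.ravel()    # row-major layout

# The pointer-based load reads the same data as the slice input[0][0:NUM_COLUMNS]
loaded = device_memory[input_ptr + numpy.arange(0, 4)]
print(numpy.array_equal(loaded, x[0][0:4]))  # True

softmax_kernel_v1_model(output_ptr, input_ptr, NUM_COLUMNS=4)
print(device_memory[output_ptr:output_ptr + 4])  # softmax of the first row only
```

The second row's output region stays zero, which mirrors the limitation of `softmax_kernel_v1`: it only ever touches the first row.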

### How to call a Triton kernel

Since Triton passes results through pointers, calling a Triton kernel involves an additional step: we need to allocate a Pytorch tensor where the result will go.

To hide this step (and other details, as we'll see), it is customary to wrap the call to a Triton kernel in a regular Python function.

```python
import torch

# Compute softmax of x (limitation: just the first row)
def softmax_v1(x):
    # Allocate output
    y = torch.empty_like(x)

    n_cols = x.shape[1]

    # Launch the kernel; [(1,)] launches a single instance (see Version 2)
    softmax_kernel_v1[(1,)](y, x, NUM_COLUMNS=n_cols)

    return y
```

Ignore the `[(1,)]` for now. It has to do with how Triton introduces parallelism, which we'll look at in the next version.

Let's test it against Pytorch's built-in softmax.

```python
torch.manual_seed(0)
x = torch.randn((4, 6))
print("x = {}".format(x))

# test helper
def check(expected, actual):
    if torch.allclose(expected, actual):
        print("PASS")
    else:
        print("FAIL")

# run Pytorch softmax
# just look at first row for now
torch_softmax_result = torch.softmax(x, dim=1)
print("torch softmax (row 0)  = {}".format(torch_softmax_result[0]))

# run Triton softmax
# just look at first row for now
softmax_v1_result = softmax_v1(x)
print("triton softmax (row 0) = {}".format(softmax_v1_result[0]))

# compare
# just look at first row for now
check(torch_softmax_result[0], softmax_v1_result[0])
```

```
x = tensor([[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920],
        [-0.3160, -2.1152,  0.4681, -0.1577,  1.4437,  0.2660],
        [ 0.1665,  0.8744, -0.1435, -0.1116,  0.9318,  1.2590],
        [ 2.0050,  0.0537,  0.6181, -0.4128, -0.8411, -2.3160]])
torch softmax (row 0)  = tensor([0.0507, 0.0494, 0.1216, 0.1012, 0.3650, 0.3121])
triton softmax (row 0) = tensor([0.0507, 0.0494, 0.1216, 0.1012, 0.3650, 0.3121])
PASS
```


### Case study: Using Triton to implement a Pytorch Function

To see an example of writing a custom `torch.autograd.Function` in Pytorch using Triton, take a look at the [cross entropy example](https://github.com/openai/triton/blob/04e47d7712e218721e54e741d73729736509abc2/python/triton/ops/cross_entropy.py#L63).
In particular, note the following for now:
- The `forward` method [calls the Triton kernel](https://github.com/openai/triton/blob/04e47d7712e218721e54e741d73729736509abc2/python/triton/ops/cross_entropy.py#L75) `_forward` (of course, you can name the kernel whatever you like)
- The `backward` method [calls the Triton kernel](https://github.com/openai/triton/blob/04e47d7712e218721e54e741d73729736509abc2/python/triton/ops/cross_entropy.py#L93) `_backward`
- As before, we need to pass the Triton kernel a tensor to put its results in, whether that tensor is newly [allocated](https://github.com/openai/triton/blob/04e47d7712e218721e54e741d73729736509abc2/python/triton/ops/cross_entropy.py#L72) or [already exists](https://github.com/openai/triton/blob/04e47d7712e218721e54e741d73729736509abc2/python/triton/ops/cross_entropy.py#L88)
- While this example implements all of the computation in Triton, if you only wanted to implement part of your operator in Triton, you could mix Pytorch and Triton

## Version 2: Compute softmax on all the rows, with each row computed in parallel

The function `softmax_v1` only computes softmax on the first row of `x`. To compute softmax for all the rows, we can have Triton launch one instance of the kernel for each row.
The tuple in the `[]` is called the "launch grid", and it specifies how many instances to launch.

In our existing code, the `[(1,)]` means "please launch one instance of the kernel."

```
softmax_kernel_v1[(1,)](y, x, NUM_COLUMNS=n_cols)
```

We can replace the `1` with `x.shape[0]` to say "please launch `x.shape[0]` (number of rows) instances of the kernel."

```
softmax_kernel_v1[(x.shape[0],)](y, x, NUM_COLUMNS=x.shape[1])
```

The instances will run in parallel, up to the number of processing units available on your device.

We'll also need to make a modification to the kernel. If we left it as is, it would just compute the first row over and over again!

```python
# differences from softmax_kernel_v1 are highlighted with comments

@triton.jit(interpret=True)
def softmax_kernel_v2(
    output,
    input,
    input_row_stride, # number of elements to get to the next row in input tensor
    output_row_stride, # number of elements to get to the next row in output tensor
    NUM_COLUMNS: tl.constexpr
):

    # find out which row we are assigned
    row_index = tl.program_id(0)

    # calculate the pointer to the start of our assigned input row
    row_start_pointer = input + row_index * input_row_stride

    col_offsets = tl.arange(0, NUM_COLUMNS)
    # use the row_start_pointer now instead of input (which would always be pointing to row 0)
    input_pointer_array = row_start_pointer + col_offsets

    row = tl.load(input_pointer_array)

    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # calculate the pointer to the start of our assigned output row
    output_row_start_pointer = output + row_index * output_row_stride
    # use the output_row_start_pointer now instead of output (which would always be pointing to row 0)
    output_pointer_array = output_row_start_pointer + col_offsets

    tl.store(output_pointer_array, softmax_output)
```

### program_id

Triton uses a Single Program, Multiple Data (SPMD) programming model: every kernel instance runs the same program, but each instance accesses a different subset of the input/output, which is what allows the instances to run in parallel.
It is the programmer's job to decide which data belongs to each instance. An instance can ask which one it is by calling `tl.program_id(0)`. In the modified kernel, we used `tl.program_id(0)` to pick which row to work on.

To seek to the assigned input and output row, we'll need to change the pointer arithmetic to use `row_index`. To do so, we need one more piece of information: stride.

### Stride

To calculate the pointer to the start of a specific row, we need to take into account how far apart elements are in the row dimension. The computer memory is a 1d array, so addresses are 1d. 
Along a specific dimension of a tensor, the distance in memory from one element to the next is the "stride".

`row_start_pointer = input + row_index * input_row_stride`
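Writing the column stride out explicitly as well (our kernels assume it is 1), the location of element $(r, c)$ of a 2D tensor whose data starts at pointer `input` is:

$$\text{pointer}(r, c) = \text{input} + r \cdot \text{stride}_0 + c \cdot \text{stride}_1$$

For example, with `input` = 400, $\text{stride}_0 = 6$ and $\text{stride}_1 = 1$, element $(3, 2)$ lives at $400 + 3 \cdot 6 + 2 \cdot 1 = 420$, which is the third location of the row-3 block in the second call of the example below.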

Here's an example:


```python
# reminder: this is numpy code, used for the purpose of understanding the pointer arithmetic
def pointer_example2(input, n_rows, NUM_COLUMNS, program_id0, input_row_stride):
    input_data = numpy.round(numpy.random.rand(n_rows, NUM_COLUMNS), 3)
    print("input (tensor) = {}".format(input_data))

    print("input = {}".format(input))
    print("NUM_COLUMNS = {}".format(NUM_COLUMNS))
    print("program_id(0) = {}".format(program_id0))
    print("input_row_stride = {}".format(input_row_stride))

    row_index = program_id0
    print("row_index = {}".format(row_index))

    row_start_pointer = input + row_index * input_row_stride
    print("row_start_pointer = {}".format(row_start_pointer))

    col_offsets = numpy.arange(0, NUM_COLUMNS)
    print("col_offsets = {}".format(col_offsets))

    input_pointer_array = row_start_pointer + col_offsets

    print("row will be a block of shape {}, containing a copy of the data at device memory locations {}\nwhich is the data {}\n".format(input_pointer_array.shape, input_pointer_array, input_data[row_index][col_offsets]))

pointer_example2(400, 4, 6, 0, 6)
pointer_example2(400, 4, 6, 3, 6)
pointer_example2(1300, 3, 16, 0, 16)
pointer_example2(1300, 3, 16, 1, 16)
# try your own!
```

```
input (tensor) = [[0.209 0.161 0.653 0.253 0.466 0.244]
 [0.159 0.11  0.656 0.138 0.197 0.369]
 [0.821 0.097 0.838 0.096 0.976 0.469]
 [0.977 0.605 0.739 0.039 0.283 0.12 ]]
input = 400
NUM_COLUMNS = 6
program_id(0) = 0
input_row_stride = 6
row_index = 0
row_start_pointer = 400
col_offsets = [0 1 2 3 4 5]
row will be a block of shape (6,), containing a copy of the data at device memory locations [400 401 402 403 404 405]
which is the data [0.209 0.161 0.653 0.253 0.466 0.244]

input (tensor) = [[0.296 0.119 0.318 0.414 0.064 0.692]
 [0.567 0.265 0.523 0.094 0.576 0.929]
 [0.319 0.667 0.132 0.716 0.289 0.183]
 [0.587 0.02  0.829 0.005 0.678 0.27 ]]
input = 400
NUM_COLUMNS = 6
program_id(0) = 3
input_row_stride = 6
row_index = 3
row_start_pointer = 418
col_offsets = [0 1 2 3 4 5]
row will be a block of shape (6,), containing a copy of the data at device memory locations [418 419 420 421 422 423]
which is the data [0.587 0.02  0.829 0.005 0.678 0.27 ]
```

### Calling the new kernel

Here is how to call the new version of the kernel.

```python
# differences from softmax_v1 are highlighted with comments

# Compute softmax of x
def softmax_v2(x):
    y = torch.empty_like(x)

    n_rows, n_cols = x.shape

    # Call n_rows copies of the kernel
    # Triton will make sure each has a unique program_id between 0 and n_rows-1
    softmax_kernel_v2[(n_rows,)](y, x,
                                 input_row_stride=x.stride(0),  # get the strides from the torch tensors
                                 output_row_stride=y.stride(0),
                                 NUM_COLUMNS=n_cols)

    return y

softmax_v2_result = softmax_v2(x)
print("softmax_v2_result =\n{}".format(softmax_v2_result))
check(torch_softmax_result, softmax_v2_result)
```

```
softmax_v2_result =
tensor([[0.0507, 0.0494, 0.1216, 0.1012, 0.3650, 0.3121],
        [0.0825, 0.0136, 0.1806, 0.0966, 0.4791, 0.1476],
        [0.1036, 0.2103, 0.0760, 0.0785, 0.2227, 0.3089],
        [0.6442, 0.0915, 0.1609, 0.0574, 0.0374, 0.0086]])
PASS
```


#### Launch grid
The tuple given within `[(n_rows,)]` is called the "launch grid". In this example, we are using a 1d launch grid, `(n_rows,)` which means: call n_rows copies of the kernel, with `tl.program_id(0)` being 0, 1, ..., n_rows-1. 

Aside: Launch grid is a tuple rather than an integer because Triton also accepts 2d and 3d launch grids, which is convenient for slicing the data in multiple dimensions.
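To make the launch grid concrete, here is a sequential NumPy model of what `softmax_kernel_v2[(n_rows,)](...)` does: run the same kernel body once per program id. The helpers `launch` and `softmax_kernel_v2_model` are hypothetical, and device memory is again modeled as a flat array indexed by integer "pointers". On a GPU the instances run in parallel and in no guaranteed order, which is why each instance must only touch its own row.

```python
import numpy

def launch(kernel, grid, *args):
    """Hypothetical sequential stand-in for kernel[grid](*args) with a 1D grid."""
    (num_programs,) = grid
    for program_id in range(num_programs):   # a GPU may run these in any order, or all at once
        kernel(program_id, *args)

def softmax_kernel_v2_model(program_id, memory, output, input,
                            input_row_stride, output_row_stride, NUM_COLUMNS):
    row_index = program_id                                     # plays the role of tl.program_id(0)
    col_offsets = numpy.arange(0, NUM_COLUMNS)
    row = memory[input + row_index * input_row_stride + col_offsets]
    numerator = numpy.exp(row - numpy.max(row))
    memory[output + row_index * output_row_stride + col_offsets] = numerator / numpy.sum(numerator)

n_rows, n_cols = 3, 5
x = numpy.random.default_rng(0).standard_normal((n_rows, n_cols))
memory = numpy.zeros(100)
input_ptr, output_ptr = 0, 50
memory[input_ptr:input_ptr + x.size] = x.ravel()   # dense row-major, so row stride == n_cols

launch(softmax_kernel_v2_model, (n_rows,), memory, output_ptr, input_ptr, n_cols, n_cols, n_cols)
y = memory[output_ptr:output_ptr + x.size].reshape(n_rows, n_cols)
print(y.sum(axis=1))  # every row sums to 1, because every row got its own program instance
```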

#### Strides
We also had to pass in arguments for the input and output strides. Pytorch tensors provide `stride(dim)` for getting the stride of each dimension. In this case we want dim 0, that is, the row dimension.

Technicality: It would be good form to also pass in a stride for the column dimension, but we are assuming that the data is stored in row-major order with contiguous columns, that is, `stride(0) >= n_cols` and `stride(1) == 1`.
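The difference between `stride(0)` and `n_cols` matters for views. The NumPy sketch below (used only for illustration) shows a dense array, a column slice whose row stride exceeds its number of columns, and a transpose that breaks the `stride(1) == 1` assumption. NumPy's `strides` attribute is in bytes, so the hypothetical helper `element_strides` divides by the item size to match Pytorch's `stride()`, which already counts elements.

```python
import numpy

def element_strides(a):
    """NumPy reports strides in bytes; divide by the item size to get strides in elements."""
    return tuple(s // a.itemsize for s in a.strides)

base = numpy.arange(24, dtype=numpy.float32).reshape(4, 6)   # dense row-major 4x6
print(element_strides(base))                # (6, 1): stride(0) == n_cols

view = base[:, 1:5]                          # column slice: a 4x4 view into the same memory
print(view.shape, element_strides(view))    # (4, 4) (6, 1): stride(0) > n_cols, stride(1) == 1

transposed = base.T                          # 6x4 view whose columns are no longer contiguous
print(element_strides(transposed))          # (1, 6): stride(1) != 1, violating the kernel's assumption
```

The column slice still satisfies the kernel's assumption, as long as the wrapper passes `stride(0)` rather than `n_cols`. For a transposed input, a wrapper could call `x.contiguous()` before launching the kernel to obtain a row-major copy.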

### Parallelism

Triton gives you different kinds of control over the parallel execution of a kernel.
- The programmer specifies the number of instances in the launch grid. That, along with how the kernel uses `program_id`, determines how the problem is parallelized across kernel instances. Each kernel instance MAY be assigned to a different processing unit ("SM" on Nvidia GPUs). We say MAY because the number of instances can be greater than the number of available processing units.
  - In our example, we assign each row to a different kernel instance
  - Aside: A common term for this parallelism is SPMD
- The Triton compiler is in charge of parallelizing operations on blocks across GPU threads within a single processing unit.
  - In our example, all the block calculations (max, exp, sum, etc.) are parallelized this way
  - Aside: Common terms for this parallelism include SIMD and SIMT
- The user can optionally exercise control over scheduling GPU resources with the `num_warps` and `num_stages` parameters (not used in our example)

## Version 3: Generalize so that it works for inputs whose number of columns (`x.shape[1]`) is not a power of two

Triton actually has a limitation: the dimensions of a loaded/stored block must be powers of two. This has to do with the algorithm the GPU compiler currently uses to produce efficient code.

**Q: What? But the examples so far all work with `NUM_COLUMNS=6`, which is not a power of two!**
- A: The examples work because in this notebook we are running these examples in the Triton interpreter, rather than compiling them.
 

Triton has a concept called "masks" for fitting any shape of data into fixed-sized blocks. This concept is illustrated in the following version of softmax.

```python
@triton.jit(interpret=True)
def softmax_kernel_v3(
    output,
    input,
    input_row_stride,
    output_row_stride,
    n_cols,                  # the actual number of columns in input and output
    BLOCK_SIZE: tl.constexpr # the desired length of a loaded/stored block
):
    row_index = tl.program_id(0)

    row_start_pointer = input + row_index * input_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_pointer_array = row_start_pointer + col_offsets

    # do not load elements past the end of the last column
    # For those remaining elements, set the element in the block to negative infinity
    row = tl.load(input_pointer_array, mask=col_offsets < n_cols, other=-float('inf'))

    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    output_row_start_pointer = output + row_index * output_row_stride
    output_pointer_array = output_row_start_pointer + col_offsets

    # only store those elements from the block that are within bounds of the output tensor
    tl.store(output_pointer_array, softmax_output, mask=col_offsets < n_cols)
```

The `mask` argument of `tl.load` is a boolean condition evaluated for every element of the block. Where `mask` is `True`, the element is loaded from the input tensor; where it is `False`, no memory is read and the element is filled with the value given by `other`. Likewise, a masked `tl.store` writes only the elements where the mask is `True`.
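Why `-inf` for `other`? The padding must not change the row maximum, and it must contribute nothing to the denominator, which holds because $\exp(-\infty) = 0$. The NumPy sketch below imitates the masked load, the softmax, and the masked store for one row (the first row of `x` from Version 1, rounded as printed), comparing padding with `-inf` against padding with `0.0`. `masked_row_softmax` is a hypothetical helper.

```python
import numpy

def masked_row_softmax(row, n_cols, BLOCK_SIZE, other):
    """Mimic softmax_kernel_v3 on one row: pad to BLOCK_SIZE with `other`, compute, keep the first n_cols."""
    col_offsets = numpy.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    padded = numpy.full(BLOCK_SIZE, other, dtype=numpy.float64)  # filled in where mask is False
    padded[mask] = row[col_offsets[mask]]                        # loaded where mask is True
    numerator = numpy.exp(padded - numpy.max(padded))
    block = numerator / numpy.sum(numerator)
    return block[mask]                                           # what the masked store writes back

row = numpy.array([-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920])
exact = numpy.exp(row - row.max()) / numpy.sum(numpy.exp(row - row.max()))

with_neg_inf = masked_row_softmax(row, n_cols=6, BLOCK_SIZE=8, other=-numpy.inf)
with_zero = masked_row_softmax(row, n_cols=6, BLOCK_SIZE=8, other=0.0)
print(numpy.allclose(with_neg_inf, exact))  # True: exp(-inf) == 0 adds nothing to the sum
print(numpy.allclose(with_zero, exact))     # False: each 0.0 pad adds exp(0 - max) to the denominator
```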

Here's an example:

```python
# reminder: this is numpy code, used for the purpose of understanding the pointer arithmetic
def pointer_example3(input, n_rows, n_cols, BLOCK_SIZE, program_id0, input_row_stride, other):
    input_data = numpy.round(numpy.random.rand(n_rows, n_cols), 3)
    print("input (tensor) = {}".format(input_data))

    print("input = {}".format(input))
    print("n_cols = {}".format(n_cols))
    print("BLOCK_SIZE = {}".format(BLOCK_SIZE))
    print("program_id(0) = {}".format(program_id0))
    print("input_row_stride = {}".format(input_row_stride))
    print("fill with padding: other = {}".format(other))

    row_index = program_id0
    print("row_index = {}".format(row_index))

    row_start_pointer = input + row_index * input_row_stride
    print("row_start_pointer = {}".format(row_start_pointer))

    col_offsets = numpy.arange(0, BLOCK_SIZE)
    print("col_offsets = {}".format(col_offsets))

    input_pointer_array = row_start_pointer + col_offsets

    # emulating invalid locations in device memory with None
    extended_input_row_data = list(input_data[row_index][col_offsets[:n_cols]]) + ([None] * (BLOCK_SIZE - n_cols))
    # emulating the padding fill here from the mask=col_offsets < n_cols
    data_in_block = numpy.where(col_offsets < n_cols, extended_input_row_data, other)
    masked_pointers = numpy.where(col_offsets < n_cols, input_pointer_array, 'INV')

    print("row will be a block of shape {}, containing a copy of the data at device memory locations {}\nmasked memory locations {}\nwhich is the data {}\n".format(input_pointer_array.shape, input_pointer_array, masked_pointers, data_in_block))

pointer_example3(400, 4, 6, 8, 0, 6, -float("inf"))
pointer_example3(400, 4, 6, 8, 3, 6, -float("inf"))
pointer_example3(1300, 3, 12, 16, 0, 16, 0.0)
pointer_example3(1300, 3, 12, 16, 1, 16, 0.0)
# try your own!
```

```
input (tensor) = [[0.861 0.727 0.27  0.131 0.055 0.302]
 [0.262 0.456 0.683 0.696 0.284 0.38 ]
 [0.181 0.789 0.057 0.697 0.779 0.777]
 [0.259 0.374 0.588 0.273 0.371 0.197]]
input = 400
n_cols = 6
BLOCK_SIZE = 8
program_id(0) = 0
input_row_stride = 6
fill with padding: other = -inf
row_index = 0
row_start_pointer = 400
col_offsets = [0 1 2 3 4 5 6 7]
row will be a block of shape (8,), containing a copy of the data at device memory locations [400 401 402 403 404 405 406 407]
masked memory locations ['400' '401' '402' '403' '404' '405' 'INV' 'INV']
which is the data [0.861 0.727 0.27 0.131 0.055 0.302 -inf -inf]

input (tensor) = [[0.46  0.045 0.8   0.077 0.519 0.307]
 [0.578 0.959 0.646 0.035 0.43  0.51 ]
 [0.536 0.681 0.278 0.129 0.393 0.956]
 [0.187 0.904 0.544 0.457 0.882 0.459]]
input = 400
n_cols = 6
BLOCK_SIZE = 8
program_id(0) = 3
input_row_stride = 6
fill with padding: other = -inf
row_index = 3
row_start_pointer = 418
col_offsets = [0 1 2 3 4 5 6 7]
```

### tl.constexpr

Up to this point, we've been ignoring the fact that `BLOCK_SIZE` (previously `NUM_COLUMNS`) is annotated with the type `tl.constexpr`.
A parameter annotated with `tl.constexpr` ("constant expression") must have a value that is known at compile time. For example, when
`BLOCK_SIZE` is `8`, Triton will compile a version of the kernel specialized for `BLOCK_SIZE` = 8.

In fact, in Triton all blocks MUST have a shape that is a constant expression. This opens significant opportunities
for compiler optimizations.

Let's summarize the restrictions on our kernels:
- `softmax_kernel_v1` and `softmax_kernel_v2` have the restriction that `NUM_COLUMNS` must be equal to the number of columns in the input (and output).
- `softmax_kernel_v3` relaxes the restriction to `n_cols <= BLOCK_SIZE`.

Those restrictions come from the way we wrote our Triton code. In general, the only fundamental restrictions are the ones mentioned above: block shapes must be 1) constant expressions and 2) powers of two.
(The power-of-two restriction is a limitation of the GPU compiler that is expected to be relaxed eventually.)
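For a positive integer `n`, the smallest power of two that is at least `n` can be computed with integer bit operations. The sketch below is a pure-Python stand-in for the role `triton.next_power_of_2` plays in the next wrapper (it makes no claim about how Triton implements it), together with a hypothetical helper, `check_v3_launch`, that tests both restrictions for `softmax_kernel_v3`.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for a positive integer n."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return 1 << (n - 1).bit_length()

def check_v3_launch(n_cols, BLOCK_SIZE):
    """Hypothetical pre-launch check: BLOCK_SIZE is a power of two and covers every column."""
    is_power_of_2 = BLOCK_SIZE > 0 and (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0
    return is_power_of_2 and n_cols <= BLOCK_SIZE

for n_cols in [1, 6, 8, 12, 1000]:
    BLOCK_SIZE = next_power_of_2(n_cols)
    print(n_cols, BLOCK_SIZE, check_v3_launch(n_cols, BLOCK_SIZE))
```

Rounding up wastes at most a factor of two in block size: for `n_cols = 6`, two of the eight lanes are masked off.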

### Calling the new kernel

```python
# differences from softmax_v2 are highlighted with comments

def softmax_v3(x):
    y = torch.empty_like(x)

    n_rows, n_cols = x.shape

    # For how we've written softmax_kernel_v3, we require n_cols <= BLOCK_SIZE
    # Combined with Triton's general restriction that block shape be powers of two,
    # we set BLOCK_SIZE to the next power of 2 greater than or equal to n_cols.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    print("n_cols = {}; BLOCK_SIZE = {}".format(n_cols, BLOCK_SIZE))

    softmax_kernel_v3[(n_rows,)](y, x,
                                 input_row_stride=x.stride(0),
                                 output_row_stride=y.stride(0),
                                 n_cols=n_cols, # reminder: kernel needs the number of columns to do the masking
                                 BLOCK_SIZE=BLOCK_SIZE)

    return y

softmax_v3_result = softmax_v3(x)
print("softmax_v3_result =\n{}".format(softmax_v3_result))
check(torch_softmax_result, softmax_v3_result)
```

```
n_cols = 6; BLOCK_SIZE = 8
softmax_v3_result =
tensor([[0.0507, 0.0494, 0.1216, 0.1012, 0.3650, 0.3121],
        [0.0825, 0.0136, 0.1806, 0.0966, 0.4791, 0.1476],
        [0.1036, 0.2103, 0.0760, 0.0785, 0.2227, 0.3089],
        [0.6442, 0.0915, 0.1609, 0.0574, 0.0374, 0.0086]])
PASS
```


### Triton does Just-in-Time compilation (JIT)

`BLOCK_SIZE` is a constant expression, yet we are passing the kernel a value that is only computed at run time (from `n_cols`).
This works because launching a Triton kernel actually consists of two steps performed by the Triton runtime: 1) compiling a GPU kernel and 2) running the kernel.
In step (1), the constant expressions are given their specific values. This approach is called just-in-time (JIT) compilation.

Because step (1) adds latency to running your code, Triton caches compiled kernels, so a kernel is not recompiled unless the code, a constexpr value, or certain other properties of the arguments (such as their data types) change.
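A toy model of this caching behaviour, under the simplifying assumption that the cache key is just the kernel and its constexpr values (real Triton also takes other properties of the arguments into account). `compile_kernel` and `launch_softmax` are hypothetical names, and `functools.lru_cache` plays the role of the compilation cache; nothing here is actually compiled.

```python
import functools

compile_count = 0

@functools.lru_cache(maxsize=None)
def compile_kernel(kernel_name, BLOCK_SIZE):
    """Hypothetical stand-in for compilation: the cache key includes every constexpr value."""
    global compile_count
    compile_count += 1
    return "{}[BLOCK_SIZE={}]".format(kernel_name, BLOCK_SIZE)  # pretend this is a compiled binary

def launch_softmax(n_cols):
    BLOCK_SIZE = 1 << (n_cols - 1).bit_length()            # next power of two, as in softmax_v3
    return compile_kernel("softmax_kernel_v3", BLOCK_SIZE)  # step 1 (cached); step 2 would run it

for n_cols in [6, 7, 8, 6, 12, 16]:
    launch_softmax(n_cols)
print(compile_count)  # 2: n_cols 6, 7, 8 share BLOCK_SIZE=8; 12 and 16 share BLOCK_SIZE=16
```

Because `n_cols` is an ordinary argument rather than a constexpr in this model, different widths that round up to the same `BLOCK_SIZE` reuse one compiled kernel.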

## Conclusion

You've taken steps towards being able to do the following:
- Write a Triton implementation of a simple calculation
- Write code to call the Triton code from Pytorch
- Run code with Triton's CPU emulator
- Explain many of the key ideas in the Triton language

What you might do next:
- tweak the examples in this notebook
- check out the [Triton tutorials](https://triton-lang.org/main/getting-started/tutorials/index.html) and [documentation](https://triton-lang.org/main/index.html)
- run our Triton GPU notebooks in AzureML Studio
- [install](https://triton-lang.org/main/getting-started/installation.html) Triton yourself on a GPU-enabled platform

## Appendix: CPU Emulator?

As of writing this tutorial, Triton [runs on](https://github.com/openai/triton#compatibility) certain Nvidia GPUs. However, there is a useful feature that lets you try programs on any platform that supports Pytorch 2.
We refer to it as the "CPU Emulator", and it is how we can run Triton code in this notebook without a GPU (e.g., if you are using this notebook within our Github Codespace).

Whether a kernel runs on the CPU Emulator or on the GPU is controlled by this annotation on the Triton kernel:
```
# Use CPU emulator
@triton.jit(interpret=True) 

# Use GPU
@triton.jit
```

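Note: as of writing this tutorial, you need to [build Triton from source](https://triton-lang.org/main/getting-started/installation.html#from-source) to use the CPU Emulator because there is not yet an official release including the feature. In later Triton releases, the interpreter is enabled by setting the environment variable `TRITON_INTERPRET=1` rather than by passing `interpret=True`.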
