# Kernel Development & Optimizations with Triton

Welcome to this hands-on workshop! [OpenAI Triton](https://github.com/triton-lang/triton) is an open-source programming language designed to simplify GPU programming for high-performance workloads, particularly in AI applications, and it supports AMD GPUs. This workshop demonstrates how to set up a Triton development environment and how to optimize Triton kernel performance on AMD MI GPUs.

### Agenda:
- Set up the Triton Development Environment 🖐️⚒️

- Kernel Development Examples (Vector-Add) 

- Kernel Development Hands-on (Online-Softmax) 

    - The Origin of Online-Softmax Algorithm

    - Online-Softmax Kernel Hands-on 🖐️⚒️

    - Performance Benchmark Result Visualization 

    - Skim through the advanced Fused-Softmax 🧠
 


## 1. Set up the Triton development environment
 

### Step1: Access AMD MI GPU 
In this workshop, we use an **MI GPU** cloud instance from AMD Developer Cloud. Open the GPU instance link provided for this workshop.

### Step2: Run ROCm Official Container on GPU card 
In this workshop, we work in the pre-built ROCm PyTorch image, pulled with `docker pull rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0`. It integrates PyTorch with the AMD ROCm software stack. If needed, you can pick a different ROCm base image from the [Docker Hub tags page](https://hub.docker.com/r/rocm/pytorch/tags).

### Step3: Install OpenAI Triton

####    Uninstall the old Triton
It is strongly recommended to use the latest version of Triton in your project, because AMD and other vendors frequently update the optimization passes and algorithms in [OpenAI Triton](https://github.com/triton-lang/triton), and these updates can improve your Triton kernel performance.


In [None]:
!pip uninstall -y triton

#### Install OpenAI Triton
The cell below installs Triton with `pip install triton`; the commented-out lines show how to build Triton from source instead. If you run into questions or issues when building Triton, please report them in [Triton Issues](https://github.com/triton-lang/triton/issues).


In [None]:
%%bash
# # Remove existing Triton folder if it exists
# if [ -d "triton" ]; then
#     echo "Removing existing triton directory..."
#     rm -rf triton
# fi

# # Clone Triton repo
# git clone https://github.com/triton-lang/triton.git

# # Install dependencies and Triton from source (non-editable install)
# cd triton
# pip install -r python/requirements.txt
# pip install .

pip install triton
pip install matplotlib
# If you are not using the Docker image, uncomment the line below:
# pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/rocm6.4
pip list | grep -E 'triton|torch'

# Please ignore any dependency-incompatibility warnings; they do not affect the examples in this notebook.
# Confirm the installation succeeded by looking for 'Successfully installed triton-xxx'.

## 2. Kernel Development Examples (Vector-Add)

Once Triton is installed, we can check that it works on the AMD GPU. When we run the vector-add sample below, the Triton kernel produces the same result as the equivalent PyTorch operation, which confirms that Triton works on AMD GPUs.

The autotuned vector-add sample below shows how a Triton kernel can be tuned with the `triton.autotune` API, which benchmarks several launch configurations and keeps the fastest one for each value of the tuning key.

### Basic Vector-Add kernel (no autotune, fixed BLOCK_SIZE) 👇

In [None]:
import torch
import triton
import triton.language as tl

# ============================================================
# Version 1: Basic Vector-Add kernel (no autotune, fixed BLOCK_SIZE)
# ============================================================

@triton.jit
def add_kernel_basic(x_ptr,  # *Pointer* to first input vector.
                     y_ptr,  # *Pointer* to second input vector.
                     output_ptr,  # *Pointer* to output vector.
                     n_elements,  # Size of the vector.
                     BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
                     ):
    # There are multiple 'programs' processing different data. We identify which program we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and BLOCK_SIZE of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` is a block of integer indices; adding it to a base pointer
    # (e.g. `x_ptr + offsets`) produces a block of pointers.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements

    # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of BLOCK_SIZE.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Compute output = x + y
    output = x + y

    # Write the results back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
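To build intuition for the launch grid and the mask, the NumPy sketch below emulates `add_kernel_basic` on the CPU by looping over program IDs. It is a mental model only (`cdiv` and `emulate_add_kernel` are hypothetical helper names); on the GPU the programs run in parallel and each block is processed by many lanes at once.

```python
import numpy as np

def cdiv(a: int, b: int) -> int:
    # Ceiling division, the same arithmetic as triton.cdiv.
    return (a + b - 1) // b

def emulate_add_kernel(x: np.ndarray, y: np.ndarray, block_size: int) -> np.ndarray:
    """CPU emulation of add_kernel_basic: one loop iteration per Triton 'program'."""
    n_elements = x.size
    output = np.empty_like(x)
    for pid in range(cdiv(n_elements, block_size)):        # the 1D launch grid
        offsets = pid * block_size + np.arange(block_size)  # like tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements                         # guards the last, partial block
        valid = offsets[mask]                               # masked lanes are never read or written
        output[valid] = x[valid] + y[valid]
    return output

x = np.random.rand(98432).astype(np.float32)
y = np.random.rand(98432).astype(np.float32)
out = emulate_add_kernel(x, y, block_size=1024)
print(cdiv(98432, 1024))  # number of programs in the grid
```

With 98432 elements and BLOCK_SIZE=1024 the grid has 97 programs; the last one covers offsets 98304 to 99327, so only 128 of its lanes are valid and the mask disables the other 896.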

### Autotuned Vector-Add kernel (w/ autotune of BLOCK_SIZE configuration) 👇

In [None]:
# ============================================================
# Version 2: Autotuned Vector-Add kernel
# ============================================================

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 4096}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 2048}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 1024}, num_stages=4, num_warps=4),
    ],
    key=['SIZE_CATEGORY'],
)
@triton.jit
def add_kernel_autotune(x_ptr,  # *Pointer* to first input vector.
                        y_ptr,  # *Pointer* to second input vector.
                        output_ptr,  # *Pointer* to output vector.
                        n_elements,  # Size of the vector.
                        SIZE_CATEGORY: tl.constexpr, # Category of the problem. 
                        BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
                        ):
    # Identify the program ID along the 1D grid.
    pid = tl.program_id(axis=0)

    # Compute start offset for this program.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Create a mask for safe memory operations.
    mask = offsets < n_elements

    # Load inputs from memory.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Perform vector addition.
    output = x + y

    # Store the results.
    tl.store(output_ptr + offsets, output, mask=mask)
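Conceptually, `triton.autotune` runs each config, times it, and caches the fastest one per value of the `key` arguments (here `SIZE_CATEGORY`), so tuning happens only once per problem-size category. The toy decorator below models that idea in plain Python. `toy_autotune` and `blocked_sum` are hypothetical names, and Triton's real autotuner benchmarks GPU launches over repeated runs rather than taking a single wall-clock measurement.

```python
import time

def toy_autotune(configs, key_fn):
    """Toy model of autotuning: time every config once per key and cache the winner.

    Illustration only; this is not Triton's implementation.
    """
    cache = {}

    def decorator(fn):
        def wrapper(*args):
            key = key_fn(*args)
            if key not in cache:
                timings = {}
                for cfg in configs:
                    start = time.perf_counter()
                    fn(*args, **cfg)
                    timings[tuple(sorted(cfg.items()))] = time.perf_counter() - start
                cache[key] = dict(min(timings, key=timings.get))  # fastest config for this key
            return fn(*args, **cache[key])
        wrapper.cache = cache
        return wrapper
    return decorator

@toy_autotune(configs=[{"block_size": 256}, {"block_size": 1024}, {"block_size": 4096}],
              key_fn=lambda xs: 1 << (len(xs) - 1).bit_length())  # size category, like next_power_of_2
def blocked_sum(xs, block_size):
    # Sum the list block by block; every config gives the same answer, only speed differs.
    return sum(sum(xs[i:i + block_size]) for i in range(0, len(xs), block_size))

data = list(range(10_000))
assert blocked_sum(data) == sum(data)
```

As with `SIZE_CATEGORY`, inputs that fall into the same size category reuse the cached choice instead of re-tuning.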

### Wrapper Python function to call the kernel 👇

In [None]:
# ============================================================
# Wrapper function to launch either basic or autotuned kernel
# ============================================================
# We can now use the below function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:

USE_AUTOTUNE = False  # User control switch: set to True to use the autotuned kernel, False for the basic kernel

def add(x: torch.Tensor, y: torch.Tensor):
    # Allocate output tensor
    output = torch.empty_like(x)
    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keyword arguments.
    if USE_AUTOTUNE:
        print("[INFO] Using autotuned kernel...")
        # BLOCK_SIZE is chosen by the autotuner; SIZE_CATEGORY is the autotune cache key.
        size_category = triton.next_power_of_2(n_elements)
        add_kernel_autotune[grid](x, y, output, n_elements, size_category)
    else:
        print("[INFO] Using basic kernel (no autotune)...")
        add_kernel_basic[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    # Return the output tensor.
    return output
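`grid` is a callable rather than a fixed tuple because, with autotuning, BLOCK_SIZE is only known once a config has been selected. A plain-Python sketch of how the number of programs changes with the block size, using the test size 98432 (`cdiv` here is a local stand-in for `triton.cdiv`):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in triton.cdiv.
    return (a + b - 1) // b

n_elements = 98432
# Same shape as the `grid` lambda in add(): it receives the chosen meta-parameters.
grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)

for block_size in (512, 1024, 2048, 4096):
    print(block_size, grid({"BLOCK_SIZE": block_size}))
```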

### Give inputs & Test 👇

In [None]:
# ============================================================
# Test section
# ============================================================
if __name__ == "__main__":
    torch.manual_seed(0)
    DEVICE = triton.runtime.driver.active.get_active_torch_device()
    size = 98432
    x = torch.rand(size, device=DEVICE)
    y = torch.rand(size, device=DEVICE)

    # Compute with PyTorch for reference
    output_torch = x + y

    # Compute with Triton
    output_triton = add(x, y)

    print(output_torch)
    print(output_triton)

    # Compare results
    diff = torch.max(torch.abs(output_torch - output_triton))
    print(f"The maximum difference between torch and triton is {diff}")

    assert torch.allclose(output_torch, output_triton, atol=1e-3, rtol=1e-3), \
        f"Accuracy mismatch! Max diff = {diff}"

    print("✅ Triton kernel vector-add verification success!")

## 3. Kernel Development Hands-on (Online-Softmax)

### 3.1 The Origin of Online-Softmax Algorithm

The softmax function, used in classification CNN models as well as Transformer-based LLMs, converts raw output scores (logits) into probabilities by exponentiating each value and dividing by the sum of all the exponentials. The outputs lie in the range (0, 1) and sum to 1, so they can be interpreted as probabilities. PyTorch implements it as [a standard API](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html).

The function $y = \mathrm{Softmax}(x)$ is defined as:

$$
y_i = \frac{e^{x_i}}{\sum_{j=1}^{V} e^{x_j}} \tag{1}
$$

where $x,y \in \mathbb{R}^V$.
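To see why Equation (1) needs care in practice, here is a direct NumPy transcription (a CPU sketch; `softmax_eq1` is a hypothetical helper name). For large logits, $e^{x_i}$ overflows to infinity (for float32 this happens once $x_i$ exceeds roughly 88.7), and the ratio becomes NaN.

```python
import numpy as np

def softmax_eq1(x: np.ndarray) -> np.ndarray:
    """Direct transcription of Equation (1), with no max subtraction."""
    e = np.exp(x)
    return e / e.sum()

print(softmax_eq1(np.array([1.0, 2.0, 3.0])))        # well-behaved for small logits
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_eq1(np.array([1000.0, 1000.0])))   # exp overflows to inf, inf / inf = nan
```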

#### 3.1.1 Naive version - Safe-Softmax

For numerical stability, we subtract the maximum value of the row vector from each input element before exponentiating; otherwise $e^{x_i}$ can overflow for large logits. The definition becomes:

$$
y_i = 
\frac{e^{\left(x_i - \max_{k=1}^V x_k\right)}}
     {\sum_{j=1}^V e^{\left(x_j - \max_{k=1}^V x_k\right)}} \tag{2}
$$

where $x,y \in \mathbb{R}^V$. This is called `Safe Softmax`.
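Subtracting the maximum does not change the result mathematically. Writing $m = \max_{k=1}^{V} x_k$, the factor $e^{-m}$ cancels between numerator and denominator:

$$
\frac{e^{x_i - m}}{\sum_{j=1}^{V} e^{x_j - m}}
= \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_{j=1}^{V} e^{x_j}}
= \frac{e^{x_i}}{\sum_{j=1}^{V} e^{x_j}}
$$

Every exponent $x_j - m$ is at most 0, so no exponential can overflow, and the denominator is at least 1 because the maximum element contributes $e^0 = 1$.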

Following the safe-softmax definition `(Equation 2)`, we implemented a naive Triton kernel. It uses two for-loops to compute the row maximum and the corresponding sum of exponentials, and a third for-loop to compute the final softmax values, so the kernel makes three passes over each row. The `Safe Softmax` algorithm (from [this paper](https://arxiv.org/pdf/1805.02867)) is:

![safe_softmax](./assets/safe_softmax_algo.png)

Suppose we want to measure this kernel's performance on an 8192x8192 tensor. The calculation is organized as follows:

![softmax_naive](./assets/softmax_naive_50p.png)

- The block size along the column dimension is 256.
- We allocate one program per row of the input tensor, so the grid size is (n_rows,), where n_rows is the number of rows of the input tensor.
- 1st for-loop: each program instance (thread block) scans its row block by block to compute the row maximum $m_k$.
- 2nd for-loop: the program scans the row again, block by block, to compute the denominator (sum of exponentials) $d_j$.
- 3rd for-loop: the program scans the row a third time to compute the final softmax values $y_i$.
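The three passes described above can be emulated on the CPU for a single row. This NumPy sketch (hypothetical helper `safe_softmax_3pass`; slicing stands in for the kernel's mask on the last, partial block) mirrors the loop structure, not the parallel execution:

```python
import numpy as np

def safe_softmax_3pass(row: np.ndarray, block_size: int = 256) -> np.ndarray:
    """Row-wise safe softmax with three passes over column blocks."""
    n_cols = row.size
    # Pass 1: maximum m of the row.
    m = -np.inf
    for start in range(0, n_cols, block_size):
        m = max(m, row[start:start + block_size].max())
    # Pass 2: denominator d = sum_j exp(x_j - m).
    d = 0.0
    for start in range(0, n_cols, block_size):
        d += np.exp(row[start:start + block_size] - m).sum()
    # Pass 3: y_i = exp(x_i - m) / d.
    out = np.empty_like(row)
    for start in range(0, n_cols, block_size):
        out[start:start + block_size] = np.exp(row[start:start + block_size] - m) / d
    return out
```

Each pass re-reads the whole row, which is exactly the memory traffic the online version tries to reduce.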

#### 3.1.2 Online-softmax version 
Triton makes it easy for GPU developers to implement an algorithm. Before tuning a kernel, though, it pays to check whether a more efficient algorithm exists and, if so, to implement that algorithm in the Triton kernel. To reduce the memory traffic caused by the three for-loops of the naive softmax algorithm, [this paper](https://arxiv.org/pdf/1805.02867) proposed the online softmax algorithm:


![online_softmax](./assets/online_softmax_algo.png)


`Algorithm 3` computes the maximum value $m_j$ and the denominator $d_j$ in a single for-loop, which removes one of the full passes over the input that `Algorithm 2` needs.
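A short derivation shows why the running sum must be rescaled whenever the maximum changes. Following the paper's notation, start from $m_0 = -\infty$ and $d_0 = 0$, and define $m_j = \max_{k \le j} x_k$ and $d_j = \sum_{k=1}^{j} e^{x_k - m_j}$. Then:

$$
d_j = \sum_{k=1}^{j-1} e^{x_k - m_j} + e^{x_j - m_j}
    = e^{m_{j-1} - m_j} \sum_{k=1}^{j-1} e^{x_k - m_{j-1}} + e^{x_j - m_j}
    = d_{j-1}\, e^{m_{j-1} - m_j} + e^{x_j - m_j}
$$

After the last element, $m_V$ is the row maximum and $d_V$ equals the denominator of Equation (2). The kernel applies the same update to a whole block $B$ of columns at once: $m_{\text{new}} = \max(m_{\text{old}}, \max_{k \in B} x_k)$ and $d_{\text{new}} = d_{\text{old}}\, e^{m_{\text{old}} - m_{\text{new}}} + \sum_{k \in B} e^{x_k - m_{\text{new}}}$.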

We organize the online calculation as follows:

- The block size along the column dimension is 256.
- We allocate one program per row of the input tensor, so the grid size is (n_rows,), where n_rows is the number of rows of the input tensor.
- 1st for-loop: each program instance (thread block) scans its row block by block, computing the running maximum $m_k$ and the denominator (sum of exponentials) $d_j$ together.
- 2nd for-loop: the program scans the row again to compute the final softmax values $y_i$.
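For comparison, here is the same kind of CPU sketch for the two-pass online version (hypothetical helper `online_softmax_2pass`). The first loop carries the running maximum and rescales the running sum whenever the maximum grows, which is the per-block update described above:

```python
import numpy as np

def online_softmax_2pass(row: np.ndarray, block_size: int = 256) -> np.ndarray:
    """Row-wise softmax with a fused max/sum pass followed by a normalization pass."""
    n_cols = row.size
    m, d = -np.inf, 0.0
    # Pass 1: running max and rescaled running sum (Algorithm 3, lines 4-5, per block).
    for start in range(0, n_cols, block_size):
        block = row[start:start + block_size]
        m_new = max(m, block.max())
        d = d * np.exp(m - m_new) + np.exp(block - m_new).sum()  # exp(-inf) = 0 on the first block
        m = m_new
    # Pass 2: normalize (vectorized here; the kernel does it block by block).
    return np.exp(row - m) / d
```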


### 3.2 Kernel Development Hands-on (Online-Softmax)

### Here are 5 things you need to do

Use `Algorithm 3` above to understand how the calculation is organized. Then:
-  Write your own code at **"YOUR CODE HERE 1️⃣"** in the code cell below to set up the program ID.
-  Write your own code in the 1st for-loop (**"YOUR CODE HERE 2️⃣"**) to set up the mask so that the program only accesses elements whose $offsets$ are less than $n\_cols$.
-  Write your own code in the 1st for-loop (**"YOUR CODE HERE 3️⃣"**) to load the input data from global memory with `tl.load()`. The address combines the row offset (`pid * row_stride`) and the column offsets (`offsets`).
-  Write your own code in the 1st for-loop (**"YOUR CODE HERE 4️⃣"**) to compute the maximum value $m_k$.
-  Write your own code in the 1st for-loop (**"YOUR CODE HERE 5️⃣️"**) to compute the denominator (sum of exponentials) $d_j$.

Tip: see the [Triton language API reference](https://triton-lang.org/main/python-api/triton.language.html).

When your code is complete, run the code cell below to check correctness and performance.


In [None]:
import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

# === Teaching mode flag ===
BEGINNER_MODE = False # True

# === Triton kernel ===
@triton.jit
def online_softmax_kernel(in_ptr, output_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    placeholder = tl.zeros([1], dtype=tl.float32) #for quiz only
    
    # === YOUR CODE HERE 1️⃣ ===: set up the program ID
    # [TODO] Replace placeholder[0] with your expression for 'pid'.
    pid = tl.program_id(0)
    
    in_max = -float('inf')
    in_exp_sum = 0.0

    # --- Step 1: Compute the running max and running sum of exponentials (Algorithm 3, Lines 4-5) ---
    for col_idx in range(0, n_cols, BLOCK_SIZE):
        offsets = col_idx + tl.arange(0, BLOCK_SIZE)
        
        # === YOUR CODE HERE 2️⃣ ===: set up the mask to make program only access elements with offsets less than n_cols 
        # [TODO] Replace placeholder[0] with your expression for 'mask'.
        mask = offsets < n_cols

        # === YOUR CODE HERE 3️⃣ ===: use tl.load() to load input data from global memory
        # [TODO] Replace placeholder[0] with your expression for 'in_data'. It's combination 
        # of rows offset (pid*row_stride) and cols offset (offsets).
        in_data = tl.load(in_ptr + pid * row_stride + offsets, mask=mask, other=-float('inf'))
        
        # === YOUR CODE HERE 4️⃣ ===:  Calculate the max value (use tl.max() api) of current block data
        # according to Line 4 in Algorithm 3, get the new global max value by comparing with the
        # previous global max value in_max (use tl.maximum() api).
        # [TODO] Replace placeholder[0] with your expression for 'in_max_new'.
        in_max_new = tl.maximum(in_max, tl.max(in_data, axis=0))
        
        # === YOUR CODE HERE 5️⃣️ ===: Calculate the sum of exponentials (use tl.exp() and tl.sum() apis)
        # of current block data according to Line 5 in Algorithm 3, and update the global sum of
        # exponentials in_exp_sum according to Line 5 in Algorithm 3
        # [TODO] Replace placeholder[0] with your expression for 'in_exp_sum'.
        in_exp_sum = in_exp_sum * tl.exp(in_max - in_max_new) + tl.sum(tl.exp(in_data - in_max_new), axis=0)

        in_max = in_max_new

    # --- Step 2: Normalize and write output ---
    for col_idx in range(0, n_cols, BLOCK_SIZE):
        offsets = col_idx + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        in_data = tl.load(in_ptr + pid * row_stride + offsets, mask=mask)
        in_exp = tl.exp(in_data - in_max)
        tl.store(output_ptr + pid * row_stride + offsets, in_exp / in_exp_sum, mask=mask)


# === Host-side wrapper ===
def run_online_softmax(x, output_triton):
    """
    Wrap the kernel call.
    If BEGINNER_MODE=True, show hint text before running kernel.
    Capture exceptions and provide teaching hints if code is incomplete.
    """

    if BEGINNER_MODE:
        print("\n🧩 [Beginner Mode Enabled]")
        print("Make sure the following lines in kernel are correctly filled:")
        print("──────────────────────────────────────────────")
        print("1️⃣  pid = tl.program_id(0)")
        print("2️⃣  mask = offsets < n_cols ")
        print("3️⃣  in_data = tl.load(in_ptr + pid * row_stride + offsets, mask=mask, other=-float('inf')) ")
        print("4️⃣  in_max_new = tl.maximum(in_max, tl.max(in_data, axis=0))")
        print("5️⃣  in_exp_sum = in_exp_sum * tl.exp(in_max - in_max_new) + tl.sum(tl.exp(in_data - in_max_new), axis=0)")       
        print("──────────────────────────────────────────────\n")

    n_rows, n_cols = x.shape

    try:
        online_softmax_kernel[(n_rows,)](
            x,
            output_triton,
            x.stride(0),
            n_cols=n_cols,
            BLOCK_SIZE=256
        )
    except Exception as e:
        print("\n❌ [Error Detected During Kernel Launch]")
        print("──────────────────────────────────────────────")
        print(f"⚠️ Exception Type: {type(e).__name__}")
        print(f"💬 Message: {e}")
        print("──────────────────────────────────────────────")
        print("💡 Possible causes:")
        print("  - The placeholders[0] have not been replaced with actual expressions.")
        print("  - Syntax errors or missing parameters in Triton kernel.")
        print("──────────────────────────────────────────────\n")
        if not BEGINNER_MODE:
            print("Please wait for the workshop presenter to introduce the assignment!")
            print("Tip: You can also enable BEGINNER_MODE = True at the top of the script for guided hints.")
        # Prevent crash — just return silently
        return


# === Main test ===
torch.manual_seed(0)
x = torch.randn(8192, 8192, device=DEVICE)
output_torch = torch.softmax(x, dim=-1)
output_triton = torch.empty_like(x)

run_online_softmax(x, output_triton)

# === Check correctness ===
print("\nChecking correctness...")
print("\n")
if torch.allclose(output_torch, output_triton, atol=1e-3, rtol=1e-3):
    print("✅ Success! The output of the Triton kernel-based Online Softmax aligns with the PyTorch version!")
else:
    print("⚠️ Accuracy mismatch detected. Triton output does not match PyTorch output.")
    # print(f"Maximum difference: {torch.max(torch.abs(output_torch - output_triton))}")
print("\n")


### Performance Measurement via 'Latency(ms)' 👇

In [None]:
n_rows, n_cols = x.shape
output_triton = torch.empty_like(x)

# Warm up the no-tune version
online_softmax_kernel[(n_rows,)](
        x,
        output_triton,
        x.stride(0),
        n_cols,
        BLOCK_SIZE=256
)
torch.cuda.empty_cache()

# latency of online softmax kernel 
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
online_softmax_kernel[(n_rows,)](
        x,
        output_triton,
        x.stride(0),
        n_cols,
        BLOCK_SIZE=256
)
end_event.record()

torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
assert torch.allclose(output_torch, output_triton, atol=1e-3, rtol=1e-3), "Accuracy mismatch: The maximum difference between torch and triton is " \
      f'{torch.max(torch.abs(output_torch - output_triton))}'
print("\n")
print(f'Online Softmax Triton Version Elapsed: {elapsed_time_ms:.3f}ms')
print("\n")


### 3.3 Visualize the performance with the perf_report feature and compare with PyTorch softmax


In [None]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import matplotlib.pyplot as plt

DEVICE = triton.runtime.driver.active.get_active_torch_device()


# --- PyTorch built-in softmax (reference) ---
def softmax_torch(x: torch.Tensor, dim=-1):
    """
    Compute softmax using PyTorch built-in function.
    Output is the same shape as input.
    """
    output = F.softmax(x, dim=dim)
    return output
    

# --- Helper to launch the Triton online-softmax kernel ---
def softmax_triton(x: torch.Tensor):
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    online_softmax_kernel[(n_rows,)](
        x,
        output,
        x.stride(0),
        n_cols,
        BLOCK_SIZE=256
    )
    # print("triton output:", output)
    return output

# --- Triton Benchmark ---
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(60, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=["Triton Online-Softmax", "Torch"],  # label name for the lines
        styles=[('blue', 'solid'), ('orange', 'dashdot')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="Online-Softmax Performance Benchmark",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    
    quantiles = [0.5, 0.2, 0.8]
    
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.softmax(x, dim=-1), rep=10, quantiles=quantiles
        )
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: softmax_triton(x), rep=10, quantiles=quantiles
        )
    
    # Effective bandwidth: each element is read once and written once (2 * bytes) per unit time
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)

# --- Run benchmark ---
benchmark.run(show_plots=True, print_data=True)
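The benchmark reports effective bandwidth: it counts one read and one write of every element. As a worked example, take M = 4096, N = 8192, float32 (4-byte) elements, and an illustrative (not measured) time of 0.1 ms; `effective_gbps` is a hypothetical helper that reproduces the benchmark's formula.

```python
def effective_gbps(m: int, n: int, element_size: int, ms: float) -> float:
    # One read plus one write of every element, as in the benchmark's formula.
    bytes_moved = 2 * m * n * element_size
    return bytes_moved * 1e-9 / (ms * 1e-3)

print(effective_gbps(4096, 8192, 4, ms=0.1))  # 268,435,456 bytes in 0.1 ms
```

Because the online-softmax kernel above actually reads each row twice, its real memory traffic is higher than this figure assumes; the metric is still useful for comparing implementations that produce the same output.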


### 3.4 Skim through the advanced Fused-Softmax 
This example demonstrates how to implement a fused softmax kernel in Triton with architecture awareness for AMD ROCm/CDNA backends.

OpenAI Triton provides a reference softmax sample named "fused-softmax". Building on online softmax, it loads each row only once and removes the loop over column blocks, uses more warps per row, and sizes the kernel launch from the GPU's hardware properties to achieve higher occupancy and better performance.


The calculation is conducted like this:

![fused_softmax](./assets/softmax_fused_50p.png)

- It processes an entire row as a single data block (a larger BLOCK_SIZE) and loads the whole row into SRAM once, instead of repeatedly accessing global memory.
- The maximum is computed as a parallel reduction across the data block, which Triton handles efficiently.
- It asks the compiler to use more threads per row by increasing the number of warps, a value that is often tuned for better performance.
- Because each program uses more resources, fewer programs than n_rows are launched, so each program processes one or more rows of the input matrix.
- A warmup run reports the kernel's register usage, and the resulting occupancy estimate determines a suitable number of programs (thread blocks).
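The row loop in the fused kernel (`row_start = tl.program_id(0)`, `row_step = tl.num_programs(0)`) assigns rows to persistent programs in a strided, round-robin fashion. A plain-Python sketch of that assignment with small hypothetical sizes (10 rows, 4 programs; `rows_for_program` is an illustrative name):

```python
def rows_for_program(pid: int, num_programs: int, n_rows: int) -> list[int]:
    # Mirrors: for row_idx in tl.range(row_start, n_rows, row_step)
    return list(range(pid, n_rows, num_programs))

n_rows, num_programs = 10, 4
assignment = {pid: rows_for_program(pid, num_programs, n_rows) for pid in range(num_programs)}
print(assignment)  # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}
```

Every row is handled by exactly one program, and programs differ by at most one row in their workload.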

  
#### Here are 2 things you need to do
1. Understand the methodology of architecture-aware algorithm optimization.
2. Check the performance improvement, then re-run the benchmark and visualization against the PyTorch kernel.


In [None]:
import torch
import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()

def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908', 'gfx950')

@triton.jit
def fused_softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)


### To tune the kernel, we first get some resource properties of our GPU by:

In [None]:
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
print(f"NUM_SM: {NUM_SM}, NUM_REGS: {NUM_REGS}, SIZE_SMEM: {SIZE_SMEM}, WARP_SIZE: {WARP_SIZE}, target: {target}")

### Then we set up the kernel launch configuration

In [None]:
torch.manual_seed(0)
x = torch.randn(8192, 8192, device=DEVICE)
output_torch = torch.softmax(x, dim=-1)
n_rows, n_cols = x.shape
# Allocate output
y = torch.empty_like(x)

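# The block size must be a power of two at least as large as the number of columns in `x`,
# so that a whole row fits in one block. Note that this uses 2 * n_cols, i.e. twice the
# minimum block size triton.next_power_of_2(n_cols).
BLOCK_SIZE = triton.next_power_of_2(n_cols*2)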

# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
num_warps = 8

# Number of software pipelining stages.
num_stages = 4 if SIZE_SMEM > 200000 else 2

print(f"BLOCK_SIZE: {BLOCK_SIZE}, num_warps: {num_warps}, num_stages: {num_stages}")
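`tl.arange` requires a power-of-two length, which is why BLOCK_SIZE is rounded up. The one-line sketch below is a hypothetical re-implementation of the rounding rule for illustration (valid for n >= 1; use `triton.next_power_of_2` in real code):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, for n >= 1.
    return 1 << (n - 1).bit_length()

print(next_power_of_2(8192), next_power_of_2(8193), next_power_of_2(2 * 8192))
```

For n_cols = 8192, `next_power_of_2(n_cols)` already gives 8192, so the doubled argument in the cell above yields a 16384-wide block, half of whose lanes are masked off for this input.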

### Kernel occupancy is limited by register usage. To maximize occupancy, we warm up the kernel to read its register usage and then compute a suitable number of programs.

In [None]:
# pre-compile kernel to get register usage and compute thread occupancy.
kernel = fused_softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared

if is_hip():
    # NUM_REGS represents the number of general-purpose registers. On CDNA architectures this is half of all registers available.
    # However, this is not always the case. In most cases all registers can be used as general-purpose registers.
    # ISA SECTION (3.6.4 for CDNA3)
    # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
    # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
    # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
    # not required to be equal numbers of both types.
    NUM_GPRS = NUM_REGS  # default for non-CDNA HIP targets; without it NUM_GPRS would be undefined below
    if is_cdna():
        NUM_GPRS = NUM_REGS * 2

    # MAX_NUM_THREADS represents the maximum number of resident threads per multiprocessor.
    # Dividing it by WARP_SIZE gives the maximum number of waves that can
    # execute on a CU (multiprocessor) in parallel.
    MAX_NUM_THREADS = properties["max_threads_per_sm"]
    max_num_waves = MAX_NUM_THREADS // WARP_SIZE
    occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
else:
    occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy

num_programs = min(num_programs, n_rows)

print(f"n_regs: {n_regs}, size_smem: {size_smem}, occupancy: {occupancy}, num_programs: {num_programs}")
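The occupancy arithmetic in the HIP branch can be checked by hand. The helper below (hypothetical name `estimate_num_programs_hip`) repeats the same formulas; the inputs in the usage lines are made-up example values, not properties reported by any specific MI GPU.

```python
def estimate_num_programs_hip(num_sm, num_regs, warp_size, max_threads_per_sm,
                              size_smem, n_regs, kernel_smem, num_warps, n_rows, cdna=True):
    # Same formulas as the HIP branch above (hypothetical helper for illustration).
    num_gprs = num_regs * 2 if cdna else num_regs
    max_num_waves = max_threads_per_sm // warp_size
    occupancy = min(num_gprs // warp_size // n_regs, max_num_waves) // num_warps
    occupancy = min(occupancy, size_smem // kernel_smem)   # shared-memory limit
    return min(num_sm * occupancy, n_rows)                 # never more programs than rows

# Made-up inputs: 304 CUs, 65536 registers, wave size 64, 2048 threads per CU,
# 64 KiB shared memory, 1 KiB per program, 8 warps, 8192 rows.
low_regs = estimate_num_programs_hip(304, 65536, 64, 2048, 65536, n_regs=32,
                                     kernel_smem=1024, num_warps=8, n_rows=8192)
high_regs = estimate_num_programs_hip(304, 65536, 64, 2048, 65536, n_regs=128,
                                      kernel_smem=1024, num_warps=8, n_rows=8192)
print(low_regs, high_regs)
```

With 32 registers per thread the wave limit (32 waves per CU) is the binding constraint, giving 4 programs per CU; at 128 registers per thread the register file becomes the limit and occupancy drops to 2 programs per CU.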

### Now everything is ready, we can verify the kernel's correctness and benchmark it, as we did in previous kernel versions:

In [None]:
# Create a number of persistent programs.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
end_event.record()

torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
assert torch.allclose(output_torch, y, atol=1e-2, rtol=1e-2), "Accuracy mismatch:The maximum difference between torch and triton is " \
      f'{torch.max(torch.abs(output_torch - y))}'
print("✅The output of the Triton kernel-based Fused Softmax aligns with the PyTorch version!")
print(f'Fused Softmax Triton Version Elapsed: {elapsed_time_ms:.3f}ms')


### Let's re-run the benchmark and visualization!


In [None]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import matplotlib.pyplot as plt

DEVICE = triton.runtime.driver.active.get_active_torch_device()


# --- PyTorch built-in softmax (reference) ---
def softmax_torch(x: torch.Tensor, dim=-1):
    """
    Compute softmax using PyTorch built-in function.
    Output is the same shape as input.
    """
    output = F.softmax(x, dim=dim)
    return output
    

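# --- Helper to launch the persistent Triton fused-softmax kernel ---
def softmax_triton(x: torch.Tensor):
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    # Size the block for this input so that a whole row fits in one block,
    # and write into `output` (not the global `y`) so the function returns the result.
    block_size = triton.next_power_of_2(n_cols)
    fused_softmax_kernel[(min(num_programs, n_rows), 1, 1)](
        output,
        x,
        x.stride(0),
        output.stride(0),
        n_rows,
        n_cols,
        BLOCK_SIZE=block_size,
        num_stages=num_stages,
        num_warps=num_warps,
    )
    return output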

# --- Triton Benchmark ---
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(60, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=["Triton Fused-Softmax", "Torch Softmax"],  # label name for the lines
        styles=[('blue', 'solid'), ('orange', 'dashdot')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="Fused-Softmax Performance Benchmark",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    
    quantiles = [0.5, 0.2, 0.8]
    
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.softmax(x, dim=-1), rep=10, quantiles=quantiles
        )
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: softmax_triton(x), rep=10, quantiles=quantiles
        )
    
    # Effective bandwidth: each element is read once and written once (2 * bytes) per unit time
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)

# --- Run benchmark ---
benchmark.run(show_plots=True, print_data=True)


## Summary 
Triton simplifies high-performance GPU kernel development. Through this workshop, you have learned how to develop and optimize Triton kernels on AMD GPUs.

To learn more about OpenAI Triton itself, the OpenAI Triton [official documentation](https://triton-lang.org/main/index.html) is a good starting point. You can find more information about Triton on AMD GPUs in [this document](https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/optimizing-triton-kernel.html) and [this blog](https://rocm.blogs.amd.com/software-tools-optimization/kernel-development-optimizations-with-triton-on-/README.html).

We hope that this workshop will encourage you to tune, test, and contribute to Triton on AMD GPUs, and help us shape the future of AI acceleration.   