# Understanding fusion: What and Why

One reason to write code in Triton is to "fuse" multiple operations together.
In this section, we'll illustrate what fusion is, why it improves speed/efficiency,
and show some empirical results.

## What does fusion look like?

Suppose we would like to compute $y = x \cdot \mathrm{sigmoid}(x)$ (the "swish" activation), where $x$ occupies 512 MB.
This fits in device memory but far exceeds local (on-chip) memory. The figure below shows two ways to compute it:
"operation at a time" and "fused: block at a time."

<img src="img/x_at_a_time.png" width=1024 alt="A comparison of two execution methods: operation at a time and fused block at a time. Operation at a time first computes sigmoid over the whole tensor, bringing in part of the input at a time and saving the result back. Then it does a multiply over the whole tensor, bringing in part of the input and intermediate result at a time and saving the result back. Fused block at a time brings in a block of input and completes sigmoid and multiply before saving the result back, then repeats for each block of the input.">
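To make the two schedules concrete, here is a minimal pure-Python sketch (not Triton code). Python lists stand in for device memory, and the names `swish_op_at_a_time`, `swish_fused`, and `block_size` are illustrative assumptions. The operation-at-a-time version materializes a full-size intermediate; the fused version loads one block, finishes both operations on it, and stores only the result. A real Triton kernel would run the blocks as parallel program instances rather than as a sequential loop.

```python
import math


def sigmoid(v: float) -> float:
    """Numerically stable logistic sigmoid."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)  # avoids overflow in exp(-v) for very negative v
    return e / (1.0 + e)


def swish_op_at_a_time(x: list[float]) -> list[float]:
    # Pass 1: sigmoid over the whole tensor. On a GPU, the full-size
    # intermediate `s` would be written to device memory.
    s = [sigmoid(v) for v in x]
    # Pass 2: multiply over the whole tensor, reading x and s back in.
    return [a * b for a, b in zip(x, s)]


def swish_fused(x: list[float], block_size: int) -> list[float]:
    y = [0.0] * len(x)
    # Each iteration plays the role of one (hypothetical) program instance.
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]    # load one block
        out = [v * sigmoid(v) for v in block]  # both ops while the block is "local"
        y[start:start + len(block)] = out      # store only the final result
    return y


# Both schedules compute the same values; only their memory traffic differs.
x = [i / 10 - 5 for i in range(100)]
assert swish_op_at_a_time(x) == swish_fused(x, block_size=16)
```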

## Why use fusion?

To compare the efficiency of these two approaches, we can count the number of bytes each one loads from and stores to device memory.

* Operation at a time
  * sigmoid
    * load 512 MB for $x$
    * store 512 MB for the intermediate result
  * multiply
    * load 512 MB for $x$, load 512 MB for $\mathrm{sigmoid}(x)$
    * store 512 MB for the final result
  * total: **1.5 GB load, 1 GB store**
* Fused: block at a time
  * $x \cdot \mathrm{sigmoid}(x)$
    * load 512 MB for $x$
    * store 512 MB for $x \cdot \mathrm{sigmoid}(x)$
  * total: **512 MB load, 512 MB store**
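The same accounting can be written as a small function of the input size. This sketch assumes, as the list does, that every pass reads its inputs from and writes its output to device memory, that the output is the same size as $x$, and that 1 GB = 1024 MB; the function names are hypothetical.

```python
MB = 2**20  # bytes


def traffic_op_at_a_time(x_bytes: int) -> tuple[int, int]:
    """(bytes loaded, bytes stored) when sigmoid and multiply run as separate passes."""
    sigmoid_load, sigmoid_store = x_bytes, x_bytes  # read x, write sigmoid(x)
    mul_load, mul_store = 2 * x_bytes, x_bytes      # read x and sigmoid(x), write y
    return sigmoid_load + mul_load, sigmoid_store + mul_store


def traffic_fused(x_bytes: int) -> tuple[int, int]:
    """(bytes loaded, bytes stored) when each block is loaded once and finished."""
    return x_bytes, x_bytes


for name, traffic in [("operation at a time", traffic_op_at_a_time),
                      ("fused: block at a time", traffic_fused)]:
    load, store = traffic(512 * MB)
    print(f"{name}: {load // MB} MB load, {store // MB} MB store")
```

For a 512 MB input this reports 1536 MB loaded and 1024 MB stored for the first schedule and 512 MB each way for the fused one, so the fused schedule moves 2.5x fewer bytes in total.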

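By keeping data in local memory until it is no longer needed, the fused version moves 1 GB in total instead of 2.5 GB, 2.5x less memory traffic. Because elementwise operations like these do little arithmetic per byte, their speed is limited by memory bandwidth, so less traffic means less time.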

## What do frameworks do?

PyTorch eager mode executes "operation at a time", while compilers (PyTorch graph mode, TorchScript via `torch.jit.script`, and others)
can apply optimizations such as "fused: block at a time". In simple cases this fusion
happens automatically, but in more complex cases it may not, and that is when Triton becomes useful.

## Experiments

To see how this analysis applies in practice, we looked at
three qualitatively different workloads:
- add: a single elementwise operation
- swish: multiple elementwise operations
- softmax: elementwise + aggregation operations

And we looked at four implementations:
- Triton
- Torch (manual) - handwritten using PyTorch
- Torch (manual) JIT - the above, optimized with TorchScript (`torch.jit.script`)
- Torch (builtin) - a call into PyTorch's built-in library function

We used a V100 GPU, ran each configuration 100 times after a warmup, and timed only
the GPU execution portion. Inputs were large enough to keep the device busy.
For Triton, we used the block size that maximized performance for each workload.
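A timing harness in the spirit of this methodology might look like the sketch below. It is an illustration under stated assumptions, not the harness used for these results: it measures wall-clock time with `time.perf_counter`, reports the median of the timed runs (a choice of this sketch), and takes a `sync` callback because GPU kernels launch asynchronously. With PyTorch on a GPU, `torch.cuda.synchronize` could serve as `sync`, although CUDA events would isolate device time more precisely. `best_block_size` shows one way to pick the fastest block size from a list of candidates; `run_with_block` is a hypothetical callable that launches the kernel with a given block size.

```python
import statistics
import time
from typing import Callable, Iterable


def benchmark(fn: Callable[[], object], warmup: int = 10, reps: int = 100,
              sync: Callable[[], None] = lambda: None) -> float:
    """Median wall-clock seconds per call of fn, measured after untimed warmup calls."""
    for _ in range(warmup):  # e.g. triggers JIT compilation and warms caches
        fn()
    sync()
    times = []
    for _ in range(reps):
        start = time.perf_counter()
        fn()
        sync()  # wait for asynchronous device work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def best_block_size(run_with_block: Callable[[int], object], sizes: Iterable[int],
                    **bench_kwargs) -> int:
    """Candidate block size with the lowest median time (run_with_block is hypothetical)."""
    return min(sizes, key=lambda bs: benchmark(lambda: run_with_block(bs), **bench_kwargs))
```

For example, `best_block_size(run, [2**k for k in range(5, 16)], sync=torch.cuda.synchronize)` would sweep block sizes from 32 to 32768, assuming `run` launches the kernel for a given block size.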

### Results

The following chart shows normalized throughput (higher is better) for each of the workloads and implementations. 

<img src="img/benchmarks results.png" width=512 alt="Bar chart showing normalized throughput (higher is better) on 3 benchmarks. Values are listed in the order Triton, Torch (manual), Torch (manual) JIT, Torch (builtin). Add: 1, 0.99. Swish: 0.995, 0.401, 0.997, 1. Softmax: 1, 0.244, 0.390, 0.817"/>
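The chart's values appear to be normalized per workload to the fastest implementation, since each group's maximum is 1. Assuming that convention, normalization is a single division per entry. The function name and the example throughputs below are hypothetical, not measurements from this document.

```python
def normalize_per_workload(results: dict[str, dict[str, float]]) -> dict[str, dict[str, float]]:
    """Divide each implementation's throughput by the best one in the same workload."""
    normalized = {}
    for workload, by_impl in results.items():
        best = max(by_impl.values())
        normalized[workload] = {impl: t / best for impl, t in by_impl.items()}
    return normalized


# Hypothetical raw throughputs (GB/s); not data from these experiments.
example = {"softmax": {"Triton": 800.0, "Torch (builtin)": 600.0}}
print(normalize_per_workload(example))
# {'softmax': {'Triton': 1.0, 'Torch (builtin)': 0.75}}
```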

Here are the interesting observations:
- When there is a single operation (add), there is nothing to fuse, so handwritten PyTorch reaches peak speed.
- When there are multiple elementwise operations (swish), fusion helps. Handwritten PyTorch executes operation at a time, so its speed suffers. In this simple case, the JIT is able to fuse the operations and reach peak speed. The builtin version (`torch.nn.SiLU`, presumably backed by a CUDA kernel) also reaches peak speed.
- When the mix of operations is more complex (softmax), the JIT is less successful: it improves on handwritten PyTorch but doesn't reach Triton's speed. The builtin (`torch.softmax`) doesn't reach Triton's speed either, possibly due to its generality.
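Softmax shows why aggregation makes fusion both harder and more valuable. The sketch below assumes a handwritten, operation-at-a-time row softmax built from five passes (row max, subtract, exp, row sum, divide) over an M×N matrix and counts elements moved; the code behind the "Torch (manual)" results isn't shown here, so this pass structure is an assumption. A fused kernel that keeps a whole row in local memory reads each element once and writes it once; `fused_softmax_row` (a hypothetical helper) shows that per-row computation in plain Python.

```python
import math


def naive_softmax_traffic(m: int, n: int) -> tuple[int, int]:
    """Elements (read, written) by a five-pass, operation-at-a-time row softmax."""
    passes = [
        (m * n, m),          # row max: read x, write one max per row
        (m * n + m, m * n),  # z = x - row_max: read x and the maxes, write z
        (m * n, m * n),      # e = exp(z): read z, write e
        (m * n, m),          # row sum: read e, write one sum per row
        (m * n + m, m * n),  # e / row_sum: read e and the sums, write the result
    ]
    return sum(r for r, _ in passes), sum(w for _, w in passes)


def fused_softmax_traffic(m: int, n: int) -> tuple[int, int]:
    """Elements (read, written) when each row is loaded once and finished locally."""
    return m * n, m * n


def fused_softmax_row(row: list[float]) -> list[float]:
    # With the whole row held locally, max, exp, sum and divide need no
    # intermediate writes back to device memory.
    row_max = max(row)  # subtracting the max keeps exp() from overflowing
    e = [math.exp(v - row_max) for v in row]
    total = sum(e)
    return [v / total for v in e]


reads, writes = naive_softmax_traffic(4096, 4096)
f_reads, f_writes = fused_softmax_traffic(4096, 4096)
print(f"five-pass softmax moves {(reads + writes) / (f_reads + f_writes):.2f}x the fused traffic")
```

In total the five-pass version moves $8MN + 4M$ elements against $2MN$ for the fused one, so for wide rows it moves roughly 4x as much data.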

## Understanding block size

Let's look at how block size affects speed, using the swish example. We fixed a single input size, and we compiled/ran the Triton program with a variety of block sizes.

<img src="img/block size swish.png" width=512 alt="line chart showing effect of block size on speed in Triton for the swish example. Y-axis is normalized throughput, x-axis is block size in powers of two from 32 to 32K. For sizes 32 to 128 the throughput is 0.2, 0.4, and about 0.88 respectively. From size 256 to 16384, peak or near peak throughput. At 32768 throughput drops off to about 0.3">

Throughput climbs as block size grows, stays at or near its peak over a wide range (256 to 16384 in the chart), and then drops off at the largest block size.
- In the climbing regime, there are two possible causes of low throughput:
  - the overhead of launching many instances dominates running time
  - the block size is small relative to a cache line, so memory transfers are inefficient
- In the peak regime, each instance has plenty of work to do, and the device's memory bandwidth is saturated with useful work.
- In the dropping regime, there are two possible causes of low throughput:
  - (more likely) the data needed to compute a block exceeds local memory capacity, in effect reverting execution to operation at a time
  - (less likely) too few instances, so there is not enough parallel work to keep all of the GPU's processing units busy. This is less likely because the total input was large enough to keep the V100's 80 SMs busy: $80 \times 2^{15} < 2^{26}$, so even at the largest block size there are $2^{26} / 2^{15} = 2048$ instances, about 25 per SM.
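The instance-count argument can be checked for every block size in the chart. This sketch assumes a one-dimensional launch with one program instance per block, i.e. a grid of $\lceil n / \text{block size} \rceil$, which is what `triton.cdiv(n, BLOCK_SIZE)` computes in typical Triton launch code. `N` = $2^{26}$ is taken from the inequality above, `N_SMS` = 80 is the V100's SM count, and `num_instances` is a hypothetical helper.

```python
N_SMS = 80  # streaming multiprocessors on a V100
N = 2**26   # total input size, from the inequality above


def num_instances(n: int, block_size: int) -> int:
    """Program instances for a 1D launch: ceil(n / block_size)."""
    return (n + block_size - 1) // block_size


for k in range(5, 16):  # block sizes 32 ... 32768
    bs = 2**k
    grid = num_instances(N, bs)
    print(f"BLOCK_SIZE={bs:>6}: {grid:>9} instances, {grid / N_SMS:>10.1f} per SM")
```

Even the largest block size leaves dozens of instances per SM, which is consistent with the text's view that too little parallelism is the less likely explanation for the drop.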