# Fused Kernels - What started as exploring DLRM

## Abstract

With the focus on getting the most out of hardware, kernel fusion has become a popular optimization technique. Researchers and practitioners sometimes rewrite their code as native CUDA or CPU kernels to get optimal performance, but projects such as torch.compile aim to make this simpler. This talk focuses on generating fused kernels and on how to leverage torch.compile to do that. We shift away from LLMs for a bit and look at recommendation models as large deep learning/AI systems. In the process, we create fused kernels (Triton and CUDA) with the help of torch.compile.

# Setup

## Code and other artifacts

- Lecture code: https://github.com/kapilsh/cuda-mode-lecture
- How to open chrome trace: chrome://tracing
- DLRM Blog Post: https://ai.meta.com/blog/dlrm-an-advanced-open-source-deep-learning-recommendation-model/
- DLRM Paper: https://arxiv.org/pdf/1906.00091
- DLRM github repo: https://github.com/facebookresearch/dlrm
- Criteo Dataset: https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/


## DLRM (Deep Learning Recommendation Model)

### Model Architecture

![DLRM Model](./data/dlrm_model.png)

### System Constraints

![System Constraints](./data/66324023_2056206621174067_2937830378620059648_n.gif)

### Criteo Dataset

- Training dataset with 24 days of ad display and click data (positives: clicked, negatives: not clicked)
- 13 features taking integer values (mostly count features)
- 26 anonymized categorical features
- Corresponding Kaggle competition: https://www.kaggle.com/c/criteo-display-ad-challenge

# Exploring DLRM

In [27]:
import json
import time
from dataclasses import dataclass
from typing import Mapping, List, Dict, Union

import click
import torch
import torch._dynamo
from loguru import logger
from torch import nn, Tensor
from torch.utils.data import DataLoader

from criteo_dataset import CriteoParquetDataset
from model import DenseArch, read_metadata, SparseArch, DenseSparseInteractionLayer, PredictionLayer, Parameters, DLRM

In [1]:
file_path = "./data/sample_criteo_data.parquet"
metadata_path = "./data/sample_criteo_metadata.json"

In [5]:
logger.info("Reading the parquet file {}...".format(file_path))
logger.info("Reading the metadata file {}...".format(metadata_path))

dataset = CriteoParquetDataset(file_path)
data_loader = DataLoader(dataset, batch_size=2, shuffle=False)
labels, dense, sparse = next(iter(data_loader))
logger.info("Labels size: {}".format(labels.size()))
logger.info("Dense size: {}".format(dense.size()))
logger.info("Sparse size: {}".format(sparse.size()))

2024-05-04 14:13:38.111 | INFO     | __main__:<module>:1 - Reading the parquet file ./data/sample_criteo_data.parquet...
2024-05-04 14:13:38.113 | INFO     | __main__:<module>:2 - Reading the metadata file ./data/sample_criteo_metadata.json...
2024-05-04 14:13:40.288 | INFO     | __main__:<module>:7 - Labels size: torch.Size([2])
2024-05-04 14:13:40.288 | INFO     | __main__:<module>:8 - Dense size: torch.Size([2, 13])
2024-05-04 14:13:40.289 | INFO     | __main__:<module>:9 - Sparse size: torch.Size([2, 26])


In [6]:
dense

tensor([[5.0000e+00, 1.1000e+02, 0.0000e+00, 1.6000e+01, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 1.4000e+01, 7.0000e+00, 1.0000e+00, 0.0000e+00, 3.0600e+02,
         0.0000e+00],
        [3.2000e+01, 3.0000e+00, 5.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
         0.0000e+00, 6.1000e+01, 5.0000e+00, 0.0000e+00, 1.0000e+00, 3.1570e+03,
         5.0000e+00]])

In [7]:
sparse

tensor([[1651969401, 3793706328, 2951365679, 2489089999,  951068488, 1875733963,
          897624609,  679512323, 1189011366,  771915201,  209470001, 2509774111,
           12976055, 3192841527, 2316006604, 1289502458, 3523761834, 3088518074,
         2501034507, 3280875304,  351689309,  632402057, 3619814411, 2091868316,
          809724924, 3977271069],
        [3857972621, 2695561126, 1873417685, 3666490401, 1020698403, 1875733963,
         2870406529, 1128426537,  502653268, 2112471209, 1716706404, 2582335015,
           12976055, 3192841527, 4089183897, 1289502458, 3523761834, 2716538129,
         2501034507, 4273985635, 2737978529, 3370249814,  391309800, 1966410890,
         2568167914, 3075991895]])

In [10]:

dense_mlp_out_size = 16
num_dense_features = dense.size()[1]
dense_arch = DenseArch(dense_feature_count=num_dense_features,
                       dense_hidden_layers_sizes=[32],
                       output_size=dense_mlp_out_size)
dense_out = dense_arch(dense)
logger.info("Dense out size: {}".format(dense_out.size()))
dense_out

2024-05-04 14:15:15.944 | INFO     | __main__:<module>:8 - Dense out size: torch.Size([2, 16])


tensor([[ -15.9158,   -3.1166,  -14.7238,  -15.7895,    9.8099,   -7.4153,
            6.1413,  -17.2388,  -17.7460,   36.3442,    9.2208,   17.1685,
          -26.7153,    4.1549,  -27.6369,   13.1371],
        [ -67.6642,  -24.3709, -192.2249, -123.3462,   86.3253,  -45.0116,
          -26.8654, -194.2189, -240.2030,  305.7565,  136.9120,   46.5375,
         -266.4658,   81.3118, -262.6347,  169.4223]],
       grad_fn=<AddmmBackward0>)

In [37]:
metadata = read_metadata(metadata_path)
embedding_size = 16
embedding_sizes = {fn: embedding_size for fn in metadata.keys()}
sparse_mlp_out_size = 16
sparse_arch = SparseArch(metadata=metadata,
                         embedding_sizes=embedding_sizes)
# compiled model hangs on running with inputs
# sparse_arch_optim = torch.compile(sparse_arch)
sparse_out = sparse_arch(sparse)
for v in sparse_out:
    logger.info("Sparse out size: {}".format(v.size()))
sparse_out[0]

2024-05-05 07:12:05.227 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.228 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.228 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.228 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.229 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.229 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.229 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.230 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.230 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 07:12:05.231 | INFO     | __main__:<module>:11 - Sparse

tensor([[ 3.4630e-01,  2.1154e+00, -1.1004e+00, -4.8268e-01, -1.1138e+00,
         -6.5604e-01, -3.5921e-01,  8.3207e-02,  2.1305e-01,  1.5671e-01,
         -2.0035e-01,  1.5515e+00, -2.6357e-01, -7.4086e-01, -1.8739e-03,
          8.1901e-01],
        [-1.2675e-01, -1.2281e+00,  7.4344e-01,  1.9631e+00, -1.3492e-01,
         -1.6420e+00, -5.3306e-01, -1.9332e-01, -3.4592e-01,  7.2285e-01,
         -1.7827e-01,  1.0495e+00, -1.0653e+00, -3.2139e-01, -2.1849e+00,
         -8.9220e-01]], grad_fn=<EmbeddingBackward0>)

In [20]:
dense_sparse_interaction_layer = DenseSparseInteractionLayer()
ds_out = dense_sparse_interaction_layer(dense_out, sparse_out)
logger.info("Dense sparse interaction out size: {}".format(ds_out.size()))
ds_out

2024-05-04 14:18:34.244 | INFO     | __main__:<module>:3 - Dense sparse interaction out size: torch.Size([2, 186624])


tensor([[ 2.5331e+02,  4.9603e+01,  2.3434e+02,  ..., -2.1245e+00,
          2.2212e-01,  9.8313e-01],
        [ 4.5784e+03,  1.6490e+03,  1.3007e+04,  ...,  3.6353e-03,
          7.7478e-03,  2.5760e-04]], grad_fn=<ViewBackward0>)
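
The width 186624 is consistent with a full pairwise product of the concatenated features: 1 dense vector plus 26 embeddings, each of size 16, give 27 × 16 = 432 values, and 432² = 186624. The first printed entries agree with that reading: (-15.9158)² ≈ 253.31 and (-15.9158)(-3.1166) ≈ 49.60. The sketch below reproduces the shape under this assumption. It is not the `DenseSparseInteractionLayer` code from `model.py`, and `outer_interaction` is a hypothetical name.

```python
from typing import List

import torch
from torch import Tensor


def outer_interaction(dense_out: Tensor, sparse_out: List[Tensor]) -> Tensor:
    """Flattened outer product of the concatenated dense and sparse vectors."""
    # (B, 16) + 26 x (B, 16) -> (B, 432)
    concat = torch.cat([dense_out] + sparse_out, dim=1)
    # (B, 432, 1) @ (B, 1, 432) -> (B, 432, 432)
    pairwise = torch.bmm(concat.unsqueeze(2), concat.unsqueeze(1))
    # (B, 432, 432) -> (B, 186624)
    return pairwise.flatten(start_dim=1)


batch_size, vector_size, num_sparse = 2, 16, 26
dense_out = torch.randn(batch_size, vector_size)
sparse_out = [torch.randn(batch_size, vector_size) for _ in range(num_sparse)]
interaction = outer_interaction(dense_out, sparse_out)
print(interaction.shape)
```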

In [29]:
prediction_layer = PredictionLayer(dense_out_size=dense_mlp_out_size,
                                   sparse_out_sizes=[sparse_mlp_out_size] * len(metadata),
                                   hidden_sizes=[16])
pred_out = prediction_layer(ds_out)
logger.info("Prediction out size: {}".format(pred_out.size()))
logger.info("Prediction out value: {}".format(pred_out))

2024-05-04 14:33:00.571 | INFO     | __main__:<module>:5 - Prediction out size: torch.Size([2, 1])
2024-05-04 14:33:00.572 | INFO     | __main__:<module>:6 - Prediction out value: tensor([[0.9976],
        [1.0000]], grad_fn=<SigmoidBackward0>)


# ONNX Model Graph

## Model Graph

![Model Graph](./data/model_graph.png)

# Profiling

### Initial Setup: Simple 2-layer MLP for each triangle in the architecture diagram

### Baseline

> python model_train.py

### Initial Distribution - Naive Implementation of index_hash

![Initial index](./perf_screenshots/pytorch_profile_initial_index_hash.png)
*PyTorch Profiler trace (initial)*

---

> tensorboard --logdir tb_logs --bind_all

![Summary Initial](./perf_screenshots/summary_initial_index_hash.png)
*Initial distribution of ops - summary from tensorboard*
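
A trace and summary like the ones above can be collected with `torch.profiler`. This is a minimal sketch rather than the setup in `model_train.py`: the toy model, batch shape, and step count are placeholders, and the `tb_logs` directory is chosen only to match the `tensorboard` command.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder model: 13 dense features in, 16 values out
model = nn.Sequential(nn.Linear(13, 32), nn.ReLU(), nn.Linear(32, 16)).to(device)
batch = torch.randn(128, 13, device=device)

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(
    activities=activities,
    # skip 1 step, warm up for 1 step, then record 3 steps
    schedule=schedule(wait=1, warmup=1, active=3),
    # writes a Chrome-trace JSON file into ./tb_logs
    on_trace_ready=tensorboard_trace_handler("tb_logs"),
    record_shapes=True,
) as prof:
    for _ in range(5):
        model(batch).sum().backward()
        prof.step()  # tell the profiler that one training step finished
```

Viewing the result in TensorBoard needs the `torch-tb-profiler` plugin; the same JSON files also open in chrome://tracing.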

---

### Tensor.item() takes a lot of running time

- What's going on - what is _local_scalar_dense and why is item() taking so long?
    - https://discuss.pytorch.org/t/tensor-item-takes-a-lot-of-running-time/16683
    - https://discuss.pytorch.org/t/calling-loss-item-is-very-slow/99774 
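
`Tensor.item()` dispatches to `aten::_local_scalar_dense`, which copies a single value from the device to the host. CUDA kernels launch asynchronously, so that copy has to wait for all work already queued on the stream, and the profiler bills the wait to `item()`. A minimal sketch of the effect, independent of `model_train.py` (the matrix size is arbitrary, and on a CPU-only machine the first timing already includes the compute):

```python
import time

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2048, 2048, device=device)
(x @ x).sum().item()  # warm-up so one-time initialization is not timed


def timed(fn):
    """Return (result, elapsed milliseconds) for a single call to fn."""
    start = time.perf_counter()
    result = fn()
    return result, (time.perf_counter() - start) * 1e3


# On CUDA this mostly measures the time to enqueue the kernels, not to run them.
y, launch_ms = timed(lambda: (x @ x).relu().sum())

# item() copies the scalar to the host, so it must wait for the queued kernels.
value, item_ms = timed(lambda: y.item())

print(f"device={device} launch={launch_ms:.3f} ms item={item_ms:.3f} ms")
```

Calling `torch.cuda.synchronize()` before stopping the timer, or running with `CUDA_LAUNCH_BLOCKING=1`, moves the cost back to the op that did the work.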

> CUDA_LAUNCH_BLOCKING=1 python model_train.py

---

### After passing `CUDA_LAUNCH_BLOCKING=1`

![Summary Initial](./perf_screenshots/summary_initial_cuda_launch_blocking.png)
*New distribution of ops after `CUDA_LAUNCH_BLOCKING=1` - summary from tensorboard*

---

![Initial Index Hash Profile](./perf_screenshots/index_hash_profile_1.png)
*Profile initial index hash implementation*

![Index Hash After Improvement](./perf_screenshots/improve_index_hash.png)
*Profile after improvements*


- Index hash is expensive
- It doesn't adapt to new sparse ids
- Can we improve or simplify the hash function?
- Let's just compute a modulus hash based on each feature's cardinality
    - May not represent the data well if ids are distributed non-uniformly across categories (but that's fine for now)
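
The modulus hash from the last bullet can be sketched as one vectorized operation. This is an illustration, not the code in the lecture repo: `modulus_hash` and the `cardinalities` values are hypothetical, and the ids are the first three columns of the `sparse` sample shown earlier.

```python
import torch
from torch import Tensor


def modulus_hash(sparse_ids: Tensor, cardinalities: Tensor) -> Tensor:
    """Map raw ids of shape (B, F) into [0, cardinality_f) for each feature f."""
    # remainder broadcasts the (F,) cardinalities across the batch dimension
    return torch.remainder(sparse_ids, cardinalities)


# First three sparse columns of the two sample rows (int64: values exceed the int32 range)
sparse_ids = torch.tensor([[1651969401, 3793706328, 2951365679],
                           [3857972621, 2695561126, 1873417685]])
# Hypothetical embedding-table sizes, one per feature
cardinalities = torch.tensor([1000, 500, 250])

indices = modulus_hash(sparse_ids, cardinalities)
print(indices)
```

The result can index `nn.Embedding` tables directly. Unseen ids still land in a valid row, at the cost of collisions between ids that share a remainder.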

---

### Using Modulus Hash

![Naive Modulus Hash](./perf_screenshots/naive_modulus_hash.png)
*PyTorch Profiler trace for naive modulus hash*

- Let's check the time it took for each of the previous versions
- Wall time down from 48ms -> 5ms

![Optimized Modulus Hash](./perf_screenshots/optimized_modulus_hash.png)
*PyTorch Profiler trace for optimized modulus hash*

- Down to 3.48ms

[ ] TODO: Add summary table of DLRM wall time

---

Based on the profile, what could be the next thing to look at?

- index_select?

# torch.compile

##  torch.compile DLRM

> TORCH_COMPILE_DEBUG_DIR=/home/ksharma/logs TORCH_LOGS=recompiles,+dynamo,inductor,guards,graph_breaks python model.py

> CUDA_LAUNCH_BLOCKING=1 python model_train.py

- GPU utilization goes up
- memory footprint goes down

## Memory Footprint

### Pre `torch.compile`
![Pre torch.compile](./perf_screenshots/pre_torch_compile_initial.png)

### Post `torch.compile`
![Post torch.compile](./perf_screenshots/post_torch_compile_initial.png)

---

### Chrome Trace after `torch.compile`
![Chrome Trace](./perf_screenshots/pytorch_profile_torch_compile.png)
*PyTorch Profiler trace after `torch.compile`*

---

### Let's look deeper into what's going on

![torch compile triton kernels](./perf_screenshots/torch_compile_triton_kernels.png)
*Custom Triton kernel scheduled on the CUDA stream*

### Increase complexity

Source: https://ai.meta.com/blog/dlrm-an-advanced-open-source-deep-learning-recommendation-model/

```shell
python dlrm_s_pytorch.py --arch-sparse-feature-size=16 --arch-mlp-bot="13-512-256-64-16" --arch-mlp-top="512-256-1" --data-generation=dataset --data-set=kaggle --processed-data-file=./input/kaggle_processed.npz --loss-function=bce --round-targets=True --learning-rate=0.1 --mini-batch-size=128 --print-freq=1024 --print-time
```

### Let's change the model architecture

- --arch-mlp-bot="13-512-256-64-16"
- --arch-mlp-top="512-256-1"

### Eager view

![Full Model Eager View](./perf_screenshots/full_model_eager_view.png)
*Full Eager Model - PyTorch Profiler trace*

- Sparse Arch is now not the biggest piece of the pie
- PredictionLayer is the highest
    - Top MLP and sigmoid  

### `torch.compile` view

![Full Model Torch Compiled](./perf_screenshots/full_model_torch_compiled.png)
*Full `torch.compile` Model - PyTorch Profiler trace*

# torch.compile -> triton code generation 

## Generate triton code

> TORCH_LOGS=output_code CUDA_LAUNCH_BLOCKING=1 python model_train.py

## Inspect

- Prints generated code for you
- Should see `... torch._inductor.graph.__output_code: [INFO] Output code written to: ...`
- Shows source nodes from where the code was generated
- Fused kernels:
    - fused_relu
    - fused_cat
    - fused_embedding
    - fused_sigmoid_squeeze

## Write our own

### Kernel

```python
@triton.jit
def pointwise_add_relu_fusion_512(in_out_ptr0, in_ptr0, XBLOCK : tl.constexpr):
    # Total number of elements in in_out_ptr0: B x N = 128 x 512
    xnumel = 65536
    # Each program handles one contiguous block of XBLOCK elements.
    # For a contiguous 128 x 512 tensor (65536 elements) and XBLOCK = 512,
    # the programs cover the flat ranges [0:512], [512:1024], ...
    # Question: Can you see how torch.compile is allocating blocks here?
    # Below, N = 512 is the row width.
    xoffset = tl.program_id(0) * XBLOCK
    # flat element indices handled by this program
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # mask to guard against reading past the end of the tensor
    xmask = xindex < xnumel
    # x2: flat index into the (B, N) result of dense @ weights
    x2 = xindex
    # the bias is a 1D tensor with only N elements;
    # the modulus maps each flat index to its column, i.e. its bias element
    x0 = xindex % 512
    # load XBLOCK elements of dense @ weights
    tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
    # load the matching bias elements
    tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
    # result = bias + dense @ weights
    tmp2 = tmp0 + tmp1
    # relu: can also use tl.maximum
    tmp3 = triton_helpers.maximum(0, tmp2)
    # write the result back in place
    tl.store(in_out_ptr0 + (x2), tmp3, None)
```
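
The `xindex % 512` step is what broadcasts the bias: for a contiguous `(128, 512)` tensor, flat index `i` sits in column `i % 512`, so `in_ptr0 + x0` always points at the bias element for that column. The same arithmetic in plain PyTorch on the CPU, as a sanity check of the indexing rather than a replacement for the kernel (the random inputs are arbitrary):

```python
import torch

B, N, XBLOCK = 128, 512, 512
x = torch.randn(B, N)  # stands in for dense @ weights
bias = torch.randn(N)

xindex = torch.arange(B * N)  # every flat index the kernel visits
x0 = xindex % N               # column index == bias index
flat_result = torch.clamp_min(x.flatten()[xindex] + bias[x0], 0.0)

# Reference: broadcast add followed by relu
eager_result = torch.relu(x + bias)
torch.testing.assert_close(flat_result.view(B, N), eager_result)

# Number of programs needed to cover the tensor with XBLOCK elements each
num_programs = (B * N + XBLOCK - 1) // XBLOCK
print(num_programs)
```

With `XBLOCK = 512`, each program covers exactly one row of the `(128, 512)` tensor.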

### Test

In [40]:
import triton
import torch
import triton.language as tl
from torch._inductor import triton_helpers
from torch._inductor.triton_heuristics import grid

@triton.jit
def pointwise_add_relu_fusion_512(in_out_ptr0, in_ptr0, XBLOCK : tl.constexpr):
    xnumel = 65536
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    # dense @ weights
    x2 = xindex
    # bias
    x0 = xindex % 512
    tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
    tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
    # bias + dense @ weights
    tmp2 = tmp0 + tmp1
    tmp3 = triton_helpers.maximum(0, tmp2)
    tl.store(in_out_ptr0 + (x2), tmp3, None)


torch.cuda.set_device(0)  # no-op to ensure context
X = torch.ones(size=(128, 512), device='cuda')
print(X[:3, :3])
Y = torch.ones(size=(512,), device='cuda')
print(Y[:3])
eager_result = torch.maximum(X + Y, torch.tensor(0., device='cuda'))
print(eager_result[:3, :3])
pointwise_add_relu_fusion_512[grid(65536)](X, Y, 512)
print(X)
torch.testing.assert_close(X, eager_result, rtol=1e-4, atol=1e-4)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')
tensor([1., 1., 1.], device='cuda:0')
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], device='cuda:0')
tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        ...,
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.]], device='cuda:0')


## Cuda Kernel

### Ask ChatGPT to generate the kernel for us

![Chat GPT Input](./data/chatgpt_input.png)

### ChatGPT output (without any changes)

```cpp
#include <cuda_fp16.h>

__global__ void pointwise_add_relu_fusion_512(float* in_out_ptr0, const float* in_ptr0, const int XBLOCK) {
    const int xnumel = 65536;
    const int N = 512; // Value of N from the Triton kernel
    const int tid = threadIdx.x;
    const int xoffset = blockIdx.x * XBLOCK;
    const int xindex = xoffset + tid;
    const bool xmask = xindex < xnumel;
    
    if (xmask) {
        int x2 = xindex;
        int x0 = xindex % N;
        
        float tmp0 = in_out_ptr0[x2];
        float tmp1 = in_ptr0[x0];
        float tmp2 = tmp0 + tmp1;
        float tmp3 = max(0.0f, tmp2); // ReLU operation
        
        in_out_ptr0[x2] = tmp3;
    }
}
```

### Let's run the generated CUDA kernel

```shell
# Download the libtorch native library
wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.1%2Bcu121.zip
```
