# Fused Kernels - What started as exploring DLRM

## Abstract

Kernel fusion is a popular technique for getting the most performance out of hardware. Researchers and practitioners sometimes rewrite their code as native CUDA or CPU kernels to reach optimal performance, but projects such as torch.compile aim to make this simpler. This talk focuses on generating fused kernels and on how to leverage torch.compile to do so. We will step away from LLMs for a bit and look at recommendation algorithms as large deep learning/AI systems. In the process, we will create fused kernels (Triton and CUDA) with the help of torch.compile.

# Setup

## Code and other artifacts

- Lecture code: https://github.com/kapilsh/cuda-mode-lecture
- How to open chrome trace: chrome://tracing
- DLRM Blog Post: https://ai.meta.com/blog/dlrm-an-advanced-open-source-deep-learning-recommendation-model/
- DLRM Paper: https://arxiv.org/pdf/1906.00091
- DLRM github repo: https://github.com/facebookresearch/dlrm
- Criteo Dataset: https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/


## DLRM (Deep Learning Recommendation Model)

### MODEL ARCHITECTURE 

![DLRM Model](./data/dlrm_model.png)

### System Constraints

![System Constraints](./data/66324023_2056206621174067_2937830378620059648_n.gif)

### Criteo Dataset

- Training dataset with 24 days of ad display and click data (positives: clicked, negatives: non-clicked)
- 13 features taking integer values (mostly count features)
- 26 anonymized categorical features
- Corresponding Kaggle competition: https://www.kaggle.com/c/criteo-display-ad-challenge

# Exploring DLRM

In [1]:
import json
import time
from dataclasses import dataclass
from typing import Mapping, List, Dict, Union

import click
import torch
import torch._dynamo
from loguru import logger
from torch import nn, Tensor
from torch.utils.data import DataLoader

from criteo_dataset import CriteoParquetDataset
from model import DenseArch, read_metadata, SparseArch, DenseSparseInteractionLayer, PredictionLayer, Parameters, DLRM

In [2]:
file_path = "./data/sample_criteo_data.parquet"
metadata_path = "./data/sample_criteo_metadata.json"

In [3]:
logger.info("Reading the parquet file {}...".format(file_path))
logger.info("Reading the metadata file {}...".format(metadata_path))

dataset = CriteoParquetDataset(file_path)
data_loader = DataLoader(dataset, batch_size=2, shuffle=False)
labels, dense, sparse = next(iter(data_loader))
logger.info("Labels size: {}".format(labels.size()))
logger.info("Dense size: {}".format(dense.size()))
logger.info("Sparse size: {}".format(sparse.size()))

2024-05-05 21:10:14.301 | INFO     | __main__:<module>:1 - Reading the parquet file ./data/sample_criteo_data.parquet...
2024-05-05 21:10:14.302 | INFO     | __main__:<module>:2 - Reading the metadata file ./data/sample_criteo_metadata.json...
2024-05-05 21:10:15.447 | INFO     | __main__:<module>:7 - Labels size: torch.Size([2])
2024-05-05 21:10:15.448 | INFO     | __main__:<module>:8 - Dense size: torch.Size([2, 13])
2024-05-05 21:10:15.448 | INFO     | __main__:<module>:9 - Sparse size: torch.Size([2, 26])


In [4]:
dense

tensor([[5.0000e+00, 1.1000e+02, 0.0000e+00, 1.6000e+01, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 1.4000e+01, 7.0000e+00, 1.0000e+00, 0.0000e+00, 3.0600e+02,
         0.0000e+00],
        [3.2000e+01, 3.0000e+00, 5.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
         0.0000e+00, 6.1000e+01, 5.0000e+00, 0.0000e+00, 1.0000e+00, 3.1570e+03,
         5.0000e+00]])

In [5]:
sparse

tensor([[1651969401, 3793706328, 2951365679, 2489089999,  951068488, 1875733963,
          897624609,  679512323, 1189011366,  771915201,  209470001, 2509774111,
           12976055, 3192841527, 2316006604, 1289502458, 3523761834, 3088518074,
         2501034507, 3280875304,  351689309,  632402057, 3619814411, 2091868316,
          809724924, 3977271069],
        [3857972621, 2695561126, 1873417685, 3666490401, 1020698403, 1875733963,
         2870406529, 1128426537,  502653268, 2112471209, 1716706404, 2582335015,
           12976055, 3192841527, 4089183897, 1289502458, 3523761834, 2716538129,
         2501034507, 4273985635, 2737978529, 3370249814,  391309800, 1966410890,
         2568167914, 3075991895]])

In [6]:

dense_mlp_out_size = 16
num_dense_features = dense.size()[1]
dense_arch = DenseArch(dense_feature_count=num_dense_features,
                       dense_hidden_layers_sizes=[32],
                       output_size=dense_mlp_out_size)
dense_out = dense_arch(dense)
logger.info("Dense out size: {}".format(dense_out.size()))
dense_out

2024-05-05 21:10:16.415 | INFO     | __main__:<module>:7 - Dense out size: torch.Size([2, 16])


tensor([[  -8.9006,   40.5694,   29.5478,   -8.9894,  -15.8872,   40.5282,
          -45.3048,  -10.0197,  -34.3252,   34.9440,  -20.3214,  -23.4073,
          -27.0077,   14.4200,   16.2891,   23.7067],
        [ -20.1532,  249.8419,  418.2879, -123.5538, -127.3028,  470.2328,
         -476.5699,    3.4968, -396.1133,  421.3827, -245.7733, -277.4859,
         -231.9042,  209.6555,  157.4375,  275.0896]],
       grad_fn=<AddmmBackward0>)
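
The `DenseArch` source lives in the lecture repo and is not shown here. As a rough mental model, the following sketch assumes it is a plain MLP with ReLU between hidden layers and a final linear projection (the trailing `AddmmBackward0` above is consistent with that, but the class name `DenseArchSketch` and its layout are assumptions, not the author's code).

```python
from typing import List

import torch
from torch import nn


class DenseArchSketch(nn.Module):
    """Hypothetical stand-in for DenseArch: an MLP over the dense features."""

    def __init__(self, dense_feature_count: int, hidden_sizes: List[int], output_size: int):
        super().__init__()
        sizes = [dense_feature_count, *hidden_sizes]
        layers: List[nn.Module] = []
        for in_size, out_size in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(in_size, out_size), nn.ReLU()]
        # Assumption: the final projection has no activation.
        layers.append(nn.Linear(sizes[-1], output_size))
        self.mlp = nn.Sequential(*layers)

    def forward(self, dense: torch.Tensor) -> torch.Tensor:
        return self.mlp(dense)


dense_batch = torch.rand(2, 13)  # same shape as the dense batch above
sketch_out = DenseArchSketch(13, [32], 16)(dense_batch)
print(sketch_out.shape)
```

The raw dense features are unnormalized counts (values such as 3157 appear above), which is why the real output contains activations in the hundreds.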

In [7]:
metadata = read_metadata(metadata_path)
embedding_size = 16
embedding_sizes = {fn: embedding_size for fn in metadata.keys()}
sparse_mlp_out_size = 16
sparse_arch = SparseArch(metadata=metadata,
                         embedding_sizes=embedding_sizes)
# compiled model hangs on running with inputs
# sparse_arch_optim = torch.compile(sparse_arch)
sparse_out = sparse_arch(sparse)
for v in sparse_out:
    logger.info("Sparse out size: {}".format(v.size()))
sparse_out[0]

2024-05-05 21:10:18.757 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.758 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.758 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.758 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.759 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.759 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.760 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.760 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.761 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.761 | INFO     | __main__:<module>:11 - Sparse out size: torch.Size([2, 16])
2024-05-05 21:10:18.761 | INFO

tensor([[-1.3973, -0.4610,  2.0486, -0.2744, -0.3556, -0.8468, -0.5945, -1.5288,
          0.1601, -0.1903,  0.1085, -0.4725,  1.2473, -0.8733, -1.9742,  1.7321],
        [-0.0477, -1.7535,  0.2312,  0.4713,  0.9088,  1.1122,  1.4918, -1.7666,
          0.1965, -0.4317,  1.0522, -3.0231, -1.1296,  0.2273,  0.0119, -0.1556]],
       grad_fn=<EmbeddingBackward0>)
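
The output above is a list with one `[batch, 16]` tensor per categorical feature. A minimal sketch of that shape contract, assuming one `nn.Embedding` table per feature and a modulus to fold raw ids into each table's range (the feature names `cat_0 ... cat_25`, the cardinality of 1000, and the class `SparseArchSketch` are all hypothetical):

```python
from typing import Dict, List

import torch
from torch import nn


class SparseArchSketch(nn.Module):
    """Hypothetical stand-in for SparseArch: one embedding table per feature."""

    def __init__(self, cardinalities: Dict[str, int], embedding_size: int):
        super().__init__()
        self.feature_names = list(cardinalities)
        self.cardinalities = cardinalities
        self.tables = nn.ModuleDict(
            {name: nn.Embedding(size, embedding_size) for name, size in cardinalities.items()}
        )

    def forward(self, sparse: torch.Tensor) -> List[torch.Tensor]:
        outputs = []
        for column, name in enumerate(self.feature_names):
            # Raw ids are large hashed values, so fold them into the table's index range.
            indices = sparse[:, column] % self.cardinalities[name]
            outputs.append(self.tables[name](indices))
        return outputs


cardinalities = {f"cat_{i}": 1000 for i in range(26)}  # assumed sizes
raw_ids = torch.randint(0, 2**32, (2, 26))             # int64 ids, like the batch above
embedded = SparseArchSketch(cardinalities, embedding_size=16)(raw_ids)
print(len(embedded), embedded[0].shape)
```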

In [8]:
dense_sparse_interaction_layer = DenseSparseInteractionLayer()
ds_out = dense_sparse_interaction_layer(dense_out, sparse_out)
logger.info("Dense sparse interaction out size: {}".format(ds_out.size()))
ds_out

2024-05-05 21:10:18.983 | INFO     | __main__:<module>:3 - Dense sparse interaction out size: torch.Size([2, 186624])


tensor([[ 7.9221e+01, -3.6109e+02, -2.6299e+02,  ...,  1.2273e-01,
         -2.9036e-02,  5.1416e-03],
        [ 4.0615e+02, -5.0351e+03, -8.4298e+03,  ...,  1.0812e+00,
         -7.4899e-01,  1.1686e+00]], grad_fn=<ViewBackward0>)
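
The output width of 186624 can be reproduced by concatenating the dense vector with the 26 sparse embeddings (16 + 26 × 16 = 432 values) and taking the outer product of that vector with itself (432 × 432 = 186624). The first printed value, 7.9221e+01, matches (-8.9006)², the square of the first dense output. The sketch below is one way to arrive at the same shape; the actual `DenseSparseInteractionLayer` may be written differently.

```python
import torch

batch_size, embedding_size, num_sparse = 2, 16, 26
dense_out = torch.rand(batch_size, embedding_size)
sparse_out = [torch.rand(batch_size, embedding_size) for _ in range(num_sparse)]

# Concatenate dense and sparse vectors: [B, 16 + 26 * 16] = [B, 432]
combined = torch.cat([dense_out, *sparse_out], dim=1)
# Batched outer product of the combined vector with itself: [B, 432, 432]
interactions = torch.bmm(combined.unsqueeze(2), combined.unsqueeze(1))
# Flatten per example: [B, 432 * 432] = [B, 186624]
flat = interactions.view(batch_size, -1)
print(flat.shape)
```

A full outer product stores every pairwise term twice and grows quadratically with the concatenated width, which is part of why the prediction layer that consumes it is so wide.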

In [9]:
prediction_layer = PredictionLayer(dense_out_size=dense_mlp_out_size,
                                   sparse_out_sizes=[sparse_mlp_out_size] * len(metadata),
                                   hidden_sizes=[16])
pred_out = prediction_layer(ds_out)
logger.info("Prediction out size: {}".format(pred_out.size()))
logger.info("Prediction out value: {}".format(pred_out))

2024-05-05 21:10:21.721 | INFO     | __main__:<module>:5 - Prediction out size: torch.Size([2, 1])
2024-05-05 21:10:21.721 | INFO     | __main__:<module>:6 - Prediction out value: tensor([[0.0213],
        [0.0000]], grad_fn=<SigmoidBackward0>)


# ONNX Model Graph

## Model Graph

![Model Graph](./data/model_graph.png)

# Profiling

### Initial Setup: Simple 2-layer MLP used for each triangle

### Baseline

> python model_train.py

### Initial Distribution - Naive Implementation of index_hash

![Initial index](./perf_screenshots/pytorch_profile_initial_index_hash.png)
*PyTorch Profiler trace (initial)*
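
Traces like the one above come from the PyTorch profiler. The profiling code in `model_train.py` is not shown, so the following is only a minimal sketch of the general pattern on a toy MLP (the model, the `forward_backward` label, and the output file name are placeholders):

```python
from pathlib import Path

import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function

model = nn.Sequential(nn.Linear(13, 32), nn.ReLU(), nn.Linear(32, 16))
inputs = torch.rand(128, 13)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    # Also record GPU kernels when a device is present.
    model, inputs = model.cuda(), inputs.cuda()
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(5):
        # record_function adds a named range that shows up in the trace.
        with record_function("forward_backward"):
            model(inputs).sum().backward()

# Per-op summary, similar in spirit to the tensorboard summary view.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

# Write a file that can be loaded in chrome://tracing.
trace_path = Path("trace_sketch.json")
prof.export_chrome_trace(str(trace_path))
```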

---

> tensorboard --logdir tb_logs --bind_all

![Summary Initial](./perf_screenshots/summary_initial_index_hash.png)
*Initial distribution of ops - summary from tensorboard*

---

### Tensor.item() takes a lot of running time

- What's going on: what is `_local_scalar_dense`, and why is `item()` taking so long?
    - https://discuss.pytorch.org/t/tensor-item-takes-a-lot-of-running-time/16683
    - https://discuss.pytorch.org/t/calling-loss-item-is-very-slow/99774 

> CUDA_LAUNCH_BLOCKING=1 python model_train.py
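
Some background on why `item()` looks expensive: CUDA kernel launches are asynchronous, so Python returns as soon as the work is queued. `Tensor.item()` has to copy a scalar to the host, so it waits for everything queued before it, and the profiler charges that wait to `item()` (which dispatches to `aten::_local_scalar_dense`). Setting `CUDA_LAUNCH_BLOCKING=1` makes each launch synchronous, so time is attributed to the op that actually spent it. A small sketch of the effect, with an arbitrary matrix size; on a CPU-only machine both timings are synchronous and the gap disappears:

```python
import time

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.rand(1024, 1024, device=device)


def timed(fn):
    """Run fn and return (result, elapsed milliseconds of host time)."""
    start = time.perf_counter()
    result = fn()
    return result, (time.perf_counter() - start) * 1e3


# On CUDA this returns once the kernels are queued, not when they finish.
total, launch_ms = timed(lambda: (a @ a).sum())
# .item() needs the value on the host, so it blocks until the queue drains.
value, item_ms = timed(lambda: total.item())

print(f"launch: {launch_ms:.3f} ms, item(): {item_ms:.3f} ms")
```

In a training loop, the usual mitigation is to avoid calling `item()` on every step, for example by accumulating the loss on the device and reading it back only when logging.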

---

### After passing `CUDA_LAUNCH_BLOCKING=1`

![Summary Initial](./perf_screenshots/summary_initial_cuda_launch_blocking.png)
*New distribution of ops after `CUDA_LAUNCH_BLOCKING=1` - summary from tensorboard*

---

![Initial Index Hash Profile](./perf_screenshots/index_hash_profile_1.png)
*Profile initial index hash implementation*

![Index Hash After Improvement](./perf_screenshots/improve_index_hash.png)
*Profile after improvements*


- Index hash seems pretty expensive
- Not adaptive to new sparse ids
- Can we improve/simplify the hash function?
- Let's just calculate the modulus hash based on cardinality
    - May not be representative of the data if the distribution is non-uniform across categories (but that's fine for now)
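
To make the trade-off in the list above concrete, here is a small sketch contrasting a lookup-table style index hash with a modulus hash. The original `index_hash` implementation is not shown, so `build_index_lookup` is a hypothetical stand-in, and the cardinality of 1000 is made up.

```python
from typing import Dict, List, Optional


def build_index_lookup(seen_ids: List[int]) -> Dict[int, int]:
    """Lookup-table hash: every id seen at build time gets its own dense index."""
    return {raw_id: index for index, raw_id in enumerate(sorted(set(seen_ids)))}


def lookup_hash(lookup: Dict[int, int], raw_id: int) -> Optional[int]:
    """Returns None for ids that were not present when the table was built."""
    return lookup.get(raw_id)


def modulus_hash(raw_id: int, cardinality: int) -> int:
    """Modulus hash: any id maps into [0, cardinality), at the cost of collisions."""
    return raw_id % cardinality


seen = [1651969401, 3857972621]  # first sparse column of the batch shown earlier
lookup = build_index_lookup(seen)

print(lookup_hash(lookup, 1651969401))  # 0
print(lookup_hash(lookup, 42))          # None: unseen id has no slot
print(modulus_hash(1651969401, 1000))   # 401
print(modulus_hash(42, 1000))           # 42
```

The modulus version needs no table and is a single elementwise op on a tensor, but distinct ids that share a remainder end up sharing an embedding row.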

---

### Using Modulus Hash

![Naive Modulus Hash](./perf_screenshots/naive_modulus_hash.png)
*PyTorch Profiler trace for naive modulus hash*

- Let's check the time it took for each of the previous versions
- Wall time down from 48ms -> 5ms

![Optimized Modulus Hash](./perf_screenshots/optimized_modulus_hash.png)
*PyTorch Profiler trace for optimized modulus hash*

- Down to 3.48ms

[ ] TODO: Add summary table of DLRM wall time

---

Based on the profile, what could be the next thing to look at?

- index_select?

# torch.compile

##  torch.compile DLRM

> TORCH_COMPILE_DEBUG_DIR=/home/ksharma/logs TORCH_LOGS=recompiles,+dynamo,inductor,guards,graph_breaks python model.py

> CUDA_LAUNCH_BLOCKING=1 python model_train.py

- GPU utilization goes up
- memory footprint goes down

## Memory Footprint

### Pre `torch.compile`
![Pre torch.compile](./perf_screenshots/pre_torch_compile_initial.png)

### Post `torch.compile`
![Post torch.compile](./perf_screenshots/post_torch_compile_initial.png)

---

### Chrome Trace after `torch.compile`
![Chrome Trace](./perf_screenshots/pytorch_profile_torch_compile.png)
*PyTorch Profiler trace after `torch.compile`*

---

### Let's look deeper into what's going on

![torch compile triton kernels](./perf_screenshots/torch_compile_triton_kernels.png)
*Custom triton kernel scheduled on the cuda stream*

### Increase complexity

Source: https://ai.meta.com/blog/dlrm-an-advanced-open-source-deep-learning-recommendation-model/

```shell
python dlrm_s_pytorch.py --arch-sparse-feature-size=16 --arch-mlp-bot="13-512-256-64-16" --arch-mlp-top="512-256-1" --data-generation=dataset --data-set=kaggle --processed-data-file=./input/kaggle_processed.npz --loss-function=bce --round-targets=True --learning-rate=0.1 --mini-batch-size=128 --print-freq=1024 --print-time
```

### Let's change the model architecture

- --arch-mlp-bot="13-512-256-64-16"
- --arch-mlp-top="512-256-1"
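
In the DLRM repo these flags are dash-separated layer widths. As a quick sanity check on how much larger the bottom MLP becomes, the sketch below parses a spec string and counts the weights and biases of the resulting fully connected layers. The helper names are made up for illustration, and `"13-32-16"` is the small dense MLP used earlier in this notebook written in the same notation.

```python
from typing import List


def parse_layer_sizes(spec: str) -> List[int]:
    """Turn a spec such as '13-512-256-64-16' into [13, 512, 256, 64, 16]."""
    return [int(size) for size in spec.split("-")]


def count_linear_parameters(sizes: List[int]) -> int:
    """Weights plus biases for a stack of fully connected layers."""
    return sum(
        in_size * out_size + out_size
        for in_size, out_size in zip(sizes[:-1], sizes[1:])
    )


small_bottom = parse_layer_sizes("13-32-16")
large_bottom = parse_layer_sizes("13-512-256-64-16")

print(count_linear_parameters(small_bottom))  # 976
print(count_linear_parameters(large_bottom))  # 155984
```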

### Eager view

![Full Model Eager View](./perf_screenshots/full_model_eager_view.png)
*Full Eager Model - PyTorch Profiler trace*

- Sparse Arch is no longer the biggest piece of the pie
- PredictionLayer is now the largest
    - Top MLP and sigmoid  

### `torch.compile` view

![Full Model Torch Compiled](./perf_screenshots/full_model_torch_compiled.png)
*Full `torch.compile` Model - PyTorch Profiler trace*

# torch.compile -> triton code generation 

## Generate triton code

> TORCH_LOGS=output_code CUDA_LAUNCH_BLOCKING=1 python model_train.py

## Inspect

- Prints generated code for you
- Should see `... torch._inductor.graph.__output_code: [INFO] Output code written to: ...`
- Shows source nodes from where the code was generated
- Fused kernels:
    - fused_relu
    - fused_cat
    - fused_embedding
    - fused_sigmoid_squeeze
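
The same logging can be switched on from Python rather than through the environment variable. The sketch below assumes `torch._logging.set_logs(output_code=True)` as the programmatic counterpart of `TORCH_LOGS=output_code` and uses a toy model, not the lecture's DLRM. Compilation is wrapped in a function because it needs a working Inductor backend (Triton on GPU, a C++ compiler on CPU).

```python
import torch
from torch import nn


def build_model() -> nn.Module:
    # Linear -> ReLU gives Inductor a bias-add + relu pair it can fuse.
    return nn.Sequential(nn.Linear(13, 512), nn.ReLU(), nn.Linear(512, 1), nn.Sigmoid())


def run_compiled(batch_size: int = 128) -> torch.Tensor:
    """Compile the toy model and log the generated code while doing so."""
    import torch._logging

    # Assumed equivalent of running with TORCH_LOGS=output_code.
    torch._logging.set_logs(output_code=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_model().to(device)
    compiled = torch.compile(model)
    return compiled(torch.rand(batch_size, 13, device=device))


# Eager reference run; call run_compiled() separately to see the generated code.
eager_out = build_model()(torch.rand(128, 13))
print(eager_out.shape)
```

Calling `run_compiled()` on a GPU machine should log the generated Triton source along with the path it was written to.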

## Write our own

### Kernel

```python
@triton.jit
def pointwise_add_relu_fusion_512(in_out_ptr0, in_ptr0, XBLOCK : tl.constexpr):
    # Number of elements in in_out_ptr0 (B x N = 128 x 512)
    xnumel = 65536
    # Each program instance handles one contiguous chunk of XBLOCK elements.
    # For a flattened 128 x 512 tensor (65536 elements) and XBLOCK = 512,
    # the programs access elements [0:512], [512:1024], ...
    # Question: Can you see how torch.compile is allocating blocks here?
    # Below, N = 512 is the row length.
    xoffset = tl.program_id(0) * XBLOCK
    # Element offsets handled by this program
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # Mask to guard against out-of-bounds accesses
    xmask = xindex < xnumel
    # x2 indexes the flattened result of dense @ weights: 0:N, N:2N, ...
    x2 = xindex
    # The bias is a 1D tensor with only N elements;
    # the modulus gives us the matching bias index (the column)
    x0 = xindex % 512
    # Load XBLOCK elements of dense @ weights
    tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
    # Load the matching bias elements
    tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
    # result = bias + dense @ weights
    tmp2 = tmp0 + tmp1
    # ReLU: can also use tl.maximum
    tmp3 = triton_helpers.maximum(0, tmp2)
    # Write the result back in place (unmasked: safe here because XBLOCK = 512 divides xnumel)
    tl.store(in_out_ptr0 + (x2), tmp3, None)
```
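
To answer the question in the comments about how blocks are allocated, it can help to replay the index arithmetic on the CPU. The sketch below is not Triton. It emulates each program instance with plain tensor indexing, using the same `xnumel`, `XBLOCK`, and modulus as the kernel above, and checks the result against eager PyTorch. The function `emulate_program` is a made-up helper.

```python
import torch

XNUMEL, N, XBLOCK = 65536, 512, 512


def emulate_program(program_id: int, in_out: torch.Tensor, bias: torch.Tensor) -> None:
    """Mimic one program instance of the kernel on flat CPU tensors."""
    xindex = program_id * XBLOCK + torch.arange(XBLOCK)
    xmask = xindex < XNUMEL
    xindex = xindex[xmask]  # masked-off lanes neither load nor store
    x0 = xindex % N         # column index selects the bias element
    in_out[xindex] = torch.clamp(in_out[xindex] + bias[x0], min=0)


x = torch.randn(128, 512)
bias = torch.randn(512)
expected = torch.relu(x + bias)

flat = x.clone().view(-1)
# Ceil-division gives the number of program instances in a 1D launch grid.
num_programs = (XNUMEL + XBLOCK - 1) // XBLOCK
for pid in range(num_programs):
    emulate_program(pid, flat, bias)

torch.testing.assert_close(flat.view(128, 512), expected)
print(num_programs)  # 128
```

With `XBLOCK` equal to the row length, each of the 128 program instances ends up processing exactly one row of the `128 x 512` input.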

### Test

In [10]:
import triton
import torch
import triton.language as tl
from torch._inductor import triton_helpers
from torch._inductor.triton_heuristics import grid

@triton.jit
def pointwise_add_relu_fusion_512(in_out_ptr0, in_ptr0, XBLOCK : tl.constexpr):
    xnumel = 65536
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    # dense @ weights
    x2 = xindex
    # bias
    x0 = xindex % 512
    tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
    tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
    # bias + dense @ weights
    tmp2 = tmp0 + tmp1
    tmp3 = triton_helpers.maximum(0, tmp2)
    tl.store(in_out_ptr0 + (x2), tmp3, None)


torch.cuda.set_device(0)  # no-op to ensure context
X = torch.ones(size=(128, 512), device='cuda')
print(X[:3, :3])
Y = torch.ones(size=(512,), device='cuda')
print(Y[:3])
eager_result = torch.maximum(X + Y, torch.tensor(0., device='cuda'))
print(eager_result[:3, :3])
pointwise_add_relu_fusion_512[grid(65536)](X, Y, 512)
print(X)
torch.testing.assert_close(X, eager_result, rtol=1e-4, atol=1e-4)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')
tensor([1., 1., 1.], device='cuda:0')
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], device='cuda:0')
tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        ...,
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.]], device='cuda:0')


## CUDA Kernel

### Ask ChatGPT to generate the kernel for us

![Chat GPT Input](./data/chatgpt_input.png)

### ChatGPT output (without any changes)

```cpp
#include <cuda_fp16.h>

__global__ void pointwise_add_relu_fusion_512(float* in_out_ptr0, const float* in_ptr0, const int XBLOCK) {
    const int xnumel = 65536;
    const int N = 512; // Value of N from the Triton kernel
    const int tid = threadIdx.x;
    const int xoffset = blockIdx.x * XBLOCK;
    const int xindex = xoffset + tid;
    const bool xmask = xindex < xnumel;
    
    if (xmask) {
        int x2 = xindex;
        int x0 = xindex % N;
        
        float tmp0 = in_out_ptr0[x2];
        float tmp1 = in_ptr0[x0];
        float tmp2 = tmp0 + tmp1;
        float tmp3 = max(0.0f, tmp2); // ReLU operation
        
        in_out_ptr0[x2] = tmp3;
    }
}
```

### Let's run the generated CUDA kernel

> NOTE: To build against native libtorch, you can download it as shown below or add your conda environment to `$CMAKE_PREFIX_PATH`

```
wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.1%2Bcu121.zip # Download torch native lib
```

### Build the cmake project

In [11]:
! mkdir -p kernels/cmake-build-debug && cd kernels/cmake-build-debug && cmake build .. && make

-- CMake version: 3.22.1
-- Caffe2: CUDA detected: 12.1
-- Caffe2: CUDA nvcc is: /home/ksharma/anaconda3/envs/cuda-learn/bin/nvcc
-- Caffe2: CUDA toolkit directory: /home/ksharma/anaconda3/envs/cuda-learn
-- Caffe2: Header version is: 12.1
-- /home/ksharma/anaconda3/envs/cuda-learn/lib/libnvrtc.so shorthash is c993a6f1
-- USE_CUDNN is set to 0. Compiling without cuDNN support
-- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support
-- Autodetected CUDA architecture(s):  7.5
-- Added CUDA NVCC flags for: -gencode;arch=compute_75,code=sm_75
-- Configuring done
-- Generating done
-- Build files have been written to: /home/ksharma/dev/git/cuda-mode-lecture/kernels/cmake-build-debug
[35m[1mConsolidate compiler generated dependencies of target pointwise_add_relu_fused[0m
[100%] Built target pointwise_add_relu_fused


In [12]:
!./kernels/cmake-build-debug/pointwise_add_relu_fused

Tensor x:
-0.9247 -0.4253 -2.6438  0.1452 -0.1209 -0.5797 -0.6229 -0.3284 -1.0745 -0.3631
-1.6711  2.2655  0.3117 -0.1842  1.2866  1.1820 -0.1271  1.2169  1.4353  1.0605
-0.4941 -1.4244 -0.7244 -1.2973  0.0697 -0.0074  1.8969  0.6878 -0.0779 -0.8373
 1.3506 -0.2879 -0.5965 -0.3283 -0.9086 -0.8059 -0.7407 -0.0504  0.5435  1.5150
 0.0141  0.4532  1.6349  0.7124 -0.1806  1.0252 -1.4622 -0.7554 -0.1836  0.3824
 0.3918 -0.0830  0.8971 -1.1123  0.1116  0.4863 -0.5499 -0.3231 -0.5469  0.9049
 0.2837  0.1210  0.4730 -1.0823 -0.0334 -0.9734  0.9559 -1.1795 -1.0064  0.1160
 0.6852 -0.4124 -0.6738 -0.5404  0.6898 -1.5517  0.3805 -0.0436  0.3597 -0.5043
[ CUDAFloatType{8,10} ]
Tensor y:
 0.1808
-0.5523
 0.9238
-0.7350
 1.3800
 0.8676
 0.1297
-0.9406
 0.8109
 0.8821
[ CUDAFloatType{10} ]
Expected:
 0.0000  0.0000  0.0000  0.0000  1.2591  0.2879  0.0000  0.0000  0.0000  0.5189
 0.0000  1.7132  1.2355  0.0000  2.6666  2.0496  0.0026  0.2763  2.2462  1.9425
 0.0000  0.0000  0.1994  0.0000  1.4497  0.8

### (OR) Run it locally with pytorch utils

In [13]:
cuda_code_file = "./kernels/src/pointwise_add_relu_fused.cu"
header_code_file = "./kernels/src/pointwise_add_relu_fused.cuh"

with open(cuda_code_file) as f:
    cuda_code = "".join([line for line in f.readlines() if not line.startswith("#include")])
    print(cuda_code)

print("----")

with open(header_code_file) as f:
    header_code = "".join([line for line in f.readlines() if not line.startswith("#include")])
    print(header_code)



__global__ void add_relu_fusion_kernel(float* in_out_ptr0, const float* in_ptr0, const int xnumel ,const int XBLOCK) {
    const int tid = threadIdx.x;
    const int xoffset = blockIdx.x * XBLOCK;
    const int xindex = xoffset + tid;
    const bool xmask = xindex < xnumel;

    if (xmask) {
        int x2 = xindex;
        int x0 = xindex % XBLOCK;
        float tmp0 = in_out_ptr0[x2];
        float tmp1 = in_ptr0[x0];
        float tmp2 = tmp0 + tmp1;
        float tmp3 = max(0.0f, tmp2); // ReLU operation

        in_out_ptr0[x2] = tmp3;
    }
}

torch::Tensor add_relu_fusion(torch::Tensor in_out, const torch::Tensor& in) {
    auto sizes = in_out.sizes();
    auto XBLOCK = sizes[1];
    auto numel = in_out.numel();
    dim3 threadsPerBlock(XBLOCK);
    dim3 numBlocks((numel + XBLOCK - 1) / XBLOCK);
    add_relu_fusion_kernel<<<numBlocks, threadsPerBlock>>>(in_out.data_ptr<float>(), in.data_ptr<float>(), numel, XBLOCK);
    cudaDeviceSynchronize();
    return std::move(in_out);
}
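
The host wrapper above uses the row length (`sizes[1]`) both as the number of threads per block and as the modulus for the bias index. The Python sketch below replays that launch arithmetic. CUDA caps a block at 1024 threads on current NVIDIA GPUs, so this scheme only works while the row length stays at or below that limit. `launch_config` is a made-up helper, not part of the extension.

```python
from typing import Tuple

MAX_THREADS_PER_BLOCK = 1024  # per-block thread limit on current NVIDIA GPUs


def launch_config(rows: int, cols: int) -> Tuple[int, int]:
    """Return (num_blocks, threads_per_block) the way add_relu_fusion derives them."""
    threads_per_block = cols  # XBLOCK = sizes[1]
    numel = rows * cols
    # Ceil-division so a partial final block is still launched.
    num_blocks = (numel + threads_per_block - 1) // threads_per_block
    if threads_per_block > MAX_THREADS_PER_BLOCK:
        raise ValueError(
            f"{threads_per_block} threads per block exceeds {MAX_THREADS_PER_BLOCK}"
        )
    return num_blocks, threads_per_block


print(launch_config(128, 512))  # (128, 512)
```

For wider rows, one option is to fix the block size (say 256 threads) and pass the row length to the kernel as a separate argument for the modulus.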


In [14]:
!mkdir -p ./build

mkdir: cannot create directory ‘./build’: File exists


In [15]:
import torch
from torch.utils.cpp_extension import load_inline

cuda_extension = load_inline(
    name='kernel_extension',
    cpp_sources=header_code,
    cuda_sources=cuda_code,
    functions=["add_relu_fusion"],
    with_cuda=True,
    verbose=True,
    extra_cuda_cflags=["-O2"],
    build_directory='./build',
)

Detected CUDA files, patching ldflags
Emitting ninja build file ./build/build.ninja...
Building extension module kernel_extension...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module kernel_extension...


In [16]:
dir(cuda_extension)

['__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'add_relu_fusion']

In [17]:
torch.cuda.set_device(0)  # no-op to ensure context
X = torch.ones(size=(128, 512), device='cuda')
Y = torch.ones(size=(512,), device='cuda')
cuda_extension.add_relu_fusion(X, Y)
print(X)
torch.testing.assert_close(X, eager_result, rtol=1e-4, atol=1e-4)

tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        ...,
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.]], device='cuda:0')


# LoRA Fused Kernels


## LoRA (LOW-RANK ADAPTATION)

<img src="./data/lora.png" width="400"/>

Source: https://arxiv.org/pdf/2106.09685
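
As a reminder of what gets fused, LoRA keeps the pretrained weight frozen and adds a scaled low-rank update next to it. The sketch below writes this out for a single linear layer followed by ReLU and checks it against the merged-weight form. All shapes, the scaling of 0.5, and the row-vector convention (`x @ A @ B`, with `A` projecting down to the rank and `B` projecting back up) are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
batch, d_in, d_out, rank = 8, 16, 8, 2
scaling = 0.5  # plays the role of alpha / r in the paper

x = torch.randn(batch, d_in, dtype=torch.float64)
W = torch.randn(d_in, d_out, dtype=torch.float64)  # frozen pretrained weight
A = torch.randn(d_in, rank, dtype=torch.float64)   # trainable down-projection
B = torch.randn(rank, d_out, dtype=torch.float64)  # trainable up-projection
bias = torch.randn(d_out, dtype=torch.float64)

base = x @ W          # frozen path: [batch, d_out]
update = (x @ A) @ B  # low-rank path: [batch, d_out]
# The elementwise tail (scale, add, add bias, relu) is the part a fused kernel can cover.
out = torch.relu(base + scaling * update + bias)

# Same result with the update merged into the weight.
merged = torch.relu(x @ (W + scaling * (A @ B)) + bias)
torch.testing.assert_close(out, merged)
print(out.shape)
```

Unfused, the scale, the two adds, and the ReLU would each read and write a full `[batch, d_out]` tensor, which is the memory traffic a single fused kernel avoids.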

## Fused Kernels

In [18]:
from lora_on_simple_mlp import *
from kernels.triton_fused_add_mul_relu import * 

### Fused Mul Add Relu

In [19]:
print(triton.__version__)
in_out_tensor, in_tensor, bias = get_inputs(add_manual_size=True)
expected_output = torch.maximum(in_out_tensor + 0.5 * in_tensor + bias, torch.tensor(0., device='cuda'))
print("Input", in_out_tensor)
print("Expected Output", expected_output)

2.2.0
Input tensor([[-0.3104, -0.0343,  0.1756, -2.2804,  0.5039,  0.5596, -0.0750,  0.9691],
        [-0.2357, -0.4582,  0.5661,  1.2851, -1.8667, -0.0312,  1.2433,  1.3689],
        [-1.0753, -0.0158, -1.4481, -1.3089,  0.6980, -0.3300, -0.7708, -0.4946],
        [ 3.1702,  0.0387, -1.5728,  1.2985, -0.4419, -0.6965, -1.4002, -0.0884],
        [-0.2369, -0.5956, -0.0263,  0.0546, -0.7082,  0.0642,  1.2830, -1.8728],
        [ 0.7562, -0.6345,  1.1176,  1.3382,  0.1994,  0.0671, -0.4159,  0.1468],
        [-0.2672, -0.7882,  0.5857, -1.0649,  1.0950,  0.2490, -0.3271, -0.8691],
        [ 0.5576, -0.1883, -0.5894, -1.0192, -1.4553, -0.8599, -1.3645, -0.0069]],
       device='cuda:0', dtype=torch.float64)
Expected Output tensor([[1.3509, 0.0000, 0.3476, 0.0000, 0.0000, 1.6216, 1.0154, 1.7055],
        [0.7115, 0.0000, 0.4937, 2.3202, 0.0000, 0.0000, 2.9328, 1.9187],
        [0.3889, 0.0000, 0.0000, 0.5725, 0.0000, 0.1259, 0.0000, 0.3423],
        [4.9722, 0.0000, 0.0000, 2.0715, 0.0000,

In [20]:
BLOCK_SIZE = 8
grid = lambda meta: (triton.cdiv(in_out_tensor.numel(), meta['BLOCK_SIZE']),)
fused_add_mul_relu[grid](in_out_tensor, bias, in_tensor, in_out_tensor.numel(), BLOCK_SIZE=BLOCK_SIZE)
print("Output 1", in_out_tensor)
torch.testing.assert_close(in_out_tensor, expected_output, rtol=1e-4, atol=1e-4)

Output 1 tensor([[1.3509, 0.0000, 0.3476, 0.0000, 0.0000, 1.6216, 1.0154, 1.7055],
        [0.7115, 0.0000, 0.4937, 2.3202, 0.0000, 0.0000, 2.9328, 1.9187],
        [0.3889, 0.0000, 0.0000, 0.5725, 0.0000, 0.1259, 0.0000, 0.3423],
        [4.9722, 0.0000, 0.0000, 2.0715, 0.0000, 0.0000, 0.0000, 0.3307],
        [1.3148, 0.0000, 0.4223, 1.3483, 0.0000, 0.0000, 2.0066, 0.0000],
        [1.8233, 0.0000, 1.0833, 2.2275, 0.4882, 0.3238, 0.4521, 0.8122],
        [1.1539, 0.0000, 0.3214, 0.0000, 0.0088, 0.0000, 0.2131, 0.1656],
        [2.3094, 0.0000, 0.0000, 0.2560, 0.0000, 0.0000, 0.0000, 0.8519]],
       device='cuda:0', dtype=torch.float64)


In [21]:
in_out_tensor, in_tensor, bias = get_inputs(add_manual_size=True)
num_weights = bias.numel()
fused_add_mul_relu_cleaner[grid](in_out_tensor, bias, in_tensor, num_weights, in_out_tensor.numel(), multiplier=0.5,
                                 BLOCK_SIZE=BLOCK_SIZE)
print("Output 2", in_out_tensor)
torch.testing.assert_close(in_out_tensor, expected_output, rtol=1e-4, atol=1e-4)

Output 2 tensor([[1.3509, 0.0000, 0.3476, 0.0000, 0.0000, 1.6216, 1.0154, 1.7055],
        [0.7115, 0.0000, 0.4937, 2.3202, 0.0000, 0.0000, 2.9328, 1.9187],
        [0.3889, 0.0000, 0.0000, 0.5725, 0.0000, 0.1259, 0.0000, 0.3423],
        [4.9722, 0.0000, 0.0000, 2.0715, 0.0000, 0.0000, 0.0000, 0.3307],
        [1.3148, 0.0000, 0.4223, 1.3483, 0.0000, 0.0000, 2.0066, 0.0000],
        [1.8233, 0.0000, 1.0833, 2.2275, 0.4882, 0.3238, 0.4521, 0.8122],
        [1.1539, 0.0000, 0.3214, 0.0000, 0.0088, 0.0000, 0.2131, 0.1656],
        [2.3094, 0.0000, 0.0000, 0.2560, 0.0000, 0.0000, 0.0000, 0.8519]],
       device='cuda:0', dtype=torch.float64)
