# Fused Kernels - What started as exploring DLRM

## Abstract

With the focus on getting the most performance out of hardware, kernel fusion has become a popular technique. Researchers and practitioners sometimes rewrite their code as native CUDA or CPU kernels to get optimal performance, but projects such as torch.compile aim to make this simpler. This talk focuses on generating fused kernels and on how to leverage torch.compile to do so. We will shift away from the usual LLM discussion and look at recommendation algorithms, which are themselves large deep learning/AI systems. Along the way, we will create fused kernels (Triton and CUDA) with the help of torch.compile.

## Code and other artifacts

- Lecture code: https://github.com/kapilsh/cuda-mode-lecture
- Opening Chrome traces: chrome://tracing
- DLRM Blog Post: https://ai.meta.com/blog/dlrm-an-advanced-open-source-deep-learning-recommendation-model/
- DLRM Paper: https://arxiv.org/pdf/1906.00091
- DLRM github repo: https://github.com/facebookresearch/dlrm
- Criteo Dataset: https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/


## DLRM (Deep Learning Recommendation Model)

### Model Architecture

![DLRM Model](./data/dlrm_model.png)

### System Constraints

![System Constraints](./data/66324023_2056206621174067_2937830378620059648_n.gif)

### Criteo Dataset

- Training dataset with 24 days of ad display and click data (positives: clicked; negatives: not clicked)
- 13 features taking integer values (mostly count features)
- 26 anonymized categorical features
- Corresponding Kaggle competition: https://www.kaggle.com/c/criteo-display-ad-challenge
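To our understanding, each raw Criteo record is one tab-separated line: a 0/1 label, the 13 integer features, then the 26 categorical features as 32-bit hashes written in hexadecimal, with empty fields for missing values. The sketch below parses one such line into the `(label, dense, sparse)` triple used throughout this notebook. The function name `parse_criteo_line` and the "missing means 0" policy are assumptions for illustration, not the behaviour of the repo's `CriteoParquetDataset`, which reads a preprocessed parquet file.

```python
from typing import List, Tuple

NUM_DENSE = 13   # integer (mostly count) features I1..I13
NUM_SPARSE = 26  # anonymized categorical features C1..C26


def parse_criteo_line(line: str) -> Tuple[int, List[float], List[int]]:
    """Hypothetical parser: split one tab-separated record into (label, dense, sparse).

    Assumption: missing fields are empty strings and are replaced by 0.
    """
    # Strip only the newline: trailing empty fields are encoded as trailing tabs.
    fields = line.rstrip("\n").split("\t")
    if len(fields) != 1 + NUM_DENSE + NUM_SPARSE:
        raise ValueError(f"expected 40 fields, got {len(fields)}")
    label = int(fields[0])
    dense = [float(v) if v else 0.0 for v in fields[1:1 + NUM_DENSE]]
    # Categorical values are hex-encoded 32-bit hashes; convert them to ints.
    sparse = [int(v, 16) if v else 0 for v in fields[1 + NUM_DENSE:]]
    return label, dense, sparse


# Synthetic record: label, 13 dense values (one missing), 26 categoricals (25 missing).
example = "\t".join(["1", "5", "110", ""] + ["0"] * 10 + ["68fd1e64"] + [""] * 25)
label, dense, sparse = parse_criteo_line(example)
print(label, len(dense), len(sparse), sparse[0])
```

Converting the hex hashes to integers yields values below 2^32, the same range as the sparse tensor printed further down.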

```python
import json
import time
from dataclasses import dataclass
from typing import Mapping, List, Dict, Union

import click
import torch
import torch._dynamo
from loguru import logger
from torch import nn, Tensor
from torch.utils.data import DataLoader

from criteo_dataset import CriteoParquetDataset
from model import DenseArch, read_metadata, SparseArch, DenseSparseInteractionLayer, PredictionLayer, Parameters, DLRM
```

## Exploring DLRM

```python
file_path = "./data/sample_criteo_data.parquet"
metadata_path = "./data/sample_criteo_metadata.json"
```

```python
logger.info("Reading the parquet file {}...".format(file_path))
logger.info("Reading the metadata file {}...".format(metadata_path))

dataset = CriteoParquetDataset(file_path)
data_loader = DataLoader(dataset, batch_size=2, shuffle=False)
labels, dense, sparse = next(iter(data_loader))
logger.info("Labels size: {}".format(labels.size()))
logger.info("Dense size: {}".format(dense.size()))
logger.info("Sparse size: {}".format(sparse.size()))
```

```
2024-05-04 14:13:38.111 | INFO     | __main__:<module>:1 - Reading the parquet file ./data/sample_criteo_data.parquet...
2024-05-04 14:13:38.113 | INFO     | __main__:<module>:2 - Reading the metadata file ./data/sample_criteo_metadata.json...
2024-05-04 14:13:40.288 | INFO     | __main__:<module>:7 - Labels size: torch.Size([2])
2024-05-04 14:13:40.288 | INFO     | __main__:<module>:8 - Dense size: torch.Size([2, 13])
2024-05-04 14:13:40.289 | INFO     | __main__:<module>:9 - Sparse size: torch.Size([2, 26])
```
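`DataLoader`'s default collate function stacks the per-sample tensors along a new leading batch dimension, which is why a batch of 2 gives shapes `[2]`, `[2, 13]` and `[2, 26]`. A minimal in-memory stand-in for a dataset that returns `(label, dense, sparse)` per row shows the same behaviour; `InMemoryCriteoDataset` and its random contents are hypothetical, not the repo's `CriteoParquetDataset`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class InMemoryCriteoDataset(Dataset):
    """Hypothetical dataset: one (label, dense, sparse) triple per row."""

    def __init__(self, labels: torch.Tensor, dense: torch.Tensor, sparse: torch.Tensor):
        assert labels.shape[0] == dense.shape[0] == sparse.shape[0]
        self.labels, self.dense, self.sparse = labels, dense, sparse

    def __len__(self) -> int:
        return self.labels.shape[0]

    def __getitem__(self, idx: int):
        # Per item: a 0-d label, a [13] dense row and a [26] sparse row.
        return self.labels[idx], self.dense[idx], self.sparse[idx]


num_rows = 8
ds = InMemoryCriteoDataset(
    labels=torch.randint(0, 2, (num_rows,)),
    dense=torch.rand(num_rows, 13),
    sparse=torch.randint(0, 2**32, (num_rows, 26), dtype=torch.int64),
)
# Default collation stacks each field across the batch.
b_labels, b_dense, b_sparse = next(iter(DataLoader(ds, batch_size=2, shuffle=False)))
print(b_labels.size(), b_dense.size(), b_sparse.size())
```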


```python
dense
```

```
tensor([[5.0000e+00, 1.1000e+02, 0.0000e+00, 1.6000e+01, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 1.4000e+01, 7.0000e+00, 1.0000e+00, 0.0000e+00, 3.0600e+02,
         0.0000e+00],
        [3.2000e+01, 3.0000e+00, 5.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
         0.0000e+00, 6.1000e+01, 5.0000e+00, 0.0000e+00, 1.0000e+00, 3.1570e+03,
         5.0000e+00]])
```
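The dense features are raw counts that already span 0 to over 3,000 in these two rows, and here they are fed to the MLP unnormalized. A common remedy, and to our understanding what the reference DLRM Criteo loader does, is a `log(1 + x)` transform. The sketch applies it to the two rows printed above; clamping negatives to 0 first is an assumption made so that `log1p` stays well defined.

```python
import torch

# The two dense rows printed above (13 count features each).
dense_rows = torch.tensor([
    [5.0, 110.0, 0.0, 16.0, 0.0, 1.0, 0.0, 14.0, 7.0, 1.0, 0.0, 306.0, 0.0],
    [32.0, 3.0, 5.0, 0.0, 1.0, 0.0, 0.0, 61.0, 5.0, 0.0, 1.0, 3157.0, 5.0],
])

# Assumption: negative counts are clamped to 0 before the log transform.
normalized = torch.log1p(dense_rows.clamp(min=0))
print(normalized.max())  # log(3158), roughly 8.06
```

Zeros stay zero, and the largest value shrinks by more than two orders of magnitude, which keeps the first linear layer's outputs in a more moderate range.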

```python
sparse
```

```
tensor([[1651969401, 3793706328, 2951365679, 2489089999,  951068488, 1875733963,
          897624609,  679512323, 1189011366,  771915201,  209470001, 2509774111,
           12976055, 3192841527, 2316006604, 1289502458, 3523761834, 3088518074,
         2501034507, 3280875304,  351689309,  632402057, 3619814411, 2091868316,
          809724924, 3977271069],
        [3857972621, 2695561126, 1873417685, 3666490401, 1020698403, 1875733963,
         2870406529, 1128426537,  502653268, 2112471209, 1716706404, 2582335015,
           12976055, 3192841527, 4089183897, 1289502458, 3523761834, 2716538129,
         2501034507, 4273985635, 2737978529, 3370249814,  391309800, 1966410890,
         2568167914, 3075991895]])
```

```python
dense_mlp_out_size = 16
num_dense_features = dense.size()[1]
dense_arch = DenseArch(dense_feature_count=num_dense_features,
                       dense_hidden_layers_sizes=[32],
                       output_size=dense_mlp_out_size)
dense_out = dense_arch(dense)
logger.info("Dense out size: {}".format(dense_out.size()))
dense_out
```

```
2024-05-04 14:15:15.944 | INFO     | __main__:<module>:8 - Dense out size: torch.Size([2, 16])
```

```
tensor([[ -15.9158,   -3.1166,  -14.7238,  -15.7895,    9.8099,   -7.4153,
            6.1413,  -17.2388,  -17.7460,   36.3442,    9.2208,   17.1685,
          -26.7153,    4.1549,  -27.6369,   13.1371],
        [ -67.6642,  -24.3709, -192.2249, -123.3462,   86.3253,  -45.0116,
          -26.8654, -194.2189, -240.2030,  305.7565,  136.9120,   46.5375,
         -266.4658,   81.3118, -262.6347,  169.4223]],
       grad_fn=<AddmmBackward0>)
```
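The `AddmmBackward0` grad_fn shows that the last operation is a linear layer (`torch.addmm`) with no activation after it. Assuming `DenseArch` is a plain MLP (Linear and ReLU blocks followed by a final Linear), a minimal equivalent for `dense_feature_count=13`, `dense_hidden_layers_sizes=[32]` and `output_size=16` might look like this; `SimpleDenseArch` is hypothetical, not the repo's class.

```python
from typing import List

import torch
from torch import nn


class SimpleDenseArch(nn.Module):
    """Hypothetical stand-in for DenseArch: Linear -> ReLU blocks, then a final Linear."""

    def __init__(self, dense_feature_count: int, dense_hidden_layers_sizes: List[int], output_size: int):
        super().__init__()
        layers: List[nn.Module] = []
        in_size = dense_feature_count
        for hidden in dense_hidden_layers_sizes:
            layers += [nn.Linear(in_size, hidden), nn.ReLU()]
            in_size = hidden
        # No activation after the last layer, consistent with the AddmmBackward0 grad_fn above.
        layers.append(nn.Linear(in_size, output_size))
        self.mlp = nn.Sequential(*layers)

    def forward(self, dense: torch.Tensor) -> torch.Tensor:
        return self.mlp(dense)


arch = SimpleDenseArch(dense_feature_count=13, dense_hidden_layers_sizes=[32], output_size=16)
out = arch(torch.rand(2, 13))
print(out.shape, out.grad_fn)
```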

```python
metadata = read_metadata(metadata_path)
embedding_size = 16
embedding_sizes = {fn: embedding_size for fn in metadata.keys()}
sparse_mlp_out_size = 16
sparse_arch = SparseArch(metadata=metadata,
                         embedding_sizes=embedding_sizes)
# The compiled model hangs when run with these inputs, so compilation is disabled here
# sparse_arch_optim = torch.compile(sparse_arch)
sparse_out = sparse_arch(sparse)
for v in sparse_out:
    logger.info("Sparse out size: {}".format(v.size()))
sparse_out[0]
```

```
2024-05-04 14:16:50.027 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.027 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.028 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.028 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.028 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.029 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.029 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.029 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.030 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.030 | INFO     | __main__:<module>:13 - Sparse out size: torch.Size([2, 16])
2024-05-04 14:16:50.031 | INFO
```

```
tensor([[ 1.0933, -0.0496,  0.7733, -1.7307, -1.5215, -1.7066, -1.0758, -1.0418,
          0.0032,  0.9964,  0.9568, -3.0041,  1.6534, -1.3191,  1.5605, -0.5941],
        [ 1.0211,  0.7705, -1.2030,  1.6011,  0.1957,  0.4547, -0.6117, -1.0417,
         -0.7514, -0.2119,  0.4045, -0.3170,  0.4403, -1.0879,  0.8618, -0.3548]],
       grad_fn=<EmbeddingBackward0>)
```
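Each of the 26 outputs is a `[2, 16]` embedding lookup (`EmbeddingBackward0`). Assuming `SparseArch` maps every raw categorical value into `[0, cardinality)` with the `modulus_hash` arithmetic inspected further down, `(x + 1) % cardinality`, and then looks it up in a per-feature `nn.Embedding`, a minimal sketch follows. `SimpleSparseArch` and the made-up cardinalities in `toy_metadata` are hypothetical.

```python
from typing import Dict, List

import torch
from torch import nn


class SimpleSparseArch(nn.Module):
    """Hypothetical stand-in for SparseArch: hash each column, then one embedding table per feature."""

    def __init__(self, metadata: Dict[str, int], embedding_sizes: Dict[str, int]):
        super().__init__()
        self.feature_names = list(metadata.keys())
        self.cardinalities = [metadata[name] for name in self.feature_names]
        self.tables = nn.ModuleList(
            [nn.Embedding(metadata[name], embedding_sizes[name]) for name in self.feature_names]
        )

    def forward(self, sparse: torch.Tensor) -> List[torch.Tensor]:
        outputs = []
        for col, (card, table) in enumerate(zip(self.cardinalities, self.tables)):
            # Same arithmetic as modulus_hash: ids land in [0, card).
            ids = (sparse[:, col] + 1) % card
            outputs.append(table(ids))
        return outputs


toy_metadata = {f"C{i}": 1000 + i for i in range(1, 27)}  # hypothetical cardinalities
arch = SimpleSparseArch(toy_metadata, {name: 16 for name in toy_metadata})
outs = arch(torch.randint(0, 2**32, (2, 26), dtype=torch.int64))
print(len(outs), outs[0].shape, outs[0].grad_fn)
```

The Python loop issues one hash and one lookup per feature, 26 of each per batch, which is the kind of many-small-ops pattern that kernel fusion targets.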

```python
dense_sparse_interaction_layer = DenseSparseInteractionLayer()
ds_out = dense_sparse_interaction_layer(dense_out, sparse_out)
logger.info("Dense sparse interaction out size: {}".format(ds_out.size()))
ds_out
```

```
2024-05-04 14:18:34.244 | INFO     | __main__:<module>:3 - Dense sparse interaction out size: torch.Size([2, 186624])
```

```
tensor([[ 2.5331e+02,  4.9603e+01,  2.3434e+02,  ..., -2.1245e+00,
          2.2212e-01,  9.8313e-01],
        [ 4.5784e+03,  1.6490e+03,  1.3007e+04,  ...,  3.6353e-03,
          7.7478e-03,  2.5760e-04]], grad_fn=<ViewBackward0>)
```
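The interaction width is exactly (16 + 26 × 16)² = 432² = 186,624, and the first printed values match products of the dense outputs: 2.5331e+02 ≈ (-15.9158)² and 4.9603e+01 ≈ (-15.9158)(-3.1166). This is consistent with the layer concatenating the dense output and the 26 embeddings into one 432-wide vector per row and flattening its outer product with itself (`ViewBackward0` fits a final reshape), although that reading is inferred from the outputs rather than from the source. The DLRM paper instead takes dot products between all pairs of the 27 feature vectors and concatenates them with the dense output. Both variants are sketched below with hypothetical function names:

```python
from typing import List

import torch


def outer_product_interaction(dense_out: torch.Tensor, sparse_out: List[torch.Tensor]) -> torch.Tensor:
    """Inferred from the [2, 186624] output: flattened outer product of all features."""
    combined = torch.cat([dense_out] + sparse_out, dim=1)            # [B, 432]
    outer = torch.bmm(combined.unsqueeze(2), combined.unsqueeze(1))  # [B, 432, 432]
    return outer.view(combined.size(0), -1)                          # [B, 186624]


def dot_product_interaction(dense_out: torch.Tensor, sparse_out: List[torch.Tensor]) -> torch.Tensor:
    """DLRM-paper style: pairwise dot products of the 27 vectors, plus the dense output."""
    feats = torch.stack([dense_out] + sparse_out, dim=1)  # [B, 27, 16]
    gram = torch.bmm(feats, feats.transpose(1, 2))        # [B, 27, 27]
    # Keep each unordered pair once (strict upper triangle): 27 * 26 / 2 = 351 pairs.
    i, j = torch.triu_indices(feats.size(1), feats.size(1), offset=1)
    return torch.cat([dense_out, gram[:, i, j]], dim=1)   # [B, 16 + 351]


toy_dense = torch.randn(2, 16)
toy_sparse = [torch.randn(2, 16) for _ in range(26)]
print(outer_product_interaction(toy_dense, toy_sparse).shape)
print(dot_product_interaction(toy_dense, toy_sparse).shape)
```

The outer product variant is 186,624 values wide per row versus 367 for the pairwise dot products, which directly sizes the first layer of the prediction MLP.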

```python
prediction_layer = PredictionLayer(dense_out_size=dense_mlp_out_size,
                                   sparse_out_sizes=[sparse_mlp_out_size] * len(metadata),
                                   hidden_sizes=[16])
pred_out = prediction_layer(ds_out)
logger.info("Prediction out size: {}".format(pred_out.size()))
logger.info("Prediction out value: {}".format(pred_out))
```

```
2024-05-04 14:33:00.571 | INFO     | __main__:<module>:5 - Prediction out size: torch.Size([2, 1])
2024-05-04 14:33:00.572 | INFO     | __main__:<module>:6 - Prediction out value: tensor([[0.9976],
        [1.0000]], grad_fn=<SigmoidBackward0>)
```
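Assuming `PredictionLayer` is an MLP over the flattened interaction vector followed by a sigmoid (the `SigmoidBackward0` grad_fn), its input width follows from `dense_out_size` and `sparse_out_sizes` as (16 + 26 × 16)² = 186,624. With `hidden_sizes=[16]`, the first linear layer alone then holds 186,624 × 16 + 16 = 2,986,000 parameters. The near-1 predictions above are what a sigmoid produces for large logits, which the unnormalized interaction values (up to 1.3007e+04 in the printed row) make likely. `SimplePredictionLayer` is hypothetical, not the repo's class.

```python
from typing import List

import torch
from torch import nn


class SimplePredictionLayer(nn.Module):
    """Hypothetical stand-in: MLP over the flattened outer-product interaction, then sigmoid."""

    def __init__(self, dense_out_size: int, sparse_out_sizes: List[int], hidden_sizes: List[int]):
        super().__init__()
        width = dense_out_size + sum(sparse_out_sizes)
        in_size = width * width  # assumption: outer product of the concatenated features
        layers: List[nn.Module] = []
        for hidden in hidden_sizes:
            layers += [nn.Linear(in_size, hidden), nn.ReLU()]
            in_size = hidden
        layers.append(nn.Linear(in_size, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, interactions: torch.Tensor) -> torch.Tensor:
        # Click probability per row.
        return torch.sigmoid(self.mlp(interactions))


layer = SimplePredictionLayer(dense_out_size=16, sparse_out_sizes=[16] * 26, hidden_sizes=[16])
first_linear = layer.mlp[0]
print(sum(p.numel() for p in first_linear.parameters()))
probs = layer(torch.randn(2, 186624))
print(probs.shape)
```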


## Model Graph

![Model Graph](./data/model_graph.png)

## Profiling

### Initial Setup: Simple Two-Layer MLP for Each MLP Block (the Triangles in the Architecture Diagram)

### Baseline

> python model_train.py
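The traces from the baseline run can be opened in chrome://tracing. How `model_train.py` collects them is not shown here; one way to get a comparable trace is `torch.profiler`, sketched below with a toy model standing in for DLRM and a temporary output path, both assumptions for illustration.

```python
import os
import tempfile

import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

model = nn.Sequential(nn.Linear(13, 32), nn.ReLU(), nn.Linear(32, 16))  # toy stand-in for DLRM
inputs = torch.rand(64, 13)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    # Also record GPU kernels when a GPU is present.
    activities.append(ProfilerActivity.CUDA)
    model, inputs = model.cuda(), inputs.cuda()

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(3):
        model(inputs).sum().backward()

# Per-operator summary, then a JSON trace that chrome://tracing can load.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
trace_path = os.path.join(tempfile.mkdtemp(), "trace.json")
prof.export_chrome_trace(trace_path)
print("trace written to", trace_path)
```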

### Initial Distribution

![Initial Model Parameters Pie](./perf_screenshots/model_parameters_initial_pie.png)
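The pie chart breaks the model's parameters down by component. Summing `numel()` over each top-level submodule is one way to produce such numbers; the helper below is a sketch, and the toy model with `dense`/`sparse`/`prediction` children is hypothetical, sized only to echo the shapes seen earlier.

```python
from typing import Dict

from torch import nn


def parameter_counts(model: nn.Module) -> Dict[str, int]:
    """Number of parameters owned by each top-level child module."""
    return {name: sum(p.numel() for p in child.parameters()) for name, child in model.named_children()}


toy = nn.Module()
toy.dense = nn.Linear(13, 16)              # 13*16 + 16 = 224
toy.sparse = nn.Embedding(1000, 16)        # 1000*16 = 16,000
toy.prediction = nn.Linear(432 * 432, 16)  # 186,624*16 + 16 = 2,986,000
counts = parameter_counts(toy)
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:>10}: {n:>9,d} ({100 * n / total:.1f}%)")
```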

### What's going on - what is local

```python
sparse_arch.index_hash??
```

```
Signature: sparse_arch.index_hash(tensor: torch.Tensor, tokenizer_values: List[int])
Docstring: <no docstring>
Source:
    @staticmethod
    def index_hash(tensor: torch.Tensor, tokenizer_values: List[int]):
        tensor = tensor.reshape(-1, 1)
        tokenizers = torch.tensor(tokenizer_values).reshape(
```

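The `index_hash` source is cut off, but reshaping the input into a column `(-1, 1)` suggests that each value is compared against the whole `tokenizer_values` vocabulary by broadcasting. Under that assumption, the sketch contrasts a broadcast comparison, which materializes an N × V boolean matrix, with `torch.searchsorted` over a sorted vocabulary, which needs no such intermediate. Both functions are hypothetical, assume every value appears in the vocabulary, and reuse a few raw ids from the `sparse` tensor printed earlier.

```python
import torch


def index_lookup_broadcast(values: torch.Tensor, vocab: torch.Tensor) -> torch.Tensor:
    """Hypothetical: position of each value in vocab via an (N, V) equality matrix."""
    matches = values.reshape(-1, 1) == vocab.reshape(1, -1)  # [N, V] booleans
    return matches.float().argmax(dim=1)                     # column of the (single) match


def index_lookup_sorted(values: torch.Tensor, vocab: torch.Tensor) -> torch.Tensor:
    """Hypothetical: same result via binary search over a sorted copy of vocab."""
    sorted_vocab, order = torch.sort(vocab)
    pos = torch.searchsorted(sorted_vocab, values.reshape(-1))
    return order[pos]  # map sorted positions back to positions in the original vocab


vocab = torch.tensor([1875733963, 12976055, 3192841527, 1289502458])
values = torch.tensor([12976055, 1289502458, 1875733963, 3192841527, 12976055])
print(index_lookup_broadcast(values, vocab))
print(index_lookup_sorted(values, vocab))
```

For a batch of N ids and a vocabulary of V entries, the broadcast version allocates N × V booleans, while the sorted version costs O(N log V) and allocates only its outputs.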
```python
sparse_arch.modulus_hash??
```

```
Signature: sparse_arch.modulus_hash(tensor: torch.Tensor, cardinality: int)
Docstring: <no docstring>
Source:
    @staticmethod
    def modulus_hash(tensor: torch.Tensor, cardinality: int):
        return (tensor + 1) % cardinality
File:      ~/dev/git/cuda-mode-lecture/model.py
Type:      function
```

```python
sparse_arch.modulus_hash_opt??
```

```
Signature:
sparse_arch.modulus_hash_opt(
    tensor: torch.Tensor,
    cardinality: torch.Tensor,
)
Docstring: <no docstring>
Source:
    @staticmethod
    def modulus_hash_opt(tensor: torch.Tensor, cardinality: torch.Tensor):
        return (tensor + 1) % cardinality
File:      ~/dev/git/cuda-mode-lecture/model.py
Type:      function
```

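`modulus_hash` and `modulus_hash_opt` compute the same `(tensor + 1) % cardinality`; the only difference is that the `_opt` variant types `cardinality` as a tensor. One thing a tensor cardinality enables, sketched here under the assumption of one cardinality per sparse column, is hashing all 26 columns with a single broadcast expression instead of 26 per-column calls. The cardinality values are made up.

```python
import torch


def modulus_hash(tensor: torch.Tensor, cardinality: int) -> torch.Tensor:
    return (tensor + 1) % cardinality


def modulus_hash_opt(tensor: torch.Tensor, cardinality: torch.Tensor) -> torch.Tensor:
    return (tensor + 1) % cardinality


raw_ids = torch.randint(0, 2**32, (2, 26), dtype=torch.int64)
cardinalities = [1_000 + 37 * i for i in range(26)]  # hypothetical per-feature cardinalities

# 26 separate calls; in eager mode each launches its own add and remainder kernels.
per_column = torch.stack(
    [modulus_hash(raw_ids[:, i], card) for i, card in enumerate(cardinalities)], dim=1
)

# One call over the whole batch: [2, 26] % [26] broadcasts one cardinality per column.
all_at_once = modulus_hash_opt(raw_ids, torch.tensor(cardinalities))

print(torch.equal(per_column, all_at_once))
```

In eager mode the per-column version issues 52 small elementwise ops (plus a stack), while the broadcast version issues two.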
## torch.compile DLRM

> TORCH_COMPILE_DEBUG_DIR=/home/ksharma/logs TORCH_LOGS=recompiles,+dynamo,inductor,guards,graph_breaks python model.py
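The `recompiles` and `guards` log artifacts enabled above report when and why Dynamo recompiles a function. A plausible motivation for the tensor-typed `cardinality` in `modulus_hash_opt` is that Dynamo guards on the values of Python `int` arguments, so a new integer can trigger a recompile, whereas a tensor argument is guarded on its metadata (shape, dtype, device) rather than its contents. The sketch counts compilations with a custom backend; `counting_backend` is a hypothetical helper, and the exact count for the `int` case can differ between PyTorch versions (automatic dynamic shapes typically makes a changing int symbolic after its first recompile).

```python
from typing import Callable, List

import torch
import torch._dynamo
import torch.fx

compile_counts = {"int": 0, "tensor": 0}


def counting_backend(kind: str) -> Callable:
    """Hypothetical backend: count each graph Dynamo hands over, then run it eagerly."""
    def backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        compile_counts[kind] += 1
        return gm.forward
    return backend


def modulus_hash(tensor: torch.Tensor, cardinality: int) -> torch.Tensor:
    return (tensor + 1) % cardinality


def modulus_hash_opt(tensor: torch.Tensor, cardinality: torch.Tensor) -> torch.Tensor:
    return (tensor + 1) % cardinality


torch._dynamo.reset()
hash_int = torch.compile(modulus_hash, backend=counting_backend("int"))
hash_tensor = torch.compile(modulus_hash_opt, backend=counting_backend("tensor"))

ids = torch.randint(0, 2**32, (2, 26), dtype=torch.int64)
for card in (1_000, 2_000, 3_000):
    hash_int(ids, card)                   # a new Python int on every call
    hash_tensor(ids, torch.tensor(card))  # same shape/dtype each call, only the value changes

print(compile_counts)
```

Running the same script with `TORCH_LOGS=recompiles` should print the failing guard for each `int` recompile.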
