diff --git a/examples/inference/README.md b/examples/inference_legacy/README.md
similarity index 100%
rename from examples/inference/README.md
rename to examples/inference_legacy/README.md
diff --git a/examples/inference/dlrm_client.py b/examples/inference_legacy/dlrm_client.py
similarity index 100%
rename from examples/inference/dlrm_client.py
rename to examples/inference_legacy/dlrm_client.py
diff --git a/examples/inference/dlrm_packager.py b/examples/inference_legacy/dlrm_packager.py
similarity index 100%
rename from examples/inference/dlrm_packager.py
rename to examples/inference_legacy/dlrm_packager.py
diff --git a/examples/inference/dlrm_predict.py b/examples/inference_legacy/dlrm_predict.py
similarity index 100%
rename from examples/inference/dlrm_predict.py
rename to examples/inference_legacy/dlrm_predict.py
diff --git a/examples/inference/dlrm_predict_single_gpu.py b/examples/inference_legacy/dlrm_predict_single_gpu.py
similarity index 100%
rename from examples/inference/dlrm_predict_single_gpu.py
rename to examples/inference_legacy/dlrm_predict_single_gpu.py
diff --git a/torchrec/inference/CMakeLists.txt b/torchrec/inference/CMakeLists.txt
index 794bb1490..3b3ac2a8d 100644
--- a/torchrec/inference/CMakeLists.txt
+++ b/torchrec/inference/CMakeLists.txt
@@ -4,110 +4,60 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
-project(inference)
+cmake_minimum_required(VERSION 3.8)
-# This step is crucial to ensure that the
-# _REFLECTION, _GRPC_GRPCPP and _PROTOBUF_LIBPROTOBUF variables are set.
-# e.g. ~/gprc/examples/cpp/cmake/common.cmake
-include(${GRPC_COMMON_CMAKE_PATH}/common.cmake)
+project(inference C CXX)
+include($ENV{HOME}/grpc/examples/cpp/cmake/common.cmake)
-# abi and other flags
-if(DEFINED GLIBCXX_USE_CXX11_ABI)
- if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1)
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
- set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1")
- endif()
-endif()
+# Proto file
+get_filename_component(hw_proto "/home/paulzhan/torchrec/torchrec/inference/protos/predictor.proto" ABSOLUTE)
+get_filename_component(hw_proto_path "${hw_proto}" PATH)
-# keep it static for now since folly-shared version is broken
-# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
-# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
+# Generated sources
+set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/predictor.pb.cc")
+set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/predictor.pb.h")
+set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/predictor.grpc.pb.cc")
+set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/predictor.grpc.pb.h")
+add_custom_command(
+ OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
+ COMMAND ${_PROTOBUF_PROTOC}
+ ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
+ --cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
+ -I "${hw_proto_path}"
+ --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
+ "${hw_proto}"
+ DEPENDS "${hw_proto}")
-# dependencies
-find_package(Boost REQUIRED)
-find_package(Torch REQUIRED)
-find_package(folly REQUIRED)
-find_package(gflags REQUIRED)
-
-include_directories(${Torch_INCLUDE_DIRS})
-include_directories(${folly_INCLUDE_DIRS})
-include_directories(${PYTORCH_FMT_INCLUDE_PATH})
+# Include generated *.pb.h files
+include_directories("${CMAKE_CURRENT_BINARY_DIR}")
set(CMAKE_CXX_STANDARD 17)
-# torch deploy library
-add_library(torch_deploy_internal STATIC
- ${DEPLOY_INTERPRETER_PATH}/libtorch_deployinterpreter.o
- ${DEPLOY_SRC_PATH}/deploy.cpp
- ${DEPLOY_SRC_PATH}/loader.cpp
- ${DEPLOY_SRC_PATH}/path_environment.cpp
- ${DEPLOY_SRC_PATH}/elf_file.cpp)
-
-# For python builtins. caffe2_interface_library properly
-# makes use of the --whole-archive option.
-target_link_libraries(torch_deploy_internal PRIVATE
- crypt pthread dl util m z ffi lzma readline nsl ncursesw panelw
-)
-target_link_libraries(torch_deploy_internal
- PUBLIC shm torch ${PYTORCH_LIB_FMT}
-)
-caffe2_interface_library(torch_deploy_internal torch_deploy)
-
-# inference library
-
-# for our own header files
-include_directories(include/)
-include_directories(gen/)
-
-# define our library target
-add_library(inference STATIC
- src/Batching.cpp
- src/BatchingQueue.cpp
- src/GPUExecutor.cpp
- src/ResultSplit.cpp
- src/Exception.cpp
- src/ResourceManager.cpp
-)
-
-# -rdynamic is needed to link against the static library
-target_link_libraries(inference "-Wl,--no-as-needed -rdynamic"
- dl torch_deploy "${TORCH_LIBRARIES}" ${FBGEMM_LIB} ${FOLLY_LIBRARIES}
-)
-
-# for generated protobuf
-
-# grpc headers. e.g. ~/.local/include
-include_directories(${GRPC_HEADER_INCLUDE_PATH})
-
-set(pred_grpc_srcs "gen/torchrec/inference/predictor.grpc.pb.cc")
-set(pred_grpc_hdrs "gen/torchrec/inference/predictor.grpc.pb.h")
-set(pred_proto_srcs "gen/torchrec/inference/predictor.pb.cc")
-set(pred_proto_hdrs "gen/torchrec/inference/predictor.pb.h")
+# Torch + FBGEMM
+find_package(Torch REQUIRED)
+add_library(fbgemm SHARED IMPORTED GLOBAL)
+set_target_properties(fbgemm PROPERTIES IMPORTED_LOCATION ${FBGEMM_LIB})
-add_library(pred_grpc_proto STATIC
- ${pred_grpc_srcs}
- ${pred_grpc_hdrs}
- ${pred_proto_srcs}
- ${pred_proto_hdrs})
-target_link_libraries(pred_grpc_proto
+add_library(hw_grpc_proto STATIC
+ ${hw_grpc_srcs}
+ ${hw_grpc_hdrs}
+ ${hw_proto_srcs}
+ ${hw_proto_hdrs})
+target_link_libraries(hw_grpc_proto
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF})
-# server
+# Server target
add_executable(server server.cpp)
target_link_libraries(server
- inference
- torch_deploy
- pred_grpc_proto
"${TORCH_LIBRARIES}"
- ${FOLLY_LIBRARIES}
- ${PYTORCH_LIB_FMT}
- ${FBGEMM_LIB}
+ fbgemm
+ hw_grpc_proto
${_REFLECTION}
${_GRPC_GRPCPP}
- ${_PROTOBUF_LIBPROTOBUF})
+ ${_PROTOBUF_LIBPROTOBUF}
+)
diff --git a/torchrec/inference/README.md b/torchrec/inference/README.md
index 8958fc72d..0a8e1cc17 100644
--- a/torchrec/inference/README.md
+++ b/torchrec/inference/README.md
@@ -1,116 +1,65 @@
# TorchRec Inference Library (**Experimental** Release)
## Overview
-TorchRec Inference is a C++ library that supports **multi-gpu inference**. The Torchrec library is used to shard models written and packaged in Python via [torch.package](https://pytorch.org/docs/stable/package.html) (an alternative to TorchScript). The [torch.deploy](https://pytorch.org/docs/stable/deploy.html) library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus subverting the GIL.
+TorchRec Inference is a C++ library that supports **GPU inference**. The previous version of this library was built on torch.package and torch.deploy, both of which are now deprecated; all of the previous files live under the `inference_legacy` directory for reference.
-Follow the instructions below to package a DLRM model in Python, run a C++ inference server with the model on a GPU and send requests to said server via a python client.
+TorchRec inference was reauthored with simplicity in mind, while also reflecting the current production environment for RecSys models: torch.fx for graph capture/tracing and TorchScript for model execution in a C++ environment. The solution here is meant to serve as a simple reference and example, not a fully scaled-out solution for production use cases. It demonstrates converting the DLRM model to TorchScript in Python, running a C++ inference server with that model on a GPU, and sending requests to the server via a Python client.
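+
+As a minimal sketch of that flow (a toy module for illustration, not the DLRM model itself):
+
+```
+import torch
+import torch.fx
+
+class Toy(torch.nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.relu(x) + 1
+
+# capture the graph with torch.fx, then compile the traced GraphModule to TorchScript
+gm = torch.fx.symbolic_trace(Toy())
+scripted = torch.jit.script(gm)
+scripted.save("/tmp/toy.pt")  # a TorchScript archive that C++ can load via torch::jit::load
+```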
-## Example
+## Requirements
-C++ 17 is a requirement.
+C++17 is a requirement. GCC must be version 9 or newer; initial testing was done with GCC 9.
### **1. Install Dependencies**
-
-Follow the instructions at: https://github.com/pytorch/pytorch/blob/master/docs/source/deploy.rst to ensure torch::deploy
-is working in your environment. Use the Dockerfile in the docker directory to install all dependencies. Run it via:
-
+1. [gRPC for C++](https://grpc.io/docs/languages/cpp/quickstart/) needs to be installed, with the resulting installation directory being `$HOME/.local`.
+2. Ensure that **the protobuf compiler (protoc) binary being used comes from the gRPC installation above**. That protoc binary lives in `$HOME/.local/bin`, which may not match the system protoc; verify with `which protoc`.
+3. Install PyTorch, FBGEMM, and TorchRec (ideally in a virtual environment):
```
-sudo nvidia-docker build -t torchrec .
-sudo nvidia-docker run -it torchrec:latest
+pip install torch --index-url https://download.pytorch.org/whl/cu121
+pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121
+pip install torchmetrics==1.0.3
+pip install torchrec --index-url https://download.pytorch.org/whl/cu121
```
+
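+As an optional sanity check, the following should run without errors in the environment used above:
+
+```
+# all three packages should import cleanly and report a CUDA-enabled torch build
+import torch, fbgemm_gpu, torchrec
+print(torch.__version__, torch.cuda.is_available())
+```
+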
### **2. Set variables**
Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime.
```
-# provide the cmake prefix path of pytorch, folly, and fmt.
-# fmt and boost are pulled from folly's installation in this example.
-export FOLLY_CMAKE_DIR="~/folly-build/installed/folly/lib/cmake/folly"
-export FMT_CMAKE_DIR="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/cmake/fmt"
-export BOOST_CMAKE_DIR="~/folly-build/installed/boost-4M2ZnvEM4UWTqpsEJRQTB4oejmX3LmgYC9pcBiuVlmA/lib/cmake/Boost-1.78.0"
-
-# provide fmt from pytorch for torch deploy
-export PYTORCH_FMT_INCLUDE_PATH="~/pytorch/third_party/fmt/include/"
-export PYTORCH_LIB_FMT="~/pytorch/build/lib/libfmt.a"
-
-# provide necessary info to link to torch deploy
-export DEPLOY_INTERPRETER_PATH="/pytorch/build/torch/csrc/deploy"
-export DEPLOY_SRC_PATH="~/pytorch/torch/csrc/deploy"
-
-# provide common.cmake from grpc/examples, makes linking to grpc easier
-export GRPC_COMMON_CMAKE_PATH="~/grpc/examples/cpp/cmake"
-export GRPC_HEADER_INCLUDE_PATH="~/.local/include/"
-
-# provide libfbgemm_gpu_py.so to enable fbgemm_gpu c++ operators
-export FBGEMM_LIB="~/anaconda3/envs/inference/lib/python3.8/site-packages/fbgemm_gpu-0.1.0-py3.8-linux-x86_64.egg/fbgemm_gpu/libfbgemm_gpu_py.so"
-
-# provide path to python packages for torch deploy runtime
-export PYTHON_PACKAGES_PATH="~/anaconda3/envs/inference/lib/python3.8/site-packages/"
-```
-
-Update `$LD_LIBRARY_PATH` and `$LIBRARY_PATH` to enable linker to locate libraries.
+# provide fbgemm_gpu_py.so to enable fbgemm_gpu c++ operators
+find $HOME -name fbgemm_gpu_py.so
+# Use the path from the correct virtual environment above and set the environment variable $FBGEMM_LIB to it
+export FBGEMM_LIB=""
```
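+
+If `find` returns more than one match, the copy that belongs to the active virtual environment can also be located from Python (a small sketch; adjust the filename if your install differs):
+
+```
+# prints the fbgemm_gpu_py.so that ships inside the installed fbgemm_gpu package
+import os, fbgemm_gpu
+print(os.path.join(os.path.dirname(fbgemm_gpu.__file__), "fbgemm_gpu_py.so"))
+```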
-# double-conversion, fmt and gflags are pulled from folly's installation in this example
-export DOUBLE_CONVERSION_LIB_PATH="~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib"
-export FMT_LIB_PATH="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/"
-export GFLAGS_LIB_PATH="~/folly-build/installed/gflags-KheHQBqQ3_iL3yJBFwWe5M5f8Syd-LKAX352cxkhQMc/lib"
-export PYTORCH_LIB_PATH="~/pytorch/build/lib/"
-export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$DOUBLE_CONVERSION_LIB_PATH:$FMT_LIB_PATH:$GFLAGS_LIB_PATH:$PYTORCH_LIB_PATH"
-export LIBRARY_PATH="$PYTORCH_LIB_PATH"
-```
+### **3. Generate TorchScripted DLRM model**
-### **3. Package DLRM model**
-
-The `PredictFactoryPackager` class in `model_packager.py` can be used to implement your own packager class. Implement
-`set_extern_modules` to specify the dependencies of your predict module that should be accessed from the system and
-implement `set_mocked_modules` to specify dependencies that should be mocked (necessary to import but not use). Read
-more about extern and mock modules in the `torch.package` documentation: https://pytorch.org/docs/stable/package.html.
-
-`/torchrec/examples/inference/dlrm_package.py` provides an example of packaging a module for inference (`/torchrec/examples/inference/dlrm_predict.py`).
-`DLRMPredictModule` is packaged for inference in the following example.
+Here, we generate the DLRM model in TorchScript and save it for loading later on.
```
git clone https://github.com/pytorch/torchrec.git
-cd ~/torchrec/examples/inference/
-python dlrm_packager.py --output_path /tmp/model_package.zip
+cd ~/torchrec/torchrec/inference/
+python3 dlrm_packager.py --output_path /tmp/model.pt
```
-
### **4. Build inference library and example server**
-Generate protobuf C++ and Python code from protobuf
+Generate the Python protobuf code for the client, then build the server.
```
-cd ~/torchrec/inference/
-mkdir -p gen/torchrec/inference
-
-# C++ (server)
-protoc -I protos/ --grpc_out=gen/torchrec/inference --plugin=protoc-gen-grpc=/home/shabab/.local/bin/grpc_cpp_plugin protos/predictor.proto
-
-protoc -I protos/ --cpp_out=gen/torchrec/inference protos/predictor.proto
-
-
# Python (client)
-python -m grpc_tools.protoc -I protos --python_out=gen/torchrec/inference --grpc_python_out=gen/torchrec/inference protos/predictor.proto
+python -m grpc_tools.protoc -I protos --python_out=. --grpc_python_out=. protos/predictor.proto
```
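+
+The command above writes `predictor_pb2.py` and `predictor_pb2_grpc.py` next to `client.py`, which uses them roughly as follows (a sketch; see `client.py` for how the request fields are populated):
+
+```
+import grpc
+import predictor_pb2, predictor_pb2_grpc
+
+# open a channel to the running server and issue a Predict RPC
+with grpc.insecure_channel("0.0.0.0:50051") as channel:
+    stub = predictor_pb2_grpc.PredictorStub(channel)
+    response = stub.Predict(predictor_pb2.PredictionRequest(batch_size=1))
+    print(response.predictions["default"].data)
+```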
-Build inference library and example server
+Build the server and the C++ protobufs
```
-cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');$FOLLY_CMAKE_DIR;$BOOST_CMAKE_DIR;$BOOST_CMAKE_DIR;"
--DPYTORCH_FMT_INCLUDE_PATH="$PYTORCH_FMT_INCLUDE_PATH" \
--DPYTORCH_LIB_FMT="$PYTORCH_LIB_FMT" \
--DDEPLOY_INTERPRETER_PATH="$DEPLOY_INTERPRETER_PATH" \
--DDEPLOY_SRC_PATH="$DEPLOY_SRC_PATH" \
--DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" \ -DGRPC_HEADER_INCLUDE_PATH="$GRPC_HEADER_INCLUDE_PATH" \
--DFBGEMM_LIB="$FBGEMM_LIB"
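+# CMAKE_PREFIX_PATH points CMake at the Torch libraries of the active Python environment; FBGEMM_LIB was set in step 2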
+cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');" -DFBGEMM_LIB="$FBGEMM_LIB"
cd build
make -j
@@ -119,28 +68,20 @@ make -j
### **5. Run server and client**
-Run server. Update `CUDA_VISABLE_DEVICES` depending on the world size.
+Start the server, loading the model saved previously:
```
-CUDA_VISABLE_DEVICES="0" ./server --package_path="/tmp/model_package.zip" --python_packages_path $PYTHON_PACKAGES_PATH
+./server /tmp/model.pt
```
**output**
-In the logs, a plan should be outputted by the Torchrec planner:
-
-```
-INFO:.torchrec.distributed.planner.stats:# --- Planner Statistics --- #
-INFO:.torchrec.distributed.planner.stats:# --- Evalulated 1 proposal(s), found 1 possible plan(s) --- #
-INFO:.torchrec.distributed.planner.stats:# ----------------------------------------------------------------------------------------------- #
-INFO:.torchrec.distributed.planner.stats:# Rank HBM (GB) DDR (GB) Perf (ms) Input (MB) Output (MB) Shards #
-INFO:.torchrec.distributed.planner.stats:# ------ ---------- ---------- ----------- ------------ ------------- -------- #
-INFO:.torchrec.distributed.planner.stats:# 0 0.2 (1%) 0.0 (0%) 0.08 0.1 1.02 TW: 26 #
-INFO:.torchrec.distributed.planner.stats:# #
-INFO:.torchrec.distributed.planner.stats:# Input: MB/iteration, Output: MB/iteration, Shards: number of tables #
-INFO:.torchrec.distributed.planner.stats:# HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients #
-INFO:.torchrec.distributed.planner.stats:# #
-INFO:.torchrec.distributed.planner.stats:# Compute Kernels: #
-INFO:.torchrec.distributed.planner.stats:# quant: 26 #
+In the logs, you should see:
+
+```
+Loading model...
+Sanity Check with dummy inputs
+ Model Forward Completed, Output: 0.489247
+Server listening on 0.0.0.0:50051
````
`nvidia-smi` output should also show allocation of the model onto the gpu:
@@ -155,7 +96,7 @@ INFO:.torchrec.distributed.planner.stats:# quant: 26
+-----------------------------------------------------------------------------+
```
-Make a request to the server via the client:
+In another terminal instance, make a request to the server via the client:
```
python client.py
@@ -166,74 +107,3 @@ python client.py
```
Response: [0.13199582695960999, -0.1048036441206932, -0.06022112816572189, -0.08765199035406113, -0.12735335528850555, -0.1004377081990242, 0.05509107559919357, -0.10504599660634995, 0.1350800096988678, -0.09468207508325577, 0.24013587832450867, -0.09682435542345047, 0.0025023818016052246, -0.09786031395196915, -0.26396819949150085, -0.09670191258192062, 0.2691854238510132, -0.10246685892343521, -0.2019493579864502, -0.09904996305704117, 0.3894067406654358, ...]
```
-
-
-
-## Planned work
-
-- Provide benchmarks for torch deploy vs TorchScript and cpu, single gpu and multi-gpu inference
-- In-code documentation
-- Simplify installation process
-
-
-
-## Potential issues and solutions
-
-Skip this section if you had no issues with installation or running the example.
-
-**Missing header files during pytorch installation**
-
-If your environment is missing a speicfic set of header files such as `nvml.h` and `cuda_profiler_api.h`, the pytorch installation will fail with error messages similar to the code snippet below:
-
-```
-~/nvml_lib.h:13:10: fatal error: nvml.h: No such file or directory
- #include
- ^~~~~~~~
-compilation terminated.
-[80/2643] Building CXX object third_party/ideep/mkl-dnn/third_party/oneDNN/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_convolution_list.cpp.o
-ninja: build stopped: subcommand failed.
-```
-
-To get these header files, install `cudatoolkit-dev`:
-```
-conda install -c conda-forge cudatoolkit-dev
-```
-
-Re-run the installation after this.
-
-**libdouble-conversion missing**
-```
-~/torchrec/torchrec/inference/build$ ./example
-./example: error while loading shared libraries: libdouble-conversion.so.3: cannot open shared object file: No such file or directory
-```
-
-If this issue persists even after adding double-conversion's path to $LD_LIBRARY_PATH (step 2) then solve by creating a symlink to `libdouble-conversion.so.3` with folly's installation of double-conversion:
-
-```
-sudo ln -s ~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib/libdouble-conversion.so.3.1.4 \
-libdouble-conversion.so.3
-```
-
-**Two installations of glog**
-```
-~/torchrec/torchrec/inference/build$ ./example
-ERROR: flag 'logtostderr' was defined more than once (in files '/home/shabab/glog/src/logging.cc' and
-'/home/shabab/folly-build/extracted/glog-v0.4.0.tar.gz/glog-0.4.0/src/logging.cc').
-```
-The above issue, along with a host of others during building, can potentially occur if libinference is pointing to two different versions of glog (if one was
-previously installed in your system). You can find this out by running `ldd` on your libinference shared object within the build path. The issue can be solved by using the glog version provided by folly.
-
-To use the glog version provided by folly, add the glog install path (in your folly-build directory) to your LD_LIBRARY_PATH much like step 2.
-
-**Undefined symbols with std::string or cxx11**
-
-If you get undefined symbol errors and the errors mention `std::string` or `cxx11`, it's likely
-that your dependencies were compiled with different ABI values. Re-compile your dependencies
-and ensure they all have the same value for `_GLIBCXX_USE_CXX11_ABI` in their build.
-
-The ABI value of pytorch can be checked via:
-
-```
-import torch
-torch._C._GLIBCXX_USE_CXX11_ABI
-```
diff --git a/torchrec/inference/__init__.py b/torchrec/inference/__init__.py
index 6c2050114..a0ce8680a 100644
--- a/torchrec/inference/__init__.py
+++ b/torchrec/inference/__init__.py
@@ -7,21 +7,4 @@
# pyre-strict
-"""Torchrec Inference
-
-Torchrec inference provides a Torch.Deploy based library for GPU inference.
-
-These includes:
- - Model packaging in Python
- - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving.
- - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package.
- - Model serving in C++
- - `BatchingQueue` is a generalized config-based request tensor batching implementation.
- - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy.
-
-We implemented an example of how to use this library with the TorchRec DLRM model.
- - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package.
- - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model.
-"""
-
from . import model_packager, modules # noqa # noqa
diff --git a/torchrec/inference/client.py b/torchrec/inference/client.py
index c185f3909..725338f46 100644
--- a/torchrec/inference/client.py
+++ b/torchrec/inference/client.py
@@ -9,8 +9,8 @@
import logging
import grpc
+import predictor_pb2, predictor_pb2_grpc
import torch
-from gen.torchrec.inference import predictor_pb2, predictor_pb2_grpc
from torch.utils.data import DataLoader
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.random import RandomRecDataset
@@ -20,18 +20,13 @@
def create_training_batch(args: argparse.Namespace) -> Batch:
return next(
iter(
- DataLoader(
- RandomRecDataset(
- keys=DEFAULT_CAT_NAMES,
- batch_size=args.batch_size,
- hash_size=args.num_embedding_features,
- ids_per_feature=1,
- num_dense=len(DEFAULT_INT_NAMES),
- ),
- batch_sampler=None,
- pin_memory=False,
- num_workers=0,
- )
+ RandomRecDataset(
+ keys=DEFAULT_CAT_NAMES,
+ batch_size=args.batch_size,
+ hash_size=args.num_embedding_features,
+ ids_per_feature=1,
+ num_dense=len(DEFAULT_INT_NAMES),
+ ),
)
)
diff --git a/torchrec/inference/dlrm_packager.py b/torchrec/inference/dlrm_packager.py
new file mode 100644
index 000000000..560e31c16
--- /dev/null
+++ b/torchrec/inference/dlrm_packager.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# @nolint
+# pyre-ignore-all-errors
+
+
+import argparse
+import sys
+from typing import List
+
+from dlrm_predict import create_training_batch, DLRMModelConfig, DLRMPredictFactory
+from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
+
+
+def parse_args(argv: List[str]) -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="torchrec dlrm model packager")
+ parser.add_argument(
+ "--num_embeddings",
+ type=int,
+ default=100_000,
+ help="max_ind_size. The number of embeddings in each embedding table. Defaults"
+ " to 100_000 if num_embeddings_per_feature is not supplied.",
+ )
+ parser.add_argument(
+ "--num_embeddings_per_feature",
+ type=str,
+ default="45833188,36746,17245,7413,20243,3,7114,1441,62,29275261,1572176,345138,"
+ "10,2209,11267,128,4,974,14,48937457,11316796,40094537,452104,12606,104,35",
+ help="Comma separated max_ind_size per sparse feature. The number of embeddings"
+ " in each embedding table. 26 values are expected for the Criteo dataset.",
+ )
+ parser.add_argument(
+ "--sparse_feature_names",
+ type=str,
+ default=",".join(DEFAULT_CAT_NAMES),
+ help="Comma separated names of the sparse features.",
+ )
+ parser.add_argument(
+ "--dense_arch_layer_sizes",
+ type=str,
+ default="512,256,64",
+ help="Comma separated layer sizes for dense arch.",
+ )
+ parser.add_argument(
+ "--over_arch_layer_sizes",
+ type=str,
+ default="512,512,256,1",
+ help="Comma separated layer sizes for over arch.",
+ )
+ parser.add_argument(
+ "--embedding_dim",
+ type=int,
+ default=64,
+ help="Size of each embedding.",
+ )
+ parser.add_argument(
+ "--num_dense_features",
+ type=int,
+ default=len(DEFAULT_INT_NAMES),
+ help="Number of dense features.",
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ help="Output path of model package.",
+ )
+ return parser.parse_args(argv)
+
+
+def main(argv: List[str]) -> None:
+ """
+    Generate a TorchScripted torchrec DLRM model and save it to the given output path.
+
+ Args:
+ argv (List[str]): command line args.
+
+ Returns:
+ None.
+ """
+
+ args = parse_args(argv)
+
+ args.batch_size = 10
+ args.num_embedding_features = 26
+ batch = create_training_batch(args)
+
+ model_config = DLRMModelConfig(
+ dense_arch_layer_sizes=list(map(int, args.dense_arch_layer_sizes.split(","))),
+ dense_in_features=args.num_dense_features,
+ embedding_dim=args.embedding_dim,
+ id_list_features_keys=args.sparse_feature_names.split(","),
+ num_embeddings_per_feature=list(
+ map(int, args.num_embeddings_per_feature.split(","))
+ ),
+ num_embeddings=args.num_embeddings,
+ over_arch_layer_sizes=list(map(int, args.over_arch_layer_sizes.split(","))),
+ sample_input=batch,
+ )
+
+ script_module = DLRMPredictFactory(model_config).create_predict_module(world_size=1)
+
+ script_module.save(args.output_path)
+ print(f"Package is saved to {args.output_path}")
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/torchrec/inference/dlrm_predict.py b/torchrec/inference/dlrm_predict.py
new file mode 100644
index 000000000..3ac2966ad
--- /dev/null
+++ b/torchrec/inference/dlrm_predict.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# @nolint
+# pyre-ignore-all-errors
+
+
+import logging
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import torch
+from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
+from torchrec.datasets.random import RandomRecDataset
+from torchrec.datasets.utils import Batch
+from torchrec.fx.tracer import Tracer
+from torchrec.inference.modules import (
+ PredictFactory,
+ PredictModule,
+ quantize_inference_model,
+ shard_quant_model,
+)
+from torchrec.models.dlrm import DLRM
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def create_training_batch(args) -> Batch:
+ return RandomRecDataset(
+ keys=DEFAULT_CAT_NAMES,
+ batch_size=args.batch_size,
+ hash_size=args.num_embedding_features,
+ ids_per_feature=1,
+ num_dense=len(DEFAULT_INT_NAMES),
+ ).batch_generator._generate_batch()
+
+
+# OSS Only
+
+
+@dataclass
+class DLRMModelConfig:
+ dense_arch_layer_sizes: List[int]
+ dense_in_features: int
+ embedding_dim: int
+ id_list_features_keys: List[str]
+ num_embeddings_per_feature: List[int]
+ num_embeddings: int
+ over_arch_layer_sizes: List[int]
+ sample_input: Batch
+
+
+class DLRMPredictModule(PredictModule):
+ """
+ nn.Module to wrap DLRM model to use for inference.
+
+ Args:
+ embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
+ used to define SparseArch.
+ dense_in_features (int): the dimensionality of the dense input features.
+ dense_arch_layer_sizes (List[int]): the layer sizes for the DenseArch.
+ over_arch_layer_sizes (List[int]): the layer sizes for the OverArch. NOTE: The
+ output dimension of the InteractionArch should not be manually specified
+ here.
+ id_list_features_keys (List[str]): the names of the sparse features. Used to
+ construct a batch for inference.
+ dense_device: (Optional[torch.device]).
+ """
+
+ def __init__(
+ self,
+ embedding_bag_collection: EmbeddingBagCollection,
+ dense_in_features: int,
+ dense_arch_layer_sizes: List[int],
+ over_arch_layer_sizes: List[int],
+ id_list_features_keys: List[str],
+ dense_device: Optional[torch.device] = None,
+ ) -> None:
+ module = DLRM(
+ embedding_bag_collection=embedding_bag_collection,
+ dense_in_features=dense_in_features,
+ dense_arch_layer_sizes=dense_arch_layer_sizes,
+ over_arch_layer_sizes=over_arch_layer_sizes,
+ dense_device=dense_device,
+ )
+ super().__init__(module)
+
+ self.id_list_features_keys: List[str] = id_list_features_keys
+
+ def predict_forward(
+ self, batch: Dict[str, torch.Tensor]
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Args:
+ batch (Dict[str, torch.Tensor]): currently expects input dense features
+ to be mapped to the key "float_features" and input sparse features
+ to be mapped to the key "id_list_features".
+
+ Returns:
+ Dict[str, torch.Tensor]: output of inference.
+ """
+
+ try:
+ logits = self.predict_module(
+ batch["float_features"],
+ KeyedJaggedTensor(
+ keys=self.id_list_features_keys,
+ lengths=batch["id_list_features.lengths"],
+ values=batch["id_list_features.values"],
+ ),
+ )
+ predictions = logits.sigmoid()
+ except Exception as e:
+ logger.info(e)
+ raise e
+
+ # Flip predictions tensor to be 1D. TODO: Determine why prediction shape
+ # can be 2D at times (likely due to input format?)
+ predictions = predictions.reshape(
+ [
+ predictions.size()[0],
+ ]
+ )
+
+ return {
+ "default": predictions.to(torch.device("cpu"), non_blocking=True).float()
+ }
+
+
+class DLRMPredictFactory(PredictFactory):
+ def __init__(self, model_config: DLRMModelConfig) -> None:
+ self.model_config = model_config
+
+ def create_predict_module(self, world_size: int) -> torch.nn.Module:
+ logging.basicConfig(level=logging.INFO)
+ device = torch.device("cuda:0")
+
+ eb_configs = [
+ EmbeddingBagConfig(
+ name=f"t_{feature_name}",
+ embedding_dim=self.model_config.embedding_dim,
+ num_embeddings=(
+ self.model_config.num_embeddings_per_feature[feature_idx]
+ if self.model_config.num_embeddings is None
+ else self.model_config.num_embeddings
+ ),
+ feature_names=[feature_name],
+ )
+ for feature_idx, feature_name in enumerate(
+ self.model_config.id_list_features_keys
+ )
+ ]
+ ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
+
+ module = DLRMPredictModule(
+ embedding_bag_collection=ebc,
+ dense_in_features=self.model_config.dense_in_features,
+ dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes,
+ over_arch_layer_sizes=self.model_config.over_arch_layer_sizes,
+ id_list_features_keys=self.model_config.id_list_features_keys,
+ dense_device=device,
+ )
+
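+        # Collect the fully qualified names of the embedding tables (named "t_<feature>") for sharding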
+ table_fqns = []
+ for name, _ in module.named_modules():
+ if "t_" in name:
+ table_fqns.append(name.split(".")[-1])
+
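+        # Quantize the embedding tables for inference and shard the quantized model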
+ quant_model = quantize_inference_model(module)
+ sharded_model, _ = shard_quant_model(quant_model, table_fqns)
+
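+        # Build a sample input in the format predict_forward expects and run it through the sharded model as a sanity check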
+ batch = {}
+ batch["float_features"] = self.model_config.sample_input.dense_features.cuda()
+ batch["id_list_features.lengths"] = (
+ self.model_config.sample_input.sparse_features.lengths().cuda()
+ )
+ batch["id_list_features.values"] = (
+ self.model_config.sample_input.sparse_features.values().cuda()
+ )
+
+ sharded_model(batch)
+
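+        # Trace with torchrec's fx Tracer, treating the quantized FBGEMM embedding module as a leaf, then script the graph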
+ tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])
+
+ graph = tracer.trace(sharded_model)
+ gm = torch.fx.GraphModule(sharded_model, graph)
+
+ gm(batch)
+ scripted_gm = torch.jit.script(gm)
+ scripted_gm(batch)
+ return scripted_gm
+
+ def batching_metadata(self) -> Dict[str, str]:
+ return {
+ "float_features": "dense",
+ "id_list_features": "sparse",
+ }
+
+ def result_metadata(self) -> str:
+ return "dict_of_tensor"
+
+ def run_weights_independent_tranformations(
+ self, predict_module: torch.nn.Module
+ ) -> torch.nn.Module:
+ return predict_module
+
+ def run_weights_dependent_transformations(
+ self, predict_module: torch.nn.Module
+ ) -> torch.nn.Module:
+ """
+        Run transformations that depend on the weights of the predict module, e.g. lowering to a backend.
+ """
+ return predict_module
diff --git a/torchrec/inference/inference_legacy/CMakeLists.txt b/torchrec/inference/inference_legacy/CMakeLists.txt
new file mode 100644
index 000000000..794bb1490
--- /dev/null
+++ b/torchrec/inference/inference_legacy/CMakeLists.txt
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
+project(inference)
+
+# This step is crucial to ensure that the
+# _REFLECTION, _GRPC_GRPCPP and _PROTOBUF_LIBPROTOBUF variables are set.
+# e.g. ~/gprc/examples/cpp/cmake/common.cmake
+include(${GRPC_COMMON_CMAKE_PATH}/common.cmake)
+
+
+# abi and other flags
+
+if(DEFINED GLIBCXX_USE_CXX11_ABI)
+ if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
+ set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1")
+ endif()
+endif()
+
+# keep it static for now since folly-shared version is broken
+# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
+# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
+
+# dependencies
+find_package(Boost REQUIRED)
+find_package(Torch REQUIRED)
+find_package(folly REQUIRED)
+find_package(gflags REQUIRED)
+
+include_directories(${Torch_INCLUDE_DIRS})
+include_directories(${folly_INCLUDE_DIRS})
+include_directories(${PYTORCH_FMT_INCLUDE_PATH})
+
+set(CMAKE_CXX_STANDARD 17)
+
+# torch deploy library
+add_library(torch_deploy_internal STATIC
+ ${DEPLOY_INTERPRETER_PATH}/libtorch_deployinterpreter.o
+ ${DEPLOY_SRC_PATH}/deploy.cpp
+ ${DEPLOY_SRC_PATH}/loader.cpp
+ ${DEPLOY_SRC_PATH}/path_environment.cpp
+ ${DEPLOY_SRC_PATH}/elf_file.cpp)
+
+# For python builtins. caffe2_interface_library properly
+# makes use of the --whole-archive option.
+target_link_libraries(torch_deploy_internal PRIVATE
+ crypt pthread dl util m z ffi lzma readline nsl ncursesw panelw
+)
+target_link_libraries(torch_deploy_internal
+ PUBLIC shm torch ${PYTORCH_LIB_FMT}
+)
+caffe2_interface_library(torch_deploy_internal torch_deploy)
+
+# inference library
+
+# for our own header files
+include_directories(include/)
+include_directories(gen/)
+
+# define our library target
+add_library(inference STATIC
+ src/Batching.cpp
+ src/BatchingQueue.cpp
+ src/GPUExecutor.cpp
+ src/ResultSplit.cpp
+ src/Exception.cpp
+ src/ResourceManager.cpp
+)
+
+# -rdynamic is needed to link against the static library
+target_link_libraries(inference "-Wl,--no-as-needed -rdynamic"
+ dl torch_deploy "${TORCH_LIBRARIES}" ${FBGEMM_LIB} ${FOLLY_LIBRARIES}
+)
+
+# for generated protobuf
+
+# grpc headers. e.g. ~/.local/include
+include_directories(${GRPC_HEADER_INCLUDE_PATH})
+
+set(pred_grpc_srcs "gen/torchrec/inference/predictor.grpc.pb.cc")
+set(pred_grpc_hdrs "gen/torchrec/inference/predictor.grpc.pb.h")
+set(pred_proto_srcs "gen/torchrec/inference/predictor.pb.cc")
+set(pred_proto_hdrs "gen/torchrec/inference/predictor.pb.h")
+
+add_library(pred_grpc_proto STATIC
+ ${pred_grpc_srcs}
+ ${pred_grpc_hdrs}
+ ${pred_proto_srcs}
+ ${pred_proto_hdrs})
+
+target_link_libraries(pred_grpc_proto
+ ${_REFLECTION}
+ ${_GRPC_GRPCPP}
+ ${_PROTOBUF_LIBPROTOBUF})
+
+# server
+add_executable(server server.cpp)
+target_link_libraries(server
+ inference
+ torch_deploy
+ pred_grpc_proto
+ "${TORCH_LIBRARIES}"
+ ${FOLLY_LIBRARIES}
+ ${PYTORCH_LIB_FMT}
+ ${FBGEMM_LIB}
+ ${_REFLECTION}
+ ${_GRPC_GRPCPP}
+ ${_PROTOBUF_LIBPROTOBUF})
diff --git a/torchrec/inference/inference_legacy/README.md b/torchrec/inference/inference_legacy/README.md
new file mode 100644
index 000000000..fc3d8afcb
--- /dev/null
+++ b/torchrec/inference/inference_legacy/README.md
@@ -0,0 +1,239 @@
+# TorchRec Inference Library (**Experimental** Release)
+
+## Overview
+TorchRec Inference is a C++ library that supports **multi-gpu inference**. The Torchrec library is used to shard models written and packaged in Python via [torch.package](https://pytorch.org/docs/stable/package.html) (an alternative to TorchScript). The [torch.deploy](https://pytorch.org/docs/stable/deploy.html) library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus subverting the GIL.
+
+Follow the instructions below to package a DLRM model in Python, run a C++ inference server with the model on a GPU and send requests to said server via a python client.
+
+## Example
+
+C++ 17 is a requirement.
+
+
+
+### **1. Install Dependencies**
+
+Follow the instructions at: https://github.com/pytorch/pytorch/blob/master/docs/source/deploy.rst to ensure torch::deploy
+is working in your environment. Use the Dockerfile in the docker directory to install all dependencies. Run it via:
+
+```
+sudo nvidia-docker build -t torchrec .
+sudo nvidia-docker run -it torchrec:latest
+```
+
+### **2. Set variables**
+
+Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime.
+
+```
+# provide the cmake prefix path of pytorch, folly, and fmt.
+# fmt and boost are pulled from folly's installation in this example.
+export FOLLY_CMAKE_DIR="~/folly-build/installed/folly/lib/cmake/folly"
+export FMT_CMAKE_DIR="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/cmake/fmt"
+export BOOST_CMAKE_DIR="~/folly-build/installed/boost-4M2ZnvEM4UWTqpsEJRQTB4oejmX3LmgYC9pcBiuVlmA/lib/cmake/Boost-1.78.0"
+
+# provide fmt from pytorch for torch deploy
+export PYTORCH_FMT_INCLUDE_PATH="~/pytorch/third_party/fmt/include/"
+export PYTORCH_LIB_FMT="~/pytorch/build/lib/libfmt.a"
+
+# provide necessary info to link to torch deploy
+export DEPLOY_INTERPRETER_PATH="/pytorch/build/torch/csrc/deploy"
+export DEPLOY_SRC_PATH="~/pytorch/torch/csrc/deploy"
+
+# provide common.cmake from grpc/examples, makes linking to grpc easier
+export GRPC_COMMON_CMAKE_PATH="~/grpc/examples/cpp/cmake"
+export GRPC_HEADER_INCLUDE_PATH="~/.local/include/"
+
+# provide libfbgemm_gpu_py.so to enable fbgemm_gpu c++ operators
+export FBGEMM_LIB="~/anaconda3/envs/inference/lib/python3.8/site-packages/fbgemm_gpu-0.1.0-py3.8-linux-x86_64.egg/fbgemm_gpu/libfbgemm_gpu_py.so"
+
+# provide path to python packages for torch deploy runtime
+export PYTHON_PACKAGES_PATH="~/anaconda3/envs/inference/lib/python3.8/site-packages/"
+```
+
+Update `$LD_LIBRARY_PATH` and `$LIBRARY_PATH` to enable linker to locate libraries.
+
+```
+# double-conversion, fmt and gflags are pulled from folly's installation in this example
+export DOUBLE_CONVERSION_LIB_PATH="~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib"
+export FMT_LIB_PATH="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/"
+export GFLAGS_LIB_PATH="~/folly-build/installed/gflags-KheHQBqQ3_iL3yJBFwWe5M5f8Syd-LKAX352cxkhQMc/lib"
+export PYTORCH_LIB_PATH="~/pytorch/build/lib/"
+
+export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$DOUBLE_CONVERSION_LIB_PATH:$FMT_LIB_PATH:$GFLAGS_LIB_PATH:$PYTORCH_LIB_PATH"
+export LIBRARY_PATH="$PYTORCH_LIB_PATH"
+```
+
+### **3. Package DLRM model**
+
+The `PredictFactoryPackager` class in `model_packager.py` can be used to implement your own packager class. Implement
+`set_extern_modules` to specify the dependencies of your predict module that should be accessed from the system and
+implement `set_mocked_modules` to specify dependencies that should be mocked (necessary to import but not use). Read
+more about extern and mock modules in the `torch.package` documentation: https://pytorch.org/docs/stable/package.html.
+
+`/torchrec/examples/inference_legacy/dlrm_packager.py` provides an example of packaging a module for inference (`/torchrec/examples/inference_legacy/dlrm_predict.py`).
+`DLRMPredictModule` is packaged for inference in the following example.
+
+```
+git clone https://github.com/pytorch/torchrec.git
+
+cd ~/torchrec/examples/inference_legacy/
+python dlrm_packager.py --output_path /tmp/model_package.zip
+```
+
+
+
+### **4. Build inference library and example server**
+
+Generate protobuf C++ and Python code from protobuf
+
+```
+cd ~/torchrec/inference/
+mkdir -p gen/torchrec/inference
+
+# C++ (server)
+protoc -I protos/ --grpc_out=gen/torchrec/inference --plugin=protoc-gen-grpc=/home/shabab/.local/bin/grpc_cpp_plugin protos/predictor.proto
+
+protoc -I protos/ --cpp_out=gen/torchrec/inference protos/predictor.proto
+
+
+# Python (client)
+python -m grpc_tools.protoc -I protos --python_out=gen/torchrec/inference --grpc_python_out=gen/torchrec/inference protos/predictor.proto
+```
+
+
+Build inference library and example server
+```
+cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');$FOLLY_CMAKE_DIR;$BOOST_CMAKE_DIR;$BOOST_CMAKE_DIR;"
+-DPYTORCH_FMT_INCLUDE_PATH="$PYTORCH_FMT_INCLUDE_PATH" \
+-DPYTORCH_LIB_FMT="$PYTORCH_LIB_FMT" \
+-DDEPLOY_INTERPRETER_PATH="$DEPLOY_INTERPRETER_PATH" \
+-DDEPLOY_SRC_PATH="$DEPLOY_SRC_PATH" \
+-DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" \ -DGRPC_HEADER_INCLUDE_PATH="$GRPC_HEADER_INCLUDE_PATH" \
+-DFBGEMM_LIB="$FBGEMM_LIB"
+
+cd build
+make -j
+```
+
+
+### **5. Run server and client**
+
+Run server. Update `CUDA_VISIBLE_DEVICES` depending on the world size.
+```
+CUDA_VISIBLE_DEVICES="0" ./server --package_path="/tmp/model_package.zip" --python_packages_path $PYTHON_PACKAGES_PATH
+```
+
+**output**
+
+In the logs, a plan should be outputted by the Torchrec planner:
+
+```
+INFO:.torchrec.distributed.planner.stats:# --- Planner Statistics --- #
+INFO:.torchrec.distributed.planner.stats:# --- Evalulated 1 proposal(s), found 1 possible plan(s) --- #
+INFO:.torchrec.distributed.planner.stats:# ----------------------------------------------------------------------------------------------- #
+INFO:.torchrec.distributed.planner.stats:# Rank HBM (GB) DDR (GB) Perf (ms) Input (MB) Output (MB) Shards #
+INFO:.torchrec.distributed.planner.stats:# ------ ---------- ---------- ----------- ------------ ------------- -------- #
+INFO:.torchrec.distributed.planner.stats:# 0 0.2 (1%) 0.0 (0%) 0.08 0.1 1.02 TW: 26 #
+INFO:.torchrec.distributed.planner.stats:# #
+INFO:.torchrec.distributed.planner.stats:# Input: MB/iteration, Output: MB/iteration, Shards: number of tables #
+INFO:.torchrec.distributed.planner.stats:# HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients #
+INFO:.torchrec.distributed.planner.stats:# #
+INFO:.torchrec.distributed.planner.stats:# Compute Kernels: #
+INFO:.torchrec.distributed.planner.stats:# quant: 26 #
+````
+
+`nvidia-smi` output should also show allocation of the model onto the gpu:
+
+```
++-----------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=============================================================================|
+| 0 N/A N/A 86668 C ./example 1357MiB |
++-----------------------------------------------------------------------------+
+```
+
+Make a request to the server via the client:
+
+```
+python client.py
+```
+
+**output**
+
+```
+Response: [0.13199582695960999, -0.1048036441206932, -0.06022112816572189, -0.08765199035406113, -0.12735335528850555, -0.1004377081990242, 0.05509107559919357, -0.10504599660634995, 0.1350800096988678, -0.09468207508325577, 0.24013587832450867, -0.09682435542345047, 0.0025023818016052246, -0.09786031395196915, -0.26396819949150085, -0.09670191258192062, 0.2691854238510132, -0.10246685892343521, -0.2019493579864502, -0.09904996305704117, 0.3894067406654358, ...]
+```
+
+
+
+## Planned work
+
+- Provide benchmarks for torch deploy vs TorchScript and cpu, single gpu and multi-gpu inference
+- In-code documentation
+- Simplify installation process
+
+
+
+## Potential issues and solutions
+
+Skip this section if you had no issues with installation or running the example.
+
+**Missing header files during pytorch installation**
+
+If your environment is missing a specific set of header files such as `nvml.h` and `cuda_profiler_api.h`, the pytorch installation will fail with error messages similar to the code snippet below:
+
+```
+~/nvml_lib.h:13:10: fatal error: nvml.h: No such file or directory
+ #include <nvml.h>
+ ^~~~~~~~
+compilation terminated.
+[80/2643] Building CXX object third_party/ideep/mkl-dnn/third_party/oneDNN/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_convolution_list.cpp.o
+ninja: build stopped: subcommand failed.
+```
+
+To get these header files, install `cudatoolkit-dev`:
+```
+conda install -c conda-forge cudatoolkit-dev
+```
+
+Re-run the installation after this.
+
+**libdouble-conversion missing**
+```
+~/torchrec/torchrec/inference/build$ ./example
+./example: error while loading shared libraries: libdouble-conversion.so.3: cannot open shared object file: No such file or directory
+```
+
+If this issue persists even after adding double-conversion's path to $LD_LIBRARY_PATH (step 2) then solve by creating a symlink to `libdouble-conversion.so.3` with folly's installation of double-conversion:
+
+```
+sudo ln -s ~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib/libdouble-conversion.so.3.1.4 \
+libdouble-conversion.so.3
+```
+
+**Two installations of glog**
+```
+~/torchrec/torchrec/inference/build$ ./example
+ERROR: flag 'logtostderr' was defined more than once (in files '/home/shabab/glog/src/logging.cc' and
+'/home/shabab/folly-build/extracted/glog-v0.4.0.tar.gz/glog-0.4.0/src/logging.cc').
+```
+The above issue, along with a host of others during building, can potentially occur if libinference is pointing to two different versions of glog (if one was
+previously installed in your system). You can find this out by running `ldd` on your libinference shared object within the build path. The issue can be solved by using the glog version provided by folly.
+
+To use the glog version provided by folly, add the glog install path (in your folly-build directory) to your LD_LIBRARY_PATH much like step 2.
+
+**Undefined symbols with std::string or cxx11**
+
+If you get undefined symbol errors and the errors mention `std::string` or `cxx11`, it's likely
+that your dependencies were compiled with different ABI values. Re-compile your dependencies
+and ensure they all have the same value for `_GLIBCXX_USE_CXX11_ABI` in their build.
+
+The ABI value of pytorch can be checked via:
+
+```
+import torch
+torch._C._GLIBCXX_USE_CXX11_ABI
+```
diff --git a/torchrec/inference/inference_legacy/__init__.py b/torchrec/inference/inference_legacy/__init__.py
new file mode 100644
index 000000000..670f2af78
--- /dev/null
+++ b/torchrec/inference/inference_legacy/__init__.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torchrec Inference
+
+Torchrec inference provides a Torch.Deploy based library for GPU inference.
+
+These includes:
+ - Model packaging in Python
+ - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving.
+ - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package.
+ - Model serving in C++
+ - `BatchingQueue` is a generalized config-based request tensor batching implementation.
+ - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy.
+
+We implemented an example of how to use this library with the TorchRec DLRM model.
+ - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package.
+ - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model.
+"""
+
+from . import model_packager, modules # noqa # noqa
diff --git a/torchrec/inference/inference_legacy/client.py b/torchrec/inference/inference_legacy/client.py
new file mode 100644
index 000000000..a3a9f2a83
--- /dev/null
+++ b/torchrec/inference/inference_legacy/client.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# @nolint
+# pyre-ignore-all-errors
+
+
+import argparse
+import logging
+
+import grpc
+import torch
+from gen.torchrec.inference import predictor_pb2, predictor_pb2_grpc
+from torch.utils.data import DataLoader
+from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
+from torchrec.datasets.random import RandomRecDataset
+from torchrec.datasets.utils import Batch
+
+
+def create_training_batch(args: argparse.Namespace) -> Batch:
+ return next(
+ iter(
+ DataLoader(
+ RandomRecDataset(
+ keys=DEFAULT_CAT_NAMES,
+ batch_size=args.batch_size,
+ hash_size=args.num_embedding_features,
+ ids_per_feature=1,
+ num_dense=len(DEFAULT_INT_NAMES),
+ ),
+ batch_sampler=None,
+ pin_memory=False,
+ num_workers=0,
+ )
+ )
+ )
+
+
+def create_request(
+ batch: Batch, args: argparse.Namespace
+) -> predictor_pb2.PredictionRequest:
+ def to_bytes(tensor: torch.Tensor) -> bytes:
+ return tensor.cpu().numpy().tobytes()
+
+ float_features = predictor_pb2.FloatFeatures(
+ num_features=args.num_float_features,
+ values=to_bytes(batch.dense_features),
+ )
+
+ id_list_features = predictor_pb2.SparseFeatures(
+ num_features=args.num_id_list_features,
+ values=to_bytes(batch.sparse_features.values()),
+ lengths=to_bytes(batch.sparse_features.lengths()),
+ )
+
+ id_score_list_features = predictor_pb2.SparseFeatures(num_features=0)
+ embedding_features = predictor_pb2.FloatFeatures(num_features=0)
+ unary_features = predictor_pb2.SparseFeatures(num_features=0)
+
+ return predictor_pb2.PredictionRequest(
+ batch_size=args.batch_size,
+ float_features=float_features,
+ id_list_features=id_list_features,
+ id_score_list_features=id_score_list_features,
+ embedding_features=embedding_features,
+ unary_features=unary_features,
+ )
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ip",
+ type=str,
+ default="0.0.0.0",
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=50051,
+ )
+ parser.add_argument(
+ "--num_float_features",
+ type=int,
+ default=13,
+ )
+ parser.add_argument(
+ "--num_id_list_features",
+ type=int,
+ default=26,
+ )
+ parser.add_argument(
+ "--num_id_score_list_features",
+ type=int,
+ default=0,
+ )
+ parser.add_argument(
+ "--num_embedding_features",
+ type=int,
+ default=100000,
+ )
+ parser.add_argument(
+ "--embedding_feature_dim",
+ type=int,
+ default=100,
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=100,
+ )
+
+ args: argparse.Namespace = parser.parse_args()
+
+ training_batch: Batch = create_training_batch(args)
+ request: predictor_pb2.PredictionRequest = create_request(training_batch, args)
+
+ with grpc.insecure_channel(f"{args.ip}:{args.port}") as channel:
+ stub = predictor_pb2_grpc.PredictorStub(channel)
+ response = stub.Predict(request)
+ print("Response: ", response.predictions["default"].data)
diff --git a/torchrec/inference/docker/Dockerfile b/torchrec/inference/inference_legacy/docker/Dockerfile
similarity index 100%
rename from torchrec/inference/docker/Dockerfile
rename to torchrec/inference/inference_legacy/docker/Dockerfile
diff --git a/torchrec/inference/docs/inference.rst b/torchrec/inference/inference_legacy/docs/inference.rst
similarity index 100%
rename from torchrec/inference/docs/inference.rst
rename to torchrec/inference/inference_legacy/docs/inference.rst
diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h
new file mode 100644
index 000000000..26e7987f9
--- /dev/null
+++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <stdexcept>
+
+#define TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(condition, message) \
+ if (!(condition)) { \
+ throw std::runtime_error( \
+ "Internal Assertion failed: (" + std::string(#condition) + "), " + \
+ "function " + __FUNCTION__ + ", file " + __FILE__ + ", line " + \
+ std::to_string(__LINE__) + ".\n" + \
+ "Please report bug to TorchRec.\n" + message + "\n"); \
+ }
+
+#define TORCHREC_INTERNAL_ASSERT_NO_MESSAGE(condition) \
+ TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(#condition, "")
+
+#define TORCHREC_INTERNAL_ASSERT_(x, condition, message, FUNC, ...) FUNC
+
+#define TORCHREC_INTERNAL_ASSERT(...) \
+ TORCHREC_INTERNAL_ASSERT_( \
+ , \
+ ##__VA_ARGS__, \
+ TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(__VA_ARGS__), \
+ TORCHREC_INTERNAL_ASSERT_NO_MESSAGE(__VA_ARGS__));
+
+#define TORCHREC_CHECK_WITH_MESSAGE(condition, message) \
+ if (!(condition)) { \
+ throw std::runtime_error( \
+ "Check failed: (" + std::string(#condition) + "), " + "function " + \
+ __FUNCTION__ + ", file " + __FILE__ + ", line " + \
+ std::to_string(__LINE__) + ".\n" + message + "\n"); \
+ }
+
+#define TORCHREC_CHECK_NO_MESSAGE(condition) \
+ TORCHREC_CHECK_WITH_MESSAGE(#condition, "")
+
+#define TORCHREC_CHECK_(x, condition, message, FUNC, ...) FUNC
+
+#define TORCHREC_CHECK(...) \
+ TORCHREC_CHECK_( \
+ , \
+ ##__VA_ARGS__, \
+ TORCHREC_CHECK_WITH_MESSAGE(__VA_ARGS__), \
+ TORCHREC_CHECK_NO_MESSAGE(__VA_ARGS__));
diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h
new file mode 100644
index 000000000..2c5e8837c
--- /dev/null
+++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "torchrec/inference/JaggedTensor.h"
+#include "torchrec/inference/Types.h"
+
+namespace torchrec {
+
+using LazyTensorRef = folly::detail::Lazy>&;
+
+// BatchingFunc should be responsible to move the output tensor to desired
+// location using the device input.
+class BatchingFunc {
+ public:
+ virtual ~BatchingFunc() = default;
+
+ virtual std::unordered_map batch(
+ const std::string& /* featureName */,
+ const std::vector>& /* requests */,
+ const int64_t& /* totalNumBatch */,
+ LazyTensorRef /* batchOffsets */,
+ const c10::Device& /* device */,
+ LazyTensorRef /* batchItems */) = 0;
+};
+
+/**
+ * TorchRecBatchingFuncRegistry is used to register custom batching functions.
+ */
+C10_DECLARE_REGISTRY(TorchRecBatchingFuncRegistry, BatchingFunc);
+
+#define REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY(name, priority, ...) \
+ C10_REGISTER_CLASS_WITH_PRIORITY( \
+ TorchRecBatchingFuncRegistry, name, priority, __VA_ARGS__);
+
+#define REGISTER_TORCHREC_BATCHING_FUNC(name, ...) \
+ REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY( \
+ name, c10::REGISTRY_DEFAULT, __VA_ARGS__);
+
+std::unordered_map combineFloat(
+ const std::string& featureName,
+ const std::vector>& requests);
+
+std::unordered_map combineSparse(
+ const std::string& featureName,
+ const std::vector>& requests,
+ bool isWeighted);
+
+std::unordered_map combineEmbedding(
+ const std::string& featureName,
+ const std::vector>& requests);
+
+void moveIValueToDevice(c10::IValue& val, const c10::Device& device);
+
+std::unordered_map moveToDevice(
+ std::unordered_map combined,
+ const c10::Device& device);
+
+} // namespace torchrec
diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h b/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h
new file mode 100644
index 000000000..f012572b2
--- /dev/null
+++ b/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include // @manual
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "torchrec/inference/Batching.h"
+#include "torchrec/inference/Observer.h"
+#include "torchrec/inference/ResourceManager.h"
+#include "torchrec/inference/Types.h"
+
+namespace torchrec {
+
+using BatchQueueCb = std::function)>;
+
+class BatchingQueue {
+ public:
+ struct Config {
+ std::chrono::milliseconds batchingInterval = std::chrono::milliseconds(10);
+ std::chrono::milliseconds queueTimeout = std::chrono::milliseconds(500);
+ int numExceptionThreads = 4;
+ int numMemPinnerThreads = 4;
+ int maxBatchSize = 2000;
+ // For feature name to BatchingFunc name.
+ const std::unordered_map batchingMetadata;
+ std::function eventCreationFn;
+ std::function warmupFn;
+ };
+
+ BatchingQueue(const BatchingQueue&) = delete;
+ BatchingQueue& operator=(const BatchingQueue&) = delete;
+
+ BatchingQueue(
+ std::vector cbs,
+ const Config& config,
+ int worldSize,
+ std::unique_ptr observer,
+ std::shared_ptr resourceManager = nullptr);
+ ~BatchingQueue();
+
+ void add(
+ std::shared_ptr request,
+ folly::Promise> promise);
+
+ void stop();
+
+ private:
+ struct QueryQueueEntry {
+ std::shared_ptr request;
+ RequestContext context;
+ std::chrono::time_point addedTime;
+ };
+
+ struct BatchingQueueEntry {
+ std::vector> requests;
+ std::vector contexts;
+ std::chrono::time_point addedTime;
+ };
+
+ void createBatch();
+
+ void pinMemory(int gpuIdx);
+
+ void observeBatchCompletion(size_t batchSizeBytes, size_t numRequests);
+
+ const Config config_;
+
+ // Batching func name to batching func instance.
+ std::unordered_map> batchingFuncs_;
+ std::vector cbs_;
+ std::thread batchingThread_;
+ std::vector memPinnerThreads_;
+ std::unique_ptr rejectionExecutor_;
+ folly::Synchronized> requestQueue_;
+ std::vector>>
+ batchingQueues_;
+ std::atomic stopping_;
+ int worldSize_;
+ std::unique_ptr observer_;
+ std::shared_ptr resourceManager_;
+};
+
+} // namespace torchrec
diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h
new file mode 100644
index 000000000..5667d1ad2
--- /dev/null
+++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <stdexcept>
+
+namespace torchrec {
+
+// We have different error code defined for different kinds of exceptions in
+// fblearner/sigrid predictor. (Code pointer:
+// fblearner/predictor/if/prediction_service.thrift.) We define different
+// exception type here so that in fblearner/sigrid predictor we can detect the
+// exception type and return the corresponding error code to reflect the right
+// info.
+class TorchrecException : public std::runtime_error {
+ public:
+ explicit TorchrecException(const std::string& error)
+ : std::runtime_error(error) {}
+};
+
+// GPUOverloadException maps to
+// PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT
+class GPUOverloadException : public TorchrecException {
+ public:
+ explicit GPUOverloadException(const std::string& error)
+ : TorchrecException(error) {}
+};
+
+// GPUExecutorOverloadException maps to
+// PredictionExceptionCode::GPU_EXECUTOR_QUEUE_TIMEOUT
+class GPUExecutorOverloadException : public TorchrecException {
+ public:
+ explicit GPUExecutorOverloadException(const std::string& error)
+ : TorchrecException(error) {}
+};
+
+// TorchDeployException maps to
+// PredictorUserErrorCode::TORCH_DEPLOY_ERROR
+class TorchDeployException : public TorchrecException {
+ public:
+ explicit TorchDeployException(const std::string& error)
+ : TorchrecException(error) {}
+};
+} // namespace torchrec
diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h
new file mode 100644
index 000000000..491acf48f
--- /dev/null
+++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+
+#include
+
+#include "torchrec/inference/Exception.h"
+#include "torchrec/inference/Types.h"
+
+namespace torchrec {
+template
+void handleRequestException(
+ folly::Promise>& promise,
+ const std::string& msg) {
+ auto ex = folly::make_exception_wrapper(msg);
+ auto response = std::make_unique();
+ response->exception = std::move(ex);
+ promise.setValue(std::move(response));
+}
+
+template
+void handleBatchException(
+ std::vector& contexts,
+ const std::string& msg) {
+ for (auto& context : contexts) {
+ handleRequestException(context.promise, msg);
+ }
+}
+
+} // namespace torchrec
diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h
new file mode 100644
index 000000000..b55dacc5a
--- /dev/null
+++ b/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include