From eec5df3d8d305beea227e9f0bd8416247346f7dd Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 14 Jun 2024 13:00:08 -0700 Subject: [PATCH] Fix TorchRec inference solution with updated torch.fx + TorchScript Solution (#2101) Summary: Previously, the TorchRec inference solution used torch.deploy, which isn't supported, and the solution itself was outdated and broken. This PR revamps the inference example with Torch.FX and TorchScript, which more accurately represents what is used in production for RecSys currently. Furthermore, the example was constructed with simplicity as the top priority, with fewer package requirements and steps to run. This is the first iteration of the inference example, with more to come! Pull Request resolved: https://github.com/pytorch/torchrec/pull/2101 Differential Revision: D58478461 Pulled By: PaulZhang12 --- .../{inference => inference_legacy}/README.md | 0 .../dlrm_client.py | 0 .../dlrm_packager.py | 0 .../dlrm_predict.py | 0 .../dlrm_predict_single_gpu.py | 0 torchrec/inference/CMakeLists.txt | 124 ++---- torchrec/inference/README.md | 198 ++------- torchrec/inference/__init__.py | 17 - torchrec/inference/client.py | 21 +- torchrec/inference/dlrm_packager.py | 112 +++++ torchrec/inference/dlrm_predict.py | 221 ++++++++++ .../inference/inference_legacy/CMakeLists.txt | 113 +++++ torchrec/inference/inference_legacy/README.md | 239 +++++++++++ .../inference/inference_legacy/__init__.py | 25 ++ torchrec/inference/inference_legacy/client.py | 130 ++++++ .../{ => inference_legacy}/docker/Dockerfile | 0 .../{ => inference_legacy}/docs/inference.rst | 0 .../include/torchrec/inference/Assert.h | 52 +++ .../include/torchrec/inference/Batching.h | 74 ++++ .../torchrec/inference/BatchingQueue.h | 107 +++++ .../include/torchrec/inference/Exception.h | 49 +++ .../torchrec/inference/ExceptionHandler.h | 39 ++ .../include/torchrec/inference/GPUExecutor.h | 99 +++++ .../include/torchrec/inference/JaggedTensor.h | 31 ++ .../include/torchrec/inference/Observer.h | 280 ++++++++++++ .../torchrec/inference/ResourceManager.h | 74 ++++ .../include/torchrec/inference/ResultSplit.h | 68 +++ .../torchrec/inference/ShardedTensor.h | 48 +++ .../torchrec/inference/SingleGPUExecutor.h | 63 +++ .../include/torchrec/inference/TestUtils.h | 37 ++ .../include/torchrec/inference/Types.h | 145 +++++++ .../include/torchrec/inference/Validation.h | 30 ++ .../inference_legacy/model_packager.py | 119 +++++ .../inference_legacy/protos/predictor.proto | 50 +++ .../inference/inference_legacy/server.cpp | 336 +++++++++++++++ .../{ => inference_legacy}/src/Batching.cpp | 0 .../src/BatchingQueue.cpp | 0 .../inference_legacy/src/Executer2.cpp | 0 .../src/GPUExecutor.cpp | 0 .../src/ResourceManager.cpp | 0 .../src/ResultSplit.cpp | 0 .../src/SingleGPUExecutor.cpp | 0 .../{ => inference_legacy}/src/TestUtils.cpp | 0 .../{ => inference_legacy}/src/Validation.cpp | 0 .../inference_legacy/state_dict_transform.py | 103 +++++ .../tests/BatchingQueueTest.cpp | 0 .../tests/BatchingTest.cpp | 0 .../tests/ResultSplitTest.cpp | 0 .../tests/SingleGPUExecutorMultiGPUTest.cpp | 0 .../tests/SingleGPUExecutorTest.cpp | 0 .../tests/ValidationTest.cpp | 0 .../tests/generate_test_packages.py | 0 .../tests/model_packager_tests.py | 1 - .../tests/predict_module_tests.py | 2 - .../tests/test_modules.py | 0 torchrec/inference/modules.py | 4 +- torchrec/inference/server.cpp | 405 ++++++------------ 57 files changed, 2849 insertions(+), 567 deletions(-) rename examples/{inference => inference_legacy}/README.md (100%) rename 
examples/{inference => inference_legacy}/dlrm_client.py (100%) rename examples/{inference => inference_legacy}/dlrm_packager.py (100%) rename examples/{inference => inference_legacy}/dlrm_predict.py (100%) rename examples/{inference => inference_legacy}/dlrm_predict_single_gpu.py (100%) create mode 100644 torchrec/inference/dlrm_packager.py create mode 100644 torchrec/inference/dlrm_predict.py create mode 100644 torchrec/inference/inference_legacy/CMakeLists.txt create mode 100644 torchrec/inference/inference_legacy/README.md create mode 100644 torchrec/inference/inference_legacy/__init__.py create mode 100644 torchrec/inference/inference_legacy/client.py rename torchrec/inference/{ => inference_legacy}/docker/Dockerfile (100%) rename torchrec/inference/{ => inference_legacy}/docs/inference.rst (100%) create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/Types.h create mode 100644 torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h create mode 100644 torchrec/inference/inference_legacy/model_packager.py create mode 100644 torchrec/inference/inference_legacy/protos/predictor.proto create mode 100644 torchrec/inference/inference_legacy/server.cpp rename torchrec/inference/{ => inference_legacy}/src/Batching.cpp (100%) rename torchrec/inference/{ => inference_legacy}/src/BatchingQueue.cpp (100%) create mode 100644 torchrec/inference/inference_legacy/src/Executer2.cpp rename torchrec/inference/{ => inference_legacy}/src/GPUExecutor.cpp (100%) rename torchrec/inference/{ => inference_legacy}/src/ResourceManager.cpp (100%) rename torchrec/inference/{ => inference_legacy}/src/ResultSplit.cpp (100%) rename torchrec/inference/{ => inference_legacy}/src/SingleGPUExecutor.cpp (100%) rename torchrec/inference/{ => inference_legacy}/src/TestUtils.cpp (100%) rename torchrec/inference/{ => inference_legacy}/src/Validation.cpp (100%) create mode 100644 torchrec/inference/inference_legacy/state_dict_transform.py rename torchrec/inference/{ => inference_legacy}/tests/BatchingQueueTest.cpp (100%) rename torchrec/inference/{ => inference_legacy}/tests/BatchingTest.cpp (100%) rename torchrec/inference/{ => inference_legacy}/tests/ResultSplitTest.cpp (100%) rename torchrec/inference/{ => 
inference_legacy}/tests/SingleGPUExecutorMultiGPUTest.cpp (100%) rename torchrec/inference/{ => inference_legacy}/tests/SingleGPUExecutorTest.cpp (100%) rename torchrec/inference/{ => inference_legacy}/tests/ValidationTest.cpp (100%) rename torchrec/inference/{ => inference_legacy}/tests/generate_test_packages.py (100%) rename torchrec/inference/{ => inference_legacy}/tests/model_packager_tests.py (99%) rename torchrec/inference/{ => inference_legacy}/tests/predict_module_tests.py (99%) rename torchrec/inference/{ => inference_legacy}/tests/test_modules.py (100%) diff --git a/examples/inference/README.md b/examples/inference_legacy/README.md similarity index 100% rename from examples/inference/README.md rename to examples/inference_legacy/README.md diff --git a/examples/inference/dlrm_client.py b/examples/inference_legacy/dlrm_client.py similarity index 100% rename from examples/inference/dlrm_client.py rename to examples/inference_legacy/dlrm_client.py diff --git a/examples/inference/dlrm_packager.py b/examples/inference_legacy/dlrm_packager.py similarity index 100% rename from examples/inference/dlrm_packager.py rename to examples/inference_legacy/dlrm_packager.py diff --git a/examples/inference/dlrm_predict.py b/examples/inference_legacy/dlrm_predict.py similarity index 100% rename from examples/inference/dlrm_predict.py rename to examples/inference_legacy/dlrm_predict.py diff --git a/examples/inference/dlrm_predict_single_gpu.py b/examples/inference_legacy/dlrm_predict_single_gpu.py similarity index 100% rename from examples/inference/dlrm_predict_single_gpu.py rename to examples/inference_legacy/dlrm_predict_single_gpu.py diff --git a/torchrec/inference/CMakeLists.txt b/torchrec/inference/CMakeLists.txt index 794bb1490..3b3ac2a8d 100644 --- a/torchrec/inference/CMakeLists.txt +++ b/torchrec/inference/CMakeLists.txt @@ -4,110 +4,60 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -cmake_minimum_required(VERSION 3.13 FATAL_ERROR) -project(inference) +cmake_minimum_required(VERSION 3.8) -# This step is crucial to ensure that the -# _REFLECTION, _GRPC_GRPCPP and _PROTOBUF_LIBPROTOBUF variables are set. -# e.g. 
~/gprc/examples/cpp/cmake/common.cmake -include(${GRPC_COMMON_CMAKE_PATH}/common.cmake) +project(inference C CXX) +include(/home/paulzhan/grpc/examples/cpp/cmake/common.cmake) -# abi and other flags -if(DEFINED GLIBCXX_USE_CXX11_ABI) - if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") - set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1") - endif() -endif() +# Proto file +get_filename_component(hw_proto "/home/paulzhan/torchrec/torchrec/inference/protos/predictor.proto" ABSOLUTE) +get_filename_component(hw_proto_path "${hw_proto}" PATH) -# keep it static for now since folly-shared version is broken -# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") -# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") +# Generated sources +set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/predictor.pb.cc") +set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/predictor.pb.h") +set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/predictor.grpc.pb.cc") +set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/predictor.grpc.pb.h") +add_custom_command( + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${hw_proto}" + DEPENDS "${hw_proto}") -# dependencies -find_package(Boost REQUIRED) -find_package(Torch REQUIRED) -find_package(folly REQUIRED) -find_package(gflags REQUIRED) - -include_directories(${Torch_INCLUDE_DIRS}) -include_directories(${folly_INCLUDE_DIRS}) -include_directories(${PYTORCH_FMT_INCLUDE_PATH}) +# Include generated *.pb.h files +include_directories("${CMAKE_CURRENT_BINARY_DIR}") set(CMAKE_CXX_STANDARD 17) -# torch deploy library -add_library(torch_deploy_internal STATIC - ${DEPLOY_INTERPRETER_PATH}/libtorch_deployinterpreter.o - ${DEPLOY_SRC_PATH}/deploy.cpp - ${DEPLOY_SRC_PATH}/loader.cpp - ${DEPLOY_SRC_PATH}/path_environment.cpp - ${DEPLOY_SRC_PATH}/elf_file.cpp) - -# For python builtins. caffe2_interface_library properly -# makes use of the --whole-archive option. -target_link_libraries(torch_deploy_internal PRIVATE - crypt pthread dl util m z ffi lzma readline nsl ncursesw panelw -) -target_link_libraries(torch_deploy_internal - PUBLIC shm torch ${PYTORCH_LIB_FMT} -) -caffe2_interface_library(torch_deploy_internal torch_deploy) - -# inference library - -# for our own header files -include_directories(include/) -include_directories(gen/) - -# define our library target -add_library(inference STATIC - src/Batching.cpp - src/BatchingQueue.cpp - src/GPUExecutor.cpp - src/ResultSplit.cpp - src/Exception.cpp - src/ResourceManager.cpp -) - -# -rdynamic is needed to link against the static library -target_link_libraries(inference "-Wl,--no-as-needed -rdynamic" - dl torch_deploy "${TORCH_LIBRARIES}" ${FBGEMM_LIB} ${FOLLY_LIBRARIES} -) - -# for generated protobuf - -# grpc headers. e.g. 
~/.local/include -include_directories(${GRPC_HEADER_INCLUDE_PATH}) - -set(pred_grpc_srcs "gen/torchrec/inference/predictor.grpc.pb.cc") -set(pred_grpc_hdrs "gen/torchrec/inference/predictor.grpc.pb.h") -set(pred_proto_srcs "gen/torchrec/inference/predictor.pb.cc") -set(pred_proto_hdrs "gen/torchrec/inference/predictor.pb.h") +# Torch + FBGEMM +find_package(Torch REQUIRED) +add_library( fbgemm SHARED IMPORTED GLOBAL ) +set_target_properties(fbgemm PROPERTIES IMPORTED_LOCATION ${FBGEMM_LIB}) -add_library(pred_grpc_proto STATIC - ${pred_grpc_srcs} - ${pred_grpc_hdrs} - ${pred_proto_srcs} - ${pred_proto_hdrs}) -target_link_libraries(pred_grpc_proto +add_library(hw_grpc_proto STATIC + ${hw_grpc_srcs} + ${hw_grpc_hdrs} + ${hw_proto_srcs} + ${hw_proto_hdrs}) +target_link_libraries(hw_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF}) -# server +# Targets greeter_[async_](client|server) add_executable(server server.cpp) target_link_libraries(server - inference - torch_deploy - pred_grpc_proto "${TORCH_LIBRARIES}" - ${FOLLY_LIBRARIES} - ${PYTORCH_LIB_FMT} - ${FBGEMM_LIB} + fbgemm + hw_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} - ${_PROTOBUF_LIBPROTOBUF}) + ${_PROTOBUF_LIBPROTOBUF} +) diff --git a/torchrec/inference/README.md b/torchrec/inference/README.md index 8958fc72d..0a8e1cc17 100644 --- a/torchrec/inference/README.md +++ b/torchrec/inference/README.md @@ -1,116 +1,65 @@ # TorchRec Inference Library (**Experimental** Release) ## Overview -TorchRec Inference is a C++ library that supports **multi-gpu inference**. The Torchrec library is used to shard models written and packaged in Python via [torch.package](https://pytorch.org/docs/stable/package.html) (an alternative to TorchScript). The [torch.deploy](https://pytorch.org/docs/stable/deploy.html) library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus subverting the GIL. +TorchRec Inference is a C++ library that supports **gpu inference**. Previously, the TorchRec inference library was authored with torch.package and torch.deploy, which are old and deprecated. All the previous files live under the directory inference_legacy for reference. -Follow the instructions below to package a DLRM model in Python, run a C++ inference server with the model on a GPU and send requests to said server via a python client. +TorchRec inference was reauthored with simplicity in mind, while also reflecting the current production environment for RecSys models, namely torch.fx for graph capturing/tracing and TorchScript for model inference in a C++ environment. The inference solution here is meant to serve as a simple reference and example, not a fully scaled out solution for production use cases. The current solution demonstrates converting the DLRM model in Python to TorchScript, running a C++ inference server with the model on a GPU, and sending requests to said server via a python client. -## Example +## Requirements -C++ 17 is a requirement. +C++ 17 is a requirement. GCC version has to be >= 9, with initial testing done on GCC 9.
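Before the numbered setup steps, a minimal sketch of the torch.fx + TorchScript flow described in the overview may help; it uses a toy module and made-up names (`TinyRanker`, `/tmp/toy_model.pt`) rather than the actual DLRM pipeline in this example:

```python
import torch
import torch.fx


class TinyRanker(torch.nn.Module):
    """Stand-in for the real DLRM predict module (illustrative only)."""

    def __init__(self) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(13, 1)

    def forward(self, dense_features: torch.Tensor) -> torch.Tensor:
        return self.dense(dense_features).sigmoid()


# 1. Capture the graph with torch.fx. The DLRM example uses torchrec's own
#    Tracer with quantized embedding modules as leaves; symbolic_trace is the
#    plain torch.fx equivalent for this toy module.
graph_module = torch.fx.symbolic_trace(TinyRanker())

# 2. Compile the traced graph with TorchScript and save it so a C++ process
#    can later load it with torch::jit::load.
scripted = torch.jit.script(graph_module)
scripted(torch.randn(10, 13))        # sanity check before saving
scripted.save("/tmp/toy_model.pt")   # analogous to /tmp/model.pt below
```

The real packaging script below follows the same two steps, just with quantization and sharding applied to the DLRM model first.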
### **1. Install Dependencies** - -Follow the instructions at: https://github.com/pytorch/pytorch/blob/master/docs/source/deploy.rst to ensure torch::deploy -is working in your environment. Use the Dockerfile in the docker directory to install all dependencies. Run it via: - +1. [GRPC for C++][https://grpc.io/docs/languages/cpp/quickstart/] needs to be installed, with the resulting installation directory being `$HOME/.local` +2. Ensure that **the protobuf compiler (protoc) binary being used is from the GRPC installation above**. The protoc binary will live in `$HOME/.local/bin`, which may not match with the system protoc binary, can check with `which protoc`. +3. Install PyTorch, FBGEMM, and TorchRec (ideally in a virtual environment): ``` -sudo nvidia-docker build -t torchrec . -sudo nvidia-docker run -it torchrec:latest +pip install torch --index-url https://download.pytorch.org/whl/cu121 +pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121 +pip install torchmetrics==1.0.3 +pip install torchrec --index-url https://download.pytorch.org/whl/cu121 ``` + ### **2. Set variables** Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime. ``` -# provide the cmake prefix path of pytorch, folly, and fmt. -# fmt and boost are pulled from folly's installation in this example. -export FOLLY_CMAKE_DIR="~/folly-build/installed/folly/lib/cmake/folly" -export FMT_CMAKE_DIR="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/cmake/fmt" -export BOOST_CMAKE_DIR="~/folly-build/installed/boost-4M2ZnvEM4UWTqpsEJRQTB4oejmX3LmgYC9pcBiuVlmA/lib/cmake/Boost-1.78.0" - -# provide fmt from pytorch for torch deploy -export PYTORCH_FMT_INCLUDE_PATH="~/pytorch/third_party/fmt/include/" -export PYTORCH_LIB_FMT="~/pytorch/build/lib/libfmt.a" - -# provide necessary info to link to torch deploy -export DEPLOY_INTERPRETER_PATH="/pytorch/build/torch/csrc/deploy" -export DEPLOY_SRC_PATH="~/pytorch/torch/csrc/deploy" - -# provide common.cmake from grpc/examples, makes linking to grpc easier -export GRPC_COMMON_CMAKE_PATH="~/grpc/examples/cpp/cmake" -export GRPC_HEADER_INCLUDE_PATH="~/.local/include/" - -# provide libfbgemm_gpu_py.so to enable fbgemm_gpu c++ operators -export FBGEMM_LIB="~/anaconda3/envs/inference/lib/python3.8/site-packages/fbgemm_gpu-0.1.0-py3.8-linux-x86_64.egg/fbgemm_gpu/libfbgemm_gpu_py.so" - -# provide path to python packages for torch deploy runtime -export PYTHON_PACKAGES_PATH="~/anaconda3/envs/inference/lib/python3.8/site-packages/" -``` - -Update `$LD_LIBRARY_PATH` and `$LIBRARY_PATH` to enable linker to locate libraries. 
+# provide fbgemm_gpu_py.so to enable fbgemm_gpu c++ operators +find $HOME -name fbgemm_gpu_py.so +# Use path from correct virtual environment above and set environment variable $FBGEMM_LIB to it +export FBGEMM_LIB="" ``` -# double-conversion, fmt and gflags are pulled from folly's installation in this example -export DOUBLE_CONVERSION_LIB_PATH="~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib" -export FMT_LIB_PATH="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/" -export GFLAGS_LIB_PATH="~/folly-build/installed/gflags-KheHQBqQ3_iL3yJBFwWe5M5f8Syd-LKAX352cxkhQMc/lib" -export PYTORCH_LIB_PATH="~/pytorch/build/lib/" -export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$DOUBLE_CONVERSION_LIB_PATH:$FMT_LIB_PATH:$GFLAGS_LIB_PATH:$PYTORCH_LIB_PATH" -export LIBRARY_PATH="$PYTORCH_LIB_PATH" -``` +### **3. Generate TorchScripted DLRM model** -### **3. Package DLRM model** - -The `PredictFactoryPackager` class in `model_packager.py` can be used to implement your own packager class. Implement -`set_extern_modules` to specify the dependencies of your predict module that should be accessed from the system and -implement `set_mocked_modules` to specify dependencies that should be mocked (necessary to import but not use). Read -more about extern and mock modules in the `torch.package` documentation: https://pytorch.org/docs/stable/package.html. - -`/torchrec/examples/inference/dlrm_package.py` provides an example of packaging a module for inference (`/torchrec/examples/inference/dlrm_predict.py`). -`DLRMPredictModule` is packaged for inference in the following example. +Here, we generate the DLRM model in Torchscript and save it for model loading later on. ``` git clone https://github.com/pytorch/torchrec.git -cd ~/torchrec/examples/inference/ -python dlrm_packager.py --output_path /tmp/model_package.zip +cd ~/torchrec/torchrec/inference/ +python3 dlrm_packager.py --output_path /tmp/model.pt ``` - ### **4. Build inference library and example server** -Generate protobuf C++ and Python code from protobuf +Generate Python code from protobuf for client and build the server. ``` -cd ~/torchrec/inference/ -mkdir -p gen/torchrec/inference - -# C++ (server) -protoc -I protos/ --grpc_out=gen/torchrec/inference --plugin=protoc-gen-grpc=/home/shabab/.local/bin/grpc_cpp_plugin protos/predictor.proto - -protoc -I protos/ --cpp_out=gen/torchrec/inference protos/predictor.proto - - # Python (client) -python -m grpc_tools.protoc -I protos --python_out=gen/torchrec/inference --grpc_python_out=gen/torchrec/inference protos/predictor.proto +python -m grpc_tools.protoc -I protos --python_out=. --grpc_python_out=. protos/predictor.proto ``` -Build inference library and example server +Build server and C++ protobufs ``` -cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');$FOLLY_CMAKE_DIR;$BOOST_CMAKE_DIR;$BOOST_CMAKE_DIR;" --DPYTORCH_FMT_INCLUDE_PATH="$PYTORCH_FMT_INCLUDE_PATH" \ --DPYTORCH_LIB_FMT="$PYTORCH_LIB_FMT" \ --DDEPLOY_INTERPRETER_PATH="$DEPLOY_INTERPRETER_PATH" \ --DDEPLOY_SRC_PATH="$DEPLOY_SRC_PATH" \ --DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" \ -DGRPC_HEADER_INCLUDE_PATH="$GRPC_HEADER_INCLUDE_PATH" \ --DFBGEMM_LIB="$FBGEMM_LIB" +cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');" -DFBGEMM_LIB="$FBGEMM_LIB" cd build make -j @@ -119,28 +68,20 @@ make -j ### **5. Run server and client** -Run server. 
Update `CUDA_VISABLE_DEVICES` depending on the world size. +Start the server, loading in the model saved previously ``` -CUDA_VISABLE_DEVICES="0" ./server --package_path="/tmp/model_package.zip" --python_packages_path $PYTHON_PACKAGES_PATH +./server /tmp/model.pt ``` **output** -In the logs, a plan should be outputted by the Torchrec planner: - -``` -INFO:.torchrec.distributed.planner.stats:# --- Planner Statistics --- # -INFO:.torchrec.distributed.planner.stats:# --- Evalulated 1 proposal(s), found 1 possible plan(s) --- # -INFO:.torchrec.distributed.planner.stats:# ----------------------------------------------------------------------------------------------- # -INFO:.torchrec.distributed.planner.stats:# Rank HBM (GB) DDR (GB) Perf (ms) Input (MB) Output (MB) Shards # -INFO:.torchrec.distributed.planner.stats:# ------ ---------- ---------- ----------- ------------ ------------- -------- # -INFO:.torchrec.distributed.planner.stats:# 0 0.2 (1%) 0.0 (0%) 0.08 0.1 1.02 TW: 26 # -INFO:.torchrec.distributed.planner.stats:# # -INFO:.torchrec.distributed.planner.stats:# Input: MB/iteration, Output: MB/iteration, Shards: number of tables # -INFO:.torchrec.distributed.planner.stats:# HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients # -INFO:.torchrec.distributed.planner.stats:# # -INFO:.torchrec.distributed.planner.stats:# Compute Kernels: # -INFO:.torchrec.distributed.planner.stats:# quant: 26 # +In the logs, you should see: + +``` +Loading model... +Sanity Check with dummy inputs + Model Forward Completed, Output: 0.489247 +Server listening on 0.0.0.0:50051 ```` `nvidia-smi` output should also show allocation of the model onto the gpu: @@ -155,7 +96,7 @@ INFO:.torchrec.distributed.planner.stats:# quant: 26 +-----------------------------------------------------------------------------+ ``` -Make a request to the server via the client: +In another terminal instance, make a request to the server via the client: ``` python client.py @@ -166,74 +107,3 @@ python client.py ``` Response: [0.13199582695960999, -0.1048036441206932, -0.06022112816572189, -0.08765199035406113, -0.12735335528850555, -0.1004377081990242, 0.05509107559919357, -0.10504599660634995, 0.1350800096988678, -0.09468207508325577, 0.24013587832450867, -0.09682435542345047, 0.0025023818016052246, -0.09786031395196915, -0.26396819949150085, -0.09670191258192062, 0.2691854238510132, -0.10246685892343521, -0.2019493579864502, -0.09904996305704117, 0.3894067406654358, ...] ``` - -
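If the server log does not show the sanity-check line above, the saved artifact can also be exercised directly from Python before starting the C++ server. This is a rough sketch, assuming a CUDA device and the input keys used by `predict_forward` in `dlrm_predict.py`; the batch size, dtypes, and id values here are illustrative:

```python
import torch

# Load the TorchScripted DLRM model written by dlrm_packager.py.
module = torch.jit.load("/tmp/model.pt")

batch_size, num_dense, num_sparse = 10, 13, 26
batch = {
    "float_features": torch.randn(batch_size, num_dense).cuda(),
    # One id per sparse feature per sample -> 26 * 10 lengths, all ones.
    "id_list_features.lengths": torch.ones(
        num_sparse * batch_size, dtype=torch.int32
    ).cuda(),
    "id_list_features.values": torch.randint(
        0, 1000, (num_sparse * batch_size,)
    ).cuda(),
}

out = module(batch)
print(out["default"].shape)  # expected: torch.Size([10])
```

If this runs but the gRPC request fails, the problem is more likely in the server/client wiring than in the exported model.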
- -## Planned work - -- Provide benchmarks for torch deploy vs TorchScript and cpu, single gpu and multi-gpu inference -- In-code documentation -- Simplify installation process - -
- -## Potential issues and solutions - -Skip this section if you had no issues with installation or running the example. - -**Missing header files during pytorch installation** - -If your environment is missing a speicfic set of header files such as `nvml.h` and `cuda_profiler_api.h`, the pytorch installation will fail with error messages similar to the code snippet below: - -``` -~/nvml_lib.h:13:10: fatal error: nvml.h: No such file or directory - #include - ^~~~~~~~ -compilation terminated. -[80/2643] Building CXX object third_party/ideep/mkl-dnn/third_party/oneDNN/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_convolution_list.cpp.o -ninja: build stopped: subcommand failed. -``` - -To get these header files, install `cudatoolkit-dev`: -``` -conda install -c conda-forge cudatoolkit-dev -``` - -Re-run the installation after this. - -**libdouble-conversion missing** -``` -~/torchrec/torchrec/inference/build$ ./example -./example: error while loading shared libraries: libdouble-conversion.so.3: cannot open shared object file: No such file or directory -``` - -If this issue persists even after adding double-conversion's path to $LD_LIBRARY_PATH (step 2) then solve by creating a symlink to `libdouble-conversion.so.3` with folly's installation of double-conversion: - -``` -sudo ln -s ~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib/libdouble-conversion.so.3.1.4 \ -libdouble-conversion.so.3 -``` - -**Two installations of glog** -``` -~/torchrec/torchrec/inference/build$ ./example -ERROR: flag 'logtostderr' was defined more than once (in files '/home/shabab/glog/src/logging.cc' and -'/home/shabab/folly-build/extracted/glog-v0.4.0.tar.gz/glog-0.4.0/src/logging.cc'). -``` -The above issue, along with a host of others during building, can potentially occur if libinference is pointing to two different versions of glog (if one was -previously installed in your system). You can find this out by running `ldd` on your libinference shared object within the build path. The issue can be solved by using the glog version provided by folly. - -To use the glog version provided by folly, add the glog install path (in your folly-build directory) to your LD_LIBRARY_PATH much like step 2. - -**Undefined symbols with std::string or cxx11** - -If you get undefined symbol errors and the errors mention `std::string` or `cxx11`, it's likely -that your dependencies were compiled with different ABI values. Re-compile your dependencies -and ensure they all have the same value for `_GLIBCXX_USE_CXX11_ABI` in their build. - -The ABI value of pytorch can be checked via: - -``` -import torch -torch._C._GLIBCXX_USE_CXX11_ABI -``` diff --git a/torchrec/inference/__init__.py b/torchrec/inference/__init__.py index 6c2050114..a0ce8680a 100644 --- a/torchrec/inference/__init__.py +++ b/torchrec/inference/__init__.py @@ -7,21 +7,4 @@ # pyre-strict -"""Torchrec Inference - -Torchrec inference provides a Torch.Deploy based library for GPU inference. - -These includes: - - Model packaging in Python - - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving. - - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package. - - Model serving in C++ - - `BatchingQueue` is a generalized config-based request tensor batching implementation. - - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy. - -We implemented an example of how to use this library with the TorchRec DLRM model. 
- - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package. - - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model. -""" - from . import model_packager, modules # noqa # noqa diff --git a/torchrec/inference/client.py b/torchrec/inference/client.py index c185f3909..725338f46 100644 --- a/torchrec/inference/client.py +++ b/torchrec/inference/client.py @@ -9,8 +9,8 @@ import logging import grpc +import predictor_pb2, predictor_pb2_grpc import torch -from gen.torchrec.inference import predictor_pb2, predictor_pb2_grpc from torch.utils.data import DataLoader from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES from torchrec.datasets.random import RandomRecDataset @@ -20,18 +20,13 @@ def create_training_batch(args: argparse.Namespace) -> Batch: return next( iter( - DataLoader( - RandomRecDataset( - keys=DEFAULT_CAT_NAMES, - batch_size=args.batch_size, - hash_size=args.num_embedding_features, - ids_per_feature=1, - num_dense=len(DEFAULT_INT_NAMES), - ), - batch_sampler=None, - pin_memory=False, - num_workers=0, - ) + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ), ) ) diff --git a/torchrec/inference/dlrm_packager.py b/torchrec/inference/dlrm_packager.py new file mode 100644 index 000000000..560e31c16 --- /dev/null +++ b/torchrec/inference/dlrm_packager.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# @nolint +# pyre-ignore-all-errors + + +import argparse +import sys +from typing import List + +from dlrm_predict import create_training_batch, DLRMModelConfig, DLRMPredictFactory +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES + + +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="torchrec dlrm model packager") + parser.add_argument( + "--num_embeddings", + type=int, + default=100_000, + help="max_ind_size. The number of embeddings in each embedding table. Defaults" + " to 100_000 if num_embeddings_per_feature is not supplied.", + ) + parser.add_argument( + "--num_embeddings_per_feature", + type=str, + default="45833188,36746,17245,7413,20243,3,7114,1441,62,29275261,1572176,345138," + "10,2209,11267,128,4,974,14,48937457,11316796,40094537,452104,12606,104,35", + help="Comma separated max_ind_size per sparse feature. The number of embeddings" + " in each embedding table. 
26 values are expected for the Criteo dataset.", + ) + parser.add_argument( + "--sparse_feature_names", + type=str, + default=",".join(DEFAULT_CAT_NAMES), + help="Comma separated names of the sparse features.", + ) + parser.add_argument( + "--dense_arch_layer_sizes", + type=str, + default="512,256,64", + help="Comma separated layer sizes for dense arch.", + ) + parser.add_argument( + "--over_arch_layer_sizes", + type=str, + default="512,512,256,1", + help="Comma separated layer sizes for over arch.", + ) + parser.add_argument( + "--embedding_dim", + type=int, + default=64, + help="Size of each embedding.", + ) + parser.add_argument( + "--num_dense_features", + type=int, + default=len(DEFAULT_INT_NAMES), + help="Number of dense features.", + ) + parser.add_argument( + "--output_path", + type=str, + help="Output path of model package.", + ) + return parser.parse_args(argv) + + +def main(argv: List[str]) -> None: + """ + Use torch.package to package the torchrec DLRM Model. + + Args: + argv (List[str]): command line args. + + Returns: + None. + """ + + args = parse_args(argv) + + args.batch_size = 10 + args.num_embedding_features = 26 + batch = create_training_batch(args) + + model_config = DLRMModelConfig( + dense_arch_layer_sizes=list(map(int, args.dense_arch_layer_sizes.split(","))), + dense_in_features=args.num_dense_features, + embedding_dim=args.embedding_dim, + id_list_features_keys=args.sparse_feature_names.split(","), + num_embeddings_per_feature=list( + map(int, args.num_embeddings_per_feature.split(",")) + ), + num_embeddings=args.num_embeddings, + over_arch_layer_sizes=list(map(int, args.over_arch_layer_sizes.split(","))), + sample_input=batch, + ) + + script_module = DLRMPredictFactory(model_config).create_predict_module(world_size=1) + + script_module.save(args.output_path) + print(f"Package is saved to {args.output_path}") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchrec/inference/dlrm_predict.py b/torchrec/inference/dlrm_predict.py new file mode 100644 index 000000000..3ac2966ad --- /dev/null +++ b/torchrec/inference/dlrm_predict.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# @nolint +# pyre-ignore-all-errors + + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.datasets.random import RandomRecDataset +from torchrec.datasets.utils import Batch +from torchrec.fx.tracer import Tracer +from torchrec.inference.modules import ( + PredictFactory, + PredictModule, + quantize_inference_model, + shard_quant_model, +) +from torchrec.models.dlrm import DLRM +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +logger: logging.Logger = logging.getLogger(__name__) + + +def create_training_batch(args) -> Batch: + return RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ).batch_generator._generate_batch() + + +# OSS Only + + +@dataclass +class DLRMModelConfig: + dense_arch_layer_sizes: List[int] + dense_in_features: int + embedding_dim: int + id_list_features_keys: List[str] + num_embeddings_per_feature: List[int] + num_embeddings: int + over_arch_layer_sizes: List[int] + sample_input: Batch + + +class DLRMPredictModule(PredictModule): + """ + nn.Module to wrap DLRM model to use for inference. + + Args: + embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags + used to define SparseArch. + dense_in_features (int): the dimensionality of the dense input features. + dense_arch_layer_sizes (List[int]): the layer sizes for the DenseArch. + over_arch_layer_sizes (List[int]): the layer sizes for the OverArch. NOTE: The + output dimension of the InteractionArch should not be manually specified + here. + id_list_features_keys (List[str]): the names of the sparse features. Used to + construct a batch for inference. + dense_device: (Optional[torch.device]). + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + dense_in_features: int, + dense_arch_layer_sizes: List[int], + over_arch_layer_sizes: List[int], + id_list_features_keys: List[str], + dense_device: Optional[torch.device] = None, + ) -> None: + module = DLRM( + embedding_bag_collection=embedding_bag_collection, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=dense_arch_layer_sizes, + over_arch_layer_sizes=over_arch_layer_sizes, + dense_device=dense_device, + ) + super().__init__(module) + + self.id_list_features_keys: List[str] = id_list_features_keys + + def predict_forward( + self, batch: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Args: + batch (Dict[str, torch.Tensor]): currently expects input dense features + to be mapped to the key "float_features" and input sparse features + to be mapped to the key "id_list_features". + + Returns: + Dict[str, torch.Tensor]: output of inference. + """ + + try: + logits = self.predict_module( + batch["float_features"], + KeyedJaggedTensor( + keys=self.id_list_features_keys, + lengths=batch["id_list_features.lengths"], + values=batch["id_list_features.values"], + ), + ) + predictions = logits.sigmoid() + except Exception as e: + logger.info(e) + raise e + + # Flip predictions tensor to be 1D. TODO: Determine why prediction shape + # can be 2D at times (likely due to input format?) 
+ predictions = predictions.reshape( + [ + predictions.size()[0], + ] + ) + + return { + "default": predictions.to(torch.device("cpu"), non_blocking=True).float() + } + + +class DLRMPredictFactory(PredictFactory): + def __init__(self, model_config: DLRMModelConfig) -> None: + self.model_config = model_config + + def create_predict_module(self, world_size: int) -> torch.nn.Module: + logging.basicConfig(level=logging.INFO) + device = torch.device("cuda:0") + + eb_configs = [ + EmbeddingBagConfig( + name=f"t_{feature_name}", + embedding_dim=self.model_config.embedding_dim, + num_embeddings=( + self.model_config.num_embeddings_per_feature[feature_idx] + if self.model_config.num_embeddings is None + else self.model_config.num_embeddings + ), + feature_names=[feature_name], + ) + for feature_idx, feature_name in enumerate( + self.model_config.id_list_features_keys + ) + ] + ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta")) + + module = DLRMPredictModule( + embedding_bag_collection=ebc, + dense_in_features=self.model_config.dense_in_features, + dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes, + over_arch_layer_sizes=self.model_config.over_arch_layer_sizes, + id_list_features_keys=self.model_config.id_list_features_keys, + dense_device=device, + ) + + table_fqns = [] + for name, _ in module.named_modules(): + if "t_" in name: + table_fqns.append(name.split(".")[-1]) + + quant_model = quantize_inference_model(module) + sharded_model, _ = shard_quant_model(quant_model, table_fqns) + + batch = {} + batch["float_features"] = self.model_config.sample_input.dense_features.cuda() + batch["id_list_features.lengths"] = ( + self.model_config.sample_input.sparse_features.lengths().cuda() + ) + batch["id_list_features.values"] = ( + self.model_config.sample_input.sparse_features.values().cuda() + ) + + sharded_model(batch) + + tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"]) + + graph = tracer.trace(sharded_model) + gm = torch.fx.GraphModule(sharded_model, graph) + + gm(batch) + scripted_gm = torch.jit.script(gm) + scripted_gm(batch) + return scripted_gm + + def batching_metadata(self) -> Dict[str, str]: + return { + "float_features": "dense", + "id_list_features": "sparse", + } + + def result_metadata(self) -> str: + return "dict_of_tensor" + + def run_weights_independent_tranformations( + self, predict_module: torch.nn.Module + ) -> torch.nn.Module: + return predict_module + + def run_weights_dependent_transformations( + self, predict_module: torch.nn.Module + ) -> torch.nn.Module: + """ + Run transformations that depends on weights of the predict module. e.g. lowering to a backend. + """ + return predict_module diff --git a/torchrec/inference/inference_legacy/CMakeLists.txt b/torchrec/inference/inference_legacy/CMakeLists.txt new file mode 100644 index 000000000..794bb1490 --- /dev/null +++ b/torchrec/inference/inference_legacy/CMakeLists.txt @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.13 FATAL_ERROR) +project(inference) + +# This step is crucial to ensure that the +# _REFLECTION, _GRPC_GRPCPP and _PROTOBUF_LIBPROTOBUF variables are set. +# e.g. 
~/gprc/examples/cpp/cmake/common.cmake +include(${GRPC_COMMON_CMAKE_PATH}/common.cmake) + + +# abi and other flags + +if(DEFINED GLIBCXX_USE_CXX11_ABI) + if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") + set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1") + endif() +endif() + +# keep it static for now since folly-shared version is broken +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") + +# dependencies +find_package(Boost REQUIRED) +find_package(Torch REQUIRED) +find_package(folly REQUIRED) +find_package(gflags REQUIRED) + +include_directories(${Torch_INCLUDE_DIRS}) +include_directories(${folly_INCLUDE_DIRS}) +include_directories(${PYTORCH_FMT_INCLUDE_PATH}) + +set(CMAKE_CXX_STANDARD 17) + +# torch deploy library +add_library(torch_deploy_internal STATIC + ${DEPLOY_INTERPRETER_PATH}/libtorch_deployinterpreter.o + ${DEPLOY_SRC_PATH}/deploy.cpp + ${DEPLOY_SRC_PATH}/loader.cpp + ${DEPLOY_SRC_PATH}/path_environment.cpp + ${DEPLOY_SRC_PATH}/elf_file.cpp) + +# For python builtins. caffe2_interface_library properly +# makes use of the --whole-archive option. +target_link_libraries(torch_deploy_internal PRIVATE + crypt pthread dl util m z ffi lzma readline nsl ncursesw panelw +) +target_link_libraries(torch_deploy_internal + PUBLIC shm torch ${PYTORCH_LIB_FMT} +) +caffe2_interface_library(torch_deploy_internal torch_deploy) + +# inference library + +# for our own header files +include_directories(include/) +include_directories(gen/) + +# define our library target +add_library(inference STATIC + src/Batching.cpp + src/BatchingQueue.cpp + src/GPUExecutor.cpp + src/ResultSplit.cpp + src/Exception.cpp + src/ResourceManager.cpp +) + +# -rdynamic is needed to link against the static library +target_link_libraries(inference "-Wl,--no-as-needed -rdynamic" + dl torch_deploy "${TORCH_LIBRARIES}" ${FBGEMM_LIB} ${FOLLY_LIBRARIES} +) + +# for generated protobuf + +# grpc headers. e.g. ~/.local/include +include_directories(${GRPC_HEADER_INCLUDE_PATH}) + +set(pred_grpc_srcs "gen/torchrec/inference/predictor.grpc.pb.cc") +set(pred_grpc_hdrs "gen/torchrec/inference/predictor.grpc.pb.h") +set(pred_proto_srcs "gen/torchrec/inference/predictor.pb.cc") +set(pred_proto_hdrs "gen/torchrec/inference/predictor.pb.h") + +add_library(pred_grpc_proto STATIC + ${pred_grpc_srcs} + ${pred_grpc_hdrs} + ${pred_proto_srcs} + ${pred_proto_hdrs}) + +target_link_libraries(pred_grpc_proto + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) + +# server +add_executable(server server.cpp) +target_link_libraries(server + inference + torch_deploy + pred_grpc_proto + "${TORCH_LIBRARIES}" + ${FOLLY_LIBRARIES} + ${PYTORCH_LIB_FMT} + ${FBGEMM_LIB} + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) diff --git a/torchrec/inference/inference_legacy/README.md b/torchrec/inference/inference_legacy/README.md new file mode 100644 index 000000000..fc3d8afcb --- /dev/null +++ b/torchrec/inference/inference_legacy/README.md @@ -0,0 +1,239 @@ +# TorchRec Inference Library (**Experimental** Release) + +## Overview +TorchRec Inference is a C++ library that supports **multi-gpu inference**. The Torchrec library is used to shard models written and packaged in Python via [torch.package](https://pytorch.org/docs/stable/package.html) (an alternative to TorchScript). 
The [torch.deploy](https://pytorch.org/docs/stable/deploy.html) library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus subverting the GIL. + +Follow the instructions below to package a DLRM model in Python, run a C++ inference server with the model on a GPU and send requests to said server via a python client. + +## Example + +C++ 17 is a requirement. + +
+ +### **1. Install Dependencies** + +Follow the instructions at: https://github.com/pytorch/pytorch/blob/master/docs/source/deploy.rst to ensure torch::deploy +is working in your environment. Use the Dockerfile in the docker directory to install all dependencies. Run it via: + +``` +sudo nvidia-docker build -t torchrec . +sudo nvidia-docker run -it torchrec:latest +``` + +### **2. Set variables** + +Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime. + +``` +# provide the cmake prefix path of pytorch, folly, and fmt. +# fmt and boost are pulled from folly's installation in this example. +export FOLLY_CMAKE_DIR="~/folly-build/installed/folly/lib/cmake/folly" +export FMT_CMAKE_DIR="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/cmake/fmt" +export BOOST_CMAKE_DIR="~/folly-build/installed/boost-4M2ZnvEM4UWTqpsEJRQTB4oejmX3LmgYC9pcBiuVlmA/lib/cmake/Boost-1.78.0" + +# provide fmt from pytorch for torch deploy +export PYTORCH_FMT_INCLUDE_PATH="~/pytorch/third_party/fmt/include/" +export PYTORCH_LIB_FMT="~/pytorch/build/lib/libfmt.a" + +# provide necessary info to link to torch deploy +export DEPLOY_INTERPRETER_PATH="/pytorch/build/torch/csrc/deploy" +export DEPLOY_SRC_PATH="~/pytorch/torch/csrc/deploy" + +# provide common.cmake from grpc/examples, makes linking to grpc easier +export GRPC_COMMON_CMAKE_PATH="~/grpc/examples/cpp/cmake" +export GRPC_HEADER_INCLUDE_PATH="~/.local/include/" + +# provide libfbgemm_gpu_py.so to enable fbgemm_gpu c++ operators +export FBGEMM_LIB="~/anaconda3/envs/inference/lib/python3.8/site-packages/fbgemm_gpu-0.1.0-py3.8-linux-x86_64.egg/fbgemm_gpu/libfbgemm_gpu_py.so" + +# provide path to python packages for torch deploy runtime +export PYTHON_PACKAGES_PATH="~/anaconda3/envs/inference/lib/python3.8/site-packages/" +``` + +Update `$LD_LIBRARY_PATH` and `$LIBRARY_PATH` to enable linker to locate libraries. + +``` +# double-conversion, fmt and gflags are pulled from folly's installation in this example +export DOUBLE_CONVERSION_LIB_PATH="~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib" +export FMT_LIB_PATH="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/" +export GFLAGS_LIB_PATH="~/folly-build/installed/gflags-KheHQBqQ3_iL3yJBFwWe5M5f8Syd-LKAX352cxkhQMc/lib" +export PYTORCH_LIB_PATH="~/pytorch/build/lib/" + +export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$DOUBLE_CONVERSION_LIB_PATH:$FMT_LIB_PATH:$GFLAGS_LIB_PATH:$PYTORCH_LIB_PATH" +export LIBRARY_PATH="$PYTORCH_LIB_PATH" +``` + +### **3. Package DLRM model** + +The `PredictFactoryPackager` class in `model_packager.py` can be used to implement your own packager class. Implement +`set_extern_modules` to specify the dependencies of your predict module that should be accessed from the system and +implement `set_mocked_modules` to specify dependencies that should be mocked (necessary to import but not use). Read +more about extern and mock modules in the `torch.package` documentation: https://pytorch.org/docs/stable/package.html. + +`/torchrec/examples/inference_legacy/dlrm_package.py` provides an example of packaging a module for inference (`/torchrec/examples/inference_legacy/dlrm_predict.py`). +`DLRMPredictModule` is packaged for inference in the following example. 
+ +``` +git clone https://github.com/pytorch/torchrec.git + +cd ~/torchrec/examples/inference_legacy/ +python dlrm_packager.py --output_path /tmp/model_package.zip +``` + + + +### **4. Build inference library and example server** + +Generate protobuf C++ and Python code from protobuf + +``` +cd ~/torchrec/inference/ +mkdir -p gen/torchrec/inference + +# C++ (server) +protoc -I protos/ --grpc_out=gen/torchrec/inference --plugin=protoc-gen-grpc=/home/shabab/.local/bin/grpc_cpp_plugin protos/predictor.proto + +protoc -I protos/ --cpp_out=gen/torchrec/inference protos/predictor.proto + + +# Python (client) +python -m grpc_tools.protoc -I protos --python_out=gen/torchrec/inference --grpc_python_out=gen/torchrec/inference protos/predictor.proto +``` + + +Build inference library and example server +``` +cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');$FOLLY_CMAKE_DIR;$BOOST_CMAKE_DIR;$BOOST_CMAKE_DIR;" +-DPYTORCH_FMT_INCLUDE_PATH="$PYTORCH_FMT_INCLUDE_PATH" \ +-DPYTORCH_LIB_FMT="$PYTORCH_LIB_FMT" \ +-DDEPLOY_INTERPRETER_PATH="$DEPLOY_INTERPRETER_PATH" \ +-DDEPLOY_SRC_PATH="$DEPLOY_SRC_PATH" \ +-DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" \ -DGRPC_HEADER_INCLUDE_PATH="$GRPC_HEADER_INCLUDE_PATH" \ +-DFBGEMM_LIB="$FBGEMM_LIB" + +cd build +make -j +``` + + +### **5. Run server and client** + +Run server. Update `CUDA_VISABLE_DEVICES` depending on the world size. +``` +CUDA_VISABLE_DEVICES="0" ./server --package_path="/tmp/model_package.zip" --python_packages_path $PYTHON_PACKAGES_PATH +``` + +**output** + +In the logs, a plan should be outputted by the Torchrec planner: + +``` +INFO:.torchrec.distributed.planner.stats:# --- Planner Statistics --- # +INFO:.torchrec.distributed.planner.stats:# --- Evalulated 1 proposal(s), found 1 possible plan(s) --- # +INFO:.torchrec.distributed.planner.stats:# ----------------------------------------------------------------------------------------------- # +INFO:.torchrec.distributed.planner.stats:# Rank HBM (GB) DDR (GB) Perf (ms) Input (MB) Output (MB) Shards # +INFO:.torchrec.distributed.planner.stats:# ------ ---------- ---------- ----------- ------------ ------------- -------- # +INFO:.torchrec.distributed.planner.stats:# 0 0.2 (1%) 0.0 (0%) 0.08 0.1 1.02 TW: 26 # +INFO:.torchrec.distributed.planner.stats:# # +INFO:.torchrec.distributed.planner.stats:# Input: MB/iteration, Output: MB/iteration, Shards: number of tables # +INFO:.torchrec.distributed.planner.stats:# HBM: est. 
peak memory usage for shards - parameter, comms, optimizer, and gradients # +INFO:.torchrec.distributed.planner.stats:# # +INFO:.torchrec.distributed.planner.stats:# Compute Kernels: # +INFO:.torchrec.distributed.planner.stats:# quant: 26 # +```` + +`nvidia-smi` output should also show allocation of the model onto the gpu: + +``` ++-----------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=============================================================================| +| 0 N/A N/A 86668 C ./example 1357MiB | ++-----------------------------------------------------------------------------+ +``` + +Make a request to the server via the client: + +``` +python client.py +``` + +**output** + +``` +Response: [0.13199582695960999, -0.1048036441206932, -0.06022112816572189, -0.08765199035406113, -0.12735335528850555, -0.1004377081990242, 0.05509107559919357, -0.10504599660634995, 0.1350800096988678, -0.09468207508325577, 0.24013587832450867, -0.09682435542345047, 0.0025023818016052246, -0.09786031395196915, -0.26396819949150085, -0.09670191258192062, 0.2691854238510132, -0.10246685892343521, -0.2019493579864502, -0.09904996305704117, 0.3894067406654358, ...] +``` + +
+ +## Planned work + +- Provide benchmarks for torch deploy vs TorchScript and cpu, single gpu and multi-gpu inference +- In-code documentation +- Simplify installation process + +
+ +## Potential issues and solutions + +Skip this section if you had no issues with installation or running the example. + +**Missing header files during pytorch installation** + +If your environment is missing a speicfic set of header files such as `nvml.h` and `cuda_profiler_api.h`, the pytorch installation will fail with error messages similar to the code snippet below: + +``` +~/nvml_lib.h:13:10: fatal error: nvml.h: No such file or directory + #include + ^~~~~~~~ +compilation terminated. +[80/2643] Building CXX object third_party/ideep/mkl-dnn/third_party/oneDNN/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_convolution_list.cpp.o +ninja: build stopped: subcommand failed. +``` + +To get these header files, install `cudatoolkit-dev`: +``` +conda install -c conda-forge cudatoolkit-dev +``` + +Re-run the installation after this. + +**libdouble-conversion missing** +``` +~/torchrec/torchrec/inference/build$ ./example +./example: error while loading shared libraries: libdouble-conversion.so.3: cannot open shared object file: No such file or directory +``` + +If this issue persists even after adding double-conversion's path to $LD_LIBRARY_PATH (step 2) then solve by creating a symlink to `libdouble-conversion.so.3` with folly's installation of double-conversion: + +``` +sudo ln -s ~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib/libdouble-conversion.so.3.1.4 \ +libdouble-conversion.so.3 +``` + +**Two installations of glog** +``` +~/torchrec/torchrec/inference/build$ ./example +ERROR: flag 'logtostderr' was defined more than once (in files '/home/shabab/glog/src/logging.cc' and +'/home/shabab/folly-build/extracted/glog-v0.4.0.tar.gz/glog-0.4.0/src/logging.cc'). +``` +The above issue, along with a host of others during building, can potentially occur if libinference is pointing to two different versions of glog (if one was +previously installed in your system). You can find this out by running `ldd` on your libinference shared object within the build path. The issue can be solved by using the glog version provided by folly. + +To use the glog version provided by folly, add the glog install path (in your folly-build directory) to your LD_LIBRARY_PATH much like step 2. + +**Undefined symbols with std::string or cxx11** + +If you get undefined symbol errors and the errors mention `std::string` or `cxx11`, it's likely +that your dependencies were compiled with different ABI values. Re-compile your dependencies +and ensure they all have the same value for `_GLIBCXX_USE_CXX11_ABI` in their build. + +The ABI value of pytorch can be checked via: + +``` +import torch +torch._C._GLIBCXX_USE_CXX11_ABI +``` diff --git a/torchrec/inference/inference_legacy/__init__.py b/torchrec/inference/inference_legacy/__init__.py new file mode 100644 index 000000000..670f2af78 --- /dev/null +++ b/torchrec/inference/inference_legacy/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Torchrec Inference + +Torchrec inference provides a Torch.Deploy based library for GPU inference. + +These includes: + - Model packaging in Python + - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving. + - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package. 
+ - Model serving in C++ + - `BatchingQueue` is a generalized config-based request tensor batching implementation. + - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy. + +We implemented an example of how to use this library with the TorchRec DLRM model. + - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package. + - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model. +""" + +from . import model_packager, modules # noqa # noqa diff --git a/torchrec/inference/inference_legacy/client.py b/torchrec/inference/inference_legacy/client.py new file mode 100644 index 000000000..a3a9f2a83 --- /dev/null +++ b/torchrec/inference/inference_legacy/client.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# @nolint +# pyre-ignore-all-errors + + +import argparse +import logging + +import grpc +import torch +from gen.torchrec.inference import predictor_pb2, predictor_pb2_grpc +from torch.utils.data import DataLoader +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.datasets.random import RandomRecDataset +from torchrec.datasets.utils import Batch + + +def create_training_batch(args: argparse.Namespace) -> Batch: + return next( + iter( + DataLoader( + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ), + batch_sampler=None, + pin_memory=False, + num_workers=0, + ) + ) + ) + + +def create_request( + batch: Batch, args: argparse.Namespace +) -> predictor_pb2.PredictionRequest: + def to_bytes(tensor: torch.Tensor) -> bytes: + return tensor.cpu().numpy().tobytes() + + float_features = predictor_pb2.FloatFeatures( + num_features=args.num_float_features, + values=to_bytes(batch.dense_features), + ) + + id_list_features = predictor_pb2.SparseFeatures( + num_features=args.num_id_list_features, + values=to_bytes(batch.sparse_features.values()), + lengths=to_bytes(batch.sparse_features.lengths()), + ) + + id_score_list_features = predictor_pb2.SparseFeatures(num_features=0) + embedding_features = predictor_pb2.FloatFeatures(num_features=0) + unary_features = predictor_pb2.SparseFeatures(num_features=0) + + return predictor_pb2.PredictionRequest( + batch_size=args.batch_size, + float_features=float_features, + id_list_features=id_list_features, + id_score_list_features=id_score_list_features, + embedding_features=embedding_features, + unary_features=unary_features, + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--ip", + type=str, + default="0.0.0.0", + ) + parser.add_argument( + "--port", + type=int, + default=50051, + ) + parser.add_argument( + "--num_float_features", + type=int, + default=13, + ) + parser.add_argument( + "--num_id_list_features", + type=int, + default=26, + ) + parser.add_argument( + "--num_id_score_list_features", + type=int, + default=0, + ) + parser.add_argument( + "--num_embedding_features", + type=int, + default=100000, + ) + parser.add_argument( + "--embedding_feature_dim", + type=int, + default=100, + ) + parser.add_argument( + "--batch_size", + type=int, + 
default=100, + ) + + args: argparse.Namespace = parser.parse_args() + + training_batch: Batch = create_training_batch(args) + request: predictor_pb2.PredictionRequest = create_request(training_batch, args) + + with grpc.insecure_channel(f"{args.ip}:{args.port}") as channel: + stub = predictor_pb2_grpc.PredictorStub(channel) + response = stub.Predict(request) + print("Response: ", response.predictions["default"].data) + +if __name__ == "__main__": + logging.basicConfig() diff --git a/torchrec/inference/docker/Dockerfile b/torchrec/inference/inference_legacy/docker/Dockerfile similarity index 100% rename from torchrec/inference/docker/Dockerfile rename to torchrec/inference/inference_legacy/docker/Dockerfile diff --git a/torchrec/inference/docs/inference.rst b/torchrec/inference/inference_legacy/docs/inference.rst similarity index 100% rename from torchrec/inference/docs/inference.rst rename to torchrec/inference/inference_legacy/docs/inference.rst diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h new file mode 100644 index 000000000..26e7987f9 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#define TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(condition, message) \ + if (!(condition)) { \ + throw std::runtime_error( \ + "Internal Assertion failed: (" + std::string(#condition) + "), " + \ + "function " + __FUNCTION__ + ", file " + __FILE__ + ", line " + \ + std::to_string(__LINE__) + ".\n" + \ + "Please report bug to TorchRec.\n" + message + "\n"); \ + } + +#define TORCHREC_INTERNAL_ASSERT_NO_MESSAGE(condition) \ + TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(#condition, "") + +#define TORCHREC_INTERNAL_ASSERT_(x, condition, message, FUNC, ...) FUNC + +#define TORCHREC_INTERNAL_ASSERT(...) \ + TORCHREC_INTERNAL_ASSERT_( \ + , \ + ##__VA_ARGS__, \ + TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(__VA_ARGS__), \ + TORCHREC_INTERNAL_ASSERT_NO_MESSAGE(__VA_ARGS__)); + +#define TORCHREC_CHECK_WITH_MESSAGE(condition, message) \ + if (!(condition)) { \ + throw std::runtime_error( \ + "Check failed: (" + std::string(#condition) + "), " + "function " + \ + __FUNCTION__ + ", file " + __FILE__ + ", line " + \ + std::to_string(__LINE__) + ".\n" + message + "\n"); \ + } + +#define TORCHREC_CHECK_NO_MESSAGE(condition) \ + TORCHREC_CHECK_WITH_MESSAGE(#condition, "") + +#define TORCHREC_CHECK_(x, condition, message, FUNC, ...) FUNC + +#define TORCHREC_CHECK(...) \ + TORCHREC_CHECK_( \ + , \ + ##__VA_ARGS__, \ + TORCHREC_CHECK_WITH_MESSAGE(__VA_ARGS__), \ + TORCHREC_CHECK_NO_MESSAGE(__VA_ARGS__)); diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h new file mode 100644 index 000000000..2c5e8837c --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include "torchrec/inference/JaggedTensor.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +using LazyTensorRef = folly::detail::Lazy>&; + +// BatchingFunc should be responsible to move the output tensor to desired +// location using the device input. +class BatchingFunc { + public: + virtual ~BatchingFunc() = default; + + virtual std::unordered_map batch( + const std::string& /* featureName */, + const std::vector>& /* requests */, + const int64_t& /* totalNumBatch */, + LazyTensorRef /* batchOffsets */, + const c10::Device& /* device */, + LazyTensorRef /* batchItems */) = 0; +}; + +/** + * TorchRecBatchingFuncRegistry is used to register custom batching functions. + */ +C10_DECLARE_REGISTRY(TorchRecBatchingFuncRegistry, BatchingFunc); + +#define REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY(name, priority, ...) \ + C10_REGISTER_CLASS_WITH_PRIORITY( \ + TorchRecBatchingFuncRegistry, name, priority, __VA_ARGS__); + +#define REGISTER_TORCHREC_BATCHING_FUNC(name, ...) \ + REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY( \ + name, c10::REGISTRY_DEFAULT, __VA_ARGS__); + +std::unordered_map combineFloat( + const std::string& featureName, + const std::vector>& requests); + +std::unordered_map combineSparse( + const std::string& featureName, + const std::vector>& requests, + bool isWeighted); + +std::unordered_map combineEmbedding( + const std::string& featureName, + const std::vector>& requests); + +void moveIValueToDevice(c10::IValue& val, const c10::Device& device); + +std::unordered_map moveToDevice( + std::unordered_map combined, + const c10::Device& device); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h b/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h new file mode 100644 index 000000000..f012572b2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include // @manual +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "torchrec/inference/Batching.h" +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/ResourceManager.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +using BatchQueueCb = std::function)>; + +class BatchingQueue { + public: + struct Config { + std::chrono::milliseconds batchingInterval = std::chrono::milliseconds(10); + std::chrono::milliseconds queueTimeout = std::chrono::milliseconds(500); + int numExceptionThreads = 4; + int numMemPinnerThreads = 4; + int maxBatchSize = 2000; + // For feature name to BatchingFunc name. 
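+    // (Typically populated from the model factory's batching metadata; each
+    // value is expected to name a BatchingFunc registered in
+    // TorchRecBatchingFuncRegistry, see Batching.h.)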
+ const std::unordered_map batchingMetadata; + std::function eventCreationFn; + std::function warmupFn; + }; + + BatchingQueue(const BatchingQueue&) = delete; + BatchingQueue& operator=(const BatchingQueue&) = delete; + + BatchingQueue( + std::vector cbs, + const Config& config, + int worldSize, + std::unique_ptr observer, + std::shared_ptr resourceManager = nullptr); + ~BatchingQueue(); + + void add( + std::shared_ptr request, + folly::Promise> promise); + + void stop(); + + private: + struct QueryQueueEntry { + std::shared_ptr request; + RequestContext context; + std::chrono::time_point addedTime; + }; + + struct BatchingQueueEntry { + std::vector> requests; + std::vector contexts; + std::chrono::time_point addedTime; + }; + + void createBatch(); + + void pinMemory(int gpuIdx); + + void observeBatchCompletion(size_t batchSizeBytes, size_t numRequests); + + const Config config_; + + // Batching func name to batching func instance. + std::unordered_map> batchingFuncs_; + std::vector cbs_; + std::thread batchingThread_; + std::vector memPinnerThreads_; + std::unique_ptr rejectionExecutor_; + folly::Synchronized> requestQueue_; + std::vector>> + batchingQueues_; + std::atomic stopping_; + int worldSize_; + std::unique_ptr observer_; + std::shared_ptr resourceManager_; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h new file mode 100644 index 000000000..5667d1ad2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include + +namespace torchrec { + +// We have different error code defined for different kinds of exceptions in +// fblearner/sigrid predictor. (Code pointer: +// fblearner/predictor/if/prediction_service.thrift.) We define different +// exception type here so that in fblearner/sigrid predictor we can detect the +// exception type and return the corresponding error code to reflect the right +// info. 
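+//
+// Usage sketch (illustrative; the handler name and mapping below are
+// assumptions, not part of this header): a serving layer can catch the
+// subclasses defined here and translate them into its own error codes, e.g.
+//
+//   try {
+//     handlePredictionRequest(request);
+//   } catch (const torchrec::GPUOverloadException& ex) {
+//     // return PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT
+//   } catch (const torchrec::TorchrecException& ex) {
+//     // return a generic TorchRec error code
+//   }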
+class TorchrecException : public std::runtime_error { + public: + explicit TorchrecException(const std::string& error) + : std::runtime_error(error) {} +}; + +// GPUOverloadException maps to +// PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT +class GPUOverloadException : public TorchrecException { + public: + explicit GPUOverloadException(const std::string& error) + : TorchrecException(error) {} +}; + +// GPUExecutorOverloadException maps to +// PredictionExceptionCode::GPU_EXECUTOR_QUEUE_TIMEOUT +class GPUExecutorOverloadException : public TorchrecException { + public: + explicit GPUExecutorOverloadException(const std::string& error) + : TorchrecException(error) {} +}; + +// TorchDeployException maps to +// PredictorUserErrorCode::TORCH_DEPLOY_ERROR +class TorchDeployException : public TorchrecException { + public: + explicit TorchDeployException(const std::string& error) + : TorchrecException(error) {} +}; +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h new file mode 100644 index 000000000..491acf48f --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +#include "torchrec/inference/Exception.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { +template +void handleRequestException( + folly::Promise>& promise, + const std::string& msg) { + auto ex = folly::make_exception_wrapper(msg); + auto response = std::make_unique(); + response->exception = std::move(ex); + promise.setValue(std::move(response)); +} + +template +void handleBatchException( + std::vector& contexts, + const std::string& msg) { + for (auto& context : contexts) { + handleRequestException(context.promise, msg); + } +} + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h new file mode 100644 index 000000000..b55dacc5a --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// remove this after we switch over to multipy externally for torchrec +#ifdef FBCODE_CAFFE2 +#include // @manual +#else +#include // @manual +#endif + +#include "torchrec/inference/BatchingQueue.h" +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/ResultSplit.h" +#include "torchrec/inference/include/torchrec/inference/Observer.h" + +namespace torchrec { + +class GPUExecutor { + public: + // Used to interface with python's garbage collector + struct GCConfig { + bool optimizationEnabled = false; + size_t collectionFreq = 1000; + size_t statReportingFreq = 10000; + std::unique_ptr observer = + std::make_unique(); + std::map threadIdToNumForwards = std::map(); + }; + + GPUExecutor( + std::shared_ptr manager, + torch::deploy::ReplicatedObj model, + size_t rank, + size_t worldSize, + std::shared_ptr func, + std::chrono::milliseconds queueTimeout, + std::shared_ptr + observer, // shared_ptr because used in completion executor callback + std::function warmupFn = {}, + std::optional numThreadsPerGPU = c10::nullopt, + std::unique_ptr gcConfig = std::make_unique()); + GPUExecutor(GPUExecutor&& executor) noexcept = default; + GPUExecutor& operator=(GPUExecutor&& executor) noexcept = default; + ~GPUExecutor(); + + void callback(std::shared_ptr batch); + + void process(int idx); + + private: + // torch deploy + std::shared_ptr manager_; + torch::deploy::ReplicatedObj model_; + const size_t rank_; + const size_t worldSize_; + + folly::MPMCQueue> batches_; + std::vector processThreads_; + std::unique_ptr rejectionExecutor_; + std::unique_ptr completionExecutor_; + std::shared_ptr resultSplitFunc_; + const std::chrono::milliseconds queueTimeout_; + std::shared_ptr observer_; + std::function warmupFn_; + + std::mutex warmUpMutex_; + std::mutex warmUpAcquireSessionMutex_; + std::condition_variable warmUpCV_; + int warmUpCounter_{0}; + + size_t numThreadsPerGPU_; + + std::unique_ptr gcConfig_; + + void reportGCStats(c10::IValue stats); +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h new file mode 100644 index 000000000..de00aebac --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +namespace torchrec { + +struct JaggedTensor { + at::Tensor lengths; + at::Tensor values; + at::Tensor weights; +}; + +struct KeyedJaggedTensor { + std::vector keys; + at::Tensor lengths; + at::Tensor values; + at::Tensor weights; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h new file mode 100644 index 000000000..14ac3bceb --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h @@ -0,0 +1,280 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace torchrec { + +// Record generic timeseries stat with a key +class IDynamicTimeseriesObserver { + public: + virtual void addCount(uint32_t value, std::string key) = 0; + + virtual ~IDynamicTimeseriesObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. +class EmptyDynamicTimeseriesObserver : public IDynamicTimeseriesObserver { + public: + void addCount(uint32_t /* value */, std::string /* key */) override {} +}; + +class IBatchingQueueObserver { + public: + // Record the amount of time an entry of PredictionRequests + // in the batching queue waits before they are read and allocated + // onto a GPU device. + virtual void recordBatchingQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the amount of time it takes for a batching function + // to execute. + virtual void recordBatchingFuncLatency( + uint32_t value, + std::string batchingFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the amount of time it takes to create a batch of + // requests. + virtual void recordBatchCreationLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Increment the number of batching queue timeouts experienced. + virtual void addBatchingQueueTimeoutCount(uint32_t value) = 0; + + // Increment the number of times a GPU could not be chosen + // for allocation. + virtual void addGPUBusyCount(uint32_t value) = 0; + + // Increment the number of requests entering the batching queue. + virtual void addRequestsCount(uint32_t value) = 0; + + // Increment the number of bytes of tensors moved to cuda. + virtual void addBytesMovedToGPUCount(uint32_t value) = 0; + + // Increment the number of batches processed by the batching + // queue (moved onto the GPU executor). + virtual void addBatchesProcessedCount(uint32_t value) = 0; + + // Increment the number of requests processed by the batching + // queue (moved onto the GPU executor). + virtual void addRequestsProcessedCount(uint32_t value) = 0; + + // The obervations that should be made when a batch is completed. + virtual void observeBatchCompletion( + size_t batchSizeBytes, + size_t numRequests) { + addBytesMovedToGPUCount(batchSizeBytes); + addBatchesProcessedCount(1); + addRequestsProcessedCount(numRequests); + } + + virtual ~IBatchingQueueObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. 
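+// For example (a sketch, assuming the BatchingQueue constructor declared in
+// BatchingQueue.h):
+//
+//   torchrec::BatchingQueue queue(
+//       std::move(callbacks), config, worldSize,
+//       std::make_unique<torchrec::EmptyBatchingQueueObserver>());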
+class EmptyBatchingQueueObserver : public IBatchingQueueObserver { + public: + void recordBatchingQueueLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordBatchingFuncLatency( + uint32_t /* value */, + std::string /* batchingFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordBatchCreationLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void addBatchingQueueTimeoutCount(uint32_t /* value */) override {} + + void addGPUBusyCount(uint32_t /* value */) override {} + + void addRequestsCount(uint32_t /* value */) override {} + + void addBytesMovedToGPUCount(uint32_t /* value */) override {} + + void addBatchesProcessedCount(uint32_t /* value */) override {} + + void addRequestsProcessedCount(uint32_t /* value */) override {} +}; + +class IGPUExecutorObserver { + public: + // Record the amount of time a batch spends in the GPU Executor + // queue. + virtual void recordQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of prediction (forward call, H2D). + virtual void recordPredictionLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of device to host transfer facilitated + // by result split function. + virtual void recordDeviceToHostLatency( + uint32_t value, + std::string resultSplitFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of splitting the result. + virtual void recordResultSplitLatency( + uint32_t value, + std::string resultSplitFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency from enqueue to completion. + virtual void recordTotalLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Increment the number of GPUExecutor queue timeouts. + virtual void addQueueTimeoutCount(uint32_t value) = 0; + + // Increment the number of predict exceptions. + virtual void addPredictionExceptionCount(uint32_t value) = 0; + + // Increment the number of batches successfully processed. + virtual void addBatchesProcessedCount(uint32_t value) = 0; + + virtual ~IGPUExecutorObserver() {} +}; + +class ISingleGPUExecutorObserver { + public: + virtual void addRequestsCount(uint32_t value) = 0; + virtual void addRequestProcessingExceptionCount(uint32_t value) = 0; + virtual void recordQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point = + std::chrono::steady_clock::now()) = 0; + + virtual void recordRequestProcessingLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + virtual ~ISingleGPUExecutorObserver() = default; +}; + +class EmptySingleGPUExecutorObserver : public ISingleGPUExecutorObserver { + void addRequestsCount(uint32_t) override {} + void addRequestProcessingExceptionCount(uint32_t) override {} + void recordQueueLatency( + uint32_t, + std::chrono::steady_clock::time_point = + std::chrono::steady_clock::now()) override {} + + void recordRequestProcessingLatency( + uint32_t, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) override {} +}; + +// Can be used for testing or for opt-ing out of observation. 
+class EmptyGPUExecutorObserver : public IGPUExecutorObserver { + public: + void recordQueueLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordPredictionLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordDeviceToHostLatency( + uint32_t /* value */, + std::string /* resultSplitFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordResultSplitLatency( + uint32_t /* value */, + std::string /* resultSplitFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordTotalLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void addQueueTimeoutCount(uint32_t /* value */) override {} + + void addPredictionExceptionCount(uint32_t /* value */) override {} + + void addBatchesProcessedCount(uint32_t /* value */) override {} +}; + +class IResourceManagerObserver { + public: + // Add the number of requests in flight for a gpu + virtual void addOutstandingRequestsCount(uint32_t value, int gpuIdx) = 0; + + // Add the most in flight requests on a gpu ever + virtual void addAllTimeHighOutstandingCount(uint32_t value, int gpuIdx) = 0; + + // Record the latency for finding a device + virtual void addWaitingForDeviceLatency( + uint32_t value, + int gpuIdx, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Recording all stats related to resource manager at once. + virtual void recordAllStats( + uint32_t outstandingRequests, + uint32_t allTimeHighOutstanding, + uint32_t waitedForMs, + int gpuIdx) { + addOutstandingRequestsCount(outstandingRequests, gpuIdx); + addAllTimeHighOutstandingCount(allTimeHighOutstanding, gpuIdx); + addWaitingForDeviceLatency(waitedForMs, gpuIdx); + } + + virtual ~IResourceManagerObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. +class EmptyResourceManagerObserver : public IResourceManagerObserver { + public: + void addOutstandingRequestsCount(uint32_t /* value */, int /* gpuIdx */) + override {} + + void addAllTimeHighOutstandingCount(uint32_t /* value */, int /* gpuIdx */) + override {} + + void addWaitingForDeviceLatency( + uint32_t /* value */, + int /* gpuIdx */, + std::chrono::steady_clock::time_point /* now */) override {} +}; + +// Helper for determining how much time has elapsed in milliseconds since a +// given time point. +inline std::chrono::milliseconds getTimeElapsedMS( + std::chrono::steady_clock::time_point startTime) { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - startTime); +} + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h new file mode 100644 index 000000000..d3dd1ea18 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "torchrec/inference/Observer.h" + +namespace torchrec { + +/** + * ResourceManager can be used to limit in-flight batches + * allocated onto GPUs to prevent OOMing. 
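+ *
+ * Intended flow (a sketch inferred from the interface below): call
+ * occupyDevice(gpuIdx, slack) before moving a batch onto a GPU and
+ * release(gpuIdx) once that batch completes; ResourceManagerGuard performs
+ * the release automatically in its destructor.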
+ */ +class ResourceManager { + public: + ResourceManager( + int worldSize, + size_t maxOutstandingBatches, + int logFrequency = 100, + std::unique_ptr observer = + std::make_unique()); + + // Returns whether batches can be allocated onto a device based on + // slack provided (ms) and maxOutstandingBatches_). + bool occupyDevice(int gpuIdx, std::chrono::milliseconds slack); + + void release(int gpuIdx); + + private: + folly::small_vector gpuToOutstandingBatches_; + // Helpful for tuning + folly::small_vector allTimeHigh_; + const size_t maxOutstandingBatches_; + const int logFrequency_; + // Align as 64B to avoid false sharing + alignas(64) std::mutex mu_; + std::unique_ptr observer_; +}; + +class ResourceManagerGuard { + public: + ResourceManagerGuard( + std::weak_ptr resourceManager, + int gpuIdx) + : resourceManager_(std::move(resourceManager)), gpuIdx_(gpuIdx) {} + + ~ResourceManagerGuard() { + std::shared_ptr rm = resourceManager_.lock(); + if (rm != nullptr) { + rm->release(gpuIdx_); + } + } + + private: + std::weak_ptr resourceManager_; + const int gpuIdx_; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h new file mode 100644 index 000000000..2c3ef2463 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace torchrec { + +class ResultSplitFunc { + public: + virtual ~ResultSplitFunc() = default; + + virtual std::string name() = 0; + + virtual c10::IValue splitResult( + c10::IValue /* result */, + size_t /* nOffset */, + size_t /* nLength */, + size_t /* nTotalLength */) = 0; + + virtual c10::IValue moveToHost(c10::IValue /* result */) = 0; +}; + +/** + * TorchRecResultSplitFuncRegistry is used to register custom result split + * functions. + */ +C10_DECLARE_REGISTRY(TorchRecResultSplitFuncRegistry, ResultSplitFunc); + +#define REGISTER_TORCHREC_RESULTSPLIT_FUNC(name, ...) \ + C10_REGISTER_CLASS(TorchRecResultSplitFuncRegistry, name, __VA_ARGS__); + +c10::IValue splitDictOfTensor( + c10::IValue result, + size_t nOffset, + size_t nLength, + size_t nTotalLength); + +c10::IValue splitDictOfTensors( + c10::IValue result, + size_t nOffset, + size_t nLength, + size_t nTotalLength); + +c10::IValue +splitDictWithMaskTensor(c10::IValue result, size_t nOffset, size_t nLength); + +class DictWithMaskTensorResultSplitFunc : public torchrec::ResultSplitFunc { + public: + virtual std::string name() override; + + virtual c10::IValue splitResult( + c10::IValue result, + size_t offset, + size_t length, + size_t /* nTotalLength */) override; + + c10::IValue moveToHost(c10::IValue result) override; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h new file mode 100644 index 000000000..478773768 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace torchrec { + +struct ShardMetadata { + std::vector shard_offsets; + std::vector shard_lengths; + + bool operator==(const ShardMetadata& other) const { + return shard_offsets == other.shard_offsets && + shard_lengths == other.shard_lengths; + } +}; + +struct Shard { + ShardMetadata metadata; + at::Tensor tensor; +}; + +struct ShardedTensorMetadata { + std::vector shards_metadata; +}; + +struct ShardedTensor { + std::vector sizes; + std::vector local_shards; + ShardedTensorMetadata metadata; +}; + +struct ReplicatedTensor { + ShardedTensor local_replica; + int64_t local_replica_id; + int64_t replica_count; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h new file mode 100644 index 000000000..9da63d7c2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +class SingleGPUExecutor { + constexpr static const size_t kQUEUE_CAPACITY = 10000; + + public: + struct ExecInfo { + size_t gpuIdx; + size_t interpIdx; + torch::deploy::ReplicatedObj model; + }; + using ExecInfos = std::vector; + + SingleGPUExecutor( + std::shared_ptr manager, + ExecInfos execInfos, + size_t numGpu, + std::shared_ptr observer = + std::make_shared(), + c10::Device resultDevice = c10::kCPU, + size_t numProcessThreads = 1u, + bool useHighPriCudaStream = false); + + // Moveable only + SingleGPUExecutor(SingleGPUExecutor&& executor) noexcept = default; + SingleGPUExecutor& operator=(SingleGPUExecutor&& executor) noexcept = default; + ~SingleGPUExecutor(); + + void schedule(std::shared_ptr request); + + private: + void process(); + + std::shared_ptr manager_; + const ExecInfos execInfos_; + const size_t numGpu_; + const size_t numProcessThreads_; + const bool useHighPriCudaStream_; + const c10::Device resultDevice_; + std::shared_ptr observer_; + folly::MPMCQueue> requests_; + + std::unique_ptr processExecutor_; + std::unique_ptr completionExecutor_; + std::atomic roundRobinExecInfoNextIdx_; +}; +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h b/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h new file mode 100644 index 000000000..1150e2c23 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +#include + +#include "torchrec/inference/JaggedTensor.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +std::shared_ptr createRequest(at::Tensor denseTensor); + +std::shared_ptr +createRequest(size_t batchSize, size_t numFeatures, const JaggedTensor& jagged); + +std::shared_ptr +createRequest(size_t batchSize, size_t numFeatures, at::Tensor embedding); + +JaggedTensor createJaggedTensor(const std::vector>& input); + +c10::List createIValueList( + const std::vector>& input); + +at::Tensor createEmbeddingTensor( + const std::vector>& input); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h new file mode 100644 index 000000000..0a09d1f35 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "torchrec/inference/ResourceManager.h" + +namespace torchrec { + +struct SparseFeatures { + uint32_t num_features; + // int32: T x B + folly::IOBuf lengths; + // T x B x L (jagged) + folly::IOBuf values; + // float16 + folly::IOBuf weights; +}; + +struct FloatFeatures { + uint32_t num_features; + // shape: {B} + folly::IOBuf values; +}; + +// TODO: Change the input format to torch::IValue. +// Currently only dense batching function support IValue. +using Feature = std::variant; + +struct PredictionRequest { + uint32_t batch_size; + std::unordered_map features; +}; + +struct PredictionResponse { + uint32_t batchSize; + c10::IValue predictions; + // If set, the result is an exception. + std::optional exception; +}; + +struct RequestContext { + uint32_t batchSize; + folly::Promise> promise; + // folly request context for request tracking in crochet + std::shared_ptr follyRequestContext; +}; + +using PredictionException = std::runtime_error; + +using Event = std:: + unique_ptr>; + +struct BatchingMetadata { + std::string type; + std::string device; + folly::F14FastSet pinned; +}; + +// noncopyable because we only want to move PredictionBatch around +// as it holds a reference to ResourceManagerGuard. We wouldn't want +// to inadvertently increase the reference count to ResourceManagerGuard +// with copies of this struct. +struct PredictionBatch : public boost::noncopyable { + std::string methodName; + std::vector args; + + size_t batchSize; + + c10::impl::GenericDict forwardArgs; + + std::vector contexts; + + std::unique_ptr resourceManagerGuard = nullptr; + + std::chrono::time_point enqueueTime = + std::chrono::steady_clock::now(); + + Event event; + + // Need a constructor to use make_shared/unique with + // noncopyable struct and not trigger copy-constructor. 
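+  // e.g. (sketch): auto batch = std::make_shared<PredictionBatch>(
+  //     batchSize, std::move(forwardArgs), std::move(contexts));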
+ PredictionBatch( + size_t bs, + c10::impl::GenericDict fa, + std::vector ctxs, + std::unique_ptr rmg = nullptr) + : batchSize(bs), + forwardArgs(std::move(fa)), + contexts(std::move(ctxs)), + resourceManagerGuard(std::move(rmg)) {} + + PredictionBatch( + std::string methodNameArg, + std::vector argsArg, + folly::Promise> promise) + : methodName(std::move(methodNameArg)), + args(std::move(argsArg)), + forwardArgs( + c10::impl::GenericDict(at::StringType::get(), at::AnyType::get())) { + contexts.push_back(RequestContext{1u, std::move(promise)}); + } + + size_t sizeOfIValue(const c10::IValue& val) const { + size_t size = 0; + if (val.isTensor()) { + size += val.toTensor().storage().nbytes(); + } else if (val.isList()) { + for (const auto& v : val.toListRef()) { + size += sizeOfIValue(v); + } + } + return size; + } + + inline size_t size() const { + size_t size = 0; + for (const auto& iter : forwardArgs) { + size += sizeOfIValue(iter.value()); + } + return size; + } +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h new file mode 100644 index 000000000..4844c9ca1 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "torchrec/inference/Types.h" + +namespace torchrec { + +// Returns whether sparse features (KeyedJaggedTensor) are valid. +// Currently validates: +// 1. Whether sum(lengths) == size(values) +// 2. Whether there are negative values in lengths +// 3. If weights is present, whether sum(lengths) == size(weights) +bool validateSparseFeatures( + at::Tensor& values, + at::Tensor& lengths, + std::optional maybeWeights = c10::nullopt); + +// Returns whether dense features are valid. +// Currently validates: +// 1. Whether the size of values is divisable by batch size (request level) +bool validateDenseFeatures(at::Tensor& values, size_t batchSize); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/model_packager.py b/torchrec/inference/inference_legacy/model_packager.py new file mode 100644 index 000000000..9957b3373 --- /dev/null +++ b/torchrec/inference/inference_legacy/model_packager.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +from pathlib import Path +from typing import Any, BinaryIO, Dict, List, Type, TypeVar, Union + +import torch +from torch.package import PackageExporter +from torchrec.inference.modules import PredictFactory + +LOADER_MODULE = "__module_loader" +LOADER_FACTORY = "MODULE_FACTORY" +LOADER_CODE = f""" +import %PACKAGE% + +{LOADER_FACTORY}=%PACKAGE%.%CLASS% +""" +CONFIG_MODULE = "__configs" + +T = TypeVar("T") + + +try: + # pyre-fixme[21]: Could not find module `torch_package_importer`. 
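+    # Note: torch_package_importer is expected to be importable only when this
+    # module runs from inside a torch.package; the ImportError fallback below
+    # keeps the module importable at packaging time.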
+ import torch_package_importer # @manual +except ImportError: + pass + + +def load_config_text(name: str) -> str: + return torch_package_importer.load_text("__configs", name) + + +def load_pickle_config(name: str, clazz: Type[T]) -> T: + loaded_obj = torch_package_importer.load_pickle("__configs", name) + assert isinstance( + loaded_obj, clazz + ), f"The loaded config {type(loaded_obj)} is not of type {clazz}" + return loaded_obj + + +class PredictFactoryPackager: + @classmethod + @abc.abstractclassmethod + def set_extern_modules(cls, pe: PackageExporter) -> None: + pass + + @classmethod + @abc.abstractclassmethod + def set_mocked_modules(cls, pe: PackageExporter) -> None: + pass + + @classmethod + def save_predict_factory( + cls, + predict_factory: Type[PredictFactory], + configs: Dict[str, Any], + output: Union[str, Path, BinaryIO], + extra_files: Dict[str, Union[str, bytes]], + loader_code: str = LOADER_CODE, + package_importer: Union[ + torch.package.Importer, List[torch.package.Importer] + ] = torch.package.sys_importer, + ) -> None: + with PackageExporter(output, importer=package_importer) as pe: + # pyre-fixme[29]: `BoundMethod[abc.abstractclassmethod[None], + # Type[PredictFactoryPackager]]` is not a function. + cls.set_extern_modules(pe) + # pyre-fixme[29]: `BoundMethod[abc.abstractclassmethod[None], + # Type[PredictFactoryPackager]]` is not a function. + cls.set_mocked_modules(pe) + pe.extern(["sys"]) + pe.intern("**") + for k, v in extra_files.items(): + if isinstance(v, str): + pe.save_text("extra_files", k, v) + elif isinstance(v, bytes): + pe.save_binary("extra_files", k, v) + else: + raise ValueError(f"Unsupported type {type(v)}") + cls._save_predict_factory( + pe, predict_factory, configs, loader_code=loader_code + ) + + @classmethod + def _save_predict_factory( + cls, + pe: PackageExporter, + predict_factory: Type[PredictFactory], + configs: Dict[str, Any], + loader_code: str = LOADER_CODE, + ) -> None: + # If predict_factory is coming from a torch package, + # __module__ would have prefix. + # To save such predict factory, we need to remove + # the prefix. + package_name = predict_factory.__module__ + if package_name.startswith(" predictions = 1; +} + +// The predictor service definition. Synchronous for now. +service Predictor { + rpc Predict(PredictionRequest) returns (PredictionResponse) {} +} diff --git a/torchrec/inference/inference_legacy/server.cpp b/torchrec/inference/inference_legacy/server.cpp new file mode 100644 index 000000000..f7695cdcf --- /dev/null +++ b/torchrec/inference/inference_legacy/server.cpp @@ -0,0 +1,336 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// remove this after we switch over to multipy externally for torchrec +#ifdef FBCODE_CAFFE2 +#include // @manual +#include +#else +#include +#include +#endif + +#include + +#include "torchrec/inference/GPUExecutor.h" +#include "torchrec/inference/predictor.grpc.pb.h" +#include "torchrec/inference/predictor.pb.h" + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using predictor::FloatVec; +using predictor::PredictionRequest; +using predictor::PredictionResponse; +using predictor::Predictor; + +DEFINE_int32(n_interp_per_gpu, 1, ""); +DEFINE_int32(n_gpu, 1, ""); +DEFINE_string(package_path, "", ""); + +DEFINE_int32(batching_interval, 10, ""); +DEFINE_int32(queue_timeout, 500, ""); + +DEFINE_int32(num_exception_threads, 4, ""); +DEFINE_int32(num_mem_pinner_threads, 4, ""); +DEFINE_int32(max_batch_size, 2048, ""); +DEFINE_int32(gpu_executor_queue_timeout, 50, ""); + +DEFINE_string(server_address, "0.0.0.0", ""); +DEFINE_string(server_port, "50051", ""); + +DEFINE_string( + python_packages_path, + "", + "Used to load the packages that you 'extern' with torch.package"); + +namespace { + +std::unique_ptr toTorchRecRequest( + const PredictionRequest* request) { + auto torchRecRequest = std::make_unique(); + torchRecRequest->batch_size = request->batch_size(); + + // Client sends a request with serialized tensor to bytes. + // Byte string is converted to folly::iobuf for torchrec request. + + { + torchrec::FloatFeatures floatFeature; + + auto feature = request->float_features(); + auto encoded_values = feature.values(); + + floatFeature.num_features = feature.num_features(); + floatFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["float_features"] = std::move(floatFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->id_list_features(); + auto encoded_values = feature.values(); + auto encoded_lengths = feature.lengths(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["id_list_features"] = std::move(sparseFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->id_score_list_features(); + auto encoded_values = feature.values(); + auto encoded_lengths = feature.lengths(); + auto encoded_weights = feature.weights(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + sparseFeature.weights = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_weights.data(), + encoded_weights.size()}; + + torchRecRequest->features["id_score_list_features"] = + std::move(sparseFeature); + } + + { + torchrec::FloatFeatures floatFeature; + + auto feature = request->embedding_features(); + auto encoded_values = feature.values(); + + floatFeature.num_features = feature.num_features(); + floatFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + 
+ torchRecRequest->features["embedding_features"] = std::move(floatFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->unary_features(); + auto encoded_lengths = feature.lengths(); + auto encoded_values = feature.values(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["unary_features"] = std::move(sparseFeature); + } + + return torchRecRequest; +} + +// Logic behind the server's behavior. +class PredictorServiceHandler final : public Predictor::Service { + public: + explicit PredictorServiceHandler(torchrec::BatchingQueue& queue) + : queue_(queue) {} + + Status Predict( + grpc::ServerContext* context, + const PredictionRequest* request, + PredictionResponse* reply) override { + folly::Promise> promise; + auto future = promise.getSemiFuture(); + queue_.add(toTorchRecRequest(request), std::move(promise)); + auto torchRecResponse = + std::move(future).get(); // blocking, TODO: Write async server + auto predictions = reply->mutable_predictions(); + + // Convert ivalue to map, TODO: find out if protobuf + // can support custom types (folly::iobuf), so we can avoid this overhead. + for (const auto& item : torchRecResponse->predictions.toGenericDict()) { + auto tensor = item.value().toTensor(); + FloatVec fv; + fv.mutable_data()->Add( + tensor.data_ptr(), tensor.data_ptr() + tensor.numel()); + (*predictions)[item.key().toStringRef()] = fv; + } + + return Status::OK; + } + + private: + torchrec::BatchingQueue& queue_; +}; + +} // namespace + +int main(int argc, char* argv[]) { + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + LOG(INFO) << "Creating GPU executors"; + + // store the executors and interpreter managers + std::vector> executors; + std::vector models; + std::vector batchQueueCbs; + std::unordered_map batchingMetadataMap; + + std::shared_ptr env = + std::make_shared( + FLAGS_python_packages_path); + + auto manager = std::make_shared( + FLAGS_n_gpu * FLAGS_n_interp_per_gpu, env); + { + torch::deploy::Package package = manager->loadPackage(FLAGS_package_path); + auto I = package.acquireSession(); + auto imported = I.self.attr("import_module")({"__module_loader"}); + auto factoryType = imported.attr("MODULE_FACTORY"); + auto factory = factoryType.attr("__new__")({factoryType}); + factoryType.attr("__init__")({factory}); + + // Process forward metadata. + try { + auto batchingMetadataJsonStr = + factory.attr("batching_metadata_json")(at::ArrayRef()) + .toIValue() + .toString() + ->string(); + auto dynamic = folly::parseJson(batchingMetadataJsonStr); + CHECK(dynamic.isObject()); + for (auto it : dynamic.items()) { + torchrec::BatchingMetadata metadata; + metadata.type = it.second["type"].asString(); + metadata.device = it.second["device"].asString(); + batchingMetadataMap[it.first.asString()] = std::move(metadata); + } + } catch (...) { + auto batchingMetadata = + factory.attr("batching_metadata")(at::ArrayRef()) + .toIValue(); + for (const auto& iter : batchingMetadata.toGenericDict()) { + torchrec::BatchingMetadata metadata; + metadata.type = iter.value().toStringRef(); + metadata.device = "cuda"; + batchingMetadataMap[iter.key().toStringRef()] = std::move(metadata); + } + } + + // Process result metadata. 
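+    // (result_metadata names a ResultSplitFunc; it is resolved from
+    // TorchRecResultSplitFuncRegistry via Create() below.)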
+ auto resultMetadata = + factory.attr("result_metadata")(at::ArrayRef()) + .toIValue() + .toStringRef(); + std::shared_ptr resultSplitFunc = + torchrec::TorchRecResultSplitFuncRegistry()->Create(resultMetadata); + + LOG(INFO) << "Creating Model Shard for " << FLAGS_n_gpu << " GPUs."; + auto dmp = factory.attr("create_predict_module") + .callKwargs({{"world_size", FLAGS_n_gpu}}); + + for (int rank = 0; rank < FLAGS_n_gpu; rank++) { + auto device = I.self.attr("import_module")({"torch"}).attr("device")( + {"cuda", rank}); + auto m = dmp.attr("copy")({device.toIValue()}); + models.push_back(I.createMovable(m)); + } + + for (int rank = 0; rank < FLAGS_n_gpu; rank++) { + auto executor = std::make_unique( + manager, + std::move(models[rank]), + rank, + FLAGS_n_gpu, + resultSplitFunc, + std::chrono::milliseconds(FLAGS_gpu_executor_queue_timeout)); + executors.push_back(std::move(executor)); + batchQueueCbs.push_back( + [&, rank](std::shared_ptr batch) { + executors[rank]->callback(std::move(batch)); + }); + } + } + + torchrec::BatchingQueue queue( + batchQueueCbs, + torchrec::BatchingQueue::Config{ + .batchingInterval = + std::chrono::milliseconds(FLAGS_batching_interval), + .queueTimeout = std::chrono::milliseconds(FLAGS_queue_timeout), + .numExceptionThreads = FLAGS_num_exception_threads, + .numMemPinnerThreads = FLAGS_num_mem_pinner_threads, + .maxBatchSize = FLAGS_max_batch_size, + .batchingMetadata = std::move(batchingMetadataMap), + }, + FLAGS_n_gpu); + + // create the server + std::string server_address(FLAGS_server_address + ":" + FLAGS_server_port); + auto service = PredictorServiceHandler(queue); + + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + grpc::ServerBuilder builder; + + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + LOG(INFO) << "Server listening on " << server_address; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. 
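+  // (For example, a signal handler or shutdown hook calling server->Shutdown().)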
+ server->Wait(); + + LOG(INFO) << "Shutting down server"; + return 0; +} diff --git a/torchrec/inference/src/Batching.cpp b/torchrec/inference/inference_legacy/src/Batching.cpp similarity index 100% rename from torchrec/inference/src/Batching.cpp rename to torchrec/inference/inference_legacy/src/Batching.cpp diff --git a/torchrec/inference/src/BatchingQueue.cpp b/torchrec/inference/inference_legacy/src/BatchingQueue.cpp similarity index 100% rename from torchrec/inference/src/BatchingQueue.cpp rename to torchrec/inference/inference_legacy/src/BatchingQueue.cpp diff --git a/torchrec/inference/inference_legacy/src/Executer2.cpp b/torchrec/inference/inference_legacy/src/Executer2.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/inference/src/GPUExecutor.cpp b/torchrec/inference/inference_legacy/src/GPUExecutor.cpp similarity index 100% rename from torchrec/inference/src/GPUExecutor.cpp rename to torchrec/inference/inference_legacy/src/GPUExecutor.cpp diff --git a/torchrec/inference/src/ResourceManager.cpp b/torchrec/inference/inference_legacy/src/ResourceManager.cpp similarity index 100% rename from torchrec/inference/src/ResourceManager.cpp rename to torchrec/inference/inference_legacy/src/ResourceManager.cpp diff --git a/torchrec/inference/src/ResultSplit.cpp b/torchrec/inference/inference_legacy/src/ResultSplit.cpp similarity index 100% rename from torchrec/inference/src/ResultSplit.cpp rename to torchrec/inference/inference_legacy/src/ResultSplit.cpp diff --git a/torchrec/inference/src/SingleGPUExecutor.cpp b/torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp similarity index 100% rename from torchrec/inference/src/SingleGPUExecutor.cpp rename to torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp diff --git a/torchrec/inference/src/TestUtils.cpp b/torchrec/inference/inference_legacy/src/TestUtils.cpp similarity index 100% rename from torchrec/inference/src/TestUtils.cpp rename to torchrec/inference/inference_legacy/src/TestUtils.cpp diff --git a/torchrec/inference/src/Validation.cpp b/torchrec/inference/inference_legacy/src/Validation.cpp similarity index 100% rename from torchrec/inference/src/Validation.cpp rename to torchrec/inference/inference_legacy/src/Validation.cpp diff --git a/torchrec/inference/inference_legacy/state_dict_transform.py b/torchrec/inference/inference_legacy/state_dict_transform.py new file mode 100644 index 000000000..0379b1b80 --- /dev/null +++ b/torchrec/inference/inference_legacy/state_dict_transform.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Union + +import torch +from torch import distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor + + +def state_dict_gather( + src: Dict[str, Union[torch.Tensor, ShardedTensor]], + dst: Dict[str, torch.Tensor], +) -> None: + """ + Gathers the values of the src state_dict of the keys present in the dst state_dict. Can handle ShardedTensors in the src state_dict. 
+ + Args: + src (Dict[str, Union[torch.Tensor, ShardedTensor]]): source's state_dict for this rank + dst (Dict[str, torch.Tensor]): destination's state_dict + """ + for key, dst_tensor in dst.items(): + src_tensor = src[key] + if isinstance(src_tensor, ShardedTensor): + src_tensor.gather(out=dst_tensor if (dist.get_rank() == 0) else None) + elif isinstance(src_tensor, torch.Tensor): + dst_tensor.copy_(src_tensor) + else: + raise ValueError(f"Unsupported tensor {key} type {type(src_tensor)}") + + +def state_dict_all_gather_keys( + state_dict: Dict[str, Union[torch.Tensor, ShardedTensor]], + pg: ProcessGroup, +) -> List[str]: + """ + Gathers all the keys of the state_dict from all ranks. Can handle ShardedTensors in the state_dict. + + Args: + state_dict (Dict[str, Union[torch.Tensor, ShardedTensor]]): keys of this state_dict will be gathered + pg (ProcessGroup): Process Group used for comms + """ + names = list(state_dict.keys()) + all_names = [None] * dist.get_world_size(pg) + dist.all_gather_object(all_names, names, pg) + deduped_names = set() + for local_names in all_names: + # pyre-ignore[16] + for name in local_names: + deduped_names.add(name) + return sorted(deduped_names) + + +def state_dict_to_device( + state_dict: Dict[str, Union[torch.Tensor, ShardedTensor]], + pg: ProcessGroup, + device: torch.device, +) -> Dict[str, Union[torch.Tensor, ShardedTensor]]: + """ + Moves a state_dict to a device with a process group. Can handle ShardedTensors in the state_dict. + + Args: + state_dict (Dict[str, Union[torch.Tensor, ShardedTensor]]): state_dict to move + pg (ProcessGroup): Process Group used for comms + device (torch.device): device to put state_dict on + """ + ret = {} + all_keys = state_dict_all_gather_keys(state_dict, pg) + for key in all_keys: + if key in state_dict: + tensor = state_dict[key] + if isinstance(tensor, ShardedTensor): + copied_shards = [ + Shard.from_tensor_and_offsets( + tensor=shard.tensor.to(device), + shard_offsets=shard.metadata.shard_offsets, + rank=dist.get_rank(pg), + ) + for shard in tensor.local_shards() + ] + ret[key] = ShardedTensor._init_from_local_shards( + copied_shards, + tensor.metadata().size, + process_group=pg, + ) + elif isinstance(tensor, torch.Tensor): + ret[key] = tensor.to(device) + else: + raise ValueError(f"Unsupported tensor {key} type {type(tensor)}") + else: + # No state_dict entries for table-wise sharding, + # but need to follow full-sync. 
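+            # Create an empty ShardedTensor on this rank so every rank still
+            # participates in the collective construction below.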
+ ret[key] = ShardedTensor._init_from_local_shards( + [], + [], + process_group=pg, + ) + return ret diff --git a/torchrec/inference/tests/BatchingQueueTest.cpp b/torchrec/inference/inference_legacy/tests/BatchingQueueTest.cpp similarity index 100% rename from torchrec/inference/tests/BatchingQueueTest.cpp rename to torchrec/inference/inference_legacy/tests/BatchingQueueTest.cpp diff --git a/torchrec/inference/tests/BatchingTest.cpp b/torchrec/inference/inference_legacy/tests/BatchingTest.cpp similarity index 100% rename from torchrec/inference/tests/BatchingTest.cpp rename to torchrec/inference/inference_legacy/tests/BatchingTest.cpp diff --git a/torchrec/inference/tests/ResultSplitTest.cpp b/torchrec/inference/inference_legacy/tests/ResultSplitTest.cpp similarity index 100% rename from torchrec/inference/tests/ResultSplitTest.cpp rename to torchrec/inference/inference_legacy/tests/ResultSplitTest.cpp diff --git a/torchrec/inference/tests/SingleGPUExecutorMultiGPUTest.cpp b/torchrec/inference/inference_legacy/tests/SingleGPUExecutorMultiGPUTest.cpp similarity index 100% rename from torchrec/inference/tests/SingleGPUExecutorMultiGPUTest.cpp rename to torchrec/inference/inference_legacy/tests/SingleGPUExecutorMultiGPUTest.cpp diff --git a/torchrec/inference/tests/SingleGPUExecutorTest.cpp b/torchrec/inference/inference_legacy/tests/SingleGPUExecutorTest.cpp similarity index 100% rename from torchrec/inference/tests/SingleGPUExecutorTest.cpp rename to torchrec/inference/inference_legacy/tests/SingleGPUExecutorTest.cpp diff --git a/torchrec/inference/tests/ValidationTest.cpp b/torchrec/inference/inference_legacy/tests/ValidationTest.cpp similarity index 100% rename from torchrec/inference/tests/ValidationTest.cpp rename to torchrec/inference/inference_legacy/tests/ValidationTest.cpp diff --git a/torchrec/inference/tests/generate_test_packages.py b/torchrec/inference/inference_legacy/tests/generate_test_packages.py similarity index 100% rename from torchrec/inference/tests/generate_test_packages.py rename to torchrec/inference/inference_legacy/tests/generate_test_packages.py diff --git a/torchrec/inference/tests/model_packager_tests.py b/torchrec/inference/inference_legacy/tests/model_packager_tests.py similarity index 99% rename from torchrec/inference/tests/model_packager_tests.py rename to torchrec/inference/inference_legacy/tests/model_packager_tests.py index b88942e19..41533e34a 100644 --- a/torchrec/inference/tests/model_packager_tests.py +++ b/torchrec/inference/inference_legacy/tests/model_packager_tests.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-strict import io import tempfile diff --git a/torchrec/inference/tests/predict_module_tests.py b/torchrec/inference/inference_legacy/tests/predict_module_tests.py similarity index 99% rename from torchrec/inference/tests/predict_module_tests.py rename to torchrec/inference/inference_legacy/tests/predict_module_tests.py index 22d8c99f8..5a1ac8a31 100644 --- a/torchrec/inference/tests/predict_module_tests.py +++ b/torchrec/inference/inference_legacy/tests/predict_module_tests.py @@ -5,8 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-strict - import unittest from typing import Dict diff --git a/torchrec/inference/tests/test_modules.py b/torchrec/inference/inference_legacy/tests/test_modules.py similarity index 100% rename from torchrec/inference/tests/test_modules.py rename to torchrec/inference/inference_legacy/tests/test_modules.py diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py index 27db09280..4d0e68a14 100644 --- a/torchrec/inference/modules.py +++ b/torchrec/inference/modules.py @@ -16,9 +16,9 @@ import torch.nn as nn import torch.quantization as quant import torchrec as trec +import torchrec.distributed as trec_dist import torchrec.quant as trec_quant from torch.fx.passes.split_utils import getattr_recursive -from torchrec import distributed as trec_dist, inference as trec_infer from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.fused_params import ( FUSED_PARAM_BOUNDS_CHECK_MODE, @@ -357,7 +357,7 @@ def _quantize_fp_module( _quantize_fp_module(model, m, n) quant_prep_enable_register_tbes(model, list(additional_mapping.keys())) - trec_infer.modules.quantize_embeddings( + quantize_embeddings( model, dtype=DEFAULT_QUANTIZATION_DTYPE, additional_qconfig_spec_keys=additional_qconfig_spec_keys, diff --git a/torchrec/inference/server.cpp b/torchrec/inference/server.cpp index f7695cdcf..51ad58fad 100644 --- a/torchrec/inference/server.cpp +++ b/torchrec/inference/server.cpp @@ -10,327 +10,170 @@ #include #include -#include -#include -#include -#include -#include -#include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/strings/str_format.h" + #include #include +#include + +#include +#include -// remove this after we switch over to multipy externally for torchrec -#ifdef FBCODE_CAFFE2 -#include // @manual -#include +#ifdef BAZEL_BUILD +#include "examples/protos/predictor.grpc.pb.h" #else -#include -#include +#include "predictor.grpc.pb.h" #endif -#include +#define NUM_BYTES_FLOAT_FEATURES 4 +#define NUM_BYTES_SPARSE_FEATURES 4 -#include "torchrec/inference/GPUExecutor.h" -#include "torchrec/inference/predictor.grpc.pb.h" -#include "torchrec/inference/predictor.pb.h" - -using grpc::Channel; -using grpc::ClientContext; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; using grpc::Status; + +ABSL_FLAG(uint16_t, port, 50051, "Server port for the service"); + using predictor::FloatVec; using predictor::PredictionRequest; using predictor::PredictionResponse; using predictor::Predictor; -DEFINE_int32(n_interp_per_gpu, 1, ""); -DEFINE_int32(n_gpu, 1, ""); -DEFINE_string(package_path, "", ""); - -DEFINE_int32(batching_interval, 10, ""); -DEFINE_int32(queue_timeout, 500, ""); - -DEFINE_int32(num_exception_threads, 4, ""); -DEFINE_int32(num_mem_pinner_threads, 4, ""); -DEFINE_int32(max_batch_size, 2048, ""); -DEFINE_int32(gpu_executor_queue_timeout, 50, ""); - -DEFINE_string(server_address, "0.0.0.0", ""); -DEFINE_string(server_port, "50051", ""); - -DEFINE_string( - python_packages_path, - "", - "Used to load the packages that you 'extern' with torch.package"); - -namespace { - -std::unique_ptr toTorchRecRequest( - const PredictionRequest* request) { - auto torchRecRequest = std::make_unique(); - torchRecRequest->batch_size = request->batch_size(); - - // Client sends a request with serialized tensor to bytes. - // Byte string is converted to folly::iobuf for torchrec request. 
- - { - torchrec::FloatFeatures floatFeature; - - auto feature = request->float_features(); - auto encoded_values = feature.values(); - - floatFeature.num_features = feature.num_features(); - floatFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["float_features"] = std::move(floatFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->id_list_features(); - auto encoded_values = feature.values(); - auto encoded_lengths = feature.lengths(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["id_list_features"] = std::move(sparseFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->id_score_list_features(); - auto encoded_values = feature.values(); - auto encoded_lengths = feature.lengths(); - auto encoded_weights = feature.weights(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - sparseFeature.weights = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_weights.data(), - encoded_weights.size()}; - - torchRecRequest->features["id_score_list_features"] = - std::move(sparseFeature); - } - - { - torchrec::FloatFeatures floatFeature; - - auto feature = request->embedding_features(); - auto encoded_values = feature.values(); - - floatFeature.num_features = feature.num_features(); - floatFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["embedding_features"] = std::move(floatFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->unary_features(); - auto encoded_lengths = feature.lengths(); - auto encoded_values = feature.values(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["unary_features"] = std::move(sparseFeature); - } - - return torchRecRequest; -} - -// Logic behind the server's behavior. class PredictorServiceHandler final : public Predictor::Service { public: - explicit PredictorServiceHandler(torchrec::BatchingQueue& queue) - : queue_(queue) {} + PredictorServiceHandler(torch::jit::script::Module& module) + : module_(module) {} Status Predict( grpc::ServerContext* context, const PredictionRequest* request, PredictionResponse* reply) override { - folly::Promise> promise; - auto future = promise.getSemiFuture(); - queue_.add(toTorchRecRequest(request), std::move(promise)); - auto torchRecResponse = - std::move(future).get(); // blocking, TODO: Write async server - auto predictions = reply->mutable_predictions(); - - // Convert ivalue to map, TODO: find out if protobuf - // can support custom types (folly::iobuf), so we can avoid this overhead. 
-    for (const auto& item : torchRecResponse->predictions.toGenericDict()) {
-      auto tensor = item.value().toTensor();
-      FloatVec fv;
-      fv.mutable_data()->Add(
-          tensor.data_ptr(), tensor.data_ptr() + tensor.numel());
-      (*predictions)[item.key().toStringRef()] = fv;
-    }
+    std::cout << "Predict Called!" << std::endl;
+    c10::Dict<std::string, torch::Tensor> dict;
+
+    auto floatFeature = request->float_features();
+    auto floatFeatureBlob = floatFeature.values();
+    auto numFloatFeatures = floatFeature.num_features();
+    auto batchSize =
+        floatFeatureBlob.size() / (NUM_BYTES_FLOAT_FEATURES * numFloatFeatures);
+
+    std::cout << "Size: " << floatFeatureBlob.size()
+              << " Num Features: " << numFloatFeatures << std::endl;
+    auto floatFeatureTensor = torch::from_blob(
+        floatFeatureBlob.data(),
+        {batchSize, numFloatFeatures},
+        torch::kFloat32);
+
+    auto idListFeature = request->id_list_features();
+    auto numIdListFeatures = idListFeature.num_features();
+    auto lengthsBlob = idListFeature.lengths();
+    auto valuesBlob = idListFeature.values();
+
+    std::cout << "Lengths Size: " << lengthsBlob.size()
+              << " Num Features: " << numIdListFeatures << std::endl;
+    assert(
+        batchSize ==
+        (lengthsBlob.size() / (NUM_BYTES_SPARSE_FEATURES * numIdListFeatures)));
+
+    auto lengthsTensor = torch::from_blob(
+        lengthsBlob.data(),
+        {lengthsBlob.size() / NUM_BYTES_SPARSE_FEATURES},
+        torch::kInt32);
+    auto valuesTensor = torch::from_blob(
+        valuesBlob.data(),
+        {valuesBlob.size() / NUM_BYTES_SPARSE_FEATURES},
+        torch::kInt32);
+
+    dict.insert("float_features", floatFeatureTensor.to(torch::kCUDA));
+    dict.insert("id_list_features.lengths", lengthsTensor.to(torch::kCUDA));
+    dict.insert("id_list_features.values", valuesTensor.to(torch::kCUDA));
+
+    std::vector<c10::IValue> input;
+    input.push_back(c10::IValue(dict));
+
+    torch::Tensor output =
+        this->module_.forward(input).toGenericDict().at("default").toTensor();
+    auto predictions = reply->mutable_predictions();
+    FloatVec fv;
+    fv.mutable_data()->Add(
+        output.data_ptr<float>(), output.data_ptr<float>() + output.numel());
+    (*predictions)["default"] = fv;
     return Status::OK;
   }
 
  private:
-  torchrec::BatchingQueue& queue_;
+  torch::jit::script::Module& module_;
 };
 
-} // namespace
-
-int main(int argc, char* argv[]) {
-  google::InitGoogleLogging(argv[0]);
-  gflags::ParseCommandLineFlags(&argc, &argv, true);
-
-  LOG(INFO) << "Creating GPU executors";
-
-  // store the executors and interpreter managers
-  std::vector> executors;
-  std::vector models;
-  std::vector batchQueueCbs;
-  std::unordered_map batchingMetadataMap;
-
-  std::shared_ptr env =
-      std::make_shared(
-          FLAGS_python_packages_path);
-
-  auto manager = std::make_shared(
-      FLAGS_n_gpu * FLAGS_n_interp_per_gpu, env);
-  {
-    torch::deploy::Package package = manager->loadPackage(FLAGS_package_path);
-    auto I = package.acquireSession();
-    auto imported = I.self.attr("import_module")({"__module_loader"});
-    auto factoryType = imported.attr("MODULE_FACTORY");
-    auto factory = factoryType.attr("__new__")({factoryType});
-    factoryType.attr("__init__")({factory});
-
-    // Process forward metadata.
- try { - auto batchingMetadataJsonStr = - factory.attr("batching_metadata_json")(at::ArrayRef()) - .toIValue() - .toString() - ->string(); - auto dynamic = folly::parseJson(batchingMetadataJsonStr); - CHECK(dynamic.isObject()); - for (auto it : dynamic.items()) { - torchrec::BatchingMetadata metadata; - metadata.type = it.second["type"].asString(); - metadata.device = it.second["device"].asString(); - batchingMetadataMap[it.first.asString()] = std::move(metadata); - } - } catch (...) { - auto batchingMetadata = - factory.attr("batching_metadata")(at::ArrayRef()) - .toIValue(); - for (const auto& iter : batchingMetadata.toGenericDict()) { - torchrec::BatchingMetadata metadata; - metadata.type = iter.value().toStringRef(); - metadata.device = "cuda"; - batchingMetadataMap[iter.key().toStringRef()] = std::move(metadata); - } - } - - // Process result metadata. - auto resultMetadata = - factory.attr("result_metadata")(at::ArrayRef()) - .toIValue() - .toStringRef(); - std::shared_ptr resultSplitFunc = - torchrec::TorchRecResultSplitFuncRegistry()->Create(resultMetadata); - - LOG(INFO) << "Creating Model Shard for " << FLAGS_n_gpu << " GPUs."; - auto dmp = factory.attr("create_predict_module") - .callKwargs({{"world_size", FLAGS_n_gpu}}); - - for (int rank = 0; rank < FLAGS_n_gpu; rank++) { - auto device = I.self.attr("import_module")({"torch"}).attr("device")( - {"cuda", rank}); - auto m = dmp.attr("copy")({device.toIValue()}); - models.push_back(I.createMovable(m)); - } - - for (int rank = 0; rank < FLAGS_n_gpu; rank++) { - auto executor = std::make_unique( - manager, - std::move(models[rank]), - rank, - FLAGS_n_gpu, - resultSplitFunc, - std::chrono::milliseconds(FLAGS_gpu_executor_queue_timeout)); - executors.push_back(std::move(executor)); - batchQueueCbs.push_back( - [&, rank](std::shared_ptr batch) { - executors[rank]->callback(std::move(batch)); - }); - } - } - - torchrec::BatchingQueue queue( - batchQueueCbs, - torchrec::BatchingQueue::Config{ - .batchingInterval = - std::chrono::milliseconds(FLAGS_batching_interval), - .queueTimeout = std::chrono::milliseconds(FLAGS_queue_timeout), - .numExceptionThreads = FLAGS_num_exception_threads, - .numMemPinnerThreads = FLAGS_num_mem_pinner_threads, - .maxBatchSize = FLAGS_max_batch_size, - .batchingMetadata = std::move(batchingMetadataMap), - }, - FLAGS_n_gpu); - - // create the server - std::string server_address(FLAGS_server_address + ":" + FLAGS_server_port); - auto service = PredictorServiceHandler(queue); +void RunServer(uint16_t port, torch::jit::script::Module& module) { + std::string server_address = absl::StrFormat("0.0.0.0:%d", port); + PredictorServiceHandler service(module); grpc::EnableDefaultHealthCheckService(true); grpc::reflection::InitProtoReflectionServerBuilderPlugin(); - grpc::ServerBuilder builder; - + ServerBuilder builder; // Listen on the given address without any authentication mechanism. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - // Register "service" as the instance through which we'll communicate with // clients. In this case it corresponds to an *synchronous* service. builder.RegisterService(&service); - // Finally assemble the server. - std::unique_ptr server(builder.BuildAndStart()); - LOG(INFO) << "Server listening on " << server_address; + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; // Wait for the server to shutdown. 
Note that some other thread must be
   // responsible for shutting down the server for this call to ever return.
   server->Wait();
+}
+
+int main(int argc, char** argv) {
+  // absl::ParseCommandLine(argc, argv);
+
+  if (argc != 2) {
+    std::cerr << "usage: ts-infer <path-to-exported-script-module>\n";
+    return -1;
+  }
+
+  std::cout << "Loading model...\n";
+
+  // deserialize ScriptModule
+  torch::jit::script::Module module;
+  try {
+    module = torch::jit::load(argv[1]);
+  } catch (const c10::Error& e) {
+    std::cerr << "Error loading model\n";
+    return -1;
+  }
+
+  torch::NoGradGuard no_grad; // ensures that autograd is off
+  module.eval(); // turn off dropout and other training-time layers/functions
+
+  std::cout << "Sanity Check with dummy inputs" << std::endl;
+  c10::Dict<std::string, torch::Tensor> dict;
+  dict.insert(
+      "float_features",
+      torch::ones(
+          {1, 13}, torch::dtype(torch::kFloat32).device(torch::kCUDA, 0)));
+  dict.insert(
+      "id_list_features.lengths",
+      torch::ones({26}, torch::dtype(torch::kLong).device(torch::kCUDA, 0)));
+  dict.insert(
+      "id_list_features.values",
+      torch::ones({26}, torch::dtype(torch::kLong).device(torch::kCUDA, 0)));
+
+  std::vector<c10::IValue> input;
+  input.push_back(c10::IValue(dict));
+
+  // Execute the model and turn its output into a tensor.
+  auto output = module.forward(input).toGenericDict().at("default").toTensor();
+  std::cout << " Model Forward Completed, Output: " << output.item()
+            << std::endl;
+
+  RunServer(absl::GetFlag(FLAGS_port), module);
 
-  LOG(INFO) << "Shutting down server";
   return 0;
 }
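
The state_dict helpers moved into inference_legacy/state_dict_transform.py are normally driven together: state_dict_to_device relocates every (possibly sharded) entry with the given process group, and the gather helper whose Args block opens that hunk collects the shards onto rank 0. The sketch below is illustrative only; it assumes that gather helper is exported as state_dict_gather, that torch.distributed is already initialized, and that unsharded_model is a plain (non-sharded) replica supplying correctly sized destination tensors.

import torch
import torch.distributed as dist

# Import path follows the new inference_legacy location; `state_dict_gather`
# is an assumed export name for the gather helper shown above.
from torchrec.inference.inference_legacy.state_dict_transform import (
    state_dict_gather,
    state_dict_to_device,
)


def materialize_on_rank_zero(sharded_model, unsharded_model) -> None:
    pg = dist.group.WORLD
    rank = dist.get_rank()

    # Move every (possibly sharded) entry onto this rank's GPU.
    cuda_state_dict = state_dict_to_device(
        sharded_model.state_dict(), pg, torch.device("cuda", rank)
    )

    # Collective gather: every rank must participate, but only rank 0
    # receives the fully materialized tensors in the destination state_dict.
    dst_state_dict = unsharded_model.state_dict()
    state_dict_gather(cuda_state_dict, dst_state_dict)

    if rank == 0:
        torch.save(dst_state_dict, "full_state_dict.pt")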
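
The sanity check in the new main() fixes the scripted module's calling convention: forward takes a single Dict[str, Tensor] holding "float_features" of shape [batch_size, 13] plus flattened "id_list_features.lengths"/"id_list_features.values", and returns a dict keyed "default". A minimal Python-side check of the same contract, assuming the TorchScript artifact sits at the placeholder path /tmp/model.pt and using int32 id-list tensors to match the Predict request path:

import torch

# "/tmp/model.pt" is a placeholder for the artifact passed to ts-infer.
module = torch.jit.load("/tmp/model.pt", map_location="cuda:0")
module.eval()

batch_size = 1
inputs = {
    "float_features": torch.ones(
        batch_size, 13, dtype=torch.float32, device="cuda:0"
    ),
    "id_list_features.lengths": torch.ones(
        26 * batch_size, dtype=torch.int32, device="cuda:0"
    ),
    "id_list_features.values": torch.ones(
        26 * batch_size, dtype=torch.int32, device="cuda:0"
    ),
}

with torch.no_grad():
    # The scripted forward returns a dict of outputs keyed by task name.
    predictions = module(inputs)["default"]

print(predictions.shape)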
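
On the wire, the Predict handler expects raw bytes in native byte order: float32 for float_features and int32 for id_list_features, with the batch size recovered as len(values) / (4 * num_features). The client sketch below mirrors how server.cpp reads the request; the predictor_pb2 / predictor_pb2_grpc module names and bytes-typed fields are assumptions about stubs generated from the example's predictor.proto.

import grpc
import numpy as np

import predictor_pb2        # assumed: messages generated from predictor.proto
import predictor_pb2_grpc   # assumed: gRPC stubs generated from predictor.proto


def make_request(batch_size: int = 1) -> "predictor_pb2.PredictionRequest":
    # 13 dense features per example, serialized as raw float32 bytes; the
    # server recovers batch_size as len(values) / (4 * num_features).
    dense = np.ones((batch_size, 13), dtype=np.float32)
    # 26 id-list features, one id per feature per example, as int32 bytes.
    lengths = np.ones(26 * batch_size, dtype=np.int32)
    values = np.ones(26 * batch_size, dtype=np.int32)

    request = predictor_pb2.PredictionRequest()
    request.float_features.num_features = 13
    request.float_features.values = dense.tobytes()
    request.id_list_features.num_features = 26
    request.id_list_features.lengths = lengths.tobytes()
    request.id_list_features.values = values.tobytes()
    return request


def main() -> None:
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = predictor_pb2_grpc.PredictorStub(channel)
        response = stub.Predict(make_request())
        print(list(response.predictions["default"].data))


if __name__ == "__main__":
    main()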