diff --git a/examples/inference/README.md b/examples/inference_legacy/README.md similarity index 100% rename from examples/inference/README.md rename to examples/inference_legacy/README.md diff --git a/examples/inference/dlrm_client.py b/examples/inference_legacy/dlrm_client.py similarity index 100% rename from examples/inference/dlrm_client.py rename to examples/inference_legacy/dlrm_client.py diff --git a/examples/inference/dlrm_packager.py b/examples/inference_legacy/dlrm_packager.py similarity index 100% rename from examples/inference/dlrm_packager.py rename to examples/inference_legacy/dlrm_packager.py diff --git a/examples/inference/dlrm_predict.py b/examples/inference_legacy/dlrm_predict.py similarity index 100% rename from examples/inference/dlrm_predict.py rename to examples/inference_legacy/dlrm_predict.py diff --git a/examples/inference/dlrm_predict_single_gpu.py b/examples/inference_legacy/dlrm_predict_single_gpu.py similarity index 100% rename from examples/inference/dlrm_predict_single_gpu.py rename to examples/inference_legacy/dlrm_predict_single_gpu.py diff --git a/torchrec/inference/CMakeLists.txt b/torchrec/inference/CMakeLists.txt index 794bb1490..3b3ac2a8d 100644 --- a/torchrec/inference/CMakeLists.txt +++ b/torchrec/inference/CMakeLists.txt @@ -4,110 +4,60 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -cmake_minimum_required(VERSION 3.13 FATAL_ERROR) -project(inference) +cmake_minimum_required(VERSION 3.8) -# This step is crucial to ensure that the -# _REFLECTION, _GRPC_GRPCPP and _PROTOBUF_LIBPROTOBUF variables are set. -# e.g. ~/gprc/examples/cpp/cmake/common.cmake -include(${GRPC_COMMON_CMAKE_PATH}/common.cmake) +project(inference C CXX) +include(/home/paulzhan/grpc/examples/cpp/cmake/common.cmake) -# abi and other flags -if(DEFINED GLIBCXX_USE_CXX11_ABI) - if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") - set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1") - endif() -endif() +# Proto file +get_filename_component(hw_proto "/home/paulzhan/torchrec/torchrec/inference/protos/predictor.proto" ABSOLUTE) +get_filename_component(hw_proto_path "${hw_proto}" PATH) -# keep it static for now since folly-shared version is broken -# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") -# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") +# Generated sources +set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/predictor.pb.cc") +set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/predictor.pb.h") +set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/predictor.grpc.pb.cc") +set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/predictor.grpc.pb.h") +add_custom_command( + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${hw_proto}" + DEPENDS "${hw_proto}") -# dependencies -find_package(Boost REQUIRED) -find_package(Torch REQUIRED) -find_package(folly REQUIRED) -find_package(gflags REQUIRED) - -include_directories(${Torch_INCLUDE_DIRS}) -include_directories(${folly_INCLUDE_DIRS}) -include_directories(${PYTORCH_FMT_INCLUDE_PATH}) +# Include generated *.pb.h files +include_directories("${CMAKE_CURRENT_BINARY_DIR}") set(CMAKE_CXX_STANDARD 17) -# torch deploy library -add_library(torch_deploy_internal STATIC - 
${DEPLOY_INTERPRETER_PATH}/libtorch_deployinterpreter.o - ${DEPLOY_SRC_PATH}/deploy.cpp - ${DEPLOY_SRC_PATH}/loader.cpp - ${DEPLOY_SRC_PATH}/path_environment.cpp - ${DEPLOY_SRC_PATH}/elf_file.cpp) - -# For python builtins. caffe2_interface_library properly -# makes use of the --whole-archive option. -target_link_libraries(torch_deploy_internal PRIVATE - crypt pthread dl util m z ffi lzma readline nsl ncursesw panelw -) -target_link_libraries(torch_deploy_internal - PUBLIC shm torch ${PYTORCH_LIB_FMT} -) -caffe2_interface_library(torch_deploy_internal torch_deploy) - -# inference library - -# for our own header files -include_directories(include/) -include_directories(gen/) - -# define our library target -add_library(inference STATIC - src/Batching.cpp - src/BatchingQueue.cpp - src/GPUExecutor.cpp - src/ResultSplit.cpp - src/Exception.cpp - src/ResourceManager.cpp -) - -# -rdynamic is needed to link against the static library -target_link_libraries(inference "-Wl,--no-as-needed -rdynamic" - dl torch_deploy "${TORCH_LIBRARIES}" ${FBGEMM_LIB} ${FOLLY_LIBRARIES} -) - -# for generated protobuf - -# grpc headers. e.g. ~/.local/include -include_directories(${GRPC_HEADER_INCLUDE_PATH}) - -set(pred_grpc_srcs "gen/torchrec/inference/predictor.grpc.pb.cc") -set(pred_grpc_hdrs "gen/torchrec/inference/predictor.grpc.pb.h") -set(pred_proto_srcs "gen/torchrec/inference/predictor.pb.cc") -set(pred_proto_hdrs "gen/torchrec/inference/predictor.pb.h") +# Torch + FBGEMM +find_package(Torch REQUIRED) +add_library( fbgemm SHARED IMPORTED GLOBAL ) +set_target_properties(fbgemm PROPERTIES IMPORTED_LOCATION ${FBGEMM_LIB}) -add_library(pred_grpc_proto STATIC - ${pred_grpc_srcs} - ${pred_grpc_hdrs} - ${pred_proto_srcs} - ${pred_proto_hdrs}) -target_link_libraries(pred_grpc_proto +add_library(hw_grpc_proto STATIC + ${hw_grpc_srcs} + ${hw_grpc_hdrs} + ${hw_proto_srcs} + ${hw_proto_hdrs}) +target_link_libraries(hw_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF}) -# server +# Targets greeter_[async_](client|server) add_executable(server server.cpp) target_link_libraries(server - inference - torch_deploy - pred_grpc_proto "${TORCH_LIBRARIES}" - ${FOLLY_LIBRARIES} - ${PYTORCH_LIB_FMT} - ${FBGEMM_LIB} + fbgemm + hw_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} - ${_PROTOBUF_LIBPROTOBUF}) + ${_PROTOBUF_LIBPROTOBUF} +) diff --git a/torchrec/inference/README.md b/torchrec/inference/README.md index 8958fc72d..0a8e1cc17 100644 --- a/torchrec/inference/README.md +++ b/torchrec/inference/README.md @@ -1,116 +1,65 @@ # TorchRec Inference Library (**Experimental** Release) ## Overview -TorchRec Inference is a C++ library that supports **multi-gpu inference**. The Torchrec library is used to shard models written and packaged in Python via [torch.package](https://pytorch.org/docs/stable/package.html) (an alternative to TorchScript). The [torch.deploy](https://pytorch.org/docs/stable/deploy.html) library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus subverting the GIL. +TorchRec Inference is a C++ library that supports **gpu inference**. Previously, the TorchRec inference library was authored with torch.package and torch.deploy, which are old and deprecated. All the previous files live under the directory inference_legacy for reference. -Follow the instructions below to package a DLRM model in Python, run a C++ inference server with the model on a GPU and send requests to said server via a python client. 
+TorchRec inference was reauthored with simplicity in mind, while also reflecting the current production environment for RecSys models: torch.fx for graph capture/tracing and TorchScript for model inference in a C++ environment. The inference solution here is meant to serve as a simple reference and example, not a fully scaled-out solution for production use cases. The current solution demonstrates converting the DLRM model in Python to TorchScript, running a C++ inference server with the model on a GPU, and sending requests to that server via a Python client.
-## Example
+## Requirements
-C++ 17 is a requirement.
+C++ 17 is a requirement. The GCC version must be >= 9; initial testing was done with GCC 9.
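To make the fx-then-TorchScript flow described in the overview concrete, here is a minimal sketch that captures a toy module with `torch.fx` and compiles the captured graph with TorchScript. It is an illustration only: the actual DLRM pipeline in `dlrm_predict.py` additionally quantizes and shards the model and traces it with TorchRec's `Tracer`; the toy module and output path below are placeholders.

```
# Minimal sketch: torch.fx graph capture followed by TorchScript compilation.
# A toy module stands in for the DLRM predict module.
import torch
import torch.fx


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(13, 1)

    def forward(self, dense: torch.Tensor) -> torch.Tensor:
        return self.linear(dense).sigmoid()


model = ToyModel().eval()

# 1. Capture the computation graph with torch.fx.
graph_module = torch.fx.symbolic_trace(model)

# 2. Compile the captured graph with TorchScript and save it for the C++ server.
scripted = torch.jit.script(graph_module)
scripted(torch.randn(10, 13))  # sanity check, mirroring the server's dummy-input check
scripted.save("/tmp/toy_model.pt")
```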
### **1. Install Dependencies** - -Follow the instructions at: https://github.com/pytorch/pytorch/blob/master/docs/source/deploy.rst to ensure torch::deploy -is working in your environment. Use the Dockerfile in the docker directory to install all dependencies. Run it via: - +1. [GRPC for C++][https://grpc.io/docs/languages/cpp/quickstart/] needs to be installed, with the resulting installation directory being `$HOME/.local` +2. Ensure that **the protobuf compiler (protoc) binary being used is from the GRPC installation above**. The protoc binary will live in `$HOME/.local/bin`, which may not match with the system protoc binary, can check with `which protoc`. +3. Install PyTorch, FBGEMM, and TorchRec (ideally in a virtual environment): ``` -sudo nvidia-docker build -t torchrec . -sudo nvidia-docker run -it torchrec:latest +pip install torch --index-url https://download.pytorch.org/whl/cu121 +pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/cu121 +pip install torchmetrics==1.0.3 +pip install torchrec --index-url https://download.pytorch.org/whl/cu121 ``` + ### **2. Set variables** Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime. ``` -# provide the cmake prefix path of pytorch, folly, and fmt. -# fmt and boost are pulled from folly's installation in this example. -export FOLLY_CMAKE_DIR="~/folly-build/installed/folly/lib/cmake/folly" -export FMT_CMAKE_DIR="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/cmake/fmt" -export BOOST_CMAKE_DIR="~/folly-build/installed/boost-4M2ZnvEM4UWTqpsEJRQTB4oejmX3LmgYC9pcBiuVlmA/lib/cmake/Boost-1.78.0" - -# provide fmt from pytorch for torch deploy -export PYTORCH_FMT_INCLUDE_PATH="~/pytorch/third_party/fmt/include/" -export PYTORCH_LIB_FMT="~/pytorch/build/lib/libfmt.a" - -# provide necessary info to link to torch deploy -export DEPLOY_INTERPRETER_PATH="/pytorch/build/torch/csrc/deploy" -export DEPLOY_SRC_PATH="~/pytorch/torch/csrc/deploy" - -# provide common.cmake from grpc/examples, makes linking to grpc easier -export GRPC_COMMON_CMAKE_PATH="~/grpc/examples/cpp/cmake" -export GRPC_HEADER_INCLUDE_PATH="~/.local/include/" - -# provide libfbgemm_gpu_py.so to enable fbgemm_gpu c++ operators -export FBGEMM_LIB="~/anaconda3/envs/inference/lib/python3.8/site-packages/fbgemm_gpu-0.1.0-py3.8-linux-x86_64.egg/fbgemm_gpu/libfbgemm_gpu_py.so" - -# provide path to python packages for torch deploy runtime -export PYTHON_PACKAGES_PATH="~/anaconda3/envs/inference/lib/python3.8/site-packages/" -``` - -Update `$LD_LIBRARY_PATH` and `$LIBRARY_PATH` to enable linker to locate libraries. 
+# provide fbgemm_gpu_py.so to enable fbgemm_gpu c++ operators +find $HOME -name fbgemm_gpu_py.so +# Use path from correct virtual environment above and set environment variable $FBGEMM_LIB to it +export FBGEMM_LIB="" ``` -# double-conversion, fmt and gflags are pulled from folly's installation in this example -export DOUBLE_CONVERSION_LIB_PATH="~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib" -export FMT_LIB_PATH="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/" -export GFLAGS_LIB_PATH="~/folly-build/installed/gflags-KheHQBqQ3_iL3yJBFwWe5M5f8Syd-LKAX352cxkhQMc/lib" -export PYTORCH_LIB_PATH="~/pytorch/build/lib/" -export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$DOUBLE_CONVERSION_LIB_PATH:$FMT_LIB_PATH:$GFLAGS_LIB_PATH:$PYTORCH_LIB_PATH" -export LIBRARY_PATH="$PYTORCH_LIB_PATH" -``` +### **3. Generate TorchScripted DLRM model** -### **3. Package DLRM model** - -The `PredictFactoryPackager` class in `model_packager.py` can be used to implement your own packager class. Implement -`set_extern_modules` to specify the dependencies of your predict module that should be accessed from the system and -implement `set_mocked_modules` to specify dependencies that should be mocked (necessary to import but not use). Read -more about extern and mock modules in the `torch.package` documentation: https://pytorch.org/docs/stable/package.html. - -`/torchrec/examples/inference/dlrm_package.py` provides an example of packaging a module for inference (`/torchrec/examples/inference/dlrm_predict.py`). -`DLRMPredictModule` is packaged for inference in the following example. +Here, we generate the DLRM model in Torchscript and save it for model loading later on. ``` git clone https://github.com/pytorch/torchrec.git -cd ~/torchrec/examples/inference/ -python dlrm_packager.py --output_path /tmp/model_package.zip +cd ~/torchrec/torchrec/inference/ +python3 dlrm_packager.py --output_path /tmp/model.pt ``` - ### **4. Build inference library and example server** -Generate protobuf C++ and Python code from protobuf +Generate Python code from protobuf for client and build the server. ``` -cd ~/torchrec/inference/ -mkdir -p gen/torchrec/inference - -# C++ (server) -protoc -I protos/ --grpc_out=gen/torchrec/inference --plugin=protoc-gen-grpc=/home/shabab/.local/bin/grpc_cpp_plugin protos/predictor.proto - -protoc -I protos/ --cpp_out=gen/torchrec/inference protos/predictor.proto - - # Python (client) -python -m grpc_tools.protoc -I protos --python_out=gen/torchrec/inference --grpc_python_out=gen/torchrec/inference protos/predictor.proto +python -m grpc_tools.protoc -I protos --python_out=. --grpc_python_out=. protos/predictor.proto ``` -Build inference library and example server +Build server and C++ protobufs ``` -cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');$FOLLY_CMAKE_DIR;$BOOST_CMAKE_DIR;$BOOST_CMAKE_DIR;" --DPYTORCH_FMT_INCLUDE_PATH="$PYTORCH_FMT_INCLUDE_PATH" \ --DPYTORCH_LIB_FMT="$PYTORCH_LIB_FMT" \ --DDEPLOY_INTERPRETER_PATH="$DEPLOY_INTERPRETER_PATH" \ --DDEPLOY_SRC_PATH="$DEPLOY_SRC_PATH" \ --DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" \ -DGRPC_HEADER_INCLUDE_PATH="$GRPC_HEADER_INCLUDE_PATH" \ --DFBGEMM_LIB="$FBGEMM_LIB" +cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');" -DFBGEMM_LIB="$FBGEMM_LIB" cd build make -j @@ -119,28 +68,20 @@ make -j ### **5. Run server and client** -Run server. 
Update `CUDA_VISABLE_DEVICES` depending on the world size. +Start the server, loading in the model saved previously ``` -CUDA_VISABLE_DEVICES="0" ./server --package_path="/tmp/model_package.zip" --python_packages_path $PYTHON_PACKAGES_PATH +./server /tmp/model.pt ``` **output** -In the logs, a plan should be outputted by the Torchrec planner: - -``` -INFO:.torchrec.distributed.planner.stats:# --- Planner Statistics --- # -INFO:.torchrec.distributed.planner.stats:# --- Evalulated 1 proposal(s), found 1 possible plan(s) --- # -INFO:.torchrec.distributed.planner.stats:# ----------------------------------------------------------------------------------------------- # -INFO:.torchrec.distributed.planner.stats:# Rank HBM (GB) DDR (GB) Perf (ms) Input (MB) Output (MB) Shards # -INFO:.torchrec.distributed.planner.stats:# ------ ---------- ---------- ----------- ------------ ------------- -------- # -INFO:.torchrec.distributed.planner.stats:# 0 0.2 (1%) 0.0 (0%) 0.08 0.1 1.02 TW: 26 # -INFO:.torchrec.distributed.planner.stats:# # -INFO:.torchrec.distributed.planner.stats:# Input: MB/iteration, Output: MB/iteration, Shards: number of tables # -INFO:.torchrec.distributed.planner.stats:# HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients # -INFO:.torchrec.distributed.planner.stats:# # -INFO:.torchrec.distributed.planner.stats:# Compute Kernels: # -INFO:.torchrec.distributed.planner.stats:# quant: 26 # +In the logs, you should see: + +``` +Loading model... +Sanity Check with dummy inputs + Model Forward Completed, Output: 0.489247 +Server listening on 0.0.0.0:50051 ```` `nvidia-smi` output should also show allocation of the model onto the gpu: @@ -155,7 +96,7 @@ INFO:.torchrec.distributed.planner.stats:# quant: 26 +-----------------------------------------------------------------------------+ ``` -Make a request to the server via the client: +In another terminal instance, make a request to the server via the client: ``` python client.py @@ -166,74 +107,3 @@ python client.py ``` Response: [0.13199582695960999, -0.1048036441206932, -0.06022112816572189, -0.08765199035406113, -0.12735335528850555, -0.1004377081990242, 0.05509107559919357, -0.10504599660634995, 0.1350800096988678, -0.09468207508325577, 0.24013587832450867, -0.09682435542345047, 0.0025023818016052246, -0.09786031395196915, -0.26396819949150085, -0.09670191258192062, 0.2691854238510132, -0.10246685892343521, -0.2019493579864502, -0.09904996305704117, 0.3894067406654358, ...] ``` - -
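For reference, the request path exercised by `client.py` amounts to serializing the dense and sparse feature tensors into the `PredictionRequest` proto and calling the `Predict` RPC. The condensed sketch below is based on the `create_request` helper visible in the client code in this diff; the dummy tensors and their dtypes are illustrative stand-ins for a `RandomRecDataset` batch.

```
# Condensed request flow from client.py. Assumes predictor_pb2 / predictor_pb2_grpc
# were generated from protos/predictor.proto (step 4) and are importable.
import grpc
import torch

import predictor_pb2
import predictor_pb2_grpc


def to_bytes(tensor: torch.Tensor) -> bytes:
    return tensor.cpu().numpy().tobytes()


# Dummy Criteo-shaped inputs: 13 dense and 26 sparse features, batch of 10.
# The real client pulls these tensors from a RandomRecDataset batch instead.
batch_size, num_dense, num_sparse = 10, 13, 26
dense = torch.rand(batch_size, num_dense, dtype=torch.float32)
sparse_values = torch.randint(0, 100_000, (batch_size * num_sparse,))
sparse_lengths = torch.ones(batch_size * num_sparse, dtype=torch.int64)

request = predictor_pb2.PredictionRequest(
    batch_size=batch_size,
    float_features=predictor_pb2.FloatFeatures(
        num_features=num_dense, values=to_bytes(dense)
    ),
    id_list_features=predictor_pb2.SparseFeatures(
        num_features=num_sparse,
        values=to_bytes(sparse_values),
        lengths=to_bytes(sparse_lengths),
    ),
    id_score_list_features=predictor_pb2.SparseFeatures(num_features=0),
    embedding_features=predictor_pb2.FloatFeatures(num_features=0),
    unary_features=predictor_pb2.SparseFeatures(num_features=0),
)

with grpc.insecure_channel("0.0.0.0:50051") as channel:
    stub = predictor_pb2_grpc.PredictorStub(channel)
    response = stub.Predict(request)
    print("Response:", response.predictions["default"].data)
```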
- -## Planned work - -- Provide benchmarks for torch deploy vs TorchScript and cpu, single gpu and multi-gpu inference -- In-code documentation -- Simplify installation process - -
- -## Potential issues and solutions - -Skip this section if you had no issues with installation or running the example. - -**Missing header files during pytorch installation** - -If your environment is missing a speicfic set of header files such as `nvml.h` and `cuda_profiler_api.h`, the pytorch installation will fail with error messages similar to the code snippet below: - -``` -~/nvml_lib.h:13:10: fatal error: nvml.h: No such file or directory - #include - ^~~~~~~~ -compilation terminated. -[80/2643] Building CXX object third_party/ideep/mkl-dnn/third_party/oneDNN/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_convolution_list.cpp.o -ninja: build stopped: subcommand failed. -``` - -To get these header files, install `cudatoolkit-dev`: -``` -conda install -c conda-forge cudatoolkit-dev -``` - -Re-run the installation after this. - -**libdouble-conversion missing** -``` -~/torchrec/torchrec/inference/build$ ./example -./example: error while loading shared libraries: libdouble-conversion.so.3: cannot open shared object file: No such file or directory -``` - -If this issue persists even after adding double-conversion's path to $LD_LIBRARY_PATH (step 2) then solve by creating a symlink to `libdouble-conversion.so.3` with folly's installation of double-conversion: - -``` -sudo ln -s ~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib/libdouble-conversion.so.3.1.4 \ -libdouble-conversion.so.3 -``` - -**Two installations of glog** -``` -~/torchrec/torchrec/inference/build$ ./example -ERROR: flag 'logtostderr' was defined more than once (in files '/home/shabab/glog/src/logging.cc' and -'/home/shabab/folly-build/extracted/glog-v0.4.0.tar.gz/glog-0.4.0/src/logging.cc'). -``` -The above issue, along with a host of others during building, can potentially occur if libinference is pointing to two different versions of glog (if one was -previously installed in your system). You can find this out by running `ldd` on your libinference shared object within the build path. The issue can be solved by using the glog version provided by folly. - -To use the glog version provided by folly, add the glog install path (in your folly-build directory) to your LD_LIBRARY_PATH much like step 2. - -**Undefined symbols with std::string or cxx11** - -If you get undefined symbol errors and the errors mention `std::string` or `cxx11`, it's likely -that your dependencies were compiled with different ABI values. Re-compile your dependencies -and ensure they all have the same value for `_GLIBCXX_USE_CXX11_ABI` in their build. - -The ABI value of pytorch can be checked via: - -``` -import torch -torch._C._GLIBCXX_USE_CXX11_ABI -``` diff --git a/torchrec/inference/__init__.py b/torchrec/inference/__init__.py index 6c2050114..a0ce8680a 100644 --- a/torchrec/inference/__init__.py +++ b/torchrec/inference/__init__.py @@ -7,21 +7,4 @@ # pyre-strict -"""Torchrec Inference - -Torchrec inference provides a Torch.Deploy based library for GPU inference. - -These includes: - - Model packaging in Python - - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving. - - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package. - - Model serving in C++ - - `BatchingQueue` is a generalized config-based request tensor batching implementation. - - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy. - -We implemented an example of how to use this library with the TorchRec DLRM model. 
- - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package. - - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model. -""" - from . import model_packager, modules # noqa # noqa diff --git a/torchrec/inference/client.py b/torchrec/inference/client.py index c185f3909..725338f46 100644 --- a/torchrec/inference/client.py +++ b/torchrec/inference/client.py @@ -9,8 +9,8 @@ import logging import grpc +import predictor_pb2, predictor_pb2_grpc import torch -from gen.torchrec.inference import predictor_pb2, predictor_pb2_grpc from torch.utils.data import DataLoader from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES from torchrec.datasets.random import RandomRecDataset @@ -20,18 +20,13 @@ def create_training_batch(args: argparse.Namespace) -> Batch: return next( iter( - DataLoader( - RandomRecDataset( - keys=DEFAULT_CAT_NAMES, - batch_size=args.batch_size, - hash_size=args.num_embedding_features, - ids_per_feature=1, - num_dense=len(DEFAULT_INT_NAMES), - ), - batch_sampler=None, - pin_memory=False, - num_workers=0, - ) + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ), ) ) diff --git a/torchrec/inference/dlrm_packager.py b/torchrec/inference/dlrm_packager.py new file mode 100644 index 000000000..560e31c16 --- /dev/null +++ b/torchrec/inference/dlrm_packager.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# @nolint +# pyre-ignore-all-errors + + +import argparse +import sys +from typing import List + +from dlrm_predict import create_training_batch, DLRMModelConfig, DLRMPredictFactory +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES + + +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="torchrec dlrm model packager") + parser.add_argument( + "--num_embeddings", + type=int, + default=100_000, + help="max_ind_size. The number of embeddings in each embedding table. Defaults" + " to 100_000 if num_embeddings_per_feature is not supplied.", + ) + parser.add_argument( + "--num_embeddings_per_feature", + type=str, + default="45833188,36746,17245,7413,20243,3,7114,1441,62,29275261,1572176,345138," + "10,2209,11267,128,4,974,14,48937457,11316796,40094537,452104,12606,104,35", + help="Comma separated max_ind_size per sparse feature. The number of embeddings" + " in each embedding table. 
26 values are expected for the Criteo dataset.", + ) + parser.add_argument( + "--sparse_feature_names", + type=str, + default=",".join(DEFAULT_CAT_NAMES), + help="Comma separated names of the sparse features.", + ) + parser.add_argument( + "--dense_arch_layer_sizes", + type=str, + default="512,256,64", + help="Comma separated layer sizes for dense arch.", + ) + parser.add_argument( + "--over_arch_layer_sizes", + type=str, + default="512,512,256,1", + help="Comma separated layer sizes for over arch.", + ) + parser.add_argument( + "--embedding_dim", + type=int, + default=64, + help="Size of each embedding.", + ) + parser.add_argument( + "--num_dense_features", + type=int, + default=len(DEFAULT_INT_NAMES), + help="Number of dense features.", + ) + parser.add_argument( + "--output_path", + type=str, + help="Output path of model package.", + ) + return parser.parse_args(argv) + + +def main(argv: List[str]) -> None: + """ + Use torch.package to package the torchrec DLRM Model. + + Args: + argv (List[str]): command line args. + + Returns: + None. + """ + + args = parse_args(argv) + + args.batch_size = 10 + args.num_embedding_features = 26 + batch = create_training_batch(args) + + model_config = DLRMModelConfig( + dense_arch_layer_sizes=list(map(int, args.dense_arch_layer_sizes.split(","))), + dense_in_features=args.num_dense_features, + embedding_dim=args.embedding_dim, + id_list_features_keys=args.sparse_feature_names.split(","), + num_embeddings_per_feature=list( + map(int, args.num_embeddings_per_feature.split(",")) + ), + num_embeddings=args.num_embeddings, + over_arch_layer_sizes=list(map(int, args.over_arch_layer_sizes.split(","))), + sample_input=batch, + ) + + script_module = DLRMPredictFactory(model_config).create_predict_module(world_size=1) + + script_module.save(args.output_path) + print(f"Package is saved to {args.output_path}") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchrec/inference/dlrm_predict.py b/torchrec/inference/dlrm_predict.py new file mode 100644 index 000000000..3ac2966ad --- /dev/null +++ b/torchrec/inference/dlrm_predict.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# @nolint +# pyre-ignore-all-errors + + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.datasets.random import RandomRecDataset +from torchrec.datasets.utils import Batch +from torchrec.fx.tracer import Tracer +from torchrec.inference.modules import ( + PredictFactory, + PredictModule, + quantize_inference_model, + shard_quant_model, +) +from torchrec.models.dlrm import DLRM +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +logger: logging.Logger = logging.getLogger(__name__) + + +def create_training_batch(args) -> Batch: + return RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ).batch_generator._generate_batch() + + +# OSS Only + + +@dataclass +class DLRMModelConfig: + dense_arch_layer_sizes: List[int] + dense_in_features: int + embedding_dim: int + id_list_features_keys: List[str] + num_embeddings_per_feature: List[int] + num_embeddings: int + over_arch_layer_sizes: List[int] + sample_input: Batch + + +class DLRMPredictModule(PredictModule): + """ + nn.Module to wrap DLRM model to use for inference. + + Args: + embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags + used to define SparseArch. + dense_in_features (int): the dimensionality of the dense input features. + dense_arch_layer_sizes (List[int]): the layer sizes for the DenseArch. + over_arch_layer_sizes (List[int]): the layer sizes for the OverArch. NOTE: The + output dimension of the InteractionArch should not be manually specified + here. + id_list_features_keys (List[str]): the names of the sparse features. Used to + construct a batch for inference. + dense_device: (Optional[torch.device]). + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + dense_in_features: int, + dense_arch_layer_sizes: List[int], + over_arch_layer_sizes: List[int], + id_list_features_keys: List[str], + dense_device: Optional[torch.device] = None, + ) -> None: + module = DLRM( + embedding_bag_collection=embedding_bag_collection, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=dense_arch_layer_sizes, + over_arch_layer_sizes=over_arch_layer_sizes, + dense_device=dense_device, + ) + super().__init__(module) + + self.id_list_features_keys: List[str] = id_list_features_keys + + def predict_forward( + self, batch: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Args: + batch (Dict[str, torch.Tensor]): currently expects input dense features + to be mapped to the key "float_features" and input sparse features + to be mapped to the key "id_list_features". + + Returns: + Dict[str, torch.Tensor]: output of inference. + """ + + try: + logits = self.predict_module( + batch["float_features"], + KeyedJaggedTensor( + keys=self.id_list_features_keys, + lengths=batch["id_list_features.lengths"], + values=batch["id_list_features.values"], + ), + ) + predictions = logits.sigmoid() + except Exception as e: + logger.info(e) + raise e + + # Flip predictions tensor to be 1D. TODO: Determine why prediction shape + # can be 2D at times (likely due to input format?) 
+ predictions = predictions.reshape( + [ + predictions.size()[0], + ] + ) + + return { + "default": predictions.to(torch.device("cpu"), non_blocking=True).float() + } + + +class DLRMPredictFactory(PredictFactory): + def __init__(self, model_config: DLRMModelConfig) -> None: + self.model_config = model_config + + def create_predict_module(self, world_size: int) -> torch.nn.Module: + logging.basicConfig(level=logging.INFO) + device = torch.device("cuda:0") + + eb_configs = [ + EmbeddingBagConfig( + name=f"t_{feature_name}", + embedding_dim=self.model_config.embedding_dim, + num_embeddings=( + self.model_config.num_embeddings_per_feature[feature_idx] + if self.model_config.num_embeddings is None + else self.model_config.num_embeddings + ), + feature_names=[feature_name], + ) + for feature_idx, feature_name in enumerate( + self.model_config.id_list_features_keys + ) + ] + ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta")) + + module = DLRMPredictModule( + embedding_bag_collection=ebc, + dense_in_features=self.model_config.dense_in_features, + dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes, + over_arch_layer_sizes=self.model_config.over_arch_layer_sizes, + id_list_features_keys=self.model_config.id_list_features_keys, + dense_device=device, + ) + + table_fqns = [] + for name, _ in module.named_modules(): + if "t_" in name: + table_fqns.append(name.split(".")[-1]) + + quant_model = quantize_inference_model(module) + sharded_model, _ = shard_quant_model(quant_model, table_fqns) + + batch = {} + batch["float_features"] = self.model_config.sample_input.dense_features.cuda() + batch["id_list_features.lengths"] = ( + self.model_config.sample_input.sparse_features.lengths().cuda() + ) + batch["id_list_features.values"] = ( + self.model_config.sample_input.sparse_features.values().cuda() + ) + + sharded_model(batch) + + tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"]) + + graph = tracer.trace(sharded_model) + gm = torch.fx.GraphModule(sharded_model, graph) + + gm(batch) + scripted_gm = torch.jit.script(gm) + scripted_gm(batch) + return scripted_gm + + def batching_metadata(self) -> Dict[str, str]: + return { + "float_features": "dense", + "id_list_features": "sparse", + } + + def result_metadata(self) -> str: + return "dict_of_tensor" + + def run_weights_independent_tranformations( + self, predict_module: torch.nn.Module + ) -> torch.nn.Module: + return predict_module + + def run_weights_dependent_transformations( + self, predict_module: torch.nn.Module + ) -> torch.nn.Module: + """ + Run transformations that depends on weights of the predict module. e.g. lowering to a backend. + """ + return predict_module diff --git a/torchrec/inference/inference_legacy/CMakeLists.txt b/torchrec/inference/inference_legacy/CMakeLists.txt new file mode 100644 index 000000000..794bb1490 --- /dev/null +++ b/torchrec/inference/inference_legacy/CMakeLists.txt @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.13 FATAL_ERROR) +project(inference) + +# This step is crucial to ensure that the +# _REFLECTION, _GRPC_GRPCPP and _PROTOBUF_LIBPROTOBUF variables are set. +# e.g. 
~/gprc/examples/cpp/cmake/common.cmake +include(${GRPC_COMMON_CMAKE_PATH}/common.cmake) + + +# abi and other flags + +if(DEFINED GLIBCXX_USE_CXX11_ABI) + if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") + set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1") + endif() +endif() + +# keep it static for now since folly-shared version is broken +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") + +# dependencies +find_package(Boost REQUIRED) +find_package(Torch REQUIRED) +find_package(folly REQUIRED) +find_package(gflags REQUIRED) + +include_directories(${Torch_INCLUDE_DIRS}) +include_directories(${folly_INCLUDE_DIRS}) +include_directories(${PYTORCH_FMT_INCLUDE_PATH}) + +set(CMAKE_CXX_STANDARD 17) + +# torch deploy library +add_library(torch_deploy_internal STATIC + ${DEPLOY_INTERPRETER_PATH}/libtorch_deployinterpreter.o + ${DEPLOY_SRC_PATH}/deploy.cpp + ${DEPLOY_SRC_PATH}/loader.cpp + ${DEPLOY_SRC_PATH}/path_environment.cpp + ${DEPLOY_SRC_PATH}/elf_file.cpp) + +# For python builtins. caffe2_interface_library properly +# makes use of the --whole-archive option. +target_link_libraries(torch_deploy_internal PRIVATE + crypt pthread dl util m z ffi lzma readline nsl ncursesw panelw +) +target_link_libraries(torch_deploy_internal + PUBLIC shm torch ${PYTORCH_LIB_FMT} +) +caffe2_interface_library(torch_deploy_internal torch_deploy) + +# inference library + +# for our own header files +include_directories(include/) +include_directories(gen/) + +# define our library target +add_library(inference STATIC + src/Batching.cpp + src/BatchingQueue.cpp + src/GPUExecutor.cpp + src/ResultSplit.cpp + src/Exception.cpp + src/ResourceManager.cpp +) + +# -rdynamic is needed to link against the static library +target_link_libraries(inference "-Wl,--no-as-needed -rdynamic" + dl torch_deploy "${TORCH_LIBRARIES}" ${FBGEMM_LIB} ${FOLLY_LIBRARIES} +) + +# for generated protobuf + +# grpc headers. e.g. ~/.local/include +include_directories(${GRPC_HEADER_INCLUDE_PATH}) + +set(pred_grpc_srcs "gen/torchrec/inference/predictor.grpc.pb.cc") +set(pred_grpc_hdrs "gen/torchrec/inference/predictor.grpc.pb.h") +set(pred_proto_srcs "gen/torchrec/inference/predictor.pb.cc") +set(pred_proto_hdrs "gen/torchrec/inference/predictor.pb.h") + +add_library(pred_grpc_proto STATIC + ${pred_grpc_srcs} + ${pred_grpc_hdrs} + ${pred_proto_srcs} + ${pred_proto_hdrs}) + +target_link_libraries(pred_grpc_proto + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) + +# server +add_executable(server server.cpp) +target_link_libraries(server + inference + torch_deploy + pred_grpc_proto + "${TORCH_LIBRARIES}" + ${FOLLY_LIBRARIES} + ${PYTORCH_LIB_FMT} + ${FBGEMM_LIB} + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) diff --git a/torchrec/inference/inference_legacy/README.md b/torchrec/inference/inference_legacy/README.md new file mode 100644 index 000000000..fc3d8afcb --- /dev/null +++ b/torchrec/inference/inference_legacy/README.md @@ -0,0 +1,239 @@ +# TorchRec Inference Library (**Experimental** Release) + +## Overview +TorchRec Inference is a C++ library that supports **multi-gpu inference**. The Torchrec library is used to shard models written and packaged in Python via [torch.package](https://pytorch.org/docs/stable/package.html) (an alternative to TorchScript). 
The [torch.deploy](https://pytorch.org/docs/stable/deploy.html) library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus subverting the GIL. + +Follow the instructions below to package a DLRM model in Python, run a C++ inference server with the model on a GPU and send requests to said server via a python client. + +## Example + +C++ 17 is a requirement. + +
+ +### **1. Install Dependencies** + +Follow the instructions at: https://github.com/pytorch/pytorch/blob/master/docs/source/deploy.rst to ensure torch::deploy +is working in your environment. Use the Dockerfile in the docker directory to install all dependencies. Run it via: + +``` +sudo nvidia-docker build -t torchrec . +sudo nvidia-docker run -it torchrec:latest +``` + +### **2. Set variables** + +Replace these variables with the relevant paths in your system. Check `CMakeLists.txt` and `server.cpp` to see how they're used throughout the build and runtime. + +``` +# provide the cmake prefix path of pytorch, folly, and fmt. +# fmt and boost are pulled from folly's installation in this example. +export FOLLY_CMAKE_DIR="~/folly-build/installed/folly/lib/cmake/folly" +export FMT_CMAKE_DIR="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/cmake/fmt" +export BOOST_CMAKE_DIR="~/folly-build/installed/boost-4M2ZnvEM4UWTqpsEJRQTB4oejmX3LmgYC9pcBiuVlmA/lib/cmake/Boost-1.78.0" + +# provide fmt from pytorch for torch deploy +export PYTORCH_FMT_INCLUDE_PATH="~/pytorch/third_party/fmt/include/" +export PYTORCH_LIB_FMT="~/pytorch/build/lib/libfmt.a" + +# provide necessary info to link to torch deploy +export DEPLOY_INTERPRETER_PATH="/pytorch/build/torch/csrc/deploy" +export DEPLOY_SRC_PATH="~/pytorch/torch/csrc/deploy" + +# provide common.cmake from grpc/examples, makes linking to grpc easier +export GRPC_COMMON_CMAKE_PATH="~/grpc/examples/cpp/cmake" +export GRPC_HEADER_INCLUDE_PATH="~/.local/include/" + +# provide libfbgemm_gpu_py.so to enable fbgemm_gpu c++ operators +export FBGEMM_LIB="~/anaconda3/envs/inference/lib/python3.8/site-packages/fbgemm_gpu-0.1.0-py3.8-linux-x86_64.egg/fbgemm_gpu/libfbgemm_gpu_py.so" + +# provide path to python packages for torch deploy runtime +export PYTHON_PACKAGES_PATH="~/anaconda3/envs/inference/lib/python3.8/site-packages/" +``` + +Update `$LD_LIBRARY_PATH` and `$LIBRARY_PATH` to enable linker to locate libraries. + +``` +# double-conversion, fmt and gflags are pulled from folly's installation in this example +export DOUBLE_CONVERSION_LIB_PATH="~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib" +export FMT_LIB_PATH="~/folly-build/fmt-WuK9LOk2P03KJKCt1so6VofoXa4qtyD5kQk6cZMphjg/lib64/" +export GFLAGS_LIB_PATH="~/folly-build/installed/gflags-KheHQBqQ3_iL3yJBFwWe5M5f8Syd-LKAX352cxkhQMc/lib" +export PYTORCH_LIB_PATH="~/pytorch/build/lib/" + +export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$DOUBLE_CONVERSION_LIB_PATH:$FMT_LIB_PATH:$GFLAGS_LIB_PATH:$PYTORCH_LIB_PATH" +export LIBRARY_PATH="$PYTORCH_LIB_PATH" +``` + +### **3. Package DLRM model** + +The `PredictFactoryPackager` class in `model_packager.py` can be used to implement your own packager class. Implement +`set_extern_modules` to specify the dependencies of your predict module that should be accessed from the system and +implement `set_mocked_modules` to specify dependencies that should be mocked (necessary to import but not use). Read +more about extern and mock modules in the `torch.package` documentation: https://pytorch.org/docs/stable/package.html. + +`/torchrec/examples/inference_legacy/dlrm_package.py` provides an example of packaging a module for inference (`/torchrec/examples/inference_legacy/dlrm_predict.py`). +`DLRMPredictModule` is packaged for inference in the following example. 
+ +``` +git clone https://github.com/pytorch/torchrec.git + +cd ~/torchrec/examples/inference_legacy/ +python dlrm_packager.py --output_path /tmp/model_package.zip +``` + + + +### **4. Build inference library and example server** + +Generate protobuf C++ and Python code from protobuf + +``` +cd ~/torchrec/inference/ +mkdir -p gen/torchrec/inference + +# C++ (server) +protoc -I protos/ --grpc_out=gen/torchrec/inference --plugin=protoc-gen-grpc=/home/shabab/.local/bin/grpc_cpp_plugin protos/predictor.proto + +protoc -I protos/ --cpp_out=gen/torchrec/inference protos/predictor.proto + + +# Python (client) +python -m grpc_tools.protoc -I protos --python_out=gen/torchrec/inference --grpc_python_out=gen/torchrec/inference protos/predictor.proto +``` + + +Build inference library and example server +``` +cmake -S . -B build/ -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)');$FOLLY_CMAKE_DIR;$BOOST_CMAKE_DIR;$BOOST_CMAKE_DIR;" +-DPYTORCH_FMT_INCLUDE_PATH="$PYTORCH_FMT_INCLUDE_PATH" \ +-DPYTORCH_LIB_FMT="$PYTORCH_LIB_FMT" \ +-DDEPLOY_INTERPRETER_PATH="$DEPLOY_INTERPRETER_PATH" \ +-DDEPLOY_SRC_PATH="$DEPLOY_SRC_PATH" \ +-DGRPC_COMMON_CMAKE_PATH="$GRPC_COMMON_CMAKE_PATH" \ -DGRPC_HEADER_INCLUDE_PATH="$GRPC_HEADER_INCLUDE_PATH" \ +-DFBGEMM_LIB="$FBGEMM_LIB" + +cd build +make -j +``` + + +### **5. Run server and client** + +Run server. Update `CUDA_VISABLE_DEVICES` depending on the world size. +``` +CUDA_VISABLE_DEVICES="0" ./server --package_path="/tmp/model_package.zip" --python_packages_path $PYTHON_PACKAGES_PATH +``` + +**output** + +In the logs, a plan should be outputted by the Torchrec planner: + +``` +INFO:.torchrec.distributed.planner.stats:# --- Planner Statistics --- # +INFO:.torchrec.distributed.planner.stats:# --- Evalulated 1 proposal(s), found 1 possible plan(s) --- # +INFO:.torchrec.distributed.planner.stats:# ----------------------------------------------------------------------------------------------- # +INFO:.torchrec.distributed.planner.stats:# Rank HBM (GB) DDR (GB) Perf (ms) Input (MB) Output (MB) Shards # +INFO:.torchrec.distributed.planner.stats:# ------ ---------- ---------- ----------- ------------ ------------- -------- # +INFO:.torchrec.distributed.planner.stats:# 0 0.2 (1%) 0.0 (0%) 0.08 0.1 1.02 TW: 26 # +INFO:.torchrec.distributed.planner.stats:# # +INFO:.torchrec.distributed.planner.stats:# Input: MB/iteration, Output: MB/iteration, Shards: number of tables # +INFO:.torchrec.distributed.planner.stats:# HBM: est. 
peak memory usage for shards - parameter, comms, optimizer, and gradients # +INFO:.torchrec.distributed.planner.stats:# # +INFO:.torchrec.distributed.planner.stats:# Compute Kernels: # +INFO:.torchrec.distributed.planner.stats:# quant: 26 # +```` + +`nvidia-smi` output should also show allocation of the model onto the gpu: + +``` ++-----------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=============================================================================| +| 0 N/A N/A 86668 C ./example 1357MiB | ++-----------------------------------------------------------------------------+ +``` + +Make a request to the server via the client: + +``` +python client.py +``` + +**output** + +``` +Response: [0.13199582695960999, -0.1048036441206932, -0.06022112816572189, -0.08765199035406113, -0.12735335528850555, -0.1004377081990242, 0.05509107559919357, -0.10504599660634995, 0.1350800096988678, -0.09468207508325577, 0.24013587832450867, -0.09682435542345047, 0.0025023818016052246, -0.09786031395196915, -0.26396819949150085, -0.09670191258192062, 0.2691854238510132, -0.10246685892343521, -0.2019493579864502, -0.09904996305704117, 0.3894067406654358, ...] +``` + +
+ +## Planned work + +- Provide benchmarks for torch deploy vs TorchScript and cpu, single gpu and multi-gpu inference +- In-code documentation +- Simplify installation process + +
+ +## Potential issues and solutions + +Skip this section if you had no issues with installation or running the example. + +**Missing header files during pytorch installation** + +If your environment is missing a speicfic set of header files such as `nvml.h` and `cuda_profiler_api.h`, the pytorch installation will fail with error messages similar to the code snippet below: + +``` +~/nvml_lib.h:13:10: fatal error: nvml.h: No such file or directory + #include + ^~~~~~~~ +compilation terminated. +[80/2643] Building CXX object third_party/ideep/mkl-dnn/third_party/oneDNN/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_convolution_list.cpp.o +ninja: build stopped: subcommand failed. +``` + +To get these header files, install `cudatoolkit-dev`: +``` +conda install -c conda-forge cudatoolkit-dev +``` + +Re-run the installation after this. + +**libdouble-conversion missing** +``` +~/torchrec/torchrec/inference/build$ ./example +./example: error while loading shared libraries: libdouble-conversion.so.3: cannot open shared object file: No such file or directory +``` + +If this issue persists even after adding double-conversion's path to $LD_LIBRARY_PATH (step 2) then solve by creating a symlink to `libdouble-conversion.so.3` with folly's installation of double-conversion: + +``` +sudo ln -s ~/folly-build/installed/double-conversion-skGL6pOaPHjtDwdXY-agzdwT1gvTXP0bD-7P4gKJD9I/lib/libdouble-conversion.so.3.1.4 \ +libdouble-conversion.so.3 +``` + +**Two installations of glog** +``` +~/torchrec/torchrec/inference/build$ ./example +ERROR: flag 'logtostderr' was defined more than once (in files '/home/shabab/glog/src/logging.cc' and +'/home/shabab/folly-build/extracted/glog-v0.4.0.tar.gz/glog-0.4.0/src/logging.cc'). +``` +The above issue, along with a host of others during building, can potentially occur if libinference is pointing to two different versions of glog (if one was +previously installed in your system). You can find this out by running `ldd` on your libinference shared object within the build path. The issue can be solved by using the glog version provided by folly. + +To use the glog version provided by folly, add the glog install path (in your folly-build directory) to your LD_LIBRARY_PATH much like step 2. + +**Undefined symbols with std::string or cxx11** + +If you get undefined symbol errors and the errors mention `std::string` or `cxx11`, it's likely +that your dependencies were compiled with different ABI values. Re-compile your dependencies +and ensure they all have the same value for `_GLIBCXX_USE_CXX11_ABI` in their build. + +The ABI value of pytorch can be checked via: + +``` +import torch +torch._C._GLIBCXX_USE_CXX11_ABI +``` diff --git a/torchrec/inference/inference_legacy/__init__.py b/torchrec/inference/inference_legacy/__init__.py new file mode 100644 index 000000000..670f2af78 --- /dev/null +++ b/torchrec/inference/inference_legacy/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Torchrec Inference + +Torchrec inference provides a Torch.Deploy based library for GPU inference. + +These includes: + - Model packaging in Python + - `PredictModule` and `PredictFactory` are the contracts between the Python model authoring and the C++ model serving. + - `PredictFactoryPackager` can be used to package a PredictFactory class using torch.package. 
+ - Model serving in C++ + - `BatchingQueue` is a generalized config-based request tensor batching implementation. + - `GPUExecutor` handles the forward call into the inference model inside Torch.Deploy. + +We implemented an example of how to use this library with the TorchRec DLRM model. + - `examples/dlrm/inference/dlrm_packager.py`: this demonstrates how to export the DLRM model as a torch.package. + - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model. +""" + +from . import model_packager, modules # noqa # noqa diff --git a/torchrec/inference/inference_legacy/client.py b/torchrec/inference/inference_legacy/client.py new file mode 100644 index 000000000..a3a9f2a83 --- /dev/null +++ b/torchrec/inference/inference_legacy/client.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# @nolint +# pyre-ignore-all-errors + + +import argparse +import logging + +import grpc +import torch +from gen.torchrec.inference import predictor_pb2, predictor_pb2_grpc +from torch.utils.data import DataLoader +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.datasets.random import RandomRecDataset +from torchrec.datasets.utils import Batch + + +def create_training_batch(args: argparse.Namespace) -> Batch: + return next( + iter( + DataLoader( + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=args.num_embedding_features, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ), + batch_sampler=None, + pin_memory=False, + num_workers=0, + ) + ) + ) + + +def create_request( + batch: Batch, args: argparse.Namespace +) -> predictor_pb2.PredictionRequest: + def to_bytes(tensor: torch.Tensor) -> bytes: + return tensor.cpu().numpy().tobytes() + + float_features = predictor_pb2.FloatFeatures( + num_features=args.num_float_features, + values=to_bytes(batch.dense_features), + ) + + id_list_features = predictor_pb2.SparseFeatures( + num_features=args.num_id_list_features, + values=to_bytes(batch.sparse_features.values()), + lengths=to_bytes(batch.sparse_features.lengths()), + ) + + id_score_list_features = predictor_pb2.SparseFeatures(num_features=0) + embedding_features = predictor_pb2.FloatFeatures(num_features=0) + unary_features = predictor_pb2.SparseFeatures(num_features=0) + + return predictor_pb2.PredictionRequest( + batch_size=args.batch_size, + float_features=float_features, + id_list_features=id_list_features, + id_score_list_features=id_score_list_features, + embedding_features=embedding_features, + unary_features=unary_features, + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--ip", + type=str, + default="0.0.0.0", + ) + parser.add_argument( + "--port", + type=int, + default=50051, + ) + parser.add_argument( + "--num_float_features", + type=int, + default=13, + ) + parser.add_argument( + "--num_id_list_features", + type=int, + default=26, + ) + parser.add_argument( + "--num_id_score_list_features", + type=int, + default=0, + ) + parser.add_argument( + "--num_embedding_features", + type=int, + default=100000, + ) + parser.add_argument( + "--embedding_feature_dim", + type=int, + default=100, + ) + parser.add_argument( + "--batch_size", + type=int, + 
default=100, + ) + + args: argparse.Namespace = parser.parse_args() + + training_batch: Batch = create_training_batch(args) + request: predictor_pb2.PredictionRequest = create_request(training_batch, args) + + with grpc.insecure_channel(f"{args.ip}:{args.port}") as channel: + stub = predictor_pb2_grpc.PredictorStub(channel) + response = stub.Predict(request) + print("Response: ", response.predictions["default"].data) + +if __name__ == "__main__": + logging.basicConfig() diff --git a/torchrec/inference/docker/Dockerfile b/torchrec/inference/inference_legacy/docker/Dockerfile similarity index 100% rename from torchrec/inference/docker/Dockerfile rename to torchrec/inference/inference_legacy/docker/Dockerfile diff --git a/torchrec/inference/docs/inference.rst b/torchrec/inference/inference_legacy/docs/inference.rst similarity index 100% rename from torchrec/inference/docs/inference.rst rename to torchrec/inference/inference_legacy/docs/inference.rst diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h new file mode 100644 index 000000000..26e7987f9 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Assert.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#define TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(condition, message) \ + if (!(condition)) { \ + throw std::runtime_error( \ + "Internal Assertion failed: (" + std::string(#condition) + "), " + \ + "function " + __FUNCTION__ + ", file " + __FILE__ + ", line " + \ + std::to_string(__LINE__) + ".\n" + \ + "Please report bug to TorchRec.\n" + message + "\n"); \ + } + +#define TORCHREC_INTERNAL_ASSERT_NO_MESSAGE(condition) \ + TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(#condition, "") + +#define TORCHREC_INTERNAL_ASSERT_(x, condition, message, FUNC, ...) FUNC + +#define TORCHREC_INTERNAL_ASSERT(...) \ + TORCHREC_INTERNAL_ASSERT_( \ + , \ + ##__VA_ARGS__, \ + TORCHREC_INTERNAL_ASSERT_WITH_MESSAGE(__VA_ARGS__), \ + TORCHREC_INTERNAL_ASSERT_NO_MESSAGE(__VA_ARGS__)); + +#define TORCHREC_CHECK_WITH_MESSAGE(condition, message) \ + if (!(condition)) { \ + throw std::runtime_error( \ + "Check failed: (" + std::string(#condition) + "), " + "function " + \ + __FUNCTION__ + ", file " + __FILE__ + ", line " + \ + std::to_string(__LINE__) + ".\n" + message + "\n"); \ + } + +#define TORCHREC_CHECK_NO_MESSAGE(condition) \ + TORCHREC_CHECK_WITH_MESSAGE(#condition, "") + +#define TORCHREC_CHECK_(x, condition, message, FUNC, ...) FUNC + +#define TORCHREC_CHECK(...) \ + TORCHREC_CHECK_( \ + , \ + ##__VA_ARGS__, \ + TORCHREC_CHECK_WITH_MESSAGE(__VA_ARGS__), \ + TORCHREC_CHECK_NO_MESSAGE(__VA_ARGS__)); diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h new file mode 100644 index 000000000..2c5e8837c --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Batching.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include "torchrec/inference/JaggedTensor.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +using LazyTensorRef = folly::detail::Lazy>&; + +// BatchingFunc should be responsible to move the output tensor to desired +// location using the device input. +class BatchingFunc { + public: + virtual ~BatchingFunc() = default; + + virtual std::unordered_map batch( + const std::string& /* featureName */, + const std::vector>& /* requests */, + const int64_t& /* totalNumBatch */, + LazyTensorRef /* batchOffsets */, + const c10::Device& /* device */, + LazyTensorRef /* batchItems */) = 0; +}; + +/** + * TorchRecBatchingFuncRegistry is used to register custom batching functions. + */ +C10_DECLARE_REGISTRY(TorchRecBatchingFuncRegistry, BatchingFunc); + +#define REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY(name, priority, ...) \ + C10_REGISTER_CLASS_WITH_PRIORITY( \ + TorchRecBatchingFuncRegistry, name, priority, __VA_ARGS__); + +#define REGISTER_TORCHREC_BATCHING_FUNC(name, ...) \ + REGISTER_TORCHREC_BATCHING_FUNC_WITH_PIORITY( \ + name, c10::REGISTRY_DEFAULT, __VA_ARGS__); + +std::unordered_map combineFloat( + const std::string& featureName, + const std::vector>& requests); + +std::unordered_map combineSparse( + const std::string& featureName, + const std::vector>& requests, + bool isWeighted); + +std::unordered_map combineEmbedding( + const std::string& featureName, + const std::vector>& requests); + +void moveIValueToDevice(c10::IValue& val, const c10::Device& device); + +std::unordered_map moveToDevice( + std::unordered_map combined, + const c10::Device& device); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h b/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h new file mode 100644 index 000000000..f012572b2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/BatchingQueue.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include // @manual +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "torchrec/inference/Batching.h" +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/ResourceManager.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +using BatchQueueCb = std::function)>; + +class BatchingQueue { + public: + struct Config { + std::chrono::milliseconds batchingInterval = std::chrono::milliseconds(10); + std::chrono::milliseconds queueTimeout = std::chrono::milliseconds(500); + int numExceptionThreads = 4; + int numMemPinnerThreads = 4; + int maxBatchSize = 2000; + // For feature name to BatchingFunc name. 
+ const std::unordered_map batchingMetadata; + std::function eventCreationFn; + std::function warmupFn; + }; + + BatchingQueue(const BatchingQueue&) = delete; + BatchingQueue& operator=(const BatchingQueue&) = delete; + + BatchingQueue( + std::vector cbs, + const Config& config, + int worldSize, + std::unique_ptr observer, + std::shared_ptr resourceManager = nullptr); + ~BatchingQueue(); + + void add( + std::shared_ptr request, + folly::Promise> promise); + + void stop(); + + private: + struct QueryQueueEntry { + std::shared_ptr request; + RequestContext context; + std::chrono::time_point addedTime; + }; + + struct BatchingQueueEntry { + std::vector> requests; + std::vector contexts; + std::chrono::time_point addedTime; + }; + + void createBatch(); + + void pinMemory(int gpuIdx); + + void observeBatchCompletion(size_t batchSizeBytes, size_t numRequests); + + const Config config_; + + // Batching func name to batching func instance. + std::unordered_map> batchingFuncs_; + std::vector cbs_; + std::thread batchingThread_; + std::vector memPinnerThreads_; + std::unique_ptr rejectionExecutor_; + folly::Synchronized> requestQueue_; + std::vector>> + batchingQueues_; + std::atomic stopping_; + int worldSize_; + std::unique_ptr observer_; + std::shared_ptr resourceManager_; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h new file mode 100644 index 000000000..5667d1ad2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Exception.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include + +namespace torchrec { + +// We have different error code defined for different kinds of exceptions in +// fblearner/sigrid predictor. (Code pointer: +// fblearner/predictor/if/prediction_service.thrift.) We define different +// exception type here so that in fblearner/sigrid predictor we can detect the +// exception type and return the corresponding error code to reflect the right +// info. 
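+//
+// Illustrative sketch of how a caller might map these exception types back to
+// error codes (the handler shape, runPrediction, and the fallback helper are
+// assumptions, not part of this header):
+//
+//   try {
+//     runPrediction(request);
+//   } catch (const torchrec::GPUOverloadException& ex) {
+//     return PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT;
+//   } catch (const torchrec::TorchrecException& ex) {
+//     return genericPredictionErrorCode(ex);  // hypothetical fallback
+//   }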
+class TorchrecException : public std::runtime_error { + public: + explicit TorchrecException(const std::string& error) + : std::runtime_error(error) {} +}; + +// GPUOverloadException maps to +// PredictionExceptionCode::GPU_BATCHING_QUEUE_TIMEOUT +class GPUOverloadException : public TorchrecException { + public: + explicit GPUOverloadException(const std::string& error) + : TorchrecException(error) {} +}; + +// GPUExecutorOverloadException maps to +// PredictionExceptionCode::GPU_EXECUTOR_QUEUE_TIMEOUT +class GPUExecutorOverloadException : public TorchrecException { + public: + explicit GPUExecutorOverloadException(const std::string& error) + : TorchrecException(error) {} +}; + +// TorchDeployException maps to +// PredictorUserErrorCode::TORCH_DEPLOY_ERROR +class TorchDeployException : public TorchrecException { + public: + explicit TorchDeployException(const std::string& error) + : TorchrecException(error) {} +}; +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h new file mode 100644 index 000000000..491acf48f --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ExceptionHandler.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +#include "torchrec/inference/Exception.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { +template +void handleRequestException( + folly::Promise>& promise, + const std::string& msg) { + auto ex = folly::make_exception_wrapper(msg); + auto response = std::make_unique(); + response->exception = std::move(ex); + promise.setValue(std::move(response)); +} + +template +void handleBatchException( + std::vector& contexts, + const std::string& msg) { + for (auto& context : contexts) { + handleRequestException(context.promise, msg); + } +} + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h new file mode 100644 index 000000000..b55dacc5a --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/GPUExecutor.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// remove this after we switch over to multipy externally for torchrec +#ifdef FBCODE_CAFFE2 +#include // @manual +#else +#include // @manual +#endif + +#include "torchrec/inference/BatchingQueue.h" +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/ResultSplit.h" +#include "torchrec/inference/include/torchrec/inference/Observer.h" + +namespace torchrec { + +class GPUExecutor { + public: + // Used to interface with python's garbage collector + struct GCConfig { + bool optimizationEnabled = false; + size_t collectionFreq = 1000; + size_t statReportingFreq = 10000; + std::unique_ptr observer = + std::make_unique(); + std::map threadIdToNumForwards = std::map(); + }; + + GPUExecutor( + std::shared_ptr manager, + torch::deploy::ReplicatedObj model, + size_t rank, + size_t worldSize, + std::shared_ptr func, + std::chrono::milliseconds queueTimeout, + std::shared_ptr + observer, // shared_ptr because used in completion executor callback + std::function warmupFn = {}, + std::optional numThreadsPerGPU = c10::nullopt, + std::unique_ptr gcConfig = std::make_unique()); + GPUExecutor(GPUExecutor&& executor) noexcept = default; + GPUExecutor& operator=(GPUExecutor&& executor) noexcept = default; + ~GPUExecutor(); + + void callback(std::shared_ptr batch); + + void process(int idx); + + private: + // torch deploy + std::shared_ptr manager_; + torch::deploy::ReplicatedObj model_; + const size_t rank_; + const size_t worldSize_; + + folly::MPMCQueue> batches_; + std::vector processThreads_; + std::unique_ptr rejectionExecutor_; + std::unique_ptr completionExecutor_; + std::shared_ptr resultSplitFunc_; + const std::chrono::milliseconds queueTimeout_; + std::shared_ptr observer_; + std::function warmupFn_; + + std::mutex warmUpMutex_; + std::mutex warmUpAcquireSessionMutex_; + std::condition_variable warmUpCV_; + int warmUpCounter_{0}; + + size_t numThreadsPerGPU_; + + std::unique_ptr gcConfig_; + + void reportGCStats(c10::IValue stats); +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h new file mode 100644 index 000000000..de00aebac --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/JaggedTensor.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +namespace torchrec { + +struct JaggedTensor { + at::Tensor lengths; + at::Tensor values; + at::Tensor weights; +}; + +struct KeyedJaggedTensor { + std::vector keys; + at::Tensor lengths; + at::Tensor values; + at::Tensor weights; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h new file mode 100644 index 000000000..14ac3bceb --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Observer.h @@ -0,0 +1,280 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace torchrec { + +// Record generic timeseries stat with a key +class IDynamicTimeseriesObserver { + public: + virtual void addCount(uint32_t value, std::string key) = 0; + + virtual ~IDynamicTimeseriesObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. +class EmptyDynamicTimeseriesObserver : public IDynamicTimeseriesObserver { + public: + void addCount(uint32_t /* value */, std::string /* key */) override {} +}; + +class IBatchingQueueObserver { + public: + // Record the amount of time an entry of PredictionRequests + // in the batching queue waits before they are read and allocated + // onto a GPU device. + virtual void recordBatchingQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the amount of time it takes for a batching function + // to execute. + virtual void recordBatchingFuncLatency( + uint32_t value, + std::string batchingFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the amount of time it takes to create a batch of + // requests. + virtual void recordBatchCreationLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Increment the number of batching queue timeouts experienced. + virtual void addBatchingQueueTimeoutCount(uint32_t value) = 0; + + // Increment the number of times a GPU could not be chosen + // for allocation. + virtual void addGPUBusyCount(uint32_t value) = 0; + + // Increment the number of requests entering the batching queue. + virtual void addRequestsCount(uint32_t value) = 0; + + // Increment the number of bytes of tensors moved to cuda. + virtual void addBytesMovedToGPUCount(uint32_t value) = 0; + + // Increment the number of batches processed by the batching + // queue (moved onto the GPU executor). + virtual void addBatchesProcessedCount(uint32_t value) = 0; + + // Increment the number of requests processed by the batching + // queue (moved onto the GPU executor). + virtual void addRequestsProcessedCount(uint32_t value) = 0; + + // The obervations that should be made when a batch is completed. + virtual void observeBatchCompletion( + size_t batchSizeBytes, + size_t numRequests) { + addBytesMovedToGPUCount(batchSizeBytes); + addBatchesProcessedCount(1); + addRequestsProcessedCount(numRequests); + } + + virtual ~IBatchingQueueObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. 
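+// For example (illustrative; assumes the BatchingQueue constructor declared in
+// BatchingQueue.h):
+//   BatchingQueue queue(
+//       std::move(cbs), config, worldSize,
+//       std::make_unique<EmptyBatchingQueueObserver>());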
+class EmptyBatchingQueueObserver : public IBatchingQueueObserver { + public: + void recordBatchingQueueLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordBatchingFuncLatency( + uint32_t /* value */, + std::string /* batchingFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordBatchCreationLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void addBatchingQueueTimeoutCount(uint32_t /* value */) override {} + + void addGPUBusyCount(uint32_t /* value */) override {} + + void addRequestsCount(uint32_t /* value */) override {} + + void addBytesMovedToGPUCount(uint32_t /* value */) override {} + + void addBatchesProcessedCount(uint32_t /* value */) override {} + + void addRequestsProcessedCount(uint32_t /* value */) override {} +}; + +class IGPUExecutorObserver { + public: + // Record the amount of time a batch spends in the GPU Executor + // queue. + virtual void recordQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of prediction (forward call, H2D). + virtual void recordPredictionLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of device to host transfer facilitated + // by result split function. + virtual void recordDeviceToHostLatency( + uint32_t value, + std::string resultSplitFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency of splitting the result. + virtual void recordResultSplitLatency( + uint32_t value, + std::string resultSplitFuncName, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Record the latency from enqueue to completion. + virtual void recordTotalLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Increment the number of GPUExecutor queue timeouts. + virtual void addQueueTimeoutCount(uint32_t value) = 0; + + // Increment the number of predict exceptions. + virtual void addPredictionExceptionCount(uint32_t value) = 0; + + // Increment the number of batches successfully processed. + virtual void addBatchesProcessedCount(uint32_t value) = 0; + + virtual ~IGPUExecutorObserver() {} +}; + +class ISingleGPUExecutorObserver { + public: + virtual void addRequestsCount(uint32_t value) = 0; + virtual void addRequestProcessingExceptionCount(uint32_t value) = 0; + virtual void recordQueueLatency( + uint32_t value, + std::chrono::steady_clock::time_point = + std::chrono::steady_clock::now()) = 0; + + virtual void recordRequestProcessingLatency( + uint32_t value, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + virtual ~ISingleGPUExecutorObserver() = default; +}; + +class EmptySingleGPUExecutorObserver : public ISingleGPUExecutorObserver { + void addRequestsCount(uint32_t) override {} + void addRequestProcessingExceptionCount(uint32_t) override {} + void recordQueueLatency( + uint32_t, + std::chrono::steady_clock::time_point = + std::chrono::steady_clock::now()) override {} + + void recordRequestProcessingLatency( + uint32_t, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) override {} +}; + +// Can be used for testing or for opt-ing out of observation. 
+class EmptyGPUExecutorObserver : public IGPUExecutorObserver { + public: + void recordQueueLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordPredictionLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordDeviceToHostLatency( + uint32_t /* value */, + std::string /* resultSplitFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordResultSplitLatency( + uint32_t /* value */, + std::string /* resultSplitFuncName */, + std::chrono::steady_clock::time_point /* now */) override {} + + void recordTotalLatency( + uint32_t /* value */, + std::chrono::steady_clock::time_point /* now */) override {} + + void addQueueTimeoutCount(uint32_t /* value */) override {} + + void addPredictionExceptionCount(uint32_t /* value */) override {} + + void addBatchesProcessedCount(uint32_t /* value */) override {} +}; + +class IResourceManagerObserver { + public: + // Add the number of requests in flight for a gpu + virtual void addOutstandingRequestsCount(uint32_t value, int gpuIdx) = 0; + + // Add the most in flight requests on a gpu ever + virtual void addAllTimeHighOutstandingCount(uint32_t value, int gpuIdx) = 0; + + // Record the latency for finding a device + virtual void addWaitingForDeviceLatency( + uint32_t value, + int gpuIdx, + std::chrono::steady_clock::time_point now = + std::chrono::steady_clock::now()) = 0; + + // Recording all stats related to resource manager at once. + virtual void recordAllStats( + uint32_t outstandingRequests, + uint32_t allTimeHighOutstanding, + uint32_t waitedForMs, + int gpuIdx) { + addOutstandingRequestsCount(outstandingRequests, gpuIdx); + addAllTimeHighOutstandingCount(allTimeHighOutstanding, gpuIdx); + addWaitingForDeviceLatency(waitedForMs, gpuIdx); + } + + virtual ~IResourceManagerObserver() {} +}; + +// Can be used for testing or for opt-ing out of observation. +class EmptyResourceManagerObserver : public IResourceManagerObserver { + public: + void addOutstandingRequestsCount(uint32_t /* value */, int /* gpuIdx */) + override {} + + void addAllTimeHighOutstandingCount(uint32_t /* value */, int /* gpuIdx */) + override {} + + void addWaitingForDeviceLatency( + uint32_t /* value */, + int /* gpuIdx */, + std::chrono::steady_clock::time_point /* now */) override {} +}; + +// Helper for determining how much time has elapsed in milliseconds since a +// given time point. +inline std::chrono::milliseconds getTimeElapsedMS( + std::chrono::steady_clock::time_point startTime) { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - startTime); +} + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h new file mode 100644 index 000000000..d3dd1ea18 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ResourceManager.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "torchrec/inference/Observer.h" + +namespace torchrec { + +/** + * ResourceManager can be used to limit in-flight batches + * allocated onto GPUs to prevent OOMing. 
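+ *
+ * Rough usage sketch (illustrative only; the maxOutstandingBatches and slack
+ * values are assumptions):
+ *
+ *   auto rm = std::make_shared<ResourceManager>(worldSize, 8);
+ *   if (rm->occupyDevice(gpuIdx, std::chrono::milliseconds(50))) {
+ *     // ... run the batch on gpuIdx ...
+ *     rm->release(gpuIdx);  // or rely on ResourceManagerGuard's destructor
+ *   }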
+ */ +class ResourceManager { + public: + ResourceManager( + int worldSize, + size_t maxOutstandingBatches, + int logFrequency = 100, + std::unique_ptr observer = + std::make_unique()); + + // Returns whether batches can be allocated onto a device based on + // slack provided (ms) and maxOutstandingBatches_). + bool occupyDevice(int gpuIdx, std::chrono::milliseconds slack); + + void release(int gpuIdx); + + private: + folly::small_vector gpuToOutstandingBatches_; + // Helpful for tuning + folly::small_vector allTimeHigh_; + const size_t maxOutstandingBatches_; + const int logFrequency_; + // Align as 64B to avoid false sharing + alignas(64) std::mutex mu_; + std::unique_ptr observer_; +}; + +class ResourceManagerGuard { + public: + ResourceManagerGuard( + std::weak_ptr resourceManager, + int gpuIdx) + : resourceManager_(std::move(resourceManager)), gpuIdx_(gpuIdx) {} + + ~ResourceManagerGuard() { + std::shared_ptr rm = resourceManager_.lock(); + if (rm != nullptr) { + rm->release(gpuIdx_); + } + } + + private: + std::weak_ptr resourceManager_; + const int gpuIdx_; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h new file mode 100644 index 000000000..2c3ef2463 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ResultSplit.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace torchrec { + +class ResultSplitFunc { + public: + virtual ~ResultSplitFunc() = default; + + virtual std::string name() = 0; + + virtual c10::IValue splitResult( + c10::IValue /* result */, + size_t /* nOffset */, + size_t /* nLength */, + size_t /* nTotalLength */) = 0; + + virtual c10::IValue moveToHost(c10::IValue /* result */) = 0; +}; + +/** + * TorchRecResultSplitFuncRegistry is used to register custom result split + * functions. + */ +C10_DECLARE_REGISTRY(TorchRecResultSplitFuncRegistry, ResultSplitFunc); + +#define REGISTER_TORCHREC_RESULTSPLIT_FUNC(name, ...) \ + C10_REGISTER_CLASS(TorchRecResultSplitFuncRegistry, name, __VA_ARGS__); + +c10::IValue splitDictOfTensor( + c10::IValue result, + size_t nOffset, + size_t nLength, + size_t nTotalLength); + +c10::IValue splitDictOfTensors( + c10::IValue result, + size_t nOffset, + size_t nLength, + size_t nTotalLength); + +c10::IValue +splitDictWithMaskTensor(c10::IValue result, size_t nOffset, size_t nLength); + +class DictWithMaskTensorResultSplitFunc : public torchrec::ResultSplitFunc { + public: + virtual std::string name() override; + + virtual c10::IValue splitResult( + c10::IValue result, + size_t offset, + size_t length, + size_t /* nTotalLength */) override; + + c10::IValue moveToHost(c10::IValue result) override; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h new file mode 100644 index 000000000..478773768 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/ShardedTensor.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace torchrec { + +struct ShardMetadata { + std::vector shard_offsets; + std::vector shard_lengths; + + bool operator==(const ShardMetadata& other) const { + return shard_offsets == other.shard_offsets && + shard_lengths == other.shard_lengths; + } +}; + +struct Shard { + ShardMetadata metadata; + at::Tensor tensor; +}; + +struct ShardedTensorMetadata { + std::vector shards_metadata; +}; + +struct ShardedTensor { + std::vector sizes; + std::vector local_shards; + ShardedTensorMetadata metadata; +}; + +struct ReplicatedTensor { + ShardedTensor local_replica; + int64_t local_replica_id; + int64_t replica_count; +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h b/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h new file mode 100644 index 000000000..9da63d7c2 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/SingleGPUExecutor.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include "torchrec/inference/Observer.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +class SingleGPUExecutor { + constexpr static const size_t kQUEUE_CAPACITY = 10000; + + public: + struct ExecInfo { + size_t gpuIdx; + size_t interpIdx; + torch::deploy::ReplicatedObj model; + }; + using ExecInfos = std::vector; + + SingleGPUExecutor( + std::shared_ptr manager, + ExecInfos execInfos, + size_t numGpu, + std::shared_ptr observer = + std::make_shared(), + c10::Device resultDevice = c10::kCPU, + size_t numProcessThreads = 1u, + bool useHighPriCudaStream = false); + + // Moveable only + SingleGPUExecutor(SingleGPUExecutor&& executor) noexcept = default; + SingleGPUExecutor& operator=(SingleGPUExecutor&& executor) noexcept = default; + ~SingleGPUExecutor(); + + void schedule(std::shared_ptr request); + + private: + void process(); + + std::shared_ptr manager_; + const ExecInfos execInfos_; + const size_t numGpu_; + const size_t numProcessThreads_; + const bool useHighPriCudaStream_; + const c10::Device resultDevice_; + std::shared_ptr observer_; + folly::MPMCQueue> requests_; + + std::unique_ptr processExecutor_; + std::unique_ptr completionExecutor_; + std::atomic roundRobinExecInfoNextIdx_; +}; +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h b/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h new file mode 100644 index 000000000..1150e2c23 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/TestUtils.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +#include + +#include "torchrec/inference/JaggedTensor.h" +#include "torchrec/inference/Types.h" + +namespace torchrec { + +std::shared_ptr createRequest(at::Tensor denseTensor); + +std::shared_ptr +createRequest(size_t batchSize, size_t numFeatures, const JaggedTensor& jagged); + +std::shared_ptr +createRequest(size_t batchSize, size_t numFeatures, at::Tensor embedding); + +JaggedTensor createJaggedTensor(const std::vector>& input); + +c10::List createIValueList( + const std::vector>& input); + +at::Tensor createEmbeddingTensor( + const std::vector>& input); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h new file mode 100644 index 000000000..0a09d1f35 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Types.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "torchrec/inference/ResourceManager.h" + +namespace torchrec { + +struct SparseFeatures { + uint32_t num_features; + // int32: T x B + folly::IOBuf lengths; + // T x B x L (jagged) + folly::IOBuf values; + // float16 + folly::IOBuf weights; +}; + +struct FloatFeatures { + uint32_t num_features; + // shape: {B} + folly::IOBuf values; +}; + +// TODO: Change the input format to torch::IValue. +// Currently only dense batching function support IValue. +using Feature = std::variant; + +struct PredictionRequest { + uint32_t batch_size; + std::unordered_map features; +}; + +struct PredictionResponse { + uint32_t batchSize; + c10::IValue predictions; + // If set, the result is an exception. + std::optional exception; +}; + +struct RequestContext { + uint32_t batchSize; + folly::Promise> promise; + // folly request context for request tracking in crochet + std::shared_ptr follyRequestContext; +}; + +using PredictionException = std::runtime_error; + +using Event = std:: + unique_ptr>; + +struct BatchingMetadata { + std::string type; + std::string device; + folly::F14FastSet pinned; +}; + +// noncopyable because we only want to move PredictionBatch around +// as it holds a reference to ResourceManagerGuard. We wouldn't want +// to inadvertently increase the reference count to ResourceManagerGuard +// with copies of this struct. +struct PredictionBatch : public boost::noncopyable { + std::string methodName; + std::vector args; + + size_t batchSize; + + c10::impl::GenericDict forwardArgs; + + std::vector contexts; + + std::unique_ptr resourceManagerGuard = nullptr; + + std::chrono::time_point enqueueTime = + std::chrono::steady_clock::now(); + + Event event; + + // Need a constructor to use make_shared/unique with + // noncopyable struct and not trigger copy-constructor. 
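+  // e.g. (illustrative; the argument names are placeholders):
+  //   auto batch = std::make_shared<PredictionBatch>(
+  //       batchSize, std::move(forwardArgs), std::move(contexts));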
+ PredictionBatch( + size_t bs, + c10::impl::GenericDict fa, + std::vector ctxs, + std::unique_ptr rmg = nullptr) + : batchSize(bs), + forwardArgs(std::move(fa)), + contexts(std::move(ctxs)), + resourceManagerGuard(std::move(rmg)) {} + + PredictionBatch( + std::string methodNameArg, + std::vector argsArg, + folly::Promise> promise) + : methodName(std::move(methodNameArg)), + args(std::move(argsArg)), + forwardArgs( + c10::impl::GenericDict(at::StringType::get(), at::AnyType::get())) { + contexts.push_back(RequestContext{1u, std::move(promise)}); + } + + size_t sizeOfIValue(const c10::IValue& val) const { + size_t size = 0; + if (val.isTensor()) { + size += val.toTensor().storage().nbytes(); + } else if (val.isList()) { + for (const auto& v : val.toListRef()) { + size += sizeOfIValue(v); + } + } + return size; + } + + inline size_t size() const { + size_t size = 0; + for (const auto& iter : forwardArgs) { + size += sizeOfIValue(iter.value()); + } + return size; + } +}; + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h b/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h new file mode 100644 index 000000000..4844c9ca1 --- /dev/null +++ b/torchrec/inference/inference_legacy/include/torchrec/inference/Validation.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "torchrec/inference/Types.h" + +namespace torchrec { + +// Returns whether sparse features (KeyedJaggedTensor) are valid. +// Currently validates: +// 1. Whether sum(lengths) == size(values) +// 2. Whether there are negative values in lengths +// 3. If weights is present, whether sum(lengths) == size(weights) +bool validateSparseFeatures( + at::Tensor& values, + at::Tensor& lengths, + std::optional maybeWeights = c10::nullopt); + +// Returns whether dense features are valid. +// Currently validates: +// 1. Whether the size of values is divisable by batch size (request level) +bool validateDenseFeatures(at::Tensor& values, size_t batchSize); + +} // namespace torchrec diff --git a/torchrec/inference/inference_legacy/model_packager.py b/torchrec/inference/inference_legacy/model_packager.py new file mode 100644 index 000000000..9957b3373 --- /dev/null +++ b/torchrec/inference/inference_legacy/model_packager.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +from pathlib import Path +from typing import Any, BinaryIO, Dict, List, Type, TypeVar, Union + +import torch +from torch.package import PackageExporter +from torchrec.inference.modules import PredictFactory + +LOADER_MODULE = "__module_loader" +LOADER_FACTORY = "MODULE_FACTORY" +LOADER_CODE = f""" +import %PACKAGE% + +{LOADER_FACTORY}=%PACKAGE%.%CLASS% +""" +CONFIG_MODULE = "__configs" + +T = TypeVar("T") + + +try: + # pyre-fixme[21]: Could not find module `torch_package_importer`. 
+ import torch_package_importer # @manual +except ImportError: + pass + + +def load_config_text(name: str) -> str: + return torch_package_importer.load_text("__configs", name) + + +def load_pickle_config(name: str, clazz: Type[T]) -> T: + loaded_obj = torch_package_importer.load_pickle("__configs", name) + assert isinstance( + loaded_obj, clazz + ), f"The loaded config {type(loaded_obj)} is not of type {clazz}" + return loaded_obj + + +class PredictFactoryPackager: + @classmethod + @abc.abstractclassmethod + def set_extern_modules(cls, pe: PackageExporter) -> None: + pass + + @classmethod + @abc.abstractclassmethod + def set_mocked_modules(cls, pe: PackageExporter) -> None: + pass + + @classmethod + def save_predict_factory( + cls, + predict_factory: Type[PredictFactory], + configs: Dict[str, Any], + output: Union[str, Path, BinaryIO], + extra_files: Dict[str, Union[str, bytes]], + loader_code: str = LOADER_CODE, + package_importer: Union[ + torch.package.Importer, List[torch.package.Importer] + ] = torch.package.sys_importer, + ) -> None: + with PackageExporter(output, importer=package_importer) as pe: + # pyre-fixme[29]: `BoundMethod[abc.abstractclassmethod[None], + # Type[PredictFactoryPackager]]` is not a function. + cls.set_extern_modules(pe) + # pyre-fixme[29]: `BoundMethod[abc.abstractclassmethod[None], + # Type[PredictFactoryPackager]]` is not a function. + cls.set_mocked_modules(pe) + pe.extern(["sys"]) + pe.intern("**") + for k, v in extra_files.items(): + if isinstance(v, str): + pe.save_text("extra_files", k, v) + elif isinstance(v, bytes): + pe.save_binary("extra_files", k, v) + else: + raise ValueError(f"Unsupported type {type(v)}") + cls._save_predict_factory( + pe, predict_factory, configs, loader_code=loader_code + ) + + @classmethod + def _save_predict_factory( + cls, + pe: PackageExporter, + predict_factory: Type[PredictFactory], + configs: Dict[str, Any], + loader_code: str = LOADER_CODE, + ) -> None: + # If predict_factory is coming from a torch package, + # __module__ would have prefix. + # To save such predict factory, we need to remove + # the prefix. + package_name = predict_factory.__module__ + if package_name.startswith(" predictions = 1; +} + +// The predictor service definition. Synchronous for now. +service Predictor { + rpc Predict(PredictionRequest) returns (PredictionResponse) {} +} diff --git a/torchrec/inference/inference_legacy/server.cpp b/torchrec/inference/inference_legacy/server.cpp new file mode 100644 index 000000000..f7695cdcf --- /dev/null +++ b/torchrec/inference/inference_legacy/server.cpp @@ -0,0 +1,336 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +// remove this after we switch over to multipy externally for torchrec +#ifdef FBCODE_CAFFE2 +#include // @manual +#include +#else +#include +#include +#endif + +#include + +#include "torchrec/inference/GPUExecutor.h" +#include "torchrec/inference/predictor.grpc.pb.h" +#include "torchrec/inference/predictor.pb.h" + +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using predictor::FloatVec; +using predictor::PredictionRequest; +using predictor::PredictionResponse; +using predictor::Predictor; + +DEFINE_int32(n_interp_per_gpu, 1, ""); +DEFINE_int32(n_gpu, 1, ""); +DEFINE_string(package_path, "", ""); + +DEFINE_int32(batching_interval, 10, ""); +DEFINE_int32(queue_timeout, 500, ""); + +DEFINE_int32(num_exception_threads, 4, ""); +DEFINE_int32(num_mem_pinner_threads, 4, ""); +DEFINE_int32(max_batch_size, 2048, ""); +DEFINE_int32(gpu_executor_queue_timeout, 50, ""); + +DEFINE_string(server_address, "0.0.0.0", ""); +DEFINE_string(server_port, "50051", ""); + +DEFINE_string( + python_packages_path, + "", + "Used to load the packages that you 'extern' with torch.package"); + +namespace { + +std::unique_ptr toTorchRecRequest( + const PredictionRequest* request) { + auto torchRecRequest = std::make_unique(); + torchRecRequest->batch_size = request->batch_size(); + + // Client sends a request with serialized tensor to bytes. + // Byte string is converted to folly::iobuf for torchrec request. + + { + torchrec::FloatFeatures floatFeature; + + auto feature = request->float_features(); + auto encoded_values = feature.values(); + + floatFeature.num_features = feature.num_features(); + floatFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["float_features"] = std::move(floatFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->id_list_features(); + auto encoded_values = feature.values(); + auto encoded_lengths = feature.lengths(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["id_list_features"] = std::move(sparseFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->id_score_list_features(); + auto encoded_values = feature.values(); + auto encoded_lengths = feature.lengths(); + auto encoded_weights = feature.weights(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + sparseFeature.weights = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_weights.data(), + encoded_weights.size()}; + + torchRecRequest->features["id_score_list_features"] = + std::move(sparseFeature); + } + + { + torchrec::FloatFeatures floatFeature; + + auto feature = request->embedding_features(); + auto encoded_values = feature.values(); + + floatFeature.num_features = feature.num_features(); + floatFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + 
+ torchRecRequest->features["embedding_features"] = std::move(floatFeature); + } + + { + torchrec::SparseFeatures sparseFeature; + + auto feature = request->unary_features(); + auto encoded_lengths = feature.lengths(); + auto encoded_values = feature.values(); + + sparseFeature.num_features = feature.num_features(); + sparseFeature.lengths = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_lengths.data(), + encoded_lengths.size()}; + sparseFeature.values = folly::IOBuf{ + folly::IOBuf::COPY_BUFFER, + encoded_values.data(), + encoded_values.size()}; + + torchRecRequest->features["unary_features"] = std::move(sparseFeature); + } + + return torchRecRequest; +} + +// Logic behind the server's behavior. +class PredictorServiceHandler final : public Predictor::Service { + public: + explicit PredictorServiceHandler(torchrec::BatchingQueue& queue) + : queue_(queue) {} + + Status Predict( + grpc::ServerContext* context, + const PredictionRequest* request, + PredictionResponse* reply) override { + folly::Promise> promise; + auto future = promise.getSemiFuture(); + queue_.add(toTorchRecRequest(request), std::move(promise)); + auto torchRecResponse = + std::move(future).get(); // blocking, TODO: Write async server + auto predictions = reply->mutable_predictions(); + + // Convert ivalue to map, TODO: find out if protobuf + // can support custom types (folly::iobuf), so we can avoid this overhead. + for (const auto& item : torchRecResponse->predictions.toGenericDict()) { + auto tensor = item.value().toTensor(); + FloatVec fv; + fv.mutable_data()->Add( + tensor.data_ptr(), tensor.data_ptr() + tensor.numel()); + (*predictions)[item.key().toStringRef()] = fv; + } + + return Status::OK; + } + + private: + torchrec::BatchingQueue& queue_; +}; + +} // namespace + +int main(int argc, char* argv[]) { + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + LOG(INFO) << "Creating GPU executors"; + + // store the executors and interpreter managers + std::vector> executors; + std::vector models; + std::vector batchQueueCbs; + std::unordered_map batchingMetadataMap; + + std::shared_ptr env = + std::make_shared( + FLAGS_python_packages_path); + + auto manager = std::make_shared( + FLAGS_n_gpu * FLAGS_n_interp_per_gpu, env); + { + torch::deploy::Package package = manager->loadPackage(FLAGS_package_path); + auto I = package.acquireSession(); + auto imported = I.self.attr("import_module")({"__module_loader"}); + auto factoryType = imported.attr("MODULE_FACTORY"); + auto factory = factoryType.attr("__new__")({factoryType}); + factoryType.attr("__init__")({factory}); + + // Process forward metadata. + try { + auto batchingMetadataJsonStr = + factory.attr("batching_metadata_json")(at::ArrayRef()) + .toIValue() + .toString() + ->string(); + auto dynamic = folly::parseJson(batchingMetadataJsonStr); + CHECK(dynamic.isObject()); + for (auto it : dynamic.items()) { + torchrec::BatchingMetadata metadata; + metadata.type = it.second["type"].asString(); + metadata.device = it.second["device"].asString(); + batchingMetadataMap[it.first.asString()] = std::move(metadata); + } + } catch (...) { + auto batchingMetadata = + factory.attr("batching_metadata")(at::ArrayRef()) + .toIValue(); + for (const auto& iter : batchingMetadata.toGenericDict()) { + torchrec::BatchingMetadata metadata; + metadata.type = iter.value().toStringRef(); + metadata.device = "cuda"; + batchingMetadataMap[iter.key().toStringRef()] = std::move(metadata); + } + } + + // Process result metadata. 
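+    // (resultMetadata is expected to be a name previously registered with
+    // REGISTER_TORCHREC_RESULTSPLIT_FUNC, e.g. "dict_of_tensor"; the exact
+    // name depends on the packaged model and is an assumption here.)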
+ auto resultMetadata = + factory.attr("result_metadata")(at::ArrayRef()) + .toIValue() + .toStringRef(); + std::shared_ptr resultSplitFunc = + torchrec::TorchRecResultSplitFuncRegistry()->Create(resultMetadata); + + LOG(INFO) << "Creating Model Shard for " << FLAGS_n_gpu << " GPUs."; + auto dmp = factory.attr("create_predict_module") + .callKwargs({{"world_size", FLAGS_n_gpu}}); + + for (int rank = 0; rank < FLAGS_n_gpu; rank++) { + auto device = I.self.attr("import_module")({"torch"}).attr("device")( + {"cuda", rank}); + auto m = dmp.attr("copy")({device.toIValue()}); + models.push_back(I.createMovable(m)); + } + + for (int rank = 0; rank < FLAGS_n_gpu; rank++) { + auto executor = std::make_unique( + manager, + std::move(models[rank]), + rank, + FLAGS_n_gpu, + resultSplitFunc, + std::chrono::milliseconds(FLAGS_gpu_executor_queue_timeout)); + executors.push_back(std::move(executor)); + batchQueueCbs.push_back( + [&, rank](std::shared_ptr batch) { + executors[rank]->callback(std::move(batch)); + }); + } + } + + torchrec::BatchingQueue queue( + batchQueueCbs, + torchrec::BatchingQueue::Config{ + .batchingInterval = + std::chrono::milliseconds(FLAGS_batching_interval), + .queueTimeout = std::chrono::milliseconds(FLAGS_queue_timeout), + .numExceptionThreads = FLAGS_num_exception_threads, + .numMemPinnerThreads = FLAGS_num_mem_pinner_threads, + .maxBatchSize = FLAGS_max_batch_size, + .batchingMetadata = std::move(batchingMetadataMap), + }, + FLAGS_n_gpu); + + // create the server + std::string server_address(FLAGS_server_address + ":" + FLAGS_server_port); + auto service = PredictorServiceHandler(queue); + + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + grpc::ServerBuilder builder; + + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + LOG(INFO) << "Server listening on " << server_address; + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. 
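+  // (e.g. a signal handler or an admin thread calling server->Shutdown();
+  // such a shutdown path is not wired up in this example.)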
+ server->Wait(); + + LOG(INFO) << "Shutting down server"; + return 0; +} diff --git a/torchrec/inference/src/Batching.cpp b/torchrec/inference/inference_legacy/src/Batching.cpp similarity index 100% rename from torchrec/inference/src/Batching.cpp rename to torchrec/inference/inference_legacy/src/Batching.cpp diff --git a/torchrec/inference/src/BatchingQueue.cpp b/torchrec/inference/inference_legacy/src/BatchingQueue.cpp similarity index 100% rename from torchrec/inference/src/BatchingQueue.cpp rename to torchrec/inference/inference_legacy/src/BatchingQueue.cpp diff --git a/torchrec/inference/inference_legacy/src/Executer2.cpp b/torchrec/inference/inference_legacy/src/Executer2.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/inference/src/GPUExecutor.cpp b/torchrec/inference/inference_legacy/src/GPUExecutor.cpp similarity index 100% rename from torchrec/inference/src/GPUExecutor.cpp rename to torchrec/inference/inference_legacy/src/GPUExecutor.cpp diff --git a/torchrec/inference/src/ResourceManager.cpp b/torchrec/inference/inference_legacy/src/ResourceManager.cpp similarity index 100% rename from torchrec/inference/src/ResourceManager.cpp rename to torchrec/inference/inference_legacy/src/ResourceManager.cpp diff --git a/torchrec/inference/src/ResultSplit.cpp b/torchrec/inference/inference_legacy/src/ResultSplit.cpp similarity index 100% rename from torchrec/inference/src/ResultSplit.cpp rename to torchrec/inference/inference_legacy/src/ResultSplit.cpp diff --git a/torchrec/inference/src/SingleGPUExecutor.cpp b/torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp similarity index 100% rename from torchrec/inference/src/SingleGPUExecutor.cpp rename to torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp diff --git a/torchrec/inference/src/TestUtils.cpp b/torchrec/inference/inference_legacy/src/TestUtils.cpp similarity index 100% rename from torchrec/inference/src/TestUtils.cpp rename to torchrec/inference/inference_legacy/src/TestUtils.cpp diff --git a/torchrec/inference/src/Validation.cpp b/torchrec/inference/inference_legacy/src/Validation.cpp similarity index 100% rename from torchrec/inference/src/Validation.cpp rename to torchrec/inference/inference_legacy/src/Validation.cpp diff --git a/torchrec/inference/inference_legacy/state_dict_transform.py b/torchrec/inference/inference_legacy/state_dict_transform.py new file mode 100644 index 000000000..0379b1b80 --- /dev/null +++ b/torchrec/inference/inference_legacy/state_dict_transform.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Union + +import torch +from torch import distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor + + +def state_dict_gather( + src: Dict[str, Union[torch.Tensor, ShardedTensor]], + dst: Dict[str, torch.Tensor], +) -> None: + """ + Gathers the values of the src state_dict of the keys present in the dst state_dict. Can handle ShardedTensors in the src state_dict. 
+ + Args: + src (Dict[str, Union[torch.Tensor, ShardedTensor]]): source's state_dict for this rank + dst (Dict[str, torch.Tensor]): destination's state_dict + """ + for key, dst_tensor in dst.items(): + src_tensor = src[key] + if isinstance(src_tensor, ShardedTensor): + src_tensor.gather(out=dst_tensor if (dist.get_rank() == 0) else None) + elif isinstance(src_tensor, torch.Tensor): + dst_tensor.copy_(src_tensor) + else: + raise ValueError(f"Unsupported tensor {key} type {type(src_tensor)}") + + +def state_dict_all_gather_keys( + state_dict: Dict[str, Union[torch.Tensor, ShardedTensor]], + pg: ProcessGroup, +) -> List[str]: + """ + Gathers all the keys of the state_dict from all ranks. Can handle ShardedTensors in the state_dict. + + Args: + state_dict (Dict[str, Union[torch.Tensor, ShardedTensor]]): keys of this state_dict will be gathered + pg (ProcessGroup): Process Group used for comms + """ + names = list(state_dict.keys()) + all_names = [None] * dist.get_world_size(pg) + dist.all_gather_object(all_names, names, pg) + deduped_names = set() + for local_names in all_names: + # pyre-ignore[16] + for name in local_names: + deduped_names.add(name) + return sorted(deduped_names) + + +def state_dict_to_device( + state_dict: Dict[str, Union[torch.Tensor, ShardedTensor]], + pg: ProcessGroup, + device: torch.device, +) -> Dict[str, Union[torch.Tensor, ShardedTensor]]: + """ + Moves a state_dict to a device with a process group. Can handle ShardedTensors in the state_dict. + + Args: + state_dict (Dict[str, Union[torch.Tensor, ShardedTensor]]): state_dict to move + pg (ProcessGroup): Process Group used for comms + device (torch.device): device to put state_dict on + """ + ret = {} + all_keys = state_dict_all_gather_keys(state_dict, pg) + for key in all_keys: + if key in state_dict: + tensor = state_dict[key] + if isinstance(tensor, ShardedTensor): + copied_shards = [ + Shard.from_tensor_and_offsets( + tensor=shard.tensor.to(device), + shard_offsets=shard.metadata.shard_offsets, + rank=dist.get_rank(pg), + ) + for shard in tensor.local_shards() + ] + ret[key] = ShardedTensor._init_from_local_shards( + copied_shards, + tensor.metadata().size, + process_group=pg, + ) + elif isinstance(tensor, torch.Tensor): + ret[key] = tensor.to(device) + else: + raise ValueError(f"Unsupported tensor {key} type {type(tensor)}") + else: + # No state_dict entries for table-wise sharding, + # but need to follow full-sync. 
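+            # (i.e. ranks that hold no local shard for this key still build an
+            # empty ShardedTensor so the collective metadata exchange stays in
+            # sync across the process group.)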
+ ret[key] = ShardedTensor._init_from_local_shards( + [], + [], + process_group=pg, + ) + return ret diff --git a/torchrec/inference/tests/BatchingQueueTest.cpp b/torchrec/inference/inference_legacy/tests/BatchingQueueTest.cpp similarity index 100% rename from torchrec/inference/tests/BatchingQueueTest.cpp rename to torchrec/inference/inference_legacy/tests/BatchingQueueTest.cpp diff --git a/torchrec/inference/tests/BatchingTest.cpp b/torchrec/inference/inference_legacy/tests/BatchingTest.cpp similarity index 100% rename from torchrec/inference/tests/BatchingTest.cpp rename to torchrec/inference/inference_legacy/tests/BatchingTest.cpp diff --git a/torchrec/inference/tests/ResultSplitTest.cpp b/torchrec/inference/inference_legacy/tests/ResultSplitTest.cpp similarity index 100% rename from torchrec/inference/tests/ResultSplitTest.cpp rename to torchrec/inference/inference_legacy/tests/ResultSplitTest.cpp diff --git a/torchrec/inference/tests/SingleGPUExecutorMultiGPUTest.cpp b/torchrec/inference/inference_legacy/tests/SingleGPUExecutorMultiGPUTest.cpp similarity index 100% rename from torchrec/inference/tests/SingleGPUExecutorMultiGPUTest.cpp rename to torchrec/inference/inference_legacy/tests/SingleGPUExecutorMultiGPUTest.cpp diff --git a/torchrec/inference/tests/SingleGPUExecutorTest.cpp b/torchrec/inference/inference_legacy/tests/SingleGPUExecutorTest.cpp similarity index 100% rename from torchrec/inference/tests/SingleGPUExecutorTest.cpp rename to torchrec/inference/inference_legacy/tests/SingleGPUExecutorTest.cpp diff --git a/torchrec/inference/tests/ValidationTest.cpp b/torchrec/inference/inference_legacy/tests/ValidationTest.cpp similarity index 100% rename from torchrec/inference/tests/ValidationTest.cpp rename to torchrec/inference/inference_legacy/tests/ValidationTest.cpp diff --git a/torchrec/inference/tests/generate_test_packages.py b/torchrec/inference/inference_legacy/tests/generate_test_packages.py similarity index 100% rename from torchrec/inference/tests/generate_test_packages.py rename to torchrec/inference/inference_legacy/tests/generate_test_packages.py diff --git a/torchrec/inference/tests/model_packager_tests.py b/torchrec/inference/inference_legacy/tests/model_packager_tests.py similarity index 99% rename from torchrec/inference/tests/model_packager_tests.py rename to torchrec/inference/inference_legacy/tests/model_packager_tests.py index b88942e19..41533e34a 100644 --- a/torchrec/inference/tests/model_packager_tests.py +++ b/torchrec/inference/inference_legacy/tests/model_packager_tests.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-strict import io import tempfile diff --git a/torchrec/inference/tests/predict_module_tests.py b/torchrec/inference/inference_legacy/tests/predict_module_tests.py similarity index 99% rename from torchrec/inference/tests/predict_module_tests.py rename to torchrec/inference/inference_legacy/tests/predict_module_tests.py index 22d8c99f8..5a1ac8a31 100644 --- a/torchrec/inference/tests/predict_module_tests.py +++ b/torchrec/inference/inference_legacy/tests/predict_module_tests.py @@ -5,8 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-strict - import unittest from typing import Dict diff --git a/torchrec/inference/tests/test_modules.py b/torchrec/inference/inference_legacy/tests/test_modules.py similarity index 100% rename from torchrec/inference/tests/test_modules.py rename to torchrec/inference/inference_legacy/tests/test_modules.py diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py index 27db09280..4d0e68a14 100644 --- a/torchrec/inference/modules.py +++ b/torchrec/inference/modules.py @@ -16,9 +16,9 @@ import torch.nn as nn import torch.quantization as quant import torchrec as trec +import torchrec.distributed as trec_dist import torchrec.quant as trec_quant from torch.fx.passes.split_utils import getattr_recursive -from torchrec import distributed as trec_dist, inference as trec_infer from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.fused_params import ( FUSED_PARAM_BOUNDS_CHECK_MODE, @@ -357,7 +357,7 @@ def _quantize_fp_module( _quantize_fp_module(model, m, n) quant_prep_enable_register_tbes(model, list(additional_mapping.keys())) - trec_infer.modules.quantize_embeddings( + quantize_embeddings( model, dtype=DEFAULT_QUANTIZATION_DTYPE, additional_qconfig_spec_keys=additional_qconfig_spec_keys, diff --git a/torchrec/inference/server.cpp b/torchrec/inference/server.cpp index f7695cdcf..51ad58fad 100644 --- a/torchrec/inference/server.cpp +++ b/torchrec/inference/server.cpp @@ -10,327 +10,170 @@ #include #include -#include -#include -#include -#include -#include -#include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/strings/str_format.h" + #include #include +#include + +#include +#include -// remove this after we switch over to multipy externally for torchrec -#ifdef FBCODE_CAFFE2 -#include // @manual -#include +#ifdef BAZEL_BUILD +#include "examples/protos/predictor.grpc.pb.h" #else -#include -#include +#include "predictor.grpc.pb.h" #endif -#include +#define NUM_BYTES_FLOAT_FEATURES 4 +#define NUM_BYTES_SPARSE_FEATURES 4 -#include "torchrec/inference/GPUExecutor.h" -#include "torchrec/inference/predictor.grpc.pb.h" -#include "torchrec/inference/predictor.pb.h" - -using grpc::Channel; -using grpc::ClientContext; +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; using grpc::Status; + +ABSL_FLAG(uint16_t, port, 50051, "Server port for the service"); + using predictor::FloatVec; using predictor::PredictionRequest; using predictor::PredictionResponse; using predictor::Predictor; -DEFINE_int32(n_interp_per_gpu, 1, ""); -DEFINE_int32(n_gpu, 1, ""); -DEFINE_string(package_path, "", ""); - -DEFINE_int32(batching_interval, 10, ""); -DEFINE_int32(queue_timeout, 500, ""); - -DEFINE_int32(num_exception_threads, 4, ""); -DEFINE_int32(num_mem_pinner_threads, 4, ""); -DEFINE_int32(max_batch_size, 2048, ""); -DEFINE_int32(gpu_executor_queue_timeout, 50, ""); - -DEFINE_string(server_address, "0.0.0.0", ""); -DEFINE_string(server_port, "50051", ""); - -DEFINE_string( - python_packages_path, - "", - "Used to load the packages that you 'extern' with torch.package"); - -namespace { - -std::unique_ptr toTorchRecRequest( - const PredictionRequest* request) { - auto torchRecRequest = std::make_unique(); - torchRecRequest->batch_size = request->batch_size(); - - // Client sends a request with serialized tensor to bytes. - // Byte string is converted to folly::iobuf for torchrec request. 
- - { - torchrec::FloatFeatures floatFeature; - - auto feature = request->float_features(); - auto encoded_values = feature.values(); - - floatFeature.num_features = feature.num_features(); - floatFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["float_features"] = std::move(floatFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->id_list_features(); - auto encoded_values = feature.values(); - auto encoded_lengths = feature.lengths(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["id_list_features"] = std::move(sparseFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->id_score_list_features(); - auto encoded_values = feature.values(); - auto encoded_lengths = feature.lengths(); - auto encoded_weights = feature.weights(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - sparseFeature.weights = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_weights.data(), - encoded_weights.size()}; - - torchRecRequest->features["id_score_list_features"] = - std::move(sparseFeature); - } - - { - torchrec::FloatFeatures floatFeature; - - auto feature = request->embedding_features(); - auto encoded_values = feature.values(); - - floatFeature.num_features = feature.num_features(); - floatFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["embedding_features"] = std::move(floatFeature); - } - - { - torchrec::SparseFeatures sparseFeature; - - auto feature = request->unary_features(); - auto encoded_lengths = feature.lengths(); - auto encoded_values = feature.values(); - - sparseFeature.num_features = feature.num_features(); - sparseFeature.lengths = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_lengths.data(), - encoded_lengths.size()}; - sparseFeature.values = folly::IOBuf{ - folly::IOBuf::COPY_BUFFER, - encoded_values.data(), - encoded_values.size()}; - - torchRecRequest->features["unary_features"] = std::move(sparseFeature); - } - - return torchRecRequest; -} - -// Logic behind the server's behavior. class PredictorServiceHandler final : public Predictor::Service { public: - explicit PredictorServiceHandler(torchrec::BatchingQueue& queue) - : queue_(queue) {} + PredictorServiceHandler(torch::jit::script::Module& module) + : module_(module) {} Status Predict( grpc::ServerContext* context, const PredictionRequest* request, PredictionResponse* reply) override { - folly::Promise> promise; - auto future = promise.getSemiFuture(); - queue_.add(toTorchRecRequest(request), std::move(promise)); - auto torchRecResponse = - std::move(future).get(); // blocking, TODO: Write async server - auto predictions = reply->mutable_predictions(); - - // Convert ivalue to map, TODO: find out if protobuf - // can support custom types (folly::iobuf), so we can avoid this overhead. 
-    for (const auto& item : torchRecResponse->predictions.toGenericDict()) {
-      auto tensor = item.value().toTensor();
-      FloatVec fv;
-      fv.mutable_data()->Add(
-          tensor.data_ptr(), tensor.data_ptr() + tensor.numel());
-      (*predictions)[item.key().toStringRef()] = fv;
-    }
+    std::cout << "Predict Called!" << std::endl;
+    c10::Dict<std::string, torch::Tensor> dict;
+
+    auto floatFeature = request->float_features();
+    auto floatFeatureBlob = floatFeature.values();
+    auto numFloatFeatures = floatFeature.num_features();
+    auto batchSize =
+        floatFeatureBlob.size() / (NUM_BYTES_FLOAT_FEATURES * numFloatFeatures);
+
+    std::cout << "Size: " << floatFeatureBlob.size()
+              << " Num Features: " << numFloatFeatures << std::endl;
+    auto floatFeatureTensor = torch::from_blob(
+        floatFeatureBlob.data(),
+        {batchSize, numFloatFeatures},
+        torch::kFloat32);
+
+    auto idListFeature = request->id_list_features();
+    auto numIdListFeatures = idListFeature.num_features();
+    auto lengthsBlob = idListFeature.lengths();
+    auto valuesBlob = idListFeature.values();
+
+    std::cout << "Lengths Size: " << lengthsBlob.size()
+              << " Num Features: " << numIdListFeatures << std::endl;
+    assert(
+        batchSize ==
+        (lengthsBlob.size() / (NUM_BYTES_SPARSE_FEATURES * numIdListFeatures)));
+
+    auto lengthsTensor = torch::from_blob(
+        lengthsBlob.data(),
+        {lengthsBlob.size() / NUM_BYTES_SPARSE_FEATURES},
+        torch::kInt32);
+    auto valuesTensor = torch::from_blob(
+        valuesBlob.data(),
+        {valuesBlob.size() / NUM_BYTES_SPARSE_FEATURES},
+        torch::kInt32);
+
+    dict.insert("float_features", floatFeatureTensor.to(torch::kCUDA));
+    dict.insert("id_list_features.lengths", lengthsTensor.to(torch::kCUDA));
+    dict.insert("id_list_features.values", valuesTensor.to(torch::kCUDA));
+
+    std::vector<c10::IValue> input;
+    input.push_back(c10::IValue(dict));
+
+    torch::Tensor output =
+        this->module_.forward(input).toGenericDict().at("default").toTensor();
+    auto predictions = reply->mutable_predictions();
+    FloatVec fv;
+    fv.mutable_data()->Add(
+        output.data_ptr<float>(), output.data_ptr<float>() + output.numel());
+    (*predictions)["default"] = fv;
     return Status::OK;
   }

  private:
-  torchrec::BatchingQueue& queue_;
+  torch::jit::script::Module& module_;
 };

-} // namespace
-
-int main(int argc, char* argv[]) {
-  google::InitGoogleLogging(argv[0]);
-  gflags::ParseCommandLineFlags(&argc, &argv, true);
-
-  LOG(INFO) << "Creating GPU executors";
-
-  // store the executors and interpreter managers
-  std::vector> executors;
-  std::vector models;
-  std::vector batchQueueCbs;
-  std::unordered_map batchingMetadataMap;
-
-  std::shared_ptr env =
-      std::make_shared(
-          FLAGS_python_packages_path);
-
-  auto manager = std::make_shared(
-      FLAGS_n_gpu * FLAGS_n_interp_per_gpu, env);
-  {
-    torch::deploy::Package package = manager->loadPackage(FLAGS_package_path);
-    auto I = package.acquireSession();
-    auto imported = I.self.attr("import_module")({"__module_loader"});
-    auto factoryType = imported.attr("MODULE_FACTORY");
-    auto factory = factoryType.attr("__new__")({factoryType});
-    factoryType.attr("__init__")({factory});
-
-    // Process forward metadata.
-    try {
-      auto batchingMetadataJsonStr =
-          factory.attr("batching_metadata_json")(at::ArrayRef())
-              .toIValue()
-              .toString()
-              ->string();
-      auto dynamic = folly::parseJson(batchingMetadataJsonStr);
-      CHECK(dynamic.isObject());
-      for (auto it : dynamic.items()) {
-        torchrec::BatchingMetadata metadata;
-        metadata.type = it.second["type"].asString();
-        metadata.device = it.second["device"].asString();
-        batchingMetadataMap[it.first.asString()] = std::move(metadata);
-      }
-    } catch (...) {
-      auto batchingMetadata =
-          factory.attr("batching_metadata")(at::ArrayRef())
-              .toIValue();
-      for (const auto& iter : batchingMetadata.toGenericDict()) {
-        torchrec::BatchingMetadata metadata;
-        metadata.type = iter.value().toStringRef();
-        metadata.device = "cuda";
-        batchingMetadataMap[iter.key().toStringRef()] = std::move(metadata);
-      }
-    }
-
-    // Process result metadata.
-    auto resultMetadata =
-        factory.attr("result_metadata")(at::ArrayRef())
-            .toIValue()
-            .toStringRef();
-    std::shared_ptr resultSplitFunc =
-        torchrec::TorchRecResultSplitFuncRegistry()->Create(resultMetadata);
-
-    LOG(INFO) << "Creating Model Shard for " << FLAGS_n_gpu << " GPUs.";
-    auto dmp = factory.attr("create_predict_module")
-                   .callKwargs({{"world_size", FLAGS_n_gpu}});
-
-    for (int rank = 0; rank < FLAGS_n_gpu; rank++) {
-      auto device = I.self.attr("import_module")({"torch"}).attr("device")(
-          {"cuda", rank});
-      auto m = dmp.attr("copy")({device.toIValue()});
-      models.push_back(I.createMovable(m));
-    }
-
-    for (int rank = 0; rank < FLAGS_n_gpu; rank++) {
-      auto executor = std::make_unique(
-          manager,
-          std::move(models[rank]),
-          rank,
-          FLAGS_n_gpu,
-          resultSplitFunc,
-          std::chrono::milliseconds(FLAGS_gpu_executor_queue_timeout));
-      executors.push_back(std::move(executor));
-      batchQueueCbs.push_back(
-          [&, rank](std::shared_ptr batch) {
-            executors[rank]->callback(std::move(batch));
-          });
-    }
-  }
-
-  torchrec::BatchingQueue queue(
-      batchQueueCbs,
-      torchrec::BatchingQueue::Config{
-          .batchingInterval =
-              std::chrono::milliseconds(FLAGS_batching_interval),
-          .queueTimeout = std::chrono::milliseconds(FLAGS_queue_timeout),
-          .numExceptionThreads = FLAGS_num_exception_threads,
-          .numMemPinnerThreads = FLAGS_num_mem_pinner_threads,
-          .maxBatchSize = FLAGS_max_batch_size,
-          .batchingMetadata = std::move(batchingMetadataMap),
-      },
-      FLAGS_n_gpu);
-
-  // create the server
-  std::string server_address(FLAGS_server_address + ":" + FLAGS_server_port);
-  auto service = PredictorServiceHandler(queue);
+void RunServer(uint16_t port, torch::jit::script::Module& module) {
+  std::string server_address = absl::StrFormat("0.0.0.0:%d", port);
+  PredictorServiceHandler service(module);

   grpc::EnableDefaultHealthCheckService(true);
   grpc::reflection::InitProtoReflectionServerBuilderPlugin();
-  grpc::ServerBuilder builder;
-
+  ServerBuilder builder;
   // Listen on the given address without any authentication mechanism.
   builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
-
   // Register "service" as the instance through which we'll communicate with
   // clients. In this case it corresponds to an *synchronous* service.
   builder.RegisterService(&service);
-
   // Finally assemble the server.
-  std::unique_ptr server(builder.BuildAndStart());
-  LOG(INFO) << "Server listening on " << server_address;
+  std::unique_ptr<Server> server(builder.BuildAndStart());
+  std::cout << "Server listening on " << server_address << std::endl;

   // Wait for the server to shutdown. Note that some other thread must be
   // responsible for shutting down the server for this call to ever return.
   server->Wait();
+}
+
+int main(int argc, char** argv) {
+  // absl::ParseCommandLine(argc, argv);
+
+  if (argc != 2) {
+    std::cerr << "usage: ts-infer <model_path>\n";
+    return -1;
+  }
+
+  std::cout << "Loading model...\n";
+
+  // deserialize ScriptModule
+  torch::jit::script::Module module;
+  try {
+    module = torch::jit::load(argv[1]);
+  } catch (const c10::Error& e) {
+    std::cerr << "Error loading model\n";
+    return -1;
+  }
+
+  torch::NoGradGuard no_grad; // ensures that autograd is off
+  module.eval(); // turn off dropout and other training-time layers/functions
+
+  std::cout << "Sanity Check with dummy inputs" << std::endl;
+  c10::Dict<std::string, torch::Tensor> dict;
+  dict.insert(
+      "float_features",
+      torch::ones(
+          {1, 13}, torch::dtype(torch::kFloat32).device(torch::kCUDA, 0)));
+  dict.insert(
+      "id_list_features.lengths",
+      torch::ones({26}, torch::dtype(torch::kLong).device(torch::kCUDA, 0)));
+  dict.insert(
+      "id_list_features.values",
+      torch::ones({26}, torch::dtype(torch::kLong).device(torch::kCUDA, 0)));
+
+  std::vector<c10::IValue> input;
+  input.push_back(c10::IValue(dict));
+
+  // Execute the model and turn its output into a tensor.
+  auto output = module.forward(input).toGenericDict().at("default").toTensor();
+  std::cout << " Model Forward Completed, Output: " << output.item<float>()
+            << std::endl;
+
+  RunServer(absl::GetFlag(FLAGS_port), module);

-  LOG(INFO) << "Shutting down server";
   return 0;
 }
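For illustration, a minimal Python client sketch for the new synchronous Predict endpoint follows. It is a sketch under assumptions, not part of the patch: the stub modules predictor_pb2 / predictor_pb2_grpc are the names protoc's Python and gRPC plugins would generate for predictor.proto, port 50051 matches the ABSL_FLAG default above, and the feature counts mirror the dummy sanity check in main(). The byte layout follows what server.cpp decodes: float_features.values as raw float32 bytes of shape [batch_size, num_features], and id_list_features.lengths / values as raw int32 bytes.

# Hypothetical client sketch (not part of this patch). Assumes predictor.proto
# was compiled with protoc + grpc_python_plugin, yielding predictor_pb2 and
# predictor_pb2_grpc, and that the served model accepts the same dummy input
# shapes as server.cpp's sanity check (13 dense features, 26 id-list features).
import grpc
import numpy as np

import predictor_pb2
import predictor_pb2_grpc


def main() -> None:
    batch_size = 1
    num_float_features = 13
    num_id_list_features = 26

    request = predictor_pb2.PredictionRequest()

    # server.cpp reinterprets float_features.values as raw float32 bytes of
    # shape [batch_size, num_features] (NUM_BYTES_FLOAT_FEATURES == 4).
    request.float_features.num_features = num_float_features
    request.float_features.values = np.random.rand(
        batch_size, num_float_features
    ).astype(np.float32).tobytes()

    # Lengths and values are decoded as int32 (NUM_BYTES_SPARSE_FEATURES == 4),
    # one length per (feature, batch element) pair.
    request.id_list_features.num_features = num_id_list_features
    request.id_list_features.lengths = np.ones(
        num_id_list_features * batch_size, dtype=np.int32
    ).tobytes()
    request.id_list_features.values = np.zeros(
        num_id_list_features * batch_size, dtype=np.int32
    ).tobytes()

    with grpc.insecure_channel("localhost:50051") as channel:
        stub = predictor_pb2_grpc.PredictorStub(channel)
        response = stub.Predict(request)

    # The server replies with a map of FloatVec keyed by "default".
    print(list(response.predictions["default"].data))


if __name__ == "__main__":
    main()

Because the server calls torch::from_blob with torch::kInt32 on the sparse blobs, a client that serializes int64 ids would be misread; packing with dtype=np.int32 keeps the wire format consistent with NUM_BYTES_SPARSE_FEATURES.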