diff --git a/.github/workflows/build_75.yaml b/.github/workflows/build_75.yaml index dc6ee6b7..d744f96c 100644 --- a/.github/workflows/build_75.yaml +++ b/.github/workflows/build_75.yaml @@ -77,7 +77,7 @@ tags: | type=semver,pattern=turing-{{version}} type=semver,pattern=turing-{{major}}.{{minor}} - type=raw,value=turing-latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} + type=raw,value=turing-latest type=raw,value=turing-sha-${{ env.GITHUB_SHA_SHORT }} - name: Build and push Docker image id: build-and-push-75 @@ -99,3 +99,37 @@ labels: ${{ steps.meta-75.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-75-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=turing-{{version}}+grpc + type=semver,pattern=turing-{{major}}.{{minor}}+grpc + type=raw,value=turing-latest+grpc + type=raw,value=turing-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-75-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=75 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + DEFAULT_USE_FLASH_ATTENTION=False + tags: ${{ steps.meta-75-grpc.outputs.tags }} + labels: ${{ steps.meta-75-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max diff --git a/.github/workflows/build_80.yaml b/.github/workflows/build_80.yaml index f5d6fe5a..589f1aef 100644 --- a/.github/workflows/build_80.yaml +++ b/.github/workflows/build_80.yaml @@ -98,3 +98,36 @@ labels: ${{ steps.meta-80.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-80,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-80,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-80-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern={{version}}+grpc + type=semver,pattern={{major}}.{{minor}}+grpc + type=raw,value=latest+grpc + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-80-grpc + uses: docker/build-push-action@v4 + with: + context: . 
+ target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=80 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-80-grpc.outputs.tags }} + labels: ${{ steps.meta-80-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-80,mode=max diff --git a/.github/workflows/build_86.yaml b/.github/workflows/build_86.yaml index bd824414..d799abaf 100644 --- a/.github/workflows/build_86.yaml +++ b/.github/workflows/build_86.yaml @@ -98,3 +98,37 @@ labels: ${{ steps.meta-86.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-86,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-86,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-86-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=86-{{version}}+grpc + type=semver,pattern=86-{{major}}.{{minor}}+grpc + type=raw,value=86-latest+grpc + type=raw,value=86-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-86-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=86 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-86-grpc.outputs.tags }} + labels: ${{ steps.meta-86-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-86,mode=max + diff --git a/.github/workflows/build_89.yaml b/.github/workflows/build_89.yaml index a5a5be7c..5126ab69 100644 --- a/.github/workflows/build_89.yaml +++ b/.github/workflows/build_89.yaml @@ -98,3 +98,37 @@ labels: ${{ steps.meta-89.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-89,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-89,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-89-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=89-{{version}}+grpc + type=semver,pattern=89-{{major}}.{{minor}}+grpc + type=raw,value=89-latest+grpc + type=raw,value=89-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-89-grpc + uses: docker/build-push-action@v4 + with: + context: . 
+ target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=89 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-89-grpc.outputs.tags }} + labels: ${{ steps.meta-89-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-89,mode=max + diff --git a/.github/workflows/build_90.yaml b/.github/workflows/build_90.yaml index 9c6f2d6a..63fc3b6f 100644 --- a/.github/workflows/build_90.yaml +++ b/.github/workflows/build_90.yaml @@ -98,4 +98,39 @@ labels: ${{ steps.meta-90.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-90,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-90,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-90-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=hopper-{{version}}+grpc + type=semver,pattern=hopper-{{major}}.{{minor}}+grpc + type=raw,value=hopper-latest+grpc + type=raw,value=hopper-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-90-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=90 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-90-grpc.outputs.tags }} + labels: ${{ steps.meta-90-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-90,mode=max + + diff --git a/.github/workflows/build_cpu.yaml b/.github/workflows/build_cpu.yaml index e0237c81..bc6623c7 100644 --- a/.github/workflows/build_cpu.yaml +++ b/.github/workflows/build_cpu.yaml @@ -97,3 +97,36 @@ labels: ${{ steps.meta-cpu.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-cpu,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-cpu,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-cpu-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=cpu-{{version}}+grpc + type=semver,pattern=cpu-{{major}}.{{minor}}+grpc + type=raw,value=cpu-latest+grpc + type=raw,value=cpu-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-cpu-grpc + uses: docker/build-push-action@v4 + with: + context: . 
+ target: grpc + file: Dockerfile + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-cpu-grpc.outputs.tags }} + labels: ${{ steps.meta-cpu-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-cpu,mode=max + diff --git a/Cargo.lock b/Cargo.lock index 023dc833..348db95e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -227,12 +227,12 @@ name = "backend-grpc-client" version = "0.5.0" dependencies = [ "grpc-metadata", - "prost", - "prost-build", + "prost 0.11.9", + "prost-build 0.11.9", "thiserror", "tokio", - "tonic", - "tonic-build", + "tonic 0.9.2", + "tonic-build 0.9.2", "tower", "tracing", ] @@ -1186,7 +1186,7 @@ name = "grpc-metadata" version = "0.1.0" dependencies = [ "opentelemetry 0.19.0", - "tonic", + "tonic 0.9.2", "tracing", "tracing-opentelemetry 0.19.0", ] @@ -2041,10 +2041,10 @@ dependencies = [ "opentelemetry-semantic-conventions", "opentelemetry_api 0.20.0", "opentelemetry_sdk 0.20.0", - "prost", + "prost 0.11.9", "thiserror", "tokio", - "tonic", + "tonic 0.9.2", ] [[package]] @@ -2055,8 +2055,8 @@ checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb" dependencies = [ "opentelemetry_api 0.20.0", "opentelemetry_sdk 0.20.0", - "prost", - "tonic", + "prost 0.11.9", + "tonic 0.9.2", ] [[package]] @@ -2275,6 +2275,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "prettyplease" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" +dependencies = [ + "proc-macro2", + "syn 2.0.39", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -2315,7 +2325,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.11.9", +] + +[[package]] +name = "prost" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c289cda302b98a28d40c8b3b90498d6e526dd24ac2ecea73e4e491685b94a" +dependencies = [ + "bytes", + "prost-derive 0.12.3", ] [[package]] @@ -2331,15 +2351,37 @@ dependencies = [ "log", "multimap", "petgraph", - "prettyplease", - "prost", - "prost-types", + "prettyplease 0.1.25", + "prost 0.11.9", + "prost-types 0.11.9", "regex", "syn 1.0.109", "tempfile", "which", ] +[[package]] +name = "prost-build" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c55e02e35260070b6f716a2423c2ff1c3bb1642ddca6f99e1f26d06268a0e2d2" +dependencies = [ + "bytes", + "heck", + "itertools 0.11.0", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease 0.2.15", + "prost 0.12.3", + "prost-types 0.12.3", + "regex", + "syn 2.0.39", + "tempfile", + "which", +] + [[package]] name = "prost-derive" version = "0.11.9" @@ -2353,13 +2395,35 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "prost-derive" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e" +dependencies = [ + "anyhow", + "itertools 0.11.0", + "proc-macro2", + "quote", + "syn 2.0.39", +] 
+ [[package]] name = "prost-types" version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13" dependencies = [ - "prost", + "prost 0.11.9", +] + +[[package]] +name = "prost-types" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "193898f59edcf43c26227dcd4c8427f00d99d61e95dcde58dabd49fa291d470e" +dependencies = [ + "prost 0.12.3", ] [[package]] @@ -3132,12 +3196,14 @@ dependencies = [ "clap", "futures", "hf-hub", + "http 0.2.11", "init-tracing-opentelemetry", "metrics", "metrics-exporter-prometheus", "num_cpus", "opentelemetry 0.20.0", "opentelemetry-otlp", + "prost 0.12.3", "reqwest", "serde", "serde_json", @@ -3146,9 +3212,13 @@ dependencies = [ "thiserror", "tokenizers", "tokio", + "tokio-stream", + "tonic 0.10.2", + "tonic-build 0.10.2", + "tonic-health", + "tonic-reflection", "tower-http", "tracing", - "tracing-chrome", "tracing-opentelemetry 0.21.0", "tracing-subscriber", "utoipa", @@ -3367,7 +3437,34 @@ dependencies = [ "hyper-timeout", "percent-encoding", "pin-project", - "prost", + "prost 0.11.9", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.21.5", + "bytes", + "h2", + "http 0.2.11", + "http-body", + "hyper", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost 0.12.3", "tokio", "tokio-stream", "tower", @@ -3382,13 +3479,52 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6fdaae4c2c638bb70fe42803a26fbd6fc6ac8c72f5c59f67ecc2a2dcabf4b07" dependencies = [ - "prettyplease", + "prettyplease 0.1.25", "proc-macro2", - "prost-build", + "prost-build 0.11.9", "quote", "syn 1.0.109", ] +[[package]] +name = "tonic-build" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d021fc044c18582b9a2408cd0dd05b1596e3ecdb5c4df822bb0183545683889" +dependencies = [ + "prettyplease 0.2.15", + "proc-macro2", + "prost-build 0.12.3", + "quote", + "syn 2.0.39", +] + +[[package]] +name = "tonic-health" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f80db390246dfb46553481f6024f0082ba00178ea495dbb99e70ba9a4fafb5e1" +dependencies = [ + "async-stream", + "prost 0.12.3", + "tokio", + "tokio-stream", + "tonic 0.10.2", +] + +[[package]] +name = "tonic-reflection" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fa37c513df1339d197f4ba21d28c918b9ef1ac1768265f11ecb6b7f1cba1b76" +dependencies = [ + "prost 0.12.3", + "prost-types 0.12.3", + "tokio", + "tokio-stream", + "tonic 0.10.2", +] + [[package]] name = "tower" version = "0.4.13" @@ -3462,17 +3598,6 @@ dependencies = [ "syn 2.0.39", ] -[[package]] -name = "tracing-chrome" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "496b3cd5447f7ff527bbbf19b071ad542a000adf297d4127078b4dfdb931f41a" -dependencies = [ - "serde_json", - "tracing-core", - "tracing-subscriber", -] - [[package]] name = "tracing-core" version = "0.1.32" diff --git a/Dockerfile b/Dockerfile index f557c1d3..07baafb6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,9 +51,23 @@ COPY router 
router COPY Cargo.toml ./ COPY Cargo.lock ./ +FROM builder as http-builder + RUN cargo build --release --bin text-embeddings-router -F candle -F mkl-dynamic --no-default-features && sccache -s -FROM debian:bookworm-slim +FROM builder as grpc-builder + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY proto proto + +RUN cargo build --release --bin text-embeddings-router -F grpc -F candle -F mkl-dynamic --no-default-features && sccache -s + +FROM debian:bookworm-slim as base ENV HUGGINGFACE_HUB_CACHE=/data \ PORT=80 \ @@ -80,7 +94,16 @@ COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /u COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2 COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so -COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router +FROM base as grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM base + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router ENTRYPOINT ["text-embeddings-router"] CMD ["--json-output"] \ No newline at end of file diff --git a/Dockerfile-cuda b/Dockerfile-cuda index 98e7f93c..bad00a70 100644 --- a/Dockerfile-cuda +++ b/Dockerfile-cuda @@ -70,6 +70,8 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ +FROM builder as http-builder + RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \ then \ cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F static-linking --no-default-features && sccache -s; \ @@ -77,7 +79,28 @@ RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \ cargo build --release --bin text-embeddings-router -F candle-cuda -F static-linking --no-default-features && sccache -s; \ fi; -FROM nvidia/cuda:12.0.0-base-ubuntu22.04 +FROM builder as grpc-builder + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + unzip \ + && rm -rf /var/lib/apt/lists/* + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY proto proto + +RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \ + then \ + cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F static-linking -F grpc --no-default-features && sccache -s; \ + else \ + cargo build --release --bin text-embeddings-router -F candle-cuda -F static-linking -F grpc --no-default-features && sccache -s; \ + fi; + +FROM nvidia/cuda:12.0.0-base-ubuntu22.04 as base ARG DEFAULT_USE_FLASH_ATTENTION=True @@ -85,7 +108,16 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ PORT=80 \ USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION -COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router +FROM base as grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT 
["text-embeddings-router"] +CMD ["--json-output"] + +FROM base + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router ENTRYPOINT ["text-embeddings-router"] CMD ["--json-output"] \ No newline at end of file diff --git a/README.md b/README.md index 66abc7b4..08f1dc7a 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ length of 512 tokens: - [Using Re-rankers models](#using-re-rankers-models) - [Using Sequence Classification models](#using-sequence-classification-models) - [Distributed Tracing](#distributed-tracing) + - [gRPC](#grpc) - [Local Install](#local-install) - [Docker Build](#docker-build) @@ -334,6 +335,25 @@ curl 127.0.0.1:8080/predict \ `text-embeddings-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature by setting the address to an OTLP collector with the `--otlp-endpoint` argument. +### gRPC + +`text-embeddings-inference` offers a gRPC API as an alternative to the default HTTP API for high performance +deployments. The API protobuf definition can be found [here](https://github.com/huggingface/text-embeddings-inference/blob/main/proto/tei.proto). + +You can use the gRPC API by adding the `+grpc` tag to any TEI Docker image. For example: + +```shell +model=BAAI/bge-large-en-v1.5 +revision=refs/pr/5 +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.5+grpc --model-id $model --revision $revision +``` + +```shell +grpcurl -d '{"inputs": "What is Deep Learning"}' -plaintext 0.0.0.0:8080 tei.v1.Embed/Embed +``` + ## Local install ### CPU diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 77da3440..fd089515 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -1,11 +1,9 @@ mod dtype; use std::path::PathBuf; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; use std::time::{Duration, Instant}; use text_embeddings_backend_core::Backend as CoreBackend; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, watch}; use tracing::{instrument, Span}; pub use crate::dtype::DType; @@ -22,7 +20,7 @@ pub struct Backend { /// Channel to communicate with the background thread backend_sender: mpsc::UnboundedSender, /// Health status - health: Arc, + health_receiver: watch::Receiver, pub max_batch_size: Option, pub model_type: ModelType, } @@ -46,11 +44,15 @@ impl Backend { )?; let max_batch_size = backend.max_batch_size(); - tokio::task::spawn_blocking(move || backend_blocking_task(backend, backend_receiver)); + let (health_sender, health_receiver) = watch::channel(false); + + tokio::task::spawn_blocking(move || { + backend_blocking_task(backend, backend_receiver, health_sender) + }); Ok(Self { backend_sender, - health: Arc::new(AtomicBool::new(false)), + health_receiver, max_batch_size, model_type, }) @@ -58,7 +60,7 @@ impl Backend { #[instrument(skip(self))] pub async fn health(&self) -> Result<(), BackendError> { - let result = if self.health.load(Ordering::SeqCst) { + if *self.health_receiver.borrow() { // The backend is healthy. Only do a basic health check by calling the // the underlying health method. 
@@ -84,11 +86,12 @@ impl Backend { ModelType::Classifier => self.predict(batch).await.map(|_| ()), ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), } - }; + } + } - // Update health - self.health.store(result.is_ok(), Ordering::SeqCst); - result + #[instrument(skip(self))] + pub fn health_watcher(&self) -> watch::Receiver { + self.health_receiver.clone() } #[instrument(skip_all)] @@ -98,13 +101,9 @@ impl Backend { self.backend_sender .send(BackendCommand::Embed(batch, Span::current(), sender)) .expect("No backend receiver. This is a bug."); - let result = receiver.await.expect( + receiver.await.expect( "Backend blocking task dropped the sender without send a response. This is a bug.", - ); - - // Update health - self.health.store(result.is_ok(), Ordering::SeqCst); - result + ) } #[instrument(skip_all)] @@ -114,13 +113,9 @@ impl Backend { self.backend_sender .send(BackendCommand::Predict(batch, Span::current(), sender)) .expect("No backend receiver. This is a bug."); - let result = receiver.await.expect( + receiver.await.expect( "Backend blocking task dropped the sender without send a response. This is a bug.", - ); - - // Update health - self.health.store(result.is_ok(), Ordering::SeqCst); - result + ) } } @@ -142,8 +137,6 @@ fn init_backend( } else if cfg!(feature = "python") { #[cfg(feature = "python")] { - use std::thread; - return Ok(Box::new( thread::spawn(move || { PythonBackend::new( @@ -165,23 +158,32 @@ fn init_backend( fn backend_blocking_task( backend: Box, mut command_receiver: mpsc::UnboundedReceiver, + health_sender: watch::Sender, ) { while let Some(cmd) = command_receiver.blocking_recv() { let start = Instant::now(); + let mut healthy = false; match cmd { BackendCommand::Health(span, sender) => { let _span = span.entered(); - let _ = sender.send(backend.health()); + let _ = sender.send(backend.health().map(|_| healthy = true)); } BackendCommand::Embed(batch, span, sender) => { let _span = span.entered(); - let _ = sender.send(backend.embed(batch).map(|e| (e, start.elapsed()))); + let _ = sender.send(backend.embed(batch).map(|e| { + healthy = true; + (e, start.elapsed()) + })); } BackendCommand::Predict(batch, span, sender) => { let _span = span.entered(); - let _ = sender.send(backend.predict(batch).map(|e| (e, start.elapsed()))); + let _ = sender.send(backend.predict(batch).map(|e| { + healthy = true; + (e, start.elapsed()) + })); } - } + }; + let _ = health_sender.send(healthy); } } diff --git a/core/src/infer.rs b/core/src/infer.rs index 1c347234..4232ea96 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -4,11 +4,11 @@ use crate::TextEmbeddingsError; use std::sync::Arc; use std::time::{Duration, Instant}; use text_embeddings_backend::{Backend, BackendError, ModelType}; -use tokio::sync::{mpsc, oneshot, Notify, OwnedSemaphorePermit, Semaphore}; +use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore}; use tracing::{instrument, Span}; /// Inference struct -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Infer { tokenization: Tokenization, queue: Queue, @@ -285,6 +285,11 @@ impl Infer { pub async fn health(&self) -> bool { self.backend.health().await.is_ok() } + + #[instrument(skip(self))] + pub fn health_watcher(&self) -> watch::Receiver { + self.backend.health_watcher() + } } #[instrument(skip_all)] diff --git a/load_tests/load.js b/load_tests/load.js index 7a945cce..abaa2090 100644 --- a/load_tests/load.js +++ b/load_tests/load.js @@ -18,16 +18,16 @@ export const options = { scenarios: { // throughput: { // executor: 
'shared-iterations', - // vus: 1000, - // iterations: 1000, + // vus: 5000, + // iterations: 5000, // maxDuration: '2m', // gracefulStop: '1s', // }, load_test: { executor: 'constant-arrival-rate', duration: '30s', - preAllocatedVUs: 10000, - rate: 9000, + preAllocatedVUs: 5000, + rate: 1000, timeUnit: '1s', gracefulStop: '1s', }, diff --git a/load_tests/load_grpc.js b/load_tests/load_grpc.js new file mode 100644 index 00000000..4c14407a --- /dev/null +++ b/load_tests/load_grpc.js @@ -0,0 +1,68 @@ +import {check} from 'k6'; +import grpc from 'k6/experimental/grpc'; +import {Trend} from 'k6/metrics'; + +const host = __ENV.HOST || '127.0.0.1:3000'; + +const totalTime = new Trend('total_time', true); +const tokenizationTIme = new Trend('tokenization_time', true); +const queueTime = new Trend('queue_time', true); +const inferenceTime = new Trend('inference_time', true); + +export const inputs = 'A path from a point approximately 330 metres east of the most south westerleasterly corner of Unit 4 Foundry Industrial Estate, then proceeding in a generally east-north-east direction for approximately 64 metres to a point approximately 282 metres east-south-east of the most easterly corner of Unit 2 Foundry Industrial Estate, Victoria Street, Widnes and approximately 259 metres east of the most southerly corner of Unit 4 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-north-east direction for approximately 350 metres to a point approximately 3 metres west-north-west of the most north westerly corner of the boundary fence of the scrap metal yard on the south side of Cornubia Road, Widnes, and approximately 47 metres west-south-west of the stub end of Cornubia Road be diverted to a 3 metre wide path from a point approximately 183 metres east-south-east of the most easterly corner of Unit 5 Foundry Industrial Estate, Victoria Street and approximately 272 metres east of the most north-easterly corner of 26 Ann Street West, Widnes, then proceeding in a generally north easterly direction for approximately 58 metres to a point approximately 216 metres east-south-east of the most easterly corner of Unit 4 Foundry Industrial Estate, Victoria Street and approximately 221 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally easterly direction for approximately 45 metres to a point approximately 265 metres east-south-east of the most north-easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 265 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-south-east direction for approximately 102 metres to a point approximately 366 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 463 metres east of the most north easterly corner of 22 Ann Street West, Widnes, then proceeding in a generally north-north-easterly direction for approximately 19 metres to a point approximately 368 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 512 metres east of the most south easterly corner of 17 Batherton Close, Widnes then proceeding in a generally east-south, easterly direction for approximately 16 metres to a point approximately 420 metres east-south-east of the most southerly corner of Unit 2 Foundry'; + +export const options = { + thresholds: { + http_req_failed: 
['rate==0'], + }, + scenarios: { + // throughput: { + // executor: 'shared-iterations', + // vus: 10000, + // iterations: 10000, + // maxDuration: '2m', + // gracefulStop: '1s', + // }, + load_test: { + executor: 'constant-arrival-rate', + duration: '30s', + preAllocatedVUs: 5000, + rate: 1000, + timeUnit: '1s', + gracefulStop: '1s', + }, + }, +}; + + +const client = new grpc.Client(); + +client.load([], '../proto/tei.proto'); + +export default function () { + if (__ITER == 0) { + client.connect(host, { + plaintext: true + }); + } + + const payload = { + inputs: inputs, + truncate: true, + }; + + const res = client.invoke('tei.v1.Embed/Embed', payload); + + check(res, { + 'status is OK': (r) => r && r.status === grpc.StatusOK, + }); + + if (res.status === grpc.StatusOK) { + totalTime.add(res.headers["x-total-time"]); + tokenizationTIme.add(res.headers["x-tokenization-time"]); + queueTime.add(res.headers["x-queue-time"]); + inferenceTime.add(res.headers["x-inference-time"]); + } else { + console.log(res.error); + } +} diff --git a/load_tests/load_grpc_stream.js b/load_tests/load_grpc_stream.js new file mode 100644 index 00000000..42ab489f --- /dev/null +++ b/load_tests/load_grpc_stream.js @@ -0,0 +1,63 @@ +import grpc from 'k6/experimental/grpc'; +import {Counter, Trend} from 'k6/metrics'; + +const host = __ENV.HOST || '127.0.0.1:3000'; + +const streamCounter = new Counter('stream_counter'); +const totalTime = new Trend('total_time', true); +const tokenizationTIme = new Trend('tokenization_time', true); +const queueTime = new Trend('queue_time', true); +const inferenceTime = new Trend('inference_time', true); + +export const inputs = 'A path from a point approximately 330 metres east of the most south westerleasterly corner of Unit 4 Foundry Industrial Estate, then proceeding in a generally east-north-east direction for approximately 64 metres to a point approximately 282 metres east-south-east of the most easterly corner of Unit 2 Foundry Industrial Estate, Victoria Street, Widnes and approximately 259 metres east of the most southerly corner of Unit 4 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-north-east direction for approximately 350 metres to a point approximately 3 metres west-north-west of the most north westerly corner of the boundary fence of the scrap metal yard on the south side of Cornubia Road, Widnes, and approximately 47 metres west-south-west of the stub end of Cornubia Road be diverted to a 3 metre wide path from a point approximately 183 metres east-south-east of the most easterly corner of Unit 5 Foundry Industrial Estate, Victoria Street and approximately 272 metres east of the most north-easterly corner of 26 Ann Street West, Widnes, then proceeding in a generally north easterly direction for approximately 58 metres to a point approximately 216 metres east-south-east of the most easterly corner of Unit 4 Foundry Industrial Estate, Victoria Street and approximately 221 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally easterly direction for approximately 45 metres to a point approximately 265 metres east-south-east of the most north-easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 265 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-south-east direction for approximately 102 metres to a point approximately 366 metres east-south-east of the 
most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 463 metres east of the most north easterly corner of 22 Ann Street West, Widnes, then proceeding in a generally north-north-easterly direction for approximately 19 metres to a point approximately 368 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 512 metres east of the most south easterly corner of 17 Batherton Close, Widnes then proceeding in a generally east-south, easterly direction for approximately 16 metres to a point approximately 420 metres east-south-east of the most southerly corner of Unit 2 Foundry'; + +export const options = { + scenarios: { + throughput: { + executor: 'shared-iterations', + vus: 1, + iterations: 1, + maxDuration: '2m', + gracefulStop: '1s', + }, + }, +}; + + +const client = new grpc.Client(); + +client.load([], '../proto/tei.proto'); + +export default function () { + if (__ITER == 0) { + client.connect(host, { + plaintext: true + }); + } + + const stream = new grpc.Stream(client, 'tei.v1.Embed/EmbedStream'); + + stream.on('data', (res) => { + totalTime.add(res.metadata.totalTimeNs / 1e6); + tokenizationTIme.add(res.metadata.tokenizationTimeNs / 1e6); + queueTime.add(res.metadata.queueTimeNs / 1e6); + inferenceTime.add(res.metadata.inferenceTimeNs / 1e6); + }); + + stream.on('error', (err) => { + console.log('Stream Error: ' + JSON.stringify(err)); + }); + + const payload = { + inputs: inputs, + truncate: true, + }; + + // send 10000 requests + for (let i = 0; i < 10000; i++) { + stream.write(payload); + } + + // close the client stream + stream.end(); +} diff --git a/proto/tei.proto b/proto/tei.proto new file mode 100644 index 00000000..d9131679 --- /dev/null +++ b/proto/tei.proto @@ -0,0 +1,113 @@ +syntax = "proto3"; + +package tei.v1; + +service Info { + rpc Info (InfoRequest) returns (InfoResponse) { + option idempotency_level = IDEMPOTENT; + }; +} + +service Embed { + rpc Embed (EmbedRequest) returns (EmbedResponse); + rpc EmbedStream (stream EmbedRequest) returns (stream EmbedResponse); +} + +service Predict { + rpc Predict (PredictRequest) returns (PredictResponse); + rpc PredictStream (stream PredictRequest) returns (stream PredictResponse); +} + +service Rerank { + rpc Rerank (RerankRequest) returns (RerankResponse); + rpc RerankStream (stream RerankStreamRequest) returns (RerankResponse); +} + +message InfoRequest {} + +enum ModelType { + MODEL_TYPE_EMBEDDING = 0; + MODEL_TYPE_CLASSIFIER = 1; + MODEL_TYPE_RERANKER = 2; +} + +message InfoResponse { + string version = 1; + optional string sha = 2; + optional string docker_label = 3; + string model_id = 4; + optional string model_sha = 5; + string model_dtype = 6; + ModelType model_type = 7; + uint32 max_concurrent_requests = 8; + uint32 max_input_length = 9; + uint32 max_batch_tokens = 10; + optional uint32 max_batch_requests = 11; + uint32 max_client_batch_size = 12; + uint32 tokenization_workers = 13; +} + +message Metadata { + uint32 compute_chars = 1; + uint32 compute_tokens = 2; + uint64 total_time_ns = 3; + uint64 tokenization_time_ns = 4; + uint64 queue_time_ns = 5; + uint64 inference_time_ns = 6; +} + +message EmbedRequest { + string inputs = 1; + bool truncate = 2; + bool normalize = 3; +} + +message EmbedResponse { + repeated float embeddings = 1; + Metadata metadata = 2; +} + +message PredictRequest { + string inputs = 1; + bool truncate = 2; + bool raw_scores = 3; +} + +message Prediction { + float score = 1; + string label = 
2; +} + +message PredictResponse { + repeated Prediction predictions = 1; + Metadata metadata = 2; +} + +message RerankRequest { + string query = 1; + repeated string texts = 2; + bool truncate = 3; + bool raw_scores = 4; + bool return_text = 5; +} + +message RerankStreamRequest{ + string query = 1; + string text = 2; + bool truncate = 3; + // The server will only consider the first value + bool raw_scores = 4; + // The server will only consider the first value + bool return_text = 5; +} + +message Rank { + uint32 index = 1; + optional string text = 2; + float score = 3; +} + +message RerankResponse { + repeated Rank ranks = 1; + Metadata metadata = 2; +} diff --git a/router/Cargo.toml b/router/Cargo.toml index a6b7494a..32d130b9 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -16,15 +16,13 @@ path = "src/main.rs" [dependencies] anyhow = "1.0.71" -async-stream = "0.3.3" -axum = { version = "0.6.4", features = ["json"] } -axum-tracing-opentelemetry = "0.14.1" text-embeddings-backend = { path = "../backends", features = ["clap"] } text-embeddings-core = { path = "../core" } clap = { version = "4.1.4", features = ["derive", "env"] } futures = "^0.3" init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } hf-hub = { version = "0.3.0", features = ["tokio"] } +http = "0.2.9" num_cpus = "1.16.0" metrics = "0.21.0" metrics-exporter-prometheus = { version = "0.12.1", features = [] } @@ -36,20 +34,34 @@ serde_json = "1.0.93" thiserror = "1.0.38" tokenizers = { version = "0.15.0", default-features=false, features=["onig", "esaxx_fast"] } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } -tower-http = { version = "0.4.0", features = ["cors"] } tracing = "0.1.37" -tracing-chrome = "0.7.1" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } -utoipa = { version = "4.0.0", features = ["axum_extras"] } -utoipa-swagger-ui = { version = "4.0.0", features = ["axum"] } veil = "0.1.6" +# HTTP dependencies +axum = { version = "0.6.4", features = ["json"], optional = true } +axum-tracing-opentelemetry = { version = "0.14.1", optional = true } +tower-http = { version = "0.4.0", features = ["cors"], optional = true } +utoipa = { version = "4.0.0", features = ["axum_extras"], optional = true } +utoipa-swagger-ui = { version = "4.0.0", features = ["axum"], optional = true } + +# gRPC dependencies +async-stream = { version = "0.3.5", optional = true } +prost = { version = "0.12.1", optional = true } +tonic = { version = "0.10.2", optional = true } +tonic-health = { version = "0.10.2", optional = true } +tonic-reflection = { version = "0.10.2", optional = true } +tokio-stream = { version = "0.1.14", optional = true } + [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } +tonic-build = { version = "0.10.2", optional = true } [features] -default = ["candle"] +default = ["candle", "http"] +http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] +grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build", "dep:async-stream", "dep:tokio-stream"] mkl = ["text-embeddings-backend/mkl"] mkl-dynamic = ["text-embeddings-backend/mkl-dynamic"] accelerate = ["text-embeddings-backend/accelerate"] diff --git a/router/build.rs b/router/build.rs index f5eb8a26..b690eae0 100644 --- a/router/build.rs +++ 
b/router/build.rs @@ -22,5 +22,24 @@ fn main() -> Result<(), Box> { println!("cargo:rustc-env=DOCKER_LABEL={label}"); } + #[cfg(feature = "grpc")] + { + use std::env; + use std::fs; + use std::path::PathBuf; + + fs::create_dir("src/grpc/pb").unwrap_or(()); + + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + tonic_build::configure() + .build_client(false) + .build_server(true) + .file_descriptor_set_path(out_dir.join("descriptor.bin")) + .out_dir("src/grpc/pb") + .include_file("mod.rs") + .compile(&["../proto/tei.proto"], &["../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e)); + } + Ok(()) } diff --git a/router/src/grpc/mod.rs b/router/src/grpc/mod.rs new file mode 100644 index 00000000..91d4c8d4 --- /dev/null +++ b/router/src/grpc/mod.rs @@ -0,0 +1,7 @@ +mod pb; +pub(crate) mod server; + +use pb::tei::v1::{ + embed_server::EmbedServer, info_server::InfoServer, predict_server::PredictServer, + rerank_server::RerankServer, *, +}; diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs new file mode 100644 index 00000000..b4983ab5 --- /dev/null +++ b/router/src/grpc/server.rs @@ -0,0 +1,948 @@ +use crate::grpc::pb::tei::v1::RerankStreamRequest; +use crate::grpc::{ + EmbedRequest, EmbedResponse, InfoRequest, InfoResponse, PredictRequest, PredictResponse, + Prediction, Rank, RerankRequest, RerankResponse, +}; +use crate::ResponseMetadata; +use crate::{grpc, shutdown, ErrorResponse, ErrorType, Info, ModelType}; +use futures::future::join_all; +use metrics_exporter_prometheus::PrometheusBuilder; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; +use text_embeddings_core::infer::Infer; +use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::StreamExt; +use tonic::codegen::http::HeaderMap; +use tonic::metadata::MetadataMap; +use tonic::server::NamedService; +use tonic::transport::Server; +use tonic::{Code, Extensions, Request, Response, Status, Streaming}; +use tonic_health::ServingStatus; +use tracing::{instrument, Span}; + +impl From<&ResponseMetadata> for grpc::Metadata { + fn from(value: &ResponseMetadata) -> Self { + Self { + compute_chars: value.compute_chars as u32, + compute_tokens: value.compute_tokens as u32, + total_time_ns: value.start_time.elapsed().as_nanos() as u64, + tokenization_time_ns: value.tokenization_time.as_nanos() as u64, + queue_time_ns: value.queue_time.as_nanos() as u64, + inference_time_ns: value.inference_time.as_nanos() as u64, + } + } +} + +#[derive(Debug, Clone)] +struct TextEmbeddingsService { + infer: Infer, + info: Info, + max_parallel_stream_requests: usize, +} + +impl TextEmbeddingsService { + fn new(infer: Infer, info: Info) -> Self { + let max_parallel_stream_requests = std::env::var("GRPC_MAX_PARALLEL_STREAM_REQUESTS") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(1024); + Self { + infer, + info, + max_parallel_stream_requests, + } + } + + #[instrument( + skip_all, + fields( + compute_chars, + compute_tokens, + total_time, + tokenization_time, + queue_time, + inference_time, + ) + )] + async fn embed_inner( + &self, + request: EmbedRequest, + permit: OwnedSemaphorePermit, + ) -> Result<(EmbedResponse, ResponseMetadata), Status> { + let span = Span::current(); + let start_time = Instant::now(); + + let compute_chars = request.inputs.chars().count(); + let response = self + .infer + .embed(request.inputs, request.truncate, request.normalize, permit) + .await + .map_err(ErrorResponse::from)?; + + let 
response_metadata = ResponseMetadata::new( + compute_chars, + response.prompt_tokens, + start_time, + response.tokenization, + response.queue, + response.inference, + ); + response_metadata.record_span(&span); + response_metadata.record_metrics(); + + tracing::info!("Success"); + + Ok(( + EmbedResponse { + embeddings: response.results, + metadata: Some(grpc::Metadata::from(&response_metadata)), + }, + response_metadata, + )) + } + + #[instrument( + skip_all, + fields( + compute_chars, + compute_tokens, + total_time, + tokenization_time, + queue_time, + inference_time, + ) + )] + async fn predict_inner( + &self, + request: PredictRequest, + permit: OwnedSemaphorePermit, + ) -> Result<(PredictResponse, ResponseMetadata), Status> { + let span = Span::current(); + let start_time = Instant::now(); + + let compute_chars = request.inputs.chars().count(); + let response = self + .infer + .predict(request.inputs, request.truncate, request.raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; + + let id2label = match &self.info.model_type { + ModelType::Classifier(classifier) => &classifier.id2label, + ModelType::Reranker(classifier) => &classifier.id2label, + _ => panic!(), + }; + + let response_metadata = ResponseMetadata::new( + compute_chars, + response.prompt_tokens, + start_time, + response.tokenization, + response.queue, + response.inference, + ); + + let mut predictions: Vec = { + // Map score to label + response + .results + .into_iter() + .enumerate() + .map(|(i, s)| Prediction { + score: s, + label: id2label.get(&i.to_string()).unwrap().clone(), + }) + .collect() + }; + // Reverse sort + predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + predictions.reverse(); + + response_metadata.record_span(&span); + response_metadata.record_metrics(); + + tracing::info!("Success"); + + Ok(( + PredictResponse { + predictions, + metadata: Some(grpc::Metadata::from(&response_metadata)), + }, + response_metadata, + )) + } +} + +#[tonic::async_trait] +impl grpc::info_server::Info for TextEmbeddingsService { + async fn info(&self, _request: Request) -> Result, Status> { + let model_type = match self.info.model_type { + ModelType::Classifier(_) => grpc::ModelType::Classifier, + ModelType::Embedding(_) => grpc::ModelType::Embedding, + ModelType::Reranker(_) => grpc::ModelType::Reranker, + }; + + Ok(Response::new(InfoResponse { + version: self.info.version.to_string(), + sha: self.info.sha.map(|s| s.to_string()), + docker_label: self.info.docker_label.map(|s| s.to_string()), + model_id: self.info.model_id.clone(), + model_sha: self.info.model_sha.clone(), + model_dtype: self.info.model_dtype.clone(), + model_type: model_type.into(), + max_concurrent_requests: self.info.max_concurrent_requests as u32, + max_input_length: self.info.max_input_length as u32, + max_batch_tokens: self.info.max_batch_tokens as u32, + max_batch_requests: self.info.max_batch_requests.map(|v| v as u32), + max_client_batch_size: self.info.max_client_batch_size as u32, + tokenization_workers: self.info.tokenization_workers as u32, + })) + } +} + +#[tonic::async_trait] +impl grpc::embed_server::Embed for TextEmbeddingsService { + #[instrument(skip_all)] + async fn embed( + &self, + request: Request, + ) -> Result, Status> { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let permit = self + .infer + .try_acquire_permit() + .map_err(ErrorResponse::from)?; + + let request = request.into_inner(); + let (response, metadata) = self.embed_inner(request, permit).await?; + let headers = 
HeaderMap::from(metadata); + + metrics::increment_counter!("te_request_success", "method" => "single"); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + response, + Extensions::default(), + )) + } + + type EmbedStreamStream = UnboundedReceiverStream>; + + #[instrument(skip_all)] + async fn embed_stream( + &self, + request: Request>, + ) -> Result, Status> { + let mut request_stream = request.into_inner(); + + // Create bounded channel to have an upper bound of spawned tasks + // We will have at most `max_parallel_stream_requests` messages from this stream in the queue + let (embed_sender, mut embed_receiver) = mpsc::channel::<( + EmbedRequest, + oneshot::Sender>, + )>(self.max_parallel_stream_requests); + + // Required for the async move below + let local = self.clone(); + + // Background task that uses the bounded channel + tokio::spawn(async move { + while let Some((request, mut sender)) = embed_receiver.recv().await { + // Wait on permit before spawning the task to avoid creating more tasks than needed + let permit = local.infer.acquire_permit().await; + + // Required for the async move below + let task_local = local.clone(); + + // Create async task for this specific input + tokio::spawn(async move { + // Select on closed to cancel work if the stream was closed + tokio::select! { + response = task_local.embed_inner(request, permit) => { + let _ = sender.send(response.map(|(r, _m)| r)); + } + _ = sender.closed() => {} + } + }); + } + }); + + // Intermediate channels + // Required to keep the order of the requests + let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + // Iterate on input + while let Some(request) = request_stream.next().await { + // Create return channel + let (result_sender, result_receiver) = oneshot::channel(); + // Push to intermediate channel and preserve ordering + intermediate_sender + .send(result_receiver) + .expect("`intermediate_receiver` was dropped. This is a bug."); + + match request { + Ok(request) => embed_sender + .send((request, result_sender)) + .await + .expect("`embed_receiver` was dropped. This is a bug."), + Err(status) => { + // Request is malformed + let _ = result_sender.send(Err(status)); + } + }; + } + }); + + // Final channel for the outputs + let (response_sender, response_receiver) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + while let Some(result_receiver) = intermediate_receiver.recv().await { + // Select on closed to cancel work if the stream was closed + tokio::select! { + response = result_receiver => { + let _ = response_sender.send(response.expect("`result_sender` was dropped. 
This is a bug.")); + } + _ = response_sender.closed() => {} + } + } + }); + + Ok(Response::new(UnboundedReceiverStream::new( + response_receiver, + ))) + } +} + +#[tonic::async_trait] +impl grpc::predict_server::Predict for TextEmbeddingsService { + #[instrument(skip_all)] + async fn predict( + &self, + request: Request, + ) -> Result, Status> { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let permit = self + .infer + .try_acquire_permit() + .map_err(ErrorResponse::from)?; + + let request = request.into_inner(); + let (response, metadata) = self.predict_inner(request, permit).await?; + let headers = HeaderMap::from(metadata); + + metrics::increment_counter!("te_request_success", "method" => "single"); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + response, + Extensions::default(), + )) + } + + type PredictStreamStream = UnboundedReceiverStream>; + + #[instrument(skip_all)] + async fn predict_stream( + &self, + request: Request>, + ) -> Result, Status> { + let mut request_stream = request.into_inner(); + + // Create bounded channel to have an upper bound of spawned tasks + // We will have at most `max_parallel_stream_requests` messages from this stream in the queue + let (predict_sender, mut predict_receiver) = mpsc::channel::<( + PredictRequest, + oneshot::Sender>, + )>(self.max_parallel_stream_requests); + + // Required for the async move below + let local = self.clone(); + + // Background task that uses the bounded channel + tokio::spawn(async move { + while let Some((request, mut sender)) = predict_receiver.recv().await { + // Wait on permit before spawning the task to avoid creating more tasks than needed + let permit = local.infer.acquire_permit().await; + + // Required for the async move below + let task_local = local.clone(); + + // Create async task for this specific input + tokio::spawn(async move { + // Select on closed to cancel work if the stream was closed + tokio::select! { + response = task_local.predict_inner(request, permit) => { + let _ = sender.send(response.map(|(r, _m)| r)); + } + _ = sender.closed() => {} + } + }); + } + }); + + // Intermediate channels + // Required to keep the order of the requests + let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + // Iterate on input + while let Some(request) = request_stream.next().await { + // Create return channel + let (result_sender, result_receiver) = oneshot::channel(); + // Push to intermediate channel and preserve ordering + intermediate_sender + .send(result_receiver) + .expect("`intermediate_receiver` was dropped. This is a bug."); + + match request { + Ok(request) => predict_sender + .send((request, result_sender)) + .await + .expect("`predict_receiver` was dropped. This is a bug."), + Err(status) => { + // Request is malformed + let _ = result_sender.send(Err(status)); + } + }; + } + }); + + // Final channel for the outputs + let (response_sender, response_receiver) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + while let Some(result_receiver) = intermediate_receiver.recv().await { + // Select on closed to cancel work if the stream was closed + tokio::select! { + response = result_receiver => { + let _ = response_sender.send(response.expect("`result_sender` was dropped. 
This is a bug.")); + } + _ = response_sender.closed() => {} + } + } + }); + + Ok(Response::new(UnboundedReceiverStream::new( + response_receiver, + ))) + } +} + +#[tonic::async_trait] +impl grpc::rerank_server::Rerank for TextEmbeddingsService { + #[instrument( + skip_all, + fields( + compute_chars, + compute_tokens, + total_time, + tokenization_time, + queue_time, + inference_time, + ) + )] + async fn rerank( + &self, + request: Request, + ) -> Result, Status> { + let span = Span::current(); + let start_time = Instant::now(); + + let request = request.into_inner(); + + match &self.info.model_type { + ModelType::Classifier(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a re-ranker model".to_string(); + tracing::error!("{message}"); + Err(Status::new(Code::FailedPrecondition, message)) + } + ModelType::Reranker(_) => Ok(()), + ModelType::Embedding(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a classifier model".to_string(); + tracing::error!("{message}"); + Err(Status::new(Code::FailedPrecondition, message)) + } + }?; + + // Closure for rerank + let rerank_inner = move |query: String, + text: String, + truncate: bool, + raw_scores: bool, + infer: Infer| async move { + let permit = infer.acquire_permit().await; + + let response = infer + .predict((query, text), truncate, raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; + + let score = response.results[0]; + + Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>(( + response.prompt_tokens, + response.tokenization, + response.queue, + response.inference, + score, + )) + }; + + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = request.texts.len(); + if batch_size > self.info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + self.info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } + + let mut futures = Vec::with_capacity(batch_size); + let query_chars = request.query.chars().count(); + let mut total_compute_chars = query_chars * batch_size; + + for text in &request.texts { + total_compute_chars += text.chars().count(); + let local_infer = self.infer.clone(); + futures.push(rerank_inner( + request.query.clone(), + text.clone(), + request.truncate, + request.raw_scores, + local_infer, + )) + } + let results = join_all(futures) + .await + .into_iter() + .collect::, ErrorResponse>>()?; + + let mut ranks = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for (index, r) in results.into_iter().enumerate() { + total_compute_tokens += r.0; + total_tokenization_time += r.1.as_nanos() as u64; + total_queue_time += r.2.as_nanos() as u64; + total_inference_time += r.3.as_nanos() as u64; + let text = if request.return_text { + Some(request.texts[index].clone()) + } else { + None + }; + + ranks.push(Rank { + index: index as u32, + text, + score: r.4, + }) + } + + // Reverse sort + ranks.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + ranks.reverse(); + + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + let 
response_metadata = ResponseMetadata::new( + total_compute_chars, + total_compute_tokens, + start_time, + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ); + response_metadata.record_span(&span); + response_metadata.record_metrics(); + + let message = RerankResponse { + ranks, + metadata: Some(grpc::Metadata::from(&response_metadata)), + }; + + let headers = HeaderMap::from(response_metadata); + + tracing::info!("Success"); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + message, + Extensions::default(), + )) + } + + #[instrument( + skip_all, + fields( + compute_chars, + compute_tokens, + total_time, + tokenization_time, + queue_time, + inference_time, + ) + )] + async fn rerank_stream( + &self, + request: Request>, + ) -> Result, Status> { + let span = Span::current(); + let start_time = Instant::now(); + + // Check model type + match &self.info.model_type { + ModelType::Classifier(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a re-ranker model".to_string(); + tracing::error!("{message}"); + Err(Status::new(Code::FailedPrecondition, message)) + } + ModelType::Reranker(_) => Ok(()), + ModelType::Embedding(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a classifier model".to_string(); + tracing::error!("{message}"); + Err(Status::new(Code::FailedPrecondition, message)) + } + }?; + + // Closure for rerank + let rerank_inner = move |index: usize, + query: String, + text: String, + truncate: bool, + raw_scores: bool, + infer: Infer, + permit: OwnedSemaphorePermit| async move { + let response = infer + .predict((query, text.clone()), truncate, raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; + + let score = response.results[0]; + + Ok::<(usize, usize, Duration, Duration, Duration, f32, String), ErrorResponse>(( + index, + response.prompt_tokens, + response.tokenization, + response.queue, + response.inference, + score, + text, + )) + }; + + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let mut request_stream = request.into_inner(); + + // Create bounded channel to have an upper bound of spawned tasks + // We will have at most `max_parallel_stream_requests` messages from this stream in the queue + let (rerank_sender, mut rerank_receiver) = mpsc::channel::<( + (usize, String, String, bool, bool), + oneshot::Sender< + Result<(usize, usize, Duration, Duration, Duration, f32, String), ErrorResponse>, + >, + )>(self.max_parallel_stream_requests); + + // Required for the async move below + let local_infer = self.infer.clone(); + + // Background task that uses the bounded channel + tokio::spawn(async move { + while let Some(((index, query, text, truncate, raw_scores), mut sender)) = + rerank_receiver.recv().await + { + // Wait on permit before spawning the task to avoid creating more tasks than needed + let permit = local_infer.acquire_permit().await; + + // Required for the async move below + let task_infer = local_infer.clone(); + + // Create async task for this specific input + tokio::spawn(async move { + // Select on closed to cancel work if the stream was closed + tokio::select! 
{ + result = rerank_inner(index, query, text, truncate, raw_scores, task_infer, permit) => { + let _ = sender.send(result); + } + _ = sender.closed() => {} + } + }); + } + }); + + let mut index = 0; + let mut total_compute_chars = 0; + + // Set by first request + let mut raw_scores = None; + let mut return_text = None; + + // Intermediate channels + // Required to keep the order of the requests + let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel(); + + while let Some(request) = request_stream.next().await { + let request = request?; + + // Create return channel + let (result_sender, result_receiver) = oneshot::channel(); + // Push to intermediate channel and preserve ordering + intermediate_sender + .send(result_receiver) + .expect("`intermediate_receiver` was dropped. This is a bug."); + + // Set `raw_scores` and `return_text` using the values in the first request + if raw_scores.is_none() && return_text.is_none() { + raw_scores = Some(request.raw_scores); + return_text = Some(request.return_text); + } + + total_compute_chars += request.query.chars().count(); + total_compute_chars += request.text.chars().count(); + + rerank_sender + .send(( + ( + index, + request.query, + request.text, + request.truncate, + raw_scores.unwrap(), + ), + result_sender, + )) + .await + .expect("`rerank_receiver` was dropped. This is a bug."); + + index += 1; + } + + // Drop the sender to signal to the underlying task that we are done + drop(rerank_sender); + + let batch_size = index; + + let mut ranks = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + // Iterate on result stream + while let Some(result_receiver) = intermediate_receiver.recv().await { + let r = result_receiver + .await + .expect("`result_sender` was dropped. 
This is a bug.")?; + + total_compute_tokens += r.1; + total_tokenization_time += r.2.as_nanos() as u64; + total_queue_time += r.3.as_nanos() as u64; + total_inference_time += r.4.as_nanos() as u64; + let text = if return_text.unwrap() { + Some(r.6) + } else { + None + }; + + ranks.push(Rank { + index: r.0 as u32, + text, + score: r.5, + }) + } + + // Check that the outputs have the correct size + if ranks.len() < batch_size { + let message = "rerank results is missing values".to_string(); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Backend, + }; + metrics::increment_counter!("te_request_failure", "err" => "missing_values"); + Err(err)?; + } + + // Reverse sort + ranks.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + ranks.reverse(); + + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + let response_metadata = ResponseMetadata::new( + total_compute_chars, + total_compute_tokens, + start_time, + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ); + response_metadata.record_span(&span); + response_metadata.record_metrics(); + + let message = RerankResponse { + ranks, + metadata: Some(grpc::Metadata::from(&response_metadata)), + }; + + let headers = HeaderMap::from(response_metadata); + + tracing::info!("Success"); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + message, + Extensions::default(), + )) + } +} + +pub async fn run( + infer: Infer, + info: Info, + addr: SocketAddr, + prom_builder: PrometheusBuilder, +) -> Result<(), anyhow::Error> { + prom_builder.install()?; + tracing::info!("Serving Prometheus metrics: 0.0.0.0:9000"); + + // Liveness service + let (mut health_reporter, health_service) = tonic_health::server::health_reporter(); + // Info is always serving + health_reporter + .set_serving::>() + .await; + // Set all other services to not serving + // Their health will be updated in the task below + health_reporter + .set_not_serving::>() + .await; + health_reporter + .set_not_serving::>() + .await; + health_reporter + .set_not_serving::>() + .await; + + // Backend health watcher + let mut health_watcher = infer.health_watcher(); + + // Clone model_type and move it to the task + let health_watcher_model_type = info.model_type.clone(); + + // Update services health + tokio::spawn(async move { + while health_watcher.changed().await.is_ok() { + let health = *health_watcher.borrow_and_update(); + let status = match health { + true => ServingStatus::Serving, + false => ServingStatus::NotServing, + }; + + // Match on model type and set the health of the correct service(s) + // + // If Reranker, we have both a predict and rerank service + // + // This logic hints back to the user that if they try using the wrong service + // given the model type, it will always return an error. + // + // For example if the model type is `Embedding`, sending requests to `Rerank` will + // always return an `UNIMPLEMENTED` Status and both the `Rerank` and `Predict` services + // will have a `NOT_SERVING` ServingStatus. 
+ match health_watcher_model_type { + ModelType::Classifier(_) => { + health_reporter + .set_service_status( + >::NAME, + status, + ) + .await + } + ModelType::Embedding(_) => { + health_reporter + .set_service_status( + >::NAME, + status, + ) + .await + } + ModelType::Reranker(_) => { + // Reranker has both a predict and rerank service + health_reporter + .set_service_status( + >::NAME, + status, + ) + .await; + health_reporter + .set_service_status( + >::NAME, + status, + ) + .await; + } + }; + } + }); + + // gRPC reflection + let file_descriptor_set: &[u8] = tonic::include_file_descriptor_set!("descriptor"); + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(file_descriptor_set) + .build()?; + + // Main service + let service = TextEmbeddingsService::new(infer, info); + + // Create gRPC server + tracing::info!("Starting gRPC server: {}", &addr); + Server::builder() + .add_service(health_service) + .add_service(reflection_service) + .add_service(grpc::InfoServer::new(service.clone())) + .add_service(grpc::EmbedServer::new(service.clone())) + .add_service(grpc::PredictServer::new(service.clone())) + .add_service(grpc::RerankServer::new(service)) + .serve_with_shutdown(addr, shutdown::shutdown_signal()) + .await?; + + Ok(()) +} + +impl From for Status { + fn from(value: ErrorResponse) -> Self { + let code = match value.error_type { + ErrorType::Unhealthy => Code::Unavailable, + ErrorType::Backend => Code::FailedPrecondition, + ErrorType::Overloaded => Code::ResourceExhausted, + ErrorType::Validation => Code::InvalidArgument, + ErrorType::Tokenizer => Code::FailedPrecondition, + }; + + Status::new(code, value.error) + } +} diff --git a/router/src/http/mod.rs b/router/src/http/mod.rs new file mode 100644 index 00000000..3214bbf0 --- /dev/null +++ b/router/src/http/mod.rs @@ -0,0 +1,2 @@ +pub mod server; +mod types; diff --git a/router/src/http/server.rs b/router/src/http/server.rs new file mode 100644 index 00000000..bd67a5bd --- /dev/null +++ b/router/src/http/server.rs @@ -0,0 +1,888 @@ +/// HTTP Server logic +use crate::http::types::{ + EmbedRequest, EmbedResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse, + OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest, + PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, +}; +use crate::{ + shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType, + ResponseMetadata, +}; +use anyhow::Context; +use axum::extract::Extension; +use axum::http::HeaderValue; +use axum::http::{HeaderMap, Method, StatusCode}; +use axum::routing::{get, post}; +use axum::{http, Json, Router}; +use axum_tracing_opentelemetry::middleware::OtelAxumLayer; +use futures::future::join_all; +use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; +use std::env; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; +use text_embeddings_backend::BackendError; +use text_embeddings_core::infer::{Infer, InferResponse}; +use text_embeddings_core::TextEmbeddingsError; +use tokio::sync::OwnedSemaphorePermit; +use tower_http::cors::{AllowOrigin, CorsLayer}; +use tracing::instrument; +use utoipa::OpenApi; +use utoipa_swagger_ui::SwaggerUi; + +///Text Embeddings Inference endpoint info +#[utoipa::path( +get, +tag = "Text Embeddings Inference", +path = "/info", +responses((status = 200, description = "Served model info", body = Info)) +)] +#[instrument] +async fn get_model_info(info: Extension) -> Json 
{ + Json(info.0) +} + +#[utoipa::path( +get, +tag = "Text Embeddings Inference", +path = "/health", +responses( +(status = 200, description = "Everything is working fine"), +(status = 503, description = "Text embeddings Inference is down", body = ErrorResponse, +example = json ! ({"error": "unhealthy", "error_type": "unhealthy"})), +) +)] +#[instrument(skip(infer))] +/// Health check method +async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { + match infer.health().await { + true => Ok(()), + false => Err(ErrorResponse { + error: "unhealthy".to_string(), + error_type: ErrorType::Unhealthy, + })?, + } +} + +/// Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model +#[utoipa::path( +post, +tag = "Text Embeddings Inference", +path = "/predict", +request_body = PredictRequest, +responses( +(status = 200, description = "Predictions", body = PredictResponse), +(status = 424, description = "Prediction Error", body = ErrorResponse, +example = json ! ({"error": "Inference failed", "error_type": "backend"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})), +(status = 422, description = "Tokenization error", body = ErrorResponse, +example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})), +(status = 413, description = "Batch size error", body = ErrorResponse, +example = json ! ({"error": "Batch size error", "error_type": "validation"})), +) +)] +#[instrument( + skip_all, + fields(total_time, tokenization_time, queue_time, inference_time,) +)] +async fn predict( + infer: Extension, + info: Extension, + Json(req): Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + + // Closure for predict + let predict_inner = move |inputs: Sequence, + truncate: bool, + raw_scores: bool, + infer: Infer, + info: Info, + permit: Option| async move { + let permit = match permit { + None => infer.acquire_permit().await, + Some(permit) => permit, + }; + + let response = infer + .predict(inputs, truncate, raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; + + let id2label = match &info.model_type { + ModelType::Classifier(classifier) => &classifier.id2label, + ModelType::Reranker(classifier) => &classifier.id2label, + _ => panic!(), + }; + + let mut predictions: Vec = { + // Map score to label + response + .results + .into_iter() + .enumerate() + .map(|(i, s)| Prediction { + score: s, + label: id2label.get(&i.to_string()).unwrap().clone(), + }) + .collect() + }; + // Reverse sort + predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + predictions.reverse(); + + Ok::<(usize, Duration, Duration, Duration, Vec), ErrorResponse>(( + response.prompt_tokens, + response.tokenization, + response.queue, + response.inference, + predictions, + )) + }; + + let (response, metadata) = match req.inputs { + PredictInput::Single(inputs) => { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let compute_chars = inputs.count_chars(); + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner( + inputs, + req.truncate, + req.raw_scores, + infer.0, + info.0, + Some(permit), + ) + .await?; + + metrics::increment_counter!("te_request_success", "method" => "single"); + + ( + PredictResponse::Single(predictions), + 
ResponseMetadata::new( + compute_chars, + prompt_tokens, + start_time, + tokenization, + queue, + inference, + ), + ) + } + PredictInput::Batch(inputs) => { + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = inputs.len(); + if batch_size > info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } + + let mut futures = Vec::with_capacity(batch_size); + let mut compute_chars = 0; + + for input in inputs { + compute_chars += input.count_chars(); + let local_infer = infer.clone(); + let local_info = info.clone(); + futures.push(predict_inner( + input, + req.truncate, + req.raw_scores, + local_infer.0, + local_info.0, + None, + )) + } + let results = join_all(futures).await.into_iter().collect::)>, + ErrorResponse, + >>()?; + + let mut predictions = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for r in results { + total_compute_tokens += r.0; + total_tokenization_time += r.1.as_nanos() as u64; + total_queue_time += r.2.as_nanos() as u64; + total_inference_time += r.3.as_nanos() as u64; + predictions.push(r.4); + } + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + ( + PredictResponse::Batch(predictions), + ResponseMetadata::new( + compute_chars, + total_compute_tokens, + start_time, + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ), + ) + } + }; + + metadata.record_span(&span); + metadata.record_metrics(); + + let headers = HeaderMap::from(metadata); + + tracing::info!("Success"); + + Ok((headers, Json(response))) +} + +/// Get Ranks. Returns a 424 status code if the model is not a Sequence Classification model with +/// a single class. +#[utoipa::path( +post, +tag = "Text Embeddings Inference", +path = "/rerank", +request_body = RerankRequest, +responses( +(status = 200, description = "Ranks", body = RerankResponse), +(status = 424, description = "Rerank Error", body = ErrorResponse, +example = json ! ({"error": "Inference failed", "error_type": "backend"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})), +(status = 422, description = "Tokenization error", body = ErrorResponse, +example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})), +(status = 413, description = "Batch size error", body = ErrorResponse, +example = json ! 
({"error": "Batch size error", "error_type": "validation"})), +) +)] +#[instrument( + skip_all, + fields(total_time, tokenization_time, queue_time, inference_time,) +)] +async fn rerank( + infer: Extension, + info: Extension, + Json(req): Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + + match &info.model_type { + ModelType::Classifier(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a re-ranker model".to_string(); + Err(TextEmbeddingsError::Backend(BackendError::Inference( + message, + ))) + } + ModelType::Reranker(_) => Ok(()), + ModelType::Embedding(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a classifier model".to_string(); + Err(TextEmbeddingsError::Backend(BackendError::Inference( + message, + ))) + } + } + .map_err(|err| { + tracing::error!("{err}"); + ErrorResponse::from(err) + })?; + + // Closure for rerank + let rerank_inner = move |query: String, + text: String, + truncate: bool, + raw_scores: bool, + infer: Infer| async move { + let permit = infer.acquire_permit().await; + + let response = infer + .predict((query, text), truncate, raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; + + let score = response.results[0]; + + Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>(( + response.prompt_tokens, + response.tokenization, + response.queue, + response.inference, + score, + )) + }; + + let (response, metadata) = { + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = req.texts.len(); + if batch_size > info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } + + let mut futures = Vec::with_capacity(batch_size); + let query_chars = req.query.chars().count(); + let mut compute_chars = query_chars * batch_size; + + for text in &req.texts { + compute_chars += text.chars().count(); + let local_infer = infer.clone(); + futures.push(rerank_inner( + req.query.clone(), + text.clone(), + req.truncate, + req.raw_scores, + local_infer.0, + )) + } + let results = join_all(futures) + .await + .into_iter() + .collect::, ErrorResponse>>()?; + + let mut ranks = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for (index, r) in results.into_iter().enumerate() { + total_compute_tokens += r.0; + total_tokenization_time += r.1.as_nanos() as u64; + total_queue_time += r.2.as_nanos() as u64; + total_inference_time += r.3.as_nanos() as u64; + let text = if req.return_text { + Some(req.texts[index].clone()) + } else { + None + }; + + ranks.push(Rank { + index, + text, + score: r.4, + }) + } + + // Reverse sort + ranks.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + ranks.reverse(); + + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + ( + RerankResponse(ranks), + ResponseMetadata::new( + compute_chars, + total_compute_tokens, + start_time, + Duration::from_nanos(total_tokenization_time / batch_size), + 
Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ), + ) + }; + + metadata.record_span(&span); + metadata.record_metrics(); + + let headers = HeaderMap::from(metadata); + + tracing::info!("Success"); + + Ok((headers, Json(response))) +} + +/// Get Embeddings. Returns a 424 status code if the model is not an embedding model. +#[utoipa::path( +post, +tag = "Text Embeddings Inference", +path = "/embed", +request_body = EmbedRequest, +responses( +(status = 200, description = "Embeddings", body = EmbedResponse), +(status = 424, description = "Embedding Error", body = ErrorResponse, +example = json ! ({"error": "Inference failed", "error_type": "backend"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})), +(status = 422, description = "Tokenization error", body = ErrorResponse, +example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})), +(status = 413, description = "Batch size error", body = ErrorResponse, +example = json ! ({"error": "Batch size error", "error_type": "validation"})), +) +)] +#[instrument( + skip_all, + fields(total_time, tokenization_time, queue_time, inference_time,) +)] +async fn embed( + infer: Extension, + info: Extension, + Json(req): Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + + let (response, metadata) = match req.inputs { + Input::Single(input) => { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let compute_chars = input.chars().count(); + + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let response = infer + .embed(input, req.truncate, req.normalize, permit) + .await + .map_err(ErrorResponse::from)?; + + metrics::increment_counter!("te_request_success", "method" => "single"); + + ( + EmbedResponse(vec![response.results]), + ResponseMetadata::new( + compute_chars, + response.prompt_tokens, + start_time, + response.tokenization, + response.queue, + response.inference, + ), + ) + } + Input::Batch(inputs) => { + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = inputs.len(); + if batch_size > info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } + + let mut futures = Vec::with_capacity(batch_size); + let mut compute_chars = 0; + + for input in inputs { + compute_chars += input.chars().count(); + + let local_infer = infer.clone(); + futures.push(async move { + let permit = local_infer.acquire_permit().await; + local_infer + .embed(input, req.truncate, req.normalize, permit) + .await + }) + } + let results = join_all(futures) + .await + .into_iter() + .collect::, TextEmbeddingsError>>() + .map_err(ErrorResponse::from)?; + + let mut embeddings = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for r in results { + total_tokenization_time += r.tokenization.as_nanos() as u64; + total_queue_time += r.queue.as_nanos() as u64; + total_inference_time += 
r.inference.as_nanos() as u64; + total_compute_tokens += r.prompt_tokens; + embeddings.push(r.results); + } + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + ( + EmbedResponse(embeddings), + ResponseMetadata::new( + compute_chars, + total_compute_tokens, + start_time, + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ), + ) + } + }; + + metadata.record_span(&span); + metadata.record_metrics(); + + let headers = HeaderMap::from(metadata); + + tracing::info!("Success"); + + Ok((headers, Json(response))) +} + +/// OpenAI compatible route. Returns a 424 status code if the model is not an embedding model. +#[utoipa::path( +post, +tag = "Text Embeddings Inference", +path = "/embeddings", +request_body = OpenAICompatRequest, +responses( +(status = 200, description = "Embeddings", body = OpenAICompatResponse), +(status = 424, description = "Embedding Error", body = OpenAICompatErrorResponse, +example = json ! ({"message": "Inference failed", "type": "backend"})), +(status = 429, description = "Model is overloaded", body = OpenAICompatErrorResponse, +example = json ! ({"message": "Model is overloaded", "type": "overloaded"})), +(status = 422, description = "Tokenization error", body = OpenAICompatErrorResponse, +example = json ! ({"message": "Tokenization error", "type": "tokenizer"})), +(status = 413, description = "Batch size error", body = OpenAICompatErrorResponse, +example = json ! ({"message": "Batch size error", "type": "validation"})), +) +)] +#[instrument( + skip_all, + fields(total_time, tokenization_time, queue_time, inference_time,) +)] +async fn openai_embed( + infer: Extension, + info: Extension, + Json(req): Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> +{ + let span = tracing::Span::current(); + let start_time = Instant::now(); + + let (embeddings, metadata) = match req.input { + Input::Single(input) => { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let compute_chars = input.chars().count(); + + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let response = infer + .embed(input, false, true, permit) + .await + .map_err(ErrorResponse::from)?; + + metrics::increment_counter!("te_request_success", "method" => "single"); + + ( + vec![OpenAICompatEmbedding { + object: "embedding", + embedding: response.results, + index: 0, + }], + ResponseMetadata::new( + compute_chars, + response.prompt_tokens, + start_time, + response.tokenization, + response.queue, + response.inference, + ), + ) + } + Input::Batch(inputs) => { + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = inputs.len(); + if batch_size > info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } + + let mut futures = Vec::with_capacity(batch_size); + let mut compute_chars = 0; + + for input in inputs { + compute_chars += input.chars().count(); + + let local_infer = infer.clone(); + futures.push(async move { + let permit = local_infer.acquire_permit().await; + local_infer.embed(input, false, true, permit).await + }) + } + let results 
= join_all(futures) + .await + .into_iter() + .collect::, TextEmbeddingsError>>() + .map_err(ErrorResponse::from)?; + + let mut embeddings = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for (i, r) in results.into_iter().enumerate() { + total_tokenization_time += r.tokenization.as_nanos() as u64; + total_queue_time += r.queue.as_nanos() as u64; + total_inference_time += r.inference.as_nanos() as u64; + total_compute_tokens += r.prompt_tokens; + embeddings.push(OpenAICompatEmbedding { + object: "embedding", + embedding: r.results, + index: i, + }); + } + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + ( + embeddings, + ResponseMetadata::new( + compute_chars, + total_compute_tokens, + start_time, + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ), + ) + } + }; + + metadata.record_span(&span); + metadata.record_metrics(); + + let compute_tokens = metadata.compute_tokens; + let headers = HeaderMap::from(metadata); + + tracing::info!("Success"); + + let response = OpenAICompatResponse { + object: "list", + data: embeddings, + model: info.model_id.clone(), + usage: OpenAICompatUsage { + prompt_tokens: compute_tokens, + total_tokens: compute_tokens, + }, + }; + Ok((headers, Json(response))) +} + +/// Prometheus metrics scrape endpoint +#[utoipa::path( +get, +tag = "Text Embeddings Inference", +path = "/metrics", +responses((status = 200, description = "Prometheus Metrics", body = String)) +)] +async fn metrics(prom_handle: Extension) -> String { + prom_handle.render() +} + +/// Serving method +pub async fn run( + infer: Infer, + info: Info, + addr: SocketAddr, + prom_builder: PrometheusBuilder, +) -> Result<(), anyhow::Error> { + // OpenAPI documentation + #[derive(OpenApi)] + #[openapi( + paths( + get_model_info, + health, + predict, + rerank, + embed, + openai_embed, + metrics, + ), + components( + schemas( + PredictInput, + Input, + Info, + ModelType, + ClassifierModel, + EmbeddingModel, + PredictRequest, + Prediction, + PredictResponse, + OpenAICompatRequest, + OpenAICompatEmbedding, + OpenAICompatUsage, + OpenAICompatResponse, + RerankRequest, + Rank, + RerankResponse, + EmbedRequest, + EmbedResponse, + ErrorResponse, + OpenAICompatErrorResponse, + ErrorType, + ) + ), + tags( + (name = "Text Embeddings Inference", description = "Hugging Face Text Embeddings Inference API") + ), + info( + title = "Text Embeddings Inference", + license( + name = "HFOIL", + ) + ) + )] + struct ApiDoc; + + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, convert to AllowOrigin + let allow_origin: Option = + env::var("CORS_ALLOW_ORIGIN").ok().map(|cors_allow_origin| { + let cors_allow_origin = cors_allow_origin.split(','); + AllowOrigin::list( + cors_allow_origin.map(|origin| origin.parse::().unwrap()), + ) + }); + + let prom_handle = prom_builder + .install_recorder() + .context("failed to install metrics recorder")?; + + // CORS layer + let allow_origin = allow_origin.unwrap_or(AllowOrigin::any()); + let cors_layer = CorsLayer::new() + .allow_methods([Method::GET, Method::POST]) + .allow_headers([http::header::CONTENT_TYPE]) + .allow_origin(allow_origin); + + // Create router + let app = Router::new() + 
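
The `Router` chain that follows registers `/embed`, `/predict`, `/rerank`, `/embeddings`, `/health` and `/metrics`. As a hedged client-side usage sketch (not part of this PR): `reqwest`, `tokio` and the `localhost:8080` address are assumptions for illustration; the body mirrors `EmbedRequest` from `router/src/http/types.rs` and the response shape mirrors `EmbedResponse`.

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();
    // Fields follow EmbedRequest: `truncate` defaults to false and
    // `normalize` defaults to true when omitted.
    let body = json!({
        "inputs": "What is Deep Learning?",
        "truncate": false,
        "normalize": true
    });
    // EmbedResponse serializes as a plain array of embeddings.
    let embeddings: Vec<Vec<f32>> = client
        .post("http://localhost:8080/embed") // assumed address and port
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    println!("got {} embedding(s)", embeddings.len());
    Ok(())
}
```
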
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) + // Base routes + .route("/info", get(get_model_info)) + .route("/embed", post(embed)) + .route("/predict", post(predict)) + .route("/rerank", post(rerank)) + // OpenAI compat route + .route("/embeddings", post(openai_embed)) + // Base Health route + .route("/health", get(health)) + // Inference API health route + .route("/", get(health)) + // AWS Sagemaker health route + .route("/ping", get(health)) + // Prometheus metrics route + .route("/metrics", get(metrics)); + + // Set default routes + let app = match &info.model_type { + ModelType::Classifier(_) => { + app.route("/", post(predict)) + // AWS Sagemaker route + .route("/invocations", post(predict)) + } + ModelType::Reranker(_) => { + app.route("/", post(rerank)) + // AWS Sagemaker route + .route("/invocations", post(rerank)) + } + ModelType::Embedding(_) => { + app.route("/", post(embed)) + // AWS Sagemaker route + .route("/invocations", post(embed)) + } + }; + + let app = app + .layer(Extension(infer)) + .layer(Extension(info)) + .layer(Extension(prom_handle.clone())) + .layer(OtelAxumLayer::default()) + .layer(cors_layer); + + // Run server + axum::Server::bind(&addr) + .serve(app.into_make_service()) + // Wait until all requests are finished to shut down + .with_graceful_shutdown(shutdown::shutdown_signal()) + .await?; + + Ok(()) +} + +impl From<&ErrorType> for StatusCode { + fn from(value: &ErrorType) -> Self { + match value { + ErrorType::Unhealthy => StatusCode::SERVICE_UNAVAILABLE, + ErrorType::Backend => StatusCode::FAILED_DEPENDENCY, + ErrorType::Overloaded => StatusCode::TOO_MANY_REQUESTS, + ErrorType::Tokenizer => StatusCode::UNPROCESSABLE_ENTITY, + ErrorType::Validation => StatusCode::PAYLOAD_TOO_LARGE, + } + } +} + +impl From for OpenAICompatErrorResponse { + fn from(value: ErrorResponse) -> Self { + OpenAICompatErrorResponse { + message: value.error, + code: StatusCode::from(&value.error_type).as_u16(), + error_type: value.error_type, + } + } +} + +/// Convert to Axum supported formats +impl From for (StatusCode, Json) { + fn from(err: ErrorResponse) -> Self { + (StatusCode::from(&err.error_type), Json(err)) + } +} + +impl From for (StatusCode, Json) { + fn from(err: ErrorResponse) -> Self { + (StatusCode::from(&err.error_type), Json(err.into())) + } +} diff --git a/router/src/http/types.rs b/router/src/http/types.rs new file mode 100644 index 00000000..dbdba0f1 --- /dev/null +++ b/router/src/http/types.rs @@ -0,0 +1,324 @@ +use crate::ErrorType; +use serde::de::{SeqAccess, Visitor}; +use serde::{de, Deserialize, Deserializer, Serialize}; +use serde_json::json; +use std::fmt::Formatter; +use text_embeddings_core::tokenization::EncodingInput; +use utoipa::openapi::{RefOr, Schema}; +use utoipa::ToSchema; + +#[derive(Debug)] +pub(crate) enum Sequence { + Single(String), + Pair(String, String), +} + +impl Sequence { + pub(crate) fn count_chars(&self) -> usize { + match self { + Sequence::Single(s) => s.chars().count(), + Sequence::Pair(s1, s2) => s1.chars().count() + s2.chars().count(), + } + } +} + +impl From for EncodingInput { + fn from(value: Sequence) -> Self { + match value { + Sequence::Single(s) => Self::Single(s), + Sequence::Pair(s1, s2) => Self::Dual(s1, s2), + } + } +} + +#[derive(Debug)] +pub(crate) enum PredictInput { + Single(Sequence), + Batch(Vec), +} + +impl<'de> Deserialize<'de> for PredictInput { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(untagged)] + 
enum Internal { + Single(String), + Multiple(Vec), + } + + struct PredictInputVisitor; + + impl<'de> Visitor<'de> for PredictInputVisitor { + type Value = PredictInput; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str( + "a string, \ + a pair of strings [string, string] \ + or a batch of mixed strings and pairs [[string], [string, string], ...]", + ) + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + Ok(PredictInput::Single(Sequence::Single(v.to_string()))) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let sequence_from_vec = |mut value: Vec| { + // Validate that value is correct + match value.len() { + 1 => Ok(Sequence::Single(value.pop().unwrap())), + 2 => { + // Second element is last + let second = value.pop().unwrap(); + let first = value.pop().unwrap(); + Ok(Sequence::Pair(first, second)) + } + // Sequence can only be a single string or a pair of strings + _ => Err(de::Error::invalid_length(value.len(), &self)), + } + }; + + // Get first element + // This will determine if input is a batch or not + let s = match seq + .next_element::()? + .ok_or_else(|| de::Error::invalid_length(0, &self))? + { + // Input is not a batch + // Return early + Internal::Single(value) => { + // Option get second element + let second = seq.next_element()?; + + if seq.next_element::()?.is_some() { + // Error as we do not accept > 2 elements + return Err(de::Error::invalid_length(3, &self)); + } + + if let Some(second) = second { + // Second element exists + // This is a pair + return Ok(PredictInput::Single(Sequence::Pair(value, second))); + } else { + // Second element does not exist + return Ok(PredictInput::Single(Sequence::Single(value))); + } + } + // Input is a batch + Internal::Multiple(value) => sequence_from_vec(value), + }?; + + let mut batch = Vec::with_capacity(32); + // Push first sequence + batch.push(s); + + // Iterate on all sequences + while let Some(value) = seq.next_element::>()? { + // Validate sequence + let s = sequence_from_vec(value)?; + // Push to batch + batch.push(s); + } + Ok(PredictInput::Batch(batch)) + } + } + + deserializer.deserialize_any(PredictInputVisitor) + } +} + +impl<'__s> ToSchema<'__s> for PredictInput { + fn schema() -> (&'__s str, RefOr) { + ( + "PredictInput", + utoipa::openapi::OneOfBuilder::new() + .item( + utoipa::openapi::ObjectBuilder::new() + .schema_type(utoipa::openapi::SchemaType::String) + .description(Some("A single string")), + ) + .item( + utoipa::openapi::ArrayBuilder::new() + .items( + utoipa::openapi::ObjectBuilder::new() + .schema_type(utoipa::openapi::SchemaType::String), + ) + .description(Some("A pair of strings")) + .min_items(Some(2)) + .max_items(Some(2)), + ) + .item( + utoipa::openapi::ArrayBuilder::new().items( + utoipa::openapi::OneOfBuilder::new() + .item( + utoipa::openapi::ArrayBuilder::new() + .items( + utoipa::openapi::ObjectBuilder::new() + .schema_type(utoipa::openapi::SchemaType::String), + ) + .description(Some("A single string")) + .min_items(Some(1)) + .max_items(Some(1)), + ) + .item( + utoipa::openapi::ArrayBuilder::new() + .items( + utoipa::openapi::ObjectBuilder::new() + .schema_type(utoipa::openapi::SchemaType::String), + ) + .description(Some("A pair of strings")) + .min_items(Some(2)) + .max_items(Some(2)), + ) + ).description(Some("A batch")), + ) + .description(Some( + "Model input. 
\ + Can be either a single string, a pair of strings or a batch of mixed single and pairs \ + of strings.", + )) + .example(Some(json!("What is Deep Learning?"))) + .into(), + ) + } +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct PredictRequest { + pub inputs: PredictInput, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub truncate: bool, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub raw_scores: bool, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct Prediction { + #[schema(example = "0.5")] + pub score: f32, + #[schema(example = "admiration")] + pub label: String, +} + +#[derive(Serialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum PredictResponse { + Single(Vec), + Batch(Vec>), +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct RerankRequest { + #[schema(example = "What is Deep Learning?")] + pub query: String, + #[schema(example = json!(["Deep Learning is ..."]))] + pub texts: Vec, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub truncate: bool, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub raw_scores: bool, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub return_text: bool, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct Rank { + #[schema(example = "0")] + pub index: usize, + #[schema(nullable = true, example = "Deep Learning is ...", default = "null")] + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[schema(example = "1.0")] + pub score: f32, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct RerankResponse(pub Vec); + +#[derive(Deserialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum Input { + Single(String), + Batch(Vec), +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct OpenAICompatRequest { + pub input: Input, + #[allow(dead_code)] + #[schema(nullable = true, example = "null")] + pub model: Option, + #[allow(dead_code)] + #[schema(nullable = true, example = "null")] + pub user: Option, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct OpenAICompatEmbedding { + #[schema(example = "embedding")] + pub object: &'static str, + #[schema(example = json!([0.0, 1.0, 2.0]))] + pub embedding: Vec, + #[schema(example = "0")] + pub index: usize, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct OpenAICompatUsage { + #[schema(example = "512")] + pub prompt_tokens: usize, + #[schema(example = "512")] + pub total_tokens: usize, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct OpenAICompatResponse { + #[schema(example = "list")] + pub object: &'static str, + pub data: Vec, + #[schema(example = "thenlper/gte-base")] + pub model: String, + pub usage: OpenAICompatUsage, +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct EmbedRequest { + pub inputs: Input, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub truncate: bool, + #[serde(default = "default_normalize")] + #[schema(default = "true", example = "true")] + pub normalize: bool, +} + +fn default_normalize() -> bool { + true +} + +#[derive(Serialize, ToSchema)] +#[schema(example = json!([[0.0, 1.0, 2.0]]))] +pub(crate) struct EmbedResponse(pub Vec>); + +#[derive(Serialize, ToSchema)] +pub(crate) struct OpenAICompatErrorResponse { + pub message: String, + pub code: u16, + #[serde(rename(serialize = "type"))] + pub error_type: ErrorType, +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 9822310d..d7164707 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,395 
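
For reference, the hand-written `PredictInput` deserializer in `router/src/http/types.rs` above accepts three JSON shapes for the `/predict` route; a hedged illustration follows (the `inputs` values are examples only):

```rust
use serde_json::json;

fn main() {
    // A single string -> PredictInput::Single(Sequence::Single(..))
    let _single = json!({ "inputs": "What is Deep Learning?" });

    // A pair of strings -> PredictInput::Single(Sequence::Pair(..))
    let _pair = json!({ "inputs": ["What is Deep Learning?", "Deep Learning is ..."] });

    // A batch mixing singles and pairs -> PredictInput::Batch(..)
    let _batch = json!({
        "inputs": [
            ["What is Deep Learning?"],
            ["What is Deep Learning?", "Deep Learning is ..."]
        ]
    });
}
```
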
+1,251 @@ +use ::http::HeaderMap; /// Text Embedding Inference Webserver -pub mod server; - -use serde::de::{SeqAccess, Visitor}; -use serde::{de, Deserialize, Deserializer, Serialize}; -use serde_json::json; +use anyhow::Result; +use serde::Serialize; use std::collections::HashMap; -use std::fmt::Formatter; -use text_embeddings_core::tokenization::EncodingInput; -use utoipa::openapi::{RefOr, Schema}; -use utoipa::ToSchema; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; +use text_embeddings_core::infer::Infer; +use text_embeddings_core::TextEmbeddingsError; +use tracing::Span; + +mod prometheus; + +#[cfg(feature = "http")] +mod http; + +#[cfg(feature = "grpc")] +mod grpc; +mod shutdown; + +/// Crate entrypoint +pub async fn run(infer: Infer, info: Info, addr: SocketAddr) -> Result<()> { + let prom_builder = prometheus::prometheus_builer(info.max_input_length)?; -#[derive(Clone, Debug, Serialize, ToSchema)] + if cfg!(feature = "http") { + #[cfg(feature = "http")] + { + return http::server::run(infer, info, addr, prom_builder).await; + } + } + + if cfg!(feature = "grpc") { + #[cfg(feature = "grpc")] + { + return grpc::server::run(infer, info, addr, prom_builder).await; + } + } + + anyhow::bail!("You must use one of `http` or `grpc`"); +} + +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct EmbeddingModel { - #[schema(example = "cls")] + #[cfg_attr(feature = "http", schema(example = "cls"))] pub pooling: String, } -#[derive(Clone, Debug, Serialize, ToSchema)] +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct ClassifierModel { - #[schema(example = json!({"0": "LABEL"}))] + #[cfg_attr(feature = "http", schema(example = json!({"0": "LABEL"})))] pub id2label: HashMap, - #[schema(example = json!({"LABEL": 0}))] + #[cfg_attr(feature = "http", schema(example = json!({"LABEL": 0})))] pub label2id: HashMap, } -#[derive(Clone, Debug, Serialize, ToSchema)] +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] #[serde(rename_all = "lowercase")] pub enum ModelType { Classifier(ClassifierModel), Embedding(EmbeddingModel), + Reranker(ClassifierModel), } -#[derive(Clone, Debug, Serialize, ToSchema)] +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct Info { /// Model info - #[schema(example = "thenlper/gte-base")] + #[cfg_attr(feature = "http", schema(example = "thenlper/gte-base"))] pub model_id: String, - #[schema(nullable = true, example = "fca14538aa9956a46526bd1d0d11d69e19b5a101")] + #[cfg_attr( + feature = "http", + schema(nullable = true, example = "fca14538aa9956a46526bd1d0d11d69e19b5a101") + )] pub model_sha: Option, - #[schema(example = "float16")] + #[cfg_attr(feature = "http", schema(example = "float16"))] pub model_dtype: String, pub model_type: ModelType, /// Router Parameters - #[schema(example = "128")] + #[cfg_attr(feature = "http", schema(example = "128"))] pub max_concurrent_requests: usize, - #[schema(example = "512")] + #[cfg_attr(feature = "http", schema(example = "512"))] pub max_input_length: usize, - #[schema(example = "2048")] + #[cfg_attr(feature = "http", schema(example = "2048"))] pub max_batch_tokens: usize, - #[schema(nullable = true, example = "null", default = "null")] + #[cfg_attr( + feature = "http", + schema(nullable = true, example = "null", default = "null") + )] pub max_batch_requests: Option, - #[schema(example = "32")] + #[cfg_attr(feature = "http", 
schema(example = "32"))] pub max_client_batch_size: usize, - #[schema(example = "4")] + #[cfg_attr(feature = "http", schema(example = "4"))] pub tokenization_workers: usize, /// Router Info - #[schema(example = "0.5.0")] + #[cfg_attr(feature = "http", schema(example = "0.5.0"))] pub version: &'static str, - #[schema(nullable = true, example = "null")] + #[cfg_attr(feature = "http", schema(nullable = true, example = "null"))] pub sha: Option<&'static str>, - #[schema(nullable = true, example = "null")] + #[cfg_attr(feature = "http", schema(nullable = true, example = "null"))] pub docker_label: Option<&'static str>, } -#[derive(Debug)] -pub(crate) enum Sequence { - Single(String), - Pair(String, String), +#[derive(Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] +pub enum ErrorType { + Unhealthy, + Backend, + Overloaded, + Validation, + Tokenizer, } -impl Sequence { - pub(crate) fn count_chars(&self) -> usize { - match self { - Sequence::Single(s) => s.chars().count(), - Sequence::Pair(s1, s2) => s1.chars().count() + s2.chars().count(), - } - } +#[derive(Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] +pub struct ErrorResponse { + pub error: String, + pub error_type: ErrorType, } -impl From for EncodingInput { - fn from(value: Sequence) -> Self { - match value { - Sequence::Single(s) => Self::Single(s), - Sequence::Pair(s1, s2) => Self::Dual(s1, s2), +impl From for ErrorResponse { + fn from(err: TextEmbeddingsError) -> Self { + let error_type = match err { + TextEmbeddingsError::Tokenizer(_) => ErrorType::Tokenizer, + TextEmbeddingsError::Validation(_) => ErrorType::Validation, + TextEmbeddingsError::Overloaded(_) => ErrorType::Overloaded, + TextEmbeddingsError::Backend(_) => ErrorType::Backend, + }; + Self { + error: err.to_string(), + error_type, } } } -#[derive(Debug)] -pub(crate) enum PredictInput { - Single(Sequence), - Batch(Vec), -} - -impl<'de> Deserialize<'de> for PredictInput { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(untagged)] - enum Internal { - Single(String), - Multiple(Vec), +struct ResponseMetadata { + compute_chars: usize, + compute_tokens: usize, + start_time: Instant, + tokenization_time: Duration, + queue_time: Duration, + inference_time: Duration, +} + +impl ResponseMetadata { + fn new( + compute_chars: usize, + compute_tokens: usize, + start_time: Instant, + tokenization_time: Duration, + queue_time: Duration, + inference_time: Duration, + ) -> Self { + Self { + compute_chars, + compute_tokens, + start_time, + tokenization_time, + queue_time, + inference_time, } - - struct PredictInputVisitor; - - impl<'de> Visitor<'de> for PredictInputVisitor { - type Value = PredictInput; - - fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { - formatter.write_str( - "a string, \ - a pair of strings [string, string] \ - or a batch of mixed strings and pairs [[string], [string, string], ...]", - ) - } - - fn visit_str(self, v: &str) -> Result - where - E: de::Error, - { - Ok(PredictInput::Single(Sequence::Single(v.to_string()))) - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: SeqAccess<'de>, - { - let sequence_from_vec = |mut value: Vec| { - // Validate that value is correct - match value.len() { - 1 => Ok(Sequence::Single(value.pop().unwrap())), - 2 => { - // Second element is last - let second = value.pop().unwrap(); - let first = value.pop().unwrap(); - Ok(Sequence::Pair(first, second)) - } - // Sequence can only be a single string or a 
pair of strings - _ => Err(de::Error::invalid_length(value.len(), &self)), - } - }; - - // Get first element - // This will determine if input is a batch or not - let s = match seq - .next_element::()? - .ok_or_else(|| de::Error::invalid_length(0, &self))? - { - // Input is not a batch - // Return early - Internal::Single(value) => { - // Option get second element - let second = seq.next_element()?; - - if seq.next_element::()?.is_some() { - // Error as we do not accept > 2 elements - return Err(de::Error::invalid_length(3, &self)); - } - - if let Some(second) = second { - // Second element exists - // This is a pair - return Ok(PredictInput::Single(Sequence::Pair(value, second))); - } else { - // Second element does not exist - return Ok(PredictInput::Single(Sequence::Single(value))); - } - } - // Input is a batch - Internal::Multiple(value) => sequence_from_vec(value), - }?; - - let mut batch = Vec::with_capacity(32); - // Push first sequence - batch.push(s); - - // Iterate on all sequences - while let Some(value) = seq.next_element::>()? { - // Validate sequence - let s = sequence_from_vec(value)?; - // Push to batch - batch.push(s); - } - Ok(PredictInput::Batch(batch)) - } - } - - deserializer.deserialize_any(PredictInputVisitor) } -} -impl<'__s> ToSchema<'__s> for PredictInput { - fn schema() -> (&'__s str, RefOr) { - ( - "PredictInput", - utoipa::openapi::OneOfBuilder::new() - .item( - utoipa::openapi::ObjectBuilder::new() - .schema_type(utoipa::openapi::SchemaType::String) - .description(Some("A single string")), - ) - .item( - utoipa::openapi::ArrayBuilder::new() - .items( - utoipa::openapi::ObjectBuilder::new() - .schema_type(utoipa::openapi::SchemaType::String), - ) - .description(Some("A pair of strings")) - .min_items(Some(2)) - .max_items(Some(2)), - ) - .item( - utoipa::openapi::ArrayBuilder::new().items( - utoipa::openapi::OneOfBuilder::new() - .item( - utoipa::openapi::ArrayBuilder::new() - .items( - utoipa::openapi::ObjectBuilder::new() - .schema_type(utoipa::openapi::SchemaType::String), - ) - .description(Some("A single string")) - .min_items(Some(1)) - .max_items(Some(1)), - ) - .item( - utoipa::openapi::ArrayBuilder::new() - .items( - utoipa::openapi::ObjectBuilder::new() - .schema_type(utoipa::openapi::SchemaType::String), - ) - .description(Some("A pair of strings")) - .min_items(Some(2)) - .max_items(Some(2)), - ) - ).description(Some("A batch")), - ) - .description(Some( - "Model input. 
\ - Can be either a single string, a pair of strings or a batch of mixed single and pairs \ - of strings.", - )) - .example(Some(json!("What is Deep Learning?"))) - .into(), - ) + fn record_span(&self, span: &Span) { + // Tracing metadata + span.record("compute_chars", self.compute_chars); + span.record("compute_tokens", self.compute_tokens); + span.record("total_time", format!("{:?}", self.start_time.elapsed())); + span.record("tokenization_time", format!("{:?}", self.tokenization_time)); + span.record("queue_time", format!("{:?}", self.queue_time)); + span.record("inference_time", format!("{:?}", self.inference_time)); } -} -#[derive(Deserialize, ToSchema)] -pub(crate) struct PredictRequest { - pub inputs: PredictInput, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub truncate: bool, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub raw_scores: bool, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct Prediction { - #[schema(example = "0.5")] - score: f32, - #[schema(example = "admiration")] - label: String, -} - -#[derive(Serialize, ToSchema)] -#[serde(untagged)] -pub(crate) enum PredictResponse { - Single(Vec), - Batch(Vec>), -} - -#[derive(Deserialize, ToSchema)] -pub(crate) struct RerankRequest { - #[schema(example = "What is Deep Learning?")] - pub query: String, - #[schema(example = json!(["Deep Learning is ..."]))] - pub texts: Vec, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub truncate: bool, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub raw_scores: bool, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub return_text: bool, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct Rank { - #[schema(example = "0")] - pub index: usize, - #[schema(nullable = true, example = "Deep Learning is ...", default = "null")] - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - #[schema(example = "1.0")] - pub score: f32, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct RerankResponse(Vec); - -#[derive(Deserialize, ToSchema)] -#[serde(untagged)] -pub(crate) enum Input { - Single(String), - Batch(Vec), -} - -#[derive(Deserialize, ToSchema)] -pub(crate) struct OpenAICompatRequest { - pub input: Input, - #[allow(dead_code)] - #[schema(nullable = true, example = "null")] - model: Option, - #[allow(dead_code)] - #[schema(nullable = true, example = "null")] - user: Option, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct OpenAICompatEmbedding { - #[schema(example = "embedding")] - object: &'static str, - #[schema(example = json!([0.0, 1.0, 2.0]))] - embedding: Vec, - #[schema(example = "0")] - index: usize, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct OpenAICompatUsage { - #[schema(example = "512")] - prompt_tokens: usize, - #[schema(example = "512")] - total_tokens: usize, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct OpenAICompatResponse { - #[schema(example = "list")] - object: &'static str, - data: Vec, - #[schema(example = "thenlper/gte-base")] - model: String, - usage: OpenAICompatUsage, -} - -#[derive(Deserialize, ToSchema)] -pub(crate) struct EmbedRequest { - pub inputs: Input, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub truncate: bool, - #[serde(default = "default_normalize")] - #[schema(default = "true", example = "true")] - pub normalize: bool, -} - -fn default_normalize() -> bool { - true -} - -#[derive(Serialize, ToSchema)] -#[schema(example = json!([[0.0, 1.0, 
2.0]]))] -pub(crate) struct EmbedResponse(Vec>); - -#[derive(Serialize, ToSchema)] -pub(crate) enum ErrorType { - Unhealthy, - Backend, - Overloaded, - Validation, - Tokenizer, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct ErrorResponse { - pub error: String, - pub error_type: ErrorType, + fn record_metrics(&self) { + // Metrics + metrics::histogram!( + "te_request_duration", + self.start_time.elapsed().as_secs_f64() + ); + metrics::histogram!( + "te_request_tokenization_duration", + self.tokenization_time.as_secs_f64() + ); + metrics::histogram!("te_request_queue_duration", self.queue_time.as_secs_f64()); + metrics::histogram!( + "te_request_inference_duration", + self.inference_time.as_secs_f64() + ); + } } -#[derive(Serialize, ToSchema)] -pub(crate) struct OpenAICompatErrorResponse { - pub message: String, - pub code: u16, - #[serde(rename(serialize = "type"))] - pub error_type: ErrorType, +impl From for HeaderMap { + fn from(value: ResponseMetadata) -> Self { + // Headers + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + headers.insert( + "x-compute-time", + value + .start_time + .elapsed() + .as_millis() + .to_string() + .parse() + .unwrap(), + ); + headers.insert( + "x-compute-characters", + value.compute_chars.to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-tokens", + value.compute_tokens.to_string().parse().unwrap(), + ); + headers.insert( + "x-total-time", + value + .start_time + .elapsed() + .as_millis() + .to_string() + .parse() + .unwrap(), + ); + headers.insert( + "x-tokenization-time", + value + .tokenization_time + .as_millis() + .to_string() + .parse() + .unwrap(), + ); + headers.insert( + "x-queue-time", + value.queue_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-inference-time", + value + .inference_time + .as_millis() + .to_string() + .parse() + .unwrap(), + ); + headers + } } diff --git a/router/src/main.rs b/router/src/main.rs index bbb4c260..92dcea61 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,5 +1,4 @@ use anyhow::{anyhow, Context, Result}; -use axum::http::HeaderValue; use clap::Parser; use hf_hub::api::tokio::ApiBuilder; use hf_hub::{Repo, RepoType}; @@ -18,10 +17,9 @@ use text_embeddings_core::download::{download_artifacts, download_pool_config}; use text_embeddings_core::infer::Infer; use text_embeddings_core::queue::Queue; use text_embeddings_core::tokenization::Tokenization; -use text_embeddings_router::{server, ClassifierModel, EmbeddingModel, Info, ModelType}; +use text_embeddings_router::{ClassifierModel, EmbeddingModel, Info, ModelType}; use tokenizers::decoders::metaspace::PrependScheme; use tokenizers::{PreTokenizerWrapper, Tokenizer}; -use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{EnvFilter, Layer}; @@ -121,9 +119,6 @@ struct Args { #[clap(long, env)] otlp_endpoint: Option, - - #[clap(long, env)] - cors_allow_origin: Option>, } #[derive(Debug, Deserialize)] @@ -236,14 +231,23 @@ async fn main() -> Result<()> { // Info model type let model_type = match &backend_model_type { - text_embeddings_backend::ModelType::Classifier => ModelType::Classifier(ClassifierModel { - id2label: config + text_embeddings_backend::ModelType::Classifier => { + let id2label = config .id2label - .context("`config.json` does not contain `id2label`")?, - label2id: config - .label2id - .context("`config.json` does not contain `label2id`")?, - 
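
The added branch just below, in `router/src/main.rs`, decides between `ModelType::Classifier` and the new `ModelType::Reranker` from the size of `id2label`. A minimal sketch of that rule (the helper function and the toy label maps are illustrative only):

```rust
use std::collections::HashMap;

// Mirrors the rule below: a classification head with a single label is served
// as a reranker; more than one label means a regular classifier.
fn is_reranker(id2label: &HashMap<String, String>) -> bool {
    id2label.len() <= 1
}

fn main() {
    let single: HashMap<_, _> = [("0".to_string(), "LABEL_0".to_string())].into();
    assert!(is_reranker(&single));

    let multi: HashMap<_, _> = [
        ("0".to_string(), "negative".to_string()),
        ("1".to_string(), "positive".to_string()),
    ]
    .into();
    assert!(!is_reranker(&multi));
    println!("ok");
}
```
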
-        }),
+                .context("`config.json` does not contain `id2label`")?;
+            let n_classes = id2label.len();
+            let classifier_model = ClassifierModel {
+                id2label,
+                label2id: config
+                    .label2id
+                    .context("`config.json` does not contain `label2id`")?,
+            };
+            if n_classes > 1 {
+                ModelType::Classifier(classifier_model)
+            } else {
+                ModelType::Reranker(classifier_model)
+            }
+        }
         text_embeddings_backend::ModelType::Embedding(pool) => {
             ModelType::Embedding(EmbeddingModel {
                 pooling: pool.to_string(),
@@ -319,7 +323,7 @@ async fn main() -> Result<()> {
         dtype.clone(),
         backend_model_type,
         args.uds_path,
-        args.otlp_endpoint,
+        args.otlp_endpoint.clone(),
     )
     .context("Could not create backend")?;
     backend
@@ -350,7 +354,7 @@ async fn main() -> Result<()> {
         model_dtype: dtype.to_string(),
         model_type,
         max_concurrent_requests: args.max_concurrent_requests,
-        max_input_length: config.max_position_embeddings,
+        max_input_length,
         max_batch_tokens: args.max_batch_tokens,
         tokenization_workers,
         max_batch_requests,
@@ -368,23 +372,18 @@ async fn main() -> Result<()> {
         }
     };
 
-    // CORS allowed origins
-    // map to go inside the option and then map to parse from String to HeaderValue
-    // Finally, convert to AllowOrigin
-    let cors_allow_origin: Option<AllowOrigin> = args.cors_allow_origin.map(|cors_allow_origin| {
-        AllowOrigin::list(
-            cors_allow_origin
-                .iter()
-                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
-        )
-    });
-
     tracing::info!("Ready");
 
     // Run axum server
-    server::run(infer, info, addr, cors_allow_origin)
+    text_embeddings_router::run(infer, info, addr)
         .await
         .unwrap();
+
+    if args.otlp_endpoint.is_some() {
+        // Shutdown tracer
+        global::shutdown_tracer_provider();
+    }
+
     Ok(())
 }
diff --git a/router/src/prometheus.rs b/router/src/prometheus.rs
new file mode 100644
index 00000000..3c16684e
--- /dev/null
+++ b/router/src/prometheus.rs
@@ -0,0 +1,36 @@
+use metrics_exporter_prometheus::{BuildError, Matcher, PrometheusBuilder};
+
+pub fn prometheus_builer(max_input_length: usize) -> Result<PrometheusBuilder, BuildError> {
+    // Duration buckets
+    let duration_matcher = Matcher::Suffix(String::from("duration"));
+    let n_duration_buckets = 35;
+    let mut duration_buckets = Vec::with_capacity(n_duration_buckets);
+    // Minimum duration in seconds
+    let mut value = 0.00001;
+    for _ in 0..n_duration_buckets {
+        // geometric sequence
+        value *= 1.5;
+        duration_buckets.push(value);
+    }
+
+    // Input Length buckets
+    let input_length_matcher = Matcher::Full(String::from("te_request_input_length"));
+    let input_length_buckets: Vec<f64> = (0..100)
+        .map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64)
+        .collect();
+
+    // Batch size buckets
+    let batch_size_matcher = Matcher::Full(String::from("te_batch_next_size"));
+    let batch_size_buckets: Vec<f64> = (0..2048).map(|x| (x + 1) as f64).collect();
+
+    // Batch tokens buckets
+    let batch_tokens_matcher = Matcher::Full(String::from("te_batch_next_tokens"));
+    let batch_tokens_buckets: Vec<f64> = (0..100_000).map(|x| (x + 1) as f64).collect();
+
+    // Prometheus handler
+    PrometheusBuilder::new()
+        .set_buckets_for_metric(duration_matcher, &duration_buckets)?
+        .set_buckets_for_metric(input_length_matcher, &input_length_buckets)?
+        .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)?
+        .set_buckets_for_metric(batch_tokens_matcher, &batch_tokens_buckets)
+}
diff --git a/router/src/server.rs b/router/src/server.rs
deleted file mode 100644
index 49c0c6aa..00000000
--- a/router/src/server.rs
+++ /dev/null
@@ -1,1109 +0,0 @@
-/// HTTP Server logic
-use crate::{
-    ClassifierModel, EmbedRequest, EmbedResponse, EmbeddingModel, ErrorResponse, ErrorType, Info,
-    Input, ModelType, OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest,
-    OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse,
-    Prediction, Rank, RerankRequest, RerankResponse, Sequence,
-};
-use axum::extract::Extension;
-use axum::http::{HeaderMap, Method, StatusCode};
-use axum::routing::{get, post};
-use axum::{http, Json, Router};
-use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
-use futures::future::join_all;
-use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
-use std::net::SocketAddr;
-use std::time::{Duration, Instant};
-use text_embeddings_backend::BackendError;
-use text_embeddings_core::infer::{Infer, InferResponse};
-use text_embeddings_core::TextEmbeddingsError;
-use tokio::signal;
-use tower_http::cors::{AllowOrigin, CorsLayer};
-use tracing::instrument;
-use utoipa::OpenApi;
-use utoipa_swagger_ui::SwaggerUi;
-
-///Text Embeddings Inference endpoint info
-#[utoipa::path(
-get,
-tag = "Text Embeddings Inference",
-path = "/info",
-responses((status = 200, description = "Served model info", body = Info))
-)]
-#[instrument]
-async fn get_model_info(info: Extension<Info>) -> Json<Info> {
-    Json(info.0)
-}
-
-#[utoipa::path(
-get,
-tag = "Text Embeddings Inference",
-path = "/health",
-responses(
-(status = 200, description = "Everything is working fine"),
-(status = 503, description = "Text embeddings Inference is down", body = ErrorResponse,
-example = json ! ({"error": "unhealthy", "error_type": "unhealthy"})),
-)
-)]
-#[instrument(skip(infer))]
-/// Health check method
-async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    match infer.health().await {
-        true => Ok(()),
-        false => Err(ErrorResponse {
-            error: "unhealthy".to_string(),
-            error_type: ErrorType::Unhealthy,
-        })?,
-    }
-}
-
-/// Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model
-#[utoipa::path(
-post,
-tag = "Text Embeddings Inference",
-path = "/predict",
-request_body = PredictRequest,
-responses(
-(status = 200, description = "Predictions", body = PredictResponse),
-(status = 424, description = "Prediction Error", body = ErrorResponse,
-example = json ! ({"error": "Inference failed", "error_type": "backend"})),
-(status = 429, description = "Model is overloaded", body = ErrorResponse,
-example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
-(status = 422, description = "Tokenization error", body = ErrorResponse,
-example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})),
-(status = 413, description = "Batch size error", body = ErrorResponse,
-example = json !
({"error": "Batch size error", "error_type": "validation"})), -) -)] -#[instrument( - skip_all, - fields(total_time, tokenization_time, queue_time, inference_time,) -)] -async fn predict( - infer: Extension, - info: Extension, - Json(req): Json, -) -> Result<(HeaderMap, Json), (StatusCode, Json)> { - let span = tracing::Span::current(); - let start_time = Instant::now(); - - // Closure for predict - let predict_inner = move |inputs: Sequence, - truncate: bool, - raw_scores: bool, - infer: Infer, - info: Info| async move { - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - let response = infer - .predict(inputs, truncate, raw_scores, permit) - .await - .map_err(ErrorResponse::from)?; - - let id2label = match &info.model_type { - ModelType::Classifier(classifier) => &classifier.id2label, - _ => panic!(), - }; - - let mut predictions: Vec = { - // Map score to label - response - .results - .into_iter() - .enumerate() - .map(|(i, s)| Prediction { - score: s, - label: id2label.get(&i.to_string()).unwrap().clone(), - }) - .collect() - }; - // Reverse sort - predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); - predictions.reverse(); - - Ok::<(usize, Duration, Duration, Duration, Vec), ErrorResponse>(( - response.prompt_tokens, - response.tokenization, - response.queue, - response.inference, - predictions, - )) - }; - - let (compute_chars, compute_tokens, tokenization_time, queue_time, inference_time, response) = - match req.inputs { - PredictInput::Single(inputs) => { - metrics::increment_counter!("te_request_count", "method" => "single"); - - let compute_chars = inputs.count_chars(); - let (prompt_tokens, tokenization, queue, inference, predictions) = - predict_inner(inputs, req.truncate, req.raw_scores, infer.0, info.0).await?; - - metrics::increment_counter!("te_request_success", "method" => "single"); - - ( - compute_chars, - prompt_tokens, - tokenization, - queue, - inference, - PredictResponse::Single(predictions), - ) - } - PredictInput::Batch(inputs) => { - metrics::increment_counter!("te_request_count", "method" => "batch"); - - let batch_size = inputs.len(); - if batch_size > info.max_client_batch_size { - let message = format!( - "batch size {batch_size} > maximum allowed batch size {}", - info.max_client_batch_size - ); - tracing::error!("{message}"); - let err = ErrorResponse { - error: message, - error_type: ErrorType::Validation, - }; - metrics::increment_counter!("te_request_failure", "err" => "batch_size"); - Err(err)?; - } - - let mut futures = Vec::with_capacity(batch_size); - let mut compute_chars = 0; - - for input in inputs { - compute_chars += input.count_chars(); - let local_infer = infer.clone(); - let local_info = info.clone(); - futures.push(predict_inner( - input, - req.truncate, - req.raw_scores, - local_infer.0, - local_info.0, - )) - } - let results = join_all(futures).await.into_iter().collect::)>, - ErrorResponse, - >>()?; - - let mut predictions = Vec::with_capacity(batch_size); - let mut total_tokenization_time = 0; - let mut total_queue_time = 0; - let mut total_inference_time = 0; - let mut total_compute_tokens = 0; - - for r in results { - total_compute_tokens += r.0; - total_tokenization_time += r.1.as_nanos() as u64; - total_queue_time += r.2.as_nanos() as u64; - total_inference_time += r.3.as_nanos() as u64; - predictions.push(r.4); - } - let batch_size = batch_size as u64; - - metrics::increment_counter!("te_request_success", "method" => "batch"); - - ( - compute_chars, - total_compute_tokens, - 
Duration::from_nanos(total_tokenization_time / batch_size), - Duration::from_nanos(total_queue_time / batch_size), - Duration::from_nanos(total_inference_time / batch_size), - PredictResponse::Batch(predictions), - ) - } - }; - - let total_time = start_time.elapsed(); - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); - - tracing::info!("Success"); - - Ok((headers, Json(response))) -} - -/// Get Ranks. Returns a 424 status code if the model is not a Sequence Classification model with -/// a single class. -#[utoipa::path( -post, -tag = "Text Embeddings Inference", -path = "/rerank", -request_body = RerankRequest, -responses( -(status = 200, description = "Ranks", body = RerankResponse), -(status = 424, description = "Rerank Error", body = ErrorResponse, -example = json ! ({"error": "Inference failed", "error_type": "backend"})), -(status = 429, description = "Model is overloaded", body = ErrorResponse, -example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})), -(status = 422, description = "Tokenization error", body = ErrorResponse, -example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})), -(status = 413, description = "Batch size error", body = ErrorResponse, -example = json ! 
({"error": "Batch size error", "error_type": "validation"})), -) -)] -#[instrument( - skip_all, - fields(total_time, tokenization_time, queue_time, inference_time,) -)] -async fn rerank( - infer: Extension, - info: Extension, - Json(req): Json, -) -> Result<(HeaderMap, Json), (StatusCode, Json)> { - let span = tracing::Span::current(); - let start_time = Instant::now(); - - match &info.model_type { - ModelType::Classifier(classifier) => { - if classifier.id2label.len() > 1 { - metrics::increment_counter!("te_request_failure", "err" => "model_type"); - let message = "model is not a re-ranker model".to_string(); - Err(TextEmbeddingsError::Backend(BackendError::Inference( - message, - ))) - } else { - Ok(()) - } - } - ModelType::Embedding(_) => { - metrics::increment_counter!("te_request_failure", "err" => "model_type"); - let message = "model is not a classifier model".to_string(); - Err(TextEmbeddingsError::Backend(BackendError::Inference( - message, - ))) - } - } - .map_err(|err| { - tracing::error!("{err}"); - ErrorResponse::from(err) - })?; - - // Closure for rerank - let rerank_inner = move |query: String, - text: String, - truncate: bool, - raw_scores: bool, - infer: Infer| async move { - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - - let response = infer - .predict((query, text), truncate, raw_scores, permit) - .await - .map_err(ErrorResponse::from)?; - - let score = response.results[0]; - - Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>(( - response.prompt_tokens, - response.tokenization, - response.queue, - response.inference, - score, - )) - }; - - let (compute_chars, compute_tokens, tokenization_time, queue_time, inference_time, response) = { - metrics::increment_counter!("te_request_count", "method" => "batch"); - - let batch_size = req.texts.len(); - if batch_size > info.max_client_batch_size { - let message = format!( - "batch size {batch_size} > maximum allowed batch size {}", - info.max_client_batch_size - ); - tracing::error!("{message}"); - let err = ErrorResponse { - error: message, - error_type: ErrorType::Validation, - }; - metrics::increment_counter!("te_request_failure", "err" => "batch_size"); - Err(err)?; - } - - let mut futures = Vec::with_capacity(batch_size); - let query_chars = req.query.chars().count(); - let mut compute_chars = query_chars * batch_size; - - for text in &req.texts { - compute_chars += text.chars().count(); - let local_infer = infer.clone(); - futures.push(rerank_inner( - req.query.clone(), - text.clone(), - req.truncate, - req.raw_scores, - local_infer.0, - )) - } - let results = join_all(futures) - .await - .into_iter() - .collect::, ErrorResponse>>()?; - - let mut ranks = Vec::with_capacity(batch_size); - let mut total_tokenization_time = 0; - let mut total_queue_time = 0; - let mut total_inference_time = 0; - let mut total_compute_tokens = 0; - - for (index, r) in results.into_iter().enumerate() { - total_compute_tokens += r.0; - total_tokenization_time += r.1.as_nanos() as u64; - total_queue_time += r.2.as_nanos() as u64; - total_inference_time += r.3.as_nanos() as u64; - let text = if req.return_text { - Some(req.texts[index].clone()) - } else { - None - }; - - ranks.push(Rank { - index, - text, - score: r.4, - }) - } - - // Reverse sort - ranks.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); - ranks.reverse(); - - let batch_size = batch_size as u64; - - metrics::increment_counter!("te_request_success", "method" => "batch"); - - ( - compute_chars, - total_compute_tokens, - 
Duration::from_nanos(total_tokenization_time / batch_size), - Duration::from_nanos(total_queue_time / batch_size), - Duration::from_nanos(total_inference_time / batch_size), - RerankResponse(ranks), - ) - }; - - let total_time = start_time.elapsed(); - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); - - tracing::info!("Success"); - - Ok((headers, Json(response))) -} - -/// Get Embeddings. Returns a 424 status code if the model is not an embedding model. -#[utoipa::path( -post, -tag = "Text Embeddings Inference", -path = "/embed", -request_body = EmbedRequest, -responses( -(status = 200, description = "Embeddings", body = EmbedResponse), -(status = 424, description = "Embedding Error", body = ErrorResponse, -example = json ! ({"error": "Inference failed", "error_type": "backend"})), -(status = 429, description = "Model is overloaded", body = ErrorResponse, -example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})), -(status = 422, description = "Tokenization error", body = ErrorResponse, -example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})), -(status = 413, description = "Batch size error", body = ErrorResponse, -example = json ! 
({"error": "Batch size error", "error_type": "validation"})), -) -)] -#[instrument( - skip_all, - fields(total_time, tokenization_time, queue_time, inference_time,) -)] -async fn embed( - infer: Extension, - info: Extension, - Json(req): Json, -) -> Result<(HeaderMap, Json), (StatusCode, Json)> { - let span = tracing::Span::current(); - let start_time = Instant::now(); - - let (compute_chars, compute_tokens, tokenization_time, queue_time, inference_time, response) = - match req.inputs { - Input::Single(input) => { - metrics::increment_counter!("te_request_count", "method" => "single"); - - let compute_chars = input.chars().count(); - - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - let response = infer - .embed(input, req.truncate, req.normalize, permit) - .await - .map_err(ErrorResponse::from)?; - - metrics::increment_counter!("te_request_success", "method" => "single"); - - ( - compute_chars, - response.prompt_tokens, - response.tokenization, - response.queue, - response.inference, - EmbedResponse(vec![response.results]), - ) - } - Input::Batch(inputs) => { - metrics::increment_counter!("te_request_count", "method" => "batch"); - - let batch_size = inputs.len(); - if batch_size > info.max_client_batch_size { - let message = format!( - "batch size {batch_size} > maximum allowed batch size {}", - info.max_client_batch_size - ); - tracing::error!("{message}"); - let err = ErrorResponse { - error: message, - error_type: ErrorType::Validation, - }; - metrics::increment_counter!("te_request_failure", "err" => "batch_size"); - Err(err)?; - } - - let mut futures = Vec::with_capacity(batch_size); - let mut compute_chars = 0; - - for input in inputs { - compute_chars += input.chars().count(); - - let local_infer = infer.clone(); - futures.push(async move { - let permit = local_infer.acquire_permit().await; - local_infer - .embed(input, req.truncate, req.normalize, permit) - .await - }) - } - let results = join_all(futures) - .await - .into_iter() - .collect::, TextEmbeddingsError>>() - .map_err(ErrorResponse::from)?; - - let mut embeddings = Vec::with_capacity(batch_size); - let mut total_tokenization_time = 0; - let mut total_queue_time = 0; - let mut total_inference_time = 0; - let mut total_compute_tokens = 0; - - for r in results { - total_tokenization_time += r.tokenization.as_nanos() as u64; - total_queue_time += r.queue.as_nanos() as u64; - total_inference_time += r.inference.as_nanos() as u64; - total_compute_tokens += r.prompt_tokens; - embeddings.push(r.results); - } - let batch_size = batch_size as u64; - - metrics::increment_counter!("te_request_success", "method" => "batch"); - - ( - compute_chars, - total_compute_tokens, - Duration::from_nanos(total_tokenization_time / batch_size), - Duration::from_nanos(total_queue_time / batch_size), - Duration::from_nanos(total_inference_time / batch_size), - EmbedResponse(embeddings), - ) - } - }; - - let total_time = start_time.elapsed(); - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - 
headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); - - tracing::info!("Success"); - - Ok((headers, Json(response))) -} - -/// OpenAI compatible route. Returns a 424 status code if the model is not an embedding model. -#[utoipa::path( -post, -tag = "Text Embeddings Inference", -path = "/embeddings", -request_body = OpenAICompatRequest, -responses( -(status = 200, description = "Embeddings", body = OpenAICompatResponse), -(status = 424, description = "Embedding Error", body = OpenAICompatErrorResponse, -example = json ! ({"message": "Inference failed", "type": "backend"})), -(status = 429, description = "Model is overloaded", body = OpenAICompatErrorResponse, -example = json ! ({"message": "Model is overloaded", "type": "overloaded"})), -(status = 422, description = "Tokenization error", body = OpenAICompatErrorResponse, -example = json ! ({"message": "Tokenization error", "type": "tokenizer"})), -(status = 413, description = "Batch size error", body = OpenAICompatErrorResponse, -example = json ! 
({"message": "Batch size error", "type": "validation"})), -) -)] -#[instrument( - skip_all, - fields(total_time, tokenization_time, queue_time, inference_time,) -)] -async fn openai_embed( - infer: Extension, - info: Extension, - Json(req): Json, -) -> Result<(HeaderMap, Json), (StatusCode, Json)> -{ - let span = tracing::Span::current(); - let start_time = Instant::now(); - - let (compute_chars, compute_tokens, tokenization_time, queue_time, inference_time, embeddings) = - match req.input { - Input::Single(input) => { - metrics::increment_counter!("te_request_count", "method" => "single"); - - let compute_chars = input.chars().count(); - - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - let response = infer - .embed(input, false, true, permit) - .await - .map_err(ErrorResponse::from)?; - - metrics::increment_counter!("te_request_success", "method" => "single"); - - ( - compute_chars, - response.prompt_tokens, - response.tokenization, - response.queue, - response.inference, - vec![OpenAICompatEmbedding { - object: "embedding", - embedding: response.results, - index: 0, - }], - ) - } - Input::Batch(inputs) => { - metrics::increment_counter!("te_request_count", "method" => "batch"); - - let batch_size = inputs.len(); - if batch_size > info.max_client_batch_size { - let message = format!( - "batch size {batch_size} > maximum allowed batch size {}", - info.max_client_batch_size - ); - tracing::error!("{message}"); - let err = ErrorResponse { - error: message, - error_type: ErrorType::Validation, - }; - metrics::increment_counter!("te_request_failure", "err" => "batch_size"); - Err(err)?; - } - - let mut futures = Vec::with_capacity(batch_size); - let mut compute_chars = 0; - - for input in inputs { - compute_chars += input.chars().count(); - - let local_infer = infer.clone(); - futures.push(async move { - let permit = local_infer.acquire_permit().await; - local_infer.embed(input, false, true, permit).await - }) - } - let results = join_all(futures) - .await - .into_iter() - .collect::, TextEmbeddingsError>>() - .map_err(ErrorResponse::from)?; - - let mut embeddings = Vec::with_capacity(batch_size); - let mut total_tokenization_time = 0; - let mut total_queue_time = 0; - let mut total_inference_time = 0; - let mut total_compute_tokens = 0; - - for (i, r) in results.into_iter().enumerate() { - total_tokenization_time += r.tokenization.as_nanos() as u64; - total_queue_time += r.queue.as_nanos() as u64; - total_inference_time += r.inference.as_nanos() as u64; - total_compute_tokens += r.prompt_tokens; - embeddings.push(OpenAICompatEmbedding { - object: "embedding", - embedding: r.results, - index: i, - }); - } - let batch_size = batch_size as u64; - - metrics::increment_counter!("te_request_success", "method" => "batch"); - - ( - compute_chars, - total_compute_tokens, - Duration::from_nanos(total_tokenization_time / batch_size), - Duration::from_nanos(total_queue_time / batch_size), - Duration::from_nanos(total_inference_time / batch_size), - embeddings, - ) - } - }; - - let total_time = start_time.elapsed(); - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - 
total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); - - tracing::info!("Success"); - - let response = OpenAICompatResponse { - object: "list", - data: embeddings, - model: info.model_id.clone(), - usage: OpenAICompatUsage { - prompt_tokens: compute_tokens, - total_tokens: compute_tokens, - }, - }; - Ok((headers, Json(response))) -} - -/// Prometheus metrics scrape endpoint -#[utoipa::path( -get, -tag = "Text Embeddings Inference", -path = "/metrics", -responses((status = 200, description = "Prometheus Metrics", body = String)) -)] -async fn metrics(prom_handle: Extension) -> String { - prom_handle.render() -} - -/// Serving method -pub async fn run( - infer: Infer, - info: Info, - addr: SocketAddr, - allow_origin: Option, -) -> Result<(), axum::BoxError> { - // OpenAPI documentation - #[derive(OpenApi)] - #[openapi( - paths( - get_model_info, - health, - predict, - rerank, - embed, - openai_embed, - metrics, - ), - components( - schemas( - PredictInput, - Input, - Info, - ModelType, - ClassifierModel, - EmbeddingModel, - PredictRequest, - Prediction, - PredictResponse, - OpenAICompatRequest, - OpenAICompatEmbedding, - OpenAICompatUsage, - OpenAICompatResponse, - RerankRequest, - Rank, - RerankResponse, - EmbedRequest, - EmbedResponse, - ErrorResponse, - OpenAICompatErrorResponse, - ErrorType, - ) - ), - tags( - (name = "Text Embeddings Inference", description = "Hugging Face Text Embeddings Inference API") - ), - info( - title = "Text Embeddings Inference", - license( - name = "HFOIL", - ) - ) - )] - struct ApiDoc; - - // Duration buckets - let duration_matcher = Matcher::Suffix(String::from("duration")); - let n_duration_buckets = 35; - let mut duration_buckets = Vec::with_capacity(n_duration_buckets); - // Minimum duration in seconds - let mut value = 0.00001; - for _ in 0..n_duration_buckets { - // geometric sequence - value *= 1.5; - duration_buckets.push(value); - } - - // Input Length buckets - let input_length_matcher = Matcher::Full(String::from("te_request_input_length")); - let input_length_buckets: Vec = (0..100) - .map(|x| (info.max_input_length as f64 / 100.0) * (x + 1) as f64) - .collect(); - - // Batch size buckets - let batch_size_matcher = Matcher::Full(String::from("te_batch_next_size")); - let batch_size_buckets: Vec = (0..2048).map(|x| (x + 1) as f64).collect(); - - // Batch tokens buckets - let batch_tokens_matcher = Matcher::Full(String::from("te_batch_next_tokens")); - let batch_tokens_buckets: Vec = (0..100_000).map(|x| (x + 1) as f64).collect(); - - // Prometheus handler - let builder = PrometheusBuilder::new() - 
.set_buckets_for_metric(duration_matcher, &duration_buckets) - .unwrap() - .set_buckets_for_metric(input_length_matcher, &input_length_buckets) - .unwrap() - .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) - .unwrap() - .set_buckets_for_metric(batch_tokens_matcher, &batch_tokens_buckets) - .unwrap(); - - let prom_handle = builder - .install_recorder() - .expect("failed to install metrics recorder"); - - // CORS layer - let allow_origin = allow_origin.unwrap_or(AllowOrigin::any()); - let cors_layer = CorsLayer::new() - .allow_methods([Method::GET, Method::POST]) - .allow_headers([http::header::CONTENT_TYPE]) - .allow_origin(allow_origin); - - // Create router - let app = Router::new() - .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) - // Base routes - .route("/info", get(get_model_info)) - .route("/embed", post(embed)) - .route("/predict", post(predict)) - .route("/rerank", post(rerank)) - // OpenAI compat route - .route("/embeddings", post(openai_embed)) - // Base Health route - .route("/health", get(health)) - // Inference API health route - .route("/", get(health)) - // AWS Sagemaker health route - .route("/ping", get(health)) - // Prometheus metrics route - .route("/metrics", get(metrics)); - - // Set default routes - let app = match &info.model_type { - ModelType::Classifier(classifier) => { - if classifier.id2label.len() > 1 { - app.route("/", post(predict)) - // AWS Sagemaker route - .route("/invocations", post(predict)) - } else { - app.route("/", post(rerank)) - // AWS Sagemaker route - .route("/invocations", post(rerank)) - } - } - ModelType::Embedding(_) => { - app.route("/", post(embed)) - // AWS Sagemaker route - .route("/invocations", post(embed)) - } - }; - - let app = app - .layer(Extension(infer)) - .layer(Extension(info)) - .layer(Extension(prom_handle.clone())) - .layer(OtelAxumLayer::default()) - .layer(cors_layer); - - // Run server - axum::Server::bind(&addr) - .serve(app.into_make_service()) - // Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()) - .await?; - - Ok(()) -} - -/// Shutdown signal handler -async fn shutdown_signal() { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - tokio::select! 
{
-        _ = ctrl_c => {},
-        _ = terminate => {},
-    }
-
-    tracing::info!("signal received, starting graceful shutdown");
-    opentelemetry::global::shutdown_tracer_provider();
-}
-
-impl From<TextEmbeddingsError> for ErrorResponse {
-    fn from(err: TextEmbeddingsError) -> Self {
-        let error_type = match err {
-            TextEmbeddingsError::Tokenizer(_) => ErrorType::Tokenizer,
-            TextEmbeddingsError::Validation(_) => ErrorType::Validation,
-            TextEmbeddingsError::Overloaded(_) => ErrorType::Overloaded,
-            TextEmbeddingsError::Backend(_) => ErrorType::Backend,
-        };
-        Self {
-            error: err.to_string(),
-            error_type,
-        }
-    }
-}
-
-impl From<&ErrorType> for StatusCode {
-    fn from(value: &ErrorType) -> Self {
-        match value {
-            ErrorType::Unhealthy => StatusCode::SERVICE_UNAVAILABLE,
-            ErrorType::Backend => StatusCode::FAILED_DEPENDENCY,
-            ErrorType::Overloaded => StatusCode::TOO_MANY_REQUESTS,
-            ErrorType::Tokenizer => StatusCode::UNPROCESSABLE_ENTITY,
-            ErrorType::Validation => StatusCode::PAYLOAD_TOO_LARGE,
-        }
-    }
-}
-
-impl From<ErrorResponse> for OpenAICompatErrorResponse {
-    fn from(value: ErrorResponse) -> Self {
-        OpenAICompatErrorResponse {
-            message: value.error,
-            code: StatusCode::from(&value.error_type).as_u16(),
-            error_type: value.error_type,
-        }
-    }
-}
-
-/// Convert to Axum supported formats
-impl From<ErrorResponse> for (StatusCode, Json<ErrorResponse>) {
-    fn from(err: ErrorResponse) -> Self {
-        (StatusCode::from(&err.error_type), Json(err))
-    }
-}
-
-impl From<ErrorResponse> for (StatusCode, Json<OpenAICompatErrorResponse>) {
-    fn from(err: ErrorResponse) -> Self {
-        (StatusCode::from(&err.error_type), Json(err.into()))
-    }
-}
diff --git a/router/src/shutdown.rs b/router/src/shutdown.rs
new file mode 100644
index 00000000..471eaf14
--- /dev/null
+++ b/router/src/shutdown.rs
@@ -0,0 +1,29 @@
+use tokio::signal;
+
+/// Shutdown signal handler
+pub(crate) async fn shutdown_signal() {
+    let ctrl_c = async {
+        signal::ctrl_c()
+            .await
+            .expect("failed to install Ctrl+C handler");
+    };
+
+    #[cfg(unix)]
+    let terminate = async {
+        signal::unix::signal(signal::unix::SignalKind::terminate())
+            .expect("failed to install signal handler")
+            .recv()
+            .await;
+    };
+
+    #[cfg(not(unix))]
+    let terminate = std::future::pending::<()>();
+
+    tokio::select! {
+        _ = ctrl_c => {},
+        _ = terminate => {},
+    }
+
+    tracing::info!("signal received, starting graceful shutdown");
+    opentelemetry::global::shutdown_tracer_provider();
+}
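
Note for reviewers: the sketch below is not part of the patch. It is a minimal illustration of how the new `prometheus::prometheus_builer` and `shutdown::shutdown_signal` helpers could be wired into the `text_embeddings_router::run` entry point that replaces `server::run`. The real `run` implementation lives elsewhere in this PR; the module paths, route set, extensions, and use of `info.max_input_length` here are assumptions made for illustration only.

// Illustrative only — not the PR's actual `run`; a caller-side sketch inside the
// router crate, under the assumptions stated above.
use std::net::SocketAddr;

use axum::{routing::get, Extension, Router};
use text_embeddings_core::infer::Infer;

pub async fn run(infer: Infer, info: crate::Info, addr: SocketAddr) -> anyhow::Result<()> {
    // Histogram buckets are sized from the model's maximum input length, then the
    // Prometheus recorder is installed once for the whole process.
    let prom_handle = crate::prometheus::prometheus_builer(info.max_input_length)?
        .install_recorder()
        .expect("failed to install metrics recorder");

    let app = Router::new()
        // Prometheus scrape endpoint backed by the handle created above.
        .route(
            "/metrics",
            get(move || std::future::ready(prom_handle.render())),
        )
        // Shared state for whatever routes the real implementation exposes.
        .layer(Extension(infer))
        .layer(Extension(info));

    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Reuse the shared Ctrl+C / SIGTERM handler from router/src/shutdown.rs.
        .with_graceful_shutdown(crate::shutdown::shutdown_signal())
        .await?;

    Ok(())
}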