diff --git a/Dockerfile b/Dockerfile index e4a01b249..1b543ccc0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,35 @@ -FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef -WORKDIR /usr/src +# Dockerfile for TEI with Python backend and CUDA support +# Supports: L40s (sm_89), RTX 3090 (sm_86) + +# ============================================================================= +# Stage 1: Rust Builder +# ============================================================================= +FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 AS rust-builder ENV SCCACHE=0.10.0 ENV RUSTC_WRAPPER=/usr/local/bin/sccache +ENV PATH="/root/.cargo/bin:${PATH}" +ENV CARGO_CHEF=0.1.71 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + libssl-dev \ + pkg-config \ + protobuf-compiler \ + && rm -rf /var/lib/apt/lists/* -# Donwload, configure sccache RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ chmod +x /usr/local/bin/sccache -FROM chef AS planner +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +RUN cargo install cargo-chef --version $CARGO_CHEF --locked + +# ============================================================================= +# Stage 2: Recipe Planner +# ============================================================================= +FROM rust-builder AS planner + +WORKDIR /usr/src COPY backends backends COPY core core @@ -16,34 +37,21 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ -RUN cargo chef prepare --recipe-path recipe.json +RUN cargo chef prepare --recipe-path recipe.json -FROM chef AS builder +# ============================================================================= +# Stage 3: Dependency Builder +# ============================================================================= +FROM rust-builder AS builder ARG GIT_SHA ARG DOCKER_LABEL -# sccache specific variables -ARG SCCACHE_GHA_ENABLED - -RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ - | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ - echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \ - tee /etc/apt/sources.list.d/oneAPI.list - -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - intel-oneapi-mkl-devel=2024.0.0-49656 \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \ - gcc -shared -fPIC -o libfakeintel.so fakeintel.c +WORKDIR /usr/src COPY --from=planner /usr/src/recipe.json recipe.json -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s +RUN cargo chef cook --release --features python --features http --recipe-path recipe.json && sccache -s COPY backends backends COPY core core @@ -51,73 +59,83 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ -FROM builder AS http-builder +RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s + +# ============================================================================= 
+# Stage 4: Python Environment +# ============================================================================= +FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS python-builder -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s +ENV DEBIAN_FRONTEND=noninteractive -FROM builder AS grpc-builder +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.10 \ + python3.10-dev \ + python3-pip \ + git \ + && rm -rf /var/lib/apt/lists/* -RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ - curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ - unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ - unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ - rm -f $PROTOC_ZIP +RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \ + ln -sf /usr/bin/python3.10 /usr/bin/python3 -COPY proto proto +RUN pip install --no-cache-dir --upgrade pip setuptools wheel -RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ - --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ - cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s +WORKDIR /opt/server -FROM debian:bookworm-slim AS base +COPY backends/proto /opt/proto +COPY backends/python/server /opt/server -ENV HUGGINGFACE_HUB_CACHE=/data \ - PORT=80 \ - MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \ - RAYON_NUM_THREADS=8 \ - LD_PRELOAD=/usr/local/libfakeintel.so \ - LD_LIBRARY_PATH=/usr/local/lib +RUN pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir && \ + mkdir -p text_embeddings_server/pb && \ + python -m grpc_tools.protoc -I/opt/proto --python_out=text_embeddings_server/pb \ + --grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb /opt/proto/embed.proto && \ + find text_embeddings_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; && \ + touch text_embeddings_server/pb/__init__.py -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - libomp-dev \ +RUN pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124 + +RUN pip install --no-cache-dir -r requirements.txt + +RUN pip install --no-cache-dir . + +# ============================================================================= +# Stage 5: Final Image +# ============================================================================= +FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV HUGGINGFACE_HUB_CACHE=/data +ENV PORT=80 +ENV TQDM_DISABLE=1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.10 \ + python3-pip \ ca-certificates \ libssl-dev \ curl \ && rm -rf /var/lib/apt/lists/* -# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch... 
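Note on the protoc step in the python-builder stage above: the same stub generation can be reproduced outside Docker for local development. The sketch below assumes a repository checkout with `backends/proto` and `backends/python/server` as the working paths, and skips the `--mypy_out` plugin.

```python
# Regenerate the Python/gRPC stubs for embed.proto and rewrite the generated
# imports to be package-relative, mirroring the protoc + sed calls in the
# Dockerfile above (mypy stubs omitted here). Paths assume a local checkout.
import re
from pathlib import Path

from grpc_tools import protoc

proto_dir = Path("backends/proto")
out_dir = Path("backends/python/server/text_embeddings_server/pb")
out_dir.mkdir(parents=True, exist_ok=True)

ret = protoc.main([
    "grpc_tools.protoc",
    f"-I{proto_dir}",
    f"--python_out={out_dir}",
    f"--grpc_python_out={out_dir}",
    str(proto_dir / "embed.proto"),
])
assert ret == 0, "protoc failed"

# Make generated imports relative so text_embeddings_server.pb is importable.
for stub in out_dir.glob("*.py"):
    text = stub.read_text()
    stub.write_text(re.sub(r"^(import.*pb2)", r"from . \1", text, flags=re.M))
(out_dir / "__init__.py").touch()
```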
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2 -COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so - -FROM base AS grpc - -COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router - -ENTRYPOINT ["text-embeddings-router"] -CMD ["--json-output"] +RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \ + ln -sf /usr/bin/python3.10 /usr/bin/python3 -FROM base AS http +COPY --from=python-builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages +COPY --from=python-builder /usr/local/bin/python-text-embeddings-server /usr/local/bin/python-text-embeddings-server +COPY --from=python-builder /opt/server /opt/server -COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router -# Amazon SageMaker compatible image -FROM http AS sagemaker -COPY --chmod=775 sagemaker-entrypoint.sh entrypoint.sh +ENV PATH="/usr/local/bin:${PATH}" +ENV PYTHONPATH="/opt/server:${PYTHONPATH}" -ENTRYPOINT ["./entrypoint.sh"] +# Download spacy model in final image (ensures it's available at runtime) +# This is needed because spacy models may not be fully copied from builder stage +RUN pip install --no-cache-dir "spacy>=3.7.0" && \ + python -m spacy download xx_sent_ud_sm && \ + python -c "import spacy; spacy.load('xx_sent_ud_sm')" && \ + echo "Spacy model verified successfully" -# Default image -FROM http +WORKDIR /opt/server ENTRYPOINT ["text-embeddings-router"] CMD ["--json-output"] diff --git a/assets/bs1-lat.png b/assets/bs1-lat.png deleted file mode 100644 index 6105ddcc9..000000000 --- a/assets/bs1-lat.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:778b29d7d21382004fef2c528973f66bb175951ab7cd168d588cd245e36bd629 -size 15202 diff --git a/assets/bs1-tp.png b/assets/bs1-tp.png deleted file mode 100644 index 953ff0b68..000000000 --- a/assets/bs1-tp.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:478984ace4f33044bc0a53b0503a0cbfcd0a64f601922e2a13cc34d52c2b7c2b -size 17169 diff --git a/assets/bs32-lat.png b/assets/bs32-lat.png deleted file mode 100644 index ed352e40f..000000000 --- a/assets/bs32-lat.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:769326aad7e582a2e5271dd2d73c3bb5289684add10eb7146ddadd00d3b2077f -size 17596 diff --git a/assets/bs32-tp.png
b/assets/bs32-tp.png deleted file mode 100644 index c952bd285..000000000 --- a/assets/bs32-tp.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c227c5adbb8664af7aa3d59aaa408557b2865dcfbd3c6c6353caf71f2eb5b7bc -size 18521 diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index ff824f555..0d5fa97fc 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -14,7 +14,7 @@ use serde::{de::Deserializer, Deserialize}; use std::collections::HashMap; use std::path::Path; use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions, + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Prediction, Predictions, }; #[cfg(feature = "cuda")] @@ -653,7 +653,10 @@ impl Backend for CandleBackend { let mut predictions = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); for (i, r) in results.into_iter().enumerate() { - predictions.insert(i, r); + predictions.insert(i, Prediction { + scores: r, + pruned_text: None, + }); } Ok(predictions) diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index 8e134d2be..55dad0d8e 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -14,6 +14,10 @@ pub struct Batch { pub max_length: u32, pub pooled_indices: Vec, pub raw_indices: Vec, + /// XProvence: raw query texts for context pruning + pub raw_queries: Vec>, + /// XProvence: raw context texts for context pruning + pub raw_texts: Vec>, } impl Batch { @@ -32,7 +36,16 @@ pub enum Embedding { } pub type Embeddings = IntMap; -pub type Predictions = IntMap>; + +/// XProvence: Prediction result containing scores and optional pruned text +#[derive(Debug, Clone)] +pub struct Prediction { + pub scores: Vec, + /// XProvence: pruned context text after removing irrelevant sentences + pub pruned_text: Option, +} + +pub type Predictions = IntMap; pub trait Backend { fn health(&self) -> Result<(), BackendError>; diff --git a/backends/grpc-client/src/client.rs b/backends/grpc-client/src/client.rs index 1f6036eed..a5872642f 100644 --- a/backends/grpc-client/src/client.rs +++ b/backends/grpc-client/src/client.rs @@ -59,6 +59,8 @@ impl Client { position_ids, max_length, cu_seq_lengths, + raw_queries: vec![], + raw_texts: vec![], }) .inject_context(); let response = self.stub.embed(request).await?.into_inner(); @@ -73,6 +75,8 @@ impl Client { position_ids: Vec, cu_seq_lengths: Vec, max_length: u32, + raw_queries: Vec, + raw_texts: Vec, ) -> Result> { let request = tonic::Request::new(EmbedRequest { input_ids, @@ -80,6 +84,8 @@ impl Client { position_ids, max_length, cu_seq_lengths, + raw_queries, + raw_texts, }) .inject_context(); let response = self.stub.predict(request).await?.into_inner(); diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index bfc2d03ad..4f84d4f79 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -8,7 +8,7 @@ use std::ops::{Div, Mul}; use std::path::Path; use std::sync::Mutex; use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions, }; #[derive(Debug, Clone, Deserialize)] @@ -679,7 +679,10 @@ impl Backend for OrtBackend { let mut predictions = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); for (i, r) in outputs.rows().into_iter().enumerate() { - predictions.insert(i, r.to_vec()); + predictions.insert(i, 
Prediction { + scores: r.to_vec(), + pruned_text: None, + }); } Ok(predictions) diff --git a/backends/proto/embed.proto b/backends/proto/embed.proto index 036f3db4b..e233902d0 100644 --- a/backends/proto/embed.proto +++ b/backends/proto/embed.proto @@ -21,6 +21,10 @@ message EmbedRequest { repeated uint32 cu_seq_lengths = 4; /// Length of the longest request uint32 max_length = 5; + /// XProvence: raw query texts for context pruning (one per batch item) + repeated string raw_queries = 6; + /// XProvence: raw context texts for context pruning (one per batch item) + repeated string raw_texts = 7; } message Embedding { @@ -33,6 +37,8 @@ message EmbedResponse { message Score { repeated float values = 1; + /// XProvence: pruned context text after removing irrelevant sentences + optional string pruned_text = 2; } message PredictResponse { diff --git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt index 687ec1028..b893f569c 100644 --- a/backends/python/server/requirements.txt +++ b/backends/python/server/requirements.txt @@ -52,6 +52,7 @@ safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13" +spacy>=3.7.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==75.6.0 ; python_version >= "3.9" and python_version < "3.13" sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 1e919f233..ac6fd0211 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -1,4 +1,5 @@ import os +import json import torch from loguru import logger @@ -11,11 +12,38 @@ from text_embeddings_server.models.masked_model import MaskedLanguageModel from text_embeddings_server.models.default_model import DefaultModel from text_embeddings_server.models.classification_model import ClassificationModel -from text_embeddings_server.models.jinaBert_model import FlashJinaBert -from text_embeddings_server.models.flash_mistral import FlashMistral -from text_embeddings_server.models.flash_qwen3 import FlashQwen3 +from text_embeddings_server.models.xprovence_model import XProvenceModel from text_embeddings_server.utils.device import get_device, use_ipex + +def _is_xprovence_model(model_path: Path) -> bool: + """Check if model is XProvence by reading config.json directly. + + This avoids calling AutoConfig.from_pretrained which can pollute + transformers' internal registry and cause config class conflicts. 
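Aside on the new spacy dependency pinned in requirements.txt above (and pre-downloaded as `xx_sent_ud_sm` in the final image): sentence boundaries are the granularity at which the pruned context is produced, and the multilingual sentence model is presumably what backs that segmentation at runtime. A minimal check, with an illustrative sample passage:

```python
# Verify the pre-downloaded multilingual sentence model loads and segments text.
# Sentence boundaries are the unit the pruning step removes; the passage is a
# made-up example.
import spacy

nlp = spacy.load("xx_sent_ud_sm")
doc = nlp(
    "Paris is the capital of France. The Eiffel Tower was completed in 1889. "
    "Mount Everest is the highest mountain on Earth."
)
for sentence in doc.sents:
    print(sentence.text)
```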
+ """ + config_path = model_path / "config.json" + if not config_path.exists(): + return False + + try: + with open(config_path, "r") as f: + config = json.load(f) + architectures = config.get("architectures", []) + return any("XProvence" in arch for arch in architectures) + except Exception: + return False + +FlashJinaBert = None +FlashMistral = None +FlashQwen3 = None +try: + from text_embeddings_server.models.jinaBert_model import FlashJinaBert + from text_embeddings_server.models.flash_mistral import FlashMistral + from text_embeddings_server.models.flash_qwen3 import FlashQwen3 +except ImportError as e: + logger.warning(f"Flash attention models not available: {e}") + __all__ = ["Model"] TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"] @@ -73,16 +101,22 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): device = get_device() logger.info(f"backend device: {device}") + # Check for XProvence BEFORE calling AutoConfig.from_pretrained + # to avoid polluting transformers' internal config registry + if _is_xprovence_model(model_path): + logger.info("Detected XProvence model for context pruning") + return XProvenceModel(model_path, device, datatype, trust_remote=True) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) if ( - hasattr(config, "auto_map") + FlashJinaBert is not None + and hasattr(config, "auto_map") and isinstance(config.auto_map, dict) and "AutoModel" in config.auto_map and config.auto_map["AutoModel"] == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel" ): - # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository return create_model(FlashJinaBert, model_path, device, datatype) if config.model_type == "bert": @@ -116,19 +150,18 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): else: return create_model(DefaultModel, model_path, device, datatype, pool) - if config.model_type == "mistral" and device.type == "hpu": + if FlashMistral is not None and config.model_type == "mistral" and device.type == "hpu": try: return create_model(FlashMistral, model_path, device, datatype, pool) except FileNotFoundError: return create_model(DefaultModel, model_path, device, datatype, pool) - if config.model_type == "qwen3" and device.type == "hpu": + if FlashQwen3 is not None and config.model_type == "qwen3" and device.type == "hpu": try: return create_model(FlashQwen3, model_path, device, datatype, pool) except FileNotFoundError: return create_model(DefaultModel, model_path, device, datatype, pool) - # Default case if config.architectures[0].endswith("Classification"): return create_model(ClassificationModel, model_path, device, datatype) elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade": diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py index f27572a9b..92eb5b2ee 100644 --- a/backends/python/server/text_embeddings_server/models/types.py +++ b/backends/python/server/text_embeddings_server/models/types.py @@ -3,7 +3,8 @@ import torch from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import List, Optional from opentelemetry import trace from text_embeddings_server.pb import embed_pb2 @@ -36,6 +37,9 @@ class PaddedBatch(Batch): token_type_ids: torch.Tensor position_ids: torch.Tensor attention_mask: torch.Tensor + 
# XProvence: raw texts for context pruning (one per batch item) + raw_queries: Optional[List[str]] = None + raw_texts: Optional[List[str]] = None @classmethod @tracer.start_as_current_span("from_pb") @@ -77,11 +81,17 @@ def from_pb( # Move padded tensors all at once all_tensors = all_tensors.to(device) + # XProvence: Extract repeated raw_queries/raw_texts from proto + raw_queries = list(pb.raw_queries) if pb.raw_queries else None + raw_texts = list(pb.raw_texts) if pb.raw_texts else None + return PaddedBatch( input_ids=all_tensors[0], token_type_ids=all_tensors[1], position_ids=all_tensors[2], attention_mask=all_tensors[3], + raw_queries=raw_queries, + raw_texts=raw_texts, ) def __len__(self): diff --git a/backends/python/server/text_embeddings_server/models/xprovence_model.py b/backends/python/server/text_embeddings_server/models/xprovence_model.py new file mode 100644 index 000000000..fb4856432 --- /dev/null +++ b/backends/python/server/text_embeddings_server/models/xprovence_model.py @@ -0,0 +1,303 @@ +import os +import torch + +from pathlib import Path +from typing import Type, List, Optional +from transformers.dynamic_module_utils import get_class_from_dynamic_module +from huggingface_hub import hf_hub_download +from opentelemetry import trace +from loguru import logger + +from text_embeddings_server.models.model import Model +from text_embeddings_server.models.types import PaddedBatch, Embedding, Score + +tracer = trace.get_tracer(__name__) + + +def _parse_bool(value: str) -> bool: + """Parse boolean from string with common conventions.""" + return str(value).lower() in ("true", "1", "t", "yes", "on") + + +def _extract_model_id(model_path_str: str) -> Optional[str]: + """Extract model_id from HF cache path format. + + Converts paths like '/data/models--naver--xprovence-reranker-bgem3-v1/snapshots/...' + to 'naver/xprovence-reranker-bgem3-v1' + """ + if "/models--" not in model_path_str: + return None + + parts = model_path_str.split("/") + for part in parts: + if part.startswith("models--"): + # models--naver--xprovence-reranker-bgem3-v1 -> naver/xprovence-reranker-bgem3-v1 + return part.replace("models--", "").replace("--", "/", 1) + return None + + +class XProvenceModel(Model): + """ + XProvence: Zero-cost context pruning model for RAG. + + XProvence removes irrelevant sentences from passages based on relevance + to the query, returning both a reranking score and pruned context. + + Based on bge-reranker-v2-m3 (XLM-RoBERTa), supports 16+ languages. + + Environment Variables: + XPROVENCE_THRESHOLD (float): Pruning threshold between 0.0-1.0. + - 0.3 (default): Conservative pruning, minimal performance drop + - 0.7: Aggressive pruning, higher compression + XPROVENCE_ALWAYS_SELECT_TITLE (bool): Keep first sentence as title. 
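Note on the `PaddedBatch.from_pb` change above: it simply lifts the two new repeated proto fields off the request. A minimal sketch of such a message built with the generated stubs (import path as already used in types.py; the token IDs are placeholders, not real tokenizer output):

```python
# Build an EmbedRequest carrying the new raw_queries/raw_texts fields, mirroring
# what the Rust gRPC client sends for predict(). A single-item batch is shown.
from text_embeddings_server.pb import embed_pb2

request = embed_pb2.EmbedRequest(
    input_ids=[0, 101, 102, 2],          # placeholder token ids
    token_type_ids=[0, 0, 0, 0],
    position_ids=[0, 1, 2, 3],
    cu_seq_lengths=[0, 4],
    max_length=4,
    raw_queries=["What is the capital of France?"],
    raw_texts=["Paris is the capital of France. Mount Everest is in Nepal."],
)
print(list(request.raw_queries), list(request.raw_texts))
```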
+ - true (default): Always include first sentence (useful for Wikipedia) + - false: Only include sentences above threshold + """ + + def __init__( + self, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + pool: str = "cls", + trust_remote: bool = True, + ): + model_path_str = str(model_path) + cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE", "/data") + + # Extract model_id from cache path for proper trust_remote_code handling + model_id = _extract_model_id(model_path_str) + + if model_id: + # Directly import the custom model class to avoid AutoModel's config class mismatch + # AutoModel.from_pretrained internally loads config which causes XLMRobertaConfig + # to be registered, conflicting with the model's expected XProvenceConfig + logger.info(f"XProvence: Loading custom model class for {model_id}") + + # Get the custom model class directly from the dynamic module + model_class = get_class_from_dynamic_module( + "modeling_xprovence_hf.XProvence", + model_id, + cache_dir=cache_dir, + ) + + # Load using the custom class directly - this uses the correct config_class + model = model_class.from_pretrained( + model_id, + trust_remote_code=True, + cache_dir=cache_dir, + ) + else: + # Fallback for local paths - try to import from local path + logger.info(f"XProvence: Loading from local path {model_path}") + model_class = get_class_from_dynamic_module( + "modeling_xprovence_hf.XProvence", + model_path, + ) + model = model_class.from_pretrained( + model_path, + trust_remote_code=True, + ) + + if dtype == torch.bfloat16: + logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility") + dtype = torch.float32 + + model = model.to(dtype).to(device) + + self.hidden_size = model.config.hidden_size + + position_offset = 0 + model_type = model.config.model_type + if model_type in ["xlm-roberta", "camembert", "roberta"]: + position_offset = model.config.pad_token_id + 1 + + if hasattr(model.config, "max_seq_length"): + self.max_input_length = model.config.max_seq_length + else: + self.max_input_length = ( + model.config.max_position_embeddings - position_offset + ) + + try: + threshold_env = os.getenv("XPROVENCE_THRESHOLD", "0.3") + self.threshold = float(threshold_env) + if not (0.0 <= self.threshold <= 1.0): + logger.warning( + f"XPROVENCE_THRESHOLD={self.threshold} out of bounds [0.0, 1.0], " + f"defaulting to 0.3" + ) + self.threshold = 0.3 + except ValueError: + logger.error( + f"Invalid XPROVENCE_THRESHOLD='{threshold_env}', defaulting to 0.3" + ) + self.threshold = 0.3 + + self.always_select_title = _parse_bool( + os.getenv("XPROVENCE_ALWAYS_SELECT_TITLE", "true") + ) + + logger.info( + f"XProvence model loaded: threshold={self.threshold}, " + f"always_select_title={self.always_select_title} " + f"(Configure via XPROVENCE_THRESHOLD, XPROVENCE_ALWAYS_SELECT_TITLE env vars)" + ) + + super(XProvenceModel, self).__init__(model=model, dtype=dtype, device=device) + + @property + def batch_type(self) -> Type[PaddedBatch]: + return PaddedBatch + + @tracer.start_as_current_span("embed") + def embed(self, batch: PaddedBatch) -> List[Embedding]: + pass + + @tracer.start_as_current_span("predict") + def predict(self, batch: PaddedBatch) -> List[Score]: + """ + XProvence prediction with context pruning support. + + For batches with raw_queries/raw_texts available (one per item), + uses XProvence's process() method for sentence-level pruning on each pair. + Otherwise falls back to standard forward pass. 
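The two knobs parsed above are plain environment variables read once at model load time. A small sketch of how they are supplied and interpreted; the values shown are illustrative, and invalid or out-of-range settings fall back to `threshold=0.3` and `always_select_title=True` as coded above.

```python
# Configure XProvence pruning via environment variables before the server loads
# the model; invalid values fall back to the defaults described above.
import os

os.environ["XPROVENCE_THRESHOLD"] = "0.7"               # more aggressive pruning
os.environ["XPROVENCE_ALWAYS_SELECT_TITLE"] = "false"   # don't force-keep the first sentence

threshold = float(os.getenv("XPROVENCE_THRESHOLD", "0.3"))
always_select_title = os.getenv("XPROVENCE_ALWAYS_SELECT_TITLE", "true").lower() in (
    "true", "1", "t", "yes", "on",
)
print(threshold, always_select_title)  # 0.7 False
```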
+ """ + batch_size = len(batch) + raw_queries = batch.raw_queries or [] + raw_texts = batch.raw_texts or [] + + # Broadcasting: 1 query → N texts (common reranking pattern) + if len(raw_queries) == 1 and len(raw_texts) == batch_size and batch_size > 1: + logger.info(f"XProvence: Broadcasting single query to {batch_size} texts") + raw_queries = raw_queries * batch_size + + # Check for dimension mismatch with explicit warning + if len(raw_queries) != batch_size or len(raw_texts) != batch_size: + if raw_queries or raw_texts: + logger.warning( + f"XProvence: Dimension mismatch - batch_size={batch_size}, " + f"raw_queries={len(raw_queries)}, raw_texts={len(raw_texts)}. " + f"Falling back to standard inference (no pruned_text)." + ) + return self._predict_standard(batch) + + # Process batch with pruning (optimized) + logger.info(f"XProvence: Processing {batch_size} pairs with pruning") + return self._predict_batch_with_pruning(raw_queries, raw_texts) + + def _predict_batch_with_pruning( + self, raw_queries: List[str], raw_texts: List[str] + ) -> List[Score]: + """ + Optimized batch processing with pruning. + + Uses inference_mode and batched dtype handling to reduce per-item overhead. + Note: XProvence process() is inherently per-pair for sentence-level analysis. + """ + batch_size = len(raw_queries) + results = [] + + # Suppress progress bars once for entire batch + os.environ["TQDM_DISABLE"] = "1" + + # Use inference_mode for better performance (no grad tracking) + with torch.inference_mode(): + original_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float32) + + try: + for i in range(batch_size): + query = raw_queries[i] + text = raw_texts[i] + + if not query or not text: + logger.warning( + f"XProvence: Empty query/text at index {i}, score=0.0" + ) + results.append(Score(values=[0.0], pruned_text=None)) + continue + + try: + output = self.model.process( + query, + text, + threshold=self.threshold, + always_select_title=self.always_select_title, + ) + + score = float(output["reranking_score"]) + pruned = output["pruned_context"] + + logger.debug( + f"XProvence [{i}]: score={score:.4f}, " + f"len={len(text)}→{len(pruned)}" + ) + results.append(Score(values=[score], pruned_text=pruned)) + + except Exception as e: + logger.error(f"XProvence process() failed at index {i}: {e}") + results.append(Score(values=[0.0], pruned_text=None)) + + finally: + torch.set_default_dtype(original_dtype) + + return results + + def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]: + """ + Use XProvence's process() method for context pruning. + + Returns score with pruned_text containing only relevant sentences. 
+ """ + try: + os.environ["TQDM_DISABLE"] = "1" + + original_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float32) + + try: + output = self.model.process( + raw_query, + raw_text, + threshold=self.threshold, + always_select_title=self.always_select_title, + ) + finally: + torch.set_default_dtype(original_dtype) + + reranking_score = float(output["reranking_score"]) + pruned_context = output["pruned_context"] + + logger.debug( + f"XProvence pruning: score={reranking_score:.4f}, " + f"original_len={len(raw_text)}, pruned_len={len(pruned_context)}" + ) + + return [Score(values=[reranking_score], pruned_text=pruned_context)] + + except Exception as e: + logger.error(f"XProvence process() failed: {e}, falling back to standard") + return [Score(values=[0.0], pruned_text=None)] + + def _predict_standard(self, batch: PaddedBatch) -> List[Score]: + kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} + + output = self.model(**kwargs, return_dict=True) + + if hasattr(output, "ranking_scores"): + scores_tensor = output.ranking_scores + elif hasattr(output, "logits"): + scores_tensor = output.logits[:, 0] if output.logits.dim() == 2 else output.logits + else: + scores_tensor = output[0] + + if scores_tensor.dim() == 0: + scores = [float(scores_tensor.item())] + else: + scores = scores_tensor.view(-1).tolist() + + if isinstance(scores, float): + scores = [scores] + + return [Score(values=[float(s)], pruned_text=None) for s in scores] diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 53255b07d..331391a99 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -5,7 +5,7 @@ use backend_grpc_client::Client; use nohash_hasher::BuildNoHashHasher; use std::collections::HashMap; use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions, }; use tokio::runtime::Runtime; @@ -108,6 +108,19 @@ impl Backend for PythonBackend { )); } let batch_size = batch.len(); + + // XProvence: Collect all raw queries/texts for the batch (one per item) + let raw_queries: Vec = batch + .raw_queries + .into_iter() + .map(|q| q.unwrap_or_default()) + .collect(); + let raw_texts: Vec = batch + .raw_texts + .into_iter() + .map(|t| t.unwrap_or_default()) + .collect(); + let results = self .tokio_runtime .block_on(self.backend_client.clone().predict( @@ -116,15 +129,22 @@ impl Backend for PythonBackend { batch.position_ids, batch.cumulative_seq_lengths, batch.max_length, + raw_queries, + raw_texts, )) .map_err(|err| BackendError::Inference(err.to_string()))?; - let raw_results: Vec> = results.into_iter().map(|r| r.values).collect(); let mut predictions = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); - for (i, r) in raw_results.into_iter().enumerate() { - predictions.insert(i, r); + for (i, score) in results.into_iter().enumerate() { + predictions.insert( + i, + Prediction { + scores: score.values, + pruned_text: score.pruned_text, + }, + ); } Ok(predictions) diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 245715b38..79bc05d29 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -223,6 +223,8 @@ impl Backend { max_length: tmp_length, pooled_indices, raw_indices: vec![], + raw_queries: vec![], + raw_texts: vec![], } } @@ -280,6 +282,8 @@ impl Backend { max_length, pooled_indices, raw_indices: vec![], + raw_queries: vec![], + raw_texts: 
vec![], }; match &self.model_type { @@ -314,6 +318,8 @@ impl Backend { max_length: 1, pooled_indices: vec![0], raw_indices: vec![], + raw_queries: vec![], + raw_texts: vec![], }; match &self.model_type { ModelType::Classifier => self.predict(batch).await.map(|_| ()), diff --git a/core/src/infer.rs b/core/src/infer.rs index a2ff22c51..fb16eb15a 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -561,11 +561,13 @@ async fn backend_task(backend: Backend, mut embed_receiver: mpsc::Receiver, + /// XProvence: pruned context text after removing irrelevant sentences + pub pruned_text: Option, pub metadata: InferMetadata, } diff --git a/core/src/queue.rs b/core/src/queue.rs index 3fd8b7715..acc3409d4 100644 --- a/core/src/queue.rs +++ b/core/src/queue.rs @@ -129,6 +129,10 @@ fn queue_blocking_task( let mut cu_seq_lengths = Vec::with_capacity(capacity); cu_seq_lengths.push(0); + // XProvence: raw text vectors for context pruning + let mut raw_queries = Vec::with_capacity(capacity); + let mut raw_texts = Vec::with_capacity(capacity); + let mut current_tokens = 0; let mut max_length = 0; @@ -168,6 +172,10 @@ fn queue_blocking_task( token_type_ids.extend(entry.encoding.token_type_ids); position_ids.extend(entry.encoding.position_ids); + // XProvence: collect raw texts for context pruning + raw_queries.push(entry.encoding.raw_query); + raw_texts.push(entry.encoding.raw_text); + current_tokens += entry_tokens; metadata.push(entry.metadata); cu_seq_lengths.push(current_tokens as u32); @@ -193,6 +201,8 @@ fn queue_blocking_task( max_length, pooled_indices, raw_indices, + raw_queries, + raw_texts, }, )) }; diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs index 3639b9845..f42ceb352 100644 --- a/core/src/tokenization.rs +++ b/core/src/tokenization.rs @@ -374,6 +374,12 @@ fn encode_input( prompts: Option<&HashMap>, tokenizer: &mut Tokenizer, ) -> Result { + // XProvence: Extract raw query and text before tokenization (for Dual inputs) + let (raw_query, raw_text) = match &inputs { + EncodingInput::Dual(query, text) => (Some(query.clone()), Some(text.clone())), + _ => (None, None), + }; + // Default truncation params let truncate_params = truncate.then_some(TruncationParams { direction: truncation_direction, @@ -406,6 +412,8 @@ fn encode_input( token_type_ids: encoding.get_type_ids().to_vec(), position_ids: (position_offset as u32..(seq_len + position_offset) as u32) .collect::>(), + raw_query, + raw_text, }) } @@ -414,6 +422,10 @@ pub struct ValidEncoding { pub input_ids: Vec, pub token_type_ids: Vec, pub position_ids: Vec, + /// XProvence: raw query text for context pruning (from Dual input) + pub raw_query: Option, + /// XProvence: raw context text for context pruning (from Dual input) + pub raw_text: Option, } #[derive(Debug)] diff --git a/revision b/revision new file mode 100644 index 000000000..d00491fd7 --- /dev/null +++ b/revision @@ -0,0 +1 @@ +1 diff --git a/router/src/http/server.rs b/router/src/http/server.rs index a22af9628..1cb57a165 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -361,13 +361,16 @@ async fn rerank( .map_err(ErrorResponse::from)?; let score = response.results[0]; + // XProvence: extract pruned_text from response + let pruned_text = response.pruned_text; - Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>(( + Ok::<(usize, Duration, Duration, Duration, f32, Option), ErrorResponse>(( response.metadata.prompt_tokens, response.metadata.tokenization, response.metadata.queue, response.metadata.inference, score, + pruned_text, 
)) }; @@ -410,7 +413,7 @@ async fn rerank( let results = join_all(futures) .await .into_iter() - .collect::, ErrorResponse>>()?; + .collect::)>, ErrorResponse>>()?; let mut ranks = Vec::with_capacity(batch_size); let mut total_tokenization_time = 0; @@ -430,6 +433,9 @@ async fn rerank( }; let score = r.4; + // XProvence: extract pruned_text from result + let pruned_text = r.5; + // Check that s is not NaN or the partial_cmp below will panic if score.is_nan() { Err(ErrorResponse { @@ -438,7 +444,7 @@ async fn rerank( })?; } - ranks.push(Rank { index, text, score }) + ranks.push(Rank { index, text, score, pruned_text }) } // Reverse sort diff --git a/router/src/http/types.rs b/router/src/http/types.rs index dedaab60a..ce9994b22 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -266,6 +266,10 @@ pub(crate) struct Rank { pub text: Option, #[schema(example = "1.0")] pub score: f32, + /// XProvence: pruned context with irrelevant sentences removed + #[schema(nullable = true, default = "null")] + #[serde(skip_serializing_if = "Option::is_none")] + pub pruned_text: Option, } #[derive(Serialize, ToSchema)] diff --git a/router/src/lib.rs b/router/src/lib.rs index d83bd95c5..9c5eb98f4 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -396,7 +396,8 @@ fn get_backend_model_type( return Ok(text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, )); - } else if arch.ends_with("Classification") { + } else if arch.ends_with("Classification") || arch == "XProvence" { + // XProvence is a reranker model for context pruning if pooling.is_some() { tracing::warn!( "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
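With `Rank` extended above, the pruned context surfaces in the `/rerank` HTTP response. An end-to-end sketch against a running instance serving an XProvence model; the URL is an assumption for a local deployment, and `pruned_text` is only present when the backend returned a pruned context.

```python
# Query the rerank endpoint and read the new pruned_text field per result.
import requests

resp = requests.post(
    "http://localhost:8080/rerank",
    json={
        "query": "What is the capital of France?",
        "texts": [
            "Paris is the capital of France. The city hosted the 2024 Olympics.",
            "Mount Everest is the highest mountain on Earth.",
        ],
    },
    timeout=60,
)
resp.raise_for_status()
for rank in resp.json():
    # pruned_text is omitted from the JSON when the backend returned no pruning
    print(rank["index"], rank["score"], rank.get("pruned_text"))
```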