15 changes: 14 additions & 1 deletion backends/core/src/lib.rs
@@ -14,6 +14,10 @@ pub struct Batch {
pub max_length: u32,
pub pooled_indices: Vec<u32>,
pub raw_indices: Vec<u32>,
/// XProvence: raw query texts for context pruning
pub raw_queries: Vec<Option<String>>,
/// XProvence: raw context texts for context pruning
pub raw_texts: Vec<Option<String>>,
}

impl Batch {
@@ -32,7 +36,16 @@ pub enum Embedding {
}

pub type Embeddings = IntMap<usize, Embedding>;
pub type Predictions = IntMap<usize, Vec<f32>>;

/// XProvence: Prediction result containing scores and optional pruned text
#[derive(Debug, Clone)]
pub struct Prediction {
pub scores: Vec<f32>,
/// XProvence: pruned context text after removing irrelevant sentences
pub pruned_text: Option<String>,
}

pub type Predictions = IntMap<usize, Prediction>;

pub trait Backend {
fn health(&self) -> Result<(), BackendError>;
4 changes: 4 additions & 0 deletions backends/grpc-client/src/client.rs
@@ -73,13 +73,17 @@ impl Client {
position_ids: Vec<u32>,
cu_seq_lengths: Vec<u32>,
max_length: u32,
raw_query: Option<String>,
raw_text: Option<String>,
) -> Result<Vec<Score>> {
let request = tonic::Request::new(EmbedRequest {
input_ids,
token_type_ids,
position_ids,
max_length,
cu_seq_lengths,
raw_query,
raw_text,
})
.inject_context();
let response = self.stub.predict(request).await?.into_inner();
6 changes: 6 additions & 0 deletions backends/proto/embed.proto
@@ -21,6 +21,10 @@ message EmbedRequest {
repeated uint32 cu_seq_lengths = 4;
/// Length of the longest request
uint32 max_length = 5;
/// XProvence: raw query text for context pruning
optional string raw_query = 6;
/// XProvence: raw context text for context pruning
optional string raw_text = 7;
}

message Embedding {
@@ -33,6 +37,8 @@ message EmbedResponse {

message Score {
repeated float values = 1;
/// XProvence: pruned context text after removing irrelevant sentences
optional string pruned_text = 2;
}

message PredictResponse {
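
The gRPC request/response shapes above can be exercised directly from Python. A minimal sketch, assuming the generated stubs are importable as embed_pb2 (the exact module path of the generated code is an assumption, not part of this diff):

from text_embeddings_server.pb import embed_pb2  # assumed location of the generated stubs

request = embed_pb2.EmbedRequest(
    input_ids=[0, 5, 9],
    token_type_ids=[0, 0, 0],
    position_ids=[0, 1, 2],
    cu_seq_lengths=[0, 3],
    max_length=3,
    raw_query="when was the eiffel tower built?",  # new optional field 6
    raw_text="The Eiffel Tower was completed in 1889. Paris is known for cafes.",  # new optional field 7
)

score = embed_pb2.Score(values=[0.92], pruned_text="The Eiffel Tower was completed in 1889.")
# proto3 `optional` fields carry explicit presence information, so HasField works on them:
if score.HasField("pruned_text"):
    print(score.pruned_text)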
31 changes: 23 additions & 8 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -11,11 +11,19 @@
from text_embeddings_server.models.masked_model import MaskedLanguageModel
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
from text_embeddings_server.models.xprovence_model import XProvenceModel
from text_embeddings_server.utils.device import get_device, use_ipex

FlashJinaBert = None
FlashMistral = None
FlashQwen3 = None
try:
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
except ImportError as e:
logger.warning(f"Flash attention models not available: {e}")

__all__ = ["Model"]

TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
@@ -76,13 +84,21 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

if (
hasattr(config, "auto_map")
hasattr(config, "architectures")
and config.architectures
and "XProvence" in config.architectures[0]
):
logger.info("Detected XProvence model for context pruning")
return XProvenceModel(model_path, device, datatype, trust_remote=True)

if (
FlashJinaBert is not None
and hasattr(config, "auto_map")
and isinstance(config.auto_map, dict)
and "AutoModel" in config.auto_map
and config.auto_map["AutoModel"]
== "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
):
# Add specific offline modeling for "jinaai/jina-embeddings-v2-base-code", which uses "auto_map" to reference code in another repository
return create_model(FlashJinaBert, model_path, device, datatype)

if config.model_type == "bert":
@@ -116,19 +132,18 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
else:
return create_model(DefaultModel, model_path, device, datatype, pool)

if config.model_type == "mistral" and device.type == "hpu":
if FlashMistral is not None and config.model_type == "mistral" and device.type == "hpu":
try:
return create_model(FlashMistral, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)

if config.model_type == "qwen3" and device.type == "hpu":
if FlashQwen3 is not None and config.model_type == "qwen3" and device.type == "hpu":
try:
return create_model(FlashQwen3, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)

# Default case
if config.architectures[0].endswith("Classification"):
return create_model(ClassificationModel, model_path, device, datatype)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
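
As a quick check on the routing added above, the same architecture test that get_model applies can be reproduced in isolation. A minimal sketch (the checkpoint path is illustrative):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("/data/xprovence-checkpoint", trust_remote_code=True)  # illustrative path
is_xprovence = bool(
    getattr(config, "architectures", None) and "XProvence" in config.architectures[0]
)
print("routes to XProvenceModel" if is_xprovence else "falls through to the standard branches")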
9 changes: 9 additions & 0 deletions backends/python/server/text_embeddings_server/models/types.py
@@ -36,6 +36,9 @@ class PaddedBatch(Batch):
token_type_ids: torch.Tensor
position_ids: torch.Tensor
attention_mask: torch.Tensor
# XProvence: raw text for context pruning
raw_query: str = None
raw_text: str = None

@classmethod
@tracer.start_as_current_span("from_pb")
@@ -77,11 +80,17 @@ def from_pb(
# Move padded tensors all at once
all_tensors = all_tensors.to(device)

# XProvence: Extract raw text if present in proto
raw_query = pb.raw_query if hasattr(pb, 'raw_query') and pb.raw_query else None
raw_text = pb.raw_text if hasattr(pb, 'raw_text') and pb.raw_text else None

return PaddedBatch(
input_ids=all_tensors[0],
token_type_ids=all_tensors[1],
position_ids=all_tensors[2],
attention_mask=all_tensors[3],
raw_query=raw_query,
raw_text=raw_text,
)

def __len__(self):
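
The guarded access in from_pb above leans on proto3 semantics: an unset optional string reads as an empty string rather than raising, and stubs generated from an older embed.proto would lack the attribute altogether. A minimal equivalent of that normalisation, shown for reference only:

# "" (unset) and a missing attribute both normalise to None
raw_query = getattr(pb, "raw_query", "") or None
raw_text = getattr(pb, "raw_text", "") or None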
176 changes: 176 additions & 0 deletions backends/python/server/text_embeddings_server/models/xprovence_model.py
@@ -0,0 +1,176 @@
import os
import torch

from pathlib import Path
from typing import Type, List
from transformers import AutoModel
from opentelemetry import trace
from loguru import logger

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)


def _parse_bool(value: str) -> bool:
"""Parse boolean from string with common conventions."""
return str(value).lower() in ("true", "1", "t", "yes", "on")


class XProvenceModel(Model):
"""
XProvence: Zero-cost context pruning model for RAG.

XProvence removes irrelevant sentences from passages based on relevance
to the query, returning both a reranking score and pruned context.

Based on bge-reranker-v2-m3 (XLM-RoBERTa), supports 16+ languages.

Environment Variables:
XPROVENCE_THRESHOLD (float): Pruning threshold between 0.0-1.0.
- 0.3 (default): Conservative pruning, minimal performance drop
- 0.7: Aggressive pruning, higher compression
XPROVENCE_ALWAYS_SELECT_TITLE (bool): Keep first sentence as title.
- true (default): Always include first sentence (useful for Wikipedia)
- false: Only include sentences above threshold
"""

def __init__(
self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = True,
):
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

if dtype == torch.bfloat16:
logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility")
dtype = torch.float32

model = model.to(dtype).to(device)

self.hidden_size = model.config.hidden_size

position_offset = 0
model_type = model.config.model_type
if model_type in ["xlm-roberta", "camembert", "roberta"]:
position_offset = model.config.pad_token_id + 1

if hasattr(model.config, "max_seq_length"):
self.max_input_length = model.config.max_seq_length
else:
self.max_input_length = (
model.config.max_position_embeddings - position_offset
)

try:
threshold_env = os.getenv("XPROVENCE_THRESHOLD", "0.3")
self.threshold = float(threshold_env)
if not (0.0 <= self.threshold <= 1.0):
logger.warning(
f"XPROVENCE_THRESHOLD={self.threshold} out of bounds [0.0, 1.0], "
f"defaulting to 0.3"
)
self.threshold = 0.3
except ValueError:
logger.error(
f"Invalid XPROVENCE_THRESHOLD='{threshold_env}', defaulting to 0.3"
)
self.threshold = 0.3

self.always_select_title = _parse_bool(
os.getenv("XPROVENCE_ALWAYS_SELECT_TITLE", "true")
)

logger.info(
f"XProvence model loaded: threshold={self.threshold}, "
f"always_select_title={self.always_select_title} "
f"(Configure via XPROVENCE_THRESHOLD, XPROVENCE_ALWAYS_SELECT_TITLE env vars)"
)

super(XProvenceModel, self).__init__(model=model, dtype=dtype, device=device)

@property
def batch_type(self) -> Type[PaddedBatch]:
return PaddedBatch

@tracer.start_as_current_span("embed")
def embed(self, batch: PaddedBatch) -> List[Embedding]:
raise NotImplementedError("XProvence is a reranker model; embed() is not supported")

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
"""
XProvence prediction with context pruning support.

For single-item batches with raw_query/raw_text available,
uses XProvence's process() method for sentence-level pruning.
Otherwise falls back to standard forward pass.
"""
batch_size = len(batch)

if batch_size == 1 and batch.raw_query and batch.raw_text:
return self._predict_with_pruning(batch.raw_query, batch.raw_text)

return self._predict_standard(batch)

def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
"""
Use XProvence's process() method for context pruning.

Returns score with pruned_text containing only relevant sentences.
"""
try:
os.environ["TQDM_DISABLE"] = "1"

original_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float32)

try:
output = self.model.process(
raw_query,
raw_text,
threshold=self.threshold,
always_select_title=self.always_select_title,
)
finally:
torch.set_default_dtype(original_dtype)

reranking_score = float(output["reranking_score"])
pruned_context = output["pruned_context"]

logger.debug(
f"XProvence pruning: score={reranking_score:.4f}, "
f"original_len={len(raw_text)}, pruned_len={len(pruned_context)}"
)

return [Score(values=[reranking_score], pruned_text=pruned_context)]

except Exception as e:
logger.error(f"XProvence process() failed: {e}; returning zero score without pruned text")
return [Score(values=[0.0], pruned_text=None)]

def _predict_standard(self, batch: PaddedBatch) -> List[Score]:
kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}

output = self.model(**kwargs, return_dict=True)

if hasattr(output, "ranking_scores"):
scores_tensor = output.ranking_scores
elif hasattr(output, "logits"):
scores_tensor = output.logits[:, 0] if output.logits.dim() == 2 else output.logits
else:
scores_tensor = output[0]

if scores_tensor.dim() == 0:
scores = [float(scores_tensor.item())]
else:
scores = scores_tensor.view(-1).tolist()

if isinstance(scores, float):
scores = [scores]

return [Score(values=[float(s)], pruned_text=None) for s in scores]
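
For context, a minimal standalone sketch of the pruning path that _predict_with_pruning wraps, assuming the checkpoint's remote code exposes process() with the keyword arguments and output keys used above (the checkpoint path is illustrative):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("/data/xprovence-checkpoint", trust_remote_code=True)  # illustrative path
model = model.to(torch.float32)  # float32 for process() compatibility, as in __init__ above

out = model.process(
    "when was the eiffel tower built?",
    "The Eiffel Tower is in Paris. It was completed in 1889. Paris is also known for its cafes.",
    threshold=0.3,             # default of XPROVENCE_THRESHOLD
    always_select_title=True,  # default of XPROVENCE_ALWAYS_SELECT_TITLE
)
print(out["reranking_score"])  # relevance score, returned to the server as Score.values[0]
print(out["pruned_context"])   # context with irrelevant sentences removed, returned as Score.pruned_text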
20 changes: 16 additions & 4 deletions backends/python/src/lib.rs
@@ -5,7 +5,7 @@ use backend_grpc_client::Client;
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions,
};
use tokio::runtime::Runtime;

@@ -108,6 +108,11 @@ impl Backend for PythonBackend {
));
}
let batch_size = batch.len();

// XProvence: Get first raw query/text from batch (for single request)
let raw_query = batch.raw_queries.first().cloned().flatten();
let raw_text = batch.raw_texts.first().cloned().flatten();

let results = self
.tokio_runtime
.block_on(self.backend_client.clone().predict(
@@ -116,15 +121,22 @@
batch.position_ids,
batch.cumulative_seq_lengths,
batch.max_length,
raw_query,
raw_text,
))
.map_err(|err| BackendError::Inference(err.to_string()))?;
let raw_results: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

for (i, r) in raw_results.into_iter().enumerate() {
predictions.insert(i, r);
for (i, score) in results.into_iter().enumerate() {
predictions.insert(
i,
Prediction {
scores: score.values,
pruned_text: score.pruned_text,
},
);
}

Ok(predictions)