From 5bee5be02a77283b875699bd9b333e38bc498e51 Mon Sep 17 00:00:00 2001 From: Rajendra Kadam Date: Wed, 10 Jul 2024 11:52:39 +0530 Subject: [PATCH] Move missing classification handling to business logic from model --- .../pebblo_retrieval/enforcement_filters.py | 17 +++++++++++++++++ .../chains/pebblo_retrieval/models.py | 10 ---------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py b/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py index 570cbdfa783f8f..63686b2ba59ae0 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/enforcement_filters.py @@ -21,6 +21,7 @@ AuthContext, SemanticContext, ) +from langchain_community.utilities.pebblo import CLASSIFICATION_UNAVAILABLE logger = logging.getLogger(__name__) @@ -522,6 +523,9 @@ def _set_semantic_enforcement_filter( This method sets the semantic enforcement filter in the search_kwargs of the retriever based on the type of the vectorstore. """ + # Add CLASSIFICATION_UNAVAILABLE to deny list if it's not empty + add_unavailable_to_deny_list(semantic_context) + # Apply semantic filter search_kwargs = retriever.search_kwargs if retriever.vectorstore.__class__.__name__ == PINECONE: _apply_pinecone_semantic_filter(search_kwargs, semantic_context) @@ -529,3 +533,16 @@ def _set_semantic_enforcement_filter( _apply_qdrant_semantic_filter(search_kwargs, semantic_context) elif retriever.vectorstore.__class__.__name__ == PGVECTOR: _apply_pgvector_semantic_filter(search_kwargs, semantic_context) + + +def add_unavailable_to_deny_list(sem_ctx: Optional[SemanticContext]) -> None: + """ + Add CLASSIFICATION_UNAVAILABLE to deny list if it's not empty. + This function handles documents with missing semantic metadata. + """ + if sem_ctx.pebblo_semantic_entities and sem_ctx.pebblo_semantic_entities.deny: + if CLASSIFICATION_UNAVAILABLE not in sem_ctx.pebblo_semantic_entities.deny: + sem_ctx.pebblo_semantic_entities.deny.append(CLASSIFICATION_UNAVAILABLE) + if sem_ctx.pebblo_semantic_topics and sem_ctx.pebblo_semantic_topics.deny: + if CLASSIFICATION_UNAVAILABLE not in sem_ctx.pebblo_semantic_topics.deny: + sem_ctx.pebblo_semantic_topics.deny.append(CLASSIFICATION_UNAVAILABLE) diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/models.py b/libs/community/langchain_community/chains/pebblo_retrieval/models.py index f72d38daf8ca20..3b7f94d44c8a41 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/models.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/models.py @@ -4,8 +4,6 @@ from langchain_core.pydantic_v1 import BaseModel -from langchain_community.utilities.pebblo import CLASSIFICATION_UNAVAILABLE - class AuthContext(BaseModel): """Class for an authorization context.""" @@ -48,14 +46,6 @@ def __init__(self, **data: Any) -> None: "'pebblo_semantic_topics'" ) - # Add CLASSIFICATION_UNAVAILABLE to deny list if it's not empty - if self.pebblo_semantic_entities and self.pebblo_semantic_entities.deny: - if CLASSIFICATION_UNAVAILABLE not in self.pebblo_semantic_entities.deny: - self.pebblo_semantic_entities.deny.append(CLASSIFICATION_UNAVAILABLE) - if self.pebblo_semantic_topics and self.pebblo_semantic_topics.deny: - if CLASSIFICATION_UNAVAILABLE not in self.pebblo_semantic_topics.deny: - self.pebblo_semantic_topics.deny.append(CLASSIFICATION_UNAVAILABLE) - class ChainInput(BaseModel): """Input for PebbloRetrievalQA chain."""