Skip to content

Commit

Permalink
Move missing classification handling to business logic from model
Browse files Browse the repository at this point in the history
  • Loading branch information
Raj725 committed Jul 10, 2024
1 parent 66f6b55 commit 5bee5be
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AuthContext,
SemanticContext,
)
from langchain_community.utilities.pebblo import CLASSIFICATION_UNAVAILABLE

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -522,10 +523,26 @@ def _set_semantic_enforcement_filter(
This method sets the semantic enforcement filter in the search_kwargs
of the retriever based on the type of the vectorstore.
"""
# Add CLASSIFICATION_UNAVAILABLE to deny list if it's not empty
add_unavailable_to_deny_list(semantic_context)
# Apply semantic filter
search_kwargs = retriever.search_kwargs
if retriever.vectorstore.__class__.__name__ == PINECONE:
_apply_pinecone_semantic_filter(search_kwargs, semantic_context)
elif retriever.vectorstore.__class__.__name__ == QDRANT:
_apply_qdrant_semantic_filter(search_kwargs, semantic_context)
elif retriever.vectorstore.__class__.__name__ == PGVECTOR:
_apply_pgvector_semantic_filter(search_kwargs, semantic_context)


def add_unavailable_to_deny_list(sem_ctx: Optional["SemanticContext"]) -> None:
    """Append CLASSIFICATION_UNAVAILABLE to every non-empty deny list.

    The entity and topic filters of ``sem_ctx`` are inspected in turn. When a
    filter is set and its ``deny`` list already contains at least one value,
    ``CLASSIFICATION_UNAVAILABLE`` is appended to that list unless it is
    already present. Empty or missing deny lists are left untouched.
    This is how documents with missing semantic metadata are handled.

    The lists are modified in place; nothing is returned.

    Args:
        sem_ctx: Semantic context whose deny lists should be extended.
            ``None`` is accepted and results in a no-op.
    """
    # The signature allows None, so there may be nothing to update.
    if sem_ctx is None:
        return

    semantic_filters = (
        sem_ctx.pebblo_semantic_entities,
        sem_ctx.pebblo_semantic_topics,
    )
    for semantic_filter in semantic_filters:
        # Skip filters that are unset or that deny nothing.
        if not semantic_filter or not semantic_filter.deny:
            continue
        # Avoid duplicate entries when called repeatedly on the same context.
        if CLASSIFICATION_UNAVAILABLE not in semantic_filter.deny:
            semantic_filter.deny.append(CLASSIFICATION_UNAVAILABLE)
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from langchain_core.pydantic_v1 import BaseModel

from langchain_community.utilities.pebblo import CLASSIFICATION_UNAVAILABLE


class AuthContext(BaseModel):
"""Class for an authorization context."""
Expand Down Expand Up @@ -48,14 +46,6 @@ def __init__(self, **data: Any) -> None:
"'pebblo_semantic_topics'"
)

# Add CLASSIFICATION_UNAVAILABLE to deny list if it's not empty
if self.pebblo_semantic_entities and self.pebblo_semantic_entities.deny:
if CLASSIFICATION_UNAVAILABLE not in self.pebblo_semantic_entities.deny:
self.pebblo_semantic_entities.deny.append(CLASSIFICATION_UNAVAILABLE)
if self.pebblo_semantic_topics and self.pebblo_semantic_topics.deny:
if CLASSIFICATION_UNAVAILABLE not in self.pebblo_semantic_topics.deny:
self.pebblo_semantic_topics.deny.append(CLASSIFICATION_UNAVAILABLE)


class ChainInput(BaseModel):
"""Input for PebbloRetrievalQA chain."""
Expand Down

0 comments on commit 5bee5be

Please sign in to comment.