## Initialize MongoDB vector database with some sample documents

In [17]:
# Restart the notebook after installing the packages

!pip install pymongo==4.7.2 langchain-core==0.2.6 langchain-openai==0.1.7 langchain==0.2.1 langchain-community==0.2.4 lark==1.1.9

In [3]:
import os

from pymongo import MongoClient

# The connection string comes from the environment so credentials never live in the notebook.
# `os` must be imported in this cell: it is the first cell that uses it on a fresh kernel.
MONGO_URI = os.getenv('MONGO_URI')

client = MongoClient(MONGO_URI)
# Placeholders: replace with the database / collection that will hold the sample documents.
DB_NAME = "your database name"
COLLECTION_NAME = "your collection name"
collection = client[DB_NAME][COLLECTION_NAME]

In [5]:
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings

# Sample movie summaries; each entry pairs the text to embed with the metadata used for filtering.
docs = [
    Document(page_content=summary, metadata=attributes)
    for summary, attributes in (
        (
            "A bunch of scientists bring back dinosaurs and mayhem breaks loose",
            {"release_date": "1994-04-15", "rating": 7.7, "genre": ["action", "scifi", "adventure"]},
        ),
        (
            "Leo DiCaprio gets lost in a dream within a dream within a dream within a ...",
            {"release_date": "2010-07-16", "director": "Christopher Nolan", "rating": 8.2, "genre": ["action", "thriller"]},
        ),
        (
            "A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea",
            {"release_date": "2006-11-25", "director": "Satoshi Kon", "rating": 8.6, "genre": ["anime", "thriller", "scifi"]},
        ),
        (
            "A bunch of normal-sized women are supremely wholesome and some men pine after them",
            {"release_date": "2019-12-25", "director": "Greta Gerwig", "rating": 8.3, "genre": ["romance", "drama", "comedy"]},
        ),
        (
            "Toys come alive and have a blast doing so",
            {"release_date": "1995-11-22", "genre": ["anime", "fantasy"]},
        ),
    )
]

In [6]:
import json
import os

from langchain_openai import OpenAIEmbeddings

# Credentials and endpoint are read from the environment; `os` is imported here so the cell
# does not depend on an import made by a later cell.
openai_api_key = os.getenv("OPEN_AI_API_KEY")
openai_api_base = os.getenv("OPEN_API_BASE")

# default_headers is optional; when set it must be a JSON object encoded as a string
default_headers = os.getenv("OPEN_API_DEFAULT_HEADERS")
default_headers = json.loads(default_headers) if default_headers else None

embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key, openai_api_base=openai_api_base, default_headers=default_headers)

In [7]:
from langchain.vectorstores import MongoDBAtlasVectorSearch

# Build the vector store from the sample documents, using the embeddings model and the target collection.
vectorStore = MongoDBAtlasVectorSearch.from_documents(
    documents=docs,
    embedding=embeddings,
    collection=collection,
)

## Create a vector search index named `default` on the collection, pasting the definition below into the JSON editor

In [8]:
# copy below while creating the vector search index
# - the "vector" field indexes the `embedding` path with cosine similarity
#   (numDimensions presumably matches the output size of the embeddings model used above - confirm if you change the model)
# - every metadata path that will appear in a pre-filter (genre, rating, release_date) is declared as a "filter" field
{
  "fields": [
    {
      "numDimensions": 1536,
      "path": "embedding",
      "similarity": "cosine",
      "type": "vector"
    },
    {
      "path": "genre",
      "type": "filter"
    },
    {
      "path": "rating",
      "type": "filter"
    },
    {
      "path": "release_date",
      "type": "filter"
    }
  ]
}

{'fields': [{'numDimensions': 1536,
   'path': 'embedding',
   'similarity': 'cosine',
   'type': 'vector'},
  {'path': 'genre', 'type': 'filter'},
  {'path': 'rating', 'type': 'filter'},
  {'path': 'release_date', 'type': 'filter'}]}

## Define the metadata for the MongoDB collection

In [9]:
from langchain.chains.query_constructor.base import AttributeInfo

# Attribute descriptions shown to the LLM when it builds a filter.
# The genre keyword list must cover every genre stored in the sample documents; otherwise the
# model is told that values such as 'scifi' or 'drama' are not valid filter keywords.
metadata_field_info = [
    AttributeInfo(
        name="genre",
        description="Keywords for filtering: ['action', 'adventure', 'anime', 'comedy', 'drama', 'fantasy', 'romance', 'scifi', 'thriller']",
        type="[string]",
    ),
    AttributeInfo(
        name="release_date",
        description="The date the movie was released on",
        type="string",
    ),
    AttributeInfo(
        name="rating", description="A 1-10 rating for the movie", type="float"
    ),
]
# One-line description of what each document's page_content holds.
document_content_description = "Brief summary of a movie"

## Define a tool that we will use for running queries on our MongoDB collection

In [10]:
import json
import logging
import os
import traceback
from typing import Dict, Optional, Type, Union, List

from pymongo import MongoClient
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool

class MongoDBClient:
    """Thin helper that runs aggregation pipelines against a single collection."""

    def __init__(self, collection):
        # Any object exposing ``aggregate(pipeline)`` can be used here.
        self.collection = collection

    def run_aggregate_pipeline(self, pipeline: List[Dict]) -> List[Dict]:
        """Pass ``pipeline`` to ``collection.aggregate`` and return all results as a list."""
        cursor = self.collection.aggregate(pipeline)
        return [document for document in cursor]

class BaseMongoDBTool(BaseModel):
    """Base tool for interacting with MongoDB."""

    # Helper used to execute aggregation pipelines; declared with exclude=True.
    client: MongoDBClient = Field(exclude=True)
    # Filter that QueryExecutorMongoDBTool prepends as a `$match` stage; declared with exclude=True.
    match_filter: dict = Field(exclude=True)

    # Inherit the pydantic configuration that BaseTool declares.
    class Config(BaseTool.Config):
        pass

# Argument schema for QueryExecutorMongoDBTool: a single required string argument.
class _QueryExecutorMongoDBToolInput(BaseModel):
    # The pipeline arrives as a JSON string and is decoded with json.loads inside the tool.
    pipeline: str = Field(..., description="A valid MongoDB pipeline in JSON string format")

class QueryExecutorMongoDBTool(BaseMongoDBTool, BaseTool):
    """Tool that runs an aggregation pipeline (given as a JSON string) and returns the documents."""

    name: str = "mongo_db_executor"
    description: str = """
    Input to this tool is a mongodb pipeline, output is a list of documents.
    If the pipeline is not correct, an error message will be returned.
    If an error is returned, report back to the user the issue and stop.
    """
    args_schema: Type[BaseModel] = _QueryExecutorMongoDBToolInput

    def _run(
            self,
            pipeline: str,
            run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Get the result for the mongodb pipeline.

        :param pipeline: JSON string that decodes to a list of aggregation stages
        :param run_manager: Optional callback manager (unused)
        :return: The list of documents, or a string starting with "Error:" when the pipeline
                 is malformed, contains a write stage, or fails to execute
        """
        # Resolve the logger locally: this cell does not define a module-level `logger`, and a
        # NameError raised here would be swallowed by the broad `except` below.
        log = logging.getLogger(__name__)
        try:
            log.info(f"Pipeline: {pipeline}/")
            log.info(f"Match filter: {self.match_filter}/")
            stages = json.loads(pipeline)
            if not isinstance(stages, list) or not all(isinstance(stage, dict) for stage in stages):
                return "Error: the pipeline must be a JSON array of stage objects"
            # The pipeline is written by an LLM; this tool is meant for lookups only, so refuse
            # the stages that write to a collection.
            write_stages = [
                stage_name
                for stage in stages
                for stage_name in stage
                if stage_name in ("$out", "$merge")
            ]
            if write_stages:
                return f"Error: write stages are not allowed in this tool: {write_stages}"
            if self.match_filter:
                # Restrict the lookup to the documents selected by the metadata pre-filter.
                stages = [{"$match": self.match_filter}] + stages
            log.info(f"Updated pipeline: {stages}/")
            return self.client.run_aggregate_pipeline(stages)
        except Exception as e:
            # Hand the failure back as text so the agent can report it instead of crashing.
            return f"Error: {e}\n{traceback.format_exc()}"


## Define prompt and examples that we will be using for the Query Constructor

> Note: Update the prompt and examples as per your use case

In [11]:
from langchain_core.prompts import PromptTemplate

# Schema instructions passed to the query constructor as `schema_prompt`.
# `{allowed_comparators}` and `{allowed_operators}` are the only template variables.
# NOTE(review): the quadruple braces around the JSON example presumably exist because the text is
# formatted more than once downstream - confirm before simplifying them.
DEFAULT_SCHEMA = """\
<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the following schema:

```json
{{{{
    "query": string \\ rewritten user's query after removing the information handled by the filter
    "filter": string \\ logical condition statement for filtering documents
}}}}
```

The query string should be re-written. Any conditions in the filter should not be mentioned in the query as well.

A logical condition statement is composed of one or more comparison and logical operation statements.

A comparison statement takes the form: `comp(attr, val)`:
- `comp` ({allowed_comparators}): comparator
- `attr` (string):  name of attribute to apply the comparison to
- `val` (string): is the comparison value

A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ({allowed_operators}): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to

Make sure that you only use the comparators and logical operators listed above and no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling date data typed values.
Make sure you understand the user's intent while generating a date filter. Use a range comparators such as gt | gte | lt | lte  for partial dates. 
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored. 
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.\
"""
# Template object built from the schema text above.
DEFAULT_SCHEMA_PROMPT = PromptTemplate.from_template(DEFAULT_SCHEMA)

# Few-shot data source: a song catalogue. Referenced by DEFAULT_EXAMPLES and EXAMPLES_WITH_LIMIT.
# Braces are doubled so they survive one round of str.format-style templating.
SONG_DATA_SOURCE = """\
```json
{{
    "content": "Lyrics of a song",
    "attributes": {{
        "artist": {{
            "type": "string",
            "description": "Name of the song artist"
        }},
        "length": {{
            "type": "integer",
            "description": "Length of the song in seconds"
        }},
        "genre": {{
            "type": "[string]",
            "description": "The song genre, one or many of [\"pop\", \"rock\" or \"rap\"]"
        }},
        "release_dt": {{
            "type": "string",
            "description": "Release date of the song."
        }}
    }}
}}
```\
"""

# Few-shot data source: memos with a date and a title. Referenced by DEFAULT_EXAMPLES.
MEMO_DATA_SOURCE = """\
```json
{{
    "content": "Estaff memos",
    "attributes": {{
        "memo_date": {{
            "type": "string",
            "description": "Date the memo was published on"
        }},
        "title": {{
            "type": "string",
            "description": "The title of the memo"
        }}
    }}
}}
```\
"""

# Few-shot data source: documents tagged with a fixed keyword list.
# The closing fence is three backticks, matching the opening ```json fence.
KEYWORDS_DATA_SOURCE = """\
```json
{{
    "content": "Documents store",
    "attributes": {{
        "tags": {{
            "type": "[string]",
            "description": "Keywords for filtering: ['credal', 'genai', 'radiant', 'langchain']"
        }}
    }}
}}
```\
"""

# Expected answer for a keyword question: both mentioned tags go into an `in` filter.
KEYWORDS_DATA_SOURCE_ANSWER = """\
```json
{{
    "query": "Updates on radiant integration with credal",
    "filter": "in(\\"tags\\", [\\"credal\\", \\"radiant\\"])"
}}
```\
"""

# Expected answer when the question also says "most recent": the keyword filter is extracted
# and the time-related wording stays in the rewritten query.
KEYWORDS_DATE_DATA_SOURCE_ANSWER = """\
```json
{{
    "query": "Updates based on most recent memo",
    "filter": "in(\\"tags\\", [\\"credal\\", \\"radiant\\"])"
}}
```\
"""

# Expected answer for the song example: nested and/or with comparison, membership and date-range filters.
FULL_ANSWER = """\
```json
{{
    "query": "songs about teenage romance",
    "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), lt(\\"length\\", 180), in(\\"genre\\", [\\"pop\\"]), and(gt(\\"release_dt\\", \\"2010-12-31\\"), lt(\\"release_dt\\", \\"2020-01-01\\")))"
}}
```\
"""

# Expected answer for a question with an explicit date: a single `gt` comparison on memo_date.
DATE_ANSWER = """\
```json
{{
    "query": "What are the updates on genai?",
    "filter": "gt(\\"memo_date\\", \\"2023-01-01\\")"
}}
```\
"""

# Expected answer when nothing in the question maps to an attribute: the "NO_FILTER" sentinel.
NO_FILTER_ANSWER = """\
```json
{{
    "query": "",
    "filter": "NO_FILTER"
}}
```\
"""

# Expected answer for "What are three songs about love": no filter, and the limit must match
# the number the question asks for.
WITH_LIMIT_ANSWER = """\
```json
{{
    "query": "love",
    "filter": "NO_FILTER",
    "limit": 3
}}
```\
"""

# Few-shot examples used when the query constructor is created with enable_limit=False.
# Each entry has an index ("i"), a data-source description, a user question and the expected answer.
DEFAULT_EXAMPLES = [
    {
        "i": 1,
        "data_source": MEMO_DATA_SOURCE,
        "user_query": "What are the updates on genai after 1 Jan 2023",
        "structured_request": DATE_ANSWER,
    },
    {
        "i": 2,
        "data_source": MEMO_DATA_SOURCE,
        "user_query": "What are the updates on genai",
        "structured_request": NO_FILTER_ANSWER
    },
    {
        "i": 3,
        "data_source": SONG_DATA_SOURCE,
        "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre released before 1 January 2020 and after 31 December, 2010",
        "structured_request": FULL_ANSWER,
    },
    {
        "i": 4,
        "data_source": SONG_DATA_SOURCE,
        "user_query": "What are songs that were not published on Spotify",
        "structured_request": NO_FILTER_ANSWER,
    },
    {
        "i": 5,
        "data_source": KEYWORDS_DATA_SOURCE,
        "user_query": "Updates on radiant integration in Credal",
        "structured_request": KEYWORDS_DATA_SOURCE_ANSWER
    },
    {
        "i": 6,
        "data_source": KEYWORDS_DATA_SOURCE,
        "user_query": "Updates on radiant integration in Credal based on most recent memo",
        "structured_request": KEYWORDS_DATE_DATA_SOURCE_ANSWER
    }
]

# Few-shot examples used when the query constructor is created with enable_limit=True.
EXAMPLES_WITH_LIMIT = [
    {
        "i": 1,
        "data_source": SONG_DATA_SOURCE,
        "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre released before 1 January 2020 and after 31 December, 2010",
        "structured_request": FULL_ANSWER,
    },
    {
        "i": 2,
        "data_source": SONG_DATA_SOURCE,
        "user_query": "What are songs that were not published on Spotify",
        "structured_request": NO_FILTER_ANSWER,
    },
    {
        "i": 3,
        "data_source": SONG_DATA_SOURCE,
        "user_query": "What are three songs about love",
        "structured_request": WITH_LIMIT_ANSWER,
    },
]


def enforce_constraints(input_json):
    """Normalise a translated filter so operator operands are plain, filterable values.

    Walks the (possibly nested) dict and, for every key that starts with ``$``:

    * list operands (``$and`` / ``$or`` conditions, ``$in`` / ``$nin`` values) are processed
      item by item;
    * scalar operands must be ``str``, ``int`` or ``float`` (``float`` is required for
      attributes such as a numeric rating);
    * date literals of the form ``{"date": "YYYY-MM-DD", ...}`` are replaced by their date string.

    :param input_json: Filter kwargs, e.g. ``{"pre_filter": {...}}``; non-dict input is returned unchanged
    :return: A new dict with the same structure and normalised operands
    :raises ValueError: If an operator has an operand of an unsupported type (e.g. ``None``)
    """
    scalar_types = (str, int, float)

    def is_date_literal(value):
        # Shape of a date operand: a dict carrying the date string under "date".
        return isinstance(value, dict) and isinstance(value.get("date"), str)

    def process_value(value):
        if isinstance(value, scalar_types):
            return value
        if isinstance(value, list) and all(isinstance(item, scalar_types) for item in value):
            return value
        if is_date_literal(value):
            return value["date"]
        raise ValueError("Invalid value type")

    def process_dict(d):
        if not isinstance(d, dict):
            return d
        processed = {}
        for key, value in d.items():
            if key.startswith("$") and isinstance(value, list):
                # Handling $and / $or conditions as well as $in / $nin value lists.
                processed[key] = [
                    item["date"] if is_date_literal(item) else process_dict(item)
                    for item in value
                ]
            elif key.startswith("$"):
                processed[key] = process_value(value)
            else:
                processed[key] = process_dict(value)
        return processed

    return process_dict(input_json)


## Define the prompt that we will be using for the Time Based query constructor

> Note: Update the prompt as per your use case

In [12]:
# System prompt for the time-based filter agent.
# It is formatted twice: first with str.format (filling `content_description` and
# `attribute_info`), then as a PromptTemplate with no variables - hence the quadruple braces
# around literal JSON braces. The content description is wrapped in double quotes so the
# rendered "Data Source" block is valid JSON.
SYSTEM_PROMPT_TEMPLATE = """
Your goal is to structure the user's query to match the request schema provided below.

<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the following schema:

```json
{{{{
    "query": string \\ rewritten user's query after removing the information handled by the filter
    "filter": string \\ logical condition statement for filtering documents
}}}}
```

The query string should be re-written. Any conditions in the filter should not be mentioned in the query as well.

A logical condition statement is composed of one or more comparison and logical operation statements.

A comparison statement takes the form: `comp(attr, val)`:
- `comp` ('eq | ne | gt | gte | lt | lte | in | nin'): comparator
- `attr` (string):  name of attribute to apply the comparison to
- `val` (string): is the comparison value

A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ('and | or'): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to

First step is to think about whether the user question mentions anything about date or time related that require a lookup in the MongoDB database. Words like "latest", "recent", "earliest", "first", "last" etc. in the query means a look up could be required.
If no lookup is required, return "NO_FILTER" for the filter value.

If required, create a syntactically correct MongoDB aggregation pipeline using '$sort' and '$limit' operator to run.
Use projection to only fetch the relevant date columns.
Then look at the results of the aggregation pipeline and generate a date range query that can be used to filter relevant documents from the collection.

Make sure to only generate date-based filters.
Make sure to only generate the query if a user asks about a time based question such as latest, most recent and not mention a specific date time.
Make sure that you only use the comparators and logical operators listed above and no others.
Make sure that filters only refer to date/time attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling date data typed values.
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.
Make sure the column names in the filter query are in double quotes.

<< Data Source >>
```json
{{{{
    "content": "{content_description}",
    "attributes": {attribute_info}
}}}}
```
"""

## Define the MetadataFilter class that we will use to generate pre_filters

In [13]:
import logging
import time
from typing import List, Dict, Tuple, Union

from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain.chains.query_constructor.base import AttributeInfo, _format_attribute_info, StructuredQueryOutputParser
from langchain.chains.query_constructor.base import load_query_constructor_runnable
from langchain_community.query_constructors.mongodb_atlas import MongoDBAtlasTranslator
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate, HumanMessagePromptTemplate, \
    SystemMessagePromptTemplate
from pymongo import MongoClient
try:
    # Internal package that the install cell does not provide; ChatModel is not referenced
    # anywhere in the visible notebook, so a missing package must not abort this cell.
    from dw_gai.common.enums import ChatModel
except ImportError:
    ChatModel = None

# Module-level logger used by the classes and functions defined in this notebook.
logger = logging.getLogger(__name__)


class MetadataFilter:
    """
    MetadataFilter is responsible for generating a MongoDB pre-filter query based on the user query.

    The filter is built in two steps: a query constructor turns the question into a metadata
    filter, then (only when a metadata filter was produced) a tool-calling agent may add a
    date-based filter for questions such as "latest" or "most recent".
    """

    def __init__(self, collection, llm, metadata_field_info, document_content_description):
        """
        Initialize the MetadataFilter with a pymongo collection
        :param collection: Pymongo collection the time-based lookup pipelines run against
        :param llm: Chat model used by the query constructor and by the tool-calling agent
        :param metadata_field_info: List of AttributeInfo (or dicts with a "name" key) describing the filterable attributes
        :param document_content_description: Short description of what the documents contain
        """
        self.collection = collection
        self.llm = llm
        # Kept for backward compatibility; nothing in this class reads it.
        self.dataset_query_constructor = {}
        self.translator = MongoDBAtlasTranslator()
        self.metadata_field_info = metadata_field_info
        self.document_content_description = document_content_description

    def create_query_constructor(self):
        """
        This method will create query constructor for the collection.
        The query constructor is a chain with a prompt created using collection's metadata and content description.
        This query constructor will be used to generate pre-filter for a user's query.
        :return: The query constructor runnable, configured with run name "query_constructor"
        """
        enable_limit = False

        # Use the values given to this instance, not same-named notebook globals, so the
        # constructor always matches the collection this object was created for.
        query_constructor = load_query_constructor_runnable(
            llm=self.llm,
            document_contents=self.document_content_description,
            attribute_info=self.metadata_field_info,
            enable_limit=enable_limit,
            schema_prompt=DEFAULT_SCHEMA_PROMPT,
            examples=EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES,
            allowed_operators=self.translator.allowed_operators,
            allowed_comparators=self.translator.allowed_comparators,
        )

        return query_constructor.with_config(run_name="query_constructor")

    def generate_metadata_filter(self, query: str) -> Tuple[Dict, str]:
        """
        This method will use the query constructor and generate the pre-filter for the user's query.

        Failures are logged rather than raised. If the query constructor fails, an empty filter and
        the original question are returned; if only the time-based step fails, the metadata filter
        that was already built is still returned.
        :param query: User's query
        :return (Tuple[Dict, str]): The pre-filter ({} when none applies) and the query to search with
                 (the rewritten query when one was produced, otherwise the original question).
        """
        user_query = query
        prompt = f"""Answer the below question:\n
                Question: {query}
                """
        query_constructor = self.create_query_constructor()

        # Defaults returned when no filter can be produced; defined before the try block so the
        # error path never hits an unbound name.
        pre_filter: Dict = {}
        new_query = user_query
        try:
            structured_query = query_constructor.invoke(prompt)
            logger.info(f"Structured query: {structured_query}")
            rewritten_query, new_kwargs = self.translator.visit_structured_query(structured_query)
            filter_kwargs = enforce_constraints(new_kwargs)
            logger.info(f"Generated pre-filter query: {filter_kwargs}")
            logger.info(f"Generated new query: {prompt} -> {rewritten_query}")
            if rewritten_query:
                new_query = rewritten_query
            if filter_kwargs:
                pre_filter = filter_kwargs["pre_filter"]
                try:
                    time_kwargs, time_query = self.generate_time_based_filter(filter_kwargs, new_query)
                except Exception as ex:
                    # The time-based step is an optional refinement; keep the metadata filter.
                    logger.error(f"Failed while creating time based pre-filter: {ex}")
                else:
                    time_filter = time_kwargs.get("pre_filter") if time_kwargs else None
                    if time_filter:
                        logger.info(f"Merging metadata filter: {pre_filter}, and\n\t{time_filter}")
                        pre_filter = {"$and": [pre_filter, time_filter]}
                    if time_query:
                        new_query = time_query
            logger.info(f"Final pre-filter query: {pre_filter}")
        except Exception as ex:
            logger.error(f"Failed while creating pre-filter: {ex}")
        return pre_filter, new_query

    def generate_time_based_filter(self, pre_filter: Dict, query: str) -> Tuple[Dict, str]:
        """
        This method is responsible for generating filter query for "most recent", "latest", "earliest" type of user
        questions.
        :param pre_filter: (Dict) metadata filter kwargs; its "pre_filter" entry restricts the lookup pipelines
        :param query: (str) user query
        :return: (Tuple[Dict, str]) Time-based filter kwargs ({} when no filter applies) and the rewritten user question
        """
        client = MongoDBClient(collection=self.collection)
        executor_tool = QueryExecutorMongoDBTool(client=client, match_filter=pre_filter["pre_filter"])
        tools = [executor_tool]
        attribute_str = _format_attribute_info(self.metadata_field_info)
        system_prompt_template = SYSTEM_PROMPT_TEMPLATE.format(attribute_info=attribute_str,
                                                               content_description=self.document_content_description)

        prompt = ChatPromptTemplate(input_variables=["agent_scratchpad", "input"],
                                    messages=[SystemMessagePromptTemplate(
                                        prompt=PromptTemplate(input_variables=[], template=system_prompt_template)),
                                        MessagesPlaceholder(variable_name="chat_history", optional=True),
                                        HumanMessagePromptTemplate(
                                            prompt=PromptTemplate(input_variables=["input"],
                                                                  template="{input}")),
                                        MessagesPlaceholder(variable_name="agent_scratchpad")])

        agent = create_tool_calling_agent(self.llm, tools, prompt)
        agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
        agent_result = agent_executor.invoke({"input": query})

        # Attribute names the parser accepts; taken from this instance's field info.
        allowed_attributes = [
            ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"]
            for ainfo in self.metadata_field_info
        ]

        output_parser = StructuredQueryOutputParser.from_components(
            allowed_comparators=self.translator.allowed_comparators,
            allowed_operators=self.translator.allowed_operators,
            allowed_attributes=allowed_attributes
        )
        structured_query = output_parser.parse(agent_result["output"])
        logger.info(f"Structured query: {structured_query}")
        new_query, new_kwargs = self.translator.visit_structured_query(structured_query)
        time_based_pre_filter = enforce_constraints(new_kwargs)
        logger.info(f"Generated time based pre-filter query: {time_based_pre_filter}")
        logger.info(f"Generated new query after time based filtering: {query} -> {new_query}")
        return time_based_pre_filter, new_query


## Use the MetadataFilter and pass the 'pre_filter' in our MongoDBAtlasVectorSearch retriever

In [14]:
from langchain_openai import ChatOpenAI

# Chat model shared by the query constructor, the time-based agent and the final QA chain.
# It reuses the key, base URL and headers read from the environment in the embeddings cell;
# no model name is passed here.
llm = ChatOpenAI(openai_api_key=openai_api_key, openai_api_base=openai_api_base, default_headers=default_headers)

In [18]:
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableParallel
from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.vectorstores import MongoDBAtlasVectorSearch



# System message for the answering step; `{context}` is filled with the retrieved documents.
system_prompt = """Use the following pieces of context to answer the user question in subsequent messages. The context was retrieved from a knowledge database and you should use only the facts from the context to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer, use the context. Don't address the context directly, but use it to answer the user question like it's your own knowledge.
Context: ```{context}```
"""

# Two-message chat prompt: the system message above plus the user's (rewritten) query.
qa_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system_prompt),
                    ("human", "{query}"),
                ]
            )

# Example question mixing a date bound, a genre keyword and a "latest" time-based condition.
query = "I want to watch a movie released before year 2000 in the anime genre with the latest release date"

metadata_filter = MetadataFilter(collection=collection,
                                 llm=llm, 
                                 metadata_field_info=metadata_field_info,
                                 document_content_description=document_content_description)

# Derive the pre-filter and the rewritten query, then search with the rewritten query.
pre_filter, new_query = metadata_filter.generate_metadata_filter(query)
query = new_query

# Wrap the existing collection (no documents are passed here) and apply the pre-filter on retrieval.
vectorStore = MongoDBAtlasVectorSearch( collection, embeddings )
retriever = vectorStore.as_retriever(
    search_kwargs={'pre_filter': pre_filter}
)

def format_docs(docs):
    """Concatenate the ``page_content`` of every document, separated by a blank line."""
    contents = (doc.page_content for doc in docs)
    return "\n\n".join(contents)
    
# QA chain: the input string is sent to the retriever (whose documents are joined by format_docs
# into `context`) and is also passed through unchanged as `query`; the filled prompt goes to the
# chat model and the reply is parsed to a plain string.
chain = (
    {"context": retriever | format_docs, "query": RunnablePassthrough()}
    | qa_prompt
    | llm
    | StrOutputParser()
)

# Run the chain with the rewritten query; the returned string is the cell's output.
chain.invoke(query)