[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mongodb-developer/GenAI-Showcase/blob/main/notebooks/agents/self_querying_agent_mongodb_unstructured.ipynb)

[![View Article](https://img.shields.io/badge/View%20Article-blue)](https://www.mongodb.com/developer/products/atlas/advanced-rag-metadata-filtering/?utm_campaign=devrel&utm_source=cross-post&utm_medium=organic_social&utm_content=https%3A%2F%2Fgithub.com%2Fmongodb-developer%2FGenAI-Showcase&utm_term=apoorva.joshi)

# Advanced RAG: Metadata Extraction and Self-Querying Retrieval

This notebook shows how to incorporate metadata filtering and self-querying retrieval into a RAG application using Unstructured, MongoDB and LangGraph.

## Step 1: Install required libraries

- **langgraph**: Python package to build controllable agents
- **langchain-mongodb**: Python package to use MongoDB Atlas with LangChain
- **langchain-openai**: Python package to use OpenAI models in LangChain

In [79]:
! pip install -qU langgraph langchain-mongodb langchain-openai


## Step 2: Setup prerequisites

- Set the MongoDB connection string. Follow the steps [here](https://www.mongodb.com/docs/manual/reference/connection-string/) to get the connection string from the Atlas UI.

- Set the OpenAI API key. Steps to obtain an API key are [here](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key).

In [10]:
import os
import getpass

In [51]:
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API Key:")

Enter your OpenAI API Key: ········


In [94]:
MONGODB_URI = getpass.getpass("Enter your MongoDB connection string:")

Enter your MongoDB connection string: ········


In [47]:
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("Enter your LangSmith API Key:")
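
When the notebook is run non-interactively, the same secrets could be read from environment variables first, with a prompt only as a fallback. The helper below is a minimal sketch of that idea; `get_secret` and the environment variable name `MONGODB_URI` are hypothetical and not part of any library used here.

```python
import getpass
import os


def get_secret(name: str, prompt: str) -> str:
    """Return the value of environment variable `name`, prompting only if it is unset or empty."""
    value = os.environ.get(name)
    if not value:
        # Fall back to an interactive prompt and cache the answer for later cells.
        value = getpass.getpass(prompt)
        os.environ[name] = value
    return value
```

For example, `MONGODB_URI = get_secret("MONGODB_URI", "Enter your MongoDB connection string:")` would skip the prompt whenever the variable is already set.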

## Step 3: Instantiate the LLM

In [None]:
from langchain_openai import ChatOpenAI

In [57]:
llm = ChatOpenAI(model="gpt-4o-2024-05-13", temperature=0)

## Step 4: Define Graph State

In [197]:
from typing_extensions import TypedDict
from typing import Any, List, Dict

In [198]:
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: User query
        metadata: Metadata dictionary
        filter: MQL filter definition
        documents: List of documents
        generation: LLM generation
    """

    question: str
    metadata: Dict[str, List]
    filter: Dict[str, Any]
    documents: List[str]
    generation: str
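
Each node in the graph returns only the keys it wants to change, and those keys are merged into the shared state. The sketch below imitates that flow with plain dictionaries, assuming the default overwrite-by-key behaviour that applies when no reducer is declared. `apply_update` is a hypothetical stand-in, not a LangGraph function, and the values are illustrative only.

```python
from typing import Any, Dict


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Overwrite-by-key merge, standing in for how a node's return value updates the state."""
    return {**state, **update}


# Illustrative values only; in the notebook these come from the LLM-backed nodes.
state: Dict[str, Any] = {"question": "How did Apple Inc. perform in 2023?"}
state = apply_update(
    state,
    {
        "metadata": {
            "metadata.custom_metadata.company": ["Apple Inc."],
            "metadata.custom_metadata.year": [2023, 2023],
        }
    },
)
state = apply_update(
    state,
    {"filter": {"metadata.custom_metadata.company": {"$eq": "Apple Inc."}}},
)
print(sorted(state))  # ['filter', 'metadata', 'question']
```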

## Step 5: Define Graph Nodes 

In [193]:
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import PromptTemplate
from datetime import datetime

### Metadata Extractor

In [117]:
companies = ["AT&T INC.",
 "American International Group, Inc.",
 "Apple Inc.",
 "BERKSHIRE HATHAWAY INC.",
 "Bank of America Corporation",
 "CENCORA, INC.",
 "CVS HEALTH CORPORATION",
 "Cardinal Health, Inc.",
 "Chevron Corporation",
 "Citigroup Inc.",
 "Costco Wholesale Corporation",
 "Exxon Mobil Corporation",
 "Ford Motor Company",
 "GENERAL ELECTRIC COMPANY",
 "GENERAL MOTORS COMPANY",
 "HP Inc.",
 "INTERNATIONAL BUSINESS MACHINES CORPORATION",
 "JPMorgan Chase & Co.",
 "MICROSOFT CORPORATION",
 "MIDLAND COMPANY",
 "McKESSON CORPORATION",
 "THE BOEING COMPANY",
 "THE HOME DEPOT, INC.",
 "THE KROGER CO.",
 "The Goldman Sachs Group, Inc.",
 "UnitedHealth Group Incorporated",
 "VALERO ENERGY CORPORATION",
 "Verizon Communications Inc.",
 "WALMART INC.",
 "WELLS FARGO & COMPANY"]

In [269]:
class Metadata(BaseModel):
    """Metadata to use for pre-filtering."""
    company: List[str] = Field(description=f"List of company names, e.g. Google, Adobe, etc. Match the names to companies on this list: {companies}")
    # The f-prefix is required so that the current year is interpolated into the description.
    year: List[int] = Field(description=f"List containing start year and end year. For phrases like 'in the past X years/last year', extract the start year by subtracting X from the current year. The current year is {datetime.now().year}")
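
The `year` description asks the model to turn phrases like "in the past X years" into a start and end year. The arithmetic it is expected to perform can be sketched as follows; `past_years_range` is a hypothetical helper for illustration and is not used elsewhere in the notebook.

```python
from datetime import datetime
from typing import List, Optional


def past_years_range(x: int, current_year: Optional[int] = None) -> List[int]:
    """Return [start_year, end_year] for a phrase like 'in the past x years'."""
    if current_year is None:
        current_year = datetime.now().year
    # Start year is the current year minus x; end year is the current year.
    return [current_year - x, current_year]


print(past_years_range(3, current_year=2024))  # [2021, 2024]
```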

In [237]:
def extract_metadata(state):
    """
    Extract metadata from natural language query.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state i.e. metadata containing the metadata extracted from the user query.
    """
    print("---EXTRACT METADATA---")
    question = state["question"]
    template = """Extract the specified metadata from the following query:\n\n{question}"""
    prompt = PromptTemplate(
        template=template,
        input_variables=["question"]
    )
    chain = prompt | llm.with_structured_output(Metadata)
    result = chain.invoke({"question": question})
    metadata = {"metadata.custom_metadata.company": result.company, "metadata.custom_metadata.year": result.year}
    return {"metadata": metadata, "question": question}

### MQL Filter Generator

In [None]:
class Filter(BaseModel):
    """MongoDB filter definition."""
    filter: Dict[str, Any] = Field(description="The generated filter definition")

In [None]:
def generate_mql_filter(state):
    """
    Generate MongoDB Query Language (MQL) filter definition.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state i.e. the MQL filter.
    """
    print("---GENERATE FILTER DEFINITION---")
    metadata = state["metadata"]
    question = state["question"]
    template = """Generate a MongoDB filter definition from the user question and metadata extracted from it. Follow the guidelines below:
    - Respond in JSON with the filter assigned to a `filter` key.
    - The metadata field `metadata.custom_metadata.company` contains the list of companies found in the user query
    - The metadata field `metadata.custom_metadata.year` contains the years found in the user query
    - If any of the metadata fields is None, DO NOT include it in the filter.
    - If both the metadata fields are None, return an empty dictionary {{}}
    - The filter should only contain the fields `metadata.custom_metadata.company` and `metadata.custom_metadata.year`
    - The filter can only contain the following MQL match expressions:
        - $gt: Greater than
        - $lt: Less than
        - $gte: Greater than or equal to
        - $lte: Less than or equal to
        - $eq: Equal to
        - $ne: Not equal to
        - $in: Specified field value equals any value in the specified array
        - $nin: Specified field value is not present in the specified array
        - $nor: Logical NOR operation
        - $and: Logical AND operation
        - $or: Logical OR operation 
    - If the question has phrases such as "in the past X years", create a date range filter using expressions such as $gt, $lt, $lte and $gte
    - If the question mentions a single company, use the $eq expression
    - If the question mentions multiple companies, use the $in expression 
    - To combine date range and company filters, use the $and operator\n\nQuestion: {question}\n\nMetadata: {metadata}
    """
    prompt = PromptTemplate(
        template=template,
        input_variables=["question", "metadata"]
    )
    chain = prompt | llm.with_structured_output(Filter, method="json_mode")
    result = chain.invoke({"question": question, "metadata": metadata})
    # Return a partial state update so that the filter is stored under the `filter` key.
    return {"filter": result.filter}
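
To make the prompt's guidelines concrete, the sketch below builds the kind of filter the LLM is being asked to produce, using plain Python instead of a model call. `build_filter` is a hypothetical helper. Treating the year list as an inclusive range (`$gte`/`$lte`) is an assumption, and the LLM may return a different but equivalent filter.

```python
from typing import Any, Dict, List, Optional

COMPANY_FIELD = "metadata.custom_metadata.company"
YEAR_FIELD = "metadata.custom_metadata.year"


def build_filter(
    companies: Optional[List[str]], years: Optional[List[int]]
) -> Dict[str, Any]:
    """Build an MQL filter following the guidelines in the prompt above."""
    clauses: List[Dict[str, Any]] = []
    if companies:
        if len(companies) == 1:
            # Single company: use $eq.
            clauses.append({COMPANY_FIELD: {"$eq": companies[0]}})
        else:
            # Multiple companies: use $in.
            clauses.append({COMPANY_FIELD: {"$in": companies}})
    if years:
        # Date range: inclusive bounds between the earliest and latest year (assumption).
        clauses.append({YEAR_FIELD: {"$gte": min(years), "$lte": max(years)}})
    if not clauses:
        # Both metadata fields missing: empty filter.
        return {}
    if len(clauses) == 1:
        return clauses[0]
    # Combine company and date range filters with $and.
    return {"$and": clauses}
```

For instance, `build_filter(["Apple Inc.", "HP Inc."], [2021, 2024])` combines an `$in` clause on the company field and a `$gte`/`$lte` clause on the year field under `$and`, while `build_filter(None, None)` returns `{}`.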