# Retrieval as a Tool

In the previous notebook, retrieval was defined as a graph so that every step could be inspected clearly. In production, however, we would rather expose retrieval as a tool. In this notebook, we build such a tool on top of the previous notebook's retrieval flow.

In [1]:
# Load environment variables from a local .env file (see env.example for the
# keys this notebook expects); later cells read them via os.getenv.
import os
from dotenv import load_dotenv
load_dotenv()

True

# Retrieval Tool

Wrapping retrieval as a tool means the tool does only what it is intended to do (expand the query and fetch matching document chunks) and leaves summarization to the agent itself.

In [27]:
# Prepare library
## AI Search related
import langchain_community.vectorstores.azuresearch as azuresearch
from langchain_community.vectorstores.azuresearch import AzureSearch
from langchain_openai import AzureOpenAIEmbeddings

## LLM related
import os
import json
from langchain.chat_models import init_chat_model

## Tool creation related
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, AnyMessage

In [28]:
# Notebook configuration: adjust these to match your own Azure resources.
index_name = "kaenova-testing-3"  # Azure AI Search index to query; change it if you want to
api_version = "2024-12-01-preview"  # Azure OpenAI API version for both embeddings and chat
langchain_model_name = "azure_openai:gpt-4"  # "<provider>:<model>" for init_chat_model; adjust to your deployment

# Verify the .env file (see env.example) provides every key this notebook reads,
# reporting all missing keys at once rather than one per re-run.
required_env_keys = (
    "AZURE_AIS_ENDPOINT",
    "AZURE_AIS_KEY",
    "AZURE_OAI_ENDPOINT",
    "AZURE_OAI_KEY",
    "AZURE_OAI_DEPLOYMENT",
    "AZURE_OAI_CHAT_DEPLOYMENT",
)
missing_env_keys = [key for key in required_env_keys if not os.getenv(key)]
if missing_env_keys:
    raise ValueError(f"Missing environment variables: {', '.join(missing_env_keys)}")

In [29]:
# !important: adjust these to your index schema. The AzureSearch connector reads
# its field names from these module-level constants, so map them onto this
# index's field names before the vector store client is created.
_field_mapping = {
    "FIELDS_ID": "id",
    "FIELDS_CONTENT": "chunk",
    "FIELDS_CONTENT_VECTOR": "chunk_vector",
}
for _attr, _field in _field_mapping.items():
    setattr(azuresearch, _attr, _field)

In [30]:
# Azure OpenAI embedding client used to vectorize search queries.
# NOTE(review): presumably this must be the same embedding deployment that
# produced the index's vector field; confirm against the indexing pipeline.
embeddings: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings(
    azure_endpoint=os.getenv("AZURE_OAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OAI_KEY"),
    azure_deployment=os.getenv("AZURE_OAI_DEPLOYMENT"),
    openai_api_version=api_version,
)

# Bare query-embedding callable (text -> vector), handy for passing around.
embeddings_function = embeddings.embed_query

In [31]:
# Client for the Azure AI Search index; incoming queries are embedded with the
# embedding client configured in the previous cell.
vector_store: AzureSearch = AzureSearch(
    index_name=index_name,
    azure_search_endpoint=os.getenv("AZURE_AIS_ENDPOINT"),
    azure_search_key=os.getenv("AZURE_AIS_KEY"),
    embedding_function=embeddings_function,
)

In [51]:
# Chat model used by the query expander.
# LangChain's Azure OpenAI chat integration picks up its credentials from the
# standard AZURE_OPENAI_* / OPENAI_API_VERSION variables, so mirror this
# project's AZURE_OAI_* settings into them before building the model.
_azure_env_aliases = {
    "AZURE_OPENAI_API_KEY": os.getenv("AZURE_OAI_KEY"),
    "AZURE_OPENAI_ENDPOINT": os.getenv("AZURE_OAI_ENDPOINT"),
    "OPENAI_API_VERSION": api_version,
}
for _env_name, _env_value in _azure_env_aliases.items():
    os.environ[_env_name] = _env_value

chat_llm = init_chat_model(
    langchain_model_name,
    azure_deployment=os.environ["AZURE_OAI_CHAT_DEPLOYMENT"],
)

In [106]:
# Define the steps of retrieval flow
def expand_query(questions: str, max_attempts: int = 3) -> list[str]:
    """Rewrite a user question into several short search queries using the chat LLM.

    Args:
        questions: The user's question in natural language.
        max_attempts: Maximum number of LLM calls before giving up. Bounded so a
            model that keeps returning malformed output cannot hang the caller.

    Returns:
        A non-empty list of query strings parsed from the LLM's JSON reply. If no
        attempt yields a non-empty JSON array of strings, returns ``[questions]``
        so retrieval can still run on the raw question.
    """
    # Adjacent string literals are concatenated verbatim, so each piece carries
    # its own trailing space or sentence boundary.
    system_prompt = (
        "You're a search query expander. You'll get a user question and you need to create "
        "multiple concise, keyworded, and general search queries "
        "that will be sent to Azure AI Search. "
        "Return the queries only as a JSON array of strings, without any code blocks. "
        'For example: ["query_1", "query_2", "query_3", "query_4", "query_5"]'
    )
    messages = [SystemMessage(content=system_prompt), HumanMessage(content=questions)]

    for _ in range(max_attempts):
        result = chat_llm.invoke(messages)
        print("[expand_query] llm_result", result.content)

        text = str(result.content).strip()
        # Models occasionally wrap the array in ``` fences despite the instruction.
        if text.startswith("```"):
            text = text.strip("`").strip().removeprefix("json").strip()

        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            data = None

        # Only accept a non-empty list of strings; anything else is retried.
        if isinstance(data, list) and data and all(isinstance(q, str) for q in data):
            return data
        print("[expand_query] not a valid value")

    print("[expand_query] falling back to the original question")
    return [questions]

def retrieve_documents(questions: list[str], k: int = 5) -> list[str]:
    """Run a semantic hybrid search for each query and return the distinct matching chunks.

    Args:
        questions: Search queries to run (e.g. the output of ``expand_query``).
        k: Number of results requested from Azure AI Search per query.

    Returns:
        JSON strings of the form ``{"id": ..., "content": ...}``, one per distinct
        chunk id, in the order each chunk was first retrieved. Non-ASCII text is
        kept as-is rather than ``\\uXXXX``-escaped, which is easier for the LLM to read.
    """
    unique_chunks = {}  # chunk id -> serialized chunk; dicts preserve insertion order
    total_hits = 0
    for question in questions:
        answers = vector_store.semantic_hybrid_search(query=question, k=k)
        for doc in answers:
            total_hits += 1
            chunk_id = doc.metadata["id"]
            # Expanded queries often hit the same chunk; keep only the first copy so
            # the agent's context is not padded with duplicates.
            if chunk_id not in unique_chunks:
                unique_chunks[chunk_id] = json.dumps(
                    {"id": chunk_id, "content": doc.page_content},
                    ensure_ascii=False,
                )
    print("[retrieve_documents] retrieved before set", total_hits, "of documents")
    print("[retrieve_documents] retrieved", len(unique_chunks), "of documents")
    return list(unique_chunks.values())

In [107]:
# Creating the tool itself.
# @tool turns the function's docstring into the tool description the agent sees,
# so it must accurately say what to pass in and what comes back.
@tool
def retrieve_tool(questions: str) -> list[str]:
    """Search the document knowledge base for passages relevant to a question.

    Pass the user's question in natural language. It is expanded into several
    search queries, and the matching document chunks are returned as a list of
    JSON strings, each with an "id" and a "content" field. Returns an empty list
    for a blank question.
    """
    # A blank question would only burn an LLM call and return arbitrary matches.
    if not questions.strip():
        return []
    generated_questions = expand_query(questions)
    return retrieve_documents(generated_questions)

In [108]:
# Try the tool directly before handing it to an agent.
# Calling a tool object like a plain function goes through the deprecated
# BaseTool.__call__; .invoke() is the supported entry point.
sample_chunks = retrieve_tool.invoke({"questions": "Kapan Gerakan perempuan?"})
sample_chunks[:1]

[expand_query] llm_result ["sejarah gerakan perempuan", "kapan mulai gerakan feminis", "tanggal penting dalam gerakan perempuan", "awal mula feminisme", "timeline gerakan hak perempuan"]
[retrieve_documents] retrieved before set 25 of documents
[retrieve_documents] retrieved 7 of documents
['{"id": "75f218c9a1341be9b298ed3848d7c055__chunk_0045", "content": "banyak kelompok yang memiliki massa, baik yang berbasis ideologi \\npolitik maupun agama. Kekuatan kelompok tersebut memunculkan warna \\nyang beragam pada identitas nasional dan berbagai peristiwa sejarah di \\nIndonesia. Beberapa di antaranya akan dibahas pada subbab berikut.\\n1.\\t Gerakan Perempuan \\nGerakan Perempuan pada tahun 1950\\u20141960 merupakan salah satu \\nperiode pergerakan paling progresif setelah tahun 1928. Pada periode ini \\nbanyak organisasi perempuan yang berafiliasi dengan kekuatan-kekuatan \\norganisasi massa yang besar. Sebagai contoh Aisyiah dari Muhammadiyah, \\nMuslimat dari Masyumi, Muslimat Nahdlatu