In [None]:
from dotenv import load_dotenv
load_dotenv()

### Query Expansion

In [None]:
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
import os

oai = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"), model="text-embedding-3-small")

docs = [
    Document(id=1, page_content="Employees have 20 leaves per year."),
    Document(id=2, page_content="Apply for leave in the HR portal."),
    Document(id=3, page_content="Remote work policies are in the handbook."),
]

vectorstore = InMemoryVectorStore(oai)
await vectorstore.aadd_documents(docs)

['1', '2', '3']

In [54]:
from langchain.tools import tool

SIMILARITY_THRESHOLD = 0.5

@tool
async def retrieve(contextual_query: str):
    '''Retrieve relevant document for given query'''

    docs = [doc.page_content for doc, sim in 
            vectorstore.similarity_search_with_score(contextual_query) 
            if sim >= SIMILARITY_THRESHOLD]
    
    return "\n".join(docs)

In [128]:
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage, SystemMessage

TOOLS = {"retrieve": retrieve}

system_prompt = '''You are a helpful internal HR assistant.
You are given a tool called `retrieve` to search for relevant documents based on a given query.
When invoking the tool, always enrich the query with the context based on chat history.'''

history = [
    SystemMessage(system_prompt),
    HumanMessage("How many leaves do I get?"),
    AIMessage("", tool_calls=[{"id": "1", "name": "retrieve", "args": {"contextual_query": "employee leave entitlements and policies"}}]),
    ToolMessage("Employees have 20 leaves per year.", tool_call_id="1"),
    AIMessage("You have 20 days annuallly."),
    ("human", "{query}")
]

prompt = ChatPromptTemplate.from_messages(history)

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5).bind_tools([retrieve])

In [129]:
chain = prompt | llm
res = chain.invoke(input={"query": "and how do I apply?"})
history.append(res)

In [130]:
res.model_dump()

{'content': '',
 'additional_kwargs': {'tool_calls': [{'id': 'call_zs7ikbtL9dIKDAmmcgccMO6f',
    'function': {'arguments': '{"contextual_query":"how to apply for leave as an employee"}',
     'name': 'retrieve'},
    'type': 'function'}],
  'refusal': None},
 'response_metadata': {'token_usage': {'completion_tokens': 22,
   'prompt_tokens': 155,
   'total_tokens': 177,
   'completion_tokens_details': {'accepted_prediction_tokens': 0,
    'audio_tokens': 0,
    'reasoning_tokens': 0,
    'rejected_prediction_tokens': 0},
   'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}},
  'model_name': 'gpt-4o-mini-2024-07-18',
  'system_fingerprint': 'fp_34a54ae93c',
  'id': 'chatcmpl-BdKvu9ELvfWd6utMM4ERotJmX3JlP',
  'service_tier': 'default',
  'finish_reason': 'tool_calls',
  'logprobs': None},
 'type': 'ai',
 'name': None,
 'id': 'run--4361b2a2-661f-4dc7-93a0-3f8c413eb736-0',
 'example': False,
 'tool_calls': [{'name': 'retrieve',
   'args': {'contextual_query': 'how to apply f

In [131]:
while res.tool_calls:
    for tcall in res.tool_calls:
        msg = await TOOLS[tcall['name'].lower()].ainvoke(tcall)
        history.append(msg)
    res = chain.invoke(history)
    history.append(res)

In [132]:
print(res.content)

To apply for leave, you can do so through the HR portal. Remember, you have 20 leaves available per year.
