In [None]:
import os

from langchain.llms import HuggingFacePipeline
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.tools import tool
from langgraph.graph import StateGraph
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from agent_state_lib import AgentState
from language_lib import detect_language
from milvus_lib import retrieve_docs

In [None]:
# Local model folder. Set the GEMMA_MODEL_PATH environment variable to point at
# a different copy; the default keeps the original hard-coded location.
model_path = os.environ.get("GEMMA_MODEL_PATH", "h:/ML_Models/_gemma/model/gemma-2b-it")

# Upper bound on tokens generated per LLM call. Every chain in this notebook
# shares this one pipeline, so the cap applies to all of them.
MAX_NEW_TOKENS = 512

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=MAX_NEW_TOKENS)
llm = HuggingFacePipeline(pipeline=pipe)
print("Model loaded.")

# Answer-generation prompt. The system turn tells the model to answer from the
# retrieved documents. The human turn carries the question, followed by the
# documents under a "Docs:" heading.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Answer based on retrieved documents."),
        ("human", "{question}\n\nDocs:\n{retrieved_docs}"),
    ]
)

# prompt -> LLM -> plain string. Runnable.pipe builds the same RunnableSequence
# as chaining the steps with `|`.
chain = prompt.pipe(llm, StrOutputParser())

def generate_answer(state: AgentState) -> AgentState:
    """Graph node: answer the question using the retrieved documents.

    Args:
        state: Agent state. This node reads ``state["question"]`` and
            ``state["retrieved_docs"]``. The documents are joined with
            newlines, so each one must be a string.

    Returns:
        A new state dict: a shallow copy of ``state`` with ``"answer"`` set to
        the LLM's string output.
    """
    question = state["question"]
    context = "\n".join(state["retrieved_docs"])
    answer = chain.invoke({"question": question, "retrieved_docs": context})

    updated = dict(state)
    updated["answer"] = answer
    return updated

In [None]:

# Smoke-test the retriever on its own before it is wired into the graph.
test_state: AgentState = {
    "question": "",
    "embedded_question": "big dogs",
}

response = retrieve_docs(test_state)
found_docs = response["retrieved_docs"]
for found in found_docs:
    print(found)

# Translation chain, built once when this cell runs rather than on every call
# to translate_question. It depends only on `llm`, which an earlier cell creates.
translation_prompt = ChatPromptTemplate.from_messages([
    ("system", "Translate the following question to English."),
    ("human", "{question}")
])
translator_chain = translation_prompt | llm | StrOutputParser()


def translate_question(state: AgentState) -> AgentState:
    """Graph node: replace the question with its English translation.

    Args:
        state: Agent state. This node reads ``state["question"]``.

    Returns:
        A new state dict: a shallow copy of ``state`` with ``"question"``
        overwritten by the translated text.
    """
    translated = translator_chain.invoke({"question": state["question"]})
    # Generated text often carries leading/trailing newlines or spaces. Strip
    # them so the translated question drops cleanly into the answer prompt.
    return {**state, "question": translated.strip()}

In [None]:
graph = StateGraph(AgentState)
graph.add_node("translate", translate_question)
graph.add_node("retrieve", retrieve_docs)
graph.add_node("reason", generate_answer)

# The original wiring used `detect_language` both as a node and as the router.
# Per the original routing comment, it returns a string key ("english" /
# "non_english"), not a state update. LangGraph nodes must return a dict of
# state updates, so registering it as a node breaks the graph. Use it purely as
# the conditional entry point that picks the first real node.
# NOTE(review): if detect_language actually returns a state dict (e.g. with a
# language field), add a separate router that reads that field instead.
graph.set_conditional_entry_point(
    detect_language,
    {
        "english": "retrieve",
        "non_english": "translate",
    },
)

graph.add_edge("translate", "retrieve")
graph.add_edge("retrieve", "reason")
graph.set_finish_point("reason")

agent = graph.compile()

In [None]:
# End-to-end run of the compiled agent: route by language, retrieve, answer.
initial_state = {
    "question": "based on the documents, what is the best dog for rescue operations?",
    "embedded_question": "big dogs",
}
response = agent.invoke(initial_state)
print(response["answer"])
