In [2]:
from typing import Annotated, Literal
from typing_extensions import TypedDict

import os
from pathlib import Path
import sys
from dotenv import load_dotenv
import logging

from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import InMemorySaver

from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage


project_root = Path().resolve().parent  
sys.path.append(str(project_root))

load_dotenv()

from src.receipt_intelligence.config import settings
from src.receipt_intelligence.modules.text2sql_service.prompts import Prompts


In [3]:
class ConversationState(TypedDict):
    """
    State class for the LangGraph workflow. 
    It keeps track of the information necessary to maintain a coherent conversation between the user and the bot.
    The StateGraph's state is defined as a typed dictionary containing an append-only list of messages. These messages form the chat history.

    Attributes:
        messages: A list of messages between the user and the bot.
        last_userquery_type (str): Current message type, used for routing the message to the correct node.
    """
    messages: Annotated[list, add_messages]
    last_userquery_type: Literal["general_prompt", "sql_prompt"] | None

In [4]:
def get_chat_model(temperature: float = 0.7, model_name: str = settings.GROQ_LLM_MODEL) -> ChatGroq:
    """
    Returns a ChatGroq model instance with the given temperature and model name.
    Defaults to standard settings.
    """
    return ChatGroq(
        api_key=os.getenv("GROQ_API_KEY"),
        model_name=model_name,
        temperature=temperature,
    )

# Function creates a runnable pipeline
def get_response_chain():

    model = get_chat_model()
    #model = model.bind_tools(tools)

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "{system_message}"),
            MessagesPlaceholder(variable_name="msgs"),
        ]
    )

    return prompt | model



In [5]:
def classify_message_node(state: ConversationState):
    
    last_message = state["messages"][-1].content

    response_chain = get_response_chain()

    response = response_chain.invoke({"system_message": Prompts.CLASSIFY_MESSAGE_PROMPT, "msgs": [HumanMessage(content=last_message)]})

    #return{"messages": [{"role": "assistant", "content": response.content}]}
    return {"last_userquery_type": response.content}

In [6]:
graph_builder = StateGraph(ConversationState)

graph_builder.add_node("classify_message", classify_message_node)

graph_builder.add_edge(START, "classify_message")
graph_builder.add_edge("classify_message", END)



<langgraph.graph.state.StateGraph at 0x271cc092ba0>

In [7]:
memory = InMemorySaver()
graph = graph_builder.compile(checkpointer=memory)



In [8]:
state = None

while True:
    user_input = input("Enter a message (or 'quit' to stop): ")
    if user_input.lower() == "quit":
        break

    # Run the graph with previous state
    state = graph.invoke(
        {"messages": [("user", user_input)]},
        config={"configurable": {"thread_id": "conversation-1"}},
        state=state,
    )

    print("Graph state:", state)

Graph state: {'messages': [HumanMessage(content='Hello world', additional_kwargs={}, response_metadata={}, id='13baa4c8-bbe4-4df5-9ce9-baaf4d5ebd8b')], 'last_userquery_type': 'general_prompt'}


In [1]:
print(state)

NameError: name 'state' is not defined