In [None]:
import json
import os
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain.schema import Document
from pydantic import BaseModel, Field, ValidationError
from typing import Literal, Dict, List
from typing_extensions import TypedDict

from langchain_community.chat_models import ChatOllama
from langchain_community.vectorstores import Chroma
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryByteStore,InMemoryStore
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.embeddings import HuggingFaceEmbeddings

import pickle
import os

os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_API_KEY"] = "YOUR_LANGSMITH_API_KEY"
os.environ["LANGSMITH_PROJECT"] = "YOUR_LANGSMITH_PROJECT"
os.environ["HUGGINGFACEHUB_API_TOKEN"] = "YOUR_HUGGINGFACEHUB_API_TOKEN"

In [3]:
from langchain_cohere import CohereEmbeddings

# Initialize embeddings
#embeddings = CohereEmbeddings(model="embed-english-light-v3.0")
llm = ChatOllama(model="llama3.2", temperature=0)
qwen = ChatOllama(model="qwen2:1.5b", temperature=0)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

  llm = ChatOllama(model="llama3.2", temperature=0)
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
  from .autonotebook import tqdm as notebook_tqdm


In [4]:
BASE_DIR = os.getcwd()
PERSIST_DIRECTORY = os.path.join(BASE_DIR, "vector-db", "lhdn_db_new")
COLLECTION_NAME = "lhdn"


vectorstore = Chroma(
    collection_name=COLLECTION_NAME,
    persist_directory=PERSIST_DIRECTORY,
    embedding_function=embeddings
)

retriever = vectorstore.as_retriever()

  vectorstore = Chroma(


In [5]:
from typing import Dict, List, Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatOllama
from pydantic import BaseModel, ValidationError, Field
import json
import time
from langchain_core.prompts import (
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
)

class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""
    datasource: Literal["vectorstore", "llm"]


# LLM (Ollama)


# Improved Prompt (clear, compact, and JSON-structured output)
system = """
You are an intelligent assistant for LHDN (Lembaga Hasil Dalam Negeri Malaysia).

Your task is to decide whether a user's question should be answered using a factual vector database (which includes complete information about Malaysian tax for individuals, companies, and e-invoicing), or by the LLM for general conversation.

Use the chat history below if helpful:
{chat_history}

Rules:
- If the question is about LHDN, taxes (personal, company), or e-invoicing → respond:
  {{ "datasource": "vectorstore" }}
- If the question is small talk, greeting, opinion-based, or unrelated to LHDN/tax/e-invoice → respond:
  {{ "datasource": "llm" }}

Respond with ONLY a single line of JSON (no explanation), like:
{{ "datasource": "vectorstore" }}
"""


system_template = SystemMessagePromptTemplate.from_template(system)
human_template  = HumanMessagePromptTemplate.from_template("{question}")
route_prompt    = ChatPromptTemplate.from_messages([system_template, human_template])
chain           = route_prompt | llm

def get_routing_output(question: str, chat_history: List[Dict[str, str]] = [], retries: int = 2):
    for attempt in range(1, retries + 1):
        start = time.time()
        response = chain.invoke({"question": question, "chat_history": chat_history})
        content = response.content.strip()

        try:
            parsed = json.loads(content)
            result = RouteQuery(**parsed)
            print(
                f"[✅ Attempt {attempt}] Took {time.time() - start:.2f}s → {result.datasource}")
            return result
        except (json.JSONDecodeError, ValidationError):
            print(f"[⚠️ Attempt {attempt}] Failed to parse → {content}")
            if attempt == retries:
                return None


# Tests
print(get_routing_output(
    "What are the type of business that require to implement e-invoice?"))
print(get_routing_output(
    "What is tax"))
print(get_routing_output(
    "Explain e-invoice"))
print(get_routing_output("Hey, how are you today?"))

[✅ Attempt 1] Took 14.04s → vectorstore
datasource='vectorstore'
[✅ Attempt 1] Took 0.41s → vectorstore
datasource='vectorstore'
[✅ Attempt 1] Took 0.42s → vectorstore
datasource='vectorstore'
[✅ Attempt 1] Took 0.37s → llm
datasource='llm'


In [6]:
## Grade Document

# Data model
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: Literal["yes", "no"] = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )

# Prompt for grading
system = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keywords or has semantic meaning related to the user question, grade it as relevant, and return {{ "binary_score": "yes" }}. 
If the document not related to the user question,  return {{ "binary_score": "no" }}
You are not required to be strict—just filter out obviously wrong matches.

Respond ONLY with a single line JSON like this (no explanation, no extra text): {{ "binary_score": "yes" }} or {{ "binary_score": "no" }}."""

grade_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "Retrieved document:\n\n{document}\n\nUser question: {question}")
])

# Chain
retrieval_grader = grade_prompt | llm

# Helper function
def grade_document(question: str, document: str):
    response = retrieval_grader.invoke({"question": question, "document": document})
    raw = response.content.strip()
    try:
        parsed = json.loads(raw)
        result = GradeDocuments(**parsed)
        return result
    except Exception:
        # Fallback handling for raw "yes" or "no"
        if raw.lower() in ["yes", "relevant" '"yes"', "'yes'"]:
            return GradeDocuments(binary_score="yes")
        elif raw.lower() in ["no", '"no"', "'no'"]:
            return GradeDocuments(binary_score="no")
        else:
            print("⚠️ Failed to parse grading result:", raw)
            return GradeDocuments(binary_score="yes")
            return None


In [19]:
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template(
    """You are a helpful and friendly assistant from LHDN (Lembaga Hasil Dalam Negeri Malaysia).

Your goal is to think through the user's question logically and clearly. If it involves numbers or reasoning (e.g., tax calculations, income estimations), first break it down step by step in your mind. Then, use the information from the context below to support and verify your answer.

Do not mention the word "context", "document", or how you got the information. Just answer as if you are directly speaking to the user, in a clear and natural tone.

If the answer cannot be clearly found or supported, respond politely and conversationally, and let the user know you’re not certain.

Context:
{context}

User Question:
{question}

Answer (with thoughtful reasoning and helpful explanation):"""
)



# LLM

# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Chain
rag_chain = prompt | llm | StrOutputParser()

In [20]:
# Data model for grading hallucination
class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""
    binary_score: Literal["yes", "no"] = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )

system = """You are a grader that only will output 'yes' or 'no' that will be assessing whether an LLM generation is grounded in or supported by a set of retrieved facts.
Return {{ "binary_score": "yes" }} if the answer is grounded in or supported by the set of facts, else return {{ "binary_score": "no" }}. 
Respond ONLY with a single line JSON like this (no explanation, no extra text): {{ "binary_score": "yes" }} or {{ "binary_score": "no" }}.
"""

hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
    ]
)

hallucination_grader = hallucination_prompt | llm

def grade_hallucination(documents, generation):
    # Prepare the input for the LLM
    response = hallucination_grader.invoke({
        "documents": documents,
        "generation": generation
    })

    try:
        parsed = json.loads(response.content.strip())
        result = GradeHallucinations(**parsed)
        return result
    except (json.JSONDecodeError, ValidationError) as e:
        print("⚠️ Failed to parse grading result:", response)
        return None

# Test the grader with real data
# graded_result = grade_hallucination(docs,generation)

# print(graded_result)


In [21]:
### Answer Grader

# Data model
class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""

    binary_score: Literal["yes", "no"] =  Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )

# Prompt
system = """You are a grader assessing whether an answer addresses / resolves a question \n 
    'Yes' means that the answer resolves the question. 'No' means that the answer doesn't resolve the question
     Give a binary score {{ "binary_score": "yes" }} or {{ "binary_score": "no" }}. 
     Respond ONLY with a single line JSON like this (no explanation, no extra text): {{ "binary_score": "yes" }} or {{ "binary_score": "no" }}
     """

answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
    ]
)

answer_grader = answer_prompt | llm

def grade_answer(question, generation):
    # Prepare the input for the LLM
    response = answer_grader.invoke({
        "question": question,
        "generation": generation
    })

    try:
        parsed = json.loads(response.content.strip())
        result = GradeAnswer(**parsed)
        return result
    except (json.JSONDecodeError, ValidationError) as e:
        print("⚠️ Failed to parse grading result:", response)
        return None

# Test the grader with real data
# graded_result = grade_answer(question, generation)

# print(graded_result)


In [22]:
### Question Re-writer


# Prompt
system = """You a question re-writer that converts an input question to a better version that is optimized \n 
     for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.
     Finally produce a new question and return the ONLY the new question.
     Don't include your reasoning process in your output.
     """

re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()
# question_rewriter.invoke({"question": question})

In [29]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

system =  """You are an assistant that helps rephrase unclear or follow-up questions into clear, standalone questions.

Use the chat history to understand the full context. Your goal is to rewrite the user's current question so that it makes sense on its own without needing prior messages.

Be clear, natural, and concise. Keep the meaning the same — do not add new assumptions.

Chat History:
{history}

User’s Follow-up Question:
{question}

Rephrased Standalone Question:"""

re_write_with_history_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human",  """
Original Question:
{question}

Recent Chat History:
{chat_history}
"""),
])

question_rewriter_with_history = (
    re_write_with_history_prompt
    | qwen
    | StrOutputParser()
)


In [24]:
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        documents: list of documents
        chat_history: list of chat history
    """

    question: str
    generation: str
    documents: List[str]
    chat_history: List[Dict[str, str]] = []
    from_llm_direct: bool
    retry_count: int = 0
    

In [33]:
def sentiment_analysis(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RUNNING: Sentiment Analysis---")

    return "peace"


def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RUNNING: retrieve---")
    question = state["question"]

    # Retrieval
    documents = retriever.invoke(question)

    return {"documents": documents, "question": question}


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """

    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    for d in documents:
        score = grade_document(question, d.page_content)
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": documents, "question": question}


def generate(state):
    """
    Generate an answer using RAG and maintain chat history.

    Args:
        state (GraphState): The current graph state

    Returns:
        GraphState: Updated state with generation and chat history
    """
    print("---RAG GENERATE---")
    question = state["question"]
    documents = state["documents"]
    chat_history = state.get("chat_history", [])

    # Add user message
    chat_history.append({"role": "user", "content": question})

    generation = rag_chain.invoke({"context": documents, "question": question})

    # Add assistant's response to history
    chat_history.append({"role": "assistant", "content": generation})

    return GraphState(
        question=question,
        documents=documents,
        generation=generation,
        chat_history=chat_history,
    )


def generate_llm_response(state: GraphState) -> GraphState:
    question = state["question"]
    chat_history = state.get("chat_history", [])

    # Add current question
    chat_history.append({"role": "user", "content": question})
    print(chat_history)

    # Get response from LLM
    response = llm.invoke(question)
    answer = response.content.strip()

    # Add response to history
    chat_history.append({"role": "assistant", "content": answer})

    return GraphState(
        question=question,
        generation=answer,
        chat_history=chat_history
    )


def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """

    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state.get("documents", [])
    retry_count = state.get("retry_count", 0)

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    print(better_question)
    return {"documents": documents, "question": better_question, "retry_count": retry_count + 1}


def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """

    print("---ASSESS GRADED DOCUMENTS---")
    state["question"]
    filtered_documents = state["documents"]

    if not filtered_documents:
        # All documents have been filtered check_relevance
        # We will re-generate a new query
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, CHECK RETRY LIMIT---"
        )
        return check_retry_limit(state)
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"


def check_retry_limit(state):
    retry_count = state.get("retry_count", 0)
    if retry_count < 1:
        print(
            "---DECISION: RETRY LIMIT NOT EXCEED, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        print(
            "---DECISION: LIMIT EXCEED, FALLBACK---"
        )
        return "llm_fallback"


def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the document and answers question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """

    print("---CHECK HALLUCINATIONS---")

    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    #score = #grade_hallucination(documents, generation)
    grade = "yes" #score.binary_score

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")

        score = grade_answer(question, generation)
        grade = score.binary_score
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "hallucinate"


def route_question(state):
    """
    Route question to appropriate node based on the datasource and vector database.
    """

    print("--- ROUTE QUESTION ---")
    question = state["question"]
    print("Question:", question)

    source = get_routing_output(question)
    print("Datasource:", source.datasource)

    if source == None:
        print("→ Source is None, fallback. Routing to: llm (direct)")
        return "llm"

    if source.datasource == "llm":
        print("→ Routing to: llm (direct)")
        state["from_llm_direct"] = True
        return "llm"

    if source.datasource == "vectorstore":
        return "vectorstore"

    raise ValueError(f"Unknown datasource: {source.datasource}")

def transform_query_with_history(state):
    """
    Transform the user's query using chat history context to generate a more complete question.

    Args:
        state (dict): The current graph state containing 'question', optional 'documents', and 'chat_history'.

    Returns:
        dict: Updated state with a refined 'question'.
    """
    print("---TRANSFORM QUERY WITH HISTORY---")

    question = state["question"]
    chat_history = state.get("chat_history", [])  # This should be a list of tuples: [(user_msg, ai_msg), ...]

    # Reconstruct context from chat history
    history_text = ""
    for i, (user_msg, ai_msg) in enumerate(chat_history[-5:]):  # Only use recent history to avoid context overload
        history_text += f"User: {user_msg}\nAI: {ai_msg}\n"
    
    prompt_input = {
        "question": question,
        "chat_history": history_text.strip()
    }

    better_question = question_rewriter_with_history.invoke(prompt_input)
    print(better_question)

    return {
        "question": better_question,
        "chat_history": chat_history,
        "documents": state.get("documents", []),
    }

def smart_transform_query(state):
    """
    Smartly transform query only if it's unclear or a follow-up.
    
    Args:
        state (dict): Contains current question, chat history, and more
    
    Returns:
        state (dict): May update the 'question' key
    """
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate

    question = state["question"]
    history = state.get("chat_history", [])
    documents = state.get("documents", [])
    retry_count = state.get("retry_count", 0)

    # Step 1: Use LLM to classify if the question needs rephrasing
    need_transform_prompt = ChatPromptTemplate.from_messages([
        ("system", 
        """You are an assistant that determines whether a user's current question needs to be rephrased based on the prior conversation history.

    Respond ONLY with 'yes' or 'no'.

    Respond 'yes' if:
    - The question contains vague or ambiguous language.
    - The question includes pronouns or references to earlier messages (e.g., "that", "this", "he", "she", "it", "the previous one", "your answer", "what you said").
    - The question is clearly a follow-up that relies on previous context (e.g., it adds new information or builds on a past response).
    - The question is expressing confusion, dissatisfaction, or asking for clarification about a prior response (e.g., "I don’t understand", "why is that?", "can you explain more?", "that doesn’t make sense").

    Respond 'no' if:
    - The question is clearly stated and can stand on its own without needing to refer to any prior messages.
    - Even if the user has asked something before, the current question is complete and self-contained.

    Always analyze both the current question and the chat history carefully before deciding.
        """),
        
        ("user", 
        "Chat history: {{history}}\n\nCurrent question: {{question}} \n\nAnswer with 'yes' or 'no' only.")
    ])


    chain = need_transform_prompt | llm | StrOutputParser()

    response = chain.invoke({
        "question": question,
        "history": format_history(history)
    })

    if response.strip().lower() == "yes" and history:
        print("[smart_transform_query] Transforming question due to follow-up or unclear input.")
        # Rewrite using historical context
        better_question = transform_query_with_history(state)["question"]
        return {
            "documents": documents,
            "question": better_question,
            "retry_count": retry_count,
            "chat_history": history
        }
    else:
        print("[smart_transform_query] Using original question. No transformation needed.")
        return {
            "documents": documents,
            "question": question,
            "retry_count": retry_count,
            "chat_history": history
        }


def format_history(history):
    return "\n".join([f"User: {q}\nBot: {a}" for q, a in history])


In [34]:
from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Nodes
workflow.add_node("sentiment_analysis", sentiment_analysis)
workflow.add_node("retrieve", retrieve)
#workflow.add_node("grade_documents", grade_documents)
workflow.add_node("transform_query", transform_query)
workflow.add_node("llm_fallback", generate_llm_response)
workflow.add_node("generate", generate)
workflow.add_node("preprocess_prompt", smart_transform_query)
workflow.add_node("transform_query_with_history",transform_query_with_history)

workflow.add_conditional_edges(
    START,
    sentiment_analysis,
    {
        "peace": "preprocess_prompt",
        "angry": "retrieve",
    },
)

### ROUTE QUESTION BRANCH ###
workflow.add_conditional_edges(
    "preprocess_prompt",
    route_question,
    {
        "vectorstore": "retrieve",
        "llm": "llm_fallback",
    },
)

### RETRIEVE BRANCH ###
#workflow.add_edge("retrieve", "grade_documents")

workflow.add_conditional_edges(
    "retrieve",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "llm_fallback": "llm_fallback",
        "generate": "generate",
    },
)

### RETRY BRANCH ###
workflow.add_edge("transform_query", "retrieve")
workflow.add_edge("llm_fallback", END)

### GENERATE BRANCH ###
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "hallucinate": END,
        "useful": END,
        #TODO
        "not useful": "transform_query_with_history",
    },
)

workflow.add_edge("transform_query_with_history", "retrieve")

# Compile
app = workflow.compile()


In [35]:
from pprint import pprint

chat_history = []

print("🤖 Welcome! Type 'exit' to quit.\n")

while True:
    user_input = input("You: ")
    if user_input.strip().lower() in ["exit", "quit"]:
        print("👋 Goodbye!")
        break

    state = {
        "question": user_input,
        "chat_history": chat_history,
    }

    print(user_input)

    # Run through the LangGraph app
    for output in app.stream(state):
        for key, value in output.items():
            if key == "llm_fallback" or key == "generate":  # depends on where generation happens
                answer = value["generation"]
                chat_history = value["chat_history"]
                pprint(f"\nAI: {answer}\n")


🤖 Welcome! Type 'exit' to quit.

my compnay have an annual revenue of Rm20 million, when should we implement e-invoice?
---RUNNING: Sentiment Analysis---
[smart_transform_query] Using original question. No transformation needed.
--- ROUTE QUESTION ---
Question: my compnay have an annual revenue of Rm20 million, when should we implement e-invoice?
[✅ Attempt 1] Took 1.03s → vectorstore
Datasource: vectorstore
---RUNNING: retrieve---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
---RAG GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
('\n'
 'AI: Based on the information provided by the Inland Revenue Board of '
 'Malaysia (IRBM), the implementation date for e-Invoice is determined based '
 'on the annual turnover or revenue.\n'
 '\n'
 'Since your company has an annual revenue of RM20 million, which falls within '
 'the threshold mentioned in Table 1.1, you a

KeyError: "Input to ChatPromptTemplate is missing variables {'history'}.  Expected: ['chat_history', 'history', 'question'] Received: ['question', 'chat_history']\nNote: if you intended {history} to be part of the string and not a variable, please escape it with double curly braces like: '{{history}}'.\nFor troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/INVALID_PROMPT_INPUT "