In [None]:
import os, getpass

# LLM og verktøy
from langchain_openai import AzureChatOpenAI
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.tools import FunctionTool
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core import ChatPromptTemplate
from llama_index.core.llms import ChatMessage
from llama_index.core import (get_response_synthesizer)
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import (VectorStoreIndex, StorageContext,  load_index_from_storage)

# Import av embedding-moduler
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingModelType

# Chat model: GPT-4o mini on Azure OpenAI. Every connection setting is read
# from an environment variable; an unset variable is passed through as None.
# The deployment name is supplied under both keyword names the client accepts.
_gpt4omini_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME_GPT4omini')
LLMGPT4omini = AzureChatOpenAI(
    model=os.getenv('AZURE_OPENAI_MODEL_GPT4omini'),
    deployment_name=_gpt4omini_deployment,
    azure_deployment=_gpt4omini_deployment,
    api_key=os.getenv('AZURE_OPENAI_API_KEY_GPT4omini'),
    azure_endpoint=os.getenv('AZURE_OPENAI_AZURE_ENDPOINT_GPT4omini'),
    api_version=os.getenv('AZURE_OPENAI_API_VERSJON_GPT4omini'),
    temperature=0.0,
    timeout=120,
)

def read_index_from_storage(storage):
    """Load a persisted LlamaIndex index from disk.

    Args:
        storage: directory handed to ``StorageContext`` as ``persist_dir``.

    Returns:
        Whatever ``load_index_from_storage`` reconstructs from that directory.
    """
    return load_index_from_storage(
        StorageContext.from_defaults(persist_dir=storage)
    )

# Short alias for the Azure chat model; the structured-output `router` is built from it.

llm = LLMGPT4omini



from typing_extensions import Literal
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from typing_extensions import TypedDict


# Schema for structured output to use as routing logic
class Route(BaseModel):
    """Structured-output schema the router LLM must fill in.

    ``step`` is required (``Field(...)``). With a ``None`` default the JSON
    schema would mark it optional. A model reply that omitted it would then
    validate and yield ``step=None``, a value no routing branch handles.
    """

    step: Literal["contraception", "tobacco", "mental health"] = Field(
        ..., description="The next step in the routing process"
    )


# Augment the LLM with schema for structured output: `router.invoke(...)`
# returns a `Route` instance (its `.step` is read by llm_call_router).
from langchain_core.output_parsers import PydanticOutputParser  # NOTE(review): not used in this cell
router = llm.with_structured_output(Route)

# State
class State(TypedDict):
    """Graph state shared by the router node and the three topic nodes."""

    # The user's question.
    input: str
    # Topic label written by llm_call_router: "tobacco", "contraception" or "mental health".
    decision: str
    # Answer written by a topic node. It holds whatever query_engine.query()
    # returns (a LlamaIndex response object), so it is not necessarily a str.
    output: str
    # Per-topic vector indexes (tobacco / contraception / mental health).
    # The example invocations initialise them to None, so in practice they are optional.
    index_tobakk: VectorStoreIndex
    index_prevensjon : VectorStoreIndex
    index_psykiskhelse : VectorStoreIndex
    
    
# Module-level prompt and response-synthesizer configuration shared by all topic nodes.
chat_text_qa_msgs = [
    ChatMessage(
        role=MessageRole.SYSTEM,
        # Adjacent string literals are concatenated verbatim, so every rule
        # needs an explicit "\n"; otherwise the rules run together on one line.
        content=(
            "You are a helpful assistant, and you will be given a user request.\n"
            "Some rules to follow:\n"
            "- Always answer the request using the given context information and not prior knowledge\n"
            "- Always answer in norwegian\n"
            "- Always answer using the format: '\033[33mAnswer: \033[34m<answer>\033[0m'"
        ),
    ),
    ChatMessage(
        role=MessageRole.USER,
        content=(
            "Context information is below.\n"
            "---------------------\n"
            "{context_str}\n"
            "---------------------\n"
            "Query: {query_str}\n"
            "Answer: "
        ),
    ),
]
text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)

# NOTE(review): no `llm=` is passed, so the synthesizer uses LlamaIndex's
# global Settings.llm rather than the LangChain `llm` above -- confirm that is
# configured elsewhere.
response_synthesizer = get_response_synthesizer(
    response_mode="tree_summarize",
    # text_qa_template and structured_answer_filtering are consumed by the
    # refine/compact modes; tree_summarize builds its prompts from
    # summary_template. They are harmless here if the mode is switched later.
    text_qa_template=text_qa_template,
    summary_template=text_qa_template,
    structured_answer_filtering=True,
    verbose=True,
)
# NOTE(review): not used in this cell; presumably for index building elsewhere.
text_splitter = SentenceSplitter.from_defaults(chunk_size=1024, chunk_overlap=75)



# Nodes
#
# Each topic node answers state["input"] from its own persisted vector index.

# Indexes already read from disk, keyed by storage directory. A module-level
# dict survives across router_workflow.invoke() calls. Assigning into the
# `state` argument inside a node does not persist: LangGraph only applies the
# update dict that a node *returns*.
_INDEX_CACHE = {}

# `as_query_engine` has no `similarity_cutoff` parameter: the keyword is
# forwarded as **kwargs and ignored, so it never filtered anything. A real
# cutoff must be applied with a SimilarityPostprocessor. It is left disabled
# (None) to keep current answers unchanged. The similarity-score scale depends
# on the embedding model, so tune a value (e.g. 0.7) before enabling it.
SIMILARITY_CUTOFF = None


def _answer_from_index(state: State, index_key: str, storage: str, topic: str):
    """Answer ``state["input"]`` using the vector index persisted at ``storage``.

    The index is looked up in this order:
    1. ``state[index_key]``;
    2. the module cache;
    3. loaded with ``read_index_from_storage`` and then cached.

    Args:
        state: current graph state.
        index_key: State key that holds this topic's index.
        storage: directory the index was persisted to.
        topic: label used in the log message.

    Returns:
        A LangGraph state update containing the query response under
        ``"output"`` and the index under ``index_key``. Returning the index is
        what actually records it in the graph state.
    """
    index = state.get(index_key)
    if index is None:
        index = _INDEX_CACHE.get(storage)
    if index is None:
        index = read_index_from_storage(storage)
        _INDEX_CACHE[storage] = index
    else:
        print(f'Index for {topic} already loaded from {storage}')

    engine_kwargs = {}
    if SIMILARITY_CUTOFF is not None:
        from llama_index.core.postprocessor import SimilarityPostprocessor

        engine_kwargs["node_postprocessors"] = [
            SimilarityPostprocessor(similarity_cutoff=SIMILARITY_CUTOFF)
        ]

    query_engine = index.as_query_engine(
        similarity_top_k=10,
        response_synthesizer=response_synthesizer,
        **engine_kwargs,
    )
    response = query_engine.query(state["input"])
    return {"output": response, index_key: index}


def llm_call_1(state: State):
    """Answer questions about tobacco."""
    print("\nllm_call_1 invoked: Svarer på spørsmål om tobakk")
    return _answer_from_index(
        state, "index_tobakk", './blobstorage/chatbot/ungnotobakk', "tobakk"
    )


def llm_call_2(state: State):
    """Answer questions about contraception."""
    print("\nllm_call_2 invoked: Svarer på spørsmål om prevensjon")
    return _answer_from_index(
        state, "index_prevensjon", './blobstorage/chatbot/prevensjon', "prevensjon"
    )


def llm_call_3(state: State):
    """Answer questions about mental health."""
    print("\nllm_call_3 invoked: Svarer på spørsmål om psykiskhelse")
    return _answer_from_index(
        state, "index_psykiskhelse", './blobstorage/chatbot/psykiskhelse', "psykiskhelse"
    )

def llm_call_router(state: State):
    """Classify the user's request into one of the three topics.

    Sends a system instruction plus the user's input to the structured-output
    ``router`` and returns the state update ``{"decision": <Route.step>}``.
    """
    routing_prompt = [
        SystemMessage(
            content="Route the input to contraception, tobacco, mental health, based on the user's request."
        ),
        HumanMessage(content=state["input"]),
    ]
    route = router.invoke(routing_prompt)
    return {"decision": route.step}

# Conditional edge: maps the router's topic label to the node that handles it.
_DECISION_TO_NODE = {
    "tobacco": "llm_call_1",
    "contraception": "llm_call_2",
    "mental health": "llm_call_3",
}


def route_decision(state: State):
    """Return the name of the node to visit after ``llm_call_router``.

    Args:
        state: graph state; only ``state["decision"]`` is read.

    Returns:
        One of ``"llm_call_1"``, ``"llm_call_2"``, ``"llm_call_3"``.

    Raises:
        ValueError: if the decision is not a known topic. Failing loudly here
            beats returning None, which is not a key in the conditional-edge
            path map and would surface as an obscure graph error.
    """
    decision = state["decision"]
    try:
        return _DECISION_TO_NODE[decision]
    except KeyError:
        raise ValueError(
            f"Unknown routing decision {decision!r}; expected one of "
            f"{sorted(_DECISION_TO_NODE)}"
        ) from None

    
from langgraph.graph import StateGraph, START, END


# Build workflow: START -> llm_call_router -> one topic node -> END.
def _build_router_graph():
    """Assemble (without compiling) the routing StateGraph."""
    topic_nodes = [
        ("llm_call_1", llm_call_1),
        ("llm_call_2", llm_call_2),
        ("llm_call_3", llm_call_3),
    ]

    graph = StateGraph(State)
    for name, node in topic_nodes:
        graph.add_node(name, node)
    graph.add_node("llm_call_router", llm_call_router)

    graph.add_edge(START, "llm_call_router")
    # route_decision returns a node name, so the path map is the identity on those names.
    graph.add_conditional_edges(
        "llm_call_router",
        route_decision,
        {name: name for name, _ in topic_nodes},
    )
    for name, _ in topic_nodes:
        graph.add_edge(name, END)
    return graph


router_builder = _build_router_graph()

# Compile workflow
router_workflow = router_builder.compile()



from IPython.display import Image, display
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles
# Visualise the compiled routing graph.
# Inline PNG rendering, currently disabled and kept for reference
# (NOTE(review): presumably disabled because it needs a Mermaid rendering backend -- confirm):
# display(Image(router_workflow.get_graph().draw_mermaid_png(
#     curve_style=CurveStyle.LINEAR,
#     node_colors=NodeStyles(first="#ffdfba", last="#baffc9", default="#fad7de"),
#     wrap_label_n_words=9,
#     background_color="white",
#     padding=10
#     ,)))
# Project helper; judging by the cell output it writes graph.mmd and points the
# user at https://mermaid.live to render it.
from graph_utils import save_mermaid_diagram
save_mermaid_diagram(router_workflow.get_graph())


def _make_initial_state(question):
    """Return a fresh graph input for ``question`` with every other field empty."""
    return {
        "input": question,
        "decision": "",
        "output": "",
        "index_tobakk": None,
        "index_prevensjon": None,
        "index_psykiskhelse": None,
    }


# Invoke the workflow on a tobacco question (quitting snus) and print the answer.
initial_state = _make_initial_state("Gi meg noen tips om hvordan slutte med snus")
state = router_workflow.invoke(initial_state)
print(state["output"])

# Invoke the workflow on a contraception question (side effects of the pill).
initial_state = _make_initial_state("Hvilke bivirkninger har p-piller?")
state = router_workflow.invoke(initial_state)
print(state["output"])

✅ Mermaid diagram saved to: graph.mmd
🌐 Opening https://mermaid.live - paste your diagram code there.

llm_call_1 invoked: Svarer på spørsmål om tobakk
1 text chunks after repacking
[33mAnswer: [34mHer er noen tips for å slutte med snus:
1. Finn noe som motiverer deg til å slutte, for eksempel å spare penger til en spesifikk ting.
2. Sett en sluttdato og vær mentalt forberedt på abstinenser.
3. Finn støtte i andre, del dine planer med venner eller bruk Slutta-appen for motivasjon.
4. Legg en plan for å håndtere fristelser og tilbakefall, og bruk alternative metoder som tyggegummi eller drops.
5. Belønn deg selv når du når delmål, for eksempel med en kinobillett eller en fin middag. Lykke til med snusslutt!

llm_call_2 invoked: Svarer på spørsmål om prevensjon
1 text chunks after repacking
[33mAnswer: [34mVanlige bivirkninger av p-piller inkluderer hodepine, kvalme, kviser, redusert sexlyst, ømme bryster og humørsvingninger.
