In [14]:
import logging
from langchain.prompts import PromptTemplate
from langchain_openai  import ChatOpenAI
from langchain.output_parsers import PydanticOutputParser
from langchain_community.embeddings import HuggingFaceEmbeddings
from chatsky_llm_autoconfig.algorithms.base import GraphGenerator
from chatsky_llm_autoconfig.graph import BaseGraph, Graph
from chatsky_llm_autoconfig.schemas import DialogueGraph
from chatsky_llm_autoconfig.dialogue import Dialogue
from chatsky_llm_autoconfig.metrics.automatic_metrics import llm_match

In [15]:
from embedder import nodes2groups
from utils import call_llm_api, nodes2graph, dialogues2list, read_json
from settings import EnvSettings
from missing_edges_prompt import three_1, three_2

In [16]:
# Project settings object; this notebook reads GENERATION_MODEL_NAME, OPENAI_API_KEY,
# OPENAI_BASE_URL, EMBEDDER_MODEL, EMBEDDER_DEVICE and TEST_DATA_PATH from it.
# NOTE(review): presumably populated from environment variables / a .env file - confirm in settings.py.
env_settings = EnvSettings()
# Only ERROR and above are emitted by this langchain logger, keeping cell output quiet.
logging.getLogger("langchain_core.vectorstores.base").setLevel(logging.ERROR)

In [17]:
class ThreeStagesGraphGenerator(GraphGenerator):
    """Graph generator based on a list of dialogues.

    Three stages:
    1. Algorithmic grouping of assistant utterances into nodes.
    2. Algorithmic connecting of nodes by edges.
    3. If one of the dialogues ends with a user's utterance, ask the LLM to add missing edges.
    """

    def invoke(self, dialogue: list[Dialogue] = None, graph: DialogueGraph = None, topic: str = "") -> BaseGraph:
        """Build a dialogue graph from the given dialogues.

        Args:
            dialogue: Dialogues to build the graph from. Passed to ``dialogues2list``
                and ``nodes2graph``; each item must provide ``to_list()``.
            graph: Not used by this generator; kept for interface compatibility.
            topic: Not used by this generator; kept for interface compatibility.

        Returns:
            A ``Graph`` wrapping the algorithmically built graph when the
            ``last_user`` flag from ``dialogues2list`` is falsy. Otherwise a
            ``Graph`` wrapping the LLM-corrected graph, or a ``Graph`` with an
            empty ``graph_dict`` when the LLM call yields ``None`` or any
            returned edge has a falsy ``target``.
        """
        nexts, utterances, starts, neigbhours, last_user = dialogues2list(dialogue)

        # Stage 1: group utterances into nodes.
        groups = nodes2groups(
            utterances,
            [" ".join(p) for p in nexts],
            [n + " ".join(p) + " " for p, n in zip(nexts, utterances)],
            neigbhours,
        )
        # A group is a start node as soon as one of its utterances is listed in `starts`.
        nodes = [
            {"id": idx + 1, "label": "", "is_start": any(gr in starts for gr in group), "utterances": group}
            for idx, group in enumerate(groups)
        ]

        # Stage 2: connect the nodes by edges.
        embeddings = HuggingFaceEmbeddings(model_name=env_settings.EMBEDDER_MODEL, model_kwargs={"device": env_settings.EMBEDDER_DEVICE})
        graph_dict = nodes2graph(nodes, dialogue, embeddings)
        graph_dict = {"nodes": graph_dict['nodes'], "edges": graph_dict['edges'], "reason": ""}

        # Per the class description, stage 3 is only needed when a dialogue ends
        # with a user's utterance; otherwise the algorithmic graph is final.
        if not last_user:
            return Graph(graph_dict=graph_dict)

        # Stage 3: the LLM client is created only here, so the purely algorithmic
        # path above neither builds a client nor touches the OpenAI settings.
        # NOTE(review): the client is configured with temperature=1 while
        # call_llm_api is given temp=0 - confirm which value takes effect.
        base_model = ChatOpenAI(model=env_settings.GENERATION_MODEL_NAME, api_key=env_settings.OPENAI_API_KEY, base_url=env_settings.OPENAI_BASE_URL, temperature=1)
        model = base_model | PydanticOutputParser(pydantic_object=DialogueGraph)

        # Every dialogue becomes its own template variable (var_0, var_1, ...)
        # referenced from the tail of the prompt.
        partial_variables = {}
        prompt_extra = ""
        for idx, dial in enumerate(dialogue):
            partial_variables[f"var_{idx}"] = dial.to_list()
            prompt_extra += f" Dialogue_{idx}: {{var_{idx}}}"
        prompt = PromptTemplate(template=three_1+"{graph_dict}. "+three_2+prompt_extra, input_variables=["graph_dict"], partial_variables=partial_variables)

        result = call_llm_api(prompt.format(graph_dict=graph_dict), model, temp=0)
        if result is None:
            return Graph(graph_dict={})
        result.reason = "Fixes: " + result.reason
        graph_dict = result.model_dump()
        # Reject the LLM answer if any edge comes back with a falsy target.
        if not all(e['target'] for e in graph_dict['edges']):
            return Graph(graph_dict={})
        return Graph(graph_dict=graph_dict)

    async def ainvoke(self, *args, **kwargs):
        """Async entry point; delegates directly to the synchronous ``invoke``."""
        return self.invoke(*args, **kwargs)


In [18]:
# Generator instance used by the cells below; constructed without arguments.
graph_generator = ThreeStagesGraphGenerator()

In [19]:
# Load the test cases from the JSON file configured as TEST_DATA_PATH.
dialogue_to_graph = read_json(env_settings.TEST_DATA_PATH)
# Work with the first case only; later cells read its "dialogues" and "graph" entries.
case = dialogue_to_graph[0]

In [None]:
# Turn each entry's "messages" into a Dialogue and generate a graph from all of them.
# NOTE(review): this cell has no execution count in the saved notebook - re-run it
# (Restart & Run All) so `result_graph` matches the outputs shown below.
result_graph = graph_generator.invoke([Dialogue.from_list(c["messages"]) for c in case["dialogues"]])

In [12]:
# Display the generated graph dictionary.
# NOTE(review): the saved execution count of this cell (12) is lower than that of the
# cells above it, so the output shown may come from an earlier kernel state.
result_graph.graph_dict

{'nodes': [{'id': 1,
   'label': '',
   'is_start': True,
   'utterances': ['Of course! What action would you like to change?',
    'Certainly! Which part would you like to modify?',
    'Sure! What would you like to alter?']},
  {'id': 2,
   'label': '',
   'is_start': True,
   'utterances': ["I'm sorry you're experiencing issues with your file download. Let's work together to fix this. Could you please describe the problem you're encountering?"]},
  {'id': 3,
   'label': '',
   'is_start': False,
   'utterances': ["I'm sorry to hear that the download is being interrupted. Let's try a few troubleshooting steps. First, can you check if your internet connection is stable?"]},
  {'id': 4,
   'label': '',
   'is_start': False,
   'utterances': ['Great. Next, please try pausing the download and then resuming it. Let me know if that helps.']},
  {'id': 5,
   'label': '',
   'is_start': False,
   'utterances': ["Understood. Let's try clearing your browser's cache and then attempt the downloa

In [20]:
# Compare the reference graph stored in the test case with the generated graph.
llm_match(Graph(case['graph']),result_graph)

SCORING...
finished
NODES:  0.9999753
ALL:  0.9999455


True