In [1]:
import numpy as np
from langchain_community.cross_encoders.huggingface import HuggingFaceCrossEncoder
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from langchain.chat_models import ChatOpenAI

In [2]:
# IPython autoreload in mode 2: changed modules are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
from chatsky_llm_autoconfig.graph import BaseGraph, Graph
from chatsky_llm_autoconfig.schemas import CompareResponse
from chatsky_llm_autoconfig.algorithms.three_stages_graph_generation import ThreeStagesGraphGenerator
from chatsky_llm_autoconfig.dialogue import Dialogue



In [4]:
from settings import EnvSettings
from utils import call_llm_api, read_json
from compare_prompt import (
    compare_graphs_prompt, graph_example_1, result_form
)

In [5]:
# Project settings object; later cells read model names, the device, the similarity
# threshold, the OpenAI credentials/base URL and the test-data path from it.
env_settings = EnvSettings()


In [6]:
# Cross-encoder whose .score() rates text pairs in get_2_rerankings.
# NOTE(review): the reranker is placed on EMBEDDER_DEVICE — confirm that sharing the embedder's device setting is intended.
evaluator = HuggingFaceCrossEncoder(model_name=env_settings.RERANKER_MODEL, model_kwargs={"device": env_settings.EMBEDDER_DEVICE})
# Graph generator invoked further down on the dialogues of one test case.
graph_generator = ThreeStagesGraphGenerator()

In [7]:
def nodes2list(graph: dict) -> list:
    """Render every node of ``graph`` as a single string of its utterances.

    Args:
        graph: mapping with a ``"nodes"`` list; each node carries an
            ``"utterances"`` list of strings.

    Returns:
        One string per node, in ``graph["nodes"]`` order. Every utterance is
        followed by a single space (so the string ends with a space); a node
        without utterances yields ``""``.
    """
    return [
        "".join(text + " " for text in entry["utterances"])
        for entry in graph["nodes"]
    ]

In [8]:
def graph2list(graph: dict) -> tuple[list, int, list]:
    """Render every node together with its outgoing edges as one string.

    Args:
        graph: mapping with ``"nodes"`` (each node has ``"id"`` and
            ``"utterances"``) and ``"edges"`` (each edge has ``"source"`` and
            ``"utterances"``).

    Returns:
        A 3-tuple ``(texts, n_utterances, lens)``:

        * ``texts`` - one string per node (in ``graph["nodes"]`` order): the
          node's utterances followed by the utterances of every edge whose
          ``"source"`` equals the node's ``"id"``, each followed by a space.
        * ``n_utterances`` - total number of edge utterances that were appended.
        * ``lens`` - utterance count of each visited edge, flattened in visiting
          order (node by node, edges in their ``graph["edges"]`` order). This is
          a per-edge list, NOT a per-node list.
    """
    # Bucket the edges by source once instead of rescanning the whole edge list
    # for every node; insertion order keeps each bucket in original edge order.
    outgoing: dict = {}
    for edge in graph['edges']:
        outgoing.setdefault(edge['source'], []).append(edge)

    res = []
    n_edges = 0
    lens = []
    for node in graph["nodes"]:
        parts = list(node['utterances'])
        for edge in outgoing.get(node["id"], []):
            lens.append(len(edge['utterances']))
            n_edges += len(edge['utterances'])
            parts.extend(edge['utterances'])
        res.append("".join(part + " " for part in parts))

    return res, n_edges, lens

In [9]:
def get_2_rerankings(generated1: list[str], golden1: list[str], generated2: list[str], golden2: list[str]) -> tuple[np.ndarray, np.ndarray]:
    """Score two generated-vs-golden text sets with a single evaluator call.

    Every ``(generated, golden)`` combination of the first pair of lists and of
    the second pair of lists is sent to the module-level ``evaluator`` in one
    batch, then the flat score vector is split back into two matrices.

    Args:
        generated1, golden1: texts for the first comparison.
        generated2, golden2: texts for the second comparison.

    Returns:
        ``(m1, m2)`` where ``m1[i, j]`` is the score of
        ``(generated1[i], golden1[j])`` and ``m2[i, j]`` the score of
        ``(generated2[i], golden2[j])``. Shapes are
        ``(len(generated1), len(golden1))`` and ``(len(generated2), len(golden2))``.
    """
    first_pairs = [(gen, gol) for gen in generated1 for gol in golden1]
    second_pairs = [(gen, gol) for gen in generated2 for gol in golden2]

    print("SCORING...")
    scores = np.array(evaluator.score(first_pairs + second_pairs))
    print("finished")

    # Split by the real number of pairs and reshape by the real list sizes.
    # Using len(generated1) for everything silently mis-slices the scores as
    # soon as the four lists do not all have the same length.
    split = len(first_pairs)
    return (
        scores[:split].reshape(len(generated1), len(golden1)),
        scores[split:].reshape(len(generated2), len(golden2)),
    )

In [None]:
def compare_edge_lens(G1: BaseGraph, G2: BaseGraph, max: list, nodes1_list, nodes2_list) -> bool:
    """Check that matched nodes have outgoing edges with the same utterance counts.

    For every index ``i``, node ``i`` of ``G1`` is taken to correspond to node
    ``max[i]`` of ``G2`` (positions in ``graph_dict["nodes"]``). The two nodes
    agree when the sorted utterance counts of their outgoing edges are equal;
    sorting makes the check independent of edge order.

    Args:
        G1, G2: graph objects exposing ``graph_dict`` with ``"nodes"`` (each
            having ``"id"``) and ``"edges"`` (each having ``"source"`` and
            ``"utterances"``).
        max: node mapping, ``max[i]`` is the index of the ``G2`` node matched to
            ``G1`` node ``i``. The name shadows the builtin inside this function
            and is kept only for signature compatibility.
        nodes1_list, nodes2_list: kept from the draft signature; not consulted,
            node order is read from ``graph_dict["nodes"]``.

    Returns:
        ``True`` when every matched node pair agrees, otherwise ``False``
        (after printing ``"for"``, the marker the draft printed on mismatch).
    """
    def outgoing_lens(graph: dict) -> list:
        # Sorted utterance counts of the edges leaving each node, in node order.
        per_source: dict = {}
        for edge in graph["edges"]:
            per_source.setdefault(edge["source"], []).append(len(edge["utterances"]))
        return [sorted(per_source.get(node["id"], [])) for node in graph["nodes"]]

    lens1 = outgoing_lens(G1.graph_dict)
    lens2 = outgoing_lens(G2.graph_dict)

    for i, j in enumerate(max):
        if lens1[i] != lens2[j]:
            print("for")
            return False
    return True

In [37]:
def _edge_lens_per_node(graph: dict) -> list:
    """Sorted utterance counts of the edges leaving each node, in ``graph["nodes"]`` order."""
    per_source: dict = {}
    for edge in graph["edges"]:
        per_source.setdefault(edge["source"], []).append(len(edge["utterances"]))
    return [sorted(per_source.get(node["id"], [])) for node in graph["nodes"]]


def llm_match(G1: BaseGraph, G2: BaseGraph) -> bool:
    """Decide whether two dialogue graphs describe the same dialogue flow.

    Steps, each of which can reject the pair early:

    1. Both graphs must have the same number of nodes and the same total number
       of edge utterances.
    2. The cross-encoder (via ``get_2_rerankings``) scores node texts and
       node-plus-outgoing-edge texts of ``G1`` against ``G2``. Taking the best
       match per row must give a one-to-one node mapping, and both score
       matrices must yield the same mapping.
    3. Every matched node pair must have outgoing edges with the same
       utterance counts.
    4. If the weakest best-match score of both matrices reaches
       ``env_settings.SIM_THRESHOLD`` the graphs match; otherwise the decision
       is delegated to the LLM configured by ``COMPARE_MODEL_NAME``.

    Args:
        G1, G2: graph objects exposing ``graph_dict`` (``"nodes"`` / ``"edges"``).

    Returns:
        ``True`` if the graphs are considered equivalent, ``False`` otherwise;
        on the LLM path the value of the ``result`` field of the parsed reply.
    """
    g1 = G1.graph_dict
    g2 = G2.graph_dict

    nodes1_list = nodes2list(g1)
    nodes2_list = nodes2list(g2)
    if len(nodes1_list) != len(nodes2_list):
        return False

    g1_list, n1, len1 = graph2list(g1)
    g2_list, n2, len2 = graph2list(g2)
    print("LEN1: ", len1, "LEN2: ", len2)

    # Cheap structural check first, so an obvious mismatch never pays for scoring.
    if n1 != n2:
        print("N!")
        return False

    nodes_matrix, matrix = get_2_rerankings(nodes1_list, nodes2_list, g1_list, g2_list)

    # Best G2 candidate for every G1 node; duplicates mean the mapping is not one-to-one.
    nodes_max = list(np.argmax(nodes_matrix, axis=1))
    if len(set(nodes_max)) < len(nodes1_list):
        print("LENS")
        return False

    # The mapping from node+edge texts must be one-to-one and agree with the node-only one.
    best = list(np.argmax(matrix, axis=1))
    if len(set(best)) < len(g1_list) or nodes_max != best:
        print("MIX")
        return False
    print("MAX: ", best)

    # Compare edge utterance counts per matched node. The lists returned by
    # graph2list are flattened per edge, so they cannot be indexed by node position.
    lens1 = _edge_lens_per_node(g1)
    lens2 = _edge_lens_per_node(g2)
    if any(lens1[i] != lens2[j] for i, j in enumerate(best)):
        print("for")
        return False

    # Weakest "best match" score in each matrix.
    nodes_floor = np.min(np.max(nodes_matrix, axis=1))
    all_floor = np.min(np.max(matrix, axis=1))
    print("NODES: ", nodes_floor)
    print("ALL: ", all_floor)

    if min(nodes_floor, all_floor) >= env_settings.SIM_THRESHOLD:
        return True

    # Similarity is inconclusive: ask the comparison LLM, with a second model
    # used by OutputFixingParser to repair replies that do not parse.
    parser = PydanticOutputParser(pydantic_object=CompareResponse)
    format_model=ChatOpenAI(model=env_settings.FORMATTER_MODEL_NAME, api_key=env_settings.OPENAI_API_KEY, base_url=env_settings.OPENAI_BASE_URL)
    model=ChatOpenAI(model=env_settings.COMPARE_MODEL_NAME, api_key=env_settings.OPENAI_API_KEY, base_url=env_settings.OPENAI_BASE_URL)
    new_parser = OutputFixingParser.from_llm(parser=parser, llm=format_model)
    result = call_llm_api(compare_graphs_prompt.format(result_form=result_form,graph_example_1=graph_example_1, graph_1=g1, graph_2=g2), model|new_parser, temp=0).model_dump()
    return result['result']



In [11]:
# Test cases read from TEST_DATA_PATH; later cells index this and access "dialogues", "graph" and "topic" on an entry.
dialogue_to_graph = read_json(env_settings.TEST_DATA_PATH)

In [12]:
# Work with the second test case (index 1).
case = dialogue_to_graph[1]

In [None]:
# Wrap each dialogue's "messages" in a Dialogue and generate a graph from all dialogues of the case.
result_graph = graph_generator.invoke([Dialogue.from_list(c["messages"]) for c in case["dialogues"]])

In [32]:
# Display the generated graph as a dict (nodes and edges).
result_graph.graph_dict

{'nodes': [{'id': 1,
   'label': '',
   'is_start': False,
   'utterances': ['Of course, we can facilitate an exchange for a different size. Please specify your preferred size.',
    'Certainly, we can exchange your product for a different size. Please let us know which size you would prefer.',
    'Sure, we can exchange your product for another size. Which size would you like instead?']},
  {'id': 2,
   'label': '',
   'is_start': True,
   'utterances': ["I'm sorry to hear that the size of the product you received isn't what you expected. How can I assist you today?"]},
  {'id': 3,
   'label': '',
   'is_start': False,
   'utterances': ["I'm sorry for the inconvenience. Could you please provide your order number?"]},
  {'id': 4,
   'label': '',
   'is_start': False,
   'utterances': ['Thank you. I see that you ordered a Medium size. Would you like a replacement, a refund, or to exchange it for a different size?']},
  {'id': 5,
   'label': '',
   'is_start': False,
   'utterances': ['U

In [28]:
# Topic string of the selected test case.
case['topic']

'Resolving a mismatched product size complaint'

In [31]:
# Graph stored with the test case; it is the first argument of the llm_match call at the end of the notebook.
case['graph']

{'edges': [{'source': 1,
   'target': 2,
   'utterances': ['The size I ordered is not correct.']},
  {'source': 2, 'target': 3, 'utterances': ["Sure, it's 123456."]},
  {'source': 2,
   'target': 7,
   'utterances': ["Never mind, I don't need assistance."]},
  {'source': 3, 'target': 4, 'utterances': ['I would like a replacement.']},
  {'source': 3, 'target': 5, 'utterances': ['I would like a refund.']},
  {'source': 3,
   'target': 6,
   'utterances': ['I would like to exchange for a different size.',
    'Actually, I need a different size.',
    'Can I exchange it for another size?']},
  {'source': 3,
   'target': 7,
   'utterances': ["I changed my mind, I don't need any help."]},
  {'source': 4, 'target': 7, 'utterances': ["No, that's all, thank you."]},
  {'source': 4,
   'target': 3,
   'utterances': ['Actually, can I change my choice?']},
  {'source': 5, 'target': 7, 'utterances': ["No, that's all, thanks."]},
  {'source': 5,
   'target': 3,
   'utterances': ['Actually, can I cha

In [31]:
# Edges of the generated graph.
result_graph.graph_dict['edges']

[{'source': 2,
  'target': 3,
  'utterances': ['The size I ordered is not correct.']},
 {'source': 3, 'target': 4, 'utterances': ["Sure, it's 123456."]},
 {'source': 4, 'target': 5, 'utterances': ['I would like a replacement.']},
 {'source': 4,
  'target': 6,
  'utterances': ["I changed my mind, I don't need any help."]},
 {'source': 5,
  'target': 4,
  'utterances': ['Actually, can I change my choice?']},
 {'source': 4, 'target': 7, 'utterances': ['I would like a refund.']},
 {'source': 7, 'target': 4, 'utterances': ['Actually, can I change my mind?']},
 {'source': 1,
  'target': 4,
  'utterances': ['Actually, I want to choose another option instead.']},
 {'source': 5, 'target': 6, 'utterances': ["No, that's all, thank you."]},
 {'source': 7, 'target': 6, 'utterances': ["No, that's all, thanks."]},
 {'source': 1, 'target': 6, 'utterances': ["No, that's all, thank you."]},
 {'source': 4,
  'target': 1,
  'utterances': ['Can I exchange it for another size?',
   'I would like to exchange

In [17]:
# Edges of the graph stored with the test case, for comparison with the generated edges.
case['graph']['edges']

[{'source': 1,
  'target': 2,
  'utterances': ['The size I ordered is not correct.']},
 {'source': 2, 'target': 3, 'utterances': ["Sure, it's 123456."]},
 {'source': 2,
  'target': 7,
  'utterances': ["Never mind, I don't need assistance."]},
 {'source': 3, 'target': 4, 'utterances': ['I would like a replacement.']},
 {'source': 3, 'target': 5, 'utterances': ['I would like a refund.']},
 {'source': 3,
  'target': 6,
  'utterances': ['I would like to exchange for a different size.',
   'Actually, I need a different size.',
   'Can I exchange it for another size?']},
 {'source': 3,
  'target': 7,
  'utterances': ["I changed my mind, I don't need any help."]},
 {'source': 4, 'target': 7, 'utterances': ["No, that's all, thank you."]},
 {'source': 4,
  'target': 3,
  'utterances': ['Actually, can I change my choice?']},
 {'source': 5, 'target': 7, 'utterances': ["No, that's all, thanks."]},
 {'source': 5, 'target': 3, 'utterances': ['Actually, can I change my mind?']},
 {'source': 6, 'targe

In [38]:
# Compare the graph stored with the test case (wrapped in Graph) against the generated graph.
llm_match(Graph(case['graph']),result_graph)

LEN1:  [1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1] LEN2:  [1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1]
SCORING...


finished
MAX:  [1, 2, 3, 4, 6, 0, 5]
for


False