In [1]:
import json
import numpy as np
from langchain_community.cross_encoders.huggingface import HuggingFaceCrossEncoder
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from langchain.chat_models import ChatOpenAI

In [2]:
%load_ext autoreload
%autoreload 2

In [69]:
from chatsky_llm_autoconfig.graph import BaseGraph, Graph
from chatsky_llm_autoconfig.metrics.embedder import get_embedding
from chatsky_llm_autoconfig.metrics.automatic_metrics import is_same_structure
from chatsky_llm_autoconfig.schemas import CompareResponse
from chatsky_llm_autoconfig.algorithms.three_stages_graph_generation import ThreeStagesGraphGenerator
from chatsky_llm_autoconfig.dialogue import Dialogue

In [4]:
from settings import EnvSettings
from utils import call_llm_api, read_json
from compare_prompt import (
    compare_graphs_prompt, graph_example_1, result_form
)

In [5]:
env_settings = EnvSettings()


In [6]:
# Cross-encoder whose `.score(...)` is called by `get_2_rerankings`; the model name and
# device are both read from `env_settings` (note: the device setting is EMBEDDER_DEVICE).
evaluator = HuggingFaceCrossEncoder(model_name=env_settings.RERANKER_MODEL, model_kwargs={"device": env_settings.EMBEDDER_DEVICE})
# Generator whose `.invoke(...)` is later fed a list of `Dialogue` objects to produce `result_graph`.
graph_generator = ThreeStagesGraphGenerator()

In [7]:
def nodes2list(graph: dict) -> list:
    """Turn each node of a graph dict into one text string.

    Args:
        graph: Graph dictionary with a "nodes" list; every node holds an
            'utterances' list of strings.

    Returns:
        One string per node, in node order. Each utterance is followed by a
        single space, so a node without utterances yields an empty string.
    """
    return [
        "".join(phrase + " " for phrase in node['utterances'])
        for node in graph["nodes"]
    ]

In [8]:
def graph2list(graph: dict) -> tuple[list, int, list]:
    """Flatten each node together with its outgoing edges into one text string.

    Args:
        graph: Graph dictionary with a "nodes" list (each node has "id" and
            'utterances') and an 'edges' list (each edge has 'source' and
            'utterances').

    Returns:
        A 3-tuple ``(texts, n_edges, lens)``:
          * ``texts`` - one string per node, in node order: the node's
            utterances followed by the utterances of every edge leaving the
            node, each followed by a single space;
          * ``n_edges`` - total number of edge *utterances* visited (despite
            the name it is not the number of edges);
          * ``lens`` - utterance count of every visited edge, in visiting order.
    """
    # Index edges by source once instead of rescanning the edge list per node.
    outgoing = {}
    for edge in graph['edges']:
        outgoing.setdefault(edge['source'], []).append(edge)

    res = []
    n_edges = 0
    lens = []
    for node in graph["nodes"]:
        pieces = list(node['utterances'])
        for edge in outgoing.get(node["id"], []):
            edge_utts = edge['utterances']
            lens.append(len(edge_utts))
            pieces.extend(edge_utts)
            n_edges += len(edge_utts)
        res.append("".join(piece + " " for piece in pieces))

    return res, n_edges, lens

In [9]:
def get_2_rerankings(generated1: list[str], golden1: list[str], generated2: list[str], golden2: list[str]):
    """Score two generated/golden list pairs with the cross-encoder in a single batch.

    NOTE: a later cell redefines ``get_2_rerankings`` (symmetric variant); once
    that cell has run, this definition is shadowed.

    Args:
        generated1: Texts compared against ``golden1``.
        golden1: Reference texts for ``generated1``.
        generated2: Texts compared against ``golden2``.
        golden2: Reference texts for ``generated2``.

    Returns:
        Two score matrices: the first of shape
        ``(len(generated1), len(golden1))``, the second of shape
        ``(len(generated2), len(golden2))``. Entry ``[i, j]`` is the score of
        the pair ``(generated[i], golden[j])``.
    """
    first_pairs = [(gen, gol) for gen in generated1 for gol in golden1]
    second_pairs = [(gen, gol) for gen in generated2 for gol in golden2]

    print("SCORING...")
    scores = np.array(evaluator.score(first_pairs + second_pairs))
    print("finished")

    # Split by the real number of pairs and reshape with the real list lengths
    # rather than assuming every list has len(generated1) items.
    split = len(first_pairs)
    return (
        scores[:split].reshape(len(generated1), len(golden1)),
        scores[split:].reshape(len(generated2), len(golden2)),
    )

In [67]:
def get_2_rerankings(generated1: list[str], golden1: list[str], generated2: list[str], golden2: list[str]):
    """Score two generated/golden list pairs symmetrically with the cross-encoder.

    Every pair is scored in both orders, ``(generated, golden)`` and
    ``(golden, generated)``, and the smaller of the two scores is kept.

    Args:
        generated1: Texts compared against ``golden1``.
        golden1: Reference texts for ``generated1``.
        generated2: Texts compared against ``golden2``.
        golden2: Reference texts for ``generated2``.

    Returns:
        Two score matrices: the first of shape
        ``(len(generated1), len(golden1))``, the second of shape
        ``(len(generated2), len(golden2))``.
    """
    to_score = []
    for generated, golden in ((generated1, golden1), (generated2, golden2)):
        for gen in generated:
            for gol in golden:
                # Forward and reversed order sit next to each other: even
                # positions hold (gen, gol), odd positions hold (gol, gen).
                to_score.append((gen, gol))
                to_score.append((gol, gen))

    print("SCORING...")
    raw_scores = np.asarray(evaluator.score(to_score))
    print("finished")
    scores = np.minimum(raw_scores[0::2], raw_scores[1::2])

    # Split by the real number of pairs and reshape with the real list lengths
    # rather than assuming every list has len(generated1) items.
    split = len(generated1) * len(golden1)
    return (
        scores[:split].reshape(len(generated1), len(golden1)),
        scores[split:].reshape(len(generated2), len(golden2)),
    )

In [10]:
def compare_edge_lens(G1: BaseGraph, G2: BaseGraph, max: list):
    """Check that matched nodes of two graphs have comparable outgoing edges.

    Args:
        G1: First graph; its ``graph_dict['nodes']`` order defines positions.
        G2: Second graph.
        max: For position ``i`` of a G1 node, ``max[i]`` is the position of the
            G2 node it is matched with.

    Returns:
        False as soon as a matched node pair has a different number of
        outgoing edges, or a G1 edge and a G2 edge that lead to matched
        targets carry a different number of utterances; True otherwise.
    """
    ids_first = [node['id'] for node in G1.graph_dict['nodes']]
    ids_second = [node['id'] for node in G2.graph_dict['nodes']]

    # G1 node id -> id of the G2 node it was matched with.
    nodes_map = {}
    for position, first_id in enumerate(ids_first):
        nodes_map[first_id] = ids_second[max[position]]
    print("MAPS: ", nodes_map)

    for node1 in ids_first:
        node2 = nodes_map[node1]
        print("NN: ", node1, node2)
        edges1 = G1.edge_by_source(node1)
        edges2 = G2.edge_by_source(node2)
        if len(edges1) != len(edges2):
            print("FF: ", edges1, edges2)
            return False
        for edge1 in edges1:
            expected_target = nodes_map[edge1['target']]
            for edge2 in edges2:
                if expected_target != edge2['target']:
                    continue
                if len(edge1['utterances']) == len(edge2['utterances']):
                    continue
                print(edge1, edge2)
                return False
    return True


In [74]:
def llm_match(G1: BaseGraph, G2: BaseGraph) -> bool:
    """Decide whether two dialogue graphs match, asking an LLM only as a last resort.

    Structural and embedding checks run first; each failing check prints a
    short debug tag and returns False:

      1. "FIRST"   - the graphs have a different number of nodes;
      2. "LLLLENS" - two G1 nodes pick the same best G2 node (node texts only);
      3. "N!"      - the total number of edge utterances differs;
      4. "MIX"     - the best matches on node+edge texts are not unique, or
                     disagree with the node-only matches;
      5. "LENS"    - ``compare_edge_lens`` rejects the node mapping.

    If every check passes and the weakest best-match score in both matrices
    reaches ``env_settings.SIM_THRESHOLD``, the graphs are accepted. Otherwise
    the comparison prompt is sent to the LLM and its verdict is returned.

    Args:
        G1: First graph (exposes ``graph_dict``).
        G2: Second graph (exposes ``graph_dict``).

    Returns:
        True/False from the checks above, or the ``'result'`` field of the
        parsed LLM response (presumably a bool - confirm against
        ``CompareResponse``).
    """
    g1 = G1.graph_dict
    g2 = G2.graph_dict

    nodes1_list = nodes2list(g1)
    nodes2_list = nodes2list(g2)
    if len(nodes1_list) != len(nodes2_list):
        print("FIRST")
        return False

    g1_list, n1, len1 = graph2list(g1)
    g2_list, n2, len2 = graph2list(g2)
    print("LEN1: ", len1, "LEN2: ", len2)

    # get_embedding presumably returns a pairwise similarity matrix with one
    # row per G1 text and one column per G2 text - TODO confirm.
    # Cross-encoder alternative with the same output layout:
    #   nodes_matrix, matrix = get_2_rerankings(nodes1_list, nodes2_list, g1_list, g2_list)
    nodes_matrix = get_embedding(nodes1_list, nodes2_list, env_settings.EMBEDDER_MODEL, env_settings.EMBEDDER_DEVICE)
    matrix = get_embedding(g1_list, g2_list, env_settings.EMBEDDER_MODEL, env_settings.EMBEDDER_DEVICE)

    # Best column per row, node texts only; the picks must all be distinct.
    nodes_max = list(np.argmax(nodes_matrix, axis=1))
    if len(set(nodes_max)) < len(nodes1_list):
        print("LLLLENS")
        return False

    if n1 != n2:
        print("N!")
        return False

    # Same matching on node+edge texts. Kept under a name that does not shadow
    # the builtin ``max``.
    full_max = list(np.argmax(matrix, axis=1))
    print("N_MAX: ", nodes_max)
    print("MAX: ", full_max)
    if len(set(full_max)) < len(g1_list) or nodes_max != full_max:
        print("MIX", len(set(full_max)), len(g1_list), nodes_max)
        return False

    if not compare_edge_lens(G1, G2, full_max):
        print("LENS")
        return False

    # Weakest best-match score in each matrix.
    nodes_floor = np.min(np.max(nodes_matrix, axis=1))
    full_floor = np.min(np.max(matrix, axis=1))
    print("NODES: ", nodes_floor)
    print("ALL: ", full_floor)

    if min(nodes_floor, full_floor) >= env_settings.SIM_THRESHOLD:
        return True

    # Scores are not conclusive: let the LLM compare the two graph dicts.
    parser = PydanticOutputParser(pydantic_object=CompareResponse)
    format_model = ChatOpenAI(model=env_settings.FORMATTER_MODEL_NAME, api_key=env_settings.OPENAI_API_KEY, base_url=env_settings.OPENAI_BASE_URL)
    model = ChatOpenAI(model=env_settings.COMPARE_MODEL_NAME, api_key=env_settings.OPENAI_API_KEY, base_url=env_settings.OPENAI_BASE_URL)
    # The formatter model is handed to OutputFixingParser together with the base parser.
    new_parser = OutputFixingParser.from_llm(parser=parser, llm=format_model)
    result = call_llm_api(compare_graphs_prompt.format(result_form=result_form,graph_example_1=graph_example_1, graph_1=g1, graph_2=g2), model|new_parser, temp=0).model_dump()
    return result['result']



In [12]:
def read_json(path):
    """Load and parse a JSON file.

    NOTE(review): this definition shadows the ``read_json`` imported from
    ``utils`` in an earlier cell - confirm which one is meant to be used.

    Args:
        path: Location of the JSON file (anything ``open`` accepts).

    Returns:
        The decoded JSON value.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, mode="r", encoding="utf-8") as file:
        return json.load(file)

In [13]:
# Benchmark cases, indexed by position in later cells; entries are read there through the
# 'graph', 'dialogues' and 'topic' keys.
dialogue_to_graph = read_json("generated_plus.json")
# Stored results, narrowed to the "o1-mini" / "Feb_10" entry (presumably model name and
# run date - confirm). `match` below indexes this by case position.
results = read_json("saved.json")["o1-mini"]["Feb_10"]

In [14]:
# NOTE(review): mid-notebook import; consider moving it to the import cell at the top.
import logging
# Only records of level ERROR and above pass for this langchain logger.
logging.getLogger("langchain_core.vectorstores.base").setLevel(logging.ERROR)

In [81]:
# Pick one benchmark case and build a graph from all of its dialogues.
case = dialogue_to_graph[12]
# Each dialogue's "messages" list is wrapped in a Dialogue before being passed to the generator.
result_graph = graph_generator.invoke([Dialogue.from_list(c["messages"]) for c in case["dialogues"]])

LISTS_N:  [(0, ["I can't complete the verification process."]), (1, ["Yes, it says 'Verification failed.'", "I don't want to troubleshoot right now."]), (2, ["I tried that, but it still doesn't work.", 'Actually, I already did that. What else can I do?', 'I already reset my password. What other options do I have?']), (3, ['Actually, I think my email is incorrect. Can I update it?', "Yes, it's user@example.com."]), (4, ["Sure, it's newuser@example.com."]), (5, ['No, thank you.']), (6, ['']), (7, ['No, thank you.']), (8, ["I cleared my cache and tried again, but it still doesn't work."]), (9, ['Yes, please send me the instructions.']), (10, ['']), (11, [''])]
LISTS:  [(0, "Hi! I'm sorry you're having trouble verifying your account. Can you describe what's happening?"), (1, "I'm sorry to hear that. Let's troubleshoot together. Can you tell me if you're receiving any error messages?"), (2, 'Thank you for the information. Please try resetting your password and attempt verification again.'),

In [79]:
# NOTE(review): this compares golden graph #13 with `result_graph`, while the visible cell that
# builds `result_graph` uses dialogue_to_graph[12] and carries a later execution count -
# confirm which case `result_graph` held when this cell ran, and align the indices.
is_same_structure(Graph(dialogue_to_graph[13]['graph']),result_graph)

True

In [82]:
llm_match(Graph(dialogue_to_graph[12]['graph']),result_graph)

LEN1:  [1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1] LEN2:  [1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1]
0 :  Hi! I'm sorry you're having trouble verifying your account. Can you describe what's happening? I can't complete the verification process. 
G2:  Great, everything seems in order. I will forward this issue to our support team for further assistance. Is there anything else I can help you with? Everything looks correct. I'll escalate this to our support team. Do you need help with anything else? No, thank you.  

1 :  I'm sorry to hear that. Let's troubleshoot together. Can you tell me if you're receiving any error messages? Yes, it says 'Verification failed.' I don't want to troubleshoot right now. 
G2:  Hi! I'm sorry you're having trouble verifying your account. Can you describe what's happening? I can't complete the verification process.  

2 :  Thank you for the information. Please try resetting your password and attempt verification again. I tried that, but it still doesn't work. Actually, I already

True

In [80]:
def remove_duplicated_edges(graph: dict):
    """Merge edges that share the same (source, target) pair into a single edge.

    The input graph and its edge dicts are left untouched. Utterances of the
    merged edges are concatenated in edge order, with repeated utterances kept
    only once. (The earlier version reset the first duplicate's utterance list
    before merging, which silently dropped that edge's utterances.)

    Args:
        graph: Graph dictionary with 'edges' (each edge has 'source', 'target'
            and 'utterances') and 'nodes'.

    Returns:
        A new dict with "edges" (all non-duplicated edges in their original
        order, followed by one merged edge per duplicated pair) and "nodes"
        (the same nodes list object as in ``graph``).
    """
    edges = graph['edges']

    # Group edges by (source, target) in one pass; dict order keeps first-seen order.
    by_pair = {}
    for edge in edges:
        by_pair.setdefault((edge['source'], edge['target']), []).append(edge)

    duplicates = [pair for pair, group in by_pair.items() if len(group) > 1]
    print(duplicates)

    new_edges = []
    for pair in duplicates:
        found = by_pair[pair]
        print("FOUND: ", found)
        # dict.fromkeys removes repeats while preserving first-seen order.
        merged = list(dict.fromkeys(utt for e in found for utt in e['utterances']))
        # Copy the first edge instead of mutating it, so the input stays intact.
        new_edges.append({**found[0], 'utterances': merged})

    unique_edges = [e for e in edges if len(by_pair[(e['source'], e['target'])]) == 1]
    new_graph = {"edges": unique_edges + new_edges, "nodes": graph['nodes']}
    return new_graph

In [40]:
result_graph.graph_dict

{'nodes': [{'id': 1,
   'label': '',
   'is_start': False,
   'utterances': ['Ad-blocking has been enabled in your browser settings. Would you like to customize the settings?',
    'Your browser is now blocking ads. Do you want to adjust the ad-blocking settings?']},
  {'id': 2,
   'label': '',
   'is_start': True,
   'utterances': ['Would you like to enable ad-blocking in your browser?']},
  {'id': 3,
   'label': '',
   'is_start': False,
   'utterances': ['Which settings would you like to adjust? You can block specific types of ads or set exceptions for certain websites.']},
  {'id': 4,
   'label': '',
   'is_start': False,
   'utterances': ['Alright, ad-blocking remains disabled. Let me know if you need anything else.']},
  {'id': 5,
   'label': '',
   'is_start': False,
   'utterances': ['Ad-blocking extension has been installed. Would you like to configure it?']},
  {'id': 6,
   'label': '',
   'is_start': False,
   'utterances': ['Which options would you like to configure in the 

In [39]:
dialogue_to_graph[11]['topic']

'Enabling ad-blocking features on a browser'

In [None]:
dialogue_to_graph[1]['graph']

In [None]:
# NOTE(review): '' is an empty key; other cells read 'graph' and 'topic' from these entries,
# so this lookup presumably needs one of those - fix or drop the cell (it has no execution count).
dialogue_to_graph[7]['']

In [52]:
list(results[7].values())[0][0]

{'nodes': [{'id': 1,
   'label': '',
   'is_start': True,
   'utterances': ["I'm sorry you're experiencing issues with your checkout cart. Can you please describe the problem?",
    "Sorry to hear you're having trouble with your checkout cart. Could you please explain what's happening?"]},
  {'id': 2,
   'label': '',
   'is_start': False,
   'utterances': ['Great to hear! Is there anything else I can help you with?',
    "You're welcome! Is there anything else I can help you with?",
    'Glad to hear that! Do you need any further assistance?']},
  {'id': 3,
   'label': '',
   'is_start': False,
   'utterances': ['Please try refreshing the page first and let me know if the issue persists.',
    'Please try using Firefox and let me know if the issue persists.']},
  {'id': 4,
   'label': '',
   'is_start': False,
   'utterances': ['Have you tried using a different browser or device?',
    "Okay, have you cleared your browser's cache and cookies?"]},
  {'id': 5,
   'label': '',
   'is_star

In [91]:
dialogue_to_graph[6]['graph']

{'edges': [{'source': 1,
   'target': 2,
   'utterances': ['I need to report an incorrect product dimension listing.',
    "There's an issue with a product's dimensions."]},
  {'source': 1,
   'target': 6,
   'utterances': ["Never mind, I don't need help right now."]},
  {'source': 2, 'target': 3, 'utterances': ['The product ID is 12345.']},
  {'source': 2,
   'target': 4,
   'utterances': ["I'd prefer to give you the product name instead."]},
  {'source': 3,
   'target': 5,
   'utterances': ['The length and width are wrong.']},
  {'source': 3,
   'target': 7,
   'utterances': ["Actually, I provided the product number incorrectly. It's 67890."]},
  {'source': 4, 'target': 5, 'utterances': ['The height is incorrect.']},
  {'source': 5, 'target': 6, 'utterances': ["No, that's all."]},
  {'source': 5,
   'target': 2,
   'utterances': ['Actually, I need to change the product ID.']},
  {'source': 7, 'target': 5, 'utterances': ['The height is wrong.']}],
 'nodes': [{'id': 1,
   'label': 'Ass

In [81]:
cur_graph = remove_duplicated_edges(dialogue_to_graph[7]['graph'])

[(6, 11)]
FOUND:  [{'source': 6, 'target': 11, 'utterances': ['I want to try Safari instead.']}, {'source': 6, 'target': 11, 'utterances': ['I want to try Safari instead.']}]


In [82]:
cur_graph

{'edges': [{'source': 1,
   'target': 2,
   'utterances': ["My cart won't let me proceed to checkout."]},
  {'source': 1,
   'target': 3,
   'utterances': ["I'm getting an error when I try to add items to the cart.",
    'The checkout page is not loading.',
    'My cart keeps crashing when I try to add items.']},
  {'source': 2, 'target': 4, 'utterances': ["Yes, but it didn't work."]},
  {'source': 2, 'target': 5, 'utterances': ["No, I haven't tried that."]},
  {'source': 2,
   'target': 6,
   'utterances': ['Actually, can I change my browser?']},
  {'source': 4,
   'target': 7,
   'utterances': ["Yes, I cleared cache and cookies, but it still doesn't work."]},
  {'source': 4,
   'target': 5,
   'utterances': ["No, I haven't cleared cache and cookies."]},
  {'source': 4, 'target': 8, 'utterances': ['Can I change my device?']},
  {'source': 5, 'target': 7, 'utterances': ["It still doesn't work."]},
  {'source': 5, 'target': 9, 'utterances': ['It worked! Thank you.']},
  {'source': 3, 't

In [83]:
llm_match(Graph(cur_graph),result_graph)

LEN1:  [1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] LEN2:  [1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
SCORING...


finished
MAX:  [2, 3, 7, 11, 9, 4, 10, 12, 0, 8, 1, 5, 6]
MAPS:  {1: 3, 2: 4, 3: 8, 4: 12, 5: 10, 6: 5, 7: 11, 8: 13, 9: 1, 10: 9, 11: 2, 12: 6, 13: 7}
NN:  1 3
NN:  2 4
NN:  3 8
NN:  4 12
NN:  5 10
NN:  6 5
NN:  7 11
NN:  8 13
NN:  9 1
NN:  10 9
NN:  11 2
NN:  12 6
NN:  13 7
NODES:  0.9999747
ALL:  0.99996185


True

In [19]:
def match(index):
    """Run ``llm_match`` for one benchmark case.

    Args:
        index: Position of the case in both ``dialogue_to_graph`` and ``results``.

    Returns:
        Whatever ``llm_match`` returns for the case's golden graph versus the
        first graph stored under the first value of ``results[index]``.
    """
    golden = Graph(dialogue_to_graph[index]['graph'])
    first_entry = list(results[index].values())[0]
    generated = Graph(first_entry[0])
    return llm_match(golden, generated)

In [26]:
match(6)

FIRST


False