In [24]:
from llama_index.core import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from llama_index.core.node_parser import LangchainNodeParser
from bs4 import BeautifulSoup, NavigableString
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from selenium.webdriver.common.by import By
import ast
from selenium import webdriver
from llama_index.core.schema import NodeWithScore, QueryBundle

def generate_xpath(element, path=""): # used to generate dict nodes
    """ Recursive function to generate the xpath of an element """
    if element.parent is None:
        return path
    else:
        siblings = [sib for sib in element.parent.children if sib.name == element.name]
        if len(siblings) > 1:
            count = siblings.index(element) + 1
            path = f"/{element.name}[{count}]{path}"
        else:
            path = f"/{element.name}{path}"
        return generate_xpath(element.parent, path)
    
def add_xpath_attributes(html_content):
    """
    Add an 'xpath' attribute to each element in the HTML content with its computed XPath.
    """
    soup = BeautifulSoup(html_content, 'lxml')
    for element in soup.find_all(True):
        xpath = generate_xpath(element)
        element['xpath'] = xpath
    return str(soup)

def create_nodes_dict(html, only_body=True, max_length=200): # used to generate dict nodes
    ''' Create a list of xpaths and a list of dict of attributes of all elements in the html'''
    soup = BeautifulSoup(html, 'html.parser')
    if only_body:
        root = soup.body
    else:
        root = soup.html
    element_attributes_list = []
    stack = [root]  # stack to keep track of elements and their paths
    while stack:
        element = stack.pop()
        if element.name is not None:
            element_attrs = dict(element.attrs)
            direct_text_content = ''.join([str(content).strip() for content in element.contents if isinstance(content, NavigableString) and content.strip()])
            if direct_text_content:
                element_attrs['text'] = direct_text_content
                element_attrs['element'] = element.name
                for key in element_attrs:
                    if len(element_attrs[key]) > max_length:
                        element_attrs[key] = element_attrs[key][:max_length]
                element_attributes_list.append(element_attrs)
            elif element_attrs != {}:
                element_attrs['element'] = element.name
                for key in element_attrs:
                    if len(element_attrs[key]) > max_length:
                        element_attrs[key] = element_attrs[key][:max_length]
                element_attributes_list.append(element_attrs)
            for child in element.children:
                if child.name is not None:
                    stack.append(child)
    return element_attributes_list

def chunk_dicts(dicts, chunk_size=10):
    def chunks(lst, n):
        for i in range(0, len(lst), n):
            yield lst[i:i + n]
    grouped_chunks = []
    for chunk in chunks(dicts, chunk_size):
        all_keys = set(key for d in chunk for key in d.keys()) 
        grouped = {key: [] for key in all_keys}
        for d in chunk:
            for key in all_keys:
                grouped[key].append(d.get(key, ''))  
        grouped_chunks.append(grouped)   
    return grouped_chunks

def unchunk_dicts(grouped_chunks):
    flat_list = []
    for group in grouped_chunks:
        max_length = max(len(v) for v in group.values())
        for i in range(max_length):
            new_dict = {}
            for key, values in group.items():
                if i < len(values):
                    if values[i] != '':
                        new_dict[key] = values[i]
            if new_dict:
                flat_list.append(new_dict)
    return flat_list

def clean_attributes(attributes_list, rank_fields): # used to generate dict nodes
    if rank_fields:
        rank_fields.append('xpath')
        attributes_list = [{k: v for k, v in d.items() if k in rank_fields} for d in attributes_list]
    attributes_list = [d for d in attributes_list if (not((len(list(d.keys()))==2) and (('element' in list(d.keys())) and 'xpath' in list(d.keys())))) or d=={}]
    return attributes_list

def get_results(query, html, embedder, top_n=5, group_by=10, rank_fields=None): # used to generate and retrieve dict nodes
    ''' Return the top_n elements of the html that are the most relevant to the query as Node objects with xpath in their metadata'''
    attributes_list = create_nodes_dict(html)
    assert group_by > 0
    #cleaning the attributes_list
    attributes_list = clean_attributes(attributes_list, rank_fields)
    #retrieving the top_n results
    
    list_of_results = []
    attributes_list = chunk_dicts(attributes_list, group_by)
    l = len(attributes_list)
    #grouping the attributes_list in groups of 1000 to avoid memory errors
    list_of_grouped_results = []
    for j in range(0, l, 1000):
        nodes = []
        attr = attributes_list[j:j+1000]
        for d in attr:
            xpath = d.pop('xpath')
            nodes.append(TextNode(text=str(d), metadata={'xpath': xpath}))
        index = VectorStoreIndex(nodes, embed_model=embedder)
        retriever = BM25Retriever.from_defaults(index = index, similarity_top_k=top_n)
        results = retriever.retrieve(query)
        list_of_grouped_results += results
    nodes = []
    for grouped_results in list_of_grouped_results:
        xpaths = grouped_results.metadata['xpath']
        ds = unchunk_dicts([ast.literal_eval(grouped_results.text)])
        assert len(xpaths) == len(ds)
        for xpath, d in zip(xpaths, ds):
            nodes.append(TextNode(text=str(d), metadata={'xpath': xpath}))
    l2 = len(nodes)
    for j in range(0, l2, 1000):
        index = VectorStoreIndex(nodes[j:j+1000], embed_model=embedder)
        retriever = BM25Retriever.from_defaults(index = index, similarity_top_k=top_n)
        results = retriever.retrieve(query)
        list_of_results += results
    list_of_results = sorted(list_of_results, key=lambda x: x.score, reverse=True)
    return list_of_results[:top_n]

def match_element(attributes, element_specs):
    i=0
    for spec in element_specs:
        if attributes['xpath'] == spec['xpath']:
            return i       
        i+=1
    return None

def return_nodes_with_xpath(nodes, results_dict, score):
    returned_nodes = []
    for node in nodes:
        split_html = node.text
        soup = BeautifulSoup(split_html, 'html.parser')
        for element  in soup.descendants:
            try:
                indice = match_element(element.attrs, results_dict)
                if indice is not None:
                    node.metadata['score'] = score[indice]
                    returned_nodes.append(node)
            except:
                pass
    return returned_nodes

def check_visibility(xpath, driver):
    ''' Check if an element is visible '''
    try:
        element = driver.find_element(By.XPATH, xpath)
        return element.is_displayed()
    except:
        return False
    

def get_nodes_sm(query, html, embedder, driver=None, top_n=5, group_by=10, rank_fields=['element', 'placeholder', 'text', 'name']):
    if isinstance(query, QueryBundle):
        query = query.query_str
    html = add_xpath_attributes(html)
    text_list = [html]
    documents = [Document(text=t) for t in text_list]
    splitter = LangchainNodeParser(lc_splitter=RecursiveCharacterTextSplitter.from_language(
            language="html",
        ))
    nodes = splitter.get_nodes_from_documents(documents)
    results = get_results(query, html, embedder=embedder, top_n=top_n, group_by=group_by, rank_fields = rank_fields)
    results_dict = [ast.literal_eval(r.text) for r in results]
    for i in range(len(results_dict)):
        results_dict[i]['xpath'] = results[i].metadata['xpath']
    score = [r.score for r in results]
    if driver:
        for r in results_dict:
            if not check_visibility(r['xpath'], driver):
                i = results_dict.index(r)
                results_dict.remove(r)
                score.pop(i)
    results_nodes = return_nodes_with_xpath(nodes, results_dict, score)
    results = [NodeWithScore(node=node, score=node.metadata['score']) for node in results_nodes] 
    return results

In [25]:
import requests
from llama_index.embeddings.openai import OpenAIEmbedding
import os

os.environ["OPENAI_API_KEY"] =  #to fill

html = requests.get('https://github.com').text
query = 'Click on start a free entreprise trial'
embedder = OpenAIEmbedding(model= "text-embedding-3-large")

In [26]:
nodes = get_nodes_sm(QueryBundle(query), html, embedder)

In [27]:
nodes

[NodeWithScore(node=TextNode(id_='fe93d482-8a2f-4a3c-bf53-75ce49ba31be', embedding=None, metadata={'score': 1.2890889229649445}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='9ce69d35-2620-48e6-a73d-d585d2975fd0', node_type=<ObjectType.DOCUMENT: '4'>, metadata={}, hash='8af57a21377c604c1c5b708a80f8f438ebaf961042270e0ce7a1b76813f89795'), <NodeRelationship.PREVIOUS: '2'>: RelatedNodeInfo(node_id='4257bbc9-b175-4675-826c-a0d1daeca369', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='9904dd3263b8c7b94cdc2acc3887d3b27ae97d96ef066b921878ae18bff3b240'), <NodeRelationship.NEXT: '3'>: RelatedNodeInfo(node_id='feb65d8e-aa44-4ec3-a65e-f4473b8bea44', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='e793640a273c09ae695c4509dc95d929e220de6dbe67aae71a4a79ec2ad35980')}, text='<body class="logged-out env-production page-responsive header-overlay home-campaign" style="word-wrap: break-word;" xpath="/html/b