# In [14]:
from llama_index.core import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from llama_index.core.node_parser import LangchainNodeParser
from bs4 import BeautifulSoup, NavigableString
import cohere
from selenium.webdriver.common.by import By
from selenium import webdriver
from llama_index.core.schema import NodeWithScore, QueryBundle

def generate_xpath(element, path=""): # used to generate dict nodes
    """Return the absolute XPath of a BeautifulSoup element.

    Walks from ``element`` up its ``parent`` chain until reaching a node
    whose ``parent`` is ``None`` (the document object itself, which adds no
    path step). Each step is ``/name``. When the parent has several children
    with the same tag name, the step is ``/name[k]``, where ``k`` is the
    1-based position among those same-named siblings.

    Args:
        element: the element to locate. Only ``.name``, ``.parent`` and
            ``.children`` are read.
        path: suffix appended after the computed path. It is kept for
            compatibility with the former recursive signature.

    Returns:
        The XPath string, e.g. ``"/html/body/ul/li[2]"``. Returns ``path``
        unchanged when ``element`` has no parent.
    """
    steps = []
    node = element
    # Iterative walk: avoids hitting the recursion limit on deeply nested pages.
    while node.parent is not None:
        parent = node.parent
        same_name = [sib for sib in parent.children if sib.name == node.name]
        if len(same_name) > 1:
            # Compare by identity. bs4 tag equality is structural, so
            # list.index() would return the first *look-alike* sibling
            # (e.g. two identical <li>) and produce a wrong position.
            position = next(k for k, sib in enumerate(same_name, 1) if sib is node)
            steps.append(f"/{node.name}[{position}]")
        else:
            steps.append(f"/{node.name}")
        node = parent
    return "".join(reversed(steps)) + path
    
def add_xpath_attributes(html_content):
    """Annotate every tag with an ``xpath`` attribute and return the markup.

    The document is parsed with the ``lxml`` parser. Each tag yielded by
    ``find_all(True)`` (all tags, in document order) receives an ``xpath``
    attribute set to ``generate_xpath(tag)``. Any pre-existing ``xpath``
    attribute is overwritten.

    Returns:
        The annotated document serialized back to a string.
    """
    tree = BeautifulSoup(html_content, 'lxml')
    # Keep computation and assignment interleaved per tag. Each
    # generate_xpath call then runs on a tree in which every earlier tag
    # already carries its xpath attribute. Do not precompute in a separate pass.
    for tag in tree.find_all(True):
        computed = generate_xpath(tag)
        tag.attrs['xpath'] = computed
    annotated = str(tree)
    return annotated

def create_nodes_dict(html, only_body=True, max_length=200): # used to generate dict nodes
    '''Collect the attributes of relevant tags as a list of dicts.

    The markup is parsed with ``html.parser`` and walked depth-first from
    ``<body>`` (``only_body=True``) or ``<html>``. A tag contributes one dict
    when it has direct text or at least one HTML attribute. The dict holds:

    * every HTML attribute of the tag (e.g. ``xpath`` if the markup was
      annotated beforehand);
    * ``'text'``: the tag's own non-blank text nodes, stripped and joined
      with single spaces. Text of descendant tags is not included, and the
      key is present only when such text exists;
    * ``'element'``: the tag name.

    Every value is cut to its first ``max_length`` items: characters for
    strings, entries for list-valued attributes.

    Returns:
        The list of dicts in stack-pop order (pre-order, last child first).
        The list is empty when the requested root tag is absent.
    '''
    soup = BeautifulSoup(html, 'html.parser')
    root = soup.body if only_body else soup.html
    if root is None:
        # html.parser adds no implicit <html>/<body>. Without this guard a
        # fragment would crash with AttributeError on ``None``.
        return []
    element_attributes_list = []
    stack = [root]  # only tags are ever pushed
    while stack:
        element = stack.pop()
        element_attrs = dict(element.attrs)
        text_parts = [content.strip() for content in element.contents
                      if isinstance(content, NavigableString) and content.strip()]
        if text_parts:
            # Join with a space so "Click <a>here</a> to go" yields
            # "Click to go" rather than "Clickto go".
            element_attrs['text'] = ' '.join(text_parts)
        if element_attrs:
            element_attrs['element'] = element.name
            element_attrs = {key: value[:max_length] for key, value in element_attrs.items()}
            element_attributes_list.append(element_attrs)
        stack.extend(child for child in element.children if child.name is not None)
    return element_attributes_list

def get_results(cohere, query, html, top_n=5, model="rerank-english-v3.0", rank_fields=None, batch_size=1000):
    """Rerank the page's element dicts against ``query`` and return the best hits.

    Element dicts come from ``create_nodes_dict(html)``. They are sent to
    ``cohere.rerank`` in slices of ``batch_size``, one request per slice,
    and each request returns its own ``top_n`` hits. Per-slice ``'index'``
    values are shifted back to positions in the full element list. The hits
    are then merged, and the overall ``top_n`` by ``relevance_score`` is kept.

    Args:
        cohere: client object exposing ``rerank``. Inside this function the
            name shadows the ``cohere`` module.
        query: query text.
        html: markup to extract elements from.
        top_n: number of results to return.
        model: rerank model name, forwarded as-is.
        rank_fields: fields the reranker should consider, forwarded as-is.
        batch_size: maximum number of documents per rerank request. The
            default of 1000 matches the previously hard-coded value.

    Returns:
        List of result dicts (``.dict()`` of each rerank result), sorted by
        descending ``relevance_score``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    attributes_list = create_nodes_dict(html)
    merged = []
    for offset in range(0, len(attributes_list), batch_size):
        batch = attributes_list[offset:offset + batch_size]
        response = cohere.rerank(model=model, query=query, documents=batch, top_n=top_n, return_documents=True, rank_fields=rank_fields)
        for item in response.results:
            hit = item.dict()
            # Rerank indices are relative to the batch; make them global.
            hit['index'] += offset
            merged.append(hit)
    merged.sort(key=lambda hit: hit['relevance_score'], reverse=True)
    return merged[:top_n]

def match_element(attributes, element_specs):
    """Find the first spec whose ``'xpath'`` equals the element's ``'xpath'``.

    Args:
        attributes: mapping of an element's attributes. It is only read (via
            ``attributes['xpath']``) while there are specs to compare against.
        element_specs: sequence of dicts, each expected to have ``'xpath'``.

    Returns:
        The 0-based index of the first matching spec, or ``None``.

    Raises:
        KeyError: if ``attributes`` or an examined spec lacks ``'xpath'``.
    """
    matches = (
        position
        for position, spec in enumerate(element_specs)
        if attributes['xpath'] == spec['xpath']
    )
    return next(matches, None)

def return_nodes_with_xpath(nodes, results_dict, score):
    """Select the nodes whose HTML contains at least one ranked element.

    Each node's ``text`` is parsed with ``html.parser``. Every tag is matched
    against ``results_dict`` by its ``xpath`` attribute (see
    ``match_element``). A node that contains one or more matches is returned
    once. Its ``metadata['score']`` is set to the highest ``score`` among its
    matched elements.

    Args:
        nodes: objects with ``.text`` (HTML) and a mutable ``.metadata`` dict.
        results_dict: ranked element dicts, each with an ``'xpath'`` key.
        score: relevance scores aligned index-by-index with ``results_dict``.

    Returns:
        The matching nodes in input order, without duplicates.
    """
    returned_nodes = []
    for node in nodes:
        soup = BeautifulSoup(node.text, 'html.parser')
        best = None
        # find_all(True) yields tags only; text nodes have no attributes.
        for tag in soup.find_all(True):
            try:
                indice = match_element(tag.attrs, results_dict)
            except KeyError:
                # The tag (or a spec) has no 'xpath'. This can happen when the
                # text splitter cut through a tag; such tags cannot be matched.
                continue
            if indice is not None and (best is None or score[indice] > best):
                best = score[indice]
        if best is not None:
            node.metadata['score'] = best
            returned_nodes.append(node)
    return returned_nodes

def check_visibility(xpath, driver):
    ''' Return whether the element at ``xpath`` is displayed in ``driver``.

    Looks the element up with ``driver.find_element(By.XPATH, xpath)`` and
    returns its ``is_displayed()`` result. Any exception raised during the
    lookup or the check (e.g. element not found) yields ``False``.
    '''
    try:
        element = driver.find_element(By.XPATH, xpath)
        return element.is_displayed()
    except Exception:
        # Best-effort: any failure counts as "not visible". Catching
        # Exception instead of using a bare except lets KeyboardInterrupt
        # and SystemExit propagate.
        return False
    

def get_nodes_sm(cohere, query, html, driver=None,  top_n=5, model="rerank-english-v3.0", rank_fields=('element', 'placeholder', 'text', 'name')):
    """Return the HTML chunks that contain the elements most relevant to ``query``.

    Steps:
    1. Annotate every tag with an ``xpath`` attribute (``add_xpath_attributes``).
    2. Split the annotated markup into nodes with LangChain's
       ``RecursiveCharacterTextSplitter`` configured for HTML.
    3. Rerank the page's element dicts with Cohere (``get_results``).
    4. If a ``driver`` is given, drop hits whose XPath is not displayed
       (``check_visibility``).
    5. Wrap the nodes that contain a remaining hit as ``NodeWithScore``,
       using ``node.metadata['score']`` set by ``return_nodes_with_xpath``.

    Args:
        cohere: client exposing ``rerank``.
        query: query string, or a ``QueryBundle`` whose ``query_str`` is used.
        html: raw page markup.
        driver: optional Selenium WebDriver for the visibility filter.
        top_n: number of rerank hits to keep.
        model: rerank model name.
        rank_fields: fields the reranker should consider. The default is an
            immutable tuple and is passed on as a list.

    Returns:
        List of ``NodeWithScore``.
    """
    if isinstance(query, QueryBundle):
        query = query.query_str
    if isinstance(rank_fields, tuple):
        # Tuple default avoids a shared mutable default argument. Downstream
        # code receives a fresh list, as before.
        rank_fields = list(rank_fields)
    html = add_xpath_attributes(html)
    documents = [Document(text=html)]
    splitter = LangchainNodeParser(lc_splitter=RecursiveCharacterTextSplitter.from_language(
            language="html",
        ))
    nodes = splitter.get_nodes_from_documents(documents)
    results = get_results(cohere, query, html, top_n=top_n, model=model, rank_fields=rank_fields)
    results_dict = [r['document'] for r in results]
    score = [r['relevance_score'] for r in results]
    if driver:
        # Filter both lists in lockstep. Calling list.remove() while
        # iterating the same list would skip the hit right after each
        # removal, so that hit would never be checked.
        kept = [(doc, s) for doc, s in zip(results_dict, score)
                if check_visibility(doc['xpath'], driver)]
        results_dict = [doc for doc, _ in kept]
        score = [s for _, s in kept]
    results_nodes = return_nodes_with_xpath(nodes, results_dict, score)
    return [NodeWithScore(node=node, score=node.metadata['score']) for node in results_nodes]