In [None]:
import os
import re
import uuid
from getpass import getpass
from typing import Generator, List

from bs4 import BeautifulSoup, Doctype, NavigableString, Tag

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryByteStore
from langchain.utils.html import PREFIXES_TO_IGNORE_REGEX
from langchain_community.document_loaders import RecursiveUrlLoader
from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter


In [None]:
# Provide the OpenAI API key without hardcoding it. Keep an already-exported
# OPENAI_API_KEY and only prompt when it is missing or empty.
# (`%env OPENAI_API_KEY=` set the variable to an empty string, which clobbered
# any key exported in the launching shell.)
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [None]:
# Chat model that generates the answers in the RAG chain built at the end of
# this notebook (pinned to a dated model snapshot).
llm = ChatOpenAI(
    model="gpt-3.5-turbo-0125",
)

In [None]:
def metadata_extractor(html: str, url: str) -> dict:
    """Build the metadata dict stored with one crawled XRootD documentation page.

    Cover-page details are read from paragraph classes (presumably produced by
    the Word export of the manuals):

    - ``StylePalatino24ptBoldCentered`` paragraphs form the title.
    - The first three non-empty ``StylePalatino10ptBoldCentered`` paragraphs
      are, in order, the release date, the software version and the author.

    The first non-empty ``<p>`` sibling after the first ``<h1>`` is used as the
    description.

    Args:
        html: Raw HTML of the page.
        url: Address the page was fetched from (stored as ``source``).

    Returns:
        A dict with string values for ``source``, ``title``,
        ``documentation_release_date``, ``xrootd_software_version``,
        ``documentation_author``, ``description`` and ``language``.
        ``title`` falls back to the ``<title>`` tag, and ``description`` falls
        back to ``<meta name="description">``, when the page has no cover or
        intro text. ``language`` defaults to ``"en"``.
    """
    soup = BeautifulSoup(html, 'lxml')

    title_tag = soup.find("title")
    description_tag = soup.find("meta", attrs={"name": "description"})
    html_tag = soup.find("html")

    # Non-empty, stripped cover-page paragraphs, in document order.
    title_parts = [
        p.text.strip()
        for p in soup.find_all('p', class_='StylePalatino24ptBoldCentered')
        if p.text.strip()
    ]
    doc_meta = [
        p.text.strip()
        for p in soup.find_all('p', class_='StylePalatino10ptBoldCentered')
        if p.text.strip()
    ]
    release_date = doc_meta[0] if len(doc_meta) > 0 else ""
    software_version = doc_meta[1] if len(doc_meta) > 1 else ""
    author = doc_meta[2] if len(doc_meta) > 2 else ""

    # Prefer the cover-page title. The previous conditional used it only when a
    # <title> tag existed (yielding a bare "XrootD " for pages without cover
    # paragraphs) and could never reach its own <title> fallback.
    # Title lines are joined with a space so they do not run together.
    if title_parts:
        title = "XrootD " + " ".join(title_parts)
    elif title_tag is not None:
        title = title_tag.get_text()
    else:
        title = ""

    # First non-empty <p> among the tag siblings that follow the first <h1>.
    introduction = ""
    h1_tag = soup.find('h1')
    if h1_tag is not None:
        for sibling in h1_tag.find_next_siblings():
            if sibling.name == 'p' and sibling.text.strip():
                introduction = sibling.text.strip()
                break

    if introduction:
        description = introduction
    elif description_tag is not None:
        description = description_tag.get("content", "")
    else:
        description = ""

    return {
        "source": url,
        "title": title,
        "documentation_release_date": release_date,
        "xrootd_software_version": software_version,
        "documentation_author": author,
        "description": description,
        "language": html_tag.get("lang", "en") if html_tag is not None else "en",
    }

# File extensions whose links the crawler's link_regex must not capture:
# stylesheets, scripts, images, data/archives and binary documents.
SUFFIXES_TO_IGNORE = (
    ".css", ".js", ".ico",
    ".png", ".jpg", ".jpeg", ".gif", ".svg",
    ".csv", ".bz2", ".zip",
    ".epub", ".pdf", ".pptx",
)
# Zero-width negative lookahead: it fails at any position where the remaining
# text is one of the suffixes above immediately followed by '#', "'" or '"'
# (i.e. the end of an href value).
SUFFIXES_TO_IGNORE_REGEX = "(?!{})".format(
    "|".join(re.escape(suffix) + r"[\#'\"]" for suffix in SUFFIXES_TO_IGNORE)
)

# Inline ``style`` attribute that XRootD's pages put on the bordered <div>
# used for code/config listings (compared after whitespace normalization).
_XROOTD_CODE_DIV_STYLE = (
    "border:solid windowtext 1.0pt;padding:1.0pt 4.0pt 1.0pt 4.0pt"
)


def _markdown_row(cells) -> str:
    """Format an iterable of cell strings as one Markdown table row."""
    return "| " + " | ".join(cells) + " |\n"


def _table_to_markdown(table: "Tag") -> Generator[str, None, None]:
    """Yield a Markdown rendering of ``table`` wrapped in ``[table]`` markers.

    The markers are what the later table-splitting cell cuts on, so tables end
    up as standalone documents.

    - Header cells come from ``<thead><th>`` when present.
    - When the table has no ``<tbody>``, the ``<td>`` cells of the first
      ``<tr>`` are emitted as a header row, and its following sibling rows as
      data rows.
    """
    yield "[table]"
    thead = table.find("thead")
    if isinstance(thead, Tag):
        headers = thead.find_all("th")
        if headers:
            yield _markdown_row(header.get_text() for header in headers)
            yield _markdown_row("----" for _ in headers)

    tbody = table.find("tbody")
    if isinstance(tbody, Tag):
        rows = tbody.find_all("tr")
    else:
        first_row = table.find("tr")
        if first_row is None:
            # A table with no rows at all previously raised AttributeError
            # here and aborted extraction of the whole page.
            rows = []
        else:
            headers = first_row.find_all("td")
            yield _markdown_row(header.get_text(strip=True) for header in headers)
            yield _markdown_row("----" for _ in headers)
            rows = first_row.find_next_siblings("tr")

    for row in rows:
        yield _markdown_row(cell.get_text(strip=True) for cell in row.find_all("td"))
    yield "\n\n[table]"


def langchain_docs_extractor(html: str) -> str:
    """Convert one crawled HTML page into Markdown-like text for chunking.

    Conversions applied:

    - Headings become ``#`` prefixes.
    - Links, images, bold/italic and inline code use Markdown syntax.
    - ``<pre><code>`` and XRootD's bordered listing ``<div>`` become fenced
      code blocks.
    - Tables are rendered between ``[table]`` markers.

    Before conversion, navigation chrome, scripts/styles and Word
    table-of-contents paragraphs are removed.

    Args:
        html: Raw HTML of the page.

    Returns:
        The extracted text, with runs of blank lines collapsed to one and
        leading/trailing whitespace stripped.
    """
    soup = BeautifulSoup(html, "lxml")

    # Remove all the tags that are not meaningful for the extraction.
    for tag in soup.find_all(["nav", "footer", "aside", "script", "style"]):
        tag.decompose()

    # Word-generated table-of-contents paragraphs; presumably they only repeat
    # headings that appear again in the body.
    for class_name in ('MsoToc1', 'MsoToc2', 'MsoToc3', 'MsoToc4'):
        for element in soup.find_all('p', class_=class_name):
            element.decompose()

    def get_text(tag: Tag) -> Generator[str, None, None]:
        """Recursively yield Markdown fragments for the children of ``tag``."""
        for child in tag.children:
            # Doctype is a NavigableString subclass, so filter it out first.
            if isinstance(child, Doctype):
                continue
            if isinstance(child, NavigableString):
                yield child
                continue
            if not isinstance(child, Tag):
                continue

            name = child.name
            if name in ("h1", "h2", "h3", "h4", "h5", "h6"):
                yield f"{'#' * int(name[1:])} {child.get_text()}\n\n"
            elif name == "a":
                yield f"[{child.get_text(strip=False)}]({child.get('href')})"
            elif name == "img":
                yield f"![{child.get('alt', '')}]({child.get('src')})"
            elif name in ("strong", "b"):
                yield f"**{child.get_text(strip=False)}**"
            elif name in ("em", "i"):
                yield f"_{child.get_text(strip=False)}_"
            elif name == "br":
                yield "\n"
            elif name == "code":
                parent = child.find_parent()
                if parent is not None and parent.name == "pre":
                    # Fence language comes from a "language-xxx" class on <pre>.
                    classes = parent.attrs.get("class", [])
                    language = next(
                        (c for c in classes if re.match(r"language-\w+", c)),
                        None,
                    )
                    language = "" if language is None else language.split("-")[1]

                    token_lines = child.find_all("span", class_="token-line")
                    if token_lines:
                        # Prism-highlighted markup: one span per source line.
                        code_content = "\n".join(
                            "".join(token.get_text() for token in line.find_all("span"))
                            for line in token_lines
                        )
                    else:
                        # Plain <pre><code>: previously this produced an empty
                        # fence and the code itself was lost.
                        code_content = child.get_text()
                    yield f"```{language}\n{code_content}\n```\n\n"
                else:
                    yield f"`{child.get_text(strip=False)}`"
            elif name == "p":
                yield from get_text(child)
                yield "\n\n"
            elif name == "ul":
                for li in child.find_all("li", recursive=False):
                    yield "- "
                    yield from get_text(li)
                    yield "\n\n"
            elif name == "ol":
                for number, li in enumerate(child.find_all("li", recursive=False), start=1):
                    yield f"{number}. "
                    yield from get_text(li)
                    yield "\n\n"
            elif name == "div" and "tabs-container" in child.attrs.get("class", [""]):
                tabs = child.find_all("li", {"role": "tab"})
                tab_panels = child.find_all("div", {"role": "tabpanel"})
                for tab, tab_panel in zip(tabs, tab_panels):
                    yield f"{tab.get_text(strip=True)}\n"
                    yield from get_text(tab_panel)
            elif name == "div" and " ".join(
                child.attrs.get("style", "").split()
            ) == _XROOTD_CODE_DIV_STYLE:
                # XRootD listing box: one output line per direct child node.
                code = "".join(node.text + "\n" for node in child.contents)
                yield f"```\n{code}\n```"
                yield "\n\n"
            elif name == "table":
                yield from _table_to_markdown(child)
            elif name == "button":
                continue
            else:
                yield from get_text(child)

    # NOTE(review): '\xa0' is deleted rather than turned into a space, which can
    # glue words separated only by a single &nbsp; -- kept for compatibility.
    joined = (
        "".join(get_text(soup))
        .replace('\xa0', '')
        .replace('\r\n', ' ')
        .replace('****', '')  # drop empty (or merge adjacent) bold markers
    )
    return re.sub(r"\n\n+", "\n\n", joined).strip()

def simple_extractor(html: str) -> str:
    """Return a page's visible text with Word table-of-contents paragraphs removed.

    Any whitespace surrounding a run of newlines collapses to a single
    newline. The result is then stripped and non-breaking spaces are deleted.
    """
    soup = BeautifulSoup(html, "lxml")
    toc_classes = ('MsoToc1', 'MsoToc2', 'MsoToc3', 'MsoToc4')
    # The generator is lazy: each class is searched only after the matches of
    # the previous class have already been decomposed.
    for toc_paragraph in (
        match
        for toc_class in toc_classes
        for match in soup.find_all('p', class_=toc_class)
    ):
        toc_paragraph.decompose()

    collapsed = re.sub(r'[\s]*\n+[\s]*', '\n', soup.text)
    return collapsed.strip().replace('\xa0', '')

# Crawl the XRootD documentation site starting from the docs index page,
# following same-site links up to max_depth=2, and convert each page to
# Markdown-like text with langchain_docs_extractor.
# Single-page alternative for debugging:
#   https://xrootd.slac.stanford.edu/doc/dev55/xrd_config.htm
loader = RecursiveUrlLoader(
    url="https://xrootd.slac.stanford.edu/docs.html",
    max_depth=2,
    metadata_extractor=metadata_extractor,
    prevent_outside=True,
    use_async=False,
    timeout=600,
    base_url="https://xrootd.slac.stanford.edu/",
    # Drop trailing / to avoid duplicate pages.
    link_regex=(
        f"href=[\"']{PREFIXES_TO_IGNORE_REGEX}((?:{SUFFIXES_TO_IGNORE_REGEX}.)*?)"
        r"(?:[\#'\"]|\/[\#'\"])"
    ),
    extractor=langchain_docs_extractor,
    check_response_status=True,
)
docs = loader.load()
# Fail fast: every downstream cell is meaningless without crawled pages.
assert docs, "crawler returned no documents"
len(docs)

In [None]:
# Split every page on Markdown headings (# .. ####). Headings are kept in the
# text, and each section inherits the page metadata merged with the heading
# metadata attached by the splitter.
headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    ("###", "Header 3"),
    ("####", "Header 4"),
]

markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=headers_to_split_on,
    strip_headers=False,
)
md_header_splits: List[Document] = []
for page in docs:
    for section in markdown_splitter.split_text(page.page_content):
        section.metadata = page.metadata.copy() | section.metadata
        md_header_splits.append(section)

In [None]:
# Cut each section at the "[table]" markers emitted by the extractor, so that
# tables become standalone documents. Whitespace-only fragments (e.g. text
# before a leading marker, or after a trailing one) are dropped instead of
# becoming empty Documents.
table_splits: List[Document] = [
    Document(page_content=fragment, metadata=md_doc.metadata.copy())
    for md_doc in md_header_splits
    for fragment in md_doc.page_content.split("[table]")
    if fragment.strip()
]


In [None]:
# Separate fenced code blocks (```...```) from the surrounding prose, so a code
# listing is not merged into the middle of a text chunk. Each match is either
# a complete fenced block, or the run of text up to the next fence / the end.
# Compiled once instead of on every loop iteration.
code_fence_pattern = re.compile(r'(```.*?```)|(.+?)(?=`{3}|\Z)', re.DOTALL)

code_splits: List[Document] = [
    # group(0) is the whole match, i.e. whichever alternative matched.
    Document(page_content=match.group(0), metadata=table_doc.metadata.copy())
    for table_doc in table_splits
    for match in code_fence_pattern.finditer(table_doc.page_content)
    if match.group(0).strip()  # skip whitespace-only gaps between fences
]

In [None]:
# Chunk the code/prose-separated documents from the previous cell. Splitting
# table_splits here left code_splits computed but unused, so code listings
# could be cut in half or glued onto prose. The splitter's default
# chunk_size is used, with 400 characters of overlap between chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=400)
final_splits = text_splitter.split_documents(code_splits)

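Summarize each chunk with the LLM. Each summary is embedded as an additional ("side") vector that points back to its original chunk through a shared `doc_id` (multi-vector retrieval).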

In [None]:
# Summarization chain: Document -> prompt -> chat model -> plain string.
chain = (
    {"doc": lambda x: x.page_content}
    | ChatPromptTemplate.from_template("Summarize the following document:\n\n{doc}")
    # Allow a few retries. The chain is run with chain.batch(max_concurrency=10),
    # and batch re-raises the first failure by default. With max_retries=0, a
    # single transient rate-limit/API error aborts the whole batch and discards
    # every summary already produced.
    | ChatOpenAI(max_retries=2)
    | StrOutputParser()
)

In [None]:
# One summary string per chunk, in the same order as final_splits
# (at most 10 requests in flight).
summaries = chain.batch(final_splits, config={"max_concurrency": 10})

In [None]:
# Multi-vector retrieval setup. The Chroma collection indexes the child texts
# (summaries and raw chunks), each tagged with a "doc_id" metadata key. The
# in-memory byte store holds the parent document for every doc_id.
id_key = "doc_id"
store = InMemoryByteStore()
vectorstore = Chroma(
    collection_name="summaries",
    embedding_function=OpenAIEmbeddings(),
)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    byte_store=store,
    id_key=id_key,
)
# One random identifier per chunk in final_splits.
doc_ids = [str(uuid.uuid4()) for _ in range(len(final_splits))]

In [None]:
# Summary documents carry the metadata of the chunk they summarize, plus that
# chunk's doc_id, so a summary hit can be mapped back to its parent.
summary_docs = [
    Document(
        page_content=summary_text,
        metadata={**final_splits[idx].metadata, id_key: doc_ids[idx]},
    )
    for idx, summary_text in enumerate(summaries)
]

In [None]:
# Index the summaries, and store each parent chunk under the doc_id that its
# summary carries. doc_ids were generated one per chunk of final_splits, so the
# parents must be final_splits. Pairing them with `docs` (whole crawled pages)
# mapped ids to unrelated pages and zip() silently truncated the list.
retriever.vectorstore.add_documents(summary_docs)
retriever.docstore.mset(list(zip(doc_ids, final_splits)))

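Also add the original chunks to the vectorstore, tagged with the same `doc_id` as their summaries, so a query can match either the summary or the raw chunk text.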

In [None]:
# Tag each raw chunk (in place) with the doc_id its summary carries, then
# index the chunks in the same Chroma collection as the summaries.
for chunk, chunk_id in zip(final_splits, doc_ids):
    chunk.metadata[id_key] = chunk_id
retriever.vectorstore.add_documents(final_splits)

In [None]:
# Sanity check: direct MMR search over the Chroma child index (summaries + chunks).
retrieved = vectorstore.max_marginal_relevance_search(query="How can I configure xrootd?")
retrieved

Regular embedding: index the chunks directly in a FAISS vectorstore (no summaries).

In [None]:
# Embed the chunks themselves into a FAISS index.
# NOTE: this rebinds `vectorstore`, which previously named the Chroma index.
vectorstore = FAISS.from_documents(
    final_splits,
    OpenAIEmbeddings(),
)

In [None]:
# Persist the FAISS index to ../embeddings (relative to the kernel's working directory).
vectorstore.save_local(folder_path="../embeddings")

In [None]:
query = "How can I configure xrootd?"
# MMR picks k diverse results out of the fetch_k nearest candidates. With
# fetch_k == k there is nothing to choose from: it just returns the k nearest
# hits (reordered). Fetch a larger candidate pool so diversity has an effect.
retrieved = vectorstore.max_marginal_relevance_search(query, k=3, fetch_k=20)
retrieved

In [None]:

# Retrieve (MMR over the FAISS index) and generate using the relevant
# snippets of the XRootD documentation.
retriever = vectorstore.as_retriever(search_type="mmr")

# Prompt fixes:
# - typo "Do now let user know" corrected;
# - XRootD is the eXtended ROOT Daemon, not "Request";
# - the fallback instruction no longer contradicts the "use your existing
#   knowledge" line.
template = """You are an assistant for question-answering tasks.
You will be asked questions about XRootD, or eXtended ROOT Daemon.
Use the following pieces of retrieved context to answer the question.
Use your existing knowledge to answer the question if the retrieved context
does not contain useful information.
Do not mention the context to the user. Do not let the user know you are given a retrieved context.
If neither the context nor your own knowledge answers the question, just say: Sorry, I don't know.

{context}

Question: {question}

Answer:"""
custom_rag_prompt = PromptTemplate.from_template(template)


def format_docs(docs: List[Document]) -> str:
    """Join the page contents of retrieved documents, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)


# question -> {retrieved context, question} -> prompt -> llm -> answer string
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | custom_rag_prompt
    | llm
    | StrOutputParser()
)

In [None]:
# Ask a sample question end-to-end through the RAG chain.
rag_chain.invoke(input="What is oss in xrootd?")