## Start with AstraDB Vector Store

In [None]:
from pathlib import Path

from dotenv import find_dotenv, load_dotenv

# The original .env location is specific to one developer's machine. Keep using
# it when it exists; otherwise fall back to the nearest .env found by searching
# upward from the current working directory, so the notebook is portable.
_author_dotenv = Path(
    "/Users/eric.pinzur/src/github.com/langchain-ai/langchain-datastax/libs/astradb/.env"
)
load_dotenv(_author_dotenv if _author_dotenv.is_file() else find_dotenv(usecwd=True))

In [None]:
import getpass
import os
import json
from glob import glob
from langchain_core.documents import Document
from langchain_astradb import AstraDBVectorStore
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_openai import OpenAIEmbeddings

In [None]:
# Prompt only for credentials that are still missing, so values already present
# in the environment (for example, loaded by the earlier load_dotenv call) are
# kept instead of being overwritten by a fresh prompt.
for _credential in (
    "ASTRA_DB_API_ENDPOINT",
    "ASTRA_DB_APPLICATION_TOKEN",
    "OPENAI_API_KEY",
):
    if not os.environ.get(_credential):
        os.environ[_credential] = getpass.getpass(f"{_credential} = ")

In [None]:
def debug_chunk(chunk: Document, header=None):
    """Print a chunk's id, metadata and text for inspection.

    Args:
        chunk: Document whose ``id``, ``metadata`` and ``page_content`` are
            printed. ``metadata`` is rendered with ``json.dumps``, so it must
            be JSON-serializable.
        header: Optional title printed on the first line; defaults to
            ``"Chunk id, metadata and text"``.
    """
    if header is None:
        header = "Chunk id, metadata and text"
    metadata_json = json.dumps(chunk.metadata, indent=4)
    # No backslash before the page content: "\{" is an invalid escape sequence
    # and previously printed a stray "\" in front of the chunk text.
    print(f"{header}:\n\n'{chunk.id}'\n\n{metadata_json}\n\n{chunk.page_content}\n\n")

In [None]:
# Markdown header markers to split on, paired with the metadata key used for
# each level ("header_1" is read back later via chunk.metadata).
headers_to_split_on = [
    ("#", "header_1"),
    ("##", "header_2"),
]

markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=headers_to_split_on,
    return_each_line=False,
    strip_headers=False,
)

# Sorted so that file indexes, and therefore chunk ids, are stable across runs.
file_paths = sorted(glob(pathname="datasets/legal/*.md"))

all_chunks = []
for file_index, file_path in enumerate(file_paths):
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(file_path, "r", encoding="utf-8") as file:
        chunks = markdown_splitter.split_text(file.read())

    for chunk_index, chunk in enumerate(chunks):
        # Record the source file and assign a deterministic id to every chunk.
        chunk.metadata["file_path"] = file_path
        chunk.id = f"file_{file_index}_chunk_{chunk_index}"

        all_chunks.append(chunk)

print(f"Split the {len(file_paths)} files into {len(all_chunks)} chunks.\n")

# Show the fourth chunk as an example, or the last one when fewer than four
# chunks exist; skip the example entirely when nothing was loaded instead of
# raising IndexError.
if all_chunks:
    debug_chunk(
        all_chunks[min(3, len(all_chunks) - 1)],
        header="Example chunk id, metadata, and content",
    )

In [None]:
# Vector store backed by the "astra_graph_upgrade" collection. The endpoint
# and token are read from the environment (None when a variable is unset).
vector_store = AstraDBVectorStore(
    collection_name="astra_graph_upgrade",
    embedding=OpenAIEmbeddings(),
    api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
    token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
)

# Store every chunk produced by the splitting step; the return value is kept
# in `added_ids`.
added_ids = vector_store.add_documents(all_chunks)

In [None]:
# Question used for every retrieval in this notebook, so the results of the
# different retrievers can be compared.
query = """
What are the consequences if the Developer for
the AI-Powered Customer Support Tool fails to
meet the Phase 2 delivery date?
"""

# Retriever over the plain vector store using the "similarity" search type
# with k=3.
retriever = vector_store.as_retriever(
    search_type="similarity",
    search_kwargs={
        "k": 3,
    },
)
# Keep the results so later cells can compare ids across retrievers.
retrieved_chunks1 = retriever.invoke(query)
for chunk in retrieved_chunks1:
    debug_chunk(chunk)

In [None]:
# Print only the ids of the chunks returned by the vector-store retriever.
for chunk in retrieved_chunks1:
    print(chunk.id)

## Upgrade to AstraDB Graph Vector Store

In [None]:
from langchain_astradb import AstraDBGraphVectorStore

import re
from langchain_community.graph_vectorstores.links import Link
from langchain_community.graph_vectorstores.extractors import KeybertLinkExtractor
from keyphrase_vectorizers import KeyphraseCountVectorizer

In [None]:
# Graph vector store opened on the same "astra_graph_upgrade" collection name
# that the plain vector store uses. The endpoint and token are read from the
# environment (None when a variable is unset).
graph_vector_store = AstraDBGraphVectorStore(
    collection_name="astra_graph_upgrade",
    embedding=OpenAIEmbeddings(),
    api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
    token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
)

In [None]:
# Same "similarity" search with k=3 as before, but issued through the graph
# vector store, so the results can be compared with the plain vector store.
graph_retriever = graph_vector_store.as_retriever(
    search_type="similarity",
    search_kwargs={
        "k": 3,
    },
)
retrieved_chunks2 = graph_retriever.invoke(query)
for chunk in retrieved_chunks2:
    debug_chunk(chunk)

In [None]:
# Print the ids from both retrievals side by side, tab-separated, one rank per
# line. zip() stops at the shorter of the two result lists.
for chunk1, chunk2 in zip(retrieved_chunks1, retrieved_chunks2):
    print(f"{chunk1.id}\t{chunk2.id}")

In [None]:
# Regex patterns used by get_links_for_chunk to find section numbers in the
# chunk text.
#
# A section number such as "1.2" followed by whitespace and bold text, e.g.
# `1.2 **Some Title**`. Captures (number, bold text).
outgoing_section_pattern = r"(\d+\.\d+)\s+\*\*(.*?)\*\*"
# A bold reference of the form `**Section 1.2**`. Captures the number.
incoming_internal_section_pattern = r"\*\*Section\s(\d+\.\d+)\*\*"
# A bold reference of the form `**Title (ABBR), Section 1.2**`.
# Captures (title, abbreviation, number).
incoming_external_section_pattern1 = r"\*\*(.*?)\s\((.*?)\),\sSection\s(\d+\.\d+)\*\*"
# A bold reference of the form `**Section 1.2 of the Title (ABBR)**`.
# Captures (number, title, abbreviation).
incoming_external_section_pattern2 = r"\*\*Section\s(\d+\.\d+)\sof\sthe\s(.*?)\s\((.*?)\)\*\*"

# Keyword-based link extractor. The options below are passed as
# `extract_keywords_kwargs` (presumably forwarded to KeyBERT's keyword
# extraction; verify against the extractor's documentation).
keybert_link_extractor = KeybertLinkExtractor(
    extract_keywords_kwargs={
        "vectorizer": KeyphraseCountVectorizer(stop_words="english"),
        "use_mmr":True,
        "diversity": 0.7
    }
)

def get_links_for_chunk(chunk: Document) -> set[Link]:
    """Build the set of links for one chunk.

    Starts from the links produced by ``keybert_link_extractor`` and adds
    "section" links derived from section numbers found in the chunk text:

    * numbered bold headings in the chunk get a ``direction="in"`` link tagged
      ``"<header_1> <number>"``;
    * ``**Section N.M**`` references get a ``direction="out"`` link with the
      same tag format, using this chunk's own ``header_1`` as the title;
    * references that name another document get a ``direction="out"`` link
      tagged ``"<title> (<abbreviation>) <number>"``.

    Args:
        chunk: Document to analyse; ``metadata["header_1"]`` is used as the
            document title (empty string when absent).

    Returns:
        The extractor's link set with the section links added to it.
    """
    title = chunk.metadata.get("header_1", "")

    links = keybert_link_extractor.extract_one(chunk)
    text = chunk.page_content

    # Sections that appear as numbered headings in this chunk.
    links.update(
        Link("section", direction="in", tag=f"{title} {number}")
        for number, _heading in re.findall(outgoing_section_pattern, text)
    )

    # References to sections of the same document.
    links.update(
        Link("section", direction="out", tag=f"{title} {number}")
        for number in re.findall(incoming_internal_section_pattern, text)
    )

    # References of the form "Title (ABBR), Section N.M".
    links.update(
        Link("section", direction="out", tag=f"{ref_title} ({ref_abbr}) {number}")
        for ref_title, ref_abbr, number in re.findall(
            incoming_external_section_pattern1, text
        )
    )

    # References of the form "Section N.M of the Title (ABBR)".
    links.update(
        Link("section", direction="out", tag=f"{ref_title} ({ref_abbr}) {number}")
        for number, ref_title, ref_abbr in re.findall(
            incoming_external_section_pattern2, text
        )
    )

    return links

In [None]:
# Keep calling upgrade_chunks until a pass reports that zero chunks were
# updated, printing the count after every pass that did update chunks.
while (
    updated_chunks := graph_vector_store.upgrade_chunks(
        link_function=get_links_for_chunk
    )
) != 0:
    print(f"Added links to {updated_chunks} chunks.")

print("Upgrade Complete!")

In [None]:
# Rebind graph_retriever to use the "mmr_traversal" search type on the graph
# vector store, run after the link-upgrade step.
graph_retriever = graph_vector_store.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={
        "k": 3,
        "fetch_k": 20, # initial starting chunks
        "depth": 2,  # presumably the maximum link-traversal depth; verify against the store's docs
    },
)
retrieved_chunks3 = graph_retriever.invoke(query)
for chunk in retrieved_chunks3:
    debug_chunk(chunk)

In [None]:
# Print the ids from all three retrievals side by side, tab-separated, one
# rank per line. zip() stops at the shortest of the three result lists.
for chunk1, chunk2, chunk3 in zip(retrieved_chunks1, retrieved_chunks2, retrieved_chunks3):
    print(f"{chunk1.id}\t{chunk2.id}\t{chunk3.id}")