# How does a Knowledge Graph (KG) improve RAG?
- Making the retrieval `structure-aware`
- Enabling `context augmentation`
- Enabling `fine-grained-access-control`
- Combining `vector + graph search`

In [None]:
import yaml, logging, sys, os
from pyvis.network import Network
from IPython.display import display
from llama_index.llms import AzureOpenAI
from llama_index.llm_predictor import LLMPredictor
from llama_index import set_global_service_context
from llama_index.graph_stores import Neo4jGraphStore
from llama_index.vector_stores import Neo4jVectorStore
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.query_engine import KnowledgeGraphQueryEngine

from llama_index import (
                        StorageContext,
                        VectorStoreIndex,
                        KnowledgeGraphIndex,
                        SimpleDirectoryReader, 
                        load_graph_from_storage,
                        load_index_from_storage,
                        ServiceContext,
                        PromptHelper
                        )

# Send INFO-and-above log records to stdout so library progress messages
# appear in the cell output.
# force=True removes handlers already attached to the root logger, so the
# configuration always takes effect and re-running this cell still leaves a
# single stdout handler. A separate addHandler(StreamHandler(sys.stdout)) call
# is deliberately not used: combined with basicConfig it prints every record
# twice.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    force=True,
)


In [None]:
# Load the credentials mapping that the following cells index by key.
# NOTE(review): the file name is spelled 'cadentials.yaml' -- presumably a
# typo for 'credentials.yaml'. It is kept unchanged because the file on disk
# may carry this exact name; confirm and rename both together.
# safe_load only constructs plain YAML types (dict/list/str/numbers), which is
# all a key/value credentials file needs. The explicit encoding keeps parsing
# independent of the platform default.
with open('cadentials.yaml', encoding='utf-8') as f:
    credentials = yaml.safe_load(f)

# Configuring LLMs

In [None]:
# Embedding model shared by every index built in this notebook.
embedding_llm = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

# Azure OpenAI chat model; all connection settings come from `credentials`.
llm = AzureOpenAI(
    model=credentials["AZURE_ENGINE"],
    deployment_name=credentials["AZURE_DEPLOYMENT_ID"],
    api_key=credentials["AZURE_OPENAI_API_KEY"],
    api_version=credentials["AZURE_OPENAI_API_VERSION"],
    azure_endpoint=credentials["AZURE_OPENAI_API_BASE"],
)
chat_llm = LLMPredictor(llm)

# Prompt sizing and chunking parameters handed to the service context below.
prompt_helper = PromptHelper(
    context_window=4096,
    num_output=256,
    chunk_overlap_ratio=0.1,
    chunk_size_limit=None,
)
text_splitter = TokenTextSplitter(
    chunk_size=1024,
    chunk_overlap=20,
    separator=" ",
    backup_separators=["\n"],
)

# Bundle the pieces and register the bundle as the process-wide default.
service_context = ServiceContext.from_defaults(
    llm_predictor=chat_llm,
    embed_model=embedding_llm,
    text_splitter=text_splitter,
    prompt_helper=prompt_helper,
)
set_global_service_context(service_context)

# Neo4j used as a vector store.
# NOTE(review): the 4th positional argument (384) is presumably the embedding
# dimension of bge-small-en-v1.5 -- confirm against the Neo4jVectorStore
# signature.
neo4j_db = Neo4jVectorStore(
    credentials["NEO4J_USERNAME"],
    credentials["NEO4J_PASSWORD"],
    credentials["NEO4J_URI"],
    384,
)

# Neo4j used as a graph store (same server and credentials, database 'neo4j').
neo4j_store = Neo4jGraphStore(
    url=credentials["NEO4J_URI"],
    username=credentials["NEO4J_USERNAME"],
    password=credentials["NEO4J_PASSWORD"],
    database="neo4j",
)


# Data Pipeline

In [None]:
# Read everything under ./data/new_papers and report how many documents
# were produced.
paper_reader = SimpleDirectoryReader("./data/new_papers")
documents = paper_reader.load_data()
print(len(documents))

# Method 01 : Graph Vector Index

### Build Index

In [None]:
# Plain vector index: built and persisted on the first run, reloaded from
# disk on later runs.
vector_persist_dir = './db/06/method-01/vector/'

if os.path.exists(vector_persist_dir):
    storage_context_vector = StorageContext.from_defaults(
        persist_dir=vector_persist_dir
    )
    vec_index = load_index_from_storage(storage_context=storage_context_vector)
    print("Loading Vector Index ...")
else:
    vec_index = VectorStoreIndex.from_documents(
        documents,
        service_context=service_context,
    )
    vec_index.storage_context.persist(persist_dir=vector_persist_dir)
    print("Saving Vector Index ...")

In [None]:
# Vector index backed by Neo4j: built and persisted on the first run,
# reloaded on later runs.
# neo4j_db is a Neo4jVectorStore, so it has to be wired in as `vector_store`.
# Passing it as `graph_store` leaves VectorStoreIndex on its default vector
# store, which makes this index a duplicate of vec_index and writes no
# embeddings to Neo4j.
# NOTE(review): a ./db/06/method-01/graph/ directory persisted while the store
# was passed as `graph_store` does not imply Neo4j holds the embeddings --
# delete that directory once so the index is rebuilt.
if not os.path.exists('./db/06/method-01/graph/'):
    storage_context_graph = StorageContext.from_defaults(vector_store=neo4j_db)

    graph_index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context_graph,
    )

    graph_index.storage_context.persist(persist_dir='./db/06/method-01/graph/')
    print("Saving Graph Index ...")
else:
    storage_context_graph = StorageContext.from_defaults(
        vector_store=neo4j_db,
        persist_dir='./db/06/method-01/graph/',
    )
    graph_index = load_index_from_storage(storage_context=storage_context_graph)
    print("Loading Graph Index ...")

### Querying

In [None]:
# One default query engine per index so the same question can be put to both.
query_engine_vector, query_engine_graph = (
    vec_index.as_query_engine(),
    graph_index.as_query_engine(),
)

In [None]:
# Question sent to both query engines in the next cell.
query = "What is ToolFormer ?"

In [None]:
# Ask each engine the same question and keep only the string form of the
# answer.
answer_vector = query_engine_vector.query(query)
response_vector = str(answer_vector)

answer_graph = query_engine_graph.query(query)
response_graph = str(answer_graph)

In [None]:
# Show the two answers one after the other for comparison.
print(f"vector db response : {response_vector}")
print(f"graph db response : {response_graph}")

# Method 02 : Text2Cypher (Knowledge Graph Index)

In [None]:
# Knowledge-graph index stored in Neo4j through neo4j_store: built and
# persisted on the first run, reloaded from ./db/06/method-02/ on later runs.
if not os.path.exists('./db/06/method-02/'):
    storage_context = StorageContext.from_defaults(graph_store=neo4j_store)
    kg_index = KnowledgeGraphIndex.from_documents(
        documents=documents,
        max_triplets_per_chunk=10,
        service_context=service_context,
        storage_context=storage_context,
        include_embeddings=True,
        verbose=True,
    )

    kg_index.storage_context.persist(persist_dir='./db/06/method-02/')

else:
    storage_context = StorageContext.from_defaults(
        graph_store=neo4j_store,
        persist_dir='./db/06/method-02/',
    )
    # load_index_from_storage restores a single persisted index.
    # load_graph_from_storage is the loader for ComposableGraph objects: it
    # requires a root_id and does not accept `documents`, so it cannot reload
    # this index. The documents are not needed here because nothing is
    # re-extracted when loading.
    kg_index = load_index_from_storage(
        storage_context=storage_context,
        service_context=service_context,
        max_triplets_per_chunk=10,
        include_embeddings=True,
    )

In [None]:
# Query engine for the "Text2Cypher" method, built over the graph-backed
# storage context from the previous cell.
# NOTE(review): chat_llm is an LLMPredictor wrapper rather than the raw LLM --
# confirm that KnowledgeGraphQueryEngine actually consumes an `llm` keyword.
nl2kg_query_engine = KnowledgeGraphQueryEngine(
    service_context=service_context,
    storage_context=storage_context,
    llm=chat_llm,
)

In [None]:
# NOTE(review): `%ngql` looks like a NebulaGraph notebook magic, but no
# extension providing it is loaded in the cells shown and the stores used
# above are Neo4j -- presumably a leftover; confirm before running.
%ngql SHOW HOSTS