## Demo: OracleVS extension for no data duplication in vectorstore


Install following packages:

In [None]:
! pip install oracledb langchain_community langchain_ollama

### OracleVS ingestion & vector store retriever implementation

In [1]:
from langchain_community.document_loaders.oracleai import OracleDocLoader
from langchain_core.documents import Document
from langchain_community.document_loaders.oracleai import OracleTextSplitter
from langchain_core.documents import Document
from langchain_community.embeddings.oracleai import OracleEmbeddings
from langchain_community.document_loaders.oracleai import OracleDocLoader
from langchain_core.documents import Document
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_ollama import OllamaEmbeddings
from langchain_ollama import OllamaLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.vectorstores.base import VectorStoreRetriever
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

import hashlib
import sys
import json
import oracledb
#import oraclevs
from typing import (
    List
)


from langchain_community.vectorstores import oraclevs
from langchain_community.vectorstores.oraclevs import OracleVS



In [None]:
def create_hash(input_string, hash_algorithm='sha256'):
    """
    Creates a hash from the given input string.

    Args:
        input_string (str): The string to be hashed.
        hash_algorithm (str, optional): The hashing algorithm to use. Defaults to 'sha256'.

    Returns:
        str: The hexadecimal representation of the hash.
    """
    # Choose the hashing algorithm
    if hash_algorithm == 'sha256':
        hash_object = hashlib.sha256()
    elif hash_algorithm == 'md5':
        hash_object = hashlib.md5()
    elif hash_algorithm == 'sha1':
        hash_object = hashlib.sha1()
    else:
        raise ValueError("Unsupported hash algorithm. Choose from 'sha256', 'md5', 'sha1'.")

    # Update the hash object with the bytes of the input string
    hash_object.update(input_string.encode('utf-8'))

    # Get the hexadecimal representation of the hash
    hash_hex = hash_object.hexdigest()

    return hash_hex

#### Ingestion function

In [3]:

def from_SQL(*args, **kwargs):
    """
    connection_load,    #source db
    sql,                #source db sql to get contents in a field
    embedding,
    connection,         # vector store connection
    table_name,         # vector store target table
    distance_strategy,
    """

    # custom code
    connection_load = kwargs["connection_load"]
    sql = kwargs["sql"]
    connection = kwargs["connection"]
    embedding =  kwargs["embedding"]
    table_name =  kwargs["table_name"]
    distance_strategy =  kwargs["distance_strategy"]

    batch_size = 1000

    documents=[]
    try:
        # Create a cursor object to execute SQL queries
        with connection_load.cursor() as cursor:
            # Execute a simple query to test the connection           
            cursor.execute(sql)
            # Get the column names
            columns = [desc[0] for desc in cursor.description]


            # Fetch the result
            while True:
                results = cursor.fetchmany(batch_size)
                if not results:
                    break
                for row in results:
                    row_dict = dict(zip(columns, row))
                    text = row_dict['TEXT']
                    ID = str(row_dict['ID'])
                    
                    if isinstance(text, oracledb.LOB):
                        text = text.read()
                    elif isinstance(text, str):
                    # text is already a string, no need to read
                        pass
                    else:
                    # Handle other types if necessary
                        text = str(text)

                    truncated_text = (text[:20] if text else '')+'...'
                    #print(f"ID: {row_dict['ID']}, TEXT: {truncated_text}")
                    documents.append(Document(
                        page_content=text,
                        metadata={
                            "id": ID,
                            "query":sql
                            }
                        )
                    )
        
        vectorstore = OracleVS.from_documents(
                    documents=documents,
                    embedding=embedding,
                    client=connection,
                    table_name=table_name,
                    distance_strategy = distance_strategy
                    )
        
        vectorstore.connection_load = connection_load  # Store the connection after vectorstore is created
        vectorstore.sql = sql  # Store the SQL query after vectorstore is created
        vectorstore.connection = connection


        #Update vector store with hash in each text        
        with connection.cursor() as cur:

            sql_update = f"""UPDATE "VECTOR"."{table_name}"  SET "TEXT" = TO_CHAR(ORA_HASH(DBMS_LOB.SUBSTR("TEXT", 4000, 1)))"""

            print(f"HASH UPDATE: \n{sql_update}")
              
            try:
                cur.execute(sql_update)
                connection.commit()
                print("Update with hash successful.")
            except oracledb.Error as e:
                print(f"An error occurred: {e}")
                # Rollback in case of an error
                connection.rollback()
                return None
       
        return vectorstore
    
    except oracledb.Error as e:
        print(f"Error occurred: {e}")
        return None  
     
OracleVS.from_SQL = from_SQL

#### Custom Retriever for referenced chunks

In [4]:
from pydantic import PrivateAttr

class CustomVectorStoreRetriever(VectorStoreRetriever):
    _connection_source: object = PrivateAttr()
    _connection: object = PrivateAttr()
    _sql: str = PrivateAttr()
    _table_name: str = PrivateAttr()
    _top_k: int = PrivateAttr()
    _vector_store: OracleVS = PrivateAttr()

    def __init__(self, vectorstore, search_type="similarity", search_kwargs=None, **kwargs):
        super().__init__(vectorstore=vectorstore, search_type=search_type, search_kwargs=search_kwargs)
        self._connection_source = vectorstore.connection_load  # Store the connection
        self._sql = vectorstore.sql  # Store the SQL query
        self._table_name = vectorstore.table_name
        self._connection = vectorstore.connection
        self._top_k = search_kwargs.get("k") if search_kwargs else None
        self._vector_store = vectorstore
        
    def get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs
    ) -> list[Document]:
        # Your custom implementation here
        return self.sql_get_relevant_documents(query, run_manager=run_manager, **kwargs)

    def sql_get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs) -> list[Document]:
    # Your custom implementation here
        kwargs = self.search_kwargs | kwargs
        document_list: list[Document] = []
        batch_size = 4
        with self._connection.cursor() as cursor:
            # Execute a simple query to test the connection   
            
            embedding: List[float] = []
            embedding = self.vectorstore.embedding_function.embed_query(query)
            embedding_arr = json.dumps(embedding)
            vector_store_query = f"""
            SELECT ID,TEXT, JSON_VALUE(METADATA, '$.id') as S_ID,
            VECTOR_DISTANCE(embedding,:embedding, {oraclevs._get_distance_function(self._vector_store.distance_strategy)}) as DISTANCE
            FROM {self._table_name} 
            ORDER BY DISTANCE ASC
            FETCH FIRST {self._top_k} ROWS ONLY
            """
            print(f"vector_store_query: {vector_store_query}")
           
            cursor.execute(vector_store_query,embedding=embedding_arr)
            
            print("\nCustomVectorStoreRetriever::sql_get_relevant_documents()")
            print(vector_store_query)
            print(self._table_name)
            columns = [desc[0] for desc in cursor.description]
            ids=[]
            #TO_BE_DONE: get from nearest based on distance metrics
            print("From the Vector store: retrieved nearest chunks")
            while True:
                results = cursor.fetchall()
                if not results:
                    break
                for row in results:
                    row_dict = dict(zip(columns, row))
                    text = row_dict['TEXT']
                    ID = str(row_dict['ID'])
                    S_ID = str(row_dict['S_ID'])
                
                    if isinstance(text, oracledb.LOB):
                        text = text.read()
                    elif isinstance(text, str):
                    # text is already a string, no need to read
                        pass
                    else:
                    # Handle other types if necessary
                        text = str(text)

                    truncated_text = (text[:20] if text else '')+'...'

            
                    ids.append(row_dict['S_ID'])
                    print(f"S_ID: {row_dict['S_ID']}, TEXT: {truncated_text}")
            
            # Get Text from the original one
            
            ids_str = ', '.join(f"'{id}'" for id in ids)
            ids_list = ids_str.replace("'", "").split(",")
            decode_order = "DECODE(ID, " + ",".join(f"'{val.strip()}', {i+1}" for i, val in enumerate(ids_list)) + ")"
            
            #query = "SELECT ID, TEXT FROM ("+self._sql + ')'+ f""" WHERE ID IN ({ids_str})"""  

            query = (
                        "SELECT ID, TEXT FROM (" + self._sql + ") "
                        f"WHERE ID IN ({ids_str}) "
                    f"ORDER BY {decode_order}"
            )
            
            print(query)

            with self._connection_source.cursor() as cursor:
                cursor.execute(query)
                print("From the Source Table: retrieved content")
                while True:
                    results = cursor.fetchall()
                    if not results:
                        break
                    for row in results:
                        row_dict = dict(zip(columns, row))
                        text = str(row_dict['TEXT'])
                        ID = str(row_dict['ID'])
                        truncated_text = (text[:20] if text else '')+'...'
                        print(f"ID: {row_dict['ID']}, TEXT: {truncated_text}")
                        #TO_BE_DONE: create an array of documents to be returned  
                        document_list.append(Document(
                            page_content=text,
                            metadata={"source": self._sql }
                        ))
            
        return document_list


  class CustomVectorStoreRetriever(VectorStoreRetriever):


## Test for RAG 
Use "Products" table from https://github.com/oracle-samples/db-sample-schemas/releases/tag/v23.3 to run this example

In [11]:
def rag(vs, query, user_question):
    # RAG usage in a LangChain Pipe
    top_k = 4
    rag_prompt = f"""
    You are an assistant for question-answering tasks, be concise. 
    Use the retrieved DOCUMENTS to answer the user input as accurately as possible. 
    Keep your answer grounded in the facts of the DOCUMENTS and reference the DOCUMENTS where possible. If there ARE DOCUMENTS, you should be able to answer.  
    If there are NO DOCUMENTS, respond only with 'I am sorry, but cannot find relevant sources.'
    """
    llm_model = "llama3.1"

    template = rag_prompt + """\n# DOCUMENTS :\n {context} \n""" + """\n # Question: {question} """

    prompt = PromptTemplate.from_template(template)

    retriever = CustomVectorStoreRetriever(vectorstore=vs, search_kwargs={"k": top_k})

    llm = OllamaLLM(model=llm_model, base_url=url)

    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    answer = chain.invoke(user_question)
    return answer

In [12]:
def rag_debug(vs, query, user_question):
    # RAG usage in a LangChain Pipe
    top_k = 4
    rag_prompt = f"""
    You are an assistant for question-answering tasks, be concise. 
    Use the retrieved DOCUMENTS to answer the user input as accurately as possible. 
    Keep your answer grounded in the facts of the DOCUMENTS and reference the DOCUMENTS where possible. If there ARE DOCUMENTS, you should be able to answer.  
    If there are NO DOCUMENTS, respond only with 'I am sorry, but cannot find relevant sources.'
    """
    llm_model = "llama3.1"

    template = rag_prompt + """\n# DOCUMENTS :\n {context} \n""" + """\n # Question: {question} """

    prompt = PromptTemplate.from_template(template)

    retriever = CustomVectorStoreRetriever(vectorstore=vs, search_kwargs={"k": top_k})

    llm = OllamaLLM(model=llm_model, base_url=url)

    # --- Get retrieved documents ---
    docs = retriever.get_relevant_documents(user_question)
    context = "\n\n".join([d.page_content for d in docs])

    # --- Render final prompt string ---
    final_prompt = prompt.format(context=context, question=user_question)
    print("\n========== PROMPT SENT TO LLM ==========\n")
    print(final_prompt)
    print("\n========================================\n")

    # --- Run LLM as usual ---
    answer = llm.invoke(final_prompt)
    return answer


### Ingestion from SQL and RAG
The query:

```python
query= "SELECT product_id as ID,product_details AS TEXT FROM products"
```
must be defined with an:
- ID alias, to indentify the key
- TEXT alias to provide the text to be embedded

To create the Vector Store:

```python
OracleVS.from_SQL(connection_load=conn, sql=query,connection = conn23,embedding=embeddings,table_name=table_name_vs,distance_strategy=DistanceStrategy.DOT_PRODUCT)
```
where:
- connection_load: DB where stored original text chunks to embedd. (Any Oracle DB)
- sql: the sql to exectute on `connection_load`.
- connection: Vector Store DB connection (DB23ai) 
- embedding: LLM model to generate vector embeddings
- table_name: vector store table name  (will store only embeddings an references to original text)
- distance_strategy: as usual on Vector Store definition

In [None]:
username = "CO"
password = "********"
dsn = "localhost:1521/FREEPDB1"

username_vs = "vector"
password_vs = "******"
dsn_vs = "localhost:1521/FREEPDB1"
table_name_vs = "VECTOR_STORE_LINKED"

model="mxbai-embed-large"
url="http://localhost:11434"
embeddings = OllamaEmbeddings(model=model, base_url=url)

with oracledb.connect(user=username_vs, password=password_vs, dsn=dsn_vs) as conn23:
    with oracledb.connect(user=username, password=password, dsn=dsn) as conn:
        #query= "SELECT product_id as ID,product_details AS TEXT FROM products"
        query= "SELECT product_id as ID,product_name AS TEXT FROM products"
        vs = OracleVS.from_SQL(connection_load=conn, sql=query,connection = conn23,embedding=embeddings,table_name=table_name_vs,distance_strategy=DistanceStrategy.DOT_PRODUCT)
        print(f"table name stored: {vs.table_name}\n---------------------")
        print(f"embeddings dim: {vs.get_embedding_dimension()}")

        #user_question="Which IDE should be used in this demo?"
        user_question="I'm looking for a Pyjamas for my little daughter of a dark color"
        print("\n\nANSWERS: \n"+rag_debug(vs,query,user_question))


HASH UPDATE: 
UPDATE "VECTOR"."VECTOR_STORE_LINKED"  SET "TEXT" = TO_CHAR(ORA_HASH(DBMS_LOB.SUBSTR("TEXT", 4000, 1)))
Update with hash successful.
table name stored: VECTOR_STORE_LINKED
---------------------
embeddings dim: 1024
vector_store_query: 
            SELECT ID,TEXT, JSON_VALUE(METADATA, '$.id') as S_ID,
            VECTOR_DISTANCE(embedding,:embedding, DOT) as DISTANCE
            FROM VECTOR_STORE_LINKED 
            ORDER BY DISTANCE ASC
            FETCH FIRST 4 ROWS ONLY
            


  docs = retriever.get_relevant_documents(user_question)



CustomVectorStoreRetriever::sql_get_relevant_documents()

            SELECT ID,TEXT, JSON_VALUE(METADATA, '$.id') as S_ID,
            VECTOR_DISTANCE(embedding,:embedding, DOT) as DISTANCE
            FROM VECTOR_STORE_LINKED 
            ORDER BY DISTANCE ASC
            FETCH FIRST 4 ROWS ONLY
            
VECTOR_STORE_LINKED
From the Vector store: retrieved nearest chunks
S_ID: 40, TEXT: 1442013119...
S_ID: 21, TEXT: 351660006...
S_ID: 38, TEXT: 2953145956...
S_ID: 30, TEXT: 1832232836...
SELECT ID, TEXT FROM (SELECT product_id as ID,product_name AS TEXT FROM products) WHERE ID IN ('40', '21', '38', '30') ORDER BY DECODE(ID, '40', 1,'21', 2,'38', 3,'30', 4)
From the Source Table: retrieved content
ID: 40, TEXT: Girl's Pyjamas (Blac...
ID: 21, TEXT: Girl's Pyjamas (Whit...
ID: 38, TEXT: Girl's Pyjamas (Red)...
ID: 30, TEXT: Women's Pyjamas (Gre...



    You are an assistant for question-answering tasks, be concise. 
    Use the retrieved DOCUMENTS to answer the user input as accu