# RAG Milvus Storage

Building on the two previous examples, we will now add a storage layer to our RAG system. We will use Milvus to store the document embeddings and to perform the similarity search against query embeddings.

If you need to set up Milvus, the included `docker-compose.yml` file starts a Milvus instance with the required configuration. Start it with `docker-compose up`. It's based on the official Milvus `docker-compose.yml` file, with the Milvus version updated to match the client in the `requirements.txt` file.
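Before running the notebook, it can help to confirm that something is listening on the Milvus port. The sketch below uses only the standard library. `milvus_port_open` is a hypothetical helper, and an open TCP port only shows that some process accepted the connection, not that Milvus itself is healthy.

```python
import socket

def milvus_port_open(host: str = "localhost", port: int = 19530, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        # create_connection resolves the host and attempts a TCP handshake
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or the host could not be resolved
        return False
```

Calling `milvus_port_open(MILVUS_HOST, MILVUS_PORT)` after `docker-compose up` should return `True` once the container is accepting connections.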

## Usage

Set the variables below.

* `os.environ['ANTHROPIC_API_KEY']` - Your API key from the [Anthropic API](https://www.anthropic.com/).
* `ANTHROPIC_MODEL_ID` - The Claude model ID to use with the [Anthropic API](https://www.anthropic.com/).
* `EMBEDDING_MODEL` - The Sentence Transformers model used to generate the embeddings.
* `MILVUS_COLLECTION_ID` - The name of the Milvus collection that stores the embeddings.
* `MILVUS_HOST` - The Milvus host.
* `MILVUS_PORT` - The Milvus port.
* `DIMENSIONS` - The number of dimensions of the embeddings. It must match the output size of `EMBEDDING_MODEL` (1024 for `BAAI/bge-large-en-v1.5`).
* `DOCUMENT_LIST` - A list of paths to the documents to parse.
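If you'd rather not edit the notebook cell, the same settings could be read from environment variables. A minimal sketch, assuming environment variable names that mirror the ones above and the defaults used in this notebook. `Settings` and `load_settings` are hypothetical, and reading `DOCUMENT_LIST` as a comma-separated string is also an assumption.

```python
import os
from dataclasses import dataclass, field
from typing import List, Mapping, Optional

@dataclass
class Settings:
    anthropic_model_id: str = "claude-3-5-sonnet-latest"
    embedding_model: str = "BAAI/bge-large-en-v1.5"
    milvus_collection_id: str = "rag_collection"
    milvus_host: str = "localhost"
    milvus_port: int = 19530
    dimensions: int = 1024
    document_list: List[str] = field(default_factory=list)

def load_settings(env: Optional[Mapping[str, str]] = None) -> Settings:
    """Build Settings from environment variables, falling back to the notebook defaults."""
    env = os.environ if env is None else env
    defaults = Settings()
    # Comma-separated file paths; empty entries are dropped
    docs = [p.strip() for p in env.get("DOCUMENT_LIST", "").split(",") if p.strip()]
    return Settings(
        anthropic_model_id=env.get("ANTHROPIC_MODEL_ID", defaults.anthropic_model_id),
        embedding_model=env.get("EMBEDDING_MODEL", defaults.embedding_model),
        milvus_collection_id=env.get("MILVUS_COLLECTION_ID", defaults.milvus_collection_id),
        milvus_host=env.get("MILVUS_HOST", defaults.milvus_host),
        # The port and the vector dimension must be integers
        milvus_port=int(env.get("MILVUS_PORT", defaults.milvus_port)),
        dimensions=int(env.get("DIMENSIONS", defaults.dimensions)),
        document_list=docs,
    )
```

The API key is left out on purpose: `ClaudeInterface` already falls back to `os.getenv("ANTHROPIC_API_KEY")`.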

In [None]:
import os

ANTHROPIC_MODEL_ID = "claude-3-5-sonnet-latest"
#os.environ['ANTHROPIC_API_KEY'] = ''
EMBEDDING_MODEL = 'BAAI/bge-large-en-v1.5'
MILVUS_COLLECTION_ID = 'rag_collection'
MILVUS_HOST = 'localhost'
MILVUS_PORT = 19530
DIMENSIONS = 1024

DOCUMENT_LIST = [
    # Paths to .html, .md, .rst, or plain-text files to load
]

Import required libraries


In [None]:
import os
import tempfile
from typing import Dict, List, Optional

import anthropic
import markdown
from bs4 import BeautifulSoup
from pymilvus import (
    connections,
    Collection,
    CollectionSchema,
    FieldSchema,
    DataType,
    utility
)
from sentence_transformers import SentenceTransformer
from sphinx.application import Sphinx

The `LocalEmbeddingGenerator` class is significantly different from the previous examples. It now uses the `sentence-transformers` library (`SentenceTransformer`) to generate the embeddings locally.
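`generate_embedding` passes `normalize_embeddings=True`, so every vector has unit length. For unit vectors, cosine similarity reduces to a plain dot product, which is why normalization pairs naturally with the `COSINE` metric used by the Milvus index. The standard-library sketch below illustrates the identity with small made-up vectors; it does not call `SentenceTransformer`.

```python
import math
from typing import List

def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length (what normalize_embeddings=True does per embedding)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity computed from raw (unnormalized) vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

a = [3.0, 4.0, 0.0]
b = [1.0, 2.0, 2.0]
# Cosine on the raw vectors equals the dot product of the normalized vectors
print(round(cosine(a, b), 6), round(dot(l2_normalize(a), l2_normalize(b)), 6))  # prints 0.733333 0.733333
```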

In [None]:
class LocalEmbeddingGenerator:
    """
    A class to generate embeddings for input text using a specified model.
    Attributes:
    -----------
    model : SentenceTransformer
        The model used for generating embeddings.
    _dimensions : int
        The dimensions of the embeddings generated by the model.
    Methods:
    --------
    generate_embedding(text):
        Generates embeddings for the input text.
    dimensions:
        Returns the dimensions of the embeddings.
    """
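    def __init__(self, model=EMBEDDING_MODEL):
        # BAAI/bge-large-en-v1.5 produces 1024-dimensional embeddings
        self.model = SentenceTransformer(model)
        # Read the dimension from the model so a mismatch with the Milvus schema is caught early
        self._dimensions = self.model.get_sentence_embedding_dimension()
        if self._dimensions != DIMENSIONS:
            raise ValueError(
                f"{model} produces {self._dimensions}-dimensional embeddings, "
                f"but DIMENSIONS is set to {DIMENSIONS}"
            )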
    
    def generate_embedding(self, text):
        """Generate embeddings for input text"""
        embeddings = self.model.encode(text, normalize_embeddings=True)
        return embeddings.tolist()
    
    @property
    def dimensions(self):
        """Return embedding dimensions for collection validation"""
        return self._dimensions

The `RAGPromptGenerator`, `LLMInterface`, and `ClaudeInterface` classes are the same as in the previous examples.
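Because `RAGWithLLM` only depends on the `generate_response` method of `LLMInterface`, you can exercise the retrieval and prompt-building path without an Anthropic API key by swapping in a stub. `EchoLLM` below is a hypothetical test double, not part of the original examples; the `LLMInterface` base is repeated so the snippet runs on its own.

```python
from typing import List

class LLMInterface:
    """Base class for LLM interactions (same contract as in the notebook)."""
    def generate_response(self, prompt: str) -> str:
        raise NotImplementedError

class EchoLLM(LLMInterface):
    """Hypothetical stub that records prompts and returns a canned reply."""
    def __init__(self, reply: str = "stub response"):
        self.reply = reply
        self.prompts: List[str] = []

    def generate_response(self, prompt: str) -> str:
        # Keep every prompt so a test can inspect what the RAG layer produced
        self.prompts.append(prompt)
        return self.reply

llm = EchoLLM()
print(llm.generate_response("Context: ...\nHuman: What is Milvus?"))
```

In the notebook, `RAGWithLLM(rag_system, EchoLLM())` would return the canned reply, while the stub's `prompts[-1]` holds the full RAG prompt, including the retrieved context.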

In [None]:
class RAGPromptGenerator:
    def __init__(self, rag_system):
        self.rag = rag_system
    
    def generate_prompt(self, query: str, system_prompt: Optional[str] = None) -> str:
        """Generate a prompt for the LLM using retrieved context."""
        context = self.rag.generate_context(query)
        
        prompt = f"""
        System: {system_prompt or 'You are a helpful AI assistant. Use the provided context to answer questions accurately. If the context does not contain relevant information, say so.'}
        
        Context:
        {context}
        
        Human: {query}
        
        Assistant: Based on the provided context, I'll help answer your question.
        """
        print(prompt)
        return prompt.strip()

class LLMInterface:
    """Base class for LLM interactions"""
    def generate_response(self, prompt: str) -> str:
        raise NotImplementedError

class ClaudeInterface(LLMInterface):
    def __init__(self, api_key: Optional[str] = None):
        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic(api_key=self.api_key)
    
    def generate_response(self, prompt: str) -> str:
        try:
            message = self.client.messages.create(
                model=ANTHROPIC_MODEL_ID,
                max_tokens=1000,
                messages=[{
                    "role": "user",
                    "content": prompt
                }]
            )
            return message.content[0].text
        except Exception as e:
            print(f"Error generating response from Claude: {str(e)}")
            return ""

`MilvusRAGSystem` is an updated version of the `RAGSystem` class from the previous examples that uses Milvus to store the embeddings. When the class is initialized, it connects to a Milvus server and creates the collection that stores the embeddings if it doesn't already exist.

The `create_collection` method creates the collection in Milvus and defines its schema. Each row stores an auto-generated integer ID, the embedding as a vector of floating-point numbers, the chunk text, and the document's metadata as JSON. The method also builds an `IVF_FLAT` index on the embeddings using the `COSINE` metric and loads the collection so it can be searched.
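Milvus rejects rows that don't match the schema, so it can be useful to check chunks before calling `insert`. The sketch mirrors the fields defined in `create_collection`: a `FLOAT_VECTOR` of `DIMENSIONS` values, a `VARCHAR` capped at 65,535, and a JSON metadata field. `validate_row` is a hypothetical helper, and counting the VARCHAR limit in UTF-8 bytes is a conservative assumption rather than a statement of exactly how Milvus measures it.

```python
import json
from typing import Any, Dict, List

DIMENSIONS = 1024
MAX_VARCHAR = 65535  # largest max_length Milvus accepts for a VARCHAR field

def validate_row(row: Dict[str, Any], dim: int = DIMENSIONS) -> List[str]:
    """Return a list of problems with a row; an empty list means it looks insertable."""
    problems = []
    # The primary key uses auto_id=True, so the server generates "id"
    if "id" in row:
        problems.append("'id' is auto-generated and must not be supplied")
    vec = row.get("embeddings")
    if not isinstance(vec, list) or len(vec) != dim:
        problems.append(f"'embeddings' must be a list of {dim} floats")
    elif not all(isinstance(x, (int, float)) for x in vec):
        problems.append("'embeddings' contains non-numeric values")
    doc = row.get("document")
    if not isinstance(doc, str):
        problems.append("'document' must be a string")
    elif len(doc.encode("utf-8")) > MAX_VARCHAR:
        problems.append("'document' exceeds the VARCHAR max_length")
    # The JSON field must hold something json.dumps can serialize
    try:
        json.dumps(row.get("metadata"))
    except (TypeError, ValueError):
        problems.append("'metadata' is not JSON-serializable")
    return problems
```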

As before, the `add_document_with_chunking` method splits each document into chunks, generates an embedding for every chunk, and stores the results. The difference is that the chunks and embeddings are inserted into Milvus instead of being kept in memory and saved to a local file.
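`add_document_with_chunking` slices the text into fixed windows of `chunk_size` characters with no overlap, so a sentence that straddles a boundary is split across two embeddings. A common variation is to overlap consecutive windows. The sketch below reproduces the original slicing when `overlap=0`; the `chunk_text` name and the `overlap` parameter are hypothetical, not part of the class.

```python
from typing import List

def chunk_text(text: str, chunk_size: int = 512, overlap: int = 0) -> List[str]:
    """Split text into chunk_size-character windows, each sharing `overlap` chars with the previous one."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    step = chunk_size - overlap
    # With overlap=0 this matches text[i:i + chunk_size] for i in range(0, len(text), chunk_size)
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]
    # Drop trailing windows that lie entirely inside the window before them
    while overlap and len(chunks) > 1 and len(chunks[-1]) <= overlap:
        chunks.pop()
    return chunks

print(chunk_text("abcdefghij", chunk_size=4))             # ['abcd', 'efgh', 'ij']
print(chunk_text("abcdefghij", chunk_size=4, overlap=2))  # ['abcd', 'cdef', 'efgh', 'ghij']
```

Overlap trades extra rows and embedding calls for fewer passages cut in half at a chunk boundary.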

`query_similar` searches the Milvus collection for documents similar to a query string. It generates an embedding for the query and returns the `top_k` closest chunks along with their metadata and similarity scores.
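The collection uses an `IVF_FLAT` index, which at search time only scans the `nprobe` clusters closest to the query (10 of the 128 created with `nlist`), so results are approximate. An exact brute-force search with the same cosine metric is a useful reference when checking recall on a small corpus. This is a standard-library sketch over in-memory vectors; `exact_top_k` is a hypothetical helper and the toy vectors are made up.

```python
import heapq
import math
from typing import Dict, List, Sequence, Tuple

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    num = sum(x * y for x, y in zip(a, b))
    den = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return num / den if den else 0.0

def exact_top_k(query: Sequence[float], rows: List[Dict], top_k: int = 5) -> List[Dict]:
    """Score every row against the query and return the top_k by cosine similarity (higher is closer)."""
    scored: List[Tuple[float, int]] = [
        (cosine(query, row["embeddings"]), i) for i, row in enumerate(rows)
    ]
    best = heapq.nlargest(top_k, scored)
    # Same shape as the dicts returned by query_similar
    return [
        {"document": rows[i]["document"], "metadata": rows[i]["metadata"], "score": s}
        for s, i in best
    ]

rows = [
    {"document": "milvus stores vectors", "embeddings": [1.0, 0.0], "metadata": {"source": "a"}},
    {"document": "claude writes answers", "embeddings": [0.0, 1.0], "metadata": {"source": "b"}},
    {"document": "vector search", "embeddings": [0.8, 0.6], "metadata": {"source": "c"}},
]
for hit in exact_top_k([1.0, 0.1], rows, top_k=2):
    print(hit["metadata"]["source"], round(hit["score"], 3))
```

Comparing the IDs returned by Milvus against this exact list for a handful of queries gives a rough recall estimate for a given `nprobe`.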

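`generate_context` runs the same kind of search, then formats the top results (source, relevance score, and content) into a single context string that `RAGPromptGenerator` inserts into the prompt. It returns a fallback message if nothing is found or the search fails.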
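`generate_context` turns each hit into a `Source` / `Relevance` / `Content` block followed by `---`. The sketch below shows that formatting as a pure function over hit dictionaries shaped like those returned by `query_similar`, which makes it easy to test without a running Milvus. `format_context` and its `min_score` filter are hypothetical; the notebook's method does not drop low-scoring hits.

```python
from typing import Dict, List

def format_context(hits: List[Dict], min_score: float = -1.0) -> str:
    """Render search hits as the Source/Relevance/Content blocks used in the RAG prompt."""
    parts = []
    for hit in hits:
        if hit["score"] < min_score:
            continue  # optional cut-off; COSINE scores range from -1 to 1
        # Metadata may be missing or null, so fall back to an empty dict
        source = (hit.get("metadata") or {}).get("source", "Unknown source")
        parts.append(
            f"\nSource: {source}\nRelevance: {hit['score']:.3f}\nContent: {hit['document']}\n---"
        )
    return "\n".join(parts) if parts else "No relevant context found."

hits = [
    {"document": "Milvus is a vector database.", "metadata": {"source": "intro.md"}, "score": 0.91234},
    {"document": "Unrelated text.", "metadata": None, "score": 0.1},
]
print(format_context(hits))
```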

In [None]:
class MilvusRAGSystem:
    def __init__(self, collection_name=MILVUS_COLLECTION_ID):
        self.collection_name = collection_name
        
        # Connect to Milvus
        connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)

        self.embedding_generator = LocalEmbeddingGenerator()
        
        # Initialize collection
        if not utility.has_collection(collection_name):
            self.create_collection(collection_name)
        self.collection = Collection(name=collection_name)
        
        # Ensure collection is loaded before operations
        self.ensure_collection_loaded()

    def create_collection(self, collection_name):
        # Define schema
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),  # Primary key field
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=DIMENSIONS),
            FieldSchema(name="document", dtype=DataType.VARCHAR, max_length=65535),  # 65535 is the largest max_length Milvus allows
            FieldSchema(name="metadata", dtype=DataType.JSON)
        ]
        schema = CollectionSchema(fields, description="Document collection")
        
        # Create collection
        collection = Collection(name=collection_name, schema=schema)
        collection.create_index(field_name="embeddings", index_params={"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 128}})
        collection.load()

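    def ensure_collection_loaded(self):
        """Ensure the collection exists and is loaded before operations"""
        if not utility.has_collection(self.collection_name):
            self.create_collection(self.collection_name)
            # Refresh the handle so it points at the newly created collection
            self.collection = Collection(name=self.collection_name)
        # load_state returns a LoadState enum, so compare its name rather than testing truthiness
        if utility.load_state(self.collection_name).name != "Loaded":
            self.collection.load()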

    def add_document_with_chunking(self, text, metadata, chunk_size=512):
        # Split the text into fixed-size character chunks
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        if not chunks:
            return  # Nothing to insert for an empty document
        
        documents = []
        
        for chunk in chunks:
            # Get embeddings for each chunk
            embeddings = self.embedding_generator.generate_embedding(chunk)
            
            # Create document entry for each chunk
            document = {
                "document": chunk,
                "embeddings": embeddings,
                "metadata": metadata
            }
            
            documents.append(document)
        
        # Insert all chunked documents into Milvus
        self.collection.insert(documents)

    def query_similar(self, query_text, top_k=5):
        # Get query embeddings
        query_embedding = self.embedding_generator.generate_embedding(query_text)
        
        # Search in Milvus
        search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
        results = self.collection.search(
            data=[query_embedding],
            anns_field="embeddings",
            param=search_params,
            limit=top_k,
            output_fields=["document", "metadata"]
        )
        
        return [
            {
                "document": hit.entity.get("document"),
                "metadata": hit.entity.get("metadata"),
                "score": hit.score
            }
            for hit in results[0]
        ]
    
    def generate_context(self, query: str, num_results: int = 3) -> str:
        try:
            self.ensure_collection_loaded()  # Ensure loaded before search
            
            # Reuse query_similar instead of duplicating the search call
            hits = self.query_similar(query, top_k=num_results)
            
            if not hits:
                return "No relevant context found."
            
            # Format context from results
            context_parts = []
            for hit in hits:
                # Metadata can be null, so fall back to an empty dict
                metadata = hit['metadata'] or {}
                source = metadata.get('source', 'Unknown source')
                
                context_part = f"""
Source: {source}
Relevance: {hit['score']:.3f}
Content: {hit['document']}
---"""
                context_parts.append(context_part)
            
            return "\n".join(context_parts)
        
        except Exception as e:
            print(f"Search error: {e}")
            return "Error retrieving context. Please ensure collection is properly initialized."


`RAGWithLLM` is the same as in the previous examples.

`RAGConversation` adds multi-turn conversation support. The class maintains a history of the conversation and includes the last five messages in the system prompt when generating each subsequent response.
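`RAGConversation.query` keeps the full history in a list but only renders the last five messages into the system prompt. If the list itself should not grow without bound, a `collections.deque` with `maxlen` gives the same sliding window. The sketch is a hypothetical standalone helper, not the notebook's `RAGConversation`.

```python
from collections import deque
from typing import Deque, Dict

class ConversationWindow:
    """Hypothetical helper: keep only the most recent messages and render them for a prompt."""
    def __init__(self, max_messages: int = 5):
        # Older messages are discarded automatically once maxlen is reached
        self.messages: Deque[Dict[str, str]] = deque(maxlen=max_messages)

    def add_message(self, role: str, content: str) -> None:
        self.messages.append({"role": role, "content": content})

    def render(self) -> str:
        # Same "role: content" lines that RAGConversation joins into its system prompt
        return "\n".join(f"{m['role']}: {m['content']}" for m in self.messages)

window = ConversationWindow(max_messages=3)
for i in range(4):
    window.add_message("user" if i % 2 == 0 else "assistant", f"message {i}")
print(window.render())
```

The trade-off is that older turns are gone for good, whereas the list-based history keeps them available for later inspection.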

In [None]:
class RAGWithLLM:
    def __init__(self, rag_system, llm_interface: LLMInterface):
        self.rag = rag_system
        self.llm = llm_interface
        self.prompt_generator = RAGPromptGenerator(rag_system)
    
    def query(self, user_query: str, system_prompt: Optional[str] = None) -> str:
        # Generate RAG-enhanced prompt
        enhanced_prompt = self.prompt_generator.generate_prompt(
            user_query,
            system_prompt
        )
        
        # Get LLM response
        response = self.llm.generate_response(enhanced_prompt)
        
        return response

# Advanced usage with conversation history
class RAGConversation:
    def __init__(self, rag_with_llm: RAGWithLLM):
        self.rag_with_llm = rag_with_llm
        self.conversation_history = []
    
    def add_message(self, role: str, content: str):
        self.conversation_history.append({
            "role": role,
            "content": content
        })
    
    def query(self, user_query: str) -> str:
        # Add user query to history
        self.add_message("user", user_query)
        
        # Generate context-aware prompt including conversation history
        context = "\n".join([
            f"{msg['role']}: {msg['content']}"
            for msg in self.conversation_history[-5:]  # Last 5 messages
        ])
        
        # Get response
        response = self.rag_with_llm.query(
            user_query,
            system_prompt=f"""
            Consider the following conversation history:
            {context}
            
            Provide a response that maintains context and continuity.
            """
        )
        
        # Add response to history
        self.add_message("assistant", response)
        
        return response

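Parsing and loading data. If we're going to store data in a real database, we should process some actual data. `parse_html`, `parse_markdown`, and `parse_restructuredtext` extract plain text from HTML, Markdown, and reStructuredText files, respectively. `load_documents` reads each file in a list, parses it according to its extension (falling back to the raw text for other file types), and returns the text together with its source path and file name as metadata. The documents are then inserted into the Milvus collection with `add_document_with_chunking`.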
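`load_documents` picks a parser from the file extension and falls back to the raw text. The same dispatch can be written as a lookup table. As a dependency-free illustration, the sketch uses the standard library's `html.parser.HTMLParser` to extract text in place of BeautifulSoup; `HTMLTextExtractor`, `html_to_text`, `PARSERS`, and `to_text` are hypothetical names, and this version skips the contents of `<script>` and `<style>` tags. In the notebook, `.md` and `.rst` entries would map to `parse_markdown` and `parse_restructuredtext`.

```python
import os
from html.parser import HTMLParser
from typing import Callable, Dict, List

class HTMLTextExtractor(HTMLParser):
    """Collect text nodes, ignoring anything inside <script> or <style>."""
    def __init__(self):
        super().__init__()
        self.parts: List[str] = []
        self._skip_depth = 0

    def handle_starttag(self, tag, attrs):
        if tag in ("script", "style"):
            self._skip_depth += 1

    def handle_endtag(self, tag):
        if tag in ("script", "style") and self._skip_depth:
            self._skip_depth -= 1

    def handle_data(self, data):
        if not self._skip_depth:
            self.parts.append(data)

def html_to_text(html: str) -> str:
    extractor = HTMLTextExtractor()
    extractor.feed(html)
    extractor.close()  # flush any buffered text
    return "".join(extractor.parts)

# Extension -> parser; anything not listed is treated as plain text
PARSERS: Dict[str, Callable[[str], str]] = {
    ".html": html_to_text,
}

def to_text(path: str, content: str) -> str:
    ext = os.path.splitext(path)[1].lower()
    parser = PARSERS.get(ext)
    return parser(content) if parser else content

print(to_text("page.html", "<h1>Title</h1><script>x=1</script><p>Body</p>"))
```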

In [None]:
def parse_html(html_content):
    soup = BeautifulSoup(html_content, 'html.parser')
    return soup.get_text()

def parse_markdown(md_content):
    html_content = markdown.markdown(md_content)
    soup = BeautifulSoup(html_content, 'html.parser')
    return soup.get_text()

def parse_restructuredtext(rst_content):
    # Create a temporary directory for Sphinx to work in
    with tempfile.TemporaryDirectory() as temp_dir:
        source_dir = os.path.join(temp_dir, 'source')
        build_dir = os.path.join(temp_dir, 'build')
        os.makedirs(source_dir)
        
        # Write the reStructuredText content to an index.rst file
        with open(os.path.join(source_dir, 'index.rst'), 'w') as f:
            f.write(rst_content)
        
        # Configure Sphinx
        conf = {
            'extensions': [],
            'master_doc': 'index',
            'html_theme': 'default',
            'exclude_patterns': [],
        }
        
        # Build the HTML using Sphinx
        app = Sphinx(
            srcdir=source_dir,
            confdir=None,
            outdir=build_dir,
            doctreedir=os.path.join(build_dir, 'doctrees'),
            buildername='html',
            confoverrides=conf,
        )
        app.build()
        
        # Read the generated HTML
        with open(os.path.join(build_dir, 'index.html'), 'r') as f:
            html_content = f.read()
    
    # Parse the HTML content with BeautifulSoup
    soup = BeautifulSoup(html_content, 'html.parser')
    return soup.get_text()


def load_documents(document_list: List[str]) -> List[Dict]:
    documents = []
    for doc in document_list:
        if doc.endswith('.html'):
            with open(doc, 'r') as f:
                html_content = f.read()
                text = parse_html(html_content)
        elif doc.endswith('.md'):
            with open(doc, 'r') as f:
                md_content = f.read()
                text = parse_markdown(md_content)
        elif doc.endswith('.rst'):
            with open(doc, 'r') as f:
                rst_content = f.read()
                text = parse_restructuredtext(rst_content)
        else:
            with open(doc, 'r') as f:
                text = f.read()
        topic = os.path.basename(doc)
        documents.append({
            'document': text,
            'metadata': {
                'source': doc,
                'topic': topic
            }
        })
    return documents



Example conversation

In [None]:
# Example usage:
def conversation_example():
    # Initialize RAG system
    #rag_system = EnhancedRAGSystem()
    rag_system = MilvusRAGSystem()
    
    # Load and parse the documents listed in DOCUMENT_LIST
    documents = load_documents(DOCUMENT_LIST)

    # Add documents to RAG system
    for doc in documents:
        rag_system.add_document_with_chunking(doc['document'], doc['metadata'])
    
    # Initialize LLM interface (only ClaudeInterface is defined in this notebook)
    llm_interface = ClaudeInterface()
    
    # Initialize RAG with LLM
    rag_with_llm = RAGWithLLM(rag_system, llm_interface)
    
    conversation = RAGConversation(rag_with_llm)
    
    # Example conversation
    queries = [
        "What is interview insights used for?",
        "Can you elaborate on its primary features?",
        "How does this relate to machine learning?",
    ]
    
    for query in queries:
        print(f"\nUser: {query}")
        response = conversation.query(query)
        print(f"Assistant: {response}")


conversation_example()