## Here we define a custom DSPy RetrieverModelClient

The client uses OllamaEmbeddingFunction to fetch relevant passages.

The OllamaEmbeddingFunction is also provided here.


This OllamaEmbeddingFunction uses the Ollama Python library to build the interface between the Ollama embedder and the retriever,
with a collection that is created by the initialize_chromadb_collection method.

The initialize_chromadb_collection method depends on OllamaEmbeddingFunction for writing vectors to the ChromaDB server.

DSPythonicRMClient uses OllamaEmbeddingFunction for reading passages from the ChromaDB server.

### OllamaEmbeddingFunction 

In [1]:
import random
import ollama
import chromadb
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing import Optional, Union, List
import pandas as pd



class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
    """Chroma embedding function backed by an Ollama embedding model.

    Besides producing embeddings (``__call__``), an instance can hold a Chroma
    collection and run similarity queries against it (``_retrieve``).
    """

    # Input types accepted by ``_embed``. Inside the class body this name
    # shadows the ``Documents`` type imported from chromadb.
    Documents = Union[str, List[str], pd.DataFrame]

    def __init__(self, model_name: str = "mxbai-embed-large", collection=None):
        """Store the model name and the (optional) collection to query.

        Args:
            model_name: Name of the Ollama model used to compute embeddings.
            collection: Collection queried by ``_retrieve``. May be ``None``
                when the instance is only used to embed documents.
        """
        self.model_name = model_name
        self.collection = collection

    def __call__(self, input: Documents) -> List[List[float]]:
        """Embed the input documents."""
        return self._embed(input)

    def _embed(self, documents: Documents) -> List[List[float]]:
        """Generate one embedding per input document using Ollama.

        Args:
            documents: A single string, a pandas DataFrame (only its first
                column is used), or any other iterable of items. Every item
                is converted with ``str()`` before being embedded.

        Returns:
            The embeddings, in the same order as the input documents.

        Raises:
            ValueError: If ``documents`` is not a string, DataFrame or iterable.
        """
        # Order matters: str and DataFrame are iterable too, so they must be
        # recognised before the generic iterable branch.
        if isinstance(documents, str):
            documents = [documents]
        elif isinstance(documents, pd.DataFrame):
            # Stringify cells so non-text values are handled like list items.
            documents = [str(doc) for doc in documents.iloc[:, 0].tolist()]
        elif hasattr(documents, '__iter__'):
            # Lists, tuples, sets and any other iterable.
            documents = [str(doc) for doc in documents]
        else:
            raise ValueError("Unsupported document type. Please provide a string, list, or pandas DataFrame.")

        return [
            ollama.embeddings(model=self.model_name, prompt=doc)["embedding"]
            for doc in documents
        ]

    def _retrieve(self, query: Union[str, List[str]], n_results: int) -> List[str]:
        """Return the documents of the stored collection closest to ``query``.

        Args:
            query: Query text. A list of strings is joined with spaces and
                treated as one query.
            n_results: Maximum number of documents to return.

        Returns:
            A flat list of document strings, as ordered by the collection query.

        Raises:
            ValueError: If no collection was supplied at construction time.
        """
        if self.collection is None:
            raise ValueError("Collection is not set. Please initialize the OllamaEmbeddingFunction with a valid collection.")

        if isinstance(query, list):
            query = ' '.join(query)

        response = ollama.embeddings(
            model=self.model_name,
            prompt=query
        )

        results = self.collection.query(
            query_embeddings=[response["embedding"]],
            n_results=n_results
        )

        # The query result holds one list of documents per query embedding.
        # Exactly one embedding was sent, so unwrap that single inner list
        # instead of returning the nested [[...]] structure.
        per_query_documents = results["documents"] or [[]]
        return list(per_query_documents[0][:n_results])



### A Robust Method for Creating the ChromaDB Collection

In [2]:
import random
import chromadb
from typing import Optional
import random
import chromadb
from typing import Optional


# Module-level record of the last used collection name and serial number.
# Pass it as `last_used` to initialize_chromadb_collection to carry that state
# across calls. (A `global` statement is unnecessary at module scope.)
last_used_info = {}



def initialize_chromadb_collection(host: str = 'localhost', port: int = 8000, reset: Optional[bool] = False, create_new_collection: bool = True, last_used: Optional[dict] = None) -> chromadb.Collection:
    """
    Connect to a ChromaDB server over HTTP and return a collection.

    Args:
        host (str): Host of the ChromaDB server. Defaults to 'localhost'.
        port (int): Port of the ChromaDB server. Defaults to 8000.
        reset (Optional[bool]): If truthy, ``client.reset()`` is called before the collection is looked up. Defaults to False.
        create_new_collection (bool): If True, the serial number in ``last_used`` is incremented and the
            collection ``docs<serial>`` is used. If False, the name stored in ``last_used`` is reused. Defaults to True.
        last_used (Optional[dict]): Mutable record with the keys 'collection_name' and 'serial_number'; it is
            updated in place. When omitted, a fresh record is used for this call only, so the name is always 'docs1'.

    Returns:
        chromadb.Collection: The created or existing ChromaDB collection.

    Raises:
        ValueError: If ``create_new_collection`` is False and ``last_used`` holds no collection name.
    """
    # Normalise the bookkeeping record so 'serial_number' is always present.
    if last_used is None:
        last_used = {'collection_name': None, 'serial_number': 0}
    else:
        last_used.setdefault('serial_number', 0)

    client = chromadb.HttpClient(host=host, port=port)
    if reset:
        client.reset()

    # Decide which collection name to use.
    if create_new_collection:
        last_used['serial_number'] += 1
        name = f"docs{last_used['serial_number']}"
    else:
        name = last_used.get('collection_name')
        if name is None:
            raise ValueError("No previous collection name found. Set create_new_collection to True to create a new collection.")

    # get_or_create_collection avoids a UniqueConstraintError when the name already exists.
    collection = client.get_or_create_collection(name=name)

    if create_new_collection:
        # Remember the name so a later call can reuse it.
        last_used['collection_name'] = name

    print(f"Using collection: {collection.name}")
    return collection


### Now comes the RMC

This RMC cannot (currently) fetch from an arbitrary URL/port, and it depends on OllamaEmbeddingFunction for encoding and decoding.

In [3]:
from typing import List, Union, Optional
import dspy


class DSPythonicRMClient(dspy.Retrieve):
    """DSPy retriever that delegates lookups to an OllamaEmbeddingFunction."""

    def __init__(self, embedding_function: OllamaEmbeddingFunction, k: int = 3):
        """
        Initialize the DSPythonicRMClient.

        Args:
            embedding_function (OllamaEmbeddingFunction): The embedding function to use for retrieval.
            k (int): The number of top passages to retrieve. Defaults to 3.
        """
        super().__init__(k=k)
        self.embedding_function = embedding_function

    def forward(self, query: Union[str, List[str]], n_results: Optional[int] = None, k: Optional[int] = None) -> dspy.Prediction:
        """
        Retrieve passages based on the embedded query.

        Args:
            query (Union[str, List[str]]): The query for which to retrieve passages.
            n_results (Optional[int]): The number of results to return. Takes precedence over ``k``.
            k (Optional[int]): Alias for ``n_results``. DSPy's documented custom RM client
                interface names this parameter ``k``, so accepting it lets callers that
                pass ``k=...`` use this client as well.

        Returns:
            dspy.Prediction: An object containing the retrieved passages.
        """
        # Precedence: explicit n_results, then k, then the default set at construction.
        if n_results is None:
            n_results = k if k is not None else self.k

        retrieved_documents = self.embedding_function._retrieve(query, n_results=n_results)

        return dspy.Prediction(passages=retrieved_documents)

### Example Use with RAG

In [4]:
import os
import pandas as pd
import nltk
from nltk.tokenize import sent_tokenize

# Make sure to download the punkt tokenizer if you haven't already
""" nltk.download('punkt') """

def load_documents(folder_path, extensions=('.md', '.txt'), splitter=None):
    """Collect the sentences of all matching text files under a folder.

    The folder and its subfolders are walked with ``os.walk``. Every file
    whose name ends with one of ``extensions`` is read as UTF-8 and split
    into sentences. All other files are skipped.

    Args:
        folder_path: Root folder to search.
        extensions: File-name suffix (or tuple of suffixes) to load.
            Defaults to ``('.md', '.txt')``.
        splitter: Callable that maps a file's text to a list of sentences.
            Defaults to ``sent_tokenize`` from NLTK.

    Returns:
        A flat list of sentences, file by file in the order ``os.walk``
        yields them.
    """
    if splitter is None:
        splitter = sent_tokenize

    # str.endswith takes a tuple of suffixes; wrap a bare string so it is
    # treated as one suffix rather than split into characters.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)

    documents = []
    for root, _, files in os.walk(folder_path):
        for filename in files:
            if not filename.endswith(suffixes):
                continue  # Skip files with other extensions

            with open(os.path.join(root, filename), 'r', encoding='utf-8') as file:
                content = file.read()

            documents.extend(splitter(content))

    return documents




# Step 1: connect to ChromaDB and obtain a collection.
# (Optionally pass last_used=last_used_info to keep the naming state between calls.)
collection = initialize_chromadb_collection(create_new_collection=True)

# Step 2: build the embedding function and bind it to that collection.
embedding_function = OllamaEmbeddingFunction(
    model_name="mxbai-embed-large",
    collection=collection,
)

# Step 3: load the sentences, embed them and store them in the collection.
# The folder is hard-coded here; it could be read from user input instead.
documents = load_documents("/home/riju279/Documents/writings/Obsidian/Videodraft2/")
last_used_info = {}

embeddings = embedding_function(documents)
collection.add(
    # Ids are the 1-based positions of the sentences, as strings.
    ids=[str(position) for position, _ in enumerate(documents, start=1)],
    embeddings=embeddings,
    documents=documents,
)



Using collection: docs1


### The code above is a fully working example of building embeddings from the Markdown files in a given directory

Getting proper infrastructure for large-scale DSPy experiments was important for us.


# Step 4: Create an instance of DSPythonicRMClient and LLM


In [5]:

# Retriever: returns up to 10 passages per query via the embedding function.
lrm = DSPythonicRMClient(embedding_function=embedding_function, k=10)

# Language model: reached through DSPy's OpenAI client at the local api_base
# below (presumably Ollama's OpenAI-compatible endpoint - confirm).
olm = dspy.OpenAI(
    api_base="http://localhost:11434/v1/",
    api_key="ollama",
    model="mistral-nemo:latest",
    stop='\n\n',
    model_type='chat',
)

# Register both as the defaults used by DSPy modules.
dspy.settings.configure(lm=olm, rm=lrm)

In [None]:
import dspy

class GenerateQuestion(dspy.Signature):
    """Generate questions about the facts in the given context."""

    # NOTE: DSPy uses a signature's docstring as the task instruction. The
    # previous text ("Answer questions ...") contradicted the fields below,
    # whose only output is a question.
    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.OutputField(desc="ask questions about the facts")
    


class RAG(dspy.Module):
    """Retrieve passages for a query and generate a question about them."""

    def __init__(self, num_passages=5):
        """Set up the retriever and the question generator.

        Args:
            num_passages: Number of passages to retrieve per query.
        """
        super().__init__()

        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_question = dspy.ChainOfThought(GenerateQuestion)

    def forward(self, question):
        """Retrieve context for ``question`` and generate a question from it.

        Args:
            question: Text used as the retrieval query.

        Returns:
            dspy.Prediction with ``context`` (the retrieved passages) and
            ``question`` (the generated question). ``answer`` carries the same
            generated text and is kept so the result still exposes the field
            name this method originally returned.
        """
        context = self.retrieve(question).passages

        # GenerateQuestion has a single input (context) and a single output
        # (question): the output field must not be passed in as an input, and
        # the prediction has no `answer` attribute to read.
        prediction = self.generate_question(context=context)
        generated = prediction.question

        return dspy.Prediction(context=context, question=generated, answer=generated)
    
import time

# Question generator that accepts an extra `hint` input alongside the context.
generate_question = dspy.ChainOfThoughtWithHint(GenerateQuestion)

# Generate and print one question per loaded sentence.
for i, context in enumerate(documents):
    hint = f"ask why, how and what about {context}"
    pred = generate_question(context=context, hint=hint)
    print(f"{i}\n Context: {context}\n\n")
    print(f"Predicted Question: {pred.question}\n\n -----\n\n######\n\n")
    # Pause between requests to the language model.
    time.sleep(2)

