# LangChain RAG Demo with Async MariaDB Connector

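This notebook demonstrates the retrieval half of a Retrieval-Augmented Generation (RAG) pipeline using our `async-mariadb-connector`: documents are embedded locally, stored in MariaDB, and fetched back as context for a query.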

We will perform the following steps:
1.  Install required libraries.
2.  Define custom LangChain components to work with MariaDB.
3.  Create a vector store in MariaDB with sample documents.
4.  Perform a similarity search to retrieve relevant context for a query.
5.  Clean up the database.

In [None]:
%pip install async-mariadb-connector langchain-core sentence-transformers numpy pandas

## 1. Define Custom LangChain Components

To integrate LangChain with MariaDB as a vector store, we need two custom classes:
-   `LocalEmbeddings`: A wrapper around `sentence-transformers` to comply with LangChain's `Embeddings` interface.
-   `MariaDBVectorStore`: A custom `VectorStore` class that handles adding texts and performing similarity searches against our MariaDB database.
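
To make the embedding contract concrete, here is a minimal sketch of the two methods an embeddings object is expected to provide. `ToyEmbeddings` is a hypothetical stand-in (it hashes words into buckets and does not subclass LangChain's `Embeddings`), useful only for seeing the input and output shapes without downloading a model.

```python
import hashlib
import math
from typing import List


class ToyEmbeddings:
    """Hypothetical stand-in that mimics the two-method embedding contract."""

    def __init__(self, dim: int = 8):
        self.dim = dim

    def _embed(self, text: str) -> List[float]:
        # Hash each lowercase token into one of `dim` buckets (bag of words).
        vec = [0.0] * self.dim
        for token in text.lower().split():
            digest = hashlib.sha256(token.encode("utf-8")).digest()
            vec[digest[0] % self.dim] += 1.0
        # L2-normalize so that a dot product equals cosine similarity.
        norm = math.sqrt(sum(x * x for x in vec)) or 1.0
        return [x / norm for x in vec]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # One vector per input text, in the same order.
        return [self._embed(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        # A single vector for a single query string.
        return self._embed(text)


toy = ToyEmbeddings(dim=8)
doc_vectors = toy.embed_documents(["MariaDB is a database", "asyncio runs coroutines"])
query_vector = toy.embed_query("MariaDB is a database")
```

`embed_documents` returns a list of vectors and `embed_query` returns one vector of the same dimension; the `LocalEmbeddings` class in the next cell follows the same shape using a real sentence-transformer model.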

In [None]:
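import asyncio
import json
import numpy as np
import pandas as pd
from typing import List, Iterable, Any, Optional

# LangChain components
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from sentence_transformers import SentenceTransformer

# Our async MariaDB library
# Note: Ensure the library is installed. If running from a local checkout, you may need to adjust the import path.
from async_mariadb_connector import AsyncMariaDB, bulk_insert

# --- Create a custom SentenceTransformer embedding class ---
class LocalEmbeddings(Embeddings):
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.model.encode(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode([text]).tolist()[0]

# --- Create a custom MariaDB VectorStore for LangChain ---
class MariaDBVectorStore(VectorStore):
    def __init__(self, db: AsyncMariaDB, table_name: str, embeddings: Embeddings):
        self.db = db
        self.table_name = table_name
        # VectorStore exposes `embeddings` as a read-only property,
        # so the object is kept in a private attribute.
        self._embeddings = embeddings

    @property
    def embeddings(self) -> Embeddings:
        return self._embeddings

    def add_texts(self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, **kwargs: Any) -> List[str]:
        # Abstract in older langchain-core releases; this store is async-only.
        raise NotImplementedError("Use `await aadd_texts(...)` instead.")

    def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
        # Declared abstract by VectorStore; this store is async-only.
        raise NotImplementedError("Use `await asimilarity_search(...)` instead.")

    async def aadd_texts(self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, **kwargs: Any) -> List[str]:
        """Embed texts and insert them, with their vectors and metadata, into the table."""
        # Materialize once: an iterator would be exhausted after the first pass.
        texts = list(texts)
        embedded_vectors = self._embeddings.embed_documents(texts)

        df_data = []
        for i, (text, vector) in enumerate(zip(texts, embedded_vectors)):
            # For simplicity, the embedding is serialized as raw float32 bytes
            # and stored in a BLOB column rather than MariaDB's native VECTOR type.
            vector_bytes = np.array(vector, dtype=np.float32).tobytes()
            metadata = metadatas[i] if metadatas else {}
            df_data.append({
                "text_content": text,
                "vector": vector_bytes,
                "metadata": json.dumps(metadata)  # JSON text, so it can be parsed back without eval()
            })

        df = pd.DataFrame(df_data)
        await bulk_insert(self.db, self.table_name, df)
        # Placeholder IDs (positions in this batch), not the AUTO_INCREMENT ids assigned by MariaDB.
        return [str(i) for i in range(len(texts))]

    async def asimilarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
        """Return up to k documents. Placeholder: rows are not ranked by similarity yet."""
        # The query embedding is computed but not used by the placeholder SQL below.
        query_vector = self._embeddings.embed_query(query)

        # In a real application with MariaDB's native VECTOR type (11.7+), you would
        # order by VEC_DISTANCE_COSINE(vector, <query vector>) instead of by id.
        # This simplified placeholder returns the first k rows in insertion order.
        sql_query = f"""
            SELECT text_content, metadata
            FROM {self.table_name}
            ORDER BY id
            LIMIT %s
        """

        results = await self.db.fetch_all(sql_query, (k,))

        # json.loads() parses the stored metadata safely; eval() would execute arbitrary database content.
        return [Document(page_content=row['text_content'], metadata=json.loads(row['metadata'])) for row in results]

    @classmethod
    def from_texts(cls, texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, db_connection: Optional[AsyncMariaDB] = None, table_name: str = "langchain_vectors", **kwargs: Any):
        # Synchronous helper required by the VectorStore interface.
        # asyncio.run() raises RuntimeError inside an already running event loop (such as a notebook cell),
        # so in async code create the store directly and `await aadd_texts(...)`.
        if db_connection is None:
            raise ValueError("db_connection is required")
        vs = cls(db_connection, table_name, embedding)
        asyncio.run(vs.aadd_texts(texts, metadatas=metadatas))
        return vs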
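
The placeholder query in `asimilarity_search` does not rank anything. As a minimal sketch of what ranking could look like while the vectors live in a `BLOB` column, the rows can be fetched and scored client-side with NumPy. The helpers `to_blob`, `from_blob`, and `rank_by_cosine` are hypothetical names, and the sketch assumes the driver returns each row as a dict whose `vector` value is a `bytes` object.

```python
import numpy as np


def to_blob(vector) -> bytes:
    # Same serialization as the notebook: raw float32 bytes.
    return np.asarray(vector, dtype=np.float32).tobytes()


def from_blob(blob: bytes) -> np.ndarray:
    # Inverse of to_blob(); assumes the bytes were written as float32.
    return np.frombuffer(blob, dtype=np.float32)


def rank_by_cosine(query_vector, rows, k=4):
    """Return the k rows whose stored vectors are most similar to the query."""
    if not rows:
        return []
    q = np.asarray(query_vector, dtype=np.float32)
    matrix = np.stack([from_blob(row["vector"]) for row in rows])
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(q)
    denom = np.where(denom == 0, 1.0, denom)  # avoid division by zero
    scores = (matrix @ q) / denom
    top = np.argsort(-scores)[:k]  # highest cosine similarity first
    return [rows[i] for i in top]


# Tiny in-memory stand-in for rows fetched from the table.
rows = [
    {"text_content": "a", "vector": to_blob([1.0, 0.0])},
    {"text_content": "b", "vector": to_blob([0.0, 1.0])},
    {"text_content": "c", "vector": to_blob([0.7, 0.7])},
]
best = rank_by_cosine([1.0, 0.1], rows, k=2)
```

This reads every row on each query, so it only suits small demo tables; larger collections would want the ranking pushed into the database.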

## 2. Run the RAG Demo

Now we'll execute the main logic:
1.  Connect to the database.
2.  Create a table to store our vectors.
3.  Add sample documents to the `MariaDBVectorStore`.
4.  Run a similarity search to find documents relevant to our query.
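
The demo stores embeddings in a `BLOB` column. On a server with native vector support (MariaDB 11.7 or later), steps 2 and 4 could instead use a `VECTOR` column and let the database do the ranking. The following is a rough sketch under that assumption: the helper names and the `embedding` column name are hypothetical, and the SQL has not been run against this notebook's setup.

```python
from typing import Sequence, Tuple


def to_vector_literal(vector: Sequence[float]) -> str:
    """Format floats as the JSON-style array text that VEC_FromText() parses."""
    return "[" + ",".join(repr(float(x)) for x in vector) + "]"


def build_native_queries(table_name: str, dim: int) -> Tuple[str, str]:
    """Build CREATE TABLE and search statements for a native VECTOR column."""
    create_sql = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            id INT AUTO_INCREMENT PRIMARY KEY,
            text_content TEXT,
            embedding VECTOR({dim}) NOT NULL,
            metadata TEXT
        )
    """
    # Parameters: (vector literal, k). A smaller distance means more similar.
    search_sql = f"""
        SELECT text_content, metadata
        FROM {table_name}
        ORDER BY VEC_DISTANCE_COSINE(embedding, VEC_FromText(%s))
        LIMIT %s
    """
    return create_sql, search_sql


create_sql, search_sql = build_native_queries("langchain_demo_vectors", 384)
params = (to_vector_literal([1, 0.5]), 2)
```

The dimension 384 matches the output size of `all-MiniLM-L6-v2`. Inserts would also need to convert each embedding, for example with `VEC_FromText(%s)` in the `INSERT` statement.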

In [None]:
async def main():
    print("--- LangChain RAG Demo with Async MariaDB ---")
    
    # Sample documents for our knowledge base
    documents = [
        "MariaDB is a community-developed, commercially supported fork of the MySQL relational database management system.",
        "Asyncio is a library to write concurrent code using the async/await syntax.",
        "LangChain is a framework for developing applications powered by language models.",
        "Vector databases are used to store and query embeddings for similarity search."
    ]
    metadatas = [{"source": f"doc_{i}"} for i in range(len(documents))]

    db = AsyncMariaDB()
    await db.initialize()

    # Define the table name before `try` so the cleanup in `finally` can always reference it
    table_name = "langchain_demo_vectors"

    try:
        # Create table for the vector store
        await db.execute(f"""
            CREATE TABLE IF NOT EXISTS {table_name} (
                id INT AUTO_INCREMENT PRIMARY KEY,
                text_content TEXT,
                vector BLOB,
                metadata VARCHAR(255)
            )
        """)
        
        # Initialize embeddings and vector store
        embeddings = LocalEmbeddings()
        vector_store = MariaDBVectorStore(db, table_name, embeddings)
        
        # Add documents to the vector store
        print("Adding documents to MariaDB vector store...")
        await vector_store.aadd_texts(documents, metadatas=metadatas)
        print("Documents added.")

        # Perform a similarity search (RAG query)
        query = "What is MariaDB?"
        print(f"\nPerforming similarity search for: '{query}'")
        results = await vector_store.asimilarity_search(query, k=2)
        
        print("\nTop relevant documents from MariaDB:")
        for doc in results:
            print(f"  - Content: {doc.page_content}")
            print(f"    Source: {doc.metadata.get('source')}")

    finally:
        # Clean up: drop the demo table, and close the connection even if the drop fails
        try:
            await db.execute(f"DROP TABLE IF EXISTS {table_name}")
            print(f"\nTable '{table_name}' dropped.")
        finally:
            await db.close()
            print("Connection closed.")
        print("\n--- LangChain RAG Demo Complete ---")

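# Run the async main function.
# Jupyter/IPython supports top-level `await`, so the call below works as-is in a notebook.
# In a plain Python script, use `asyncio.run(main())` instead.
# Calling `asyncio.run()` from inside a notebook raises
# "RuntimeError: asyncio.run() cannot be called from a running event loop";
# nest_asyncio can work around that:
# import nest_asyncio
# nest_asyncio.apply()
await main()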
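
How a coroutine should be started depends on whether an event loop is already running. The sketch below shows one way to check, using only the standard library; `inside_running_loop` and `demo_main` are hypothetical names standing in for the notebook's `main()`.

```python
import asyncio


def inside_running_loop() -> bool:
    """Return True when called from code that an event loop is executing."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return False
    return True


async def demo_main() -> str:
    # Stand-in for the notebook's main(); yields control once and returns.
    await asyncio.sleep(0)
    return "done"


async def check_inside() -> bool:
    # Called from within a coroutine, so a loop is running here.
    return inside_running_loop()


if not inside_running_loop():
    # Plain script: no loop is running yet, so asyncio.run() may start one.
    result = asyncio.run(demo_main())
else:
    # Notebook cell: a loop is already running; use `result = await demo_main()` instead.
    result = None
```

In a notebook the second branch applies, which is why the cell above calls `await main()` directly.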