# Using Semantic Chunks with the Gemini API and Gemini Embeddings

In [None]:
# Regular Imports
import os
import glob
import time
from dotenv import load_dotenv
from tqdm.notebook import tqdm
import gradio as gr

In [None]:
# Visual Import
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
import plotly.graph_objects as go

In [None]:
# Lang Chain Imports

from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_core.messages import HumanMessage, AIMessage
from langchain_chroma import Chroma
from langchain_experimental.text_splitter import SemanticChunker
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.runnables import RunnableLambda

In [None]:
# Constants

CHAT_MODEL = "gemini-2.5-flash"
EMBEDDING_MODEL = "models/text-embedding-004"
# EMBEDDING_MODEL_EXP = "models/gemini-embedding-exp-03-07"

# Each sub-directory of knowledge-base/ is one document category; the loading
# cell uses the directory name as the doc_type metadata. Stray files at the top
# level are skipped (DirectoryLoader expects a directory), and sorting keeps the
# load order stable across platforms.
folders = sorted(path for path in glob.glob("knowledge-base/*") if os.path.isdir(path))
text_loader_kwargs = {'encoding': 'utf-8'}
db_name = "vector_db"

In [None]:
# Read variables from a local .env file into the process environment;
# override=True makes the .env values take precedence over already-set ones.
load_dotenv(override=True)

api_key = os.getenv("GOOGLE_API_KEY")
key_status = "API Key loaded in memory" if api_key else "API Key not found!"
print(key_status)

In [None]:
def add_metadata(doc, doc_type):
    """Tag *doc* with its category under the ``doc_type`` metadata key.

    The document's ``metadata`` dict is modified in place, and the very same
    document object is returned so the call can be used inside a comprehension.
    """
    tags = doc.metadata
    tags.update(doc_type=doc_type)
    return doc

In [None]:
# Load every Markdown file under each category folder and tag it with the
# folder's name as its doc_type.
documents = []
for folder_path in tqdm(folders, desc="Loading folders"):
    category = os.path.basename(folder_path)
    md_loader = DirectoryLoader(
        folder_path,
        glob="**/*.md",
        loader_cls=TextLoader,
        loader_kwargs=text_loader_kwargs,
    )
    for loaded_doc in md_loader.load():
        documents.append(add_metadata(loaded_doc, category))

print(f"Total documents loaded: {len(documents)}")

## Create Semantic Chunks

In [None]:
# Split each loaded document into semantically coherent chunks. SemanticChunker
# embeds groups of sentences with chunking_embedding_model and starts a new
# chunk where the embedding distance between neighbours jumps.
chunking_embedding_model = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, task_type="retrieval_document")

text_splitter = SemanticChunker(
    chunking_embedding_model,
    # Break where the distance between adjacent sentence groups exceeds the
    # 95th percentile of all such distances within the document.
    breakpoint_threshold_type="percentile",
    breakpoint_threshold_amount=95.0,
    # NOTE(review): langchain_experimental compares this against the chunk's
    # text length (characters), so 3 filters almost nothing - confirm intent.
    min_chunk_size=3,
)

# Chunk one document at a time so a single failure (e.g. an API error) skips
# only that document instead of aborting the whole run. split_documents copies
# each document's metadata (including doc_type) onto its chunks.
start = time.perf_counter()

semantic_chunks = []
failed_sources = []
pbar = tqdm(documents, desc="Semantic chunking documents")

for doc in pbar:
    doc_type = doc.metadata.get('doc_type', 'Unknown')
    pbar.set_postfix_str(f"Processing: {doc_type}")
    try:
        doc_chunks = text_splitter.split_documents([doc])
    except Exception as e:
        source = doc.metadata.get('source', 'unknown source')
        failed_sources.append(source)
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write(f"❌ Failed to split doc ({source}): {e}")
    else:
        semantic_chunks.extend(doc_chunks)

print(f"⏱️ Took {time.perf_counter() - start:.2f} seconds")
print(f"Total semantic chunks: {len(semantic_chunks)}")
if failed_sources:
    print(f"⚠️ {len(failed_sources)} document(s) failed to chunk: {failed_sources}")

In [None]:
# Some Preview of the chunks
for i, doc in enumerate(semantic_chunks[:15]):
    print(f"--- Chunk {i+1} ---")
    print(doc.page_content) 
    print("\n")

## Embed with Gemini Embeddings

In [None]:
# Index the semantic chunks in a persistent Chroma store.
#
# No task_type is pinned on this instance: langchain-google-genai then embeds
# via embed_documents as RETRIEVAL_DOCUMENT and via embed_query as
# RETRIEVAL_QUERY. Chroma reuses this same object to embed user questions at
# retrieval time, so pinning task_type="retrieval_document" here would make
# queries use the document task type too.
embedding = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL)

# Check before touching the existing store: with no chunks we would otherwise
# delete the old collection and index nothing.
if not semantic_chunks:
    raise ValueError("No semantic chunks to index - check the knowledge-base folder and the chunking step.")

# Start from a clean collection so re-running the notebook does not add
# duplicate copies of every chunk.
if os.path.exists(db_name):
    Chroma(persist_directory=db_name, embedding_function=embedding).delete_collection()

vectorstore = Chroma.from_documents(
    documents=semantic_chunks,
    embedding=embedding,
    persist_directory=db_name
)

print(f"✅ Vectorstore created with {vectorstore._collection.count()} documents")

## Visualizing Vectors

In [None]:
# Pull every stored vector back out of Chroma with its text and metadata.
collection = vectorstore._collection
result = collection.get(include=['embeddings', 'documents', 'metadatas'])
vectors = np.array(result['embeddings'])
# NOTE: this rebinds `documents` (the loaded Document objects) to plain chunk
# strings, which the plotting cell uses for hover text. Re-run the loading
# cell before re-running the chunking cell.
documents = result['documents']
metadatas = result['metadatas']
# Chroma may return None for a record without metadata; fall back to 'Unknown'.
doc_types = [(metadata or {}).get('doc_type', 'Unknown') for metadata in metadatas]

# Known categories get fixed colours; any other folder name is drawn in gray
# instead of raising ValueError as a list.index lookup would.
doc_type_colors = {'products': 'blue', 'employees': 'green', 'contracts': 'red', 'company': 'orange'}
colors = [doc_type_colors.get(t, 'gray') for t in doc_types]

In [None]:
# We humans find it easier to visualize things in 2D!
# Reduce the dimensionality of the vectors to 2D using t-SNE
# (t-distributed stochastic neighbor embedding).

# scikit-learn requires perplexity < n_samples and defaults to 30, which raises
# for small knowledge bases; shrink it when there are few chunks.
perplexity = min(30.0, max(1.0, len(vectors) - 1))
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
reduced_vectors = tsne.fit_transform(vectors)

# Create the 2D scatter plot
fig = go.Figure(data=[go.Scatter(
    x=reduced_vectors[:, 0],
    y=reduced_vectors[:, 1],
    mode='markers',
    marker=dict(size=5, color=colors, opacity=0.8),
    text=[f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(doc_types, documents)],
    hoverinfo='text'
)])

# Axis titles go on the 2D axes directly; `scene` only applies to 3D plots and
# was silently ignored here.
fig.update_layout(
    title='2D Chroma Vector Store Visualization',
    xaxis_title='x',
    yaxis_title='y',
    width=800,
    height=600,
    margin=dict(r=20, b=10, l=10, t=40)
)

fig.show()

## RAG Setup

In [None]:
retriever = vectorstore.as_retriever(search_kwargs={"k": 20})

chat_llm = ChatGoogleGenerativeAI(model=CHAT_MODEL, temperature=0.7)

# Instruction for the query-rewriting step of the history-aware retriever.
# Previously the rewriting prompt contained only the history and the question,
# with no instruction, so the LLM answered the question and that answer was
# used as the retrieval query. The history and new question arrive as messages
# (placeholder + human turn), so this text has no template variables.
question_generator_template = """Given the conversation so far and a follow up question, rephrase the follow up question to be a standalone question.
If the follow up question is already a standalone question, return it as is.
Do NOT answer the question; only return the standalone question."""

question_generator_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(question_generator_template),
    MessagesPlaceholder(variable_name="chat_history"),
    HumanMessagePromptTemplate.from_template("{input}"),
])

history_aware_retriever = create_history_aware_retriever(
    chat_llm, retriever, question_generator_prompt
)

# System prompt for the answering step; {context} is filled with the retrieved
# chunks by create_stuff_documents_chain.
qa_system_prompt = """You are Insurellm’s intelligent virtual assistant, designed to answer questions with accuracy and clarity. Respond naturally and helpfully, as if you're part of the team.
Use the retrieved documents and prior conversation to provide accurate, conversational, and concise answers.Rephrase source facts in a natural tone, not word-for-word.
When referencing people or company history, prioritize clarity and correctness.
Only infer from previous conversation if it provides clear and factual clues. Do not guess or assume missing information.
If you truly don’t have the answer, respond with:
"I don't have that information."
Avoid repeating the user's wording unnecessarily. Do not refer to 'the context', speculate, or make up facts.

{context}"""

qa_human_prompt = "{input}"

qa_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(qa_system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    HumanMessagePromptTemplate.from_template(qa_human_prompt),
])

combine_docs_chain = create_stuff_documents_chain(chat_llm, qa_prompt)

# Output is a dict carrying input, chat_history, context and answer.
base_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)


# Per-session chat history, kept in process memory only (lost on restart).
# To persist history across sessions, swap InMemoryChatMessageHistory for a
# database-backed history class (e.g. MongoDBChatMessageHistory).
chat_histories = {}

def get_history(session_id: str) -> InMemoryChatMessageHistory:
    """Return the message history for *session_id*, creating it on first use."""
    if session_id not in chat_histories:
        chat_histories[session_id] = InMemoryChatMessageHistory()
    return chat_histories[session_id]

# Wraps base_chain so each call reads prior turns from get_history(session_id)
# into "chat_history" and afterwards records the "input" and the "answer" field
# of the output dict as the new user/AI turns.
conversation_chain = RunnableWithMessageHistory(
    base_chain,
    get_history,
    output_messages_key="answer",
    input_messages_key="input",
    history_messages_key="chat_history",
)

In [None]:
def chat(question, history):
    """Stream an answer to *question* from the RAG conversation chain for Gradio.

    Gradio's ChatInterface calls this with the new user message and its own copy
    of the conversation (*history*). That copy is not used: prior turns are kept
    by ``conversation_chain`` under a single fixed session id.

    Yields the answer accumulated so far after each streamed ``answer`` fragment,
    so the UI shows the response growing. On any error, yields an apology.
    """
    # NOTE(review): every browser tab shares this one session, and clearing the
    # Gradio chat does not reset it - consider a per-user id (e.g. from gr.Request).
    session_id = "default-session"
    config = {"configurable": {"session_id": session_id}}

    response_buffer = ""
    try:
        for chunk in conversation_chain.stream({"input": question}, config=config):
            if "answer" in chunk:
                response_buffer += chunk["answer"]
                yield response_buffer
    except Exception as e:
        print(f"An error occurred during chat: {e}")
        # This function is a generator: a `return value` would be discarded and
        # the user would see nothing, so the apology must be yielded.
        yield "I apologize, but I encountered an error and cannot answer that right now."

In [None]:
# Build the chat UI around the streaming `chat` generator and open it in a browser.
chat_ui = gr.ChatInterface(chat, type="messages")
view = chat_ui.launch(inbrowser=True)