<a href="https://colab.research.google.com/github/msquareddd/ai-engineering-notebooks/blob/main/simple_rag_llamaindex.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Additional Installations

In [None]:
!pip install -q requests torch bitsandbytes transformers sentencepiece accelerate
!pip install -q llama-index-llms-ollama llama-index-embeddings-huggingface
!pip install -q llama-index-llms-huggingface llama-index-llms-langchain
!pip install -q llama-index-readers-file
!pip install -q docx2txt pypdf
!pip install -q langchain langchain_community langchain_core langgraph huggingface_hub
!pip install -U -q bitsandbytes

# Imports

In [None]:
import torch, os
import gradio as gr
from llama_index.llms.langchain import LangChainLLM
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import (
    SimilarityPostprocessor,
    KeywordNodePostprocessor,
    MetadataReplacementPostProcessor,
    LongContextReorder,
)
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          BitsAndBytesConfig,Gemma3ForConditionalGeneration,
                          AutoProcessor)
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_core.callbacks import StdOutCallbackHandler
from transformers import pipeline
from google.colab import userdata, files
from huggingface_hub import login
from langchain.schema import BaseRetriever, Document
from typing import Any, Optional, List
from pydantic import Field, ConfigDict

# Hugging Face Login

In [None]:
# Read credentials from Colab's Secrets panel.

# LangSmith tracing is optional: enable it only when the LANGSMITH_API_KEY
# secret is available. userdata.get raises (rather than returning None) when a
# secret is missing or the notebook has no access to it, so an unconditional
# lookup would make this optional step abort the whole cell.
try:
    langsmith_api_key = userdata.get('LANGSMITH_API_KEY')
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    langsmith_api_key = None

if langsmith_api_key:
    os.environ["LANGSMITH_TRACING"] = "true"
    os.environ["LANGSMITH_API_KEY"] = langsmith_api_key
    os.environ["LANGSMITH_PROJECT"] = "RAG LlamaIndex"
else:
    print("LANGSMITH_API_KEY not available; LangSmith tracing disabled.")

# Hugging Face login (presumably required to download gated models such as
# the Gemma checkpoints used below — confirm for other model choices).
hf_token = userdata.get('HF_TOKEN')
login(hf_token, add_to_git_credential=True)

# Directory Creation

In [None]:
base_path = "/content/data"

# Folder the documents are later loaded from; create it on first run only.
if os.path.exists(base_path):
    print(f"Directory '{base_path}' already exists.")
else:
    os.makedirs(base_path)
    print(f"Directory '{base_path}' created successfully.")

# LLM Initialization

In [None]:
# Hugging Face Hub model IDs.
# Alternatives kept for reference:
#   LLM:        "meta-llama/Llama-3.2-3B-Instruct", "HuggingFaceTB/SmolLM3-3B",
#               "meta-llama/Llama-3.1-8B-Instruct", "Qwen/Qwen3-14B",
#               "MiniMaxAI/MiniMax-M1-80k"
#   Embeddings: "BAAI/bge-m3",
#               "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
#               "nomic-ai/nomic-embed-text-v1.5", "BAAI/bge-base-en-v1.5"

# Instruction-tuned model that generates the answers.
LLM_MODEL = "google/gemma-3-4b-it"
# Embedding model used to index and retrieve document chunks.
EMBED_MODEL = "google/embeddinggemma-300m"

In [None]:
# Quantization config: load the LLM weights in 8-bit to reduce GPU memory use.
# Only `load_in_8bit` is set on purpose. The previously passed
# `bnb_8bit_compute_dtype`, `bnb_8bit_quant_type` and
# `bnb_8bit_use_double_quant` are not BitsAndBytesConfig parameters (the
# compute-dtype / quant-type / double-quant options exist only as `bnb_4bit_*`
# for 4-bit loading), so transformers ignored them as unused kwargs.
quant_config = BitsAndBytesConfig(load_in_8bit=True)

# Tokenizer for the chosen LLM.
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)

# Load the Gemma 3 model with 8-bit weights; device_map="auto" lets the
# layers be placed on the available device(s), torch_dtype is bfloat16.
model = Gemma3ForConditionalGeneration.from_pretrained(
    LLM_MODEL,
    device_map="auto",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Text-generation pipeline that returns only the newly generated text
# (return_full_text=False), sampled at a low temperature.
llm_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    repetition_penalty=1.1,
    torch_dtype=torch.bfloat16,
)

# Wrap with LangChain; this `llm` is used both by the LangChain chain and,
# via LangChainLLM, by LlamaIndex.
llm = HuggingFacePipeline(pipeline=llm_pipeline)
print("✅ Model loaded successfully!")

In [None]:
# Ask PyTorch to release cached, currently unused CUDA memory after model loading.
torch.cuda.empty_cache()

# LlamaIndex Settings

In [None]:
# Global LlamaIndex defaults: the embedding model used to embed chunks and
# queries, and the LLM (the LangChain `llm` wrapped by LangChainLLM) used by
# LlamaIndex components that generate text.
Settings.embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL)
Settings.llm = LangChainLLM(llm)  # `llm` is the HuggingFacePipeline wrapper
# Chunking is configured explicitly via `node_parser` (SentenceSplitter)
# rather than Settings.chunk_size / Settings.chunk_overlap.

In [None]:
# Create a node parser
# Sentence-aware chunker: chunk_size=512 with chunk_overlap=50 between
# consecutive chunks (LlamaIndex measures both in tokens).
node_parser = SentenceSplitter(chunk_size=512, chunk_overlap=50)

In [None]:
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

# Load every file from the data directory created earlier. `base_path` is
# reused instead of repeating the literal so the two paths cannot drift apart.
documents = SimpleDirectoryReader(base_path).load_data()

for i, doc in enumerate(documents, 1):
    print(f"Document {i}: {doc.metadata.get('file_name', 'Unknown')} - {len(doc.text)} characters")

# Split the documents into chunks (nodes) with the sentence splitter above.
nodes = node_parser.get_nodes_from_documents(documents, show_progress=True)
print(f"Created {len(nodes)} text chunks")

In [None]:
# Build a vector index over the chunks; each node is embedded with
# Settings.embed_model (no custom storage context is supplied).
index = VectorStoreIndex(nodes, show_progress=True)
print("Index created successfully!")

In [None]:
# How many chunks to fetch per query; the LangChain wrapper reuses this value
# as its own document limit.
top_k = 3
retriever = VectorIndexRetriever(similarity_top_k=top_k, index=index)

In [None]:
# Assemble the query engine: fetch `top_k` chunks with `retriever`, then run
# the node postprocessors below, in order, on the retrieved nodes.
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[
        # Drop nodes whose similarity score is below 0.2.
        SimilarityPostprocessor(similarity_cutoff=0.2),
        # Optional keyword filter (disabled):
        # KeywordNodePostprocessor(keywords=["important", "critical"]),  # Filter by keywords
        LongContextReorder(),  # Reorder nodes for better context window usage
    ],
)

# LangChain Wrapper

In [None]:
class LlamaIndexRetrieverWrapper(BaseRetriever):
    """Adapt a LlamaIndex query engine or retriever to LangChain's retriever API.

    LangChain components such as ``ConversationalRetrievalChain`` call
    ``_get_relevant_documents``. This wrapper fetches nodes from LlamaIndex
    and converts each into a LangChain ``Document``. The document's metadata
    is a copy of the node metadata plus the node's similarity ``score``.

    Args:
        query_engine: LlamaIndex query engine (preferred). It takes precedence
            when both sources are given, so the engine's node postprocessors
            are applied to the retrieved nodes.
        llamaindex_retriever: Plain LlamaIndex retriever. It is used only when
            no query engine is given; ``similarity_cutoff`` is then applied here.
        top_k: Maximum number of documents returned per query.
        similarity_cutoff: Minimum score kept on the plain-retriever path.
            Defaults to 0.5, the previously hard-coded value.
        verbose: Print every converted document (debugging aid; default True).
        **kwargs: Forwarded to ``BaseRetriever``.

    Raises:
        ValueError: If neither ``query_engine`` nor ``llamaindex_retriever``
            is provided.
    """

    # Declared explicitly because BaseRetriever is a pydantic model;
    # exclude=True keeps these objects out of serialized output.
    query_engine: Optional[Any] = Field(default=None, exclude=True)
    llamaindex_retriever: Optional[Any] = Field(default=None, exclude=True)
    use_engine: bool = Field(default=False, exclude=True)
    top_k: int = Field(default=3, exclude=True)
    similarity_cutoff: float = Field(default=0.5, exclude=True)
    verbose: bool = Field(default=True, exclude=True)

    def __init__(self, query_engine=None, llamaindex_retriever=None, top_k=3,
                 similarity_cutoff=0.5, verbose=True, **kwargs):
        if query_engine is None and llamaindex_retriever is None:
            raise ValueError("Must provide either query_engine or llamaindex_retriever")
        # Initialise every field through pydantic instead of bypassing it with
        # object.__setattr__ after construction.
        super().__init__(
            query_engine=query_engine,
            llamaindex_retriever=llamaindex_retriever,
            # The query engine wins when both sources are supplied.
            use_engine=query_engine is not None,
            top_k=top_k,
            similarity_cutoff=similarity_cutoff,
            verbose=verbose,
            **kwargs,
        )

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``top_k`` LangChain documents relevant to ``query``."""
        if self.use_engine:
            # Nodes are already filtered/reordered by the engine's postprocessors.
            nodes = self._retrieve_from_engine(query)[:self.top_k]
            label = "with engine"
        else:
            # No postprocessors on this path: apply the similarity cutoff here.
            nodes = [
                node
                for node in self.llamaindex_retriever.retrieve(query)[:self.top_k]
                if not self._below_cutoff(node)
            ]
            label = "no engine"

        docs = []
        for node in nodes:
            doc = self._node_to_document(node)
            docs.append(doc)
            if self.verbose:
                print(f"{label} - doc appended: {doc}")
        return docs

    def _retrieve_from_engine(self, query: str) -> list:
        """Fetch post-processed nodes from the query engine without answer synthesis.

        ``RetrieverQueryEngine.retrieve`` runs the retriever plus the node
        postprocessors only. ``query()`` would additionally have the LLM write
        an answer that is discarded here; the LangChain chain generates its
        own answer. That costs an extra LLM generation per retrieval.
        """
        from llama_index.core.schema import QueryBundle  # only needed here

        retrieve = getattr(self.query_engine, "retrieve", None)
        if retrieve is not None:
            try:
                return retrieve(QueryBundle(query_str=query))
            except NotImplementedError:
                pass  # engine does not support retrieve(); fall back to query()
        return list(self.query_engine.query(query).source_nodes)

    def _below_cutoff(self, node) -> bool:
        """Return True when ``node`` has a numeric score under ``similarity_cutoff``."""
        score = getattr(node, "score", None)
        # Unscored nodes (score None) are kept instead of raising TypeError.
        return score is not None and score < self.similarity_cutoff

    @staticmethod
    def _node_to_document(node) -> Document:
        """Convert a LlamaIndex node into a LangChain Document (metadata + score)."""
        metadata = dict(getattr(node, "metadata", None) or {})
        if hasattr(node, "score"):
            metadata["score"] = node.score
        return Document(page_content=node.text, metadata=metadata)

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        """Async entry point; runs the synchronous implementation directly."""
        return self._get_relevant_documents(query)

In [None]:
# Create the LangChain-compatible retriever. Because `query_engine` is given,
# the wrapper retrieves through it (so its node postprocessors apply); the
# plain `retriever` is not used on that path.
langchain_retriever = LlamaIndexRetrieverWrapper(query_engine=query_engine, llamaindex_retriever=retriever, top_k=top_k)

# Prompt Template

In [None]:
# Prompt used by the chain's document-combining step (passed below through
# combine_docs_chain_kwargs). {context} receives the retrieved document text
# and {question} the question; both are declared as input variables.
prompt_template = """You are a helpful multilingual assistant answering questions based on provided documents and context.
The documents may be in English or Italian. You should answer in the same language as the question.

Instructions:
- Answer based ONLY on the provided context
- Be concise and direct
- If the answer is not in the context, say "I don't have enough information" (or "Non ho abbastanza informazioni" in Italian)
- Respond in the same language as the question
- The documents provided might not be related to the topic, feel free to ignore them if you deem them unrelated
- Just provide a direct answer to the question asked

{context}

Question: {question}

Answer:"""

PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

In [None]:
# Debugging aid (disabled): preview the exact prompt sent to the LLM.
# question = "question"

# response = query_engine.query(question)
# # print(response)

# context = "Context:\n"
# # Iterate over the source nodes actually returned; after the similarity
# # cutoff there may be fewer than top_k, so don't index with range(top_k).
# for node in response.source_nodes:
#     context = context + node.text + "\n\n"

# # print(context)

# prompt_text = PROMPT.format(
#             context=context,
#             question=question)

# # print(prompt_text)

# Chain Setup

In [None]:
# Conversation memory for the chain. The history is stored under
# 'chat_history' as message objects. output_key='answer' selects which chain
# output is saved, since the chain also returns 'source_documents'.
memory = ConversationBufferMemory(
    memory_key='chat_history',
    return_messages=True,
    output_key='answer'
)

In [None]:
# Create the conversational chain. It retrieves context through the wrapped
# LlamaIndex retriever, answers with `llm` using PROMPT, records the dialogue
# in `memory`, and returns the source documents with each answer.
conversation_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=langchain_retriever,  # Use the wrapped retriever
    memory=memory,
    callbacks=[StdOutCallbackHandler()],  # print chain callback events to stdout
    combine_docs_chain_kwargs={"prompt": PROMPT},  # custom answer prompt
    return_source_documents=True,
    verbose=False
)

In [None]:
# Ask PyTorch to release cached, currently unused CUDA memory before querying.
torch.cuda.empty_cache()

# Query

In [None]:
# Sample question used to try the chain before launching the UI.
message = "how was this problem solved?"

In [None]:
# Debugging aid (disabled): inspect what the retriever returns for `message`.
# (`invoke` replaces the deprecated `get_relevant_documents`.)
# docs = langchain_retriever.invoke(message)
# print(docs[0].page_content)

In [None]:
# Run the sample question through the chain (the chain's memory records this turn).
result = conversation_chain.invoke({"question": message})
# print(result)
answer = result.get("answer", "Sorry, I couldn't process your question.")
print(answer)

# Gradio Chat Interface

In [None]:
# Function for Gradio chat interface
def chat_function(message, history):
    """Answer one Gradio chat turn with the conversational RAG chain.

    Args:
        message: The user's latest message.
        history: The conversation so far, as tracked by Gradio. Only its
            emptiness is used. The chain keeps its own global ``memory``,
            which is cleared when Gradio starts a new (or cleared) chat so
            turns from an earlier conversation are not fed back into the chain.

    Returns:
        The chain's answer, followed by snippets (up to 100 characters) of up
        to three source documents. If anything fails, ``"Error: ..."``.
    """
    import traceback  # only needed for error reporting below

    try:
        # Keep the chain's memory in sync with the Gradio session.
        if not history:
            memory.clear()
        result = conversation_chain.invoke({"question": message})
    except Exception as e:
        # UI boundary: show the error in the chat instead of crashing the app,
        # but keep the full traceback in the notebook output for debugging.
        traceback.print_exc()
        return f"Error: {str(e)}"

    answer = result.get("answer", "Sorry, I couldn't process your question.")

    # Append short previews of the documents the answer was based on.
    sources = result.get("source_documents", [])
    if sources:
        answer += "\n\n📚 Sources used:"
        for i, doc in enumerate(sources[:3], 1):
            text = doc.page_content
            snippet = text[:100] + "..." if len(text) > 100 else text
            answer += f"\n{i}. {snippet}"

    return answer

In [None]:
# Gradio chat UI backed by chat_function. The examples are starter prompts
# shown in the UI; replace the "[name]" placeholder with a real project name.
chat_interface = gr.ChatInterface(
    fn=chat_function,
    title="RAG LlamaIndex",
    description="Ask questions about your documents. The bot remembers conversation history!",
    examples=[
        "Can you summarize the activities performed in the [name] project?",
        "What are the main findings?",
        "Tell me more about the first point you mentioned"
    ]
)

In [None]:
# Launch the chat UI. Gradio's debug=True blocks the cell while the app runs,
# which in Colab lets errors appear in the cell output.
chat_interface.launch(debug=True)