# Introduction

This Jupyter notebook walks you through building a RAG (retrieval-augmented generation) system from scratch.

I strongly recommend reading the [README](./readme.md) for background on the topic before diving into the code.


# Set Up Dev Environment


## Python Virtual Environment

- [Read this primer](https://realpython.com/python-virtual-environments-a-primer/) to see why a venv is useful.
- Run the cell below to create a venv.


In [None]:
# Create a Python virtual environment
!python -m venv rag_venv

# Once the environment is activated, select it as your Jupyter kernel (kernel picker at the top right of the notebook).

## Install Packages


In [None]:
# Install all dependencies
!pip install -r requirements.txt


## Load API Key

The Groq API key is read from an environment variable, which the [python-dotenv package](https://pypi.org/project/python-dotenv/) loads from a `.env` file.

Get your Groq key by following the [Groq quickstart](https://console.groq.com/docs/quickstart).
Once you have the API key, create a `.env` file in the root of the downloaded Git repository.
Add this line to the newly created file:

`GROQ_API_KEY="your_key"`


In [None]:
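import os
from dotenv import load_dotenv

# Load variables from the .env file into the process environment.
load_dotenv()

# Confirm the key is available without printing the secret itself.
print("GROQ_API_KEY loaded:", bool(os.getenv('GROQ_API_KEY')))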
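
When checking that the key loaded, it helps to avoid echoing the full secret into the notebook output. The following is a minimal sketch, assuming a hypothetical helper named `describe_secret` (it is not part of python-dotenv or the Groq SDK), that reports only whether the variable is set and its last few characters:

```python
import os


def describe_secret(name: str, visible: int = 4) -> str:
    """Return a masked description of an environment variable (hypothetical helper)."""
    value = os.environ.get(name)
    if not value:
        return f"{name} is not set"
    # Show only the tail so the full secret never reaches the cell output.
    tail = value[-visible:] if len(value) > visible else ""
    return f"{name} is set (ends with ...{tail})"


print(describe_secret("GROQ_API_KEY"))
```

If the message says the variable is not set, check that the `.env` file sits in the directory the notebook is running from.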


# Run RAGify


In [73]:
import os
import groq
from typing import List
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from pypdf import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import json
import logging
from dotenv import load_dotenv
import streamlit as st
import hashlib
import time
from groq import RateLimitError

# I have loaded environment variables to keep sensitive information out of the codebase.
# This is crucial for security and allows for easy configuration changes across environments.
load_dotenv()

# I have set up logging to track execution and debug issues.
# Proper logging is essential for monitoring and troubleshooting in production environments.
logging.basicConfig(level=logging.INFO)

# I have initialized the Groq client for API access.
# Here, I'm using an API key stored in environment variables for security.
# The commented out line shows an alternative using Streamlit secrets, which is useful for deployment scenarios.
client = groq.Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
    #api_key=st.secrets["GROQ_API_KEY"], # GROQ_API_KEY = ""
)

# I have loaded a pre-trained sentence transformer model for generating text embeddings.
# I chose 'all-mpnet-base-v2' for its balance of performance and accuracy.
# This model is crucial for converting text to vector representations for similarity search.
model_name = 'all-mpnet-base-v2'
model = SentenceTransformer(model_name)

def extract_text_from_pdf(pdf_path):
    # I have extracted text from PDFs to make the content searchable.
    # This allows us to work with various document formats in a unified way.
    with open(pdf_path, 'rb') as file:
        reader = PdfReader(file)
        text = ''
        for page in reader.pages:
            text += page.extract_text() + '\n'
    return text

def create_chunks(text, chunk_size=1000, chunk_overlap=200):
    # I have chunked the text for two main reasons:
    # 1. It allows us to process long documents that might exceed model token limits.
    # 2. It creates more granular pieces of text, improving retrieval accuracy.
    # I have used overlap to maintain context between chunks.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    chunks = text_splitter.split_text(text)
    return chunks

def get_files_hash(directory):
    # I have hashed the input files to detect changes.
    # This is crucial for maintaining an up-to-date knowledge base without unnecessary reprocessing.
    hash_md5 = hashlib.md5()
    for filename in sorted(os.listdir(directory)):
        if filename.endswith('.pdf'):
            with open(os.path.join(directory, filename), "rb") as f:
                for chunk in iter(lambda: f.read(4096), b""):
                    hash_md5.update(chunk)
    return hash_md5.hexdigest()

@st.cache_data
def _process_pdfs_cached(files_hash, pdf_directory):
    # I have used caching to avoid reprocessing PDFs on every run.
    # The file hash is part of the cache key, so changed input files trigger reprocessing.
    all_chunks = []
    chunk_to_doc = {}
    for filename in sorted(os.listdir(pdf_directory)):
        if filename.endswith('.pdf'):
            pdf_path = os.path.join(pdf_directory, filename)
            text = extract_text_from_pdf(pdf_path)
            chunks = create_chunks(text)
            all_chunks.extend(chunks)
            for chunk in chunks:
                chunk_to_doc[chunk] = filename

    # I have used logging to help with debugging and monitoring the chunking process.
    logging.info(f"Total chunks: {len(all_chunks)}")
    if all_chunks:  # Avoid an IndexError when no PDF text was extracted
        logging.info(f"Sample chunk: {all_chunks[0][:100]}...")

    return all_chunks, chunk_to_doc, files_hash

def process_pdfs(_hash=None):
    # Streamlit skips arguments whose names start with an underscore when building the cache key,
    # so `_hash` alone cannot invalidate the cache. I have recomputed the hash on every call instead;
    # `_hash` is kept only so existing callers keep working.
    pdf_directory = './input_files/'
    return _process_pdfs_cached(get_files_hash(pdf_directory), pdf_directory)

@st.cache_resource
def create_faiss_index(all_chunks):
    # I have used FAISS for efficient similarity search.
    # This is crucial for quickly finding relevant chunks when answering queries.
    embeddings = model.encode(all_chunks)
    dimension = embeddings.shape[1]
    num_chunks = len(all_chunks)

    # I have dynamically chosen the index type based on the dataset size.
    # This optimizes search performance: FlatL2 for small datasets, IVFFlat for larger ones.
    if num_chunks < 100:
        logging.info("Using FlatL2 index due to small number of chunks")
        index = faiss.IndexFlatL2(dimension)
    else:
        logging.info("Using IVFFlat index")
        n_clusters = min(int(np.sqrt(num_chunks)), 100)  # Balancing clustering and search efficiency
        quantizer = faiss.IndexFlatL2(dimension)
        index = faiss.IndexIVFFlat(quantizer, dimension, n_clusters)
        index.train(embeddings.astype('float32'))

    index.add(embeddings.astype('float32'))
    return index

# I have initialized a cache for storing query results.
# Caching improves response times for repeated or similar queries.
cache_file = 'semantic_cache.json'

def load_cache():
    # I have loaded the cache from a file to persist it across sessions.
    # This improves the system's efficiency over time.
    try:
        with open(cache_file, 'r') as f:
            cache = json.load(f)
            # I have reset the cache if the embedding model changes to ensure consistency.
            if cache.get('model_name') != model_name:
                logging.info("Embedding model changed. Resetting cache.")
                return {"queries": [], "embeddings": [], "responses": [], "model_name": model_name}
            return cache
    except FileNotFoundError:
        return {"queries": [], "embeddings": [], "responses": [], "model_name": model_name}

def save_cache(cache):
    # I have regularly saved the cache to ensure we don't lose valuable precomputed results.
    with open(cache_file, 'w') as f:
        json.dump(cache, f)

cache = load_cache()

def retrieve_from_cache(query_embedding, threshold=0.5):
    # I have implemented semantic caching to reuse results for similar queries.
    # This significantly reduces API calls and improves response times.
    for i, cached_embedding in enumerate(cache['embeddings']):
        if len(cached_embedding) != len(query_embedding):
            logging.warning("Cached embedding dimension mismatch. Skipping cache entry.")
            continue
        distance = np.linalg.norm(query_embedding - np.array(cached_embedding))
        if distance < threshold:
            return cache['responses'][i]
    return None

def update_cache(query, query_embedding, response):
    # I have updated the cache with new queries to continually improve performance.
    cache['queries'].append(query)
    cache['embeddings'].append(query_embedding.tolist())
    cache['responses'].append(response)
    cache['model_name'] = model_name
    save_cache(cache)

def retrieve_relevant_chunks(query, index, all_chunks, top_k=10):
    # I have used vector similarity to find the most relevant chunks.
    # This is more effective than keyword matching for understanding context and semantics.
    query_vector = model.encode([query])[0]

    cached_response = retrieve_from_cache(query_vector)
    if cached_response:
        logging.info("Relevant chunks recovered from cache.")
        return cached_response

    # I have limited top_k to avoid retrieving more chunks than available.
    top_k = min(top_k, len(all_chunks))
    D, I = index.search(np.array([query_vector]).astype('float32'), top_k)
    # FAISS pads missing results with -1 (possible with IVF indexes), so I have skipped those entries.
    relevant_chunks = [all_chunks[i] for i in I[0] if i != -1]

    update_cache(query, query_vector, relevant_chunks)
    return relevant_chunks

def generate_response(query: str, relevant_chunks: List[str], primary_model: str = "llama-3.1-8b-instant", fallback_model: str = "gemma2-9b-it", max_retries: int = 3):
    # I have used a language model to generate responses based on retrieved chunks.
    # This allows for more natural and contextually appropriate answers.
    context = "\n".join(relevant_chunks)
    prompt = f"""Based on the following context, please answer the question. If the answer is not fully contained in the context, provide the most relevant information available and indicate any uncertainty.

Context:
{context}

Question: {query}

Answer:"""

    # I have implemented a fallback mechanism and retry logic for robustness.
    # This ensures the system can handle API errors and rate limits gracefully.
    models = [primary_model, fallback_model]
    for model in models:
        for attempt in range(max_retries):
            try:
                chat_completion = client.chat.completions.create(
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a helpful assistant that answers questions based on the given context."
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    model=model,
                    temperature=0.7,
                    max_tokens=1024,
                    top_p=1,
                    stream=False,
                    stop=None
                )

                response = chat_completion.choices[0].message.content.strip()
                usage_info = {
                    "prompt_tokens": chat_completion.usage.prompt_tokens,
                    "completion_tokens": chat_completion.usage.completion_tokens,
                    "total_tokens": chat_completion.usage.total_tokens,
                    "model_used": model
                }
                logging.info(f"Usage Info: {usage_info}")
                return response, usage_info, relevant_chunks

            except RateLimitError as e:
                if model == fallback_model and attempt == max_retries - 1:
                    logging.error(f"Rate limit exceeded for both models after {max_retries} attempts.")
                    raise e
                logging.warning(f"Rate limit exceeded for model {model}. Retrying in 5 seconds...")
                time.sleep(5)
            except Exception as e:
                logging.error(f"Error occurred with model {model}: {str(e)}")
                break  # Move to the next model if any other error occurs

    raise Exception("Failed to generate response with all available models.")

def rag_query(query: str, index, all_chunks, chunk_to_doc, top_k: int = 10) -> tuple:
    # I have combined retrieval and generation for a complete RAG pipeline.
    # RAG allows us to ground the model's responses in specific, relevant information.
    relevant_chunks = retrieve_relevant_chunks(query, index, all_chunks, top_k)
    response, usage_info, used_chunks = generate_response(query, relevant_chunks)

    # I have tracked source documents for transparency and citation.
    source_docs = list(set([chunk_to_doc.get(chunk, "Unknown Source") for chunk in used_chunks]))

    return response, usage_info, source_docs

# I have configured the Streamlit app.
# I have used Streamlit for rapid prototyping and easy deployment of the user interface.
st.set_page_config(page_title="Blunder Mifflin", page_icon=":soccer:", layout="wide", initial_sidebar_state="expanded", menu_items=None)

def main():
    st.write("Ask questions about Blunder Mifflin's Company Policy.")

    # I have processed PDFs and created the index at the start to ensure up-to-date information.
    all_chunks, chunk_to_doc, current_hash = process_pdfs()
    index = create_faiss_index(all_chunks)

    # I have provided default questions to guide users and demonstrate system capabilities.
    default_questions = [
        "Select a question",
        "What is Blunder Mifflin's product range?",
        "Who is part of Blunder Mifflin's team?",
        "What is Blunder Mifflin's policy on relationships and nepotism?",
        "Describe Blunder Mifflin's Birthday Party Committee Rules",
        "Other (Type your own question)"
    ]

    # I have used a dropdown for ease of use, but also allowed custom questions for flexibility.
    selected_question = st.selectbox("Choose a question or select 'Other' to type your own:", default_questions)

    if selected_question == "Other (Type your own question)":
        user_query = st.text_input("Enter your question:")
    elif selected_question != "Select a question":
        user_query = selected_question
    else:
        user_query = ""

    # I have used a button to trigger the query process, giving users control over when to send a request.
    if st.button("Get Answer"):
        if user_query:
            with st.spinner("Generating answer..."):
                # I have rechecked for changes in PDFs to ensure we're using the latest data.
                all_chunks, chunk_to_doc, _ = process_pdfs(current_hash)
                index = create_faiss_index(all_chunks)
                response, usage_info, source_docs = rag_query(user_query, index, all_chunks, chunk_to_doc)

            # I have displayed the response, sources, and usage info for transparency.
            st.subheader("Answer:")
            st.write(response)

            st.subheader("Source Documents:")
            for doc in source_docs:
                st.write(f"- {doc}")

            with st.expander("Usage Information"):
                st.json({
                    "Prompt Tokens": usage_info["prompt_tokens"],
                    "Completion Tokens": usage_info["completion_tokens"],
                    "Total Tokens": usage_info["total_tokens"],
                    "Model Used": usage_info["model_used"]
                })
        else:
            st.warning("Please select a question or enter your own.")

if __name__ == "__main__":
    main()


INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-mpnet-base-v2
INFO:root:Total chunks: 28
INFO:root:Sample chunk: # About Blunder Mifﬂin## Our HistoryFounded in 1785 by the visionary Michael Blunder, Blunder Mifﬂin...
Batches: 100%|██████████| 1/1 [00:02<00:00,  2.46s/it]
INFO:root:Using FlatL2 index due to small number of chunks
Batches: 100%|██████████| 1/1 [00:00<00:00, 15.66it/s]
INFO:httpx:HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
INFO:root:Usage Info: CompletionUsage(completion_tokens=150, prompt_tokens=1569, total_tokens=1719, completion_time=0.2, prompt_time=0.341392703, queue_time=None, total_time=0.5413927030000001)


Query: Describe Blunder Mifflin's remote work policy?
Response: According to the provided context, Blunder Mifflin's remote work policy allows for remote and hybrid work options, but they "value in-person collaboration" and seem to prefer traditional office work arrangements. The policy mentions that remote and hybrid work are options "if they fit the job," but does not provide further details on eligibility, expectations, or procedures.

It's worth noting that the context also mentions a "Staff Grievance Procedure" and a "Drug, Alcohol, and Smoking Policy," but these sections do not directly relate to remote work policies.

Overall, Blunder Mifflin's remote work policy appears to be somewhat ambiguous and in need of further clarification, as it prioritizes in-person collaboration but also offers remote options.
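
The `chunk_size=1000` and `chunk_overlap=200` defaults in `create_chunks` mean that each new chunk starts roughly 800 characters after the previous one. The sketch below illustrates that arithmetic with a plain sliding window. It is a simplification, not the `RecursiveCharacterTextSplitter` algorithm, which also prefers to split on separators such as paragraph breaks, and the function name `sliding_chunks` is hypothetical.

```python
from typing import List


def sliding_chunks(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> List[str]:
    """Split text into fixed-size windows that share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # distance between chunk starts
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks


sample = "".join(str(i % 10) for i in range(2500))
chunks = sliding_chunks(sample)
print([len(c) for c in chunks])            # [1000, 1000, 900]
print(chunks[0][-200:] == chunks[1][:200])  # True: neighbours share 200 characters
```

A larger overlap preserves more context across chunk boundaries at the cost of more chunks to embed and search.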

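`retrieve_from_cache` treats two queries as equivalent when the Euclidean distance between their embeddings is below `threshold=0.5`. If the embeddings are unit-length (an assumption here; check the norms your model produces), distance d and cosine similarity c are related by d² = 2 − 2c, so a distance of 0.5 corresponds to a cosine similarity of 0.875. The following pure-Python sketch uses toy 2-D vectors and hypothetical function names to show the lookup:

```python
import math
from typing import List, Optional, Sequence, Tuple


def distance_to_cosine(distance: float) -> float:
    """Convert an L2 distance between unit vectors into cosine similarity."""
    return 1.0 - (distance ** 2) / 2.0


def lookup(cache: List[Tuple[Sequence[float], str]], query: Sequence[float],
           threshold: float = 0.5) -> Optional[str]:
    """Return the closest cached value within threshold, or None on a miss."""
    best_value, best_distance = None, threshold
    for embedding, value in cache:
        d = math.dist(embedding, query)
        if d < best_distance:
            best_value, best_distance = value, d
    return best_value


# Toy unit vectors standing in for real sentence embeddings.
cache = [((1.0, 0.0), "answer about holidays"), ((0.0, 1.0), "answer about payroll")]
near = (math.cos(0.2), math.sin(0.2))  # about 0.2 rad away from the first entry
far = (math.cos(0.8), math.sin(0.8))   # roughly halfway between both entries

print(distance_to_cosine(0.5))  # 0.875
print(lookup(cache, near))      # answer about holidays
print(lookup(cache, far))       # None
```

Unlike the notebook's function, which returns the first entry under the threshold, this sketch returns the closest one; either choice is reasonable for a small cache.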

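The retry loop in `generate_response` waits a fixed 5 seconds between attempts before falling back to the second model. A common variation is exponential backoff, sketched below with a generic callable so that it runs without network access. `RateLimited`, `call_with_fallback`, and `fake_call` are hypothetical names for illustration; in the notebook the exception would be `groq.RateLimitError`.

```python
import time
from typing import Callable, Sequence


class RateLimited(Exception):
    """Stand-in for groq.RateLimitError so the sketch runs offline."""


def call_with_fallback(call: Callable[[str], str], models: Sequence[str],
                       max_retries: int = 3, base_delay: float = 1.0,
                       sleep: Callable[[float], None] = time.sleep) -> str:
    """Try each model in order, backing off exponentially on rate limits."""
    for model_id in models:
        for attempt in range(max_retries):
            try:
                return call(model_id)
            except RateLimited:
                if attempt < max_retries - 1:
                    sleep(base_delay * 2 ** attempt)  # 1s, 2s, 4s, ...
        # Retries exhausted for this model: fall through to the next one.
    raise RuntimeError("All models are rate limited.")


# Usage with a fake backend: the primary model is always rate limited.
def fake_call(model_id: str) -> str:
    if model_id == "primary":
        raise RateLimited()
    return f"answer from {model_id}"


delays = []
print(call_with_fallback(fake_call, ["primary", "fallback"], sleep=delays.append))
print(delays)  # [1.0, 2.0]
```

Passing `sleep` as a parameter keeps the function testable, since a test can record the delays instead of actually waiting.
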
# Playground


## Full JSON Response


In [34]:
import os
import json
from groq import Groq
from datetime import datetime

client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Give me a funny one-liner.",
        }
    ],
    model="gemma2-9b-it",
    temperature=1,
    max_tokens=1024,
    top_p=1,
    stream=False,
    stop=None
)

# Create a dictionary with the desired structure
response_dict = {
    "id": chat_completion.id,
    "object": "chat.completion",
    "created": chat_completion.created,  # Creation timestamp reported by the API, not the local clock
    "model": chat_completion.model,
    "system_fingerprint": chat_completion.system_fingerprint,  # This might be None
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": chat_completion.choices[0].message.content
            },
            "finish_reason": chat_completion.choices[0].finish_reason,
            "logprobs": None
        }
    ],
    "usage": {
        "prompt_tokens": chat_completion.usage.prompt_tokens,
        "completion_tokens": chat_completion.usage.completion_tokens,
        "total_tokens": chat_completion.usage.total_tokens,
        "prompt_time": round(chat_completion.usage.prompt_time, 3),
        "completion_time": round(chat_completion.usage.completion_time, 3),
        "total_time": round(chat_completion.usage.total_time, 3)
    }
}

# Print the formatted JSON response
print(json.dumps(response_dict, indent=2))

{
  "id": "chatcmpl-8daa49bf-3a07-42fb-ab0e-c954a10d0107",
  "object": "chat.completion",
  "created": 1722858859,
  "model": "gemma2-9b-it",
  "system_fingerprint": "fp_10c08bf97d",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "I'm reading a book about anti-gravity. It's impossible to put down!  \ud83d\udcda\ud83d\ude04  \n\n"
      },
      "finish_reason": "stop",
      "logprobs": null
    }
  ],
  "usage": {
    "prompt_tokens": 28,
    "completion_tokens": 27,
    "total_tokens": 55,
    "prompt_time": 0.003,
    "completion_time": 0.054,
    "total_time": 0.057
  }
}
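
The `usage` block is enough to estimate generation throughput: 27 completion tokens in 0.054 seconds is about 500 tokens per second. A small sketch of that calculation, using a hypothetical helper `tokens_per_second` on a dictionary shaped like the output above:

```python
from typing import Optional


def tokens_per_second(usage: dict) -> Optional[float]:
    """Completion tokens divided by completion time; None if the time is missing or zero."""
    seconds = usage.get("completion_time")
    if not seconds:
        return None
    return usage.get("completion_tokens", 0) / seconds


usage = {"prompt_tokens": 28, "completion_tokens": 27, "total_tokens": 55,
         "prompt_time": 0.003, "completion_time": 0.054, "total_time": 0.057}
print(round(tokens_per_second(usage)))  # 500
```

Because the times here were rounded to three decimals, the result is only an approximation of the true rate.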


## Rate Limits

- [Groq rate limits documentation](https://console.groq.com/docs/rate-limits)
- [Groq error codes documentation](https://console.groq.com/docs/errors)


In [49]:
import requests
import os
from datetime import datetime, timedelta

# Set your Groq API key
api_key = os.environ.get("GROQ_API_KEY")

# Groq API endpoint
# https://console.groq.com/docs/api-reference#chat-create
url = "https://api.groq.com/openai/v1/chat/completions"

# Headers
headers = {
    "Authorization": f"Bearer {api_key}",
    "Content-Type": "application/json"
}

# Example request payload
payload = {
    "model": "llama-3.1-8b-instant",  # Alternative: gemma2-9b-it
    "messages": [{"role": "user", "content": "Hello, how are you?"}]
}

import re

def parse_time(time_str):
    # Parse duration strings such as "450ms", "7.66s", or "2m59.56s" into seconds.
    units = {'h': 3600, 'm': 60, 's': 1, 'ms': 0.001}
    parts = re.findall(r'(\d+(?:\.\d+)?)(ms|h|m|s)', time_str)
    if parts:
        return sum(float(value) * units[unit] for value, unit in parts)
    try:
        return float(time_str)  # Assume it's already in seconds
    except ValueError:
        return 0  # Default to 0 if format is unrecognized

def check_rate_limits():
    response = requests.post(url, json=payload, headers=headers)

    if response.status_code == 200:
        # Token usage limits
        token_limit = int(response.headers.get('x-ratelimit-limit-tokens', 0))
        tokens_remaining = int(response.headers.get('x-ratelimit-remaining-tokens', 0))
        token_reset = parse_time(response.headers.get('x-ratelimit-reset-tokens', '0'))

        # Daily request limits
        daily_limit = int(response.headers.get('x-ratelimit-limit-requests', 0))
        requests_remaining = int(response.headers.get('x-ratelimit-remaining-requests', 0))
        request_reset = parse_time(response.headers.get('x-ratelimit-reset-requests', '0'))

        # Calculate reset times
        token_reset_time = datetime.now() + timedelta(seconds=token_reset)
        request_reset_time = datetime.now() + timedelta(seconds=request_reset)

        print("Token Usage:")
        print(f"  Limit: {token_limit} tokens per minute")
        print(f"  Remaining: {tokens_remaining} tokens")
        print(f"  Resets in: {token_reset:.2f} seconds")
        print(f"  Resets at: {token_reset_time.strftime('%Y-%m-%d %H:%M:%S')}")
        print("\nDaily Request Limits:")
        print(f"  Limit: {daily_limit} requests per day")
        print(f"  Remaining: {requests_remaining} requests")
        print(f"  Resets in: {request_reset:.2f} seconds")
        print(f"  Resets at: {request_reset_time.strftime('%Y-%m-%d %H:%M:%S')}")

        # Check usage for this specific request
        usage = response.json().get('usage', {})
        print("\nThis request used:")
        print(f"  Prompt tokens: {usage.get('prompt_tokens', 0)}")
        print(f"  Completion tokens: {usage.get('completion_tokens', 0)}")
        print(f"  Total tokens: {usage.get('total_tokens', 0)}")

    else:
        print(f"Error: {response.status_code}")
        print(response.text)

if __name__ == "__main__":
    check_rate_limits()


Token Usage:
  Limit: 131072 tokens per minute
  Remaining: 131063 tokens
  Resets in: 0.00 seconds
  Resets at: 2024-08-05 14:21:34

Daily Request Limits:
  Limit: 14400 requests per day
  Remaining: 14399 requests
  Resets in: 6.00 seconds
  Resets at: 2024-08-05 14:21:40

This request used:
  Prompt tokens: 16
  Completion tokens: 50
  Total tokens: 66
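
One way to use these headers is to decide whether to pause before sending the next request. The sketch below assumes the reset values have already been converted to seconds (as `parse_time` does) and uses a hypothetical function name, `seconds_to_wait`. The first call uses the numbers from the output above.

```python
def seconds_to_wait(tokens_needed: int, tokens_remaining: int, token_reset_s: float,
                    requests_remaining: int, request_reset_s: float) -> float:
    """Return how long to pause so the next request fits within both limits."""
    wait = 0.0
    if tokens_remaining < tokens_needed:
        wait = max(wait, token_reset_s)    # wait for the token window to reset
    if requests_remaining < 1:
        wait = max(wait, request_reset_s)  # wait for the request window to reset
    return wait


# Values from the run above: plenty of budget left, so no pause is needed.
print(seconds_to_wait(66, 131063, 0.0, 14399, 6.0))  # 0.0
# Hypothetical exhausted request budget: pause until the request window resets.
print(seconds_to_wait(66, 131063, 0.0, 0, 6.0))      # 6.0
```

The token estimate for the next request (`tokens_needed`) is a guess the caller has to supply; the previous request's `total_tokens` is one rough starting point.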
