<a href="https://colab.research.google.com/github/mahirbarot/thrifty-ai/blob/main/task_1_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import google.generativeai as genai


def load_txt_documents(folder_path="data"):
    """
    Load all .txt files from the given folder as documents.

    Only regular files whose name ends in ``.txt`` are read (the folder is
    not searched recursively). Files are processed in sorted filename order
    and numbered consecutively, so the same folder contents always produce
    the same ``doc1``, ``doc2``, ... ids regardless of the order in which the
    operating system lists the directory.

    Args:
        folder_path: Directory containing the ``.txt`` files.

    Returns:
        A dict mapping each doc id to ``{"title": <filename without
        extension>, "text": <file contents, stripped>}``.
    """
    # Filter first, then enumerate: otherwise non-.txt entries would consume
    # id numbers and leave gaps (doc1, doc3, ...).
    txt_files = sorted(
        name
        for name in os.listdir(folder_path)
        if name.endswith(".txt") and os.path.isfile(os.path.join(folder_path, name))
    )

    documents = {}
    for idx, filename in enumerate(txt_files, start=1):
        file_path = os.path.join(folder_path, filename)
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        documents[f"doc{idx}"] = {
            "title": os.path.splitext(filename)[0],
            "text": text,
        }
    print(f"Loaded {len(documents)} text documents from '{folder_path}'")
    return documents


def embed_documents(docs, model_name='all-MiniLM-L6-v2'):
    """
    Embed all documents using sentence-transformers.

    Args:
        docs: Mapping of doc id to a dict with ``'text'`` and ``'title'`` keys.
        model_name: Name passed to ``SentenceTransformer`` to load the model.

    Returns:
        A tuple ``(model, embeddings, doc_ids, doc_texts, doc_titles)``. The
        three lists are parallel and follow the iteration order of ``docs``;
        ``embeddings`` is the result of encoding ``doc_texts``.
    """
    print("Loading embedding model...")
    encoder = SentenceTransformer(model_name)

    # Unzip the mapping into three parallel lists in a single pass.
    doc_ids, doc_texts, doc_titles = [], [], []
    for key, payload in docs.items():
        doc_ids += [key]
        doc_texts += [payload['text']]
        doc_titles += [payload['title']]

    print("Embedding documents...")
    vectors = encoder.encode(doc_texts, convert_to_numpy=True)
    return encoder, vectors, doc_ids, doc_texts, doc_titles


def create_faiss_index(embeddings):
    """
    Build a FAISS inner-product index over L2-normalized embeddings.

    Because the stored vectors are unit length, searching the index with a
    normalized query yields cosine similarity as the score.

    Args:
        embeddings: 2-D array-like of shape (num_documents, dimension).

    Returns:
        A ``faiss.IndexFlatIP`` containing the normalized vectors.

    Raises:
        ValueError: If ``embeddings`` is not 2-dimensional (for example the
            result of encoding zero documents).
    """
    # FAISS works on C-contiguous float32 data. np.array() also makes a fresh
    # copy, so the in-place normalization below never touches the caller's array.
    vectors = np.array(embeddings, dtype=np.float32, order="C")
    if vectors.ndim != 2:
        raise ValueError(
            "embeddings must be a 2-D array of shape (num_documents, dimension), "
            f"got shape {vectors.shape}"
        )

    index = faiss.IndexFlatIP(vectors.shape[1])
    faiss.normalize_L2(vectors)
    index.add(vectors)
    return index



# Retrieve Top K Documents


def retrieve_documents(query, model, index, doc_ids, doc_texts, doc_titles, top_k=3):
    """
    Return the documents most similar to ``query``.

    Args:
        query: The question text to search for.
        model: Embedding model exposing ``encode(list_of_str, convert_to_numpy=True)``.
        index: FAISS index built over the document embeddings.
        doc_ids, doc_texts, doc_titles: Parallel lists; position ``i`` must
            describe the vector stored at row ``i`` of ``index``.
        top_k: Maximum number of documents to return (capped at the number
            of documents available).

    Returns:
        A list of dicts with keys ``'doc_id'``, ``'title'``, ``'text'`` and
        ``'similarity_score'``, in the order FAISS returned them. Empty when
        there are no documents or ``top_k`` is not positive.
    """
    print(f"\nQuery: {query}")

    top_k = min(top_k, len(doc_ids))
    if top_k <= 0:
        # Nothing to search for; avoid calling index.search with k <= 0.
        return []

    query_embedding = model.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(query_embedding)
    similarities, indices = index.search(query_embedding, top_k)

    retrieved_docs = []
    seen_ids = set()  # Track seen document IDs to avoid duplicates

    for score, raw_idx in zip(similarities[0], indices[0]):
        idx = int(raw_idx)
        if idx < 0:
            # FAISS pads missing results with -1; without this guard the
            # negative index would silently pick the last document.
            continue

        doc_id = doc_ids[idx]
        if doc_id in seen_ids:
            continue
        seen_ids.add(doc_id)

        retrieved_docs.append({
            'doc_id': doc_id,
            'title': doc_titles[idx],
            'text': doc_texts[idx],
            'similarity_score': float(score)
        })

    return retrieved_docs


def generate_answer(query, retrieved_docs, api_key):
    """
    Ask Gemini to answer ``query`` using the retrieved documents as context.

    Args:
        query: The user's question.
        retrieved_docs: Iterable of dicts with ``'title'`` and ``'text'`` keys.
        api_key: API key passed to ``genai.configure``.

    Returns:
        The generated answer text, or a fallback message when the response
        has no usable content or the request raises an exception.
    """
    genai.configure(api_key=api_key)

    # One "Document: <title>" section per retrieved document, blank-line separated.
    sections = []
    for entry in retrieved_docs:
        sections.append(f"Document: {entry['title']}\n{entry['text']}")
    context = "\n\n".join(sections)

    prompt = f"""Based on the following context, answer the question. If the answer cannot be found, say so.

Context:
{context}

Question: {query}

Answer:"""
    print("\nGenerating answer using Gemini...")
    gemini = genai.GenerativeModel('gemini-2.0-flash-exp')

    try:
        settings = genai.types.GenerationConfig(
            temperature=0.3,
            max_output_tokens=200,
        )
        response = gemini.generate_content(prompt, generation_config=settings)

        # Only read response.text when the first candidate carries content parts.
        candidates = response.candidates
        if candidates and candidates[0].content.parts:
            return response.text
        return "Could not generate an answer. The response was blocked or empty."
    except Exception as err:
        print(f"Error generating answer: {err}")
        return "Could not generate an answer due to unavailability of data."


def calculate_confidence(retrieved_docs):
    """
    Combine the retrieved documents' similarity scores into a value in [0, 1].

    Scores at or below -1e10 are ignored. The result is 70% of the mean of
    the remaining scores, plus (when at least two remain) 30% of the gap
    between the first and second score, clamped to the range [0, 1].

    Args:
        retrieved_docs: List of dicts each holding a ``'similarity_score'``.

    Returns:
        The confidence as a Python float; 0.0 when there are no usable scores.
    """
    if not retrieved_docs:
        return 0.0

    # Keep only scores above the sentinel threshold.
    usable = []
    for entry in retrieved_docs:
        score = entry['similarity_score']
        if score > -1e10:
            usable.append(score)
    if not usable:
        return 0.0

    raw = 0.7 * np.mean(usable)
    if len(usable) > 1:
        # Reward a clear margin between the best and the runner-up match.
        raw = raw + 0.3 * (usable[0] - usable[1])

    if raw < 0:
        return 0.0
    if raw > 1:
        return 1.0
    return float(raw)



# MAIN

def main():
    """
    Run the demo RAG pipeline end to end and print the results.

    Loads the ``.txt`` files in ``data``, embeds them, builds the FAISS
    index, then for each sample query retrieves the top documents, asks
    Gemini for an answer and prints the answer, the retrieved documents and
    a confidence score.

    Raises:
        RuntimeError: If the ``GEMINI_API_KEY`` environment variable is not set.
    """
    # Secrets must not live in source control: read the key from the
    # environment and fail fast, before any model is loaded, if it is missing.
    # NOTE(review): a key was previously hard-coded in this file; it should be
    # treated as leaked and rotated.
    gemini_api_key = os.environ.get("GEMINI_API_KEY")
    if not gemini_api_key:
        raise RuntimeError(
            "GEMINI_API_KEY environment variable is not set; "
            "set it to your Gemini API key before running."
        )

    # Load text files as documents
    documents = load_txt_documents("data")

    model, embeddings, doc_ids, doc_texts, doc_titles = embed_documents(documents)
    index = create_faiss_index(embeddings)

    queries = [
        "What is redis?",
        "Which star is in center of solar system?",  # out of context question
        "What is react?"
    ]

    for query in queries:
        print("\n" + "="*80)
        retrieved_docs = retrieve_documents(query, model, index, doc_ids, doc_texts, doc_titles, top_k=3)
        answer = generate_answer(query, retrieved_docs, gemini_api_key)
        confidence = calculate_confidence(retrieved_docs)

        print("\n" + "-"*80)
        print("RESULTS")
        print("-"*80)
        print(f"\nFinal Answer:\n{answer}")
        print("\nRetrieved Documents:")
        # The number printed after "Doc ID:" is the rank in the result list,
        # not the document's 'doc_id' field.
        for i, doc in enumerate(retrieved_docs, 1):
            print(f"Doc ID:{i}. {doc['title']} (similarity score: {doc['similarity_score']:.4f})")
        print(f"\nOverall Confidence Score: {confidence:.4f}")
        print("="*80)


# Run the demo only when executed as a script / notebook cell, not on import.
if __name__ == "__main__":
    main()

Loaded 6 text documents from 'data'
Loading embedding model...
Embedding documents...


Query: What is redis?

Generating answer using Gemini...

--------------------------------------------------------------------------------
RESULTS
--------------------------------------------------------------------------------

Final Answer:
Redis is an in-memory data structure store used as a cache, message broker, and database.


Retrieved Documents:
Doc ID:1. redis (similarity score: 0.5846)
Doc ID:2. fastapi (similarity score: 0.1877)
Doc ID:3. git (similarity score: 0.1649)

Overall Confidence Score: 0.3377


Query: Which star is in center of solar system?

Generating answer using Gemini...

--------------------------------------------------------------------------------
RESULTS
--------------------------------------------------------------------------------

Final Answer:
The provided documents do not contain information about the star at the center of the solar system.


Retrieved Documents: