<a href="https://colab.research.google.com/github/chemala/GenAI_StudyBuddy/blob/dev/rag_study_assistant_tavily_ui.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pdfplumber faiss-cpu sentence-transformers tavily-python gradio langchain


In [None]:
import os, io
import pdfplumber
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from tavily import TavilyClient
import gradio as gr

In [None]:
# Tavily web-search key (free tier: 1000 credits / month).
# Read it from the environment (or a Colab secret) instead of hard-coding it:
# a key committed to a notebook is readable by anyone with access to the repo.
# NOTE(review): the key that used to be hard-coded here has been published and
# should be revoked/rotated in the Tavily dashboard.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY", "")
if not TAVILY_API_KEY:
    try:
        from google.colab import userdata  # only importable inside Colab

        TAVILY_API_KEY = userdata.get("TAVILY_API_KEY") or ""
    except Exception:
        # Best effort: not running in Colab, or the secret is missing / not
        # shared with this notebook. Leave the key empty.
        pass

print("Tavily key set:", bool(TAVILY_API_KEY))

Extract text from PDF

In [None]:
# def extract_pdf_text(pdf_bytes):
#     text = ""
#     with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
#         for page in pdf.pages:
#             t = page.extract_text()
#             if t:
#                 text += t + "\n"
#     return text

def extract_pdf_text_from_path(path):
    """Return the concatenated text of every page of a PDF.

    Args:
        path: Filesystem path, or anything ``pdfplumber.open`` accepts (such as
            a binary file object). Raw ``bytes``/``bytearray`` are wrapped in a
            ``BytesIO`` so in-memory uploads work as well.

    Returns:
        The text of each page for which ``extract_text()`` returned something,
        each followed by a newline. Pages yielding nothing are skipped, so the
        result may be ``""``.
    """
    if isinstance(path, (bytes, bytearray)):
        path = io.BytesIO(path)
    with pdfplumber.open(path) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages]
    # Join once instead of growing a string with `+=` per page.
    return "".join(f"{t}\n" for t in page_texts if t)

Chunk text

In [None]:
def chunk_text(text, chunk_size=600, overlap=120):
    """Split *text* into fixed-size character windows that overlap.

    Consecutive chunks share ``overlap`` characters, so text cut at one chunk
    boundary also appears intact near the start of the next chunk.

    Args:
        text: The string to split. ``None`` or ``""`` yields no chunks.
        chunk_size: Maximum length of each chunk, in characters.
        overlap: Characters shared by consecutive chunks. Must satisfy
            ``0 <= overlap < chunk_size`` so the window always moves forward.

    Returns:
        List of chunk strings in document order.

    Raises:
        ValueError: If ``chunk_size``/``overlap`` would stall the window
            (which previously looped forever) or skip text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size), got overlap={overlap}, "
            f"chunk_size={chunk_size}"
        )
    if not text:
        return []

    step = chunk_size - overlap
    length = len(text)
    chunks = []
    start = 0
    while True:
        end = start + chunk_size
        chunks.append(text[start:end])
        # Stop once the window reaches the end of the text; stepping again would
        # only emit a tail chunk that is a strict subset of this one.
        if end >= length:
            break
        start += step
    return chunks

Make embeddings from text

In [None]:
# Sentence-transformer encoder shared by indexing and querying.
embedder = SentenceTransformer("all-mpnet-base-v2")
EMB_DIM = embedder.get_sentence_embedding_dimension()  # width of each vector

def embed(texts):
    """Encode *texts* and return a float32 NumPy matrix, one row per string.

    The explicit float32 cast matches what FAISS indexes expect as input.
    """
    vectors = embedder.encode(texts, convert_to_numpy=True)
    return vectors.astype("float32")


Build FAISS Index

In [None]:
def build_index(chunks):
    """Embed *chunks* and load them into an exact L2 FAISS index.

    Args:
        chunks: Non-empty list of text chunks. Row ``i`` of the index
            corresponds to ``chunks[i]``.

    Returns:
        ``(index, chunks)``: the populated ``faiss.IndexFlatL2`` and the same
        chunk list, so search ids can be mapped back to text.

    Raises:
        ValueError: If *chunks* is empty; there is nothing to embed, and the
            failure would otherwise surface as an obscure embedding/FAISS error.
    """
    if not chunks:
        raise ValueError("Cannot build an index from zero chunks.")
    vectors = embed(chunks)
    index = faiss.IndexFlatL2(EMB_DIM)
    index.add(vectors)
    return index, chunks


In [None]:
# Authenticate with the Hugging Face Hub before downloading the model below.
# With no token argument, login() asks for one interactively.
# NOTE(review): presumably required because the Mistral checkpoint is gated
# (licence must be accepted on the Hub) — confirm.
from huggingface_hub import login
login()

In [None]:
# Load the instruction-tuned Mistral 7B tokenizer and weights used for generation.
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True
)

# device_map="auto" lets transformers place the weights on the available
# devices; dtype="auto" keeps the dtype stored with the checkpoint.
# NOTE(review): older transformers releases call this argument `torch_dtype`;
# confirm the installed version accepts `dtype`.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype="auto"
)

In [None]:
# Text-generation pipeline shared by the Q&A and flashcard features.
# Low-temperature sampling (0.3) keeps outputs near the most likely tokens while
# allowing some variation; max_new_tokens caps the length of each reply.
# NOTE(review): by default the pipeline's "generated_text" also contains the
# prompt, which is why callers split on "[/INST]"; return_full_text=False
# would avoid that — confirm before changing.
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1000,
    do_sample=True,
    temperature=0.3
)

Use Tavily Web Access Layer for web search

In [None]:
def tavily_search(query, api_key, k=3):
    """Return the text content of the top *k* Tavily web results for *query*.

    Args:
        query: Search query string.
        api_key: Tavily API key.
        k: Maximum number of results to request.

    Returns:
        List of result ``content`` strings (``""`` for a result without one);
        empty when the response carries no ``results``.
    """
    client = TavilyClient(api_key=api_key)
    # The Tavily client's result-count parameter is `max_results` (as used in
    # rag_with_llm); `limit` is not one of its documented search parameters.
    response = client.search(query, max_results=k)
    return [r.get("content", "") for r in response.get("results", [])]

Retrieve Local Chunks

In [None]:
def retrieve_local(query, index, chunks, k=5):
    """Return up to *k* chunks from *index* nearest to *query*.

    Args:
        query: Question text, encoded with the same ``embed`` used for indexing.
        index: FAISS index whose rows correspond to *chunks*.
        chunks: Chunk texts parallel to the index rows.
        k: Maximum number of chunks to return.

    Returns:
        Chunk texts in the order FAISS ranks them. When the index holds fewer
        than *k* vectors, FAISS pads the result ids with ``-1``. Those ids are
        dropped rather than being read as ``chunks[-1]``.
    """
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    qv = embed([query])
    _, ids = index.search(qv, k)
    return [chunks[i] for i in ids[0] if i >= 0]


RAG context retrieval (local, web, or hybrid, depending on the `mode` parameter). The three modes are exposed as separate options so the user can select one of them.

In [None]:
def rag(query, mode, index, chunks, api_key):
    """Gather retrieval context for *query* according to *mode*.

    ``"local"`` searches the FAISS index. ``"web"`` queries Tavily, but only
    when an API key is supplied. ``"hybrid"`` does both, local passages first.
    Any other mode yields no context.

    Returns:
        The passages joined by a ``---`` separator, or ``""`` if none.
    """
    use_local = mode in ("local", "hybrid")
    use_web = mode in ("web", "hybrid") and bool(api_key)

    passages = []
    if use_local:
        passages.extend(retrieve_local(query, index, chunks))
    if use_web:
        passages.extend(tavily_search(query, api_key))
    return "\n\n---\n\n".join(passages)


In [None]:
# Module-level state shared by the Gradio callbacks below. Because it is a
# single global dict, every session served by this process shares it.
state = {
    "index": None,        # FAISS index, set by build()
    "chunks": None,       # chunk texts parallel to the index rows
    "chat_history": [],   # list of {"question": ..., "answer": ...} turns
}

In [None]:
def build(files):
    """Extract, chunk and index the uploaded PDFs, storing the result in ``state``.

    Args:
        files: Uploads from ``gr.File(file_count="multiple")``. These are file
            paths, or (depending on the Gradio version) file wrappers that
            expose the path as ``.name``.

    Returns:
        A human-readable status string for the UI.
    """
    if not files:
        return "❌ No files uploaded."

    all_chunks = []
    for f in files:
        path = f if isinstance(f, (str, os.PathLike)) else getattr(f, "name", f)
        all_chunks.extend(chunk_text(extract_pdf_text_from_path(path)))

    if not all_chunks:
        # An empty chunk list cannot be embedded; report it instead of crashing.
        return "❌ No text could be extracted from the uploaded PDFs."

    idx, ch = build_index(all_chunks)
    state["index"] = idx
    state["chunks"] = ch

    return f"✅ Indexed {len(ch)} chunks."


In [None]:
# Redundant re-import (SentenceTransformer is imported in an earlier cell); harmless.
from sentence_transformers import SentenceTransformer

# Alias for the encoder that built the FAISS index. Queries must be encoded with
# this same model, or their vectors are not comparable with the indexed chunks.
embedding_model = embedder # Use the same embedder that built the index

In [None]:
import json

def rag_with_llm(question, mode, index, chunks, tavily_api_key, chat_history, top_k=5):
    """Answer *question* with the LLM, grounded in retrieved context and chat history.

    Args:
        question: The user's question.
        mode: ``"local"`` (FAISS only), ``"web"`` (Tavily only) or
            ``"hybrid"`` (both), using the same semantics as ``rag()``.
        index: FAISS index built by ``build_index``.
        chunks: Chunk texts parallel to the index rows.
        tavily_api_key: Tavily key. Web search is skipped when it is empty.
        chat_history: List of ``{"question", "answer"}`` dicts from earlier
            turns, or ``None``. It is not mutated.
        top_k: Maximum number of local chunks to retrieve.

    Returns:
        ``(answer, new_chat_history)``. The new history is a fresh list with
        this turn appended.
    """
    history = list(chat_history or [])

    # 1-2. Local retrieval (skipped in "web" mode, matching rag()).
    local_context = ""
    if mode in ("local", "hybrid") and index is not None:
        local_context = "\n".join(_answer_local_hits(question, index, chunks, top_k))

    # 3. Optional web search; skipped when no key is available.
    web_context = ""
    if mode in ("web", "hybrid") and tavily_api_key:
        web_context = "\n".join(_answer_web_hits(question, tavily_api_key))

    # 4. Build prompt.
    context = local_context
    if web_context:
        context += "\n\nWeb context:\n" + web_context

    history_str = _format_chat_history(history)

    prompt = f"""<s>[INST]
Based *strictly* on the following context and previous conversation, formulate a helpful response to the query. Provide information or tips *only as directly relevant* to the question and found within the context. Do not ask clarifying questions, introduce new topics, or discuss linguistic nuances unless explicitly asked about them in the question.

Context:
{context}
{history_str}

Question:
{question}
[/INST]
"""

    # 5. Generate. Keep only the text after the last [/INST] marker, since the
    # generated text may echo the prompt.
    result = llm(prompt)
    answer = result[0]["generated_text"].split("[/INST]")[-1].strip()

    return answer, history + [{"question": question, "answer": answer}]


def _answer_local_hits(question, index, chunks, k):
    """Return up to *k* chunk texts nearest to *question*, ignoring FAISS -1 padding."""
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    # Encode through `embed`, the same path build_index used for the chunks.
    q_emb = embed([question])
    _, ids = index.search(q_emb, k)
    return [chunks[i] for i in ids[0] if i >= 0]


def _answer_web_hits(question, api_key, max_results=3):
    """Return the content of the top Tavily results for *question*."""
    client = TavilyClient(api_key=api_key)
    response = client.search(question, max_results=max_results)
    return [r.get("content", "") for r in response.get("results", [])]


def _format_chat_history(history):
    """Render prior turns as a 'Previous conversation' section, or '' if none."""
    if not history:
        return ""
    turns = "\n".join(
        f"User: {h['question']}\nAssistant: {h['answer']}" for h in history
    )
    return "\n\nPrevious conversation:\n" + turns

def ask(q, mode, api_key):
    """Gradio callback for the Ask button.

    Args:
        q: Question text from the textbox.
        mode: Search mode (``"local"``, ``"web"`` or ``"hybrid"``).
        api_key: Tavily API key forwarded to ``rag_with_llm``.

    Returns:
        ``(answer_text, chatbot_rows, question_box_update)``. On success the
        question box is cleared. On error the user's text is kept and the
        existing conversation stays on screen.
    """
    # setdefault guards against a state dict created without this key.
    history = state.setdefault("chat_history", [])
    rows = [[h["question"], h["answer"]] for h in history]

    if state["index"] is None:
        return "❌ Build the index first.", rows, gr.update(value=q)
    if not q or not q.strip():
        return "❌ Please enter a question.", rows, gr.update(value=q)

    answer, updated_history = rag_with_llm(
        q,
        mode,
        state["index"],
        state["chunks"],
        api_key,
        history,
    )
    state["chat_history"] = updated_history

    # gr.Chatbot expects [user_message, bot_message] pairs.
    rows = [[h["question"], h["answer"]] for h in updated_history]
    return answer, rows, gr.update(value="")

def generate_flashcards(flashcard_topic, num_flashcards, mode, index, chunks, api_key):
    """Generate study flashcards about *flashcard_topic* with the LLM.

    Context is gathered with ``rag()`` using *mode*. The model is asked for a
    JSON array of ``{"question": ..., "answer": ...}`` objects.

    Args:
        flashcard_topic: Subject of the flashcards.
        num_flashcards: Number of cards to request; also the maximum returned.
        mode, index, chunks, api_key: Forwarded to ``rag()``.

    Returns:
        List of flashcard dicts, each containing both ``"question"`` and
        ``"answer"`` keys. Returns ``[]`` when the response holds no JSON
        array; in that case the raw response is printed for debugging.
    """
    query = f"Information about {flashcard_topic}"
    context = rag(query, mode, index, chunks, api_key)

    prompt = f"""<s>[INST]
Based on the following context, generate {num_flashcards} flashcards about '{flashcard_topic}'.
Each flashcard should have a 'question' and an 'answer'.
Return the output as a JSON array of objects, where each object has a 'question' key and an 'answer' key.

Context:
{context}

Example format:
[
  {{"question": "What is X?", "answer": "Y."}},
  {{"question": "What is A?", "answer": "B."}}
]
[/INST]
"""

    llm_response = llm(prompt)
    raw_response = llm_response[0]["generated_text"].split("[/INST]")[-1].strip()

    try:
        return _parse_flashcards(raw_response, num_flashcards)
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        print(f"Raw LLM response: {raw_response}")
        return []  # Return empty list on parsing error


def _parse_flashcards(raw_response, limit):
    """Return up to *limit* flashcard dicts from the first usable JSON array.

    LLMs often surround the JSON with prose. Each ``[`` is therefore tried as
    the start of an array using ``JSONDecoder.raw_decode``, which stops at the
    end of that array and ignores any trailing text, including stray ``]``.
    Items that are not dicts with both ``"question"`` and ``"answer"`` are
    dropped.

    Raises:
        json.JSONDecodeError: If no ``[`` begins a valid JSON value.
    """
    decoder = json.JSONDecoder()
    found_array = False
    pos = raw_response.find("[")
    while pos != -1:
        try:
            candidate, consumed = decoder.raw_decode(raw_response[pos:])
        except json.JSONDecodeError:
            pos = raw_response.find("[", pos + 1)
            continue
        found_array = True
        cards = [
            item
            for item in candidate
            if isinstance(item, dict) and "question" in item and "answer" in item
        ]
        if cards:
            return cards[: int(limit)]
        pos = raw_response.find("[", pos + consumed)
    if not found_array:
        raise json.JSONDecodeError("No JSON array found", raw_response, 0)
    return []

def call_generate_flashcards(flashcard_topic, num_flashcards, mode, api_key):
    """Gradio callback for the Generate Flashcards button.

    Args:
        flashcard_topic: Topic text from the UI.
        num_flashcards: Value of the number box. May be ``None`` when the box
            is cleared, in which case the UI default of 5 is used.
        mode: Search mode forwarded to ``generate_flashcards``.
        api_key: Tavily API key forwarded to ``generate_flashcards``.

    Returns:
        Rows ``[question, answer]`` for the ``gr.DataFrame`` output.

    Raises:
        gr.Error: If the index has not been built or the topic is empty.
            The click handler has a single DataFrame output, so an error
            message cannot be returned alongside the rows.
    """
    if state["index"] is None:
        raise gr.Error("❌ Build the index first before generating flashcards.")
    if not flashcard_topic or not str(flashcard_topic).strip():
        raise gr.Error("❌ Please enter a flashcard topic.")

    count = 5 if num_flashcards is None else max(1, int(num_flashcards))

    flashcards_data = generate_flashcards(
        flashcard_topic,
        count,
        mode,
        state["index"],
        state["chunks"],
        api_key,
    )

    # Tolerate malformed cards instead of failing the whole table.
    return [
        [card.get("question", ""), card.get("answer", "")]
        for card in flashcards_data
        if isinstance(card, dict)
    ]

def launch_ui(share=True, debug=True):
    """Build the RAG Study Assistant Gradio app and launch it.

    Args:
        share: Forwarded to ``demo.launch``. ``True`` requests a public share link.
        debug: Forwarded to ``demo.launch``.

    The Tavily key box starts empty. A blank box falls back to the
    module-level ``TAVILY_API_KEY`` on the server side.
    NOTE(review): Gradio includes component default values in the page config
    sent to each browser, so pre-filling the box with the key would expose it
    to anyone opening the share link.
    """

    def _resolve_key(key):
        # A key typed into the UI wins; otherwise use the server-side key.
        return key or TAVILY_API_KEY

    def _ask(q, search_mode, key):
        return ask(q, search_mode, _resolve_key(key))

    def _flashcards(topic, count, search_mode, key):
        return call_generate_flashcards(topic, count, search_mode, _resolve_key(key))

    def update_chatbot_display():
        # Render the stored conversation as [user, assistant] pairs on page load.
        return [[h["question"], h["answer"]] for h in state.get("chat_history", [])]

    with gr.Blocks() as demo:
        gr.Markdown("## ✊ RAG Study Assistant")

        files = gr.File(
            file_types=[".pdf"],
            file_count="multiple",
            label="Upload PDFs",
        )

        api = gr.Textbox(
            label="Tavily API Key",
            value="",
            placeholder="Leave blank to use the server-side key",
            type="password",
        )

        build_btn = gr.Button("Build Index")
        status = gr.Textbox(label="Status")

        mode = gr.Radio(
            ["local", "web", "hybrid"],
            value="hybrid",
            label="Search mode",
        )

        question = gr.Textbox(label="Question", lines=2)

        chatbot = gr.Chatbot(
            label="Chat History",
            height=300,  # fixed height for the chat window
        )

        answer = gr.Textbox(label="Answer", lines=12)

        # Flashcard generation inputs.
        flashcard_topic = gr.Textbox(
            label="Flashcard Topic",
            placeholder="e.g., Key concepts from the document",
        )
        num_flashcards = gr.Number(
            label="Number of Flashcards",
            value=5,
            minimum=1,
            maximum=20,
            step=1,
        )

        # Table that displays the generated flashcards.
        flashcard_display = gr.DataFrame(
            headers=["Question", "Answer"],
            value=[],
            label="Generated Flashcards",
        )

        ask_btn = gr.Button("Ask")
        generate_flashcards_btn = gr.Button("Generate Flashcards")

        build_btn.click(build, inputs=files, outputs=status)

        # Ask: show the answer, refresh the chat history, clear the question box.
        ask_btn.click(
            fn=_ask,
            inputs=[question, mode, api],
            outputs=[answer, chatbot, question],
        )

        generate_flashcards_btn.click(
            fn=_flashcards,
            inputs=[flashcard_topic, num_flashcards, mode, api],
            outputs=flashcard_display,
        )

        demo.load(update_chatbot_display, inputs=None, outputs=chatbot)

    demo.launch(share=share, debug=debug)

In [None]:
launch_ui()

## Final Task

### Subtask:
Summarize the newly implemented flashcard generation feature, explaining how users can now specify a topic and quantity to create flashcards from their study material and web context, and how these are displayed in the Gradio interface.


## Summary:

### Data Analysis Key Findings

*   **User Interface for Flashcard Generation**: The Gradio user interface was enhanced with two new input components:
    *   A `gr.Textbox` labeled "Flashcard Topic" to allow users to specify the subject for flashcard creation.
    *   A `gr.Number` labeled "Number of Flashcards" which enables users to set the desired quantity of flashcards, with a default of 5, a minimum of 1, and a maximum of 20.
*   **Flashcard Generation Logic**: A new Python function, `generate_flashcards`, was developed to create flashcards. This function:
    *   Constructs a query based on the user-provided `flashcard_topic`.
    *   Retrieves relevant context using a Retrieval-Augmented Generation (RAG) approach, drawing from local study materials (through an index and chunks), web search results, or both, depending on the selected search mode.
    *   Sends a carefully crafted prompt to a Large Language Model (LLM) to generate flashcards in a JSON array format, each with "question" and "answer" keys.
    *   Includes robust JSON parsing to extract the flashcards from the LLM's raw output, handling cases where the LLM might include additional text.
*   **Flashcard Display Component**: A `gr.DataFrame` component was added to the Gradio UI, labeled "Generated Flashcards," configured with "Question" and "Answer" headers, to visually present the output from the `generate_flashcards` function.
*   **Integrated Flashcard Workflow**: A new "Generate Flashcards" button was added to the UI. When this button is clicked, it invokes an intermediary function, `call_generate_flashcards`. This function gathers the user-specified topic, quantity, search mode, and API key, calls the `generate_flashcards` logic, and then populates the `gr.DataFrame` (`flashcard_display`) with the generated flashcards. An initial check ensures that the RAG index has been built before attempting flashcard generation.

### Insights or Next Steps

*   The newly implemented flashcard generation feature provides users with a powerful tool to create study aids dynamically from their materials and web context, greatly enhancing the study assistant's utility. Users can easily specify their desired topic and quantity, and the system intelligently leverages RAG and LLMs to produce structured flashcards displayed directly within the Gradio interface.
*   Future development could focus on adding functionalities such as saving or exporting the generated flashcards, allowing users to edit flashcard content, or integrating a basic spaced repetition system for active recall.
