In [None]:
# --- Installs (Colab-friendly) ---
!pip -q install gradio PyPDF2 sentence-transformers transformers torch --upgrade

# --- Imports ---
import gc
import os
import re
from typing import List, Optional, Tuple

import gradio as gr
import torch
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Helpers: PDF -> text ---
def load_pdf_text(pdf_file) -> str:
    """
    Extract the text of every page of a PDF and join it into one string.

    Args:
        pdf_file: A path or file-like object accepted by ``PdfReader``
            (what Gradio hands over for an uploaded file).

    Returns:
        The page texts joined with single newlines. Pages whose extraction
        raises, returns ``None``, or yields only whitespace are left out.
    """
    def _page_text(page) -> str:
        # Extraction is best-effort: a page that fails to extract is treated
        # as empty rather than aborting the whole document.
        try:
            return page.extract_text() or ""
        except Exception:
            return ""

    extracted = (_page_text(page) for page in PdfReader(pdf_file).pages)
    return "\n".join(piece for piece in extracted if piece.strip())

# --- Chunking (same defaults as your code) ---
def split_text(text: str, chunk_size: int = 700, chunk_overlap: int = 200) -> List[str]:
    """
    Split ``text`` into overlapping fixed-size character windows.

    Windows start every ``chunk_size - chunk_overlap`` characters (at least 1)
    and are ``chunk_size`` characters long; the last window may be shorter.
    Splitting stops at the first window that reaches the end of the text, so
    no trailing window is merely a suffix of the one before it.

    Args:
        text: Text to split. An empty string yields an empty list.
        chunk_size: Window length in characters. Cast to ``int`` (UI sliders
            may deliver floats) and must be positive.
        chunk_overlap: Characters shared by neighbouring windows. Cast to
            ``int``; negative values are treated as 0 so no text is skipped.

    Returns:
        The windows in document order.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    chunk_size = int(chunk_size)
    chunk_overlap = max(0, int(chunk_overlap))
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # guard against degenerate overlap (overlap >= chunk_size): always advance
    step = max(1, chunk_size - chunk_overlap)
    chunks = []
    total = len(text)
    for start in range(0, total, step):
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= total:
            # This window already reaches the end of the text; any later
            # window would be a suffix of it and only add duplicate content.
            break
    return chunks

# --- Cleaning (your function verbatim) ---
def clean_text_final(text: str) -> str:
    """
    Normalise raw PDF text with an ordered series of regex substitutions.

    The rules run in sequence (order matters: line breaks are handled before
    whitespace is squeezed), and the result is stripped of outer whitespace.
    """
    substitutions = (
        (r'-\n', ''),                           # re-join words hyphenated across a line break
        (r'\n', ' '),                           # flatten the remaining line breaks
        (r'\b(\w+)( \1){2,}\b', r'\1'),         # collapse a word repeated 3+ times in a row
        # NOTE(review): this rule drops ANY run of 3+ consecutive 1-3 letter
        # words, e.g. "It is on the table" -> "table"; confirm that losing such
        # ordinary phrases is acceptable for retrieval quality.
        (r'\b(\w{1,3})( \w{1,3}){2,}\b', ''),
        (r'\s+', ' '),                          # squeeze whitespace runs to one space
        (r'\s([.,])', r'\1'),                   # no space before '.' or ','
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()

# --- Model handles shared by every callback ---
# All three start as None; init_models() loads them (first reached via
# "Build Index") and reset_models() sets them back to None.
_embed_model = _tokenizer = _model = None

def get_device():
    """Return ``"cuda"`` when PyTorch can see a GPU, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def init_models(model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    """
    Lazily load the sentence embedder and the causal LM into module globals.

    Each model is loaded only while its global is still ``None``, so repeated
    calls are cheap. NOTE(review): ``model_name`` is only used when the LLM is
    (re)loaded; a later call with a different name keeps the already-loaded
    model until ``reset_models`` clears it.
    """
    global _embed_model, _tokenizer, _model
    device = get_device()

    if _embed_model is None:
        _embed_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    if _tokenizer is not None and _model is not None:
        return

    _tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Try to be light on memory: float16 on GPU, float32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    _model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    _model.to(device)
    _model.eval()

# --- Build index (embeddings) ---
def build_index(pdf, chunk_size, chunk_overlap) -> Tuple[str, Optional[List[str]], Optional[torch.Tensor]]:
    """
    Extract, chunk, clean and embed the text of an uploaded PDF.

    Args:
        pdf: Value of the ``gr.File`` component: a path string, or an object
            exposing the file path as ``.name``.
        chunk_size: Characters per chunk (slider value; cast to ``int``).
        chunk_overlap: Characters shared by consecutive chunks (cast to
            ``int``); must satisfy ``0 <= chunk_overlap < chunk_size``.

    Returns:
        ``(status_message, cleaned_chunks, chunk_embeddings)``. On any failure
        the last two elements are ``None``. On success ``chunk_embeddings`` is
        the tensor returned by the embedder, row-aligned with ``cleaned_chunks``.
    """
    if pdf is None:
        return ("Please upload a PDF first.", None, None)

    chunk_size = int(chunk_size)
    chunk_overlap = int(chunk_overlap)
    # Validate before any heavy work: with overlap >= chunk size the splitter
    # advances one character at a time and would emit ~len(text) chunks.
    if chunk_size <= 0:
        return ("Chunk size must be a positive number of characters.", None, None)
    if not 0 <= chunk_overlap < chunk_size:
        return (
            f"Chunk overlap ({chunk_overlap}) must be between 0 and chunk size - 1 ({chunk_size - 1}).",
            None,
            None,
        )

    path = pdf.name if hasattr(pdf, "name") else pdf
    try:
        text = load_pdf_text(path)
    except Exception as exc:  # UI boundary: report unreadable PDFs instead of failing the callback
        return (f"Could not read the PDF: {exc}", None, None)
    if not text.strip():
        return ("No extractable text found in the PDF.", None, None)

    raw_chunks = split_text(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    # Cleaning can reduce a chunk to nothing, so filter empties after cleaning.
    cleaned_chunks = [c for c in map(clean_text_final, raw_chunks) if c.strip()]
    if not cleaned_chunks:
        return ("Text was extracted, but cleaning produced empty chunks.", None, None)

    init_models()  # lazy-load models only once there is something to embed
    chunk_embeddings = _embed_model.encode(cleaned_chunks, convert_to_tensor=True, show_progress_bar=True)
    return (f"Index built ✅ | {len(cleaned_chunks)} chunks", cleaned_chunks, chunk_embeddings)

# --- Retrieve top-k chunks ---
def retrieve(query: str, cleaned_chunks: List[str], chunk_embeddings, top_k: int = 3) -> List[str]:
    """
    Return the ``top_k`` chunks whose embeddings are most cosine-similar to ``query``.

    ``chunk_embeddings`` is presumably the tensor built by ``build_index``,
    row-aligned with ``cleaned_chunks``. At most ``len(cleaned_chunks)``
    chunks are returned, highest score first (``torch.topk`` sorts by default).
    """
    query_vector = _embed_model.encode(query, convert_to_tensor=True)
    # cos_sim yields a (1, num_chunks) matrix for a single query; take its row.
    similarities = util.cos_sim(query_vector, chunk_embeddings)[0]
    k = min(top_k, len(cleaned_chunks))
    best_indices = torch.topk(similarities, k=k).indices.tolist()
    return [cleaned_chunks[idx] for idx in best_indices]

# --- LLM answer generation (same prompt pattern you used) ---
def answer_query(query: str, cleaned_chunks: Optional[List[str]], chunk_embeddings, max_new_tokens: int = 200, top_k: int = 3) -> Tuple[str, str]:
    """
    Answer ``query`` with the chat LLM, using the top-k retrieved chunks as context.

    Args:
        query: The user's question; ``None`` or blank input is rejected.
        cleaned_chunks: Chunks stored by ``build_index`` (``None`` until built).
        chunk_embeddings: Embeddings stored by ``build_index``, row-aligned
            with ``cleaned_chunks`` (``None`` until built).
        max_new_tokens: Generation budget (slider value; cast to ``int``).
        top_k: Number of chunks to retrieve (slider value; cast to ``int``).

    Returns:
        ``(answer, retrieved_chunks_text)``. For invalid input the first
        element is a user-facing hint and the second is an empty string.
    """
    if not query or not query.strip():
        return "Please enter a question.", ""

    if cleaned_chunks is None or chunk_embeddings is None:
        return "Please upload a PDF and click 'Build Index' first.", ""

    # "Clear Models" drops the models but leaves the index in gr.State, so make
    # sure they are loaded again (a no-op when they already are).
    init_models()

    top_chunks = retrieve(query, cleaned_chunks, chunk_embeddings, top_k=int(top_k))
    context = " ".join(top_chunks)

    # NOTE(review): no truncation is applied; a large chunk size combined with
    # a large top_k may exceed the LLM's context window — confirm the limits.
    messages = [{
        "role": "user",
        "content": (
            "Answer the question using ONLY the context below. If you do not know the answer, say you are not trained on this type of data.\n\n"
            f"Context: {context}\n\n"
            f"Question: {query}"
        ),
    }]

    inputs = _tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = {name: tensor.to(_model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = _model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,  # no sampling: deterministic output
        )

    # The generated sequence starts with the prompt; keep only the new tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    generated_text = _tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # For transparency, also return the retrieved chunks that grounded the answer.
    joined = "\n\n--- Retrieved Chunk ---\n\n".join(top_chunks)
    return generated_text.strip(), joined

# --- Clear GPU memory (optional button) ---
def reset_models():
    """
    Drop the module-level model references and release cached GPU memory.

    A later ``build_index`` call reloads them through ``init_models``. Any
    chunks/embeddings the UI holds in its state are not touched here.

    Returns:
        A status message for the UI.
    """
    global _embed_model, _tokenizer, _model
    _embed_model = _tokenizer = _model = None
    # Collect first so the dropped models are really freed before the CUDA
    # allocator cache is emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return "Cleared models from memory."

# ------------------ Gradio UI ------------------
# Left column: PDF upload, chunking controls, and the index/model buttons.
# Right column: the question, retrieval/generation controls, and the answer.
with gr.Blocks(title="RAG PDF Q&A (TinyLlama)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📄🔎 RAG PDF Q&A\nUpload a PDF, build the index, then ask questions grounded in the document.")
    with gr.Row():
        with gr.Column(scale=1):
            pdf = gr.File(label="Upload PDF", file_types=[".pdf"])
            chunk_size = gr.Slider(200, 2000, value=700, step=50, label="Chunk Size")
            # NOTE(review): this range lets the overlap reach or exceed the chunk
            # size (whose minimum is 200); confirm build_index/split_text handle it.
            chunk_overlap = gr.Slider(0, 1000, value=200, step=50, label="Chunk Overlap")
            build_btn = gr.Button("Build Index", variant="primary")
            status = gr.Markdown("Status: _waiting for PDF_")
            clear_btn = gr.Button("Clear Models (free VRAM)")
        with gr.Column(scale=2):
            query = gr.Textbox(label="Ask a question about the PDF")
            top_k = gr.Slider(1, 10, value=3, step=1, label="Top-K Chunks")
            max_new_tokens = gr.Slider(32, 1024, value=200, step=16, label="Max New Tokens")
            ask_btn = gr.Button("Answer", variant="primary")
            answer = gr.Markdown(label="Answer")
            with gr.Accordion("Show retrieved chunks", open=False):
                chunks_view = gr.Markdown()

    # App state: keep chunks & embeddings between calls.
    # Both start as None; build_index overwrites them (with None on failure)
    # and answer_query reads them.
    state_chunks = gr.State(value=None)
    state_embeds = gr.State(value=None)

    # Wire buttons. The inputs/outputs lists are matched positionally, so their
    # order must follow each callback's parameters / returned tuple.
    build_btn.click(
        fn=build_index,
        inputs=[pdf, chunk_size, chunk_overlap],
        outputs=[status, state_chunks, state_embeds],
        api_name="build_index"
    )

    ask_btn.click(
        fn=answer_query,
        inputs=[query, state_chunks, state_embeds, max_new_tokens, top_k],
        outputs=[answer, chunks_view],
        api_name="ask"
    )

    # Only the status text is updated; state_chunks/state_embeds are left as
    # they are, so an already-built index survives a "Clear Models" click.
    clear_btn.click(
        fn=reset_models,
        inputs=[],
        outputs=[status],
        api_name="clear_models"
    )

# Launch
# The request queue is enabled before launching; debug=True keeps the notebook
# cell attached to the server so errors and logs stay visible.
demo.queue().launch(debug=True)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.6/59.6 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m324.6/324.6 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m232.6/232.6 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.3/11.3 MB[0m [31m40.6 MB/s[0m eta [36m0:00:00[0m
[?25hIt looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://e3b2cfbf4d57ac14e6.

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]