In [None]:
# Install Dependencies
# Installs the third-party packages used by the cells below into the kernel's
# environment (%pip targets the running kernel, unlike !pip).
# NOTE: no versions are pinned, so every run installs whatever is current;
# pin versions if reproducible runs matter.
# NOTE(review): accelerate, bitsandbytes, pandas and plotly are not imported
# directly by any cell visible in this notebook. accelerate/bitsandbytes are
# presumably needed by transformers for device_map="auto" and 4-bit loading;
# confirm whether pandas and plotly are required.
%pip install streamlit pymupdf sentence-transformers faiss-cpu transformers accelerate bitsandbytes pyngrok python-dotenv tiktoken rank-bm25 pandas plotly


In [None]:
# Setup Authentication
# Collects the Hugging Face and ngrok credentials used by the later cells.
# Tokens are never hard-coded: they come from the environment / a .env file,
# with an interactive prompt as the fallback for the Hugging Face token.
import os
from getpass import getpass
from dotenv import load_dotenv
from pyngrok import ngrok
import huggingface_hub

# Load environment variables (e.g. from a local .env file)
load_dotenv()

# Get Hugging Face token: prefer a value already provided through the
# environment / .env, and only prompt when none is set.
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN") or getpass("Enter your Hugging Face token: ")
# Export it so the Streamlit subprocess started later inherits the token.
os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token

# Login to Hugging Face
huggingface_hub.login(token=hf_token)

# Get ngrok token
ngrok_token = os.getenv("NGROK_AUTH_TOKEN")
if ngrok_token:
    # Register the token through pyngrok's API instead of shelling out:
    # this avoids interpolating a secret into a shell command line and does
    # not depend on an `ngrok` executable being on PATH.
    ngrok.set_auth_token(ngrok_token)
    print("Authentication setup complete!")
else:
    print("Please set NGROK_AUTH_TOKEN in .env file")


In [None]:
%%writefile app.py
import html
import json
import os
import re
import tempfile
import time
from typing import List, Dict

import faiss
import fitz
import numpy as np
import streamlit as st
import tiktoken
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Page configuration: sets the browser tab title and uses the wide layout.
st.set_page_config(page_title="RAG PDF Chat", layout="wide")

# Inject the app's custom stylesheet. unsafe_allow_html=True is required
# because the payload is a raw <style> element rather than markdown.
# The .chat-container, .user-message, .bot-message and .context-box classes
# are the ones applied by the chat rendering code further down this file.
# NOTE(review): .upload-zone and .citation are not referenced anywhere else in
# the visible file, and the `.stApp > div:first-child` rule appears twice
# (the later, !important variant wins) — confirm whether these are still needed.
# NOTE(review): `.stSlider label { display: none }` targets label elements, not
# only value read-outs — verify it does not hide the slider captions.
st.markdown("""
<style>
    /* Remove upload zone border */
    .upload-zone {
        border: none;
        padding: 0;
        margin: 0;
    }
    
    /* Remove file uploader border completely */
    .stFileUploader > div {
        border: none !important;
    }
    
    .stFileUploader > div > div {
        border: none !important;
    }
    
    /* Chat container */
    .chat-container {
        max-width: 800px;
        margin: 0 auto;
    }
    
    /* User messages - right aligned */
    .user-message {
        background: #007bff;
        color: white;
        padding: 12px 18px;
        border-radius: 18px;
        margin: 10px 0;
        text-align: right;
        margin-left: 20%;
        max-width: 80%;
        word-wrap: break-word;
    }
    
    /* Bot messages - left aligned */
    .bot-message {
        background: #f8f9fa;
        color: #333;
        padding: 12px 18px;
        border-radius: 18px;
        margin: 10px 0;
        border: 1px solid #e9ecef;
        margin-right: 20%;
        max-width: 80%;
        word-wrap: break-word;
    }
    
    /* Context boxes - better visibility */
    .context-box {
        background: #f8f9fa;
        border-left: 4px solid #28a745;
        padding: 15px;
        margin: 10px 0;
        border-radius: 5px;
        color: #333;
        border: 1px solid #dee2e6;
    }
    
    /* Citations - dark theme for visibility */
    .citation {
        background: #343a40;
        color: #ffffff;
        border: 1px solid #495057;
        padding: 4px 8px;
        border-radius: 12px;
        font-size: 0.8em;
        margin: 2px;
        display: inline-block;
        font-weight: 500;
    }
    
    /* Professional slider styling */
    .stSlider > div > div > div {
        background: #f8f9fa !important;
        border-radius: 8px !important;
        height: 6px !important;
        border: 1px solid #e9ecef !important;
    }
    
    .stSlider > div > div > div > div {
        background: linear-gradient(90deg, #28a745, #20c997) !important;
        border-radius: 8px !important;
        height: 6px !important;
        box-shadow: none !important;
    }
    
    .stSlider > div > div > div > div > div {
        background: #ffffff !important;
        border: 2px solid #28a745 !important;
        border-radius: 50% !important;
        width: 18px !important;
        height: 18px !important;
        box-shadow: 0 1px 3px rgba(40, 167, 69, 0.2) !important;
        transition: all 0.2s ease !important;
        cursor: pointer !important;
    }
    
    .stSlider > div > div > div > div > div:hover {
        transform: scale(1.05) !important;
        box-shadow: 0 2px 4px rgba(40, 167, 69, 0.3) !important;
        border-color: #20c997 !important;
    }
    
    .stSlider > div > div > div > div > div:active {
        transform: scale(0.95) !important;
    }
    
    /* Slider track styling */
    .stSlider > div > div > div > div > div::before {
        content: '' !important;
        position: absolute !important;
        top: 50% !important;
        left: 50% !important;
        transform: translate(-50%, -50%) !important;
        width: 6px !important;
        height: 6px !important;
        background: #28a745 !important;
        border-radius: 50% !important;
    }
    
    /* Hide the value display on sliders */
    .stSlider > div > div > div > div > div > div {
        display: none !important;
    }
    
    /* Hide any other value displays */
    .stSlider label {
        display: none !important;
    }
    
    .stSlider > div > div > div > div > div > div > div {
        display: none !important;
    }
    
    /* Custom checkbox colors */
    .stCheckbox > div > label > div[data-testid="stMarkdownContainer"] {
        color: #333 !important;
    }
    
    /* Better button styling */
    .stButton > button {
        background: #007bff;
        color: white;
        border: none;
        border-radius: 6px;
        padding: 8px 16px;
    }
    
    .stButton > button:hover {
        background: #0056b3;
    }
    
    /* Sidebar improvements */
    .css-1d391kg {
        background: #f8f9fa;
    }
    
    /* Better spacing - minimal top padding */
    .main .block-container {
        padding-top: 0.5rem;
        padding-bottom: 2rem;
    }
    
    /* Remove extra header padding */
    .stApp > div:first-child {
        padding-top: 0;
    }
    
    /* Hide Streamlit header */
    header[data-testid="stHeader"] {
        display: none;
    }
    
    /* Remove any extra margins from title */
    .stApp h1 {
        margin-top: 0;
        padding-top: 0;
    }
    
    /* Ensure clean top spacing */
    .stApp > div:first-child {
        padding-top: 0 !important;
        margin-top: 0 !important;
    }
</style>
""", unsafe_allow_html=True)

# Model Manager for lazy loading
class SimpleModelManager:
    """Lazily loads the embedding model, the reranker and the Mistral LLM.

    Each getter is wrapped in ``st.cache_resource`` so the heavy model objects
    are created once and reused across Streamlit reruns. The first parameter
    is named ``_self`` (leading underscore) so Streamlit does not try to hash
    the manager instance when building the cache key.
    """

    # Model identifiers, kept in one place so they are easy to swap.
    EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
    RERANKER_MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
    LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"

    def __init__(self):
        # Per-instance memo of already-loaded models (None until first use).
        self._embedding_model = None
        self._reranker = None
        self._mistral_model = None
        self._mistral_tokenizer = None
    
    @st.cache_resource
    def get_embedding_model(_self):
        """Return the sentence-embedding model, loading it on first use."""
        if _self._embedding_model is None:
            _self._embedding_model = SentenceTransformer(_self.EMBEDDING_MODEL_NAME)
        return _self._embedding_model
    
    @st.cache_resource
    def get_reranker(_self):
        """Return the cross-encoder reranker, loading it on first use."""
        if _self._reranker is None:
            _self._reranker = CrossEncoder(_self.RERANKER_MODEL_NAME)
        return _self._reranker
    
    @st.cache_resource
    def get_mistral_model(_self):
        """Return ``(model, tokenizer)`` for the 4-bit quantized Mistral LLM.

        Raises:
            ValueError: if no Hugging Face token is available in the
                ``HUGGINGFACE_HUB_TOKEN`` (or ``HF_TOKEN``) environment variable.
        """
        if _self._mistral_model is None or _self._mistral_tokenizer is None:
            # HUGGINGFACE_HUB_TOKEN is what the notebook exports; HF_TOKEN is
            # accepted as a fallback name.
            hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
            if not hf_token:
                raise ValueError("Hugging Face token not found")
            
            # 4-bit NF4 quantization with double quantization, fp16 compute.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )
            
            tokenizer = AutoTokenizer.from_pretrained(
                _self.LLM_MODEL_NAME,
                token=hf_token
            )
            # generate() needs a pad token; fall back to EOS when none is set.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            # trust_remote_code is deliberately NOT enabled: it allows code
            # shipped in the model repository to be executed locally, and it is
            # not needed for an architecture implemented inside transformers.
            model = AutoModelForCausalLM.from_pretrained(
                _self.LLM_MODEL_NAME,
                quantization_config=bnb_config,
                device_map="auto",
                token=hf_token
            )
            
            _self._mistral_model = model
            _self._mistral_tokenizer = tokenizer
        return _self._mistral_model, _self._mistral_tokenizer

# Utility functions
def extract_text_from_pdf(pdf_path):
    """Return the concatenated plain text of every page in a PDF.

    Args:
        pdf_path: Path of the PDF file to open.

    Returns:
        The text of all pages joined in page order, with no separator added
        between pages.

    The document is opened in a ``with`` block so the handle is released even
    when text extraction of a page raises.
    """
    with fitz.open(pdf_path) as doc:
        return "".join(page.get_text() for page in doc)

def clean_text(text):
    """Normalize raw PDF text into a single whitespace-collapsed string.

    Steps, in order:
      1. Drop lines that contain nothing but a number (typically page
         numbers). This must run while line breaks still exist; collapsing
         whitespace first would leave no line boundaries to match against.
      2. Collapse every run of whitespace (including newlines) into a single
         space.

    Args:
        text: Raw text as extracted from a PDF.

    Returns:
        The cleaned text with leading/trailing whitespace removed.
    """
    # [^\S\n] = any whitespace except a newline, so the match stays on one line.
    text = re.sub(r'^[^\S\n]*\d+[^\S\n]*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def count_tokens(text: str) -> int:
    """Return the number of tokens in ``text`` under the gpt-3.5-turbo encoding.

    Args:
        text: Arbitrary text (here: text extracted from user-supplied PDFs).

    Returns:
        The token count.
    """
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    # By default tiktoken raises when the input contains the literal spelling
    # of a special token (e.g. "<|endoftext|>"). Document text is untrusted and
    # may contain such strings, so treat them as ordinary text instead.
    return len(encoding.encode(text, disallowed_special=()))

def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> List[Dict]:
    """Split text into sentence-aligned chunks of roughly ``chunk_size`` tokens.

    Sentences are accumulated until adding the next one would push the chunk
    past ``chunk_size`` tokens (as measured by ``count_tokens``); the chunk is
    then emitted and the next chunk starts with the last ``overlap``
    *characters* of the previous one followed by the new sentence.

    Args:
        text: The text to split.
        chunk_size: Soft upper bound on tokens per chunk. A single sentence
            longer than this still becomes its own (oversized) chunk.
        overlap: Number of trailing characters carried into the next chunk.
            Values <= 0 disable the overlap.

    Returns:
        A list of dicts with keys ``id`` (0-based position), ``text``,
        ``token_count`` and ``char_count``. The two counts describe the
        stripped ``text`` that is actually stored.
    """
    def _make_chunk(chunk_id: int, body: str) -> Dict:
        # Measure the stripped text so the metadata matches what is stored.
        body = body.strip()
        return {
            "id": chunk_id,
            "text": body,
            "token_count": count_tokens(body),
            "char_count": len(body)
        }

    # Split after sentence terminators, keeping the punctuation attached to
    # its sentence. Requiring whitespace after the terminator also keeps
    # tokens such as "3.5" in one piece.
    sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
    chunks: List[Dict] = []
    current_chunk = ""

    for sentence in sentences:
        candidate = f"{current_chunk} {sentence}" if current_chunk else sentence
        if current_chunk and count_tokens(candidate) > chunk_size:
            chunks.append(_make_chunk(len(chunks), current_chunk))
            # Guard overlap <= 0 explicitly: current_chunk[-0:] would be the
            # whole chunk, not an empty tail.
            tail = current_chunk[-overlap:] if overlap > 0 else ""
            current_chunk = f"{tail} {sentence}" if tail else sentence
        else:
            current_chunk = candidate

    if current_chunk.strip():
        chunks.append(_make_chunk(len(chunks), current_chunk))
    return chunks

# Simple Retriever
class SimpleRetriever:
    """Hybrid retriever combining dense (FAISS) and sparse (BM25) search,
    with optional cross-encoder reranking.

    Attributes set by the constructor:
        embedding_model: object whose ``encode(list_of_str)`` returns one
            vector per input string.
        faiss_index: index whose ``search(queries, n)`` returns
            ``(distances, indices)``.
        bm25_index: object whose ``get_scores(tokens)`` returns one score per
            indexed chunk.
        metadata: dict with a ``"chunks"`` list; positions in that list are the
            ids used by both indexes.
        reranker: optional object whose ``predict(pairs)`` scores
            ``(query, text)`` pairs.
    """

    def __init__(self, embedding_model, faiss_index, bm25_index, metadata, reranker=None):
        self.embedding_model = embedding_model
        self.faiss_index = faiss_index
        self.bm25_index = bm25_index
        self.metadata = metadata
        self.reranker = reranker
    
    def retrieve(self, query: str, k: int = 3) -> List[Dict]:
        """Return the ``k`` best chunks for ``query``, best first.

        Each returned dict is a copy of the stored chunk with the extra keys
        ``dense_score``, ``sparse_score``, ``combined_score``, ``final_score``,
        ``rank`` (1-based) and, when reranking ran, ``rerank_score``.
        Fewer than ``k`` chunks are returned when fewer candidates exist.
        """
        pool_size = k * 2  # candidates requested from each retriever / reranked
        chunks = self.metadata["chunks"]
        num_chunks = len(chunks)

        # Dense retrieval
        query_embedding = self.embedding_model.encode([query])[0]
        distances, dense_indices = self.faiss_index.search(
            query_embedding.reshape(1, -1).astype('float32'), pool_size
        )
        # Map chunk id -> distance. The index pads its result with -1 when it
        # holds fewer than pool_size vectors; those (and any out-of-range id)
        # must be dropped, otherwise -1 would silently address the last chunk.
        dense_distance = {}
        for dist, idx in zip(distances[0], dense_indices[0]):
            idx = int(idx)
            if 0 <= idx < num_chunks:
                dense_distance.setdefault(idx, float(dist))

        # Sparse retrieval
        bm25_scores = np.asarray(self.bm25_index.get_scores(query.split()), dtype=float)
        sparse_indices = np.argsort(bm25_scores)[::-1][:pool_size]
        # Normalizer hoisted out of the loop; epsilon avoids division by zero.
        bm25_norm = (np.max(bm25_scores) if bm25_scores.size else 0.0) + 1e-8

        # Combine results (sorted for a deterministic tie order)
        candidate_ids = set(dense_distance)
        candidate_ids.update(int(i) for i in sparse_indices if 0 <= int(i) < num_chunks)

        chunks_with_scores = []
        for idx in sorted(candidate_ids):
            chunk = chunks[idx].copy()

            # Smaller distance -> score closer to 1; 0 if not a dense hit.
            dense_score = 1 / (1 + dense_distance[idx]) if idx in dense_distance else 0
            sparse_score = bm25_scores[idx] / bm25_norm
            combined_score = 0.6 * dense_score + 0.4 * sparse_score

            chunk["dense_score"] = dense_score
            chunk["sparse_score"] = sparse_score
            chunk["combined_score"] = combined_score
            chunks_with_scores.append(chunk)

        # Sort by combined score and keep only the candidate pool. Every chunk
        # that survives this cut receives a final_score below, so the final
        # sort can never hit a chunk without one.
        chunks_with_scores.sort(key=lambda x: x["combined_score"], reverse=True)
        candidates = chunks_with_scores[:pool_size]

        # Rerank if available
        if self.reranker and len(candidates) > k:
            query_chunk_pairs = [(query, chunk["text"]) for chunk in candidates]
            rerank_scores = self.reranker.predict(query_chunk_pairs)

            for chunk, rerank_score in zip(candidates, rerank_scores):
                chunk["rerank_score"] = rerank_score
                chunk["final_score"] = 0.7 * chunk["combined_score"] + 0.3 * rerank_score
        else:
            for chunk in candidates:
                chunk["final_score"] = chunk["combined_score"]

        candidates.sort(key=lambda x: x["final_score"], reverse=True)

        # Add ranking
        top_chunks = candidates[:k]
        for rank, chunk in enumerate(top_chunks, start=1):
            chunk["rank"] = rank

        return top_chunks

# Simple RAG System
class SimpleRAGSystem:
    """Retrieval-augmented question answering on top of a retriever and a
    causal language model.

    Args (constructor):
        retriever: object exposing ``retrieve(question, k=...)`` that returns a
            list of chunk dicts with a ``"text"`` key.
        model: causal LM exposing ``parameters()`` and ``generate(...)``.
        tokenizer: tokenizer matching ``model``.
    """

    # Maximum number of prompt tokens passed to the model.
    MAX_PROMPT_TOKENS = 2048

    def __init__(self, retriever, model, tokenizer):
        self.retriever = retriever
        self.model = model
        self.tokenizer = tokenizer
    
    def _build_prompt(self, context: str, question: str) -> str:
        """Render the instruction prompt for the given context and question."""
        return f"""<s>[INST] You are an expert AI assistant. Answer the question based ONLY on the provided context. Use citations [1], [2], etc. to reference specific sources. If the answer is not in the context, say "I cannot find this information in the provided context."

CONTEXT:
{context}

QUESTION: {question}

Please provide a comprehensive answer with proper citations based on the context: [/INST]"""

    def _count_prompt_tokens(self, prompt: str) -> int:
        """Return the number of tokens ``prompt`` occupies for this tokenizer."""
        return len(self.tokenizer.encode(prompt, add_special_tokens=False))

    def answer_question(self, question: str, k: int = 3) -> Dict:
        """Retrieve context for ``question`` and generate a cited answer.

        Chunks are added to the prompt in rank order for as long as the whole
        prompt fits into ``MAX_PROMPT_TOKENS``. Without this budget the
        tokenizer would cut the prompt off at the end, which is where the
        question and the closing ``[/INST]`` live. The best-ranked chunk is
        always included, even if it alone exceeds the budget.

        Returns:
            Dict with ``question``, ``answer``, ``retrieved_chunks`` (everything
            the retriever returned) and ``context`` (the cited text that was
            actually placed in the prompt).
        """
        retrieved_chunks = self.retriever.retrieve(question, k=k)
        
        # Create context with citations, stopping before the prompt overflows
        context_parts = []
        for i, chunk in enumerate(retrieved_chunks):
            candidate_parts = context_parts + [f"[{i+1}] {chunk['text']}"]
            candidate_prompt = self._build_prompt("\n\n".join(candidate_parts), question)
            if context_parts and self._count_prompt_tokens(candidate_prompt) > self.MAX_PROMPT_TOKENS:
                break
            context_parts = candidate_parts
        context = "\n\n".join(context_parts)
        
        rag_prompt = self._build_prompt(context, question)
        
        response = self.generate_response(rag_prompt)
        return {
            "question": question, 
            "answer": response, 
            "retrieved_chunks": retrieved_chunks, 
            "context": context
        }
    
    def generate_response(self, prompt: str, max_length: int = 512) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Full prompt, optionally already starting with the
                tokenizer's BOS string.
            max_length: Maximum number of NEW tokens to generate.

        Returns:
            The generated continuation, without the prompt and without
            instruction markers.
        """
        # The RAG prompt already spells out the BOS token; letting the
        # tokenizer prepend another one would produce a duplicate.
        bos_token = getattr(self.tokenizer, "bos_token", None)
        prompt_has_bos = bool(bos_token) and prompt.startswith(bos_token)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_PROMPT_TOKENS,
            add_special_tokens=not prompt_has_bos
        )
        
        # Move to same device as model
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs, 
                max_new_tokens=max_length, 
                temperature=0.2, 
                top_p=0.85, 
                do_sample=True, 
                pad_token_id=self.tokenizer.eos_token_id, 
                eos_token_id=self.tokenizer.eos_token_id
            )
        
        # The output sequence is the prompt followed by the new tokens. Decode
        # only the new part instead of searching the decoded text for prompt
        # fragments, which cut off answers that themselves contain
        # "CONTEXT:" or "QUESTION:".
        generated_ids = outputs[0][prompt_length:]
        response = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        
        # Clean up any instruction markers the model may have echoed
        response = response.replace("[INST]", "").replace("[/INST]", "").strip()
        
        return response

# Initialize model manager
# One manager instance per script run; the model getters it exposes are cached.
model_manager = SimpleModelManager()

# Main UI: page heading
st.title("RAG PDF Chat System")
st.markdown("Upload PDFs and ask questions about their content")

# Sidebar: user-tunable settings
with st.sidebar:
    st.header("Settings")

    # Which language model to generate with
    model_type = st.selectbox(
        "Select Model",
        options=["Mistral-7B", "FLAN-T5", "GPT-2"],
        help="Choose the language model for generation",
    )

    # Which retrieval strategy to use
    retriever_type = st.selectbox(
        "Select Retriever",
        options=["Hybrid (Dense + Sparse)", "Dense Only", "Sparse Only"],
        help="Choose the retrieval method",
    )

    # How many chunks go into the prompt
    num_chunks = st.slider(
        "Number of chunks to retrieve",
        min_value=1,
        max_value=10,
        value=3,
        help="More chunks = more context but slower",
    )

    # Extra knobs, only rendered (and only defined) when the box is ticked
    show_advanced = st.checkbox("Show Advanced Options")
    if show_advanced:
        chunk_size = st.slider("Chunk Size", min_value=200, max_value=1000, value=500)
        temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.2)
        max_tokens = st.slider("Max Tokens", min_value=100, max_value=1000, value=512)

# Upload zone: one or more PDF files
uploaded_files = st.file_uploader(
    label="Upload PDF files",
    type="pdf",
    accept_multiple_files=True,
    help="Upload one or more PDF files to analyze",
)

# Seed the per-session state on the first run of a session
for state_key, initial_value in (
    ("conversation", []),
    ("rag_system", None),
    ("processing", False),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Process uploaded files
if uploaded_files:
    # Fingerprint of the current upload set. The index is rebuilt whenever it
    # differs from the fingerprint the existing index was built from, so that
    # adding, removing or replacing files is picked up.
    upload_signature = tuple(sorted((f.name, f.size) for f in uploaded_files))
    needs_indexing = (
        st.session_state.rag_system is None
        or st.session_state.get("indexed_files") != upload_signature
    )

    if needs_indexing:
        st.info("Getting model ready...")
        try:
            # Extract text from all PDFs
            all_pdf_text = ""
            for uploaded_file in uploaded_files:
                # Write to a securely created temp file rather than a path
                # derived from the (user-controlled) upload name, and always
                # remove it, even when extraction fails.
                tmp_file = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False)
                try:
                    with tmp_file:
                        tmp_file.write(uploaded_file.getbuffer())
                    raw_text = extract_text_from_pdf(tmp_file.name)
                finally:
                    os.remove(tmp_file.name)
                cleaned_text = clean_text(raw_text)
                all_pdf_text += f" {cleaned_text}"
            
            # Chunk the text
            chunks = chunk_text(all_pdf_text, chunk_size=500, overlap=50)
            if not chunks:
                # e.g. scanned PDFs without a text layer
                raise ValueError("No extractable text found in the uploaded PDF(s)")
            
            # Generate embeddings
            embedding_model = model_manager.get_embedding_model()
            texts = [chunk["text"] for chunk in chunks]
            embeddings = embedding_model.encode(texts)
            
            # Add embeddings to chunks
            chunks_with_embeddings = [
                {**chunk, "embedding": embedding}
                for chunk, embedding in zip(chunks, embeddings)
            ]
            
            # Create FAISS index
            embeddings_array = np.array([chunk["embedding"] for chunk in chunks_with_embeddings])
            dimension = embeddings_array.shape[1]
            faiss_index = faiss.IndexFlatL2(dimension)
            faiss_index.add(embeddings_array.astype('float32'))
            
            # Create BM25 index
            tokenized_texts = [text.split() for text in texts]
            bm25_index = BM25Okapi(tokenized_texts)
            
            # Load reranker
            reranker = model_manager.get_reranker()
            
            # Create retriever
            retriever = SimpleRetriever(embedding_model, faiss_index, bm25_index, {"chunks": chunks_with_embeddings}, reranker)
            
            # Load Mistral model
            mistral_model, mistral_tokenizer = model_manager.get_mistral_model()
            
            # Create RAG system and remember which uploads it was built from
            st.session_state.rag_system = SimpleRAGSystem(retriever, mistral_model, mistral_tokenizer)
            st.session_state.indexed_files = upload_signature
            
            st.success(f"Successfully processed {len(uploaded_files)} PDF(s) with {len(chunks)} chunks!")
            
        except Exception as e:
            st.error(f"Error processing PDFs: {str(e)}")
            st.stop()
    
    # Chat interface
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    
    # Display conversation history.
    # All dynamic text (user input, model output, PDF excerpts) is HTML-escaped
    # before being placed into markup rendered with unsafe_allow_html=True.
    for i, turn in enumerate(st.session_state.conversation):
        if turn["role"] == "user":
            st.markdown(f'<div class="user-message">{html.escape(turn["text"])}</div>', unsafe_allow_html=True)
        else:
            st.markdown(f'<div class="bot-message">{html.escape(turn["text"])}</div>', unsafe_allow_html=True)
            
            # Show retrieved context if available
            if "context" in turn and turn["context"]:
                with st.expander(f"Retrieved Context (Turn {i+1})", expanded=False):
                    contexts = turn["context"].split("\n\n")
                    for j, context in enumerate(contexts):
                        if context.strip():
                            st.markdown(f'<div class="context-box"><strong>Source {j+1}:</strong><br>{html.escape(context)}</div>', unsafe_allow_html=True)
    
    # Show typing indicator
    if st.session_state.processing:
        st.markdown('<div class="bot-message">Thinking...</div>', unsafe_allow_html=True)
    
    # Question input
    st.markdown("### Ask a Question")
    # A non-empty label is given for accessibility and collapsed visually,
    # since the heading above already introduces the field.
    user_input = st.text_input(
        "Question", 
        placeholder="Ask a question about your PDFs...", 
        disabled=st.session_state.processing,
        key="user_input",
        label_visibility="collapsed"
    )
    
    # Handle input submission
    if user_input and not st.session_state.processing:
        # The text box keeps its value across reruns; last_input prevents the
        # same submission from being processed again on every rerun.
        if "last_input" not in st.session_state or st.session_state.last_input != user_input:
            # Add user message
            st.session_state.conversation.append({"role": "user", "text": user_input})
            st.session_state.processing = True
            st.session_state.last_input = user_input
            st.rerun()
    
    # Process bot response
    if st.session_state.processing and len(st.session_state.conversation) > 0 and st.session_state.conversation[-1]["role"] == "user":
        try:
            user_question = st.session_state.conversation[-1]['text']
            
            # Generate response first
            result = st.session_state.rag_system.answer_question(user_question, k=num_chunks)
            
            # Clear the "Thinking..." message by updating processing state
            st.session_state.processing = False
            
            # Add the response to conversation
            st.session_state.conversation.append({
                "role": "bot", 
                "text": result["answer"],
                "context": result["context"],
                "retrieved_chunks": result["retrieved_chunks"]
            })
            
            # Rerun to show the response
            st.rerun()
            
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
            st.session_state.conversation.append({
                "role": "bot", 
                "text": "I apologize, but I encountered an error while processing your question. Please try again."
            })
            st.session_state.processing = False
            st.rerun()
    
    st.markdown('</div>', unsafe_allow_html=True)

else:
    st.info("Please upload PDF files to get started")


In [None]:
# Run the Application
# Starts the Streamlit app written to app.py by the previous cell and exposes
# it through an ngrok tunnel.
import os
import socket
import subprocess
import sys
import time

from pyngrok import ngrok

STREAMLIT_PORT = 8501
STARTUP_TIMEOUT_SECONDS = 60

# Kill any existing processes
os.system("pkill -f streamlit")
ngrok.kill()
time.sleep(2)

# Start Streamlit with the kernel's own interpreter so it runs in the same
# environment that %pip installed the dependencies into.
streamlit_process = subprocess.Popen([
    sys.executable, "-m", "streamlit", "run", "app.py",
    f"--server.port={STREAMLIT_PORT}", "--server.headless=true",
])

# Wait until the server actually accepts connections instead of sleeping for
# a fixed time, and fail loudly if the process dies or never comes up.
deadline = time.time() + STARTUP_TIMEOUT_SECONDS
while True:
    if streamlit_process.poll() is not None:
        raise RuntimeError(
            f"Streamlit exited during startup (exit code {streamlit_process.returncode})"
        )
    try:
        with socket.create_connection(("localhost", STREAMLIT_PORT), timeout=1):
            break
    except OSError:
        if time.time() > deadline:
            raise RuntimeError(
                f"Streamlit did not start listening on port {STREAMLIT_PORT} "
                f"within {STARTUP_TIMEOUT_SECONDS} seconds"
            )
        time.sleep(0.5)

# Start ngrok tunnel
public_url = ngrok.connect(STREAMLIT_PORT)
print("RAG PDF Chat System URL:", public_url)