In [1]:
# Upgrade pip, setuptools and wheel. The IPython shell escape (`!`) runs pip
# through `sys.executable`, i.e. the interpreter that is running this kernel.
import sys
!{sys.executable} -m pip install --upgrade pip setuptools wheel



In [None]:
# Install the packages this notebook imports, again through `sys.executable`
# so they are installed for the interpreter that is running this kernel.
# NOTE(review): no versions are pinned here, so a fresh run may pull newer
# releases than the ones the notebook was written against.
import sys
!{sys.executable} -m pip install chromadb gradio sentence-transformers PyPDF2 langchain-text-splitters torch torchvision ollama langchain



In [3]:
import os
import uuid
import torch
import chromadb
import gradio as gr
from sentence_transformers import SentenceTransformer
from ollama import Client
import PyPDF2
from langchain.text_splitter import RecursiveCharacterTextSplitter


In [4]:
# Create the Ollama client with no arguments (library defaults).
ollama_client = Client()
print("Ollama client initialized!")

# Report which compute device torch can see.
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if not cuda_ready:
    print("Running on CPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU device: {gpu_name} ({gpu_mem_gb:.1f} GB)")

Ollama client initialized!
CUDA available: False
Running on CPU


In [5]:
class DocumentProcessor:
    """Read PDF/TXT files and split their text into overlapping chunks."""

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        """Build the text splitter.

        Args:
            chunk_size: Maximum chunk length as measured by ``len``.
            chunk_overlap: Overlap between consecutive chunks, same unit.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )

    def extract_text_from_pdf(self, file_path: str) -> str:
        """Return the text of all pages of a PDF, each page followed by a newline.

        Errors are printed, not raised; whatever was extracted before the
        error (possibly nothing) is returned.
        """
        pages = []
        try:
            with open(file_path, "rb") as f:
                reader = PyPDF2.PdfReader(f)
                for page in reader.pages:
                    # extract_text() may return None for a page.
                    pages.append((page.extract_text() or "") + "\n")
            print(f"Extracted {sum(map(len, pages))} chars from PDF: {file_path}")
        except Exception as e:
            print(f"Error reading PDF: {e}")
        # Joining once avoids rebuilding the string for every page.
        return "".join(pages)

    def extract_text_from_txt(self, file_path: str) -> str:
        """Return the contents of a text file, or "" if it cannot be read.

        The file is decoded as UTF-8 (a leading BOM is dropped). Bytes that are
        not valid UTF-8 are replaced instead of aborting the whole read, so a
        file in another encoding still yields its readable portion.
        """
        text = ""
        try:
            with open(file_path, "r", encoding="utf-8-sig", errors="replace") as f:
                text = f.read()
            print(f"Extracted {len(text)} chars from TXT: {file_path}")
        except Exception as e:
            print(f"Error reading TXT: {e}")
        return text

    def process_document(self, file_path: str, file_type: str) -> list:
        """Extract the text of a document and split it into chunks.

        Args:
            file_path: Path of the file to read.
            file_type: "pdf" or "txt"; case and a leading dot are ignored.

        Returns:
            The list of chunks produced by the text splitter.

        Raises:
            ValueError: If the type is unsupported or no text was extracted.
        """
        kind = file_type.strip().lower().lstrip(".")
        if kind == "pdf":
            text = self.extract_text_from_pdf(file_path)
        elif kind == "txt":
            text = self.extract_text_from_txt(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_type}")

        if not text.strip():
            raise ValueError("No text could be extracted from document")

        chunks = self.text_splitter.split_text(text)
        print(f"✂ Split into {len(chunks)} chunks. Example: {chunks[:2]}")
        return chunks

# Module-level processor instance, built with the class defaults.
doc_processor = DocumentProcessor()

In [6]:
class VectorStore:
    """Chroma collection plus the sentence-transformer model used to embed text for it."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2",
                 collection_name: str = "ai_ml_documents"):
        """Load the embedding model and (re)create the collection.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            collection_name: Name of the Chroma collection to use.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.embedding_model = SentenceTransformer(model_name, device=self.device)
        self.chroma_client = chromadb.Client()
        self.collection_name = collection_name
        # Best effort: drop a collection left over from an earlier run of this
        # cell so create_collection starts clean. A failed delete is ignored
        # (presumably the collection did not exist yet).
        try:
            self.chroma_client.delete_collection(name=self.collection_name)
        except Exception:
            pass
        self.collection = self.chroma_client.create_collection(name=self.collection_name)
        print(f"Vector store initialized on {self.device}")

    def add_documents(self, chunks, source_file: str):
        """Embed ``chunks`` and store them with ``source_file`` as their source.

        Each chunk gets an id built from the source name, its index and a
        random suffix, and metadata ``{"source", "chunk_id"}``. An empty chunk
        list is a no-op.
        """
        if not chunks:
            print(f"No chunks to add from {source_file}")
            return
        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=False)
        ids = [f"{source_file}{i}{uuid.uuid4().hex[:8]}" for i in range(len(chunks))]
        metadatas = [{"source": source_file, "chunk_id": i} for i in range(len(chunks))]
        self.collection.add(
            documents=chunks,
            embeddings=embeddings.tolist(),
            metadatas=metadatas,
            ids=ids,
        )
        print(f"Added {len(chunks)} chunks from {source_file}")

    def search(self, query: str, top_k: int = 5):
        """Return the collection's query result for the chunks closest to ``query``.

        At most ``top_k`` results are requested, never more than the collection
        holds. For an empty collection (or ``top_k <= 0``) an empty result with
        the same ``documents`` layout is returned without querying.
        """
        available = self.collection.count()
        if available == 0 or top_k <= 0:
            print(f"Search for '{query}' retrieved 0 chunks")
            return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
        query_embedding = self.embedding_model.encode([query])[0].tolist()
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, available),
        )
        print(f"Search for '{query}' retrieved {len(results['documents'][0])} chunks")
        return results

    def get_collection_stats(self):
        """Summarise the collection.

        Returns:
            A dict with ``total_chunks``, ``total_documents`` (distinct
            sources), ``sources`` (sorted), ``embedding_model`` and ``device``;
            or ``{"error": message}`` if anything fails.
        """
        try:
            count = self.collection.count()
            # Only metadata is needed here, so do not fetch the documents too.
            fetched = self.collection.get(include=["metadatas"]) if count > 0 else {"metadatas": []}
            metadatas = fetched["metadatas"] or []
            # Entries without metadata are counted under "Unknown".
            sources = {(m or {}).get("source", "Unknown") for m in metadatas}
            return {
                "total_chunks": count,
                "total_documents": len(sources),
                "sources": sorted(sources),
                "embedding_model": self.model_name,
                "device": self.device,
            }
        except Exception as e:
            return {"error": str(e)}

    def clear_collection(self):
        """Delete and recreate the collection; failures are printed, not raised."""
        try:
            self.chroma_client.delete_collection(name=self.collection_name)
            self.collection = self.chroma_client.create_collection(name=self.collection_name)
            print("Vector collection cleared")
        except Exception as e:
            print(f"Error clearing collection: {e}")

# Module-level store instance; constructing it loads the embedding model and
# recreates the collection (see VectorStore.__init__).
vector_store = VectorStore()

Vector store initialized on cpu


In [None]:
class RAGChatbot:
    """Answer questions with an Ollama chat model, using retrieved chunks as context."""

    def __init__(self, vector_store, ollama_client, model: str = "qwen3:4b",
                 top_k: int = 5, history_turns: int = 3):
        """Store collaborators and settings.

        Args:
            vector_store: Object whose ``search(query, top_k=...)`` returns a
                mapping with the retrieved texts under ``["documents"][0]``.
            ollama_client: Object with a ``chat(model=, messages=, stream=)`` method.
            model: Model name passed to ``chat``.
            top_k: Number of chunks requested per question.
            history_turns: Number of most recent Q/A pairs put into the prompt.
        """
        self.vector_store = vector_store
        self.ollama_client = ollama_client
        self.model = model
        self.top_k = top_k
        self.history_turns = history_turns
        self.conversation_history = []

    def get_relevant_context(self, query):
        """Return the retrieved chunks joined by blank lines, or "" on any retrieval error."""
        try:
            results = self.vector_store.search(query, top_k=self.top_k)
            docs = results["documents"][0]
            return "\n\n".join(docs)
        except Exception as e:
            print(f"Retrieval error: {e}")
            return ""

    @staticmethod
    def _extract_answer(response):
        """Pull the message text out of an attribute-style or dict-style chat response."""
        if hasattr(response, "message") and hasattr(response.message, "content"):
            return response.message.content
        if isinstance(response, dict):
            message = response.get("message", {})
            if isinstance(message, dict):
                return message.get("content", "")
            if "content" in response:
                return response["content"]
            return ""
        return str(response)

    @staticmethod
    def _strip_reasoning(text):
        """Remove ``<think>...</think>`` sections and surrounding whitespace.

        If a ``<think>`` tag is never closed, everything from that tag onward
        is dropped so unfinished reasoning is not shown as the answer.
        """
        while "<think>" in text:
            before, _, rest = text.partition("<think>")
            _, closing, after = rest.partition("</think>")
            text = before + after if closing else before
        return text.strip()

    def generate_response(self, query: str):
        """Answer ``query`` and record the exchange in the conversation history.

        Returns:
            The model's answer with reasoning sections removed, or an
            ``"Error from Ollama: ..."`` string if the call or the parsing
            fails (in which case nothing is added to the history).
        """
        context = self.get_relevant_context(query)
        # A slice of [-0:] would select everything, so handle 0 explicitly.
        recent = self.conversation_history[-self.history_turns:] if self.history_turns > 0 else []
        history = "\n".join([
            f"Q: {h['question']}\nA: {h['answer']}" for h in recent
        ])

        prompt = f"""
You are a helpful assistant answering based ONLY on the uploaded document context.
If unsure, say so.

Conversation history:
{history}

Relevant document context:
{context}

User question: {query}
"""

        try:
            response = self.ollama_client.chat(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                stream=False
            )
            answer = self._strip_reasoning(self._extract_answer(response))
        except Exception as e:
            return f"Error from Ollama: {e}"

        self.conversation_history.append({"question": query, "answer": answer})
        return answer

    def clear_history(self):
        """Forget all recorded question/answer pairs."""
        self.conversation_history = []
        print("Conversation history cleared")

# Module-level chatbot wired to the shared vector store and Ollama client.
chatbot = RAGChatbot(vector_store, ollama_client)

In [8]:
def get_db_stats():
    """Format the vector store statistics as display text.

    Returns a ``DB Error: ...`` line when the stats dict carries an "error"
    key, otherwise a multi-line summary of documents, chunks, model, device
    and sources.
    """
    info = vector_store.get_collection_stats()
    if "error" not in info:
        report_lines = [
            "",
            "Vector DB Stats:",
            f"- Documents: {info['total_documents']}",
            f"- Chunks: {info['total_chunks']}",
            f"- Model: {info['embedding_model']}",
            f"- Device: {info['device']}",
            f"Sources: {info['sources']}",
            "",
        ]
        return "\n".join(report_lines)
    return f"DB Error: {info['error']}"

def upload_and_process_files(files):
    """Chunk and index every uploaded file, reporting the outcome per file.

    Args:
        files: Uploaded entries; each is either an object with a ``.name``
            path attribute or a plain path string.

    Returns:
        ``(status_text, db_stats_text)``. The status has one line per file:
        ``"<path> (<n> chunks)"`` on success or ``"<path> failed: <reason>"``
        on failure. A failing file no longer aborts the rest of the batch.
    """
    if not files:
        return "No files uploaded!", get_db_stats()
    messages = []
    for file in files:
        # NOTE(review): which of the two forms Gradio passes depends on the
        # installed version — accept both.
        path = str(getattr(file, "name", file))
        ext = path.split(".")[-1].lower()
        try:
            chunks = doc_processor.process_document(path, ext)
            vector_store.add_documents(chunks, path)
        except Exception as e:
            messages.append(f"{path} failed: {e}")
            continue
        messages.append(f"{path} ({len(chunks)} chunks)")
    return "\n".join(messages), get_db_stats()

def chat_response(message, history):
    """Answer ``message`` and return the updated chat history.

    Args:
        message: Text from the question box; ``None`` or blank is ignored.
        history: Existing list of ``[user, bot]`` pairs, or ``None``.

    Returns:
        ``(new_history, "")`` — the empty string clears the question box.
        The incoming ``history`` list is copied, not modified in place.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history, ""
    response = chatbot.generate_response(message)
    history.append([message, response])
    return history, ""

def clear_all_data():
    """Wipe the vector collection and the conversation history.

    Returns the status message, an empty chat list and the refreshed stats
    text, in that order.
    """
    vector_store.clear_collection()
    chatbot.clear_history()
    status_text, empty_chat = "All data cleared!", []
    return status_text, empty_chat, get_db_stats()

In [None]:
def _clear_chat_and_history():
    """Empty the chat widget AND the chatbot's stored conversation.

    Clearing only the widget would leave earlier turns in
    ``chatbot.conversation_history``, and they would still be sent to the
    model with the next question.
    """
    chatbot.clear_history()
    return [], "Chat cleared!"


with gr.Blocks(title="Document Q&A Assistant") as demo:
    gr.Markdown("# Document-Based AI Q&A")

    with gr.Row():
        # Left column: upload, status and database controls.
        with gr.Column(scale=1):
            file_upload = gr.Files(file_types=[".pdf",".txt"], file_count="multiple")
            upload_btn = gr.Button("Process Files", variant="primary")
            upload_status = gr.Textbox(label="File Status")
            db_stats = gr.Textbox(label="DB Stats", value=get_db_stats())
            clear_btn = gr.Button("Clear All Data")
            refresh_btn = gr.Button("Refresh Stats")

        # Right column: the conversation.
        with gr.Column(scale=2):
            # NOTE(review): the captured cell output shows a warning pointing at
            # this line (message text not captured) — check the Chatbot
            # arguments against the installed Gradio version.
            chatbot_ui = gr.Chatbot(label="Chatbot", height=500)
            msg_box = gr.Textbox(label="Your Question")
            send_btn = gr.Button("Send")
            clear_chat_btn = gr.Button("Clear Chat")

    # Event wiring: both the Send button and Enter in the textbox ask a question.
    upload_btn.click(upload_and_process_files, file_upload, [upload_status, db_stats])
    send_btn.click(chat_response, [msg_box, chatbot_ui], [chatbot_ui, msg_box])
    msg_box.submit(chat_response, [msg_box, chatbot_ui], [chatbot_ui, msg_box])
    clear_btn.click(clear_all_data, outputs=[upload_status, chatbot_ui, db_stats])
    refresh_btn.click(get_db_stats, outputs=[db_stats])
    clear_chat_btn.click(_clear_chat_and_history, outputs=[chatbot_ui, upload_status])

if __name__ == "__main__":
    print("Launching Gradio...")
    # NOTE(review): share=True asks Gradio for a public share link to this app,
    # which exposes the uploaded documents' contents to anyone with the URL —
    # confirm this is intended.
    demo.launch(share=True, debug=True)

  chatbot_ui = gr.Chatbot(label="Chatbot", height=500)


Launching Gradio...
* Running on local URL:  http://127.0.0.1:7860

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


Keyboard interruption in main thread... closing server.
