<a href="https://colab.research.google.com/github/badrmellal/project2_MIT_Rag_Rlhf/blob/main/MitRag%26Rlhf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


---

# Project 2: MIT RAG System with RLHF
# Authors: Badr Mellal & Ilyass Benayed
# Date: 27 April 2025

---

Setup & Imports

---


In [None]:
# Install required libraries
!pip install -q chromadb sentence-transformers transformers ipywidgets==7.7.1 scikit-learn pandas matplotlib seaborn PyPDF2 python-docx tiktoken


# Import libraries
import os, re, time, datetime, uuid, json, sqlite3, tempfile
from pathlib import Path
import numpy as np, pandas as pd
import matplotlib.pyplot as plt, seaborn as sns
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import chromadb
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import PyPDF2
from docx import Document
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Setup environment
# Set the TOKENIZERS_PARALLELISM environment variable to "false" before any
# tokenizer is created (presumably to silence the Hugging Face tokenizers
# fork/parallelism warning in notebooks -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Create the directories used as defaults further down: the ChromaDB
# persistence directory (MitRetriever) and the feedback database directory
# (FeedbackSystem).
os.makedirs("/content/data/chroma_db", exist_ok=True)
os.makedirs("/content/data/feedback", exist_ok=True)
# Colab-only helpers: `files` provides the upload dialog used by the UI,
# `output` is used to enable the custom widget manager for ipywidgets.
from google.colab import output, files
output.enable_custom_widget_manager()


---

Document Processing Components

---


In [None]:
class MitChunker:
    """Split text into overlapping chunks, preferring paragraph boundaries.

    The chunk size that is actually used is derived from the length of the
    text being split (400 / 800 / 1200 characters). The ``chunk_size``
    constructor argument is stored on the instance but is not consulted by
    ``split_text``; it is kept so existing callers keep working.

    ``chunk_overlap`` is honoured when it lies in ``[0, chunk size)``.
    Values outside that range are sanitised (see ``_effective_overlap``)
    instead of crashing or silently dropping text.
    """

    def __init__(self, chunk_size=500, chunk_overlap=100):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    @staticmethod
    def _dynamic_chunk_size(text_length):
        """Return the chunk size used for a text of ``text_length`` characters."""
        if text_length < 5000:
            return 400
        if text_length < 15000:
            return 800
        return 1200

    def _effective_overlap(self, chunk_size):
        """Return an overlap that is guaranteed to be in ``[0, chunk_size)``.

        A negative overlap would leave gaps between windows, and an overlap
        greater than or equal to the chunk size would give the sliding
        window a zero or negative step.
        """
        overlap = self.chunk_overlap
        if overlap < 0:
            return 0
        if overlap >= chunk_size:
            return chunk_size // 2
        return overlap

    @staticmethod
    def _sliding_windows(text, chunk_size, chunk_overlap):
        """Cut ``text`` into fixed-size windows that overlap by ``chunk_overlap``."""
        step = chunk_size - chunk_overlap
        windows = []
        for start in range(0, len(text), step):
            windows.append(text[start:start + chunk_size])
            # Once a window reaches the end of the text, any further window
            # would only repeat content already contained in this one.
            if start + chunk_size >= len(text):
                break
        return windows

    def split_text(self, text):
        """Split ``text`` into a list of chunk strings.

        Text containing blank-line separated paragraphs is grouped paragraph
        by paragraph; a new chunk starts when adding the next paragraph would
        exceed the chunk size, and trailing paragraphs of the previous chunk
        are carried over as long as they fit into the overlap budget. Text
        without paragraph breaks falls back to a character sliding window.

        Returns an empty list for empty or ``None`` input.
        """
        if not text:
            return []

        chunk_size = self._dynamic_chunk_size(len(text))
        chunk_overlap = self._effective_overlap(chunk_size)

        paragraphs = re.split(r'\n\s*\n', text)
        print(f"Paragraphs found: {len(paragraphs)}")

        # No paragraph structure to exploit: use fixed-size windows.
        if len(paragraphs) <= 1:
            return self._sliding_windows(text, chunk_size, chunk_overlap)

        sections = []
        current_section, current_length = [], 0

        for paragraph in paragraphs:
            if not paragraph.strip():
                continue
            para_length = len(paragraph)

            if current_length + para_length > chunk_size and current_section:
                sections.append("\n\n".join(current_section))

                # Seed the next chunk with as many trailing paragraphs of the
                # finished chunk as fit into the overlap budget.
                overlap_size, overlap_paragraphs = 0, []
                for prev_para in reversed(current_section):
                    if overlap_size + len(prev_para) > chunk_overlap:
                        break
                    overlap_paragraphs.insert(0, prev_para)
                    overlap_size += len(prev_para)
                current_section = overlap_paragraphs
                current_length = overlap_size

            current_section.append(paragraph)
            current_length += para_length

        if current_section:
            sections.append("\n\n".join(current_section))

        return sections


class DocumentProcessor:
    """Extract text from PDF, DOCX and plain-text files and split it into chunks.

    After each ``process_document`` call, ``self.metadata`` describes only
    the document handled by that call.
    """

    # File extension -> name of the method that extracts and chunks that type.
    # NOTE(review): '.doc' is routed to the python-docx based reader exactly
    # as before; legacy binary .doc files will presumably fail there and
    # yield empty output -- confirm whether '.doc' should stay listed.
    _HANDLERS = {
        '.pdf': 'process_pdf',
        '.docx': 'process_docx',
        '.doc': 'process_docx',
        '.txt': 'process_text',
        '.md': 'process_text',
    }

    def __init__(self, chunk_size=500, chunk_overlap=100):
        self.chunker = MitChunker(chunk_size, chunk_overlap)
        self.metadata = {}

    def preprocess_text(self, text):
        """Normalise whitespace and remove spaces before punctuation.

        NOTE(review): every whitespace run, including blank lines, is
        collapsed to a single space, so the returned text contains no
        paragraph breaks for the chunker to split on.
        """
        if not text:
            return ""
        text = text.replace('\xa0', ' ')
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'\s([.,:;?!])', r'\1', text)
        return text.strip()

    def process_pdf(self, file_path):
        """Return ``(text, chunks)`` for a PDF, or ``("", [])`` on any error."""
        try:
            metadata = {}
            page_texts = []
            with open(file_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                metadata['pages'] = len(reader.pages)
                for page in reader.pages:
                    page_text = page.extract_text()
                    if page_text:
                        page_texts.append(page_text + "\n\n")

            self.metadata = metadata
            text = self.preprocess_text("".join(page_texts))
            chunks = self.chunker.split_text(text)
            return text, chunks
        except Exception as e:
            print(f"Error processing PDF: {e}")
            return "", []

    def process_docx(self, file_path):
        """Return ``(text, chunks)`` for a DOCX file, or ``("", [])`` on any error."""
        try:
            doc = Document(file_path)
            lines = [
                para.text.strip() + "\n"
                for para in doc.paragraphs
                if para.text.strip()
            ]

            self.metadata = {}
            text = self.preprocess_text("".join(lines))
            chunks = self.chunker.split_text(text)
            return text, chunks
        except Exception as e:
            print(f"Error processing DOCX: {e}")
            return "", []

    def process_text(self, file_path):
        """Return ``(text, chunks)`` for a text file, or ``("", [])`` on any error.

        The file is decoded with the first of utf-8, latin-1 and
        windows-1252 that succeeds.
        """
        try:
            with open(file_path, 'rb') as file:
                file_content = file.read()

            text = None
            for encoding in ['utf-8', 'latin-1', 'windows-1252']:
                try:
                    text = file_content.decode(encoding)
                    break
                except UnicodeDecodeError:
                    continue

            if text is None:
                raise ValueError("Could not decode file with any supported encoding")

            self.metadata = {'size': len(file_content), 'format': 'text', 'path': file_path}
            text = self.preprocess_text(text)
            chunks = self.chunker.split_text(text)
            return text, chunks
        except Exception as e:
            print(f"Error processing text file: {e}")
            return "", []

    def _extract(self, file_extension, path):
        """Run the handler registered for ``file_extension`` on ``path``.

        Returns ``("", [])`` when the extension is not supported.
        """
        handler_name = self._HANDLERS.get(file_extension)
        if handler_name is None:
            return "", []
        return getattr(self, handler_name)(path)

    def process_document(self, file_path=None, file_content=None, file_name=None):
        """Process a document given either a path or raw content plus a name.

        Args:
            file_path: Path of a file on disk. Takes precedence when given.
            file_content: Raw file content (``bytes``; a ``str`` is encoded
                as UTF-8). Requires ``file_name`` to determine the type.
            file_name: Original file name; its suffix selects the handler.

        Returns:
            ``(text, chunks)``. Both are empty when the input is missing,
            the file does not exist, the type is unsupported, or extraction
            failed.
        """
        # Start from a clean slate so a failed or unsupported document never
        # inherits metadata (e.g. a page count) from the previous one.
        self.metadata = {}
        text, chunks = "", []

        if file_path is not None:
            if not os.path.exists(file_path):
                return "", []

            file_name = os.path.basename(file_path)
            text, chunks = self._extract(Path(file_path).suffix.lower(), file_path)

        elif file_content is not None and file_name:
            if isinstance(file_content, str):
                # The temp file is opened in binary mode.
                file_content = file_content.encode('utf-8')

            file_extension = Path(file_name).suffix.lower()

            # Only touch the disk when some handler will actually read the file.
            if file_extension in self._HANDLERS:
                with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
                    temp_file.write(file_content)
                    temp_path = temp_file.name
                try:
                    text, chunks = self._extract(file_extension, temp_path)
                finally:
                    # Best-effort cleanup; a leftover temp file is not fatal.
                    try:
                        os.unlink(temp_path)
                    except OSError:
                        pass

            self.metadata.update({"uploaded_filename": file_name, "file_size_bytes": len(file_content)})
        else:
            return "", []

        self.metadata.update({
            "filename": file_name,
            "file_type": Path(file_name).suffix.lower(),
            "total_chars": len(text),
            "chunk_count": len(chunks),
            "processing_timestamp": datetime.datetime.now().isoformat()
        })

        return text, chunks



---

Vector Retrieval System

---



In [None]:
class MitRetriever:
    """Retrieve chunks by embedding similarity (ChromaDB) with a TF-IDF fallback.

    The retriever holds one document at a time: ``add_documents`` replaces
    the chunks written by the previous call on this instance.
    """

    def __init__(self, collection_name="default_collection",
                 model_name='paraphrase-multilingual-MiniLM-L12-v2',
                 persist_directory="/content/data/chroma_db"):
        self.model_name = model_name
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self.document_chunks = []
        self.chunk_ids = []

        # Initialize ChromaDB
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)
        try:
            self.collection = self.chroma_client.get_collection(name=collection_name)
            print(f"Retrieved existing collection '{collection_name}' with {self.collection.count()} documents")
        except Exception:
            # The "collection not found" error type is not visible from this
            # file, so stay broad, but no longer swallow
            # KeyboardInterrupt/SystemExit like the bare ``except`` did.
            self.collection = self.chroma_client.create_collection(name=collection_name)
            print(f"Created new collection '{collection_name}'")

        # Initialize embedding model
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)

        # Fallback TF-IDF
        self.vectorizer = TfidfVectorizer(lowercase=True, stop_words='english', ngram_range=(1, 2))
        self.tfidf_matrix = None

    @staticmethod
    def _distance_to_score(distance):
        """Map a non-negative distance to a score in ``(0, 1]``.

        The mapping is strictly decreasing, so a closer match always gets a
        higher score (distance 0 -> 1.0).
        """
        return 1.0 / (1.0 + max(float(distance), 0.0))

    def add_documents(self, chunks, metadata=None):
        """Embed ``chunks`` and store them, replacing the previous document.

        Args:
            chunks: List of chunk strings. Nothing happens when it is empty.
            metadata: Optional list with one metadata dict per chunk.
        """
        if not chunks:
            return

        previous_ids = self.chunk_ids
        self.document_chunks = chunks
        self.chunk_ids = [f"chunk_{i}" for i in range(len(chunks))]

        if metadata is None:
            metadata = [{"index": i, "source": "document"} for i in range(len(chunks))]

        # Compute embeddings and write them to ChromaDB. Ids are reused
        # across documents, so ``upsert`` is required: ``add`` does not
        # overwrite an id that already exists in the collection.
        print(f"Computing embeddings for {len(chunks)} chunks...")
        embeddings = self.model.encode(chunks, show_progress_bar=True)
        self.collection.upsert(
            embeddings=embeddings.tolist(),
            documents=chunks,
            ids=self.chunk_ids,
            metadatas=metadata
        )

        # Drop chunks of the previous (longer) document that the upsert did
        # not overwrite. Only ids written by this instance are known here.
        stale_ids = previous_ids[len(chunks):]
        if stale_ids:
            self.collection.delete(ids=stale_ids)
        print(f"Added {len(chunks)} chunks to vector database")

        # Also compute TF-IDF as fallback
        try:
            self.tfidf_matrix = self.vectorizer.fit_transform(chunks)
        except ValueError as e:
            # E.g. an empty vocabulary when chunks contain only stop words.
            print(f"TF-IDF fallback unavailable: {e}")
            self.tfidf_matrix = None

    def search(self, query, top_k=5):
        """Return up to ``top_k`` result dicts for ``query``, best first.

        Each result has the keys ``rank``, ``index``, ``score``, ``text``,
        ``method`` (``'embedding'`` or ``'tfidf'``) and ``metadata``.
        Returns an empty list when nothing has been indexed or nothing
        matches.
        """
        if not self.document_chunks and self.collection.count() == 0:
            return []

        results = []

        # Try ChromaDB with embeddings
        try:
            # Never ask for more neighbours than the collection holds.
            n_results = min(top_k, self.collection.count())
            if n_results > 0:
                query_embedding = self.model.encode(query).tolist()
                chroma_results = self.collection.query(
                    query_embeddings=[query_embedding],
                    n_results=n_results,
                    include=["documents", "distances", "metadatas"]
                )

                if chroma_results and chroma_results.get('documents'):
                    documents = chroma_results['documents'][0]
                    distances = (chroma_results.get('distances') or [[]])[0]
                    metadatas = (chroma_results.get('metadatas') or [[]])[0]

                    for i, (doc, distance) in enumerate(zip(documents, distances)):
                        metadata = (metadatas[i] if i < len(metadatas) else None) or {}
                        # Chunk metadata written by callers may use either key.
                        index = metadata.get('index', metadata.get('chunk_index', i))

                        results.append({
                            'rank': i + 1,
                            'index': index,
                            'score': self._distance_to_score(distance),
                            'text': doc,
                            'method': 'embedding',
                            'metadata': metadata
                        })

                if results:
                    return results
        except Exception as e:
            print(f"Error in embedding search: {e}")
            results = []

        # Fallback to TF-IDF if needed
        if self.tfidf_matrix is not None:
            query_vector = self.vectorizer.transform([query])
            similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()
            top_indices = np.argsort(similarities)[-top_k:][::-1]

            for i, idx in enumerate(top_indices):
                score = float(similarities[idx])
                if score > 0:
                    results.append({
                        'rank': i + 1,
                        'index': int(idx),
                        'score': score,
                        'text': self.document_chunks[idx],
                        'method': 'tfidf',
                        'metadata': {}
                    })

        return results



---

LLM Integration

---


In [None]:
class HuggingFaceLLM:
    """Answer questions from retrieved contexts with a seq2seq model.

    When the model cannot be loaded, fails at generation time, or produces
    an answer shorter than five words, a keyword-based sentence extraction
    over the contexts is used instead.
    """

    # Words ignored when matching the query against context sentences.
    _STOP_WORDS = frozenset({
        'the', 'and', 'is', 'in', 'it', 'to', 'of', 'for', 'a', 'on',
        'with', 'what', 'how', 'why',
    })

    def __init__(self, model_name="google/flan-t5-base"):
        self.model_name = model_name
        # Define every attribute up front so the instance has a consistent
        # shape even when loading fails.
        self.tokenizer = None
        self.model = None
        self.generator = None
        self.available = False
        try:
            print(f"Loading LLM: {model_name}...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            self.generator = pipeline("text2text-generation", model=self.model, tokenizer=self.tokenizer, max_length=512)
            self.available = True
            print(f"Successfully loaded {model_name}")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.available = False

    def generate_answer(self, query, contexts):
        """Return an answer string for ``query`` based on ``contexts``.

        Args:
            query: The user question.
            contexts: List of dicts, each with a ``'text'`` entry, ordered
                from most to least relevant.
        """
        if not contexts:
            return "No relevant information found to answer this question."

        if not self.available:
            return self._extract_answer(query, contexts)

        context_text = "\n\n".join([c['text'] for c in contexts])

        try:
            # Keep the prompt simple for T5 models; only the first 1500
            # characters of context are sent.
            prompt = f"Answer based on this context: {context_text[:1500]}\n\nQuestion: {query}"

            outputs = self.generator(prompt, max_length=300, do_sample=True, temperature=0.7)
            answer = outputs[0]['generated_text']

            # Very short generations are replaced by extracted sentences.
            if len(answer.split()) < 5:
                return self._extract_answer(query, contexts)

            return f"Based on the document information:\n\n{answer}"

        except Exception as e:
            print(f"Error generating with model: {e}")
            return self._extract_answer(query, contexts)

    @staticmethod
    def _tokenize(text):
        """Return the set of lower-cased word tokens in ``text``.

        Punctuation is not part of a token, so "tuition?" and "tuition."
        both yield "tuition".
        """
        return set(re.findall(r'\w+', text.lower()))

    def _extract_answer(self, query, contexts):
        """Build an answer from the context sentences sharing most query words.

        Uses the first three contexts. Up to three sentences are returned,
        ordered by number of matching query words; ties keep document order.
        Falls back to the full text of the best context when no sentence
        matches.
        """
        if not contexts:
            return "No relevant information found."

        best_context = contexts[0]['text']

        all_text = " ".join([c['text'] for c in contexts[:3]])
        sentences = re.split(r'(?<=[.!?])\s+', all_text)

        query_words = self._tokenize(query) - self._STOP_WORDS

        scored_sentences = []
        for sentence in sentences:
            if not sentence.strip():
                continue
            match_count = len(self._tokenize(sentence) & query_words)
            if match_count:
                scored_sentences.append((match_count, sentence))

        # Sort on the count only; the sort is stable, so equally relevant
        # sentences stay in the order they appear in the document.
        scored_sentences.sort(key=lambda item: item[0], reverse=True)

        if scored_sentences:
            answer_sentences = [sentence for _, sentence in scored_sentences[:3]]
            return f"Based on the document information:\n\n{' '.join(answer_sentences)}"
        return f"Based on the document information:\n\n{best_context}"



---

Feedback System

---


In [None]:
class FeedbackSystem:
    """Store answer ratings in a SQLite database and summarise them."""

    def __init__(self, db_path="/content/data/feedback/feedback.db"):
        """Create the database file and the ``feedback`` table if missing.

        Args:
            db_path: Path of the SQLite file. A bare file name (no directory
                part) is created in the current working directory.
        """
        self.db_path = db_path
        # os.makedirs("") raises, so only create a directory when the path
        # actually has one.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        # Initialize database
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute('''
            CREATE TABLE IF NOT EXISTS feedback (
                id TEXT PRIMARY KEY,
                query TEXT,
                answer TEXT,
                context TEXT,
                rating INTEGER,
                timestamp TEXT,
                model TEXT,
                retrieval_method TEXT,
                metadata TEXT
            )
            ''')
            conn.commit()
        finally:
            conn.close()

    def add_feedback(self, query, answer, context, rating, model="unknown", retrieval_method="unknown", metadata=None):
        """Insert one feedback row.

        ``context`` may be a string, or a list/dict which is stored as JSON.
        ``metadata`` is stored as JSON (``"{}"`` when empty or ``None``).

        Returns:
            ``True`` when the row was stored, ``False`` on any error.
        """
        try:
            feedback_id = str(uuid.uuid4())
            timestamp = datetime.datetime.now().isoformat()
            metadata_json = json.dumps(metadata) if metadata else "{}"

            if isinstance(context, (list, dict)):
                context = json.dumps(context)

            conn = sqlite3.connect(self.db_path)
            try:
                conn.execute('''
                INSERT INTO feedback (id, query, answer, context, rating, timestamp, model, retrieval_method, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (feedback_id, query, answer, context, rating, timestamp, model, retrieval_method, metadata_json))
                conn.commit()
            finally:
                # Close even when the insert fails so the handle is not leaked.
                conn.close()
            return True
        except Exception as e:
            print(f"Error storing feedback: {e}")
            return False

    def get_feedback_stats(self):
        """Return total count, average rating and per-rating counts.

        Returns:
            A dict with ``total_feedback``, ``average_rating`` (``None`` when
            there are no rows) and ``rating_distribution``; an empty dict on
            error.
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()

                cursor.execute("SELECT COUNT(*) FROM feedback")
                total_count = cursor.fetchone()[0]

                cursor.execute("SELECT AVG(rating) FROM feedback")
                avg_rating = cursor.fetchone()[0]

                cursor.execute("SELECT rating, COUNT(*) FROM feedback GROUP BY rating")
                rating_dist = dict(cursor.fetchall())
            finally:
                conn.close()

            return {
                "total_feedback": total_count,
                "average_rating": avg_rating,
                "rating_distribution": rating_dist
            }
        except Exception as e:
            print(f"Error getting feedback stats: {e}")
            return {}

    def analyze_feedback(self):
        """Derive simple insights from high (>= 4) and low (<= 2) rated answers.

        Returns:
            ``{"status": "No data"}`` when there is no feedback, a dict with
            counts and an ``insights`` list otherwise, or
            ``{"status": "Error", "message": ...}`` on failure.
        """
        stats = self.get_feedback_stats()
        if not stats or stats.get("total_feedback", 0) == 0:
            return {"status": "No data"}

        try:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()

                cursor.execute("SELECT query, answer, context, rating FROM feedback WHERE rating >= 4")
                high_rated = cursor.fetchall()

                cursor.execute("SELECT query, answer, context, rating FROM feedback WHERE rating <= 2")
                low_rated = cursor.fetchall()
            finally:
                conn.close()

            insights = []

            if len(high_rated) < 3 or len(low_rated) < 3:
                insights.append("Not enough feedback data for detailed analysis")
                # NOTE(review): these two statements are derived only from
                # which group is larger, not from answer length -- confirm
                # the wording is intended.
                if len(high_rated) > len(low_rated):
                    insights.append("Users seem to prefer more detailed answers")
                if len(low_rated) > len(high_rated):
                    insights.append("Users prefer more concise answers")
            else:
                # Compare average answer length; a NULL answer counts as empty.
                high_lengths = [len(answer or "") for _, answer, _, _ in high_rated]
                low_lengths = [len(answer or "") for _, answer, _, _ in low_rated]
                avg_high = sum(high_lengths) / len(high_lengths)
                avg_low = sum(low_lengths) / len(low_lengths)

                if avg_high > avg_low * 1.2:
                    insights.append("Longer answers tend to receive higher ratings")
                elif avg_low > avg_high * 1.2:
                    insights.append("Shorter, more concise answers tend to receive higher ratings")

            return {
                "total_feedback": stats["total_feedback"],
                "average_rating": stats["average_rating"],
                "high_rated_count": len(high_rated),
                "low_rated_count": len(low_rated),
                "insights": insights
            }
        except Exception as e:
            print(f"Error analyzing feedback: {e}")
            return {"status": "Error", "message": str(e)}



---

Main RAG System

---


In [None]:
class MitRAGSystem:
    """Tie together document processing, retrieval, generation and feedback.

    The system tracks one current document, the conversation history of
    answered questions, and simple timing metrics.
    """

    def __init__(self, collection_name="default_collection", chunk_size=500, chunk_overlap=100,
                model_name="google/flan-t5-base", embedding_model="paraphrase-multilingual-MiniLM-L12-v2",
                persist_directory="/content/data/chroma_db"):
        # Initialize components
        self.document_processor = DocumentProcessor(chunk_size, chunk_overlap)
        self.retriever = MitRetriever(collection_name, embedding_model, persist_directory)
        self.llm = HuggingFaceLLM(model_name)
        self.feedback = FeedbackSystem()

        # Document information
        self.document_text = ""
        self.document_chunks = []
        self.document_name = None
        self.document_metadata = {}

        # Conversation tracking
        self.conversation_history = []

        # System metrics
        self.metrics = {
            'document_processing': [],
            'queries': [],
            'feedback': []
        }

        print(f"MIT RAG System initialized with LLM: {model_name} and Embedding: {embedding_model}")

    def process_document(self, file_path=None, file_content=None, file_name=None):
        """Extract, chunk and index a document.

        Returns:
            ``True`` when at least one chunk was produced and indexed,
            ``False`` otherwise (the previous document state is kept).
        """
        start_time = time.time()

        text, chunks = self.document_processor.process_document(
            file_path=file_path,
            file_content=file_content,
            file_name=file_name
        )

        if not chunks:
            print("No chunks produced from document")
            return False

        # Store document info
        self.document_text = text
        self.document_chunks = chunks
        self.document_name = file_name or (file_path and os.path.basename(file_path))
        # Copy so later changes to the processor's dict cannot alter the
        # metadata recorded for this document.
        self.document_metadata = dict(self.document_processor.metadata)

        # One metadata record per chunk
        document_type = self.document_metadata.get("file_type", "unknown")
        chunk_metadata = [
            {
                "document_name": self.document_name,
                "chunk_index": i,
                "document_type": document_type,
                "source": "document"
            } for i in range(len(chunks))
        ]

        self.retriever.add_documents(chunks, metadata=chunk_metadata)

        # Track metrics
        processing_time = time.time() - start_time
        self.metrics['document_processing'].append({
            'filename': self.document_name,
            'time': processing_time,
            'chunks': len(chunks)
        })

        print(f"Document processed: {len(chunks)} chunks in {processing_time:.2f}s")
        return True

    def answer_question(self, query, top_k=3):
        """Answer ``query`` from the indexed document.

        Returns:
            ``(answer, retrieval_results, conversation_id)``. The id is
            ``None`` when nothing is indexed or nothing was retrieved; in
            those cases no conversation entry is recorded.
        """
        # Nothing processed in this session and nothing stored in the
        # collection: there is nothing to search. (The previous
        # ``hasattr(self.retriever, 'collection')`` test was always true,
        # which made this message unreachable.)
        if not self.document_chunks and self.retriever.collection.count() == 0:
            return "Please process a document first.", [], None

        start_time = time.time()

        retrieval_results = self.retriever.search(query, top_k=top_k)

        if not retrieval_results:
            return "No relevant information found to answer this question.", [], None

        answer = self.llm.generate_answer(query, retrieval_results)

        # Create conversation entry
        conversation_id = str(uuid.uuid4())
        self.conversation_history.append({
            "id": conversation_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "query": query,
            "answer": answer,
            "contexts": retrieval_results,
            "model": self.llm.model_name,
            "feedback": None
        })

        # Track metrics
        total_time = time.time() - start_time
        self.metrics['queries'].append({
            'query': query,
            'time': total_time,
            'result_count': len(retrieval_results)
        })

        return answer, retrieval_results, conversation_id

    def add_feedback(self, conversation_id, rating):
        """Attach ``rating`` to a recorded conversation and persist it.

        Returns:
            ``False`` when ``conversation_id`` is unknown; otherwise the
            result of storing the feedback in the database.
        """
        entry = next((item for item in self.conversation_history if item["id"] == conversation_id), None)

        if not entry:
            print(f"Conversation ID {conversation_id} not found")
            return False

        # Update entry with feedback
        entry["feedback"] = {
            "rating": rating,
            "timestamp": datetime.datetime.now().isoformat()
        }

        # Store in feedback database
        contexts = entry["contexts"]
        success = self.feedback.add_feedback(
            query=entry["query"],
            answer=entry["answer"],
            context=[c["text"] for c in contexts],
            rating=rating,
            model=entry["model"],
            retrieval_method=contexts[0]["method"] if contexts else "unknown",
            metadata={"conversation_id": entry["id"], "document_name": self.document_name}
        )

        # Track metrics
        self.metrics['feedback'].append({
            'conversation_id': entry["id"],
            'query': entry["query"],
            'rating': rating
        })

        return success

    def get_document_info(self):
        """Return a summary dict of the current document.

        Contains ``filename``, ``chunks`` and ``total_characters`` plus all
        document metadata except ``processing_timestamp``; or a status dict
        when no document has been processed.
        """
        if not self.document_name:
            return {"status": "No document processed"}

        return {
            "filename": self.document_name,
            "chunks": len(self.document_chunks),
            "total_characters": len(self.document_text),
            **{k: v for k, v in self.document_metadata.items()
               if k not in ['processing_timestamp']}
        }

    def analyze_feedback(self):
        """Return the feedback analysis produced by the feedback store."""
        return self.feedback.analyze_feedback()



---

UI & Run System

---


In [None]:
def create_mit_rag_system():
    """Build a ``MitRAGSystem``, render its notebook UI and return the system.

    The UI offers document upload and processing, a question box, a rating
    control for the latest answer, and an analytics view of stored feedback.
    """
    from html import escape  # local import: only needed for the upload label

    # A timestamped collection name gives every session its own collection.
    rag_system = MitRAGSystem(
      collection_name=f"collection_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}",
      model_name="google/flan-t5-base",
      embedding_model="paraphrase-multilingual-MiniLM-L12-v2"
    )

    # Output areas
    doc_info_output = widgets.Output()
    answer_output = widgets.Output()
    analytics_output = widgets.Output()

    # Upload button; the file dialog itself comes from the Colab files API
    upload_button = widgets.Button(
        description='Upload Document',
        button_style='primary'
    )

    process_button = widgets.Button(
        description='Process Document',
        disabled=True,
        button_style='primary'
    )

    # Text display to show uploaded filename
    file_text = widgets.HTML("No file uploaded")

    # One-element lists act as mutable cells shared by the callbacks below.
    uploaded_file = [None]
    current_conversation_id = [None]

    # Query widgets
    query_input = widgets.Text(
        description='Question:',
        placeholder='Ask a question about the document...',
        disabled=True,
        layout={'width': '80%'}
    )

    query_button = widgets.Button(
        description='Search',
        disabled=True,
        button_style='success'
    )

    # Feedback widgets
    rating_buttons = widgets.RadioButtons(
        options=[('⭐', 1), ('⭐⭐', 2), ('⭐⭐⭐', 3), ('⭐⭐⭐⭐', 4), ('⭐⭐⭐⭐⭐', 5)],
        layout={'width': 'max-content'},
        disabled=True
    )
    submit_button = widgets.Button(
        description='Submit',
        disabled=True,
        button_style='info',
        layout={'width': 'auto'}
    )
    feedback_widget = widgets.HBox([
        widgets.Label("Rate:"),
        rating_buttons,
        submit_button
    ])

    # Analytics button
    analytics_button = widgets.Button(
        description='Show Analytics',
        disabled=False,
        button_style='info'
    )

    def render_ui():
        """Display the complete widget layout in the current output cell."""
        display(widgets.HTML("<h2>MIT RAG System with RLHF</h2>"))
        display(widgets.VBox([
            widgets.HBox([upload_button, process_button]),
            file_text
        ]))
        display(doc_info_output)
        display(widgets.HTML("<h3>Ask a Question</h3>"))
        display(widgets.HBox([query_input, query_button]))
        display(answer_output)
        display(widgets.HBox([feedback_widget, analytics_button]))
        display(analytics_output)

    def set_feedback_enabled(enabled):
        """Enable or disable the rating control and its submit button."""
        rating_buttons.disabled = not enabled
        submit_button.disabled = not enabled

    def on_upload_click(b):
        try:
            # Redraw the UI, then open the upload dialog below it.
            clear_output(wait=True)
            render_ui()

            print("Please select a file in the pop-up dialog...")
            uploaded = files.upload()

            if uploaded:
                filename = next(iter(uploaded))
                content = uploaded[filename]
                # The file name is user-controlled and ends up in HTML.
                file_text.value = f"<b>Uploaded:</b> {escape(filename)} ({len(content)} bytes)"
                uploaded_file[0] = (filename, content)
                process_button.disabled = False
            else:
                file_text.value = "<b>Upload cancelled or failed</b>"
        except Exception as e:
            file_text.value = f"<b>Error:</b> {escape(str(e))}"

    def on_process_click(b):
        doc_info_output.clear_output()
        with doc_info_output:
            if not uploaded_file[0]:
                print("Please upload a document first.")
                return

            filename, content = uploaded_file[0]
            print(f"Processing: {filename}")

            success = rag_system.process_document(
                file_content=content,
                file_name=filename
            )

            if success:
                print("\n✅ Document Information:")
                for key, value in rag_system.get_document_info().items():
                    print(f"  {key}: {value}")
                query_input.disabled = False
                query_button.disabled = False
            else:
                print("❌ Document processing failed.")

    def on_query_click(b):
        answer_output.clear_output()
        with answer_output:
            query = query_input.value
            if not query:
                print("Please enter a question.")
                return

            print(f"Query: '{query}'")
            answer, results, conversation_id = rag_system.answer_question(query, top_k=3)

            current_conversation_id[0] = conversation_id
            # Only offer rating when a conversation was actually recorded;
            # otherwise Submit would be a silent no-op.
            set_feedback_enabled(conversation_id is not None)

            print("\nAnswer:")
            print(answer)

            print("\nRetrieved passages:")
            for i, result in enumerate(results[:2]):  # Show only top 2 for brevity
                print(f"  Passage {i+1} (Score: {result['score']:.4f})")
                print(f"  {result['text'][:150]}..." if len(result['text']) > 150 else f"  {result['text']}")

    def on_feedback_submit(b):
        if current_conversation_id[0] is None:
            return

        rating = rating_buttons.value
        saved = rag_system.add_feedback(current_conversation_id[0], rating)

        with answer_output:
            if saved:
                print(f"\n✅ Thank you for your feedback! (Rating: {rating})")
            else:
                print("\n❌ Feedback could not be saved. Please try again.")

        # Keep the controls enabled after a failed save so the user can retry.
        if saved:
            set_feedback_enabled(False)

    def on_analytics_click(b):
        analytics_output.clear_output()
        with analytics_output:
            print("System Analytics and Feedback")
            print("=============================")

            # Plot feedback stats
            stats = rag_system.feedback.get_feedback_stats()
            if stats and stats.get("total_feedback", 0) > 0:
                plt.figure(figsize=(10, 4))
                ratings = list(stats["rating_distribution"].keys())
                counts = list(stats["rating_distribution"].values())
                plt.bar(ratings, counts)
                plt.title('Feedback Ratings')
                plt.xlabel('Rating')
                plt.ylabel('Count')
                plt.xticks(range(1, 6))
                plt.show()

                print(f"Average Rating: {stats['average_rating']:.2f} ({stats['total_feedback']} ratings)")
            else:
                print("No feedback data available yet.")

            # Show RLHF analysis
            analysis = rag_system.analyze_feedback()
            if isinstance(analysis, dict) and "insights" in analysis:
                print("\nInsights from feedback:")
                for i, insight in enumerate(analysis["insights"]):
                    print(f"{i+1}. {insight}")
            else:
                print("\nNot enough feedback data for detailed analysis.")

    # Connect callbacks
    upload_button.on_click(on_upload_click)
    process_button.on_click(on_process_click)
    query_button.on_click(on_query_click)
    submit_button.on_click(on_feedback_submit)
    analytics_button.on_click(on_analytics_click)

    # Display UI
    render_ui()

    return rag_system

# Build the system and render the UI; the returned MitRAGSystem is kept in
# `mit_rag` so it stays reachable from the notebook after this cell runs.
mit_rag = create_mit_rag_system()