# PDF to RAG Processing System
# This notebook prepares PDFs for RAG systems: it extracts structured content,
# classifies sections, and adds LLM summaries, using grounded prompts and a
# similarity check to reduce and flag likely hallucinations.

In [None]:
import fitz 
import re
import os
import json
import pandas as pd
from typing import List, Dict, Any
from langdetect import detect, LangDetectException
from collections import defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_openai import ChatOpenAI
from chromadb.utils import embedding_functions
import chromadb
from CSMetadataExtractor import CSMetadataExtractor
# Constants
MIN_SECTION_LENGTH = 20  
MIN_SIMILARITY_THRESHOLD = 0.5 
import matplotlib.pyplot as plt

In [None]:
class ImprovedPDFProcessor:
    """Convert lecture-style PDFs into cleaned, classified, LLM-annotated sections.

    Pipeline: ``extract_structured_content`` (PyMuPDF line extraction and
    heading detection) -> ``post_process_sections`` (cleaning, metadata
    filtering, merging of short similar sections) -> ``process_with_llm``
    (classification, prompt selection, field parsing) ->
    ``validate_llm_outputs`` (TF-IDF similarity used as a hallucination-risk
    heuristic).
    """

    # Unicode glyphs always mark a bullet. ASCII "o", "-" and "*" only count
    # when followed by whitespace (or end of line), so ordinary text such as
    # "object", "-5" or "*args" is not treated as a list item and mangled.
    _BULLET_RE = re.compile(r'^\s*(?:[■◼▪◦•]\s*|[o\-*](?:\s+|$))')

    # Field headers the prompts ask the LLM to emit; each header also acts as
    # the terminator of the preceding field when parsing a response.
    _LLM_FIELDS = (
        "Title", "Language", "Description", "Content",
        "Summary", "Key Concepts", "Main Topic", "Source",
    )

    def __init__(
        self,
        classifier_model_path: str = None,
        llm_base_url: str = "http://localhost:4500/v1",
        llm_api_key: str = "lm-studio",
        llm_model: str = "meta-llama-3.1-8b-instruct",
        temperature: float = 0.2
    ):
        """Create the LLM client, build the prompts and optionally load a classifier.

        Args:
            classifier_model_path: Hugging Face path/name of a sequence
                classifier. When None, ``classify_by_patterns`` is used.
            llm_base_url: OpenAI-compatible endpoint (the default presumably
                points at a local LM Studio server — confirm for your setup).
            llm_api_key: API key sent to the endpoint.
            llm_model: Model name requested from the endpoint.
            temperature: Sampling temperature for the LLM.
        """
        self.classifier_model_path = classifier_model_path
        self.classifier = None
        self.tokenizer = None
        self.id2label = None

        # Setup LLM
        self.llm = ChatOpenAI(
            base_url=llm_base_url,
            api_key=llm_api_key,
            temperature=temperature,
            model=llm_model,
        )

        # Initialize prompt templates
        self.setup_prompts()

        # Load classifier if provided
        if classifier_model_path:
            self.load_classifier()

        print(f"Initialized PDF processor with LLM: {llm_model}")

    def load_classifier(self):
        """Load the pretrained text classifier.

        On any failure the error is printed and the processor keeps working
        with ``self.classifier = None`` (pattern-based classification).
        """
        try:
            from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

            self.tokenizer = AutoTokenizer.from_pretrained(self.classifier_model_path)
            model = AutoModelForSequenceClassification.from_pretrained(self.classifier_model_path)
            self.id2label = model.config.id2label
            self.classifier = pipeline(
                "text-classification",
                model=model,
                tokenizer=self.tokenizer
            )
            print(f"Loaded classifier with {len(self.id2label)} classes: {list(self.id2label.values())}")
        except Exception as e:
            print(f"Error loading classifier: {e}")
            print("Proceeding without classification capability")

    def setup_prompts(self):
        """Build the general, code and math prompt templates.

        All three take ``title``, ``text`` and ``source`` and ask for a fixed
        "Field: value" layout that ``extract_fields_from_llm_output`` parses.
        """
        # Prompts are defined inline; an attempt to load them from .txt files did not work.
        self.general_prompt = PromptTemplate(
            input_variables=["title", "text", "source"],
            template="""
    You are analyzing ACADEMIC CONTENT from a document titled: "{title}" from source "{source}".
    
    IMPORTANT: Use ONLY the information provided in the text below. Do NOT add any information that is not present in the text. If you don't have enough information, simply state what is available without inventing details.
    
    TEXT TO ANALYZE:
    ```
    {text}
    ```
    
    Instructions:
    1. Summarize ONLY the content provided above
    2. Extract key concepts and terms mentioned explicitly in the text
    3. Identify the main topic based solely on the content provided
    4. Do not add any information that's not in the original text
    5. If the content is incomplete or unclear, reflect that in your response - don't hallucinate details
    
    Format your response exactly as follows:
    Title: <title from the document, or "Untitled" if unclear>
    Content: <key points from the text, maintaining the original hierarchy and organization>
    Summary: <brief factual summary based only on what's in the text>
    Key Concepts: <list of key concepts explicitly mentioned, comma-separated>
    Main Topic: <the overarching topic based only on content provided>
    Source: <source document title>
    """
        )

        # Code-specific prompt with strict instructions
        self.code_prompt = PromptTemplate(
            input_variables=["title", "text", "source"],
            template="""
    You are analyzing CODE CONTENT from a document titled: "{title}" from source "{source}".
    
    IMPORTANT: Use ONLY the information provided in the text below. Do NOT add any information that is not present in the text.
    
    CODE TO ANALYZE:
    ```
    {text}
    ```
    
    Instructions:
    1. Identify the programming language ONLY if it's clearly identifiable from the code
    2. Describe only what is visible in the code snippet
    3. Do not invent function behaviors that aren't shown
    4. Extract only concepts that are explicitly demonstrated
    5. Do not add any information that's not in the original code
    
    Format your response exactly as follows:
    Title: <title from the document, or "Code Snippet" if unclear>
    Language: <programming language if identifiable, "Unknown" if unclear>
    Description: <brief description of visible code elements ONLY>
    Key Concepts: <list of programming concepts explicitly demonstrated, comma-separated>
    Content: <the code with proper formatting>
    Source: <source document title>
    """
        )

        # Mathematical/theoretical content prompt
        self.math_prompt = PromptTemplate(
            input_variables=["title", "text", "source"],
            template="""
        You are analyzing MATHEMATICAL/THEORETICAL CONTENT from a document titled: "{title}" from source "{source}".
        
        IMPORTANT: Use ONLY the information provided in the text below. Do NOT add any information that is not present in the text.
        
        TEXT TO ANALYZE:
        ```
        {text}
        ```
        
        Instructions:
        1. Identify mathematical concepts, theorems, proofs, or algorithms described
        2. Preserve all mathematical notation and formulas
        3. Maintain the logical flow of the mathematical arguments
        4. Extract only concepts that are explicitly mentioned
        5. Do not invent or extend mathematical content that isn't present
        
        Format your response exactly as follows:
        Title: <title from the document>
        Content: <mathematical content with hierarchy and notation preserved>
        Summary: <brief factual summary of the mathematical content>
        Key Concepts: <list of mathematical concepts explicitly mentioned, comma-separated>
        Main Topic: <the main mathematical topic discussed>
        Source: <source document title>
        """
        )

    @staticmethod
    def _is_page_artifact(text: str) -> bool:
        """Return True for page numbers ("12", "3 / 20"), "Page N" and short LECTURE footers.

        Only number-only lines are dropped, so numbered content such as
        "1) first example" or "2 + 3 = 5" survives extraction.
        """
        return bool(
            re.fullmatch(r'\d+(?:\s*/\s*\d+)?', text)
            or re.match(r'Page \d+', text, re.IGNORECASE)
            or ('LECTURE' in text and len(text) < 25)
        )

    @classmethod
    def _split_bullet(cls, text: str):
        """Return ``(is_bullet, text_without_marker)`` for one extracted line."""
        match = cls._BULLET_RE.match(text)
        if match is None:
            return False, text
        return True, text[match.end():].strip()

    @staticmethod
    def _indent_level(x_pos: float) -> int:
        """Map a line's left x coordinate to an indent depth of 0, 1 or 2."""
        # Fixed thresholds in PDF units; presumably tuned for the target slide decks.
        if x_pos > 130:
            return 2
        if x_pos > 70:
            return 1
        return 0

    def _format_line(self, record: Dict[str, Any]) -> str:
        """Render a line record with tab indentation and a normalized "- " bullet."""
        indent = "\t" * self._indent_level(record["x_pos"])
        if record["is_bullet"]:
            return indent + "- " + record["bullet_text"]
        return indent + record["text"]

    def _extract_line_records(self, doc) -> List[Dict[str, Any]]:
        """Return one record per non-empty, non-footer text line of an open PyMuPDF document."""
        records = []
        for page_num, page in enumerate(doc):
            for block in page.get_text("dict")["blocks"]:
                # Non-text (e.g. image) blocks carry no "lines" entry.
                if "lines" not in block:
                    continue

                # Each line is handled separately so multi-line bullet lists keep their structure.
                for line in block["lines"]:
                    spans = line["spans"]
                    line_text = " ".join(span["text"] for span in spans).strip()
                    if not line_text or self._is_page_artifact(line_text):
                        continue

                    is_bullet, bullet_text = self._split_bullet(line_text)
                    records.append({
                        "text": line_text,
                        "page": page_num,
                        "bbox": line["bbox"],
                        "font_size": max((span["size"] for span in spans), default=0),
                        # Left edge of the first span drives the indent level.
                        "x_pos": spans[0]["bbox"][0] if spans else block["bbox"][0],
                        "y_pos": line["bbox"][1],
                        "is_bullet": is_bullet,
                        "bullet_text": bullet_text,
                    })
        return records

    def extract_structured_content(self, pdf_path: str) -> List[Dict[str, Any]]:
        """
        Extract sections from a PDF, preserving the bullet/indent hierarchy.

        Each returned dict has ``title``, ``content`` (newline-joined lines with
        leading tabs for indentation and "- " for bullets) and ``source`` (the
        PDF file name). Lines at or above the heading font-size threshold, or
        all-caps non-bullet lines longer than 3 characters, start a new
        section. If no section is long enough, content is grouped per page
        instead. The result goes through ``post_process_sections``.
        """
        doc_title = os.path.basename(pdf_path)

        # The context manager guarantees the document is closed.
        with fitz.open(pdf_path) as doc:
            all_blocks = self._extract_line_records(doc)

        all_blocks.sort(key=lambda b: (b["page"], b["y_pos"]))

        # Heuristic: the 4th-largest line font size (duplicates included) is the
        # heading threshold; 15 when the document has no text.
        font_sizes = sorted((b["font_size"] for b in all_blocks if b["font_size"] > 0), reverse=True)
        main_heading_threshold = font_sizes[min(3, len(font_sizes) - 1)] if font_sizes else 15

        sections = []

        def _append_section(title, lines):
            # Keep only titled sections whose content exceeds MIN_SECTION_LENGTH.
            section_content = "\n".join(lines)
            if title and len(section_content) > MIN_SECTION_LENGTH:
                sections.append({
                    "title": title,
                    "content": section_content,
                    "source": doc_title
                })

        current_title = None
        current_content = []
        for block in all_blocks:
            is_heading = (
                block["font_size"] >= main_heading_threshold
                or (block["text"].isupper() and len(block["text"]) > 3 and not block["is_bullet"])
            )
            if is_heading:
                _append_section(current_title, current_content)
                current_title = block["text"]
                current_content = []
            elif current_title:
                current_content.append(self._format_line(block))
            else:
                # No heading seen yet: the first line becomes the section title.
                current_title = block["text"]

        _append_section(current_title, current_content)

        # Unusual layouts (no usable headings): fall back to one section per page.
        if not sections:
            print("No clear sections found - grouping by page")
            pages_content = defaultdict(list)
            for block in all_blocks:
                pages_content[block["page"]].append(block["text"])

            for page_num, content in pages_content.items():
                full_content = "\n".join(content)
                if len(full_content) > MIN_SECTION_LENGTH:
                    sections.append({
                        "title": f"Page {page_num + 1}",
                        "content": full_content,
                        "source": doc_title
                    })

        return self.post_process_sections(sections)

    def post_process_sections(self, sections: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Clean sections, drop metadata, and merge short sections into similar neighbours.

        Every section is cleaned and metadata-checked before any merging, so a
        merge partner is always cleaned and never a metadata section. A section
        shorter than 200 characters is merged with the following one when their
        TF-IDF cosine similarity exceeds ``MIN_SIMILARITY_THRESHOLD``; a merged
        section is not merged again. Input dicts are modified in place.
        """
        cleaned = []
        for section in sections:
            section["title"] = self.clean_text(section["title"])
            section["content"] = self.clean_text(section["content"])
            if not self.is_metadata(section):
                cleaned.append(section)

        processed = []
        i = 0
        while i < len(cleaned):
            section = cleaned[i]
            if len(section["content"]) < 200 and i + 1 < len(cleaned):
                next_section = cleaned[i + 1]
                similarity = self.calculate_similarity(section["content"], next_section["content"])
                if similarity > MIN_SIMILARITY_THRESHOLD:
                    processed.append({
                        "title": section["title"],
                        "content": section["content"] + "\n\n" + next_section["content"],
                        "source": section["source"]
                    })
                    i += 2
                    continue

            processed.append(section)
            i += 1

        return processed

    def clean_text(self, text: str) -> str:
        """Remove PDF extraction artifacts while keeping the line structure.

        Removes the MIT OpenCourseWare footer and URLs. It then collapses
        whitespace runs inside each line, keeps leading tabs (indentation),
        drops blank lines, and normalizes Unicode bullet glyphs at line start
        to "- ".
        """
        if not text:
            return ""

        # The footer may span several lines, hence DOTALL and \s+.
        text = re.sub(r'MIT\s+OpenCourseWare.*?http://ocw\.mit\.edu', '', text, flags=re.DOTALL)
        text = re.sub(r'https?://\S+', '', text)  # Remove URLs

        cleaned_lines = []
        for line in text.splitlines():
            indent = line[:len(line) - len(line.lstrip("\t"))]
            body = re.sub(r'\s+', ' ', line).strip()
            if not body:
                continue
            body = re.sub(r'^[■◼▪◦•]\s*', '- ', body)
            cleaned_lines.append(indent + body)

        return "\n".join(cleaned_lines)

    def is_metadata(self, section: Dict[str, Any]) -> bool:
        """Return True if the section's title or content matches a boilerplate pattern.

        Matching is case-insensitive. Note that ``post_process_sections`` calls
        this after ``clean_text`` has removed URLs.
        """
        combined = section["title"] + " " + section["content"]
        metadata_patterns = [
            r'MIT OpenCourseWare',
            r'copyright',
            r'https?://',
            r'License',
            r'Terms of Use',
            r'Table of Contents',
            r'Recapitulare',
            r'Bibliografie',
            r'References'
        ]

        return any(re.search(pattern, combined, re.IGNORECASE) for pattern in metadata_patterns)

    def calculate_similarity(self, text1: str, text2: str) -> float:
        """Return the TF-IDF cosine similarity of two texts, or 0.0 if it cannot be computed.

        ``TfidfVectorizer`` raises ValueError for an empty vocabulary (e.g.
        empty strings). Non-string input raises TypeError/AttributeError. Both
        cases score 0.0; unrelated errors (including KeyboardInterrupt) are no
        longer swallowed.
        """
        try:
            tfidf_matrix = TfidfVectorizer().fit_transform([text1, text2])
        except (ValueError, TypeError, AttributeError):
            return 0.0
        return float(cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0])

    def classify_section(self, section: Dict[str, Any]) -> str:
        """Return the section's class label, from the model if loaded, else from patterns.

        Only the first 1000 characters are sent to the model. Any model error
        falls back to ``classify_by_patterns``.
        """
        if not self.classifier:
            return self.classify_by_patterns(section)

        try:
            text = section["content"][:1000]
            result = self.classifier(text)[0]
            return result["label"]
        except Exception as e:
            print(f"Classification error: {e}")
            return self.classify_by_patterns(section)

    def classify_by_patterns(self, section: Dict[str, Any]) -> str:
        """Classify as 'code', 'math', 'example' or 'context' using regex heuristics.

        Patterns are tested in that priority order over title + content.
        """
        title = section["title"]
        content = section["content"]
        combined = title + "\n" + content

        # Check for code patterns
        code_patterns = [
            r'^\s*(def|if|for|while|print|import|return)\b',
            r'==|!=|<=|>=|\+=|-=|\*=|/=',
            r'\brange\(|\bbreak\b|\breturn\b',
            r'^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*[^=]',
            r'```python|```java|```c\+\+|```javascript',
            r'>>> |In \[\d+\]:'
        ]

        if any(re.search(pattern, combined, re.MULTILINE) for pattern in code_patterns):
            return 'code'

        # Check for mathematical content
        math_patterns = [
            r'(\b[A-Z]\([^\)]+\))|(\b[a-z]\([^\)]+\))',
            r'[≤≥≠∈∉∑∏∫∂∇∞∝∀∃]',
            r'\bproof\b|\btheorem\b|\blemma\b|\bclaim\b',
            r'\blim\b|\bmax\b|\bmin\b|\bsup\b|\binf\b',
            r'[A-Za-z]_{\d+}|[A-Za-z]_\d',
            r'O\([^)]*n[^)]*\)',
        ]

        if any(re.search(pattern, combined, re.IGNORECASE | re.MULTILINE) for pattern in math_patterns):
            return 'math'

        # Check for examples
        example_patterns = [
            r'EXAMPLE|for example',
            r'e\.g\.',
            r'^\d+\)\s+'
        ]

        if any(re.search(pattern, combined, re.IGNORECASE | re.MULTILINE) for pattern in example_patterns):
            return 'example'

        # Default to context
        return 'context'

    def detect_language(self, text: str) -> str:
        """Return the langdetect language code of *text*, defaulting to "en" when undetectable."""
        try:
            return detect(text)
        except LangDetectException:
            return "en"

    def process_with_llm(self, section: Dict[str, Any]) -> Dict[str, Any]:
        """Classify *section*, run the matching prompt, and merge the parsed fields into it.

        Adds ``class``, ``language`` (langdetect code of the content),
        ``llm_output`` and whatever ``extract_fields_from_llm_output`` finds.
        If the LLM call or parsing fails, ``llm_output`` holds the error and
        ``summary`` is "Failed to process with LLM". The dict is mutated and
        returned.
        """
        section_class = self.classify_section(section)
        section["class"] = section_class
        # Downstream consumers read section["language"]; record the detected one.
        section["language"] = self.detect_language(section["content"])

        prompts = {"code": self.code_prompt, "math": self.math_prompt}
        prompt = prompts.get(section_class, self.general_prompt)

        # LCEL pipe (prompt | llm) replaces the deprecated LLMChain / Chain.run API.
        chain = prompt | self.llm

        try:
            response = chain.invoke({
                "title": section["title"],
                "text": section["content"],
                "source": section["source"],
            })
            # Chat models return a message object; keep only its text.
            result = getattr(response, "content", response)
            section["llm_output"] = result
            section.update(self.extract_fields_from_llm_output(result))
        except Exception as e:
            print(f"LLM processing error: {e}")
            section["llm_output"] = f"Error: {str(e)}"
            section["summary"] = "Failed to process with LLM"

        return section

    @classmethod
    def _extract_llm_field(cls, text: str, name: str):
        """Return the stripped value of field *name* in *text*, or None if the header is absent.

        The value runs until the next known field header or the end of the text.
        """
        others = "|".join(re.escape(field) for field in cls._LLM_FIELDS if field != name)
        pattern = rf'{re.escape(name)}:\s*(.*?)\s*(?=(?:{others}):|\Z)'
        match = re.search(pattern, text, re.DOTALL)
        return match.group(1).strip() if match else None

    def extract_fields_from_llm_output(self, text: str) -> Dict[str, Any]:
        """Parse the "Field: value" layout requested by the prompts.

        Keys are returned only when their field is present:
        ``processed_title``, ``processed_content``, ``summary``,
        ``description``, ``key_concepts`` (list of non-empty strings),
        ``main_topic`` and ``programming_language``. Every field except Title
        ends at the next known header, so multi-line values and both prompt
        layouts parse correctly. The code prompt has no Summary field, so its
        Description is used as ``summary``. Without that, code sections would
        always be rated high hallucination risk.
        """
        extracted = {}
        if not text:
            return extracted

        # The title is taken to the end of its line.
        title_match = re.search(r'Title:\s*(.*?)(?:\n|$)', text)
        if title_match:
            extracted["processed_title"] = title_match.group(1).strip()

        field_keys = {
            "Content": "processed_content",
            "Summary": "summary",
            "Description": "description",
            "Main Topic": "main_topic",
            "Language": "programming_language",
        }
        for field, key in field_keys.items():
            value = self._extract_llm_field(text, field)
            if value is not None:
                extracted[key] = value

        concepts = self._extract_llm_field(text, "Key Concepts")
        if concepts is not None:
            # Accept comma- or newline-separated lists and drop "- " style bullets.
            parts = (re.sub(r'^[-*•]\s+', '', part.strip()) for part in re.split(r'[,\n]', concepts))
            extracted["key_concepts"] = [part for part in parts if part]

        if "summary" not in extracted and extracted.get("description"):
            extracted["summary"] = extracted["description"]

        return extracted

    def process_pdf(self, pdf_path: str, output_path: str = None) -> List[Dict[str, Any]]:
        """
        Process a PDF end-to-end: extract sections, then annotate each with the LLM.

        When *output_path* is given, the annotated sections are also written
        there as UTF-8 JSON.
        """
        print(f"Processing {pdf_path}...")

        sections = self.extract_structured_content(pdf_path)
        print(f"Extracted {len(sections)} sections")

        processed_sections = []
        for i, section in enumerate(sections):
            print(f"Processing section {i+1}/{len(sections)}: {section['title']}")
            processed_sections.append(self.process_with_llm(section))

        if output_path:
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(processed_sections, f, ensure_ascii=False, indent=2)
            print(f"Saved processed data to {output_path}")

        return processed_sections

    def validate_llm_outputs(self, processed_sections: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Score each section's LLM output against its original text (in place).

        ``validation_score`` is the mean TF-IDF cosine similarity of the
        original content to ``processed_content`` and to ``summary``, rounded
        to 2 decimals. ``hallucination_risk`` is "high" below 0.3, "medium"
        below 0.6, otherwise "low". Sections with no summary, or a failed LLM
        call, get 0.0 / "high".
        """
        for section in processed_sections:
            if "summary" not in section or section["summary"].startswith("Failed"):
                section["validation_score"] = 0.0
                section["hallucination_risk"] = "high"
                continue

            original_content = section["content"]
            processed_content = section.get("processed_content", "")
            summary = section.get("summary", "")

            content_similarity = self.calculate_similarity(original_content, processed_content)
            summary_similarity = self.calculate_similarity(original_content, summary)
            avg_similarity = float((content_similarity + summary_similarity) / 2.0)

            section["validation_score"] = round(avg_similarity, 2)

            if avg_similarity < 0.3:
                section["hallucination_risk"] = "high"
            elif avg_similarity < 0.6:
                section["hallucination_risk"] = "medium"
            else:
                section["hallucination_risk"] = "low"

        return processed_sections

In [None]:

def process_pdf_for_rag(pdf_path, output_path=None, llm_base_url="http://localhost:4500/v1",
                        llm_api_key="lm-studio", llm_model="meta-llama-3.1-8b-instruct",
                        temperature=0.2):
    """Process a PDF for RAG and return the validated sections plus a per-section metrics table.

    Args:
        pdf_path: PDF to process.
        output_path: Optional JSON path. It is written after validation, so
            the file includes ``validation_score`` and ``hallucination_risk``.
        llm_base_url, llm_api_key, llm_model, temperature: Forwarded to
            ``ImprovedPDFProcessor``.

    Returns:
        (validated_data, sections_df). ``sections_df`` always has the metric
        columns, even when no section was extracted, so column access such as
        ``sections_df['classification']`` never raises KeyError.
    """
    processor = ImprovedPDFProcessor(
        llm_base_url=llm_base_url,
        llm_api_key=llm_api_key,
        llm_model=llm_model,
        temperature=temperature
    )

    # Validate before saving; saving inside process_pdf would omit the scores.
    processed_data = processor.process_pdf(pdf_path)
    validated_data = processor.validate_llm_outputs(processed_data)

    if output_path:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(validated_data, f, ensure_ascii=False, indent=2)
        print(f"Saved processed data to {output_path}")

    print("\n===== PROCESSING SUMMARY =====")
    for i, section in enumerate(validated_data):
        print(f"\nSection {i+1}: {section.get('title', 'Untitled')}")
        print(f"Classification: {section.get('class', 'Unknown')}")
        print(f"Main Topic: {section.get('main_topic', 'Unknown')}")
        print(f"Validation Score: {section.get('validation_score', 0.0)} (Hallucination Risk: {section.get('hallucination_risk', 'Unknown')})")

    metric_columns = [
        "title", "classification", "main_topic", "num_concepts",
        "content_length", "summary_length", "validation_score", "hallucination_risk",
    ]
    sections_df = pd.DataFrame(
        [
            {
                "title": section.get("title", ""),
                "classification": section.get("class", "unknown"),
                "main_topic": section.get("main_topic", ""),
                "num_concepts": len(section.get("key_concepts", [])),
                "content_length": len(section.get("content", "")),
                "summary_length": len(section.get("summary", "")),
                "validation_score": section.get("validation_score", 0.0),
                "hallucination_risk": section.get("hallucination_risk", "unknown")
            }
            for section in validated_data
        ],
        columns=metric_columns,
    )

    print("\n===== SECTIONS SUMMARY =====")
    print(sections_df)

    return validated_data, sections_df

In [None]:
pdf_path = "data/slides/MIT6_0001F16_Lec8.pdf"
output_json = "processed_data.json"

data, metrics = process_pdf_for_rag(pdf_path, output_json)


print("\nStep 2: Visualizing the results")

# Plot classification distribution
plt.figure(figsize=(10, 5))
metrics['classification'].value_counts().plot(kind='bar')
plt.title('Content Classification Distribution')
plt.xlabel('Class')
plt.ylabel('Count')
plt.tight_layout()
plt.show()

# Plot hallucination risk. value_counts() orders bars by frequency, so a fixed
# colour list would paint whichever level is most common green. Pin the bar
# order (low, medium, high, then any other labels) and colour each bar by its label.
risk_colors = {"low": "green", "medium": "orange", "high": "red"}
risk_counts = metrics['hallucination_risk'].value_counts()
risk_levels = list(risk_colors) + [lvl for lvl in risk_counts.index if lvl not in risk_colors]
risk_counts = risk_counts.reindex(risk_levels, fill_value=0)

plt.figure(figsize=(10, 5))
risk_counts.plot(kind='bar', color=[risk_colors.get(lvl, 'gray') for lvl in risk_counts.index])
plt.title('Hallucination Risk Assessment')
plt.xlabel('Risk Level')
plt.ylabel('Count')
plt.tight_layout()
plt.show()

In [None]:
def process_directory(pdf_dir, output_dir='processed_data'):
    """Run ``process_pdf_for_rag`` on every PDF file in *pdf_dir*.

    Writes one ``<name>_processed.json`` per PDF into *output_dir* (created if
    missing) and concatenates the per-file metrics into
    ``combined_metrics.csv``. A failure on one file is printed and skipped.

    Returns:
        (all_data, combined_metrics), or ``([], empty DataFrame)`` when no
        file was processed successfully.
    """
    os.makedirs(output_dir, exist_ok=True)
    all_data = []
    all_metrics = []

    # Case-insensitive extension match ("Lecture.PDF"), files only, stable order.
    pdf_files = sorted(
        f for f in os.listdir(pdf_dir)
        if f.lower().endswith('.pdf') and os.path.isfile(os.path.join(pdf_dir, f))
    )

    for pdf_file in pdf_files:
        pdf_path = os.path.join(pdf_dir, pdf_file)
        output_json = os.path.join(output_dir, f"{os.path.splitext(pdf_file)[0]}_processed.json")

        try:
            print(f"\nProcessing {pdf_file}...")
            data, metrics = process_pdf_for_rag(pdf_path, output_json)
            all_data.extend(data)
            all_metrics.append(metrics)
        except Exception as e:
            print(f"Error processing {pdf_file}: {e}")

    if not all_metrics:
        print("No files were processed successfully.")
        return [], pd.DataFrame()

    combined_metrics = pd.concat(all_metrics, ignore_index=True)
    metrics_path = os.path.join(output_dir, "combined_metrics.csv")
    combined_metrics.to_csv(metrics_path, index=False)
    # Report successes out of attempts; failed files contribute no metrics.
    print(f"\nProcessed {len(all_metrics)}/{len(pdf_files)} PDF files. Combined metrics saved to {metrics_path}")
    return all_data, combined_metrics


In [None]:

def prepare_for_rag(processed_data):
    """
    Turn processed sections into retrieval documents of the form ``{'content', 'metadata'}``.

    Sections with ``hallucination_risk == 'high'`` are skipped. ``content`` is
    the first non-empty value among ``summary``, ``processed_content`` and
    ``content``. The title prefers a non-empty ``processed_title``. An empty
    string (e.g. a field the LLM left blank) therefore no longer hides the
    fallback.
    """
    rag_documents = []

    for item in processed_data:
        if item.get('hallucination_risk') == 'high':
            continue

        content = item.get('summary') or item.get('processed_content') or item.get('content', '')

        metadata = {
            'source': item.get('source', ''),
            'title': item.get('processed_title') or item.get('title', ''),
            'classification': item.get('class', 'unknown'),
            'main_topic': item.get('main_topic', ''),
            'key_concepts': item.get('key_concepts', []),
        }

        rag_documents.append({
            'content': content,
            'metadata': metadata
        })

    print(f"Prepared {len(rag_documents)} documents for RAG (filtered out {len(processed_data) - len(rag_documents)} high-risk items)")
    return rag_documents


In [None]:

def export_for_embeddings(rag_documents, output_file="rag_documents.jsonl"):
    """Write *rag_documents* to *output_file* as UTF-8 JSON Lines (one document per line).

    Any iterable of JSON-serializable documents is accepted, including
    generators. Non-ASCII text is written as-is (``ensure_ascii=False``),
    matching the JSON written by ``process_pdf``.
    """
    count = 0
    with open(output_file, 'w', encoding='utf-8') as f:
        for doc in rag_documents:
            f.write(json.dumps(doc, ensure_ascii=False) + '\n')
            count += 1

    print(f"Exported {count} documents to {output_file}")


In [None]:
# Embedding model name and on-disk location of the Chroma vector store.
embedding_model = "sentence-transformers/LaBSE"
vector_db_path = "./rag/chroma_db"

# Shared pipeline components used when indexing PDFs.
pdf_processor = ImprovedPDFProcessor()
metadata_extractor = CSMetadataExtractor()

# Persistent Chroma client whose collection embeds documents with LaBSE.
chroma_client = chromadb.PersistentClient(path=vector_db_path)
sentence_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model)

# Reuse the collection if it already exists; otherwise create it.
collection = chroma_client.get_or_create_collection(
    embedding_function=sentence_ef,
    name="cs_multilingual_docs",
)
def _chroma_safe_metadata(metadata):
    """Return a copy of *metadata* whose values are scalars accepted as Chroma metadata.

    str/int/float/bool values are kept. Lists, tuples and sets become
    comma-separated strings, dicts become JSON strings, None values are
    dropped, and anything else is stringified.
    """
    safe = {}
    for key, value in metadata.items():
        if value is None:
            continue
        if isinstance(value, (str, bool, int, float)):
            safe[str(key)] = value
        elif isinstance(value, (list, tuple, set)):
            safe[str(key)] = ", ".join(str(v) for v in value)
        elif isinstance(value, dict):
            safe[str(key)] = json.dumps(value, ensure_ascii=False)
        else:
            safe[str(key)] = str(value)
    return safe


def process_and_index_pdf(pdf_path):
    """Process *pdf_path*, score its sections and upsert the non-high-risk ones into ``collection``.

    Each section is embedded from its summary when that is longer than 50
    characters, otherwise from its processed or raw content. Sections with
    nothing to embed are skipped. IDs are ``<file name>_section_<index>``, and
    ``upsert`` makes re-indexing the same PDF update existing entries instead
    of skipping them.

    Returns:
        Number of sections written to the collection.
    """
    processed_data = pdf_processor.process_pdf(pdf_path)
    # Without validation 'hallucination_risk' is never set and the filter below would be a no-op.
    processed_data = pdf_processor.validate_llm_outputs(processed_data)

    source_name = os.path.basename(pdf_path)
    indexed = 0
    for i, section in enumerate(processed_data):
        if section.get('hallucination_risk') == 'high':
            continue

        cs_metadata = metadata_extractor.extract_metadata(
            text=section.get('content', ''),
            title=section.get('title', ''),
            lang=section.get('language', 'en')
        )

        combined_metadata = {
            **{k: v for k, v in section.items() if k not in ['content', 'processed_content', 'summary']},
            **cs_metadata
        }

        summary = section.get('summary') or ''
        if len(summary) > 50:
            content_for_embedding = summary
        else:
            content_for_embedding = section.get('processed_content') or section.get('content', '')

        if not content_for_embedding:
            continue

        collection.upsert(
            documents=[content_for_embedding],
            # Chroma metadata values must be scalars; key_concepts is a list.
            metadatas=[_chroma_safe_metadata(combined_metadata)],
            ids=[f"{source_name}_section_{i}"]
        )
        indexed += 1

    return indexed