In [1]:
# Imports
import os
import re
import nltk
import torch
import numpy as np
import logging
from typing import List, Dict, Any
from docx import Document
from tqdm.notebook import tqdm
from pymilvus import connections, Collection, utility, CollectionSchema, FieldSchema, DataType
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pathlib import Path
# Logging: INFO-level root configuration plus a logger for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Environment: turn off HuggingFace tokenizers' internal parallelism and
# expose only the first CUDA device to this process.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Utility Functions
def print_status(section_name: str, status: bool, message: str = ""):
    """Print a ✅/❌ status header for *section_name*.

    When *message* is non-empty, an indented detail line follows the header.
    """
    badge = ("❌ FAILED", "✅ SUCCESS")[bool(status)]
    lines = [f"\n{badge} | {section_name}"]
    if message:
        lines.append(f"  └─ {message}")
    for line in lines:
        print(line)

def verify_section(section_number: int, verification_func) -> bool:
    """Run *verification_func* and print a pass/fail status line.

    Args:
        section_number: Label shown in the status line. Any value that
            formats into a string works (callers also pass section names).
        verification_func: Zero-argument callable. Raising an exception or
            explicitly returning ``False`` counts as failure. Any other return
            value (including ``None``) counts as success.

    Returns:
        True if the section verified successfully, False otherwise.
    """
    label = f"Section {section_number} Verification"
    try:
        result = verification_func()
    except Exception as e:
        print_status(label, False, f"Error: {str(e)}")
        return False

    # Previously the return value was computed and then ignored, so a check
    # that reported failure via `return False` was still shown as a success.
    if result is False:
        print_status(label, False, "Verification function returned False")
        return False

    print_status(label, True, "Successfully executed")
    return True

def get_model_embedding_dim(model_name: str = "TurkuNLP/bert-base-finnish-cased-v1") -> int:
    """Return the hidden size (embedding width) declared in *model_name*'s config.

    Only the model configuration is loaded, so no weights are downloaded or
    instantiated just to read ``hidden_size``.

    NOTE: the default model differs from the one EmbeddingGenerator loads, so
    the value returned for the default need not match EMBEDDING_DIM.
    """
    # Local import: the file only imports AutoModel/AutoTokenizer at top level.
    from transformers import AutoConfig

    return AutoConfig.from_pretrained(model_name).hidden_size

def check_cuda():
    """Report whether a GPU or the CPU will be used.

    Returns:
        True when the device query succeeds (GPU or CPU), False if querying
        CUDA raised an exception.
    """
    try:
        detail = (
            f"Using GPU: {torch.cuda.get_device_name(0)}"
            if torch.cuda.is_available()
            else "Using CPU"
        )
        print_status("CUDA Check", True, detail)
        return True
    except Exception as err:
        print_status("CUDA Check", False, str(err))
        return False

def ensure_stopwords_downloaded(language='finnish'):
    """Download the NLTK stopword corpus and confirm *language* is available.

    ``nltk.download`` reports failure by returning False instead of raising
    (unless ``raise_on_error=True``), so its return value is checked
    explicitly. Loading the word list afterwards confirms that the requested
    language actually exists in the corpus.

    Returns:
        True on success, False otherwise. A status line is printed either way.
    """
    try:
        if not nltk.download('stopwords', quiet=True):
            print_status("NLTK Setup", False, "nltk.download('stopwords') reported failure")
            return False

        from nltk.corpus import stopwords  # local import; only `nltk` is imported at top level

        # Raises if the corpus has no word list for this language.
        stopwords.words(language)
        print_status("NLTK Setup", True, f"Downloaded {language} stopwords")
        return True
    except Exception as e:
        print_status("NLTK Setup", False, str(e))
        return False

# Global Constants
# Width of the Milvus "embedding" vector field. It must equal the hidden size
# of the model loaded by EmbeddingGenerator; RAGPipeline.__init__ and
# EmbeddingGenerator.generate both raise on a mismatch, so update this value
# together with the embedding model.
EMBEDDING_DIM = 1536
MILVUS_HOST = "milvus-standalone"
MILVUS_PORT = "19530"
MILVUS_ALIAS = "default"

# Initial Setup Verification
def verify_initial_setup() -> bool:
    """Run the environment checks and report the configured embedding dimension.

    Both checks always run (no short-circuit) so that every status line is
    printed.

    Returns:
        True if both the CUDA check and the NLTK stopword setup succeeded.
    """
    cuda_ok = check_cuda()
    nltk_ok = ensure_stopwords_downloaded()
    print_status("Embedding Dimension", True, f"Using dimension: {EMBEDDING_DIM}")
    return cuda_ok and nltk_ok


# Run initial verification
verify_initial_setup()

# Core Pipeline Components
class DocumentProcessor:
    """Load .docx files, normalise Finnish text and split it into chunks.

    Each chunk comes with filename-derived metadata (person name, age and
    document id) and a heuristic importance score.
    """

    # (regex, replacement) pairs applied in order by preprocess_text.
    # Only horizontal whitespace is collapsed, so line breaks survive for the
    # "\n\n" / "\n" separators of the text splitter.
    DEFAULT_CLEAN_PATTERNS = (
        (r'[^\S\n]+', ' '),  # Normalize spaces/tabs (newlines are kept)
        (r'[\(\{\[\]\}\)]', ''),  # Remove brackets
        (r'[^\w\s\.\,\?\!\-\:\;äöåÄÖÅ]', ''),  # Keep letters (incl. Finnish), digits, basic punctuation
        (r'[^\S\n]+\.', '.'),  # Fix spacing before periods
        (r'\.+', '.'),  # Normalize multiple periods
    )

    # "<Name> <age>v <doc-id>", e.g. "Mäkinen 45v AB-12". The name part accepts
    # any Unicode letters (so ä/ö/å work) and hyphenated names ("Anna-Liisa").
    _FILENAME_PATTERN = re.compile(r'([^\W\d_]+(?:-[^\W\d_]+)*)\s+(\d{1,3})v\s+([A-Za-z0-9\-]+)')

    # Words whose presence raises a chunk's importance score.
    _KEY_PHRASES = (
        "tärkeä", "merkittävä", "olennainen", "keskeinen",
        "huomattava", "erityinen", "tärkein", "ensisijainen",
    )
    # First-person pronouns (common in Finnish personal documents).
    _PERSONAL_PRONOUNS = ("minä", "minun", "minua", "minulla")

    def __init__(self, chunk_size=400, chunk_overlap=80):
        try:
            self.text_splitter = RecursiveCharacterTextSplitter(
                separators=[
                    "\n\n",  # First split on double newlines
                    "\n",    # Then single newlines
                    ".",     # Then sentence endings
                    ":",     # Then colons (common in Finnish formatting)
                    ";",     # Then semicolons
                    ",",     # Then commas
                    " ",     # Finally, split on spaces if needed
                    ""
                ],
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                keep_separator=True,
                add_start_index=True
            )

            # Per-instance copy so callers may extend the cleaning rules.
            self.clean_patterns = list(self.DEFAULT_CLEAN_PATTERNS)
            print_status("Document Processor", True, "Initialized with Finnish-optimized settings")
        except Exception as e:
            print_status("Document Processor", False, str(e))
            raise

    def extract_metadata_from_filename(self, filename: str) -> tuple:
        """Parse ``(name, age, document_id)`` from a filename.

        Returns:
            ``(str, int, str)`` when the stem matches "<Name> <age>v <doc-id>",
            otherwise ``(None, None, None)``.
        """
        title = os.path.splitext(filename)[0]
        match = self._FILENAME_PATTERN.match(title)
        if match is None:
            return None, None, None
        name, age, doc_id = match.groups()
        return name, int(age), doc_id

    def preprocess_text(self, text: str) -> str:
        """Clean Finnish text while keeping its line structure.

        Applies ``self.clean_patterns`` in order, puts each sentence that
        starts with an uppercase letter on its own line, and strips every line.
        """
        try:
            for pattern, replacement in self.clean_patterns:
                text = re.sub(pattern, replacement, text)

            # Ensure proper sentence boundaries
            text = re.sub(r'([.!?])\s*([A-ZÄÖÅ])', r'\1\n\2', text)

            # Trim each line; line breaks themselves are preserved.
            text = '\n'.join(line.strip() for line in text.split('\n'))
            return text.strip()
        except Exception as e:
            print_status("Text Preprocessing", False, str(e))
            raise

    def process_document(self, file_path: str) -> List[Dict[str, Any]]:
        """Read one .docx file and return its chunks with metadata.

        Returns:
            A list of ``{"text": str, "metadata": dict}`` entries, one per chunk.
            Metadata values parsed from the filename are None when the
            filename does not follow the expected pattern.
        """
        try:
            doc = Document(file_path)
            text = "\n".join([para.text for para in doc.paragraphs if para.text.strip()])

            filename = os.path.basename(file_path)
            name, age, doc_id = self.extract_metadata_from_filename(filename)

            clean_text = self.preprocess_text(text)
            chunks = self.text_splitter.split_text(clean_text)

            processed_chunks = []
            for i, chunk in enumerate(chunks):
                processed_chunks.append({
                    "text": chunk,
                    "metadata": {
                        "source": filename,
                        "person_name": name,
                        "person_age": age,
                        "document_id": doc_id,
                        "chunk_index": i,
                        "importance_score": self._calculate_chunk_importance(chunk),
                        "chunk_length": len(chunk),
                        "contains_question": "?" in chunk,
                    }
                })

            print_status("Document Processing", True,
                        f"Processed {filename} into {len(processed_chunks)} chunks")
            return processed_chunks

        except Exception as e:
            print_status("Document Processing", False, f"Error processing {file_path}: {str(e)}")
            raise

    def _calculate_chunk_importance(self, chunk: str) -> float:
        """Heuristic importance score for a chunk; the base score is 1.0.

        The score is multiplied by 1.2 for a key phrase, by 1.1 if the chunk
        contains a period, and by 1.15 for a first-person pronoun
        (all substring matches on the lowercased chunk).
        """
        lowered = chunk.lower()
        score = 1.0

        if any(phrase in lowered for phrase in self._KEY_PHRASES):
            score *= 1.2

        # Prefer chunks with complete sentences
        if '.' in chunk:
            score *= 1.1

        if any(pronoun in lowered for pronoun in self._PERSONAL_PRONOUNS):
            score *= 1.15

        return score


class MilvusManager:
    """Manage a named pymilvus connection and the document-embeddings collection.

    Every pymilvus call is issued with ``using=self.alias``, so a non-default
    alias refers to the same connection throughout. The module-level
    ``utility`` helpers and ``Collection`` otherwise fall back to "default".
    """

    def __init__(self, host: str = "milvus-standalone", port: str = "19530", alias: str = "default"):
        self.host = host
        self.port = port
        self.alias = alias
        self.connected = False
        self.connect()

    def connect(self):
        """(Re)connect to Milvus under ``self.alias`` and verify the link.

        Raises:
            Exception: if connecting or the server-version probe fails. In that
                case ``self.connected`` is left False.
        """
        try:
            self._drop_existing_connection()

            connections.connect(
                alias=self.alias,
                host=self.host,
                port=self.port,
                timeout=10.0
            )

            try:
                utility.get_server_version(using=self.alias)
            except Exception as ve:
                raise Exception(f"Connection verification failed: {str(ve)}") from ve

            self.connected = True
            print_status("Milvus Connection", True, f"Connected to {self.host}:{self.port}")

        except Exception as e:
            self.connected = False
            print_status("Milvus Connection", False, str(e))
            raise

    def _drop_existing_connection(self):
        """Best-effort removal of any connection already registered under ``self.alias``."""
        try:
            if connections.has_connection(self.alias):
                connections.remove_connection(alias=self.alias)
                print_status("Milvus Connection", True, "Cleaned up existing connection")
        except Exception as cleanup_error:
            # A stale alias must not prevent a fresh connect; record it and continue.
            logger.debug("Ignoring Milvus cleanup failure for alias %r: %s", self.alias, cleanup_error)

    def create_collection(self, collection_name: str = "document_embeddings", drop_existing: bool = True):
        """Create the collection (schema, IVF_FLAT index) and load it.

        Args:
            collection_name: Name of the collection.
            drop_existing: When True (the default), an existing collection of
                the same name is dropped and recreated, which discards its
                data. When False, the existing collection is reloaded and
                returned unchanged.

        Returns:
            The loaded ``Collection``.
        """
        try:
            if utility.has_collection(collection_name, using=self.alias):
                if not drop_existing:
                    return self.reload_collection(collection_name)
                Collection(name=collection_name, using=self.alias).drop()
                print_status("Milvus Collection", True, f"Dropped existing collection: {collection_name}")

            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
                FieldSchema(name="person_name", dtype=DataType.VARCHAR, max_length=100),
                FieldSchema(name="person_age", dtype=DataType.INT64),
                FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=100),
                FieldSchema(name="chunk_index", dtype=DataType.INT64)
            ]

            schema = CollectionSchema(
                fields=fields,
                description="Document embeddings collection",
                enable_dynamic_field=False
            )
            collection = Collection(name=collection_name, schema=schema, using=self.alias)

            self.create_and_load_index(collection)

            print_status("Milvus Collection", True, f"Created new collection: {collection_name} with dim={EMBEDDING_DIM}")
            return collection
        except Exception as e:
            print_status("Milvus Collection", False, str(e))
            raise

    def create_and_load_index(self, collection):
        """Build an IVF_FLAT inner-product index on "embedding", then load the collection."""
        try:
            index_params = {
                "metric_type": "IP",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 1024}
            }
            collection.create_index(field_name="embedding", index_params=index_params)
            print_status("Index Creation", True, "Created IVF_FLAT index")

            collection.load()
            print_status("Collection Load", True, "Loaded collection into memory")

        except Exception as e:
            print_status("Index Creation", False, str(e))
            raise

    def reload_collection(self, collection_name: str = "document_embeddings"):
        """Load an existing collection into memory and return it.

        Raises:
            Exception: if the collection does not exist.
        """
        try:
            if not utility.has_collection(collection_name, using=self.alias):
                raise Exception(f"Collection {collection_name} does not exist")
            collection = Collection(name=collection_name, using=self.alias)
            collection.load()
            print_status("Collection Reload", True, f"Reloaded collection: {collection_name}")
            return collection
        except Exception as e:
            print_status("Collection Reload", False, str(e))
            raise


class EmbeddingGenerator:
    """Turn texts into L2-normalised sentence embeddings via masked mean pooling."""

    def __init__(self, model_name: str = "TurkuNLP/gpt3-finnish-large"):
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name).to(self.device)
            self.embedding_dim = self.model.config.hidden_size
            print_status("Embedding Model", True, f"Loaded {model_name} (dim={self.embedding_dim})")
        except Exception as e:
            print_status("Embedding Model", False, str(e))
            raise

    def mean_pooling(self, model_output, attention_mask):
        """Average the token vectors in ``model_output[0]`` where attention_mask is 1.

        The clamp prevents division by zero for a row that is entirely padding.
        """
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def generate(self, texts: List[str], batch_size: int = 8) -> np.ndarray:
        """Embed *texts* in batches of *batch_size* (max 512 tokens each).

        Returns:
            One L2-normalised row per input text. An empty *texts* yields an
            empty ``(0, embedding_dim)`` float32 array instead of failing
            inside ``np.concatenate``.

        Raises:
            ValueError: if the produced width differs from EMBEDDING_DIM.
        """
        if not texts:
            return np.empty((0, self.embedding_dim), dtype=np.float32)

        try:
            all_embeddings = []

            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start:start + batch_size]

                encoded_input = self.tokenizer(
                    batch_texts,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                with torch.no_grad():
                    model_output = self.model(**encoded_input)

                sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
                sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

                all_embeddings.append(sentence_embeddings.cpu().numpy())

            result = np.concatenate(all_embeddings)

            # The Milvus schema's vector width is fixed by EMBEDDING_DIM.
            if result.shape[1] != EMBEDDING_DIM:
                raise ValueError(f"Embedding dimension mismatch. Expected {EMBEDDING_DIM}, got {result.shape[1]}")

            print_status("Embedding Generation", True,
                    f"Generated {len(texts)} embeddings with dimension {result.shape[1]}")
            return result
        except Exception as e:
            print_status("Embedding Generation", False, str(e))
            raise

    def get_embedding_dim(self) -> int:
        """Return the hidden size of the loaded model."""
        return self.embedding_dim


class RAGPipeline:
    """Retrieval-augmented question answering over Finnish .docx documents.

    Combines DocumentProcessor (chunking), EmbeddingGenerator (vectors),
    MilvusManager (vector store) and a 4-bit quantised causal LM (answers).

    NOTE: __init__ calls ``MilvusManager.create_collection()``, which by
    default drops an existing ``document_embeddings`` collection, so each
    instantiation starts from an empty index.
    """

    def __init__(self, model_id: str = "Finnish-NLP/llama-7b-finnish-instruct-v0.2"):
        try:
            self.setup_llm(model_id)
            print_status("LLM Setup", True, f"Loaded {model_id}")

            self.doc_processor = DocumentProcessor()
            self.embedding_generator = EmbeddingGenerator()

            if self.embedding_generator.get_embedding_dim() != EMBEDDING_DIM:
                raise ValueError(f"Embedding dimension mismatch. Global: {EMBEDDING_DIM}, "
                               f"Generator: {self.embedding_generator.get_embedding_dim()}")

            self.milvus_manager = MilvusManager(
                host=MILVUS_HOST,
                port=MILVUS_PORT,
                alias=MILVUS_ALIAS
            )
            self.collection = self.milvus_manager.create_collection()
            print_status("RAG Pipeline", True, "All components initialized")
        except Exception as e:
            print_status("RAG Pipeline", False, str(e))
            raise

    def process_documents(self, folder_path: str):
        """Chunk, embed and index every ``.docx`` file in *folder_path*.

        Files are processed in sorted order so insertion is deterministic.
        Word lock files (``~$*.docx``) are skipped because they are not valid
        documents.

        Raises:
            ValueError: if the folder contains no .docx files.
        """
        try:
            file_paths = sorted(
                f for f in os.listdir(folder_path)
                if f.lower().endswith('.docx') and not f.startswith('~$')
            )
            if not file_paths:
                raise ValueError(f"No .docx files found in {folder_path}")

            print_status("Document Loading", True, f"Found {len(file_paths)} documents")
            all_chunks = []

            for file in tqdm(file_paths, desc="Processing documents"):
                file_path = os.path.join(folder_path, file)
                all_chunks.extend(self.doc_processor.process_document(file_path))

            if not all_chunks:
                print_status("Document Processing", True,
                            f"No text found in {len(file_paths)} documents; nothing to index")
                return

            texts = [chunk["text"] for chunk in all_chunks]
            embeddings = self.embedding_generator.generate(texts)

            entities = [
                self._chunk_to_entity(chunk, embedding)
                for chunk, embedding in zip(all_chunks, embeddings)
            ]

            # Insert into Milvus in batches
            batch_size = 100
            for start in range(0, len(entities), batch_size):
                self.collection.insert(entities[start:start + batch_size])

            # Ensure data is persisted
            self.collection.flush()

            # Create index and load collection
            self.milvus_manager.create_and_load_index(self.collection)

            print_status("Document Processing", True,
                        f"Processed {len(texts)} chunks from {len(file_paths)} documents")

        except Exception as e:
            print_status("Document Processing", False, f"Error: {str(e)}")
            raise

    @staticmethod
    def _chunk_to_entity(chunk: Dict[str, Any], embedding: np.ndarray) -> Dict[str, Any]:
        """Map a processed chunk plus its embedding onto a Milvus row.

        The collection schema declares no nullable fields, so metadata the
        filename parser could not extract (None) is stored as "" for strings
        and -1 for the age.
        """
        meta = chunk["metadata"]
        return {
            "text": chunk["text"],
            "embedding": embedding.tolist(),
            "person_name": meta["person_name"] or "",
            "person_age": meta["person_age"] if meta["person_age"] is not None else -1,
            "document_id": meta["document_id"] or "",
            "chunk_index": meta["chunk_index"],
        }

    def setup_llm(self, model_id: str):
        """Load *model_id* in 4-bit NF4 and wrap it in a text-generation pipeline."""
        try:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
                max_memory={0: "6GiB"},
                offload_folder="offload"
            )

            tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.pipeline = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.3,
                top_p=0.95,
                repetition_penalty=1.15
            )
            print_status("LLM Pipeline", True, "Pipeline configured successfully")
        except Exception as e:
            print_status("LLM Pipeline", False, str(e))
            raise

    def query(self, question: str, top_k: int = 10):
        """Answer *question* from the indexed documents.

        Retrieves ``2 * top_k`` candidates, reranks them, keeps ``top_k`` as
        context and asks the LLM for an answer.

        Returns:
            ``{"answer": str, "sources": list, "metadata": dict}``.
            ``max_similarity_score`` is 0.0 when nothing was retrieved.
        """
        try:
            self.collection = self.milvus_manager.reload_collection()

            # Preprocess question
            question = question.strip()
            if not question.endswith('?'):
                question += '?'

            question_embedding = self.embedding_generator.generate([question])[0]

            search_params = {
                "metric_type": "IP",
                "params": {"nprobe": 50}
            }

            search_results = self.collection.search(
                data=[question_embedding.tolist()],
                anns_field="embedding",
                param=search_params,
                limit=top_k * 2,
                output_fields=["text", "person_name", "document_id", "chunk_index"]
            )

            initial_results = [
                {
                    'text': hit.entity.get('text'),
                    'person_name': hit.entity.get('person_name'),
                    'document_id': hit.entity.get('document_id'),
                    'chunk_index': hit.entity.get('chunk_index'),
                    'score': float(hit.score)
                }
                for hit in search_results[0]
            ]

            # Reuse the question embedding instead of recomputing it in the reranker.
            reranked_results = self._rerank_results(
                question, initial_results, question_embedding=question_embedding
            )
            results = reranked_results[:top_k]

            context_parts = []
            for i, hit in enumerate(results):
                context_parts.append(
                    f"[Dokumentti {i+1}]\n"
                    f"Lähde: {hit['document_id']}\n"
                    f"Henkilö: {hit['person_name']}\n"
                    f"Luotettavuus: {hit['score']:.2%}\n"
                    f"Tekstikatkelma:\n{hit['text']}\n"
                    f"{'-' * 40}\n"
                )

            context = "\n".join(context_parts)

            prompt = f"""Tehtävä: Etsi tarkka vastaus annettuun kysymykseen käyttäen vain alla olevaa kontekstia.
    
            Kysymys: {question}
            
            Konteksti:
            {context}
            
            Tärkeät ohjeet:
            1. Jos löydät suoran vastauksen:
               - Mainitse AINA ensin dokumentti, josta vastaus löytyy (esim. "Dokumentti 1:")
               - Lainaa TARKASTI alkuperäistä tekstiä käyttäen lainausmerkkejä
               - Perustele vastauksen luotettavuus yhdellä lyhyellä lauseella
            
            2. Jos löydät vain osittaisen vastauksen:
               - Kerro selkeästi mikä osa vastauksesta löytyi ja mikä puuttuu
               - Käytä silti suoria lainauksia löytyneestä osasta
            
            3. Jos et löydä minkäänlaista vastausta:
               - Vastaa vain: "En löydä suoraa vastausta annetusta kontekstista"
            
            Vastaus:"""

            # The text-generation pipeline returns prompt + completion by
            # default. With the 500-char cut in _clean_response, that left only
            # the start of the prompt, so the model's answer was discarded.
            response = self.pipeline(
                prompt,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.1,
                top_p=0.85,
                repetition_penalty=1.2,
                return_full_text=False
            )[0]["generated_text"]

            response = self._clean_response(response)

            return {
                "answer": response,
                "sources": results,
                "metadata": {
                    "question": question,
                    "num_chunks_retrieved": len(results),
                    "max_similarity_score": max((hit['score'] for hit in results), default=0.0)
                }
            }

        except Exception as e:
            print_status("Query", False, str(e))
            raise

    def _rerank_results(self, question: str, results: List[Dict], question_embedding=None) -> List[Dict]:
        """Rescore and sort *results* (mutating each hit's 'score').

        Each score is multiplied by:
          * ``1 + cosine(question, chunk)`` (embeddings are L2-normalised);
          * 1.05 if the chunk contains '?';
          * 1.1 if any capitalised word of the question appears in the chunk
            (case-insensitive). Sentence-initial question words also count,
            because they are capitalised.

        Args:
            question_embedding: Optional precomputed question embedding. It is
                generated when omitted.
        """
        if not results:
            return []

        if question_embedding is None:
            question_embedding = self.embedding_generator.generate([question])[0]

        # One batched call instead of one model call per hit.
        chunk_embeddings = self.embedding_generator.generate([hit['text'] for hit in results])

        # Lowercase the entities: they are compared against lowercased chunk
        # text, so capitalised entities would otherwise never match.
        entities = [entity.lower() for entity in self._extract_entities(question)]

        reranked = []
        for hit, chunk_embedding in zip(results, chunk_embeddings):
            text = hit['text']
            score = hit['score']

            semantic_similarity = np.dot(question_embedding, chunk_embedding)
            score *= (1 + semantic_similarity)

            if '?' in text:
                score *= 1.05  # Slight boost if chunk contains a question
            text_lower = text.lower()
            if any(entity in text_lower for entity in entities):
                score *= 1.1  # Boost if entities match

            hit['score'] = score
            reranked.append(hit)

        reranked.sort(key=lambda x: x['score'], reverse=True)
        return reranked

    def _extract_entities(self, text: str) -> List[str]:
        """Return capitalised words (including Ä/Ö/Å initials) as naive entity candidates."""
        return re.findall(r'\b[A-ZÄÖÅ][a-zäöå]*\b', text)

    def _clean_response(self, response: str) -> str:
        """Collapse repeated document tags and drop confidence boilerplate.

        The result is truncated to 500 characters (including '...') and stripped.
        """
        response = re.sub(r'(\[Dokumentti \d+\])\s*\1', r'\1', response)
        response = re.sub(r'Luottamus: \d+%\s*Selitys:', '', response)
        if len(response) > 500:
            response = response[:497] + '...'
        return response.strip()

    def _create_prompt(self, question: str, context: str) -> str:
        """Alternative answer prompt. NOTE: not used by ``query``, which builds its own."""
        return f"""Tehtävä: Vastaa annettuun kysymykseen käyttäen vain alla olevaa kontekstia.

        Kysymys: {question}
        
        Konteksti:
        {context}
        
        Vastausohjeet:
        1. Jos löydät suoran vastauksen:
           - Mainitse ensin dokumentti, josta vastaus löytyy
           - Lainaa tekstiä tarkasti käyttäen lainausmerkkejä
           - Mainitse vastauksen luotettavuus prosentteina
        
        2. Jos löydät osittaisen vastauksen:
           - Kerro, mitä tietoa löysit ja mistä
           - Mainitse selkeästi, mitä tietoa puuttuu
        
        3. Jos et löydä vastausta:
           - Vastaa vain: "En löydä suoraa vastausta annetusta kontekstista"
        
        Vastaus:"""

## validating
def verify_pipeline_components():
    """Build a RAGPipeline and check that each of its components exists.

    RAGPipeline constructs its own DocumentProcessor, EmbeddingGenerator and
    MilvusManager with the same settings, so they are not instantiated
    separately here. Doing so loaded the embedding model twice and opened a
    redundant Milvus connection. Note that constructing RAGPipeline recreates
    (drops) the Milvus collection.

    Returns:
        True if every component attribute is set (not None).
    """
    rag = RAGPipeline()
    components = (
        rag.doc_processor,
        rag.embedding_generator,
        rag.milvus_manager,
        rag.collection,
        rag.pipeline,
    )
    return all(component is not None for component in components)


verify_section("Pipeline Components", verify_pipeline_components)

class Neo4jDocumentManager:
    """Store document chunks and the people they mention as a Neo4j graph.

    NOTE(review): ``GraphDatabase`` is not imported in the visible part of
    this file, and ``setup_constraints``, ``_process_relationships`` and
    ``_get_related_documents`` are called below but not defined in this
    class. Unless they are provided elsewhere, construction, graph creation
    and ``query`` raise NameError/AttributeError. Confirm before use.
    """

    # Regexes used to harvest candidate person names from chunk text.
    _NAME_PATTERNS = (
        r'\b[A-ZÄÖÅ][a-zäöå]+\b',  # Basic capitalized word
        r'(?:herra|rouva|neiti)\s+[A-ZÄÖÅ][a-zäöå]+',  # Titles
        r'(?:ystävä|naapuri)\s+[A-ZÄÖÅ][a-zäöå]+',  # Relationships
    )

    def __init__(self, uri: str = "neo4j://neo4j:7687",
                 username: str = "neo4j",
                 password: str = "test"):
        self.driver = GraphDatabase.driver(uri, auth=(username, password))
        self.setup_constraints()
        self.person_names = set()  # Names collected by create_document_graph

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def _extract_person_names(self, text: str) -> set:
        """Return candidate names (last word of each match, longer than 2 chars)."""
        names = set()
        for pattern in self._NAME_PATTERNS:
            for match in re.finditer(pattern, text):
                name = match.group().split()[-1]  # Drop the title/relationship word
                if len(name) > 2:  # Filter out very short words
                    names.add(name)
        return names

    def create_document_graph(self, processed_chunks: List[Dict]):
        """Record person names from *processed_chunks*, then write their nodes/edges."""
        # First pass: collect all person names (metadata + text heuristics).
        for chunk in processed_chunks:
            if chunk["metadata"].get("person_name"):
                self.person_names.add(chunk["metadata"]["person_name"])
            self.person_names.update(self._extract_person_names(chunk["text"]))

        with self.driver.session() as session:
            for chunk in processed_chunks:
                self._create_document_and_person_nodes(session, chunk)
                self._process_relationships(session, chunk)

    def _create_document_and_person_nodes(self, session, chunk: Dict):
        """MERGE the chunk's Document node and, when a person is known, its Person node.

        The Person node is linked with APPEARS_IN, and RELATED_TO edges are
        added to the person's other documents.
        """
        meta = chunk["metadata"]
        params = {
            "person_name": meta["person_name"],
            "age": meta["person_age"],
            "doc_id": f"{meta['document_id']}_{meta['chunk_index']}",
            "content": chunk["text"],
            "chunk_index": meta["chunk_index"]
        }

        if params["person_name"] is None:
            # Neo4j rejects MERGE on a null property value, so for chunks whose
            # filename yielded no person only the Document node is written.
            session.run("""
                MERGE (d:Document {id: $doc_id})
                SET d.content = $content,
                    d.chunk_index = $chunk_index
            """, params)
            return

        session.run("""
            MERGE (p:Person {name: $person_name})
            SET p.age = $age
            
            MERGE (d:Document {id: $doc_id})
            SET d.content = $content,
                d.chunk_index = $chunk_index
            
            MERGE (p)-[:APPEARS_IN]->(d)
            
            WITH p, d
            MATCH (p)-[:APPEARS_IN]->(other:Document)
            WHERE other.id <> d.id
            MERGE (d)-[:RELATED_TO]->(other)
        """, params)

    def find_mentioned_persons(self, text: str) -> List[str]:
        """Return known person names mentioned in *text*, sorted.

        A name matches at the start of a word, case-insensitively, so
        inflected forms with suffixes ("Liisan") match but names embedded
        inside other words ("Ari" in "Marika") do not.
        """
        text_lower = text.lower()
        return sorted(
            name for name in self.person_names
            if re.search(r'\b' + re.escape(name.lower()), text_lower)
        )

    def query(self, question: str) -> Dict:
        """Collect graph context for every known person mentioned in *question*."""
        mentioned_persons = self.find_mentioned_persons(question)

        with self.driver.session() as session:
            person_contexts = []
            for person_name in mentioned_persons:
                context = self._get_person_context(session, person_name)
                if context:
                    person_contexts.append(context)

            documents = self._get_related_documents(session, mentioned_persons)

            return {
                "person_contexts": person_contexts,
                "related_documents": documents,
                "mentioned_persons": mentioned_persons
            }

    def _get_person_context(self, session, person_name: str) -> Dict:
        """Return a summary map for *person_name*, or None if no such Person exists."""
        result = session.run("""
            MATCH (p:Person {name: $name})
            OPTIONAL MATCH (p)-[:APPEARS_IN]->(d:Document)
            OPTIONAL MATCH (d)-[:HAS_RELATIONSHIP]->(r:Relationship)
            
            WITH p,
                 COLLECT(DISTINCT r.type) as relationship_types,
                 COLLECT(DISTINCT d.content) as contents,
                 COLLECT(DISTINCT d.id) as doc_ids
            
            RETURN {
                name: p.name,
                age: p.age,
                relationship_types: relationship_types,
                document_count: SIZE(contents),
                document_ids: doc_ids
            } as context
        """, {"name": person_name})

        return result.single()["context"] if result.peek() else None

class Neo4jEnhancedRAGPipeline:
    """Answer questions with a base RAG pipeline enriched by Neo4j person context.

    The base pipeline supplies vector-search sources. ``Neo4jDocumentManager``
    supplies the persons mentioned in the question, their graph summaries and
    graph-related documents. Everything is merged into one context that is sent
    back through the base pipeline to produce the final answer.
    """

    def __init__(self, base_rag_pipeline, neo4j_uri: str = "neo4j://neo4j:7687",
                 username: str = "neo4j", password: str = "test"):
        """Wrap *base_rag_pipeline* and connect a ``Neo4jDocumentManager``.

        Args:
            base_rag_pipeline: object exposing ``query(question)`` (returning a dict
                with ``answer`` and ``sources``) and ``process_documents(path)``.
            neo4j_uri: Bolt/Neo4j URI of the graph database.
            username: Neo4j user (defaults to the previously hard-coded value).
            password: Neo4j password (defaults to the previously hard-coded value).
        """
        self.base_rag = base_rag_pipeline
        self.neo4j_manager = Neo4jDocumentManager(
            uri=neo4j_uri,
            username=username,
            password=password
        )

    def process_documents(self, folder_path: str):
        """Index the documents in *folder_path* through the base pipeline.

        Returns whatever ``base_rag.process_documents`` returns.
        NOTE(review): graph ingestion into Neo4j is not performed here; populate
        the graph through ``Neo4jDocumentManager`` separately — confirm.
        """
        return self.base_rag.process_documents(folder_path)

    def query(self, question: str) -> Dict:
        """Answer *question* using both vector search and graph context.

        Returns:
            A dict with ``answer``, ``sources`` (base pipeline hits),
            ``graph_enhanced_sources``, ``person_contexts`` and
            ``mentioned_persons``.
        """
        # Vector-search results from the base pipeline.
        base_results = self.base_rag.query(question)

        # Graph results; persons are detected dynamically from the question.
        neo4j_results = self.neo4j_manager.query(question)

        enhanced_context = self._combine_contexts(
            base_results["sources"],
            neo4j_results["related_documents"],
            neo4j_results["person_contexts"]
        )

        final_answer = self._generate_enhanced_answer(
            question,
            enhanced_context
        )

        return {
            "answer": final_answer,
            "sources": base_results["sources"],
            "graph_enhanced_sources": neo4j_results["related_documents"],
            "person_contexts": neo4j_results["person_contexts"],
            "mentioned_persons": neo4j_results["mentioned_persons"]
        }

    @staticmethod
    def _document_text(doc) -> str:
        """Best-effort passage text of a hit: its ``text`` or ``content`` entry, or ``str(doc)``."""
        getter = getattr(doc, "get", None)
        if callable(getter):
            text = getter("text") or getter("content") or ""
        else:
            text = doc or ""
        return str(text).strip()

    @staticmethod
    def _combine_contexts(sources: List[Dict], related_documents: List,
                          person_contexts: List[Dict]) -> str:
        """Merge person summaries, vector hits and graph documents into one text.

        Args:
            sources: base-pipeline hits; their ``text`` is used.
            related_documents: graph documents. NOTE(review): their schema is not
                visible here; mappings are read for ``text``/``content`` and other
                values are stringified — confirm against ``_get_related_documents``.
            person_contexts: person summaries with ``name``, ``age``,
                ``relationship_types`` and ``document_count``.

        Returns:
            A Finnish-labelled context block in which identical passages appear
            once, or ``""`` when there is nothing to combine.
        """
        lines: List[str] = []

        if person_contexts:
            lines.append("Henkilötiedot:")
            for person in person_contexts:
                details = [str(person.get("name"))]
                if person.get("age") is not None:
                    details.append(f"ikä {person['age']}")
                relationship_types = person.get("relationship_types") or []
                if relationship_types:
                    details.append("suhteet: " + ", ".join(str(t) for t in relationship_types))
                if person.get("document_count") is not None:
                    details.append(f"dokumentteja: {person['document_count']}")
                lines.append("- " + "; ".join(details))
            lines.append("")

        # The same chunk is often returned by both the vector search and the graph.
        seen = set()
        passages = []
        for item in list(sources or []) + list(related_documents or []):
            text = Neo4jEnhancedRAGPipeline._document_text(item)
            if text and text not in seen:
                seen.add(text)
                passages.append(text)

        if passages:
            lines.append("Dokumentit:")
            lines.extend(f"[{i}] {text}" for i, text in enumerate(passages, 1))

        return "\n".join(lines).strip()

    def _generate_enhanced_answer(self, question: str, context: str) -> str:
        """Ask the base pipeline to answer *question* from the combined *context*.

        Uses the same technique as ``FinnishRAGAgent``: the enriched prompt is
        passed as the query to ``base_rag.query`` and its ``answer`` is returned.
        """
        prompt = (
            "Vastaa kysymykseen käyttäen alla olevaa kontekstia.\n\n"
            f"Kysymys: {question}\n\n"
            f"Konteksti:\n{context}"
        )
        response = self.base_rag.query(prompt)
        if isinstance(response, dict):
            return str(response.get("answer", ""))
        return str(response)

# Testing Infrastructure
class RAGTester:
    """Run questions through a RAG pipeline and score the answers heuristically.

    ``pipeline.query(question)`` must return a dict with an ``"answer"`` string and
    a ``"sources"`` list whose items carry a ``"similarity_score"``.
    """

    # Straight ASCII quote plus the typographic quotes used in Finnish text
    # (”lainaus” and »lainaus«), so quoted answers are detected in either style.
    _QUOTE_CHARS = ('"', '”', '“', '»', '«')

    # Word endings treated as a sign of Finnish inflection by _check_finnish_structure.
    _FINNISH_ENDINGS = (
        'ssa', 'ssä', 'sta', 'stä', 'lla', 'llä', 'lta', 'ltä',
        'ksi', 'in', 'en', 'teen', 'seen',
    )

    def __init__(self, pipeline: "RAGPipeline"):
        self.pipeline = pipeline
        # Detailed per-question metrics of the most recent run_test_suite() call.
        self.test_results = []

    def run_test_suite(self, test_questions: List[str]) -> Dict[str, Any]:
        """Ask every question, validate each answer and aggregate the results.

        Returns:
            ``{"detailed_results": [...], "summary": {...}}``. ``average_similarity``
            and ``direct_quote_ratio`` are averaged over successful responses only.
            A question whose query raises is printed and counted as failed.
        """
        test_results = []
        summary_stats = {
            "total_questions": len(test_questions),
            "successful_responses": 0,
            "failed_responses": 0,
            "average_similarity": 0.0,
            "direct_quote_ratio": 0.0
        }

        for question in test_questions:
            try:
                result = self.pipeline.query(question)
                answer = result["answer"]
                sources = result["sources"]

                validation = self._validate_response(answer)

                metrics = {
                    "question": question,
                    "has_direct_quote": validation["has_direct_quote"],
                    "source_count": len(sources),
                    # default=0.0: an answer with no retrieved sources is still a
                    # response; max() of an empty sequence would raise ValueError.
                    "max_similarity": max(
                        (s["similarity_score"] for s in sources), default=0.0
                    ),
                    "response_quality": validation,
                    "response_length": len(answer),
                }

                test_results.append(metrics)
                summary_stats["successful_responses"] += 1
                summary_stats["average_similarity"] += metrics["max_similarity"]
                summary_stats["direct_quote_ratio"] += int(metrics["has_direct_quote"])

                self._print_test_result(question, result, metrics)

            except Exception as e:
                print(f"❌ Error testing question '{question}': {str(e)}")
                summary_stats["failed_responses"] += 1

        # Turn the accumulated sums into averages over the successful responses.
        if summary_stats["successful_responses"] > 0:
            summary_stats["average_similarity"] /= summary_stats["successful_responses"]
            summary_stats["direct_quote_ratio"] /= summary_stats["successful_responses"]

        self.test_results = test_results
        return {
            "detailed_results": test_results,
            "summary": summary_stats
        }

    def _validate_response(self, response: str) -> Dict[str, bool]:
        """Run structural and Finnish-language heuristics on an answer."""
        return {
            # Basic structural checks
            "has_source_reference": bool(re.search(r'\[Dokumentti \d+\]', response)),
            "has_direct_quote": any(ch in response for ch in self._QUOTE_CHARS),
            "is_complete_sentence": response.strip().endswith(('.', '?', '!')),
            "has_confidence": bool(re.search(r'\d+\s*%', response)),
            "reasonable_length": 10 <= len(response) <= 500,

            # Finnish language specific checks
            "has_finnish_chars": bool(re.search(r'[äöåÄÖÅ]', response)),
            "proper_finnish_structure": self._check_finnish_structure(response)
        }

    def _check_finnish_structure(self, text: str) -> bool:
        """Return True if any word ends with a typical Finnish ending."""
        # \w+ strips punctuation, so a sentence-final word such as "talossa." is
        # still recognised as ending in -ssa.
        words = re.findall(r'\w+', text.lower())
        return any(word.endswith(self._FINNISH_ENDINGS) for word in words)

    def _print_test_result(self, question: str, result: Dict, metrics: Dict):
        """Print formatted test results."""
        print("\n" + "="*80)
        print(f"Question: {question}")
        print(f"Answer: {result['answer']}")
        print("\nMetrics:")
        print(f"- Source count: {metrics['source_count']}")
        print(f"- Max similarity: {metrics['max_similarity']:.2%}")
        print(f"- Response length: {metrics['response_length']}")
        print("\nValidation:")
        for key, value in metrics['response_quality'].items():
            print(f"- {key}: {'✅' if value else '❌'}")
        print("="*80)


def run_rag_tests(pipeline: RAGPipeline, test_questions: List[str] = None):
    """Run ``RAGTester`` over the given questions and print a summary.

    Args:
        pipeline: the pipeline under test, handed to ``RAGTester`` unchanged.
        test_questions: questions to ask; ``None`` selects five built-in
            Finnish questions.

    Returns:
        The dict produced by ``RAGTester.run_test_suite`` (``detailed_results``
        and ``summary``).
    """
    default_questions = [
        "Onko Marjatta Eilan ystävä?",
        "Miten Sulo kokee sosiaalisen kanssakäymisen merkityksen?",
        "Montako sisarusta Sulolla on?",
        "Millainen on Eilan arki?",
        "Mikä on Sulolle tärkeää?"
    ]
    questions = default_questions if test_questions is None else test_questions

    outcome = RAGTester(pipeline).run_test_suite(questions)
    summary = outcome['summary']

    print("\nTest Summary:")
    print(f"Total questions: {summary['total_questions']}")
    print(f"Successful responses: {summary['successful_responses']}")
    print(f"Failed responses: {summary['failed_responses']}")
    print(f"Average similarity: {summary['average_similarity']:.2%}")
    print(f"Direct quote ratio: {summary['direct_quote_ratio']:.2%}")

    return outcome

def verify_testing_infrastructure():
    """Smoke-check that a ``RAGTester`` can be built around a fresh ``RAGPipeline``.

    Any exception raised while constructing either object propagates to the
    caller (``verify_section`` reports it as a failure).
    """
    return bool(RAGTester(RAGPipeline()))

# Section check: constructing the test harness must not raise.
verify_section(
    section_number="Testing Infrastructure",
    verification_func=verify_testing_infrastructure,
)

# Add these imports at the top of your file
from langchain.memory import ConversationBufferMemory
from typing import List, Dict, Any, Optional
import re

class FinnishRAGAgent:
    """Finnish-aware question-answering agent layered on a base RAG pipeline.

    ``process_query`` runs a semantic search through the base pipeline. When no
    semantic hit scores at least 0.5 it adds a substring search on the pipeline's
    collection. It then extracts regex-based context (names, relationship
    phrases, temporal expressions, key topics) from the top three hits and asks
    the base pipeline to answer an enriched Finnish prompt. Every exchange is
    saved to a LangChain ``ConversationBufferMemory``.
    """

    # Boost factors applied by _calculate_finnish_relevance. Their product is the
    # largest raw score possible and is used to normalise relevance into (0, 1].
    _CASE_ENDING_BOOST = 1.1
    _QA_PAIR_BOOST = 1.2
    _ENTITY_BOOST = 1.3
    _MAX_RELEVANCE = _CASE_ENDING_BOOST * _QA_PAIR_BOOST * _ENTITY_BOOST

    # A word ending in a common local-case suffix (-ssa/-ssä, -sta/-stä,
    # -lla/-llä, -lta/-ltä). The leading \w requires at least one letter before
    # the ending, so it matches a suffix rather than a standalone token.
    _CASE_ENDING_RE = re.compile(r'\w(?:ssa|ssä|sta|stä|lla|llä|lta|ltä)\b')

    def __init__(self, base_pipeline: "RAGPipeline"):
        """Wrap *base_pipeline*, which must expose ``query`` and ``collection``."""
        self.pipeline = base_pipeline
        # Conversation log: process_query writes to it; nothing in this class reads it.
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.setup_tools()

    def setup_tools(self):
        """Initialize search and analysis tools."""
        self.tools = {
            "semantic_search": self._semantic_search,
            "exact_match": self._exact_match_search,
            "context_analysis": self._analyze_context
        }

    def _semantic_search(self, query: str) -> List[Dict]:
        """Search via the base pipeline and re-rank hits with Finnish heuristics.

        Returns:
            ``{'text', 'score', 'source', 'metadata'}`` dicts sorted by descending
            score, where score = Finnish relevance in (0, 1] × pipeline score.
        """
        preprocessed_query = self._preprocess_finnish_text(query)
        base_results = self.pipeline.query(preprocessed_query, top_k=5)

        enhanced_results = []
        for result in base_results.get('sources', []):
            relevance = self._calculate_finnish_relevance(preprocessed_query, result['text'])
            # NOTE(review): RAGTester reads "similarity_score" from this pipeline's
            # sources while this method historically read "score"; accept either
            # until the source schema is confirmed.
            base_score = result.get('score', result.get('similarity_score', 0.0))
            enhanced_results.append({
                'text': result['text'],
                'score': relevance * base_score,
                'source': result['document_id'],
                'metadata': {
                    'person_name': result.get('person_name'),
                    'chunk_index': result.get('chunk_index')
                }
            })

        return sorted(enhanced_results, key=lambda hit: hit['score'], reverse=True)

    def _exact_match_search(self, query: str) -> List[Dict]:
        """Substring search on the pipeline's collection with a normalized query.

        Every hit gets a fixed score of 1.0. Returns ``[]`` when the query
        normalizes to an empty string.
        """
        normalized_query = self._normalize_finnish_text(query)
        if not normalized_query:
            # An empty pattern would become `like "%%"` and match every chunk.
            return []

        # _normalize_finnish_text keeps only word characters and whitespace, so
        # the query cannot contain the `"` that delimits the expression string.
        results = self.pipeline.collection.query(
            expr=f'text like "%{normalized_query}%"',
            output_fields=["text", "person_name", "document_id", "chunk_index"]
        )

        return [
            {
                'text': r['text'],
                'score': 1.0,
                'source': r['document_id'],
                'metadata': {
                    'person_name': r['person_name'],
                    'chunk_index': r.get('chunk_index')
                }
            }
            for r in results
        ]

    def _analyze_context(self, passages: List[Dict]) -> Dict:
        """Collect entities, relationship phrases, temporal refs and topics from passages."""
        context_data = {
            'entities': set(),
            'relationships': [],
            'temporal_refs': [],
            'key_topics': set()
        }

        for passage in passages:
            text = passage['text']
            context_data['entities'].update(self._extract_finnish_entities(text))
            context_data['relationships'].extend(self._find_relationships(text))
            context_data['temporal_refs'].extend(self._extract_temporal_refs(text))
            context_data['key_topics'].update(self._identify_key_topics(text))

        return context_data

    def _preprocess_finnish_text(self, text: str) -> str:
        """Collapse whitespace and expand common Finnish abbreviations."""
        text = re.sub(r'\s+', ' ', text.strip())

        abbreviations = {
            'esim.': 'esimerkiksi',
            'ns.': 'niin sanottu',
            'jne.': 'ja niin edelleen'
        }
        for abbr, full in abbreviations.items():
            # (?<!\w): expand only standalone abbreviations, so the "ns." that
            # ends e.g. "Hans." is left untouched.
            text = re.sub(r'(?<!\w)' + re.escape(abbr), full, text)

        return text

    def _find_relationships(self, text: str) -> List[Dict]:
        """Find simple relationship phrases in Finnish text via regex patterns."""
        relationships = []

        patterns = [
            (r'(\w+)\s+on\s+(\w+)\s+ystävä', 'ystävyys'),
            (r'(\w+)\s+asuu\s+(\w+)', 'asuminen'),
            (r'(\w+)\s+tekee\s+(\w+)', 'toiminta'),
            (r'(\w+)\s+pitää\s+(\w+)', 'pitäminen'),
            (r'(\w+)\s+kanssa', 'yhteys'),
            (r'(\w+)\s+tärkeä', 'tärkeys'),
            (r'(\w+)\s+auttaa\s+(\w+)', 'auttaminen')
        ]

        for pattern, rel_type in patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                relationships.append({
                    'type': rel_type,
                    'entities': match.groups(),
                    'text': match.group(0)
                })

        return relationships

    def _extract_temporal_refs(self, text: str) -> List[Dict]:
        """Extract temporal references (durations, ages, days, months, ...)."""
        temporal_refs = []

        # The leading \b anchors each time word at the start of a word, while
        # still allowing inflected forms ("aamulla", "maanantaina"). Without it,
        # "yö" matched inside the very common "myös" and "työ".
        patterns = [
            (r'\d+\s*vuotta', 'duration'),
            (r'\d+\s*vuotias', 'age'),
            (r'\b(maanantai|tiistai|keskiviikko|torstai|perjantai|lauantai|sunnuntai)', 'weekday'),
            (r'\b(tammikuu|helmikuu|maaliskuu|huhtikuu|toukokuu|kesäkuu|heinäkuu|elokuu|syyskuu|lokakuu|marraskuu|joulukuu)', 'month'),
            (r'\b(aamu|päivä|ilta|yö)', 'time_of_day'),
            (r'\b(eilen|tänään|huomenna)', 'relative_day'),
            (r'\b(viikko|kuukausi|vuosi)', 'time_unit')
        ]

        for pattern, ref_type in patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                temporal_refs.append({
                    'type': ref_type,
                    'text': match.group(0),
                    'position': match.span()
                })

        return temporal_refs

    def _identify_key_topics(self, text: str) -> set:
        """Return ``"<topic_type>:<word>"`` strings for words following topic cues."""
        topics = set()

        key_patterns = [
            (r'tärkeä\w*\s+(\w+)', 'importance'),
            (r'harrastaa\w*\s+(\w+)', 'hobby'),
            (r'pitää\w*\s+(\w+)', 'preference'),
            (r'ongelma\w*\s+(\w+)', 'problem'),
            (r'tavoite\w*\s+(\w+)', 'goal')
        ]

        # Every pattern has exactly one capturing group.
        for pattern, topic_type in key_patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                topics.add(f"{topic_type}:{match.group(1)}")

        return topics

    def _normalize_finnish_text(self, text: str) -> str:
        """Lower-case, strip, and drop everything but word characters and whitespace."""
        text = text.lower().strip()
        # \w is Unicode-aware for str patterns, so ä/ö/å are kept.
        return re.sub(r'[^\w\s]', '', text)

    def _calculate_finnish_relevance(self, query: str, text: str) -> float:
        """Return a Finnish relevance factor in (0, 1] for *text* w.r.t. *query*.

        The raw score starts at 1.0 and is multiplied by a boost for each signal:
        a local-case ending in the text, a question answered by a sentence early
        in the text, and a shared capitalized name. It is then divided by the
        largest possible product, so 1.0 means "all signals present". (Capping
        the raw score at 1.0 instead would return 1.0 for every input.)
        """
        score = 1.0

        if self._CASE_ENDING_RE.search(text):
            score *= self._CASE_ENDING_BOOST

        if '?' in query and '.' in text[:100]:
            score *= self._QA_PAIR_BOOST

        if self._extract_finnish_entities(query) & self._extract_finnish_entities(text):
            score *= self._ENTITY_BOOST

        return score / self._MAX_RELEVANCE

    def _extract_finnish_entities(self, text: str) -> set:
        """Return capitalized words (a crude stand-in for Finnish named entities)."""
        return set(re.findall(r'\b[A-ZÄÖÅ][a-zäöå]+\b', text))

    def _generate_response(self, query: str, results: List[Dict], context: Dict) -> Dict:
        """Answer *query* from *results* via the base pipeline.

        Returns:
            A dict with ``answer``, ``confidence`` (best hit score, 0.0 on failure),
            ``sources`` (top three hits) and ``context_analysis``.
        """
        if not results:
            return {
                "answer": "En löytänyt vastausta kysymykseesi saatavilla olevista dokumenteista.",
                "confidence": 0.0,
                "sources": [],
                "context_analysis": context
            }

        try:
            enhanced_query = self._build_enhanced_query(query, results)
            response = self.pipeline.query(enhanced_query)
            answer = self._process_response(response, results, query)

            return {
                "answer": answer,
                "confidence": max(r['score'] for r in results),
                "sources": results[:3],
                "context_analysis": context
            }

        except Exception as e:
            print(f"Error in response generation: {str(e)}")
            return {
                "answer": "Virhe vastauksen muodostamisessa.",
                "confidence": 0.0,
                "sources": results[:3],
                "context_analysis": context
            }

    def _build_enhanced_query(self, query: str, results: List[Dict]) -> str:
        """Build the Finnish prompt: question, top three documents and instructions."""
        query_type = self._analyze_query_type(query)

        context = (
            "Etsi tarkka vastaus seuraavaan kysymykseen käyttäen alla olevia dokumentteja.\n"
            "    \n"
            f"    Kysymys: {query}\n"
            "    \n"
            "    Dokumentit:\n"
            "    "
        )

        for i, result in enumerate(results[:3], 1):
            context += f"\nDokumentti {i} ({result['source']}):\n{result['text'].strip()}\n"

        context += self._get_query_instructions(query_type)

        return context

    def _analyze_query_type(self, query: str) -> str:
        """Classify *query* by keyword; returns ``'general'`` when nothing matches."""
        query_lower = query.lower()

        characteristics = {
            'personal_info': any(word in query_lower for word in ['kuka', 'kenen', 'kenelle']),
            'age_related': any(word in query_lower for word in ['vanha', 'ikä', 'syntymä', 'vuosi']),
            'activity_related': any(word in query_lower for word in ['tekee', 'harrastaa', 'pitää', 'tykkää']),
            'ability_related': any(word in query_lower for word in ['pystyy', 'osaa', 'liikkuu', 'käyttää']),
            'preference_related': any(word in query_lower for word in ['pitää', 'tykkää', 'haluaa']),
            'relationship_related': any(word in query_lower for word in ['ystävä', 'tärkeä', 'läheinen'])
        }

        # First matching type wins (same tie-break as max() over the flags).
        # Without this fallback an unmatched query was labelled 'personal_info',
        # which made the generic instructions unreachable.
        matched = [query_type for query_type, hit in characteristics.items() if hit]
        return matched[0] if matched else 'general'

    def _get_query_instructions(self, query_type: str) -> str:
        """Return the Finnish answering instructions for *query_type*.

        Unknown types (including ``'general'``) get generic instructions.
        """
        def steps(*lines: str, header: str = "Ohje:") -> str:
            # Each instruction line is indented by eight spaces, keeping the
            # prompt text unchanged.
            return "\n" + "\n".join(f"        {line}" for line in (header, *lines))

        cite = "3. Mainitse dokumentin lähde"
        instructions = {
            'personal_info': steps(
                "1. Etsi henkilöön liittyvät suorat maininnat",
                "2. Käytä suoria lainauksia henkilöiden nimistä ja suhteista",
                cite,
                header="Ohje: ",  # trailing space is part of the existing prompt text
            ),
            'age_related': steps(
                "1. Etsi tarkat ikä- ja syntymävuositiedot",
                "2. Ilmoita sekä ikä että syntymävuosi jos molemmat löytyvät",
                cite,
            ),
            'activity_related': steps(
                "1. Etsi kaikki mainitut aktiviteetit ja harrastukset",
                "2. Käytä suoria lainauksia aktiviteettien kuvauksista",
                cite,
            ),
            'ability_related': steps(
                "1. Etsi kuvaukset henkilön kyvyistä ja toiminnasta",
                "2. Käytä suoria lainauksia toimintakyvyn kuvauksista",
                cite,
            ),
            'preference_related': steps(
                "1. Etsi kaikki mainitut mieltymykset ja kiinnostukset",
                "2. Käytä suoria lainauksia mieltymysten kuvauksista",
                cite,
            ),
            'relationship_related': steps(
                "1. Etsi kuvaukset ihmissuhteista ja tärkeistä henkilöistä",
                "2. Käytä suoria lainauksia suhteiden kuvauksista",
                cite,
            ),
        }

        return instructions.get(query_type, steps(
            "1. Etsi suora vastaus kysymykseen dokumenteista",
            "2. Käytä suoria lainauksia",
            cite,
        ))

    def _process_response(self, response: Dict, results: List[Dict], query: str) -> str:
        """Post-process the pipeline answer: add a source prefix and quote marks if missing."""
        if not response or 'answer' not in response:
            return "En löytänyt vastausta annetusta kontekstista."

        answer = response['answer'].strip()

        if not any(f"Dokumentin {result['source']}" in answer for result in results):
            # Attribute the answer to the first document that contains one of its
            # quotes. Empty quotes ("") are skipped: '' is a substring of any text.
            quotes = [quote for quote in re.findall(r'"([^"]*)"', answer) if quote]
            for result in results:
                if any(quote in result['text'] for quote in quotes):
                    answer = f"Dokumentin {result['source']} mukaan {answer}"
                    break

        # If a whole passage was copied verbatim without quotes, quote the body.
        if '"' not in answer and any(result['text'] in answer for result in results):
            answer = re.sub(r'(Dokumentin [^\s]+ mukaan) (.*)', r'\1 "\2"', answer)

        return answer

    def process_query(self, query: str) -> Dict:
        """Answer *query* end to end, print the answer and store it in memory.

        Returns the dict produced by ``_generate_response``. Exceptions are
        printed and re-raised.
        """
        try:
            semantic_results = self._semantic_search(query)

            if not semantic_results or max(r['score'] for r in semantic_results) < 0.5:
                # Weak semantic hits: add literal matches and re-rank, so strong
                # exact matches (score 1.0) reach the top three used for context
                # and the prompt instead of trailing the weak hits.
                all_results = sorted(
                    semantic_results + self._exact_match_search(query),
                    key=lambda r: r['score'],
                    reverse=True,
                )
            else:
                all_results = semantic_results

            context_data = self._analyze_context(all_results[:3])
            response = self._generate_response(query, all_results, context_data)

            self.memory.save_context(
                {"input": query},
                {"output": response['answer']}
            )

            print("\nVASTAUS:")
            print("-" * 40)
            print(response['answer'])
            print("-" * 40)

            return response

        except Exception as e:
            print(f"Error processing query: {str(e)}")
            raise
def run_neo4j_enhanced_tests(folder_path: str = '/home/jovyan/work/notebooks/data/',
                             neo4j_uri: str = "neo4j://neo4j:7687"):
    """Run the Neo4j-enhanced pipeline on a fixed set of Finnish questions.

    Args:
        folder_path: directory whose documents are indexed before querying.
        neo4j_uri: URI of the Neo4j database.

    Returns:
        A list of ``{"question": ..., "response": ...}`` dicts, one per question
        answered without raising. Questions that raise are printed and skipped.

    Raises:
        Any setup error (pipeline construction, document processing) is printed
        and re-raised.
    """
    try:
        print("\nStarting Neo4j Enhanced RAG test execution...")

        print("\n1. Initializing base pipeline...")
        base_pipeline = RAGPipeline()
        print("   ✓ Base pipeline initialized")

        print("\n2. Creating Neo4j Enhanced Pipeline...")
        neo4j_rag = Neo4jEnhancedRAGPipeline(
            base_rag_pipeline=base_pipeline,
            neo4j_uri=neo4j_uri
        )
        print("   ✓ Neo4j Enhanced pipeline created")

        print("\n3. Processing documents...")
        neo4j_rag.process_documents(folder_path)
        print("   ✓ Documents processed with Neo4j enhancement")

        test_questions = [
            "Onko Marjatta Eilan ystävä?",
            "Miten Sulo kokee sosiaalisen kanssakäymisen merkityksen?",
            "Montako sisarusta Sulolla on?",
            "Millainen on Eilan arki?",
            "Mikä on Sulolle tärkeää?"
        ]

        print("\n4. Running test questions with Neo4j enhancement...")
        results = []
        for i, question in enumerate(test_questions, 1):
            try:
                print(f"\nKysymys {i}/{len(test_questions)}:")
                print("-"*80)
                print(f"{question}")
                print("-"*80)

                response = neo4j_rag.query(question)
                results.append({
                    "question": question,
                    "response": response
                })

                print("\nNeo4j Enhanced Vastaus:")
                print(response.get('answer', 'Ei vastausta'))
                print("\nLähteet ja konteksti:")
                print("- Dokumenttilähteet:")
                for src in response.get('sources', [])[:2]:
                    print(f"  • {src['document_id']}: {src['text'][:100]}...")
                print("- Graafikonteksti:")
                # The pipeline returns a *list* under 'person_contexts'; the
                # singular 'person_context' key never exists.
                for person in response.get('person_contexts') or []:
                    relationship_types = person.get('relationship_types') or []
                    print(f"  • Henkilö: {person.get('name')}")
                    print(f"  • Suhteet: {', '.join(str(t) for t in relationship_types)}")
                print("-"*80)

            except Exception as e:
                print(f"Virhe kysymyksen käsittelyssä: {str(e)}")

        print("\n5. Neo4j Enhanced testaus valmis")
        return results

    except Exception as e:
        print(f"\nVirhe Neo4j testauksessa: {str(e)}")
        raise

def run_improved_rag_tests(folder_path: str = '/home/jovyan/work/notebooks/data/'):
    """Run the Finnish RAG agent against a fixed set of Finnish test questions.

    Builds a fresh ``RAGPipeline``, wraps it in a ``FinnishRAGAgent``, indexes the
    documents under *folder_path* and answers each question, printing the answer
    and its first source.

    Args:
        folder_path: directory passed to ``RAGPipeline.process_documents``.

    Returns:
        A list of ``{"question": ..., "response": ...}`` dicts, one per question
        answered without raising. Questions that raise are printed and skipped.

    Raises:
        Any setup error (pipeline, agent, document processing) is printed and
        re-raised.
    """
    try:
        print("\nStarting RAG test execution...")

        print("\n1. Initializing base pipeline...")
        base_pipeline = RAGPipeline()
        print("   ✓ Base pipeline initialized")

        print("\n2. Creating Finnish RAG Agent...")
        agent = FinnishRAGAgent(base_pipeline)
        print("   ✓ Agent created")

        print("\n3. Processing documents...")
        base_pipeline.process_documents(folder_path)
        print("   ✓ Documents processed")

        test_questions = [
            "Kuka on tärkeitä henkilöitä Eilalle?",
            "Kuinka vanha on Sulo ja mikä on hänen syntymävuotensa?",
            "Mitä Eila harrastaa?",
            "Miten Sulo liikkuu?",
            "Mistä asioista Eila pitää?"
        ]

        print("\n4. Running test questions...")
        results = []
        for i, question in enumerate(test_questions, 1):
            try:
                print(f"\nKysymys {i}/{len(test_questions)}:")
                print("-"*80)
                print(f"{question}")
                print("-"*80)

                response = agent.process_query(question)
                results.append({
                    "question": question,
                    "response": response
                })

                print("\nVastaus:")
                print(response.get('answer', 'Ei vastausta'))
                print("\nLähde:")
                for src in response.get('sources', [])[:1]:
                    print(f"- {src['source']}: {src['text'][:150]}...")
                print("-"*80)

            except Exception as e:
                print(f"Virhe kysymyksen käsittelyssä: {str(e)}")

        print("\n5. Testaus valmis")
        return results

    except Exception as e:
        print(f"\nVirhe testauksessa: {str(e)}")
        raise


# Entry point. It must come after the definition of run_improved_rag_tests:
# referencing the function before its ``def`` raises NameError on a fresh kernel.
# In Jupyter ``__name__`` is also "__main__", so this branch runs in the notebook.
if __name__ == "__main__":
    print("\nRunning standard RAG tests...")
    standard_results = run_improved_rag_tests()

    print("\nRunning Neo4j enhanced RAG tests...")
    neo4j_results = run_neo4j_enhanced_tests()

    # The two suites ask different questions, and each keeps only questions that
    # were answered without error. Pairs are therefore matched by position, and
    # both questions are shown rather than assumed equal.
    print("\nComparison of results:")
    for std, neo in zip(standard_results, neo4j_results):
        print("\nQuestion:", std["question"])
        print("Neo4j question:", neo["question"])
        print("\nStandard RAG Answer:")
        print(std["response"].get("answer", "Ei vastausta"))
        print("\nNeo4j Enhanced Answer:")
        print(neo["response"].get("answer", "Ei vastausta"))
        print("-"*80)
else:
    # Imported as a module: keep running the standard suite once, as the
    # previous unconditional call at the end of this section did.
    run_improved_rag_tests()





# KeyboardInterrupt:  (notebook cell output captured in this export; not code)