In [None]:
## section 1
# Imports
import os
import re
import nltk
import torch
import numpy as np
import logging
from typing import List, Dict, Any
from docx import Document
from tqdm.notebook import tqdm
from pymilvus import connections, Collection, utility, CollectionSchema, FieldSchema, DataType
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pathlib import Path
# Setup logging
logging.basicConfig(level=logging.INFO)
import hashlib
import time
from typing import Optional
import logging

logger = logging.getLogger(__name__)

## Env 
# Turn off Hugging Face tokenizers' internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Default to the first visible GPU, but respect a device selection already made
# by the caller or the job scheduler. This must run before torch first touches CUDA.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Utility Functions
def print_status(section_name: str, status: bool, message: str = ""):
    """Print a status banner for *section_name*, plus an optional detail line.

    A truthy *status* is shown as a success marker and a falsy one as a failure
    marker. The detail line is printed only when *message* is non-empty.
    """
    label = {True: "✅ SUCCESS", False: "❌ FAILED"}[bool(status)]
    output_lines = [f"\n{label} | {section_name}"]
    if message:
        output_lines.append(f"  └─ {message}")
    for line in output_lines:
        print(line)

def verify_section(section_number: int, verification_func) -> bool:
    """Run *verification_func* and report whether it finished without raising.

    The callable's return value is ignored; only the absence of an exception
    counts as success. Returns True on success, and False (after printing the
    error) otherwise.
    """
    label = f"Section {section_number} Verification"
    try:
        verification_func()
        print_status(label, True, "Successfully executed")
    except Exception as err:
        print_status(label, False, f"Error: {str(err)}")
        return False
    return True

def get_model_embedding_dim(model_name: str = "TurkuNLP/bert-base-finnish-cased-v1") -> int:
    """Return the hidden size (embedding dimension) declared in *model_name*'s config.

    Only the model configuration is fetched. Loading the full weights just to
    read one config field is slow and memory-hungry.
    """
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(model_name)
    return config.hidden_size

def check_cuda():
    """Report which compute device torch will use.

    Returns True whenever the probe itself succeeds, whether a GPU or the CPU
    is selected. Returns False only if querying torch raised; in that case the
    error is printed.
    """
    try:
        if not torch.cuda.is_available():
            print_status("CUDA Check", True, "Using CPU")
            return True
        gpu_name = torch.cuda.get_device_name(0)
        print_status("CUDA Check", True, f"Using GPU: {gpu_name}")
        return True
    except Exception as exc:
        print_status("CUDA Check", False, str(exc))
        return False

def ensure_stopwords_downloaded(language='finnish'):
    """Download the NLTK ``stopwords`` corpus quietly and report the outcome.

    *language* is used only in the status message; the whole multi-language
    ``stopwords`` package is downloaded.

    Returns:
        True if the download succeeded (or was already present), False otherwise.
    """
    try:
        # nltk.download() reports most failures (e.g. network errors) by
        # returning False rather than raising, so the result must be checked.
        downloaded = nltk.download('stopwords', quiet=True)
        if downloaded is False:
            print_status("NLTK Setup", False, f"Could not download {language} stopwords")
            return False
        print_status("NLTK Setup", True, f"Downloaded {language} stopwords")
        return True
    except Exception as e:
        print_status("NLTK Setup", False, str(e))
        return False

# Global Constants
# Vector size used for the Milvus "embedding" field. It must equal the hidden
# size of the sentence-embedding model; get_model_embedding_dim() can read that
# from a model config if the model is changed.
EMBEDDING_DIM = 768
# Milvus endpoint. Environment variables override the defaults, so the server
# address is not tied to the source code.
MILVUS_HOST = os.environ.get("MILVUS_HOST", "87.92.59.201")
MILVUS_PORT = os.environ.get("MILVUS_PORT", "19530")
MILVUS_ALIAS = os.environ.get("MILVUS_ALIAS", "default")

# Initial Setup Verification
def verify_initial_setup():
    """Run the start-up checks and print the embedding size in use.

    The checks are the device probe and the NLTK stopwords download.
    """
    for startup_check in (check_cuda, ensure_stopwords_downloaded):
        startup_check()
    print_status("Embedding Dimension", True, f"Using dimension: {EMBEDDING_DIM}")


# Run initial verification
verify_initial_setup()

## section2
# import just in case it doesn't load properly!
from typing import Optional, List, Dict
import logging
import hashlib
import time
import torch.nn.functional as F

# Module logger, re-bound here so this section still works when run on its own
# (the imports above are repeated "just in case").
logger = logging.getLogger(__name__)
class TokenManager:
    """Budget retrieved passages by token count and lay them out as a prompt section."""

    def __init__(self, max_tokens: int = 2048):
        # Total number of tokens allowed across all passages kept by truncate_context().
        self.max_tokens = max_tokens

    def truncate_context(self, context: List[Dict], tokenizer) -> List[Dict]:
        """Select passages, highest ``score`` first, until the token budget is spent.

        Args:
            context: Passage dicts. Each needs a ``'text'`` key and may carry a
                ``'score'``; a missing score ranks as 0.
            tokenizer: Object providing ``encode(text)`` and
                ``decode(ids, skip_special_tokens=...)``.

        Returns:
            The kept passages in descending score order. The first passage that
            does not fit is handled as follows:
            - If more than 100 tokens of budget remain, it is shortened to the
              remaining budget (cut back to its last '.', if any) and marked
              ``'truncated': True``.
            - Either way, no later passages are added.
            Any per-passage error is logged and that passage is skipped.
        """
        selected = []
        spent = 0

        for passage in sorted(context, key=lambda p: p.get('score', 0), reverse=True):
            try:
                token_ids = tokenizer.encode(passage['text'])
                needed = len(token_ids)

                if spent + needed <= self.max_tokens:
                    spent += needed
                    selected.append(passage)
                    continue

                # The budget would overflow: keep a prefix only if it is sizeable.
                leftover = self.max_tokens - spent
                if leftover > 100:
                    snippet = tokenizer.decode(token_ids[:leftover], skip_special_tokens=True)
                    sentence_end = snippet.rfind('.')
                    if sentence_end > 0:
                        snippet = snippet[:sentence_end + 1]
                    selected.append({**passage, 'text': snippet, 'truncated': True})
                break
            except Exception as e:
                logger.error(f"Error truncating document: {str(e)}")

        return selected

    def format_for_llm(self, context: List[Dict], question: str) -> str:
        """Render the question followed by numbered passages.

        Each passage shows its source, its text and, if it was shortened, a
        truncation marker.
        """
        pieces = ["Kysymys: " + question, "\nKonteksti:"]
        for number, passage in enumerate(context, start=1):
            source = passage.get('source', 'Tuntematon')
            text = passage['text']
            marker = '(Katkaistu)' if passage.get('truncated') else ''
            pieces.append(f"\nDokumentti {number}:\nLähde: {source}\nTeksti: {text}\n{marker}")
        return "\n".join(pieces)
class TokenLimitManager:
    """Truncate single texts to a token budget, using a tokenizer supplied after construction."""

    def __init__(self, max_tokens: int = 2048):
        # Default budget used by truncate_text() when no explicit limit is given.
        self.max_tokens = max_tokens
        # Set via setup_tokenizer(); truncate_text() cannot run until then.
        self.tokenizer = None

    def setup_tokenizer(self, tokenizer):
        """Set up tokenizer for token counting."""
        self.tokenizer = tokenizer

    def truncate_text(self, text: str, max_tokens: Optional[int] = None) -> str:
        """Return *text* cut down to at most *max_tokens* tokens.

        Args:
            text: Text to shorten.
            max_tokens: Token limit. ``None`` means use ``self.max_tokens``; an
                explicit 0 is honoured rather than being treated as "unset".

        Returns:
            *text* unchanged if it already fits. Otherwise the decoded prefix,
            cut back to its last '.' when one exists after position 0.

        Raises:
            RuntimeError: if setup_tokenizer() has not been called.
        """
        if self.tokenizer is None:
            raise RuntimeError("No tokenizer configured; call setup_tokenizer() first")
        if max_tokens is None:
            max_tokens = self.max_tokens

        tokens = self.tokenizer.encode(text)
        if len(tokens) <= max_tokens:
            return text

        # Truncate tokens and decode back to text
        truncated_text = self.tokenizer.decode(tokens[:max_tokens], skip_special_tokens=True)

        # Try to end at a sentence boundary
        last_period = truncated_text.rfind('.')
        if last_period > 0:
            return truncated_text[:last_period + 1]
        return truncated_text
        
# Core Pipeline Components
class DocumentProcessor:
    """Read .docx files, clean Finnish text and split it into scored, metadata-tagged chunks."""

    # (pattern, replacement) pairs applied in order by preprocess_text().
    # Horizontal whitespace is collapsed *after* symbol removal and never touches
    # '\n'. This keeps paragraph breaks for the splitter's "\n" separators and
    # stops removed symbols from leaving double spaces behind.
    _CLEAN_PATTERNS = [
        (r'[\(\{\[\]\}\)]', ''),  # Remove brackets
        (r'[^\w\s\.\,\?\!\-\:\;äöåÄÖÅ]', ''),  # Drop other symbols; keep Finnish letters
        (r'[^\S\n]+', ' '),  # Collapse runs of spaces/tabs, but keep newlines
        (r' +\.', '.'),  # Fix spacing before periods
        (r'\.+', '.'),  # Normalize multiple periods
    ]

    def __init__(self, chunk_size=400, chunk_overlap=80):
        try:
            self.text_splitter = RecursiveCharacterTextSplitter(
                separators=[
                    "\n\n",  # First split on double newlines
                    "\n",    # Then single newlines
                    ".",     # Then sentence endings
                    ":",     # Then colons (common in Finnish formatting)
                    ";",     # Then semicolons
                    ",",     # Then commas
                    " ",     # Finally, split on spaces if needed
                    ""
                ],
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                keep_separator=True,
                add_start_index=True
            )

            # Per-instance copy so callers may tweak patterns without affecting others.
            self.clean_patterns = list(self._CLEAN_PATTERNS)
            print_status("Document Processor", True, "Initialized with Finnish-optimized settings")
        except Exception as e:
            print_status("Document Processor", False, str(e))
            raise

    def _calculate_chunk_importance(self, chunk: str) -> float:
        """Heuristic importance score for a chunk, starting at 1.0.

        The score is multiplied up for:
        - Finnish emphasis phrases (weighted high, medium or low);
        - each sentence terminator;
        - first-person pronouns;
        - digits;
        - capitalised words.
        Matching is by substring, so for example 'tärkeä' also matches inside
        'erittäin tärkeä'.
        """
        score = 1.0

        # Key phrase indicators (extended for improved scoring)
        key_phrases = {
            'high': ['erittäin tärkeä', 'merkittävä', 'olennainen', 'keskeinen'],
            'medium': ['tärkeä', 'huomattava', 'kiinnostava'],
            'low': ['mainittava', 'mahdollinen']
        }

        # Check for key phrases with weighted importance
        chunk_lower = chunk.lower()
        for phrase in key_phrases['high']:
            if phrase in chunk_lower:
                score *= 1.3
        for phrase in key_phrases['medium']:
            if phrase in chunk_lower:
                score *= 1.2
        for phrase in key_phrases['low']:
            if phrase in chunk_lower:
                score *= 1.1

        # Prefer complete sentences
        sentence_count = len(re.findall(r'[.!?]+', chunk))
        if sentence_count > 0:
            score *= (1 + (0.1 * sentence_count))

        # Prefer chunks with personal pronouns
        if any(pronoun in chunk_lower for pronoun in ["minä", "minun", "minua", "minulla"]):
            score *= 1.15

        # Prefer chunks with specific details
        if re.search(r'\d+', chunk):  # Contains numbers
            score *= 1.1

        # Prefer chunks with names (any capitalised word, including sentence starts)
        if re.search(r'[A-ZÄÖÅ][a-zäöå]+', chunk):
            score *= 1.1

        return score

    def process_document(self, file_path: str) -> List[Dict[str, Any]]:
        """Read a .docx file and return its cleaned chunks with metadata.

        Each item is ``{"text": chunk, "metadata": {...}}``. The metadata values
        parsed from the filename are ``None`` when the filename does not match
        the expected pattern.
        """
        try:
            # Read document, skipping empty paragraphs
            doc = Document(file_path)
            text = "\n".join([para.text for para in doc.paragraphs if para.text.strip()])

            # Extract metadata
            filename = os.path.basename(file_path)
            name, age, doc_id = self.extract_metadata_from_filename(filename)

            # Preprocess and split text
            clean_text = self.preprocess_text(text)
            chunks = self.text_splitter.split_text(clean_text)

            # Create chunks with enhanced metadata
            processed_chunks = []
            for i, chunk in enumerate(chunks):
                importance_score = self._calculate_chunk_importance(chunk)

                processed_chunks.append({
                    "text": chunk,
                    "metadata": {
                        "source": filename,
                        "person_name": name,
                        "person_age": age,
                        "document_id": doc_id,
                        "chunk_index": i,
                        "importance_score": importance_score,
                        "chunk_length": len(chunk),
                        "contains_question": "?" in chunk,
                    }
                })

            print_status("Document Processing", True,
                         f"Processed {filename} into {len(processed_chunks)} chunks")
            return processed_chunks

        except Exception as e:
            print_status("Document Processing", False, f"Error processing {file_path}: {str(e)}")
            raise

    def extract_metadata_from_filename(self, filename: str) -> tuple:
        """Parse ``"<name> <age>v <id>.<ext>"`` into ``(name, int(age), id)``.

        Returns ``(None, None, None)`` when the filename does not match.
        """
        title = os.path.splitext(filename)[0]
        # [^\W\d_] matches any Unicode letter, so Finnish names such as
        # "Jääskeläinen" are accepted. The id allows letters, digits and hyphens.
        match = re.match(r'([^\W\d_]+)\s+(\d{1,3})v\s+((?:[^\W_]|-)+)', title)
        if match:
            return match.group(1), int(match.group(2)), match.group(3)
        return None, None, None

    def preprocess_text(self, text: str) -> str:
        """Clean and normalise Finnish text, keeping line breaks as split points.

        Applies ``self.clean_patterns`` in order, puts each sentence that starts
        with a capital letter on its own line, and strips every line.
        """
        try:
            # Apply all cleaning patterns
            for pattern, replacement in self.clean_patterns:
                text = re.sub(pattern, replacement, text)

            # Ensure proper sentence boundaries
            text = re.sub(r'([.!?])\s*([A-ZÄÖÅ])', r'\1\n\2', text)

            # Strip each line; the newlines themselves survive the cleaning above
            text = '\n'.join(line.strip() for line in text.split('\n'))
            return text.strip()
        except Exception as e:
            print_status("Text Preprocessing", False, str(e))
            raise

class MilvusManager:
    """Manage one named pymilvus connection and the document collection created on it.

    Every pymilvus call is routed through ``self.alias`` (``using=``), so a
    non-default alias talks to the server it was actually connected to.
    """

    def __init__(self, host: str = "87.92.59.201", port: str = "19530", alias: str = "default"):
        self.host = str(host)
        self.port = str(port)
        self.alias = alias
        self.connected = False
        self.connect()

    def connect(self):
        """Open the connection under ``self.alias`` and verify it with a server-version query.

        An existing connection with the same alias is closed first. On failure,
        ``self.connected`` is False, a status line is printed and the error is
        re-raised.
        """
        try:
            if connections.has_connection(self.alias):
                try:
                    connections.remove_connection(alias=self.alias)
                    print_status("Milvus Connection", True, "Cleaned up existing connection")
                except Exception as cleanup_error:
                    # Best effort: log and still attempt the fresh connection below.
                    logger.warning("Could not remove existing Milvus connection %r: %s",
                                   self.alias, cleanup_error)

            connections.connect(
                alias=self.alias,
                host=self.host,
                port=self.port,
                timeout=10.0
            )

            try:
                utility.get_server_version(using=self.alias)
            except Exception as ve:
                raise Exception(f"Connection verification failed: {str(ve)}") from ve
            self.connected = True
            print_status("Milvus Connection", True, f"Connected to {self.host}:{self.port}")

        except Exception as e:
            self.connected = False
            print_status("Milvus Connection", False, str(e))
            raise

    def create_collection(self, collection_name: str = "document_embeddings", dim: Optional[int] = None):
        """(Re)create *collection_name* with the document schema, then index and load it.

        WARNING: an existing collection with that name is dropped first, which
        deletes its data.

        Args:
            collection_name: Name of the collection to create.
            dim: Vector dimension of the ``embedding`` field. Defaults to the
                module-level ``EMBEDDING_DIM``.

        Returns:
            The loaded ``Collection``.
        """
        try:
            dim = EMBEDDING_DIM if dim is None else dim
            if utility.has_collection(collection_name, using=self.alias):
                Collection(name=collection_name, using=self.alias).drop()
                print_status("Milvus Collection", True, f"Dropped existing collection: {collection_name}")

            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
                FieldSchema(name="person_name", dtype=DataType.VARCHAR, max_length=100),
                FieldSchema(name="person_age", dtype=DataType.INT64),
                FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=100),
                FieldSchema(name="chunk_index", dtype=DataType.INT64)
            ]

            schema = CollectionSchema(
                fields=fields,
                description="Document embeddings collection",
                enable_dynamic_field=False
            )
            collection = Collection(name=collection_name, schema=schema, using=self.alias)

            self.create_and_load_index(collection)

            print_status("Milvus Collection", True, f"Created new collection: {collection_name} with dim={dim}")
            return collection
        except Exception as e:
            print_status("Milvus Collection", False, str(e))
            raise

    def create_and_load_index(self, collection):
        """Build an IVF_FLAT inner-product index on ``embedding`` and load the collection."""
        try:
            index_params = {
                "metric_type": "IP",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 1024}
            }
            collection.create_index(field_name="embedding", index_params=index_params)
            print_status("Index Creation", True, "Created IVF_FLAT index")

            collection.load()
            print_status("Collection Load", True, "Loaded collection into memory")

        except Exception as e:
            print_status("Index Creation", False, str(e))
            raise

    def reload_collection(self, collection_name: str = "document_embeddings"):
        """Load an existing collection into memory and return it.

        Raises:
            Exception: if the collection does not exist on this connection.
        """
        try:
            if utility.has_collection(collection_name, using=self.alias):
                collection = Collection(name=collection_name, using=self.alias)
                collection.load()
                print_status("Collection Reload", True, f"Reloaded collection: {collection_name}")
                return collection
            raise Exception(f"Collection {collection_name} does not exist")
        except Exception as e:
            print_status("Collection Reload", False, str(e))
            raise
class EmbeddingCache:
    """Two-level embedding cache keyed by the MD5 of the text.

    The first level is an in-memory dict; the second is one ``.npy`` file per
    entry under ``cache_dir``.
    """

    def __init__(self, cache_dir: str = "/scratch/project_2011638/safdarih/huggingface_cache/embedding_cache"):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        # In-memory layer: hex digest -> embedding array.
        self.cache = {}

    @staticmethod
    def _digest(text: str) -> str:
        """Hex MD5 of the text's encoded bytes, used as the cache key and file stem."""
        return hashlib.md5(text.encode()).hexdigest()

    def get_cache_path(self, text_hash: str) -> str:
        """Path of the ``.npy`` file that stores the entry for *text_hash*."""
        return os.path.join(self.cache_dir, text_hash + ".npy")

    def get(self, text: str) -> Optional[np.ndarray]:
        """Return the cached embedding for *text*, or None if it is not cached.

        Memory is checked first. A disk hit is promoted into memory.
        """
        key = self._digest(text)
        if key in self.cache:
            return self.cache[key]

        stored_at = self.get_cache_path(key)
        if not os.path.exists(stored_at):
            return None
        loaded = np.load(stored_at)
        self.cache[key] = loaded
        return loaded

    def put(self, text: str, embedding: np.ndarray):
        """Store *embedding* for *text* in memory and on disk."""
        key = self._digest(text)
        self.cache[key] = embedding
        np.save(self.get_cache_path(key), embedding)
        
class EnhancedEmbeddingGenerator:
    """Sentence embeddings from a Hugging Face encoder.

    Token embeddings are combined by attention-masked mean pooling and then
    L2-normalised. Results are memoised in-process, keyed by the exact input
    text.
    """

    def __init__(self,
                 model_name: str = "TurkuNLP/sbert-cased-finnish-paraphrase",
                 cache_dir: str = "/scratch/project_2011638/safdarih/huggingface_cache",
                 device: str = None):
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Maps text -> 1-D numpy embedding. It is unbounded and keyed by the text
        # itself, so distinct texts can never collide (unlike keying by hash(text)).
        self.embedding_cache = {}
        self._initialize_model()

    def _initialize_model(self):
        """Load tokenizer and model, switch to eval mode and record the hidden size."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir
            )
            self.model = AutoModel.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir
            ).to(self.device)

            # Set model to evaluation mode
            self.model.eval()

            self.embedding_dim = self.model.config.hidden_size
            logger.info(f"Loaded {self.model_name} (dim={self.embedding_dim})")
        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise

    def _mean_pooling(self, model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average the token embeddings in ``model_output[0]``, ignoring padding positions."""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def _encode_batch(self, texts: List[str]) -> np.ndarray:
        """Run the encoder on *texts* (bypassing the cache) and return L2-normalised rows."""
        with torch.no_grad():
            encoded_input = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            ).to(self.device)

            model_output = self.model(**encoded_input)
            sentence_embeddings = self._mean_pooling(
                model_output,
                encoded_input['attention_mask']
            )
            sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.cpu().numpy()

    def generate(self, texts: List[str], batch_size: int = 8) -> np.ndarray:
        """Embed *texts* into an array of shape ``(len(texts), embedding_dim)``.

        Row i corresponds to ``texts[i]``. Texts are processed in batches of
        *batch_size*; each distinct text is encoded at most once and afterwards
        served from the in-memory cache. An empty input returns an empty
        ``(0, embedding_dim)`` array.

        Raises:
            ValueError: if the model output width differs from ``embedding_dim``.
        """
        try:
            if len(texts) == 0:
                return np.empty((0, self.embedding_dim), dtype=np.float32)

            all_embeddings = []
            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start:start + batch_size]

                # Encode only distinct, not-yet-cached texts (dict.fromkeys keeps order).
                uncached = [t for t in dict.fromkeys(batch_texts) if t not in self.embedding_cache]
                if uncached:
                    for text, embedding in zip(uncached, self._encode_batch(uncached)):
                        self.embedding_cache[text] = embedding

                # Every text is now cached, so the output order follows the input order.
                all_embeddings.extend(self.embedding_cache[t] for t in batch_texts)

            final_embeddings = np.vstack(all_embeddings)

            # Verify embedding dimension
            if final_embeddings.shape[1] != self.embedding_dim:
                raise ValueError(
                    f"Embedding dimension mismatch. Expected {self.embedding_dim}, "
                    f"got {final_embeddings.shape[1]}"
                )

            return final_embeddings

        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    def compute_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray:
        """Dot-product scores of each document row against the query.

        For L2-normalised inputs (as produced by ``generate``) this equals cosine
        similarity. The result is always at least 1-D, so a single document gives
        shape ``(1,)`` rather than a 0-d scalar that cannot be iterated.
        """
        scores = np.dot(doc_embeddings, np.asarray(query_embedding).T)
        return np.atleast_1d(np.squeeze(scores))

    def batch_compute_similarity(self, queries: List[str], documents: List[str]) -> np.ndarray:
        """Return the ``(len(queries), len(documents))`` matrix of embedding dot products."""
        query_embeddings = self.generate(queries)
        doc_embeddings = self.generate(documents)

        # Compute similarity matrix
        similarity_matrix = np.dot(query_embeddings, doc_embeddings.T)

        return similarity_matrix

class RAGPipeline:
    """End-to-end retrieval-augmented QA.

    Flow: .docx chunks -> embeddings stored in Milvus -> retrieved context ->
    LLM answer.
    """

    def __init__(self, model_id: str = "Finnish-NLP/llama-7b-finnish-instruct-v0.2"):
        """Build all components: token budgeting, chunking, embeddings, Milvus and the 4-bit LLM.

        Raises:
            ValueError: if the embedding model's vector size differs from
                ``EMBEDDING_DIM``, which the Milvus collection schema is built with.
            Exception: anything raised while loading models or connecting to Milvus.
        """
        try:
            # Initialize token managers.
            # NOTE(review): 2048 tokens of context plus the instruction text and
            # max_new_tokens=300 may exceed the LLM's context window. Confirm the
            # model's maximum length and lower this budget if generation degrades.
            self.token_manager = TokenManager(max_tokens=2048)
            self.token_limit_manager = TokenLimitManager(max_tokens=2048)

            # Initialize core components
            self.doc_processor = DocumentProcessor()
            self.embedding_generator = EnhancedEmbeddingGenerator(
                model_name="TurkuNLP/sbert-cased-finnish-paraphrase",
                cache_dir="/scratch/project_2011638/safdarih/huggingface_cache"
            )
            # The Milvus schema is built from the module-level EMBEDDING_DIM. Fail
            # fast on a mismatch instead of at insert/search time.
            model_dim = self.embedding_generator.embedding_dim
            if model_dim != EMBEDDING_DIM:
                raise ValueError(
                    f"Embedding model dimension {model_dim} does not match "
                    f"EMBEDDING_DIM={EMBEDDING_DIM} used by the Milvus schema"
                )

            # Initialize Milvus components
            self.milvus_manager = MilvusManager(
                host=MILVUS_HOST,
                port=MILVUS_PORT,
                alias=MILVUS_ALIAS
            )

            # Reuse the existing collection; create it only if loading it fails.
            self.collection_manager = CollectionManager()
            try:
                self.collection = self.collection_manager.collection
            except Exception as e:
                print(f"Creating new collection: {str(e)}")
                self.collection = self.milvus_manager.create_collection()

            # 4-bit NF4 quantisation so the 7B model fits in GPU memory
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            # Initialize tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir="/scratch/project_2011638/safdarih/huggingface_cache"
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=bnb_config,
                device_map="auto",
                cache_dir="/scratch/project_2011638/safdarih/huggingface_cache"
            )
            # Give the limit manager the LLM tokenizer so truncate_text() is usable.
            self.token_limit_manager.setup_tokenizer(self.tokenizer)

            # Initialize pipeline
            self.pipeline = pipeline(
                "text-generation",
                model=model,
                tokenizer=self.tokenizer,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.1,
                top_p=0.95,
                repetition_penalty=1.2
            )

            print_status("RAG Pipeline", True, "Successfully initialized all components")
        except Exception as e:
            print_status("RAG Pipeline", False, f"Initialization error: {str(e)}")
            raise

    def process_documents(self, folder_path: str) -> List[Dict]:
        """Chunk every .docx in *folder_path*, embed the chunks and insert them into Milvus.

        Files are processed in sorted name order. Word lock files (``~$*.docx``)
        are skipped.

        Returns:
            All processed chunk dicts, as produced by ``DocumentProcessor``.

        Raises:
            ValueError: if the folder has no .docx files or they yield no chunks.
        """
        try:
            # Get all .docx files in the folder, in a deterministic order
            file_paths = sorted(
                f for f in os.listdir(folder_path)
                if f.endswith('.docx') and not f.startswith('~$')
            )
            if not file_paths:
                raise ValueError(f"No .docx files found in {folder_path}")

            print(f"Found {len(file_paths)} documents to process")

            all_chunks = []
            for file in file_paths:
                file_path = os.path.join(folder_path, file)
                chunks = self.doc_processor.process_document(file_path)
                all_chunks.extend(chunks)

            if not all_chunks:
                raise ValueError("No chunks were processed from documents")

            # Generate embeddings for all chunks
            texts = [chunk["text"] for chunk in all_chunks]
            try:
                embeddings = self.embedding_generator.generate(texts)

                # Prepare entities for Milvus. The schema fields are non-nullable,
                # but filenames that do not match "<name> <age>v <id>" give None
                # metadata, so placeholders ("" / -1) are substituted.
                entities = []
                for chunk, embedding in zip(all_chunks, embeddings):
                    meta = chunk["metadata"]
                    entities.append({
                        "text": chunk["text"],
                        "embedding": embedding.tolist(),
                        "person_name": meta["person_name"] or "",
                        "person_age": meta["person_age"] if meta["person_age"] is not None else -1,
                        "document_id": meta["document_id"] or "",
                        "chunk_index": meta["chunk_index"]
                    })

                # Insert into Milvus
                self.collection.insert(entities)
                self.collection.flush()

                print_status("Document Processing", True,
                             f"Processed {len(all_chunks)} chunks from {len(file_paths)} documents")
                return all_chunks

            except Exception as e:
                print_status("Embedding Generation", False, f"Error generating embeddings: {str(e)}")
                raise

        except Exception as e:
            print_status("Document Processing", False, f"Error: {str(e)}")
            raise

    def _prepare_context(self, question: str, top_k: int = 5) -> List[Dict]:
        """Search Milvus for the chunks most similar to *question*.

        Returns up to *top_k* hit dicts with keys text, source, person_name,
        chunk_index and score, best first.
        """
        try:
            # Reload the collection once per pipeline instance
            if not hasattr(self, '_collection_loaded'):
                self.collection = self.milvus_manager.reload_collection()
                self._collection_loaded = True

            # Generate question embedding
            question_embedding = self.embedding_generator.generate([question])[0]

            # Search parameters
            search_params = {
                "metric_type": "IP",
                "params": {"nprobe": 50}
            }

            # Over-fetch 2x; NOTE(review): no reranking is applied before slicing.
            initial_results = self.collection.search(
                data=[question_embedding.tolist()],
                anns_field="embedding",
                param=search_params,
                limit=top_k * 2,
                output_fields=["text", "person_name", "document_id", "chunk_index"]
            )

            # Process results
            processed_results = []
            for hit in initial_results[0]:
                processed_results.append({
                    'text': hit.entity.get('text'),
                    'source': hit.entity.get('document_id'),
                    'person_name': hit.entity.get('person_name'),
                    'chunk_index': hit.entity.get('chunk_index'),
                    'score': float(hit.score)
                })

            return processed_results[:top_k]

        except Exception as e:
            logger.error(f"Error preparing context: {str(e)}")
            raise

    def query(self, question: str, top_k: int = 5) -> Dict:
        """Answer *question* from retrieved context.

        Returns:
            A dict with keys ``answer``, ``sources`` and ``metadata``.
            - ``answer`` contains only the newly generated text, not the prompt.
            - ``metadata["max_similarity_score"]`` is None when no context
              survived truncation.
        """
        try:
            # Get initial results
            results = self._prepare_context(question, top_k)

            # Truncate context to fit token limits
            truncated_context = self.token_manager.truncate_context(
                results,
                self.pipeline.tokenizer
            )

            # Format context for LLM
            formatted_input = self.token_manager.format_for_llm(
                truncated_context,
                question
            )

            # Generate response with proper prompt
            prompt = f"""Tehtävä: Etsi tarkka vastaus kysymykseen käyttäen vain annettua kontekstia.

{formatted_input}

Vastausohjeet:
1. Jos löydät suoran vastauksen:
   - Mainitse dokumentti, josta vastaus löytyy
   - Käytä suoria lainauksia
   - Arvioi vastauksen luotettavuus
2. Jos et löydä vastausta:
   - Ilmoita selkeästi, ettei vastausta löydy annetusta kontekstista
3. Jos löydät vain osittaisen vastauksen:
   - Kerro mitä tietoa löysit ja mitä puuttuu

Vastaus:"""

            # Without return_full_text=False the text-generation pipeline echoes
            # the whole prompt at the start of generated_text.
            response = self.pipeline(
                prompt,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.1,
                top_p=0.85,
                repetition_penalty=1.2,
                return_full_text=False
            )[0]["generated_text"].strip()

            return {
                "answer": response,
                "sources": truncated_context,
                "metadata": {
                    "question": question,
                    "num_chunks_retrieved": len(truncated_context),
                    "max_similarity_score": max(
                        (doc['score'] for doc in truncated_context), default=None
                    )
                }
            }

        except Exception as e:
            logger.error(f"Error in query processing: {str(e)}")
            raise

    def _rerank_results(self, question: str, results: List[Dict]) -> List[Dict]:
        """Re-score hits as ``score * embedding similarity * heuristic boost`` and sort descending.

        The boost multiplies for any sentence punctuation, the substrings
        'on'/'ovat'/'oli', and any question word appearing in the text.
        """
        reranked = []

        # Get question embedding once
        question_embedding = self.embedding_generator.generate([question])[0]

        # Get all text embeddings at once for efficiency
        texts = [hit['text'] for hit in results]
        text_embeddings = self.embedding_generator.generate(texts)

        # Compute similarities all at once; atleast_1d keeps a single hit iterable
        similarities = np.atleast_1d(self.embedding_generator.compute_similarity(
            question_embedding,
            text_embeddings
        ))

        for hit, similarity in zip(results, similarities):
            score = hit['score']
            text = hit['text']

            # Finnish-specific boosts
            boost_score = 1.0

            # Boost for complete sentences
            if re.search(r'[.!?][\s]*', text):
                boost_score *= 1.1

            # Boost for answer indicators
            if any(indicator in text.lower() for indicator in ['on', 'ovat', 'oli']):
                boost_score *= 1.15

            # Boost for question-context match
            if any(word in text.lower() for word in question.lower().split()):
                boost_score *= 1.2

            final_score = score * similarity * boost_score
            reranked.append({**hit, 'score': final_score})

        return sorted(reranked, key=lambda x: x['score'], reverse=True)

    def _extract_entities(self, text: str) -> List[str]:
        """Return capitalised words (including Ä/Ö/Å initials) as candidate entities."""
        return re.findall(r'\b[A-ZÄÖÅ][a-zäöå]*\b', text)

    def _clean_response(self, response: str) -> str:
        """Collapse repeated document tags, drop confidence boilerplate and cap the answer at 500 chars."""
        response = re.sub(r'(\[Dokumentti \d+\])\s*\1', r'\1', response)
        response = re.sub(r'Luottamus: \d+%\s*Selitys:', '', response)
        if len(response) > 500:
            response = response[:497] + '...'
        return response.strip()

    def _create_prompt(self, question: str, context: str) -> str:
        """Build an alternative answer prompt from *question* and preformatted *context*."""
        return f"""Tehtävä: Vastaa annettuun kysymykseen käyttäen vain alla olevaa kontekstia.

        Kysymys: {question}
        
        Konteksti:
        {context}
        
        Vastausohjeet:
        1. Jos löydät suoran vastauksen:
           - Mainitse ensin dokumentti, josta vastaus löytyy
           - Lainaa tekstiä tarkasti käyttäen lainausmerkkejä
           - Mainitse vastauksen luotettavuus prosentteina
        
        2. Jos löydät osittaisen vastauksen:
           - Kerro, mitä tietoa löysit ja mistä
           - Mainitse selkeästi, mitä tietoa puuttuu
        
        3. Jos et löydä vastausta:
           - Vastaa vain: "En löydä suoraa vastausta annetusta kontekstista"
        
        Vastaus:"""
class CollectionManager:
    """Own one Milvus collection: (re)create it and reload it periodically.

    The ``collection`` property loads ``collection_name`` on first access and
    reloads it when more than ``reload_interval`` seconds have passed since the
    last (re)load. ``create_collection`` switches the manager to whatever name
    it creates, so later reloads target that collection.
    """

    def __init__(self, collection_name: str = "document_embeddings"):
        # Name of the collection that ``collection`` loads and reloads.
        self.collection_name = collection_name
        self._collection = None
        self._last_reload_time = 0
        self.reload_interval = 300  # seconds (5 minutes)

    @property
    def collection(self):
        """Return the loaded collection, reloading it if missing or stale."""
        is_stale = time.time() - self._last_reload_time > self.reload_interval
        # Explicit ``is None`` check: only a missing collection should count as
        # "not loaded", never an object that happens to evaluate falsy.
        if self._collection is None or is_stale:
            self._reload_collection()
        return self._collection

    def _reload_collection(self):
        """Open ``self.collection_name`` from Milvus and load it into memory.

        Raises:
            Exception: if the collection does not exist or Milvus reports an
                error. The error is logged before being re-raised.
        """
        try:
            if not utility.has_collection(self.collection_name):
                raise Exception("Collection does not exist")

            self._collection = Collection(name=self.collection_name)
            self._collection.load()
            self._last_reload_time = time.time()
        except Exception as e:
            logger.error(f"Error reloading collection: {str(e)}")
            raise

    def create_collection(self, collection_name: str = "document_embeddings"):
        """(Re)create *collection_name* with the document-embedding schema.

        WARNING: an existing collection with the same name is DROPPED first,
        and all of its data is lost. The new collection gets an IVF_FLAT/IP
        index on ``embedding`` and is loaded. It then becomes the collection
        this manager tracks (``collection_name`` is updated).

        Returns:
            The newly created ``Collection``.

        Raises:
            Exception: any Milvus error. It is logged before being re-raised.
        """
        try:
            if utility.has_collection(collection_name):
                Collection(name=collection_name).drop()
                print_status("Collection Manager", True, f"Dropped existing collection: {collection_name}")

            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
                # EMBEDDING_DIM is a module-level constant defined elsewhere in
                # this file; presumably it matches the embedding model's
                # hidden size (TODO confirm).
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
                FieldSchema(name="person_name", dtype=DataType.VARCHAR, max_length=100),
                FieldSchema(name="person_age", dtype=DataType.INT64),
                FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=100),
                FieldSchema(name="chunk_index", dtype=DataType.INT64)
            ]

            schema = CollectionSchema(
                fields=fields,
                description="Document embeddings collection",
                enable_dynamic_field=False
            )

            self._collection = Collection(name=collection_name, schema=schema)
            # Track the new name so periodic reloads do not switch back to the
            # default collection.
            self.collection_name = collection_name
            self._create_index()
            self._collection.load()
            self._last_reload_time = time.time()

            print_status("Collection Manager", True, f"Created new collection: {collection_name}")
            return self._collection

        except Exception as e:
            logger.error(f"Error creating collection: {str(e)}")
            raise

    def _create_index(self):
        """Build an IVF_FLAT index with inner-product metric on ``embedding``.

        Raises:
            Exception: any Milvus error. It is logged before being re-raised.
        """
        try:
            index_params = {
                "metric_type": "IP",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 1024}
            }
            self._collection.create_index(
                field_name="embedding",
                index_params=index_params
            )
            print_status("Collection Manager", True, "Created index")
        except Exception as e:
            logger.error(f"Error creating index: {str(e)}")
            raise
            

def verify_pipeline_components():
    """Build each core pipeline component and report whether all are truthy.

    Components are constructed in a fixed order: document processor, Milvus
    manager, embedding generator, then the RAG pipeline.
    """
    components = (
        DocumentProcessor(),
        MilvusManager(),
        EnhancedEmbeddingGenerator(),
        RAGPipeline(),
    )
    return all(components)

verify_section("Pipeline Components", verify_pipeline_components)

## section3:
from neo4j import GraphDatabase
class Neo4jDocumentManager:
    """Mirror processed document chunks into a Neo4j graph and query it.

    Graph shape written by this class:
      (:Person {name, age})-[:APPEARS_IN]->(:Document {id, content, chunk_index})
      (:Document)-[:RELATED_TO]->(:Document)   chunks that share a Person
      (:Person)-[:KNOWS]->(:Person)            a person's chunk mentions another

    The manager owns a Neo4j driver. Call ``close()``, or use the manager in a
    ``with`` statement, to release the driver's connections.
    """

    # Legacy connection settings. They are used only when neither an argument
    # nor the matching NEO4J_* environment variable is supplied.
    _DEFAULT_URI = "bolt://87.92.59.201:7687"
    _DEFAULT_USERNAME = "neo4j"
    _DEFAULT_PASSWORD = "test"

    def __init__(self, uri: Optional[str] = None,
                 username: Optional[str] = None,
                 password: Optional[str] = None):
        """Connect to Neo4j and try to create the uniqueness constraints.

        Each setting is resolved in this order: the explicit argument, then the
        ``NEO4J_URI`` / ``NEO4J_USERNAME`` / ``NEO4J_PASSWORD`` environment
        variable, then the legacy default above. This keeps credentials out of
        call sites.
        """
        if uri is None:
            uri = os.environ.get("NEO4J_URI", self._DEFAULT_URI)
        if username is None:
            username = os.environ.get("NEO4J_USERNAME", self._DEFAULT_USERNAME)
        if password is None:
            password = os.environ.get("NEO4J_PASSWORD", self._DEFAULT_PASSWORD)

        # Names seen in chunk metadata or text; used for substring matching.
        self.person_names = set()
        self.driver = GraphDatabase.driver(uri, auth=(username, password))
        self.setup_constraints()

    def close(self) -> None:
        """Close the underlying Neo4j driver, if one was created."""
        driver = getattr(self, "driver", None)
        if driver is not None:
            driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def get_related_documents(self, session, mentioned_persons: List[str]) -> List[Dict]:
        """Return the document chunks linked to any of *mentioned_persons*.

        Results follow the order of *mentioned_persons*. Within each person they
        are ordered by ``chunk_index``. A chunk shared by several persons
        appears only once. On error the exception is logged and an empty list
        is returned, so callers can degrade gracefully.
        """
        try:
            documents = []
            seen_doc_ids = set()
            for person in mentioned_persons:
                result = session.run("""
                    MATCH (p:Person {name: $name})-[:APPEARS_IN]->(d:Document)
                    RETURN d.content as content, d.id as doc_id, 
                           d.chunk_index as chunk_index
                    ORDER BY d.chunk_index
                """, {"name": person})

                for record in result:
                    doc_id = record["doc_id"]
                    if doc_id in seen_doc_ids:
                        continue
                    seen_doc_ids.add(doc_id)
                    documents.append({
                        "content": record["content"],
                        "doc_id": doc_id,
                        "chunk_index": record["chunk_index"],
                        "relevance": 1.0
                    })

            return documents

        except Exception as e:
            logger.error(f"Error retrieving related documents: {str(e)}")
            return []

    def check_connection(self) -> bool:
        """Run ``RETURN 1`` against Neo4j and report whether it succeeded."""
        try:
            with self.driver.session() as session:
                record = session.run("RETURN 1 as test").single()
            # Explicit check instead of ``assert``, which is stripped under -O.
            if record is None or record["test"] != 1:
                raise RuntimeError("unexpected result from test query")
            print("Neo4j connection test successful")
            return True
        except Exception as e:
            print(f"Neo4j connection test failed: {str(e)}")
            return False

    def setup_constraints(self):
        """Create uniqueness constraints on Person.name and Document.id.

        The Neo4j 5+ syntax is tried first, then the legacy ``ASSERT`` syntax.
        All failures are printed and ignored: constraints help but are not
        required for the graph to work.
        """
        try:
            with self.driver.session() as session:
                constraints = [
                    """CREATE CONSTRAINT person_name_unique IF NOT EXISTS 
                       FOR (p:Person) REQUIRE p.name IS UNIQUE""",
                    """CREATE CONSTRAINT document_id_unique IF NOT EXISTS 
                       FOR (d:Document) REQUIRE d.id IS UNIQUE"""
                ]

                for constraint in constraints:
                    try:
                        session.run(constraint)
                    except Exception as e:
                        print(f"Warning: Could not create constraint: {str(e)}")
                        # Fall back to the pre-5.x constraint syntax.
                        try:
                            if "person_name" in constraint:
                                session.run("""CREATE CONSTRAINT ON (p:Person) 
                                             ASSERT p.name IS UNIQUE""")
                            elif "document_id" in constraint:
                                session.run("""CREATE CONSTRAINT ON (d:Document) 
                                             ASSERT d.id IS UNIQUE""")
                        except Exception as e2:
                            print(f"Warning: Could not create constraint with old syntax: {str(e2)}")

                print("Neo4j constraints setup completed")

        except Exception as e:
            # Deliberately best-effort: continue without constraints.
            print(f"Warning: Error during constraint setup: {str(e)}")

    def _extract_person_names(self, text: str) -> set:
        """Return capitalised words in *text* (longer than 2 chars) as candidate names.

        This is a heuristic. Any capitalised word qualifies, including
        sentence-initial common words. Words after a title (herra/rouva/neiti)
        or a relation word (ystävä/naapuri) are also captured.
        """
        names = set()

        patterns = [
            r'\b[A-ZÄÖÅ][a-zäöå]+\b',  # Basic capitalized word
            r'(?:herra|rouva|neiti)\s+[A-ZÄÖÅ][a-zäöå]+',  # Titles
            r'(?:ystävä|naapuri)\s+[A-ZÄÖÅ][a-zäöå]+',  # Relationships
        ]

        for pattern in patterns:
            for match in re.finditer(pattern, text):
                name = match.group().split()[-1]  # keep only the name part
                if len(name) > 2:  # filter out very short words
                    names.add(name)

        return names

    def create_document_graph(self, processed_chunks: List[Dict]):
        """Write Person/Document nodes and their relationships for all chunks.

        The first pass collects every known person name, so that relationship
        creation in the second pass can match names across all chunks.
        """
        for chunk in processed_chunks:
            if chunk["metadata"].get("person_name"):
                self.person_names.add(chunk["metadata"]["person_name"])
            self.person_names.update(self._extract_person_names(chunk["text"]))

        with self.driver.session() as session:
            for chunk in processed_chunks:
                self._create_document_and_person_nodes(session, chunk)
                self._process_relationships(session, chunk)

    def _create_document_and_person_nodes(self, session, chunk: Dict):
        """MERGE the chunk's Document node and, when known, its Person node.

        The Document id is ``"{document_id}_{chunk_index}"``. With a person
        name, the person is linked via APPEARS_IN, and the Document gets
        RELATED_TO edges to the person's other Documents.
        """
        metadata = chunk["metadata"]
        doc_id = f"{metadata['document_id']}_{metadata['chunk_index']}"
        person_name = metadata.get("person_name")

        if not person_name:
            # Neo4j refuses to MERGE a node on a null property, so store the
            # chunk without a Person link rather than aborting the whole import.
            session.run("""
                MERGE (d:Document {id: $doc_id})
                SET d.content = $content,
                    d.chunk_index = $chunk_index
            """, {
                "doc_id": doc_id,
                "content": chunk["text"],
                "chunk_index": metadata["chunk_index"]
            })
            return

        session.run("""
            MERGE (p:Person {name: $person_name})
            SET p.age = $age
            
            MERGE (d:Document {id: $doc_id})
            SET d.content = $content,
                d.chunk_index = $chunk_index
            
            MERGE (p)-[:APPEARS_IN]->(d)
            
            WITH p, d
            MATCH (p)-[:APPEARS_IN]->(other:Document)
            WHERE other.id <> d.id
            MERGE (d)-[:RELATED_TO]->(other)
        """, {
            "person_name": person_name,
            "age": metadata.get("person_age"),
            "doc_id": doc_id,
            "content": chunk["text"],
            "chunk_index": metadata["chunk_index"]
        })

    def _process_relationships(self, session, chunk: Dict):
        """Create KNOWS edges from the chunk's main person to others it mentions.

        Edges are created only when both Person nodes already exist, because
        the query uses MATCH rather than MERGE on the persons.
        """
        main_person = chunk["metadata"].get("person_name")
        if not main_person:
            return

        for mentioned_person in self.find_mentioned_persons(chunk["text"]):
            # A person's own chunk naturally mentions them; skip self-loops.
            if mentioned_person == main_person:
                continue
            session.run("""
                MATCH (p1:Person {name: $main_person})
                MATCH (p2:Person {name: $mentioned_person})
                MERGE (p1)-[r:KNOWS]->(p2)
                RETURN r
            """, {
                "main_person": main_person,
                "mentioned_person": mentioned_person
            })

    def find_mentioned_persons(self, text: str) -> List[str]:
        """Return known person names that occur in *text*, sorted by name.

        Matching is a case-insensitive substring test. As a result, inflected
        Finnish forms that keep the base name intact also match
        (e.g. "Jarmon" matches "Jarmo"). Sorting makes the result
        deterministic, because ``person_names`` is a set.
        """
        text_lower = text.lower()
        return [name for name in sorted(self.person_names) if name.lower() in text_lower]

    def query(self, question: str) -> Dict:
        """Collect graph context for the persons mentioned in *question*.

        Returns:
            A dict with ``person_contexts`` (one dict per person found in the
            graph), ``related_documents`` (their chunks) and
            ``mentioned_persons``.
        """
        with self.driver.session() as session:
            mentioned_persons = self.find_mentioned_persons(question)
            related_docs = self.get_related_documents(session, mentioned_persons)

            person_contexts = []
            for person_name in mentioned_persons:
                context = self._get_person_context(session, person_name)
                if context:
                    person_contexts.append(context)

            return {
                "person_contexts": person_contexts,
                "related_documents": related_docs,
                "mentioned_persons": mentioned_persons
            }

    def _get_person_context(self, session, person_name: str) -> Optional[Dict]:
        """Return name, age, relationship types and document ids for a person.

        Returns None when no Person node with that name exists.
        NOTE(review): relationship_types comes from
        (:Document)-[:HAS_RELATIONSHIP]->(:Relationship) nodes, which nothing
        in this class creates. Confirm they are written elsewhere, otherwise
        this list is always empty.
        """
        result = session.run("""
            MATCH (p:Person {name: $name})
            OPTIONAL MATCH (p)-[:APPEARS_IN]->(d:Document)
            OPTIONAL MATCH (d)-[:HAS_RELATIONSHIP]->(r:Relationship)
            
            WITH p,
                 COLLECT(DISTINCT r.type) as relationship_types,
                 COLLECT(DISTINCT d.id) as doc_ids
            
            RETURN {
                name: p.name,
                age: p.age,
                relationship_types: relationship_types,
                document_count: SIZE(doc_ids),
                document_ids: doc_ids
            } as context
        """, {"name": person_name})

        record = result.single()
        return record["context"] if record is not None else None

class Neo4jEnhancedRAGPipeline:
    """Combine a vector-search RAG pipeline with Neo4j person/document context.

    ``base_rag_pipeline`` is expected to expose ``doc_processor``,
    ``process_documents(folder)``, ``query(question)`` and a text-generation
    ``pipeline`` callable. These are the attributes this class uses.
    """

    # Tokenizer used only to measure and trim prompt sections.
    TOKENIZER_NAME = "Finnish-NLP/llama-7b-finnish-instruct-v0.2"

    def __init__(self, base_rag_pipeline, neo4j_uri: str = "bolt://87.92.59.201:7687"):
        self.base_rag = base_rag_pipeline
        # Credentials may be overridden via the environment; the previous
        # literal values remain the fallback.
        self.neo4j_manager = Neo4jDocumentManager(
            uri=neo4j_uri,
            username=os.environ.get("NEO4J_USERNAME", "neo4j"),
            password=os.environ.get("NEO4J_PASSWORD", "test")
        )
        self._tokenizer = None  # loaded lazily by _get_tokenizer()

    def process_documents(self, folder_path: str):
        """Chunk every .docx in *folder_path*, build the graph, then embed.

        Returns:
            All processed chunks.

        Raises:
            ValueError: if no .docx files, or no chunks, are found.
        """
        try:
            print("Processing documents with base RAG...")
            all_chunks = []

            # Skip Word lock files ("~$name.docx"): they are not valid .docx
            # documents. Sorting makes the processing order reproducible.
            file_paths = sorted(
                f for f in os.listdir(folder_path)
                if f.endswith('.docx') and not f.startswith('~$')
            )
            if not file_paths:
                raise ValueError(f"No .docx files found in {folder_path}")

            print(f"Found {len(file_paths)} documents to process")

            for file in file_paths:
                file_path = os.path.join(folder_path, file)
                chunks = self.base_rag.doc_processor.process_document(file_path)
                all_chunks.extend(chunks)

            if not all_chunks:
                raise ValueError("No chunks were processed from documents")

            print(f"Processed {len(all_chunks)} chunks from documents")

            print("Creating Neo4j graph structure...")
            self.neo4j_manager.create_document_graph(all_chunks)

            # Vector embeddings are produced by the base pipeline itself.
            self.base_rag.process_documents(folder_path)

            return all_chunks

        except Exception as e:
            print(f"Error processing documents: {str(e)}")
            raise

    def query(self, question: str) -> Dict:
        """Answer *question* using vector search plus graph context.

        If anything in the enhanced path fails, the error is printed and the
        plain base-pipeline answer is returned instead.
        """
        try:
            base_response = self.base_rag.query(question)
            graph_context = self.neo4j_manager.query(question)
            person_contexts = graph_context.get("person_contexts", [])

            combined_context = self._combine_contexts(
                base_sources=base_response.get("sources", []),
                graph_sources=graph_context.get("related_documents", []),
                person_contexts=person_contexts
            )

            enhanced_answer = self._generate_enhanced_answer(question, combined_context)

            return {
                "answer": enhanced_answer,
                "sources": base_response.get("sources", []),
                "person_context": person_contexts[0] if person_contexts else None,
                "metadata": {
                    "question": question,
                    "mentioned_persons": graph_context.get("mentioned_persons", []),
                    "base_similarity_score": base_response.get("metadata", {}).get("max_similarity_score", 0)
                }
            }

        except Exception as e:
            print(f"Error in Neo4j enhanced query: {str(e)}")
            return self.base_rag.query(question)

    def _combine_contexts(self, base_sources: List[Dict],
                          graph_sources: List[Dict],
                          person_contexts: List[Dict]) -> Dict:
        """Merge vector-search and graph sources into one context dict.

        Base sources come first, in their given order. A graph source is added
        only if its content has not been seen yet. Base and graph sources use
        different id schemes, so duplicates are detected by content text.

        Returns:
            ``{"documents": [...], "relationships": [], "persons": [...]}``
        """
        combined = {
            "documents": [],
            "relationships": [],
            "persons": []
        }
        seen_contents = set()

        for source in base_sources:
            content = source["text"]
            seen_contents.add(content)
            combined["documents"].append({
                "content": content,
                "source": source.get("document_id", source.get("source", "unknown")),
                # Sources may carry the score as "score" or "similarity_score".
                "relevance": source.get("score", source.get("similarity_score", 0.0))
            })

        for source in graph_sources:
            content = source["content"]
            if content in seen_contents:
                continue
            seen_contents.add(content)
            combined["documents"].append({
                "content": content,
                "source": source["doc_id"],
                "relevance": source.get("relevance", 0.5)
            })

        for context in person_contexts:
            combined["persons"].append({
                "name": context["name"],
                "age": context["age"],
                "relationships": context["relationship_types"],
                "document_count": context["document_count"]
            })

        return combined

    def _generate_enhanced_answer(self, question: str, context: Dict) -> str:
        """Generate an answer from the combined context with the base LLM.

        The prompt is assembled from explicit lines, so the method's source
        indentation does not leak into it. The generator's echo of the prompt
        is removed, so only the newly generated text is returned.
        """
        formatted_context = self._format_context_within_limits(context)

        prompt = "\n".join([
            "Käytä seuraavaa kontekstia vastataksesi kysymykseen. Huomioi erityisesti henkilöiden väliset suhteet.",
            "",
            f"Kysymys: {question}",
            "",
            formatted_context,
            "",
            "Vastausohjeet:",
            "1. Käytä dokumenttien ja suhteiden tietoja yhdessä",
            "2. Mainitse lähteet selkeästi",
            "3. Korosta henkilöiden välisiä suhteita",
            "4. Käytä suoria lainauksia kun mahdollista",
            "",
            "Vastaus:",
        ])

        generated = self.base_rag.pipeline(
            prompt,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.1,
            top_p=0.85,
            repetition_penalty=1.2
        )[0]["generated_text"]

        # HF text-generation pipelines return prompt + completion by default
        # (return_full_text=True). Keep only the completion.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]

        return generated.strip()

    def _format_document_context(self, documents: List[Dict]) -> str:
        """Render documents as numbered "Dokumentti N (source):" sections."""
        return "\n\n".join([
            f"Dokumentti {i+1} ({doc['source']}):\n{doc['content']}"
            for i, doc in enumerate(documents)
        ])

    def _format_person_context(self, persons: List[Dict]) -> str:
        """Render each person's name, age and relationships as a block."""
        return "\n\n".join([
            f"Henkilö: {person['name']}\n"
            f"Ikä: {person['age']}\n"
            f"Suhteet: {', '.join(person['relationships'])}"
            for person in persons
        ])

    def _get_tokenizer(self):
        """Load the measuring tokenizer once per instance and cache it."""
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_NAME)
            self._tokenizer = tokenizer
        return tokenizer

    def _truncate_text(self, text: str, max_tokens: int = 1800) -> str:
        """Trim *text* to at most *max_tokens* tokens, ending at a period if possible."""
        if not text:
            return text  # nothing to measure; avoid loading the tokenizer

        tokenizer = self._get_tokenizer()
        tokens = tokenizer.encode(text)

        if len(tokens) <= max_tokens:
            return text

        truncated_text = tokenizer.decode(tokens[:max_tokens], skip_special_tokens=True)

        # Prefer ending on a sentence boundary.
        last_period = truncated_text.rfind('.')
        if last_period > 0:
            truncated_text = truncated_text[:last_period + 1]

        return truncated_text

    def _format_context_within_limits(self, context: Dict) -> str:
        """Format the document and person sections, each within its token budget."""
        DOC_TOKENS = 1200     # budget for the document section
        PERSON_TOKENS = 400   # budget for the person section

        formatted_docs = self._truncate_text(
            self._format_document_context(context['documents']), DOC_TOKENS
        )
        formatted_persons = self._truncate_text(
            self._format_person_context(context['persons']), PERSON_TOKENS
        )

        return f"Dokumenttikonteksti:\n{formatted_docs}\n\nHenkilökonteksti:\n{formatted_persons}"

def verify_neo4j_components():
    """Check that the Neo4j-enhanced pipeline can be built and reach Neo4j.

    Builds a base RAGPipeline, wraps it in Neo4jEnhancedRAGPipeline, and runs
    ``RETURN 1`` on the server. Failures are reported via print_status and
    result in ``False``.

    NOTE: verify_section (top of file) ignores this return value. A ``False``
    here is therefore still followed by its SUCCESS line; only the FAILED line
    printed below signals the problem.
    """
    neo4j_rag = None
    try:
        base_pipeline = RAGPipeline()

        neo4j_rag = Neo4jEnhancedRAGPipeline(
            base_rag_pipeline=base_pipeline,
            neo4j_uri="bolt://87.92.59.201:7687"
        )

        with neo4j_rag.neo4j_manager.driver.session() as session:
            record = session.run("RETURN 1 as test").single()
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        if record is None or record["test"] != 1:
            raise RuntimeError("Neo4j test query returned an unexpected result")

        print_status("Neo4j Components", True, "Connection verified")
        return True

    except Exception as e:
        print_status("Neo4j Components", False, f"Error: {str(e)}")
        return False

    finally:
        # This throw-away pipeline owns a driver; release it after the check.
        if neo4j_rag is not None:
            neo4j_rag.neo4j_manager.driver.close()

# Add verification call after other verifications
verify_section("Neo4j Components", verify_neo4j_components)

## section 4: 
# Testing Infrastructure
class RAGTester:
    """Run questions through a RAG pipeline and score the answers heuristically.

    ``pipeline`` must provide ``query(question)`` returning a dict with an
    ``"answer"`` string and a ``"sources"`` list of dicts.
    """

    # Accepted quote marks: the ASCII double quote plus the typographic quotes
    # (” and “) common in Finnish text.
    _QUOTE_CHARS = ('"', '\u201d', '\u201c')

    # Common Finnish case/derivational endings used by the structure heuristic.
    _FINNISH_ENDINGS = (
        'ssa', 'ssä', 'sta', 'stä', 'lla', 'llä', 'lta', 'ltä',
        'ksi', 'in', 'en', 'teen', 'seen'
    )

    def __init__(self, pipeline: "RAGPipeline"):
        self.pipeline = pipeline
        # Per-question metrics accumulated over every run_test_suite() call.
        self.test_results = []

    def run_test_suite(self, test_questions: List[str]) -> Dict[str, Any]:
        """Ask every question, validate the answers, and aggregate statistics.

        A question counts as failed only if querying or scoring it raises.
        ``average_similarity`` and ``direct_quote_ratio`` are averaged over the
        successful responses.

        Returns:
            ``{"detailed_results": [per-question metrics], "summary": {...}}``
        """
        test_results = []
        summary_stats = {
            "total_questions": len(test_questions),
            "successful_responses": 0,
            "failed_responses": 0,
            "average_similarity": 0.0,
            "direct_quote_ratio": 0.0
        }

        for question in test_questions:
            try:
                result = self.pipeline.query(question)
                answer = result["answer"]
                sources = result.get("sources", [])

                validation = self._validate_response(answer)

                metrics = {
                    "question": question,
                    "has_direct_quote": validation["has_direct_quote"],
                    "source_count": len(sources),
                    # default=0.0: an answer without sources is still a response,
                    # not a failure.
                    "max_similarity": max(
                        (self._source_similarity(s) for s in sources), default=0.0
                    ),
                    "response_quality": validation,
                    "response_length": len(answer),
                }

                test_results.append(metrics)
                summary_stats["successful_responses"] += 1
                summary_stats["average_similarity"] += metrics["max_similarity"]
                summary_stats["direct_quote_ratio"] += int(metrics["has_direct_quote"])

                self._print_test_result(question, result, metrics)

            except Exception as e:
                print(f"❌ Error testing question '{question}': {str(e)}")
                summary_stats["failed_responses"] += 1

        if summary_stats["successful_responses"] > 0:
            summary_stats["average_similarity"] /= summary_stats["successful_responses"]
            summary_stats["direct_quote_ratio"] /= summary_stats["successful_responses"]

        self.test_results.extend(test_results)

        return {
            "detailed_results": test_results,
            "summary": summary_stats
        }

    @staticmethod
    def _source_similarity(source: Dict) -> float:
        """Return a source's similarity, read from ``similarity_score`` or ``score``."""
        return source.get("similarity_score", source.get("score", 0.0))

    def _validate_response(self, response: str) -> Dict[str, bool]:
        """Run structural and Finnish-language checks on one answer."""
        validation = {
            # Basic structural checks
            "has_source_reference": bool(re.search(r'\[Dokumentti \d+\]', response)),
            "has_direct_quote": any(q in response for q in self._QUOTE_CHARS),
            "is_complete_sentence": response.strip().endswith(('.', '?', '!')),
            "has_confidence": bool(re.search(r'\d+\s*%', response)),
            "reasonable_length": 10 <= len(response) <= 500,

            # Finnish language specific checks
            "has_finnish_chars": bool(re.search(r'[äöåÄÖÅ]', response)),
            "proper_finnish_structure": self._check_finnish_structure(response)
        }
        return validation

    def _check_finnish_structure(self, text: str) -> bool:
        """Return True if any word ends with a common Finnish suffix.

        Words are taken as runs of word characters, so trailing punctuation
        (e.g. "talossa.") does not hide the ending.
        """
        words = re.findall(r"\w+", text.lower())
        if not words:
            return False
        return any(word.endswith(self._FINNISH_ENDINGS) for word in words)

    def _print_test_result(self, question: str, result: Dict, metrics: Dict):
        """Print one question's answer, metrics and validation flags."""
        print("\n" + "="*80)
        print(f"Question: {question}")
        print(f"Answer: {result['answer']}")
        print("\nMetrics:")
        print(f"- Source count: {metrics['source_count']}")
        print(f"- Max similarity: {metrics['max_similarity']:.2%}")
        print(f"- Response length: {metrics['response_length']}")
        print("\nValidation:")
        for key, value in metrics['response_quality'].items():
            print(f"- {key}: {'✅' if value else '❌'}")
        print("="*80)


def run_rag_tests(pipeline: "RAGPipeline", test_questions: Optional[List[str]] = None):
    """Run RAGTester over *test_questions* and print a summary.

    Args:
        pipeline: Object exposing ``query(question)``, normally a RAGPipeline.
        test_questions: Questions to ask. ``None`` selects the built-in default
            set below.

    Returns:
        The dict produced by ``RAGTester.run_test_suite`` (``detailed_results``
        and ``summary``).
    """
    if test_questions is None:
        test_questions = [
            "Kuinka kauan Annikki oli naimisissa miehensä kanssa?",  # Clear answer: 55 years from Annikki's document
            # Further candidate questions, currently disabled:
            #"Mitä digitaalisia laitteita Elisa käyttää?",  # Clear answer from Elisa's document about her devices
            #"Mikä on Jarmon tärkein toive arjessa?",  # Clear answer about wanting to maintain possibility for solitude
            #"Millä alalla Kosti työskentelee eläkkeellä ollessaan?",  # Clear answer: kauppa-ala from Kosti's document
            #"Montako lastenlasta Annikilla on?"
        ]

    tester = RAGTester(pipeline)
    results = tester.run_test_suite(test_questions)
    summary = results['summary']

    print("\nTest Summary:")
    print(f"Total questions: {summary['total_questions']}")
    print(f"Successful responses: {summary['successful_responses']}")
    print(f"Failed responses: {summary['failed_responses']}")
    print(f"Average similarity: {summary['average_similarity']:.2%}")
    print(f"Direct quote ratio: {summary['direct_quote_ratio']:.2%}")

    return results

def verify_testing_infrastructure():
    """Smoke-test that a RAGTester can wrap a freshly built RAGPipeline."""
    tester = RAGTester(RAGPipeline())
    return bool(tester)

verify_section("Testing Infrastructure", verify_testing_infrastructure)

## section 5: 
class TokenManager:
    """Fit retrieved documents into a token budget and format them for the LLM."""

    def __init__(self, max_tokens: int = 2048):
        self.max_tokens = max_tokens

    def truncate_context(self, context: List[Dict], tokenizer) -> List[Dict]:
        """Keep the highest-scoring documents that fit within ``max_tokens``.

        Documents are taken in descending ``score`` order (missing score = 0).
        The first document that does not fit is added in shortened form, marked
        ``truncated=True``, but only if more than 100 tokens of budget remain.
        Selection then stops. A document whose processing raises is logged and
        skipped.
        """
        kept: List[Dict] = []
        used = 0
        ranked = sorted(context, key=lambda item: item.get('score', 0), reverse=True)

        for item in ranked:
            try:
                ids = tokenizer.encode(item['text'])
                n_ids = len(ids)
                budget = self.max_tokens

                if used + n_ids <= budget:
                    used += n_ids
                    kept.append(item)
                    continue

                # Over budget: include a partial document only if enough room remains.
                room = budget - used
                if room > 100:
                    partial = tokenizer.decode(ids[:room], skip_special_tokens=True)
                    cut = partial.rfind('.')
                    if cut > 0:
                        partial = partial[:cut + 1]  # finish on a sentence end
                    kept.append({**item, 'text': partial, 'truncated': True})
                break
            except Exception as e:
                logger.error(f"Error truncating document: {str(e)}")

        return kept

    def format_for_llm(self, context: List[Dict], question: str) -> str:
        """Render the question and numbered documents as one prompt string."""
        pieces = ["Kysymys: " + question, "\nKonteksti:"]

        for number, item in enumerate(context, start=1):
            pieces.append(
                f"\nDokumentti {number}:\n"
                f"Lähde: {item.get('source', 'Tuntematon')}\n"
                f"Teksti: {item['text']}\n"
                f"{'(Katkaistu)' if item.get('truncated') else ''}"
            )

        return "\n".join(pieces)

## section 6:

# Add these imports at the top of your file
from langchain.memory import ConversationBufferMemory
from typing import List, Dict, Any, Optional
import re
from neo4j import GraphDatabase
import torch.nn.functional as F

class FinnishRAGAgent:
    """Question-answering agent layered on top of a base ``RAGPipeline``.

    Adds Finnish-specific text preprocessing, heuristic re-scoring of the
    retrieved passages, an exact-match fallback search against the
    pipeline's Milvus collection, and LangChain conversation memory. The
    memory contents feed a context-overlap boost into the scoring.
    """

    # Finnish abbreviations expanded before matching. Shared by
    # _normalize_finnish_text and _preprocess_finnish_text so both
    # preprocessing paths expand the same set.
    _FINNISH_ABBREVIATIONS = {
        'esim.': 'esimerkiksi',
        'ns.': 'niin sanottu',
        'jne.': 'ja niin edelleen',
        'ym.': 'ynnä muuta',
        'mm.': 'muun muassa',
    }

    def __init__(self, base_pipeline: "RAGPipeline"):
        """Wrap *base_pipeline* and set up memory, limits and tools.

        Args:
            base_pipeline: Pipeline whose ``query``, ``collection`` and
                ``generate_answer`` members are used by this agent.
        """
        self.pipeline = base_pipeline
        # return_messages=True: chat_history is a list of message objects
        # (with .type / .content), not a single string.
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.min_similarity_threshold = 0.65  # NOTE(review): not read anywhere in this class
        self.max_answer_length = 150  # character budget used by _summarize_answer
        self.setup_tools()

    def setup_tools(self):
        """Register the search/analysis callables under their tool names."""
        self.tools = {
            "semantic_search": self._semantic_search,
            "exact_match": self._exact_match_search,
            "context_analysis": self._analyze_context
        }

    @classmethod
    def _expand_abbreviations(cls, text: str) -> str:
        """Expand known Finnish abbreviations that start at a word boundary.

        Matching is case-sensitive. The ``(?<!\\w)`` guard stops e.g.
        ``'mm.'`` from matching inside ``'summ.'``.
        """
        for abbr, full in cls._FINNISH_ABBREVIATIONS.items():
            text = re.sub(r'(?<!\w)' + re.escape(abbr), full, text)
        return text

    @staticmethod
    def _source_label(result: Dict) -> str:
        """Return a printable source id from a result's 'source' or 'document_id'."""
        return result.get('source') or result.get('document_id') or 'Tuntematon'

    def _semantic_search(self, query: str) -> List[Dict]:
        """Run the base pipeline's search on a preprocessed query and re-score the hits.

        Returns:
            Dicts with ``text``, ``score``, ``source`` and ``metadata`` keys,
            sorted by descending score. Returns an empty list on failure.
            Individual results that fail to score are logged and skipped.
        """
        try:
            preprocessed_query = self._preprocess_finnish_text(query)
            base_results = self.pipeline.query(preprocessed_query, top_k=5)

            enhanced_results = []
            for result in base_results.get('sources', []):
                try:
                    score = self._calculate_relevance_score(preprocessed_query, result)
                    enhanced_results.append({
                        'text': result.get('text', ''),
                        'score': score,
                        # Pipelines report the id as either 'source' or 'document_id'.
                        'source': result.get('source') or result.get('document_id', 'unknown'),
                        'metadata': {
                            'person_name': result.get('person_name', ''),
                            'chunk_index': result.get('chunk_index', 0)
                        }
                    })
                except Exception as e:
                    logger.error(f"Error processing result: {str(e)}")
                    continue

            return sorted(enhanced_results, key=lambda x: x['score'], reverse=True)
        except Exception as e:
            logger.error(f"Error in semantic search: {str(e)}")
            return []

    def _normalize_finnish_text(self, text: str) -> str:
        """Normalize Finnish text for exact matching.

        Steps: lower-case; fold a few foreign diacritics (keeping ä and ö,
        which are distinct Finnish letters); collapse whitespace; expand
        common abbreviations. On error, logs and returns the text as
        processed so far.
        """
        try:
            text = text.lower()

            special_chars = {'å': 'a', 'é': 'e', 'è': 'e', 'ü': 'u'}
            for char, replacement in special_chars.items():
                text = text.replace(char, replacement)

            text = ' '.join(text.split())
            return self._expand_abbreviations(text)
        except Exception as e:
            logger.error(f"Error normalizing text: {str(e)}")
            return text

    def _exact_match_search(self, query: str) -> List[Dict]:
        """Substring search in the Milvus collection using the normalized query.

        Returns:
            Scored result dicts (``text``, ``source``, ``metadata``,
            ``score``), in the order the collection returned them. Returns
            an empty list for an empty query or on failure.
        """
        try:
            normalized_query = self._normalize_finnish_text(query)
            if not normalized_query:
                # An empty pattern would become 'like "%%"', i.e. every row.
                return []

            # Escape backslashes and double quotes so the query text cannot
            # terminate the quoted string literal inside the expression.
            escaped = normalized_query.replace('\\', '\\\\').replace('"', '\\"')

            collection = self.pipeline.collection
            results = collection.query(
                expr=f'text like "%{escaped}%"',
                output_fields=["text", "person_name", "document_id", "chunk_index"]
            )

            scored_results = []
            for r in results:
                try:
                    result_dict = {
                        'text': r.get('text', ''),
                        'source': r.get('document_id', 'unknown'),
                        'metadata': {
                            'person_name': r.get('person_name', ''),
                            'chunk_index': r.get('chunk_index', 0)
                        }
                    }
                    result_dict['score'] = self._calculate_relevance_score(normalized_query, result_dict)
                    scored_results.append(result_dict)
                except Exception as e:
                    logger.error(f"Error scoring result: {str(e)}")
                    continue

            return scored_results

        except Exception as e:
            logger.error(f"Error in exact match search: {str(e)}")
            return []

    def _analyze_context(self, passages: List[Dict]) -> Dict:
        """Collect entities, relationships, temporal references and key topics.

        Returns:
            Dict with ``entities`` (set), ``relationships`` (list),
            ``temporal_refs`` (list) and ``key_topics`` (set).
        """
        context_data = {
            'entities': set(),
            'relationships': [],
            'temporal_refs': [],
            'key_topics': set()
        }

        for passage in passages:
            text = passage.get('text', '')
            context_data['entities'].update(self._extract_finnish_entities(text))
            context_data['relationships'].extend(self._find_relationships(text))
            context_data['temporal_refs'].extend(self._extract_temporal_refs(text))
            context_data['key_topics'].update(self._identify_key_topics(text))

        return context_data

    def _preprocess_finnish_text(self, text: str) -> str:
        """Collapse whitespace and expand abbreviations (case is preserved)."""
        text = re.sub(r'\s+', ' ', text.strip())
        return self._expand_abbreviations(text)

    def _find_relationships(self, text: str) -> List[Dict]:
        """Find simple relationship patterns (friendship, living, helping, ...).

        Returns:
            Dicts with ``type``, ``entities`` (captured groups) and the
            matched ``text``.
        """
        relationships = []

        patterns = [
            (r'(\w+)\s+on\s+(\w+)\s+ystävä', 'ystävyys'),
            (r'(\w+)\s+asuu\s+(\w+)', 'asuminen'),
            (r'(\w+)\s+tekee\s+(\w+)', 'toiminta'),
            (r'(\w+)\s+pitää\s+(\w+)', 'pitäminen'),
            (r'(\w+)\s+kanssa', 'yhteys'),
            (r'(\w+)\s+tärkeä', 'tärkeys'),
            (r'(\w+)\s+auttaa\s+(\w+)', 'auttaminen')
        ]

        for pattern, rel_type in patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                relationships.append({
                    'type': rel_type,
                    'entities': match.groups(),
                    'text': match.group(0)
                })

        return relationships

    def _extract_temporal_refs(self, text: str) -> List[Dict]:
        """Find durations, ages, weekdays, months and other time words.

        Returns:
            Dicts with ``type``, matched ``text`` and ``position`` (span).
        """
        temporal_refs = []

        patterns = [
            (r'\d+\s*vuotta', 'duration'),
            (r'\d+\s*vuotias', 'age'),
            (r'(maanantai|tiistai|keskiviikko|torstai|perjantai|lauantai|sunnuntai)', 'weekday'),
            (r'(tammikuu|helmikuu|maaliskuu|huhtikuu|toukokuu|kesäkuu|heinäkuu|elokuu|syyskuu|lokakuu|marraskuu|joulukuu)', 'month'),
            (r'(aamu|päivä|ilta|yö)', 'time_of_day'),
            (r'(eilen|tänään|huomenna)', 'relative_day'),
            (r'(viikko|kuukausi|vuosi)', 'time_unit')
        ]

        for pattern, ref_type in patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                temporal_refs.append({
                    'type': ref_type,
                    'text': match.group(0),
                    'position': match.span()
                })

        return temporal_refs

    def _identify_key_topics(self, text: str) -> set:
        """Return ``"<topic_type>:<word>"`` strings for words following topic cues."""
        topics = set()

        key_patterns = [
            (r'tärkeä\w*\s+(\w+)', 'importance'),
            (r'harrastaa\w*\s+(\w+)', 'hobby'),
            (r'pitää\w*\s+(\w+)', 'preference'),
            (r'ongelma\w*\s+(\w+)', 'problem'),
            (r'tavoite\w*\s+(\w+)', 'goal')
        ]

        for pattern, topic_type in key_patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                if len(match.groups()) > 0:
                    topics.add(f"{topic_type}:{match.group(1)}")

        return topics

    def _extract_finnish_entities(self, text: str) -> set:
        """Return capitalized words as a rough stand-in for named entities."""
        entities = set()
        name_pattern = r'\b[A-ZÄÖÅ][a-zäöå]+\b'
        entities.update(re.findall(name_pattern, text))
        return entities

    def _calculate_relevance_score(self, query: str, document: Dict) -> float:
        """Re-score *document* for *query* with multiplicative heuristic boosts.

        Starts from ``document['score']`` (0.0 when absent). It then
        multiplies in boosts for:
          - query-token overlap,
          - Finnish surface patterns,
          - word overlap with the conversation history.
        The result is capped at 1.0.
        """
        base_score = document.get('score', 0.0)
        text = document.get('text', '')

        boosts = {
            'exact_match': 1.0,
            'semantic_similarity': 1.0,  # currently never adjusted
            'finnish_specific': 1.0,
            'context_match': 1.0
        }

        # Any query token occurring in the text (case-insensitive substring).
        text_lower = text.lower()
        if any(term.lower() in text_lower for term in query.split()):
            boosts['exact_match'] *= 1.3

        finnish_patterns = [
            # Words ending in a locative case suffix. The leading \w requires
            # a stem, so inflected words match rather than the bare suffix.
            (r'\w(?:ssa|ssä|sta|stä|lla|llä|lta|ltä)\b', 1.1),
            (r'\b(ja|tai|eli|sekä)\b', 1.05),  # Conjunctions
            (r'[.!?][^.!?]+\?', 1.2),  # Question-answer patterns
            (r'\b[A-ZÄÖÅ][a-zäöå]+\b', 1.15)  # Capitalized words (named entities)
        ]

        for pattern, boost in finnish_patterns:
            if re.search(pattern, text):
                boosts['finnish_specific'] *= boost

        if self.memory:
            history = self.memory.load_memory_variables({}).get('chat_history')
            if history:
                # return_messages=True yields message objects rather than
                # strings, so join their .content. A plain string history is
                # used as-is instead of being split into characters.
                if isinstance(history, str):
                    history_text = history
                else:
                    history_text = ' '.join(str(getattr(msg, 'content', msg)) for msg in history)
                context_terms = {word.lower() for word in history_text.split()}
                text_terms = {word.lower() for word in text.split()}
                overlap = len(context_terms & text_terms)
                if overlap > 0:
                    boosts['context_match'] *= (1 + (overlap * 0.1))

        final_score = base_score
        for boost in boosts.values():
            final_score *= boost

        return min(1.0, final_score)

    def _generate_response(self, question: str, results: List[Dict], chat_history: Optional[str] = None) -> Dict:
        """Ask the pipeline's LLM for a short, cited answer grounded in *results*.

        Args:
            question: The user's question.
            results: Scored result dicts, best first.
            chat_history: Accepted for interface compatibility. It is not
                currently included in the prompt.

        Returns:
            Dict with ``answer``, ``confidence`` (best result score) and up
            to three ``sources``. On an empty result list or an error, returns
            a Finnish fallback answer with zero confidence.
        """
        if not results:
            return {
                "answer": "En löytänyt vastausta kysymykseesi.",
                "confidence": 0.0,
                "sources": []
            }

        try:
            prompt = self._build_focused_prompt(question, results)
            response = self.pipeline.generate_answer(prompt)
            answer = self._process_response(response, results)
            answer = self._enforce_concise_answer(answer)

            return {
                "answer": answer,
                "confidence": max(r.get('score', 0.0) for r in results),
                "sources": results[:3]
            }

        except Exception as e:
            logger.error(f"Error in response generation: {str(e)}")
            return {
                "answer": "Virhe vastauksen muodostamisessa.",
                "confidence": 0.0,
                "sources": []
            }

    def _enforce_concise_answer(self, answer: str) -> str:
        """Shorten *answer* to its source-reference sentences plus the first quote.

        If no such sentences exist, only the first sentence is kept. The
        result is prefixed with "Dokumentin mukaan" when it does not mention
        "dokumentin". On error, falls back to a 150-character truncation.
        """
        try:
            sentences = re.split(r'(?<=[.!?])\s+', answer)

            essential_sentences = []
            for sentence in sentences:
                if any(x in sentence.lower() for x in ['dokumentin', 'dokumentista', 'lähteen']):
                    essential_sentences.append(sentence)
                elif '"' in sentence:  # Contains a quote
                    essential_sentences.append(sentence)
                    break  # Stop after first quote

            if not essential_sentences:
                return sentences[0] if sentences else answer

            concise_answer = ' '.join(essential_sentences)

            if 'dokumentin' not in concise_answer.lower():
                concise_answer = f"Dokumentin mukaan {concise_answer}"

            return concise_answer

        except Exception as e:
            logger.error(f"Error enforcing concise answer: {str(e)}")
            return answer[:150] + ("..." if len(answer) > 150 else "")

    def _build_focused_prompt(self, question: str, results: List[Dict]) -> str:
        """Build the Finnish answer-generation prompt from the top three results."""
        context = f"""Etsi tarkka ja ytimekäs vastaus seuraavaan kysymykseen käyttäen vain annettua kontekstia.
Vastaa mahdollisimman lyhyesti ja selkeästi, keskittyen vain olennaiseen tietoon.

Kysymys: {question}

Konteksti:
"""
        for i, result in enumerate(results[:3], 1):
            # Results from the search methods carry 'source', not 'document_id'.
            context += f"\nDokumentti {i}:\nLähde: {self._source_label(result)}\nTeksti: {result.get('text', '')}\n"

        context += """\nVastausohjeet:
1. Jos löydät suoran vastauksen:
   - Mainitse dokumentti, josta vastaus löytyy
   - Käytä lyhyitä, tarkkoja lainauksia
   - Keskity vain kysyttyyn asiaan
   - Vältä turhaa toistoa ja ylimääräisiä selityksiä
2. Jos et löydä vastausta:
   - Ilmoita selkeästi ja lyhyesti, ettei vastausta löydy
3. Jos löydät vain osittaisen vastauksen:
   - Mainitse lyhyesti löydetty tieto ja mitä puuttuu

Vastaus:"""

        return context

    def _summarize_answer(self, answer: str) -> str:
        """Trim *answer* to ``self.max_answer_length`` characters.

        Prefers keeping document-reference sentences plus the first quote
        that follows them. Otherwise truncates and appends "...".
        """
        if len(answer) <= self.max_answer_length:
            return answer

        try:
            sentences = re.split(r'(?<=[.!?])\s+', answer)

            summary_parts = []
            doc_ref_found = False

            for sentence in sentences:
                if 'Dokumentin' in sentence or 'dokumentin' in sentence:
                    summary_parts.append(sentence)
                    doc_ref_found = True
                    continue

                if doc_ref_found and '"' in sentence:
                    summary_parts.append(sentence)
                    break

            if not summary_parts:
                return answer[:self.max_answer_length] + "..."

            summary = ' '.join(summary_parts)

            if not any(summary.endswith(end) for end in '.!?'):
                summary += '.'

            return summary

        except Exception as e:
            logger.error(f"Error in summarization: {str(e)}")
            return answer[:self.max_answer_length] + "..."

    def _analyze_query_type(self, query: str) -> str:
        """Classify *query* by keyword into one of the instruction categories.

        Returns:
            The first category (in declaration order) whose keywords appear
            in the query, or ``'general'`` when none match.
        """
        query_lower = query.lower()

        characteristics = {
            'personal_info': any(word in query_lower for word in ['kuka', 'kenen', 'kenelle']),
            'age_related': any(word in query_lower for word in ['vanha', 'ikä', 'syntymä', 'vuosi']),
            'activity_related': any(word in query_lower for word in ['tekee', 'harrastaa', 'pitää', 'tykkää']),
            'ability_related': any(word in query_lower for word in ['pystyy', 'osaa', 'liikkuu', 'käyttää']),
            'preference_related': any(word in query_lower for word in ['pitää', 'tykkää', 'haluaa']),
            'relationship_related': any(word in query_lower for word in ['ystävä', 'tärkeä', 'läheinen'])
        }

        # With no hits, return 'general' so _get_query_instructions uses its
        # generic fallback instead of mislabelling the query as personal_info.
        matched = [name for name, hit in characteristics.items() if hit]
        return matched[0] if matched else 'general'

    def _get_query_instructions(self, query_type: str) -> str:
        """Return Finnish answering instructions for *query_type* (generic if unknown)."""
        instructions = {
            'personal_info': """
        Ohje: 
        1. Etsi henkilöön liittyvät suorat maininnat
        2. Käytä suoria lainauksia henkilöiden nimistä ja suhteista
        3. Mainitse dokumentin lähde""",

            'age_related': """
        Ohje:
        1. Etsi tarkat ikä- ja syntymävuositiedot
        2. Ilmoita sekä ikä että syntymävuosi jos molemmat löytyvät
        3. Mainitse dokumentin lähde""",

            'activity_related': """
        Ohje:
        1. Etsi kaikki mainitut aktiviteetit ja harrastukset
        2. Käytä suoria lainauksia aktiviteettien kuvauksista
        3. Mainitse dokumentin lähde""",

            'ability_related': """
        Ohje:
        1. Etsi kuvaukset henkilön kyvyistä ja toiminnasta
        2. Käytä suoria lainauksia toimintakyvyn kuvauksista
        3. Mainitse dokumentin lähde""",

            'preference_related': """
        Ohje:
        1. Etsi kaikki mainitut mieltymykset ja kiinnostukset
        2. Käytä suoria lainauksia mieltymysten kuvauksista
        3. Mainitse dokumentin lähde""",

            'relationship_related': """
        Ohje:
        1. Etsi kuvaukset ihmissuhteista ja tärkeistä henkilöistä
        2. Käytä suoria lainauksia suhteiden kuvauksista
        3. Mainitse dokumentin lähde"""
        }

        return instructions.get(query_type, """
        Ohje:
        1. Etsi suora vastaus kysymykseen dokumenteista
        2. Käytä suoria lainauksia
        3. Mainitse dokumentin lähde""")

    def _process_response(self, response: str, results: List[Dict]) -> str:
        """Clean the raw LLM output and make sure it cites a source.

        Steps:
          1. Strip filler phrases.
          2. If the answer cites no document but contains a full result text,
             prefix it with that result's source.
          3. If it has no quotation yet, replace it with a sourced quote of
             the first result sharing more than 20 consecutive characters.
        """
        from difflib import SequenceMatcher  # local: not imported at file level

        if not response:
            return "En löytänyt vastausta annetusta kontekstista."

        answer = response.strip()

        redundant_phrases = [
            "On tärkeää huomata, että",
            "On syytä mainita, että",
            "Voidaan todeta, että",
            "Tämän perusteella",
            "Yhteenvetona voidaan sanoa, että",
            "Luottamus:",
            "Selitys:",
            "Huomaa, että",
            "On tärkeää muistaa"
        ]
        for phrase in redundant_phrases:
            answer = answer.replace(phrase, "")

        # Pair each result with its text once; empty texts never count as
        # "contained" (the empty string is a substring of everything).
        texts = [(result, result.get('text', '')) for result in results]

        # Add a citation only if the answer does not already mention a
        # document ("Dokumentti"/"Dokumentin"/...), to avoid double prefixes.
        if 'dokument' not in answer.lower():
            for result, text in texts:
                if text and text in answer:
                    answer = f"Dokumentin {self._source_label(result)} mukaan {answer}"
                    break

        if '"' not in answer and any(text and text in answer for _, text in texts):
            for result, text in texts:
                overlap = SequenceMatcher(None, text, answer).find_longest_match(0, len(text), 0, len(answer))
                if overlap.size > 20:
                    quoted_text = text[overlap.a:overlap.a + overlap.size]
                    answer = f"Dokumentin {self._source_label(result)} mukaan \"{quoted_text}\""
                    break

        return answer.strip()

    def process_query(self, question: str) -> Dict:
        """Answer *question* and record the exchange in conversation memory.

        Runs semantic search first and falls back to exact-match search only
        when that returns nothing. Exceptions are logged and turned into a
        Finnish error answer with zero confidence.

        Returns:
            Dict with ``answer``, ``confidence`` and ``sources`` keys.
        """
        try:
            chat_history = []
            if self.memory:
                memory_vars = self.memory.load_memory_variables({})
                if "chat_history" in memory_vars:
                    chat_history = [
                        f"{msg.type}: {msg.content}"
                        for msg in memory_vars["chat_history"]
                    ]

            all_results = self._semantic_search(question)
            if not all_results:
                # Semantic search found nothing: fall back to substring matching.
                all_results = sorted(
                    self._exact_match_search(question),
                    key=lambda x: x.get('score', 0),
                    reverse=True
                )

            response = self._generate_response(
                question,
                all_results,
                "\n".join(chat_history)
            )

            if self.memory and response and response.get('answer'):
                self.memory.save_context(
                    {"input": question},
                    {"output": response['answer']}
                )

            return response

        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return {
                "answer": "Virhe kysymyksen käsittelyssä.",
                "confidence": 0.0,
                "sources": []
            }
            
def run_neo4j_enhanced_tests(
    folder_path: str = '/scratch/project_2011638/notebooks/data/',
    neo4j_uri: str = "bolt://87.92.59.201:7687",
    test_questions: Optional[List[str]] = None,
):
    """Smoke-test the Neo4j-enhanced pipeline end to end and print the answers.

    Builds a ``RAGPipeline`` and wraps it in ``Neo4jEnhancedRAGPipeline``.
    It then ingests the documents in *folder_path* and sends each question
    through ``query``. A failure on one question is printed and the loop
    continues; a failure during setup is printed and re-raised.

    Args:
        folder_path: Directory containing the source documents.
        neo4j_uri: Bolt URI of the Neo4j server.
        test_questions: Questions to ask; defaults to the built-in Finnish set.

    Returns:
        List of ``{"question": ..., "response": ...}`` dicts, one per
        question whose ``query`` call succeeded.
    """
    if test_questions is None:
        test_questions = [
            "Kuinka kauan Annikki oli naimisissa miehensä kanssa?",  # Clear answer: 55 years from Annikki's document
            #"Mitä digitaalisia laitteita Elisa käyttää?",  # Clear answer from Elisa's document about her devices
            #"Mikä on Jarmon tärkein toive arjessa?",  # Clear answer about wanting to maintain possibility for solitude
            #"Millä alalla Kosti työskentelee eläkkeellä ollessaan?",  # Clear answer: kauppa-ala from Kosti's document
            #"Montako lastenlasta Annikilla on?"
        ]

    try:
        print("\nStarting Neo4j Enhanced RAG test execution...")

        print("\n1. Initializing base pipeline...")
        base_pipeline = RAGPipeline()
        print("   ✓ Base pipeline initialized")

        print("\n2. Creating Neo4j Enhanced Pipeline...")
        neo4j_rag = Neo4jEnhancedRAGPipeline(
            base_rag_pipeline=base_pipeline,
            neo4j_uri=neo4j_uri
        )
        print("   ✓ Neo4j Enhanced pipeline created")

        print("\n3. Processing documents...")
        neo4j_rag.process_documents(folder_path)
        print("   ✓ Documents processed with Neo4j enhancement")

        print("\n4. Running test questions with Neo4j enhancement...")
        results = []
        for i, question in enumerate(test_questions, 1):
            try:
                print(f"\nKysymys {i}/{len(test_questions)}:")
                print("-"*80)
                print(f"{question}")
                print("-"*80)

                response = neo4j_rag.query(question)
                results.append({
                    "question": question,
                    "response": response
                })

                print("\nNeo4j Enhanced Vastaus:")
                print(response.get('answer', 'Ei vastausta'))
                print("\nLähteet ja konteksti:")
                print("- Dokumenttilähteet:")
                for src in response.get('sources', [])[:2]:
                    # Tolerate either id key so one odd source cannot abort the printout.
                    src_id = src.get('document_id') or src.get('source', 'Tuntematon')
                    print(f"  • {src_id}: {src.get('text', '')[:100]}...")
                print("- Graafikonteksti:")
                person_context = response.get('person_context')
                if person_context:
                    print(f"  • Henkilö: {person_context.get('name', 'Tuntematon')}")
                    print(f"  • Suhteet: {', '.join(person_context.get('relationship_types', []))}")
                print("-"*80)

            except Exception as e:
                print(f"Virhe kysymyksen käsittelyssä: {str(e)}")

        print("\n5. Neo4j Enhanced testaus valmis")
        return results

    except Exception as e:
        print(f"\nVirhe Neo4j testauksessa: {str(e)}")
        raise

def run_improved_rag_tests(
    folder_path: str = '/scratch/project_2011638/notebooks/data/',
    test_questions: Optional[List[str]] = None,
):
    """Smoke-test ``FinnishRAGAgent`` end to end and print each answer.

    Builds a ``RAGPipeline``, wraps it in a ``FinnishRAGAgent``, ingests the
    documents in *folder_path* and sends each question through
    ``process_query``. A failure on one question is printed and the loop
    continues; a failure during setup is printed and re-raised.

    Args:
        folder_path: Directory containing the source documents.
        test_questions: Questions to ask; defaults to the built-in Finnish set.

    Returns:
        List of ``{"question": ..., "response": ...}`` dicts.
    """
    if test_questions is None:
        test_questions = [
            "Kuka on tärkeitä henkilöitä Eilalle?",
            "Kuinka vanha on Sulo ja mikä on hänen syntymävuotensa?",
            "Mitä Eila harrastaa?",
            "Miten Sulo liikkuu?",
            "Mistä asioista Eila pitää?"
        ]

    try:
        print("\nStarting RAG test execution...")

        print("\n1. Initializing base pipeline...")
        base_pipeline = RAGPipeline()
        print("   ✓ Base pipeline initialized")

        print("\n2. Creating Finnish RAG Agent...")
        agent = FinnishRAGAgent(base_pipeline)
        print("   ✓ Agent created")

        print("\n3. Processing documents...")
        base_pipeline.process_documents(folder_path)
        print("   ✓ Documents processed")

        print("\n4. Running test questions...")
        results = []
        for i, question in enumerate(test_questions, 1):
            try:
                print(f"\nKysymys {i}/{len(test_questions)}:")
                print("-"*80)
                print(f"{question}")
                print("-"*80)

                response = agent.process_query(question)
                results.append({
                    "question": question,
                    "response": response
                })

                print("\nVastaus:")
                print(response.get('answer', 'Ei vastausta'))
                print("\nLähde:")
                for src in response.get('sources', [])[:1]:
                    # Tolerate either id key so the printout cannot raise KeyError.
                    src_id = src.get('source') or src.get('document_id', 'Tuntematon')
                    print(f"- {src_id}: {src.get('text', '')[:150]}...")
                print("-"*80)

            except Exception as e:
                print(f"Virhe kysymyksen käsittelyssä: {str(e)}")

        print("\n5. Testaus valmis")
        return results

    except Exception as e:
        print(f"\nVirhe testauksessa: {str(e)}")
        raise
        
class SharedPipelineManager:
    """Lazily build pipelines that share a single base ``RAGPipeline`` instance.

    Both the standard ``FinnishRAGAgent`` and the
    ``Neo4jEnhancedRAGPipeline`` are constructed on first access and wrap
    the same ``base_pipeline`` object.
    """

    def __init__(self, neo4j_uri: str = "bolt://87.92.59.201:7687"):
        """Create an empty manager.

        Args:
            neo4j_uri: Bolt URI passed to ``Neo4jEnhancedRAGPipeline`` when
                it is first built.
        """
        self.neo4j_uri = neo4j_uri
        self._base_pipeline = None
        self._neo4j_pipeline = None
        self._finnish_agent = None
        self.cleanup_gpu_memory()

    @property
    def base_pipeline(self):
        """The shared ``RAGPipeline``, created on first access."""
        if self._base_pipeline is None:
            self._base_pipeline = RAGPipeline()
        return self._base_pipeline

    @property
    def neo4j_pipeline(self):
        """The Neo4j-enhanced pipeline over the shared base, created on first access."""
        if self._neo4j_pipeline is None:
            self._neo4j_pipeline = Neo4jEnhancedRAGPipeline(
                base_rag_pipeline=self.base_pipeline,
                neo4j_uri=self.neo4j_uri
            )
        return self._neo4j_pipeline

    @property
    def finnish_agent(self):
        """The ``FinnishRAGAgent`` over the shared base, created on first access."""
        if self._finnish_agent is None:
            self._finnish_agent = FinnishRAGAgent(self.base_pipeline)
        return self._finnish_agent

    def cleanup_gpu_memory(self):
        """Release cached, unused CUDA memory (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def process_documents(self, folder_path: str):
        """Ingest *folder_path* once, through the Neo4j pipeline only.

        NOTE(review): assumes Neo4jEnhancedRAGPipeline.process_documents also
        populates the shared base pipeline's store — confirm.
        """
        print("Processing documents...")
        self.cleanup_gpu_memory()
        self.neo4j_pipeline.process_documents(folder_path)
        self.cleanup_gpu_memory()
        print("Documents processed successfully")

    def run_comparison_tests(self, test_questions: List[str]):
        """Ask every question through both pipelines.

        Returns:
            ``{"standard": [...], "neo4j": [...]}``, each a list of
            ``{"question": ..., "response": ...}`` dicts in question order.
        """
        results = {
            "standard": [],
            "neo4j": []
        }

        for question in test_questions:
            self.cleanup_gpu_memory()
            standard_response = self.finnish_agent.process_query(question)
            results["standard"].append({
                "question": question,
                "response": standard_response
            })

            # Clear the cache before the second model call too, not only the first.
            self.cleanup_gpu_memory()
            neo4j_response = self.neo4j_pipeline.query(question)
            results["neo4j"].append({
                "question": question,
                "response": neo4j_response
            })

        return results
class SharedEmbeddingManager:
    """Batch embedding generation with an in-memory cache keyed by text hash."""

    def __init__(self, embedding_model=None, batch_size: int = 8):
        """Create a manager with an empty cache.

        Args:
            embedding_model: Object exposing ``generate(list_of_texts)`` that
                returns one vector per input text. It may also be assigned
                later, but must be set before uncached texts are requested.
            batch_size: Maximum number of texts per ``generate`` call.
        """
        self.cache = {}
        self.embedding_model = embedding_model
        self.batch_size = batch_size

    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Return embeddings for *texts* in order, computing only uncached ones.

        Duplicate texts within one call are embedded once.

        Raises:
            RuntimeError: If uncached texts exist but no model is set.
            ValueError: If the model returns a different number of vectors
                than it was given texts.
        """
        cache_keys = [hashlib.md5(text.encode()).hexdigest() for text in texts]

        # Unique uncached texts in first-seen order (dicts preserve order).
        pending = {}
        for key, text in zip(cache_keys, texts):
            if key not in self.cache and key not in pending:
                pending[key] = text

        if pending:
            if self.embedding_model is None:
                raise RuntimeError("SharedEmbeddingManager.embedding_model is not set")

            pending_keys = list(pending)
            pending_texts = list(pending.values())
            generated = []
            for start in range(0, len(pending_texts), self.batch_size):
                batch = pending_texts[start:start + self.batch_size]
                generated.extend(self.embedding_model.generate(batch))

            # zip() would silently drop texts if the model returned too few vectors.
            if len(generated) != len(pending_texts):
                raise ValueError(
                    f"Embedding model returned {len(generated)} vectors for {len(pending_texts)} texts"
                )
            self.cache.update(zip(pending_keys, generated))

        return np.array([self.cache[key] for key in cache_keys])

def run_comparison_tests(
    folder_path: str = '/scratch/project_2011638/notebooks/data/',
    test_questions: Optional[List[str]] = None,
):
    """Compare standard-agent and Neo4j-enhanced answers on the same questions.

    Ingests *folder_path* once through a ``SharedPipelineManager``, asks every
    question through both pipelines and prints the answers side by side.
    Any error is printed and re-raised.

    Args:
        folder_path: Directory containing the source documents.
        test_questions: Questions to ask; defaults to the built-in Finnish set.

    Returns:
        The ``{"standard": [...], "neo4j": [...]}`` dict produced by
        ``SharedPipelineManager.run_comparison_tests``.
    """
    if test_questions is None:
        test_questions = [
            "Kuinka kauan Annikki oli naimisissa miehensä kanssa?",  # Clear answer: 55 years from Annikki's document
            #"Mitä digitaalisia laitteita Elisa käyttää?",  # Clear answer from Elisa's document about her devices
            #"Mikä on Jarmon tärkein toive arjessa?",  # Clear answer about wanting to maintain possibility for solitude
            #"Millä alalla Kosti työskentelee eläkkeellä ollessaan?",  # Clear answer: kauppa-ala from Kosti's document
            #"Montako lastenlasta Annikilla on?"
        ]

    try:
        print("\nStarting comparison test execution...")

        pipeline_manager = SharedPipelineManager()
        pipeline_manager.process_documents(folder_path)

        results = pipeline_manager.run_comparison_tests(test_questions)

        print("\n5. Results Comparison:")
        print("-" * 80)
        for std, neo in zip(results["standard"], results["neo4j"]):
            print(f"\nQuestion: {std['question']}")
            print("\nStandard RAG Answer:")
            # .get(): a response without 'answer' must not abort the comparison printout.
            print((std["response"] or {}).get("answer", "Ei vastausta"))
            print("\nNeo4j Enhanced Answer:")
            print((neo["response"] or {}).get("answer", "Ei vastausta"))
            print("-" * 80)

        return results

    except Exception as e:
        print(f"\nError in comparison testing: {str(e)}")
        raise

if __name__ == "__main__":
    # Script/notebook entry point: run the standard-vs-Neo4j comparison.
    # Binding `results` at module level keeps the returned dict available
    # for inspection after the run.
    results = run_comparison_tests()