# In [None]:
# Import required libraries
import os
import re
import nltk
import torch
import numpy as np
import logging
from typing import List, Dict, Any
from docx import Document
from tqdm.notebook import tqdm
from pymilvus import connections, Collection, utility, CollectionSchema, FieldSchema, DataType
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Setup logging
# Configure the root logger to emit INFO-level (and higher) records, then
# create a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def print_status(section_name: str, status: bool, message: str = ""):
    """Report the outcome of a pipeline step on standard output.

    The report starts with a blank line, followed by a success or failure
    badge and the section name.  When ``message`` is non-empty it is printed
    on an additional, indented detail line.

    Args:
        section_name: Label of the step being reported.
        status: Truthy for success, falsy for failure.
        message: Optional detail text.
    """
    if status:
        badge = "✅ SUCCESS"
    else:
        badge = "❌ FAILED"

    report_lines = [f"\n{badge} | {section_name}"]
    if message:
        report_lines.append(f"  └─ {message}")
    print("\n".join(report_lines))

def check_cuda():
    """Report which compute device torch will use.

    Prints the name of GPU 0 when CUDA is available and "Using CPU"
    otherwise; both outcomes are reported as a successful check.

    Returns:
        True when the probe completed (with or without a GPU), False when
        querying torch raised an exception.
    """
    try:
        if torch.cuda.is_available():
            detail = f"Using GPU: {torch.cuda.get_device_name(0)}"
        else:
            detail = "Using CPU"
        print_status("CUDA Check", True, detail)
    except Exception as err:
        print_status("CUDA Check", False, str(err))
        return False
    return True

def ensure_stopwords_downloaded(language='finnish'):
    """Make sure the NLTK stopwords corpus is available for ``language``.

    ``nltk.download`` reports most failures (for example no network access)
    by returning ``False`` rather than raising, so its return value is
    checked instead of assuming the download succeeded.

    Args:
        language: Name of the stopword list that must be present in the
            corpus.

    Returns:
        True when the corpus is available and contains a list for
        ``language``; False otherwise.  Never raises: any exception is
        reported through ``print_status`` and turned into False.
    """
    try:
        if not nltk.download('stopwords', quiet=True):
            print_status("NLTK Setup", False, "Could not download the stopwords corpus")
            return False

        # The corpus ships one word list per language; confirm that the
        # requested one is actually part of it before reporting success.
        from nltk.corpus import stopwords
        if language not in stopwords.fileids():
            print_status("NLTK Setup", False, f"No stopword list available for '{language}'")
            return False

        print_status("NLTK Setup", True, f"Downloaded {language} stopwords")
        return True
    except Exception as e:
        print_status("NLTK Setup", False, str(e))
        return False

# Milvus Connection Settings
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
MILVUS_ALIAS = "default"
# Dimension of the vectors stored in Milvus.  It has to equal the hidden size
# of the model used by EmbeddingGenerator: the default model
# (TurkuNLP/bert-base-finnish-cased-v1) is a BERT-base model whose mean-pooled
# output has 768 components, so a 384-dimensional field rejects every insert.
# NOTE(review): a collection that was already created with dim=384 must be
# dropped and recreated - confirm before running against existing data.
EMBEDDING_DIM = 768

class DocumentProcessor:
    """Read .docx files and split their text into metadata-tagged chunks."""

    # Expected title layout: "<name> <age>v <document id>".
    # [^\W\d_] matches any Unicode letter, so Finnish names containing
    # ä/ö/å are accepted as well as plain ASCII ones.
    _FILENAME_PATTERN = re.compile(r'([^\W\d_]+)\s+(\d{1,3})v\s+([A-Za-z0-9\-]+)')

    def __init__(self, chunk_size=400, chunk_overlap=50):
        """Create the text splitter used by ``process_document``.

        Args:
            chunk_size: Maximum chunk length, measured with ``len``.
            chunk_overlap: Number of characters shared by adjacent chunks.

        Raises:
            Exception: Whatever the splitter constructor raises; the failure
                is reported through ``print_status`` and re-raised.
        """
        try:
            self.text_splitter = RecursiveCharacterTextSplitter(
                separators=["\n\n", "\n", ". ", ", ", " "],
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                keep_separator=True,
                add_start_index=True
            )
            print_status("Document Processor", True, "Initialized successfully")
        except Exception as e:
            print_status("Document Processor", False, str(e))
            raise

    def extract_metadata_from_filename(self, filename: str) -> tuple:
        """Extract ``(name, age, document_id)`` from a file name.

        The extension is dropped and the remaining title must start with
        "<letters> <1-3 digits>v <id>".

        Returns:
            A ``(str, int, str)`` tuple on a match, otherwise
            ``(None, None, None)``.
        """
        title = os.path.splitext(filename)[0]
        match = self._FILENAME_PATTERN.match(title)
        if match:
            return match.group(1), int(match.group(2)), match.group(3)
        return None, None, None

    def preprocess_text(self, text: str) -> str:
        """Clean and normalize Finnish text while keeping paragraph breaks.

        Characters other than word characters, whitespace and basic
        punctuation are removed first, so that the whitespace left behind by
        a removed symbol is collapsed as well.  Runs of spaces/tabs become a
        single space, but line breaks are preserved (at most one blank line
        in a row) because the splitter's first separators are "\\n\\n" and
        "\\n".
        """
        text = re.sub(r'[^\w\s\.\,\?\!\-\:\;äöåÄÖÅ]', '', text)
        text = re.sub(r'[^\S\n]+', ' ', text)   # horizontal whitespace only
        text = re.sub(r' ?\n ?', '\n', text)    # no stray spaces around breaks
        text = re.sub(r'\n{3,}', '\n\n', text)  # at most one blank line
        return text.strip()

    def process_document(self, file_path: str) -> List[Dict[str, Any]]:
        """Process a single document and return chunks with metadata.

        Args:
            file_path: Path of the .docx file to read.

        Returns:
            One dict per chunk with the keys ``"text"`` and ``"metadata"``;
            the metadata holds the source file name, the values parsed from
            the file name (``None`` when it does not match the expected
            layout) and the chunk's position within the document.

        Raises:
            Exception: Any error raised while reading or splitting; it is
                reported through ``print_status`` and re-raised.
        """
        try:
            # Read document (only paragraph text is used)
            doc = Document(file_path)
            text = "\n".join(para.text for para in doc.paragraphs)

            # Extract metadata
            filename = os.path.basename(file_path)
            name, age, doc_id = self.extract_metadata_from_filename(filename)

            # Preprocess and split text
            clean_text = self.preprocess_text(text)
            chunks = self.text_splitter.split_text(clean_text)

            # Create chunks with metadata
            processed_chunks = [
                {
                    "text": chunk,
                    "metadata": {
                        "source": filename,
                        "person_name": name,
                        "person_age": age,
                        "document_id": doc_id,
                        "chunk_index": i
                    }
                }
                for i, chunk in enumerate(chunks)
            ]

            print_status("Document Processing", True, f"Processed {filename}")
            return processed_chunks
        except Exception as e:
            print_status("Document Processing", False, f"Error processing {file_path}: {str(e)}")
            raise

class MilvusManager:
    """Own a named pymilvus connection and the document-embedding collection."""

    def __init__(self, host: str, port: str, alias: str = "default"):
        """Store the connection settings and connect immediately.

        Args:
            host: Milvus server host.
            port: Milvus server port.
            alias: Name under which the connection is registered; every
                later collection operation of this manager uses it.

        Raises:
            Exception: Propagated from ``connect`` when the connection fails.
        """
        self.host = host
        self.port = port
        self.alias = alias
        self.connected = False  # set to True once connect() succeeds
        self.connect()

    def connect(self):
        """Establish connection to Milvus.

        Raises:
            Exception: Whatever ``connections.connect`` raises; the failure
                is reported through ``print_status`` and re-raised.
        """
        try:
            connections.connect(
                alias=self.alias,
                host=self.host,
                port=self.port
            )
            self.connected = True
            print_status("Milvus Connection", True, f"Connected to {self.host}:{self.port}")
        except Exception as e:
            print_status("Milvus Connection", False, str(e))
            raise

    def create_collection(self, collection_name: str = "document_embeddings"):
        """Return the named collection, creating it with an index if missing.

        An existing collection is returned as-is (its schema is not checked
        against the one defined here).  A new collection gets an auto-id
        primary key, the chunk text, its embedding vector and the metadata
        parsed from the file name, plus an IVF_FLAT inner-product index on
        the embedding field.

        Args:
            collection_name: Name of the collection to open or create.

        Returns:
            The pymilvus ``Collection`` bound to this manager's alias.

        Raises:
            Exception: Any pymilvus error; reported through ``print_status``
                and re-raised.
        """
        try:
            # Pass the alias explicitly: without ``using`` pymilvus falls back
            # to the "default" connection regardless of what connect() opened.
            if utility.has_collection(collection_name, using=self.alias):
                collection = Collection(name=collection_name, using=self.alias)
                print_status("Milvus Collection", True, f"Using existing collection: {collection_name}")
                return collection

            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
                FieldSchema(name="person_name", dtype=DataType.VARCHAR, max_length=100),
                FieldSchema(name="person_age", dtype=DataType.INT64),
                FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=100),
                FieldSchema(name="chunk_index", dtype=DataType.INT64)
            ]

            schema = CollectionSchema(fields=fields, description="Document embeddings collection")
            collection = Collection(name=collection_name, schema=schema, using=self.alias)

            # Create IVF_FLAT index
            index_params = {
                "metric_type": "IP",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 1024}
            }
            collection.create_index(field_name="embedding", index_params=index_params)
            print_status("Milvus Collection", True, f"Created new collection: {collection_name}")
            return collection
        except Exception as e:
            print_status("Milvus Collection", False, str(e))
            raise

class EmbeddingGenerator:
    """Turn texts into L2-normalized, mean-pooled transformer embeddings."""

    def __init__(self, model_name: str = "TurkuNLP/bert-base-finnish-cased-v1"):
        """Load the tokenizer and model, placing the model on GPU if available.

        Args:
            model_name: Hugging Face model identifier to load.

        Raises:
            Exception: Any loading error; reported through ``print_status``
                and re-raised.
        """
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name).to(self.device)
            print_status("Embedding Model", True, f"Loaded {model_name}")
        except Exception as e:
            print_status("Embedding Model", False, str(e))
            raise

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions kept by the mask.

        Args:
            model_output: Sequence whose first element holds the token
                embeddings (assumed shape ``(batch, tokens, hidden)`` -
                confirm against the model in use).
            attention_mask: Mask with one entry per token; positions with 0
                do not contribute to the average.

        Returns:
            One pooled vector per batch row.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # clamp() guards against division by zero for a fully masked row.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def generate(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Generate embeddings for a list of texts.

        Args:
            texts: Texts to embed; inputs longer than 512 tokens are
                truncated by the tokenizer.
            batch_size: Number of texts encoded per forward pass.

        Returns:
            A 2-D array with one unit-length row per input text.  For an
            empty ``texts`` the result has shape ``(0, hidden_size)``.

        Raises:
            Exception: Any tokenizer/model error; reported through
                ``print_status`` and re-raised.
        """
        if not texts:
            # np.concatenate raises on an empty list; return a well-formed
            # empty matrix so callers can keep treating the result as 2-D.
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        try:
            all_embeddings = []

            # Process in batches
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]

                # Tokenize texts
                encoded_input = self.tokenizer(
                    batch_texts,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                # Compute token embeddings without tracking gradients
                with torch.no_grad():
                    model_output = self.model(**encoded_input)

                # Perform pooling
                sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])

                # Normalize to unit length so that the inner product used by
                # the Milvus index equals cosine similarity.
                sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

                all_embeddings.append(sentence_embeddings.cpu().numpy())

            result = np.concatenate(all_embeddings)
            print_status("Embedding Generation", True, f"Generated {len(texts)} embeddings")
            return result
        except Exception as e:
            print_status("Embedding Generation", False, str(e))
            raise

class RAGPipeline:
    """Retrieval-augmented question answering over .docx documents.

    Combines a quantized causal LLM, a document processor, an embedding
    generator and a Milvus collection.
    """

    def __init__(self, model_id: str = "Finnish-NLP/llama-7b-finnish-instruct-v0.2"):
        """Load the LLM and initialize every other pipeline component.

        Args:
            model_id: Hugging Face identifier of the causal language model.

        Raises:
            Exception: Any component failure; reported through
                ``print_status`` and re-raised.
        """
        try:
            self.setup_llm(model_id)
            print_status("LLM Setup", True, f"Loaded {model_id}")

            self.doc_processor = DocumentProcessor()
            self.embedding_generator = EmbeddingGenerator()
            self.milvus_manager = MilvusManager(
                host=MILVUS_HOST,
                port=MILVUS_PORT,
                alias=MILVUS_ALIAS
            )
            self.collection = self.milvus_manager.create_collection()
            print_status("RAG Pipeline", True, "All components initialized")
        except Exception as e:
            print_status("RAG Pipeline", False, str(e))
            raise

    def setup_llm(self, model_id: str):
        """Initialize the LLM with optimized settings.

        Loads the model with 4-bit NF4 quantization (double quantization,
        float16 compute) and wraps it in a sampling text-generation pipeline
        stored on ``self.pipeline``.

        Args:
            model_id: Hugging Face identifier of the causal language model.

        Raises:
            Exception: Any loading error; reported through ``print_status``
                and re-raised.
        """
        try:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            # NOTE(review): max_memory caps device 0 at 6GiB and anything
            # that does not fit is presumably offloaded to "offload" -
            # confirm these limits match the target hardware.
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                device_map="auto",
                max_memory={0: "6GiB"},
                offload_folder="offload"
            )

            tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.pipeline = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.3,
                top_p=0.95,
                repetition_penalty=1.15
            )
            print_status("LLM Pipeline", True, "Pipeline configured successfully")
        except Exception as e:
            print_status("LLM Pipeline", False, str(e))
            raise

    def process_documents(self, folder_path: str):
        """Process all documents in the specified folder.

        Every .docx file is chunked, embedded and inserted into the Milvus
        collection.  Nothing is inserted when no chunks are produced.

        Args:
            folder_path: Directory that is scanned (non-recursively).

        Raises:
            Exception: Any processing, embedding or insert error; reported
                through ``print_status`` and re-raised.
        """
        try:
            # Word keeps "~$<name>.docx" lock files next to open documents;
            # they end in .docx but cannot be parsed, so leave them out.
            file_names = sorted(
                f for f in os.listdir(folder_path)
                if f.lower().endswith('.docx') and not f.startswith('~$')
            )
            all_chunks = []

            for file in tqdm(file_names, desc="Processing documents"):
                file_path = os.path.join(folder_path, file)
                all_chunks.extend(self.doc_processor.process_document(file_path))

            if not all_chunks:
                print_status("Document Processing", True, "No document chunks found; nothing inserted")
                return

            # Generate embeddings
            texts = [chunk["text"] for chunk in all_chunks]
            embeddings = self.embedding_generator.generate(texts)

            # Insert into Milvus.  pymilvus expects column-based data as a
            # list of columns in schema order (the auto-id primary key is
            # omitted); a dict is rejected or read as one single row.
            # File names that do not follow the naming scheme yield None
            # metadata, which the non-nullable fields cannot store, so
            # placeholders are used instead.
            metadata = [chunk["metadata"] for chunk in all_chunks]
            columns = [
                texts,
                embeddings.tolist(),
                [m["person_name"] or "" for m in metadata],
                [m["person_age"] if m["person_age"] is not None else -1 for m in metadata],
                [m["document_id"] or "" for m in metadata],
                [m["chunk_index"] for m in metadata],
            ]

            self.collection.insert(columns)
            self.collection.flush()
            print_status("Document Processing", True, f"Inserted {len(texts)} chunks into Milvus")
        except Exception as e:
            print_status("Document Processing", False, str(e))
            raise

    def query(self, question: str, top_k: int = 3):
        """Query the system with a question.

        Args:
            question: Question to answer.
            top_k: Number of chunks retrieved from Milvus as context.

        Returns:
            A dict with the generated ``"answer"`` and the ``"sources"``
            (text, person name and document id of each retrieved chunk).

        Raises:
            Exception: Any embedding, search or generation error; reported
                through ``print_status`` and re-raised.
        """
        try:
            # Generate question embedding
            question_embedding = self.embedding_generator.generate([question])[0]

            # A collection has to be loaded before it can be searched;
            # calling load() again on a loaded collection is harmless.
            self.collection.load()

            # Search in Milvus
            search_params = {"metric_type": "IP", "params": {"nprobe": 10}}
            results = self.collection.search(
                data=[question_embedding.tolist()],
                anns_field="embedding",
                param=search_params,
                limit=top_k,
                output_fields=["text", "person_name", "document_id"]
            )
            hits = results[0]

            # Format context
            context = "\n".join([
                f"Dokumentti {i+1}:\n{hit.entity.get('text')}\n"
                for i, hit in enumerate(hits)
            ])

            # Generate answer
            prompt = f"""Käytä seuraavaa kontekstia vastataksesi kysymykseen.
            Vastaa vain kysyttyyn kysymykseen ja käytä vain annettua kontekstia.
            Jos et löydä vastausta kontekstista, kerro se rehellisesti.

            Konteksti:
            {context}

            Kysymys: {question}

            Vastaus:"""

            # Ask the pipeline for the continuation only.  Splitting the full
            # text on "Vastaus:" picks the wrong segment whenever the model
            # itself writes that marker again in its output.
            response = self.pipeline(prompt, return_full_text=False)[0]["generated_text"]
            answer = response.strip()

            print_status("Query", True, "Generated response successfully")
            return {
                "answer": answer,
                "sources": [
                    {
                        "text": hit.entity.get('text'),
                        "person_name": hit.entity.get('person_name'),
                        "document_id": hit.entity.get('document_id')
                    }
                    for hit in hits
                ]
            }
        except Exception as e:
            print_status("Query", False, str(e))
            raise

def main(folder_path: str = '/home/jovyan/work/notebooks/data/'):
    """Run the demo: set up the pipeline, index documents, answer questions.

    Args:
        folder_path: Directory containing the .docx files to index.  The
            default is the notebook's data folder; pass another path to
            index a different location.

    Raises:
        Exception: Any setup or indexing failure is reported and re-raised.
            A failing query is reported and counted but does not stop the
            remaining queries.
    """
    try:
        # Check CUDA
        check_cuda()

        # Download stopwords
        ensure_stopwords_downloaded()

        # Initialize the RAG pipeline
        rag = RAGPipeline()

        # Process documents
        print_status("Document Path", True, f"Using folder: {folder_path}")
        rag.process_documents(folder_path)

        # Example queries
        questions = [
            "Onko Marjatta Eilan ystävä?",
            "Miten Sulo kokee sosiaalisen kanssakäymisen merkityksen?",
            "Montako sisarusta Sulolla on?"
        ]

        failed_queries = 0
        for i, question in enumerate(questions, 1):
            print(f"\nProcessing Query {i}/{len(questions)}")
            try:
                result = rag.query(question)
                print_status(f"Query {i}", True, f"Question: {question}")
                print(f"Answer: {result['answer']}")
                print("\nSources:")
                for source in result['sources']:
                    print(f"- {source['document_id']}: {source['text'][:100]}...")
            except Exception as e:
                print_status(f"Query {i}", False, f"Failed to process question: {str(e)}")
                failed_queries += 1

        # Only claim overall success when every query actually succeeded.
        if failed_queries:
            print_status("Main Execution", False, f"{failed_queries}/{len(questions)} queries failed")
        else:
            print_status("Main Execution", True, "All operations completed successfully")

    except Exception as e:
        print_status("Main Execution", False, str(e))
        raise

# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()