In [None]:
from openai import OpenAI
from google import genai
from google.genai.types import EmbedContentConfig
import faiss
import numpy as np
from dotenv import load_dotenv
import os
from markitdown import MarkItDown

# Load environment variables (API keys) from a local .env file, if present.
load_dotenv()

# Initialize the clients.
gemini_client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))

# OpenAI client
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Cerebras client (OpenAI-compatible endpoint via base_url)
cerebras_client = OpenAI(
    api_key=os.getenv("CEREBRAS_API_KEY"),
    base_url="https://api.cerebras.ai/v1"
)

# Groq client (OpenAI-compatible endpoint via base_url)
groq_client = OpenAI(
    api_key=os.getenv("GROQ_API_KEY"),
    base_url="https://api.groq.com/openai/v1"
)

# FPT embedding configuration.
# The key must come from the environment: a literal fallback in source code
# leaks the credential to anyone who can read this file. Any key previously
# committed here should be treated as compromised and revoked.
API_KEY_FPT = os.getenv("FPT_API_KEY")
if not API_KEY_FPT:
    # Fail fast with a clear message. Passing api_key=None to OpenAI() would
    # silently fall back to OPENAI_API_KEY and send that key to the FPT host.
    raise RuntimeError(
        "FPT_API_KEY is not set; add it to your environment or .env file"
    )
BASE_URL_FPT = "https://mkp-api.fptcloud.com"
MODEL_FPT = "Vietnamese_Embedding"

# Client for the FPT embedding endpoint (OpenAI-compatible API).
fpt_embedding_client = OpenAI(
    api_key=API_KEY_FPT,
    base_url=BASE_URL_FPT
)

# Module-level MarkItDown converter, with plugins disabled for stability.
markitdown = MarkItDown(enable_plugins=False)

def get_embedding(text: str, model: str = MODEL_FPT) -> "list | None":
    """
    Return the embedding vector for ``text`` from the FPT embedding endpoint.

    Args:
        text (str): The input text to embed.
        model (str): Embedding model name; defaults to ``MODEL_FPT``.
    Returns:
        list | None: The embedding of the (single) input, or ``None`` if the
        request or response handling raised. The error is printed, not raised,
        so callers must check for ``None``.
    """
    try:
        response = fpt_embedding_client.embeddings.create(
            input=text,
            model=model
        )
        # One input string was sent, so the vector is the first data entry.
        return response.data[0].embedding
    except Exception as e:
        # Best-effort: callers treat None as "skip this chunk / query failed".
        print(f"Error getting embedding: {e}")
        return None

def extract_text_with_markitdown(file_path):
    """
    Convert a document to text using the module-level MarkItDown converter.

    Handles the formats MarkItDown supports (PDF, DOCX, PPTX, XLSX, CSV,
    HTML, images, audio, and more). Progress and errors are printed.

    Returns:
        str: The extracted text, or "" when nothing was extracted or the
        conversion raised an exception.
    """
    try:
        print(f"Processing file: {file_path}")
        converted = markitdown.convert(file_path)
        content = converted.text_content if converted else None
        if not content:
            print("No content extracted from file")
            return ""
        print(f"Successfully extracted {len(content)} characters")
        return content
    except Exception as exc:
        print(f"Error extracting text from {file_path}: {exc}")
        return ""

def split_text(text, chunk_size=800, overlap=100):
    """
    Split text into overlapping chunks, preferring sentence/line boundaries.

    Each window holds at most ``chunk_size`` characters. When a window does
    not reach the end of the text, it is shortened to end just after the last
    '.' or newline inside it, provided that boundary lies past the window's
    midpoint. Consecutive windows overlap by ``overlap`` characters. The
    fairly large default chunk size keeps enough context from MarkItDown's
    structured (Markdown) output.

    Args:
        text (str): Text to split.
        chunk_size (int): Maximum characters per chunk (before stripping).
        overlap (int): Characters shared by consecutive chunks.
    Returns:
        list[str]: Whitespace-stripped, non-empty chunks.
    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``overlap`` is not
            smaller than ``chunk_size`` (the window could never advance).
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    chunks = []
    start = 0
    text_len = len(text)
    while start < text_len:
        end = start + chunk_size
        chunk = text[start:end]

        # Try to break at a sentence/line boundary. rfind() returns an offset
        # relative to `chunk`, so convert it back to an absolute position.
        if end < text_len:
            break_point = max(chunk.rfind('.'), chunk.rfind('\n'))
            if break_point > chunk_size * 0.5:  # keep at least half the window
                end = start + break_point + 1
                chunk = text[start:end]

        chunks.append(chunk.strip())
        if end >= text_len:
            # This chunk reached the end of the text; stepping back by
            # `overlap` would only emit a duplicate of its tail.
            break
        # Always move forward, even if a boundary-trimmed chunk is shorter
        # than the overlap.
        start = max(end - overlap, start + 1)

    return [chunk for chunk in chunks if chunk]  # Filter empty chunks

def create_faiss_index(embeddings):
    """
    Build a FAISS ``IndexFlatL2`` from embedding vectors.

    ``None`` or empty entries (failed embeddings) are skipped, so FAISS row
    ids refer to the remaining vectors in their original order.

    Args:
        embeddings: Sequence of equal-length vectors, possibly containing None.
    Returns:
        The populated FAISS index, or None if there is no usable vector.
    """
    if embeddings is None:
        return None
    # Filter before sizing the index so that a failed *first* embedding does
    # not discard the whole batch.
    valid_embeddings = [
        emb for emb in embeddings if emb is not None and len(emb) > 0
    ]
    if not valid_embeddings:
        return None

    vectors = np.asarray(valid_embeddings, dtype='float32')
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index

def search_faiss_index(index, query_embedding, k=5):
    """
    Return the ``k`` nearest stored vectors to ``query_embedding``.

    ``k`` is capped at ``index.ntotal``. When fewer than ``k`` vectors exist,
    FAISS pads the result with label -1, which callers could otherwise
    mistake for a valid (last) position.

    Args:
        index: A populated FAISS index, or None.
        query_embedding: The query vector (list or 1-D array), or None.
        k (int): Maximum number of neighbours to return.
    Returns:
        tuple: ``(distances, indices)`` for the single query row, or
        ``(None, None)`` if the index/query is missing or the index is empty.
    """
    # Explicit None/len checks: `not array` raises for numpy inputs.
    if index is None or query_embedding is None or len(query_embedding) == 0:
        return None, None

    k = min(int(k), index.ntotal)
    if k <= 0:
        return None, None

    query_vector = np.asarray([query_embedding], dtype='float32')
    distances, indices = index.search(query_vector, k)
    return distances, indices

def generate_response(prompt, context, provider="cerebras"):
    """Generate response using specified provider"""
    # Chọn client và model dựa trên provider
    if provider == "cerebras":
        client = cerebras_client
        model = "llama-4-scout-17b-16e-instruct"
    elif provider == "groq":
        client = groq_client
        model = "llama3.1-8b"
    elif provider == "openai":
        client = openai_client
        model = "gpt-4o-mini"
    else:
        print(f"Unknown provider: {provider}, using cerebras as default")
        client = cerebras_client
        model = "llama-4-scout-17b-16e-instruct"
    
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system", 
                    "content": "You are a helpful assistant. Answer questions based on the provided context. If the context doesn't contain relevant information, say so clearly. Answer in Vietnamese if the question is in Vietnamese."
                },
                {
                    "role": "user", 
                    "content": f"Question: {prompt}\n\nContext:\n{context}"
                }
            ],
            stream=False,
            temperature=0.7,
            max_tokens=1500
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error generating response with {provider}: {e}"

def display_results(chunks, indices, distances=None):
    """
    Print retrieved chunks together with a relevance score.

    Args:
        chunks: All chunk strings; entries of ``indices[0]`` index into it.
        indices: 2-D sequence or array whose first row lists chunk positions
            (e.g. the ``indices`` of a FAISS search). Negative entries (FAISS
            padding) and out-of-range entries are skipped.
        distances: Optional 2-D sequence or array aligned with ``indices``.
            Each distance d is shown as 1 / (1 + d), so smaller distances
            score closer to 1.
    """
    # `not indices` would raise for a multi-element numpy array, so check
    # None and length explicitly.
    if indices is None or len(indices) == 0:
        print("No relevant chunks found")
        return

    print("Retrieved relevant chunks:")
    print("=" * 80)

    for i, idx in enumerate(indices[0]):
        # Skip FAISS's -1 padding; a negative index would wrap to the last chunk.
        if not 0 <= idx < len(chunks):
            continue
        relevance_score = ""
        if distances is not None and len(distances[0]) > i:
            score = distances[0][i]
            relevance_score = f" (Relevance: {1/(1+score):.3f})"

        print(f"Chunk {i+1}{relevance_score}:")
        print(chunks[idx])
        print("-" * 80)

class RAGSystem:
    """
    Retrieval-augmented generation over a single loaded document.

    Documents are converted to text with Microsoft MarkItDown, split into
    overlapping chunks, embedded, and indexed in FAISS. Questions are answered
    by an LLM using the top-k retrieved chunks as context. Supports the
    formats MarkItDown handles: PDF, DOCX, PPTX, XLSX, CSV, HTML, images,
    audio, etc.
    """

    def __init__(self, llm_client=None, llm_model=None):
        """
        Initialize an empty RAG system.

        Args:
            llm_client: OpenAI client MarkItDown uses for image descriptions
                (optional; only used together with ``llm_model``).
            llm_model: Model name for image descriptions (optional).
        """
        self.chunks = []         # text chunks of the loaded document
        self.embeddings = []     # vectors of the chunks that embedded successfully
        self.index = None        # FAISS index over self.embeddings
        self.valid_indices = []  # FAISS row id -> position in self.chunks

        # MarkItDown instance used by load_document; optionally LLM-enabled
        # so images can be described.
        if llm_client and llm_model:
            self.markitdown = MarkItDown(
                llm_client=llm_client,
                llm_model=llm_model,
                enable_plugins=False
            )
            print("MarkItDown initialized with LLM support for image descriptions")
        else:
            self.markitdown = MarkItDown(enable_plugins=False)
            print("MarkItDown initialized (basic mode)")

    def _extract_text(self, file_path):
        """
        Convert ``file_path`` to text with this instance's MarkItDown converter.

        Using ``self.markitdown`` (not the module-level converter) is what
        makes the optional LLM image-description support take effect.

        Returns:
            str: Extracted text, or "" if nothing was extracted or conversion failed.
        """
        try:
            print(f"Processing file: {file_path}")
            result = self.markitdown.convert(file_path)
            if result and result.text_content:
                print(f"Successfully extracted {len(result.text_content)} characters")
                return result.text_content
            print("No content extracted from file")
            return ""
        except Exception as e:
            print(f"Error extracting text from {file_path}: {e}")
            return ""

    def load_document(self, file_path):
        """
        Load, chunk, embed and index a document, replacing any previous one.

        State is left untouched if the file is missing or yields no text.
        Once new chunks are produced, the index is reset first, so a later
        failure never leaves an index that refers to the old document.

        Args:
            file_path: Path to any file MarkItDown can convert.
        Returns:
            bool: True if an index was built, False otherwise.
        """
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            return False

        text = self._extract_text(file_path)
        if not text:
            print("No text extracted from file")
            return False

        # Reset all derived state together: chunks, embeddings, the row-id
        # mapping and the index must always describe the same document.
        self.index = None
        self.embeddings = []
        self.valid_indices = []

        self.chunks = split_text(text)
        print(f"Document split into {len(self.chunks)} chunks")
        if not self.chunks:
            print("No chunks created")
            return False

        print("Creating embeddings...")
        for i, chunk in enumerate(self.chunks):
            if i % 10 == 0:
                print(f"Processing chunk {i+1}/{len(self.chunks)}")

            embedding = get_embedding(chunk)
            if embedding:
                self.embeddings.append(embedding)
                self.valid_indices.append(i)
            else:
                print(f"Failed to create embedding for chunk {i+1}")

        if not self.embeddings:
            print("No valid embeddings created")
            return False

        self.index = create_faiss_index(self.embeddings)
        if self.index is not None:
            print(f"FAISS index created successfully with {len(self.embeddings)} embeddings!")
            return True
        print("Failed to create FAISS index")
        return False

    def _collect_hits(self, indices, distances):
        """
        Map FAISS search results back to chunk positions.

        Args:
            indices: FAISS row ids for one query; the first row is used.
            distances: Distances aligned with ``indices``.
        Returns:
            list[tuple[int, float]]: ``(chunk_position, distance)`` pairs in
            rank order. Row ids outside ``self.valid_indices`` are dropped
            instead of wrapping to the last chunk. This includes FAISS's -1
            padding when fewer than k vectors exist.
        """
        hits = []
        for rank, row_id in enumerate(indices[0]):
            row_id = int(row_id)
            if 0 <= row_id < len(self.valid_indices):
                hits.append((self.valid_indices[row_id], float(distances[0][rank])))
        return hits

    def query(self, question, provider="cerebras", k=3, show_context=True):
        """
        Answer ``question`` using the top-``k`` chunks of the loaded document.

        Args:
            question (str): The user's question.
            provider (str): LLM provider passed to ``generate_response``.
            k (int): Number of chunks to retrieve.
            show_context (bool): Print all retrieved chunks before answering.
        Returns:
            str: The generated answer, or a human-readable error message.
        """
        if self.index is None:
            return "Please load a document first using load_document()"

        if not question.strip():
            return "Please provide a valid question"

        print(f"Searching for: {question}")

        query_embedding = get_embedding(question)
        if not query_embedding:
            return "Error: Could not create embedding for the question"

        distances, indices = search_faiss_index(self.index, query_embedding, k=k)
        if distances is None or indices is None:
            return "Error: Search failed"

        hits = self._collect_hits(indices, distances)
        if not hits:
            return "No relevant information found"

        chunk_ids = [chunk_id for chunk_id, _ in hits]
        context = "\n\n".join(self.chunks[chunk_id] for chunk_id in chunk_ids)

        if show_context:
            # Show every retrieved chunk (not only the first). Distances are
            # re-aligned to the hits that survived filtering.
            display_results(self.chunks, [chunk_ids], [[dist for _, dist in hits]])

        print(f"\nGenerating response using {provider}...")
        return generate_response(question, context, provider)

    def get_document_info(self):
        """Return a human-readable summary of the loaded document and its index."""
        if not self.chunks:
            return "No document loaded"

        total_chars = sum(len(chunk) for chunk in self.chunks)
        valid_embeddings = len(self.embeddings)

        return f"""
Document Information:
- Total chunks: {len(self.chunks)}
- Valid embeddings: {valid_embeddings}
- Total characters: {total_chars:,}
- Average chunk size: {total_chars // len(self.chunks) if self.chunks else 0} characters
- Index ready: {'Yes' if self.index is not None else 'No'}
        """.strip()

# Usage guide printed when this notebook cell runs. A single print with
# sep="\n" emits exactly the same text as one print call per line.
print(
    "RAG System with Microsoft MarkItDown initialized!",
    "\nSupported file formats:",
    "- PDF, DOCX, PPTX, XLSX, CSV",
    "- HTML, TXT, JSON, XML",
    "- Images (with OCR), Audio (with transcription)",
    "- ZIP files, YouTube URLs, EPubs",
    "\nUsage example:",
    "# Create RAG system",
    "rag = RAGSystem()",
    "\n# For image descriptions, use:",
    "# from openai import OpenAI",
    "# client = OpenAI()",
    "# rag = RAGSystem(llm_client=client, llm_model='gpt-4o')",
    "\n# Load document (any supported format)",
    "rag.load_document('university_texts.csv')",
    "# or",
    "rag.load_document('document.pdf')",
    "\n# Query the system",
    "response = rag.query('Your question here', provider='cerebras')",
    "print(response)",
    "\n# Get document info",
    "print(rag.get_document_info())",
    sep="\n",
)