In [31]:
import faiss
import numpy as np
import pickle
import zlib
import os
import re
import pandas as pd
from typing import List, Dict
import json
from langchain_core.documents import Document
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_community.tools import QuerySQLDataBaseTool

from langchain.chains import create_sql_query_chain
from langchain_google_genai import ChatGoogleGenerativeAI
import os

from operator import itemgetter
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage

# SECURITY: secrets must never be committed to source control. The API keys
# and the database password that used to be hard-coded in this cell should be
# treated as compromised and rotated. They are now read from the environment.
#
# Required environment variables:
#   GOOGLE_API_KEY - API key used by the Google Generative AI chat/embedding
#                    clients created later in this file
#   DB_PASSWORD    - password of the reporting database user
# Optional overrides (defaults keep the previous non-secret settings):
#   DB_USER, DB_HOST, DB_NAME, DB_PORT, DB_SCHEMA
# NOTE(review): GEMINI_API_KEY was also assigned here before; export it in the
# environment as well if any dependency actually reads it.
_missing_secrets = [
    name for name in ("GOOGLE_API_KEY", "DB_PASSWORD") if not os.environ.get(name)
]
if _missing_secrets:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_secrets)
    )

db_user = os.environ.get("DB_USER", "usr_reporting")
db_password = os.environ["DB_PASSWORD"]
db_host = os.environ.get("DB_HOST", "alspgbdvit01q.ohl.com")
db_name = os.environ.get("DB_NAME", "vite_reporting_r_qa")
db_port = int(os.environ.get("DB_PORT", "6432"))
db_schema = os.environ.get("DB_SCHEMA", "customersetup")  # Define the schema name
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

# Create the engine without specifying the schema in the URI.
# URL.create assembles the DSN from its components, so characters such as
# "@", ":" or "/" in the password cannot corrupt the connection string the way
# plain f-string interpolation can.
engine = create_engine(
    URL.create(
        drivername="postgresql+psycopg2",
        username=db_user,
        password=db_password,
        host=db_host,
        port=db_port,
        database=db_name,
    )
)

# Create the SQLDatabase object, specifying the schema
db = SQLDatabase(engine, schema=db_schema)

# Chat model shared by the SQL-generation and answer-rephrasing chains.
# No credentials argument is passed because authentication uses the API key.
llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-2.0-flash")

# System instructions for the SQL-generation step (fixed text, no placeholders).
_sql_system_message = SystemMessage(content="""You are an SQL expert. Generate SQL queries based on the user's question and provided schema context.
Follow these rules:
1. Use only tables and columns mentioned in the schema context
2. Write clear, efficient SQL queries
3. Consider table relationships and column types from the schema""")

# Human turn: filled in with the retrieved schema and the user's question.
_sql_human_message = ("human", "Schema Context: {schema}\nQuestion: {question}")

sql_prompt = ChatPromptTemplate.from_messages(
    [_sql_system_message, _sql_human_message]
)

# SQL generation chain: prompt -> model -> plain string. It does not touch the
# database itself.
_sql_text_parser = StrOutputParser()
generate_query = sql_prompt | llm | _sql_text_parser

def safe_execute_query(query):
    """Run a generated SQL statement, returning its result or an error string.

    The text is cleaned with ``strip_sql_markdown``, echoed to stdout, and
    passed to the module-level ``execute_query`` tool. Any exception raised
    along the way is caught and reported as a string rather than propagated.

    Args:
        query: SQL text, possibly wrapped in Markdown code fences.

    Returns:
        Whatever ``execute_query.invoke`` returns, or a string starting with
        ``"Error executing query: "`` when something fails.
    """
    try:
        cleaned_sql = strip_sql_markdown(query)
        print("query:", cleaned_sql)
        outcome = execute_query.invoke(cleaned_sql)
    except Exception as exc:
        return "Error executing query: " + str(exc)
    return outcome

# Query tool bound to `db`; safe_execute_query calls its .invoke() with the
# cleaned SQL text.
execute_query = QuerySQLDataBaseTool(db=db)
# Language tags that may follow an opening code fence. Matched
# case-insensitively and only at the very start of the fenced body.
_SQL_FENCE_TAG_RE = re.compile(
    r"\A\s*(?:sql|postgresql|postgres|pgsql|psql|mysql|sqlite|tsql|plsql)\b",
    re.IGNORECASE,
)


def strip_sql_markdown(sql: str) -> str:
    """Extract bare SQL from text that may be wrapped in Markdown code fences.

    Behaviour:
      * No fence present: the input is returned stripped of outer whitespace.
      * One complete fenced block (optionally surrounded by prose): only the
        content of the first fenced block is returned.
      * A single, unbalanced fence: the non-empty side of the fence is used.
      * A leading language tag such as ``sql`` / ``SQL`` / ``postgresql`` is
        removed regardless of case.

    Args:
        sql: Raw text, typically the output of the SQL-generation model.

    Returns:
        The SQL statement without fences, language tag or outer whitespace.
    """
    text = sql.strip()
    parts = text.split("```")
    if len(parts) == 1:
        # No Markdown fence at all.
        return text

    if len(parts) >= 3:
        # At least one complete fenced block: keep the first block only, so
        # explanatory prose before/after it never reaches the database.
        body = parts[1]
    else:
        # Exactly one fence marker (e.g. truncated output): prefer the text
        # after it, falling back to the text before it when nothing follows.
        body = parts[1] if parts[1].strip() else parts[0]

    return _SQL_FENCE_TAG_RE.sub("", body, count=1).strip()


# Prompt that turns the question, generated SQL, SQL result and schema context
# into a final answer.
# NOTE(review): the system text reads "based on the from LLM" - a word appears
# to be missing. It is reproduced verbatim here; confirm the intended wording.
_rephrase_messages = [
    SystemMessage(content="Rephrase the answer to the question based on the from LLM and schema context."),
    ("human", "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nSchema Context: {schema}"),
]
chat_prompt = ChatPromptTemplate.from_messages(_rephrase_messages)

# Answer chain: prompt -> model -> plain string.
_answer_text_parser = StrOutputParser()
rephrase_answer = chat_prompt | llm | _answer_text_parser

class IVFFAISSStore:
    """
    FAISS IVF (inverted file) vector store with zlib-compressed documents.

    Vectors live in a FAISS ``IndexIVFFlat``. Each document dict is stored as
    zlib-compressed JSON keyed by a sequential integer id, and metadata dicts
    are kept uncompressed under the same id. Ids are assigned in insertion
    order, so they line up with the row positions FAISS reports from a search.
    """

    def __init__(self, embedding_dim: int = 768, nlist: int = 100):
        """
        Initialize an empty store.

        Args:
            embedding_dim: Dimension of embeddings
            nlist: Upper bound on the number of IVF clusters
        """
        self.embedding_dim = embedding_dim
        self.nlist = nlist
        self.index = None  # Built lazily on the first add_documents call
        self.document_store = {}  # doc_id -> compressed document bytes
        self.metadata_store = {}  # doc_id -> metadata dict (uncompressed)
        self.id_counter = 0  # Next id to assign

    def _create_ivf_index(self, sample_embeddings: np.ndarray) -> "faiss.Index":
        """
        Create an untrained IVF index sized for ``sample_embeddings``.

        Args:
            sample_embeddings: The first batch of vectors; only its length is
                used, to pick the number of clusters.

        Returns:
            A new ``faiss.IndexIVFFlat`` with ``nprobe`` already set.
        """
        # Quantizer: the flat L2 index that holds the cluster centroids
        quantizer = faiss.IndexFlatL2(self.embedding_dim)

        # Roughly one cluster per 10 vectors, capped at self.nlist. Clamp to
        # at least 1: with fewer than 10 vectors the division yields 0, which
        # is not a usable cluster count.
        n_clusters = max(1, min(self.nlist, len(sample_embeddings) // 10))
        index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, n_clusters)

        # Number of clusters scanned per query (higher = more accurate but
        # slower). Clamp to at least 1: with fewer than 5 clusters the
        # division yields 0, i.e. no cluster would be scanned at all.
        index.nprobe = max(1, min(20, n_clusters // 5))

        return index

    def add_documents(self, embeddings: np.ndarray, documents: List[Dict], metadata: List[Dict] = None):
        """
        Add documents with their embeddings to the store.

        Args:
            embeddings: 2-D array-like of shape (n, embedding_dim).
            documents: ``n`` JSON-serializable dicts, one per embedding row.
            metadata: Optional per-document metadata dicts; documents without
                a matching entry get ``{}``.

        Raises:
            ValueError: If the number of embeddings and documents differ, or
                the embeddings do not have shape (n, embedding_dim).
        """
        if len(embeddings) == 0:
            return

        # Ids are derived from insertion order, so a length mismatch would
        # silently shift every later id away from its vector.
        if len(embeddings) != len(documents):
            raise ValueError(
                f"Got {len(embeddings)} embeddings for {len(documents)} documents"
            )

        # Convert once: FAISS needs a contiguous float32 matrix
        vectors = np.ascontiguousarray(embeddings, dtype='float32')
        if vectors.ndim != 2 or vectors.shape[1] != self.embedding_dim:
            raise ValueError(
                f"Expected embeddings of shape (n, {self.embedding_dim}), got {vectors.shape}"
            )

        # Create index if it doesn't exist
        if self.index is None:
            self.index = self._create_ivf_index(vectors)

        # Train the index (required for IVF before vectors can be added)
        if not self.index.is_trained:
            self.index.train(vectors)

        # Add embeddings to index
        self.index.add(vectors)

        # Store compressed documents and metadata under sequential ids
        for offset, document in enumerate(documents):
            doc_id = self.id_counter + offset
            self.document_store[doc_id] = self._compress_document(document)

            if metadata and offset < len(metadata):
                self.metadata_store[doc_id] = metadata[offset]
            else:
                self.metadata_store[doc_id] = {}

        # Update counter
        self.id_counter += len(documents)

    def _compress_document(self, document: Dict) -> bytes:
        """Serialize ``document`` to compact JSON and zlib-compress it."""
        doc_json = json.dumps(document, separators=(',', ':'))  # Minimal JSON
        return zlib.compress(doc_json.encode('utf-8'), level=9)

    def _decompress_document(self, compressed_doc: bytes) -> Dict:
        """Inverse of ``_compress_document``: decompress and parse the JSON."""
        doc_json = zlib.decompress(compressed_doc).decode('utf-8')
        return json.loads(doc_json)

    def search(self, query_embedding: np.ndarray, k: int = 5) -> List[Dict]:
        """
        Return up to ``k`` nearest documents for ``query_embedding``.

        Args:
            query_embedding: A single query vector of length embedding_dim.
            k: Maximum number of results.

        Returns:
            A list of dicts with keys ``content``, ``metadata``, ``distance``
            and ``doc_id`` (a plain ``int``). Empty when the store has no
            index yet or ``k`` is not positive.
        """
        if self.index is None or k <= 0:
            return []

        query = np.ascontiguousarray(
            np.asarray(query_embedding, dtype='float32').reshape(1, -1)
        )
        distances, indices = self.index.search(query, k)

        results = []
        for distance, raw_id in zip(distances[0], indices[0]):
            # Convert the numpy integer to a built-in int so results stay
            # JSON-serializable.
            doc_id = int(raw_id)

            # Ids outside [0, id_counter) are skipped, which also drops
            # negative ids.
            if not 0 <= doc_id < self.id_counter:
                continue

            # Decompress document only when needed
            if doc_id in self.document_store:
                doc = self._decompress_document(self.document_store[doc_id])
                content = doc.get('content', '')
            else:
                content = f"Document {doc_id} not found"

            results.append({
                'content': content,
                'metadata': self.metadata_store.get(doc_id, {}),
                'distance': float(distance),
                'doc_id': doc_id
            })

        return results

    def get_stats(self) -> Dict:
        """
        Report document count and compression statistics.

        The uncompressed size is an estimate extrapolated from the first (up
        to) 10 stored documents.

        Returns:
            ``{"error": "No documents stored"}`` for an empty store, otherwise
            a dict of counts, sizes in bytes, the compression ratio and the
            configured ``embedding_dim`` / ``nlist``.
        """
        if not self.document_store:
            return {"error": "No documents stored"}

        total_docs = len(self.document_store)
        total_compressed_size = sum(len(blob) for blob in self.document_store.values())

        # Sample a few documents to estimate the uncompressed size
        sample_ids = list(self.document_store)[:10]
        sample_uncompressed = sum(
            len(json.dumps(self._decompress_document(self.document_store[doc_id])).encode('utf-8'))
            for doc_id in sample_ids
        )
        estimated_uncompressed = (sample_uncompressed / len(sample_ids)) * total_docs
        compression_ratio = (
            estimated_uncompressed / total_compressed_size if total_compressed_size > 0 else 0
        )

        return {
            "document_count": total_docs,
            "compressed_size_bytes": total_compressed_size,
            "estimated_uncompressed_bytes": estimated_uncompressed,
            "compression_ratio": compression_ratio,
            "avg_document_size_bytes": total_compressed_size / total_docs,
            "embedding_dim": self.embedding_dim,
            "nlist": self.nlist
        }

    def save(self, directory: str):
        """
        Save the store to ``directory`` (created if missing).

        Writes ``ivf_index.faiss`` (only when an index exists),
        ``documents.pkl``, ``metadata.pkl`` and ``config.json``.
        """
        os.makedirs(directory, exist_ok=True)

        index_path = os.path.join(directory, "ivf_index.faiss")
        if self.index is not None:
            faiss.write_index(self.index, index_path)
        elif os.path.exists(index_path):
            # Do not leave an index file from an earlier save next to the
            # documents of an index-less store; load() would pair them up.
            os.remove(index_path)

        # Save document store and metadata
        with open(os.path.join(directory, "documents.pkl"), "wb") as f:
            pickle.dump(self.document_store, f)

        with open(os.path.join(directory, "metadata.pkl"), "wb") as f:
            pickle.dump(self.metadata_store, f)

        # Save configuration
        config = {
            "embedding_dim": self.embedding_dim,
            "nlist": self.nlist,
            "id_counter": self.id_counter
        }

        with open(os.path.join(directory, "config.json"), "w") as f:
            json.dump(config, f)

    @classmethod
    def load(cls, directory: str) -> 'IVFFAISSStore':
        """
        Load a store previously written by ``save``.

        SECURITY: the document and metadata files are read with ``pickle``,
        which can execute arbitrary code. Only load directories you trust.

        Args:
            directory: Directory containing the saved files.

        Returns:
            The restored store; ``index`` is ``None`` when no index file was
            saved.
        """
        # Load configuration
        with open(os.path.join(directory, "config.json"), "r") as f:
            config = json.load(f)

        store = cls(
            embedding_dim=config["embedding_dim"],
            nlist=config["nlist"]
        )

        # save() skips the index file when nothing has been added yet, so its
        # absence is a valid state rather than an error.
        index_path = os.path.join(directory, "ivf_index.faiss")
        if os.path.exists(index_path):
            store.index = faiss.read_index(index_path)

        # Load document store and metadata
        with open(os.path.join(directory, "documents.pkl"), "rb") as f:
            store.document_store = pickle.load(f)

        with open(os.path.join(directory, "metadata.pkl"), "rb") as f:
            store.metadata_store = pickle.load(f)

        # Set counter
        store.id_counter = config["id_counter"]

        return store

# Create IVF FAISS store from schema data

# Heuristic "name type [word]" matcher, compiled once instead of per CSV row.
# NOTE(review): it is applied to the whole DDL text, so it also matches word
# sequences that are not column definitions (e.g. "CREATE TABLE <name>").
_COLUMN_DEF_PATTERN = re.compile(r'(\w+)\s+(\w+(?:\(\d+\))?)\s+(\w+)?')


def _build_schema_documents(df):
    """Turn schema rows into parallel lists of document and metadata dicts.

    Each row yields one ``table_schema`` document, followed by one
    ``column_info`` document per regex match in the row's DDL.
    """
    documents = []
    metadata = []

    for _, row in df.iterrows():
        table_name = row['table_name']
        ddl = row['DDL']

        # Document for the table as a whole
        documents.append({
            'content': f"TABLE: {table_name} SCHEMA: {ddl}",
            'type': 'table_schema'
        })
        metadata.append({
            'table': table_name,
            'type': 'table_schema'
        })

        # One document per matched column definition. findall returns a
        # 3-tuple per match (the pattern has three groups); only the first two
        # are used.
        for col_name, col_type, _ in _COLUMN_DEF_PATTERN.findall(ddl):
            documents.append({
                'content': f"TABLE: {table_name} COLUMN: {col_name} TYPE: {col_type}",
                'type': 'column_info'
            })
            metadata.append({
                'table': table_name,
                'column': col_name,
                'type': 'column_info'
            })

    return documents, metadata


def create_ivf_schema_store(
    schema_csv_path="./table_schema.csv",
    store_dir="ivf_schema_store",
    nlist=50,
    embedding_model_name="models/text-embedding-004",
):
    """Create an IVF FAISS store from schema data and save it to disk.

    Args:
        schema_csv_path: CSV file with ``table_name`` and ``DDL`` columns.
        store_dir: Directory the finished store is saved to.
        nlist: Cluster-count upper bound passed to ``IVFFAISSStore``.
        embedding_model_name: Model name passed to
            ``GoogleGenerativeAIEmbeddings``.

    Returns:
        The populated ``IVFFAISSStore``.

    Raises:
        ValueError: If the CSV contains no rows, so there is nothing to embed.
    """
    # Initialize embeddings
    embeddings_model = GoogleGenerativeAIEmbeddings(model=embedding_model_name)

    # Load the schema data and prepare documents and metadata
    df = pd.read_csv(schema_csv_path)
    documents, metadata = _build_schema_documents(df)
    if not documents:
        # Fail with a clear message instead of an IndexError on shape[1] below
        raise ValueError(f"No schema rows found in {schema_csv_path}")

    # Get embeddings for all documents
    texts = [doc['content'] for doc in documents]
    embeddings = np.array(embeddings_model.embed_documents(texts))

    # Create IVF store sized to the embedding dimension that came back
    store = IVFFAISSStore(embedding_dim=embeddings.shape[1], nlist=nlist)
    store.add_documents(embeddings, documents, metadata)

    # Save the store
    store.save(store_dir)

    return store

# LangChain-compatible retriever for IVF store
class IVFFAISSRetriever:
    """Retriever exposing an ``invoke(query)`` method over an IVF FAISS store.

    The query is embedded with ``GoogleGenerativeAIEmbeddings`` and the nearest
    stored documents are returned as LangChain ``Document`` objects.
    """

    def __init__(self, store: IVFFAISSStore, k: int = 5):
        """
        Args:
            store: The vector store to search.
            k: Maximum number of documents returned per query.

        Raises:
            ValueError: If ``k`` is not positive.
        """
        if k <= 0:
            raise ValueError(f"k must be positive, got {k}")
        self.store = store
        self.k = k
        self.embeddings_model = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")

    def invoke(self, query: str) -> List[Document]:
        """Embed ``query`` and return up to ``self.k`` matching documents.

        Each returned ``Document`` has the stored content as ``page_content``
        and the stored metadata dict as ``metadata``.
        """
        # Get query embedding
        query_embedding = np.array(self.embeddings_model.embed_query(query))

        # Search, then convert the hits to LangChain Document format
        hits = self.store.search(query_embedding, k=self.k)
        return [
            Document(page_content=hit['content'], metadata=hit['metadata'])
            for hit in hits
        ]

# Create and test the IVF store
# Both statements run as soon as this cell/module executes:
# create_ivf_schema_store() reads the schema CSV, embeds it and saves the
# store, and the retriever built on it is what the chain further down uses to
# fill in the schema context.
ivf_store = create_ivf_schema_store()
ivf_retriever = IVFFAISSRetriever(ivf_store)

# Test the IVF retriever




In [None]:
def _schema_for(payload):
    """Retrieve schema documents for the question carried by *payload*.

    *payload* may be the chain's input dict (the question is read from its
    "question" key) or the bare question itself.
    """
    question_text = payload["question"] if isinstance(payload, dict) else payload
    return ivf_retriever.invoke(question_text)


def _execute_generated_sql(sql_text):
    """Run the generated SQL through safe_execute_query (looked up at call time)."""
    return safe_execute_query(sql_text)


# Pipeline stages, each adding one key to the running dict:
#   schema -> retrieved schema documents
#   query  -> SQL text produced by generate_query
#   result -> outcome of executing that SQL
# The enriched dict is then handed to rephrase_answer for the final wording.
_with_schema = RunnablePassthrough().assign(schema=_schema_for)
_with_query = _with_schema.assign(query=generate_query)
_with_result = _with_query.assign(
    result=itemgetter("query") | RunnableLambda(_execute_generated_sql)
)
chain_with_ivf_retriever = _with_result | rephrase_answer
# Interactive loop: read a question, run it through the chain, print the
# answer. Ends on "exit" (any case, surrounding whitespace ignored) or on
# end-of-input.
while True:
    try:
        question = input("Enter your question (or 'exit' to quit): ").strip()
    except EOFError:
        break
    if question.lower() == 'exit':
        break
    if not question:
        # Blank line: ask again instead of spending an LLM call and a
        # database round trip on an empty question.
        continue
    # Invoke the chain with a properly formatted input dictionary
    response = chain_with_ivf_retriever.invoke({"question": question})
    print("\nResponse with IVF FAISS Retriever:")
    print(response)

retreviever   [Document(metadata={'table': 'locations', 'column': 'weightcapacity', 'type': 'column_info'}, page_content='TABLE: locations COLUMN: weightcapacity TYPE: int4'), Document(metadata={'table': 'container', 'column': 'weightuom', 'type': 'column_info'}, page_content='TABLE: container COLUMN: weightuom TYPE: varchar(10)'), Document(metadata={'table': 'container', 'column': 'weightmin', 'type': 'column_info'}, page_content='TABLE: container COLUMN: weightmin TYPE: numeric'), Document(metadata={'table': 'container', 'column': 'weightmax', 'type': 'column_info'}, page_content='TABLE: container COLUMN: weightmax TYPE: numeric'), Document(metadata={'table': 'container', 'column': 'unituom', 'type': 'column_info'}, page_content='TABLE: container COLUMN: unituom TYPE: varchar(10)')]
query: SELECT
  COUNT(CASE WHEN weightuom = 'Pound' THEN 1 ELSE NULL END)
FROM container;

Response with IVF FAISS Retriever:
There are 9 containers with a weight unit of measure in pounds.
