# In [None]:
import faiss
import numpy as np
import pickle
import os
from typing import Dict, List, Tuple, Optional
import json

class FAISSIndexTester:
    """Load FAISS indices from ``<base_path>/indices`` and exercise them.

    Loaded indices are stored in ``loaded_indices`` and any pickled side-car
    metadata in ``index_metadata``; both dicts are keyed by index name (the
    file name without the ``.faiss`` extension).
    """

    # File-name conventions used when locating indices on disk.
    _INDEX_SUFFIX = ".faiss"
    _METADATA_SUFFIX = "_metadata.pkl"

    def __init__(self, base_path: str = "."):
        """
        Initialize the FAISS Index Tester

        Args:
            base_path: Base directory containing the indices folder
        """
        self.base_path = base_path
        self.indices_path = os.path.join(base_path, "indices")
        self.loaded_indices = {}
        self.index_metadata = {}

    def load_index(self, index_name: str) -> bool:
        """
        Load a FAISS index (and its optional metadata pickle) from disk

        The return value reflects whether the index itself was loaded and
        registered. A metadata file that exists but cannot be read only
        produces a warning; the index stays usable.

        Args:
            index_name: Name of the index (without file extension)

        Returns:
            bool: True if the index was loaded, False otherwise
        """
        index_path = os.path.join(
            self.indices_path, f"{index_name}{self._INDEX_SUFFIX}")
        metadata_path = os.path.join(
            self.indices_path, f"{index_name}{self._METADATA_SUFFIX}")

        if not os.path.exists(index_path):
            print(f"✗ Index file not found: {index_path}")
            return False

        try:
            index = faiss.read_index(index_path)
        except Exception as e:
            print(f"✗ Error loading index {index_name}: {str(e)}")
            return False

        self.loaded_indices[index_name] = index
        # Forget metadata from an earlier load so it cannot go stale.
        self.index_metadata.pop(index_name, None)
        print(f"✓ Successfully loaded index: {index_name}")

        if os.path.exists(metadata_path):
            try:
                # NOTE(security): pickle.load can execute arbitrary code.
                # Only load metadata files that come from a trusted source.
                with open(metadata_path, 'rb') as f:
                    self.index_metadata[index_name] = pickle.load(f)
                print(f"✓ Loaded metadata for: {index_name}")
            except Exception as e:
                # The index is already registered, so report and carry on.
                print(f"✗ Error loading metadata for {index_name}: {str(e)}")

        return True

    def load_all_indices(self):
        """Load all FAISS indices found in the indices directory

        Files are processed in sorted order so repeated runs behave the same.
        """
        if not os.path.isdir(self.indices_path):
            print(f"✗ Indices directory not found: {self.indices_path}")
            return

        faiss_files = sorted(
            f for f in os.listdir(self.indices_path)
            if f.endswith(self._INDEX_SUFFIX)
        )

        if not faiss_files:
            print("No FAISS index files found")
            return

        for faiss_file in faiss_files:
            # Strip only the trailing extension; the name itself may
            # legitimately contain ".faiss" somewhere else.
            index_name = faiss_file[:-len(self._INDEX_SUFFIX)]
            self.load_index(index_name)

    def get_index_info(self, index_name: str) -> Dict:
        """
        Get detailed information about a loaded index

        Args:
            index_name: Name of the index

        Returns:
            Dict containing index information, or a dict with a single
            "error" key if the index is not loaded
        """
        if index_name not in self.loaded_indices:
            return {"error": f"Index {index_name} not loaded"}

        index = self.loaded_indices[index_name]

        # Only label the metrics we can positively identify; anything else
        # is reported with its raw numeric code instead of being called "IP".
        metric = index.metric_type
        if metric == faiss.METRIC_L2:
            metric_name = "L2"
        elif metric == faiss.METRIC_INNER_PRODUCT:
            metric_name = "IP"
        else:
            metric_name = f"OTHER({metric})"

        info = {
            "name": index_name,
            "total_vectors": index.ntotal,
            "dimension": index.d,
            "is_trained": index.is_trained,
            "metric_type": metric_name,
            "index_type": type(index).__name__
        }

        # Add metadata if available
        if index_name in self.index_metadata:
            info["metadata"] = self.index_metadata[index_name]

        return info

    def test_search(self, index_name: str, query_vector: Optional[np.ndarray] = None,
                   k: int = 5) -> Dict:
        """
        Test search functionality on an index

        Args:
            index_name: Name of the index to test
            query_vector: Query vector (if None, uses a random vector). It is
                flattened to a single row, so it must hold exactly ``index.d``
                values.
            k: Number of nearest neighbors to return

        Returns:
            Dict containing search results and performance metrics, or a dict
            with a single "error" key on failure. "num_results" counts only
            real hits; "distances" and "indices" keep the full k-length
            output including any -1 padding.
        """
        if index_name not in self.loaded_indices:
            return {"error": f"Index {index_name} not loaded"}

        if k <= 0:
            return {"error": f"Search failed: k must be positive, got {k}"}

        index = self.loaded_indices[index_name]

        # Generate random query vector if none provided
        if query_vector is None:
            query_vector = np.random.random((1, index.d)).astype('float32')
            print(f"Using random query vector of dimension {index.d}")
        else:
            query_vector = np.ascontiguousarray(
                query_vector, dtype='float32').reshape(1, -1)
            if query_vector.shape[1] != index.d:
                # Report the mismatch explicitly rather than relying on the
                # index to reject the query.
                return {
                    "error": (
                        f"Search failed: query dimension "
                        f"{query_vector.shape[1]} does not match index "
                        f"dimension {index.d}"
                    )
                }

        try:
            import time
            # perf_counter is monotonic and high resolution, which matters
            # for sub-millisecond searches.
            start_time = time.perf_counter()

            # Perform search
            distances, indices = index.search(query_vector, k)

            search_time = time.perf_counter() - start_time

            labels = indices[0]
            results = {
                "search_time_ms": search_time * 1000,
                # FAISS pads missing neighbours with label -1 (e.g. when
                # k exceeds the number of stored vectors); do not count them.
                "num_results": int((labels != -1).sum()),
                "distances": distances[0].tolist(),
                "indices": labels.tolist(),
                "query_shape": query_vector.shape
            }

            return results

        except Exception as e:
            return {"error": f"Search failed: {str(e)}"}

    def benchmark_index(self, index_name: str, num_queries: int = 100, k: int = 10) -> Dict:
        """
        Benchmark search performance on an index

        Runs a short warmup search, then times one batched search over
        ``num_queries`` random vectors.

        Args:
            index_name: Name of the index to benchmark
            num_queries: Number of random queries to run (must be positive)
            k: Number of nearest neighbors per query (must be positive)

        Returns:
            Dict containing benchmark results, or a dict with a single
            "error" key on failure
        """
        if index_name not in self.loaded_indices:
            return {"error": f"Index {index_name} not loaded"}

        if num_queries <= 0:
            return {"error": f"Benchmark failed: num_queries must be positive, got {num_queries}"}

        if k <= 0:
            return {"error": f"Benchmark failed: k must be positive, got {k}"}

        index = self.loaded_indices[index_name]

        # Generate random query vectors
        query_vectors = np.random.random((num_queries, index.d)).astype('float32')

        try:
            import time

            # Warmup
            index.search(query_vectors[:5], k)

            # Benchmark
            start_time = time.perf_counter()
            index.search(query_vectors, k)
            total_time = time.perf_counter() - start_time

            # Guard against a zero reading from the clock on very fast runs.
            if total_time > 0:
                queries_per_second = num_queries / total_time
            else:
                queries_per_second = float("inf")

            results = {
                "total_queries": num_queries,
                "k": k,
                "total_time_seconds": total_time,
                "average_query_time_ms": (total_time / num_queries) * 1000,
                "queries_per_second": queries_per_second,
                "index_size": index.ntotal
            }

            return results

        except Exception as e:
            return {"error": f"Benchmark failed: {str(e)}"}

    def compare_indices(self, query_vector: Optional[np.ndarray] = None, k: int = 5) -> Dict:
        """
        Compare search results across all loaded indices

        The same query is sent to every index. When no query is supplied a
        random one is generated with the first loaded index's dimension, so
        indices with a different dimension get an "error" entry.

        Args:
            query_vector: Query vector to use for comparison
            k: Number of results to compare

        Returns:
            Dict mapping each index name to its test_search() result, or a
            dict with a single "error" key if nothing is loaded
        """
        if not self.loaded_indices:
            return {"error": "No indices loaded"}

        # Use the first index's dimension if no query vector provided
        first_index = next(iter(self.loaded_indices.values()))
        if query_vector is None:
            query_vector = np.random.random((1, first_index.d)).astype('float32')

        return {
            index_name: self.test_search(index_name, query_vector, k)
            for index_name in self.loaded_indices
        }

    def print_summary(self):
        """Print a summary of all loaded indices"""
        print("\n" + "="*60)
        print("FAISS INDICES SUMMARY")
        print("="*60)

        if not self.loaded_indices:
            print("No indices loaded")
            return

        for index_name in self.loaded_indices:
            info = self.get_index_info(index_name)
            print(f"\nIndex: {index_name}")
            print(f"  Vectors: {info['total_vectors']:,}")
            print(f"  Dimensions: {info['dimension']}")
            print(f"  Type: {info['index_type']}")
            print(f"  Trained: {info['is_trained']}")
            print(f"  Metric: {info['metric_type']}")

            if 'metadata' in info:
                metadata = info['metadata']
                # The pickle may hold any object, not necessarily a dict.
                if isinstance(metadata, dict):
                    print(f"  Metadata keys: {list(metadata.keys())}")
                else:
                    print(f"  Metadata type: {type(metadata).__name__}")


# Example usage and testing functions
def main():
    """Demonstrate the tester end to end: load, summarize, search, benchmark."""
    separator = "=" * 60
    tester = FAISSIndexTester()

    # Load the two project indices by name; tester.load_all_indices() is the
    # alternative that picks up every .faiss file in the indices directory.
    print("Loading FAISS indices...")
    for name in ("faiss_severity_index_medibot", "faiss_symptom_index_medibot"):
        tester.load_index(name)

    tester.print_summary()

    # One search per loaded index; no query is passed, so a random one is used.
    print("\n" + separator)
    print("TESTING SEARCH FUNCTIONALITY")
    print(separator)

    for name in tester.loaded_indices:
        print(f"\nTesting {name}:")
        outcome = tester.test_search(name, k=5)
        if "error" in outcome:
            print(f"  Error: {outcome['error']}")
            continue
        print(f"  Search time: {outcome['search_time_ms']:.2f}ms")
        print(f"  Results found: {outcome['num_results']}")
        print(f"  Top distances: {outcome['distances'][:3]}")

    # Throughput measurement with 50 random queries per index.
    print("\n" + separator)
    print("PERFORMANCE BENCHMARK")
    print(separator)

    for name in tester.loaded_indices:
        print(f"\nBenchmarking {name}:")
        stats = tester.benchmark_index(name, num_queries=50)
        if "error" in stats:
            print(f"  Error: {stats['error']}")
            continue
        print(f"  Average query time: {stats['average_query_time_ms']:.2f}ms")
        print(f"  Queries per second: {stats['queries_per_second']:.1f}")

# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()