# In [None]:
import chromadb
from typing import Dict, List, Optional, Any, Union
import uuid
import time
from datetime import datetime


class ChromaFramework:
    """
    A framework for managing records with CRUD operations using ChromaDB.

    Uses two separate collections: 'graph' for graph embeddings and 'text' for
    text embeddings. A record can exist in one or both collections under the
    same ID. IDs are auto-generated (UUID4 strings) unless given explicitly.

    Every stored metadata dict carries the bookkeeping keys ``name``,
    ``record_id`` and ``embedding_type``. Those keys (plus ``auto_counter``)
    are stripped from the ``metadata`` field of every dict returned to callers.
    """

    # Metadata keys managed by this class rather than by the caller.
    # 'auto_counter' is never written by this class but is still hidden on read.
    _RESERVED_METADATA_KEYS = frozenset({"name", "record_id", "embedding_type", "auto_counter"})

    def __init__(self, persist_directory: Optional[str] = None):
        """
        Initialize the ChromaDB framework with two collections.

        Args:
            persist_directory: Directory to persist the database. If None, uses in-memory storage.
        """
        if persist_directory:
            self.client = chromadb.PersistentClient(path=persist_directory)
        else:
            self.client = chromadb.Client()

        # Create (or reopen) the two collections
        self.graph_collection = self.client.get_or_create_collection(name="graph")
        self.text_collection = self.client.get_or_create_collection(name="text")

        self.collections = {
            "graph": self.graph_collection,
            "text": self.text_collection
        }

    @classmethod
    def _user_metadata(cls, metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return a copy of *metadata* without the reserved bookkeeping keys (None -> {})."""
        return {k: v for k, v in (metadata or {}).items()
                if k not in cls._RESERVED_METADATA_KEYS}

    @staticmethod
    def _exists_in(collection: Any, record_id: str) -> bool:
        """Return True if *record_id* is stored in *collection*; lookup errors count as absent."""
        try:
            result = collection.get(ids=[record_id], include=["metadatas"])
        except Exception:
            return False
        return bool(result["ids"])

    def create_record(self,
                      name: str,
                      metadata: Optional[Dict[str, Any]] = None,
                      documents: Optional[List[str]] = None,
                      embeddings: Optional[Dict[str, List[float]]] = None,
                      record_id: Optional[str] = None) -> str:
        """
        Create a new record in the specified collection(s).

        Args:
            name: Name of the record (mandatory).
            metadata: Optional user metadata. The reserved keys 'name',
                'record_id' and 'embedding_type' are always set by this method
                and cannot be overridden through this argument.
            documents: Optional list of text documents. Only the first entry is
                stored; if omitted or empty, ``name`` is stored as the document.
            embeddings: Optional mapping {"graph": [...], "text": [...]}. The
                record is created only in the collections named here. If omitted
                (or empty), the record is created in both collections without
                explicit vectors, so ChromaDB embeds the document itself.
            record_id: Optional explicit ID. A UUID4 string is generated if None.

        Returns:
            The record ID (generated or provided).

        Raises:
            ValueError: If a record with this ID already exists, or if an
                embedding type other than 'graph' or 'text' is given.
        """
        if record_id is None:
            record_id = str(uuid.uuid4())

        if self._record_exists(record_id):
            raise ValueError(f"Record with ID '{record_id}' already exists")

        # User metadata first, reserved keys last so they always win.
        base_metadata = {**(metadata or {}), "name": name, "record_id": record_id}

        # Use name as the default document if none provided
        record_document = documents[0] if documents else name

        # Validate every embedding type before writing anything.
        if embeddings:
            targets = []
            for embedding_type, embedding_vector in embeddings.items():
                if embedding_type not in self.collections:
                    raise ValueError(f"Invalid embedding type '{embedding_type}'. Must be 'graph' or 'text'")
                targets.append((embedding_type, embedding_vector))
        else:
            targets = [("graph", None), ("text", None)]

        added = []  # collections successfully written, for rollback
        try:
            for embedding_type, embedding_vector in targets:
                collection = self.collections[embedding_type]
                add_kwargs = {
                    "ids": [record_id],
                    "documents": [record_document],
                    "metadatas": [{**base_metadata, "embedding_type": embedding_type}],
                }
                # Check length instead of truthiness so NumPy vectors work too.
                # An empty vector still falls back to ChromaDB-computed embeddings.
                if embedding_vector is not None and len(embedding_vector) > 0:
                    add_kwargs["embeddings"] = [embedding_vector]
                collection.add(**add_kwargs)
                added.append(collection)
        except Exception as e:
            # Roll back collections already written so the record is not left half-created.
            for collection in added:
                try:
                    collection.delete(ids=[record_id])
                except Exception:
                    pass  # best-effort rollback; the original error is re-raised below
            if "already exists" in str(e):
                raise ValueError(f"Record with ID '{record_id}' already exists") from e
            raise

        return record_id

    def read_record(self, record_id: str, include_embeddings: bool = False) -> Optional[Dict[str, Any]]:
        """
        Read a record by its ID from all collections where it exists.

        Name, document and metadata are taken from the first collection (in
        'graph', 'text' order) that contains the record.

        Args:
            record_id: The unique identifier of the record.
            include_embeddings: Whether to include an ``embeddings`` dict keyed
                by collection name in the result.

        Returns:
            Dictionary containing the record data, or None if not found.
        """
        include = (["documents", "metadatas", "embeddings"] if include_embeddings
                   else ["documents", "metadatas"])
        record_data = None

        for embedding_type, collection in self.collections.items():
            try:
                result = collection.get(ids=[record_id], include=include)
            except Exception:
                continue  # an unreadable collection is treated as not containing the record

            if not result["ids"]:
                continue

            metadata = result["metadatas"][0] or {}

            # Initialize record_data on first find
            if record_data is None:
                record_data = {
                    "id": record_id,
                    "name": metadata.get("name"),
                    "collections": [],
                    "documents": [result["documents"][0]] if result["documents"] else [],
                    "metadata": self._user_metadata(metadata),
                }
                if include_embeddings:
                    record_data["embeddings"] = {}

            record_data["collections"].append(embedding_type)

            if include_embeddings and result.get("embeddings") is not None:
                record_data["embeddings"][embedding_type] = result["embeddings"][0]

        return record_data

    def update_record(self,
                      record_id: str,
                      name: Optional[str] = None,
                      metadata: Optional[Dict[str, Any]] = None,
                      documents: Optional[List[str]] = None,
                      embeddings: Optional[Dict[str, List[float]]] = None) -> bool:
        """
        Update an existing record in all collections where it exists.

        Args:
            record_id: The unique identifier of the record.
            name: New name for the record (kept if None).
            metadata: New metadata, merged over the existing user metadata.
                Reserved keys are always set by this method.
            documents: New documents. Only the first entry is stored. If
                omitted, the existing document is left untouched.
            embeddings: New embeddings keyed by collection name. They replace
                existing vectors only in collections that already hold the
                record; other keys are ignored.

        Returns:
            True if every collection update succeeded, False if the record does
            not exist or any update failed.
        """
        existing_record = self.read_record(record_id)
        if not existing_record:
            return False

        updated_name = name if name is not None else existing_record["name"]
        merged_metadata = {**existing_record["metadata"], **(metadata or {})}
        success = True

        for embedding_type in existing_record["collections"]:
            collection = self.collections[embedding_type]
            update_kwargs = {
                "ids": [record_id],
                "metadatas": [{
                    **merged_metadata,
                    "name": updated_name,
                    "record_id": record_id,
                    "embedding_type": embedding_type,
                }],
            }
            # ChromaDB recomputes the embedding for documents passed to update()
            # without explicit embeddings. That would overwrite (or mismatch in
            # dimension with) custom vectors, so only send documents when the
            # caller supplies new ones.
            if documents:
                update_kwargs["documents"] = [documents[0]]
            if embeddings and embedding_type in embeddings:
                update_kwargs["embeddings"] = [embeddings[embedding_type]]

            try:
                collection.update(**update_kwargs)
            except Exception as e:
                print(f"Update error in {embedding_type} collection: {e}")
                success = False

        return success

    def delete_record(self, record_id: str) -> bool:
        """
        Delete a record by its ID from every collection that contains it.

        Args:
            record_id: The unique identifier of the record.

        Returns:
            True if the record was removed from at least one collection, False
            if it was not found anywhere or every delete failed.
        """
        deleted = False
        for embedding_type, collection in self.collections.items():
            # ChromaDB's delete() does not raise for unknown IDs, so check
            # first in order to report accurately whether anything was removed.
            if not self._exists_in(collection, record_id):
                continue
            try:
                collection.delete(ids=[record_id])
            except Exception as e:
                print(f"Delete error in {embedding_type} collection: {e}")
                continue
            deleted = True

        return deleted

    def list_records(self,
                     embedding_type: Optional[str] = None,
                     limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        List records from specified collection(s).

        Args:
            embedding_type: Type of embedding ("graph", "text", or None for both).
                Any other value yields an empty list.
            limit: Maximum number of records to return per collection.

        Returns:
            List of record dictionaries (one per record per collection).
        """
        if embedding_type is None:
            collections_to_search = self.collections
        elif embedding_type in self.collections:
            collections_to_search = {embedding_type: self.collections[embedding_type]}
        else:
            return []  # Invalid embedding_type

        records = []
        for emb_type, collection in collections_to_search.items():
            try:
                result = collection.get(
                    include=["documents", "metadatas", "embeddings"],
                    limit=limit
                )
            except Exception:
                continue

            docs = result["documents"]
            vectors = result.get("embeddings")
            for i, record_id in enumerate(result["ids"]):
                # Records written outside this class may have no metadata at all.
                metadata = result["metadatas"][i] or {}
                records.append({
                    "id": record_id,
                    "name": metadata.get("name"),
                    "embedding_type": emb_type,
                    "documents": [docs[i]] if docs else [],
                    "metadata": self._user_metadata(metadata),
                    "embeddings": vectors[i] if vectors is not None else None,
                })

        return records

    def search_records(self,
                       query_text: str,
                       embedding_type: str,
                       n_results: int = 5,
                       where: Optional[Dict[str, Any]] = None,
                       query_embedding: Optional[List[float]] = None) -> List[Dict[str, Any]]:
        """
        Search for records using similarity search in specified collection.

        Args:
            query_text: Text to search for. ChromaDB embeds it with the
                collection's embedding function. Ignored when
                ``query_embedding`` is given.
            embedding_type: Collection to search in ("graph" or "text").
            n_results: Maximum number of results to return.
            where: Metadata filter conditions.
            query_embedding: Optional precomputed query vector. Needed for
                collections populated with custom vectors, which must be
                queried with a vector of the same dimension.

        Returns:
            List of matching records with similarity distances. Empty if the
            collection is unknown or the query fails.
        """
        if embedding_type not in self.collections:
            return []

        collection = self.collections[embedding_type]
        if query_embedding is not None:
            query_arg = {"query_embeddings": [query_embedding]}
        else:
            query_arg = {"query_texts": [query_text]}

        try:
            result = collection.query(
                n_results=n_results,
                where=where,
                include=["documents", "metadatas", "distances", "embeddings"],
                **query_arg
            )

            vectors = result.get("embeddings")
            records = []
            for i, record_id in enumerate(result["ids"][0]):
                metadata = result["metadatas"][0][i] or {}
                records.append({
                    "id": record_id,
                    "name": metadata.get("name"),
                    "embedding_type": embedding_type,
                    "documents": [result["documents"][0][i]] if result["documents"] else [],
                    "metadata": self._user_metadata(metadata),
                    "embeddings": vectors[0][i] if vectors is not None else None,
                    "distance": result["distances"][0][i],
                })

            return records

        except Exception as e:
            print(f"Search error: {e}")
            return []

    def get_collection_stats(self) -> Dict[str, int]:
        """
        Get statistics about both collections.

        Returns:
            Dictionary with collection names and record counts (0 when a count fails).
        """
        stats = {}
        for name, collection in self.collections.items():
            try:
                stats[name] = collection.count()
            except Exception:
                stats[name] = 0
        return stats

    def _record_exists(self, record_id: str) -> bool:
        """Check if a record with the given ID already exists in any collection."""
        return any(self._exists_in(collection, record_id)
                   for collection in self.collections.values())

# In [None]:
# Open (or create) the persistent database stored under ./chroma_db.
vdb = ChromaFramework(persist_directory="./chroma_db")

# In [None]:
# Create records with different embedding combinations
print("Creating records...")

# Record with both graph and text embeddings (same ID in both collections)
both_embeddings_id = vdb.create_record(
    name="Multi-Modal Document",
    metadata={"type": "research_paper", "topic": "AI"},
    documents=["This is a research paper about artificial intelligence and graph networks"],
    embeddings={
        "graph": [0.1, 0.2, 0.3, 0.4, 0.5],  # Custom graph embedding
        "text": [0.6, 0.7, 0.8, 0.9, 1.0]    # Custom text embedding
    }
)
print(f"Created record in both collections: {both_embeddings_id}")

# Record with only graph embedding
graph_only_id = vdb.create_record(
    name="Graph Node",
    metadata={"node_type": "entity", "connections": 10},
    documents=["This represents an entity in a knowledge graph"],
    embeddings={
        "graph": [0.2, 0.4, 0.6, 0.8, 1.0]  # Only graph embedding
    }
)
print(f"Created graph-only record: {graph_only_id}")

# Record with only text embedding
text_only_id = vdb.create_record(
    name="Text Article",
    metadata={"category": "news", "language": "en"},
    documents=["This is a news article about recent developments"],
    embeddings={
        "text": [0.3, 0.6, 0.9, 0.2, 0.5]  # Only text embedding
    }
)
print(f"Created text-only record: {text_only_id}")

# NOTE(review): the example below stays disabled. Without explicit vectors,
# ChromaDB embeds the document with the collection's embedding function, whose
# vector size presumably differs from the 5-d custom vectors above. Confirm
# before enabling.
# # Record with no custom embeddings (auto-generated in both collections)
# auto_embeddings_id = vdb.create_record(
#     name="Auto Embeddings Document",
#     metadata={"source": "web", "auto_generated": True},
#     documents=["This document will have auto-generated embeddings in both collections"]
#     # No embeddings dict - will create in both collections with auto-generated embeddings
# )
# print(f"Created record with auto-generated embeddings: {auto_embeddings_id}")

# Read records and show which collections they exist in
print("\nReading records:")

multi_modal = vdb.read_record(both_embeddings_id, include_embeddings=True)
print(f"Multi-modal record '{multi_modal['name']}' exists in collections: {multi_modal['collections']}")
print(f"  Has embeddings for: {list(multi_modal['embeddings'].keys())}")

graph_record = vdb.read_record(graph_only_id, include_embeddings=True)
print(f"Graph record '{graph_record['name']}' exists in collections: {graph_record['collections']}")

text_record = vdb.read_record(text_only_id, include_embeddings=True)
print(f"Text record '{text_record['name']}' exists in collections: {text_record['collections']}")

# auto_record = vdb.read_record(auto_embeddings_id)
# print(f"Auto record '{auto_record['name']}' exists in collections: {auto_record['collections']}")

# Update a record - update embeddings in both collections
print(f"\nUpdating record {both_embeddings_id}...")
update_ok = vdb.update_record(
    both_embeddings_id,
    name="Updated Multi-Modal Document",
    metadata={"updated": True},
    embeddings={
        "graph": [0.9, 0.8, 0.7, 0.6, 0.5],  # New graph embedding
        "text": [0.5, 0.4, 0.3, 0.2, 0.1]    # New text embedding
    }
)
# update_record returns False when the record is missing or any collection update failed.
if not update_ok:
    print("  Update did not fully succeed")

updated_record = vdb.read_record(both_embeddings_id)
print(f"Updated record name: {updated_record['name']}")
print(f"Updated metadata: {updated_record['metadata']}")

# List records by collection
print("\nListing records by collection:")
graph_records = vdb.list_records(embedding_type="graph")
print(f"Graph collection has {len(graph_records)} records:")
for record in graph_records:
    print(f"  - {record['name']} (ID: {record['id']})")

text_records = vdb.list_records(embedding_type="text")
print(f"Text collection has {len(text_records)} records:")
for record in text_records:
    print(f"  - {record['name']} (ID: {record['id']})")

# Search in specific collections.
# NOTE(review): these are text queries that ChromaDB embeds itself, while the
# collections hold 5-d custom vectors. A dimension mismatch would make
# search_records print an error and return no matches — confirm.
print("\nSearching in collections:")
graph_search = vdb.search_records("graph network", embedding_type="graph", n_results=3)
print(f"Graph search results: {len(graph_search)} matches")
for result in graph_search:
    print(f"  - {result['name']} (distance: {result['distance']:.3f})")

text_search = vdb.search_records("artificial intelligence", embedding_type="text", n_results=3)
print(f"Text search results: {len(text_search)} matches")
for result in text_search:
    print(f"  - {result['name']} (distance: {result['distance']:.3f})")

# Collection statistics
print("\nCollection statistics:")
stats = vdb.get_collection_stats()
for collection_name, count in stats.items():
    print(f"  {collection_name}: {count} records")


# In [None]:
vdb.list_records(limit=10)  # both collections (embedding_type defaults to None), up to 10 each

# In [None]:
graph_record  # notebook cell output: show the graph-only record read earlier via read_record