In [1]:
from rdflib import BNode, Graph, Literal, Namespace, URIRef
from rdflib.namespace import RDF, XSD
from humemai.rdflib import Humemai

# Instantiate Humemai (imported from humemai.rdflib). The object is not
# referenced anywhere else in this excerpt — presumably later notebook cells
# use it; TODO confirm.
humemai = Humemai()

In [5]:
# Cell: JanusGraph Manager - Add and Delete Data with Streaming and Corrected Counting

import nest_asyncio
import asyncio
import logging
import random
from itertools import islice
from gremlin_python.driver.client import Client
from gremlin_python.driver.protocol import GremlinServerError
from tqdm.auto import tqdm

# Apply nest_asyncio to allow nested event loops (useful in Jupyter notebooks,
# where the helpers below call asyncio.run() from inside the notebook kernel).
nest_asyncio.apply()

# Configure logging: INFO level on the root logger, plus a module-level logger
# that every helper in this cell writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the Gremlin client. This single module-level client is shared by
# every query helper below, and it is closed both by main_async() and at the
# end of this cell, so it is not reusable after either of those has run.
GREMLIN_SERVER_URL = "ws://localhost:8282/gremlin"  # Update if different
client = Client(GREMLIN_SERVER_URL, "g")

# # Define the 'entity_instance_of' dictionary
# # For streaming purposes, consider loading data from an external source or generator
# # Here, we'll simulate streaming with a generator function
# entity_instance_of = {
#     "Dog": ["Mammal", "Canine"],
#     "Cat": ["Mammal", "Feline"],
#     "Snake": ["Reptile", "Serpent"],
#     "Lizard": ["Reptile"],
#     "Crocodile": ["Reptile", "Crocodilian"],
#     "Crocodilian": ["Reptile"],
#     "Mammal": ["Animal"],
#     "Reptile": ["Animal"],
#     "Canine": ["Carnivore"],
#     "Feline": ["Carnivore"],
#     "Carnivore": ["Animal"],
#     "Serpent": ["Reptile"],
#     "Animal": [],  # Top-level entity with no parents
#     # Add more entities and relationships as needed
# }


def data_generator(entity_dict, slice_size=10000):
    """
    Lazily walk ``entity_dict`` in chunks of at most ``slice_size`` entries.

    For every chunk, the child labels and all of their parent labels are
    collected into one set, and each (child, parent) pair is recorded in
    iteration order.

    Args:
        entity_dict (dict): Dictionary mapping child labels to parent labels.
        slice_size (int): Number of entities to process per slice.

    Yields:
        tuple: (set of unique labels, list of (child_label, parent_label) tuples)
    """
    remaining_items = iter(entity_dict.items())

    def take_chunk():
        # Pull the next ``slice_size`` entries; an empty dict signals exhaustion.
        return dict(islice(remaining_items, slice_size))

    # iter(callable, sentinel) stops as soon as take_chunk() returns {}.
    for chunk in iter(take_chunk, {}):
        pairs = [
            (child, parent)
            for child, parents in chunk.items()
            for parent in parents
        ]
        labels = set(chunk)
        labels.update(parent for _, parent in pairs)
        yield labels, pairs


async def submit_query(query, operation="query", bindings=None):
    """
    Submit a Gremlin query with basic error handling.

    Errors are logged and swallowed rather than raised, so callers must check
    for a ``None`` return value to detect a failed submission.

    NOTE(review): the result set is returned without being read. Errors raised
    while the server evaluates the script may only surface when the caller
    consumes the result set — confirm against the driver's behaviour.

    Args:
        query (str): The Gremlin query string.
        operation (str): Description of the operation (for logging purposes).
        bindings (dict | None): Optional script parameter bindings forwarded
            to the client. Prefer bindings over interpolating values into
            ``query``. Defaults to None (no bindings).

    Returns:
        ResultSet | None: The result of the query, or None if submitting it
        raised an exception.
    """
    try:
        future = client.submitAsync(query, bindings=bindings)
        # The driver hands back a concurrent.futures-style future; wrap it so
        # the event loop can await it without blocking.
        result = await asyncio.wrap_future(future)
        logger.info(f"Successfully executed {operation}.")
        return result
    except GremlinServerError as e:
        # Python 3 exceptions have no built-in `.message` attribute; format the
        # exception itself so this handler cannot raise AttributeError.
        logger.error(f"Gremlin Server Error during {operation}: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error during {operation}: {e}")
        return None


async def upsert_vertex(label):
    """
    Insert a vertex with the given label if it doesn't already exist.

    The label is embedded in the script as a single-quoted string literal, so
    backslashes, single quotes and line breaks are escaped first. Without the
    escaping, a label such as ``O'Brien`` would terminate the literal early and
    break (or alter) the query.

    Args:
        label (str): The unique label of the entity.
    """
    # Escape order matters: backslashes first, so the escapes added afterwards
    # are not themselves doubled.
    # NOTE(review): assumes the server's script engine uses backslash escapes
    # inside single-quoted strings (as Groovy does) — confirm for this server.
    safe_label = (
        str(label)
        .replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )
    query = f"""
        g.V().hasLabel('{safe_label}').fold().
        coalesce(
            unfold(),
            addV('{safe_label}')
        )
    """
    await submit_query(query, operation=f"upserting vertex '{label}'")


async def upsert_edge(child_label, parent_label):
    """
    Insert an edge 'subclass_of' from child to parent if it doesn't already exist.

    Both labels are embedded in the script as single-quoted string literals,
    so they are escaped first. Without the escaping, a label containing a
    single quote or backslash would break (or alter) the query.

    Args:
        child_label (str): Label of the child entity.
        parent_label (str): Label of the parent entity.
    """

    def _escape(value):
        # Backslashes first, so the escapes added afterwards are not doubled.
        # NOTE(review): assumes the server's script engine uses backslash
        # escapes inside single-quoted strings (as Groovy does) — confirm.
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace("'", "\\'")
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )

    safe_child = _escape(child_label)
    safe_parent = _escape(parent_label)
    query = f"""
        g.V().hasLabel('{safe_child}').as('child').
        V().hasLabel('{safe_parent}').as('parent').
        coalesce(
            __.select('child').outE('subclass_of').where(__.inV().as('parent')),
            __.select('child').addE('subclass_of').to('parent')
        )
    """
    await submit_query(
        query, operation=f"upserting edge from '{child_label}' to '{parent_label}'"
    )


async def process_vertices(unique_labels, max_workers=3):
    """
    Upsert one vertex per label, running at most ``max_workers`` upserts at once.

    A task is created for every label up front; a semaphore limits how many of
    them are inside ``upsert_vertex`` at the same time. Failures of individual
    tasks are logged and do not stop the remaining ones.

    Args:
        unique_labels (set): Set of unique entity labels.
        max_workers (int): Number of concurrent tasks for processing.
    """
    gate = asyncio.Semaphore(max_workers)

    async def guarded_upsert(vertex_label):
        # Hold one semaphore slot for the duration of a single upsert.
        async with gate:
            await upsert_vertex(vertex_label)

    pending = [
        asyncio.create_task(guarded_upsert(vertex_label))
        for vertex_label in unique_labels
    ]

    # Drain tasks in completion order so the progress bar advances as work finishes.
    progress = tqdm(
        asyncio.as_completed(pending),
        total=len(pending),
        desc="Upserting Vertices",
    )
    for finished in progress:
        try:
            await finished
        except Exception as e:
            logger.error(f"Error during vertex upsert: {e}")


async def process_edges(relationships, batch_size=5, max_workers=3):
    """
    Upsert edges batch by batch; each batch is awaited before the next starts.

    Args:
        relationships (list): List of (child, parent) tuples.
        batch_size (int): Number of relationships to process per batch.
        max_workers (int): Number of concurrent tasks for processing.
    """
    total = len(relationships)
    logger.info(f"Total relationships to process: {total}")

    # Walk the start offset of every batch; batch numbers are 1-based for logging.
    batch_starts = range(0, total, batch_size)
    for batch_number, start in enumerate(batch_starts, start=1):
        batch = relationships[start : start + batch_size]
        logger.info(
            f"Processing edge batch {batch_number} with {len(batch)} relationships."
        )
        await _process_edge_batch(batch, max_workers)
        logger.info(f"Edge batch {batch_number} processed.")


async def _process_edge_batch(batch, max_workers):
    """
    Upsert every edge of one batch, with at most ``max_workers`` in flight.

    Failures of individual edge upserts are logged and do not stop the rest of
    the batch.

    Args:
        batch (list): List of (child, parent) tuples.
        max_workers (int): Number of concurrent tasks for processing.
    """
    gate = asyncio.Semaphore(max_workers)

    async def guarded_upsert(child_label, parent_label):
        # Hold one semaphore slot for the duration of a single edge upsert.
        async with gate:
            await upsert_edge(child_label, parent_label)

    pending = [asyncio.create_task(guarded_upsert(*pair)) for pair in batch]

    # Drain tasks in completion order so the progress bar advances as work finishes.
    progress = tqdm(
        asyncio.as_completed(pending),
        total=len(pending),
        desc="Upserting Edges",
    )
    for finished in progress:
        try:
            await finished
        except Exception as e:
            logger.error(f"Error during edge upsert: {e}")


def get_data_stream(entity_dict, slice_size=10000):
    """
    Build the slice generator for ``entity_dict``.

    This is a thin wrapper that delegates to ``data_generator``.

    Args:
        entity_dict (dict): Dictionary mapping child labels to parent labels.
        slice_size (int): Number of entities to process per slice.

    Returns:
        generator: Yields (set of unique labels, list of
        (child_label, parent_label) tuples) for each slice.
    """
    stream = data_generator(entity_dict, slice_size)
    return stream


async def delete_all_vertices():
    """
    Delete all vertices (and consequently all edges) from the JanusGraph database.

    ``submit_query`` reports failure by returning None rather than raising, so
    the return value is checked before success is logged.
    """
    delete_query = "g.V().drop().iterate()"
    try:
        result = await submit_query(
            delete_query, operation="deleting all vertices and edges"
        )
    except Exception as e:
        logger.error(f"Failed to delete all data: {e}")
        return

    if result is None:
        # submit_query has already logged the underlying error.
        logger.error("Failed to delete all data: the delete query was not executed.")
        return

    logger.info(
        "All vertices and their associated edges have been successfully deleted."
    )


async def delete_vertices_in_batches(batch_size=100):
    """
    Delete all vertices in smaller batches to minimize load and lock contention.

    Each iteration drops up to ``batch_size`` vertices and then counts what is
    left. The loop ends when the count reaches zero or when any step fails.

    Args:
        batch_size (int): Number of vertices to delete per batch. Must be a
            positive integer; other values are rejected with an error log.
    """
    if batch_size < 1:
        # limit(0) drops nothing, so the loop below would never terminate.
        logger.error(
            f"Invalid batch_size {batch_size}; it must be a positive integer."
        )
        return

    logger.info("Starting deletion of vertices in batches...")
    count_query = "g.V().count()"
    while True:
        delete_batch_query = f"""
            g.V().limit({batch_size}).drop().iterate()
        """
        try:
            drop_result = await submit_query(
                delete_batch_query,
                operation=f"deleting a batch of {batch_size} vertices",
            )
            if drop_result is None:
                # submit_query already logged the cause; retrying blindly could
                # spin forever if the failure is persistent.
                logger.error(
                    "Failed to delete a batch of vertices; aborting batch deletion."
                )
                break
            # Read the drop's results before counting, without blocking the loop.
            # NOTE(review): relies on ResultSet.all() completing only once the
            # response has been fully received — confirm against the driver.
            await asyncio.wrap_future(drop_result.all())

            # Check if any vertices remain
            count_result = await submit_query(
                count_query, operation="counting remaining vertices"
            )
            if count_result is None:
                logger.warning("Failed to retrieve remaining vertex count.")
                break

            counts = await asyncio.wrap_future(count_result.all())
            remaining = counts[0] if counts else 0
            logger.info(
                f"Deleted a batch of {batch_size} vertices. Remaining vertices: {remaining}"
            )
            if remaining == 0:
                logger.info("All vertices have been successfully deleted.")
                break
        except Exception as e:
            logger.error(f"Failed to delete a batch of vertices: {e}")
            break


async def add_data(
    slice_size=10000, max_workers_vertices=3, max_workers_edges=10, batch_size=100
):
    """
    Load the module-level ``entity_instance_of`` mapping into JanusGraph.

    The mapping is consumed slice by slice. Within each slice all vertices are
    upserted first, then the edges that connect them.

    Args:
        slice_size (int): Number of entities to process per data slice.
        max_workers_vertices (int): Number of concurrent tasks for vertex upsertion.
        max_workers_edges (int): Number of concurrent tasks for edge upsertion.
        batch_size (int): Number of relationships to process per edge batch.
    """
    slices = get_data_stream(entity_instance_of, slice_size)
    for slice_number, (unique_labels, relationships) in enumerate(slices, start=1):
        logger.info(
            f"Processing slice {slice_number} with {len(unique_labels)} unique labels and {len(relationships)} relationships."
        )

        # Vertices go in before edges so both endpoints exist for every edge.
        await process_vertices(unique_labels, max_workers=max_workers_vertices)
        await process_edges(
            relationships, batch_size=batch_size, max_workers=max_workers_edges
        )


async def delete_data(batch_size=100, method="batch"):
    """
    Remove all data from the JanusGraph database using the chosen strategy.

    Args:
        batch_size (int): Number of vertices to delete per batch (used in batch deletion).
        method (str): Deletion method - 'all' or 'batch'.
                      'all' deletes all vertices in one operation.
                      'batch' deletes vertices in smaller batches.
    """
    # Single-shot deletion.
    if method == "all":
        await delete_all_vertices()
        return

    # Incremental deletion to minimize lock contention.
    if method == "batch":
        await delete_vertices_in_batches(batch_size=batch_size)
        return

    logger.error("Invalid deletion method specified. Choose 'all' or 'batch'.")


async def close_client_async():
    """
    Close the shared Gremlin client.

    Although this is a coroutine, the close call itself is a plain synchronous
    call. Any error raised while closing is logged rather than propagated.
    """
    try:
        client.close()
    except Exception as e:
        logger.error(f"Error while closing client: {e}")
    else:
        logger.info("Disconnected from the Gremlin server.")


async def main_async(
    operation="add",
    slice_size=10000,
    max_workers_vertices=3,
    max_workers_edges=10,
    batch_size=100,
    delete_method="batch",
):
    """
    Main asynchronous function to perform add or delete operations based on arguments.

    The shared Gremlin client is always closed before this coroutine returns,
    including when the requested operation raises. The client is therefore not
    usable for further queries afterwards.

    Args:
        operation (str): 'add' to add data or 'delete' to delete all data.
        slice_size (int): Number of entities to process when adding data.
        max_workers_vertices (int): Number of concurrent tasks for vertex upsertion.
        max_workers_edges (int): Number of concurrent tasks for edge upsertion.
        batch_size (int): Number of relationships per edge batch or number of vertices per deletion batch.
        delete_method (str): Method of deletion - 'all' or 'batch'.
    """
    try:
        if operation == "add":
            logger.info("Starting data addition process...")
            await add_data(
                slice_size=slice_size,
                max_workers_vertices=max_workers_vertices,
                max_workers_edges=max_workers_edges,
                batch_size=batch_size,
            )
        elif operation == "delete":
            logger.info("Starting data deletion process...")
            await delete_data(batch_size=batch_size, method=delete_method)
        else:
            logger.error("Invalid operation specified. Choose 'add' or 'delete'.")
    finally:
        # Close the Gremlin connection even if the operation above raised, so
        # the connection is not leaked.
        await close_client_async()


def run_operation(
    operation="add",
    slice_size=10000,
    max_workers_vertices=3,
    max_workers_edges=10,
    batch_size=100,
    delete_method="batch",
):
    """
    Synchronous entry point: drive ``main_async`` to completion with ``asyncio.run``.

    Exceptions are logged instead of being propagated to the caller.

    Args:
        operation (str): 'add' to add data or 'delete' to delete all data.
        slice_size (int): Number of entities to process when adding data.
        max_workers_vertices (int): Number of concurrent tasks for vertex upsertion.
        max_workers_edges (int): Number of concurrent tasks for edge upsertion.
        batch_size (int): Number of relationships per edge batch or number of vertices per deletion batch.
        delete_method (str): Method of deletion - 'all' or 'batch'.
    """
    try:
        job = main_async(
            operation=operation,
            slice_size=slice_size,
            max_workers_vertices=max_workers_vertices,
            max_workers_edges=max_workers_edges,
            batch_size=batch_size,
            delete_method=delete_method,
        )
        asyncio.run(job)
    except Exception as e:
        # RuntimeError keeps its own log prefix; everything else is generic.
        if isinstance(e, RuntimeError):
            logger.error(f"Runtime error: {e}")
        else:
            logger.error(f"An error occurred: {e}")


async def count_vertices():
    """
    Count the number of vertices in the JanusGraph database.

    Returns:
        int | None: Total number of vertices, 0 if the query produced no
        result rows, or None if the query failed (the error is logged).
    """
    query = "g.V().count()"
    try:
        future = client.submitAsync(query)
        result = await asyncio.wrap_future(future)
        # Await the collected results instead of blocking the event loop on .result().
        counts = await asyncio.wrap_future(result.all())
        if counts:
            count = counts[0]
            logger.info(f"Total number of vertices: {count}")
            return count
        else:
            logger.warning("No vertices found.")
            return 0
    except GremlinServerError as e:
        # Python 3 exceptions have no built-in `.message` attribute; format the
        # exception itself so this handler cannot raise AttributeError.
        logger.error(f"Gremlin Server Error: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return None


async def count_edges():
    """
    Count the number of edges in the JanusGraph database.

    Returns:
        int | None: Total number of edges, 0 if the query produced no result
        rows, or None if the query failed (the error is logged).
    """
    query = "g.E().count()"
    try:
        future = client.submitAsync(query)
        result = await asyncio.wrap_future(future)
        # Await the collected results instead of blocking the event loop on .result().
        counts = await asyncio.wrap_future(result.all())
        if counts:
            count = counts[0]
            logger.info(f"Total number of edges: {count}")
            return count
        else:
            logger.warning("No edges found.")
            return 0
    except GremlinServerError as e:
        # Python 3 exceptions have no built-in `.message` attribute; format the
        # exception itself so this handler cannot raise AttributeError.
        logger.error(f"Gremlin Server Error: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return None


def run_count_vertices():
    """
    Blocking wrapper around ``count_vertices``.

    Returns:
        int: Total number of vertices, or None if running the coroutine
        raised (the error is logged).
    """
    try:
        return asyncio.run(count_vertices())
    except RuntimeError as e:
        logger.error(f"Runtime error: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
    # Reached only when one of the handlers above ran.
    return None


def run_count_edges():
    """
    Blocking wrapper around ``count_edges``.

    Returns:
        int: Total number of edges, or None if running the coroutine raised
        (the error is logged).
    """
    try:
        return asyncio.run(count_edges())
    except RuntimeError as e:
        logger.error(f"Runtime error: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
    # Reached only when one of the handlers above ran.
    return None


# Example Usage:

# To add data to the JanusGraph database
# Adjust the parameters as needed and uncomment the line below to execute

# run_operation(
#     operation='add',
#     slice_size=10000,              # Number of entities per slice
#     max_workers_vertices=3,        # Number of concurrent tasks for vertex upsertion
#     max_workers_edges=10,          # Number of concurrent tasks for edge upsertion
#     batch_size=100,                # Number of relationships per edge batch
# )

# To delete all data from the JanusGraph database using batch deletion
# Adjust the batch_size and method as needed and uncomment the line below to execute

# run_operation(
#     operation='delete',
#     batch_size=100,                # Number of vertices to delete per batch
#     delete_method='batch'          # Deletion method: 'all' or 'batch'
# )

# To delete all data from the JanusGraph database in one go
# Uncomment the lines below to execute

# run_operation(
#     operation='delete',
#     batch_size=0,                   # Not used in 'all' method
#     delete_method='all'             # Specify deletion method as 'all'
# )

# Count the number of vertices currently stored in the graph.
# NOTE: if one of the run_operation(...) examples above is enabled, main_async
# closes the shared client when it finishes, before these counts run.

total_vertices = run_count_vertices()
print(f"Total Vertices: {total_vertices}")

# Count the number of edges currently stored in the graph.
# (Either count is None if the corresponding query failed.)

total_edges = run_count_edges()
print(f"Total Edges: {total_edges}")

# Close the client after all operations are complete; errors while closing
# are logged rather than raised.
try:
    client.close()
    logger.info("Disconnected from the Gremlin server.")
except Exception as e:
    logger.error(f"Error while closing client: {e}")

INFO:gremlinpython:Creating Client with url 'ws://localhost:8282/gremlin'
INFO:__main__:Total number of vertices: 68805


Total Vertices: 68805


INFO:__main__:Total number of edges: 53987
INFO:gremlinpython:Closing Client with url 'ws://localhost:8282/gremlin'
INFO:__main__:Disconnected from the Gremlin server.


Total Edges: 53987
