# Neo4j Write Performance Test

This notebook tests Neo4j write performance using **random embeddings** to isolate
database throughput from embedding generation overhead.

## What This Test Does

1. Reads data from a Delta table
2. Generates random 384-dimensional float arrays as "embeddings"
3. Writes nodes with embeddings to Neo4j
4. Reports throughput statistics

## Why Use Random Embeddings?

- Establishes baseline Neo4j write throughput
- Helps tune batch size and parallelism settings
- Identifies whether Neo4j or embedding generation is the bottleneck

**Typical Results:**
- Random embeddings: 2,000-5,000 rows/second
- With ai_query: 100-500 rows/second

## Imports and Setup

In [None]:
# =============================================================================
# PATH SETUP - Ensures modules are importable from the notebook
# =============================================================================
# This cell adds the notebook's directory to sys.path if needed
# Required when running in Databricks Repos or when modules aren't on the path

import sys
import os

# Work out which directory holds the modules that sit next to this notebook.
# On Databricks the notebook context reports the notebook's workspace path;
# if that lookup fails for any reason (e.g. `dbutils` is not defined), the
# current working directory is used instead.
try:
    notebook_path = (
        dbutils.notebook.entry_point.getDbutils()
        .notebook()
        .getContext()
        .notebookPath()
        .get()
    )
    # Drop the final path component (the notebook name) to get its folder.
    notebook_dir = "/Workspace" + notebook_path.rpartition("/")[0]
    _module_dir = notebook_dir
except Exception:
    cwd = os.getcwd()
    _module_dir = cwd

# Put the chosen directory at the front of sys.path, once.
if _module_dir not in sys.path:
    sys.path.insert(0, _module_dir)
    print(f"Added to sys.path: {_module_dir}")

print("Path setup complete!")

In [None]:
import time
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

# Import from modular components
from load_utils import (
    Config,
    load_config,
    neo4j_driver,
    print_config,
    print_section_header,
    test_neo4j_connection,
    format_duration,
)

from neo4j_schema import (
    SchemaConfig,
    setup_neo4j_schema,
    delete_nodes_by_label,
)

from embedding_providers import (
    EmbeddingConfig,
    RandomEmbeddingProvider,
)

from streaming_pipeline import (
    PipelineConfig,
    run_pipeline,
    print_pipeline_summary,
)

print("Imports complete!")

## Configuration

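Adjust these settings to control the test run. Note that the default `MAX_ROWS` (500) is smaller than the default `BATCH_SIZE` (5000), so the quick default run fits within a single batch; raise `MAX_ROWS` (or set it to -1) when comparing batch sizes or partition counts.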

In [None]:
# =============================================================================
# TEST CONFIGURATION - Modify these values to control the test
# =============================================================================

# Row limit. A positive value caps how many rows are processed; -1 means
# "process every row". The default of 500 keeps a quick test short.
MAX_ROWS = 500  # Set to -1 to process all rows

# Rows per Neo4j transaction.
# Larger batches are faster but use more memory; smaller ones are safer
# but slower.
BATCH_SIZE = 5000

# Number of parallel writers to Neo4j.
# 1 is serial (safest), 2-4 is moderate parallelism, and more than 4 may
# cause lock contention.
WRITE_PARTITIONS = 1

# Length of each embedding vector (must match production settings).
EMBEDDING_DIMENSIONS = 384

# Node label used for test data, kept separate from production.
TEST_LABEL = "RemovalEventTest"

# Databricks secret scope that holds the Neo4j credentials.
SCOPE_NAME = "airline-neo4j-secrets"

# Unity Catalog table the test reads from.
SOURCE_TABLE = "airline_test.airline_test_lakehouse.nodes_removals_large"

# Source column names: the text field and the node identifier.
TEXT_COLUMN = "RMV_REA_TX"
ID_COLUMN = ":ID(RemovalEvent)"

# Checkpoint directory (for streaming mode when MAX_ROWS=-1).
CHECKPOINT_LOCATION = "/tmp/neo4j_write_test_checkpoint"

# Delete existing test nodes before the run?
CLEANUP_NODES = True

# Set up the Neo4j schema (constraint) before the run?
SETUP_SCHEMA = True

# Echo the key settings, one per line, so the run log records them.
print(
    "Configuration loaded!",
    f"  MAX_ROWS: {MAX_ROWS} {'(all rows)' if MAX_ROWS == -1 else ''}",
    f"  BATCH_SIZE: {BATCH_SIZE:,}",
    f"  WRITE_PARTITIONS: {WRITE_PARTITIONS}",
    f"  TEST_LABEL: {TEST_LABEL}",
    sep="\n",
)

## Helper Functions

In [None]:
def select_columns(df: DataFrame) -> DataFrame:
    """Project the source DataFrame onto the six columns the test uses.

    Three columns are renamed: the ``ID_COLUMN`` field becomes
    ``removal_id``, the ``TEXT_COLUMN`` field becomes ``removal_reason``,
    and ``RMV_TRK_NO`` becomes ``rmv_trk_no``. ``component_id``,
    ``aircraft_id`` and ``removal_date`` keep their names.

    Args:
        df: DataFrame read from the source table.

    Returns:
        A DataFrame containing only the six selected columns.
    """
    # The ID column name is wrapped in backticks so Spark treats it as one
    # identifier.
    renamed = [
        col(f"`{ID_COLUMN}`").alias("removal_id"),
        col(TEXT_COLUMN).alias("removal_reason"),
        col("RMV_TRK_NO").alias("rmv_trk_no"),
    ]
    unchanged = [
        col(name) for name in ("component_id", "aircraft_id", "removal_date")
    ]
    return df.select(*renamed, *unchanged)


def verify_test_nodes(config: Config) -> int:
    """Report what the test run wrote to Neo4j and return the node count.

    Runs three read queries against nodes labelled ``TEST_LABEL`` and prints
    each result: the total node count, the count of nodes whose
    ``embedding`` property is not null, and up to three sample nodes with
    the size of their embedding.

    Args:
        config: Connection settings handed to ``neo4j_driver``; its
            ``database`` attribute selects the session's database.

    Returns:
        The total number of nodes carrying the ``TEST_LABEL`` label.
    """
    print_section_header("VERIFYING TEST NODES")

    with neo4j_driver(config) as driver:
        with driver.session(database=config.database) as session:
            # How many test nodes exist in total?
            total_count = session.run(f"""
                MATCH (n:{TEST_LABEL})
                RETURN count(n) AS count
            """).single()["count"]
            print(f"  Total {TEST_LABEL} nodes: {total_count:,}")

            # How many of them have an embedding set?
            with_embedding = session.run(f"""
                MATCH (n:{TEST_LABEL})
                WHERE n.embedding IS NOT NULL
                RETURN count(n) AS count
            """).single()["count"]
            print(f"  Nodes with embeddings: {with_embedding:,}")

            # Spot-check a few nodes and the length of their embedding.
            sample = [*session.run(f"""
                MATCH (n:{TEST_LABEL})
                WHERE n.embedding IS NOT NULL
                RETURN n.removal_id AS id, size(n.embedding) AS dims
                LIMIT 3
            """)]
            if sample:
                print("\n  Sample nodes:")
            for node in sample:
                print(f"    ID: {node['id']}, embedding dims: {node['dims']}")

    return total_count


print("Helper functions defined!")

## Run the Test

Execute the cell below to run the Neo4j write performance test.

In [None]:
# =============================================================================
# RUN NEO4J WRITE PERFORMANCE TEST
# =============================================================================
# Runs the whole test in order: connect, optionally prepare the schema and
# clear old test nodes, write the source rows with random embeddings, then
# verify and summarise the result.

# perf_counter() is monotonic, so the elapsed time reported at the end cannot
# be skewed by system clock adjustments made while the test is running.
pipeline_start = time.perf_counter()

# Step 1: Print header
print_section_header("NEO4J WRITE PERFORMANCE TEST")
print("Testing Neo4j write performance with random embeddings")
print(f"Target label: :{TEST_LABEL}")
print(f"Embedding dimensions: {EMBEDDING_DIMENSIONS} (random)")
print(f"Max rows: {MAX_ROWS} {'(all rows)' if MAX_ROWS == -1 else ''}")

# Step 2: Load configuration
# The embedding endpoint default is a placeholder: this test uses
# RandomEmbeddingProvider (Step 6) rather than a model endpoint.
config = load_config(
    dbutils,
    SCOPE_NAME,
    default_database="neo4j",
    default_protocol="neo4j+s",
    default_embedding_endpoint="unused",
)
print_config(config, SCOPE_NAME, EMBEDDING_DIMENSIONS, BATCH_SIZE)

# Step 3: Test Neo4j connection
# Stop here on failure so no schema change, delete, or write is attempted
# against a database that cannot be reached.
if not test_neo4j_connection(config):
    raise RuntimeError("Neo4j connection failed! Check credentials and network.")

# Step 4: Setup schema (if requested)
# Constraint and vector index names are derived from the lower-cased label.
if SETUP_SCHEMA:
    schema_config = SchemaConfig(
        node_label=TEST_LABEL,
        id_property="removal_id",
        embedding_dimensions=EMBEDDING_DIMENSIONS,
        constraint_name=f"{TEST_LABEL.lower()}_removal_id_unique",
        vector_index_name=f"{TEST_LABEL.lower()}_embeddings",
    )
    setup_neo4j_schema(config, schema_config)

# Step 5: Cleanup existing test nodes (if requested)
if CLEANUP_NODES:
    delete_nodes_by_label(config, TEST_LABEL)

# Step 6: Configure embedding provider (random)
# Column names here are the post-rename names produced by select_columns.
embedding_config = EmbeddingConfig(
    endpoint_name="random",
    dimensions=EMBEDDING_DIMENSIONS,
    text_column="removal_reason",
    output_column="embedding",
)
embedding_provider = RandomEmbeddingProvider(embedding_config)

print_section_header("VALIDATING EMBEDDING PROVIDER")
embedding_provider.validate_endpoint()

# Step 7: Configure pipeline
pipeline_config = PipelineConfig(
    source_table=SOURCE_TABLE,
    node_label=TEST_LABEL,
    id_column="removal_id",
    batch_size=BATCH_SIZE,
    write_partitions=WRITE_PARTITIONS,
    checkpoint_location=CHECKPOINT_LOCATION,
    max_files_per_trigger=1,
    max_rows=MAX_ROWS,
)

# Step 8: Run pipeline
stats = run_pipeline(
    spark=spark,
    neo4j_config=config,
    pipeline_config=pipeline_config,
    embedding_provider=embedding_provider,
    column_selector=select_columns,
    clear_checkpoint=True,
    dbutils=dbutils,
)

# Step 9: Verify results
verify_test_nodes(config)

# Step 10: Print summary
print_pipeline_summary(stats, pipeline_config, embedding_provider)

# Elapsed seconds for the whole cell, including setup and verification.
total_time = time.perf_counter() - pipeline_start
print(f"\nTotal test time: {format_duration(total_time)}")
print("\nDone!")

## Interpret Results

After running the test, review the statistics:

- **Rows per second**: This is your Neo4j write throughput baseline
- **Batch time**: Time to process each chunk (should be consistent)
- **Total nodes**: Verify all expected nodes were created

### Tuning Recommendations

| Symptom | Adjustment |
|---------|------------|
| Low throughput | Increase BATCH_SIZE (try 10000) |
| Memory errors | Decrease BATCH_SIZE (try 1000) |
| Lock contention | Decrease WRITE_PARTITIONS to 1 |
| Need more speed | Increase WRITE_PARTITIONS (try 2-4) |