# Embedding Pipeline - Custom Model Endpoints

This notebook loads data from a Delta table into Neo4j with vector embeddings
using a **custom model** deployed to Databricks Model Serving.

## Custom Model Advantages

- **Control**: Choose any sentence-transformer model
- **Fine-tuning**: Use domain-specific models
- **Dimensions**: Select appropriate embedding size
- **Cost**: May be cheaper for high volume

## Common Model Dimensions

| Model | Dimensions | Notes |
|-------|------------|-------|
| all-MiniLM-L6-v2 | 384 | Fast, good quality |
| all-mpnet-base-v2 | 768 | Higher quality |
| e5-large-v2 | 1024 | Highest quality |
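
This is a minimal sketch of catching a dimension mismatch before any data is written. The hypothetical `KNOWN_DIMENSIONS` map copies the table above, and `check_dimensions` compares it with a configured value. Names that are not in the table are rejected rather than guessed.

```python
# Hypothetical lookup copied from the Common Model Dimensions table
KNOWN_DIMENSIONS = {
    "all-MiniLM-L6-v2": 384,
    "all-mpnet-base-v2": 768,
    "e5-large-v2": 1024,
}


def expected_dimensions(model_name: str) -> int:
    """Return the output size for a known model.

    Accepts a bare name or a Hugging Face style id such as
    "sentence-transformers/all-MiniLM-L6-v2" (only the part after the last "/" is used).
    """
    short_name = model_name.rsplit("/", 1)[-1]
    try:
        return KNOWN_DIMENSIONS[short_name]
    except KeyError:
        raise ValueError(f"Unknown model {model_name!r}; set EMBEDDING_DIMENSIONS manually") from None


def check_dimensions(model_name: str, configured: int) -> None:
    """Raise if the configured dimension disagrees with the model's output size."""
    expected = expected_dimensions(model_name)
    if configured != expected:
        raise ValueError(
            f"{model_name} produces {expected}-d vectors, but EMBEDDING_DIMENSIONS={configured}"
        )
```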

## API Format

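The custom endpoint receives input in the `dataframe_records` format, with one record per input text. It returns one embedding per record under `predictions`:

```python
# Request body: one record per text, keyed by the model's input column
request = {"dataframe_records": [{"text": "hello"}]}
# Response body: one embedding (list of floats) per input record
response = {"predictions": [[0.1, 0.2, ...]]}
```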
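
The request/response round trip can be sketched with the standard library. The helper names `build_request`, `parse_predictions`, and `invoke_endpoint` are hypothetical and are not part of `embedding_providers`. The sketch makes three assumptions: the model reads a `text` column, `predictions` come back in input order, and a workspace URL and personal access token are available. Databricks serves endpoints at `/serving-endpoints/<name>/invocations`.

```python
import json
import urllib.request


def build_request(texts):
    """Wrap texts in the dataframe_records format: one record per text."""
    return {"dataframe_records": [{"text": t} for t in texts]}


def parse_predictions(response, expected_dims):
    """Extract embeddings from a {"predictions": [[...], ...]} response and check their size."""
    vectors = response["predictions"]
    for i, vec in enumerate(vectors):
        if len(vec) != expected_dims:
            raise ValueError(f"record {i}: got {len(vec)} dims, expected {expected_dims}")
    return vectors


def invoke_endpoint(host, token, endpoint_name, texts, expected_dims):
    """POST the texts to a Model Serving endpoint and return one vector per text."""
    url = f"{host.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations"
    body = json.dumps(build_request(texts)).encode("utf-8")
    req = urllib.request.Request(
        url,
        data=body,
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return parse_predictions(json.loads(resp.read()), expected_dims)
```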

## Imports and Setup

In [None]:
# =============================================================================
# PATH SETUP - Ensures modules are importable from the notebook
# =============================================================================
# This cell adds the notebook's directory to sys.path if needed
# Required when running in Databricks Repos or when modules aren't on the path

import sys
import os

# Get the directory containing this notebook
# In Databricks, use the notebook path to find the module directory
try:
    # Try to get the notebook path from Databricks context
    notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
    notebook_dir = "/Workspace" + "/".join(notebook_path.split("/")[:-1])
    
    if notebook_dir not in sys.path:
        sys.path.insert(0, notebook_dir)
        print(f"Added to sys.path: {notebook_dir}")
except Exception:
    # Fallback: add current working directory
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.insert(0, cwd)
        print(f"Added to sys.path: {cwd}")

print("Path setup complete!")
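
The path arithmetic in the cell above can be pulled out into a pure function, which makes it easy to check outside Databricks. `workspace_module_dir` is a hypothetical name. Like the cell above, it assumes `notebookPath()` returns a path without the `/Workspace` prefix. It also skips adding the prefix if one is already present.

```python
import posixpath


def workspace_module_dir(notebook_path: str) -> str:
    """Return the /Workspace directory that contains the given notebook."""
    parent = posixpath.dirname(notebook_path)  # drop the notebook name
    if parent == "/Workspace" or parent.startswith("/Workspace/"):
        return parent  # already an absolute workspace path
    return "/Workspace" + parent
```

In the cell above, `notebook_dir = workspace_module_dir(notebook_path)` would replace the manual split and join.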

In [None]:
import time
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

# Import from modular components
from load_utils import (
    Config,
    load_config,
    neo4j_driver,
    print_config,
    print_section_header,
    test_neo4j_connection,
    format_duration,
)

from neo4j_schema import (
    SchemaConfig,
    setup_neo4j_schema,
    wait_for_vector_index,
)

from embedding_providers import (
    EmbeddingConfig,
    CustomModelEmbeddingProvider,
    generate_query_embedding,
)

from streaming_pipeline import (
    PipelineConfig,
    run_pipeline,
    print_pipeline_summary,
)

print("Imports complete!")

## Configuration

Adjust these settings to control the embedding pipeline.

In [None]:
# =============================================================================
# PIPELINE CONFIGURATION - Modify these values to control the run
# =============================================================================

# Row limit: Set to a positive number to limit rows, or -1 for all rows
# Default: 500 rows for quick testing
MAX_ROWS = 500  # Set to -1 to process all rows

# Batch size: Number of rows per Neo4j transaction
# Larger = faster but more memory; Smaller = safer but slower
BATCH_SIZE = 5000

# Write partitions: Number of parallel writers to Neo4j
# 4 is a reasonable starting point; reduce it if you see lock contention
WRITE_PARTITIONS = 4

# Custom embedding model endpoint (your deployed model)
EMBEDDING_ENDPOINT = "rk_serving_airline_embedding"

# Embedding dimensions (must match your model!)
# MiniLM: 384, MPNet: 768, E5-Large: 1024
EMBEDDING_DIMENSIONS = 384

# Neo4j node label (kept separate from the Databricks-hosted model pipeline to avoid conflicts)
NODE_LABEL = "RemovalEvent"

# Neo4j schema names
CONSTRAINT_NAME = "removal_event_removal_id_unique"
VECTOR_INDEX_NAME = "removal_reason_embeddings"

# Databricks secret scope containing Neo4j credentials
SCOPE_NAME = "airline-neo4j-secrets"

# Source table in Unity Catalog
SOURCE_TABLE = "airline_test.airline_test_lakehouse.nodes_removals_large"

# Column mappings
TEXT_COLUMN = "RMV_REA_TX"
ID_COLUMN = ":ID(RemovalEvent)"

# Checkpoint location (for streaming mode when MAX_ROWS=-1)
CHECKPOINT_LOCATION = "/tmp/removal_embeddings_checkpoint"

# Whether to set up the Neo4j schema (constraint and vector index)
SETUP_SCHEMA = True

# Whether to run similarity search test after loading
TEST_SEARCH = True

# =============================================================================
# Step 1: Print header
# =============================================================================
print_section_header("EMBEDDING PIPELINE (CUSTOM MODEL)")
print(f"Embedding Model: {EMBEDDING_ENDPOINT}")
print(f"Dimensions: {EMBEDDING_DIMENSIONS}")
print(f"Neo4j Label: :{NODE_LABEL}")
print(f"Vector Index: {VECTOR_INDEX_NAME}")
print(f"Max rows: {MAX_ROWS} {'(all rows)' if MAX_ROWS == -1 else ''}")

# =============================================================================
# Step 2: Load configuration
# =============================================================================
config = load_config(
    dbutils,
    SCOPE_NAME,
    default_database="neo4j",
    default_protocol="neo4j+s",
    default_embedding_endpoint=EMBEDDING_ENDPOINT,
)
# Ensure the config points at the custom endpoint, whatever load_config resolved
config.embedding_endpoint = EMBEDDING_ENDPOINT
print_config(config, SCOPE_NAME, EMBEDDING_DIMENSIONS, BATCH_SIZE)
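
This is a rough sketch of what these settings imply for a run. `describe_run` is hypothetical and does not reproduce the logic of `run_pipeline`. It assumes rows are spread evenly across `WRITE_PARTITIONS` writers and that each writer commits one transaction per `BATCH_SIZE` rows. With the defaults (500 rows, 4 writers, batch size 5000), each writer would handle about 125 rows in a single transaction.

```python
import math


def describe_run(max_rows: int, batch_size: int, write_partitions: int) -> dict:
    """Illustrative summary of a run; not run_pipeline's actual logic."""
    if max_rows == -1:
        # No limit: per the config comments, all rows are processed in streaming mode
        return {"mode": "streaming", "rows": None, "max_transactions": None}
    if max_rows <= 0:
        raise ValueError("MAX_ROWS must be a positive number or -1")
    # Assumption: rows are spread evenly over writers; a writer with no rows commits nothing
    active_writers = min(write_partitions, max_rows)
    rows_per_writer = math.ceil(max_rows / active_writers)
    # Assumption: each writer commits one transaction per batch_size rows
    batches_per_writer = math.ceil(rows_per_writer / batch_size)
    return {
        "mode": "batch",
        "rows": max_rows,
        "max_transactions": active_writers * batches_per_writer,
    }
```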

## Helper Functions

In [None]:
def select_columns(df: DataFrame) -> DataFrame:
    """Select and rename columns from the source table."""
    return df.select(
        col(f"`{ID_COLUMN}`").alias("removal_id"),
        col(TEXT_COLUMN).alias("removal_reason"),
        col("RMV_TRK_NO").alias("rmv_trk_no"),
        col("component_id"),
        col("aircraft_id"),
        col("removal_date"),
    )


def verify_embeddings(config: Config) -> bool:
    """Verify embeddings were stored correctly using Spark Connector."""
    print_section_header("VERIFYING EMBEDDINGS")

    df = (
        spark.read.format("org.neo4j.spark.DataSource")
        .option("url", config.uri)
        .option("authentication.basic.username", config.username)
        .option("authentication.basic.password", config.password)
        .option("database", config.database)
        .option("labels", NODE_LABEL)
        .load()
    )

    total_count = df.count()
    print(f"Total {NODE_LABEL} nodes: {total_count:,}")

    if "embedding" not in df.columns:
        print("Warning: 'embedding' column not found!")
        return False

    with_embeddings = df.filter(col("embedding").isNotNull())
    embedding_count = with_embeddings.count()
    print(f"Nodes with embeddings: {embedding_count:,}")

    print("\nSample nodes:")
    sample = with_embeddings.limit(3).collect()

    all_valid = True
    for i, row in enumerate(sample):
        emb = row["embedding"]
        if emb is None:
            print(f"  [{i+1}] No embedding")
            all_valid = False
        elif len(emb) != EMBEDDING_DIMENSIONS:
            print(f"  [{i+1}] Wrong dimensions ({len(emb)})")
            all_valid = False
        else:
            preview = [f"{v:.4f}" for v in emb[:3]]
            print(f"  [{i+1}] {len(emb)} dims, [{', '.join(preview)}, ...]")

    return all_valid and embedding_count == total_count


def test_vector_search(config: Config, test_text: str = "hydraulic pump failure"):
    """Test vector similarity search with the loaded embeddings."""
    print_section_header("TESTING VECTOR SEARCH")

    print(f"Query: '{test_text}'")

    print("\nGenerating query embedding...")
    test_embedding = generate_query_embedding(
        EMBEDDING_ENDPOINT,
        test_text,
        api_format="custom",  # Custom model format
    )
    print(f"  Dimensions: {len(test_embedding)}")

    query = """
        CALL db.index.vector.queryNodes($index_name, $top_k, $embedding)
        YIELD node, score
        RETURN node.removal_id AS removal_id,
               node.removal_reason AS reason,
               score
    """

    print("\nTop 5 similar items:")
    with neo4j_driver(config) as driver:
        with driver.session(database=config.database) as session:
            # Pass the index name and k as parameters instead of formatting them into the query
            result = session.run(query, index_name=VECTOR_INDEX_NAME, top_k=5, embedding=test_embedding)
            records = list(result)

            if not records:
                print("  No results found. Is the vector index populated?")
            else:
                for i, record in enumerate(records):
                    reason = record["reason"] or ""
                    print(f"  [{i+1}] Score: {record['score']:.4f}")
                    print(f"      ID: {record['removal_id']}")
                    print(f"      Reason: {reason}")
                    print()

print("Helper functions defined!")

## Run the Pipeline

Execute each step below to run the embedding pipeline with your custom model.

In [None]:
# Step 3: Test Neo4j connection
if not test_neo4j_connection(config):
    raise Exception("Neo4j connection failed! Check credentials and network.")

In [None]:
# Step 4: Validate embedding endpoint
embedding_config = EmbeddingConfig(
    endpoint_name=EMBEDDING_ENDPOINT,
    dimensions=EMBEDDING_DIMENSIONS,
    text_column="removal_reason",
    output_column="embedding",
)
embedding_provider = CustomModelEmbeddingProvider(embedding_config)

print_section_header("VALIDATING EMBEDDING ENDPOINT")
if not embedding_provider.validate_endpoint():
    raise Exception(
        f"Embedding endpoint validation failed!\n"
        f"Check that your Model Serving endpoint '{EMBEDDING_ENDPOINT}' is running\n"
        f"and EMBEDDING_DIMENSIONS ({EMBEDDING_DIMENSIONS}) matches your model."
    )
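
The internals of `validate_endpoint()` live in `embedding_providers` and are not shown here. Assuming a useful probe check looks at the first vector in the response, the hypothetical `validate_probe` below tests three things: the vector has the right length, every value is finite and numeric, and the norm is non-zero (cosine similarity is undefined for a zero vector).

```python
import math


def validate_probe(predictions, expected_dims):
    """Return a list of problems found in a probe response (an empty list means OK)."""
    if not predictions:
        return ["endpoint returned no predictions"]
    problems = []
    vec = predictions[0]
    if len(vec) != expected_dims:
        problems.append(f"dimension mismatch: got {len(vec)}, expected {expected_dims}")
    if not all(isinstance(v, (int, float)) and math.isfinite(v) for v in vec):
        problems.append("vector contains non-numeric or non-finite values")
    elif math.sqrt(sum(v * v for v in vec)) == 0.0:
        problems.append("zero vector: cosine similarity is undefined")
    return problems
```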

In [None]:
# Step 5: Set up the Neo4j schema (if requested)
if SETUP_SCHEMA:
    schema_config = SchemaConfig(
        node_label=NODE_LABEL,
        id_property="removal_id",
        embedding_dimensions=EMBEDDING_DIMENSIONS,
        constraint_name=CONSTRAINT_NAME,
        vector_index_name=VECTOR_INDEX_NAME,
    )
    setup_neo4j_schema(config, schema_config)
    wait_for_vector_index(config, VECTOR_INDEX_NAME)
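
`setup_neo4j_schema` is not shown in this notebook. The Cypher below sketches, in Neo4j 5 syntax, the two objects this step creates. Two details are assumptions: the `embedding` property name (taken from `output_column`) and the `cosine` similarity function. The sketch also assumes labels and names are plain identifiers. Because of `IF NOT EXISTS`, rerunning it against an existing index with the same name does nothing. Changing `EMBEDDING_DIMENSIONS` therefore also means dropping the old index.

```python
def schema_statements(label, id_property, constraint_name, index_name, dimensions,
                      similarity="cosine", embedding_property="embedding"):
    """Build Cypher for a uniqueness constraint and a vector index (illustrative)."""
    constraint = (
        f"CREATE CONSTRAINT {constraint_name} IF NOT EXISTS "
        f"FOR (n:{label}) REQUIRE n.{id_property} IS UNIQUE"
    )
    # Doubled braces render as literal { } in the OPTIONS map
    index = (
        f"CREATE VECTOR INDEX {index_name} IF NOT EXISTS "
        f"FOR (n:{label}) ON n.{embedding_property} "
        f"OPTIONS {{indexConfig: {{"
        f"`vector.dimensions`: {dimensions}, "
        f"`vector.similarity_function`: '{similarity}'}}}}"
    )
    return [constraint, index]
```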

In [None]:
# Step 6: Configure pipeline
pipeline_config = PipelineConfig(
    source_table=SOURCE_TABLE,
    node_label=NODE_LABEL,
    id_column="removal_id",
    batch_size=BATCH_SIZE,
    write_partitions=WRITE_PARTITIONS,
    checkpoint_location=CHECKPOINT_LOCATION,
    max_files_per_trigger=1,
    max_rows=MAX_ROWS,
)

# Validate max_rows was set correctly
print(f"PipelineConfig created: max_rows={pipeline_config.max_rows}")
if MAX_ROWS > 0 and pipeline_config.max_rows != MAX_ROWS:
    raise ValueError(f"max_rows mismatch: expected {MAX_ROWS}, got {pipeline_config.max_rows}")

In [None]:
# Step 7: Run pipeline
pipeline_start = time.time()

stats = run_pipeline(
    spark=spark,
    neo4j_config=config,
    pipeline_config=pipeline_config,
    embedding_provider=embedding_provider,
    column_selector=select_columns,
    clear_checkpoint=True,  # Fresh start
    dbutils=dbutils,
)

pipeline_duration = time.time() - pipeline_start
print(f"\nPipeline completed in {format_duration(pipeline_duration)}")
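
To compare a run with the Performance Expectations table under Interpret Results, divide the row count by the elapsed seconds. The hypothetical `throughput_report` below copies the table's ranges. Note that `pipeline_duration` includes Spark and connector overhead, so a 500-row test run can look slower than the steady-state rates.

```python
# Ranges (rows/second) copied from the Performance Expectations table, keyed by dimension
EXPECTED_THROUGHPUT = {384: (200, 500), 768: (100, 300), 1024: (50, 150)}


def throughput_report(rows: int, seconds: float, dimensions: int) -> tuple:
    """Return (rows per second, 'below' | 'within' | 'above' the expected range)."""
    if seconds <= 0:
        raise ValueError("duration must be positive")
    rate = rows / seconds
    low, high = EXPECTED_THROUGHPUT[dimensions]
    if rate < low:
        verdict = "below"
    elif rate > high:
        verdict = "above"
    else:
        verdict = "within"
    return rate, verdict


# In this notebook (meaningful only when MAX_ROWS > 0):
# rate, verdict = throughput_report(MAX_ROWS, pipeline_duration, EMBEDDING_DIMENSIONS)
```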

In [None]:
# Step 8: Verify embeddings
verification_passed = verify_embeddings(config)
print(f"\nVerification: {'PASSED' if verification_passed else 'FAILED'}")
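
`verify_embeddings` checks dimensions on only three sampled rows. A single Cypher aggregation can check every node instead. `verification_query` is a hypothetical helper, and the query assumes embeddings are stored as list properties named `embedding`. `count(expr)` skips nulls, so nodes without an embedding are not counted as wrong-sized.

```python
def verification_query(label: str, embedding_property: str = "embedding") -> str:
    """Cypher that counts nodes, nodes with embeddings, and wrong-sized vectors."""
    safe_label = label.replace("`", "``")  # escape backticks inside a quoted identifier
    prop = embedding_property.replace("`", "``")
    return (
        f"MATCH (n:`{safe_label}`) "
        f"RETURN count(n) AS total, "
        f"count(n.`{prop}`) AS with_embedding, "
        f"count(CASE WHEN size(n.`{prop}`) <> $dims THEN 1 END) AS wrong_dims"
    )


# Usage with the notebook's helpers (not run here):
# with neo4j_driver(config) as driver:
#     with driver.session(database=config.database) as session:
#         row = session.run(verification_query(NODE_LABEL), dims=EMBEDDING_DIMENSIONS).single()
```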

In [None]:
# Step 9: Test similarity search (if requested)
if TEST_SEARCH:
    test_vector_search(config)
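
`db.index.vector.queryNodes` returns approximate nearest neighbours with scores normalized to [0, 1]. The brute-force sketch below shows the ranking idea under two assumptions: the index uses cosine similarity, and the score is mapped as (1 + cos) / 2. The sketch is exact rather than approximate. It is meant for intuition on small data, not as a replacement for the index.

```python
import math


def cosine_score(a, b):
    """Cosine similarity rescaled to [0, 1] as (1 + cos) / 2."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return (1 + dot / norm) / 2


def top_k(query, candidates, k=5):
    """Rank (id, vector) pairs by score, highest first, and keep the best k."""
    scored = [(cid, cosine_score(query, vec)) for cid, vec in candidates]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:k]
```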

In [None]:
# Step 10: Print summary
print_pipeline_summary(stats, pipeline_config, embedding_provider)
print("\nDone!")

## Interpret Results

### Performance Expectations

| Model | Typical Throughput |
|-------|-------------------|
| MiniLM-L6 (384d) | 200-500 rows/second |
| MPNet (768d) | 100-300 rows/second |
| E5-Large (1024d) | 50-150 rows/second |
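
To plan a full run (`MAX_ROWS = -1`), divide the row count by the rate. As an illustration only, here is the calculation for a hypothetical table of one million rows at the MiniLM-L6 rates above. Fixed overhead such as cluster startup and schema setup comes on top of this.

```latex
\begin{aligned}
t &= \frac{N}{r}, \qquad N = 10^{6}\ \text{rows}, \quad r \in [200, 500]\ \text{rows/s} \\
t &\in \left[\frac{10^{6}}{500}, \frac{10^{6}}{200}\right]\ \text{s} = [2000, 5000]\ \text{s} \approx [33, 83]\ \text{min}
\end{aligned}
```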

### Common Issues

| Error | Solution |
|-------|----------|
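| Dimension mismatch | Set EMBEDDING_DIMENSIONS to your model's output size; drop and recreate the vector index if it was built with a different size |
| Endpoint not found | Check that the Model Serving endpoint exists, is running, and matches EMBEDDING_ENDPOINT |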
| Rate limiting | Reduce BATCH_SIZE |
| Lock contention | Reduce WRITE_PARTITIONS |

### Custom Model Checklist

- [ ] Model Serving endpoint is running
- [ ] EMBEDDING_ENDPOINT matches your endpoint name
- [ ] EMBEDDING_DIMENSIONS matches your model's output
- [ ] Input format uses `dataframe_records` with `text` field