# Neo4j to Delta Lake: Spark Connector Data Extraction

This notebook demonstrates a **simplified approach** to extracting graph data from Neo4j into Databricks Delta Lake tables using the **Neo4j Spark Connector**.

---

## Overview

### What This Notebook Does
Extracts **node data** from Neo4j using the Neo4j Spark Connector and writes one Delta table per entity type (Part 2 does the same for the relationships between them):
- 👥 **Customer** - Customer profile data
- 🏦 **Bank** - Financial institution data
- 💼 **Account** - Customer account information
- 🏢 **Company** - Corporate entity information
- 📈 **Stock** - Stock and security data
- 📊 **Position** - Investment portfolio holdings
- 💰 **Transaction** - Financial transaction records

### Key Features
- ✅ **Direct Spark Pipeline** - Neo4j → Spark → Delta (no intermediate conversions)
- ✅ **Type Safety** - Fixed Spark schemas using `neo4j_schemas.py` module
- ✅ **Simple Code** - No custom config modules or helper functions
- ✅ **Metadata Tracking** - Includes node ID, labels, and ingestion timestamp

### Unity Catalog Location
This notebook writes to: `fintech.default.neo4j_databricks_graph_demo_*`

---

## Setup: Import Schema Module

Import the `neo4j_schemas` module which provides fixed Spark schemas for all node types.
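
Python resolves `import neo4j_schemas` by scanning the directories listed in `sys.path`, so the entry that gets added has to be the folder holding the module, not the `.py` file itself. The sketch below illustrates this with a throwaway module written to a temporary directory. `add_module_dir` and `demo_schemas` are hypothetical names used only for this illustration; they are not part of the notebook or of `neo4j_schemas.py`.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def add_module_dir(directory: str) -> None:
    """Append a directory to sys.path once; reject anything that is not a directory."""
    path = Path(directory)
    if not path.is_dir():
        raise ValueError(f"sys.path entries must be directories: {directory}")
    if str(path) not in sys.path:
        sys.path.append(str(path))


# Demonstration with a throwaway module (a stand-in for neo4j_schemas.py)
with tempfile.TemporaryDirectory() as tmp:
    module_file = Path(tmp) / "demo_schemas.py"
    module_file.write_text("def get_version():\n    return '0.0.0-demo'\n")

    add_module_dir(tmp)              # add the folder, not the file
    importlib.invalidate_caches()    # make sure the new file is visible to the import system
    demo_schemas = importlib.import_module("demo_schemas")
    demo_version = demo_schemas.get_version()

    sys.path.remove(str(Path(tmp)))  # clean up the temporary entry
```

Passing a path that ends in `neo4j_schemas.py` to `add_module_dir` raises a `ValueError` instead of silently adding an entry that Python would never search.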

In [None]:
# Standard library imports
import os
import sys
import traceback
from typing import Dict, Any

# Third-party imports
import pandas as pd
from pyspark.sql.functions import col, current_timestamp, lit

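# Configure Python path for the schema modules.
# sys.path entries must be directories, so add the folder that contains
# neo4j_schemas.py and databricks_constants.py rather than the files themselves.
python_repo_url = os.environ.get("PYTHON_REPO_URL")
if not python_repo_url:
    raise ValueError("PYTHON_REPO_URL environment variable is not set.")
if python_repo_url not in sys.path:
    sys.path.append(python_repo_url)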

# Import schema module - platform-agnostic schemas
from neo4j_schemas import (
    get_node_schema,
    get_version,
    list_node_schemas,
    validate_node_schema,
    # Relationship schema imports
    get_relationship_schema,
    list_relationship_schemas,
    RELATIONSHIP_METADATA,
    get_relationship_metadata,
    validate_relationship_schema,
)

# Import Databricks-specific constants
from databricks_constants import (
    NODE_TABLE_NAMES,
    RELATIONSHIP_TABLE_NAMES,
)

print(f"✅ Schema module loaded (version {get_version()})")
print(f"\nAvailable node schemas: {', '.join(list_node_schemas())}")
print(f"Available relationship schemas: {', '.join(list_relationship_schemas())}")

In [None]:
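# Optional - reset the schema before a fresh run.
# WARNING: this drops EVERY table in fintech.default, not only the tables this notebook creates.
tables = spark.sql("SHOW TABLES IN fintech.default").toPandas()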

# Generate and execute DROP TABLE statements
for table in tables['tableName']:
    drop_stmt = f"DROP TABLE fintech.default.{table}"
    print(f"Executing: {drop_stmt}")
    spark.sql(drop_stmt)

## Configure Neo4j Spark Connector

Configure connection to Neo4j using environment variables and Databricks secrets.
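
When one of the settings is absent, it helps to know which one. The sketch below is one possible way to report the missing names individually; `find_missing_settings` and the sample values are hypothetical and only stand in for what the notebook reads from `os.environ` and `dbutils.secrets`.

```python
from typing import Dict, List, Mapping, Optional


def find_missing_settings(settings: Mapping[str, Optional[str]]) -> List[str]:
    """Return the names of settings that are unset (None) or empty."""
    return [name for name, value in settings.items() if not value]


# Example values only; the real notebook reads these from the environment and a secret scope
example_settings: Dict[str, Optional[str]] = {
    "NEO4J_URL": "neo4j+s://example.databases.neo4j.io",
    "NEO4J_USERNAME": "neo4j",
    "NEO4J_DATABASE": None,  # simulate a variable that was never set
    "neo4j/password (secret)": "not-a-real-password",
}

missing = find_missing_settings(example_settings)
if missing:
    print(f"Missing Neo4j configuration: {', '.join(missing)}")
```

Naming the missing entries in the error message saves a round of checking each variable by hand.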

In [None]:
# Get Neo4j connection details from environment variables
neo4j_url = os.environ.get("NEO4J_URL")
neo4j_username = os.environ.get("NEO4J_USERNAME")
neo4j_database = os.environ.get("NEO4J_DATABASE")

# Get password from Databricks secrets
neo4j_password = dbutils.secrets.get(scope="neo4j", key="password")

# Validate configuration
if not all([neo4j_url, neo4j_username, neo4j_database, neo4j_password]):
    raise ValueError("Missing Neo4j configuration. Please check environment variables and secrets.")

print("=" * 60)
print("Neo4j Spark Connector Configuration")
print("=" * 60)
print(f"URL:      {neo4j_url}")
print(f"Username: {neo4j_username}")
print(f"Database: {neo4j_database}")
print("Password: ********")  # fixed-width mask, so the output does not reveal the password length
print("=" * 60)
print("✅ Configuration validated")

## Define Extraction Function

Define the function that extracts nodes from Neo4j to Delta Lake using the **optimized `labels` option approach**.

This approach:
- ✅ Automatically includes all node properties
- ✅ Includes metadata (`<id>` and `<labels>` columns)
- ✅ No manual field listing required
- ✅ Simple, clean code for demo purposes
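
The connector options used for a node read can be collected in a plain dictionary and passed to the reader in one call. The sketch below uses only the option keys that appear in this notebook; `build_node_read_options` is a hypothetical helper, and the connection values are placeholders.

```python
from typing import Dict


def build_node_read_options(
    url: str, username: str, password: str, database: str, label: str
) -> Dict[str, str]:
    """Assemble the Neo4j Spark Connector options for a label-based node read."""
    return {
        "url": url,
        "authentication.type": "basic",
        "authentication.basic.username": username,
        "authentication.basic.password": password,
        "database": database,
        "labels": label,
    }


# Placeholder values for illustration only
options = build_node_read_options(
    url="neo4j+s://example.databases.neo4j.io",
    username="neo4j",
    password="not-a-real-password",
    database="neo4j",
    label="Customer",
)

# On a cluster with the connector installed, the dictionary can be unpacked into the reader:
# df = spark.read.format("org.neo4j.spark.DataSource").options(**options).load()
```

Keeping the options in one place means the node and relationship readers can share the connection settings and differ only in the read-specific keys.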

In [None]:
def extract_node_type_to_delta(node_label: str, limit: int = 100) -> Dict[str, Any]:
    """
    Extract nodes from Neo4j using Spark Connector and write to Delta table.
    
    IMPLEMENTATION: Uses option("labels", ...) approach (BEST PRACTICE)
    - Automatically includes ALL node properties
    - Automatically includes <id> and <labels> metadata
    - Simple, clean code for demo purposes
    
    Args:
        node_label (str): Neo4j node label (e.g., 'Customer', 'Bank')
        limit (int): Maximum number of records to extract (default: 100)
    
    Returns:
        Dict[str, Any]: Extraction statistics
    """
    print(f"\n{'=' * 70}")
    print(f"Extracting {node_label} nodes...")
    print(f"{'=' * 70}")
    
    try:
        # Read from Neo4j using labels option
        # This automatically includes:
        # - <id>: Neo4j internal node ID (long)
        # - <labels>: Array of node labels
        # - All node properties
        df = (spark.read
              .format("org.neo4j.spark.DataSource")
              .option("url", neo4j_url)
              .option("authentication.type", "basic")
              .option("authentication.basic.username", neo4j_username)
              .option("authentication.basic.password", neo4j_password)
              .option("database", neo4j_database)
              .option("labels", node_label)
              .load())
        
        # Rename metadata columns to match schema module expectations
        # <id> → neo4j_id, <labels> → neo4j_labels
        # and add ingestion timestamp
        df_final = (df
                    .withColumnRenamed("<id>", "neo4j_id")
                    .withColumnRenamed("<labels>", "neo4j_labels")
                    .withColumn("ingestion_timestamp", current_timestamp())
                    .limit(limit))
        
        # Get table name from schema module
        table_name = NODE_TABLE_NAMES[node_label]
        
        # Write to Delta table
        print(f"  ⚙️  Writing to Delta table...")
        (df_final.write
         .format("delta")
         .mode("overwrite")
         .option("overwriteSchema", "true")
         .saveAsTable(table_name))
        
        # Get record count
        count = spark.sql(f"SELECT COUNT(*) as count FROM {table_name}").first()["count"]
        
        print(f"  ✅ Complete: {count} records → {table_name}")
        
        return {
            "node_label": node_label,
            "record_count": count,
            "table_name": table_name,
            "status": "success"
        }
        
    except Exception as e:
        print(f"  ❌ ERROR: {str(e)}")
        return {
            "node_label": node_label,
            "record_count": 0,
            "table_name": NODE_TABLE_NAMES.get(node_label, "unknown"),
            "status": "error",
            "error": str(e)
        }

print("✅ Extraction function defined with simplified approach for demo")

## Node Extraction and Summary

Run the extraction for each node type, then display statistics about the extraction process.
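
The extraction function catches exceptions and returns a status dictionary, so a scheduled run would still finish "successfully" even if every node type failed. The sketch below shows one way to turn collected failures into a hard error after the loop. `failed_extractions` and `assert_all_succeeded` are hypothetical helpers, and the sample statistics are made up to match the shape of the dictionaries returned in this notebook.

```python
from typing import Any, Dict, List


def failed_extractions(stats_by_name: Dict[str, Dict[str, Any]]) -> List[str]:
    """List 'name: error' strings for every entry whose status is 'error'."""
    return [
        f"{name}: {stats.get('error', 'Unknown error')}"
        for name, stats in stats_by_name.items()
        if stats["status"] == "error"
    ]


def assert_all_succeeded(stats_by_name: Dict[str, Dict[str, Any]]) -> None:
    """Raise if any extraction reported an error, so the job is marked as failed."""
    failures = failed_extractions(stats_by_name)
    if failures:
        raise RuntimeError("Extraction failed for: " + "; ".join(failures))


# Made-up statistics in the same shape as the extraction function's return value
sample_stats = {
    "Customer": {"status": "success", "record_count": 100},
    "Bank": {"status": "error", "record_count": 0, "error": "connection refused"},
}

failures = failed_extractions(sample_stats)
```

Calling `assert_all_succeeded(sample_stats)` here would raise a `RuntimeError` that names `Bank`, while the summary table can still be printed before the check runs.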

In [None]:
print("\n" + "=" * 80)
print("EXTRACTING NODE DATA FROM NEO4J TO DELTA TABLES")
print("=" * 80 + "\n")

# Define node types to extract - aligned with neo4j_schemas.py definitions
node_types = [
    "Account",
    "Bank",
    "Company",
    "Customer",
    "Position",
    "Stock",
    "Transaction",
]

# Track extraction statistics with enhanced details
extraction_stats = {}

# Extract each node type
for node_label in node_types:
    stats = extract_node_type_to_delta(node_label, limit=100)
    extraction_stats[node_label] = stats

print("\n" + "=" * 80)
print("✅ DATA EXTRACTION COMPLETE")
print("=" * 80)

In [None]:
# Create summary DataFrame with extraction results
summary_data = []
for node_type in node_types:
    stats = extraction_stats[node_type]
    
    # Determine status display
    if stats["status"] == "error":
        status_display = "❌ Error"
        details = stats.get("error", "Unknown error")
    elif stats["record_count"] == 0:
        status_display = "⚠️  Empty"
        details = "No records found"
    else:
        status_display = "✅ Success"
        details = "OK"
    
    summary_data.append({
        "Node Type": node_type,
        "Records": stats["record_count"],
        "Delta Table": stats["table_name"],
        "Status": status_display,
        "Details": details
    })

summary_df = pd.DataFrame(summary_data)

print("\n" + "=" * 100)
print("NODE EXTRACTION SUMMARY")
print("=" * 100)
print(summary_df.to_string(index=False))
print("\n" + "=" * 100)
print(f"Total Node Types Processed: {len(extraction_stats)}")
print(f"Total Records Extracted: {sum(s['record_count'] for s in extraction_stats.values())}")
print(f"Successful Extractions: {sum(1 for s in extraction_stats.values() if s['status'] == 'success')}")
print(f"Catalog Location: fintech.default")
print(f"Extraction Method: Neo4j Spark Connector (org.neo4j.spark.DataSource)")
print("=" * 100)

## Sample Data Preview

Preview sample records from each extracted node type to verify data quality and schema correctness.

In [None]:
# Display sample data from each node type for verification
for node_label in node_types:
    stats = extraction_stats[node_label]
    table_name = stats["table_name"]
    
    print(f"\n{'=' * 100}")
    print(f"Sample Data: {node_label} ({stats['record_count']} total records)")
    print(f"Table: {table_name}")
    print(f"{'=' * 100}")
    
    if stats["record_count"] > 0:
        # Show schema
        sample_df = spark.sql(f"SELECT * FROM {table_name} LIMIT 3")
        print(f"\nSchema:")
        sample_df.printSchema()
        
        print(f"\nSample Records (3 of {stats['record_count']}):")
        display(sample_df)
    else:
        print("⚠️  No records to display")
    
    print()

## Detailed Schema Inspection

Examine the detailed schema and metadata for the Customer table to verify field types, constraints, and Delta table properties.

In [None]:
# Show detailed schema for Customer table as an example
example_table = NODE_TABLE_NAMES["Customer"]

print(f"\n{'=' * 100}")
print(f"Detailed Schema Inspection: Customer Table")
print(f"Table: {example_table}")
print(f"{'=' * 100}\n")

# Show schema in tree format
customer_df = spark.table(example_table)
print("Schema Structure:")
customer_df.printSchema()

print(f"\n{'=' * 100}")
print("Extended Table Metadata:")
print("=" * 100)
display(spark.sql(f"DESCRIBE TABLE EXTENDED {example_table}"))

---

## PART 2: Relationship (Edge) Extraction

Extract relationship data from Neo4j to create graph edge tables in Delta Lake. This enables graph analytics, path queries, and network analysis using standard SQL and Spark.

---

## Define Relationship Extraction Function

Define the function that extracts relationships (edges) from Neo4j to Delta Lake using the **`relationship` option approach**.

This approach:
- ✅ Automatically includes relationship metadata (`<rel.id>`, `<rel.type>`, source/target IDs)
- ✅ Uses business-meaningful column names (e.g., `customerId`, `senderAccountId`)
- ✅ Simple, clean code following best practices from RELATIONSHIP_HANDLING.md
- ✅ Ready for graph analytics and path queries
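
With `relationship.nodes.map` set to `"false"`, node properties arrive as flat columns named `source.<property>` and `target.<property>`. Because these names contain a dot, Spark would read the dot as struct field access unless the whole name is wrapped in backticks. The sketch below builds the quoted names; `quote_column` and `flat_mode_key_columns` are hypothetical helpers, and `customerId` / `accountId` are example keys taken from the queries later in this notebook.

```python
from typing import Tuple


def quote_column(name: str) -> str:
    """Backtick-quote a column name, doubling any embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def flat_mode_key_columns(source_key: str, dest_key: str) -> Tuple[str, str]:
    """Return the quoted source/target key columns produced by a flat-mode relationship read."""
    return quote_column(f"source.{source_key}"), quote_column(f"target.{dest_key}")


# Example keys only; the notebook takes the real keys from the relationship metadata
src_col, dst_col = flat_mode_key_columns("customerId", "accountId")
print(src_col, dst_col)
```

The resulting strings can be passed to `col(...)` and aliased to the output column names expected by the schema module.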

In [None]:
def extract_relationship_type_to_delta(rel_type: str, limit: int = 100) -> Dict[str, Any]:
    """
    Extract relationships from Neo4j using Spark Connector and write to Delta table.
    
    Args:
        rel_type (str): Neo4j relationship type (e.g., 'HAS_ACCOUNT', 'PERFORMS')
        limit (int): Maximum number of records to extract (default: 100)
    
    Returns:
        Dict[str, Any]: Extraction statistics
    """
    print(f"\n{'=' * 70}")
    print(f"Extracting {rel_type} relationships...")
    print(f"{'=' * 70}")
    
    try:
        # Get metadata for this relationship type
        metadata = get_relationship_metadata(rel_type)
        
        source_label = metadata["source_label"]
        dest_label = metadata["destination_label"]
        source_key = metadata["source_key"]  # Actual property name on source node
        dest_key = metadata["destination_key"]  # Actual property name on target node
        
        # Get schema to determine output column names
        from neo4j_schemas import BASE_RELATIONSHIP_SCHEMAS
        schema = BASE_RELATIONSHIP_SCHEMAS[rel_type]
        output_source_col = schema.fields[0].name  # First field is source key
        output_dest_col = schema.fields[1].name    # Second field is destination key
        
        print(f"  ⚙️  Pattern: ({source_label})-[:{rel_type}]->({dest_label})")
        print(f"  ⚙️  Neo4j Keys: {source_key} → {dest_key}")
        print(f"  ⚙️  Output Columns: {output_source_col} → {output_dest_col}")
        
        # Read from Neo4j using relationship option in FLAT MODE
        df = (spark.read
              .format("org.neo4j.spark.DataSource")
              .option("url", neo4j_url)
              .option("authentication.type", "basic")
              .option("authentication.basic.username", neo4j_username)
              .option("authentication.basic.password", neo4j_password)
              .option("database", neo4j_database)
              .option("relationship", rel_type)
              .option("relationship.source.labels", source_label)
              .option("relationship.target.labels", dest_label)
              .option("relationship.nodes.map", "false")
              .load())
        
        # Build column selection using actual Neo4j property names
        source_col = f"`source.{source_key}`"
        dest_col = f"`target.{dest_key}`"
        
        # Select and rename columns to match schema expectations
        df_final = (df
                    .select(
                        col(source_col).alias(output_source_col),
                        col(dest_col).alias(output_dest_col),
                        col("`<rel.id>`").alias("rel_element_id"),
                        col("`<rel.type>`").alias("rel_type"),
                        col("`<source.id>`").alias("src_neo4j_id"),
                        col("`<target.id>`").alias("dst_neo4j_id")
                    )
                    .withColumn("ingestion_timestamp", current_timestamp())
                    .limit(limit))
        
        # Get table name from schema module
        table_name = RELATIONSHIP_TABLE_NAMES[rel_type]
        
        # Write to Delta table
        print(f"  ⚙️  Writing to Delta table...")
        (df_final.write
         .format("delta")
         .mode("overwrite")
         .option("overwriteSchema", "true")
         .saveAsTable(table_name))
        
        # Get record count
        count = spark.sql(f"SELECT COUNT(*) as count FROM {table_name}").first()["count"]
        
        print(f"  ✅ Complete: {count} edges → {table_name}")
        
        return {
            "rel_type": rel_type,
            "record_count": count,
            "table_name": table_name,
            "pattern": f"{source_label} → {dest_label}",
            "status": "success"
        }
        
    except Exception as e:
        print(f"  ❌ ERROR: {str(e)}")
        return {
            "rel_type": rel_type,
            "record_count": 0,
            "table_name": RELATIONSHIP_TABLE_NAMES.get(rel_type, "unknown"),
            "pattern": "unknown",
            "status": "error",
            "error": str(e)
        }

print("✅ Relationship extraction function defined")

## Extract All Relationship Types

Extract all 7 relationship types from Neo4j to Delta Lake.

In [None]:
print("\n" + "=" * 80)
print("EXTRACTING RELATIONSHIP DATA FROM NEO4J TO DELTA TABLES")
print("=" * 80 + "\n")

# Define relationship types to extract - aligned with neo4j_schemas.py definitions
relationship_types = [
    "HAS_ACCOUNT",
    "AT_BANK",
    "OF_COMPANY",
    "PERFORMS",
    "BENEFITS_TO",
    "HAS_POSITION",
    "OF_SECURITY",
]

# Track extraction statistics
relationship_stats = {}

# Extract each relationship type
for rel_type in relationship_types:
    stats = extract_relationship_type_to_delta(rel_type, limit=100)
    relationship_stats[rel_type] = stats

print("\n" + "=" * 80)
print("✅ RELATIONSHIP EXTRACTION COMPLETE")
print("=" * 80)

## Relationship Extraction Summary

Display statistics about the relationship extraction process.

In [None]:
# Create summary DataFrame with relationship extraction results
rel_summary_data = []
for rel_type in relationship_types:
    stats = relationship_stats[rel_type]
    
    # Determine status display
    if stats["status"] == "error":
        status_display = "❌ Error"
        details = stats.get("error", "Unknown error")
    elif stats["record_count"] == 0:
        # Status is still "success" here, so report the empty result explicitly
        status_display = "⚠️  Empty"
        details = "No records found"
    else:
        status_display = "✅ Success"
        details = "OK"
    
    rel_summary_data.append({
        "Relationship": rel_type,
        "Pattern": stats["pattern"],
        "Records": stats["record_count"],
        "Delta Table": stats["table_name"],
        "Status": status_display,
        "Details": details
    })

rel_summary_df = pd.DataFrame(rel_summary_data)

print("\n" + "=" * 100)
print("RELATIONSHIP EXTRACTION SUMMARY")
print("=" * 100)
print(rel_summary_df.to_string(index=False))
print("\n" + "=" * 100)
print(f"Total Relationship Types Processed: {len(relationship_stats)}")
print(f"Total Edges Extracted: {sum(s['record_count'] for s in relationship_stats.values())}")
print(f"Successful Extractions: {sum(1 for s in relationship_stats.values() if s['status'] == 'success')}")
print(f"Catalog Location: fintech.default")
print(f"Extraction Method: Neo4j Spark Connector (org.neo4j.spark.DataSource) with relationship option")
print("=" * 100)

## Sample Relationship Queries

Demonstrate how to use the relationship tables with joins to perform graph analytics using standard SQL.
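
Each hop in the graph becomes two joins in SQL: node table to edge table, then edge table to the next node table. The sketch below shows the pattern for `Customer -[HAS_ACCOUNT]-> Account -[AT_BANK]-> Bank` on a few made-up rows. It uses Python's built-in `sqlite3` as a stand-in so it runs without a cluster; the table names are shortened and the data is invented for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customer    (customerId TEXT, firstName TEXT);
CREATE TABLE has_account (customerId TEXT, accountId TEXT);  -- edge table
CREATE TABLE account     (accountId TEXT, balance REAL);
CREATE TABLE at_bank     (accountId TEXT, bankId TEXT);      -- edge table
CREATE TABLE bank        (bankId TEXT, name TEXT);

INSERT INTO customer    VALUES ('C1', 'Ada'), ('C2', 'Grace');
INSERT INTO has_account VALUES ('C1', 'A1'), ('C2', 'A2');
INSERT INTO account     VALUES ('A1', 2500.0), ('A2', 900.0);
INSERT INTO at_bank     VALUES ('A1', 'B1'), ('A2', 'B1');
INSERT INTO bank        VALUES ('B1', 'Example Bank');
""")

# Two hops = four joins: node -> edge -> node -> edge -> node
rows = conn.execute("""
SELECT c.firstName, a.accountId, a.balance, b.name
FROM customer c
JOIN has_account ha ON c.customerId = ha.customerId
JOIN account a      ON ha.accountId = a.accountId
JOIN at_bank ab     ON a.accountId = ab.accountId
JOIN bank b         ON ab.bankId = b.bankId
ORDER BY a.balance DESC
""").fetchall()
conn.close()

for row in rows:
    print(row)
```

The same join shape carries over to the Delta tables; only the fully qualified table names and the SQL engine differ.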

In [None]:
# Example 1: Customer Accounts with Bank Information
# Join: Customer -[HAS_ACCOUNT]-> Account -[AT_BANK]-> Bank

print("=" * 100)
print("Example 1: Customer Accounts with Bank Information")
print("=" * 100)

query1 = """
SELECT 
    c.customerId,
    c.firstName,
    c.lastName,
    a.accountId,
    a.accountType,
    a.balance,
    b.name AS bank_name,
    b.bankType
FROM fintech.default.neo4j_customer c
JOIN fintech.default.neo4j_has_account ha ON c.customerId = ha.customerId
JOIN fintech.default.neo4j_account a ON ha.accountId = a.accountId
JOIN fintech.default.neo4j_at_bank ab ON a.accountId = ab.accountId
JOIN fintech.default.neo4j_bank b ON ab.bankId = b.bankId
ORDER BY a.balance DESC
LIMIT 10
"""

result1 = spark.sql(query1)
print("\nTop 10 Accounts by Balance with Customer and Bank Information:")
display(result1)

In [None]:
# Example 2: Portfolio Holdings Analysis
# Join: Account -[HAS_POSITION]-> Position -[OF_SECURITY]-> Stock -[OF_COMPANY]-> Company

print("\n" + "=" * 100)
print("Example 2: Portfolio Holdings with Stock and Company Details")
print("=" * 100)

query3 = """
SELECT 
    hp.accountId,
    p.positionId,
    p.shares,
    p.currentValue,
    p.percentageOfPortfolio,
    s.ticker,
    s.currentPrice,
    co.name AS company_name,
    co.sector,
    co.industry
FROM fintech.default.neo4j_has_position hp
JOIN fintech.default.neo4j_position p ON hp.positionId = p.positionId
JOIN fintech.default.neo4j_of_security os ON p.positionId = os.positionId
JOIN fintech.default.neo4j_stock s ON os.stockId = s.stockId
JOIN fintech.default.neo4j_of_company oc ON s.stockId = oc.stockId
JOIN fintech.default.neo4j_company co ON oc.companyId = co.companyId
ORDER BY p.currentValue DESC
LIMIT 10
"""

result3 = spark.sql(query3)
print("\nTop 10 Portfolio Positions by Value:")
display(result3)

In [None]:
# Example 3: Complete Customer Financial Profile
# Multi-path join showing each customer's accounts and investments
# (the transaction CTE in the query is currently commented out)

print("\n" + "=" * 100)
print("Example 3: Complete Financial Profile for a Sample Customer")
print("=" * 100)

query4 = """
WITH customer_accounts AS (
    SELECT 
        c.customerId,
        c.firstName,
        c.lastName,
        c.riskProfile,
        a.accountId,
        a.accountType,
        a.balance
    FROM fintech.default.neo4j_customer c
    JOIN fintech.default.neo4j_has_account ha ON c.customerId = ha.customerId
    JOIN fintech.default.neo4j_account a ON ha.accountId = a.accountId
),
account_positions AS (
    SELECT 
        hp.accountId,
        COUNT(*) AS num_positions,
        SUM(p.currentValue) AS total_portfolio_value
    FROM fintech.default.neo4j_has_position hp
    JOIN fintech.default.neo4j_position p ON hp.positionId = p.positionId
    GROUP BY hp.accountId
)
/*
,account_transactions AS (
    SELECT
        p.senderAccountId AS accountId,
        COUNT(*) AS num_transactions,
        SUM(t.amount) AS total_sent
    FROM fintech.default.neo4j_performs p
    JOIN fintech.default.neo4j_transaction t ON p.transactionId = t.transactionId
    GROUP BY p.senderAccountId
)
*/
SELECT
    ca.customerId,
    ca.firstName,
    ca.lastName,
    ca.riskProfile,
    ca.accountId,
    ca.accountType,
    ca.balance,
    COALESCE(ap.num_positions, 0) AS num_positions,
    COALESCE(ap.total_portfolio_value, 0) AS portfolio_value
    -- ,COALESCE(at.num_transactions, 0) AS num_transactions_sent
    -- ,COALESCE(at.total_sent, 0) AS total_amount_sent
FROM customer_accounts ca
LEFT JOIN account_positions ap ON ca.accountId = ap.accountId
-- LEFT JOIN account_transactions at ON ca.accountId = at.accountId
ORDER BY ca.balance DESC
LIMIT 10
"""

result4 = spark.sql(query4)
print("\nComplete Financial Profile (Top 10 Accounts by Balance):")
display(result4)