# Export Neo4j Graph Data to Databricks Unity Catalog

This notebook extracts nodes and relationships from Neo4j and writes them as Delta tables in Unity Catalog.

## Prerequisites

1. **Neo4j database** with financial demo data loaded
2. **Databricks Secrets**: run `./scripts/setup_databricks_secrets.sh` to create the `neo4j-creds` secret scope
3. **Databricks cluster** with Neo4j Spark Connector installed
4. **Cluster access mode**: Dedicated (not Shared)

## Output

Creates 14 Delta tables in Unity Catalog:
- 7 node tables: customer, bank, account, company, stock, position, transaction
- 7 relationship tables: has_account, at_bank, of_company, performs, benefits_to, has_position, of_security
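The table names follow a simple convention: each node label or relationship type is lowercased and qualified with the catalog and schema. Here is a minimal sketch of that mapping. It assumes the default `neo4j_augmentation_demo.graph_data` location from Step 1, and the helper name `qualified_table_names` is hypothetical:

```python
def qualified_table_names(catalog, schema, node_labels, rel_types):
    """Map each node label / relationship type to a fully qualified table name."""
    names = {}
    for name in list(node_labels) + list(rel_types):
        # Table name is the lowercased label or relationship type.
        names[name] = f"{catalog}.{schema}.{name.lower()}"
    return names


labels = ["Customer", "Bank", "Account", "Company", "Stock", "Position", "Transaction"]
rels = ["HAS_ACCOUNT", "AT_BANK", "OF_COMPANY", "PERFORMS",
        "BENEFITS_TO", "HAS_POSITION", "OF_SECURITY"]

tables = qualified_table_names("neo4j_augmentation_demo", "graph_data", labels, rels)
print(len(tables))            # 14
print(tables["HAS_ACCOUNT"])  # neo4j_augmentation_demo.graph_data.has_account
```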

## Step 1: Configuration

Define the node labels and relationship types to export from Neo4j. The catalog and schema names below are defaults. Step 2 replaces them with values parsed from the `volume_path` secret when that secret is available.
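Each relationship entry names its source and target labels. This makes it easy to reference a label that is missing from `NODE_LABELS` when editing the lists. A small consistency check can catch that before any data is read. The following is a sketch, and `find_unknown_endpoints` is a hypothetical helper that is not part of the notebook:

```python
def find_unknown_endpoints(node_labels, relationships):
    """Return (rel_type, label) pairs whose endpoint label is not in node_labels."""
    known = set(node_labels)
    unknown = []
    for rel_type, source, target in relationships:
        for label in (source, target):
            if label not in known:
                unknown.append((rel_type, label))
    return unknown


labels = ["Customer", "Account", "Bank"]
rels = [("HAS_ACCOUNT", "Customer", "Account"), ("AT_BANK", "Account", "Branch")]
print(find_unknown_endpoints(labels, rels))  # [('AT_BANK', 'Branch')]
```

Applied to the `NODE_LABELS` and `RELATIONSHIPS` lists below, it should return an empty list.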

In [None]:
import time

# Default catalog/schema (will be overridden from secrets)
CATALOG = "neo4j_augmentation_demo"
SCHEMA = "graph_data"

# Node labels to extract
NODE_LABELS = [
    "Customer",
    "Bank",
    "Account",
    "Company",
    "Stock",
    "Position",
    "Transaction",
]

# Relationships: (type, source_label, target_label)
RELATIONSHIPS = [
    ("HAS_ACCOUNT", "Customer", "Account"),
    ("AT_BANK", "Account", "Bank"),
    ("OF_COMPANY", "Stock", "Company"),
    ("PERFORMS", "Account", "Transaction"),
    ("BENEFITS_TO", "Transaction", "Account"),
    ("HAS_POSITION", "Account", "Position"),
    ("OF_SECURITY", "Position", "Stock"),
]

print(f"Nodes to export: {len(NODE_LABELS)}")
print(f"Relationships to export: {len(RELATIONSHIPS)}")

## Step 2: Load Credentials

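Retrieve the Neo4j connection credentials from Databricks Secrets. The `volume_path` secret (`/Volumes/{catalog}/{schema}/{volume}`) is parsed to obtain the catalog and schema names, so this notebook writes to the same location as the other labs. If the secret is missing or malformed, the defaults from Step 1 are kept.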
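Moving the path parsing into a pure function makes it testable without `dbutils`. This sketch mirrors the parsing in the cell below. The function name `parse_catalog_schema` is hypothetical:

```python
DEFAULT_CATALOG = "neo4j_augmentation_demo"
DEFAULT_SCHEMA = "graph_data"


def parse_catalog_schema(volume_path, default=(DEFAULT_CATALOG, DEFAULT_SCHEMA)):
    """Extract (catalog, schema) from a path like /Volumes/{catalog}/{schema}/{volume}."""
    parts = volume_path.strip("/").split("/")
    if len(parts) >= 3 and parts[0] == "Volumes":
        return parts[1], parts[2]
    # Not a Unity Catalog volume path: fall back to the defaults.
    return default


print(parse_catalog_schema("/Volumes/my_catalog/my_schema/my_volume"))
# ('my_catalog', 'my_schema')
print(parse_catalog_schema("/dbfs/tmp/data"))
# ('neo4j_augmentation_demo', 'graph_data')
```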

In [None]:
print("Loading credentials from Databricks Secrets...")

NEO4J_URL = dbutils.secrets.get(scope="neo4j-creds", key="url")
NEO4J_USER = dbutils.secrets.get(scope="neo4j-creds", key="username")
NEO4J_PASS = dbutils.secrets.get(scope="neo4j-creds", key="password")
NEO4J_DATABASE = "neo4j"

# Extract catalog/schema from volume_path: /Volumes/{catalog}/{schema}/{volume}
try:
    volume_path = dbutils.secrets.get(scope="neo4j-creds", key="volume_path")
    parts = volume_path.strip("/").split("/")
    if len(parts) >= 3 and parts[0] == "Volumes":
        CATALOG = parts[1]
        SCHEMA = parts[2]
        print(f"[OK] Catalog: {CATALOG}")
        print(f"[OK] Schema: {SCHEMA}")
    else:
        print(f"[INFO] Unexpected volume_path format, using defaults: {CATALOG}.{SCHEMA}")
except Exception:
    # Secret missing or unreadable: keep the defaults from Step 1
    print(f"[INFO] Using defaults: {CATALOG}.{SCHEMA}")

# Configure Spark for Neo4j
spark.conf.set("neo4j.url", NEO4J_URL)
spark.conf.set("neo4j.authentication.basic.username", NEO4J_USER)
spark.conf.set("neo4j.authentication.basic.password", NEO4J_PASS)
spark.conf.set("neo4j.database", NEO4J_DATABASE)

print(f"[OK] Neo4j URL: {NEO4J_URL}")

## Step 3: Test Neo4j Connection

Verify that the Spark session can connect to Neo4j using the configured credentials. This runs a simple Cypher query to confirm the Neo4j Spark Connector is working before attempting the full export.
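Connection failures often come from a mistyped URL rather than bad credentials. You could run a lightweight pre-flight check like the following before the Cypher test. It assumes the value in the `url` secret uses one of the standard Neo4j URI schemes (`neo4j`, `bolt`, and their `+s` / `+ssc` variants). `check_neo4j_url` is a hypothetical helper, not part of the Neo4j Spark Connector:

```python
from urllib.parse import urlparse

# Standard Neo4j URI schemes; +s and +ssc are the encrypted variants.
VALID_SCHEMES = {"neo4j", "neo4j+s", "neo4j+ssc", "bolt", "bolt+s", "bolt+ssc"}


def check_neo4j_url(url):
    """Return a list of problems found in a Neo4j connection URL (empty if none)."""
    problems = []
    parsed = urlparse(url.strip())
    if parsed.scheme not in VALID_SCHEMES:
        problems.append(f"unexpected scheme '{parsed.scheme}'")
    if not parsed.hostname:
        problems.append("missing host")
    return problems


print(check_neo4j_url("neo4j+s://abcd1234.databases.neo4j.io"))  # []
print(check_neo4j_url("https://abcd1234.databases.neo4j.io"))    # ["unexpected scheme 'https'"]
```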

In [None]:
print("Testing Neo4j connection...")

try:
    test_df = spark.read.format("org.neo4j.spark.DataSource").option("query", "RETURN 1 AS test").load()
    test_df.collect()
    print("[OK] Connected to Neo4j!")
except Exception as e:
    print(f"[FAIL] {e}")
    raise

## Step 4: Setup Unity Catalog

Verify the target catalog exists and create the schema if it doesn't exist. The schema will hold all the Delta tables created during the export process.
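The catalog and schema names come from a secret, so they may contain characters such as hyphens that are not valid in unquoted Spark SQL identifiers. Wrapping each part in backticks (and doubling any embedded backtick) avoids that. This is a sketch, and the helper names are hypothetical:

```python
def quote_identifier(name):
    """Backtick-quote a Spark SQL identifier, escaping embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def create_schema_sql(catalog, schema):
    """Build a CREATE SCHEMA statement with both name parts quoted."""
    return f"CREATE SCHEMA IF NOT EXISTS {quote_identifier(catalog)}.{quote_identifier(schema)}"


print(create_schema_sql("my-catalog", "graph_data"))
# CREATE SCHEMA IF NOT EXISTS `my-catalog`.`graph_data`
```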

In [None]:
print("Setting up Unity Catalog...")

# Check catalog exists
try:
    catalogs = [row.catalog for row in spark.sql("SHOW CATALOGS").collect()]
    if CATALOG not in catalogs:
        print(f"[ERROR] Catalog '{CATALOG}' not found.")
        print(f"Available: {catalogs}")
        raise ValueError(f"Catalog '{CATALOG}' not found")
    print(f"[OK] Catalog exists: {CATALOG}")
except ValueError:
    raise
except Exception:
    print(f"[INFO] Could not list catalogs, trying to use '{CATALOG}' directly")

# Create schema
spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{CATALOG}`.`{SCHEMA}`")
print(f"[OK] Schema ready: {CATALOG}.{SCHEMA}")

## Step 5: Define Helper Functions

Create reusable functions for reading data from Neo4j and writing to Delta tables:
- `read_nodes()`: Reads all nodes with a given label from Neo4j
- `read_relationship()`: Reads relationships of a given type between specified node labels
- `write_table()`: Writes a DataFrame to a Delta table in Unity Catalog and returns the table's row count
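The two read helpers differ only in the reader options they pass. If you express those options as plain dictionaries, you can inspect and test them without a Spark session. The following sketch assumes you apply them with PySpark's `DataFrameReader.options(**opts)`. The `*_read_options` helper names are hypothetical:

```python
NEO4J_FORMAT = "org.neo4j.spark.DataSource"


def node_read_options(label):
    """Reader options for all nodes with one label."""
    return {"labels": label}


def relationship_read_options(rel_type, source_label, target_label):
    """Reader options for one relationship type between two labels.

    relationship.nodes.map = "false" flattens source/target node properties
    into individual columns instead of returning them as maps.
    """
    return {
        "relationship": rel_type,
        "relationship.source.labels": source_label,
        "relationship.target.labels": target_label,
        "relationship.nodes.map": "false",
    }


# In the notebook these would be applied as, for example:
#   spark.read.format(NEO4J_FORMAT).options(**relationship_read_options(...)).load()
opts = relationship_read_options("HAS_ACCOUNT", "Customer", "Account")
print(opts["relationship.source.labels"])  # Customer
```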

In [None]:
def read_nodes(label):
    """Read nodes from Neo4j."""
    return spark.read.format("org.neo4j.spark.DataSource").option("labels", label).load()

def read_relationship(rel_type, source_label, target_label):
    """Read relationships from Neo4j."""
    return (
        spark.read.format("org.neo4j.spark.DataSource")
        .option("relationship", rel_type)
        .option("relationship.source.labels", source_label)
        .option("relationship.target.labels", target_label)
        .option("relationship.nodes.map", "false")
        .load()
    )

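def write_table(df, table_name):
    """Overwrite a Delta table with df and return the written table's row count."""
    full_name = f"{CATALOG}.{SCHEMA}.{table_name}"
    (
        df.write.format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")  # allow re-runs if node/relationship properties change
        .saveAsTable(full_name)
    )
    # Count the Delta table, not df: df.count() would query Neo4j a second time
    return spark.table(full_name).count()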

print("Helper functions defined.")

## Step 6: Export Nodes

Iterate through all node labels defined in the configuration, read each node type from Neo4j, and write them as Delta tables. Progress and timing information is displayed for each table.
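Steps 6 and 7 share the same read-write-time pattern. The following sketch shows that pattern as a generic helper. It uses `time.perf_counter()`, a high-resolution clock intended for measuring durations. `export_all` and the fake export function are hypothetical stand-ins for `read_nodes`/`write_table`:

```python
import time


def export_all(names, export_fn, to_table_name=str.lower):
    """Call export_fn(name, table_name) for each name, recording row count and time."""
    results = {}
    for i, name in enumerate(names, 1):
        table_name = to_table_name(name)
        start = time.perf_counter()
        count = export_fn(name, table_name)
        elapsed = time.perf_counter() - start
        results[name] = {"count": count, "time": elapsed}
        print(f"[{i}/{len(names)}] {name} -> {table_name}: {count} rows in {elapsed:.2f}s")
    return results


# Stand-in for read + write: returns a fixed row count per label (illustrative only).
fake_counts = {"Customer": 102, "Bank": 102}
results = export_all(list(fake_counts), lambda name, table: fake_counts[name])
```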

In [None]:
print("=" * 50)
print("EXPORTING NODES")
print("=" * 50)

node_results = {}

for i, label in enumerate(NODE_LABELS, 1):
    table_name = label.lower()
    print(f"\n[{i}/{len(NODE_LABELS)}] {label} -> {table_name}")
    
    start = time.time()
    df = read_nodes(label)
    count = write_table(df, table_name)
    elapsed = time.time() - start
    
    node_results[label] = {"count": count, "time": elapsed}
    print(f"    [OK] {count} rows in {elapsed:.2f}s")

print(f"\nExported {len(node_results)} node tables")

## Step 7: Export Relationships

Iterate through all relationship types defined in the configuration, read each relationship from Neo4j (including source and target node IDs), and write them as Delta tables.
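Each relationship row carries the IDs of its source and target nodes, so you can check the relationship tables against the node tables. The following sketch shows that referential check in plain Python with made-up IDs. On the real tables, the same idea would be an anti-join between a relationship table and its node tables. The helper name `find_dangling_edges` is hypothetical:

```python
def find_dangling_edges(edges, source_ids, target_ids):
    """Return (source, target) pairs whose endpoint ID is missing from the node IDs."""
    sources, targets = set(source_ids), set(target_ids)
    return [
        (src, dst) for src, dst in edges
        if src not in sources or dst not in targets
    ]


customer_ids = [1, 2, 3]
account_ids = [10, 11]
has_account = [(1, 10), (2, 11), (4, 10)]  # customer 4 is not among the exported nodes
print(find_dangling_edges(has_account, customer_ids, account_ids))  # [(4, 10)]
```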

In [None]:
print("=" * 50)
print("EXPORTING RELATIONSHIPS")
print("=" * 50)

rel_results = {}

for i, (rel_type, source, target) in enumerate(RELATIONSHIPS, 1):
    table_name = rel_type.lower()
    print(f"\n[{i}/{len(RELATIONSHIPS)}] {rel_type} -> {table_name}")
    
    start = time.time()
    df = read_relationship(rel_type, source, target)
    count = write_table(df, table_name)
    elapsed = time.time() - start
    
    rel_results[rel_type] = {"count": count, "time": elapsed}
    print(f"    [OK] {count} rows in {elapsed:.2f}s")

print(f"\nExported {len(rel_results)} relationship tables")

## Step 8: Validate

Compare the exported row counts against the expected counts for the financial demo dataset. A mismatch means either that the data in Neo4j differs from the standard demo load or that the export did not capture every row.
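The validation can also be expressed as a function that returns only the mismatches, which is easier to assert on than printed output. This is a sketch. It takes flat `{name: count}` dicts (for example `{k: v["count"] for k, v in node_results.items()}`), and `count_mismatches` is a hypothetical helper:

```python
def count_mismatches(expected, actual):
    """Return {name: (expected, actual)} for every count that differs.

    Names missing from `actual` count as 0 rows, as in the validation cell.
    """
    return {
        name: (want, actual.get(name, 0))
        for name, want in expected.items()
        if actual.get(name, 0) != want
    }


expected = {"Customer": 102, "Bank": 102, "Account": 123}
actual = {"Customer": 102, "Account": 120}
print(count_mismatches(expected, actual))
# {'Bank': (102, 0), 'Account': (123, 120)}
```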

In [None]:
print("=" * 50)
print("VALIDATION")
print("=" * 50)

expected_nodes = {
    "Customer": 102, "Bank": 102, "Account": 123, "Company": 102,
    "Stock": 102, "Position": 110, "Transaction": 123,
}

expected_rels = {
    "HAS_ACCOUNT": 123, "AT_BANK": 123, "OF_COMPANY": 102, "PERFORMS": 123,
    "BENEFITS_TO": 123, "HAS_POSITION": 110, "OF_SECURITY": 110,
}

all_valid = True

print(f"\n{'Table':<15} {'Expected':>10} {'Actual':>10} {'Status':>10}")
print("-" * 50)

for label, expected in expected_nodes.items():
    actual = node_results.get(label, {}).get("count", 0)
    status = "OK" if actual == expected else "MISMATCH"
    if actual != expected:
        all_valid = False
    print(f"{label.lower():<15} {expected:>10} {actual:>10} {status:>10}")

for rel_type, expected in expected_rels.items():
    actual = rel_results.get(rel_type, {}).get("count", 0)
    status = "OK" if actual == expected else "MISMATCH"
    if actual != expected:
        all_valid = False
    print(f"{rel_type.lower():<15} {expected:>10} {actual:>10} {status:>10}")

print(f"\n{'All validations passed!' if all_valid else 'Some counts do not match'}")

## Step 9: Summary

Display a final summary of the export operation, including the destination catalog/schema, total number of tables created, and total row count across all tables.
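Beyond total rows, you can use the timing data collected during the export to identify the slowest table. The following sketch aggregates over the same `{name: {"count": n, "time": s}}` structure. `summarize` is a hypothetical helper, and the sample timings are illustrative, not measured. It assumes node labels and relationship types never share a name, since both dicts are merged:

```python
def summarize(node_results, rel_results):
    """Aggregate per-table {"count", "time"} results into overall totals."""
    all_results = {**node_results, **rel_results}
    total_rows = sum(r["count"] for r in all_results.values())
    total_time = sum(r["time"] for r in all_results.values())
    slowest = max(all_results, key=lambda name: all_results[name]["time"], default=None)
    return {"tables": len(all_results), "rows": total_rows,
            "seconds": total_time, "slowest": slowest}


nodes = {"Customer": {"count": 102, "time": 1.5}, "Bank": {"count": 102, "time": 0.8}}
rels = {"HAS_ACCOUNT": {"count": 123, "time": 2.1}}
summary = summarize(nodes, rels)
print(f"{summary['tables']} tables, {summary['rows']} rows, slowest: {summary['slowest']}")
# 3 tables, 327 rows, slowest: HAS_ACCOUNT
```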

In [None]:
total_nodes = sum(r["count"] for r in node_results.values())
total_rels = sum(r["count"] for r in rel_results.values())

print("=" * 50)
print("EXPORT COMPLETE")
print("=" * 50)
print(f"Destination: {CATALOG}.{SCHEMA}")
print(f"Tables: {len(node_results) + len(rel_results)}")
print(f"Total rows: {total_nodes + total_rels}")
print("=" * 50)