# Financial Demo Data Import to Neo4j

This notebook imports the retail banking and investment portfolio demonstration data from Databricks Unity Catalog into Neo4j.

## Prerequisites

1. **Neo4j database** running (Aura or self-hosted)
2. **Databricks Secrets** configured with `neo4j-creds` scope containing `username`, `password`, `url`, and `volume_path`
3. **Unity Catalog Volume** with CSV files uploaded
4. **Databricks Cluster** with Neo4j Spark Connector installed

## Data Overview

- **102 Customers** with demographics and financial profiles
- **102 Banks** across multiple types (commercial, regional, community)
- **123 Accounts** (checking, savings, investment)
- **102 Companies** across 12+ sectors
- **102 Stocks** with market data
- **110 Portfolio Positions** linking accounts to stocks
- **123 Transactions** between accounts

---

## Step 1: Configuration

Configure the Neo4j connection and the data source path. The credentials and the volume path are all read from the `neo4j-creds` scope in Databricks Secrets.
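
Every later cell depends on these values, so it can help to sanity-check them as soon as they are read. The following is a minimal sketch, not part of the original notebook: `validate_config` is a hypothetical helper, and it assumes the URL uses one of the standard Neo4j URI schemes and that the volume path follows the Unity Catalog `/Volumes/<catalog>/<schema>/<volume>` layout.

```python
from urllib.parse import urlparse

# URI schemes accepted by Neo4j drivers.
NEO4J_SCHEMES = {"neo4j", "neo4j+s", "neo4j+ssc", "bolt", "bolt+s", "bolt+ssc"}


def validate_config(neo4j_url: str, volume_path: str) -> list:
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    parsed = urlparse(neo4j_url)
    if parsed.scheme not in NEO4J_SCHEMES:
        problems.append(f"Unsupported Neo4j URL scheme: {parsed.scheme!r}")
    if not parsed.hostname:
        problems.append("Neo4j URL has no host")
    # Assumed layout: /Volumes/<catalog>/<schema>/<volume>[/subdirectory...]
    parts = [p for p in volume_path.split("/") if p]
    if not volume_path.startswith("/Volumes/") or len(parts) < 4:
        problems.append(f"Not a Unity Catalog volume path: {volume_path!r}")
    return problems


# Literal example values; in the notebook these would be NEO4J_URL and VOLUME_PATH.
for problem in validate_config("http://example.com", "/tmp/data"):
    print(problem)
```

In the notebook, a non-empty result could be turned into a `ValueError` before any Spark configuration is set.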

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

# Retrieve Neo4j credentials from Databricks Secrets
NEO4J_USER = dbutils.secrets.get(scope="neo4j-creds", key="username")
NEO4J_PASS = dbutils.secrets.get(scope="neo4j-creds", key="password")
NEO4J_URL = dbutils.secrets.get(scope="neo4j-creds", key="url")

# Unity Catalog Volume path containing CSV files
VOLUME_PATH = dbutils.secrets.get(scope="neo4j-creds", key="volume_path")

# Neo4j database name (default: neo4j)
NEO4J_DATABASE = "neo4j"

# Configure Spark session for Neo4j connector
spark.conf.set("neo4j.url", NEO4J_URL)
spark.conf.set("neo4j.authentication.basic.username", NEO4J_USER)
spark.conf.set("neo4j.authentication.basic.password", NEO4J_PASS)
spark.conf.set("neo4j.database", NEO4J_DATABASE)

print(f"Neo4j URL: {NEO4J_URL}")
print(f"Database: {NEO4J_DATABASE}")
print(f"Volume Path: {VOLUME_PATH}")

## Step 2: Verify Prerequisites

Verify Neo4j connectivity and CSV file availability before proceeding.
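
The file check is easier to act on when the error names the files that are absent. A small sketch of that idea follows; `find_missing` is a hypothetical helper, and the file lists are literal examples rather than the output of `dbutils.fs.ls`.

```python
def find_missing(expected, found):
    """Return the expected file names that are absent from `found`, in order."""
    found_set = set(found)
    return [name for name in expected if name not in found_set]


# Example with literal names; in the notebook `found` would come from
# [f.name for f in dbutils.fs.ls(VOLUME_PATH)].
missing = find_missing(
    ["customers.csv", "banks.csv", "accounts.csv"],
    ["accounts.csv", "customers.csv", "notes.txt"],
)
if missing:
    message = "Missing required CSV files: " + ", ".join(missing)
else:
    message = "All required CSV files are present."
print(message)
```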

In [None]:
# Verify CSV files exist in Unity Catalog Volume
print("Checking CSV files in Unity Catalog Volume...")
print("=" * 60)

expected_files = [
    "customers.csv",
    "banks.csv",
    "accounts.csv",
    "companies.csv",
    "stocks.csv",
    "portfolio_holdings.csv",
    "transactions.csv"
]

files = dbutils.fs.ls(VOLUME_PATH)
found_files = [f.name for f in files]

all_present = True
for expected in expected_files:
    status = "FOUND" if expected in found_files else "MISSING"
    if status == "MISSING":
        all_present = False
    print(f"  {expected}: {status}")

if all_present:
    print("\nAll required CSV files are present.")
else:
    raise FileNotFoundError("Missing required CSV files. Please upload all files to the Unity Catalog Volume.")

In [None]:
# Verify Neo4j connectivity
print("Testing Neo4j connection...")
print("=" * 60)

try:
    test_df = (
        spark.read.format("org.neo4j.spark.DataSource")
        .option("url", NEO4J_URL)
        .option("authentication.basic.username", NEO4J_USER)
        .option("authentication.basic.password", NEO4J_PASS)
        .option("database", NEO4J_DATABASE)
        .option("query", "RETURN 'Connected' AS status")
        .load()
    )
    test_df.show()
    print("Neo4j connection successful!")
except Exception as e:
    print(f"Connection failed: {str(e)}")
    print("\nPlease verify:")
    print("  1. Neo4j database is running")
    print("  2. Connection URL is correct")
    print("  3. Credentials are valid")
    print("  4. Network connectivity is available")
    raise

## Step 3: Helper Functions

Define reusable functions for data loading and Neo4j operations.
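
The relationship helper in this step passes a fixed set of connector options. Building them as a plain dictionary makes it possible to inspect them without a Spark session. This is a sketch under that assumption: `relationship_options` is an illustrative name, and the option keys are the ones the notebook's `write_relationship` already uses.

```python
def relationship_options(rel_type, source_label, source_key, target_label, target_key):
    """Build the connector options for a key-matched relationship write."""
    return {
        "relationship": rel_type,
        "relationship.save.strategy": "keys",
        # "Match" looks up existing nodes instead of creating new ones
        "relationship.source.save.mode": "Match",
        "relationship.source.labels": f":{source_label}",
        # "<DataFrame column>:<node property>"; both names are the same here
        "relationship.source.node.keys": f"{source_key}:{source_key}",
        "relationship.target.save.mode": "Match",
        "relationship.target.labels": f":{target_label}",
        "relationship.target.node.keys": f"{target_key}:{target_key}",
    }


opts = relationship_options("HAS_ACCOUNT", "Customer", "customer_id", "Account", "account_id")
for key, value in opts.items():
    print(f"{key} = {value}")
```

With a real DataFrame the dictionary could be applied in one call, for example `df.write.format("org.neo4j.spark.DataSource").mode("Append").options(**opts).save()`.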

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, IntegerType


def read_csv(filename: str) -> DataFrame:
    """Read a CSV file from the Unity Catalog Volume."""
    path = f"{VOLUME_PATH}/{filename}"
    return spark.read.option("header", "true").csv(path)


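def write_nodes(df: DataFrame, label: str, node_key: str) -> None:
    """Write DataFrame rows as nodes to Neo4j.

    Note: with the "Append" save mode the connector issues CREATE statements,
    so re-running this against a populated database duplicates nodes or fails
    on the uniqueness constraints. The "Overwrite" mode MERGEs on `node_key`.

    Args:
        df: DataFrame containing node properties
        label: Neo4j node label
        node_key: Property name to use as the node key
    """
    (
        df.write.format("org.neo4j.spark.DataSource")
        .mode("Append")
        .option("labels", f":{label}")
        .option("node.keys", node_key)
        .save()
    )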


def write_relationship(
    df: DataFrame,
    rel_type: str,
    source_label: str,
    source_key: str,
    target_label: str,
    target_key: str
) -> None:
    """Write DataFrame rows as relationships to Neo4j using key matching.
    
    Args:
        df: DataFrame with source and target key columns
        rel_type: Neo4j relationship type
        source_label: Source node label
        source_key: Source node key column name
        target_label: Target node label
        target_key: Target node key column name
    """
    (
        df.write.format("org.neo4j.spark.DataSource")
        .mode("Append")
        .option("relationship", rel_type)
        .option("relationship.save.strategy", "keys")
        .option("relationship.source.save.mode", "Match")
        .option("relationship.source.labels", f":{source_label}")
        .option("relationship.source.node.keys", f"{source_key}:{source_key}")
        .option("relationship.target.save.mode", "Match")
        .option("relationship.target.labels", f":{target_label}")
        .option("relationship.target.node.keys", f"{target_key}:{target_key}")
        .save()
    )


def run_cypher(query: str) -> DataFrame:
    """Execute a Cypher query and return results as DataFrame."""
    return (
        spark.read.format("org.neo4j.spark.DataSource")
        .option("url", NEO4J_URL)
        .option("authentication.basic.username", NEO4J_USER)
        .option("authentication.basic.password", NEO4J_PASS)
        .option("database", NEO4J_DATABASE)
        .option("query", query)
        .load()
    )


print("Helper functions defined.")

## Step 4: Load and Transform CSV Data

Load all CSV files and apply appropriate data type conversions.
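
The conversions in this step repeat one pattern per column, so they can also be written as a column-to-type mapping applied in a loop. The sketch below is an alternative phrasing, not the notebook's code: `apply_casts` and `CUSTOMER_CASTS` are hypothetical names, it relies on `Column.cast` accepting type names as strings, and it assumes casting to `"date"` is acceptable in place of `F.to_date` for ISO-formatted dates.

```python
from typing import Any, Dict

# Column name -> Spark SQL type name, mirroring the Customer conversions.
CUSTOMER_CASTS: Dict[str, str] = {
    "annual_income": "int",
    "credit_score": "int",
    "registration_date": "date",
    "date_of_birth": "date",
}


def apply_casts(df: Any, casts: Dict[str, str]) -> Any:
    """Return `df` with each listed column cast to the given type name.

    `df` is expected to behave like a PySpark DataFrame: `df[name]` yields a
    column with a `.cast()` method, and `withColumn` returns a new DataFrame.
    """
    for column, type_name in casts.items():
        df = df.withColumn(column, df[column].cast(type_name))
    return df
```

In the notebook this would be called as `customers = apply_casts(customers_raw, CUSTOMER_CASTS)`, with one mapping per CSV file.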

In [None]:
print("Loading CSV files from Unity Catalog Volume...")
print("=" * 60)

# Load raw CSVs
customers_raw = read_csv("customers.csv")
banks_raw = read_csv("banks.csv")
accounts_raw = read_csv("accounts.csv")
companies_raw = read_csv("companies.csv")
stocks_raw = read_csv("stocks.csv")
positions_raw = read_csv("portfolio_holdings.csv")
transactions_raw = read_csv("transactions.csv")

print(f"  customers.csv: {customers_raw.count()} rows")
print(f"  banks.csv: {banks_raw.count()} rows")
print(f"  accounts.csv: {accounts_raw.count()} rows")
print(f"  companies.csv: {companies_raw.count()} rows")
print(f"  stocks.csv: {stocks_raw.count()} rows")
print(f"  portfolio_holdings.csv: {positions_raw.count()} rows")
print(f"  transactions.csv: {transactions_raw.count()} rows")

print("\nAll CSV files loaded successfully.")

In [None]:
print("Applying data type conversions...")
print("=" * 60)

# Transform Customers
customers = (
    customers_raw
    .withColumn("annual_income", F.col("annual_income").cast(IntegerType()))
    .withColumn("credit_score", F.col("credit_score").cast(IntegerType()))
    .withColumn("registration_date", F.to_date(F.col("registration_date")))
    .withColumn("date_of_birth", F.to_date(F.col("date_of_birth")))
)
print("  Customers: annual_income (INT), credit_score (INT), dates (DATE)")

# Transform Banks
banks = (
    banks_raw
    .withColumn("total_assets_billions", F.col("total_assets_billions").cast(DoubleType()))
    .withColumn("established_year", F.col("established_year").cast(IntegerType()))
)
print("  Banks: total_assets_billions (DOUBLE), established_year (INT)")

# Transform Accounts
accounts = (
    accounts_raw
    .withColumn("balance", F.col("balance").cast(DoubleType()))
    .withColumn("interest_rate", F.col("interest_rate").cast(DoubleType()))
    .withColumn("opened_date", F.to_date(F.col("opened_date")))
)
print("  Accounts: balance (DOUBLE), interest_rate (DOUBLE), opened_date (DATE)")

# Transform Companies
companies = (
    companies_raw
    .withColumn("market_cap_billions", F.col("market_cap_billions").cast(DoubleType()))
    .withColumn("annual_revenue_billions", F.col("annual_revenue_billions").cast(DoubleType()))
    .withColumn("founded_year", F.col("founded_year").cast(IntegerType()))
    .withColumn("employee_count", F.col("employee_count").cast(IntegerType()))
)
print("  Companies: market_cap (DOUBLE), revenue (DOUBLE), founded_year (INT), employees (INT)")

# Transform Stocks
stocks = (
    stocks_raw
    .withColumn("current_price", F.col("current_price").cast(DoubleType()))
    .withColumn("previous_close", F.col("previous_close").cast(DoubleType()))
    .withColumn("opening_price", F.col("opening_price").cast(DoubleType()))
    .withColumn("day_high", F.col("day_high").cast(DoubleType()))
    .withColumn("day_low", F.col("day_low").cast(DoubleType()))
    .withColumn("volume", F.col("volume").cast(IntegerType()))
    .withColumn("market_cap_billions", F.col("market_cap_billions").cast(DoubleType()))
    .withColumn("pe_ratio", F.col("pe_ratio").cast(DoubleType()))
    .withColumn("dividend_yield", F.col("dividend_yield").cast(DoubleType()))
    .withColumn("fifty_two_week_high", F.col("fifty_two_week_high").cast(DoubleType()))
    .withColumn("fifty_two_week_low", F.col("fifty_two_week_low").cast(DoubleType()))
)
print("  Stocks: all price/ratio fields (DOUBLE), volume (INT)")

# Transform Positions (rename holding_id to position_id for clarity)
positions = (
    positions_raw
    .withColumnRenamed("holding_id", "position_id")
    .withColumn("shares", F.col("shares").cast(IntegerType()))
    .withColumn("purchase_price", F.col("purchase_price").cast(DoubleType()))
    .withColumn("current_value", F.col("current_value").cast(DoubleType()))
    .withColumn("percentage_of_portfolio", F.col("percentage_of_portfolio").cast(DoubleType()))
    .withColumn("purchase_date", F.to_date(F.col("purchase_date")))
)
print("  Positions: shares (INT), prices/values (DOUBLE), purchase_date (DATE)")

# Transform Transactions
transactions = (
    transactions_raw
    .withColumn("amount", F.col("amount").cast(DoubleType()))
    .withColumn("transaction_date", F.to_date(F.col("transaction_date")))
)
print("  Transactions: amount (DOUBLE), transaction_date (DATE)")

print("\nData type conversions complete.")

## Step 5: Create Indexes and Constraints

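Create uniqueness constraints BEFORE loading data. Each uniqueness constraint is backed by an index on its key property, so no separate index is needed for the node keys.

**Best Practice**: Creating the constraints first lets the key-based `Match` lookups in Step 7 use an index instead of scanning every node with the label, and it rejects duplicate keys at write time.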
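
The connector's `query` read option is designed for queries that return rows, so schema commands may not execute through it, depending on the connector version. One hedged alternative is to run them with the official Neo4j Python driver. The sketch assumes the `neo4j` package is installed on the cluster, which is not among this notebook's listed prerequisites; `constraint_statements` and `apply_constraints` are illustrative names.

```python
from typing import Iterable, List, Tuple


def constraint_statements(constraints: Iterable[Tuple[str, str, str]]) -> List[str]:
    """Turn (name, label, property) tuples into CREATE CONSTRAINT statements."""
    return [
        f"CREATE CONSTRAINT {name} IF NOT EXISTS "
        f"FOR (n:{label}) REQUIRE n.{prop} IS UNIQUE"
        for name, label, prop in constraints
    ]


def apply_constraints(uri: str, user: str, password: str, database: str,
                      statements: Iterable[str]) -> None:
    """Run each statement through the Neo4j Python driver (assumed installed)."""
    from neo4j import GraphDatabase  # imported lazily; assumption: package is available

    with GraphDatabase.driver(uri, auth=(user, password)) as driver:
        with driver.session(database=database) as session:
            for statement in statements:
                session.run(statement).consume()  # consume() waits for completion


statements = constraint_statements([
    ("customer_id_unique", "Customer", "customer_id"),
    ("bank_id_unique", "Bank", "bank_id"),
])
for statement in statements:
    print(statement)
```

In the notebook, `apply_constraints(NEO4J_URL, NEO4J_USER, NEO4J_PASS, NEO4J_DATABASE, statements)` would then raise on the first failing statement instead of printing a truncated message.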

In [None]:
print("Creating indexes and constraints in Neo4j...")
print("=" * 60)

# Uniqueness constraints as (constraint name, node label, key property)
constraints = [
    ("customer_id_unique", "Customer", "customer_id"),
    ("bank_id_unique", "Bank", "bank_id"),
    ("account_id_unique", "Account", "account_id"),
    ("company_id_unique", "Company", "company_id"),
    ("stock_id_unique", "Stock", "stock_id"),
    ("position_id_unique", "Position", "position_id"),
    ("transaction_id_unique", "Transaction", "transaction_id"),
]

for constraint_name, label, property_name in constraints:
    query = f"""
    CREATE CONSTRAINT {constraint_name} IF NOT EXISTS
    FOR (n:{label})
    REQUIRE n.{property_name} IS UNIQUE
    """
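    try:
        run_cypher(query).collect()
        print(f"  Constraint: {constraint_name}")
    except Exception as e:
        # Show enough of the message to diagnose the failure
        print(f"  Constraint {constraint_name} FAILED: {str(e)[:200]}")

print("\nConstraint creation finished. Review any FAILED lines above before loading data.")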

## Step 6: Write Nodes to Neo4j

Write all node types to Neo4j. The node types can be written in any order because relationships are created separately in Step 7, after every node exists.

**Graph Schema:**
- Customer (102 nodes)
- Bank (102 nodes)
- Account (123 nodes)
- Company (102 nodes)
- Stock (102 nodes)
- Position (110 nodes)
- Transaction (123 nodes)
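
Several node writes in this step drop foreign-key columns so that they end up as relationships rather than node properties. Instead of listing the kept columns by hand, the selection can be derived by exclusion. This is a sketch only: `property_columns` is a hypothetical helper, and the column list is an illustrative example rather than the verified header of `accounts.csv`.

```python
def property_columns(columns, foreign_keys):
    """Return `columns` minus `foreign_keys`, preserving the original order."""
    excluded = set(foreign_keys)
    return [c for c in columns if c not in excluded]


# Illustrative column list; in the notebook this would be accounts.columns.
account_columns = [
    "account_id", "account_number", "customer_id", "bank_id",
    "account_type", "balance", "currency", "opened_date", "status", "interest_rate",
]
kept = property_columns(account_columns, ["customer_id", "bank_id"])
print(kept)
```

The result could be passed straight to `accounts.select(*kept)`. The explicit lists used in this step have the advantage that a new CSV column never reaches Neo4j unnoticed.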

In [None]:
print("Writing nodes to Neo4j...")
print("=" * 60)

# Customer nodes
print("\n[1/7] Writing Customer nodes...")
write_nodes(customers, "Customer", "customer_id")
print(f"       {customers.count()} Customer nodes written")

# Bank nodes
print("\n[2/7] Writing Bank nodes...")
write_nodes(banks, "Bank", "bank_id")
print(f"       {banks.count()} Bank nodes written")

# Account nodes (foreign keys are excluded; they become relationships in Step 7)
print("\n[3/7] Writing Account nodes...")
account_props = accounts.select(
    "account_id", "account_number", "account_type", 
    "balance", "currency", "opened_date", "status", "interest_rate"
)
write_nodes(account_props, "Account", "account_id")
print(f"       {account_props.count()} Account nodes written")

# Company nodes
print("\n[4/7] Writing Company nodes...")
write_nodes(companies, "Company", "company_id")
print(f"       {companies.count()} Company nodes written")

# Stock nodes (exclude foreign key)
print("\n[5/7] Writing Stock nodes...")
stock_props = stocks.select(
    "stock_id", "ticker", "current_price", "previous_close", "opening_price",
    "day_high", "day_low", "volume", "market_cap_billions", "pe_ratio",
    "dividend_yield", "fifty_two_week_high", "fifty_two_week_low", "exchange"
)
write_nodes(stock_props, "Stock", "stock_id")
print(f"       {stock_props.count()} Stock nodes written")

# Position nodes (exclude foreign keys)
print("\n[6/7] Writing Position nodes...")
position_props = positions.select(
    "position_id", "shares", "purchase_price", "purchase_date",
    "current_value", "percentage_of_portfolio"
)
write_nodes(position_props, "Position", "position_id")
print(f"       {position_props.count()} Position nodes written")

# Transaction nodes (exclude foreign keys)
print("\n[7/7] Writing Transaction nodes...")
transaction_props = transactions.select(
    "transaction_id", "amount", "currency", "transaction_date",
    "transaction_time", "type", "status", "description"
)
write_nodes(transaction_props, "Transaction", "transaction_id")
print(f"       {transaction_props.count()} Transaction nodes written")

print("\n" + "=" * 60)
print("ALL NODES WRITTEN SUCCESSFULLY!")
print("=" * 60)

## Step 7: Write Relationships to Neo4j

Create all relationships between nodes using key matching.

**Relationship Types:**
1. `(:Customer)-[:HAS_ACCOUNT]->(:Account)` - Customer owns account
2. `(:Account)-[:AT_BANK]->(:Bank)` - Account held at bank
3. `(:Stock)-[:OF_COMPANY]->(:Company)` - Stock issued by company
4. `(:Account)-[:PERFORMS]->(:Transaction)` - Account initiates transfer
5. `(:Transaction)-[:BENEFITS_TO]->(:Account)` - Account receives funds
6. `(:Account)-[:HAS_POSITION]->(:Position)` - Account holds position
7. `(:Position)-[:OF_SECURITY]->(:Stock)` - Position is in specific stock
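
Because both endpoints are looked up with the `Match` save mode, a row whose key has no corresponding node is expected to produce no relationship rather than an error. Checking for such dangling keys beforehand makes a later count mismatch easier to explain. The sketch below works on plain Python lists, which is reasonable at this data size; `dangling_keys` is a hypothetical helper and the IDs are made-up examples.

```python
def dangling_keys(relationship_keys, node_keys):
    """Return the distinct keys used by relationships that match no node."""
    known = set(node_keys)
    unmatched = {k for k in relationship_keys if k not in known}
    return sorted(unmatched, key=str)


# Made-up IDs; in the notebook these lists could be collected from
# accounts.select("customer_id") and customers.select("customer_id").
account_owner_ids = ["C0001", "C0002", "C0999", "C0999"]
customer_ids = ["C0001", "C0002", "C0003"]
print(dangling_keys(account_owner_ids, customer_ids))
```

For larger datasets the same check is usually expressed in Spark as a `left_anti` join between the relationship DataFrame and the node DataFrame.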

In [None]:
print("Writing relationships to Neo4j...")
print("=" * 60)

# 1. HAS_ACCOUNT: Customer -> Account
print("\n[1/7] Writing HAS_ACCOUNT relationships...")
has_account_df = accounts.select("customer_id", "account_id")
write_relationship(
    has_account_df, "HAS_ACCOUNT",
    "Customer", "customer_id",
    "Account", "account_id"
)
print(f"       {has_account_df.count()} HAS_ACCOUNT relationships written")

# 2. AT_BANK: Account -> Bank
print("\n[2/7] Writing AT_BANK relationships...")
at_bank_df = accounts.select("account_id", "bank_id")
write_relationship(
    at_bank_df, "AT_BANK",
    "Account", "account_id",
    "Bank", "bank_id"
)
print(f"       {at_bank_df.count()} AT_BANK relationships written")

# 3. OF_COMPANY: Stock -> Company
print("\n[3/7] Writing OF_COMPANY relationships...")
of_company_df = stocks.select("stock_id", "company_id")
write_relationship(
    of_company_df, "OF_COMPANY",
    "Stock", "stock_id",
    "Company", "company_id"
)
print(f"       {of_company_df.count()} OF_COMPANY relationships written")

# 4. PERFORMS: Account -> Transaction (from_account initiates)
print("\n[4/7] Writing PERFORMS relationships...")
performs_df = transactions.select(
    F.col("from_account_id").alias("account_id"),
    "transaction_id"
)
write_relationship(
    performs_df, "PERFORMS",
    "Account", "account_id",
    "Transaction", "transaction_id"
)
print(f"       {performs_df.count()} PERFORMS relationships written")

# 5. BENEFITS_TO: Transaction -> Account (to_account receives)
print("\n[5/7] Writing BENEFITS_TO relationships...")
benefits_df = transactions.select(
    "transaction_id",
    F.col("to_account_id").alias("account_id")
)
write_relationship(
    benefits_df, "BENEFITS_TO",
    "Transaction", "transaction_id",
    "Account", "account_id"
)
print(f"       {benefits_df.count()} BENEFITS_TO relationships written")

# 6. HAS_POSITION: Account -> Position
print("\n[6/7] Writing HAS_POSITION relationships...")
has_position_df = positions.select("account_id", "position_id")
write_relationship(
    has_position_df, "HAS_POSITION",
    "Account", "account_id",
    "Position", "position_id"
)
print(f"       {has_position_df.count()} HAS_POSITION relationships written")

# 7. OF_SECURITY: Position -> Stock
print("\n[7/7] Writing OF_SECURITY relationships...")
of_security_df = positions.select("position_id", "stock_id")
write_relationship(
    of_security_df, "OF_SECURITY",
    "Position", "position_id",
    "Stock", "stock_id"
)
print(f"       {of_security_df.count()} OF_SECURITY relationships written")

print("\n" + "=" * 60)
print("ALL RELATIONSHIPS WRITTEN SUCCESSFULLY!")
print("=" * 60)

## Step 8: Validate Import

Run validation queries to verify the import completed correctly.
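
The expected counts used for validation follow directly from the source row counts: each CSV row becomes one node, and each relationship type is written from exactly one DataFrame. Deriving them avoids hard-coded numbers going stale when the demo data changes. This is a sketch with a hypothetical helper name, and it assumes every foreign key in the data is non-null and matches an existing node.

```python
def expected_counts(rows):
    """Derive expected node and relationship counts from CSV row counts."""
    nodes = {
        "Customer": rows["customers"],
        "Bank": rows["banks"],
        "Account": rows["accounts"],
        "Company": rows["companies"],
        "Stock": rows["stocks"],
        "Position": rows["positions"],
        "Transaction": rows["transactions"],
    }
    rels = {
        "HAS_ACCOUNT": rows["accounts"],      # one owner per account row
        "AT_BANK": rows["accounts"],          # one bank per account row
        "OF_COMPANY": rows["stocks"],         # one company per stock row
        "PERFORMS": rows["transactions"],     # one sender per transaction row
        "BENEFITS_TO": rows["transactions"],  # one receiver per transaction row
        "HAS_POSITION": rows["positions"],    # one account per position row
        "OF_SECURITY": rows["positions"],     # one stock per position row
    }
    return nodes, rels


# Row counts taken from the Data Overview section of this notebook.
nodes, rels = expected_counts({
    "customers": 102, "banks": 102, "accounts": 123, "companies": 102,
    "stocks": 102, "positions": 110, "transactions": 123,
})
print(nodes)
print(rels)
```

In the notebook the input dictionary could be filled from the `count()` results of the transformed DataFrames.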

In [None]:
print("Validating node counts...")
print("=" * 60)

# Expected counts
expected_nodes = {
    "Customer": 102,
    "Bank": 102,
    "Account": 123,
    "Company": 102,
    "Stock": 102,
    "Position": 110,
    "Transaction": 123
}

# Query actual counts
node_count_query = """
MATCH (n)
RETURN labels(n)[0] AS label, count(n) AS count
ORDER BY label
"""
node_counts = run_cypher(node_count_query).collect()

print("\n{:<15} {:>10} {:>10} {:>10}".format("Label", "Expected", "Actual", "Status"))
print("-" * 50)

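actual_nodes = {row["label"]: row["count"] for row in node_counts}

all_valid = True
# Iterate over the union so that labels missing from Neo4j are reported too
for label in sorted(set(expected_nodes) | set(actual_nodes)):
    actual = actual_nodes.get(label, 0)
    expected = expected_nodes.get(label, "N/A")
    status = "OK" if actual == expected else "MISMATCH"
    if status == "MISMATCH":
        all_valid = False
    print("{:<15} {:>10} {:>10} {:>10}".format(label, expected, actual, status))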

if all_valid:
    print("\nNode validation: PASSED")
else:
    print("\nNode validation: FAILED - check mismatched counts")

In [None]:
print("Validating relationship counts...")
print("=" * 60)

# Expected counts
expected_rels = {
    "HAS_ACCOUNT": 123,
    "AT_BANK": 123,
    "OF_COMPANY": 102,
    "PERFORMS": 123,
    "BENEFITS_TO": 123,
    "HAS_POSITION": 110,
    "OF_SECURITY": 110
}

# Query actual counts
rel_count_query = """
MATCH ()-[r]->()
RETURN type(r) AS relationship_type, count(r) AS count
ORDER BY relationship_type
"""
rel_counts = run_cypher(rel_count_query).collect()

print("\n{:<20} {:>10} {:>10} {:>10}".format("Relationship", "Expected", "Actual", "Status"))
print("-" * 55)

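actual_rels = {row["relationship_type"]: row["count"] for row in rel_counts}

all_valid = True
# Iterate over the union so that relationship types missing from Neo4j are reported too
for rel_type in sorted(set(expected_rels) | set(actual_rels)):
    actual = actual_rels.get(rel_type, 0)
    expected = expected_rels.get(rel_type, "N/A")
    status = "OK" if actual == expected else "MISMATCH"
    if status == "MISMATCH":
        all_valid = False
    print("{:<20} {:>10} {:>10} {:>10}".format(rel_type, expected, actual, status))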

if all_valid:
    print("\nRelationship validation: PASSED")
else:
    print("\nRelationship validation: FAILED - check mismatched counts")

In [None]:
print("Running sample queries to verify data integrity...")
print("=" * 60)

# Sample query 1: Customer's complete financial profile
print("\n[Query 1] Customer C0001's complete financial profile:")
print("-" * 50)

query1 = """
MATCH (c:Customer {customer_id: 'C0001'})-[:HAS_ACCOUNT]->(a:Account)
OPTIONAL MATCH (a)-[:AT_BANK]->(b:Bank)
OPTIONAL MATCH (a)-[:HAS_POSITION]->(p:Position)-[:OF_SECURITY]->(s:Stock)
RETURN
    c.first_name + ' ' + c.last_name AS customer_name,
    a.account_id AS account,
    a.account_type AS account_type,
    a.balance AS balance,
    b.name AS bank_name,
    s.ticker AS ticker,
    p.shares AS shares,
    p.current_value AS holding_value
ORDER BY holding_value DESC
LIMIT 5
"""
display(run_cypher(query1))

In [None]:
# Sample query 2: Top 5 banks by total deposits
print("\n[Query 2] Top 5 banks by total deposits:")
print("-" * 50)

query2 = """
MATCH (b:Bank)<-[:AT_BANK]-(a:Account)
RETURN
    b.name AS bank_name,
    b.bank_type AS bank_type,
    count(DISTINCT a) AS num_accounts,
    round(sum(a.balance), 2) AS total_deposits
ORDER BY total_deposits DESC
LIMIT 5
"""
display(run_cypher(query2))

In [None]:
# Sample query 3: Transaction flow validation
print("\n[Query 3] Recent transaction flow:")
print("-" * 50)

query3 = """
MATCH (from:Account)-[:PERFORMS]->(t:Transaction)-[:BENEFITS_TO]->(to:Account)
RETURN
    from.account_id AS from_account,
    t.transaction_id AS transaction_id,
    t.amount AS amount,
    t.transaction_date AS date,
    to.account_id AS to_account
ORDER BY t.transaction_date DESC
LIMIT 5
"""
display(run_cypher(query3))

In [None]:
# Sample query 4: Most popular stocks
print("\n[Query 4] Most widely held stocks:")
print("-" * 50)

query4 = """
MATCH (a:Account)-[:HAS_POSITION]->(p:Position)-[:OF_SECURITY]->(s:Stock)-[:OF_COMPANY]->(c:Company)
RETURN
    c.name AS company_name,
    s.ticker AS ticker,
    count(DISTINCT a) AS num_holders,
    sum(p.shares) AS total_shares_held,
    round(sum(p.current_value), 2) AS total_market_value
ORDER BY num_holders DESC
LIMIT 5
"""
display(run_cypher(query4))

## Import Complete

The financial demo data has been successfully imported to Neo4j!

### Summary

**Nodes Created:**
- 102 Customers
- 102 Banks
- 123 Accounts
- 102 Companies
- 102 Stocks
- 110 Positions
- 123 Transactions

**Relationships Created:**
- 123 HAS_ACCOUNT (Customer -> Account)
- 123 AT_BANK (Account -> Bank)
- 102 OF_COMPANY (Stock -> Company)
- 123 PERFORMS (Account -> Transaction)
- 123 BENEFITS_TO (Transaction -> Account)
- 110 HAS_POSITION (Account -> Position)
- 110 OF_SECURITY (Position -> Stock)

### Next Steps

1. **Explore with Neo4j Browser** - Visualize the graph
2. **Run Graph Algorithms** - PageRank, Community Detection
3. **Build Applications** - Portfolio dashboards, risk analysis
4. **Extend the Schema** - Add Advisors, Branches, Products

---

See `DATA_IMPORT.md` for additional query examples and use cases.