# Financial Demo Data Import to Neo4j

This notebook imports the retail banking and investment portfolio demonstration data from Databricks Unity Catalog into Neo4j.

## Prerequisites

1. **Neo4j database** running (Aura or self-hosted)
2. **Databricks Secrets** configured with `neo4j-creds` scope containing `username`, `password`, `url`, and `volume_path`
3. **Unity Catalog Volume** with the CSV files uploaded to its `csv` subdirectory
4. **Databricks Cluster** with Neo4j Spark Connector installed

## Data Overview

- **102 Customers** with demographics and financial profiles
- **102 Banks** across multiple types (commercial, regional, community)
- **123 Accounts** (checking, savings, investment)
- **102 Companies** across 12+ sectors
- **102 Stocks** with market data
- **110 Portfolio Positions** linking accounts to stocks
- **123 Transactions** between accounts

---

## Step 1: Configuration

Configure Neo4j connection and data source paths. All credentials are retrieved from Databricks Secrets.
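
As a minimal sketch of a pre-flight check, the value read from the `url` secret could be validated before any connection attempt. The helper name `check_neo4j_url` is hypothetical and not part of this notebook; the scheme set is an assumption based on the URI schemes the Neo4j drivers accept.

```python
from urllib.parse import urlparse

# URI schemes accepted by the Neo4j drivers (routed and direct, with TLS variants).
VALID_SCHEMES = {"neo4j", "neo4j+s", "neo4j+ssc", "bolt", "bolt+s", "bolt+ssc"}


def check_neo4j_url(url: str) -> str:
    """Return the stripped URL if it looks usable, otherwise raise ValueError.

    Hypothetical helper for illustration only.
    """
    cleaned = url.strip()
    parsed = urlparse(cleaned)
    if parsed.scheme not in VALID_SCHEMES:
        raise ValueError(
            f"Unsupported scheme {parsed.scheme!r}; expected one of {sorted(VALID_SCHEMES)}"
        )
    if not parsed.hostname:
        raise ValueError("URL has no host component")
    return cleaned
```

Calling `check_neo4j_url(NEO4J_URL)` right after the secret is retrieved would turn a malformed URL into an immediate, readable error rather than a less obvious connector failure later.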

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================
import time

print("=" * 70)
print("CONFIGURATION - Loading secrets from Databricks")
print("=" * 70)
print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
print("")

# Retrieve Neo4j credentials from Databricks Secrets
print("[DEBUG] Retrieving secrets from scope 'neo4j-creds'...")

try:
    NEO4J_USER = dbutils.secrets.get(scope="neo4j-creds", key="username")
    print(f"  [OK] username: retrieved ({len(NEO4J_USER)} chars)")
except Exception as e:
    print(f"  [FAIL] username: {str(e)}")
    raise

try:
    NEO4J_PASS = dbutils.secrets.get(scope="neo4j-creds", key="password")
    print(f"  [OK] password: retrieved ({len(NEO4J_PASS)} chars, masked)")
except Exception as e:
    print(f"  [FAIL] password: {str(e)}")
    raise

try:
    NEO4J_URL = dbutils.secrets.get(scope="neo4j-creds", key="url")
    print(f"  [OK] url: {NEO4J_URL}")
except Exception as e:
    print(f"  [FAIL] url: {str(e)}")
    raise

try:
    VOLUME_PATH = dbutils.secrets.get(scope="neo4j-creds", key="volume_path")
    print(f"  [OK] volume_path: {VOLUME_PATH}")
except Exception as e:
    print(f"  [FAIL] volume_path: {str(e)}")
    raise

# Neo4j database name (default: neo4j)
NEO4J_DATABASE = "neo4j"

print("")
print("[DEBUG] Configuring Spark session for Neo4j connector...")
try:
    spark.conf.set("neo4j.url", NEO4J_URL)
    spark.conf.set("neo4j.authentication.basic.username", NEO4J_USER)
    spark.conf.set("neo4j.authentication.basic.password", NEO4J_PASS)
    spark.conf.set("neo4j.database", NEO4J_DATABASE)
    print("  [OK] Spark session configured")
except Exception as e:
    print(f"  [FAIL] Spark configuration: {str(e)}")
    raise

print("")
print("=" * 70)
print("CONFIGURATION SUMMARY")
print("=" * 70)
print(f"  Neo4j URL:    {NEO4J_URL}")
print(f"  Database:     {NEO4J_DATABASE}")
print(f"  Username:     {NEO4J_USER}")
print(f"  Volume Path:  {VOLUME_PATH}")
print("=" * 70)

## Step 2: Verify Prerequisites

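Verify Neo4j connectivity and CSV file availability before proceeding. The last cell in this step also clears the target database (`MATCH (n) DETACH DELETE n`), so run it only against a database whose contents can be discarded.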
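
The file check in this step only prints a warning when files are missing. If a hard stop is preferred, a fail-fast variant could look like the sketch below. `require_files` is a hypothetical helper, and `listed_names` stands in for the `name` fields of a directory listing.

```python
def require_files(expected, listed_names):
    """Raise FileNotFoundError if any expected file is absent from a listing.

    Trailing slashes and leading path segments in `listed_names` are ignored
    when comparing, so both "customers.csv" and "dir/customers.csv" match.
    """
    found = {name.rstrip("/").split("/")[-1] for name in listed_names}
    missing = [name for name in expected if name not in found]
    if missing:
        raise FileNotFoundError(f"Missing CSV files: {', '.join(missing)}")
```

In the notebook this might be called as `require_files(expected_files, [f.name for f in dbutils.fs.ls(CSV_PATH)])`, using the names defined in the file verification cell.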

In [None]:
# =============================================================================
# VERIFY CSV FILES IN UNITY CATALOG VOLUME
# =============================================================================
print("=" * 70)
print("FILE VERIFICATION - Checking CSV files in Unity Catalog Volume")
print("=" * 70)

# CSV files are in the /csv subdirectory of the volume
CSV_PATH = f"{VOLUME_PATH}/csv"
print(f"CSV path: {CSV_PATH}")
print("")

expected_files = [
    "customers.csv",
    "banks.csv",
    "accounts.csv",
    "companies.csv",
    "stocks.csv",
    "portfolio_holdings.csv",
    "transactions.csv"
]

print(f"[DEBUG] Expected files: {len(expected_files)}")
for f in expected_files:
    print(f"  - {f}")
print("")

# Try listing files
print(f"[DEBUG] Calling dbutils.fs.ls('{CSV_PATH}')...")
try:
    files = dbutils.fs.ls(CSV_PATH)
    print(f"  [OK] Listed {len(files)} items")
    print("")
    
    # Show raw file info for debugging
    print("[DEBUG] Raw file listing from dbutils.fs.ls():")
    print("-" * 70)
    for i, f in enumerate(files):
        print(f"  [{i}] name: {f.name}, size: {f.size} bytes")
    print("-" * 70)
    print("")
    
    # Extract filenames - handle various formats
    found_files = []
    for f in files:
        name = f.name.rstrip('/')
        if '/' in name:
            name = name.split('/')[-1]
        found_files.append(name)
    
    # Check for expected files
    print("[DEBUG] Checking for expected files:")
    print("-" * 70)
    all_present = True
    for expected in expected_files:
        found = expected in found_files
        status = "[OK]  " if found else "[MISSING]"
        if not found:
            all_present = False
        print(f"  {status} {expected}")
    print("-" * 70)
    
    if all_present:
        print("\n[OK] All required CSV files are present!")
    else:
        print("\n[WARNING] Some files are missing!")
        print("Found files:", found_files)
        
except Exception as e:
    print(f"  [FAIL] Error listing files: {type(e).__name__}: {str(e)}")
    print("")
    print("[WARNING] Could not list files. Will attempt to read files directly.")

In [None]:
# =============================================================================
# VERIFY NEO4J CONNECTIVITY
# =============================================================================
print("=" * 70)
print("NEO4J CONNECTION TEST")
print("=" * 70)
print(f"URL: {NEO4J_URL}")
print(f"Database: {NEO4J_DATABASE}")
print(f"Username: {NEO4J_USER}")
print("")

print("[DEBUG] Building Spark DataSource read...")
print("  Format: org.neo4j.spark.DataSource")
print(f"  Query: RETURN 'Connected' AS status")
print("")

try:
    print("[DEBUG] Executing connection test...")
    start_time = time.time()
    
    test_df = (
        spark.read.format("org.neo4j.spark.DataSource")
        .option("url", NEO4J_URL)
        .option("authentication.basic.username", NEO4J_USER)
        .option("authentication.basic.password", NEO4J_PASS)
        .option("database", NEO4J_DATABASE)
        .option("query", "RETURN 'Connected' AS status")
        .load()
    )
    
    result = test_df.collect()
    elapsed = time.time() - start_time
    
    print(f"  [OK] Query executed in {elapsed:.2f}s")
    print(f"  [OK] Result: {result}")
    print("")
    print("=" * 70)
    print("[OK] NEO4J CONNECTION SUCCESSFUL!")
    print("=" * 70)
    
except Exception as e:
    print(f"  [FAIL] Connection failed!")
    print(f"  Error type: {type(e).__name__}")
    print(f"  Error message: {str(e)}")
    print("")
    print("[DEBUG] Troubleshooting tips:")
    print("  1. Verify Neo4j database is running")
    print("  2. Check URL format:")
    print("     - Aura: neo4j+s://xxxxx.databases.neo4j.io")
    print("     - Self-hosted with TLS: neo4j+s://host:7687")
    print("     - Self-hosted no TLS: bolt://host:7687")
    print("  3. Verify credentials are correct")
    print("  4. Check network connectivity (firewall, VPC)")
    print("  5. Ensure Neo4j Spark Connector is installed on cluster")
    raise

In [None]:
# =============================================================================
# CLEAR EXISTING DATABASE
# =============================================================================
from neo4j import GraphDatabase

print("=" * 70)
print("DATABASE CLEANUP - Removing existing nodes and relationships")
print("=" * 70)
print("")

print("[DEBUG] Deleting all nodes and relationships...")
print("        Query: MATCH (n) DETACH DELETE n")
print("")

try:
    start_time = time.time()

    # Use neo4j Python driver for direct Cypher execution.
    # The context managers close the session and driver even if the query fails.
    with GraphDatabase.driver(NEO4J_URL, auth=(NEO4J_USER, NEO4J_PASS)) as driver:
        with driver.session(database=NEO4J_DATABASE) as session:
            result = session.run("MATCH (n) DETACH DELETE n")
            summary = result.consume()
            nodes_deleted = summary.counters.nodes_deleted
            rels_deleted = summary.counters.relationships_deleted

    elapsed = time.time() - start_time
    print(f"  [OK] Deleted {nodes_deleted} nodes and {rels_deleted} relationships")
    print(f"  [OK] Completed in {elapsed:.2f}s")
    print("")
    print("=" * 70)
    print("[OK] DATABASE CLEANUP COMPLETE!")
    print("=" * 70)

except Exception as e:
    print(f"  [FAIL] Cleanup failed: {type(e).__name__}")
    print(f"         {str(e)[:200]}")
    raise

## Step 3: Helper Functions

Define reusable functions for data loading and Neo4j operations.
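
One way to keep the relationship write readable is to build the connector options as a plain dictionary first. The sketch below is illustrative: `relationship_options` is a hypothetical name, and the option keys are the ones `write_relationship` sets in this step.

```python
def relationship_options(rel_type, source_label, source_key, target_label, target_key):
    """Build connector options for a keys-strategy relationship write.

    Source and target nodes are matched (not created) on a single key each,
    where the DataFrame column and the node property share the same name.
    """
    return {
        "relationship": rel_type,
        "relationship.save.strategy": "keys",
        "relationship.source.save.mode": "Match",
        "relationship.source.labels": f":{source_label}",
        "relationship.source.node.keys": f"{source_key}:{source_key}",
        "relationship.target.save.mode": "Match",
        "relationship.target.labels": f":{target_label}",
        "relationship.target.node.keys": f"{target_key}:{target_key}",
    }
```

The dictionary could then be passed in one call, for example `df.write.format("org.neo4j.spark.DataSource").mode("Append").options(**opts).save()`.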

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, IntegerType, DateType


def read_csv(filename: str) -> DataFrame:
    """Read a CSV file from the Unity Catalog Volume /csv subdirectory."""
    path = f"{CSV_PATH}/{filename}"
    return spark.read.option("header", "true").csv(path)


def write_nodes(df: DataFrame, label: str, node_key: str) -> dict:
    """Write DataFrame rows as nodes to Neo4j with debug logging."""
    count = df.count()
    print(f"[DEBUG] write_nodes(label={label}, key={node_key}, count={count})")
    
    try:
        start_time = time.time()
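        # Append mode issues CREATE for every row. The connector only uses
        # node.keys in Overwrite (MERGE) mode, so re-running this without the
        # cleanup cell would violate the uniqueness constraints from Step 5.
        (
            df.write.format("org.neo4j.spark.DataSource")
            .mode("Append")
            .option("labels", f":{label}")
            .option("node.keys", node_key)
            .save()
        )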
        elapsed = time.time() - start_time
        print(f"  [OK] {count} {label} nodes written in {elapsed:.2f}s")
        return {"status": "OK", "count": count, "elapsed": elapsed}
    except Exception as e:
        print(f"  [FAIL] Error writing {label} nodes: {type(e).__name__}")
        print(f"         {str(e)[:200]}")
        raise


def write_relationship(
    df: DataFrame,
    rel_type: str,
    source_label: str,
    source_key: str,
    target_label: str,
    target_key: str
) -> dict:
    """Write DataFrame rows as relationships to Neo4j with debug logging."""
    count = df.count()
    print(f"[DEBUG] write_relationship(type={rel_type})")
    print(f"        Pattern: (:{source_label})-[:{rel_type}]->(:{target_label})")
    print(f"        Keys: {source_key} -> {target_key}, count={count}")
    
    try:
        start_time = time.time()
        (
            df.write.format("org.neo4j.spark.DataSource")
            .mode("Append")
            .option("relationship", rel_type)
            .option("relationship.save.strategy", "keys")
            .option("relationship.source.save.mode", "Match")
            .option("relationship.source.labels", f":{source_label}")
            .option("relationship.source.node.keys", f"{source_key}:{source_key}")
            .option("relationship.target.save.mode", "Match")
            .option("relationship.target.labels", f":{target_label}")
            .option("relationship.target.node.keys", f"{target_key}:{target_key}")
            .save()
        )
        elapsed = time.time() - start_time
        print(f"  [OK] {count} {rel_type} relationships written in {elapsed:.2f}s")
        return {"status": "OK", "count": count, "elapsed": elapsed}
    except Exception as e:
        print(f"  [FAIL] Error writing {rel_type} relationships: {type(e).__name__}")
        print(f"         {str(e)[:200]}")
        raise


def run_cypher(query: str, debug: bool = False) -> DataFrame:
    """Execute a Cypher query and return results as DataFrame."""
    if debug:
        print(f"[DEBUG] run_cypher: {query[:100]}...")
    return (
        spark.read.format("org.neo4j.spark.DataSource")
        .option("url", NEO4J_URL)
        .option("authentication.basic.username", NEO4J_USER)
        .option("authentication.basic.password", NEO4J_PASS)
        .option("database", NEO4J_DATABASE)
        .option("query", query)
        .load()
    )


print("=" * 70)
print("HELPER FUNCTIONS LOADED")
print("=" * 70)
print("Functions available:")
print("  - read_csv(filename) -> DataFrame")
print("  - write_nodes(df, label, node_key) -> dict")
print("  - write_relationship(df, rel_type, src_label, src_key, tgt_label, tgt_key) -> dict")
print("  - run_cypher(query, debug=False) -> DataFrame")
print("=" * 70)

## Step 4: Load and Transform CSV Data

Load all CSV files and apply appropriate data type conversions.
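
The conversions in this step are driven by lists of `(column, type)` pairs. A small check of such a list against the loaded columns can catch typos before any cast runs. This is a sketch under assumptions: `check_transformations` is a hypothetical helper, and the supported type names mirror the `INT`, `DOUBLE`, and `DATE` labels used in this notebook.

```python
SUPPORTED_TYPES = {"INT", "DOUBLE", "DATE"}


def check_transformations(columns, transformations):
    """Return a list of problems found in a (column, type) transformation spec.

    An empty list means every column exists and every type name is supported.
    """
    known = set(columns)
    problems = []
    for col_name, transform_type in transformations:
        if col_name not in known:
            problems.append(f"{col_name}: column not found")
        if transform_type not in SUPPORTED_TYPES:
            problems.append(f"{col_name}: unsupported type {transform_type!r}")
    return problems
```

For example, `check_transformations(customers_raw.columns, [("credit_score", "INT")])` would return an empty list when the column is present.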

In [None]:
# =============================================================================
# LOAD CSV FILES FROM UNITY CATALOG VOLUME
# =============================================================================
print("=" * 70)
print("DATA LOADING - Reading CSV files from Unity Catalog Volume")
print("=" * 70)
print(f"CSV path: {CSV_PATH}")
print("")

def load_csv_with_debug(filename: str) -> DataFrame:
    """Load CSV with detailed debug logging."""
    path = f"{CSV_PATH}/{filename}"
    print(f"[DEBUG] Loading: {filename}")
    print(f"        Path: {path}")
    
    try:
        start_time = time.time()
        df = spark.read.option("header", "true").csv(path)
        count = df.count()
        elapsed = time.time() - start_time
        print(f"  [OK] Loaded {count} rows in {elapsed:.2f}s")
        print(f"       Columns: {df.columns}")
        return df
    except Exception as e:
        print(f"  [FAIL] Error: {type(e).__name__}: {str(e)[:100]}")
        raise

# Load all CSVs with debug output
print("-" * 70)
customers_raw = load_csv_with_debug("customers.csv")
print("")

banks_raw = load_csv_with_debug("banks.csv")
print("")

accounts_raw = load_csv_with_debug("accounts.csv")
print("")

companies_raw = load_csv_with_debug("companies.csv")
print("")

stocks_raw = load_csv_with_debug("stocks.csv")
print("")

positions_raw = load_csv_with_debug("portfolio_holdings.csv")
print("")

transactions_raw = load_csv_with_debug("transactions.csv")
print("-" * 70)

print("")
print("=" * 70)
print("[OK] ALL CSV FILES LOADED SUCCESSFULLY!")
print("=" * 70)
print("")
print("Summary:")
print(f"  customers.csv:          {customers_raw.count()} rows")
print(f"  banks.csv:              {banks_raw.count()} rows")
print(f"  accounts.csv:           {accounts_raw.count()} rows")
print(f"  companies.csv:          {companies_raw.count()} rows")
print(f"  stocks.csv:             {stocks_raw.count()} rows")
print(f"  portfolio_holdings.csv: {positions_raw.count()} rows")
print(f"  transactions.csv:       {transactions_raw.count()} rows")

In [None]:
# =============================================================================
# DATA TYPE CONVERSIONS
# =============================================================================
print("=" * 70)
print("DATA TRANSFORMATION - Applying data type conversions")
print("=" * 70)
print("")

def transform_with_debug(df: DataFrame, name: str, transformations: list) -> DataFrame:
    """Apply (column, type) casts in order, logging each result."""
    print(f"[DEBUG] Transforming: {name}")
    result = df
    for col_name, transform_type in transformations:
        try:
            if transform_type == "INT":
                result = result.withColumn(col_name, F.col(col_name).cast(IntegerType()))
            elif transform_type == "DOUBLE":
                result = result.withColumn(col_name, F.col(col_name).cast(DoubleType()))
            elif transform_type == "DATE":
                result = result.withColumn(col_name, F.to_date(F.col(col_name)))
            else:
                # Reject unknown type names instead of reporting them as converted
                raise ValueError(f"unsupported type {transform_type!r}")
            print(f"  [OK] {col_name} -> {transform_type}")
        except Exception as e:
            # A failed column is logged and left unchanged; later columns still run
            print(f"  [FAIL] {col_name}: {str(e)[:50]}")
    return result

# Transform Customers
customers = transform_with_debug(customers_raw, "Customers", [
    ("annual_income", "INT"),
    ("credit_score", "INT"),
    ("registration_date", "DATE"),
    ("date_of_birth", "DATE"),
])
print("")

# Transform Banks
banks = transform_with_debug(banks_raw, "Banks", [
    ("total_assets_billions", "DOUBLE"),
    ("established_year", "INT"),
])
print("")

# Transform Accounts
accounts = transform_with_debug(accounts_raw, "Accounts", [
    ("balance", "DOUBLE"),
    ("interest_rate", "DOUBLE"),
    ("opened_date", "DATE"),
])
print("")

# Transform Companies
companies = transform_with_debug(companies_raw, "Companies", [
    ("market_cap_billions", "DOUBLE"),
    ("annual_revenue_billions", "DOUBLE"),
    ("founded_year", "INT"),
    ("employee_count", "INT"),
])
print("")

# Transform Stocks
stocks = transform_with_debug(stocks_raw, "Stocks", [
    ("current_price", "DOUBLE"),
    ("previous_close", "DOUBLE"),
    ("opening_price", "DOUBLE"),
    ("day_high", "DOUBLE"),
    ("day_low", "DOUBLE"),
    ("volume", "INT"),
    ("market_cap_billions", "DOUBLE"),
    ("pe_ratio", "DOUBLE"),
    ("dividend_yield", "DOUBLE"),
    ("fifty_two_week_high", "DOUBLE"),
    ("fifty_two_week_low", "DOUBLE"),
])
print("")

# Transform Positions (rename holding_id to position_id)
print("[DEBUG] Transforming: Positions")
print("  [OK] holding_id -> position_id (renamed)")
positions = positions_raw.withColumnRenamed("holding_id", "position_id")
positions = transform_with_debug(positions, "Positions (continued)", [
    ("shares", "INT"),
    ("purchase_price", "DOUBLE"),
    ("current_value", "DOUBLE"),
    ("percentage_of_portfolio", "DOUBLE"),
    ("purchase_date", "DATE"),
])
print("")

# Transform Transactions
transactions = transform_with_debug(transactions_raw, "Transactions", [
    ("amount", "DOUBLE"),
    ("transaction_date", "DATE"),
])

print("")
print("=" * 70)
print("[OK] ALL DATA TRANSFORMATIONS COMPLETE!")
print("=" * 70)

## Step 5: Create Indexes and Constraints

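Create uniqueness constraints BEFORE loading data. Each uniqueness constraint is backed by an index on the constrained property, so this step covers both indexes and constraints.

**Best Practice**: With the indexes in place first, the key lookups that match relationship endpoints in Step 7 can use an index instead of scanning every node with the label, and the constraints reject duplicate IDs at write time.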
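
Because Cypher schema statements cannot take the label or property name as a query parameter, the constraint statement has to be assembled as a string. The sketch below shows one cautious way to do that; `uniqueness_constraint` is a hypothetical helper, and the identifier check is an added safeguard rather than something this notebook performs.

```python
import re

# Plain identifiers only, so no quoting or escaping is needed in the statement.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def uniqueness_constraint(name: str, label: str, prop: str) -> str:
    """Return a CREATE CONSTRAINT statement for a single-property uniqueness rule."""
    for part in (name, label, prop):
        if not _IDENTIFIER.fullmatch(part):
            raise ValueError(f"Not a plain identifier: {part!r}")
    return (
        f"CREATE CONSTRAINT {name} IF NOT EXISTS "
        f"FOR (n:{label}) REQUIRE n.{prop} IS UNIQUE"
    )
```

`IF NOT EXISTS` makes the statement safe to re-run, which is why a repeated execution of this step should not fail on constraints that already exist.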

In [None]:
# =============================================================================
# CREATE INDEXES AND CONSTRAINTS
# =============================================================================
from neo4j import GraphDatabase

print("=" * 70)
print("SCHEMA SETUP - Creating indexes and constraints in Neo4j")
print("=" * 70)
print("")
print("[DEBUG] Creating uniqueness constraints...")
print("        (If constraint already exists, it will be skipped)")
print("")

# Define constraints
constraints = [
    ("customer_id_unique", "Customer", "customer_id"),
    ("bank_id_unique", "Bank", "bank_id"),
    ("account_id_unique", "Account", "account_id"),
    ("company_id_unique", "Company", "company_id"),
    ("stock_id_unique", "Stock", "stock_id"),
    ("position_id_unique", "Position", "position_id"),
    ("transaction_id_unique", "Transaction", "transaction_id"),
]

# Use neo4j Python driver for DDL operations
driver = GraphDatabase.driver(NEO4J_URL, auth=(NEO4J_USER, NEO4J_PASS))

success_count = 0
skip_count = 0
fail_count = 0

with driver.session(database=NEO4J_DATABASE) as session:
    for constraint_name, label, property_name in constraints:
        query = f"""
        CREATE CONSTRAINT {constraint_name} IF NOT EXISTS
        FOR (n:{label})
        REQUIRE n.{property_name} IS UNIQUE
        """
        print(f"[DEBUG] Creating: {constraint_name}")
        print(f"        FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE")
        
        try:
            start_time = time.time()
            # consume() forces execution so any error surfaces inside this try block
            session.run(query).consume()
            elapsed = time.time() - start_time
            print(f"  [OK] Created/verified in {elapsed:.2f}s")
            success_count += 1
        except Exception as e:
            error_msg = str(e)
            if "already exists" in error_msg.lower():
                print(f"  [SKIP] Already exists")
                skip_count += 1
            else:
                print(f"  [FAIL] {type(e).__name__}: {error_msg[:100]}")
                fail_count += 1
        print("")

driver.close()

print("=" * 70)
print("SCHEMA SETUP SUMMARY")
print("=" * 70)
print(f"  Created/Verified: {success_count}")
print(f"  Skipped (exists): {skip_count}")
print(f"  Failed:           {fail_count}")

if fail_count > 0:
    print("\n[WARNING] Some constraints failed to create!")
else:
    print("\n[OK] All constraints ready!")
print("=" * 70)

## Step 6: Write Nodes to Neo4j

Write all node types to Neo4j. The order of the node writes doesn't matter because no relationships are created until Step 7.

**Graph Schema:**
- Customer (102 nodes)
- Bank (102 nodes)
- Account (123 nodes)
- Company (102 nodes)
- Stock (102 nodes)
- Position (110 nodes)
- Transaction (123 nodes)
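
Several node writes in this step drop foreign-key columns so that IDs of other entities are not stored as node properties. Instead of listing the kept columns by hand, the same result can be derived by exclusion. This is a sketch: `property_columns` is a hypothetical helper, and the column names in the usage note are taken from the selections in this notebook.

```python
def property_columns(columns, foreign_keys):
    """Return the columns to store as node properties, preserving their order."""
    excluded = set(foreign_keys)
    return [c for c in columns if c not in excluded]
```

For accounts this might be used as `accounts.select(*property_columns(accounts.columns, ["customer_id", "bank_id"]))`. The explicit lists in the cells have the advantage of failing loudly if an expected column is absent from the CSV.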

In [None]:
# =============================================================================
# WRITE NODES TO NEO4J
# =============================================================================
print("=" * 70)
print("NODE CREATION - Writing nodes to Neo4j")
print("=" * 70)
print("")

node_results = {}
total_start = time.time()

# Customer nodes
print("[1/7] CUSTOMER NODES")
print("-" * 40)
node_results["Customer"] = write_nodes(customers, "Customer", "customer_id")
print("")

# Bank nodes
print("[2/7] BANK NODES")
print("-" * 40)
node_results["Bank"] = write_nodes(banks, "Bank", "bank_id")
print("")

# Account nodes (exclude foreign keys)
print("[3/7] ACCOUNT NODES")
print("-" * 40)
print("[DEBUG] Selecting columns (excluding foreign keys)...")
account_props = accounts.select(
    "account_id", "account_number", "account_type", 
    "balance", "currency", "opened_date", "status", "interest_rate"
)
print(f"        Columns: {account_props.columns}")
node_results["Account"] = write_nodes(account_props, "Account", "account_id")
print("")

# Company nodes
print("[4/7] COMPANY NODES")
print("-" * 40)
node_results["Company"] = write_nodes(companies, "Company", "company_id")
print("")

# Stock nodes (exclude foreign key)
print("[5/7] STOCK NODES")
print("-" * 40)
print("[DEBUG] Selecting columns (excluding foreign keys)...")
stock_props = stocks.select(
    "stock_id", "ticker", "current_price", "previous_close", "opening_price",
    "day_high", "day_low", "volume", "market_cap_billions", "pe_ratio",
    "dividend_yield", "fifty_two_week_high", "fifty_two_week_low", "exchange"
)
print(f"        Columns: {stock_props.columns}")
node_results["Stock"] = write_nodes(stock_props, "Stock", "stock_id")
print("")

# Position nodes (exclude foreign keys)
print("[6/7] POSITION NODES")
print("-" * 40)
print("[DEBUG] Selecting columns (excluding foreign keys)...")
position_props = positions.select(
    "position_id", "shares", "purchase_price", "purchase_date",
    "current_value", "percentage_of_portfolio"
)
print(f"        Columns: {position_props.columns}")
node_results["Position"] = write_nodes(position_props, "Position", "position_id")
print("")

# Transaction nodes (exclude foreign keys)
print("[7/7] TRANSACTION NODES")
print("-" * 40)
print("[DEBUG] Selecting columns (excluding foreign keys)...")
transaction_props = transactions.select(
    "transaction_id", "amount", "currency", "transaction_date",
    "transaction_time", "type", "status", "description"
)
print(f"        Columns: {transaction_props.columns}")
node_results["Transaction"] = write_nodes(transaction_props, "Transaction", "transaction_id")

total_elapsed = time.time() - total_start

print("")
print("=" * 70)
print("NODE CREATION SUMMARY")
print("=" * 70)
print("")
print(f"{'Label':<15} {'Count':>10} {'Time':>10} {'Status':>10}")
print("-" * 50)
for label, result in node_results.items():
    print(f"{label:<15} {result['count']:>10} {result['elapsed']:>9.2f}s {'[OK]':>10}")
print("-" * 50)
total_nodes = sum(r['count'] for r in node_results.values())
print(f"{'TOTAL':<15} {total_nodes:>10} {total_elapsed:>9.2f}s")
print("")
print("=" * 70)
print("[OK] ALL NODES WRITTEN SUCCESSFULLY!")
print("=" * 70)

## Step 7: Write Relationships to Neo4j

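Create all relationships by matching existing source and target nodes on their ID properties. Rows whose keys match no node produce no relationship, which is why the counts are re-checked in Step 8.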

**Relationship Types:**
1. `(:Customer)-[:HAS_ACCOUNT]->(:Account)` - Customer owns account
2. `(:Account)-[:AT_BANK]->(:Bank)` - Account held at bank
3. `(:Stock)-[:OF_COMPANY]->(:Company)` - Stock issued by company
4. `(:Account)-[:PERFORMS]->(:Transaction)` - Account initiates transfer
5. `(:Transaction)-[:BENEFITS_TO]->(:Account)` - Account receives funds
6. `(:Account)-[:HAS_POSITION]->(:Position)` - Account holds position
7. `(:Position)-[:OF_SECURITY]->(:Stock)` - Position is in specific stock
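
The seven relationship types above can also be held as data, which makes it easy to render the patterns or to see which relationships start from a given label. The sketch below is illustrative; `RELATIONSHIPS`, `pattern`, and `outgoing_by_label` are hypothetical names, not part of this notebook.

```python
from collections import defaultdict

# (relationship type, source label, target label), as listed in this step
RELATIONSHIPS = [
    ("HAS_ACCOUNT", "Customer", "Account"),
    ("AT_BANK", "Account", "Bank"),
    ("OF_COMPANY", "Stock", "Company"),
    ("PERFORMS", "Account", "Transaction"),
    ("BENEFITS_TO", "Transaction", "Account"),
    ("HAS_POSITION", "Account", "Position"),
    ("OF_SECURITY", "Position", "Stock"),
]


def pattern(rel_type, source, target):
    """Render one relationship as a Cypher pattern string."""
    return f"(:{source})-[:{rel_type}]->(:{target})"


def outgoing_by_label(relationships):
    """Group relationship types by the label they start from."""
    grouped = defaultdict(list)
    for rel_type, source, _target in relationships:
        grouped[source].append(rel_type)
    return dict(grouped)
```

`Account` is the hub of this schema: it is the source of three relationship types and the target of two.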

In [None]:
# =============================================================================
# WRITE RELATIONSHIPS TO NEO4J
# =============================================================================
print("=" * 70)
print("RELATIONSHIP CREATION - Writing relationships to Neo4j")
print("=" * 70)
print("")

rel_results = {}
total_start = time.time()

# 1. HAS_ACCOUNT: Customer -> Account
print("[1/7] HAS_ACCOUNT RELATIONSHIPS")
print("-" * 40)
has_account_df = accounts.select("customer_id", "account_id")
rel_results["HAS_ACCOUNT"] = write_relationship(
    has_account_df, "HAS_ACCOUNT",
    "Customer", "customer_id",
    "Account", "account_id"
)
print("")

# 2. AT_BANK: Account -> Bank
print("[2/7] AT_BANK RELATIONSHIPS")
print("-" * 40)
at_bank_df = accounts.select("account_id", "bank_id")
rel_results["AT_BANK"] = write_relationship(
    at_bank_df, "AT_BANK",
    "Account", "account_id",
    "Bank", "bank_id"
)
print("")

# 3. OF_COMPANY: Stock -> Company
print("[3/7] OF_COMPANY RELATIONSHIPS")
print("-" * 40)
of_company_df = stocks.select("stock_id", "company_id")
rel_results["OF_COMPANY"] = write_relationship(
    of_company_df, "OF_COMPANY",
    "Stock", "stock_id",
    "Company", "company_id"
)
print("")

# 4. PERFORMS: Account -> Transaction (from_account initiates)
print("[4/7] PERFORMS RELATIONSHIPS")
print("-" * 40)
print("[DEBUG] Aliasing from_account_id -> account_id")
performs_df = transactions.select(
    F.col("from_account_id").alias("account_id"),
    "transaction_id"
)
rel_results["PERFORMS"] = write_relationship(
    performs_df, "PERFORMS",
    "Account", "account_id",
    "Transaction", "transaction_id"
)
print("")

# 5. BENEFITS_TO: Transaction -> Account (to_account receives)
print("[5/7] BENEFITS_TO RELATIONSHIPS")
print("-" * 40)
print("[DEBUG] Aliasing to_account_id -> account_id")
benefits_df = transactions.select(
    "transaction_id",
    F.col("to_account_id").alias("account_id")
)
rel_results["BENEFITS_TO"] = write_relationship(
    benefits_df, "BENEFITS_TO",
    "Transaction", "transaction_id",
    "Account", "account_id"
)
print("")

# 6. HAS_POSITION: Account -> Position
print("[6/7] HAS_POSITION RELATIONSHIPS")
print("-" * 40)
has_position_df = positions.select("account_id", "position_id")
rel_results["HAS_POSITION"] = write_relationship(
    has_position_df, "HAS_POSITION",
    "Account", "account_id",
    "Position", "position_id"
)
print("")

# 7. OF_SECURITY: Position -> Stock
print("[7/7] OF_SECURITY RELATIONSHIPS")
print("-" * 40)
of_security_df = positions.select("position_id", "stock_id")
rel_results["OF_SECURITY"] = write_relationship(
    of_security_df, "OF_SECURITY",
    "Position", "position_id",
    "Stock", "stock_id"
)

total_elapsed = time.time() - total_start

print("")
print("=" * 70)
print("RELATIONSHIP CREATION SUMMARY")
print("=" * 70)
print("")
print(f"{'Type':<20} {'Count':>10} {'Time':>10} {'Status':>10}")
print("-" * 55)
for rel_type, result in rel_results.items():
    print(f"{rel_type:<20} {result['count']:>10} {result['elapsed']:>9.2f}s {'[OK]':>10}")
print("-" * 55)
total_rels = sum(r['count'] for r in rel_results.values())
print(f"{'TOTAL':<20} {total_rels:>10} {total_elapsed:>9.2f}s")
print("")
print("=" * 70)
print("[OK] ALL RELATIONSHIPS WRITTEN SUCCESSFULLY!")
print("=" * 70)

## Step 8: Validate Import

Run validation queries to verify the import completed correctly.
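
A count comparison is only complete if it also reports labels or relationship types that are expected but absent from the query result. The sketch below shows that comparison on plain dictionaries; `compare_counts` is a hypothetical helper, and the status strings are illustrative.

```python
def compare_counts(expected, actual):
    """Compare expected and actual counts over the union of both key sets.

    Returns {name: (expected_count, actual_count, status)}. A name missing from
    `actual` is treated as a count of 0; a name missing from `expected` is
    reported as UNEXPECTED with expected_count set to None.
    """
    names = list(expected) + [k for k in actual if k not in expected]
    report = {}
    for name in names:
        exp = expected.get(name)
        act = actual.get(name, 0)
        if exp is None:
            status = "UNEXPECTED"
        elif exp == act:
            status = "OK"
        else:
            status = "MISMATCH"
        report[name] = (exp, act, status)
    return report
```

With the rows returned by a count query collected into a dictionary, `compare_counts(expected_nodes, actual_nodes)` would flag a label with zero nodes as a mismatch instead of silently skipping it.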

In [None]:
# =============================================================================
# VALIDATE NODE COUNTS
# =============================================================================
print("=" * 70)
print("VALIDATION - Node counts")
print("=" * 70)
print("")

# Expected counts
expected_nodes = {
    "Customer": 102,
    "Bank": 102,
    "Account": 123,
    "Company": 102,
    "Stock": 102,
    "Position": 110,
    "Transaction": 123
}

# Query actual counts
print("[DEBUG] Running node count query...")
node_count_query = """
MATCH (n)
RETURN labels(n)[0] AS label, count(n) AS count
ORDER BY label
"""

try:
    start_time = time.time()
    node_counts = run_cypher(node_count_query).collect()
    elapsed = time.time() - start_time
    print(f"  [OK] Query completed in {elapsed:.2f}s")
    print("")
    
    print(f"{'Label':<15} {'Expected':>10} {'Actual':>10} {'Status':>10}")
    print("-" * 50)
    
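    # Index actual counts by label so labels missing from Neo4j are checked too
    actual_nodes = {row["label"]: row["count"] for row in node_counts}
    labels_to_check = list(expected_nodes) + [l for l in actual_nodes if l not in expected_nodes]

    all_valid = True
    for label in labels_to_check:
        actual = actual_nodes.get(label, 0)
        expected = expected_nodes.get(label, "N/A")
        if expected == "N/A":
            status = "[?]"
        elif actual == expected:
            status = "[OK]"
        else:
            status = "[MISMATCH]"
            all_valid = False
        print(f"{label:<15} {expected:>10} {actual:>10} {status:>10}")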
    
    print("-" * 50)
    if all_valid:
        print("\n[OK] Node validation PASSED!")
    else:
        print("\n[WARNING] Node validation FAILED - check mismatched counts")
        
except Exception as e:
    print(f"  [FAIL] Validation query failed: {type(e).__name__}")
    print(f"         {str(e)[:200]}")

In [None]:
# =============================================================================
# VALIDATE RELATIONSHIP COUNTS
# =============================================================================
print("=" * 70)
print("VALIDATION - Relationship counts")
print("=" * 70)
print("")

# Expected counts
expected_rels = {
    "HAS_ACCOUNT": 123,
    "AT_BANK": 123,
    "OF_COMPANY": 102,
    "PERFORMS": 123,
    "BENEFITS_TO": 123,
    "HAS_POSITION": 110,
    "OF_SECURITY": 110
}

# Query actual counts
print("[DEBUG] Running relationship count query...")
rel_count_query = """
MATCH ()-[r]->()
RETURN type(r) AS relationship_type, count(r) AS count
ORDER BY relationship_type
"""

try:
    start_time = time.time()
    rel_counts = run_cypher(rel_count_query).collect()
    elapsed = time.time() - start_time
    print(f"  [OK] Query completed in {elapsed:.2f}s")
    print("")
    
    print(f"{'Relationship':<20} {'Expected':>10} {'Actual':>10} {'Status':>10}")
    print("-" * 55)
    
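    # Index actual counts by type so relationship types missing from Neo4j are checked too
    actual_rels = {row["relationship_type"]: row["count"] for row in rel_counts}
    types_to_check = list(expected_rels) + [t for t in actual_rels if t not in expected_rels]

    all_valid = True
    for rel_type in types_to_check:
        actual = actual_rels.get(rel_type, 0)
        expected = expected_rels.get(rel_type, "N/A")
        if expected == "N/A":
            status = "[?]"
        elif actual == expected:
            status = "[OK]"
        else:
            status = "[MISMATCH]"
            all_valid = False
        print(f"{rel_type:<20} {expected:>10} {actual:>10} {status:>10}")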
    
    print("-" * 55)
    if all_valid:
        print("\n[OK] Relationship validation PASSED!")
    else:
        print("\n[WARNING] Relationship validation FAILED - check mismatched counts")
        
except Exception as e:
    print(f"  [FAIL] Validation query failed: {type(e).__name__}")
    print(f"         {str(e)[:200]}")

In [None]:
# =============================================================================
# SAMPLE QUERIES - Customer Profile
# =============================================================================
print("=" * 70)
print("SAMPLE QUERY 1 - Customer C0001's complete financial profile")
print("=" * 70)
print("")

query1 = """
MATCH (c:Customer {customer_id: 'C0001'})-[:HAS_ACCOUNT]->(a:Account)
OPTIONAL MATCH (a)-[:AT_BANK]->(b:Bank)
OPTIONAL MATCH (a)-[:HAS_POSITION]->(p:Position)-[:OF_SECURITY]->(s:Stock)
RETURN
    c.first_name + ' ' + c.last_name AS customer_name,
    a.account_id AS account,
    a.account_type AS account_type,
    a.balance AS balance,
    b.name AS bank_name,
    s.ticker AS ticker,
    p.shares AS shares,
    p.current_value AS holding_value
ORDER BY holding_value IS NULL, holding_value DESC
LIMIT 5
"""

print("[DEBUG] Query:")
for line in query1.strip().split('\n'):
    print(f"  {line}")
print("")

try:
    start_time = time.time()
    result = run_cypher(query1)
    elapsed = time.time() - start_time
    print(f"[OK] Query completed in {elapsed:.2f}s")
    print("")
    display(result)
except Exception as e:
    print(f"[FAIL] Query failed: {type(e).__name__}")
    print(f"       {str(e)[:200]}")

In [None]:
# =============================================================================
# SAMPLE QUERIES - Top Banks by Deposits
# =============================================================================
print("=" * 70)
print("SAMPLE QUERY 2 - Top 5 banks by total deposits")
print("=" * 70)
print("")

query2 = """
MATCH (b:Bank)<-[:AT_BANK]-(a:Account)
RETURN
    b.name AS bank_name,
    b.bank_type AS bank_type,
    count(DISTINCT a) AS num_accounts,
    round(sum(a.balance), 2) AS total_deposits
ORDER BY total_deposits DESC
LIMIT 5
"""

print("[DEBUG] Query:")
for line in query2.strip().split('\n'):
    print(f"  {line}")
print("")

try:
    start_time = time.time()
    result = run_cypher(query2)
    elapsed = time.time() - start_time
    print(f"[OK] Query completed in {elapsed:.2f}s")
    print("")
    display(result)
except Exception as e:
    print(f"[FAIL] Query failed: {type(e).__name__}")
    print(f"       {str(e)[:200]}")

In [None]:
# =============================================================================
# SAMPLE QUERIES - Transaction Flow
# =============================================================================
print("=" * 70)
print("SAMPLE QUERY 3 - Recent transaction flow")
print("=" * 70)
print("")

query3 = """
MATCH (from:Account)-[:PERFORMS]->(t:Transaction)-[:BENEFITS_TO]->(to:Account)
RETURN
    from.account_id AS from_account,
    t.transaction_id AS transaction_id,
    t.amount AS amount,
    t.transaction_date AS date,
    to.account_id AS to_account
ORDER BY t.transaction_date DESC
LIMIT 5
"""

print("[DEBUG] Query:")
for line in query3.strip().split('\n'):
    print(f"  {line}")
print("")

try:
    start_time = time.time()
    result = run_cypher(query3)
    elapsed = time.time() - start_time
    print(f"[OK] Query completed in {elapsed:.2f}s")
    print("")
    display(result)
except Exception as e:
    print(f"[FAIL] Query failed: {type(e).__name__}")
    print(f"       {str(e)[:200]}")

In [None]:
# =============================================================================
# SAMPLE QUERIES - Most Popular Stocks
# =============================================================================
print("=" * 70)
print("SAMPLE QUERY 4 - Most widely held stocks")
print("=" * 70)
print("")

query4 = """
MATCH (a:Account)-[:HAS_POSITION]->(p:Position)-[:OF_SECURITY]->(s:Stock)-[:OF_COMPANY]->(c:Company)
RETURN
    c.name AS company_name,
    s.ticker AS ticker,
    count(DISTINCT a) AS num_holders,
    sum(p.shares) AS total_shares_held,
    round(sum(p.current_value), 2) AS total_market_value
ORDER BY num_holders DESC
LIMIT 5
"""

print("[DEBUG] Query:")
for line in query4.strip().split('\n'):
    print(f"  {line}")
print("")

try:
    start_time = time.time()
    result = run_cypher(query4)
    elapsed = time.time() - start_time
    print(f"[OK] Query completed in {elapsed:.2f}s")
    print("")
    display(result)
except Exception as e:
    print(f"[FAIL] Query failed: {type(e).__name__}")
    print(f"       {str(e)[:200]}")

## Import Complete

If the validation cells above reported `[OK]`, the financial demo data has been imported to Neo4j with the counts listed below.

### Summary

**Nodes Created:**
- 102 Customers
- 102 Banks
- 123 Accounts
- 102 Companies
- 102 Stocks
- 110 Positions
- 123 Transactions

**Relationships Created:**
- 123 HAS_ACCOUNT (Customer -> Account)
- 123 AT_BANK (Account -> Bank)
- 102 OF_COMPANY (Stock -> Company)
- 123 PERFORMS (Account -> Transaction)
- 123 BENEFITS_TO (Transaction -> Account)
- 110 HAS_POSITION (Account -> Position)
- 110 OF_SECURITY (Position -> Stock)
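
The counts above also allow a quick aggregate check: they sum to 764 nodes and 814 relationships. The sketch below works through that arithmetic and adds a hypothetical `diff_counts` helper (not part of this notebook) that treats a type missing from the actual counts as zero. The node labels are taken from the sample queries, and the `actual` dictionary is made-up illustration data, not output from a real run.

```python
# Expected counts copied from the summary above.
EXPECTED_NODES = {
    "Customer": 102, "Bank": 102, "Account": 123, "Company": 102,
    "Stock": 102, "Position": 110, "Transaction": 123,
}
EXPECTED_RELS = {
    "HAS_ACCOUNT": 123, "AT_BANK": 123, "OF_COMPANY": 102, "PERFORMS": 123,
    "BENEFITS_TO": 123, "HAS_POSITION": 110, "OF_SECURITY": 110,
}


def diff_counts(expected, actual):
    """Hypothetical helper: return {name: (expected, actual)} for each mismatch.

    Names missing from `actual` are treated as a count of 0.
    """
    return {
        name: (want, actual.get(name, 0))
        for name, want in expected.items()
        if actual.get(name, 0) != want
    }


total_nodes = sum(EXPECTED_NODES.values())  # 764
total_rels = sum(EXPECTED_RELS.values())    # 814

# Illustration only: pretend OF_SECURITY relationships were never created.
actual = {k: v for k, v in EXPECTED_RELS.items() if k != "OF_SECURITY"}
problems = diff_counts(EXPECTED_RELS, actual)
print(total_nodes, total_rels, problems)
```

The totals can be compared against single-number queries such as `MATCH (n) RETURN count(n)`, while the per-type diff points to the load step that needs rerunning.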

### Debug Logging

This notebook prefixes its output with the following markers to help identify issues:

- `[DEBUG]` - Informational messages about what operation is being attempted
- `[OK]` - Operation completed successfully
- `[FAIL]` - Operation failed (see error details)
- `[WARNING]` - Non-fatal issue that may need attention
- `[SKIP]` - Operation skipped (e.g., constraint already exists)
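
One way to keep these markers consistent is to route output through a small formatting helper. The sketch below is hypothetical: the notebook's cells call `print` directly, and `format_log` and its `indent` parameter are names introduced here for illustration only. The timing value in the usage lines is an example, not a measured result.

```python
# The five markers used throughout this notebook.
MARKERS = {"DEBUG", "OK", "FAIL", "WARNING", "SKIP"}


def format_log(marker, message, indent=0):
    """Return a log line such as '  [OK] done' (hypothetical helper)."""
    if marker not in MARKERS:
        raise ValueError(f"Unknown marker: {marker!r}")
    return f"{' ' * indent}[{marker}] {message}"


print(format_log("DEBUG", "Running relationship count query..."))
print(format_log("OK", "Query completed in 0.42s", indent=2))
```

Rejecting unknown markers keeps typos such as `[WARN]` from slipping into the output, which matters when logs are searched by prefix.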

### Next Steps

1. **Explore with Neo4j Browser** - Visualize the graph
2. **Run Graph Algorithms** - PageRank, Community Detection
3. **Build Applications** - Portfolio dashboards, risk analysis
4. **Extend the Schema** - Add Advisors, Branches, Products

---

See `IMPORT.md` for setup instructions and troubleshooting.