In [None]:
# Databricks Notebook: Fraud_Detection_With_Graph_Embeddings.py

# COMMAND ----------
# MAGIC %md
# MAGIC # Fraud Detection with Neo4j Graph Embeddings (Serverless-Compatible)
# MAGIC This notebook connects to a Neo4j Aura instance, runs graph algorithms via GDS, and extracts provider-level embeddings for fraud prediction.

# COMMAND ----------
# 1. Setup SparkSession and Retrieve Secrets for Neo4j GDS

from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
from graphdatascience.session import DbmsConnectionInfo, GdsSessions, AuraAPICredentials, SessionMemory, CloudLocation
import os

# --- Configuration Constants ---
# Databricks secret scope from which all Neo4j / Aura credentials are read below.
SECRET_SCOPE_NAME = "my-neo4j-scope"
# Session name and memory size handed to GdsSessions.get_or_create().
GDS_SESSION_NAME = "gs-gds-session"
GDS_SESSION_MEMORY = SessionMemory.m_4GB
# NOTE(review): these two values are not referenced anywhere in the visible part
# of this notebook. "us-east-1" looks like AWS-style region naming while the
# provider is "gcp" (GCP regions are usually written like "us-east1") — confirm
# the pair before wiring it into a CloudLocation.
GDS_CLOUD_PROVIDER = "gcp"
GDS_CLOUD_REGION = "us-east-1"

print("--- Cell 1: SparkSession Initialization and Credential Retrieval ---")

# Read every credential from the Databricks secret scope in a single ordered
# pass. The secret keys are listed in the same order as the variables they
# populate; any failure is reported and re-raised so the notebook stops here.
try:
    (
        neo4j_url,
        neo4j_username,
        neo4j_password,
        neo4j_dbname,
        aura_client_id,
        aura_client_secret,
        aura_project_id,
    ) = (
        dbutils.secrets.get(scope=SECRET_SCOPE_NAME, key=secret_key)
        for secret_key in (
            "neo4j-url",
            "neo4j-username",
            "neo4j-password",
            "neo4j-database",
            "aura-gds-client-id",
            "aura-gds-client-secret",
            "aura-gds-project-id",
        )
    )
except Exception as err:
    print(f"ERROR: Failed to retrieve secrets. Please ensure your secret scope '{SECRET_SCOPE_NAME}' exists and contains all required keys. Details: {err}")
    raise
else:
    print("Successfully retrieved all credentials from Databricks Secrets.")

# Initialize SparkSession once. getOrCreate() hands back the session that is
# already running when there is one, so repeating the builder chain a second
# time adds nothing.
# NOTE(review): the Neo4j password is placed in the Spark configuration here —
# confirm that this is acceptable for the workspace's security policy.
spark = (
    SparkSession.builder.appName("Neo4jGDSIntegration")
    .config("neo4j.url", neo4j_url)
    .config("neo4j.authentication.basic.username", neo4j_username)
    .config("neo4j.authentication.basic.password", neo4j_password)
    .config("neo4j.database", neo4j_dbname)
    .getOrCreate()
)

print("SparkSession initialized with Neo4j configurations.")

# COMMAND ----------
# 2. Initialize GDS Session

print("--- GDS Session Initialization ---")

# Initialize AuraAPICredentials with retrieved secrets
# (client id / secret / project id come from the secret scope read earlier).
api_credentials = AuraAPICredentials(
    client_id=aura_client_id,
    client_secret=aura_client_secret,
    project_id=aura_project_id
)

# Entry point for managing GDS sessions with the Aura API credentials.
sessions = GdsSessions(api_credentials=api_credentials)

# Create or Get GDS Session
# The session is requested by name and memory size and is attached to the
# Neo4j database identified by the URL / username / password secrets.
# Any failure is printed with the session name and re-raised, so later cells
# never run against an undefined `gds`.
try:
    gds = sessions.get_or_create(
        session_name=GDS_SESSION_NAME,
        memory=GDS_SESSION_MEMORY,
        db_connection=DbmsConnectionInfo(neo4j_url, neo4j_username, neo4j_password),
    )
    print(f"Successfully connected to GDS session: '{GDS_SESSION_NAME}'")
    print(f"GDS Version: {gds.version()}")

except Exception as e:
    print(f"ERROR: Failed to create or connect to GDS session '{GDS_SESSION_NAME}'. Details: {e}")
    raise
# Only reached when the try block above succeeded.
print("GDS session is active.")

# COMMAND ----------
# 5. Train Fraud Model with Augmented Features

print("📦 Projecting graph using Cypher query and running fastRP...")

# Drop graph if it already exists to avoid ALREADY_EXISTS error
if gds.graph.exists("provider_graph")["exists"]:
    gds.graph.drop("provider_graph")

# Project "provider_graph" from a Cypher query:
#   * every Provider node is matched;
#   * incoming HAS_INPATIENT_CLAIM / HAS_OUTPATIENT_CLAIM relationships are
#     matched optionally, so a provider without claims still yields a row;
#   * node property maps are empty; labels, relationship type and
#     relationship properties are forwarded to gds.graph.project.remote.
# NOTE(review): properties(rel) forwards every relationship property —
# confirm they are all of a type the projection accepts.
G, result = gds.graph.project(
    graph_name="provider_graph",
    query="""
    CALL {
        MATCH (p:Provider)
        OPTIONAL MATCH (p)<-[r1:HAS_INPATIENT_CLAIM|HAS_OUTPATIENT_CLAIM]-(claim)
        RETURN p AS source, r1 AS rel, claim AS target, {} AS sourceNodeProperties, {} AS targetNodeProperties
    }
    RETURN gds.graph.project.remote(source, target, {
      sourceNodeProperties: sourceNodeProperties,
      targetNodeProperties: targetNodeProperties,
      sourceNodeLabels: labels(source),
      targetNodeLabels: labels(target),
      relationshipType: type(rel),
      relationshipProperties: properties(rel)
    })
    """,
)

print(f"Projected graph '{G.name}' with {G.node_count()} nodes.")

# Run fastRP
# Writes a 64-dimensional vector into the in-memory node property "embedding"
# (mutate mode). `embedding_result` is not read again in the visible part of
# this notebook.
# NOTE(review): no random seed is passed, so embeddings may differ between
# runs even though the downstream model fixes random_state — confirm whether
# reproducibility is required.
embedding_result = gds.fastRP.mutate(
    G,
    embeddingDimension=64,
    mutateProperty="embedding",
    iterationWeights=[0.8, 1, 1, 0.5]
)
print("fastRP embeddings generated.")

# Stream the mutated "embedding" node property out of the in-memory graph.
embedding_df = gds.graph.nodeProperties.stream(G, "embedding")
print("Embedding properties streamed.")

# Convert to Pandas and then Spark DataFrame
import pandas as pd
pdf_embeddings = pd.DataFrame(embedding_df)
print("Embedding DataFrame preview:")
print(pdf_embeddings.head())

# Confirm available columns
print("Available columns:", pdf_embeddings.columns.tolist())

# Pick the column that actually holds the vectors. When the stream is not
# split into one column per property, the values arrive in "propertyValue";
# "nodeProperty" only carries the property *name* (the string "embedding"),
# so using it as the vector column would turn every embedding into [].
# "nodeProperty" is kept purely as a last-resort fallback for compatibility.
if "embedding" in pdf_embeddings.columns:
    embedding_col_name = "embedding"
elif "propertyValue" in pdf_embeddings.columns:
    embedding_col_name = "propertyValue"
else:
    embedding_col_name = "nodeProperty"
def safe_float_embedding(val):
    """Return *val* as a flat list of Python floats, or [] if it is malformed.

    Accepts a list or tuple of numbers, and any array-like object exposing
    ``tolist()`` (for example a 1-D ``numpy.ndarray``, which is how list
    columns are commonly delivered inside a pandas DataFrame). Anything else
    (strings, ``None``, nested sequences, sequences containing non-numeric
    items) yields an empty list so callers always receive a list.
    """
    # Array-likes are converted first; without this, ndarray-valued cells
    # would be rejected and every embedding silently replaced by [].
    if not isinstance(val, (list, tuple)) and hasattr(val, "tolist"):
        val = val.tolist()
    if isinstance(val, (list, tuple)) and all(isinstance(e, (int, float)) for e in val):
        return [float(e) for e in val]
    return []  # return empty embedding if malformed

# Normalise every embedding cell to a list of floats (malformed cells become []).
pdf_embeddings[embedding_col_name] = pdf_embeddings[embedding_col_name].apply(safe_float_embedding)

# Rename once to the column names used by the Spark schema and the join.
# NOTE(review): "nodeId" is the node id reported by GDS, not a provider
# identifier property, and the projection also contains claim nodes. Confirm
# that these ids really line up with the baseline table's ProviderID values;
# otherwise map node ids to the provider key before joining.
pdf_embeddings.rename(columns={"nodeId": "ProviderID", embedding_col_name: "embedding"}, inplace=True)

# Join key as string, with any literal "PRV" marker removed.
pdf_embeddings["ProviderID"] = (
    pdf_embeddings["ProviderID"].astype(str).str.replace("PRV", "", regex=False)
)

from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType

# Spark schema for the embedding table: string join key plus float vector.
schema = StructType([
    StructField("ProviderID", StringType(), True),
    StructField("embedding", ArrayType(FloatType()), True)
])

# Hand Spark exactly the two schema columns, in schema order. The pandas frame
# may carry extra columns from the GDS stream, and a frame whose width differs
# from the schema cannot be converted reliably.
spark_embeddings_df = spark.createDataFrame(pdf_embeddings[["ProviderID", "embedding"]], schema)
print("Embedding DataFrame created in Spark.")

# Baseline provider features, read from the catalog table.
baseline_df = spark.table("workspace.default.temp_provider_features_for_augmented_run")

# Keep only providers present on both sides (inner join on ProviderID),
# then report the resulting shape.
augmented_df = baseline_df.join(spark_embeddings_df, "ProviderID", "inner")
print(
    f"Joined DataFrame has {augmented_df.count()} rows and {len(augmented_df.columns)} columns."
)

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix

# Convert Spark DataFrame to Pandas for sklearn training
augmented_pdf = augmented_df.toPandas()

# The "embedding" column holds one vector per row. scikit-learn needs a flat
# numeric matrix, so expand it into one column per dimension (emb_0, emb_1, ...)
# instead of passing the array-valued column to the estimator.
# NOTE(review): rows whose embedding is empty or shorter than the others end up
# with NaN in the missing positions — confirm how those should be handled.
embedding_features = pd.DataFrame(
    augmented_pdf["embedding"].tolist(),
    index=augmented_pdf.index,
).add_prefix("emb_")

# Features: every baseline column except the key and the label, plus the
# expanded embedding dimensions. Target: the PotentialFraud label.
X_aug = pd.concat(
    [
        augmented_pdf.drop(columns=["ProviderID", "PotentialFraud", "embedding"]),
        embedding_features,
    ],
    axis=1,
)
y_aug = augmented_pdf["PotentialFraud"]

# Stratified split keeps the label ratio equal in train and test sets.
X_train, X_test, y_train, y_test = train_test_split(X_aug, y_aug, stratify=y_aug, random_state=42)

model_aug = RandomForestClassifier(n_estimators=100, random_state=42)
model_aug.fit(X_train, y_train)

# Hard predictions plus the probability of the second class in model_aug.classes_.
y_pred = model_aug.predict(X_test)
y_proba = model_aug.predict_proba(X_test)[:, 1]

precision = precis
