# Load Edgy output from S3 into a Databricks table

Use this notebook to load CashConnectedGraphChangeEvent data from the Avro files in Edgy's S3 archive into a Databricks table.

In [0]:
# Install dependencies.
# Each package is resolved with the internal Artifactory PyPI index added as an
# extra index next to the default one.
# NOTE(review): no versions are pinned, so every run installs whatever is
# latest — consider pinning for reproducibility.
# NOTE(review): block-cloud-auth and kafka-python are not imported by any cell
# visible in this notebook — confirm they are still required.

%pip install --extra-index-url https://artifactory.global.square/artifactory/api/pypi/block-pypi/simple sq-protos-py
%pip install --extra-index-url https://artifactory.global.square/artifactory/api/pypi/block-pypi/simple block-cloud-auth
%pip install --extra-index-url https://artifactory.global.square/artifactory/api/pypi/block-pypi/simple kafka-python

Python interpreter will be restarted.
Looking in indexes: https://pypi.org/simple, https://artifactory.global.square/artifactory/api/pypi/block-pypi/simple
Collecting sq-protos-py
  Downloading https://artifactory.global.square/artifactory/api/pypi/block-pypi/sq-protos-py/20250701.2936/sq_protos_py-20250701.2936-py2.py3-none-any.whl (63.5 MB)
Collecting protobuf<=4.21.0,>=3.20.0
  Downloading https://artifactory.global.square/artifactory/api/pypi/block-pypi/packages/packages/c7/df/ec3ecb8c940b36121c7b77c10acebf3d1c736498aa2f1fe3b6231ee44e76/protobuf-3.20.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)
Collecting redbaron
  Downloading https://artifactory.global.square/artifactory/api/pypi/block-pypi/packages/packages/d8/06/c1c97efe5d30593337721923c5813b3b4eaffcffb706e523acf3d3bc9e8c/redbaron-0.9.2-py2.py3-none-any.whl (34 kB)
Collecting baron>=0.7
  Downloading https://artifactory.global.square/artifactory/api/pypi/block-pypi/packages/packages/5b/e5/d0bff1cda8e5404a41ae

In [0]:
# Restart the Python process — presumably so the packages installed in the
# previous cell are picked up by the imports below (TODO confirm). Any Python
# state defined before this point is lost, so keep this cell ahead of all
# imports and definitions.
dbutils.library.restartPython()

In [0]:
from sq_protos_py.squareup.riskapi.wrapper import wrapper_pb2

def deserialize_ri(ri_bytes):
    """Parse a serialized protobuf payload into a ``RiskInput`` message.

    Args:
        ri_bytes: The serialized ``RiskInput`` bytes.

    Returns:
        A ``wrapper_pb2.RiskInput`` instance populated from ``ri_bytes``.
    """
    risk_input = wrapper_pb2.RiskInput()
    risk_input.ParseFromString(ri_bytes)
    return risk_input

In [0]:
# Date (YYYY-MM-DD) of the partition to load. Both the output table name and
# the S3 path below are derived from it so the two cannot drift apart.
LOAD_DATE = "2025-06-26"

# Name of the table we want to load the data into.
# Settings used by earlier runs (kept for reference):
#   Backfill test 1: "cash_banking_ml_eng.cash_connected_graph_change_event.edgy_test"
#   Backfill test 2: "cash_banking_ml_eng.cash_connected_graph_change_event.edgy_test_20250512"
# BAU test: table name suffixed with the load date as YYYYMMDD.
OUTPUT_TABLE = (
    "cash_banking_ml_eng.cash_connected_graph_change_event.edgy_test_"
    + LOAD_DATE.replace("-", "")
)

# Path to the files we want to load.
# Settings used by earlier runs (kept for reference):
#   Live output, all dates:
#     "s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=*"
#   Backfill test 1:
#     "s3a://lakehouse-edgy-production-us-west-2/output_20250417/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=*"
#   Backfill test 2:
#     "s3a://lakehouse-edgy-production-us-west-2/output_20250509/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=*"
# BAU test: a single date partition of the live output.
S3_PATH = (
    "s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/"
    f"date={LOAD_DATE}"
)

In [0]:
# List the objects under S3_PATH and display them, as a quick check that the
# path exists and contains Avro part files before loading.
# To browse the available date partitions instead, list the parent prefix, e.g.:
#files = dbutils.fs.ls("s3a://lakehouse-edgy-production-us-west-2/output_20250509/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3")
# NOTE(review): if S3_PATH is switched to one of the `date=*` glob variants,
# this listing may not expand the wildcard — verify before relying on it.
files = dbutils.fs.ls(S3_PATH)
display(files)

path,name,size,modificationTime
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-041074d3-f65d-45d3-843f-bf082eefab35.c000.avro,part-00000-041074d3-f65d-45d3-843f-bf082eefab35.c000.avro,1826,1751047553000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-062796e6-31b2-4be7-aff4-58b26f5f407b.c000.avro,part-00000-062796e6-31b2-4be7-aff4-58b26f5f407b.c000.avro,3811555,1750950148000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-068d22f1-6d71-49cd-8fc2-1fcdd45ba248.c000.avro,part-00000-068d22f1-6d71-49cd-8fc2-1fcdd45ba248.c000.avro,615,1751248928000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-0fdfbec1-7648-4cb9-9a5f-b91af1162243.c000.avro,part-00000-0fdfbec1-7648-4cb9-9a5f-b91af1162243.c000.avro,752,1751137366000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-116b75f3-7e41-4922-a927-59e60f2b23a7.c000.avro,part-00000-116b75f3-7e41-4922-a927-59e60f2b23a7.c000.avro,1638,1751065342000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-11bbfcdb-c816-44d7-b894-7a63d1163d73.c000.avro,part-00000-11bbfcdb-c816-44d7-b894-7a63d1163d73.c000.avro,760,1751011310000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-1b803605-c714-4889-b95d-b17403fa6ad6.c000.avro,part-00000-1b803605-c714-4889-b95d-b17403fa6ad6.c000.avro,607,1751331740000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-1bd4e974-2a08-4954-8612-e18dab309b23.c000.avro,part-00000-1bd4e974-2a08-4954-8612-e18dab309b23.c000.avro,764,1751259742000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-1c620080-bd65-4070-8753-290ad68514ca.c000.avro,part-00000-1c620080-bd65-4070-8753-290ad68514ca.c000.avro,884,1751169900000
s3a://lakehouse-edgy-production-us-west-2/output/CASH_CONNECTED_GRAPH_CHANGE_EVENT/V3/date=2025-06-26/part-00000-1d24e4e4-5111-4f92-842b-d189959982af.c000.avro,part-00000-1d24e4e4-5111-4f92-842b-d189959982af.c000.avro,745,1751270493000


In [0]:
# Test that data loads correctly: read the Avro files, take the first record,
# and parse its first column as a RiskInput protobuf.
df_s3_egress = spark.read.format("avro").load(S3_PATH)

# DataFrame.first() returns None when there are no rows; fail with a clear
# message instead of an opaque "'NoneType' object is not subscriptable".
first_record = df_s3_egress.first()
if first_record is None:
    raise ValueError(f"No Avro records found under {S3_PATH}")

raw_proto = first_record[0]
ri_s3 = deserialize_ri(raw_proto)
# Displayed as the cell output.
# NOTE(review): the rendered message contains user tokens — clear the output
# before sharing the notebook if that is a concern.
ri_s3

Out[11]: event_id: "E/AH_7py1kneng/AH_v6y0a22jh/EMAIL"
message_type: CASH_CONNECTED_GRAPH_CHANGE_EVENT
event_time_millis: 1750968061985
cash_connected_graph_change_event {
  event_id: "E/AH_7py1kneng/AH_v6y0a22jh/EMAIL"
  event_type: CONNECTION_ADDED
  target_user_token: "AH_v6y0a22jh"
  source_user_token: "AH_7py1kneng"
  effective_at_millis: 1750968061985
  user_type: CASH_ACCOUNT_HOLDER
  connection_change {
    changed_node_type: EMAIL
    source_edge {
      from_node {
        type: CASH_CUSTOMER
        token: "C_gnenk1yp7"
      }
      to_node {
        type: EMAIL
        token: "1ca7755baf8c1579816ca07aacd11794f90a736a"
      }
      effective_at_msec: 1750968061985
      created_at: 1750968062549
      updated_at: 1750968062549
    }
  }
  event_source_type: BAU
  published_at_millis: 1750968065201
}

In [0]:
from pyspark.sql.types import StructType, StructField, StringType, LongType, ArrayType, BooleanType # Added BooleanType
from pyspark.sql import Row
from sq_protos_py.squareup.edgy.events_pb2 import CashConnectedGraphChangeEvent as CashConnectedGraphChangeEvent_pb2
from sq_protos_py.squareup.duplograph.data_pb2 import NodeType as NodeType_pb2
from sq_protos_py.squareup.duplograph.labels_pb2 import LabelType as LabelType_pb2
from sq_protos_py.squareup.duplograph.labels_pb2 import LabelStorageMethod as LabelStorageMethod_pb2

# Spark schemas for the flattened RiskInput / CashConnectedGraphChangeEvent data.
# Every field is declared nullable, and every protobuf enum is stored as a string.
# NOTE(review): the Rows built in risk_input_to_row list their fields in exactly
# the order declared here; keep the two in sync when adding or reordering
# fields, since the Row values are presumably matched to the schema by position.

# Define the schema for NodeId (used in Edge and LabelEvent)
node_id_schema = StructType([
    StructField("type", StringType(), True),  # Enum as StringType (NodeType)
    StructField("token", StringType(), True)
])

# Define the schema for SourceEdge: the two endpoint nodes plus three
# millisecond-valued long fields copied straight from the protobuf.
source_edge_schema = StructType([
    StructField("from_node", node_id_schema, True),
    StructField("to_node", node_id_schema, True),
    StructField("effective_at_msec", LongType(), True),
    StructField("created_at", LongType(), True),
    StructField("updated_at", LongType(), True)
])

# Define the schema for SourceLabelEvent
source_label_event_schema = StructType([
    StructField("node", node_id_schema, True),
    StructField("label", StringType(), True),  # Enum as StringType (LabelType)
    StructField("effective_at_msec", LongType(), True),
    StructField("present", BooleanType(), True),
    StructField("created_at", LongType(), True),
    StructField("updated_at", LongType(), True),
    StructField("storage_method", StringType(), True)  # Enum as StringType (LabelStorageMethod)
])

# Define the schema for CashConnectedGraphChangeEvent.
# connection_change and label_change are both nullable structs; risk_input_to_row
# leaves one as None when the protobuf does not have that field set.
cash_connected_graph_change_event_schema = StructType([
    StructField("event_id", StringType(), True),
    StructField("event_type", StringType(), True),  # Enum as StringType
    StructField("target_user_token", StringType(), True),
    StructField("source_user_token", StringType(), True),
    StructField("effective_at_millis", LongType(), True),
    StructField("published_at_millis", LongType(), True),
    StructField("user_type", StringType(), True),  # Enum as StringType
    StructField("event_source_type", StringType(), True),  # Enum as StringType
    StructField("connection_change", StructType([
        StructField("changed_node_type", StringType(), True),  # Enum as StringType (NodeType)
        StructField("source_user_labels", ArrayType(StringType()), True),  # LabelType names
        StructField("target_user_labels", ArrayType(StringType()), True),  # LabelType names
        StructField("source_edge", source_edge_schema, True)
    ]), True),
    StructField("label_change", StructType([
        StructField("connection_node_types", ArrayType(StringType()), True),  # NodeType names
        StructField("changed_source_user_label", StringType(), True),  # Enum as StringType (LabelType)
        StructField("source_label_event", source_label_event_schema, True)
    ]), True)
])

# Define the schema for RiskInput: the top-level table row, wrapping the nested
# cash_connected_graph_change_event_schema.
risk_input_schema = StructType([
    StructField("event_id", StringType(), True),
    StructField("event_time_millis", LongType(), True),
    StructField("cash_connected_graph_change_event", cash_connected_graph_change_event_schema, True)
])

def get_enum_name(enum_class, value):
    """Return the name that ``enum_class`` reports for the integer ``value``.

    Delegates to ``enum_class.Name(value)``; whatever that call raises for an
    unrecognised value propagates to the caller.
    """
    enum_name = enum_class.Name(value)
    return enum_name

# For every protobuf enum used by the event, invert its (name, number) pairs
# into a number -> name dict and broadcast that dict. risk_input_to_row reads
# the broadcasts through `.value` to turn enum integers into strings.

# CashConnectedGraphChangeEvent.EventType
event_type_mapping = {number: name for name, number in CashConnectedGraphChangeEvent_pb2.EventType.items()}
broadcast_event_type_mapping = spark.sparkContext.broadcast(event_type_mapping)

# CashConnectedGraphChangeEvent.UserType
user_type_mapping = {number: name for name, number in CashConnectedGraphChangeEvent_pb2.UserType.items()}
broadcast_user_type_mapping = spark.sparkContext.broadcast(user_type_mapping)

# NodeType (covers the NodeType in NodeId as well as changed_node_type)
changed_node_type_mapping = {number: name for name, number in NodeType_pb2.items()}
broadcast_changed_node_type_mapping = spark.sparkContext.broadcast(changed_node_type_mapping)

# LabelType (covers the LabelType in LabelEvent)
label_type_mapping = {number: name for name, number in LabelType_pb2.items()}
broadcast_label_type_mapping = spark.sparkContext.broadcast(label_type_mapping)

# CashConnectedGraphChangeEvent.EventSourceType
event_source_type_mapping = {number: name for name, number in CashConnectedGraphChangeEvent_pb2.EventSourceType.items()}
broadcast_event_source_type_mapping = spark.sparkContext.broadcast(event_source_type_mapping)

# LabelStorageMethod
label_storage_method_mapping = {number: name for name, number in LabelStorageMethod_pb2.items()}
broadcast_label_storage_method_mapping = spark.sparkContext.broadcast(label_storage_method_mapping)

def _enum_label(broadcast_mapping, value):
    """Look up an enum integer in a broadcast number -> name dict ("UNKNOWN" if absent)."""
    return broadcast_mapping.value.get(value, "UNKNOWN")


def _node_id_row(node):
    """Convert a NodeId message to a Row with fields in node_id_schema order."""
    return Row(
        type=_enum_label(broadcast_changed_node_type_mapping, node.type),
        token=node.token
    )


def _source_edge_row(se):
    """Convert a source edge message to a Row with fields in source_edge_schema order."""
    return Row(
        from_node=_node_id_row(se.from_node) if se.HasField("from_node") else None,
        to_node=_node_id_row(se.to_node) if se.HasField("to_node") else None,
        effective_at_msec=se.effective_at_msec,
        created_at=se.created_at,
        updated_at=se.updated_at
    )


def _source_label_event_row(sle):
    """Convert a source label event message to a Row with fields in source_label_event_schema order."""
    return Row(
        node=_node_id_row(sle.node) if sle.HasField("node") else None,
        label=_enum_label(broadcast_label_type_mapping, sle.label),
        effective_at_msec=sle.effective_at_msec,
        present=sle.present,
        created_at=sle.created_at,
        updated_at=sle.updated_at,
        storage_method=_enum_label(broadcast_label_storage_method_mapping, sle.storage_method)
    )


def _connection_change_row(cc):
    """Convert the event's connection_change message to a Row (connection_change struct order)."""
    return Row(
        changed_node_type=_enum_label(broadcast_changed_node_type_mapping, cc.changed_node_type),
        source_user_labels=[_enum_label(broadcast_label_type_mapping, label) for label in cc.source_user_labels],
        target_user_labels=[_enum_label(broadcast_label_type_mapping, label) for label in cc.target_user_labels],
        source_edge=_source_edge_row(cc.source_edge) if cc.HasField("source_edge") else None
    )


def _label_change_row(lc):
    """Convert the event's label_change message to a Row (label_change struct order)."""
    return Row(
        connection_node_types=[_enum_label(broadcast_changed_node_type_mapping, node) for node in lc.connection_node_types],
        changed_source_user_label=_enum_label(broadcast_label_type_mapping, lc.changed_source_user_label),
        source_label_event=_source_label_event_row(lc.source_label_event) if lc.HasField("source_label_event") else None
    )


def risk_input_to_row(ri_bytes):
    """Convert serialized RiskInput bytes into a Row matching risk_input_schema.

    The bytes are parsed with deserialize_ri and the nested
    cash_connected_graph_change_event is flattened into nested Rows. Enum
    integers are replaced by their names through the broadcast number -> name
    dicts; a value missing from a dict becomes "UNKNOWN". Sub-messages for
    which HasField is false are emitted as None.

    Every Row lists its keyword arguments in the same order as the fields of
    the corresponding StructType; keep that order when editing.

    Args:
        ri_bytes: Serialized RiskInput protobuf payload.

    Returns:
        A Row with event_id, event_time_millis and a nested
        cash_connected_graph_change_event Row.
    """
    ri = deserialize_ri(ri_bytes)
    event = ri.cash_connected_graph_change_event

    # Each of the two change payloads is only converted when it is set on the event.
    connection_change_row = (
        _connection_change_row(event.connection_change)
        if event.HasField("connection_change") else None
    )
    label_change_row = (
        _label_change_row(event.label_change)
        if event.HasField("label_change") else None
    )

    return Row(
        event_id=ri.event_id,
        event_time_millis=ri.event_time_millis,
        cash_connected_graph_change_event=Row(
            event_id=event.event_id,
            event_type=_enum_label(broadcast_event_type_mapping, event.event_type),
            target_user_token=event.target_user_token,
            source_user_token=event.source_user_token,
            effective_at_millis=event.effective_at_millis,
            published_at_millis=event.published_at_millis,
            user_type=_enum_label(broadcast_user_type_mapping, event.user_type),
            event_source_type=_enum_label(broadcast_event_source_type_mapping, event.event_source_type),
            connection_change=connection_change_row,
            label_change=label_change_row
        )
    )

In [0]:
# Load every Avro record found under S3_PATH.
df_s3_egress = spark.read.format("avro").load(S3_PATH)

# The first column of each record holds the serialized RiskInput; parse it and
# flatten it into a Row, one record at a time, on the underlying RDD.
rdd = df_s3_egress.rdd.map(lambda record: risk_input_to_row(record[0]))

# Turn the RDD of Rows back into a DataFrame typed by risk_input_schema.
df_risk_input = spark.createDataFrame(rdd, schema=risk_input_schema)

In [0]:
# The target catalog and schema are not created here; they are expected to
# exist already.

# Save the flattened events as a Delta table named OUTPUT_TABLE, replacing any
# existing contents ("overwrite") with the mergeSchema option enabled.
(
    df_risk_input.write
    .format("delta")
    .mode("overwrite")
    .option("mergeSchema", "true")
    .saveAsTable(OUTPUT_TABLE)
)


In [0]:
# Sanity check: total number of rows now present in the output table.
query = f"""
select count(*) from {OUTPUT_TABLE};
"""
display(spark.sql(query))

count(1)
11749550


In [0]:
# Break the loaded rows down by event source type and by whether label_change
# is populated. For each group, report the row count and the earliest/latest
# effective and published times (the millisecond fields are divided by 1000
# before being passed to from_unixtime).
query = f"""
select 
    cash_connected_graph_change_event.event_source_type as event_source_type, 
    cash_connected_graph_change_event.label_change is not null as is_label_change,
    count(*) as count,
    min(from_unixtime(cash_connected_graph_change_event.effective_at_millis / 1000)) as min_effective_at, 
    max(from_unixtime(cash_connected_graph_change_event.effective_at_millis / 1000)) as max_effective_at, 
    min(from_unixtime(cash_connected_graph_change_event.published_at_millis / 1000)) as min_published_at, 
    max(from_unixtime(cash_connected_graph_change_event.published_at_millis / 1000)) as max_published_at
from {OUTPUT_TABLE}
group by event_source_type, is_label_change;
"""
display(spark.sql(query))

event_source_type,is_label_change,count,min_effective_at,max_effective_at,min_published_at,max_published_at
BAU,False,5006462,2025-06-26 00:00:00,2025-06-26 23:59:59,2025-06-26 00:00:02,2025-07-01 01:00:35
BAU,True,6743088,2025-06-26 00:00:00,2025-06-26 23:59:57,2025-06-26 00:00:02,2025-06-27 00:00:04


In [0]:
# Show duplication: for each (event_source_type, is_label_change) group,
# compare the number of distinct top-level event_id values with the total
# number of rows. duplicate_ratio is the share of rows in excess of one per
# distinct event_id, rounded to two decimals.
query = f"""
select 
    cash_connected_graph_change_event.event_source_type as event_source_type, 
    cash_connected_graph_change_event.label_change is not null as is_label_change,
    count(distinct(event_id)) as num_unique_ids, 
    count(*) as num_events,
    round((count(*) - count(distinct(event_id))) * 1.0 / count(*), 2) as duplicate_ratio
from {OUTPUT_TABLE}
group by event_source_type, is_label_change;
"""
display(spark.sql(query))

event_source_type,is_label_change,num_unique_ids,num_events,duplicate_ratio
BAU,True,4152611,6743088,0.38
BAU,False,4823770,5006462,0.04
