In [None]:
import os
import uuid
from notebookutils import mssparkutils
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
from pyspark.sql.types import StringType,BooleanType
# from graphframes import *

# Column factory producing a random UUID string per row. uuid4() returns a new
# value on every call, so the UDF is flagged non-deterministic: Spark treats
# UDFs as deterministic by default and may otherwise eliminate or duplicate
# invocations while optimising the plan.
f_uuid = F.udf(lambda: str(uuid.uuid4()), StringType()).asNondeterministic()
# Column factory producing a constant boolean True for every row.
f_bool = F.udf(lambda: True, BooleanType())


In [None]:
# update the values from your infra
# NOTE(review): this is the Gremlin (wss://) endpoint, while the writes in this
# notebook go through the "cosmos.oltp" Spark connector - confirm whether the
# account's https://<account>.documents.azure.com:443/ endpoint is needed instead.
cosmosEndpoint = "wss://ontologypoc.gremlin.cosmos.azure.com:443/"
# The account key is a secret and must never be stored in the notebook.
# It is read from the COSMOS_MASTER_KEY environment variable; a Key Vault lookup
# (e.g. mssparkutils.credentials.getSecret) is an alternative source.
# Any key that was previously committed here should be treated as leaked and rotated.
cosmosMasterKey = os.environ.get("COSMOS_MASTER_KEY")
if not cosmosMasterKey:
    raise RuntimeError(
        "Cosmos DB account key not found: set the COSMOS_MASTER_KEY environment "
        "variable before running this notebook."
    )
cosmosDatabaseName = "ontology_nn"
cosmosContainerName = "graphnn7"

In [None]:
# update csv file path based on your infra
# Read the CSV with its first line as the header, then preview ten rows.
df = (
    spark.read
    .format("csv")
    .options(header=True)
    .load("https://owl2jsonmanish.blob.core.windows.net/cosmosdbgp/PS_20174392719_1491204439457_log.csv")
)
display(df.limit(10))

In [None]:
# Project the columns used downstream; the amount and balance columns are
# cast to int, the remaining columns are passed through unchanged.
raw_data = df.select(
    F.col("type"),
    F.col("amount").cast("int").alias("amount"),
    F.col("nameOrig"),
    F.col("oldbalanceOrg").cast("int").alias("oldbalanceOrg"),
    F.col("newbalanceOrig").cast("int").alias("newbalanceOrig"),
    F.col("nameDest"),
    F.col("oldbalanceDest").cast("int").alias("oldbalanceDest"),
    F.col("newbalanceDest").cast("int").alias("newbalanceDest"),
)


In [None]:
# Print the first 20 rows of the projected data (long values truncated).
raw_data.show(n=20, truncate=True)

In [None]:
# Connection options passed to the "cosmos.oltp" writer.
cfg = {}
cfg["spark.cosmos.accountEndpoint"] = cosmosEndpoint
cfg["spark.cosmos.accountKey"] = cosmosMasterKey
cfg["spark.cosmos.database"] = cosmosDatabaseName
cfg["spark.cosmos.container"] = cosmosContainerName

# Configure Catalog Api to be used, followed by the throughput-control
# settings; applied to the session in the order listed.
_session_settings = [
    ("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog"),
    ("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", cosmosEndpoint),
    ("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", cosmosMasterKey),
    ("spark.cosmos.throughputControl.enabled", True),
    ("spark.cosmos.throughputControl.targetThroughput", 40000),
]
for _setting_key, _setting_value in _session_settings:
    spark.conf.set(_setting_key, _setting_value)

def write_to_cosmos_graph(df: DataFrame, data_type: str, save: bool = False):
    """Append ``df`` to the Cosmos DB container described by ``cfg``.

    Args:
        df: DataFrame to write through the "cosmos.oltp" connector.
        data_type: Name of the data set (e.g. "vertices" or "edges"); used only
            as the folder name of the optional Delta copy.
        save: When True, a Delta copy of ``df`` is first written (overwriting
            any previous copy) to ``synfs:/<job id>/mydata/<data_type>/``.
    """
    if save:
        # synfs paths are addressed by the id of the running job. The original
        # code referenced an undefined `job_id`, which raised NameError.
        # NOTE(review): assumes a storage mount named "/mydata" exists - confirm.
        job_id = mssparkutils.env.getJobId()
        (
            df.write.format("delta")
            .mode("overwrite")
            .option("overwriteSchema", "true")
            .save(f"synfs:/{job_id}/mydata/{data_type}/")
        )

    (
        df.write.format("cosmos.oltp")
        .options(**cfg)
        .mode("APPEND")
        .save()
    )

In [None]:
# PySpark function that creates vertex and edge DataFrames, in a format accepted by the Cosmos SQL API, from the raw DataFrame.
# TODO: Add vertex properties
def prepare_vertices_edge_df(
    df: DataFrame,
    source_col_name: str,
    dest_col_name: str,
    parition_key_col_name: str,
    cosmos_parition_name: str,
    edge_properties_col_name: list,
    vertex_properties_col_name: list,
    vertex_label: str = "account",
    edge_label_col_name: str = "type",
    sample: bool = False,
):
    """Derive vertex and edge DataFrames from a DataFrame of source/destination rows.

    Args:
        df: Input DataFrame; each row yields one edge.
        source_col_name: Column holding the source vertex id.
        dest_col_name: Column holding the destination vertex id.
        parition_key_col_name: Column whose value becomes the edge's partition key.
        cosmos_parition_name: Name of the partition-key column in the output frames.
        edge_properties_col_name: Columns of ``df`` copied onto each edge.
        vertex_properties_col_name: Extra columns selected for each vertex. The
            vertex frame only contains "label", "id" and the partition key, so
            other names cannot be resolved yet (see the TODO on vertex properties).
        vertex_label: Label given to every vertex and used as the source/sink
            label on every edge.
        edge_label_col_name: Column whose value becomes the edge label.
        sample: When True, only 100 rows of ``df`` are used.

    Returns:
        A ``(cosmos_vertices_df, cosmos_edges_df)`` tuple.
    """
    if sample:
        df = df.limit(100)

    # Every id seen as a source or a destination becomes a vertex.
    nameOrig = df.select(source_col_name).withColumnRenamed(source_col_name, "id")
    nameDest = df.select(dest_col_name).withColumnRenamed(dest_col_name, "id")
    all_vertices = nameOrig.union(nameDest)
    # A single distinct() at the end is sufficient: the added columns are
    # derived only from "id" and a literal, so de-duplicating the union first
    # would add a second shuffle without changing the result.
    cosmos_vertices_df = (
        all_vertices.withColumn(cosmos_parition_name, all_vertices["id"])
        .withColumn("label", F.lit(vertex_label))
        .select("label", "id", cosmos_parition_name, *vertex_properties_col_name)
        .distinct()
    )

    # Create dataframe with required columns
    # _sink => destination vertex id (dest_col_name)
    # _sinkLabel => destination vertex label (vertex_label)
    # _sinkPartition => destination vertex id (dest_col_name)
    # _vertexId => source vertex id (source_col_name)
    # _vertexLabel => source vertex label (vertex_label)
    # cosmos_parition_name => value of parition_key_col_name
    cosmos_edges_df = (
        # Spark's built-in uuid() is used rather than a Python UDF: it is
        # treated as non-deterministic by the optimiser and avoids a
        # per-row round-trip to Python.
        df.withColumn("id", F.expr("uuid()"))
        .withColumn(cosmos_parition_name, df[parition_key_col_name])
        .withColumn("label", df[edge_label_col_name])
        .withColumn("_sinkPartition", df[dest_col_name])
        .withColumn("_vertexId", df[source_col_name])
        .withColumn("_sink", df[dest_col_name])
        .withColumn("_sinkLabel", F.lit(vertex_label))
        .withColumn("_vertexLabel", F.lit(vertex_label))
        # Constant flag; a literal replaces the former constant-True UDF.
        .withColumn("_isEdge", F.lit(True))
        .select(
            "id",
            "label",
            "_sink",
            "_sinkLabel",
            "_sinkPartition",
            "_vertexId",
            "_vertexLabel",
            "_isEdge",
            cosmos_parition_name,
            *edge_properties_col_name
        )
    )
    return cosmos_vertices_df, cosmos_edges_df


In [None]:
# Build the vertex and edge frames from the full data set (no sampling).
# Vertices come from nameOrig/nameDest; edges are partitioned by nameOrig
# under the "accountId" partition-key column.
# NOTE(review): "newbalanceOrig" is not among the edge properties - confirm
# whether that omission is intentional.
v, e = prepare_vertices_edge_df(
    raw_data,
    source_col_name="nameOrig",
    dest_col_name="nameDest",
    parition_key_col_name="nameOrig",
    cosmos_parition_name="accountId",
    vertex_properties_col_name=[],
    edge_properties_col_name=["amount", "oldbalanceOrg", "oldbalanceDest", "newbalanceDest"],
    sample=False,
)


In [None]:
# Write the vertices first, then the edges; no Delta copy is kept.
write_to_cosmos_graph(df=v, data_type="vertices", save=False)
write_to_cosmos_graph(df=e, data_type="edges", save=False)