In [None]:
import os
import uuid
from array import array
from pyspark.sql import DataFrame
import pyspark.sql.functions as f
from pyspark.sql.types import StringType,BooleanType,StructType,StructField,IntegerType, DecimalType
from pyspark.sql.functions import lit
from decimal import Decimal

# Spark UDF that returns a fresh random UUID string on every invocation.
# Spark treats UDFs as deterministic by default, which lets the optimizer drop or
# repeat invocations; a random-id generator must opt out of that assumption.
f_uuid = f.udf(lambda: str(uuid.uuid4()), StringType()).asNondeterministic()


In [None]:
# Cosmos DB account / database / container to write to. The values below are
# placeholders and must be replaced before the notebook is run.
cosmosEndpoint = "https://xxxxxx.documents.azure.com:443/"
cosmosMasterKey = "*******"
cosmosDatabaseName = "*******"
cosmosContainerName = "*******"

# Options passed to the "cosmos.oltp" writer by write_to_cosmos_graph.
cfg = dict([
    ("spark.cosmos.accountEndpoint", cosmosEndpoint),
    ("spark.cosmos.accountKey", cosmosMasterKey),
    ("spark.cosmos.database", cosmosDatabaseName),
    ("spark.cosmos.container", cosmosContainerName),
])

# Session-level settings: register the Cosmos catalog (Catalog API) with its
# account endpoint/key, and set the throughput-control flags.
for _conf_key, _conf_value in (
    ("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog"),
    ("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", cosmosEndpoint),
    ("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", cosmosMasterKey),
    ("spark.cosmos.throughputControl.enabled", True),
    ("spark.cosmos.throughputControl.targetThroughput", 20000),
):
    spark.conf.set(_conf_key, _conf_value)

def write_to_cosmos_graph(df: DataFrame):
    """Append every row of ``df`` to Cosmos DB through the ``cosmos.oltp`` format.

    The target account, database and container come from the notebook-level
    ``cfg`` dictionary.
    """
    writer = df.write.format("cosmos.oltp")
    writer = writer.options(**cfg)
    writer.mode("Append").save()

In [None]:
def create_vertex_df(
    df: DataFrame,
    vertex_properties_col_name: list, partition_col: str,
    vertex_label: str,id: str, display_name_col: str
):
  """Project ``df`` into vertex rows ready to be written to the graph container.

  Args:
    df: Source frame holding one row per vertex.
    vertex_properties_col_name: Names of the columns to emit as vertex
      properties. Each one is wrapped as a single-element array of
      ``{"id": uuid(), "_value": <column>}`` (NULL when the column is NULL).
      Repeated names are emitted once.
    partition_col: Column passed through unchanged (the partition key column).
    vertex_label: Either the name of a column of ``df`` that holds the label,
      or, when no such column exists, the literal label applied to every row.
    id: Name of the column holding the vertex id; it is renamed to ``"id"``.
    display_name_col: Column wrapped like a property and emitted as
      ``DisplayName``.

  Returns:
    A DataFrame with ``id``, ``partition_col``, ``label``, one column per
    property, and ``DisplayName``.
  """
  property_template = 'nvl2({x}, array(named_struct("id", uuid(), "_value", {x})), NULL) AS {alias}'

  # Drop repeated property names (keeping first-seen order) so the projection
  # never yields two columns with the same name.
  property_cols = list(dict.fromkeys(vertex_properties_col_name))

  columns = [id, partition_col, "label"]
  columns.extend(property_template.format(x=name, alias=name) for name in property_cols)
  columns.append(property_template.format(x=display_name_col, alias="DisplayName"))

  # Decide on vertex_label itself (not on the fixed name "label"): a column
  # name copies that column, anything else is treated as a constant label.
  if vertex_label in df.columns:
    df = df.withColumn("label", df[vertex_label])
  else:
    df = df.withColumn("label", f.lit(vertex_label))

  return df.selectExpr(*columns).withColumnRenamed(id, "id")
  

In [None]:
def _column_or_literal(frame: DataFrame, name: str):
  """Return ``frame[name]`` when ``name`` is a column of ``frame``, else ``name`` as a literal."""
  return frame[name] if name in frame.columns else f.lit(name)


def create_edge_df(srcdf: DataFrame, destdf: DataFrame, label: str, partition_col: str, 
                   vertexidcol: str, sinkcol: str, sinklabel: str, vertexlabel: str, sinkpartitioncol: str,srcjoincol: str,destjoincol: str,isedgetable: bool):
  """Build edge rows by inner-joining ``srcdf`` and ``destdf``.

  Args:
    srcdf: Frame joined on ``srcjoincol``; also supplies ``partition_col``.
    destdf: Frame joined on ``destjoincol``; supplies ``sinkcol`` and
      ``sinkpartitioncol``.
    label: Edge label. With ``isedgetable`` it may name a column of ``destdf``
      (its values are used) or be a literal; otherwise it is always a literal.
    partition_col: Column of ``srcdf`` carried into the result unchanged.
    vertexidcol: Column copied to ``_vertexId`` (read from ``destdf`` when
      ``isedgetable``, from ``srcdf`` otherwise).
    sinkcol: Column of ``destdf`` copied to ``_sink``.
    sinklabel: Value for ``_sinkLabel``. With ``isedgetable`` a column of
      ``srcdf`` if one has that name, else a literal; otherwise a literal.
    vertexlabel: Value for ``_vertexLabel``, resolved like ``sinklabel``.
    sinkpartitioncol: Column of ``destdf`` copied to ``_sinkPartition``.
    srcjoincol: Join key in ``srcdf``.
    destjoincol: Join key in ``destdf``.
    isedgetable: True when ``destdf`` is itself an edge table carrying the
      vertex id, sink and (optionally) label columns.

  Returns:
    A DataFrame with columns ``label``, ``_sink``, ``_sinkLabel``,
    ``_vertexId``, ``_vertexLabel``, ``_isEdge``, ``_sinkPartition``,
    ``partition_col`` and ``id`` (a generated UUID per row).
  """
  if isedgetable:
    # destdf is an edge table: the labels for both endpoints come from srcdf.
    srcdf = srcdf.withColumn("_sinkLabel", _column_or_literal(srcdf, sinklabel))
    srcdf = srcdf.withColumn("_vertexLabel", _column_or_literal(srcdf, vertexlabel))
    srcdf = srcdf.selectExpr("_sinkLabel", "_vertexLabel", srcjoincol, partition_col)

    # Only project `label` from the edge table when it really is one of its
    # columns; a literal label must not be handed to selectExpr as a column.
    label_is_column = label in destdf.columns
    edge_cols = [destjoincol, vertexidcol, sinkcol, sinkpartitioncol]
    if label_is_column:
      edge_cols.insert(0, label)
    destdf = destdf.selectExpr(*edge_cols)

    df = srcdf.join(destdf, srcdf[srcjoincol] == destdf[destjoincol], "inner")
    df = df.withColumn("label", df[label] if label_is_column else f.lit(label))
    df = df.withColumn("_sink", df[sinkcol]).withColumn("_sinkPartition", df[sinkpartitioncol]).withColumn("_vertexId", df[vertexidcol])\
        .withColumn("id", f_uuid()).withColumn("_isEdge", f.lit(True))
  else:
    # No edge table: derive the edge endpoints from the two joined frames and
    # use the label arguments as constants.
    destdf = destdf.withColumn("_sink", destdf[sinkcol]).withColumn("_sinkPartition", destdf[sinkpartitioncol]).select(destjoincol, "_sink", "_sinkPartition")
    srcdf = srcdf.withColumn("_vertexId", srcdf[vertexidcol]).select(srcjoincol, "_vertexId", partition_col)
    df = srcdf.join(destdf, srcdf[srcjoincol] == destdf[destjoincol], "inner")
    df = df.withColumn("label", f.lit(label)).withColumn("id", f_uuid()).withColumn("_sinkLabel", f.lit(sinklabel))\
        .withColumn("_vertexLabel", f.lit(vertexlabel)).withColumn("_isEdge", f.lit(True))

  columns = ["label", "_sink", "_sinkLabel", "_vertexId", "_vertexLabel", "_isEdge", "_sinkPartition", partition_col, "id"]
  return df.selectExpr(*columns)
  

In [None]:
#vertex_airroutes
import pandas as pd

# Fetch the air-routes node list with pandas and hand it to Spark.
df=spark.createDataFrame(pd.read_csv("https://raw.githubusercontent.com/krlawrence/graph/master/sample-data/air-routes-latest-nodes.csv"))

# Keep a copy of "~id" as "srno", then replace the decorated CSV headers
# with plain column names.
_nodes = df.withColumn("srno", df["~id"])
for _old_name, _new_name in [
    ("~id", "id"),
    ("~label", "label"),
    ("code:string", "code"),
    ("desc:string", "desc"),
    ("country:string", "country"),
    ("city:string", "city"),
]:
    _nodes = _nodes.withColumnRenamed(_old_name, _new_name)

# Both id-like columns are cast to string; the rest pass through.
airroutes = _nodes.selectExpr(
    "cast(srno as string) srno",
    "cast(id as string) id",
    "label",
    "code",
    "desc",
    "country",
    "city",
)

airroutes.show()



In [None]:
#edges_airroutes
import pandas as pd

# Fetch the air-routes edge list with pandas and hand it to Spark.
df=spark.createDataFrame(pd.read_csv("https://raw.githubusercontent.com/krlawrence/graph/master/sample-data/air-routes-latest-edges.csv"))

# Keep a copy of "~id" as "srno", then replace the decorated CSV headers
# with plain column names.
_edges = df.withColumn("srno", df["~id"])
for _old_name, _new_name in [
    ("~id", "id"),
    ("~label", "label"),
    ("~from", "from"),
    ("~to", "to"),
    ("dist:int", "dist"),
]:
    _edges = _edges.withColumnRenamed(_old_name, _new_name)

# The two endpoint columns are cast to string; the rest pass through.
airroutesedges = _edges.selectExpr(
    "id",
    "cast(from as string) from",
    "cast(to as string) to",
    "label",
    "dist",
    "srno",
)

airroutesedges.show()


In [None]:
#Vertex
# Build the vertex rows from the node frame: "id" is the vertex id, "srno" the
# partition column, the "label" column supplies the label, and "code" doubles
# as the display name. Each property column is listed exactly once.
vertex_airroutes = create_vertex_df(
    df=airroutes,
    vertex_properties_col_name=["code","desc","country","city"],
    vertex_label="label", id="id", partition_col="srno",
    display_name_col="code"
)

vertex_airroutes.display()



In [None]:
# Join the node frame to the edge table (node "srno" == edge "from") to build
# the edge rows; arguments are spelled out by name for readability.
edges_airroutes = create_edge_df(
    srcdf=airroutes,
    destdf=airroutesedges,
    label="label",
    partition_col="srno",
    vertexidcol="from",
    sinkcol="to",
    sinklabel="label",
    vertexlabel="label",
    sinkpartitioncol="to",
    srcjoincol="srno",
    destjoincol="from",
    isedgetable=True,
)

edges_airroutes.schema

# To preview rows instead of the schema: edges_airroutes.show()

In [None]:
# Write the vertex rows to Cosmos DB (appends via write_to_cosmos_graph).
write_to_cosmos_graph(vertex_airroutes)


In [None]:
# Write the edge rows to Cosmos DB (appends via write_to_cosmos_graph).
write_to_cosmos_graph(edges_airroutes)