In [1]:
# Evaluate the session handle; the cell output shows the Spark application
# being started and the resulting SparkSession object.
spark

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log
44,application_1610641989196_0291,pyspark,idle,Link,Link


SparkSession available as 'spark'.
<pyspark.sql.session.SparkSession object at 0x7f3965eb4c10>

In [2]:
import hashlib
from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
import hsfs

In [3]:
# Open a connection through hsfs. NOTE(review): no cell visible in this
# notebook calls `connection.close()`, which the connect message recommends
# for terminating the connection gracefully.
connection = hsfs.connection()
# Get the feature store handle for the project's feature store; every
# feature group read and training dataset creation below goes through `fs`.
fs = connection.get_feature_store()

Connected. Call `.close()` to terminate connection gracefully.

## Load transactions feature group from hsfs

In [4]:
# Look up version 1 of the transactions feature group, read it into a Spark
# DataFrame, and preview the first rows.
transactions_fg = fs.get_feature_group(name="transactions_fg", version=1)
transactions_df = transactions_fg.read()
transactions_df.show()

+--------+--------+-------+-------+--------+--------------+
|     src|base_amt|tran_id|tx_type|     dst|tran_timestamp|
+--------+--------+-------+-------+--------+--------------+
|3aa9646b|  858.77|    496|      4|1e46e726|        Jan-01|
|49203bc3|  386.86|   1342|      4|a74d1101|        Jan-01|
|616d4505|  616.43|   1580|      4|99af2455|        Jan-02|
|39be1ea2|  146.44|   2866|      4|e7ec7bdb|        Jan-02|
|e2e0d938|  439.09|   3997|      4|afc399a9|        Jan-03|
|75c9a805|   361.0|   5518|      4|d7a317f6|        Jan-04|
|c14f4989|  768.98|   7340|      4|733a496b|        Jan-06|
|576eb672|   943.4|   9376|      4|aa49b0eb|        Jan-07|
|847a9cf6|   668.3|  10362|      4|b070a6bb|        Jan-08|
|12a388ff|  139.84|  10817|      4|586377aa|        Jan-08|
|b36f9c84|  499.47|  11317|      4|1b467848|        Jan-08|
|362e42e0|  357.96|  11748|      4|385afb8b|        Jan-09|
|572014da|   630.9|  13285|      4|acd60eca|        Jan-10|
|5ff2d9a7|  685.07|  14832|      4|31976

## Load alert transactions feature group from hsfs

In [5]:
# Look up version 1 of the alert transactions feature group, read it into a
# Spark DataFrame, and preview the first rows.
alert_transactions_fg = fs.get_feature_group(name="alert_transactions_fg", version=1)
alert_transactions_df = alert_transactions_fg.read()
alert_transactions_df.show()

+------+-------+--------+--------------+
|is_sar|tran_id|alert_id|    alert_type|
+------+-------+--------+--------------+
|  true| 447977|      52|gather_scatter|
|  true| 449282|      23|scatter_gather|
|  true| 454797|      52|gather_scatter|
|  true| 462363|      68|gather_scatter|
|  true| 468776|      68|gather_scatter|
|  true| 518050|      26|scatter_gather|
|  true| 518475|      10|scatter_gather|
|  true| 519362|      26|scatter_gather|
|  true| 521249|      26|scatter_gather|
|  true| 521357|      65|gather_scatter|
|  true| 557238|      25|scatter_gather|
|  true| 558782|      25|scatter_gather|
|  true| 559459|      61|gather_scatter|
|  true| 559460|      61|gather_scatter|
|  true| 559567|      69|gather_scatter|
|  true| 553958|      69|gather_scatter|
|  true| 554411|      25|scatter_gather|
|  true| 555697|      61|gather_scatter|
|  true| 556608|      61|gather_scatter|
|  true| 556609|      61|gather_scatter|
+------+-------+--------+--------------+
only showing top 20 rows

## Load party feature group from hsfs

In [6]:
# Look up version 1 of the party feature group, read it into a Spark
# DataFrame, and preview the first rows.
party_fg = fs.get_feature_group(name="party_fg", version=1)
party_df = party_fg.read()
party_df.show()

+----+--------+
|type|      id|
+----+--------+
|   0|5628bd6c|
|   0|a1fcba39|
|   1|f56c9501|
|   0|9969afdd|
|   1|b356eeae|
|   0|3406706a|
|   0|26c56102|
|   1|e386ebf7|
|   1|8c094b0d|
|   1|939235aa|
|   0|de6bf2a5|
|   0|33a8ff5b|
|   1|a32807a1|
|   0|2906ef08|
|   1|c2a01b8d|
|   1|5a99160f|
|   0|8b9017b8|
|   1|fcf3bbf3|
|   0|5132aa4d|
|   1|68b90958|
+----+--------+
only showing top 20 rows

## Create graph edge training dataset

In [7]:
# Row count of the alert transactions data, used as a reference figure for
# the number of flagged edges after the join below.
alert_transactions_df.count()

915

In [8]:
# Row count of the transactions data, used as a reference figure for the
# number of edges after the join below.
transactions_df.count()

438386

In [9]:
# Build the edge list: one row per transaction, labelled with is_sar.
# The left join keeps every transaction; rows without a matching alert get a
# null is_sar, which the when/otherwise below maps to 0.
edges = (
    transactions_df
    .join(alert_transactions_df, on=["tran_id"], how="left")
    .withColumn("is_sar", F.when(F.col("is_sar") == "true", 1).otherwise(0))
    .select(
        F.col("src").alias("source"),
        F.col("dst").alias("target"),
        "tx_type",
        "base_amt",
        "tran_id",
        "is_sar",
    )
)

In [10]:
# Preview the edge list: renamed source/target columns and the 0/1 is_sar label.
edges.show()

+--------+--------+-------+--------+-------+------+
|  source|  target|tx_type|base_amt|tran_id|is_sar|
+--------+--------+-------+--------+-------+------+
|3aa9646b|1e46e726|      4|  858.77|    496|     0|
|49203bc3|a74d1101|      4|  386.86|   1342|     0|
|616d4505|99af2455|      4|  616.43|   1580|     0|
|39be1ea2|e7ec7bdb|      4|  146.44|   2866|     0|
|e2e0d938|afc399a9|      4|  439.09|   3997|     0|
|75c9a805|d7a317f6|      4|   361.0|   5518|     0|
|c14f4989|733a496b|      4|  768.98|   7340|     0|
|576eb672|aa49b0eb|      4|   943.4|   9376|     0|
|847a9cf6|b070a6bb|      4|   668.3|  10362|     0|
|12a388ff|586377aa|      4|  139.84|  10817|     0|
|b36f9c84|1b467848|      4|  499.47|  11317|     0|
|362e42e0|385afb8b|      4|  357.96|  11748|     0|
|572014da|acd60eca|      4|   630.9|  13285|     0|
|5ff2d9a7|31976e38|      4|  685.07|  14832|     0|
|24bf603c|fcf3bbf3|      4|  964.81|  15619|     0|
|9a118f8d|ca0967a6|      4|  919.76|  16574|     0|
|65b8a85f|bc

In [11]:
# Sanity check: the left join on tran_id should leave the number of rows equal
# to transactions_df.count(); a larger value would indicate duplicated tran_ids
# on the alert side.
edges.count()

438386

In [12]:
# Number of edges labelled is_sar == 1; compare with alert_transactions_df.count().
edges.where(F.col("is_sar")==1).count()

915

In [13]:
# Register the metadata for the edges training dataset (CSV, version 1) with
# is_sar declared as the label column, then write the edges DataFrame to it.
edges_td_meta = fs.create_training_dataset(
    name="edges_td",
    version=1,
    data_format="csv",
    label=["is_sar"],
    description="edges training dataset",
)
edges_td_meta.save(edges)

<hsfs.training_dataset.TrainingDataset object at 0x7f3973229cd0>

## Create graph node training dataset

In [14]:
# Collect every id that appears on either end of an edge and de-duplicate,
# giving one row per graph node.
sources = edges.select(F.col("source").alias("id"))
targets = edges.select(F.col("target").alias("id"))
nodes = sources.union(targets).dropDuplicates(["id"])
nodes.show()

+--------+
|      id|
+--------+
|5c01ec6e|
|5132aa4d|
|62827917|
|7138cbc6|
|243b1e8b|
|a4e0bd48|
|a32807a1|
|c412103b|
|01fdc089|
|5a99160f|
|d3adb450|
|d7a0ca48|
|c2a01b8d|
|de5c22e0|
|68b90958|
|939235aa|
|a9edaba6|
|e386ebf7|
|26c56102|
|9969afdd|
+--------+
only showing top 20 rows

In [15]:
# Number of distinct ids that appear as a source or a target in the edge list.
nodes.count()

7347

In [16]:
# Attach the party attributes to each node id. This is an inner join, so any
# node id missing from party_df would be dropped; the count below should
# therefore match nodes.count().
nodes_td = nodes.join(party_df, on="id", how="inner")
nodes_td.count()

7347

In [17]:
# Preview the node table: each id together with its party type.
nodes_td.show()

+--------+----+
|      id|type|
+--------+----+
|243b1e8b|   1|
|a4e0bd48|   0|
|01fdc089|   1|
|4b46d80d|   0|
|ab04afb9|   0|
|cf2f4c98|   0|
|62827917|   1|
|5a99160f|   1|
|e386ebf7|   1|
|68b90958|   1|
|d3adb450|   1|
|90e0340f|   1|
|1a14903a|   0|
|c412103b|   1|
|5645140a|   0|
|b6529244|   0|
|939235aa|   1|
|5132aa4d|   0|
|fcf3bbf3|   1|
|9969afdd|   0|
+--------+----+
only showing top 20 rows

In [18]:
# Register the metadata for the node training dataset (CSV, version 1; no
# label column is declared), then write the nodes_td DataFrame to it.
node_td_meta = fs.create_training_dataset(
    name="node_td",
    version=1,
    data_format="csv",
    description="node training dataset",
)
node_td_meta.save(nodes_td)

<hsfs.training_dataset.TrainingDataset object at 0x7f39736bff90>