In [1]:
# Evaluate the SparkSession handle provided by the kernel. In this environment,
# running the cell starts the Spark application and displays its YARN info (see output).
spark

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log
3,application_1613335206545_0001,pyspark,idle,Link,Link


SparkSession available as 'spark'.
<pyspark.sql.session.SparkSession object at 0x7ff05dc49f10>

In [2]:
import hashlib
from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
import hsfs

In [3]:
# Open a session against the Hopsworks feature store service, then take the
# handle of this project's own feature store; every later cell goes through `fs`.
# NOTE(review): nothing visible in this notebook calls connection.close().
connection = hsfs.connection()
fs = connection.get_feature_store()

Connected. Call `.close()` to terminate connection gracefully.

## Load the transactions feature group from HSFS

In [4]:
# Raw transfers: one row per transaction (src -> dst, base_amt, tx_type, tran_timestamp).
transactions_fg = fs.get_feature_group("transactions_fg", version=1)
transactions_df = transactions_fg.read()
transactions_df.show()

+--------+-------+--------------+-------+--------+--------+
|     dst|tran_id|tran_timestamp|tx_type|     src|base_amt|
+--------+-------+--------------+-------+--------+--------+
|1e46e726|    496|        Jan-01|      4|3aa9646b|  858.77|
|a74d1101|   1342|        Jan-01|      4|49203bc3|  386.86|
|99af2455|   1580|        Jan-02|      4|616d4505|  616.43|
|e7ec7bdb|   2866|        Jan-02|      4|39be1ea2|  146.44|
|afc399a9|   3997|        Jan-03|      4|e2e0d938|  439.09|
|d7a317f6|   5518|        Jan-04|      4|75c9a805|   361.0|
|733a496b|   7340|        Jan-06|      4|c14f4989|  768.98|
|aa49b0eb|   9376|        Jan-07|      4|576eb672|   943.4|
|b070a6bb|  10362|        Jan-08|      4|847a9cf6|   668.3|
|586377aa|  10817|        Jan-08|      4|12a388ff|  139.84|
|1b467848|  11317|        Jan-08|      4|b36f9c84|  499.47|
|385afb8b|  11748|        Jan-09|      4|362e42e0|  357.96|
|acd60eca|  13285|        Jan-10|      4|572014da|   630.9|
|31976e38|  14832|        Jan-11|      4

## Load the alert transactions feature group from HSFS

In [5]:
# Alerted transactions: tran_id together with the alert it belongs to and its is_sar flag.
alert_transactions_fg = fs.get_feature_group("alert_transactions_fg", version=1)
alert_transactions_df = alert_transactions_fg.read()
alert_transactions_df.show()

+--------------+------+-------+--------+
|    alert_type|is_sar|tran_id|alert_id|
+--------------+------+-------+--------+
|gather_scatter|  true|  11873|      47|
|gather_scatter|  true|  11874|      47|
|gather_scatter|  true|  11875|      47|
|gather_scatter|  true|  13151|      47|
|gather_scatter|  true|  23148|      47|
|scatter_gather|  true|  23779|      17|
|scatter_gather|  true|  23780|      17|
|scatter_gather|  true|  26441|      17|
|scatter_gather|  true|  26442|      17|
|gather_scatter|  true|  28329|      47|
|gather_scatter|  true|  31581|      47|
|gather_scatter|  true|  34310|      47|
|scatter_gather|  true|  34433|      17|
|gather_scatter|  true|  36131|      58|
|scatter_gather|  true|  36563|      17|
|scatter_gather|  true|  41430|      17|
|scatter_gather|  true|  42363|      17|
|gather_scatter|  true|  42511|      58|
|gather_scatter|  true|  44370|      58|
|gather_scatter|  true|  46176|      58|
+--------------+------+-------+--------+
only showing top 20 rows

## Load the party feature group from HSFS

In [6]:
# Parties (accounts): one row per id with its integer `type` attribute.
party_fg = fs.get_feature_group("party_fg", version=1)
party_df = party_fg.read()
party_df.show()

+--------+----+
|      id|type|
+--------+----+
|5628bd6c|   0|
|a1fcba39|   0|
|f56c9501|   1|
|9969afdd|   0|
|b356eeae|   1|
|3406706a|   0|
|26c56102|   0|
|e386ebf7|   1|
|8c094b0d|   1|
|939235aa|   1|
|de6bf2a5|   0|
|33a8ff5b|   0|
|a32807a1|   1|
|2906ef08|   0|
|c2a01b8d|   1|
|5a99160f|   1|
|8b9017b8|   0|
|fcf3bbf3|   1|
|5132aa4d|   0|
|68b90958|   1|
+--------+----+
only showing top 20 rows

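## Create the graph edge training dataset

Each transaction becomes a directed edge `source → target`, labelled `is_sar = 1` when it appears in the alert transactions and `0` otherwise.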

In [7]:
# Number of alert-transaction rows (the positive examples for the edge label).
alert_transactions_df.count()

915

In [8]:
# Total number of transactions, i.e. the expected number of edges after the left join below.
transactions_df.count()

438386

In [9]:
# Build the labelled edge list: one row per transaction, oriented src -> dst.
#
# Alerts are first collapsed to a single label per tran_id. Joining the raw
# alert rows directly would duplicate any transaction that belongs to more
# than one alert. A transaction is SAR if any of its alerts says so.
# The explicit boolean cast works whether is_sar is stored as a boolean or
# as the string "true"/"false", instead of relying on implicit coercion.
sar_labels = (
    alert_transactions_df
    .groupBy("tran_id")
    .agg(F.max(F.col("is_sar").cast("boolean").cast("int")).alias("is_sar"))
)

# Left join keeps every transaction; those with no alert get a null label,
# which coalesce turns into 0. Explicit aliases replace the positional toDF
# rename so a column reorder cannot silently mislabel fields.
edges = (
    transactions_df
    .join(sar_labels, on="tran_id", how="left")
    .select(
        F.col("src").alias("source"),
        F.col("dst").alias("target"),
        "tx_type",
        "base_amt",
        "tran_id",
        F.coalesce(F.col("is_sar"), F.lit(0)).alias("is_sar"),
    )
)

In [10]:
# Preview the labelled edge list; transactions without an alert carry is_sar = 0.
edges.show()

+--------+--------+-------+--------+-------+------+
|  source|  target|tx_type|base_amt|tran_id|is_sar|
+--------+--------+-------+--------+-------+------+
|3aa9646b|1e46e726|      4|  858.77|    496|     0|
|49203bc3|a74d1101|      4|  386.86|   1342|     0|
|616d4505|99af2455|      4|  616.43|   1580|     0|
|39be1ea2|e7ec7bdb|      4|  146.44|   2866|     0|
|e2e0d938|afc399a9|      4|  439.09|   3997|     0|
|75c9a805|d7a317f6|      4|   361.0|   5518|     0|
|c14f4989|733a496b|      4|  768.98|   7340|     0|
|576eb672|aa49b0eb|      4|   943.4|   9376|     0|
|847a9cf6|b070a6bb|      4|   668.3|  10362|     0|
|12a388ff|586377aa|      4|  139.84|  10817|     0|
|b36f9c84|1b467848|      4|  499.47|  11317|     0|
|362e42e0|385afb8b|      4|  357.96|  11748|     0|
|572014da|acd60eca|      4|   630.9|  13285|     0|
|5ff2d9a7|31976e38|      4|  685.07|  14832|     0|
|24bf603c|fcf3bbf3|      4|  964.81|  15619|     0|
|9a118f8d|ca0967a6|      4|  919.76|  16574|     0|
|65b8a85f|bc

In [11]:
# Should equal transactions_df.count(); a larger value would mean the join
# duplicated transactions.
edges.count()

438386

In [12]:
# Number of edges labelled SAR (is_sar == 1). Expected to match the
# alert-transaction count when every alerted tran_id exists in transactions_df.
edges.where(F.col("is_sar")==1).count()

915

In [13]:
# Register the edge list as a CSV training dataset with is_sar declared as the
# label, then write the `edges` DataFrame into it.
edges_td_meta = fs.create_training_dataset(
    name="edges_td",
    version=1,
    description="edges training dataset",
    data_format="csv",
    label=["is_sar"],
)
edges_td_meta.save(edges)

<hsfs.training_dataset.TrainingDataset object at 0x7ff0115cd3d0>

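## Create the graph node training dataset

Nodes are the distinct account ids that appear on either end of an edge, joined with their party `type`.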

In [14]:
# A node is any account id seen on either side of an edge: stack the source and
# target columns under a common name `id`, then keep each id exactly once.
sources = edges.select(F.col("source").alias("id"))
targets = edges.select(F.col("target").alias("id"))
nodes = sources.union(targets).dropDuplicates(["id"])
nodes.show()

+--------+
|      id|
+--------+
|26c56102|
|9969afdd|
|2906ef08|
|5a99160f|
|de6bf2a5|
|b356eeae|
|f56c9501|
|a32807a1|
|939235aa|
|e386ebf7|
|33a8ff5b|
|a9edaba6|
|8c094b0d|
|fcf3bbf3|
|5628bd6c|
|a1fcba39|
|c2a01b8d|
|7138cbc6|
|3406706a|
|68b90958|
+--------+
only showing top 20 rows

In [15]:
# Number of distinct account ids across both edge endpoints.
nodes.count()

7347

In [16]:
# Attach each node's party attributes. This is an inner join, so ids absent from
# party_df are dropped; compare this count with nodes.count() to detect that.
nodes_td = nodes.join(party_df, on="id", how="inner")
nodes_td.count()

7347

In [17]:
# Preview node ids together with the party `type` column taken from party_df.
nodes_td.show()

+--------+----+
|      id|type|
+--------+----+
|26c56102|   0|
|9969afdd|   0|
|2906ef08|   0|
|5a99160f|   1|
|de6bf2a5|   0|
|b356eeae|   1|
|f56c9501|   1|
|a32807a1|   1|
|939235aa|   1|
|e386ebf7|   1|
|33a8ff5b|   0|
|a9edaba6|   0|
|8c094b0d|   1|
|fcf3bbf3|   1|
|5628bd6c|   0|
|a1fcba39|   0|
|c2a01b8d|   1|
|7138cbc6|   1|
|3406706a|   0|
|68b90958|   1|
+--------+----+
only showing top 20 rows

In [18]:
# Register the node table (id plus party type) as a CSV training dataset.
# Unlike edges_td, no label column is declared. Then write nodes_td into it.
node_td_meta = fs.create_training_dataset(
    name="node_td",
    version=1,
    description="node training dataset",
    data_format="csv",
)
node_td_meta.save(nodes_td)

<hsfs.training_dataset.TrainingDataset object at 0x7ff011147a90>