In [1]:
# Display the SparkSession the kernel exposes as `spark`. The output shows the
# Spark application being started when this first cell is run.
spark

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log
27,application_1609605496842_0009,pyspark,idle,Link,Link


SparkSession available as 'spark'.
<pyspark.sql.session.SparkSession object at 0x7f336b801790>

In [2]:
import hashlib
from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
import hsfs

In [3]:
connection = hsfs.connection()  # client connection; keep the name so it can be .close()'d later
fs = connection.get_feature_store()  # handle to this project's feature store

Connected. Call `.close()` to terminate connection gracefully.

## Load transactions feature group from hsfs

In [4]:
# Raw transactions (src -> dst with amount and type), version 1 of the feature group.
transactions_fg = fs.get_feature_group(name="transactions_fg", version=1)
transactions_df = transactions_fg.read()
transactions_df.show()  # preview only; show() prints the first 20 rows by default

+--------+-------+--------------+-------+--------+--------+
|     dst|tx_type|tran_timestamp|tran_id|     src|base_amt|
+--------+-------+--------------+-------+--------+--------+
|1e46e726|      4|        Jan-01|    496|3aa9646b|  858.77|
|a74d1101|      4|        Jan-01|   1342|49203bc3|  386.86|
|99af2455|      4|        Jan-02|   1580|616d4505|  616.43|
|e7ec7bdb|      4|        Jan-02|   2866|39be1ea2|  146.44|
|afc399a9|      4|        Jan-03|   3997|e2e0d938|  439.09|
|d7a317f6|      4|        Jan-04|   5518|75c9a805|   361.0|
|733a496b|      4|        Jan-06|   7340|c14f4989|  768.98|
|aa49b0eb|      4|        Jan-07|   9376|576eb672|   943.4|
|b070a6bb|      4|        Jan-08|  10362|847a9cf6|   668.3|
|586377aa|      4|        Jan-08|  10817|12a388ff|  139.84|
|1b467848|      4|        Jan-08|  11317|b36f9c84|  499.47|
|385afb8b|      4|        Jan-09|  11748|362e42e0|  357.96|
|acd60eca|      4|        Jan-10|  13285|572014da|   630.9|
|31976e38|      4|        Jan-11|  14832

## Load alert transactions feature group from hsfs

In [5]:
# Alerted transactions: tran_id, alert id/type and the is_sar flag (version 1).
alert_transactions_fg = fs.get_feature_group(name="alert_transactions_fg", version=1)
alert_transactions_df = alert_transactions_fg.read()
alert_transactions_df.show()  # preview only; show() prints the first 20 rows by default

+--------------+-------+--------+------+
|    alert_type|tran_id|alert_id|is_sar|
+--------------+-------+--------+------+
|gather_scatter| 447977|      52|  true|
|scatter_gather| 449282|      23|  true|
|gather_scatter| 454797|      52|  true|
|gather_scatter| 462363|      68|  true|
|gather_scatter| 468776|      68|  true|
|scatter_gather| 518050|      26|  true|
|scatter_gather| 518475|      10|  true|
|scatter_gather| 519362|      26|  true|
|scatter_gather| 521249|      26|  true|
|gather_scatter| 521357|      65|  true|
|scatter_gather| 557238|      25|  true|
|scatter_gather| 558782|      25|  true|
|gather_scatter| 559459|      61|  true|
|gather_scatter| 559460|      61|  true|
|gather_scatter| 559567|      69|  true|
|gather_scatter| 553958|      69|  true|
|scatter_gather| 554411|      25|  true|
|gather_scatter| 555697|      61|  true|
|gather_scatter| 556608|      61|  true|
|gather_scatter| 556609|      61|  true|
+--------------+-------+--------+------+
only showing top

## Load party feature group from hsfs

In [6]:
# Parties (account id and a numeric `type`), version 1 of the feature group.
party_fg = fs.get_feature_group(name="party_fg", version=1)
party_df = party_fg.read()
party_df.show()  # preview only; show() prints the first 20 rows by default

+----+--------+
|type|      id|
+----+--------+
|   0|5628bd6c|
|   0|a1fcba39|
|   1|f56c9501|
|   0|9969afdd|
|   1|b356eeae|
|   0|3406706a|
|   0|26c56102|
|   1|e386ebf7|
|   1|8c094b0d|
|   1|939235aa|
|   0|de6bf2a5|
|   0|33a8ff5b|
|   1|a32807a1|
|   0|2906ef08|
|   1|c2a01b8d|
|   1|5a99160f|
|   0|8b9017b8|
|   1|fcf3bbf3|
|   0|5132aa4d|
|   1|68b90958|
+----+--------+
only showing top 20 rows

## Create graph edge training dataset

In [7]:
# Number of rows in the alert-transactions feature group.
alert_transactions_df.count()

915

In [8]:
# Number of rows in the transactions feature group; the edge list built below
# is expected to have exactly this many rows.
transactions_df.count()

438386

In [9]:
# Build the labelled edge list: one row per transaction, from src (source) to
# dst (target), with is_sar = 1 when the alert table flags the transaction and
# 0 otherwise.
#
# The alert side is first reduced to a single flag per tran_id, so the left
# join can never duplicate a transaction even if a tran_id appears in more than
# one alert row. Casting is_sar to boolean (instead of comparing it with the
# string "true") works whether the column is stored as a boolean or a string.
sar_flags = (
    alert_transactions_df
    .groupBy("tran_id")
    .agg(F.max(F.col("is_sar").cast("boolean").cast("int")).alias("is_sar"))
)
edges = (
    transactions_df
    .join(sar_flags, on="tran_id", how="left")
    .select(
        F.col("src").alias("source"),
        F.col("dst").alias("target"),
        "tx_type",
        "base_amt",
        "tran_id",
        # transactions without an alert row come back as null from the left join
        F.coalesce(F.col("is_sar"), F.lit(0)).alias("is_sar"),
    )
)

In [10]:
# Preview the labelled edge list (show() prints the first 20 rows by default).
edges.show()

+--------+--------+-------+--------+-------+------+
|  source|  target|tx_type|base_amt|tran_id|is_sar|
+--------+--------+-------+--------+-------+------+
|3aa9646b|1e46e726|      4|  858.77|    496|     0|
|49203bc3|a74d1101|      4|  386.86|   1342|     0|
|616d4505|99af2455|      4|  616.43|   1580|     0|
|39be1ea2|e7ec7bdb|      4|  146.44|   2866|     0|
|e2e0d938|afc399a9|      4|  439.09|   3997|     0|
|75c9a805|d7a317f6|      4|   361.0|   5518|     0|
|c14f4989|733a496b|      4|  768.98|   7340|     0|
|576eb672|aa49b0eb|      4|   943.4|   9376|     0|
|847a9cf6|b070a6bb|      4|   668.3|  10362|     0|
|12a388ff|586377aa|      4|  139.84|  10817|     0|
|b36f9c84|1b467848|      4|  499.47|  11317|     0|
|362e42e0|385afb8b|      4|  357.96|  11748|     0|
|572014da|acd60eca|      4|   630.9|  13285|     0|
|5ff2d9a7|31976e38|      4|  685.07|  14832|     0|
|24bf603c|fcf3bbf3|      4|  964.81|  15619|     0|
|9a118f8d|ca0967a6|      4|  919.76|  16574|     0|
|65b8a85f|bc

In [11]:
# The edge list is built with a left join on tran_id and must keep exactly one
# edge per transaction; fail loudly if the join ever adds or drops rows.
edge_count = edges.count()
assert edge_count == transactions_df.count(), (
    f"edges has {edge_count} rows but transactions_df has {transactions_df.count()}"
)
edge_count

438386

In [12]:
edges.filter(F.col("is_sar") == 1).count()  # number of edges labelled as SAR

915

In [13]:
# Register the edge list as a CSV training dataset with is_sar as the label,
# then materialise it from the `edges` DataFrame.
edges_td_meta = fs.create_training_dataset(
    name="edges_td",
    version=1,
    data_format="csv",
    label=["is_sar"],
    description="edges training dataset",
)
edges_td_meta.save(edges)

<hsfs.training_dataset.TrainingDataset object at 0x7f3378ce9490>

## Create graph node training dataset

In [19]:
# Every account id that occurs at either end of an edge becomes one graph node.
sources = edges.select(F.col("source").alias("id"))
targets = edges.select(F.col("target").alias("id"))
all_endpoints = sources.union(targets)
nodes = all_endpoints.dropDuplicates(["id"])
nodes.show()

+--------+
|      id|
+--------+
|fcf3bbf3|
|9969afdd|
|a1fcba39|
|d7a0ca48|
|43e028ef|
|a9edaba6|
|3406706a|
|b356eeae|
|de6bf2a5|
|f7e4e741|
|33a8ff5b|
|2906ef08|
|3406d993|
|f56c9501|
|26c56102|
|8b9017b8|
|8c094b0d|
|243b1e8b|
|5628bd6c|
|c2a01b8d|
+--------+
only showing top 20 rows

In [20]:
# Number of distinct account ids appearing as a source or target of an edge.
nodes.count()

7347

In [21]:
# Attach the party `type` to every graph node. This is an inner join, so a node
# id with no row in party_df would silently disappear from the node set, and a
# repeated party id would duplicate a node. The node count must therefore be
# unchanged by the join; a mismatch means ids were dropped or duplicated.
nodes_td = nodes.join(party_df, on="id", how="inner")
nodes_td_count = nodes_td.count()
assert nodes_td_count == nodes.count(), (
    f"joining party_df changed the node count: {nodes.count()} -> {nodes_td_count}"
)
nodes_td_count

7347

In [22]:
# Preview graph nodes with their party type (first 20 rows by default).
nodes_td.show()

+--------+----+
|      id|type|
+--------+----+
|fcf3bbf3|   1|
|9969afdd|   0|
|a1fcba39|   0|
|d7a0ca48|   1|
|43e028ef|   0|
|a9edaba6|   0|
|3406706a|   0|
|b356eeae|   1|
|de6bf2a5|   0|
|f7e4e741|   0|
|33a8ff5b|   0|
|2906ef08|   0|
|3406d993|   1|
|f56c9501|   1|
|26c56102|   0|
|8b9017b8|   0|
|8c094b0d|   1|
|243b1e8b|   1|
|5628bd6c|   0|
|c2a01b8d|   1|
+--------+----+
only showing top 20 rows

In [23]:
# Register the node table (id, type) as an unlabelled CSV training dataset,
# then materialise it from the `nodes_td` DataFrame.
node_td_meta = fs.create_training_dataset(
    name="node_td",
    version=1,
    data_format="csv",
    description="node training dataset",
)
node_td_meta.save(nodes_td)

<hsfs.training_dataset.TrainingDataset object at 0x7f3315312690>