In [3]:
# Display the active SparkSession. It is not created anywhere in this notebook,
# so it is presumably injected by the kernel — confirm for your environment.
spark

<pyspark.sql.session.SparkSession object at 0x7fa9ef040d90>

In [19]:
import hashlib
from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
import hsfs

In [5]:
# Open a connection through hsfs. No host, project or credentials are passed
# here, so whatever defaults hsfs.connection() applies are used.
connection = hsfs.connection()
# Get the feature store handle for the project's feature store.
# NOTE(review): no store name is given; presumably the connected project's
# default feature store is returned — confirm against the hsfs docs.
fs = connection.get_feature_store()

Connected. Call `.close()` to terminate connection gracefully.

## Load transactions feature group from hsfs

In [10]:
# Fetch version 1 of the transactions feature group, read it into a DataFrame
# and preview the first rows.
transactions_fg = fs.get_feature_group(name="transactions_fg", version=1)
transactions_df = transactions_fg.read()
transactions_df.show()

+-------+--------------+-------+--------+--------+--------+
|tx_type|tran_timestamp|tran_id|     dst|     src|base_amt|
+-------+--------------+-------+--------+--------+--------+
|      4|        Jan-01|    496|1e46e726|3aa9646b|  858.77|
|      4|        Jan-01|   1342|a74d1101|49203bc3|  386.86|
|      4|        Jan-02|   1580|99af2455|616d4505|  616.43|
|      4|        Jan-02|   2866|e7ec7bdb|39be1ea2|  146.44|
|      4|        Jan-03|   3997|afc399a9|e2e0d938|  439.09|
|      4|        Jan-04|   5518|d7a317f6|75c9a805|   361.0|
|      4|        Jan-06|   7340|733a496b|c14f4989|  768.98|
|      4|        Jan-07|   9376|aa49b0eb|576eb672|   943.4|
|      4|        Jan-08|  10362|b070a6bb|847a9cf6|   668.3|
|      4|        Jan-08|  10817|586377aa|12a388ff|  139.84|
|      4|        Jan-08|  11317|1b467848|b36f9c84|  499.47|
|      4|        Jan-09|  11748|385afb8b|362e42e0|  357.96|
|      4|        Jan-10|  13285|acd60eca|572014da|   630.9|
|      4|        Jan-11|  14832|31976e38

## Load alert transactions feature group from hsfs

In [9]:
# Fetch version 1 of the alert transactions feature group, read it into a
# DataFrame and preview the first rows.
alert_transactions_fg = fs.get_feature_group(name="alert_transactions_fg", version=1)
alert_transactions_df = alert_transactions_fg.read()
alert_transactions_df.show()

+--------+--------------+-------+------+
|alert_id|    alert_type|tran_id|is_sar|
+--------+--------------+-------+------+
|      52|gather_scatter| 447977|  true|
|      23|scatter_gather| 449282|  true|
|      52|gather_scatter| 454797|  true|
|      68|gather_scatter| 462363|  true|
|      68|gather_scatter| 468776|  true|
|      26|scatter_gather| 518050|  true|
|      10|scatter_gather| 518475|  true|
|      26|scatter_gather| 519362|  true|
|      26|scatter_gather| 521249|  true|
|      65|gather_scatter| 521357|  true|
|      25|scatter_gather| 557238|  true|
|      25|scatter_gather| 558782|  true|
|      61|gather_scatter| 559459|  true|
|      61|gather_scatter| 559460|  true|
|      69|gather_scatter| 559567|  true|
|      69|gather_scatter| 553958|  true|
|      25|scatter_gather| 554411|  true|
|      61|gather_scatter| 555697|  true|
|      61|gather_scatter| 556608|  true|
|      61|gather_scatter| 556609|  true|
+--------+--------------+-------+------+
only showing top 20 rows

## Load party feature group from hsfs

In [11]:
# Fetch version 1 of the party feature group, read it into a DataFrame and
# preview the first rows.
party_fg = fs.get_feature_group(name="party_fg", version=1)
party_df = party_fg.read()
party_df.show()

+--------+----+
|      id|type|
+--------+----+
|5628bd6c|   0|
|a1fcba39|   0|
|f56c9501|   1|
|9969afdd|   0|
|b356eeae|   1|
|3406706a|   0|
|26c56102|   0|
|e386ebf7|   1|
|8c094b0d|   1|
|939235aa|   1|
|de6bf2a5|   0|
|33a8ff5b|   0|
|a32807a1|   1|
|2906ef08|   0|
|c2a01b8d|   1|
|5a99160f|   1|
|8b9017b8|   0|
|fcf3bbf3|   1|
|5132aa4d|   0|
|68b90958|   1|
+--------+----+
only showing top 20 rows

## Create graph edge training dataset

In [14]:
# Number of alert rows; used as a baseline for checking the join result below.
alert_transactions_df.count()

915

In [15]:
# Number of transaction rows; the left join below should keep exactly this many.
transactions_df.count()

438386

In [31]:
# Build the edge table: one row per transaction (src -> dst).
# The left join keeps every transaction; those without a matching alert get a
# null is_sar, which falls through to the otherwise(0) branch.
edges = (
    transactions_df
    .join(alert_transactions_df, on=["tran_id"], how="left")
    .withColumn("is_sar", F.when(F.col("is_sar") == "true", 1).otherwise(0))
    .select("dst", "src", "tx_type", "base_amt", "tran_timestamp", "is_sar")
)

In [32]:
# Preview the edge table to confirm the selected columns and the 0/1 is_sar flag.
edges.show()

+--------+--------+-------+--------+--------------+------+
|     dst|     src|tx_type|base_amt|tran_timestamp|is_sar|
+--------+--------+-------+--------+--------------+------+
|1e46e726|3aa9646b|      4|  858.77|        Jan-01|     0|
|a74d1101|49203bc3|      4|  386.86|        Jan-01|     0|
|99af2455|616d4505|      4|  616.43|        Jan-02|     0|
|e7ec7bdb|39be1ea2|      4|  146.44|        Jan-02|     0|
|afc399a9|e2e0d938|      4|  439.09|        Jan-03|     0|
|d7a317f6|75c9a805|      4|   361.0|        Jan-04|     0|
|733a496b|c14f4989|      4|  768.98|        Jan-06|     0|
|aa49b0eb|576eb672|      4|   943.4|        Jan-07|     0|
|b070a6bb|847a9cf6|      4|   668.3|        Jan-08|     0|
|586377aa|12a388ff|      4|  139.84|        Jan-08|     0|
|1b467848|b36f9c84|      4|  499.47|        Jan-08|     0|
|385afb8b|362e42e0|      4|  357.96|        Jan-09|     0|
|acd60eca|572014da|      4|   630.9|        Jan-10|     0|
|31976e38|5ff2d9a7|      4|  685.07|        Jan-11|     

In [33]:
# Sanity check: should equal transactions_df.count(), i.e. the left join did
# not duplicate or drop any transaction.
edges.count()

438386

In [34]:
# Sanity check: the number of flagged edges should equal the number of alert rows.
edges.filter(F.col("is_sar") == 1).count()

915

In [35]:
# Define a CSV training dataset for the edges, with is_sar as the label column,
# then save the edges DataFrame into it.
edges_td_meta = fs.create_training_dataset(
    name="edges_td",
    version=1,
    description="edges training dataset",
    data_format="csv",
    label=["is_sar"],
)
edges_td_meta.save(edges)

<hsfs.training_dataset.TrainingDataset object at 0x7fa9fa558150>

## Create graph node training dataset

In [36]:
# Collect every account id that appears on either end of an edge, under a
# single "id" column, and de-duplicate to obtain the node set.
sources = edges.select(F.col("src").alias("id"))
targets = edges.select(F.col("dst").alias("id"))
nodes = sources.union(targets).dropDuplicates(["id"])
nodes.show()

+--------+
|      id|
+--------+
|26c56102|
|9969afdd|
|2906ef08|
|5a99160f|
|de6bf2a5|
|b356eeae|
|f56c9501|
|a32807a1|
|939235aa|
|e386ebf7|
|33a8ff5b|
|a9edaba6|
|8c094b0d|
|fcf3bbf3|
|5628bd6c|
|a1fcba39|
|c2a01b8d|
|7138cbc6|
|3406706a|
|68b90958|
+--------+
only showing top 20 rows

In [37]:
# Number of distinct node ids found in the edge table.
nodes.count()

7347

In [38]:
# Attach party attributes to each node. This is an inner join, so a node with
# no party record would be dropped; the count shows whether any were lost.
nodes_td = nodes.join(party_df, on="id", how="inner")
nodes_td.count()

7347

In [39]:
# Preview the node table (id plus the party "type" attribute).
nodes_td.show()

+--------+----+
|      id|type|
+--------+----+
|26c56102|   0|
|9969afdd|   0|
|2906ef08|   0|
|5a99160f|   1|
|de6bf2a5|   0|
|b356eeae|   1|
|f56c9501|   1|
|a32807a1|   1|
|939235aa|   1|
|e386ebf7|   1|
|33a8ff5b|   0|
|a9edaba6|   0|
|8c094b0d|   1|
|fcf3bbf3|   1|
|5628bd6c|   0|
|a1fcba39|   0|
|c2a01b8d|   1|
|7138cbc6|   1|
|3406706a|   0|
|68b90958|   1|
+--------+----+
only showing top 20 rows

In [41]:
# Define a CSV training dataset for the nodes (no label column is set) and save
# the node DataFrame into it.
node_td_meta = fs.create_training_dataset(
    name="node_td",
    version=1,
    description="node training dataset",
    data_format="csv",
)
node_td_meta.save(nodes_td)

<hsfs.training_dataset.TrainingDataset object at 0x7fa9fa53d290>