### Create training dataset for anomaly detection model
In this notebook We are going to create training dataset from node embeddings feature group and register to Hopsworks Feature Store. 
![Training Dataset](./images/create_training_dataset.png)

### Create a connection to hsfs

In [1]:
import hsfs
from hops import hdfs
# Create a connection
connection = hsfs.connection()
# Get the feature store handle for the project's feature store
fs = connection.get_feature_store()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log
26,application_1630916144621_0031,pyspark,idle,Link,Link


SparkSession available as 'spark'.
Connected. Call `.close()` to terminate connection gracefully.

### Retrieve alert nodes feature group from hsfs

In [2]:
alert_nodes_fg = fs.get_feature_group("alert_nodes_fg", 1)
node_embeddings_fg = fs.get_feature_group("node_embeddings_fg", 1) 

### Prepare training datasets for anomaly detection 
###### In the next notebook we are going to train [gan for anomaly detection](https://arxiv.org/pdf/1905.11034.pdf). Durring training step  we will provide only features of accounts that have never been reported for money laundering behaviour.  But we will disclose previously reported accounts to the model only in evaluation step.   

In [11]:
non_sar_emb_query = node_embeddings_fg.select(["embedding"])\
                                      .join(alert_nodes_fg.select(["is_sar"])\
                                      .filter(alert_nodes_fg.is_sar == 0))

In [12]:
non_sar_emb_query.show(5)

+--------------------+------+
|           embedding|is_sar|
+--------------------+------+
|[0.62268567085266...|     0|
|[0.53648710250854...|     0|
|[0.27968072891235...|     0|
|[0.09070301055908...|     0|
|[-0.1516845226287...|     0|
+--------------------+------+
only showing top 5 rows

In [13]:
non_sar_emb_query.read().count()

6531

In [14]:
non_sar_td = fs.create_training_dataset(name="gan_non_sar_training_df",
                                       version=1,
                                       data_format="tfrecord",
                                       label=["is_sar"], 
                                       statistics_config=False, 
                                       splits={'train': 0.8, 'test': 0.2},
                                       coalesce=True,
                                       description="non sar dataset for gan training")
non_sar_td.save(non_sar_emb_query)

## For testing and evaluation we will include known SAR nodes to measure anomaly score  

In [15]:
non_sar_td = fs.get_training_dataset("gan_non_sar_training_df", 2)
non_sar_test_df = non_sar_td.read(split="test")

In [17]:
sar_emb_query = node_embeddings_fg.select(["embedding"])\
                                  .join(alert_nodes_fg.select(["is_sar"])\
                                  .filter(alert_nodes_fg.is_sar == 1))

In [18]:
sar_df = sar_emb_query.read()
sar_df = sar_df.select(*non_sar_test_df.columns)
eval_df = non_sar_test_df.union(sar_df)
eval_df.cache()
eval_df.show()

+--------------------+------+
|           embedding|is_sar|
+--------------------+------+
|[-0.9995496273040...|     0|
|[-0.9989874362945...|     0|
|[-0.9973371028900...|     0|
|[-0.9962804317474...|     0|
|[-0.9948248863220...|     0|
|[-0.9936296939849...|     0|
|[-0.9926474094390...|     0|
|[-0.9919345378875...|     0|
|[-0.9901528358459...|     0|
|[-0.9888505935668...|     0|
|[-0.9861581325531...|     0|
|[-0.9829399585723...|     0|
|[-0.9826931953430...|     0|
|[-0.9795489311218...|     0|
|[-0.9787058830261...|     0|
|[-0.9780728816986...|     0|
|[-0.9771800041198...|     0|
|[-0.9768762588500...|     0|
|[-0.9761099815368...|     0|
|[-0.9757995605468...|     0|
+--------------------+------+
only showing top 20 rows

In [19]:
non_sar_test_df.count()

1322

In [20]:
sar_df.count()

816

In [21]:
eval_df.count()

2138

In [22]:
gan_eval_ds = fs.create_training_dataset(name="gan_eval_df",
                                       version=1,
                                       data_format="tfrecord",
                                       label=["is_sar"], 
                                       statistics_config=False, 
                                       coalesce = True,
                                       description="evaluation dataset for gan training")
gan_eval_ds.save(eval_df)