### Create training dataset for anomaly detection model
In this notebook we create a training dataset from the node embeddings feature group and register it in the Hopsworks Feature Store.
![Training Dataset](./images/create_training_dataset.png)

### Create a connection to hsfs

In [1]:
import hsfs
from hops import hdfs
# Create a connection
connection = hsfs.connection()
# Get the feature store handle for the project's feature store
fs = connection.get_feature_store()

Connected. Call `.close()` to terminate connection gracefully.


### Retrieve feature groups from hsfs

In [4]:
# alert_nodes_fg is used by the training dataset queries further down
alert_nodes_fg = fs.get_feature_group("alert_nodes_fg", 1)
transactions_fg = fs.get_feature_group("transactions_fg", 1)
alert_transactions_fg = fs.get_feature_group("alert_transactions_fg", 1)
node_embeddings_fg = fs.get_feature_group("node_embeddings_fg", 1) 

In [8]:
# Label transaction edges: is_sar marks transactions that were part of a previously known money laundering scheme
edges_with_labels = transactions_fg.select(["source","target","tran_id","base_amt","tran_timestamp"]).join(alert_transactions_fg.select(["is_sar"]),["tran_id"],"left")
edges_with_labels_pdf = edges_with_labels.read()

2022-05-30 22:54:28,062 INFO: USE `aml_demo_featurestore`
2022-05-30 22:54:28,939 INFO: SELECT `fg1`.`source` `source`, `fg1`.`target` `target`, `fg1`.`tran_id` `tran_id`, `fg1`.`base_amt` `base_amt`, `fg1`.`tran_timestamp` `tran_timestamp`, `fg0`.`is_sar` `is_sar`
FROM `aml_demo_featurestore`.`transactions_fg_1` `fg1`
INNER JOIN `aml_demo_featurestore`.`alert_transactions_fg_1` `fg0` ON `fg1`.`tran_id` = `fg0`.`tran_id`


In [9]:
edges_with_labels_pdf.head()

| | source | target | tran_id | base_amt | tran_timestamp | is_sar |
|---|---|---|---|---|---|---|
| 0 | fa9657cd | b7c3ee5f | 309059 | 2927.58 | 1596499200000 | 1 |
| 1 | f9660357 | d5c32e28 | 310070 | 2443.28 | 1596499200000 | 1 |
| 2 | 5cc91626 | 56e9fc25 | 569446 | 2715.63 | 1612224000000 | 1 |
| 3 | fd166006 | 7afc4353 | 929777 | 2916.66 | 1633996800000 | 1 |
| 4 | 56eaa6a7 | 7afc4353 | 929776 | 2916.66 | 1633996800000 | 1 |
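
The logged SQL above shows an `INNER JOIN`, and every row in the preview has `is_sar` equal to 1, which is what an inner join against the alert transactions would produce. The following is a minimal pandas sketch of how an inner join and a left join differ on this kind of data. The two toy tables and their values are made up for illustration and are not taken from the feature store.

```python
import pandas as pd

# Toy stand-ins for the transactions and alert transactions tables (made-up values).
transactions = pd.DataFrame({
    "tran_id": [1, 2, 3, 4],
    "base_amt": [100.0, 250.0, 75.5, 310.0],
})
alerts = pd.DataFrame({"tran_id": [2, 4], "is_sar": [1, 1]})

# Inner join: only transactions that also appear in the alert table survive.
inner = transactions.merge(alerts, on="tran_id", how="inner")

# Left join: every transaction survives; rows without a match get NaN for is_sar,
# which is filled with 0 here to mark them as not reported.
left = transactions.merge(alerts, on="tran_id", how="left")
left["is_sar"] = left["is_sar"].fillna(0).astype(int)

print(len(inner), len(left))
print(left["is_sar"].tolist())
```

With an inner join the unreported transactions disappear from the result, so the later filter on `is_sar == 1` keeps every row.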


In [13]:
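import pandas as pd

# Keep only edges flagged as suspicious activity reports (SAR)
alert_edges = edges_with_labels_pdf[edges_with_labels_pdf.is_sar == 1]
# Collect both endpoints of every alert edge under a single "id" column
alert_sources = alert_edges[["source"]].rename(columns={"source": "id"})
alert_targets = alert_edges[["target"]].rename(columns={"target": "id"})
# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent
alert_nodes = pd.concat([alert_sources, alert_targets], ignore_index=True)
alert_nodes = alert_nodes.drop_duplicates()
alert_nodes.head()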

| | id |
|---|---|
| 0 | fa9657cd |
| 1 | f9660357 |
| 2 | 5cc91626 |
| 3 | fd166006 |
| 4 | 56eaa6a7 |
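
The queries later in this notebook read an `is_sar` column per node from `alert_nodes_fg`. As a rough sketch of how a node-level label table of that shape could be derived from labeled edges, the example below marks a node as 1 if it is an endpoint of at least one alert edge and 0 otherwise. The toy edge list and the name `node_labels` are hypothetical, and this is not necessarily how `alert_nodes_fg` was actually built.

```python
import pandas as pd

# Toy labeled edge list (made-up ids); only the first edge is an alert edge.
edges = pd.DataFrame({
    "source": ["a", "b", "c", "a"],
    "target": ["b", "c", "d", "d"],
    "is_sar": [1, 0, 0, 0],
})

# Ids of all nodes that touch at least one alert edge.
alert_edges = edges[edges["is_sar"] == 1]
alert_ids = pd.concat(
    [alert_edges["source"], alert_edges["target"]], ignore_index=True
).drop_duplicates()

# Ids of all nodes in the graph, in order of first appearance.
all_ids = pd.concat(
    [edges["source"], edges["target"]], ignore_index=True
).drop_duplicates()

# One row per node, labeled 1 if the node appears in any alert edge.
node_labels = pd.DataFrame({"id": all_ids.to_numpy()})
node_labels["is_sar"] = node_labels["id"].isin(alert_ids).astype(int)

print(node_labels)
```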


### Prepare training datasets for anomaly detection 
###### In the next notebook we are going to train a [GAN for anomaly detection](https://arxiv.org/pdf/1905.11034.pdf). During the training step we will provide only the features of accounts that have never been reported for money laundering behaviour. Previously reported accounts will be disclosed to the model only in the evaluation step.

In [3]:
non_sar_emb_query = node_embeddings_fg.select(["embedding"])\
                                      .join(alert_nodes_fg.select(["is_sar"])\
                                      .filter(alert_nodes_fg.is_sar == 0))

In [4]:
non_sar_emb_query.show(5)

+--------------------+------+
|           embedding|is_sar|
+--------------------+------+
|[-0.2894337177276...|     0|
|[-0.8168580532073...|     0|
|[0.89537668228149...|     0|
|[0.55149841308593...|     0|
|[0.28041338920593...|     0|
+--------------------+------+
only showing top 5 rows

In [5]:
non_sar_emb_query.read().count()

6531

In [6]:
non_sar_td = fs.create_training_dataset(name="gan_non_sar_training_df",
                                       version=1,
                                       data_format="tfrecord",
                                       label=["is_sar"], 
                                       statistics_config={"enabled": False, "histograms": False, "correlations": False, "exact_uniqueness": False}, 
                                       splits={'train': 0.8, 'test': 0.2},
                                       coalesce=True,
                                       description="non sar dataset for gan training")
non_sar_td.save(non_sar_emb_query)
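
The `splits` weights are proportions rather than exact row counts. The test split read back later in this notebook has 1267 rows, while 20% of the 6531 rows counted above would be about 1306. One plausible explanation, assuming the split assigns each row at random (an assumption about the implementation, not something this notebook confirms), is ordinary sampling variation. The sketch below simulates such a per-row random assignment with NumPy; the seed is arbitrary.

```python
import numpy as np

n_rows = 6531  # row count reported for the non-SAR embedding query
rng = np.random.default_rng(seed=42)

# Assign each row independently to the test split with probability 0.2.
is_test = rng.random(n_rows) < 0.2
n_test = int(is_test.sum())
n_train = n_rows - n_test

# The sizes land near 80/20 but are generally not exact.
print(n_train, n_test)
```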



## For testing and evaluation we include known SAR nodes to measure the anomaly score

In [7]:
non_sar_td = fs.get_training_dataset("gan_non_sar_training_df", 1)
non_sar_test_df = non_sar_td.read(split="test")

In [8]:
sar_emb_query = node_embeddings_fg.select(["embedding"])\
                                  .join(alert_nodes_fg.select(["is_sar"])\
                                  .filter(alert_nodes_fg.is_sar == 1))

In [9]:
sar_df = sar_emb_query.read()
sar_df = sar_df.select(*non_sar_test_df.columns)
eval_df = non_sar_test_df.union(sar_df)
eval_df.cache()
eval_df.show()

+--------------------+------+
|           embedding|is_sar|
+--------------------+------+
|[-0.9998507499694...|     0|
|[-0.9986557960510...|     0|
|[-0.9984421730041...|     0|
|[-0.9970214366912...|     0|
|[-0.9947502613067...|     0|
|[-0.9934816360473...|     0|
|[-0.9908211231231...|     0|
|[-0.9882340431213...|     0|
|[-0.9830579757690...|     0|
|[-0.9823658466339...|     0|
|[-0.9815602302551...|     0|
|[-0.9814951419830...|     0|
|[-0.9812114238739...|     0|
|[-0.9809970855712...|     0|
|[-0.9808425903320...|     0|
|[-0.9780859947204...|     0|
|[-0.9746136665344...|     0|
|[-0.9715352058410...|     0|
|[-0.9711296558380...|     0|
|[-0.9682199954986...|     0|
+--------------------+------+
only showing top 20 rows

In [10]:
non_sar_test_df.count()

1267

In [11]:
sar_df.count()

816

In [12]:
eval_df.count()

2083
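
The counts add up as expected: 1267 non-SAR test rows plus 816 SAR rows give 2083 evaluation rows, so roughly 39% of the evaluation set is SAR. The cell that builds `eval_df` reorders the SAR columns with `select(*non_sar_test_df.columns)` before the union because Spark's `union` matches columns by position rather than by name. Below is a small pandas sketch of the same assembly step on made-up data. Note that `pd.concat` aligns columns by name, so the explicit reordering is shown only to mirror the Spark code.

```python
import pandas as pd

# Toy stand-ins for the non-SAR test split and the SAR rows (made-up embeddings).
non_sar_test = pd.DataFrame({
    "embedding": [[0.1, 0.2], [0.3, 0.4]],
    "is_sar": [0, 0],
})
# Same columns as above, but in a different order.
sar = pd.DataFrame({"is_sar": [1], "embedding": [[0.9, 0.8]]})

# Put the SAR columns in the same order as the test split before stacking.
sar = sar[non_sar_test.columns]
eval_pdf = pd.concat([non_sar_test, sar], ignore_index=True)

# Fraction of evaluation rows that are known SAR nodes.
sar_share = eval_pdf["is_sar"].mean()
print(len(eval_pdf), sar_share)
```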

In [None]:
gan_eval_ds = fs.create_training_dataset(name="gan_eval_df",
                                       version=1,
                                       data_format="tfrecord",
                                       label=["is_sar"], 
                                       statistics_config={"enabled": False, "histograms": False, "correlations": False, "exact_uniqueness": False}, 
                                       coalesce = True,
                                       description="evaluation dataset for gan training")
gan_eval_ds.save(eval_df)

## Training dataset provenance
![Training dataset provenance](./images/provenance_td.png)