
   The Elliptic Data Set maps Bitcoin transactions to real entities belonging to licit categories (exchanges,
   wallet providers, miners, licit services, etc.) versus illicit ones (scams, malware, terrorist organizations, 
   ransomware, Ponzi schemes, etc.). 
   
   The task on the dataset is to classify the illicit and licit nodes in the graph.
   
   
   We will use Graph Neural Networks (GraphSAGE) to perform the node classification task. This demo shows the
   end-to-end pipeline that can be done directly on the Katana Graph platform. 
   



   Original data located at,
      https://www.kaggle.com/datasets/ellipticco/elliptic-data-set?resource=download


      203769   10_elliptic_txs_features.csv    (nodes)
      234356   11_elliptic_txs_edgelist.csv    (edges)
      203770   14_elliptic_txs_classes.csv 
      -------------------------------------
      641895   (total)


   This anonymized data set is a transaction graph collected from the Bitcoin blockchain. A node in the graph
   represents a transaction, an edge can be viewed as a flow of Bitcoins between one transaction and the other. 
    
   Each node has 166 features (property keys) and has been labeled as being created by a "licit", "illicit" or
   "unknown" entity.

   Two percent (4,545) of the nodes are labelled class1 (illicit). Twenty-one percent (42,019) are labelled
   class2 (licit). The remaining transactions are not labelled with regard to licit versus illicit.

   There are 166 property keys associated with each node. These features carry no descriptive names and
   are obfuscated to protect privacy.
   
   
   There is a time step associated with each node, representing a measure of the time when a transaction was
   broadcast to the Bitcoin network. The time steps, running from 1 to 49, are evenly spaced with an interval
   of about two weeks. Each time step contains a single connected component of transactions that appeared on
   the blockchain within less than three hours of each other; there are no edges connecting the different
   time steps.

   The first 94 features represent local information about the transaction – including the time step described
   above, number of inputs/outputs, transaction fee, output volume and aggregated figures such as average BTC 
   received (spent) by the inputs/outputs and average number of incoming (outgoing) transactions associated with
   the inputs/outputs. 
   
   The remaining 72 features are aggregated features, obtained using transaction information one-hop backward/
   forward from the center node - giving the maximum, minimum, standard deviation and correlation coefficients
   of the neighbour transactions for the same information data (number of inputs/outputs, transaction fee, etc.).



In [None]:
# ##################################################################

In [2]:

#  The source data files proper ..

#  Raw CSV locations: node features, edge list, and per-node class labels ..
tx_features, tx_edges, tx_classes = (
   "gs://katana-demo-datasets/fsi/solution_raw_data/elliptic/elliptic_txs_features.csv",
   "gs://katana-demo-datasets/fsi/solution_raw_data/elliptic/elliptic_txs_edgelist.csv",
   "gs://katana-demo-datasets/fsi/solution_raw_data/elliptic/elliptic_txs_classes.csv",
)

print("--")


--


In [3]:

#  Creating column names for nodes ..

#  Generated column names for the node features:
#
#     2   fixed names  (txId, timestamp)
#    93   local features       local_feat_2  .. local_feat_94
#    72   aggregated features  agg_feat_95   .. agg_feat_166
#  ----
#   167   entries in feat_col_names

local_feats_name  = [f"local_feat_{idx}" for idx in range(2, 95)]
agg_feats_name    = [f"agg_feat_{idx}" for idx in range(95, 167)]

#  Order matters: id and time step first, then local, then aggregated features ..
feat_col_names    = ["txId", "timestamp", *local_feats_name, *agg_feats_name]

print("--")


#  display(print(f"Number of rows, feat_col_names: {len(feat_col_names)}"))

#  Number of rows, feat_col_names: 167


#  print(feat_col_names)

#  ['txId', 'timestamp', 'local_feat_2', 'local_feat_3', 'local_feat_4', 'local_feat_5', 'local_feat_6', 'local_feat_7', 
#         ...
#     'local_feat_90', 'local_feat_91', 'local_feat_92', 'local_feat_93', 'local_feat_94', 'agg_feat_95',
#
#     'agg_feat_96', 'agg_feat_97', 'agg_feat_98', 'agg_feat_99', 'agg_feat_100', 'agg_feat_101', 'agg_feat_102',
#         ...
#      'agg_feat_161', 'agg_feat_162', 'agg_feat_163', 'agg_feat_164', 'agg_feat_165', 'agg_feat_166']



--


In [4]:

#  Creating column data types for nodes ..

#  Every generated feature column is imported as a float ..
local_cols   = {f"local_feat_{idx}": "float" for idx in range(2, 95)}
agg_cols     = {f"agg_feat_{idx}": "float" for idx in range(95, 167)}

#  Fixed columns first, then the local and aggregated feature columns ..
feat_types   = {
   "class": "string",
   "timestamp": "string",
   "target": "float",
   "node_type": "string",
   **local_cols,
   **agg_cols,
   }


print("--")

#  print(feat_types)

#  {'class': 'string', 'timestamp': 'string', 'target': 'float', 'node_type': 'string', 'local_feat_2': 'float',
#     'local_feat_3': 'float', 'local_feat_4': 'float', 'local_feat_5': 'float', 'local_feat_6': 'float',
#         ...
#     'local_feat_92': 'float', 'local_feat_93': 'float', 'local_feat_94': 'float', 'agg_feat_95': 'float',
#
#     'agg_feat_96': 'float', 'agg_feat_97': 'float', 'agg_feat_98': 'float', 'agg_feat_99': 'float',
#        ...
#     'agg_feat_163': 'float', 'agg_feat_164': 'float', 'agg_feat_165': 'float', 'agg_feat_166': 'float'}



--


In [None]:
# ##################################################################

In [5]:

#  Read the three source CSVs through Dask ..
#
#  Don't forget: Dask evaluates lazily, so each read is followed by .compute() to
#  materialize the data right away.
#     (This cell will take a while because it reads files from GS/S3.)
#
#  The features file is read with header=None and the generated feat_col_names;
#  the edge and class files are read with read_csv's default header handling.

import dask.dataframe as dd


features   = dd.read_csv(tx_features, header=None, names=feat_col_names).compute()
edges      = dd.read_csv(tx_edges).compute()
classes    = dd.read_csv(tx_classes).compute()

print("--")


--


In [6]:

#  Setting display width ..

import pandas as pd

#  Widen the console rendering so wide frames are not elided when printed ..
for _opt_name, _opt_value in (('display.max_columns', 500), ('display.width', 1000)):
   pd.set_option(_opt_name, _opt_value)

print("--")



--


In [None]:

#  Non-functional: Just looking at the data ..

#  display(print(f"Number of rows, features: {len(features)}"))
#  display(print(f"Sample data, features: {features.head(5)}"))
#      
#  display(print(f"Number of rows, edges: {len(edges)}"))
#  display(print(f"Sample data, edges: {edges.head(5)}"))
#      
#  display(print(f"Number of rows, classes: {len(classes)}"))
#  display(print(f"Sample data, classes: {classes.head(5)}"))
#      
#  print("--")    


#  Number of rows, features: 203769
#  Sample data, features:    txId        timestamp  local_feat_2  local_feat_3   ...   local_feat_93  local_feat_94     agg_feat_95  agg_feat_96  agg_feat_97   ...    agg_feat_164  agg_feat_165  agg_feat_166  
#                         0  230425980   1          -0.171469     -0.184668      ...    1.135523        1.135279        -0.169160    -0.201584    -0.116817   ...      -0.097524     -0.120613     -0.119792  
#                         1    5530458   1          -0.171484     -0.184668      ...   -1.084907       -1.084845        -0.170113    -0.202332    -0.116817   ...      -0.097524     -0.120613     -0.119792  
#                         2  232022460   1          -0.172107     -0.184668      ...   -1.084907       -1.084845        -0.170528    -0.202658    -0.116817   ...      -0.183671     -0.120613     -0.119792  
#                         3  232438397   1           0.163054      1.963790      ...    0.025308        0.025217        -0.171098     0.266450     0.159432   ...       0.677799     -0.120613     -0.119792  
#                         4  230460314   1           1.011523     -0.081127      ...   -0.487315       -0.563089        -0.162974     0.844932     1.723414   ...       1.293750      0.178136      0.179117   

#  Number of rows, edges: 234355
#  Sample data, edges:     txId1      txId2
#                       0  230425980    5530458
#                       1  232022460  232438397
#                       2  230460314  230459870
#                       3  230333930  230595899
#                       4  232013274  232029206

#  Number of rows, classes: 203769
#  Sample data, classes:    txId    class
#                        0  230425980  unknown
#                        1    5530458  unknown
#                        2  232022460  unknown
#                        3  232438397        2
#                        4  230460314  unknown



In [None]:

#  Derive two columns from the raw 'class' value ..
#
#     target     numeric label:  '1' -> 1.0,  '2' -> 0.0,  'unknown' -> -1.0
#     node_type  'Classified_Txn' for classes '1' and '2', 'Unclassified_Txn' for 'unknown'

_target_by_class    = {'unknown': -1.0, '1': 1.0, '2': 0.0}
_node_type_by_class = {'unknown': 'Unclassified_Txn', '1': 'Classified_Txn', '2': 'Classified_Txn'}

classes['target']    = classes['class'].map(_target_by_class)
classes['node_type'] = classes['class'].map(_node_type_by_class)

print("--")


In [None]:

#  Non-functional: Just looking at the data ..

#  display(print(f"Number of rows, classes: {len(classes)}"))
#  display(print(f"Sample data, classes: {classes.head(5)}"))


#  Number of rows, classes: 203769
#  Sample data, classes:     txId       class      target       node_type
#                         0  230425980  unknown    -1.0         Unclassified_Txn
#                         1    5530458  unknown    -1.0         Unclassified_Txn
#                         2  232022460  unknown    -1.0         Unclassified_Txn
#                         3  232438397        2     0.0           Classified_Txn
#                         4  230460314  unknown    -1.0         Unclassified_Txn



In [None]:

#  New DataFrame, nodes: features merged with classes ..
#
#     (The two source DataFrames share exactly one column name, txId, so the
#      merged frame has 167 + 4 - 1 = 170 columns.)


#  display(print(f"Number of columns, features: {len(features.columns)}"))
#  display(print(f"Number of columns, classes: {len(classes.columns)}"))


#  Join key and join type spelled out; these are the values merge() falls back
#  to when called with no arguments on these two frames.
nodes = features.merge(classes, on="txId", how="inner")


#  display(print(f"Number of columns, nodes: {len(nodes.columns)}"))

print("--")


#  Number of columns, features: 167
#  Number of columns, classes: 4
#  Number of columns, nodes: 170



In [None]:
# ##################################################################

In [None]:

#  Get a KatanaGraph Connection Handle ..
#
#  Also defines the constants used by the following cells: the partition count
#  requested when the graph is created, and the database / graph names.

from katana import remote
from katana.remote import import_data


NUM_PARTITIONS = 8
   #
DB_NAME        = "my_db"
GRAPH_NAME     = "my_graph"


#  Client() is called with no arguments here; connection settings are presumably
#  taken from the environment -- confirm against the deployment.
my_client = remote.Client()

print(my_client)


In [None]:

#  Create KatanaGraph Graph object ..
#
#  Creates a new graph called GRAPH_NAME inside database DB_NAME, requesting
#  NUM_PARTITIONS partitions, and keeps the returned handle in my_graph.

my_graph = my_client.get_database(name=DB_NAME).create_graph(name=GRAPH_NAME, num_partitions=NUM_PARTITIONS)

print(my_graph)


In [None]:
#  CONNECT TO GRAPH
#
#  Alternative to creating a graph: scan the graphs the client can see and rebind
#  my_graph to the one named GRAPH_NAME (looked up by id in database DB_NAME).
#
#  NOTE: the loop does not stop at the first hit, so with several graphs of the
#  same name the last one wins; if none matches, my_graph keeps whatever value it
#  had before this cell (or is undefined if it was never set).

for l_graph in my_client.graphs():
   if (l_graph.name == GRAPH_NAME):
      my_graph=my_client.get_database(name=DB_NAME).get_graph_by_id(id=l_graph.graph_id)
        
print(my_graph)

In [None]:

#  Load Graph ..
#
#  Nodes come from the merged `nodes` frame. Every row of `edges` is imported as a
#  "tx_flow" edge and, when REVERSE_EDGES is set, once more with source and
#  destination swapped as a "rev_tx_flow" edge.


REVERSE_EDGES = True


#  One (source column, destination column, edge type) triple per import pass ..
_edge_passes = [("txId1", "txId2", "tx_flow")]
if REVERSE_EDGES:
   _edge_passes.append(("txId2", "txId1", "rev_tx_flow"))


with import_data.DataFrameImporter(my_graph) as df_importer:

   df_importer.nodes_dataframe(
      nodes,
      id_column         = "txId",
      id_space          = "transaction",
      property_columns  = feat_types,
      label_column      = "node_type",
      )

   for _src_col, _dst_col, _edge_type in _edge_passes:
      df_importer.edges_dataframe(
         edges,
         source_id_space       = "transaction",
         destination_id_space  = "transaction",
         source_column         = _src_col,
         destination_column    = _dst_col,
         type                  = _edge_type,
         )


print("--")
    

In [None]:

#  Checking number of partitions in the graph ..

my_graph.num_partitions


In [None]:

#  Check our work from above ..
#
#  Plain print() is enough here: print() returns None, so wrapping it in
#  display() only rendered an extra "None" after each line.

print(f"Number of nodes: {my_graph.num_nodes()}")
print(f"Number of edges: {my_graph.num_edges()}")


#  Number of nodes: 203769
#  Number of edges: 468710


In [None]:

#  Check the schema and counts ..
#
#  Results are shown with print() only: print() returns None, so the former
#  display(print(...)) wrapper rendered an extra "None" after each result.


#  Node count per label ..
l_result1 = my_graph.query("""

   MATCH (a) WITH DISTINCT LABELS(a) AS temp, COUNT(a) AS tempCnt
   UNWIND temp AS label
   RETURN label, SUM(tempCnt) AS cnt
   ORDER BY label
   
   """)
print(l_result1)


#  Edge count per relationship type ..
l_result2 = my_graph.query("""

   MATCH (m)-[r]->(n) 
   WITH DISTINCT TYPE(r) AS temp, COUNT(r) AS tempCnt
   RETURN temp, tempCnt
   ORDER BY temp

   """)
print(l_result2)


#        cnt        label
#  0  203769  transaction
#  
#            temp  tempCnt
#  0  rev_tx_flow   234355
#  1      tx_flow   234355



In [None]:

#  Compared to above; pure schema ..


my_graph.query("CALL graph.schema() RETURN *")


#        neighbor                         nodeType         properties
#        -----------------------------------------------------------------------------------------
#    0                                    [transaction]    agg_feat_100,agg_feat_101,agg_feat_102,agg_fea...
#    1   ([transaction]::[rev_tx_flow])   [transaction] 	
#    2   ([transaction]::[tx_flow])       [transaction] 	


In [None]:

#  Ability to set labels and more on the visualization ..
#
#  Builds a GraphVisOptions object whose single node option applies to ANY node
#  and uses the property named "title" as the label.
#
#  NOTE(review): none of the view() calls visible in this notebook pass `options`,
#  and no "title" property is created by the import cells shown here -- confirm
#  the intended label property and where these options should be applied.

from katana_visualization_widget import GraphVisOptions, NodeVisOption, EdgeVisOption, ANY

options = GraphVisOptions(
   node_options = [
      NodeVisOption(ANY, label = "title")
   ])

print("--")
    

In [None]:

#  Visualize the graph;  using a small sample ..
#
#  Returns at most 1000 (node, node, edge) rows between `transaction` nodes, any
#  edge type, and opens the result in the viewer.

l_result = my_graph.query("""

   MATCH (n: transaction )  - [ r ] ->  (m: transaction)
   RETURN n, m, r
   LIMIT 1000
   
   """, contextualize=True)

l_result.view()


#  Result set ..

<img src="./01_Images/result4.png" alt="Result" style="width: 5000px;"/>

In [None]:
# ##################################################################


   To initialize features for the GNN, combine all node properties into a feature vector and save as a new feature on the graph. 

   Katana Graph supports saving binary feature vectors as individual properties on the graph. In this case, we save 3 different feature vectors: 

      1. local_feats   - raw features provided for each transaction
      2. agg_feats     - aggregated features from each node's neighborhood
    
      3. h_init        - both (local_feats + agg_feats) combined into one feature. 
        
         This will be the starting point for our GNN



In [None]:

#  Check the given property keys ..
#
#  Run before the feature-initialization step: the three vector properties have
#  not been written yet, so every value comes back as None (see recorded output).

l_result = my_graph.query("""

   MATCH (n: transaction )
   RETURN n.h_init, n.local_feats, n.agg_feats
   LIMIT 5
   
   """, contextualize=True)

l_result.view()


#  Count: 5 rows
#  n.agg_feats   n.h_init   n.local_feats
#  None	         None       None
#  None	         None       None
#  None	         None       None
#  None	         None       None
#  None	         None       None



In [None]:

#  This next step requires that katana_ai.py be placed on the worker nodes,
#
#     .  Get the name of just one worker node. 
#        Program 70* will do this.
#
#     .  Sample SCP,
#
#        First, on worker node, make a folder,  (ssh in and),
#           mkdir 05_Packages
#
#        Then,
#           gcloud compute scp  katana_ai.py  known-mongrel-compute-0:/home/farrell_katanagraph_com/05_Packages



In [None]:

#  Last of data mung; mutate the graph for GNN operation ..

def run_feature_init(g): 
   """Pack the scalar feature properties into vector properties and add a split mask.

   Reads the per-node properties named in the notebook-level lists
   ``local_feats_name`` and ``agg_feats_name``, stores them on the graph as the
   vector properties ``local_feats`` and ``agg_feats`` plus their concatenation
   ``h_init`` (local first, then aggregated), adds a train/test/validation split
   mask with proportions 0.8 / 0.15 / 0.05, and finally calls ``g.write()``.

   Args:
      g: graph object handed in by ``my_graph.run``.

   Requires katana_ai.py to be present on the worker node under
   $HOME/05_Packages (the path is hard-coded below).
   """

   import sys
      #
   import numpy as np


   #  Must copy katana_ai.py to Worker node, $HOME/05_Packages
   #     (This seems finicky; finding the file.)
   #
   sys.path.insert(1, "/home/farrell_katanagraph_com/05_Packages")

   #  Only the helpers this function actually calls are imported ..
   from katana_ai import get_node_property_list, train_test_split_mask, save_features_to_graph


   # extract features
   #
   local_feats = get_node_property_list(g, property_list=local_feats_name)
   agg_feats   = get_node_property_list(g, property_list=agg_feats_name)
      #
   #  Concatenate along the last axis: local features first, aggregated second ..
   feat_vec    = np.concatenate([local_feats, agg_feats], axis=-1)


   # save new features vector to graph
   #
   g = save_features_to_graph(g, feat_vec,    feature_name="h_init"     )
   g = save_features_to_graph(g, local_feats, feature_name="local_feats")
   g = save_features_to_graph(g, agg_feats,   feature_name="agg_feats"  )


   # create train/test split mask
   #
   g = train_test_split_mask(g, train_test_validation_split=[0.8, 0.15, 0.05])
   g.write()

    
my_graph.run(lambda g: run_feature_init(g))


print("--")


In [None]:

#  Check the given property keys ..
#
#  Same query as before the feature-initialization step; h_init, local_feats and
#  agg_feats should now come back as float vectors (see recorded output).

l_result = my_graph.query("""

   MATCH (n: transaction )
   RETURN n.h_init, n.local_feats, n.agg_feats
   LIMIT 5
   
   """, contextualize=True)

l_result.view()


#  Count: 5 rows
#  n.agg_feats                                                                n.h_init                                                                 n.local_feats
#  [-0.1687650829553604, -0.20127403736114502, -0.11681671440601349, ... ]    [-0.16940271854400635, -0.18466755747795105, -1.201368808746338, ... ]   [-0.16940271854400635, -0.18466755747795105, -1.201368808746338, ...
#        ...



In [None]:

#  Generate run ID for Tensorboard ..
#
#  A random UUID rendered as 32 lowercase hex characters; it is embedded in the
#  TensorBoard directory names so runs do not overwrite each other.

import uuid

run_id = "%032x" % uuid.uuid4().int

print("--")


In [None]:
# ##################################################################

In [None]:

#  Need to place a script on the worker nodes that installs TensorFlow and related packages ..
#
#  Full script at,
#     https://github.com/KatanaGraph/solutions/blob/main/fsi/setup_scripts/gcp_dgl_worker_install.sh
#
#  need lines 1-6, 23-33, which are now in 05_Packages/install_tensor.sh
#
#  From, ccd,

   #  Load the shared defaults  (presumably where MY_CLUSTER_DIR and
   #  CLOUDSDK_CORE_PROJECT come from -- confirm) ..
   . ./20_Defaults.sh
      #
   FILE_NAME="/workbook/04_Version04/N1_GNN_FinTech/05_Packages/install_tensor.sh"
   #  `70*` is a glob that is expected to expand to the helper program that lists
   #  the cluster machines; column 2 of a "compute-" row is used as the zone ..
   l_zone=`70* | grep "compute\-" | head -1 | awk '{print $2}'`
    
   #  Column 1 of every "compute-" row is used as the host name. For each host:
   #  copy the install script to /tmp, then source it over ssh ..
   for l_host in `70* | grep "compute\-" | awk '{print $1}'`
      do
      gcloud compute scp  ${MY_CLUSTER_DIR}/${FILE_NAME}  ${l_host}:/tmp 
      gcloud compute ssh --zone ${l_zone} --project ${CLOUDSDK_CORE_PROJECT} ${l_host} --  ". /tmp/install_tensor.sh"
      done

        

In [None]:
# ##################################################################

In [None]:

def analyze_features(g): 
   """Send the three feature vectors to TensorBoard for visual inspection.

   For each of ``h_init``, ``local_feats`` and ``agg_feats`` this calls
   ``katana_ai.visualize_embeddings`` with ``target`` as the label property,
   restricted to ``Classified_Txn`` nodes and a sample size of 2000. Output goes
   to ``/tmp/tensorboard/elliptic-embed-init-<run_id>``, where ``run_id`` is the
   notebook-level run identifier.

   Args:
      g: graph object handed in by ``my_graph.run``.
   """

   import sys


   #  Must copy katana_ai.py to Worker node, $HOME/05_Packages
   #
   sys.path.append("/home/farrell_katanagraph_com/05_Packages")

   from katana_ai import visualize_embeddings

   from torch.utils.tensorboard import SummaryWriter


   writer = SummaryWriter(f"/tmp/tensorboard/elliptic-embed-init-{run_id}")

   # analyze features in tensorboard
   #
   #  The writer is closed in a finally block so pending events are flushed to
   #  disk even when one of the exports fails ..
   try:
      for feature_name in ("h_init", "local_feats", "agg_feats"):
         visualize_embeddings(g, writer, feature_name=feature_name, target_name="target", filter_node_type="Classified_Txn", sample_size=2000)
   finally:
      writer.close()
    
    
my_graph.run(lambda g: analyze_features(g))


print(f"See results at https://demo-finance-tensorboard.katanagraph.com/")
print(f"Run ID: elliptic-embed-init-{run_id}")



In [None]:
# ##################################################################

In [None]:

#  Check the given property keys ..
#
#  First query: peek at the numeric label (target) stored on a handful of nodes.

l_result = my_graph.query("""

   MATCH (n: transaction )
   RETURN n.target
   LIMIT 5
   
   """, contextualize=True)

l_result.view()


#  Count: 5 rows
#  n.target
#  -1.0
#  -1.0
#  -1.0
#  -1.0
#  -1.0

#  Second query: average target over ALL transaction nodes. Unknown nodes carry
#  -1.0, so the mean is strongly negative:
#     (4545 * 1.0  +  42019 * 0.0  +  157205 * -1.0) / 203769  =  -0.7492
l_result = my_graph.query("""

   MATCH (n: transaction )
   RETURN AVG(n.target)
   
   """, contextualize=True)

l_result.view()


#  Count: 1 rows
#  AVG(n.target)
#  -0.7491816714024214



In [None]:

#  GNN Training ..


import argparse
import numpy


#  Hyper-parameters and paths handed to run_gnn ..
#
#  NOTE(review): `load_edge_types`, which run_gnn's test sampler looks up, is not
#  set here -- confirm the intended value.
args = argparse.Namespace(
   #  Graph properties holding the input vector and the numeric label ..
   feat_name             = "h_init",

   label_name            = "target",

   label_dtype           = numpy.float32,
   split_name            = "train_test_val_mask",
   distributed_execution = True,
   tensorboard_dir       = f"/tmp/tensorboard/elliptic-remote-{run_id}",
   model_dir             = "/tmp/models",

   #  Folder on the worker nodes that contains katana_ai.py. No trailing blank:
   #  the value is appended to sys.path verbatim, and "…/05_Packages " (with a
   #  space) does not name the real directory.
   #
   # katana_ai_dir       = "/home/gsteck_katanagraph_com/solutions/fsi/src",
   katana_ai_dir         = "/home/farrell_katanagraph_com/05_Packages",

   #  Train / predict only on nodes whose node_type is Classified_Txn ..
   pred_node_label       = "Classified_Txn",
   pred_node_label_prop  = "node_type",
   pos_weight            = 8,
   #  93 local + 72 aggregated features = 165 entries in h_init ..
   in_dim                = 165,
   hidden_dim            = 256,
   #  One fan-in value per layer (num_layers = 4) ..
   train_fan_in          = "100,100,100,100",
   test_fan_in           = "100,100,100,100",
   num_layers            = 4,
   out_dim               = 1,
   minibatch_size        = 1024,
   max_minibatches       = 20,
   lr                    = 0.001,
   dropout               = 0.2,
   num_epochs            = 100
)

print("--")



In [None]:


def run_gnn(graph, args):
    """Train a GraphSAGE node classifier on ``graph``.

    Builds a train and a test subgraph sampler plus their data loaders, creates
    a multi-layer SAGEConv model (wrapped in DistributedDataParallel when
    ``args.distributed_execution`` is true) and passes everything to
    ``katana_ai.train_model``.

    Args:
        graph: graph object handed in by the remote runner.
        args: namespace of settings. Fields read directly in this function:
            katana_ai_dir, distributed_execution, tensorboard_dir, split_name,
            pred_node_label, pred_node_label_prop, train_fan_in, test_fan_in
            (optional, falls back to train_fan_in), load_edge_types (optional,
            only forwarded to the test sampler when present), max_minibatches,
            feat_name, label_name, label_dtype, minibatch_size, in_dim,
            hidden_dim, out_dim, num_layers, dropout, lr, pos_weight.
            The whole namespace is also forwarded to ``train_model``.
    """
    import time

    import torch
    import katana
    from katana_enterprise.distributed.pytorch import init_workers
    from katana_enterprise.ai.data import PyGNodeSubgraphSampler, SampledSubgraphConfig 
    from katana_enterprise.ai.data import NodeDataLoader
    from torch_geometric.nn import SAGEConv
    from torch.nn.parallel import DistributedDataParallel as torch_DDP
    from torch.utils.tensorboard import SummaryWriter
    import sys, os

    # strip(): stray whitespace around the configured folder would make the
    # sys.path entry name a directory that does not exist.
    sys.path.append(args.katana_ai_dir.strip())
    from katana_ai import get_split, train_model
    
    os.environ['MODIN_ENGINE']='python'
    #katana.distributed.initialize()

    katana.set_active_threads(32)
    exec_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # model definition
    class DistSAGE(torch.nn.Module):
        """Stack of SAGEConv layers with batch norm, ReLU and dropout in between."""

        def __init__(self, in_dim, hidden_dim, out_dim, num_layers,
                     dropout):
            super().__init__()

            # in_dim -> hidden_dim, (num_layers - 2) x hidden -> hidden,
            # hidden_dim -> out_dim; one BatchNorm1d per non-final conv.
            self.convs = torch.nn.ModuleList()
            self.convs.append(SAGEConv(in_dim, hidden_dim))
            self.bns = torch.nn.ModuleList()
            self.bns.append(torch.nn.BatchNorm1d(hidden_dim))
            for _ in range(num_layers - 2):
                self.convs.append(SAGEConv(hidden_dim, hidden_dim))
                self.bns.append(torch.nn.BatchNorm1d(hidden_dim))
            self.convs.append(SAGEConv(hidden_dim, out_dim))
            self.activation = torch.nn.functional.relu
            self.dropout = torch.nn.Dropout(dropout)

        def reset_parameters(self):
            """Re-initialize every conv and batch-norm layer."""
            for conv in self.convs:
                conv.reset_parameters()
            for bn in self.bns:
                bn.reset_parameters()

        def forward(self, data):
            """Return ``(output of the last conv, last hidden activation)``."""
            # unpack data loader
            x, edges = data.x, data.adjs
            #x.to(exec_device)
            #edges.to(exec_device)

            for i, conv in enumerate(self.convs):
                # for multilayer, set x_target for each layer
                x_target = x[:data.dest_count[i]]
                x = conv((x, x_target), edges[i])
                if i != len(self.convs) - 1:
                    x = self.bns[i](x)
                    x = self.activation(x)
                    x = self.dropout(x) 
                    # kept so the hidden representation before the final conv
                    # can be returned alongside the output
                    embed = x
            return x, embed
    
    # initialize torch mpi process
    main_start = time.time()
    if args.distributed_execution:
        init_workers()

    # tensorboard writer
    writer = SummaryWriter(args.tensorboard_dir)

    # split train / test node idx
    # (second argument selects the split: 0 for the train nodes, 1 for the test
    #  nodes -- presumably an index into the split mask; confirm in katana_ai)
    train_nodes = get_split(graph, 0, split_name=args.split_name, node_label=args.pred_node_label, node_label_prop=args.pred_node_label_prop)
    test_nodes = get_split(graph, 1, split_name=args.split_name, node_label=args.pred_node_label, node_label_prop=args.pred_node_label_prop)

    def build_sampler(fan_in, **extra_config):
        """Create a sampler whose per-layer fan-in comes from a "a,b,c" string."""
        return PyGNodeSubgraphSampler(
            graph, 
            SampledSubgraphConfig(
                layer_fan=[int(fan) for fan in fan_in.split(',')], 
                max_minibatches=args.max_minibatches, 
                batch_props_to_pull=args.max_minibatches,
                feat_prop_name=args.feat_name,
                label_prop_name=args.label_name,
                label_dtype=args.label_dtype,
                multilayer_export=True,
                **extra_config
            )
        )

    # initialize the multiminibatch sampler
    train_sampler = build_sampler(args.train_fan_in)
    
    # test sampler used for evaluation; it takes its own per-hop fan-in from
    # args.test_fan_in (falling back to the training fan-in when that field is
    # absent) and forwards args.load_edge_types only when the caller set it.
    test_extra = {}
    if hasattr(args, "load_edge_types"):
        test_extra["pull_edge_types"] = args.load_edge_types
    test_sampler = build_sampler(getattr(args, "test_fan_in", args.train_fan_in), **test_extra)
    
    # shuffle seeds between epochs + balance seed nodes across hosts
    train_dataloader = NodeDataLoader(
        train_sampler, 
        local_batch_size=args.minibatch_size, 
        node_ids=train_nodes,  
        shuffle=True, 
        drop_last=True,
        balance_seeds=True)
    test_dataloader = NodeDataLoader(
        test_sampler, 
        local_batch_size=args.minibatch_size, 
        node_ids=test_nodes, 
        balance_seeds=True)

    # model initialization
    model = DistSAGE(
        in_dim=args.in_dim, 
        hidden_dim=args.hidden_dim, 
        out_dim=args.out_dim, 
        num_layers=args.num_layers,
        dropout=args.dropout
    ).to(exec_device)

    if args.distributed_execution:
        model = torch_DDP(model)
    
    # optimizer and loss fn
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # NOTE(review): pos_weight is created on the CPU while the model may live on
    # CUDA -- confirm that train_model moves the loss or its inputs as needed.
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([args.pos_weight]))
    
    # train model
    train_model(model, loss_function, optimizer, writer, train_dataloader, test_dataloader, args)
    
    # save model
    #ts = time.time()
    #torch.save(model.state_dict(), os.path.join(args.model_dir, 'graph_sage.'+str(ts)+'.pth'))
    
    