In [1]:
import os
import random
import numpy as np
import torch
from torch_geometric.data import HeteroData
from neo4j import GraphDatabase

# set seed
seed = 2023
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
DATA_PATH = "./data"

URI = "neo4j://localhost"
AUTH = ("neo4j", "password")

In [2]:
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    driver.verify_connectivity()

data = HeteroData()

In [18]:
transaction_index_mappings = {}
account_index_mappings = {}
user_index_mappings = {}
country_index_mappings = {}
lob_index_mappings = {}
sector_index_mappings = {}

In [19]:
def construct_nodes(transactions, users, accounts):#, countries, lobs, sectors):
    data['transaction'].x = []
    for index, t in enumerate(transactions):
        data['transaction'].x.append(list(t[0].values()))
        transaction_index_mappings[t[0]['id']] = index

    data['user'].x = []
    for index, t in enumerate(users):
        data['user'].x.append(list(t[0].values()))
        user_index_mappings[t[0]['id']] = index

    data['account'].x = []
    for index, t in enumerate(accounts):
        data['account'].x.append(list(t[0].values()))
        account_index_mappings[t[0]['id']] = index
    
    # data['user'].x = [list(t[0].values()) for t in users]
    # data['account'].x = [list(t[0].values()) for t in accounts]
    # data['country'].x = [list(t[0].values()) for t in countries]
    # data['lob'].x = [list(t[0].values()) for t in lobs]
    # data['sector'].x = [list(t[0].values()) for t in sectors]

def fetch_nodes(tx):
    transactions = list(tx.run("MATCH (n:Transaction) RETURN properties(n)"))
    users = list(tx.run("MATCH (n:User) RETURN properties(n)"))
    accounts = list(tx.run("MATCH (n:Account) RETURN properties(n)"))
    # countries = list(tx.run("MATCH (n:Country) RETURN properties(n)"))
    # lobs = list(tx.run("MATCH (n:Lob) RETURN properties(n)"))
    # sectors = list(tx.run("MATCH (n:Sector) RETURN properties(n)"))

    construct_nodes(transactions, users, accounts)#, countries, lobs, sectors)

In [20]:
with driver.session() as session:
    session.execute_read(fetch_nodes)

In [21]:
def construct_edges(belongs_to, received_by, transferred_by):
    data['account', 'belongs_to', 'user'].edge_index = [[account_index_mappings[r[0]['account_id']], user_index_mappings[r[0]['user_id']]] for r in belongs_to]
    # data['account', 'from', 'country'].edge_index = []#... # [2, num_edges_writes]
    # data['account', 'lob_in', 'lob'].edge_index = []#... # [2, num_edges_affiliated]
    data['transaction', 'received_by', 'account'].edge_index = [[transaction_index_mappings[r[0]['txn_id']], account_index_mappings[r[0]['account_id']]] for r in received_by]
    data['transaction', 'transferred_by', 'account'].edge_index = [[transaction_index_mappings[r[0]['txn_id']], account_index_mappings[r[0]['account_id']]] for r in transferred_by]
    # data['account', 'works_in', 'sector'].edge_index = []
    
def fetch_edges(tx):
    belongs_to = list(tx.run(f"MATCH ()-[r:BELONGS_TO]->() RETURN properties(r)"))
    # from_country = list(tx.run("MATCH ()-[r:FROM]->() RETURN r"))
    # lob_in = list(tx.run("MATCH ()-[r:LOB_IN]->() RETURN r"))
    received_by = list(tx.run(f"MATCH ()-[r:RECEIVED_BY]->() RETURN properties(r)"))
    transferred_by = list(tx.run(f"MATCH ()-[r:TRANSFERRED_BY]->() RETURN properties(r)"))
    # works_in = list(tx.run("MATCH ()-[r:WORKS_IN]->() RETURN r"))
    construct_edges(belongs_to, received_by, transferred_by)

In [22]:
with driver.session() as session:
    session.execute_read(fetch_edges)

print(data)

HeteroData(
  [1mtransaction[0m={ x=[1498177] },
  [1muser[0m={ x=[288867] },
  [1maccount[0m={ x=[305429] },
  [1m(account, belongs_to, user)[0m={ edge_index=[305429] },
  [1m(transaction, received_by, account)[0m={ edge_index=[1282284] },
  [1m(transaction, transferred_by, account)[0m={ edge_index=[1279291] }
)


In [None]:
graphs = data[0]
labels = data[1]
test_graphs = data[2]
test_labels = data[3]

for i in range(len(graphs)):
    graphs[i].graph["label"] = labels[i]
for i in range(len(test_graphs)):
    test_graphs[i].graph["label"] = test_labels[i]

if feat is None:
    featgen_const = featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float))
    for G in graphs:
        featgen_const.gen_node_features(G)
    for G in test_graphs:
        featgen_const.gen_node_features(G)

train_dataset, test_dataset, max_num_nodes = prepare_data(
    graphs, args, test_graphs=test_graphs
)
model = models.GcnEncoderGraph(
    args.input_dim,
    args.hidden_dim,
    args.output_dim,
    args.num_classes,
    args.num_gc_layers,
    bn=args.bn,
).cuda()
train(train_dataset, model, args, test_dataset=test_dataset)
evaluate(test_dataset, model, args, "Validation")

In [None]:
driver.close()