In [167]:
# %%
import json
import os

import dgl
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from dgl import save_graphs, load_graphs
from torch_geometric.utils import dense_to_sparse, to_dense_adj
from tqdm import tqdm

In [168]:
# Location of the DGraphFin archive. Set the DGRAPHFIN_NPZ environment variable
# to point at a local copy; the original hard-coded path is kept as the fallback
# so existing runs behave exactly as before.
data_path = os.environ.get(
    "DGRAPHFIN_NPZ",
    "/Users/jl102430/Documents/study/anomaly_detection/data/dynamic/DGraph/DGraphFin/dgraphfin.npz",
)

In [169]:
data = np.load(data_path)
data

# %%
# Node-level arrays: feature matrix and label vector.
X, y = data["x"], data["y"]

# Edge-level arrays: endpoint pairs, edge type and edge timestamp.
edge_index, edge_type, edge_timestamp = (
    data["edge_index"],
    data["edge_type"],
    data["edge_timestamp"],
)

# Train / validation / test split arrays as stored in the archive.
train_mask, valid_mask, test_mask = (
    data["train_mask"],
    data["valid_mask"],
    data["test_mask"],
)


# Quick sanity check of every array's shape.
print(
    f"""
X shape: {X.shape},
y shape: {y.shape}

edge_index shape: {edge_index.shape}
edge_type shape: {edge_type.shape}
edge_timestamp shape: {edge_timestamp.shape}

train_mask shape: {train_mask.shape}
valid_mask shape: {valid_mask.shape}
test_mask shape: {test_mask.shape}
"""
)


X shape: (3700550, 17),
y shape: (3700550,)

edge_index shape: (4300999, 2)
edge_type shape: (4300999,)
edge_timestamp shape: (4300999,)

train_mask shape: (857899,)
valid_mask shape: (183862,)
test_mask shape: (183840,)



In [170]:
# How many edges have a timestamp of at most 7.
edge_timestamp[edge_timestamp <= 7].shape

(32454,)

In [171]:
# Shapes obtained when train_mask is used to index each edge-level array.
tuple(arr[train_mask].shape for arr in (edge_index, edge_timestamp, edge_type))

((857899, 2), (857899,), (857899,))

In [172]:
# Edge table for the training split, ordered by time.
# NOTE(review): train_mask has 857,899 entries while the edge arrays have
# 4,300,999 rows, so it acts as an array of row positions here. Confirm these
# positions are meant to address edges and not nodes.
train_edge_index = (
    pd.DataFrame(edge_index[train_mask], columns=['src_id', 'dst_id'])
    .assign(
        timestamp=edge_timestamp[train_mask],
        edge_type=edge_type[train_mask],
    )
    .sort_values('timestamp')
    .reset_index(drop=True)
)

train_edge_index

Unnamed: 0,src_id,dst_id,timestamp,edge_type
0,1810566,1361425,1,10
1,1783155,1544039,1,11
2,1728394,2239849,1,10
3,1886055,683274,1,10
4,2203323,773310,1,11
...,...,...,...,...
857894,3683404,3490494,821,2
857895,3683543,2721874,821,8
857896,394482,936197,821,5
857897,3683498,3595997,821,5


In [173]:
# Edge table for the validation split, ordered by time.
# NOTE(review): valid_mask (183,862 entries) is used as row positions into the
# 4,300,999-row edge arrays. Confirm it is meant to address edges, not nodes.
valid_edge_index = (
    pd.DataFrame(edge_index[valid_mask], columns=['src_id', 'dst_id'])
    .assign(
        timestamp=edge_timestamp[valid_mask],
        edge_type=edge_type[valid_mask],
    )
    .sort_values('timestamp')
    .reset_index(drop=True)
)

valid_edge_index

Unnamed: 0,src_id,dst_id,timestamp,edge_type
0,1415795,543208,1,10
1,1736265,1592619,1,10
2,2195490,2276340,1,9
3,1884566,1879177,1,10
4,577700,24468,1,9
...,...,...,...,...
183857,185828,1669105,821,5
183858,3683854,3145323,821,5
183859,971252,3684448,821,4
183860,3684515,301434,821,5


In [174]:
# Edge table for the test split, ordered by time.
# NOTE(review): test_mask (183,840 entries) is used as row positions into the
# 4,300,999-row edge arrays. Confirm it is meant to address edges, not nodes.
test_edge_index = (
    pd.DataFrame(edge_index[test_mask], columns=['src_id', 'dst_id'])
    .assign(
        timestamp=edge_timestamp[test_mask],
        edge_type=edge_type[test_mask],
    )
    .sort_values('timestamp')
    .reset_index(drop=True)
)

test_edge_index

Unnamed: 0,src_id,dst_id,timestamp,edge_type
0,682425,1496933,1,9
1,1911080,2199706,1,9
2,5388,1223207,1,9
3,666234,1265083,1,11
4,204916,400718,1,10
...,...,...,...,...
183835,3683908,408379,821,5
183836,3677127,914834,821,5
183837,2975154,1029181,821,5
183838,606050,2370665,821,4


In [175]:
# How often source nodes appear: the first groupby counts edges per src_id,
# the second counts how many sources share each per-source edge count
# (the count ends up in the column still called dst_id).
(
    train_edge_index
    .groupby('src_id')
    .count()
    .reset_index()
    .groupby('dst_id')
    .count()
)

Unnamed: 0_level_0,src_id,timestamp,edge_type
dst_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,626433,626433,626433
2,97411,97411,97411
3,10916,10916,10916
4,900,900,900
5,58,58,58
6,1,1,1


In [176]:
# Same source-frequency table restricted to timestamps <= 6: within this early
# slice almost every source occurs once, i.e. repeat appearances of a node tend
# to happen at later timestamps (large gap).
(
    train_edge_index[train_edge_index['timestamp'] <= 6]
    .groupby('src_id')
    .count()
    .reset_index()
    .groupby('dst_id')
    .count()
)

Unnamed: 0_level_0,src_id,timestamp,edge_type
dst_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,5494,5494,5494
2,42,42,42


In [177]:
# Rows and columns of the raw feature matrix.
X.shape

(3700550, 17)

In [178]:
# One row per row of X: feature columns feat_0 .. feat_{d-1} followed by the
# label y. Column names follow X's actual width instead of a hard-coded 17.
node_feature = (
    pd.DataFrame(X, columns=[f'feat_{i}' for i in range(X.shape[1])])
    .assign(y=y)
    # Expose the positional row number as an explicit "index" column; it is
    # the key later cells use to join against src_id / dst_id.
    .reset_index()
)
node_feature

Unnamed: 0,index,feat_0,feat_1,feat_2,feat_3,feat_4,feat_5,feat_6,feat_7,feat_8,feat_9,feat_10,feat_11,feat_12,feat_13,feat_14,feat_15,feat_16,y
0,0,0.0,5.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2
1,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,3
2,2,0.0,5.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2
3,3,1.0,5.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,3
4,4,1.0,7.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3700545,3700545,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2
3700546,3700546,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2
3700547,3700547,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2
3700548,3700548,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2


In [179]:
# setting node type
# feat_0 is interpreted as a categorical code and translated to a type label.
node_type_map = {
    0: 'A',
    1: 'B',
    -1: 'C'
}
# Vectorised lookup instead of a per-row Python lambda over every node.
node_types = (
    node_feature['feat_0']
    .astype(int)
    .map(node_type_map)
    .reset_index(name='node_type')
)
# Series.map leaves unknown codes as NaN; fail loudly instead, as a plain
# dict lookup would.
if node_types['node_type'].isna().any():
    raise KeyError("feat_0 contains values that are missing from node_type_map")


# NOTE: this consumes feat_0, so re-running the cell without rebuilding
# node_feature first raises a KeyError.
node_feature = node_feature.drop('feat_0', axis=1).merge(
    node_types,
    on='index',
    how='left'
)

node_types

Unnamed: 0,index,node_type
0,0,A
1,1,C
2,2,A
3,3,B
4,4,B
...,...,...
3700545,3700545,C
3700546,3700546,C
3700547,3700547,C
3700548,3700548,C


In [180]:
# node_feature after feat_0 has been replaced by the node_type column.
node_feature

Unnamed: 0,index,feat_1,feat_2,feat_3,feat_4,feat_5,feat_6,feat_7,feat_8,feat_9,feat_10,feat_11,feat_12,feat_13,feat_14,feat_15,feat_16,y,node_type
0,0,5.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2,A
1,1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,3,C
2,2,5.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2,A
3,3,5.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,3,B
4,4,7.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2,B
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3700545,3700545,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2,C
3700546,3700546,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2,C
3700547,3700547,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2,C
3700548,3700548,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,2,C


In [181]:
# Label column as a NumPy array.
node_feature.y.values

array([2, 3, 2, ..., 2, 2, 2])

# Construct DGL Dataset

In [182]:
def construct_htg_dgraph(edge_index_df, node_feature_df, timestamp, time_window=2):
    """Build one heterogeneous DGL graph from the edges of a trailing time window.

    Edges whose ``timestamp`` lies in ``(timestamp - time_window, timestamp]``
    are kept. Every node touched by those edges is renumbered from 0 within its
    ``node_type`` (following ``node_feature_df`` row order), and each distinct
    ``(src_type, edge_type, timestamp, dst_type)`` combination becomes its own
    relation named ``"{edge_type}_t{timestamp}"``.

    NOTE(review): relation names embed the edge's own timestamp, so graphs built
    for different ``timestamp`` values expose different edge-type names.

    Args:
        edge_index_df: DataFrame with columns ``src_id``, ``dst_id``,
            ``timestamp`` and ``edge_type``.
        node_feature_df: DataFrame with columns ``index`` (node id matching
            ``src_id`` / ``dst_id``), ``node_type``, ``y`` and feature columns
            whose names contain ``"feat"``.
        timestamp: Last timestamp included in the window.
        time_window: Window length; the window start is exclusive.

    Returns:
        ``(G_feat, G_label)``: the DGL heterograph with a float32 ``'feat'``
        tensor on every node type, and a dict mapping node type to a float32
        tensor of ``y`` values in the same per-type node order.
    """
    window_start = timestamp - time_window
    in_window = (edge_index_df["timestamp"] > window_start) & (
        edge_index_df["timestamp"] <= timestamp
    )
    _df = edge_index_df[in_window]

    # Unique ids of every node that is an endpoint of an edge in the window.
    all_nodes = np.union1d(_df["src_id"].values, _df["dst_id"].values)

    _node_feature = node_feature_df[
        node_feature_df["index"].isin(all_nodes)
    ].reset_index(drop=True)
    # Per-type running counter: the node id DGL will use inside each node type.
    _node_feature["type_index"] = _node_feature.groupby(["node_type"]).cumcount()

    # Attach type and per-type id to both endpoints of every edge.
    node_lookup = _node_feature[["index", "node_type", "type_index"]]
    _df = _df.merge(
        node_lookup.rename(
            columns={
                "index": "src_id",
                "node_type": "src_type",
                "type_index": "type_src_id",
            }
        ),
        on="src_id",
        how="left",
    ).merge(
        node_lookup.rename(
            columns={
                "index": "dst_id",
                "node_type": "dst_type",
                "type_index": "type_dst_id",
            }
        ),
        on="dst_id",
        how="left",
    )

    # One relation per (src_type, edge_type, timestamp, dst_type) group.
    hetero_dict = {}
    relation_keys = ["src_type", "edge_type", "timestamp", "dst_type"]
    for (src_type, etype, e_ts, dst_type), rel_edges in _df.groupby(relation_keys):
        hetero_dict[(src_type, f"{etype}_t{e_ts}", dst_type)] = (
            rel_edges["type_src_id"].tolist(),
            rel_edges["type_dst_id"].tolist(),
        )
    G_feat = dgl.heterograph(hetero_dict)

    feat_cols = [c for c in _node_feature.columns if 'feat' in c]

    G_label = {}
    for ntype in G_feat.ntypes:
        # Rows stay in type_index order, so row i belongs to node i of this type.
        type_rows = _node_feature[_node_feature['node_type'] == ntype]
        G_feat.nodes[ntype].data['feat'] = torch.tensor(
            type_rows[feat_cols].values, dtype=torch.float32
        )
        G_label[ntype] = torch.tensor(type_rows["y"].values, dtype=torch.float32)

    return G_feat, G_label

def load_dgraph_data(edge_index_df, node_feature_df, time_window=2):
    """Build one snapshot graph per distinct timestamp in ``edge_index_df``.

    Timestamps are processed in ascending order and each one is passed to
    ``construct_htg_dgraph`` as the end of its window.

    Args:
        edge_index_df: Edge table with a ``timestamp`` column, forwarded to
            ``construct_htg_dgraph``.
        node_feature_df: Node table forwarded to ``construct_htg_dgraph``.
        time_window: Window length forwarded to ``construct_htg_dgraph``
            (default 2, which is what the callee used before this parameter
            existed).

    Returns:
        ``(feats, labels)``: two lists of equal length holding, for each
        timestamp, the graph and the label dict returned by
        ``construct_htg_dgraph``.
    """
    feats = []
    labels = []
    for ts in tqdm(sorted(edge_index_df['timestamp'].unique())):
        G_feat, G_label = construct_htg_dgraph(
            edge_index_df,
            node_feature_df,
            timestamp=ts,
            time_window=time_window,
        )
        feats.append(G_feat)
        labels.append(G_label)

    return feats, labels


In [183]:
# Build one graph (plus its label dict) for every distinct timestamp of the
# training edges, then show the snapshot count and the first snapshot.
train_feats, train_labels = load_dgraph_data(train_edge_index, node_feature)
len(train_feats), train_feats[0], train_labels[0]

  0%|          | 0/821 [00:00<?, ?it/s]

100%|██████████| 821/821 [03:08<00:00,  4.36it/s]


(821,
 Graph(num_nodes={'A': 521, 'B': 916, 'C': 307},
       num_edges={('A', '10_t1', 'A'): 42, ('A', '10_t1', 'B'): 46, ('A', '10_t1', 'C'): 32, ('A', '11_t1', 'A'): 10, ('A', '11_t1', 'B'): 15, ('A', '11_t1', 'C'): 9, ('A', '9_t1', 'A'): 15, ('A', '9_t1', 'B'): 58, ('A', '9_t1', 'C'): 14, ('B', '10_t1', 'A'): 64, ('B', '10_t1', 'B'): 142, ('B', '10_t1', 'C'): 78, ('B', '11_t1', 'A'): 23, ('B', '11_t1', 'B'): 72, ('B', '11_t1', 'C'): 17, ('B', '9_t1', 'A'): 89, ('B', '9_t1', 'B'): 18, ('B', '9_t1', 'C'): 26, ('C', '10_t1', 'A'): 23, ('C', '10_t1', 'B'): 25, ('C', '10_t1', 'C'): 12, ('C', '11_t1', 'A'): 1, ('C', '11_t1', 'B'): 6, ('C', '11_t1', 'C'): 6, ('C', '9_t1', 'A'): 21, ('C', '9_t1', 'B'): 11, ('C', '9_t1', 'C'): 5},
       metagraph=[('A', 'A', '10_t1'), ('A', 'A', '11_t1'), ('A', 'A', '9_t1'), ('A', 'B', '10_t1'), ('A', 'B', '11_t1'), ('A', 'B', '9_t1'), ('A', 'C', '10_t1'), ('A', 'C', '11_t1'), ('A', 'C', '9_t1'), ('B', 'A', '10_t1'), ('B', 'A', '11_t1'), ('B', 'A', '9_t1')

In [184]:
# Build one graph (plus its label dict) for every distinct timestamp of the
# validation edges, then show the snapshot count and the first snapshot.
valid_feats, valid_labels = load_dgraph_data(valid_edge_index, node_feature)
len(valid_feats), valid_feats[0], valid_labels[0]

100%|██████████| 821/821 [02:35<00:00,  5.26it/s]


(821,
 Graph(num_nodes={'A': 126, 'B': 178, 'C': 78},
       num_edges={('A', '10_t1', 'A'): 10, ('A', '10_t1', 'B'): 7, ('A', '10_t1', 'C'): 5, ('A', '11_t1', 'B'): 4, ('A', '11_t1', 'C'): 2, ('A', '9_t1', 'A'): 5, ('A', '9_t1', 'B'): 9, ('A', '9_t1', 'C'): 3, ('B', '10_t1', 'A'): 18, ('B', '10_t1', 'B'): 20, ('B', '10_t1', 'C'): 21, ('B', '11_t1', 'A'): 4, ('B', '11_t1', 'B'): 14, ('B', '11_t1', 'C'): 3, ('B', '9_t1', 'A'): 29, ('B', '9_t1', 'B'): 1, ('B', '9_t1', 'C'): 3, ('C', '10_t1', 'A'): 8, ('C', '10_t1', 'B'): 7, ('C', '10_t1', 'C'): 7, ('C', '11_t1', 'A'): 1, ('C', '11_t1', 'B'): 1, ('C', '9_t1', 'A'): 6, ('C', '9_t1', 'B'): 2, ('C', '9_t1', 'C'): 1},
       metagraph=[('A', 'A', '10_t1'), ('A', 'A', '9_t1'), ('A', 'B', '10_t1'), ('A', 'B', '11_t1'), ('A', 'B', '9_t1'), ('A', 'C', '10_t1'), ('A', 'C', '11_t1'), ('A', 'C', '9_t1'), ('B', 'A', '10_t1'), ('B', 'A', '11_t1'), ('B', 'A', '9_t1'), ('B', 'B', '10_t1'), ('B', 'B', '11_t1'), ('B', 'B', '9_t1'), ('B', 'C', '10_t1'), ('

In [185]:
# Build one graph (plus its label dict) for every distinct timestamp of the
# test edges, then show the snapshot count and the first snapshot.
test_feats, test_labels = load_dgraph_data(test_edge_index, node_feature)
len(test_feats), test_feats[0], test_labels[0]

100%|██████████| 821/821 [02:32<00:00,  5.38it/s]


(821,
 Graph(num_nodes={'A': 123, 'B': 187, 'C': 80},
       num_edges={('A', '10_t1', 'A'): 10, ('A', '10_t1', 'B'): 7, ('A', '10_t1', 'C'): 6, ('A', '11_t1', 'A'): 1, ('A', '11_t1', 'B'): 2, ('A', '11_t1', 'C'): 3, ('A', '9_t1', 'A'): 4, ('A', '9_t1', 'B'): 17, ('A', '9_t1', 'C'): 6, ('B', '10_t1', 'A'): 11, ('B', '10_t1', 'B'): 34, ('B', '10_t1', 'C'): 18, ('B', '11_t1', 'A'): 4, ('B', '11_t1', 'B'): 6, ('B', '11_t1', 'C'): 5, ('B', '9_t1', 'A'): 21, ('B', '9_t1', 'B'): 3, ('B', '9_t1', 'C'): 6, ('C', '10_t1', 'A'): 5, ('C', '10_t1', 'B'): 7, ('C', '11_t1', 'A'): 2, ('C', '11_t1', 'B'): 2, ('C', '11_t1', 'C'): 1, ('C', '9_t1', 'A'): 9, ('C', '9_t1', 'B'): 1, ('C', '9_t1', 'C'): 4},
       metagraph=[('A', 'A', '10_t1'), ('A', 'A', '11_t1'), ('A', 'A', '9_t1'), ('A', 'B', '10_t1'), ('A', 'B', '11_t1'), ('A', 'B', '9_t1'), ('A', 'C', '10_t1'), ('A', 'C', '11_t1'), ('A', 'C', '9_t1'), ('B', 'A', '10_t1'), ('B', 'A', '11_t1'), ('B', 'A', '9_t1'), ('B', 'B', '10_t1'), ('B', 'B', '11_t1')

In [186]:
# Label tensors of the first training snapshot, keyed by node type.
train_labels[0]

{'A': tensor([3., 0., 0., 3., 0., 3., 0., 0., 0., 0., 3., 2., 0., 0., 0., 0., 3., 3.,
         0., 0., 0., 0., 3., 0., 3., 3., 0., 0., 0., 3., 0., 0., 2., 3., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 3., 2., 3., 0., 2., 0., 0., 0., 3., 3., 0.,
         2., 3., 2., 0., 3., 0., 1., 3., 3., 0., 2., 3., 2., 2., 0., 0., 0., 0.,
         0., 0., 3., 3., 0., 0., 0., 2., 0., 0., 0., 3., 0., 3., 3., 2., 0., 3.,
         0., 0., 0., 3., 3., 3., 0., 3., 3., 0., 0., 0., 3., 2., 0., 0., 2., 3.,
         3., 0., 3., 3., 0., 3., 0., 3., 3., 2., 0., 2., 0., 0., 3., 3., 0., 3.,
         0., 0., 3., 3., 0., 0., 2., 0., 2., 3., 0., 3., 0., 0., 3., 3., 3., 2.,
         2., 0., 0., 0., 0., 3., 3., 0., 3., 0., 3., 0., 3., 0., 3., 0., 3., 2.,
         0., 2., 3., 3., 3., 0., 2., 0., 0., 0., 3., 3., 3., 0., 0., 0., 3., 0.,
         0., 2., 0., 0., 3., 0., 2., 0., 0., 3., 3., 0., 3., 3., 3., 3., 1., 0.,
         3., 0., 2., 2., 2., 3., 2., 3., 3., 3., 3., 0., 0., 0., 0., 3., 0., 0.,
         3., 0., 0., 3.

## Save DGL Datasets

In [187]:
# Create the output directory up front so none of the writes below depends on
# it already existing (torch.save does not create missing parent directories).
os.makedirs("./data/dgraph", exist_ok=True)

# Graph snapshots, one list per split.
dgl.save_graphs("./data/dgraph/train_feats.bin", train_feats)
dgl.save_graphs("./data/dgraph/valid_feats.bin", valid_feats)
dgl.save_graphs("./data/dgraph/test_feats.bin", test_feats)

# Matching per-snapshot label dicts (node type -> label tensor).
torch.save(train_labels, './data/dgraph/train_labels.pt')
torch.save(valid_labels, './data/dgraph/valid_labels.pt')
torch.save(test_labels, './data/dgraph/test_labels.pt')


: 

In [135]:
# dgl.load_graphs("./data/dgraph/valid_feats.bin")