In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
import random
import torch
from torch_geometric.utils import (
    dense_to_sparse,
    to_dense_adj,
    remove_isolated_nodes,
    contains_isolated_nodes,
    subgraph, k_hop_subgraph
)
from data_loader import HetGCNEventGraphDataset

In [None]:
data_root_dir = '../ProcessedData_HetGCN'
dataset = HetGCNEventGraphDataset(
    node_feature_csv=f'{data_root_dir}/node_feature_norm.csv',
    edge_index_csv=f'{data_root_dir}/edge_index.csv',
    node_type_txt=f'{data_root_dir}/node_types.txt',
    edge_ratio_csv=f'{data_root_dir}/edge_ratio.csv',
    ignore_weight=True,
    include_edge_type=True,
    edge_ratio_percentile=0.75
)

reading node features..
reading edge index..
Ignore Edge Weights.
read node types ..
node types txt: 132485
reading edge ratio ...
done


In [None]:
# take a sample graph
node_feature, edge_index, (edge_weight, edge_type), node_types = dataset[0]

In [8]:
node_feature.shape, edge_index.shape, (edge_weight, edge_type.shape), len(node_types)

(torch.Size([841, 7]), torch.Size([2, 962]), (None, torch.Size([962])), 8)

In [12]:
contains_isolated_nodes(edge_index)

False

In [67]:
sample_node_ratio = 0.01
sampled_nodes = random.sample(
    range(node_feature.shape[0]),
    int(node_feature.shape[0] * sample_node_ratio)
)
len(sampled_nodes)

8

In [76]:
# subgraph(
#     subset=sampled_nodes,
#     edge_index=edge_index
# )

In [69]:

row, col = edge_index
cond = col == 108
row[cond],col[cond]

(tensor([274]), tensor([108]))

In [138]:
device = edge_index.device
sub_nodes, sub_edge_index, _, sub_edge_mask  = k_hop_subgraph(
    node_idx=sampled_nodes,
    num_hops=1,
    edge_index=edge_index,
    flow='target_to_source'
)
sub_edge_index

tensor([[138, 138, 179, 179, 181, 181, 245, 245, 482, 482],
        [705, 705, 747, 750, 753, 753, 812, 816, 394, 425]])

In [139]:
sampled_nodes

[482, 825, 179, 559, 181, 731, 245, 138]

In [140]:
sub_edge_index.index_fill(0, torch.tensor([1], device=device), 888)

tensor([[138, 138, 179, 179, 181, 181, 245, 245, 482, 482],
        [888, 888, 888, 888, 888, 888, 888, 888, 888, 888]])

In [141]:
row, col == edge_index

torch.cat(
    [torch.stack([row[~sub_edge_mask], col[~sub_edge_mask]]), sub_edge_index.index_fill(0, torch.tensor([1], device=device), 888)],
    dim=1
)

tensor([[  0,   2,   2,  ..., 245, 482, 482],
        [558, 559, 559,  ..., 888, 888, 888]])

In [142]:
sub_edge_index

tensor([[138, 138, 179, 179, 181, 181, 245, 245, 482, 482],
        [705, 705, 747, 750, 753, 753, 812, 816, 394, 425]])

In [143]:
new_node_list = []
last_node_id = node_feature.shape[0] - 1
for ntype, ntype_list in enumerate(node_types):
    if len(ntype_list) == 0:
        print(f'skip node type {ntype}')
        continue 
    _mask = sum(sub_edge_index[1] == i for i in ntype_list).bool()

    # skip if no node matched
    if _mask.sum() == 0:
        print('skip mask')
        continue

    new_node_id = last_node_id + 1
    new_node_list.append((new_node_id, ntype))
    sub_edge_index[1] = sub_edge_index[1].masked_fill_(_mask, new_node_id)

    last_node_id = new_node_id

skip mask
skip node type 6
skip node type 7


In [144]:
new_node_list

[(841, 0), (842, 2), (843, 3), (844, 4), (845, 5)]

In [145]:
sub_edge_index

tensor([[138, 138, 179, 179, 181, 181, 245, 245, 482, 482],
        [844, 844, 842, 841, 845, 845, 842, 841, 843, 842]])

In [136]:
if _mask.sum() == 1:
    print(1)
else:
    print(0)

0


In [125]:
sub_edge_index[1].masked_fill_(_mask, 1)

tensor([705, 705, 747,   1, 753, 753, 812,   1, 394, 425])

In [156]:
node_feature

tensor([[ 2.3552e-04,  2.6773e-04, -1.8384e-05,  ...,  0.0000e+00,
          0.0000e+00,  2.6262e-04],
        [ 2.3552e-04,  2.6773e-04, -1.8384e-05,  ...,  1.0313e-04,
          0.0000e+00,  2.6278e-04],
        [ 8.0982e-05,  2.6773e-04, -1.8384e-05,  ...,  2.0626e-04,
          0.0000e+00,  7.9132e-05],
        ...,
        [ 3.8716e-06,  1.2357e-04,  2.3899e-04,  ...,  2.0626e-04,
          4.9475e-04,  3.7832e-06],
        [ 2.2149e-04,  1.2357e-04,  1.1030e-04,  ...,  2.5782e-04,
          4.9622e-04,  2.2116e-04],
        [ 2.2149e-04,  1.2357e-04,  1.1030e-04,  ...,  2.5782e-04,
          4.9768e-04,  2.2116e-04]])

In [6]:
from graph_augmentation import GraphAugmentator
GraphAugmentator().create_het_node_insertion.__name__
# a = create_het_node_insertion([dataset[0]])

'create_het_node_insertion'

In [202]:
a[1].shape, edge_index.shape, b[1].shape

(torch.Size([2, 962]), torch.Size([2, 962]), torch.Size([2, 962]))

In [203]:
a[0].shape, node_feature.shape, b[0].shape

(torch.Size([846, 7]), torch.Size([841, 7]), torch.Size([850, 7]))

In [204]:
a[1], edge_index

(tensor([[  0,   2,   2,  ..., 458, 520, 520],
         [558, 559, 559,  ..., 843, 844, 843]]),
 tensor([[  0,   2,   2,  ..., 557, 557, 558],
         [558, 559, 559,  ..., 700, 837,  73]]))

In [205]:
edge_index

tensor([[  0,   2,   2,  ..., 557, 557, 558],
        [558, 559, 559,  ..., 700, 837,  73]])

In [206]:
b[1]

tensor([[  0,   2,   2,  ..., 376, 445, 445],
        [558, 559, 559,  ..., 846, 849, 848]])

In [209]:
b[1][1][b[1][1] <= 840]

tensor([558, 559, 559, 556, 560, 562, 562, 563, 563, 561, 565, 567, 567, 568,
        568, 569, 569, 566, 570, 572, 572, 573, 573, 574, 574, 571, 575, 577,
        577, 578, 578, 579, 579, 576, 580, 582, 582, 583, 583, 584, 584, 581,
        585, 587, 587, 588, 588, 589, 589, 590, 590, 586, 591, 593, 593, 594,
        594, 595, 595, 592, 596, 598, 598, 599, 599, 600, 600, 601, 601, 597,
        602, 604, 604, 605, 605, 603, 606, 608, 608, 609, 609, 607, 610, 612,
        612, 613, 613, 614, 614, 611, 615, 617, 617, 618, 618, 619, 619, 616,
        620, 622, 622, 623, 623, 624, 624, 621, 625, 627, 627, 628, 628, 629,
        629, 630, 630, 626, 631, 633, 633, 634, 634, 635, 635, 636, 636, 632,
        637, 639, 639, 640, 640, 641, 641, 642, 642, 638, 643, 645, 645, 646,
        646, 644, 647, 649, 649, 650, 650, 648, 651, 653, 653, 654, 654, 652,
        655, 657, 657, 658, 658, 659, 659, 656, 660, 662, 662, 663, 663, 664,
        664, 665, 665, 666, 666, 667, 667, 668, 668, 669, 669, 6

In [208]:
a[1]

tensor([[  0,   2,   2,  ..., 458, 520, 520],
        [558, 559, 559,  ..., 843, 844, 843]])

In [None]:
from graph_augmentation import GraphAugmentator
GraphAugmentator().create_node_type_swap()
# a = create_het_node_insertion([dataset[0]])