In [22]:
import random
import torch
from torch_geometric.utils import (
    dense_to_sparse,
    to_dense_adj,
    remove_isolated_nodes,
    contains_isolated_nodes,
    subgraph, k_hop_subgraph
)
from data_loader import HetGCNEventGraphDataset

In [3]:
# Build the heterogeneous event-graph dataset from the preprocessed files.
data_root_dir = '../ProcessedData_HetGCN'
dataset_files = dict(
    node_feature_csv=f'{data_root_dir}/node_feature_norm.csv',
    edge_index_csv=f'{data_root_dir}/edge_index.csv',
    node_type_txt=f'{data_root_dir}/node_types.txt',
    edge_ratio_csv=f'{data_root_dir}/edge_ratio.csv',
)
dataset = HetGCNEventGraphDataset(
    **dataset_files,
    ignore_weight=True,
    include_edge_type=True,
    edge_ratio_percentile=0.75,
)

reading node features..
reading edge index..
Ignore Edge Weights.
read node types ..
node types txt: 132485
reading edge ratio ...
done


In [5]:
# Take the first graph of the dataset as a working sample. Each item unpacks
# as (node features, edge_index, (edge_weight, edge_type), node_types).
sample_graph = dataset[0]
node_feature, edge_index, (edge_weight, edge_type), node_types = sample_graph

In [8]:
# Quick size check of the sample graph's components.
(
    node_feature.size(),
    edge_index.size(),
    (edge_weight, edge_type.size()),
    len(node_types),
)

(torch.Size([841, 7]), torch.Size([2, 962]), (None, torch.Size([962])), 8)

In [12]:
# Does the sample graph contain nodes without edges? Pass the true node count:
# if omitted it is inferred as edge_index.max() + 1, so isolated nodes with the
# highest ids would go unnoticed.
contains_isolated_nodes(edge_index, num_nodes=node_feature.shape[0])

False

In [67]:
# Randomly pick a fraction of the nodes as centres for the k-hop subgraph.
# A fixed seed makes the draw (and every output derived from it) reproducible
# on Restart & Run All. The sample size is clamped to [1, total_nodes], so a
# tiny graph still yields a non-empty sample and a ratio > 1 cannot raise.
SEED = 0
sample_node_ratio = 0.01
total_nodes = node_feature.shape[0]
num_sampled = min(total_nodes, max(1, int(total_nodes * sample_node_ratio)))
sampled_nodes = random.Random(SEED).sample(range(total_nodes), num_sampled)
len(sampled_nodes)

8

In [76]:
# NOTE(review): abandoned experiment kept only as comments; the notebook uses
# the k_hop_subgraph cell below instead. Consider deleting this cell.
# subgraph(
#     subset=sampled_nodes,
#     edge_index=edge_index
# )

In [69]:

# Probe a single node: which edges have node 108 in the target row (row 1)?
row, col = edge_index
cond = col.eq(108)
(row[cond], col[cond])

(tensor([274]), tensor([108]))

In [102]:
# 1-hop neighbourhood of the sampled nodes. With flow='target_to_source' the
# sampled ids show up in row 0 of the resulting edge index; relabel_nodes is
# left at its default, so the original node ids are kept.
device = edge_index.device
(sub_nodes, sub_edge_index, _, sub_edge_mask) = k_hop_subgraph(
    sampled_nodes,
    1,
    edge_index,
    flow='target_to_source',
)
sub_edge_index

tensor([[138, 138, 179, 179, 181, 181, 245, 245, 482, 482],
        [705, 705, 747, 750, 753, 753, 812, 816, 394, 425]])

In [98]:
# The node ids drawn in the sampling cell.
list(sampled_nodes)

[482, 825, 179, 559, 181, 731, 245, 138]

In [90]:
# Out-of-place preview: overwrite row 1 of the subgraph edge index with the
# placeholder id 888. index_fill returns a new tensor, so sub_edge_index is
# left unchanged.
fill_rows = torch.tensor([1], device=device)
sub_edge_index.index_fill(0, fill_rows, 888)

tensor([[  2,   2, 160, 160, 252, 252, 294, 300, 347, 445],
        [888, 888, 888, 888, 888, 888, 888, 888, 888, 888]])

In [97]:
# Rebuild the whole graph with the sampled subgraph redirected: keep every edge
# outside the subgraph, then append the subgraph edges with row 1 replaced by
# the placeholder id 888.
row, col = edge_index  # assignment (was `==`, which silently did nothing)
kept_edges = torch.stack([row[~sub_edge_mask], col[~sub_edge_mask]])
placeholder_sub_edges = sub_edge_index.index_fill(
    0, torch.tensor([1], device=device), 888
)
torch.cat([kept_edges, placeholder_sub_edges], dim=1)

tensor([[  0,   3,   3,  ..., 300, 347, 445],
        [558, 556, 560,  ..., 888, 888, 888]])

In [103]:
# Re-display the 1-hop subgraph edge index.
sub_edge_index[:, :]

tensor([[138, 138, 179, 179, 181, 181, 245, 245, 482, 482],
        [705, 705, 747, 750, 753, 753, 812, 816, 394, 425]])

In [130]:
# Rewire the sampled subgraph: for each node type, redirect every subgraph edge
# whose row-1 endpoint belongs to that type to one fresh node id. New ids are
# appended after the existing nodes (node_feature.shape[0], then +1, ...).
# The rewiring is done on a copy, so sub_edge_index keeps its original ids and
# the cell gives the same result however many times it is re-run.
rewired_edge_index = sub_edge_index.clone()
targets = sub_edge_index[1]
new_node_list = []  # (new_node_id, node_type) pairs, in creation order
last_node_id = node_feature.shape[0] - 1
for ntype, ntype_list in enumerate(node_types):
    if len(ntype_list) == 0:
        print(f'skip node type {ntype}')
        continue
    # Vectorized membership test: is each target one of this type's node ids?
    ntype_ids = torch.as_tensor(ntype_list, device=targets.device).reshape(1, -1)
    _mask = (targets.unsqueeze(1) == ntype_ids).any(dim=1)

    # skip if no node matched
    if not _mask.any():
        print('skip mask')
        continue

    new_node_id = last_node_id + 1
    new_node_list.append((new_node_id, ntype))
    # In-place fill of row 1 of the copy (a view into rewired_edge_index).
    rewired_edge_index[1].masked_fill_(_mask, new_node_id)

    last_node_id = new_node_id

rewired_edge_index

skip node type 6
skip node type 7


In [129]:
# (new_node_id, node_type) pairs created by the rewiring loop.
list(new_node_list)

[]

In [120]:
# Which subgraph edges have a row-1 endpoint of node type 7? A membership test
# is required: `tensor == id_list` only broadcasts when the list has one
# element (or exactly as many as there are edges).
type7_ids = torch.as_tensor(node_types[7], device=sub_edge_index.device).reshape(1, -1)
(sub_edge_index[1].unsqueeze(1) == type7_ids).any(dim=1)

tensor([False, False, False,  True, False, False, False,  True, False, False])

In [126]:
# Number of subgraph edges selected by the last mask computed in the loop.
_mask.long().sum()

tensor(2)

In [125]:
# Preview redirecting the masked targets to id 1, out of place. The in-place
# masked_fill_ wrote into sub_edge_index itself (and 1 is an existing node
# id), corrupting it for every later cell and every re-run.
sub_edge_index[1].masked_fill(_mask, 1)

tensor([705, 705, 747,   1, 753, 753, 812,   1, 394, 425])