In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [3]:
import torch
from torch_geometric.utils import dense_to_sparse, from_scipy_sparse_matrix, to_dense_adj
from data_loader import HetGCNEventGraphDataset

In [4]:
# Root directory (relative path) from which the dataset's CSV/TXT inputs are read.
data_root_dir = '../ProcessedData_HetGCN'

In [5]:
# Build the event-graph dataset from the three input files under data_root_dir.
# NOTE(review): ignore_weight=True presumably makes the loader skip edge weights
# (it logs "Ignore Edge Weights."), and include_edge_type=True presumably is why
# each item later unpacks an (edge_weight, edge_type) pair - confirm in data_loader.
dataset = HetGCNEventGraphDataset(
    node_feature_csv=f'{data_root_dir}/node_feature_norm.csv',
    edge_index_csv=f'{data_root_dir}/edge_index.csv',
    node_type_txt=f'{data_root_dir}/node_types.txt',
    ignore_weight=True,
    include_edge_type=True
)

reading node features..
reading edge index..
Ignore Edge Weights.
read node types ..
node types txt: 132485
done


In [182]:
# Unpack the first graph of the dataset.
node_feature, edge_index, (edge_weight, edge_type), node_types = dataset[0]
num_edge_types, num_node_types = 1, 8

# Bernoulli sampler: each candidate edge is drawn with probability 0.01.
m = torch.distributions.Bernoulli(torch.tensor([0.01]))
num_nodes = node_feature.shape[0]

# Scratch run: only the very first (edge type, source type, destination type)
# combination is evaluated, so the indices are pinned to 0 directly.
etype = src_type = dst_type = 0
src_node_list, dst_node_list = node_types[src_type], node_types[dst_type]

# The sampler yields shape (num_nodes, num_nodes, 1); drop the trailing axis.
random_het_adj_mat = m.sample((num_nodes, num_nodes)).squeeze(-1).to(node_feature.device)

In [171]:
# Display the sampled 0/1 adjacency matrix.
random_het_adj_mat
# dense_to_sparse(random_het_adj_mat)

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [172]:
# Number of node ids in the current source and destination node-type lists.
len(src_node_list), len(dst_node_list)

(123, 123)

In [183]:
# Float indicator vectors: 1.0 where the node id belongs to the list, else 0.0.
src_bool = torch.Tensor([i in src_node_list for i in range(num_nodes)])
dst_bool = torch.Tensor([i in dst_node_list for i in range(num_nodes)])
# The outer product is non-zero exactly where the row node is a source-type node
# and the column node is a destination-type node.
bool_mat = torch.outer(src_bool, dst_bool).bool()
~bool_mat

tensor([[False,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])

In [174]:
# Total number of sampled entries across the whole matrix.
random_het_adj_mat.sum()

tensor(7041.)

In [175]:
# Zero everything outside the (source type, destination type) block, then sum
# each row to get the number of sampled entries per source node.
het_block = random_het_adj_mat.masked_fill(bool_mat.logical_not(), 0)
s = het_block.sum(dim=1)

# Shape of the index list of rows with a non-zero count.
torch.argwhere(s).shape

torch.Size([83, 1])

In [184]:
# Keep the sampled values inside the type block and zero the rest.
masked_adj_mat = torch.where(
    bool_mat, random_het_adj_mat, torch.zeros_like(random_het_adj_mat)
)
masked_adj_mat.sum()

tensor(152.)

In [155]:
# Overwrite every position inside the type block with 2 before converting, so the
# sparse result lists all positions of the block rather than only sampled ones.
dense_to_sparse(masked_adj_mat.masked_fill(bool_mat, 2))

(tensor([[  0,   0,   0,  ..., 836, 836, 836],
         [  0,  74,  79,  ..., 826, 831, 836]]),
 tensor([2., 2., 2.,  ..., 2., 2., 2.]))

In [177]:
# masked_edge_index = masked_adj_mat.nonzero().t().contiguous()
# masked_edge_index

In [178]:
# masked_edge_index.shape[1]

In [71]:
# Number of node ids in the current source node-type list.
len(src_node_list)

123

In [206]:
def get_het_edge_index(edge_index, edge_weight, node_types, ntype, source_types=None,
                       edge_type_list=None, edge_type=None, num_node_types=8):
    """
    Select the edges of one heterogeneous relation from a flat edge list.

    Two selection modes exist:

    * ``source_types is None``: ``ntype`` indexes ``node_types`` directly and
      every edge whose destination (second row of ``edge_index``) is listed in
      ``node_types[ntype]`` is kept.
    * ``source_types`` given: ``ntype`` is a flattened (source, destination)
      index. ``int(ntype / num_node_types)`` picks an entry of ``source_types``
      (the source node type) and the remainder is the destination node type.
      An edge is kept when its source is in ``node_types[src_type]`` and its
      destination is in ``node_types[dst_type]``. If both ``edge_type`` and
      ``edge_type_list`` are given, the edge must additionally satisfy
      ``edge_type_list == edge_type``.

    Args:
        edge_index: tensor unpacked as ``row, col`` (source ids, destination ids).
        edge_weight: per-edge tensor indexed with the same mask, or ``None``.
        node_types: sequence where ``node_types[t]`` holds the node ids of type
            ``t`` (assumed to be integer ids in a list or tensor).
        ntype: node-type index, or flattened relation index (see above).
        source_types: optional sequence of source node types.
        edge_type_list: optional per-edge tensor of edge types.
        edge_type: optional edge type to keep.
        num_node_types: number of node types used to unflatten ``ntype``.

    Returns:
        ``(ntype, het_edge_index, het_edge_weight)``. Both tensors are ``None``
        when one of the involved node types has no nodes; ``het_edge_weight`` is
        ``None`` when ``edge_weight`` is ``None``.

    Raises:
        Any lookup error (e.g. ``IndexError`` for an out-of-range type) is
        propagated unchanged with its original traceback.
    """
    row, col = edge_index

    def member_mask(endpoints, node_ids):
        # Vectorised "endpoint id is one of node_ids" test; replaces one
        # full-tensor comparison per node id.
        ids = torch.as_tensor(node_ids, dtype=endpoints.dtype, device=endpoints.device)
        return torch.isin(endpoints, ids)

    if source_types is None:
        if len(node_types[ntype]) == 0:
            return ntype, None, None
        keep = member_mask(col, node_types[ntype])
    else:
        # Unflatten ntype into (index into source_types, destination type).
        src_type_idx = int(ntype / num_node_types)
        dst_type = ntype - num_node_types * src_type_idx
        src_type = source_types[src_type_idx]

        if len(node_types[dst_type]) == 0 or len(node_types[src_type]) == 0:
            return ntype, None, None

        keep = member_mask(row, node_types[src_type]) & member_mask(col, node_types[dst_type])
        if edge_type is not None and edge_type_list is not None:
            keep = keep & (edge_type_list == edge_type)

    het_weight = None if edge_weight is None else edge_weight[keep]
    return ntype, torch.stack([row[keep], col[keep]]), het_weight

In [461]:
# Use a unit weight for every edge, allocated on the same device as the edge list.
edge_weight = torch.ones(edge_index.shape[1], device=edge_index.device)

# Edges selected for node type 0 (no source types given).
_, e_idx, e_weight = get_het_edge_index(
    edge_index=edge_index,
    edge_weight=edge_weight,
    node_types=node_types,
    ntype=0,
)

In [462]:
# Display the edge index returned by get_het_edge_index for node type 0.
e_idx

tensor([[  3,   7,  11,  15,  19,  23,  28,  32,  37,  40,  43,  47,  51,  55,
          60,  65,  70,  73,  77,  81,  85, 107, 114, 136, 139, 147, 155, 158,
         176, 179, 187, 190, 195, 200, 206, 212, 218, 224, 230, 235, 240, 245,
         249, 253, 257, 261, 266, 267, 268, 269, 270, 271, 244, 274, 275, 273,
         278, 279, 277, 282, 283, 281, 286, 287, 285, 290, 291, 289, 294, 295,
         296, 293, 299, 300, 298, 303, 304, 302, 307, 308, 309, 310, 306, 313,
         314, 312, 317, 318, 316, 321, 322, 323, 320, 326, 327, 328, 325, 331,
         332, 333, 330, 336, 337, 338, 335, 341, 342, 343, 340, 346, 347, 348,
         345, 351, 352, 353, 354, 350, 357, 358, 359, 360, 356, 363, 364, 365,
         366, 362, 369, 370, 368, 373, 374, 372, 377, 378, 376, 381, 382, 383,
         380, 386, 387, 388, 385, 391, 392, 393, 394, 390, 397, 398, 399, 396,
         402, 403, 404, 405, 401, 408, 409, 410, 411, 407, 414, 415, 413, 418],
        [560, 565, 570, 575, 580, 585, 591, 596, 60

In [350]:
perturbation_prob = 0.0002
m = torch.distributions.Bernoulli(torch.tensor([perturbation_prob]))
# TODO: edge perturbation based on edge_index and edge_type
node_feature, edge_index, (edge_weight, edge_type), node_types = dataset[0]
num_nodes = node_feature.size()[0]
# Accumulator for the sampled edges of every (edge type, src type, dst type)
# block; kept on the same device as the samples so the additions below work
# off-CPU as well.
new_adj_mat = torch.zeros(num_nodes, num_nodes, device=node_feature.device)

# One boolean membership vector per node type, built once instead of rescanning
# the node-id lists inside the innermost loop.
# Assumes node_types[t] is a list/tensor of integer node ids - TODO confirm.
all_node_ids = torch.arange(num_nodes, device=node_feature.device)
node_type_masks = [
    torch.isin(
        all_node_ids,
        torch.as_tensor(node_types[t], dtype=torch.long, device=node_feature.device),
    )
    for t in range(num_node_types)
]

for etype in range(num_edge_types):
    for src_type in range(num_node_types):
        for dst_type in range(num_node_types):
            src_node_list = node_types[src_type]
            dst_node_list = node_types[dst_type]

            # A fresh Bernoulli draw per block; sample() has a trailing size-1 axis.
            random_het_adj_mat = m.sample((num_nodes, num_nodes)).view((num_nodes, num_nodes)).to(node_feature.device)
            src_bool = node_type_masks[src_type]
            dst_bool = node_type_masks[dst_type]
            # True where the row node has src_type and the column node has dst_type.
            bool_mat = src_bool.view(-1, 1) & dst_bool.view(1, -1)
            # TODO: continue to reset non-het node type matrix
            # NOTE(review): this adds etype to every entry, sampled or not, so for
            # etype > 0 the whole block becomes non-zero after masking. Harmless
            # while num_edge_types == 1 - confirm the intended encoding.
            random_het_adj_mat += etype
            masked_adj_mat = random_het_adj_mat.masked_fill(~bool_mat, 0)
            # print(masked_adj_mat.sum())

            new_adj_mat += masked_adj_mat

In [351]:
# Inspect the current sample: its shape, its total, the total left after masking,
# and how many distinct node ids appear in the masked matrix's sparse edge list.
# NOTE(review): presumably these are the values left by the last loop iteration.
random_het_adj_mat.shape, random_het_adj_mat.sum(), masked_adj_mat.sum(), dense_to_sparse(masked_adj_mat)[0].unique().shape

(torch.Size([841, 841]), tensor(142.), tensor(0.), torch.Size([0]))

In [352]:
# Compare the size of the original edge list with the sum of the accumulated
# perturbation matrix.
edge_index.shape, new_adj_mat.sum()

(torch.Size([2, 962]), tensor(164.))

In [353]:
# Count the entries that are non-zero in exactly one of the original adjacency
# and the perturbation matrix (i.e. edges left after flipping).
# max_num_nodes pins the dense size to num_nodes; without it to_dense_adj infers
# the size from the largest node id in edge_index, which is too small when the
# highest-numbered nodes have no edges.
torch.logical_xor(
    to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0),
    new_adj_mat
).int().sum()

tensor(960)

In [355]:
# Sparse (edge_index, edge_attr) form of the XOR between the original adjacency
# and the perturbation matrix.
# max_num_nodes pins the dense size to num_nodes; without it to_dense_adj infers
# the size from the largest node id in edge_index, which is too small when the
# highest-numbered nodes have no edges.
dense_to_sparse(
    torch.logical_xor(
        to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0),
        new_adj_mat
    ).int()
)

(tensor([[  0,   0,   0,  ..., 823, 831, 834],
         [196, 558, 599,  ..., 597, 497, 509]]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1

In [424]:
# Display the per-edge type tensor unpacked from the dataset item.
edge_type

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 

In [463]:
node_feature, edge_index, (edge_weight, edge_type), node_types = dataset[0]
dense_edge_attr = to_dense_adj(edge_index, edge_attr=edge_type+1).view(num_nodes, -1)
dense_edge_attr

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [478]:
# Sum of the plain dense adjacency versus sum of the dense (edge_type + 1) matrix.
to_dense_adj(edge_index).sum(), dense_edge_attr.sum()

(tensor(962.), tensor(1328.))

In [487]:
# Float indicator vectors for node type 0 (rows) and node type 1 (columns).
src_bool = torch.Tensor([i in node_types[0] for i in range(num_nodes)])
dst_bool = torch.Tensor([i in node_types[1] for i in range(num_nodes)])
# The broadcast product is non-zero only where both indicators are 1.
bool_mat = (src_bool[:, None] * dst_bool[None, :]).bool()

# Edge attributes restricted to the type-0 -> type-1 block.
masked_edge_attr = dense_edge_attr.masked_fill(~bool_mat, 0)
masked_edge_attr.sum(), dense_edge_attr.sum(), masked_edge_attr.unique()

(tensor(34.), tensor(1328.), tensor([0., 1.]))

In [488]:
# Distinct values present in the masked block, with the "no edge" value 0 dropped.
block_values = dense_edge_attr.masked_fill(~bool_mat, 0).unique()
[value for value in block_values if value != 0]

[tensor(1.)]

In [15]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell.
from graph_augmentation import create_het_edge_perturbation

# Perturb a one-element batch holding only the first graph.
a = create_het_edge_perturbation([dataset[0]])

torch.Size([2, 2264])


In [16]:
# Compare one element of the original item with the same element of its
# perturbed counterpart. element = 1 is the edge_index slot in the unpacking
# used elsewhere in this notebook.
d_idx = 0
element = 1
# `a` was built from dataset[0] only, so the original must be indexed with d_idx
# too (it was hard-coded to dataset[6], a different graph).
dataset[d_idx][element].shape, a[d_idx][element].shape

(torch.Size([2, 962]), torch.Size([2, 2264]))