In [1]:
%load_ext autoreload
%autoreload 2


In [256]:
import torch
from torch_geometric.utils import dense_to_sparse, from_scipy_sparse_matrix, to_dense_adj
from data_loader import HetGCNEventGraphDataset

In [3]:
data_root_dir = '../ProcessedData_HetGCN'

In [4]:
dataset = HetGCNEventGraphDataset(
    node_feature_csv=f'{data_root_dir}/node_feature_norm.csv',
    edge_index_csv=f'{data_root_dir}/edge_index.csv',
    node_type_txt=f'{data_root_dir}/node_types.txt',
    ignore_weight=True,
    include_edge_type=True
)

reading node features..
reading edge index..
Ignore Edge Weights.
read node types ..
node types txt: 132485
done


In [182]:
node_feature, edge_index, (edge_weight, edge_type), node_types = dataset[0]
num_edge_types = 1
num_node_types = 8

m = torch.distributions.Bernoulli(torch.tensor([0.01]))
num_nodes = node_feature.size()[0]
for etype in range(num_edge_types):
    for src_type in range(num_node_types):
        for dst_type in range(num_node_types):
            src_node_list = node_types[src_type]
            dst_node_list = node_types[dst_type]

            random_het_adj_mat = m.sample((num_nodes, num_nodes)).view((num_nodes, num_nodes)).to(node_feature.device)
            break
        break

In [171]:
random_het_adj_mat
# dense_to_sparse(random_het_adj_mat)

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [172]:
len(src_node_list), len(dst_node_list)

(123, 123)

In [183]:
src_bool = torch.Tensor([True if i in src_node_list else False for i in range(0, num_nodes)])
dst_bool = torch.Tensor([True if i in dst_node_list else False for i in range(0, num_nodes)])
bool_mat = torch.matmul(src_bool.view(-1,1), dst_bool.view(1,-1)).bool()
~bool_mat

tensor([[False,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])

In [174]:
random_het_adj_mat.sum()

tensor(7041.)

In [175]:
s = random_het_adj_mat.masked_fill(~bool_mat, 0).sum(dim=1)

s.argwhere().shape

torch.Size([83, 1])

In [184]:
masked_adj_mat = random_het_adj_mat.masked_fill(~bool_mat, 0)
masked_adj_mat.sum()

tensor(152.)

In [155]:
dense_to_sparse(masked_adj_mat.masked_fill(bool_mat, 2))

(tensor([[  0,   0,   0,  ..., 836, 836, 836],
         [  0,  74,  79,  ..., 826, 831, 836]]),
 tensor([2., 2., 2.,  ..., 2., 2., 2.]))

In [177]:
# masked_edge_index = masked_adj_mat.nonzero().t().contiguous()
# masked_edge_index

In [178]:
# masked_edge_index.shape[1]

In [71]:
len(src_node_list)

123

In [206]:
def get_het_edge_index(edge_index, edge_weight, node_types, ntype, source_types=None,
                           edge_type_list=None, edge_type=None, num_node_types=8):
        """
        get het edge index by given type
        """
        row, col = edge_index

        if source_types is not None:
            try:
                num_src_types = len(source_types)
                src_type_idx = int(ntype / num_node_types)
                dst_type = ntype - num_node_types * src_type_idx
                src_type = source_types[src_type_idx]

                if len(node_types[dst_type]) == 0 or len(node_types[src_type]) == 0:
                    return ntype, None, None
                
                # TODO: handle edge type
                

                src_het_mask = sum(row == i for i in node_types[src_type]).bool()
                dst_het_mask = sum(col == i for i in node_types[dst_type]).bool()

                if edge_type is not None and edge_type_list is not None:
                    edge_mask = edge_type_list == edge_type
                    cmask = src_het_mask & dst_het_mask & edge_mask
                else:
                    cmask = src_het_mask & dst_het_mask
            except Exception as e:
                print(f'{src_type_idx} - {dst_type}')
                print(f'row: {row}')
                print(f'node_types[src_type]: {node_types[src_type]}')
                raise Exception(e)
            return ntype, torch.stack([row[cmask], col[cmask]]), edge_weight[cmask]
        else:

            if len(node_types[ntype]) == 0:
                return ntype, None, None

            het_mask = sum(col == i for i in node_types[ntype]).bool()
            return ntype, torch.stack([row[het_mask], col[het_mask]]), edge_weight[het_mask]

In [248]:
edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
_, e_idx, e_weight = get_het_edge_index(edge_index, edge_weight, node_types, 0)

In [278]:
e_idx.unique().shape, e_idx.shape

(torch.Size([287]), torch.Size([2, 168]))

In [322]:
perturbation_prob = 0.01
m = torch.distributions.Bernoulli(torch.tensor([perturbation_prob]))
# TODO: edge perturbation based on edge_index and edge_type
node_feature, edge_index, (edge_weight, edge_type), node_types = dataset[0]
num_nodes = node_feature.size()[0]
new_adj_mat = torch.zeros(num_nodes, num_nodes)

for etype in range(num_edge_types):
    for src_type in range(num_node_types):
        for dst_type in range(num_node_types):
            src_node_list = node_types[src_type]
            dst_node_list = node_types[dst_type]

            random_het_adj_mat = m.sample((num_nodes, num_nodes)).view((num_nodes, num_nodes)).to(node_feature.device)
            src_bool = torch.Tensor([True if i in src_node_list else False for i in range(0, num_nodes)])
            dst_bool = torch.Tensor([True if i in dst_node_list else False for i in range(0, num_nodes)])
            bool_mat = torch.matmul(src_bool.view(-1, 1), dst_bool.view(1, -1)).bool()
            # TODO: continue to reset non-het node type matrix
            random_het_adj_mat += etype
            masked_adj_mat = random_het_adj_mat.masked_fill(~bool_mat, 0)
            # print(masked_adj_mat.sum())

            new_adj_mat += masked_adj_mat

In [315]:
random_het_adj_mat.shape, random_het_adj_mat.sum(), masked_adj_mat.sum(), dense_to_sparse(masked_adj_mat)[0].unique().shape

(torch.Size([841, 841]), tensor(7008.), tensor(136.), torch.Size([114]))

In [324]:
edge_index.shape, new_adj_mat.sum()

(torch.Size([2, 962]), tensor(148.))

In [323]:
torch.logical_xor(
    to_dense_adj(edge_index).view(num_nodes, num_nodes),
    new_adj_mat
).int().sum()

tensor(944)

In [246]:
dense_to_sparse(new_adj_mat)[1].unique(), dense_to_sparse(new_adj_mat)

(tensor([1.]),
 (tensor([[  0,   0,   0,  ..., 840, 840, 840],
          [ 32,  43, 308,  ..., 101, 212, 384]]),
  tensor([1., 1., 1.,  ..., 1., 1., 1.])))