In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np

In [3]:
from torch_geometric.datasets import ZINC

In [4]:
# Training split of the ZINC subset, stored under ./datasets/ZINC.
zinc_root = './datasets/ZINC'
train_set = ZINC(zinc_root, subset=True, split='train')
# Disabled experiment, kept for reference (would add a trailing axis to edge_attr):
# train_set.data.edge_attr = train_set.data.edge_attr[:, None]

In [6]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [7]:
from torch_geometric.data import Batch

In [8]:
# Collate the first four training graphs into one mini-batch on the target device.
first_graphs = train_set[:4]
data = Batch.from_data_list(first_graphs)
data = data.to(device)

In [9]:
from models.my_encoders import get_atom_encoder, get_bond_encoder

In [10]:
# Build the ZINC atom and bond encoders with the same width (128) and move both to `device`.
atom_encoder, bond_encoder = (
    make_encoder('zinc', 128).to(device)
    for make_encoder in (get_atom_encoder, get_bond_encoder)
)

In [11]:
# Encode the batched graph with the atom encoder; the recorded output below
# shows the result has shape [98, 128] for this batch.
x = atom_encoder(data)

In [12]:
# Display only: shape of the encoded node features.
x.shape

torch.Size([98, 128])

In [13]:
from samplers.simple_scheme import SIMPLESampler

# Sampler used below to turn `scores` into node masks and marginals.
# NOTE(review): the meaning of the first argument (1) is not visible in this
# notebook — confirm against samplers.simple_scheme.SIMPLESampler.
sampler = SIMPLESampler(1, device)

In [14]:
from models.base2centroid import GNNMultiEdgeset

In [16]:
# Base-graph -> centroid model. Arguments are positional; their meaning is
# defined by models.base2centroid.GNNMultiEdgeset and is not visible here
# (NOTE(review): presumably conv type, edge encoder, widths/depths, norm,
# activation, dropout — confirm against the class signature).
model = GNNMultiEdgeset(
    'gine',
    bond_encoder,
    128,
    3,
    3,
    128,
    'graph_norm',
    'gelu',
    0.,
)
model = model.to(device)

In [17]:
# Random per-node scores of shape (num_nodes, 3, 2), wrapped as a Parameter so
# gradients can flow back to them.
# Seed the global RNG first so the scores (and anything sampled from them
# afterwards) are reproducible under Restart & Run All.
torch.manual_seed(0)
scores = torch.nn.Parameter(torch.rand(data.num_nodes, 3, 2).to(device))

In [18]:
# Sample node masks (plus marginals) from the scores, then unpack the four
# mask dimensions: (n_samples, nnodes, n_centroids, n_ensemble).
node_mask, marginal = sampler(scores)
n_samples, nnodes, n_centroids, n_ensemble = node_mask.size()
# Samples and ensemble members are later folded into one leading axis.
repeats = n_ensemble * n_samples

4C1 done


In [19]:
# Guard against re-running this cell: it overwrites `node_mask` with a
# function of itself, and a second run would still reshape without error
# (element counts match) while silently scrambling the mask.
expected_raw_shape = (n_samples, nnodes, n_centroids, n_ensemble)
if tuple(node_mask.shape) != expected_raw_shape:
    raise RuntimeError(
        f'node_mask has shape {tuple(node_mask.shape)}, expected the raw sampler output '
        f'{expected_raw_shape}; re-run the sampling cell before this one.'
    )

# (n_samples, nnodes, n_centroids, n_ensemble) -> (repeats, n_centroids, nnodes, 1);
# the trailing singleton axis is added for multiply broadcasting.
node_mask = node_mask.permute(0, 3, 2, 1).reshape(repeats, n_centroids, nnodes)[..., None]
# An edge's mask is the product of the masks of its two endpoints.
edge_mask = node_mask[:, :, data.edge_index[0], :] * node_mask[:, :, data.edge_index[1], :]

In [20]:
# Display only: (repeats, n_centroids, nnodes, 1) after the reshape above.
node_mask.shape

torch.Size([2, 3, 98, 1])

In [21]:
# Display only: one mask entry per (repeat, centroid, edge), plus a trailing singleton axis.
edge_mask.shape

torch.Size([2, 3, 214, 1])

In [22]:
# Display only: edge_index of the batched graph, shape (2, num_edges).
data.edge_index.shape

torch.Size([2, 214])

In [30]:
# (repeat, centroid, edge) index triples of every non-zero edge-mask entry.
nnz = torch.nonzero(edge_mask.squeeze(-1))

In [31]:
# Turn the (repeat, centroid, edge) triples of the non-zero edge-mask entries
# into row-major flat offsets: edge + centroid * n_edges + repeat * n_edges * n_centroids.
# The sizes are read from edge_mask rather than hard-coded (214 edges and
# 3 centroids only hold for this particular batch), and the triples are
# recomputed from edge_mask so that re-running this cell is safe.
_, n_mask_centroids, n_mask_edges, _ = edge_mask.shape
nnz_triples = edge_mask.squeeze(-1).nonzero()
nnz = (
    nnz_triples[:, 2]
    + nnz_triples[:, 1] * n_mask_edges
    + nnz_triples[:, 0] * n_mask_edges * n_mask_centroids
)

In [32]:
# Display only: flat offsets of the non-zero edge-mask entries.
nnz

tensor([   9,   10,   44,   52,   53,   54,  100,  101,  122,  123,  124,  125,
         126,  141,  142,  144,  160,  162,  175,  177,  244,  245,  307,  308,
         309,  310,  312,  321,  331,  333,  368,  369,  371,  383,  384,  396,
         397,  426,  430,  431,  450,  452,  462,  464,  465,  468,  495,  497,
         498,  499,  500,  501,  502,  503,  504,  516,  517,  519,  537,  541,
         574,  575,  614,  615,  616,  630,  632,  638,  642,  643,  676,  677,
         678,  680,  681,  683,  684,  690,  692,  697,  759,  761,  780,  782,
         802,  804,  806,  820,  826,  827,  865,  866,  867,  868,  869,  879,
         884,  908,  910,  913,  915,  916,  933,  934,  936,  938,  988,  989,
        1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1047, 1048, 1049, 1050,
        1051, 1052, 1055, 1059, 1100, 1101, 1143, 1144, 1160, 1161, 1162, 1179,
        1180, 1182, 1192, 1193, 1194, 1211, 1243, 1244, 1278, 1281],
       device='cuda:0')

In [45]:
# Tile the batch's edge_index once per (repeat, centroid) copy and shift the
# node ids of the k-th copy by k * num_nodes so the copies do not overlap.
n_graph_copies = repeats * n_centroids
copy_ids = torch.arange(n_graph_copies, device=device)
node_id_offsets = copy_ids.repeat_interleave(data.num_edges) * data.num_nodes
new_edge_index = data.edge_index.repeat(1, n_graph_copies) + node_id_offsets

In [50]:
# Flat positions of the non-zero edge-mask entries. They are computed here
# because the cell that originally defined `nnz_weights` sits further down the
# notebook, so a top-to-bottom run would otherwise hit a NameError.
# squeeze(-1) (rather than a bare squeeze()) keeps the index 1-D even when
# exactly one entry is non-zero.
nnz_weights = edge_mask.reshape(-1).nonzero().squeeze(-1)
# Display only: the tiled edges whose mask entry is non-zero.
new_edge_index[:, nnz_weights]

tensor([[  4,   5,  19,  23,  24,  24,  45,  46,  56,  56,  57,  57,  58,  65,
          65,  66,  73,  75,  80,  81, 111, 112, 140, 140, 141, 142, 142, 147,
         151, 152, 169, 170, 170, 175, 176, 181, 182, 195, 197, 198, 206, 207,
         211, 212, 212, 213, 226, 228, 228, 228, 229, 229, 230, 230, 231, 236,
         237, 237, 246, 248, 263, 264, 281, 281, 282, 288, 289, 292, 294, 295,
         309, 309, 310, 310, 310, 312, 313, 315, 316, 319, 347, 348, 357, 358,
         367, 369, 369, 376, 378, 379, 396, 397, 397, 397, 398, 402, 404, 415,
         416, 418, 419, 419, 427, 428, 428, 430, 452, 453, 460, 460, 461, 461,
         462, 462, 463, 464, 479, 479, 480, 480, 480, 481, 483, 485, 503, 504,
         523, 524, 531, 531, 532, 540, 540, 541, 546, 546, 547, 555, 569, 570,
         584, 586],
        [  5,   4,  24,  24,  19,  23,  46,  45,  57,  65,  56,  58,  57,  56,
          66,  65,  75,  73,  81,  80, 112, 111, 141, 142, 140, 140, 147, 142,
         152, 151, 170, 169, 195

In [49]:
# Flat positions of the non-zero entries of edge_mask.
# nonzero() on a 1-D tensor returns shape (k, 1); squeeze only that last axis
# so the result stays 1-D when k == 1. A bare squeeze() would return a 0-d
# tensor there, and indexing columns with it would drop a dimension.
nnz_weights = edge_mask.reshape(-1).nonzero().squeeze(-1)

In [None]:
# Display only: raw edge mask, left in for inspection.
edge_mask

In [None]:
# Graph assignment for every node of every (repeat, centroid) copy: tile the
# batch vector and shift the graph ids of the k-th copy by k * num_graphs.
n_model_copies = repeats * n_centroids
graph_id_offsets = (
    torch.arange(n_model_copies, device=device).repeat_interleave(nnodes) * data.num_graphs
)
batch = data.batch.repeat(n_model_copies) + graph_id_offsets

# Node features tiled once per (repeat, centroid) copy and flattened to 2-D.
x_tiled = x.repeat(repeats, n_centroids, 1, 1).reshape(-1, x.shape[-1])

# Original author's note on the output layout: repeats, n_centroids, n_graphs, features
# NOTE(review): the actual layout depends on GNNMultiEdgeset.forward — confirm.
centroid_x = model(x_tiled, batch, data.edge_index, data.edge_attr, node_mask, edge_mask)

In [None]:
# Slice boundaries of `x` recorded by the batch; consecutive differences give
# the number of nodes in each graph.
cumsum_nnodes = data._slice_dict['x'].to(device)
nnodes_list = torch.diff(cumsum_nnodes)

In [None]:
# Low-to-high hierarchy: connect every base node (of every repeat) to each of
# the n_centroids centroids of its own graph.
# src lists each node id n_centroids times in a row.
src = torch.arange(data.num_nodes * repeats, device=device).repeat_interleave(n_centroids)

# Id of the first centroid belonging to each node's (repeat, graph) pair.
graph_of_node = torch.arange(repeats * data.num_graphs, device=device).repeat_interleave(
    nnodes_list.repeat(repeats))
first_centroid = graph_of_node * n_centroids

# Broadcast to (node, centroid) and flatten so dst lines up with src.
centroid_offsets = torch.arange(n_centroids, device=device, dtype=torch.long)
dst = (first_centroid[:, None] + centroid_offsets[None, :]).reshape(-1)

In [None]:
from torch_scatter import scatter_sum
from torch_geometric.utils import to_undirected

# Centroid pairs to weight: the strict upper triangle (i < j) first, then the
# n_centroids diagonal pairs (i, i).
upper_pairs = np.vstack(np.triu_indices(n_centroids, k=1))
diag_pairs = np.vstack(np.diag_indices(n_centroids))
pair_idx = np.hstack([upper_pairs, diag_pairs])

# Number of edges per graph, from the batch's edge_index slice boundaries.
cumsum_nedges = data._slice_dict['edge_index'].to(device)
nedges_list = cumsum_nedges[1:] - cumsum_nedges[:-1]
graph_of_edge = torch.arange(data.num_graphs, device=device).repeat_interleave(nedges_list)

# For every centroid pair (a, b): mask of the edge's source under a times the
# mask of its target under b, summed over the edges of each graph.
source_in_first = node_mask[:, pair_idx[0], :, 0][:, :, data.edge_index[0]]
target_in_second = node_mask[:, pair_idx[1], :, 0][:, :, data.edge_index[1]]
intra_num_edges = scatter_sum(source_in_first * target_in_second, graph_of_edge, dim=2)
# Halve the diagonal (last n_centroids) pairs.
# NOTE(review): presumably because each undirected edge is stored in both
# directions in edge_index — confirm.
intra_num_edges[:, -n_centroids:, :] = intra_num_edges[:, -n_centroids:, :] / 2

# Symmetrise the pair list (mean over duplicates), then scale by the global
# maximum; the maximum is detached so it acts as a constant in the backward pass.
idx, symmetric_weights = to_undirected(torch.from_numpy(pair_idx).to(device),
                                       intra_num_edges.permute(1, 0, 2),
                                       reduce='mean')
intra_edge_weights = symmetric_weights.permute(1, 0, 2) / symmetric_weights.detach().max()

In [None]:
from torch_geometric.data import HeteroData

# base -> base: the original edges, tiled once per repeat with node ids of the
# k-th copy shifted by k * num_nodes.
base_node_offsets = torch.arange(repeats, device=device).repeat_interleave(data.num_edges) * data.num_nodes
base_edge_index = data.edge_index.repeat(1, repeats) + base_node_offsets
if data.edge_attr.dim() == 1:
    base_edge_attr = data.edge_attr.repeat(repeats)
else:
    base_edge_attr = data.edge_attr.repeat(repeats, 1)

# base <-> centroid: node mask laid out as (repeat, node, centroid) so that the
# flattened weights follow the order of src/dst.
node_to_centroid_mask = node_mask.squeeze(-1).permute(0, 2, 1)

# centroid -> centroid: the pair list `idx` tiled per (repeat, graph), with
# centroid ids of the k-th copy shifted by k * n_centroids.
n_centroid_groups = data.num_graphs * repeats
centroid_id_offsets = (torch.arange(n_centroid_groups, device=device) * n_centroids).repeat_interleave(idx.shape[1])
centroid_edge_index = idx.repeat(1, n_centroid_groups) + centroid_id_offsets

new_data = HeteroData(
    base={'x': x.repeat(repeats, 1)},
    centroid={'x': centroid_x},

    base__to__base={'edge_index': base_edge_index,
                    'edge_attr': base_edge_attr,
                    'edge_weight': None},
    base__to__centroid={'edge_index': torch.vstack([src, dst]),
                        'edge_attr': None,
                        'edge_weight': node_to_centroid_mask.reshape(-1)},
    centroid__to__base={'edge_index': torch.vstack([dst, src]),
                        'edge_attr': None,
                        'edge_weight': node_to_centroid_mask.reshape(-1)},
    centroid__to__centroid={'edge_index': centroid_edge_index,
                            'edge_attr': None,
                            'edge_weight': intra_edge_weights.permute(0, 2, 1).reshape(-1)},
)

In [None]:
from models.hetero_gnn import HeteroGNN

In [None]:
# Heterogeneous GNN over the base/centroid graph. Arguments are positional;
# their meaning is defined by models.hetero_gnn.HeteroGNN and is not visible
# in this notebook (NOTE(review): confirm against the class signature).
gnn = HeteroGNN(
    'gine',
    True,
    bond_encoder,
    128,
    3,
    3,
    3,
    5,
    0.,
    'batchnorm',
    True,
)
gnn = gnn.to(device)

In [None]:
# Smoke test: run the GNN forward, reduce its first output to a scalar and
# backpropagate through the whole pipeline.
gnn_outputs = gnn(new_data)
scalar_loss = gnn_outputs[0].sum()
scalar_loss.backward()