In [1]:
# Auto-reload project modules (models/, samplers/) before each cell runs, so edits to
# those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np

In [3]:
from torch_geometric.datasets import ZINC

In [4]:
# 12k-molecule ZINC subset, training split, cached under ./datasets/ZINC.
train_set = ZINC(root='./datasets/ZINC', subset=True, split='train')

In [5]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
from torch_geometric.data import Batch

In [149]:
# Collate the first 4 training graphs into one batch and move it to `device`.
first_graphs = list(train_set[:4])
b = Batch.from_data_list(first_graphs).to(device)

In [150]:
from models.my_encoders import get_atom_encoder, get_bond_encoder

In [151]:
# ZINC atom / bond feature encoders of width 128, both placed on `device`
# (atom encoder is built first, as before, so parameter-init RNG order is unchanged).
atom_encoder, bond_encoder = (
    get_atom_encoder('zinc', 128).to(device),
    get_bond_encoder('zinc', 128).to(device),
)

In [152]:
# Replace the node features stored on the batch with the atom-encoder output.
# WARNING: this overwrites `b.x` in place. Re-running this cell without first re-running the
# batching cell presumably feeds already-encoded features back into the encoder
# (assuming atom_encoder reads b.x — confirm); re-create `b` before re-running.
b.x = atom_encoder(b)

In [153]:
# Fix the seed so the random scores (and therefore the sampled masks) are reproducible.
torch.manual_seed(21)
# Scores of shape (num_nodes, 3, 2) drawn with the default generator, then moved to
# `device` and wrapped as a trainable Parameter.
# NOTE(review): 3 appears to become n_centroids after sampling — confirm the meaning of the last axis.
init_scores = torch.rand((b.num_nodes, 3, 2))
scores = torch.nn.Parameter(init_scores.to(device))

In [154]:
from samplers.simple_scheme import SIMPLESampler

# NOTE(review): the meaning of the positional `1` is not visible here (possibly the number
# of samples drawn per call) — confirm against SIMPLESampler's signature.
sampler = SIMPLESampler(1, device)

In [155]:
# Sample hard node-to-centroid masks from the scores.
mask, marginal = sampler(scores)
# Axis names follow this unpacking; their semantics come from SIMPLESampler — confirm.
VE, nnodes, n_centroids, E = mask.shape
repeats = VE * E
# (VE, nnodes, C, E) -> (VE, E, C, nnodes) -> (VE*E, C, nnodes) -> (VE*E, C, nnodes, 1)
node_mask = mask.permute(0, 3, 2, 1).flatten(0, 1).unsqueeze(-1)

In [156]:
# An edge is kept for a centroid only as much as both of its endpoints are:
# product of the source-node and target-node masks along the node axis (dim 2).
edge_mask = node_mask.index_select(2, b.edge_index[0]) * node_mask.index_select(2, b.edge_index[1])

In [157]:
node_mask.size()  # (repeats, n_centroids, nnodes, 1)

torch.Size([2, 3, 98, 1])

In [129]:
from models.base2centroid import GNNMultiEdgeset

In [130]:
# Positional constructor arguments, one per line. Only the 32 is corroborated in this
# notebook (it equals the last dim of `embd`); the other roles are guesses — confirm
# against GNNMultiEdgeset.__init__.
gnn_args = (
    'gine',        # presumably the message-passing layer type
    bond_encoder,  # edge-feature encoder built above
    128,           # presumably the input/hidden width (matches the encoders' 128)
    3,
    3,
    32,            # output width — equals embd.shape[-1]
    'graph_norm',  # presumably the normalization layer
    0.,            # presumably the dropout rate
)
model = GNNMultiEdgeset(*gnn_args).to(device)

In [131]:
# Replicate the node features once per (repeat, centroid) copy:
# (nnodes, F) -> (repeats, n_centroids, nnodes, F) -> (repeats * n_centroids * nnodes, F).
num_graphs = b.num_graphs
n_copies = repeats * n_centroids
x = b.x.repeat(repeats, n_centroids, 1, 1).reshape(-1, b.x.shape[-1])
# Copy k of graph g gets batch id k * num_graphs + g, so every copy is a separate graph.
graph_offsets = num_graphs * torch.arange(n_copies, device=x.device).repeat_interleave(nnodes)
batch = b.batch.repeat(n_copies) + graph_offsets

In [200]:
# Run the GNN over all masked copies of the batch, passing node and edge masks.
# The recorded output shape is torch.Size([2, 3, 4, 32]), which lines up with
# (repeats, n_centroids, b.num_graphs, 32) for this batch — presumably one embedding per
# (mask copy, centroid, graph); confirm against GNNMultiEdgeset.forward.
embd = model(x, batch, b.edge_index, b.edge_attr, node_mask, edge_mask)

In [201]:
embd.size()  # embedding tensor shape; last dim is the model's output width

torch.Size([2, 3, 4, 32])

# Construct a heterogeneous graph of base nodes and centroid nodes

In [134]:
# Cumulative node counts per graph: num_graphs + 1 offsets starting at 0.
# Uses the public `Batch.ptr` attribute, which holds the same offsets as the private
# `_slice_dict['x']` but is part of PyG's supported API. `.to(device)` is kept for safety.
cumsum_nnodes = b.ptr.to(device)

In [135]:
# Number of nodes in each graph = consecutive differences of the cumulative offsets.
nnodes_list = cumsum_nnodes.diff()

In [136]:
nnodes_list  # node count of each graph in the batch

tensor([29, 26, 16, 27], device='cuda:0')

In [172]:
# Base-node id for every (repeat, node, centroid) triple: node ids run over all repeats,
# and each id is emitted once per centroid.
src = (torch.arange(b.num_nodes * repeats, device=device)
       .unsqueeze(1)
       .expand(-1, n_centroids)
       .reshape(-1))

In [173]:
# Centroid-node id for every (repeat, node, centroid) triple, in the same order as `src`.
# Each (repeat, graph) copy owns n_centroids consecutive ids:
# id = (repeat * num_graphs + graph) * n_centroids + centroid.
graph_copy_of_node = torch.arange(repeats * b.num_graphs, device=device).repeat_interleave(
    nnodes_list.repeat(repeats))
dst = (graph_copy_of_node[:, None] * n_centroids
       + torch.arange(n_centroids, device=device, dtype=torch.long)[None, :]).reshape(-1)

In [174]:
dst.size()  # repeats * b.num_nodes * n_centroids entries

torch.Size([588])

In [175]:
# Weight of each base-node -> centroid edge: the node's mask value for that centroid,
# flattened in the same (repeat, node, centroid) order as `src` / `dst`.
# Stored under its own name: the generic `edge_weights` is rebound later to the
# centroid-centroid weights, which would silently discard these base->centroid weights.
base2higher_edge_weights = node_mask.squeeze(-1).permute(0, 2, 1).reshape(-1)

In [176]:
# Directed edges base node -> centroid node (row 0 = src, row 1 = dst), and the reverse.
base2higher_edge_index = torch.stack([src, dst], dim=0)
higher2base_edge_index = base2higher_edge_index.flip(0)

In [177]:
# Centroid pairs as a (2, P) index array: strictly-upper-triangle pairs (i < j) first,
# then the diagonal (i, i). The diagonal occupying the last n_centroids columns is relied
# on when the diagonal counts are halved.
upper_pairs = np.stack(np.triu_indices(n_centroids, k=1))
diag_pairs = np.stack(np.diag_indices(n_centroids))
idx = np.concatenate([upper_pairs, diag_pairs], axis=1)

In [178]:
from torch_scatter import scatter_sum

In [179]:
# For every centroid pair (i, j) in `idx` and every graph in the batch, sum over directed
# base edges u -> v the product mask_i(u) * mask_j(v).
# Result shape: (repeats, idx.shape[1], b.num_graphs).
cumsum_nedges = b._slice_dict['edge_index'].to(device)
nedges_list = cumsum_nedges[1:] - cumsum_nedges[:-1]
edge_graph = torch.arange(b.num_graphs, device=device).repeat_interleave(nedges_list)
pair_edge_weight = (node_mask[:, idx[0], :, 0][:, :, b.edge_index[0]]
                    * node_mask[:, idx[1], :, 0][:, :, b.edge_index[1]])
# dim_size guarantees one slot per graph. Without it, scatter_sum sizes the output by
# max(index) + 1, dropping columns when the trailing graphs have no edges.
num_edges = scatter_sum(pair_edge_weight, edge_graph, dim=2, dim_size=b.num_graphs)
# The last n_centroids columns of `idx` (as built above: upper-triangle pairs, then the
# diagonal) are the pairs (i, i); halve their counts. This is done in the same cell that
# creates `num_edges`, so re-running the cell recomputes the counts before halving
# instead of halving an already-halved tensor again.
# NOTE(review): presumably the /2 compensates for edge_index storing each undirected bond
# in both directions — confirm.
num_edges[:, -n_centroids:, :] = num_edges[:, -n_centroids:, :] / 2

In [185]:
from torch_geometric.utils import to_undirected

In [187]:
# Symmetrize the centroid-pair graph. to_undirected appends (j, i) for every (i, j) and
# coalesces duplicates, combining their attributes with reduce='add' (the default, made
# explicit here). Attributes must be edge-major, hence the (pair, repeat, graph) permute.
# NOTE(review): with reduce='add' each diagonal pair (i, i) is merged with its own flipped
# copy, so its weight is summed twice. This cancels the earlier /2 on the diagonal —
# confirm that this is the intended diagonal weight.
centroid_pairs = torch.from_numpy(idx).to(device)
pair_counts = num_edges.permute(1, 0, 2)
idx, edge_weights = to_undirected(centroid_pairs, pair_counts, reduce='add')

In [189]:
# Move the repeat axis back to the front: (pair, repeat, graph) -> (repeat, pair, graph).
edge_weights = edge_weights.transpose(0, 1)

In [202]:
edge_weights.size()  # (repeats, symmetrized centroid pairs, num_graphs)

torch.Size([2, 9, 4])

In [190]:
# Give every (repeat, graph) copy its own block of n_centroids centroid-node ids:
# tile the pair list once per copy and shift copy k by k * n_centroids.
# NOTE: this rebinds `idx`, so the cell is not idempotent — re-running it shifts again.
n_graph_copies = b.num_graphs * repeats
copy_shift = torch.arange(n_graph_copies, device=device).repeat_interleave(idx.shape[1]) * n_centroids
idx = idx.repeat(1, n_graph_copies) + copy_shift

In [None]:
# Flatten to one weight per centroid edge in (repeat, graph, pair) order, matching the
# tiled/shifted `idx`.
edge_weights = edge_weights.transpose(1, 2).flatten()

In [None]:
edge_weights  # flat centroid-centroid edge weights in (repeat, graph, centroid pair) order