In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np

In [None]:
from torch_geometric.datasets import ZINC

In [None]:
# Training split of PyG's ZINC dataset. `subset=True` presumably selects the smaller
# benchmark subset (see the PyG `ZINC` docs); files are stored under this root directory.
zinc_root = './datasets/ZINC'
train_set = ZINC(zinc_root, split='train', subset=True)

In [None]:
from torch_geometric.utils import is_undirected

In [None]:
# Fix the RNG before any encoder/model is initialized or scores are drawn, so that
# Restart Kernel -> Run All reproduces the same parameters, samples and outputs.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from torch_geometric.data import Batch

In [None]:
# Collate the first four training graphs into a single mini-batch on `device`.
first_four = train_set[:4]
data = Batch.from_data_list(first_four).to(device)

In [None]:
from models.my_encoders import get_atom_encoder, get_bond_encoder

In [None]:
# Node (atom) and edge (bond) embedders for ZINC. The second argument is presumably
# the embedding width; confirm in models.my_encoders.
embed_dim = 128
atom_encoder = get_atom_encoder('zinc', embed_dim).to(device)
bond_encoder = get_bond_encoder('zinc', embed_dim).to(device)

In [None]:
# Embed the atom features of the whole batch. Later cells treat each row of `x` as one node.
atom_embeddings = atom_encoder(data)
x = atom_embeddings

In [None]:
# Later cells tile `x` per repeat using `data.num_nodes` offsets, so it must hold exactly one row per node.
assert x.shape[0] == data.num_nodes, (x.shape, data.num_nodes)
x.shape

In [None]:
from samplers.simple_scheme import SIMPLESampler

# The first positional argument is presumably the number of samples drawn per call;
# confirm in samplers.simple_scheme.
n_draws = 1
sampler = SIMPLESampler(n_draws, device)

In [None]:
from models.base2centroid import GNNMultiEdgeset

In [None]:
# Centroid-level GNN. The positional arguments are forwarded unchanged; their meaning
# is defined by models.base2centroid.GNNMultiEdgeset (not visible in this notebook).
centroid_gnn_args = (
    'gine',
    bond_encoder,
    128,
    3,
    3,
    128,
    'graph_norm',
    0.,
)
model = GNNMultiEdgeset(*centroid_gnn_args).to(device)

In [None]:
# Learnable scores that the sampler turns into node masks. Axes are presumably
# (num_nodes, n_centroids=3, n_ensemble=2); this matches how the sampler output is unpacked next.
initial_scores = torch.rand(data.num_nodes, 3, 2)
scores = torch.nn.Parameter(initial_scores.to(device))

In [None]:
# Draw node masks from the scores. `marginal` is not used further in this notebook.
node_mask, marginal = sampler(scores)

# Layout of the sampled masks, as unpacked here: (samples, nodes, centroids, ensemble).
n_samples, nnodes, n_centroids, n_ensemble = node_mask.shape

# Downstream cells fold samples and ensemble members into a single "repeats" axis.
repeats = n_ensemble * n_samples

In [None]:
# Reorder the sampled masks to (repeats, n_centroids, nnodes) and append a trailing singleton
# axis so they broadcast against feature tensors.
# NOTE: this overwrites `node_mask`, so the cell must run exactly once after the sampler cell.
node_mask = (
    node_mask
    .permute(0, 3, 2, 1)
    .reshape(repeats, n_centroids, nnodes)
    .unsqueeze(-1)
)

# Per-edge mask = product of the two endpoint masks -> (repeats, n_centroids, num_edges, 1).
src_node_mask = node_mask[:, :, data.edge_index[0], :]
dst_node_mask = node_mask[:, :, data.edge_index[1], :]
edge_mask = src_node_mask * dst_node_mask

In [None]:
# Replicate the batch once per (repeat, centroid) pair. Graph ids of copy k are shifted by
# k * num_graphs, so every copy's graphs become distinct entries of the enlarged batch.
n_copies = repeats * n_centroids
copy_offset = torch.arange(n_copies, device=device).repeat_interleave(nnodes) * data.num_graphs
batch = data.batch.repeat(n_copies) + copy_offset

# Node features tiled in the same (repeat, centroid, node) order as `batch`.
x_tiled = x.repeat(repeats, n_centroids, 1, 1).reshape(-1, x.shape[-1])

# Per the original note, the output is laid out as (repeats, n_centroids, n_graphs, features).
centroid_x = model(x_tiled, batch, data.edge_index, data.edge_attr, node_mask, edge_mask)

In [None]:
# Per-graph node counts of the batch. `data.ptr` is the public cumulative node-offset vector
# that `Batch.from_data_list` attaches. It is already on `device`, so no private `_slice_dict`
# access and no extra host-to-device copy are needed.
cumsum_nnodes = data.ptr
nnodes_list = cumsum_nnodes[1:] - cumsum_nnodes[:-1]

In [None]:
# Base -> centroid connectivity ("low to high hierarchy").
# Every base node (tiled over `repeats`) links to each of the n_centroids centroids of its own graph copy.
node_ids = torch.arange(data.num_nodes * repeats, device=device)
src = node_ids.repeat_interleave(n_centroids)

# Graph-copy index (repeat * num_graphs + graph) of every tiled base node.
graph_of_node = torch.arange(repeats * data.num_graphs, device=device).repeat_interleave(
    nnodes_list.repeat(repeats))
centroid_offsets = torch.arange(n_centroids, device=device, dtype=torch.long)

# Centroid id = graph_copy * n_centroids + centroid, listed node-major to line up with `src`.
dst = (graph_of_node[:, None] * n_centroids + centroid_offsets[None, :]).reshape(-1)

In [None]:
from torch_scatter import scatter_sum
from torch_geometric.utils import to_undirected

# Centroid pairs: strict upper triangle first, then the diagonal. The `-n_centroids:` slice below
# relies on the diagonal pairs coming last. The pairs are built directly as a long tensor on `device`,
# avoiding the numpy round-trip, whose integer dtype is platform dependent.
upper_pairs = torch.triu_indices(n_centroids, n_centroids, offset=1, device=device)
diag_pairs = torch.arange(n_centroids, device=device).repeat(2, 1)
idx = torch.cat([upper_pairs, diag_pairs], dim=1)

# Graph id of every edge. In a PyG batch, edges stay inside their graph, so the source node's graph suffices.
edge_graph = data.batch[data.edge_index[0]]

# (repeats, n_pairs, num_edges): mask product of each edge whose source is in centroid idx[0]
# and whose target is in centroid idx[1].
pair_edge_mask = (node_mask[:, idx[0], :, 0][:, :, data.edge_index[0]] *
                  node_mask[:, idx[1], :, 0][:, :, data.edge_index[1]])

# Sum per graph -> (repeats, n_pairs, num_graphs). Without dim_size, trailing graphs that have
# no edges would silently disappear from the output.
intra_num_edges = scatter_sum(pair_edge_mask, edge_graph, dim=2, dim_size=data.num_graphs)

# Within one centroid, each edge is presumably seen in both directions (edge_index stores both
# directions of a bond), so the diagonal counts are halved.
intra_num_edges[:, -n_centroids:, :] = intra_num_edges[:, -n_centroids:, :] / 2

idx, intra_edge_weights = to_undirected(idx,
                                        intra_num_edges.permute(1, 0, 2),
                                        reduce='mean')

# Normalize by the largest count, treated as a constant for autograd. The clamp prevents
# 0/0 -> NaN weights when every count is zero.
max_count = intra_edge_weights.detach().max().clamp_min(torch.finfo(intra_edge_weights.dtype).eps)
intra_edge_weights = intra_edge_weights.permute(1, 0, 2) / max_count

In [None]:
from torch_geometric.data import HeteroData

# Two-level heterogeneous graph, tiled `repeats` times:
#   base     - original nodes, one copy per repeat
#   centroid - n_centroids nodes per (repeat, graph)
# NOTE(review): `dst` numbers centroids as (repeat * num_graphs + graph) * n_centroids + centroid,
# but the model cell's note describes centroid_x as (repeats, n_centroids, n_graphs, features).
# Confirm GNNMultiEdgeset emits rows in (repeat, graph, centroid) order; otherwise centroid
# features end up attached to the wrong edges.
repeat_ids = torch.arange(repeats, device=device)
base_edge_index = data.edge_index.repeat(1, repeats) + \
    repeat_ids.repeat_interleave(data.num_edges) * data.num_nodes
if data.edge_attr.dim() == 1:
    base_edge_attr = data.edge_attr.repeat(repeats)
else:
    base_edge_attr = data.edge_attr.repeat(repeats, 1)

# Membership weights ordered (repeat, node, centroid) to line up with the (src, dst) pairs.
membership = node_mask.squeeze(-1).permute(0, 2, 1)

# Centroid-centroid pairs tiled once per (repeat, graph), shifted into that copy's centroid id range.
graph_copy_offsets = torch.arange(data.num_graphs * repeats, device=device) * n_centroids
centroid_edge_index = idx.repeat(1, data.num_graphs * repeats) + \
    graph_copy_offsets.repeat_interleave(idx.shape[1])

new_data = HeteroData(
    base={'x': x.repeat(repeats, 1)},
    centroid={'x': centroid_x},
    base__to__base={'edge_index': base_edge_index,
                    'edge_attr': base_edge_attr,
                    'edge_weight': None},
    base__to__centroid={'edge_index': torch.vstack([src, dst]),
                        'edge_attr': None,
                        'edge_weight': membership.reshape(-1)},
    centroid__to__base={'edge_index': torch.vstack([dst, src]),
                        'edge_attr': None,
                        'edge_weight': membership.reshape(-1)},
    centroid__to__centroid={'edge_index': centroid_edge_index,
                            'edge_attr': None,
                            'edge_weight': intra_edge_weights.permute(0, 2, 1).reshape(-1)},
)

In [None]:
from models.hetero_gnn import HeteroGNN

In [None]:
# Heterogeneous GNN over the base/centroid graph. The positional arguments are forwarded
# unchanged; their meaning is defined by models.hetero_gnn.HeteroGNN (not visible here).
hetero_gnn_args = (
    'gine',
    True,
    bond_encoder,
    128,
    3,
    3,
    3,
    5,
    0.,
    'batchnorm',
    True,
)
gnn = HeteroGNN(*hetero_gnn_args).to(device)

In [None]:
# Smoke test: one forward/backward pass through the hetero GNN.
# new_data's tensors (tiled `x`, centroid_x, edge weights) are outputs of the autograd graph built
# in earlier cells. Without retain_graph=True, re-running this cell raises "Trying to backward
# through the graph a second time". Old gradients are cleared first so repeated runs do not accumulate.
for module in (gnn, model, atom_encoder, bond_encoder):
    module.zero_grad(set_to_none=True)
scores.grad = None

gnn_out = gnn(new_data)
# gnn presumably returns a tuple/sequence; its first element is reduced to a scalar loss.
gnn_out[0].sum().backward(retain_graph=True)