### Autoregressive Diffusion Models on Graphs

In [1]:
import torch_geometric
from torch_geometric.datasets import ZINC
from tqdm import tqdm
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Instantiate the ZINC molecule dataset; files are downloaded to and cached under /tmp/ZINC.
dataset = ZINC(root='/tmp/ZINC', transform=None, pre_transform=None)

Downloading https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1
Extracting /tmp/ZINC/molecules.zip
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/train.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/val.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/test.index
Processing...
Processing train dataset: 100%|██████████| 220011/220011 [00:14<00:00, 15604.59it/s]
Processing val dataset: 100%|██████████| 24445/24445 [00:02<00:00, 10267.08it/s]
Processing test dataset: 100%|██████████| 5000/5000 [00:00<00:00, 12979.67it/s]
Done!


In [3]:
# Pick one sample graph to experiment with; the bare name on the last line displays its summary.
point = dataset[2]
point

Data(x=[26, 1], edge_index=[2, 58], edge_attr=[58], y=[1])

In [4]:
# Per-edge attributes of the sample graph (one value per column of edge_index).
point.edge_attr

tensor([2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2,
        1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2,
        1, 2, 1, 1, 1, 1, 1, 1, 1, 1])

In [6]:
# absence of edges should be an edge type, attribute 0
# fully connect graph

In [7]:
# Target value attached to the sample graph.
point.y

tensor([1.0851])

In [8]:
'''
Masking node mechanism
1. Masked node (x = -1)
2. Connected to all other nodes in graph by masked edges (edge_attr = -1)
'''

'\nMasking node mechanism\n1. Masked node (x = 0)\n2. Connected to all other nodes in graph by masked edges (edge_attr = -1)\n'

In [9]:
# Dump every field of the sample as a plain dict for inspection.
point.to_dict()

{'x': tensor([[0],
         [0],
         [0],
         [2],
         [0],
         [5],
         [4],
         [0],
         [0],
         [2],
         [0],
         [0],
         [0],
         [1],
         [2],
         [0],
         [0],
         [0],
         [0],
         [5],
         [2],
         [2],
         [0],
         [0],
         [0],
         [0]]),
 'edge_index': tensor([[ 0,  1,  1,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  6,  6,  6,  7,  8,
           8,  8,  9,  9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 14, 14, 15, 15,
          15, 16, 17, 17, 17, 18, 19, 19, 20, 20, 21, 21, 22, 22, 22, 23, 23, 23,
          24, 24, 25, 25],
         [ 1,  0,  2,  1,  3,  2,  4, 22,  3,  5, 20,  4,  6,  5,  7,  8,  6,  6,
           9, 14,  8, 10,  9, 11, 19, 10, 12, 15, 11, 13, 14, 12,  8, 12, 11, 16,
          17, 15, 15, 18, 19, 17, 10, 17,  4, 21, 20, 22,  3, 21, 23, 22, 24, 25,
          23, 25, 23, 24]]),
 'edge_attr': tensor([2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 

In [10]:
'''
Node decay ordering, for forward absorbing pass
* Naive: random
Later on, use diffusion ordering network
'''

def node_decay_ordering(datapoint):
    """Return a random order in which the nodes of ``datapoint`` decay.

    Args:
        datapoint: graph object whose ``x`` has one row per node.

    Returns:
        list[int]: a random permutation of ``0 .. num_nodes - 1``.
    """
    num_nodes = datapoint.x.size(0)
    shuffled = torch.randperm(num_nodes)
    return shuffled.tolist()



def mask_node(datapoint, selected_node):
    '''
    Mask one node and every edge touching it, using -1 as the mask value.

    Fields of ``datapoint`` that are read or written:
      x:          node feature matrix (row ``selected_node`` is set to -1)
      edge_index: edge index matrix (rows 0 and 1 hold the two endpoints of each edge)
      edge_attr:  edge attributes (set to -1 for every edge with ``selected_node`` as an endpoint)

    ** Modifies ``datapoint`` in place and returns the same object.
    '''
    first_endpoint, second_endpoint = datapoint.edge_index[0], datapoint.edge_index[1]
    # an edge is incident to the node if either of its endpoints is the node
    touches_node = (first_endpoint == selected_node) | (second_endpoint == selected_node)

    datapoint.x[selected_node] = -1
    datapoint.edge_attr[touches_node] = -1
    return datapoint

# Example: draw one random decay order for the sample graph.
node_decay_ordering(point)

[17,
 7,
 12,
 1,
 18,
 20,
 14,
 13,
 25,
 16,
 22,
 23,
 6,
 9,
 19,
 4,
 10,
 24,
 2,
 11,
 8,
 15,
 3,
 5,
 21,
 0]

### Forward Diffusion Process

The node ordering $\sigma$ is sampled at random. Exactly one node decays at a time.

At each step $t$, distribution of $t$-th node is conditioned on original graph $G$ and generated node ordering $\sigma$ up to $t-1$.

In [11]:
# Run the full forward (absorbing) process on the sample: mask the nodes one at a time in a
# random order. mask_node works in place, so `point` is fully masked once this cell has run.
for node in node_decay_ordering(point):
    mask_node(point, node)

## Reverse Diffusion (Generative) Process

The denoising network $p_\theta(G_t \mid G_{t+1})$ is a (vanilla) graph attention network (GAT).

For now, simple graph convolutional network (GCN) is used.


**Note:** Initially, the graph is fully connected and fully masked, so in the paper only the masked node that is about to be denoised is kept during each generation step.

For now, we will just use the whole graph as well.

In [5]:
'''
Message passing
Custom message passing function
'''
from torch import nn
from torch_geometric.nn import MessagePassing, GATConv
from torch_geometric.utils import add_self_loops, degree
from torch.nn import functional as F


In [13]:

from torch.nn import Linear, Parameter
class MessagePassingLayer(MessagePassing):
    """GCN-style message passing layer: linear transform, normalised aggregation, bias.

    For every node the layer
      1. adds a self loop,
      2. applies the learnable linear map ``W`` to the node features,
      3. scales each message by ``deg^-1/2[row] * deg^-1/2[col]``,
      4. aggregates the incoming messages and adds a learnable bias.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        # NOTE(review): aggregation is 'mean' even though messages are already scaled by the
        # symmetric degree normalisation; the reference GCN layer uses 'add' — confirm intent.
        super().__init__(aggr='mean', flow="target_to_source")

        self.W = nn.Linear(in_channels, out_channels)
        # self.attention = nn.Linear(2 * in_channels, attention_dim) # No attention for now
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear map and zero the bias."""
        self.W.reset_parameters()
        self.bias.data.zero_()

    def message(self, x_j, norm):
        """Scale the neighbour features ``x_j`` ([E, out_channels]) by the per-edge ``norm``."""
        return norm.view(-1, 1) * x_j

    def forward(self, x, edge_index):
        """Compute the updated node features.

        Args:
            x: node feature matrix with one row per node; integer inputs are cast to the
                dtype of the layer weights.
            edge_index: edge index matrix with two rows (the endpoints of each edge).

        Returns:
            Tensor with one row of ``out_channels`` features per node.
        """
        # The linear map needs inputs in the dtype of its weights (features may arrive as ints).
        x = x.to(self.W.weight.dtype)

        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Learnable transform; previously W was created but never applied, so the layer
        # had no trainable effect besides the bias.
        x = self.W(x)

        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)  # self loops guarantee deg >= 1, so no division by zero
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # TODO: attention-weighted messages using edge_attr (not implemented yet).
        out = self.propagate(edge_index, x=x, norm=norm) # , edge_attr=edge_attr

        # out-of-place add keeps autograd happy regardless of how `out` was produced
        return out + self.bias
class DenoisingNetwork(nn.Module):
    """Stack of message passing layers followed by two per-node classification heads.

    Args:
        node_feature_dim: size of the node feature vectors; kept constant through all layers.
        num_node_types: number of classes produced by the node-type head.
        num_edge_types: number of classes produced by the edge-type head.
        num_layers: number of message passing layers.
        out_channels: currently unused; kept so existing callers keep working.
    """

    def __init__(self, node_feature_dim, num_node_types, num_edge_types, num_layers, out_channels):
        super().__init__()

        # self.embedding_layer = nn.Embedding(num_embeddings=node_feature_dim, embedding_dim=node_feature_dim)

        self.message_passing_layers = nn.ModuleList(
            [MessagePassingLayer(node_feature_dim, node_feature_dim) for _ in range(num_layers)]
        )

        # Node type prediction
        self.node_type_prediction = nn.Linear(node_feature_dim, num_node_types)

        # Edge type prediction
        self.edge_type_prediction = nn.Linear(node_feature_dim, num_edge_types)

    def forward(self, data):
        """Predict node-type and edge-type distributions for every node of ``data``.

        Args:
            data: graph object providing ``x`` (node features) and ``edge_index``.
                Edge attributes are not used yet.

        Returns:
            tuple: ``(node_type_probs, edge_type_probs)``, each with one row per node and
            softmax-normalised over the last dimension.
        """
        # Features must be floating point for the layers below. The previous cast to long
        # contradicted the comment and silently truncated float inputs.
        x = data.x.float()
        edge_index = data.edge_index.long()

        # h = self.embedding_layer(x) # bypass embedding
        h = x
        for layer in self.message_passing_layers:
            h = layer(h, edge_index)

        # Softmax turns the logits of each head into a multinomial distribution per node.
        node_type_probs = F.softmax(self.node_type_prediction(h), dim=-1)
        edge_type_probs = F.softmax(self.edge_type_prediction(h), dim=-1)

        return node_type_probs, edge_type_probs


In [14]:
# Graph Convolutional Network, with message passing
class OldDenoisingNetwork(torch.nn.Module):
    """Earlier denoising model: one GCN layer, batch norm, dropout, linear head, softmax.

    Args:
        input_dim: size of the input node features.
        hidden_dim: size of the hidden representation produced by the GCN layer.
        output_dim: size of the output produced by the linear head.
        num_layers: stored on the instance only; a single GCN layer is built.
        dropout: dropout probability.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 num_layers,
                 dropout):
        # Was `super(DenoisingNetwork, self)`, which names a different class and makes
        # construction fail; the zero-argument form always refers to this class.
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.gcn = torch_geometric.nn.GCNConv(input_dim, hidden_dim)
        self.batch_norm = torch.nn.BatchNorm1d(hidden_dim)
        # `self.dropout` is the module; the raw probability is not kept separately.
        self.dropout = torch.nn.Dropout(dropout)
        # embedding layer
        self.linear = torch.nn.Linear(hidden_dim, output_dim)
        # NOTE(review): dim=0 normalises across nodes, whereas DenoisingNetwork normalises
        # across classes (dim=-1) — confirm which is intended before reusing this model.
        self.softmax = torch.nn.Softmax(dim=0)

    def forward(self, x, edge_index, edge_attr, batch):
        """Run the network on one graph (or batch of graphs).

        Args:
            x: node feature matrix.
            edge_index: edge index matrix.
            edge_attr: edge attributes; passed to the GCN layer as its third argument.
            batch: batch vector (currently unused).

        Returns:
            Tensor with one row per node and ``output_dim`` columns, softmax-normalised
            along dim 0.
        """
        # make sure node features and edge attributes are floats, and edge indices are ints
        x = x.float()
        edge_attr = edge_attr.float()
        edge_index = edge_index.long()

        # GCN -> batch norm -> dropout -> linear head -> softmax
        x = self.gcn(x, edge_index, edge_attr)
        x = self.batch_norm(x)
        x = self.dropout(x)
        x = self.linear(x)
        x = self.softmax(x)

        return x

In [15]:
# (Disabled) construction of the earlier GCN-based network, kept for reference:
# odn = OldDenoisingNetwork(input_dim=dataset.num_features,
#                         hidden_dim=64,
#                         output_dim=dataset.num_features,
#                         num_layers=3,
#                         dropout=0.5)

# Build the denoising network with three message passing layers.
# NOTE(review): node_feature_dim, num_node_types and num_edge_types are all set to
# dataset.num_features. The [26, 1] prediction shape printed in the next cell suggests this
# value is 1, and a softmax over a single class is always 1.0 — confirm whether the number of
# node / edge types was intended here instead.
dn = DenoisingNetwork(
    node_feature_dim=dataset.num_features,
    num_node_types=dataset.num_features,
    num_edge_types=dataset.num_features,
    num_layers=3,
    out_channels=1
)

In [16]:
# Smoke test: run the untrained network on a copy of the sample graph and check the
# shape of the edge-type output (one row per node).
p = point.clone()
pred_x, pred_edges = dn(p)
pred_edges.shape # connections of new node to all previous nodes

torch.Size([26, 1])

## Training

For now, use MSE / Reconstruction error instead of loss function from the paper.


Optimize by sampling multiple ($M$) diffusion trajectories, which enables training both the diffusion ordering network $q_\phi(\sigma \mid G_0)$ and the denoising network
$p_\theta(G_t \mid G_{t+1})$ using gradient descent.

Create $M$ trajectories (sequences of graphs) for each training graph, where the node decay order is sampled from $q_\phi(\sigma \mid G_0)$ (or at random, initially).

Train denoising network to minimize the negative VLB using SGD.

The diffusion ordering network can be updated with common RL optimization methods, e.g., the REINFORCE algorithm: create $M$ trajectories, compute the negative VLB to obtain rewards, and then update the parameters of the diffusion ordering network.

In [17]:
# import Variable class
from torch.autograd import Variable

In [19]:
# Train the denoising network on randomly ordered diffusion trajectories.
#
# For each of the first 100 graphs, M forward trajectories are built by masking one node at a
# time; the network then tries to undo each masking step (G_{t+1} -> G_t).
#
# NOTE(review): the loss below compares *sampled* node values with the target, and sampling is
# not differentiable, so no gradient reaches the parameters of `dn` and optimizer.step()
# cannot change the model yet. Making the loss a function of the predicted probabilities
# (e.g. a log-likelihood of the true node type) is required before this cell trains anything.

optimizer = torch.optim.Adam(dn.parameters(), lr=0.0001)
# make sure every model parameter is trainable
dn.requires_grad_(True)


M = 5 # number of diffusion trajectories to be created for each graph


with tqdm(range(100)) as pbar:
    for batch in pbar:
        graph = dataset[batch]
        original_data = graph.clone()

        # ---- forward process: build M masking trajectories for this graph ----
        diffusion_trajectories = []
        node_orders = []  # node_orders[k][t] is the node masked between step t and t+1 of trajectory k
        for _ in range(M):
            node_order = node_decay_ordering(graph)
            diffusion_trajectory = [original_data]
            masked_data = graph.clone()
            for node in node_order:
                # clone before masking so that every step keeps its own snapshot
                # (mask_node modifies its argument in place)
                masked_data = mask_node(masked_data.clone(), node)
                diffusion_trajectory.append(masked_data)

            diffusion_trajectories.append(diffusion_trajectory)
            node_orders.append(node_order)

        # ---- reverse process: predictions & loss ----
        for node_order, diffusion_trajectory in zip(node_orders, diffusion_trajectories):
            for t in range(1, len(diffusion_trajectory)-1):
                # Node that is masked in G_{t+1} but not in G_t, i.e. the one to denoise now.
                # (Previously a stale loop variable from the trajectory construction was used.)
                node = node_order[t]

                G_t = diffusion_trajectory[t]
                G_t.x = G_t.x.float()

                G_tplus1 = diffusion_trajectory[t+1].clone()
                G_tplus1.x = G_tplus1.x.float()
                # Keeps loss.backward() runnable although the loss is not connected to the
                # model (see the note at the top of the cell). edge_index stays integer:
                # it holds indices and is never differentiated.
                G_tplus1.x.requires_grad = True

                G_pred = G_tplus1.clone()

                # predict node type
                node_type_probs, node_connections_probs = dn(G_tplus1)
                # sample node type
                # NOTE(review): after squeeze() the multinomial is taken over whatever the
                # last remaining dimension is; with a single class that is the node axis —
                # confirm the intended sampling once the number of classes is larger than one.
                pred_node_type = torch.distributions.multinomial.Multinomial(1, node_type_probs.squeeze())
                # add node type to node
                G_pred.x[node] = pred_node_type.sample()[node]
                # sample edge type
                new_connections = torch.distributions.multinomial.Multinomial(1, node_connections_probs.squeeze()).sample()
                # add new connections to edge_attr
                for i in  range(len(new_connections)):
                    new_connection = new_connections[i]
                    if new_connection != -1:
                        G_pred.edge_attr[G_pred.edge_index[0] == i] = new_connection
                # calculate loss
                node_loss = torch.nn.functional.mse_loss(G_t.x.float(), G_pred.x.float())
                # connections_loss = torch.nn.functional.mse_loss(G_t.edge_index[0].float(), G_pred.edge_index[0].float())
                loss = node_loss #+ connections_loss
                # backprop; clear gradients first so they do not accumulate across steps
                optimizer.zero_grad()
                loss.backward()
                # update parameters
                optimizer.step()
                # log the actual loss value (it was previously shown modulo 10)
                pbar.set_description(f"Epoch: {batch}, Loss: {loss.item():.4f}")
            
        
'''
TODO: edge_attributes have to be used. The masked nodes can be identified through them. There'll never be error in edge_index
'''

Epoch: 99, Loss: 0.0455: 100%|██████████| 100/100 [00:35<00:00,  2.81it/s]


"\nTODO: edge_attributes have to be used. The masked nodes can be identified through them. There'll never be error in edge_index\n"

In [20]:
# Sanity check: apply one denoising step to the last graph of the last training trajectory
# (the graph in which every node has been masked).
fully_masked = diffusion_trajectory[-1]
G_pred = fully_masked.clone()
node_type_probs, node_connections_probs = dn(fully_masked)
node_type_probs.requires_grad_()
node_connections_probs.requires_grad_()
# distribution from which the node type is drawn
pred_node_type = torch.distributions.multinomial.Multinomial(1, node_type_probs.squeeze())
# write the sampled value into the node (`node` is left over from the training cell)
G_pred.x[node] = pred_node_type.sample()[node]
# draw the connection values
new_connections = torch.distributions.multinomial.Multinomial(1, node_connections_probs.squeeze()).sample()
# copy every non-masked value onto the edges whose first endpoint is node i
for i, new_connection in enumerate(new_connections):
    if new_connection != -1:
        G_pred.edge_attr[G_pred.edge_index[0] == i] = new_connection

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [22]:
 == 

SyntaxError: invalid syntax (2539200343.py, line 1)