# Autoregressive Diffusion Models on Graphs

In [2]:
import torch_geometric
from torch_geometric.datasets import ZINC
from tqdm import tqdm
import torch
from torch import nn
from torch_geometric.nn import MessagePassing, GATConv, GAT
from torch_geometric.utils import add_self_loops, degree
from torch.nn import functional as F
import math

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# instanciate the dataset
dataset = ZINC(root='/tmp/ZINC', transform=None, pre_transform=None)

Downloading https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1
Extracting /tmp/ZINC/molecules.zip
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/train.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/val.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/test.index
Processing...
Processing train dataset: 100%|██████████| 220011/220011 [00:16<00:00, 13174.82it/s]
Processing val dataset: 100%|██████████| 24445/24445 [00:02<00:00, 8901.46it/s] 
Processing test dataset: 100%|██████████| 5000/5000 [00:00<00:00, 12933.02it/s]
Done!


In [57]:
point = dataset[1]
point

Data(x=[18, 1], edge_index=[2, 36], edge_attr=[36], y=[1])

In [6]:
point.edge_attr

tensor([1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1,
        2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1])

In [7]:
# absence of edges should be an edge type, attribute 0
# fully connect graph

In [8]:
point.y

tensor([-0.2184])

In [9]:
'''
Masking node mechanism
1. Masked node (x = -1)
2. Connected to all other nodes in graph by masked edges (edge_attr = -1)
'''

'\nMasking node mechanism\n1. Masked node (x = -1)\n2. Connected to all other nodes in graph by masked edges (edge_attr = -1)\n'

In [10]:
point.to_dict()

{'x': tensor([[0],
         [4],
         [2],
         [0],
         [1],
         [4],
         [1],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [2],
         [2],
         [0],
         [0],
         [0],
         [0],
         [0],
         [2]]),
 'edge_index': tensor([[ 0,  1,  1,  1,  2,  2,  3,  3,  3,  4,  5,  5,  5,  6,  7,  7,  7,  8,
           8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 16, 16,
          16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 21],
         [ 1,  0,  2, 13,  1,  3,  2,  4,  5,  3,  3,  6,  7,  5,  5,  8, 12,  7,
           9,  8, 10,  9, 11, 10, 12,  7, 11,  1, 14, 21, 13, 15, 14, 16, 15, 17,
          21, 16, 18, 17, 19, 18, 20, 19, 21, 13, 16, 20]]),
 'edge_attr': tensor([1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1,
         2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1]),
 'y': tensor([-0.2184])}

In [65]:
'''
Node decay ordering, for forward absorbing pass
* Naive: random
Later on, use diffusion ordering network
'''

def node_decay_ordering(datapoint):
    # create random list of nodes
    return torch.randperm(datapoint.x.shape[0]).tolist()

def is_masked(datapoint, node=None):
    if node is None:
        return datapoint.x == -1
    return datapoint.x[node] == -1



def mask_node(datapoint, selected_node):
    '''
    datapoint.x: node feature matrix
    datapoint.edge_index: edge index matrix
    datapoint.edge_attr: edge attribute matrix
    datapoint.y: target value

    ** Changes datapoint inplace
    '''
    # mask node
    datapoint.x[selected_node] = -1
    # mask edges
    datapoint.edge_attr[datapoint.edge_index[0] == selected_node] = 0
    datapoint.edge_attr[datapoint.edge_index[1] == selected_node] = 0
    return datapoint

node_decay_ordering(point)

[13, 10, 11, 3, 17, 15, 16, 1, 2, 6, 0, 4, 12, 5, 9, 8, 7, 14]

![Alt text](media/don.png)

![Alt text](media/don_2.png)

In [75]:
class DiffusionOrderingNetwork(nn.Module):
    '''
    at each diffusion step t, we sample from this network to select a node 
    v_sigma(t) to be absorbed and obtain the corresponding masked graph Gt
    '''
    def __init__(self,
                 node_feature_dim,
                 num_node_types,
                 num_edge_types,
                 num_layers=3,
                 out_channels=1):
        super(DiffusionOrderingNetwork, self).__init__()

        # add positional encodings into node features
        self.embedding = nn.Embedding(100, node_feature_dim)

        self.gat = GAT(
            in_channels=node_feature_dim,
            out_channels=node_feature_dim,
            hidden_channels=6 * 6,
            num_layers=num_layers,
            dropout=0,
            heads=6,
        )


    def positionalencoding(self, lengths, permutations):
        '''
        From Chen, et al. 2021 (Order Matters: Probabilistic Modeling of Node Sequences for Graph Generation)
        '''
        # length = sum([len(perm) for perm in permutations])
        l_t = len(permutations[0])
        # pes = [torch.zeros(length, self.d_model) for length in lengths]
        pes = torch.split(torch.zeros((sum(lengths), self.d_model), device=self.device), lengths)
        # print(pes[0].device)
        position = torch.arange(0, l_t, device=self.device).unsqueeze(1) + 1
        div_term = torch.exp((torch.arange(0, self.d_model, 2, dtype=torch.float, device=self.device) *
                              -(math.log(10000.0) / self.d_model)))
        # test = torch.sin(position.float() * div_term)
        for i in range(len(lengths)):
            pes[i][permutations[i], 0::2] = torch.sin(position.float() * div_term)
            pes[i][permutations[i], 1::2] = torch.cos(position.float() * div_term)

        pes = torch.cat(pes)
        return pes

    def forward(self, G, p=None):
        h = self.gat(G.x.float(), G.edge_index)

        # TODO augment node features with positional encodings
        # if p is not None:
        #     # p = self.positionalencoding(G.batch_num_nodes().tolist(), p) original from Chen et al.
        #     p = self.positionalencoding(G.x.shape[0], p)
        #     h = h + p
        
        # softmax over nodes
        h = F.softmax(h, dim=0)
        
        return h # outputs probabilities for a categorical distribution over nodes

In [77]:
diff_ord_net = DiffusionOrderingNetwork(node_feature_dim=1,
                                        num_node_types=dataset.x.unique().shape[0],
                                        num_edge_types=3,
                                        num_layers=3,
                                        out_channels=1)
sigma_t_dist = diff_ord_net(point)
print(sigma_t_dist.flatten())
# sample from categorical distribution to get node to mask
# TODO only on the samples that are not masked
# sigma_t = F.softmax(sigma_t_dist.flatten(), dim=0)
torch.distributions.Categorical(probs=sigma_t_dist.flatten()).sample()

tensor([0.0544, 0.0550, 0.0558, 0.0584, 0.0593, 0.0616, 0.0595, 0.0586, 0.0567,
        0.0576, 0.0551, 0.0526, 0.0522, 0.0523, 0.0529, 0.0535, 0.0523, 0.0522],
       grad_fn=<ReshapeAliasBackward0>)


tensor(16)

### Forward Diffusion Process

Node ordering $\sigma$ using diffusion ordering network. Exactly one node decays at a time.

At each step $t$, distribution of $t$-th node is conditioned on original graph $G$ and generated node ordering $\sigma$ up to $t-1$.

In [86]:
def node_decay_ordering(datapoint):
    p = datapoint.clone()
    node_order = []
    for i in range(p.x.shape[0]):
        # use diffusion ordering network to get probabilities
        sigma_t_dist = diff_ord_net(p, i)
        # sample from categorical distribution to get node to mask
        # TODO only on the samples that are not masked
        unmasked = ~is_masked(p)
        sigma_t = torch.distributions.Categorical(probs=sigma_t_dist[unmasked].flatten()).sample()
        
        # get node index
        sigma_t = torch.where(unmasked.flatten())[0][sigma_t.long()]
        node_order.append(sigma_t)
        # mask node
        p = mask_node(p, sigma_t)
    return node_order

In [87]:
node_decay_ordering(point)

[tensor(5),
 tensor(11),
 tensor(1),
 tensor(12),
 tensor(15),
 tensor(6),
 tensor(3),
 tensor(14),
 tensor(8),
 tensor(16),
 tensor(9),
 tensor(10),
 tensor(13),
 tensor(4),
 tensor(17),
 tensor(0),
 tensor(7),
 tensor(2)]

## Reverse Diffusion (Generative) Process

Denoising network $p_ \theta (G_t | G_{t+1})$ is a graph attention network (GAT). (Vanilla GAT)

For now, simple graph convolutional network (GCN) is used.


** Initially, the graph is fully connected and masked, so in the paper they only keep the masked node to be denoised during the generation step. 

For now, we will just use the whole graph as well.

In [1]:
'''
Message passing
Custom message passing function
'''

'\nMessage passing\nCustom message passing function\n'

Make sure attentive message passing is done correctly.


![Alt text](media/image.png)

In [12]:

from torch.nn import Linear, Parameter
class MessagePassingLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(MessagePassingLayer, self).__init__(aggr='mean', flow="target_to_source")  # 'add' aggregation for simplicity

        self.W = nn.Linear(in_channels, out_channels)
        # self.attention = nn.Linear(2 * in_channels, attention_dim)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.W.reset_parameters()
        self.bias.data.zero_()

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j # TODO change to adapt attention mechanism
    
    
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Attention mechanism
        # alpha = F.softmax(self.attention(torch.cat([x[row], x[col]], dim=-1)), dim=-1)
        # alpha = F.dropout(alpha, p=0.5)
        # messages = alpha.view(-1, 1) * edge_attr
        # messages = scatter_add(messages, col, dim=0, dim_size=x.size(0))

        # h = self.W(torch.cat([x, messages], dim=-1))

        out = self.propagate(edge_index, x=x, norm=norm)
        out += self.bias

        return out
    


![Alt text](media/architecture.png)


In [169]:

class DenoisingNetwork(nn.Module):
    def __init__(self,
                 node_feature_dim,
                 num_node_types,
                 num_edge_types,
                 num_layers,
                 out_channels):
        super(DenoisingNetwork, self).__init__()

        self.embedding_layer = nn.Embedding(num_embeddings=num_node_types, embedding_dim=node_feature_dim)

        print(node_feature_dim, num_node_types, num_edge_types, num_layers, out_channels)

        self.GAT = GAT(
            in_channels=node_feature_dim,
            out_channels=node_feature_dim,
            hidden_channels=128,
            num_layers=num_layers,
            dropout=0,
        )

        # Custom message passing layers
        # self.message_passing_layers = nn.ModuleList([
        #     GATConv(node_feature_dim, node_feature_dim, heads=1, dropout=0.6) for _ in range(num_layers)
        # ])
        
        # Node type prediction
        self.node_type_prediction = nn.Linear(node_feature_dim, num_node_types) # Use only element of the new node
        
        # Edge type prediction
        self.edge_type_prediction = nn.Linear(node_feature_dim, num_edge_types) # Use all elements (connections to other nodes)

    def forward(self, data):

        '''
        Outputs: 
        new_node_type: type of new node to be unmasked
        new_edge_type: types of new edges from previous nodes to the one to be unmasked
        '''
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x = x.squeeze()
        # Make sure x is float
        x = x.long()
        edge_index = edge_index.long()


        h = self.embedding_layer(x)

        # for layer in self.message_passing_layers:
        #     h = layer(h, edge_index, edge_attr)


        h = self.GAT(h, edge_index)


        # Node type prediction
        node_type_logits = self.node_type_prediction(h)
        # Applying softmax for the multinomial distribution
        node_type_probs = F.softmax(node_type_logits, dim=-1)  
        



        # Edge type prediction
        edge_type_logits = self.edge_type_prediction(h)
        # Applying softmax for the multinomial distribution
        edge_type_probs = F.softmax(edge_type_logits, dim=-1)
        
        
        return node_type_probs, edge_type_probs

In [170]:
# num_features
dn = DenoisingNetwork(
    node_feature_dim=dataset.num_features,
    num_node_types=dataset.x.unique().shape[0],
    num_edge_types=3,
    num_layers=7,
    out_channels=1
)

1 28 3 7 1


In [171]:
p = point.clone()
node_type_probs, edge_type_probs = dn(p)
edge_type_probs # connections of new node to all previous nodes
node_type_probs[0].shape

torch.Size([28])

## Training

For now, use MSE / Reconstruction error instead of loss function from the paper.


Optimize by sampling multiple ($M$) diffusion trajectories, thereby enabling training both the diffusion ordering network $q_ϕ(σ|G_0)$ and the denoising network
$p_θ(Gt|G_t+1)$ using gradient descent.

Create M trajectories (sequence of graphs) for each training graph. Where node decay order is sampled from $q_ϕ(σ|G_0)$ (or random, initially).

Train denoising network to minimize the negative VLB using SGD.

Diffusion ordering network can be updated with common RL optimization methods, e.g., the REINFORCE algorithm. Creating M trajectories and computing the negative VLB to obtain rewards, and then updating the parameters of the diffusion ordering network using the REINFORCE algorithm.

In [201]:
# use wandb to log the model
import wandb

dn = DenoisingNetwork(
    node_feature_dim=dataset.num_features,
    num_node_types=dataset.x.unique().shape[0],
    num_edge_types=3,
    num_layers=7,
    out_channels=1
)

wandb.init(
        project="ARGD",
        group=f"v0.21",
        name=f"overfit_nozerograd",
        # track hyperparameters and run metadata
        config={
            "policy": "train",
            "n_epochs": 10000,
            "batch_size": 1,
        }
    )


1 28 3 7 1


0,1
loss,███▇▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,-1.0


![Alt text](media/optimizer.png)

In [204]:
# train denoising diffusion model
# use node decay ordering

optimizer = torch.optim.Adam(dn.parameters(), lr=1e-4, betas=(0.9, 0.999))
loss_fcn = torch.nn.NLLLoss() # TODO use appropriate loss function (possibly NLLLoss)
# require grad for model
dn.train()

M = 4 # number of diffusion trajectories to be created for each graph


with tqdm(range(50)) as pbar:
    for batch in pbar:
        graph = dataset[0]
        # print(f"Generating trajectories for graph")
        # node decay ordering, accoding to node_decay_ordering
        original_data = graph.clone()
        diffusion_trajectories = []
        node_order = torch.range(0, graph.x.shape[0]-1).long()
        for m in range(M):
            # node_order = node_decay_ordering(graph)
            
            # create diffusion trajectory
            diffusion_trajectory = [original_data]
            masked_data = graph.clone()
            for node in node_order:
                masked_data = masked_data.clone()
                
                masked_data = mask_node(masked_data, node)
                diffusion_trajectory.append(masked_data)

            diffusion_trajectories.append(diffusion_trajectory)
        
        # predictions & loss
        for diffusion_trajectory in diffusion_trajectories:
            G_0 = diffusion_trajectory[0]
            # optimizer.zero_grad()
            # node_order = node_decay_ordering(G_0) 
            for t in range(1, len(diffusion_trajectory)-1):
                node = node_order[len(diffusion_trajectory)-t-1]
                G_t = diffusion_trajectory[t]
                # transform to float
                G_t.x = G_t.x.float()
                G_t.edge_index = G_t.edge_index.float()

                G_tplus1 = diffusion_trajectory[t+1].clone()
                # transform to float
                G_tplus1.x = G_tplus1.x.float()
                G_tplus1.edge_index = G_tplus1.edge_index.float()

                G_pred = G_tplus1.clone()
                

                # predict node type
                node_type_probs, edge_type_probs = dn(G_tplus1)

                node_type_probs = node_type_probs[node]
                
                #  Apply multinomial distribution
                # node_dist = torch.distributions.Multinomial(probs=node_type_probs.squeeze(), total_count=1)
                # node_type = node_dist.sample()
                # edge_dist = torch.distributions.Multinomial(probs=edge_type_probs.squeeze(), total_count=G_tplus1.x.shape[0])
                # node_connections = edge_dist.sample()

                # add node type to node
                node_dist = torch.distributions.Categorical(probs=node_type_probs.squeeze())
                node_type = node_dist.sample()
                # print(node_type)
                # G_pred.x[node] = node_type
                # sample edge type
                # new_connections = torch.multinomial(node_connections_probs.squeeze(), num_samples=1, replacement=True)
                '''
                TODO make sure you 
                "predict the connections of the new node to all previous nodes at once 
                                        using a mixture of multinomial distribution"
                '''
                # add new connections to edge_attr
                # for i in  range(len(node_connections)):
                #     new_connection = node_connections[i]
                #     if new_connection != 0:
                #         G_pred.edge_attr[G_pred.edge_index[0] == i] = new_connection
                
                # calculate loss
                # loss = loss_fcn(G_t.x.float(), G_pred.x.float())

                loss = loss_fcn(node_type_probs.reshape(1, -1), G_t.x[node].long())

                # backprop
                loss.backward()

                # DEBUG if gradients are not None
                # for name, param in dn.named_parameters():
                #     if param.grad != None and torch.sum(param.grad) != 0:
                #         print(name, param.grad.shape, torch.sum(param.grad))
                
                # update parameters
                optimizer.step()
                # log loss
                pbar.set_description(f"Epoch: {batch}, Loss: {loss.item()%10:.4f}")
                wandb.log({"loss": loss.item()})
        # save model
        torch.save(dn.state_dict(), "ardm_model.pt")
            
            
    '''
    TODO: edge_attributes have to be used. The masked nodes can be identified through them. There'll never be error in edge_index
    '''

  node_order = torch.range(0, graph.x.shape[0]-1).long()
Epoch: 26, Loss: 0.0000:  52%|█████▏    | 26/50 [00:42<00:39,  1.65s/it]


KeyboardInterrupt: 

In [175]:
diffusion_trajectory[0].x

tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [2],
        [2],
        [0],
        [2],
        [0],
        [0],
        [0],
        [0],
        [0],
        [2],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [2],
        [0],
        [5]])

In [176]:
G_pred = diffusion_trajectory[-1].clone()
G_pred.x = G_pred.x.float()

In [189]:
node_connections.shape

torch.Size([33, 3])

In [192]:
forward_pass = [G_pred]

for i in node_order:
    # 1 diffusion step for each node
    with torch.no_grad():
        node_type_probs, edge_type_probs = dn(G_pred)
        # sample node type
        node_dist = torch.distributions.Categorical(probs=node_type_probs.squeeze())
        node_type = node_dist.sample()[i]
        print(node_type)
        edge_dist = torch.distributions.Multinomial(probs=edge_type_probs.squeeze(), total_count=G_tplus1.x.shape[0])
        node_connections = edge_dist.sample()
        G_pred.x[i] = node_type
        # add new connections to edge_attr
        # for i in  range(len(node_connections)):
        #     new_connection = node_connections[i]
        #     if new_connection != 0:
        #         G_pred.edge_attr[G_pred.edge_index[0] == i] = new_connection
        
        forward_pass.append(G_pred.clone())
    # print(G_pred.x)
    # print(G_pred.edge_attr)
    # print(G_pred.edge_index)

tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)


In [191]:
forward_pass[-1].x

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])