# Autoregressive Diffusion Models on Graphs

In [1]:
import torch_geometric
from torch_geometric.datasets import ZINC
from tqdm import tqdm
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# instanciate the dataset
dataset = ZINC(root='/tmp/ZINC', transform=None, pre_transform=None)

Downloading https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1
Extracting /tmp/ZINC/molecules.zip
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/train.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/val.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/test.index
Processing...
Processing train dataset: 100%|██████████| 220011/220011 [00:12<00:00, 18195.33it/s]
Processing val dataset: 100%|██████████| 24445/24445 [00:02<00:00, 10637.45it/s]
Processing test dataset: 100%|██████████| 5000/5000 [00:00<00:00, 15249.15it/s]
Done!


In [3]:
point = dataset[10]
point

Data(x=[22, 1], edge_index=[2, 48], edge_attr=[48], y=[1])

In [4]:
point.edge_attr

tensor([1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1,
        2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1])

In [5]:
# absence of edges should be an edge type, attribute 0
# fully connect graph

In [6]:
point.y

tensor([-0.2184])

In [7]:
'''
Masking node mechanism
1. Masked node (x = -1)
2. Connected to all other nodes in graph by masked edges (edge_attr = -1)
'''

'\nMasking node mechanism\n1. Masked node (x = -1)\n2. Connected to all other nodes in graph by masked edges (edge_attr = -1)\n'

In [8]:
point.to_dict()

{'x': tensor([[0],
         [4],
         [2],
         [0],
         [1],
         [4],
         [1],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [2],
         [2],
         [0],
         [0],
         [0],
         [0],
         [0],
         [2]]),
 'edge_index': tensor([[ 0,  1,  1,  1,  2,  2,  3,  3,  3,  4,  5,  5,  5,  6,  7,  7,  7,  8,
           8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 16, 16,
          16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 21],
         [ 1,  0,  2, 13,  1,  3,  2,  4,  5,  3,  3,  6,  7,  5,  5,  8, 12,  7,
           9,  8, 10,  9, 11, 10, 12,  7, 11,  1, 14, 21, 13, 15, 14, 16, 15, 17,
          21, 16, 18, 17, 19, 18, 20, 19, 21, 13, 16, 20]]),
 'edge_attr': tensor([1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1,
         2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1]),
 'y': tensor([-0.2184])}

In [9]:
'''
Node decay ordering, for forward absorbing pass
* Naive: random
Later on, use diffusion ordering network
'''

def node_decay_ordering(datapoint):
    # create random list of nodes
    return torch.randperm(datapoint.x.shape[0]).tolist()



def mask_node(datapoint, selected_node):
    '''
    datapoint.x: node feature matrix
    datapoint.edge_index: edge index matrix
    datapoint.edge_attr: edge attribute matrix
    datapoint.y: target value

    ** Changes datapoint inplace
    '''
    # mask node
    datapoint.x[selected_node] = 0
    # mask edges
    datapoint.edge_attr[datapoint.edge_index[0] == selected_node] = 0
    datapoint.edge_attr[datapoint.edge_index[1] == selected_node] = 0
    return datapoint

node_decay_ordering(point)

[17, 19, 9, 1, 2, 15, 21, 13, 10, 4, 8, 0, 12, 14, 5, 6, 18, 20, 16, 3, 11, 7]

### Forward Diffusion Process

Node orderint $\sigma$ randomly sampled. Exactly one node decays at a time.

At each step $t$, distribution of $t$-th node is conditioned on original graph $G$ and generated node ordering $\sigma$ up to $t-1$.

In [10]:
for node in node_decay_ordering(point):
    mask_node(point, node)

## Reverse Diffusion (Generative) Process

Denoising network $p_ \theta (G_t | G_{t+1})$ is a graph attention network (GAT). (Vanilla GAT)

For now, simple graph convolutional network (GCN) is used.


** Initially, the graph is fully connected and masked, so in the paper they only keep the masked node to be denoised during the generation step. 

For now, we will just use the whole graph as well.

In [11]:
'''
Message passing
Custom message passing function
'''
from torch import nn
from torch_geometric.nn import MessagePassing, GATConv
from torch_geometric.utils import add_self_loops, degree
from torch.nn import functional as F


Make sure attentive message passing is done correctly.


![Alt text](media/image.png)

In [12]:

from torch.nn import Linear, Parameter
class MessagePassingLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(MessagePassingLayer, self).__init__(aggr='mean', flow="target_to_source")  # 'add' aggregation for simplicity

        self.W = nn.Linear(in_channels, out_channels)
        # self.attention = nn.Linear(2 * in_channels, attention_dim) # TODO add attention mechanism
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.W.reset_parameters()
        self.bias.data.zero_()

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j # TODO change to adapt attention mechanism
    
    
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Attention mechanism
        # alpha = F.softmax(self.attention(torch.cat([x[row], x[col]], dim=-1)), dim=-1)
        # alpha = F.dropout(alpha, p=0.5)
        # messages = alpha.view(-1, 1) * edge_attr
        # messages = scatter_add(messages, col, dim=0, dim_size=x.size(0))

        # h = self.W(torch.cat([x, messages], dim=-1))

        out = self.propagate(edge_index, x=x, norm=norm) # , edge_attr=edge_attr
        out += self.bias

        return out
    

class DenoisingNetwork(nn.Module):
    def __init__(self,
                 node_feature_dim,
                 num_node_types,
                 num_edge_types,
                 num_layers,
                 out_channels):
        super(DenoisingNetwork, self).__init__()

        self.embedding_layer = nn.Embedding(num_embeddings=num_node_types, embedding_dim=node_feature_dim)

        print(node_feature_dim, num_node_types, num_edge_types, num_layers, out_channels)
        self.message_passing_layers = nn.ModuleList([
            GATConv(node_feature_dim, node_feature_dim, heads=1, dropout=0.6) for _ in range(num_layers)
        ])
        # TODO use models.GAT network instead of GATConv layers separately.
        
        # Node type prediction
        self.node_type_prediction = nn.Linear(node_feature_dim, out_channels)
        
        # Edge type prediction
        self.edge_type_prediction = nn.Linear(node_feature_dim, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x = x.squeeze()
        # Make sure x is float
        x = x.long()
        edge_index = edge_index.long()


        h = self.embedding_layer(x)

        for layer in self.message_passing_layers:

            h = layer(h, edge_index, edge_attr)

        # Node type prediction
        node_type_logits = self.node_type_prediction(h)
        # node_type_probs = F.softmax(node_type_logits, dim=-1)  # TODO Applying softmax for the multinomial distribution
        # TODO Apply multinomial distribution


        # Edge type prediction
        edge_type_logits = self.edge_type_prediction(h)
        # edge_type_probs = F.softmax(edge_type_logits, dim=-1)  # TODO Applying softmax for the multinomial distribution
        # TODO Apply multinomial distribution

        return node_type_logits, edge_type_logits


In [77]:
num_layers = 10
# num_features
dn = DenoisingNetwork(
    node_feature_dim=dataset.num_features,
    num_node_types=dataset.num_classes+1,
    num_edge_types=3,
    num_layers=num_layers,
    out_channels=1
)

1 218363 3 10 1


In [78]:
p = point.clone()
pred_x, pred_edges = dn(p)
pred_edges.shape # connections of new node to all previous nodes

torch.Size([22, 1])

## Training

For now, use MSE / Reconstruction error instead of loss function from the paper.


Optimize by sampling multiple ($M$) diffusion trajectories, thereby enabling training both the diffusion ordering network $q_ϕ(σ|G_0)$ and the denoising network
$p_θ(Gt|G_t+1)$ using gradient descent.

Create M trajectories (sequence of graphs) for each training graph. Where node decay order is sampled from $q_ϕ(σ|G_0)$ (or random, initially).

Train denoising network to minimize the negative VLB using SGD.

Diffusion ordering network can be updated with common RL optimization methods, e.g., the REINFORCE algorithm. Creating M trajectories and computing the negative VLB to obtain rewards, and then updating the parameters of the diffusion ordering network using the REINFORCE algorithm.

In [89]:
# train denoising diffusion model
# use node decay ordering

# use wandb to log the model
import wandb

for num_layers in range(0, 200, 10):

    dn = DenoisingNetwork(
        node_feature_dim=dataset.num_features,
        num_node_types=dataset.num_classes+1,
        num_edge_types=3,
        num_layers=num_layers,
        out_channels=1
    )

    wandb.init(
            project="ARGD",
            group=f"v0.10_{num_layers}_layers",
            name=f"overfit_sample_right_mse_zerograd",
            # track hyperparameters and run metadata
            config={
                "policy": "train",
                "n_epochs": 10000,
                "batch_size": 1,
            }
        )

    optimizer = torch.optim.Adam(dn.parameters(), lr=0.0001)
    loss_fcn = torch.nn.MSELoss() # TODO use appropriate loss function (possibly NLLLoss)
    # require grad for model
    dn.train()

    M = 10 # number of diffusion trajectories to be created for each graph


    with tqdm(range(50)) as pbar:
        for batch in pbar:
            graph = dataset[batch]
            # print(f"Generating trajectories for graph")
            # node decay ordering, accoding to node_decay_ordering
            original_data = graph.clone()
            diffusion_trajectories = []
            for m in range(M):
                node_order = node_decay_ordering(graph)
                node_order = list(range(graph.x.shape[0]))
                # create diffusion trajectory
                diffusion_trajectory = [original_data]
                masked_data = graph.clone()
                for node in node_order:
                    masked_data = masked_data.clone()
                    
                    masked_data = mask_node(masked_data, node)
                    diffusion_trajectory.append(masked_data)

                diffusion_trajectories.append(diffusion_trajectory)
            
            # predictions & loss
            for diffusion_trajectory in diffusion_trajectories:
                G_0 = diffusion_trajectory[0]
                optimizer.zero_grad()
                for t in range(1, len(diffusion_trajectory)-1):
                    G_t = diffusion_trajectory[t]
                    # transform to float
                    G_t.x = G_t.x.float()
                    G_t.edge_index = G_t.edge_index.float()

                    G_tplus1 = diffusion_trajectory[t+1].clone()
                    # transform to float
                    G_tplus1.x = G_tplus1.x.float()
                    G_tplus1.edge_index = G_tplus1.edge_index.float()

                    G_pred = G_tplus1.clone()

                    # predict node type
                    node_type_probs, node_connections_probs = dn(G_tplus1)
                    # sample node type
                    # pred_node_type = torch.multinomial( node_type_probs.squeeze(), num_samples=1, replacement=True)
                    pred_node_type = node_type_probs[node]
                    # add node type to node
                    G_pred.x[node] = pred_node_type
                    # sample edge type
                    # new_connections = torch.multinomial(node_connections_probs.squeeze(), num_samples=1, replacement=True)
                    new_connections = node_connections_probs[node]
                    '''
                    TODO make sure you 
                    "predict the connections of the new node to all previous nodes at once 
                                            using a mixture of multinomial distribution"
                    '''
                    # add new connections to edge_attr
                    for i in  range(len(new_connections)):
                        new_connection = new_connections[i]
                        if new_connection != -1:
                            G_pred.edge_attr[G_pred.edge_index[0] == i] = new_connection
                    # calculate loss
                    loss = loss_fcn(G_t.x.float(), G_pred.x.float())
                    # node_loss = torch.nn.functional.mse_loss(G_t.x.float(), G_pred.x.float())
                    # connections_loss = torch.nn.functional.mse_loss(G_t.edge_index[0].float(), G_pred.edge_index[0].float())
                    # loss = node_loss + connections_loss
                    # zero gradients

                    # backprop
                    loss.backward()

                    # DEBUG if gradients are not None
                    # for name, param in dn.named_parameters():
                    #     if param.grad != None and torch.sum(param.grad) != 0:
                    #         print(name, param.grad.shape, torch.sum(param.grad))
                    
                    # update parameters
                    optimizer.step()
                    # log loss
                    pbar.set_description(f"Epoch: {batch}, Loss: {loss.item()%10:.4f}")
                    wandb.log({"loss": loss.item()})
            # save model
            torch.save(dn.state_dict(), "ardm_model.pt")
                
            
    '''
    TODO: edge_attributes have to be used. The masked nodes can be identified through them. There'll never be error in edge_index
    '''

1 218363 3 0 1


0,1
loss,▂▁▁▁▁▁▁▄▃▁▁▁▁▁▁▁▂▂▁▁▂▁▁▁▄▁▂▃█▁▂▃▁▃▁▁▁▁▁▁

0,1
loss,0.03957


Epoch: 49, Loss: 0.0002: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s]


1 218363 3 10 1


0,1
loss,▄▁▁▁▁▂▁▇▆▃▂▁▁▁▁▁▂▃▁▁▃▁▁▁▄▁▂▅█▁▁▃▁▅▁▁▁▁▂▁

0,1
loss,0.00018


Epoch: 49, Loss: 0.0229: 100%|██████████| 50/50 [02:43<00:00,  3.27s/it]


1 218363 3 20 1


0,1
loss,▃▁▁▁▁▁▁▅▄▂▁▁▁▁▁▁▂▂▁▁▂▁▁▁▄▁▂▃█▁▂▃▁▄▁▁▁▁▁▁

0,1
loss,0.02294


Epoch: 49, Loss: 0.0213: 100%|██████████| 50/50 [04:20<00:00,  5.20s/it]


1 218363 3 30 1


0,1
loss,▄▁▁▁▁▂▁▅▅▂▁▁▁▁▁▁▂▂▁▁▂▁▁▁▄▁▂▄█▁▂▃▁▄▁▁▁▁▁▁

0,1
loss,0.02132


Epoch: 49, Loss: 0.0085: 100%|██████████| 50/50 [05:41<00:00,  6.84s/it]


1 218363 3 40 1


0,1
loss,▄▁▁▁▁▂▁▆▅▃▂▁▁▁▁▁▂▃▁▁▂▁▁▁▄▁▂▄█▁▂▃▁▄▁▁▁▁▁▁

0,1
loss,0.00845


Epoch: 49, Loss: 0.0171: 100%|██████████| 50/50 [06:54<00:00,  8.29s/it]


1 218363 3 50 1


0,1
loss,▃▁▁▁▁▂▁▅▅▂▁▁▁▁▁▁▂▃▁▁▂▁▁▁▄▁▂▃█▁▂▃▁▄▁▁▁▁▁▁

0,1
loss,0.01715


Epoch: 49, Loss: 0.0025: 100%|██████████| 50/50 [08:25<00:00, 10.11s/it]


1 218363 3 60 1


0,1
loss,▄▁▁▁▁▂▁▆▆▂▂▁▁▁▁▁▂▃▁▁▂▁▁▁▄▁▂▄█▁▁▃▁▃▁▁▁▁▁▁

0,1
loss,0.00255


Epoch: 4, Loss: 0.0104:   8%|▊         | 4/50 [00:51<09:50, 12.83s/it]


KeyboardInterrupt: 

wandb: Network error (ConnectionError), entering retry loop.


In [32]:
# load model
dn.load_state_dict(torch.load("ardm_model.pt"))

<All keys matched successfully>

In [83]:
diffusion_trajectory[0].x

tensor([[0],
        [2],
        [0],
        [1],
        [4],
        [0],
        [0],
        [0],
        [2],
        [0],
        [2],
        [0],
        [2],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [3],
        [0],
        [0],
        [2],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]])

In [84]:
G_pred = diffusion_trajectory[-1].clone()
G_pred.x = G_pred.x.float()

In [85]:
forward_pass = [G_pred]

for i in range(G_pred.x.shape[0]):
    # 1 diffusion step for each node
    with torch.no_grad():
        node_type_probs, node_connections_probs = dn(G_pred)    
        # sample node type
        # pred_node_type = torch.multinomial( node_type_probs.squeeze(), num_samples=1, replacement=True)
        pred_node_type = node_type_probs[i]
        # add node type to node
        G_pred.x[i] = pred_node_type
        print(pred_node_type)
        # sample edge type
        # new_connections = torch.multinomial(node_connections_probs.squeeze(), num_samples=1, replacement=True)
        new_connections = node_connections_probs[i]
        # add new connections to edge_attr
        for i in  range(len(new_connections)):
            new_connection = new_connections[i]
            if new_connection != -1:
                G_pred.edge_attr[G_pred.edge_index[0] == i] = new_connection

        forward_pass.append(G_pred.clone())
    # print(G_pred.x)
    # print(G_pred.edge_attr)
    # print(G_pred.edge_index)

tensor([1.2445])
tensor([1.2445])
tensor([0.7421])
tensor([1.2445])
tensor([1.0534])
tensor([1.0167])
tensor([0.9974])
tensor([1.0024])
tensor([1.1158])
tensor([1.1269])
tensor([0.9464])
tensor([0.9622])
tensor([1.2445])
tensor([0.7681])
tensor([0.9999])
tensor([1.0590])
tensor([1.0460])
tensor([1.2445])
tensor([0.9554])
tensor([1.0220])
tensor([0.6927])
tensor([0.9803])
tensor([0.8272])
tensor([0.8478])
tensor([0.7732])
tensor([0.9623])
tensor([0.8723])
tensor([0.7071])
tensor([1.1206])
tensor([0.8973])
tensor([0.8551])
tensor([0.7293])


In [86]:
G_pred.x.shape[0]

32

In [87]:
forward_pass[-1].x

tensor([[1.2445],
        [1.2445],
        [0.7421],
        [1.2445],
        [1.0534],
        [1.0167],
        [0.9974],
        [1.0024],
        [1.1158],
        [1.1269],
        [0.9464],
        [0.9622],
        [1.2445],
        [0.7681],
        [0.9999],
        [1.0590],
        [1.0460],
        [1.2445],
        [0.9554],
        [1.0220],
        [0.6927],
        [0.9803],
        [0.8272],
        [0.8478],
        [0.7732],
        [0.9623],
        [0.8723],
        [0.7071],
        [1.1206],
        [0.8973],
        [0.8551],
        [0.7293]])