# Autoregressive Diffusion Models on Graphs

In [12]:
import torch_geometric
from torch_geometric.datasets import ZINC
from tqdm import tqdm
import torch
from torch import nn
from torch_geometric.nn import MessagePassing, GAT
from torch_geometric.utils import add_self_loops, degree
from torch.nn import functional as F
import math

from benchmarks.GraphARM.models import DiffusionOrderingNetwork, DenoisingNetwork
from benchmarks.GraphARM.utils import NodeMasking

In [3]:
# instanciate the dataset
dataset = ZINC(root='../data/ZINC', transform=None, pre_transform=None)

In [4]:
point = dataset[1]
point

Data(x=[18, 1], edge_index=[2, 36], edge_attr=[36], y=[1])

In [5]:
point.edge_attr

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1,
        2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2])

In [6]:
# absence of edges should be an edge type, attribute 0
# TODO fully connect graph with unexisting edges in the beginning

In [7]:
point.y

tensor([0.4907])

In [8]:
point.to_dict()

{'x': tensor([[0],
         [1],
         [0],
         [0],
         [4],
         [0],
         [0],
         [1],
         [2],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [1],
         [0],
         [0]]),
 'edge_index': tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  4,  5,  6,  6,  6,  7,  8,  8,  8,
           9, 10, 10, 11, 11, 11, 12, 12, 13, 13, 14, 14, 14, 15, 16, 16, 17, 17],
         [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  6,  4,  4,  7,  8,  6,  6,  9, 10,
           8,  8, 11, 10, 12, 17, 11, 13, 12, 14, 13, 15, 16, 14, 14, 17, 11, 16]]),
 'edge_attr': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1,
         2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2]),
 'y': tensor([0.4907])}

Node decay ordering, for forward absorbing pass

![Alt text](media/don.png)

![Alt text](media/don_2.png)

In [9]:
diff_ord_net = DiffusionOrderingNetwork(node_feature_dim=1,
                                        num_node_types=dataset.x.unique().shape[0] + 1,
                                        num_edge_types=3,
                                        num_layers=3,
                                        out_channels=1)
sigma_t_dist = diff_ord_net(point)
print(sigma_t_dist.flatten())
# sample from categorical distribution to get node to mask
# TODO only on the samples that are not masked
# sigma_t = F.softmax(sigma_t_dist.flatten(), dim=0)
torch.distributions.Categorical(probs=sigma_t_dist.flatten()).sample()

tensor([0.0546, 0.0551, 0.0559, 0.0574, 0.0578, 0.0589, 0.0567, 0.0560, 0.0557,
        0.0557, 0.0552, 0.0545, 0.0544, 0.0544, 0.0544, 0.0545, 0.0544, 0.0544],
       grad_fn=<ReshapeAliasBackward0>)


  return src.new_zeros(size).scatter_reduce_(


tensor(11)

In [13]:
masker = NodeMasking(dataset)

### Forward Diffusion Process

Node ordering $\sigma$ using diffusion ordering network. Exactly one node decays at a time.

At each step $t$, distribution of $t$-th node is conditioned on original graph $G$ and generated node ordering $\sigma$ up to $t-1$.

In [14]:
def node_decay_ordering(datapoint):
    p = datapoint.clone()
    node_order = []
    for i in range(p.x.shape[0]):
        # use diffusion ordering network to get probabilities
        sigma_t_dist = diff_ord_net(p, i)
        # sample from categorical distribution to get node to mask
        # TODO only on the samples that are not masked
        unmasked = ~masker.is_masked(p)
        sigma_t = torch.distributions.Categorical(probs=sigma_t_dist[unmasked].flatten()).sample()
        
        # get node index
        sigma_t = torch.where(unmasked.flatten())[0][sigma_t.long()]
        node_order.append(sigma_t)
        # mask node
        p = masker.mask_node(p, sigma_t)
    return node_order

In [15]:
node_decay_ordering(point)

[tensor(12),
 tensor(6),
 tensor(0),
 tensor(4),
 tensor(5),
 tensor(7),
 tensor(11),
 tensor(1),
 tensor(13),
 tensor(3),
 tensor(15),
 tensor(8),
 tensor(10),
 tensor(2),
 tensor(9),
 tensor(16),
 tensor(14),
 tensor(17)]

## Reverse Diffusion (Generative) Process

Denoising network $p_ \theta (G_t | G_{t+1})$ is a graph attention network (GAT). (Vanilla GAT)

For now, simple graph convolutional network (GCN) is used.


** Initially, the graph is fully connected and masked, so in the paper they only keep the masked node to be denoised during the generation step. 

For now, we will just use the whole graph as well.

In [16]:
'''
Message passing
TODO find out if custom message passing function is necessary
'''

'\nMessage passing\nTODO find out if custom message passing function is necessary\n'

Make sure attentive message passing is done correctly.


![Alt text](media/image.png)

In [17]:

from torch.nn import Linear, Parameter
class MessagePassingLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(MessagePassingLayer, self).__init__(aggr='mean', flow="target_to_source")  # 'add' aggregation for simplicity

        self.W = nn.Linear(in_channels, out_channels)
        # self.attention = nn.Linear(2 * in_channels, attention_dim)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.W.reset_parameters()
        self.bias.data.zero_()

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j # TODO change to adapt attention mechanism
    
    
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Attention mechanism
        # alpha = F.softmax(self.attention(torch.cat([x[row], x[col]], dim=-1)), dim=-1)
        # alpha = F.dropout(alpha, p=0.5)
        # messages = alpha.view(-1, 1) * edge_attr
        # messages = scatter_add(messages, col, dim=0, dim_size=x.size(0))

        # h = self.W(torch.cat([x, messages], dim=-1))

        out = self.propagate(edge_index, x=x, norm=norm)
        out += self.bias

        return out
    


![Alt text](media/architecture.png)


In [18]:
# num_features
denoising_net = DenoisingNetwork(
    node_feature_dim=dataset.num_features,
    num_node_types=dataset.x.unique().shape[0] + 1,
    num_edge_types=3,
    num_layers=7,
    out_channels=1
)

1 29 3 7 1


In [19]:
p = point.clone()
node_type_probs, edge_type_probs = denoising_net(p)
edge_type_probs # connections of new node to all previous nodes
node_type_probs[0].shape

torch.Size([29])

## Training

For now, use MSE / Reconstruction error instead of loss function from the paper.


Optimize by sampling multiple ($M$) diffusion trajectories, thereby enabling training both the diffusion ordering network $q_ϕ(σ|G_0)$ and the denoising network
$p_θ(Gt|G_t+1)$ using gradient descent.

Create M trajectories (sequence of graphs) for each training graph. Where node decay order is sampled from $q_ϕ(σ|G_0)$ (or random, initially).

Train denoising network to minimize the negative VLB using SGD.

Diffusion ordering network can be updated with common RL optimization methods, e.g., the REINFORCE algorithm. Creating M trajectories and computing the negative VLB to obtain rewards, and then updating the parameters of the diffusion ordering network using the REINFORCE algorithm.

In [20]:
# use wandb to log the model
import wandb

denoising_net = DenoisingNetwork(
    node_feature_dim=dataset.num_features,
    num_node_types=dataset.x.unique().shape[0] + 1,
    num_edge_types=3,
    num_layers=7,
    out_channels=1
)

wandb.init(
        project="ARGD",
        group=f"v0.5",
        name=f"denoising_and_ordering",
        # track hyperparameters and run metadata
        config={
            "policy": "train",
            "n_epochs": 10000,
            "batch_size": 1,
        }
    )


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


1 29 3 7 1


[34m[1mwandb[0m: Currently logged in as: [33mcaiofreitas[0m. Use [1m`wandb login --relogin`[0m to force relogin


![Alt text](media/optimizer.png)

In [21]:
# train denoising diffusion model
# use node decay ordering

EPSILON = 1e-8

optimizer = torch.optim.Adam(denoising_net.parameters(), lr=1e-4, betas=(0.9, 0.999))
ordering_optimizer = torch.optim.Adam(diff_ord_net.parameters(), lr=5e-4, betas=(0.9, 0.999))

M = 4 # number of diffusion trajectories to be created for each graph

for superbatch in range(10, 20):
    loss = 0
    with tqdm(range(4*superbatch, 2+4*superbatch)) as pbar:
        denoising_net.train()
        diff_ord_net.eval()
        for batch in pbar:
            graph = dataset[batch]
            # print(f"Generating trajectories for graph")
            # node decay ordering, accoding to node_decay_ordering
            original_data = graph.clone()
            diffusion_trajectories = []
            for m in range(M):
                node_order = node_decay_ordering(graph)
                
                # create diffusion trajectory
                diffusion_trajectory = [original_data]
                masked_data = graph.clone()
                for node in node_order:
                    masked_data = masked_data.clone()
                    
                    masked_data = masker.mask_node(masked_data, node)
                    diffusion_trajectory.append(masked_data)

                diffusion_trajectories.append(diffusion_trajectory)
            
            # predictions & loss
            for diffusion_trajectory in diffusion_trajectories:
                G_0 = diffusion_trajectory[0]
                # optimizer.zero_grad()
                node_order = node_decay_ordering(G_0)
                for t in range(1, len(diffusion_trajectory)-1):
                    node = node_order[len(diffusion_trajectory)-t-1]
                    G_t = diffusion_trajectory[t]
                    # transform to float
                    G_t.x = G_t.x.float()
                    G_t.edge_index = G_t.edge_index.float()

                    G_tplus1 = diffusion_trajectory[t+1].clone()
                    # transform to float
                    G_tplus1.x = G_tplus1.x.float()
                    G_tplus1.edge_index = G_tplus1.edge_index.float()

                    G_pred = G_tplus1.clone()
                    

                    # predict node type
                    node_type_probs, edge_type_probs = denoising_net(G_pred)
                    
                    #  Apply multinomial distribution
                    # node_dist = torch.distributions.Multinomial(probs=node_type_probs.squeeze(), total_count=1)
                    # node_type = node_dist.sample()
                    # edge_dist = torch.distributions.Multinomial(probs=edge_type_probs.squeeze(), total_count=G_tplus1.x.shape[0])
                    # node_connections = edge_dist.sample()

                    # add node type to node
                    # node_dist = torch.distributions.Categorical(probs=node_type_probs.squeeze())
                    # node_type = node_dist.sample()
                    # print(node_type)
                    # G_pred.x[node] = node_type
                    # sample edge type
                    # new_connections = torch.multinomial(node_connections_probs.squeeze(), num_samples=1, replacement=True)
                    '''
                    TODO make sure you 
                    "predict the connections of the new node to all previous nodes at once 
                                            using a mixture of multinomial distribution"
                    '''
                    # add new connections to edge_attr
                    # for i in  range(len(node_connections)):
                    #     new_connection = node_connections[i]
                    #     if new_connection != 0:
                    #         G_pred.edge_attr[G_pred.edge_index[0] == i] = new_connection
                    
                    # calculate loss
                    p_O_v =  node_type_probs[node].mean() + EPSILON # TODO add edges (joint probability)
                    w_k = diff_ord_net(G_tplus1)[node]
                    n_i = G_t.x.shape[0]
                    loss -= (n_i/(len(diffusion_trajectory)-1))*torch.log(p_O_v)*w_k
    
    loss /= M
    # backprop
    loss.backward()
    # update parameters
    optimizer.step()
    
    # log loss
    pbar.set_description(f"Epoch: {batch}, Loss: {loss.item()%10:.4f}")
    wandb.log({"loss": loss.item()})
    # validation batch (for diffusion ordering network)

    reward = 0
    with tqdm(range(2+4*superb
        self.diffusion_ordering_network.train()atch, 4+4*superbatch)) as pbar:
        for i in pbar:
            graph = dataset[i]
            n_i = graph.x.shape[0]
            original_data = graph.clone()
            diffusion_trajectories = []
            
            denoising_net.eval()
            diff_ord_net.train()

            # Generate M diffusion trajectories
            for m in range(M):
                node_order = node_decay_ordering(original_data)
                
                # create diffusion trajectory
                diffusion_trajectory = [original_data]
                masked_data = graph.clone()
                for node in node_order:
                    masked_data = masked_data.clone()
                    
                    masked_data = masker.mask_node(masked_data, node)
                    diffusion_trajectory.append(masked_data)

                diffusion_trajectories.append(diffusion_trajectory)

            for diffusion_trajectory in diffusion_trajectories:
                G_0 = diffusion_trajectory[0]
                node_order = node_decay_ordering(G_0)
                for t in range(len(diffusion_trajectory)-1):
                    node = node_order[G_0.x.shape[0] - t - 1]
                    G_tplus1 = diffusion_trajectory[t+1]
                    # predict node type
                    node_type_probs, edge_type_probs = denoising_net(G_tplus1)
                    # node_type_probs.register_hook(lambda grad: print(grad.mean()))

                    # node_type_probs = node_type_probs[node]
                    #  Apply multinomial distribution
                    # node_dist = torch.distributions.Multinomial(probs=node_type_probs.squeeze(), total_count=1)
                    # node_type = node_dist.sample()

                    # calculate reward (VLB)
                    
                    p_O_v =  node_type_probs[node].mean() + EPSILON # TODO add edges (joint probability)

                    # reward -= torch.log(O_v)*n_i/(len(diffusion_trajectory)-1)
                    r = (n_i/(len(diffusion_trajectory)-1))*torch.log(p_O_v)
                    w_k = diff_ord_net(G_tplus1)[node]

                    reward -= w_k*r
    reward /= M
    wandb.log({"reward": reward.item()})
    # update parameters (REINFORCE algorithm)
    ordering_optimizer.zero_grad()
    reward.backward()
    ordering_optimizer.step()
    pbar.set_description(f"Epoch: {batch}, Loss: {reward.item()%10:.4f}")
    # save model
    torch.save(denoising_net.state_dict(), "ardm_model_overfit.pt")
    torch.save(diff_ord_net.state_dict(), "ordering_model_overfit.pt")
                
    '''
    TODO: edge_attributes have to be used. The masked nodes can be identified through them. There'll never be error in edge_index
    '''

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:01<00:00,  1.20it/s]
100%|██████████| 2/2 [00:01<00:00,  1.96it/s]
100%|██████████| 2/2 [00:01<00:00,  1.56it/s]
100%|██████████| 2/2 [00:01<00:00,  1.68it/s]
100%|██████████| 2/2 [00:01<00:00,  1.38it/s]
100%|██████████| 2/2 [00:01<00:00,  1.40it/s]
100%|██████████| 2/2 [00:01<00:00,  1.64it/s]
100%|██████████| 2/2 [00:01<00:00,  1.82it/s]
100%|██████████| 2/2 [00:01<00:00,  1.46it/s]
100%|██████████| 2/2 [00:01<00:00,  1.51it/s]
100%|██████████| 2/2 [00:01<00:00,  1.49it/s]
100%|██████████| 2/2 [00:01<00:00,  1.68it/s]
100%|██████████| 2/2 [00:01<00:00,  1.61it/s]
100%|██████████| 2/2 [00:01<00:00,  1.94it/s]
100%|██████████| 2/2 [00:01<00:00,  1.55it/s]
100%|██████████| 2/2 [00:01<00:00,  1.30it/s]
100%|██████████| 2/2 [00:01<00:00,  1.31it/s]
100%|██████████| 2/2 [00:01<00:00,  1.57it/s]
100%|██████████| 2/2 [00:01<00:00,  1.49it/s]
100%|██████████| 2/2 [00:01<00:00,  1.59it/s]


In [127]:
G_pred = diffusion_trajectory[-1].clone()
G_pred.x

tensor([[28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28]])

In [128]:
# compute node_order
node_order = node_decay_ordering(G_0)

forward_pass = []

for i in node_order:
    forward_pass.append(G_pred.clone())
    # 1 diffusion step for each node
    with torch.no_grad():
        node_type_probs, edge_type_probs = denoising_net(G_pred)
        # sample node type
        node_dist = torch.distributions.Categorical(probs=node_type_probs.squeeze())
        node_type = node_dist.sample()[i]
        print(node_type)
        edge_dist = torch.distributions.Multinomial(probs=edge_type_probs.squeeze(), total_count=G_tplus1.x.shape[0])
        node_connections = edge_dist.sample()
        G_pred.x[i] = node_type
        # add new connections to edge_attr
        # for i in  range(len(node_connections)):
        #     new_connection = node_connections[i]
        #     if new_connection != 0:
        #         G_pred.edge_attr[G_pred.edge_index[0] == i] = new_connection

tensor(13)
tensor(27)
tensor(19)
tensor(26)
tensor(24)
tensor(25)
tensor(5)
tensor(11)
tensor(1)
tensor(26)
tensor(20)
tensor(4)
tensor(26)
tensor(6)
tensor(25)
tensor(19)
tensor(2)
tensor(24)
tensor(24)
tensor(2)
tensor(19)
tensor(19)
tensor(15)
tensor(28)
tensor(16)
tensor(11)
tensor(24)
tensor(13)
tensor(15)
tensor(24)
tensor(25)
tensor(22)
tensor(2)


In [130]:
forward_pass[0].x

tensor([[28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28],
        [28]])

In [None]:
for i in range(50):
    denoising_net(dataset[i])