<a href="https://colab.research.google.com/github/bachnguyenTE/temporal-mgn/blob/prototype-mgvae/mgvae_qm9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!nvidia-smi

Tue Apr 19 13:57:56 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Arguments (inactive)

# def _parse_args():
#     parser = argparse.ArgumentParser(description = 'Temporal graph learning')
#     parser.add_argument('--dir', '-dir', type = str, default = '.', help = 'Directory')
#     parser.add_argument('--learning_target', '-learning_target', type = str, default = 'U0', help = 'Learning target')
#     parser.add_argument('--name', '-name', type = str, default = 'NAME', help = 'Name')
#     parser.add_argument('--dataset', '-dataset', type = str, default = 'ZINC_12k', help = 'ZINC')
#     parser.add_argument('--num_epoch', '-num_epoch', type = int, default = 2048, help = 'Number of epochs')
#     parser.add_argument('--batch_size', '-batch_size', type = int, default = 20, help = 'Batch size')
#     parser.add_argument('--learning_rate', '-learning_rate', type = float, default = 0.001, help = 'Initial learning rate')
#     parser.add_argument('--seed', '-s', type = int, default = 123456789, help = 'Random seed')
#     parser.add_argument('--n_clusters', '-n_clusters', type = int, default = 2, help = 'Number of clusters')
#     parser.add_argument('--n_levels', '-n_levels', type = int, default = 3, help = 'Number of levels of resolution')
#     parser.add_argument('--n_layers', '-n_layers', type = int, default = 3, help = 'Number of layers of message passing')
#     parser.add_argument('--hidden_dim', '-hidden_dim', type = int, default = 32, help = 'Hidden dimension')
#     parser.add_argument('--z_dim', '-z_dim', type = int, default = 32, help = 'Latent dimension')
#     parser.add_argument('--device', '-device', type = str, default = 'cpu', help = 'cuda/cpu')
#     args = parser.parse_args()
#     return args

# args = _parse_args()
# log_name = args.dir + "/" + args.name + ".log"
# model_name = args.dir + "/" + args.name + ".model"
# LOG = open(log_name, "w")

In [3]:
# Add this in a Google Colab cell to install the correct version of Pytorch Geometric.
%%capture
import torch

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric 

!pip install einops
!wget -c https://gist.githubusercontent.com/Luvata/55f7b3e9ae451122b9e3faf0a7387b4f/raw/440fac5c6e7153fd39e4eb9ebec6e51c9520ef1f/visualize.py
!pip install --upgrade graphviz

In [4]:
# Library import (legacy MGVAE)
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, Adagrad
import pickle
from torch import optim
from torch.utils.data import DataLoader
import numpy as np
import os
import time
import argparse

# Library import (pytorch-geometric)
import torch_geometric 
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader, ClusterLoader, ClusterData
from torch_geometric.nn import MessagePassing, GCNConv
from torch_geometric.utils import to_dense_adj, dense_to_sparse

###############################################################
# NOTE: 
# We preferably define our own clustering 
# procedure, rather than using the built-in ClusterLoader
# since there is a chance using ClusterLoader will not
# make the entire net differentiable (separate data process),
# and the net may no longer be isomorphic invariant.
###############################################################

from visualize import display_module

In [5]:
# Fix all random seed
torch_geometric.seed.seed_everything(69420)

# Set device to gpu
device = torch.device('cuda')

In [6]:
# Define glorot initialization
def glorot_init(input_dim, output_dim):
    init_range = np.sqrt(6.0 / (input_dim + output_dim))
    initial = torch.rand(input_dim, output_dim) * 2 * init_range - init_range
    return nn.Parameter(initial)

In [7]:
# Multiresolution Graph Network
class MGN(nn.Module):
    def __init__(self, clusters, num_layers, node_dim, edge_dim, hidden_dim, z_dim, num_classes):
        super(MGN, self).__init__()
        self.clusters = clusters
        self.num_layers = num_layers
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim
        self.num_classes = num_classes

        self.base_encoder = GraphEncoder(self.num_layers, self.node_dim, self.edge_dim, self.hidden_dim, self.z_dim)

        self.cluster_learner = nn.ModuleList()
        self.global_encoder = nn.ModuleList()
        for i in range(len(self.clusters)):
            N = self.clusters[i]
            self.cluster_learner.append(GraphCluster(self.num_layers, self.z_dim, self.hidden_dim, N))
            self.global_encoder.append(GraphEncoder(self.num_layers, self.z_dim, None, self.hidden_dim, self.z_dim))

        D = self.z_dim * (len(self.clusters) + 1)
        self.fc1 = nn.Linear(D, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, adj, node_feat, edge_feat = None):
        outputs = []

        # Base encoder
        base_latent = self.base_encoder(adj, node_feat, edge_feat)

        outputs.append([base_latent, adj])

        l = 0
        while l < len(self.clusters):
            if l == 0:
                prev_adj = adj
                prev_latent = base_latent
            else:
                prev_adj = outputs[len(outputs) - 1][1]
                prev_latent = outputs[len(outputs) - 1][0]

            # Assignment score
            assign_score = self.cluster_learner[l](prev_adj, prev_latent)

            # Softmax (soft assignment)
            # assign_matrix = F.softmax(assign_score, dim = 2)

            # Gumbel softmax (hard assignment)
            assign_matrix = F.gumbel_softmax(assign_score, tau = 1, hard = True, dim = 1)

            # Print out the cluster assignment matrix
            # print(torch.sum(assign_matrix, dim = 0))

            # Shrinked latent
            shrinked_latent = torch.matmul(assign_matrix.transpose(0, 1), prev_latent)

            # Latent normalization
            shrinked_latent = F.normalize(shrinked_latent, dim = 0)

            # Shrinked adjacency
            # print(f'Iteration: {l}')
            # print(to_dense_adj(prev_adj))
            # print(f'to_dense_adj size: {to_dense_adj(prev_adj, max_num_nodes=self.clusters[l]).size()}')
            # print(f'assign_matrix size: {assign_matrix.size()}')
            # print(f'node_feat size: {node_feat.size()}')
            # print(f'adj size: {prev_adj.size()}')
            # print(f'prev_adj size: {prev_adj.size()}')
            if l == 0:
                shrinked_adj = torch.matmul(torch.matmul(assign_matrix.transpose(0, 1), to_dense_adj(prev_adj, max_num_nodes=node_feat.size()[0])[0]), assign_matrix)
            else:
                shrinked_adj = torch.matmul(torch.matmul(assign_matrix.transpose(0, 1), to_dense_adj(prev_adj, max_num_nodes=self.clusters[l - 1])[0]), assign_matrix)

            # Adjacency normalization
            shrinked_adj = shrinked_adj / torch.sum(shrinked_adj)

            # Reformatting adjacency matrix as edge index
            shrinked_adj = shrinked_adj[None, :]
            # print(f'shrinked_adj matrix: {shrinked_adj}')
            shrinked_adj, _ = dense_to_sparse(shrinked_adj)

            # Global encoder
            next_latent = self.global_encoder[l](shrinked_adj, shrinked_latent)

            outputs.append([next_latent, shrinked_adj])
            l += 1

        # Scalar prediction
        # print(f'size of output elem: {outputs[0][0].size()}')
        latent = torch.cat([torch.sum(output[0], dim = 0) for output in outputs], dim = 0)
        # print(f'final latent size: {latent.size()}')
        hidden = torch.tanh(self.fc1(latent))
        predict = self.fc2(hidden)

        return predict, latent, outputs

In [8]:
# Graph encoder block
class GraphEncoder(nn.Module):
    def __init__(self, num_layers, node_dim, edge_dim, hidden_dim, z_dim, use_concat_layer=True, **kwargs):
        super(GraphEncoder, self).__init__(**kwargs)
        self.num_layers = num_layers
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim
        self.use_concat_layer = use_concat_layer

        self.node_fc1 = nn.Linear(self.node_dim, 128)
        self.node_fc2 = nn.Linear(128, self.hidden_dim)

        if self.edge_dim is not None:
            self.edge_fc1 = nn.Linear(self.edge_dim, 128)
            self.edge_fc2 = nn.Linear(128, self.hidden_dim)

        self.base_net = nn.ModuleList()
        self.combine_net = nn.ModuleList()
        for layer in range(self.num_layers):
            self.base_net.append(GCNConv(self.hidden_dim, self.hidden_dim))
            if self.edge_dim is not None:
                self.combine_net.append(nn.Linear(2 * self.hidden_dim, self.hidden_dim))

        if self.use_concat_layer == True:
            self.latent_fc1 = nn.Linear((self.num_layers + 1) * self.hidden_dim, 256)
            self.latent_fc2 = nn.Linear(256, self.z_dim)
        else:
            self.latent_fc1 = nn.Linear(self.hidden_dim, 256)
            self.latent_fc2 = nn.Linear(256, self.z_dim)

    def forward(self, adj, node_feat, edge_feat=None):
        node_hidden = torch.tanh(self.node_fc1(node_feat))
        node_hidden = torch.tanh(self.node_fc2(node_hidden))

        if edge_feat is not None and self.edge_dim is not None:
            edge_hidden = torch.tanh(self.edge_fc1(edge_feat))
            edge_hidden = torch.tanh(self.edge_fc2(edge_hidden))

        all_hidden = [node_hidden]
        for layer in range(len(self.base_net)):
            if layer == 0:
                hidden = self.base_net[layer](node_hidden, adj)
            else:
                hidden = self.base_net[layer](node_hidden, adj)
            
            if edge_feat is not None and self.edge_dim is not None:
                hidden = torch.cat((hidden, torch.tanh(torch.einsum('bijc,bjk->bik', edge_hidden, hidden))), dim = 2)
                hidden = torch.tanh(self.combine_net[layer](hidden))
        
            all_hidden.append(hidden)

        if self.use_concat_layer == True:
            hidden = torch.cat(all_hidden, dim=1)

        latent = torch.tanh(self.latent_fc1(hidden))
        latent = torch.tanh(self.latent_fc2(latent))
        return latent

In [9]:
# Graph clustering block
class GraphCluster(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, z_dim, **kwargs):
        super(GraphCluster, self).__init__(**kwargs)
        self.num_layers = num_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim

        self.fc1 = nn.Linear(self.input_dim, 128)
        self.fc2 = nn.Linear(128, self.hidden_dim)

        # Option 1: Learnable clustering
        self.base_net = nn.ModuleList()
        
        # Option 2: Fixed clustering
        # self.base_net = []

        for layer in range(self.num_layers):
            self.base_net.append(GCNConv(self.hidden_dim, self.hidden_dim))

        self.assign_net = GCNConv(self.hidden_dim, self.z_dim)

    def forward(self, adj, X):
        hidden = torch.sigmoid(self.fc1(X))
        hidden = torch.sigmoid(self.fc2(hidden))
        for net in self.base_net:
            hidden = net(hidden, adj)
        assign = self.assign_net(hidden, adj)
        return assign

In [55]:
# Dataset testing
from torch_geometric.datasets import ZINC, QM9, TUDataset, KarateClub, GNNBenchmarkDataset

# dataset = ZINC('data/ZINC')
dataset = QM9(root='data/QM7b')
# dataset = TUDataset(root='data/TUDataset', name='PROTEINS')
# dataset = KarateClub()
# dataset = GNNBenchmarkDataset(root='data/GNNBenchmarkDataset', name='MNIST')

print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0] # Get the first graph object

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

Dataset: QM9(130831):
Number of graphs: 130831
Number of features: 11
Number of classes: 19

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], idx=[1], name='gdb_1', z=[5])
Number of nodes: 5
Number of edges: 8
Average node degree: 1.60
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [56]:
dataset = dataset.shuffle()

train_dataset = dataset[:1000]
test_dataset = dataset[1000:1100]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 1000
Number of test graphs: 100


In [57]:
# Minibatching the dataset 
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, shuffle=True)
test_loader = DataLoader(test_dataset, shuffle=False)

In [58]:
# adj is the adjacency matrix
# PyG: need to convert graph data format to the adjacency matrix format, or rewrite code

In [65]:
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

model = MGN(
    clusters=[4, 2],
    num_layers=4,
    node_dim=dataset.num_features,
    edge_dim=None,
    hidden_dim=64,
    z_dim=16,
    num_classes=dataset.num_classes
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        data.to(device)
        predict, latent, outputs = model(data.edge_index, data.x)  # Perform a single forward pass.
        loss = criterion(predict[None, :], data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        data.to(device)
        predict, latent, outputs = model(data.edge_index, data.x)  
        modelLoss = criterion(predict[None, :], data.y)
    return modelLoss  # Derive ratio of correct predictions.


for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_acc:.4f}, Test Loss: {test_acc:.4f}')

<IPython.core.display.Javascript object>

Epoch: 001, Train Loss: 10523363.0000, Test Loss: 31387712.0000
Epoch: 002, Train Loss: 7632530.5000, Test Loss: 21199454.0000
Epoch: 003, Train Loss: 5837754.0000, Test Loss: 13780696.0000
Epoch: 004, Train Loss: 902384.4375, Test Loss: 8641111.0000
Epoch: 005, Train Loss: 774752.1875, Test Loss: 5361662.5000
Epoch: 006, Train Loss: 714150.8750, Test Loss: 3533513.2500
Epoch: 007, Train Loss: 6508.6162, Test Loss: 2755543.0000
Epoch: 008, Train Loss: 371119.2188, Test Loss: 2518667.0000
Epoch: 009, Train Loss: 93806.3203, Test Loss: 2512439.5000
Epoch: 010, Train Loss: 11583.3301, Test Loss: 2527837.7500
Epoch: 011, Train Loss: 342976.7812, Test Loss: 2462667.2500
Epoch: 012, Train Loss: 232957.3750, Test Loss: 2499491.5000
Epoch: 013, Train Loss: 344506.0938, Test Loss: 2515739.5000
Epoch: 014, Train Loss: 92312.4453, Test Loss: 2512872.0000
Epoch: 015, Train Loss: 34557.6289, Test Loss: 2485619.0000
Epoch: 016, Train Loss: 109966.9766, Test Loss: 2464449.7500
Epoch: 017, Train Loss:

In [66]:
dataset[0]

Data(x=[19, 11], edge_index=[2, 44], edge_attr=[44, 4], y=[1, 19], pos=[19, 3], idx=[1], name='gdb_78976', z=[19])

In [71]:
data = dataset[0]
data.to(device)
predict, latent, outputs = model(data.edge_index, data.x)

In [72]:
predict

tensor([ 3.7165e+00,  7.9656e+01, -6.9316e+00,  3.3466e-01,  6.7760e+00,
         1.1531e+03,  4.0558e+00, -1.1217e+04, -1.1217e+04, -1.1217e+04,
        -1.1218e+04,  3.3927e+01, -8.2029e+01, -8.2510e+01, -8.2974e+01,
        -7.6420e+01,  1.5397e+00,  1.3349e+00,  1.5131e+00], device='cuda:0',
       grad_fn=<AddBackward0>)

In [73]:
data.y

tensor([[ 1.7379e+00,  7.5730e+01, -6.5770e+00,  2.2749e+00,  8.8519e+00,
          8.6733e+02,  4.4423e+00, -1.0500e+04, -1.0500e+04, -1.0500e+04,
         -1.0501e+04,  2.7433e+01, -8.2382e+01, -8.2946e+01, -8.3409e+01,
         -7.6727e+01,  2.9823e+00,  1.9835e+00,  1.7279e+00]], device='cuda:0')