<a href="https://colab.research.google.com/github/bachnguyenTE/temporal-mgn/blob/prototype-mgvae/mgvae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Inspect the attached NVIDIA GPU (model, driver/CUDA version, memory use) before training.
!nvidia-smi

Wed Apr  6 08:53:30 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Arguments (inactive)

# def _parse_args():
#     parser = argparse.ArgumentParser(description = 'Temporal graph learning')
#     parser.add_argument('--dir', '-dir', type = str, default = '.', help = 'Directory')
#     parser.add_argument('--learning_target', '-learning_target', type = str, default = 'U0', help = 'Learning target')
#     parser.add_argument('--name', '-name', type = str, default = 'NAME', help = 'Name')
#     parser.add_argument('--dataset', '-dataset', type = str, default = 'ZINC_12k', help = 'ZINC')
#     parser.add_argument('--num_epoch', '-num_epoch', type = int, default = 2048, help = 'Number of epochs')
#     parser.add_argument('--batch_size', '-batch_size', type = int, default = 20, help = 'Batch size')
#     parser.add_argument('--learning_rate', '-learning_rate', type = float, default = 0.001, help = 'Initial learning rate')
#     parser.add_argument('--seed', '-s', type = int, default = 123456789, help = 'Random seed')
#     parser.add_argument('--n_clusters', '-n_clusters', type = int, default = 2, help = 'Number of clusters')
#     parser.add_argument('--n_levels', '-n_levels', type = int, default = 3, help = 'Number of levels of resolution')
#     parser.add_argument('--n_layers', '-n_layers', type = int, default = 3, help = 'Number of layers of message passing')
#     parser.add_argument('--hidden_dim', '-hidden_dim', type = int, default = 32, help = 'Hidden dimension')
#     parser.add_argument('--z_dim', '-z_dim', type = int, default = 32, help = 'Latent dimension')
#     parser.add_argument('--device', '-device', type = str, default = 'cpu', help = 'cuda/cpu')
#     args = parser.parse_args()
#     return args

# args = _parse_args()
# log_name = args.dir + "/" + args.name + ".log"
# model_name = args.dir + "/" + args.name + ".model"
# LOG = open(log_name, "w")

In [3]:
# Add this in a Google Colab cell to install the correct version of Pytorch Geometric.
%%capture
import torch

def format_pytorch_version(version):
  """Return `version` with any local build tag (the '+...' suffix) removed."""
  release, _sep, _local_tag = version.partition('+')
  return release

# Installed torch version, and the same string with any '+...' build tag dropped
# (the latter is interpolated into the PyG wheel-index URL below).
TORCH_version = torch.__version__
TORCH = format_pytorch_version(version=TORCH_version)

def format_cuda_version(version):
  """Convert a CUDA version string into the PyG wheel tag, e.g. '11.1' -> 'cu111'.

  `torch.version.cuda` is None for CPU-only PyTorch builds. PyG publishes those
  wheels under the 'cpu' tag, so None maps to 'cpu' instead of raising
  AttributeError on `None.replace`.
  """
  if version is None:
    return 'cpu'
  return 'cu' + version.replace('.', '')

# CUDA version the installed torch was built against, turned into the wheel tag
# used in the URLs below. NOTE(review): torch.version.cuda is None for CPU-only
# PyTorch builds, so format_cuda_version must handle that case.
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# The compiled PyG extensions are installed from the wheel index matching the exact
# torch-{TORCH}+{CUDA} build; torch-geometric is installed from the default index.
# NOTE(review): none of these installs are version-pinned.
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric 

In [4]:
# Library import (legacy MGVAE)
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, Adagrad
import pickle
from torch import optim
from torch.utils.data import DataLoader
import numpy as np
import os
import time
import argparse

# Library import (pytorch-geometric)
import torch_geometric 
import torch_geometric.transforms as T
# NOTE: this DataLoader intentionally shadows torch.utils.data.DataLoader imported above.
from torch_geometric.loader import DataLoader, ClusterLoader, ClusterData
from torch_geometric.nn import MessagePassing, GCNConv, DenseGCNConv

###############################################################
# NOTE:
# We prefer to define our own clustering procedure rather
# than use the built-in ClusterLoader. ClusterLoader clusters
# the graph in a separate data-processing step, so there is a
# chance the entire net would no longer be end-to-end
# differentiable, and it may no longer be invariant to graph
# isomorphism (i.e. to permutations of the nodes).
###############################################################

In [5]:
# Fix all random seeds (Python, NumPy, PyTorch) so runs are reproducible.
SEED = 69420
torch_geometric.seed.seed_everything(SEED)

# Use the GPU when one is available; fall back to the CPU so the notebook still
# runs on a runtime without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Define glorot initialization
def glorot_init(input_dim, output_dim):
    """Return an (input_dim, output_dim) nn.Parameter with Glorot/Xavier-uniform entries.

    Each entry is drawn uniformly from [-r, r) with r = sqrt(6 / (input_dim + output_dim)).
    """
    bound = np.sqrt(6.0 / (input_dim + output_dim))
    unit_sample = torch.rand(input_dim, output_dim)  # uniform on [0, 1)
    weights = unit_sample * 2 * bound - bound
    return nn.Parameter(weights)

In [24]:
# Multiresolution Graph Network
class MGN(nn.Module):
    """Multiresolution graph network for graph-level prediction.

    A base GraphEncoder embeds the input graph. Then, for each entry of
    ``clusters`` (visited from the last index to the first):
      1. a GraphCluster hard-assigns the current nodes to ``clusters[l]`` clusters;
      2. the node latents and the adjacency are pooled with that assignment;
      3. a GraphEncoder re-encodes the coarsened graph.
    The node-summed latents of every resolution are concatenated and passed
    through a two-layer readout head.

    Inputs are treated as dense and batched; the ``transpose(1, 2)`` calls and
    matrix products in forward() assume a leading batch dimension.

    Args:
        clusters: sequence of cluster counts, one per coarsening level.
        num_layers: number of graph-convolution layers in every encoder / cluster learner.
        node_dim: input node-feature dimension.
        edge_dim: input edge-feature dimension, or None.
        hidden_dim: hidden width of the encoders and cluster learners.
        z_dim: latent dimension per node at every resolution.
        readout_dim: hidden width of the prediction head (default 512).
        output_dim: size of the prediction per graph (default 1, i.e. a scalar).
    """
    def __init__(self, clusters, num_layers, node_dim, edge_dim, hidden_dim, z_dim, readout_dim = 512, output_dim = 1):
        super(MGN, self).__init__()
        self.clusters = clusters
        self.num_layers = num_layers
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim

        self.base_encoder = GraphEncoder(self.num_layers, self.node_dim, self.edge_dim, self.hidden_dim, self.z_dim)

        self.cluster_learner = nn.ModuleList()
        self.global_encoder = nn.ModuleList()
        for num_clusters in self.clusters:
            self.cluster_learner.append(GraphCluster(self.num_layers, self.z_dim, self.hidden_dim, num_clusters))
            # Coarsened graphs carry no edge features, hence edge_dim = None.
            self.global_encoder.append(GraphEncoder(self.num_layers, self.z_dim, None, self.hidden_dim, self.z_dim))

        # Readout head used in forward(). Its input holds one z_dim vector per
        # resolution: the base level plus one per entry of `clusters`.
        self.fc1 = nn.Linear((len(self.clusters) + 1) * self.z_dim, readout_dim)
        self.fc2 = nn.Linear(readout_dim, output_dim)

    def forward(self, adj, node_feat, edge_feat = None):
        """Encode the graph at every resolution and predict a per-graph output.

        Args:
            adj: dense adjacency, assumed (batch, N, N).
            node_feat: node features, assumed (batch, N, node_dim).
            edge_feat: optional edge features, forwarded to the base encoder only.

        Returns:
            predict: readout of ``latent``, (batch, output_dim).
            latent: concatenation of every level's node-summed latents,
                (batch, (len(clusters) + 1) * z_dim).
            outputs: list of ``[latent, adj]`` pairs; the input resolution first,
                then each coarsened level in the order it was computed.
        """
        base_latent = self.base_encoder(adj, node_feat, edge_feat)
        outputs = [[base_latent, adj]]

        for l in reversed(range(len(self.clusters))):
            # Each level coarsens the graph produced by the previous level
            # (the input graph for the first iteration).
            prev_latent, prev_adj = outputs[-1]

            # Assignment scores, one column per cluster of this level
            assign_score = self.cluster_learner[l](prev_adj, prev_latent)

            # Hard assignment via Gumbel-softmax with hard=True: one-hot in the
            # forward pass, soft-sample gradients in the backward pass.
            # Soft alternative: F.softmax(assign_score, dim = 2)
            assign_matrix = F.gumbel_softmax(assign_score, tau = 1, hard = True, dim = 2)
            assign_t = assign_matrix.transpose(1, 2)

            # Pool node latents into cluster latents (sum over members), then normalize.
            # NOTE(review): dim = 1 normalizes across clusters, not across the
            # z_dim features of each cluster; confirm this is intended.
            shrinked_latent = F.normalize(torch.matmul(assign_t, prev_latent), dim = 1)

            # Coarsened adjacency S^T A S, normalized separately for each graph:
            # summing over the whole tensor would couple graphs in the same batch.
            # The clamp avoids 0/0 = NaN when a coarsened graph has no edges
            # (assumes non-negative edge weights).
            shrinked_adj = torch.matmul(torch.matmul(assign_t, prev_adj), assign_matrix)
            adj_total = torch.sum(shrinked_adj, dim = (-2, -1), keepdim = True)
            shrinked_adj = shrinked_adj / adj_total.clamp(min = 1e-12)

            # Re-encode the coarsened graph
            next_latent = self.global_encoder[l](shrinked_adj, shrinked_latent)
            outputs.append([next_latent, shrinked_adj])

        # Graph-level readout: sum each level's latents over nodes, concatenate, 2-layer MLP.
        latent = torch.cat([torch.sum(level_latent, dim = 1) for level_latent, _ in outputs], dim = 1)
        hidden = torch.tanh(self.fc1(latent))
        predict = self.fc2(hidden)

        return predict, latent, outputs

In [21]:
# Graph encoder block
class GraphEncoder(nn.Module):
    """Dense, batched graph encoder producing one latent vector per node.

    Processing steps:
      1. Node features (and optionally edge features) are embedded with 2-layer tanh MLPs.
      2. The node embeddings are propagated through ``num_layers`` DenseGCNConv layers.
      3. The result is mapped to ``z_dim`` by a 2-layer tanh readout.

    Args:
        num_layers: number of graph-convolution layers.
        node_dim: input node-feature dimension.
        edge_dim: input edge-feature dimension, or None to build no edge branch.
        hidden_dim: width of the embeddings and convolution layers.
        z_dim: output latent dimension per node.
        use_concat_layer: if True, the readout sees the concatenation of the node
            embedding and every layer's output; otherwise only the last hidden state.
    """
    def __init__(self, num_layers, node_dim, edge_dim, hidden_dim, z_dim, use_concat_layer=True, **kwargs):
        super(GraphEncoder, self).__init__(**kwargs)
        self.num_layers = num_layers
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim
        self.use_concat_layer = use_concat_layer

        self.node_fc1 = nn.Linear(self.node_dim, 128)
        self.node_fc2 = nn.Linear(128, self.hidden_dim)

        if self.edge_dim is not None:
            self.edge_fc1 = nn.Linear(self.edge_dim, 128)
            self.edge_fc2 = nn.Linear(128, self.hidden_dim)

        self.base_net = nn.ModuleList()
        self.combine_net = nn.ModuleList()
        for _ in range(self.num_layers):
            # DenseGCNConv takes node features plus a dense (batch, N, N) adjacency,
            # which is the representation MGN builds for the coarsened graphs.
            self.base_net.append(DenseGCNConv(self.hidden_dim, self.hidden_dim))
            if self.edge_dim is not None:
                self.combine_net.append(nn.Linear(2 * self.hidden_dim, self.hidden_dim))

        if self.use_concat_layer:
            readout_in = (self.num_layers + 1) * self.hidden_dim
        else:
            readout_in = self.hidden_dim
        self.latent_fc1 = nn.Linear(readout_in, 256)
        self.latent_fc2 = nn.Linear(256, self.z_dim)

    def forward(self, adj, node_feat, edge_feat=None):
        """Return per-node latents, (batch, N, z_dim).

        Args:
            adj: dense adjacency, assumed (batch, N, N).
            node_feat: node features, assumed (batch, N, node_dim).
            edge_feat: optional dense edge features, (batch, N, N, edge_dim); only
                used when the encoder was built with an edge_dim.
        """
        node_hidden = torch.tanh(self.node_fc1(node_feat))
        node_hidden = torch.tanh(self.node_fc2(node_hidden))

        use_edges = edge_feat is not None and self.edge_dim is not None
        if use_edges:
            edge_hidden = torch.tanh(self.edge_fc1(edge_feat))
            edge_hidden = torch.tanh(self.edge_fc2(edge_hidden))

        # Start from the node embedding so `hidden` is defined even with num_layers == 0.
        hidden = node_hidden
        all_hidden = [node_hidden]
        for layer, conv in enumerate(self.base_net):
            hidden = conv(hidden, adj)

            if use_edges:
                # Neighbour states weighted by the channel-summed edge embedding
                # (the einsum contracts over both the neighbour j and channel c axes).
                edge_message = torch.tanh(torch.einsum('bijc,bjk->bik', edge_hidden, hidden))
                hidden = torch.tanh(self.combine_net[layer](torch.cat((hidden, edge_message), dim = 2)))

            all_hidden.append(hidden)

        if self.use_concat_layer:
            hidden = torch.cat(all_hidden, dim = 2)

        latent = torch.tanh(self.latent_fc1(hidden))
        latent = torch.tanh(self.latent_fc2(latent))
        return latent

In [23]:
# Graph clustering block
class GraphCluster(nn.Module):
    """Predict per-node cluster-assignment scores with a dense GCN.

    Args:
        num_layers: number of hidden DenseGCNConv layers before the assignment layer.
        input_dim: node-feature dimension of ``X``.
        hidden_dim: hidden width.
        z_dim: number of score columns; MGN passes the cluster count of a level here.
        learnable_base: selects between the two clustering options:
            - False (default): "fixed clustering". The hidden conv layers keep their
              random initialisation and are frozen.
            - True: the hidden conv layers are trained.
            The MLP and the assignment layer are trainable in both modes.
    """
    def __init__(self, num_layers, input_dim, hidden_dim, z_dim, learnable_base=False, **kwargs):
        super(GraphCluster, self).__init__(**kwargs)
        self.num_layers = num_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim
        self.learnable_base = learnable_base

        self.fc1 = nn.Linear(self.input_dim, 128)
        self.fc2 = nn.Linear(128, self.hidden_dim)

        # Always register the hidden layers in a ModuleList so .to(device),
        # state_dict() and train()/eval() reach them. A plain Python list hides
        # them from the module and leaves them on the CPU. The "fixed" option is
        # achieved by freezing their parameters instead.
        self.base_net = nn.ModuleList(
            DenseGCNConv(self.hidden_dim, self.hidden_dim) for _ in range(self.num_layers)
        )
        if not self.learnable_base:
            self.base_net.requires_grad_(False)

        self.assign_net = DenseGCNConv(self.hidden_dim, self.z_dim)

    def forward(self, adj, X):
        """Return raw (unnormalized) assignment scores, (batch, N, z_dim).

        Args:
            adj: dense adjacency, assumed (batch, N, N).
            X: node features, assumed (batch, N, input_dim).
        """
        hidden = torch.sigmoid(self.fc1(X))
        hidden = torch.sigmoid(self.fc2(hidden))
        for conv in self.base_net:
            hidden = conv(hidden, adj)
        return self.assign_net(hidden, adj)