# Requirements
* pytorch-geometric
* gpytorch



In [None]:
# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
# !pip install torch-geometric
# !pip install gpytorch

# General Imports

In [None]:
import os
import json
import logging
import warnings
import torch
import gpytorch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Gaussian Process Model
To build the main model, the following custom GPyTorch modules are needed:
* Distribution
* Variational strategy
* Variational ELBO

### Input Dependent Variational Distribution
The distribution must be implemented so that its parameters are set from the estimates produced by an external non-linear function (i.e., a Graph Neural Network). Thus, the new class **InputDependentVariationalDistribution** subclasses the GPyTorch class **_VariationalDistribution** and adds a new method ***update_params($q_{\mu}\in\mathbb{R}^{N\times C\times M}$, $q_{L}\in\mathbb{R}^{N\times C\times M(M+1)/2}$)***, where $N$ is the number of samples, $C$ is the number of classes, and $M$ is the number of inducing points. $q_{L}$ holds the vectorized lower triangle of each Cholesky factor of the variational covariance; its diagonal is passed through a softplus to keep it positive.

In [None]:
from gpytorch.variational import _VariationalDistribution
from gpytorch.lazy import TriangularLazyTensor, CholLazyTensor
from gpytorch.distributions import MultivariateNormal

class InputDependentVariationalDistribution(_VariationalDistribution):
    """Variational distribution q(u) whose parameters are set externally per input.

    The mean and the Cholesky factor of the covariance are plain tensor
    attributes, not learnable parameters. They are overwritten through
    :meth:`update_params` with the estimates of an amortization network.

    Args:
        num_inducing_points: number of inducing points M.
        batch_shape: batch shape of q(u) (the notebook passes one entry per class).
        mean_init_std: forwarded to ``_VariationalDistribution``; it has no effect
            here because :meth:`initialize_variational_distribution` is a no-op.
    """

    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3):
        super(InputDependentVariationalDistribution, self).__init__(
            num_inducing_points=num_inducing_points,
            batch_shape=batch_shape,
            mean_init_std=mean_init_std
        )

        # Placeholders (zero mean, identity factor) with a leading singleton input
        # dimension; update_params replaces them before the model is evaluated.
        self.variational_mean = torch.zeros(num_inducing_points).repeat(*batch_shape, 1).unsqueeze(0)
        self.chol_variational_covar = torch.eye(num_inducing_points).repeat(*batch_shape, 1, 1).unsqueeze(0)

    def forward(self):
        """Return q(u) = N(variational_mean, L L^T) with L = chol_variational_covar."""
        chol_variational_covar = TriangularLazyTensor(self.chol_variational_covar)
        variational_covar = CholLazyTensor(chol_variational_covar)
        return MultivariateNormal(self.variational_mean, variational_covar)

    def initialize_variational_distribution(self, prior_dist):
        # Nothing to initialize: the parameters always come from update_params.
        pass

    def shape(self):
        return torch.Size(self.variational_mean.shape)

    def update_params(self, variational_mean, chol_variational_covar_vec):
        """Set q(u) from an externally estimated mean and vectorized Cholesky factor.

        Args:
            variational_mean: new mean tensor, stored as-is (presumably (N, C, M),
                as produced by the Variational layer).
            chol_variational_covar_vec: (N, C, M(M+1)/2) lower-triangular entries of
                the Cholesky factors; the diagonal entries are pre-softplus.
        """
        self._update_variational_mean(variational_mean)
        self._update_chol_variational_covar(chol_variational_covar_vec)

    def _update_variational_mean(self, variational_mean):
        self.variational_mean = variational_mean

    def _update_chol_variational_covar(self, chol_variational_covar_vec):
        """Unpack the vectorized lower triangle into (N, C, M, M) Cholesky factors."""
        # Local import: this class is defined before `F` is imported elsewhere in
        # the file, so do not rely on that module-level name.
        import torch.nn.functional as F

        m = self.num_inducing_points
        device = chol_variational_covar_vec.device

        # Indices of the lower triangle (diagonal included), in the order used to
        # scatter the vectorized entries; created on the input's device.
        tril_i = torch.tril_indices(m, m, device=device)
        diag_i = torch.arange(m, device=device)

        num_inputs, num_classes, _ = chol_variational_covar_vec.shape
        # new_zeros keeps both the device and the dtype of the input; a fixed
        # float32 buffer would reject float64 values in index_put_.
        chol = chol_variational_covar_vec.new_zeros((num_inputs, num_classes, m, m))
        chol[..., tril_i[0], tril_i[1]] = chol_variational_covar_vec

        # Softplus keeps the diagonal of the Cholesky factor positive.
        chol[..., diag_i, diag_i] = F.softplus(chol[..., diag_i, diag_i])
        self.chol_variational_covar = chol

### Input Dependent Variational Strategy
The variational strategy must be rewritten to support batched (3D) tensor operations. Thus, a new class **InputDependentVariationalStrategy** is created. It subclasses **VariationalStrategy** and overrides ***prior_distribution*** and ***forward***. It also overrides ***__call__***, which clears the cached quantities in evaluation mode as well, because the variational parameters change on every call.

In [None]:
from gpytorch.variational import VariationalStrategy
from gpytorch.lazy import DiagLazyTensor, SumLazyTensor, MatmulLazyTensor
from gpytorch.utils import cached
from gpytorch.settings import trace_mode, _linalg_dtype_cholesky

class InputDependentVariationalStrategy(VariationalStrategy):
    """Whitened variational strategy whose inducing points and q(u) are given per input.

    As used in this notebook:
    - ``x`` is (N, 1, D): one single-point batch per node.
    - The inducing points are (N, M, D): one set per node.
    - q(u) has batch shape (N, C).

    The kernel-derived terms are therefore unsqueezed at dim -3, so they broadcast
    over the C outputs.
    """

    def __init__(self, model, inducing_points, variational_distribution):
        super(InputDependentVariationalStrategy, self).__init__(
            model, inducing_points, variational_distribution, learn_inducing_locations=False
        )

        # The parent stores `inducing_points` as a registered buffer. Replace it with a
        # plain attribute so it stays out of state_dict()/.to(); it is overwritten
        # externally before every call.
        delattr(self, "inducing_points")
        self.inducing_points = inducing_points

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self):
        """Whitened prior p(u) = N(0, I) with the same shape as q(u)."""
        zeros = torch.zeros(self.variational_distribution.mean.shape, device=self.inducing_points.device)
        ones = torch.ones_like(zeros, device=zeros.device)
        return MultivariateNormal(zeros, DiagLazyTensor(ones))

    def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None, **kwargs):
        """Return q(f) at ``x`` given the per-input inducing points and q(u)."""
        # compute p(f), p(u), and Kuf; `covar_part` is forwarded to the kernel
        dist_data = self.model.forward(x, covar_part="Kff", **kwargs)
        dist_induc = self.model.forward(inducing_points, covar_part="Kuu", **kwargs)
        induc_data_covar = self.model.covar_module(inducing_points, x, covar_part="Kuf", **kwargs)

        test_mean = dist_data.mean
        induc_induc_covar = dist_induc.lazy_covariance_matrix.add_jitter(1e-4)
        data_data_covar = dist_data.lazy_covariance_matrix
        induc_data_covar = induc_data_covar.evaluate()

        # compute interpolation terms
        # Kuu^{-1/2} Kuf
        L = self._cholesky_factor(induc_induc_covar)
        interp_term = L.inv_matmul(induc_data_covar.type(_linalg_dtype_cholesky.value())).to(inducing_points.dtype)
        # Insert the output (class) dimension so the kernel terms broadcast against
        # the (N, C, ...) variational parameters.
        interp_term_b = interp_term.unsqueeze(-3)

        # compute the mean of q(f)
        # Kfu Kuu^{-1/2} m + f_mean  (whitened prior mean is zero)
        predictive_mean = (interp_term_b.transpose(-1, -2) @ inducing_values.unsqueeze(-1)).squeeze(-1) + test_mean.unsqueeze(-1)

        # compute the covariance of q(f)
        # Kff + Kfu Kuu^{-1/2} (S - I) Kuu^{-1/2} Kuf
        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
        if variational_inducing_covar is not None:
            middle_term = SumLazyTensor(variational_inducing_covar, middle_term)

        if trace_mode.on():
            # Dense path: it must broadcast exactly like the lazy path below. Without
            # the class dimension, the node axis N would be matched against C.
            predictive_covar = (
                data_data_covar.add_jitter(1e-4).evaluate().unsqueeze(-3)
                + interp_term_b.transpose(-1, -2) @ middle_term.evaluate() @ interp_term_b
            )
        else:
            predictive_covar = SumLazyTensor(
                data_data_covar.add_jitter(1e-4).unsqueeze(-3),
                MatmulLazyTensor(interp_term_b.transpose(-1, -2), middle_term @ interp_term_b)
            )

        # return the distribution
        return MultivariateNormal(predictive_mean, predictive_covar)

    def __call__(self, x, prior=False, **kwargs):
        # The inducing points and q(u) change on every call, so cached quantities
        # (e.g. the memoized prior) are cleared in eval mode too. Presumably the
        # parent only clears them while training.
        if not self.training:
            self._clear_cache()
        return super().__call__(x, prior=prior, **kwargs)

### Input Dependent Variational ELBO
The variational ELBO must be adapted to 3D tensor operations. Accordingly, a new class **InputDependentVariationalELBO** is created; it subclasses **VariationalELBO** and overrides the method ***_log_likelihood_term***.

In [None]:
from gpytorch.mlls import VariationalELBO

class InputDependentVariationalELBO(VariationalELBO):
    """VariationalELBO whose likelihood term keeps one value per input node."""

    def _log_likelihood_term(self, variational_dist_f, target, **kwargs):
        # NOTE(review): expected_log_prob is presumably a square matrix here, because
        # the N per-node distributions broadcast against all N targets. The diagonal
        # pairs node i with its own label. TODO confirm the shape.
        expected = self.likelihood.expected_log_prob(target, variational_dist_f, **kwargs)
        return expected.diag()

### Input Dependent Stochastic Variational Gaussian Process
The GP model is built from the **InputDependentVariationalDistribution** and **InputDependentVariationalStrategy** classes. To this end, the class **InputDependentSVGP** is created; it subclasses **ApproximateGP** from *gpytorch.models*. **InputDependentSVGP** also implements the method ***update_variational_params***, which updates the variational parameters by calling ***update_params*** of the **InputDependentVariationalDistribution** class.

In [None]:
from gpytorch.models import ApproximateGP

class InputDependentSVGP(ApproximateGP):
    """Approximate GP whose inducing points and q(u) are supplied on every call.

    Args:
        mean: GPyTorch mean module.
        kernel: GPyTorch kernel; it receives the extra keyword arguments (such as
            ``covar_part``) that the variational strategy forwards.
        inducing_points: initial inducing points; ``shape[-2]`` sets M, and the
            tensor becomes the strategy's initial inducing points.
        num_tasks: batch size of q(u), i.e. the number of GP outputs.
    """

    def __init__(self, mean, kernel, inducing_points, num_tasks):
        num_inducing = inducing_points.shape[-2]
        q_u = InputDependentVariationalDistribution(num_inducing, torch.Size([num_tasks]))
        strategy = InputDependentVariationalStrategy(self, inducing_points, q_u)
        super().__init__(strategy)

        self.mean_module = mean
        self.covar_module = kernel

    def update_variational_params(self, inducing_points, variational_mean, chol_variational_covar):
        """Overwrite the inducing points and the q(u) parameters before the next call."""
        strategy = self.variational_strategy
        strategy.inducing_points = inducing_points
        strategy._variational_distribution.update_params(variational_mean, chol_variational_covar)

    def forward(self, x, **kwargs) -> MultivariateNormal:
        """Return the GP prior at ``x``; ``kwargs`` go to the kernel only."""
        return MultivariateNormal(self.mean_module(x), self.covar_module(x, **kwargs))

# Graph Neural Networks
The following layers are implemented:
* Variational Layer

The following GNNs are implemented:
* Graph Convolutional Networks (GCN)
* Graph Attention Networks (GAT)
* Approximate Personalized Propagation of Neural Predictions (APPNP)
* Graph Convolutional Network via Initial residual and Identity mapping (GCNII)

### Variational Layer
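The variational layer is defined by:
$$
Z, q_{\mu}, q_{L} = \mathrm{Variational}(x),
$$
where $Z$, $q_{\mu}$ and $q_{L}$ are the variational parameters of a GP: the inducing points, the variational mean, and the vectorized Cholesky factor of the variational covariance, respectively.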

In [None]:
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential, Dropout, Linear, ReLU
from torch.nn.init import xavier_uniform_, zeros_
from torch_geometric.nn import GCNConv, GATConv, APPNP, GCN2Conv

class Variational(Module):
    """Linear heads mapping node embeddings to per-node GP variational parameters.

    For each of the ``n`` input rows, forward returns:
    - ``z``: (n, M, num_features) inducing points;
    - ``q_mu``: (n, num_classes, M) variational means;
    - ``q_sqrt``: (n, num_classes, M(M+1)/2) vectorized lower-triangular
      Cholesky entries of the variational covariances.
    """

    def __init__(self, input_dim, num_features, num_classes, num_inducing_points):

        super(Variational, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.num_inducing_points = num_inducing_points
        m = num_inducing_points
        # number of entries in an M x M lower-triangular matrix (diagonal included)
        self.chol_dim = m * (m + 1) // 2

        # built in this order so the random initialization order is preserved
        self.Z = Linear(input_dim, m * num_features)
        self.q_mu = Linear(input_dim, num_classes * m)
        self.q_sqrt = Linear(input_dim, num_classes * self.chol_dim)

    def reset_parameters(self):
        """Xavier-uniform weights and zero biases for the three heads."""
        heads = (self.Z, self.q_mu, self.q_sqrt)
        for head in heads:
            xavier_uniform_(head.weight)
        for head in heads:
            zeros_(head.bias)

    def forward(self, x):
        n = x.size(0)
        m = self.num_inducing_points
        z = self.Z(x).reshape(n, m, self.num_features)
        q_mu = self.q_mu(x).reshape(n, self.num_classes, m)
        q_sqrt = self.q_sqrt(x).reshape(n, self.num_classes, self.chol_dim)
        return z, q_mu, q_sqrt

### Graph Convolutional Networks
The GCN is defined by
\begin{equation}
    H^{(k + 1)} = \sigma(\hat{A}H^{(k)}W^{(k)}),
\end{equation}
where $H^{(0)} = X$, and $\sigma(\cdot)$ is an activation function.

In [None]:
class GCN(Module):
    """Two-layer GCN encoder followed by a Variational head.

    Args:
        input_dim: number of node features (also the inducing-point dimension).
        hidden_dim: width of both GCN layers.
        num_classes: number of GP outputs.
        num_inducing_points: M, inducing points per node.
        dropout: dropout probability applied after each GCN layer.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_inducing_points, dropout=0.5):
        super(GCN, self).__init__()

        self.dropout = dropout
        convs = [
            GCNConv(input_dim, hidden_dim, cached=True),
            GCNConv(hidden_dim, hidden_dim, cached=True),
        ]
        self.hidden = ModuleList(convs)
        # inducing points live in the raw input space, hence num_features=input_dim
        self.q_dist = Variational(hidden_dim, input_dim, num_classes, num_inducing_points)

    def reset_parameters(self):
        for conv in self.hidden:
            conv.reset_parameters()
        self.q_dist.reset_parameters()

    def forward(self, x, edge_index):
        """Return (z, q_mu, q_sqrt) for every node."""
        *inner, last = self.hidden
        h = x
        for conv in inner:
            h = F.dropout(conv(h, edge_index).relu(), p=self.dropout, training=self.training)
        # no activation after the last GCN layer
        h = F.dropout(last(h, edge_index), p=self.dropout, training=self.training)
        return self.q_dist(h)

### Graph Attention Networks
The GAT model operator is defined by:
\begin{equation}
    h_i' = \alpha_{ii}Wh_i + \sum_{j\in\mathcal{N}(i)}\alpha_{ij}Wh_j,
\end{equation}
where the normalized attention coefficients are computed as follows:
\begin{equation}
    \alpha_{ij} = \frac{\exp(LeakyReLU(a^T[Wh_i||Wh_j]))}{\sum_{k\in\mathcal{N}(i)}\exp(LeakyReLU(a^T[Wh_i||Wh_k]))},
\end{equation}
where $||$ is the concatenation operator.

In [None]:
class GAT(Module):
    """Two-layer GAT encoder followed by a Variational head.

    The first layer concatenates ``in_heads`` heads. The second averages
    ``out_heads`` heads (``concat=False``) and keeps width ``hidden_dim * in_heads``.

    Args:
        input_dim: number of node features (also the inducing-point dimension).
        hidden_dim: per-head width of the first layer.
        num_classes: number of GP outputs.
        num_inducing_points: M, inducing points per node.
        dropout: feature dropout after each layer.
        in_heads / out_heads: attention heads of the first / second layer.
        att_dropout: dropout on the attention coefficients.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_inducing_points,
                 dropout=0.0, in_heads=8, out_heads=1, att_dropout=0.0):

        super(GAT, self).__init__()
        self.dropout = dropout
        concat_dim = hidden_dim * in_heads
        self.hidden = ModuleList([
            GATConv(input_dim, hidden_dim, heads=in_heads, dropout=att_dropout),
            GATConv(concat_dim, concat_dim, heads=out_heads, concat=False, dropout=att_dropout),
        ])
        self.q_dist = Variational(concat_dim, input_dim, num_classes, num_inducing_points)

    def reset_parameters(self):
        for conv in self.hidden:
            conv.reset_parameters()
        self.q_dist.reset_parameters()

    def forward(self, x, edge_index):
        """Return (z, q_mu, q_sqrt) for every node."""
        *inner, last = self.hidden
        h = x
        for conv in inner:
            h = F.dropout(F.elu(conv(h, edge_index)), p=self.dropout, training=self.training)
        h = F.dropout(last(h, edge_index), p=self.dropout, training=self.training)
        return self.q_dist(h)

### Approximate Personalized Propagation of Neural Predictions
The APPNP has linear computational complexity and approximates topic-sensitive PageRank via $K$ aggregations:

\begin{equation}
    H^{(k + 1)} = (1 - \alpha)\hat{A}H^{(k)} + \alpha H^{(0)},
\end{equation}
where $H^{(0)} = f(X)$.

In [None]:
class GNNPP(Module):
    """Two-layer MLP followed by APPNP propagation and a Variational head.

    Args:
        input_dim: number of node features (also the inducing-point dimension).
        hidden_dim: width of the MLP layers.
        num_classes: number of GP outputs.
        num_inducing_points: M, inducing points per node.
        dropout: dropout probability.
        K: number of APPNP propagation steps.
        alpha: APPNP teleport (restart) probability.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_inducing_points,
                 dropout=0.5, K=10, alpha=0.1):

        super(GNNPP, self).__init__()
        self.dropout = dropout
        self.hidden = ModuleList([Linear(input_dim, hidden_dim), Linear(hidden_dim, hidden_dim)])
        self.propagate = APPNP(K, alpha, cached=True)
        self.q_dist = Variational(hidden_dim, input_dim, num_classes, num_inducing_points)

    def reset_parameters(self):
        for lin in self.hidden:
            xavier_uniform_(lin.weight)
            zeros_(lin.bias)
        self.q_dist.reset_parameters()

    def forward(self, x, edge_index):
        """Return (z, q_mu, q_sqrt) for every node."""
        def drop(t):
            return F.dropout(t, p=self.dropout, training=self.training)

        *inner, last = self.hidden
        h = x
        for lin in inner:
            h = lin(drop(h)).relu()
        h = last(drop(h))
        # propagate the MLP output over the graph, then regularize
        h = drop(self.propagate(h, edge_index))
        return self.q_dist(h)

### Graph Convolutional Networks via Initial residual and Identity mapping
The GCNII operator is given by:

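\begin{equation}
\begin{aligned}
    \tilde{H}^{(k)} &= (1 - \alpha)\hat{A}H^{(k)} + \alpha H^{(0)},\\
    H^{(k + 1)} &= \sigma\left((1 - \beta_k)\tilde{H}^{(k)} + \beta_k\tilde{H}^{(k)}W^{(k)}\right),
\end{aligned}
\end{equation}
where $H^{(0)} = f(X)$, $\sigma(\cdot)$ is an activation function, and $\beta_k = \log(\theta / k + 1)$ decreases with the layer index $k$.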

In [None]:
from torch.nn.init import kaiming_normal_, ones_

class GCNII(Module):
    """GCNII encoder: a linear input layer, ``depth`` GCN2Conv layers, then a Variational head.

    Args:
        input_dim: number of node features (also the inducing-point dimension).
        hidden_dim: width of every layer.
        num_classes: number of GP outputs.
        num_inducing_points: M, inducing points per node.
        dropout: dropout probability.
        depth: number of GCN2Conv layers.
        alpha: initial-residual strength.
        theta: identity-mapping strength (combined with the layer index).
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_inducing_points, dropout=0.0,
                 depth=64, alpha=0.1, theta=0.5):

        super(GCNII, self).__init__()
        self.dropout = dropout
        # q_dist is created first to keep the original initialization order
        self.q_dist = Variational(hidden_dim, input_dim, num_classes, num_inducing_points)
        layers = [Linear(input_dim, hidden_dim)]
        layers += [GCN2Conv(hidden_dim, alpha, theta, layer, cached=True)
                   for layer in range(1, depth + 1)]
        self.hidden = ModuleList(layers)

    def reset_parameters(self):
        lin0, *convs = self.hidden
        kaiming_normal_(lin0.weight)
        zeros_(lin0.bias)
        for conv in convs:
            conv.reset_parameters()

        self.q_dist.reset_parameters()

    def forward(self, x, edge_index):
        """Return (z, q_mu, q_sqrt) for every node."""
        lin0, *convs = self.hidden
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = lin0(h).relu()
        # initial representation reused by every GCN2Conv layer
        h0 = h.clone()
        for conv in convs:
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = conv(h, h0, edge_index).relu()
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.q_dist(h)

# Utilities
The utility functions are:
* Prepare Logging
* Load Dataset
* Metrics: Accuracy, ECE, Reliability Diagram
* Step and Evaluation functions
* Get Amortizer, GP and Likelihood, and Train
* Grid Search

### Logging and Load Dataset

In [None]:
from sklearn.feature_extraction.text import TfidfTransformer
from torch_geometric.datasets import Planetoid
from typing import NamedTuple

# Container for one Planetoid graph as assembled by load_dataset:
# - feature and class counts;
# - node features, labels and the edge index;
# - index tensors (not boolean masks) selecting the train/val/test nodes.
Data = NamedTuple('Data', [
    ('num_features', int),
    ('num_classes', int),
    ('x', torch.Tensor),
    ('y', torch.Tensor),
    ('edge_index', torch.Tensor),
    ('train_mask', torch.Tensor),
    ('val_mask', torch.Tensor),
    ('test_mask', torch.Tensor),
])

def prepare_logging(name):
    """Create (or reset) a file logger writing to ``./reports/<name>.log``.

    The log file is truncated on every call. Any handler that an earlier call
    attached to the same file is closed and removed first. Re-running the setup
    (e.g. re-executing a notebook cell) therefore neither duplicates log lines
    nor leaks open file handles.

    Args:
        name: logger name; lower-cased for both the logger and the file name.

    Returns:
        The configured ``logging.Logger`` at level INFO.
    """
    os.makedirs('./reports/', exist_ok=True)
    fmt = '%(asctime)s, %(name)s: %(message)s'
    datefmt = '%Y.%m.%d - %H:%M:%S'
    log_file = f'./reports/{name.lower()}.log'

    log = logging.getLogger(f'{name.lower()}')

    # Drop handlers from previous calls that point at this same file.
    target = os.path.abspath(log_file)
    for old in list(log.handlers):
        if isinstance(old, logging.FileHandler) and old.baseFilename == target:
            log.removeHandler(old)
            old.close()

    # mode='w' truncates any previous log (replaces deleting the file by hand)
    handler = logging.FileHandler(log_file, mode='w')
    handler.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt))

    log.setLevel(logging.INFO)
    log.addHandler(handler)

    return log

def load_dataset(name, device):
    """Load a Planetoid citation graph and move its tensors to ``device``.

    For every dataset except PubMed, the bag-of-words features are re-weighted
    with TF-IDF (scikit-learn ``TfidfTransformer``, smooth IDF). The boolean
    train/val/test masks are converted to index tensors.

    Args:
        name: Planetoid dataset name (e.g. 'Cora', 'CiteSeer', 'PubMed'); the
            lower-cased name selects the download directory and the PubMed check.
        device: torch device for all returned tensors.

    Returns:
        Data: named tuple with sizes, features, labels, edges and index masks.
    """
    loader = Planetoid(root=f'datasets/{name.lower()}', name=name)
    # dataset[0] is the public accessor for the single graph; the internal
    # `.data` attribute is deprecated in recent PyG releases.
    graph = loader[0]

    x = graph.x
    if name.lower() != 'pubmed':
        # TF-IDF runs on CPU/NumPy, so transform before moving to the device
        # instead of making a device -> CPU -> device round trip.
        x = TfidfTransformer(smooth_idf=True).fit_transform(x.cpu().numpy()).todense()
        x = torch.Tensor(x)
    x = x.to(device)

    y = graph.y.to(device)
    edge_index = graph.edge_index.to(device)
    train_mask = torch.where(graph.train_mask)[0].to(device)
    val_mask = torch.where(graph.val_mask)[0].to(device)
    test_mask = torch.where(graph.test_mask)[0].to(device)

    data = Data(loader.num_features, loader.num_classes, x, y, edge_index,
                train_mask, val_mask, test_mask)

    return data

### Metrics
* Accuracy
* Expected Calibration Error
* Reliability Diagram

In [None]:
def accuracy_score(y_true, y_pred):
    """Return the fraction of positions where ``y_pred`` equals ``y_true`` as a Python float."""
    hits = (y_true == y_pred).double()
    return hits.mean().item()

def ECE(y_prob, y, nbins=10):
    """Expected Calibration Error with equal-width confidence bins.

    Each sample goes into a bin according to its top-class probability. Within
    each bin, the accuracy of the arg-max prediction is compared with the mean
    confidence of the samples in that bin. The absolute gaps are then averaged,
    weighted by bin size:

        ECE = sum_b |B_b| / n * |acc(B_b) - conf(B_b)|

    Args:
        y_prob: (n, num_classes) class probabilities (any device).
        y: (n,) integer labels (any device).
        nbins: number of equal-width bins on [0, 1]; the last bin is closed on
            the right so that a probability of exactly 1 is counted.

    Returns:
        dict with keys:
        - 'ece': the scalar ECE (float);
        - 'accuracy': per-bin accuracy (0 for empty bins);
        - 'confidence': per-bin centers (bar positions for reliability_diagram);
        - 'bin_sizes': samples per bin;
        - 'avg_confidence': per-bin mean confidence (0 for empty bins).
    """
    edges = torch.linspace(0, 1, nbins + 1)
    accuracy = torch.zeros(nbins)
    avg_confidence = torch.zeros(nbins)
    bin_sizes = torch.zeros(nbins)

    prob, pred = torch.max(y_prob, dim=1)
    # the bin edges live on the CPU, so bin there regardless of the input device
    prob, pred, y = prob.cpu(), pred.cpu(), y.cpu()

    for b in range(nbins):
        lower, upper = edges[b], edges[b + 1]
        if b == nbins - 1:
            in_bin = (lower <= prob) & (prob <= upper)
        else:
            in_bin = (lower <= prob) & (prob < upper)

        bin_sizes[b] = in_bin.sum()
        if bin_sizes[b] > 0:
            accuracy[b] = torch.mean((y[in_bin] == pred[in_bin]).double())
            avg_confidence[b] = torch.mean(prob[in_bin].double())

    # Bin centers are only plot positions. The ECE itself uses the mean
    # confidence of each bin (standard definition, Guo et al. 2017).
    centers = (edges[1:] + edges[:-1]) / 2
    ece = torch.sum(torch.abs(accuracy - avg_confidence) * bin_sizes) / len(y)

    calibration = {
        'ece': ece.item(),
        'accuracy': accuracy.tolist(),
        'confidence': centers.tolist(),
        'bin_sizes': bin_sizes.tolist(),
        'avg_confidence': avg_confidence.tolist(),
    }

    return calibration

def reliability_diagram(conf, acc, nbins=10):
    """Plot a reliability diagram: per-bin accuracy bars plus the gap to perfect calibration.

    Args:
        conf: bin positions on the confidence axis, e.g. ``ECE(...)['confidence']``
            (a list, array or CPU tensor).
        acc: per-bin accuracy, e.g. ``ECE(...)['accuracy']``.
        nbins: number of bins; both bar series use width ``1 / nbins``.
    """
    # ECE returns plain lists; convert so that `conf - acc` is element-wise.
    conf = np.asarray(conf, dtype=float)
    acc = np.asarray(acc, dtype=float)
    width = 1.0 / nbins  # same as the former 0.1 / (nbins / 10), now used by both series

    plt.bar(conf, height=acc, width=width, edgecolor='k', align='center', label='Output')
    plt.bar(conf, height=conf - acc, bottom=acc, color='r', width=width, edgecolor='k',
            align='center', alpha=0.5, label='Gap')
    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect Calibration')
    plt.legend()
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.show()

### Step and Evaluation functions

In [None]:
def step(x, y, edge_index, mask, amortizer, gp_model, likelihood, optimizer, criterion,
         num_samples=8):
    """Run one optimization step on the nodes indexed by ``mask``.

    Args:
        x, y, edge_index: node features, labels and graph edges.
        mask: index tensor of the nodes used for the loss.
        amortizer: GNN producing (z, q_mu, q_sqrt) for every node.
        gp_model: InputDependentSVGP updated with the masked amortizer outputs.
        likelihood: GP likelihood (only switched to train mode here).
        optimizer: optimizer over the amortizer / GP / likelihood parameters.
        criterion: ELBO object returning (ll, kl, _) with ``combine_terms=False``.
        num_samples: likelihood Monte-Carlo samples.

    Returns:
        float: the scaled ELBO computed before the parameter update.
    """
    for module in (amortizer, gp_model, likelihood):
        module.train()

    optimizer.zero_grad()

    # the amortizer predicts inducing points and q(u) for every node; keep the masked ones
    z, m, L = amortizer(x, edge_index)
    gp_model.update_variational_params(z[mask], m[mask], L[mask])

    n = len(mask)
    with gpytorch.settings.num_likelihood_samples(num_samples):
        # each node becomes its own single-point batch: (n, D) -> (n, 1, D)
        f_dist = gp_model(x[mask].unsqueeze(1))
        ll, kl, _ = criterion(f_dist, y[mask])
    # NOTE(review): the KL is divided by n and then the whole sum is divided by n
    # again; confirm this weighting is intended.
    elbo = (ll.sum() - kl.sum().div(n)).div(n)

    # maximize the ELBO by minimizing its negative
    (-elbo).backward()
    optimizer.step()

    return elbo.item()

In [None]:
def evaluate(x, y, edge_index, mask, amortizer, gp, likelihood, criterion, num_samples=16):
    """Evaluate the model, without gradients, on the nodes indexed by ``mask``.

    Returns:
        dict with keys:
        - 'elbo': scaled ELBO, same formula as in ``step``;
        - 'mean_ll': mean log-probability of the true labels;
        - 'accuracy': arg-max accuracy;
        - 'calibration': the output of ``ECE``.
    """
    for module in (amortizer, gp, likelihood):
        module.eval()

    n = len(mask)
    with torch.no_grad(), gpytorch.settings.num_likelihood_samples(num_samples):
        # variational parameters for every node, then keep the masked ones
        z, m, L = amortizer(x, edge_index)
        gp.update_variational_params(z[mask], m[mask], L[mask])

        f_dist = gp(x[mask].unsqueeze(1))

        ll, kl, _ = criterion(f_dist, y[mask])
        elbo = (ll.sum() - kl.sum().div(n)).div(n).item()

        # class probabilities averaged over dim 0 (presumably the likelihood samples)
        y_prob = likelihood(f_dist).probs.mean(0).squeeze(1)
        y_pred = y_prob.argmax(-1)
        y_true = y[mask]

        # nll_loss on log-probabilities is the mean negative log-likelihood; negate it
        mean_ll = -F.nll_loss(y_prob.log(), y_true).item()
        accuracy = accuracy_score(y_true, y_pred)
        calibration = ECE(y_prob.cpu(), y_true.cpu())

    return {'elbo': elbo, 'mean_ll': mean_ll, 'accuracy': accuracy, 'calibration': calibration}

### Get Amortizer, GP and Optimizer, and Train

In [None]:
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, PolynomialKernel
from gpytorch.likelihoods import SoftmaxLikelihood

def get_amortizer(gnn, input_dim, num_classes, params):
    """Build the GNN amortizer named by ``gnn``.

    Args:
        gnn: one of 'gcn', 'gat', 'appnp', 'gcnii' (case-insensitive).
        input_dim: number of node features.
        num_classes: number of GP outputs.
        params: hyper-parameter row (attribute access). It needs ``hidden_dim``,
            ``num_induc`` and ``dropout``, plus, depending on ``gnn``,
            ``in_heads``/``out_heads``/``att_dropout`` or ``depth``. Integer
            fields are cast with int(), presumably because DataFrame rows may
            carry them as floats.

    Raises:
        NotImplementedError: for 'sgc', which has no implementation here.
        ValueError: for any other unknown name. Previously both cases silently
            returned None and crashed later.
    """
    gnn = gnn.lower()
    hidden_dim = int(params.hidden_dim)
    num_inducing_points = int(params.num_induc)
    dropout = params.dropout

    if gnn == 'gcn':
        return GCN(input_dim, hidden_dim, num_classes, num_inducing_points, dropout=dropout)

    if gnn == 'gat':
        return GAT(input_dim, hidden_dim, num_classes, num_inducing_points, dropout=dropout,
                   in_heads=int(params.in_heads), out_heads=int(params.out_heads),
                   att_dropout=params.att_dropout)

    if gnn == 'appnp':
        return GNNPP(input_dim, hidden_dim, num_classes, num_inducing_points, dropout=dropout)

    if gnn == 'gcnii':
        return GCNII(input_dim, hidden_dim, num_classes, num_inducing_points, dropout=dropout,
                     depth=int(params.depth))

    if gnn == 'sgc':
        raise NotImplementedError("the 'sgc' amortizer is not implemented")

    raise ValueError(f"unknown gnn {gnn!r}; expected one of 'gcn', 'gat', 'appnp', 'gcnii'")

def get_gp(input_dim, num_classes, params):
    """Build the input-dependent SVGP and its softmax likelihood.

    Args:
        input_dim: input dimension (and dimension of the placeholder inducing points).
        num_classes: number of GP outputs / classes.
        params: hyper-parameter row with ``num_induc`` and ``power``.

    Returns:
        (gp, likelihood)
    """
    # placeholder inducing points; update_variational_params overwrites them every call
    inducing_points = torch.zeros(int(params.num_induc), input_dim).unsqueeze(0)
    mean = ConstantMean()
    # cast like the other integer hyper-parameters (DataFrame rows may hold floats)
    kernel = ScaleKernel(PolynomialKernel(power=int(params.power)))
    gp = InputDependentSVGP(mean, kernel, inducing_points, num_classes)
    likelihood = SoftmaxLikelihood(mixing_weights=False, num_classes=num_classes)
    return gp, likelihood

def get_optimizer(amortizer, gp, likelihood, lr, num_epochs, params,
                  steps=(0.25, 0.5), gamma=0.1):
    """Create the ELBO criterion, the Adam optimizer and a step LR schedule.

    Args:
        amortizer: GNN with ``hidden`` (encoder) and ``q_dist`` (variational head).
        gp: InputDependentSVGP with ``covar_module`` and ``mean_module``.
        likelihood: GP likelihood.
        lr: learning rate shared by all parameter groups.
        num_epochs: total number of epochs, used to place the LR milestones.
        params: hyper-parameter row with ``gnn_wdecay`` and ``var_wdecay``.
        steps: fractions of ``num_epochs`` at which the LR is multiplied by ``gamma``.
        gamma: multiplicative LR decay.

    Returns:
        (optimizer, scheduler, criterion)
    """
    criterion = InputDependentVariationalELBO(likelihood, gp, num_data=1, combine_terms=False)
    optimizer = torch.optim.Adam([
        # weight decay only on the amortizer: encoder and variational head separately
        {'params': amortizer.hidden.parameters(), 'weight_decay': params.gnn_wdecay, 'lr': lr},
        {'params': amortizer.q_dist.parameters(), 'weight_decay': params.var_wdecay, 'lr': lr},
        {'params': gp.covar_module.parameters()},
        {'params': gp.mean_module.parameters()},
        {'params': likelihood.parameters()},
    ], lr=lr)
    # MultiStepLR matches milestones against the integer epoch counter, so round the
    # fractional positions to whole epochs (e.g. 0.7 * 10 == 7.000000000000001 would
    # never match).
    milestones = [int(round(s * num_epochs)) for s in steps]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma)
    return optimizer, scheduler, criterion

def train(amortizer, gp, likelihood, optimizer, scheduler, criterion, x, y, edge_index,
          train_mask, val_mask, metric='elbo', max_epochs=None, patience=None):
    """Train with early stopping on a validation metric and return the best model.

    Every improvement of ``scores[metric]`` is also checkpointed to
    ``./models/{amortizer,gp,likelihood}.pt``.

    Args:
        amortizer, gp, likelihood: the models to train (in place).
        optimizer, scheduler, criterion: as returned by ``get_optimizer``.
        x, y, edge_index: node features, labels and edges.
        train_mask: index tensor of the training nodes.
        val_mask: index tensor of the nodes used for model selection.
        metric: key of ``evaluate``'s result to maximize.
        max_epochs: maximum number of epochs; defaults to the module-level ``num_epochs``.
        patience: epochs without improvement before stopping; defaults to the
            module-level ``n_iter_no_improve``.

    Returns:
        (elbo_curve, best_val_scores, best_model), where:
        - ``elbo_curve`` holds the per-epoch training ELBO;
        - ``best_model`` maps 'amortizer'/'gp'/'likelihood' to state dicts.
        If no epoch ever improved (e.g. a NaN metric), ``best_model`` is the
        final state rather than a stale checkpoint left on disk by an earlier run.
    """
    import copy

    if max_epochs is None:
        max_epochs = num_epochs
    if patience is None:
        patience = n_iter_no_improve

    def _snapshot():
        # deep copies, so later optimization steps cannot mutate the stored tensors
        return {
            'amortizer': copy.deepcopy(amortizer.state_dict()),
            'gp': copy.deepcopy(gp.state_dict()),
            'likelihood': copy.deepcopy(likelihood.state_dict()),
        }

    elbo_curve = []
    no_improve = 0
    best_val_scores = {'elbo': -torch.inf, 'mean_ll': -torch.inf, 'accuracy': -torch.inf}
    best_model = None

    for _ in range(max_epochs):
        elbo = step(x, y, edge_index, train_mask, amortizer, gp, likelihood,
                    optimizer, criterion, num_samples=10)
        scores = evaluate(x, y, edge_index, val_mask, amortizer, gp, likelihood,
                          criterion, num_samples=20)

        elbo_curve.append(elbo)
        scheduler.step()

        if scores[metric] > best_val_scores[metric]:
            no_improve = 0
            best_val_scores = scores
            best_model = _snapshot()
            torch.save(best_model['amortizer'], './models/amortizer.pt')
            torch.save(best_model['gp'], './models/gp.pt')
            torch.save(best_model['likelihood'], './models/likelihood.pt')
        else:
            no_improve += 1

        if no_improve >= patience:
            break

    if best_model is None:
        best_model = _snapshot()
    return elbo_curve, best_val_scores, best_model

### Grid Search

In [None]:
def grid_search(data, filtered_grid, gnn, metric, val_with_train, device):
    """Train one model per hyper-parameter row and keep the best by ``metric``.

    Uses the module-level ``lr`` and ``num_epochs`` settings.

    Args:
        data: Data tuple from ``load_dataset``.
        filtered_grid: DataFrame with one hyper-parameter combination per row.
        gnn: amortizer name passed to ``get_amortizer``.
        metric: key of the validation scores to maximize.
        val_with_train: select on the training nodes instead of the validation nodes.
        device: torch device for the models.

    Returns:
        (best_params, best_scores), where ``best_params`` is a row of
        ``filtered_grid``. A single-row grid is returned directly with
        ``best_scores=None`` (no training is run). If no row beats -inf,
        the first row is returned.

    Raises:
        ValueError: if ``filtered_grid`` is empty.
    """
    grid_size = filtered_grid.shape[0]
    if grid_size == 0:
        raise ValueError('filtered_grid is empty: no hyper-parameter combination to evaluate')

    if grid_size == 1:
        return filtered_grid.iloc[0], None

    best_scores = {'elbo': -torch.inf, 'mean_ll': -torch.inf, 'accuracy': -torch.inf}
    # Fallback so callers always receive a usable parameter row; an empty dict
    # broke `.to_dict()` and attribute access downstream.
    best_params = filtered_grid.iloc[0]

    # grid search loop
    for trial, (_, params) in enumerate(filtered_grid.iterrows(), start=1):
        # initialize models
        amortizer = get_amortizer(gnn, data.num_features, data.num_classes, params)
        gp, likelihood = get_gp(data.num_features, data.num_classes, params)
        optimizer, scheduler, criterion = get_optimizer(amortizer, gp, likelihood, lr,
                                                        num_epochs, params)

        # setting device
        amortizer.to(device).reset_parameters()
        gp.to(device)
        likelihood.to(device)

        # train
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')

            mask = data.train_mask if val_with_train else data.val_mask
            _, scores, _ = train(
                amortizer, gp, likelihood, optimizer, scheduler, criterion,
                data.x, data.y, data.edge_index, data.train_mask, mask, metric
            )

        if scores[metric] > best_scores[metric]:
            best_scores = scores
            best_params = params

        # 1-based progress counter
        print(f'Grid Search: {trial}/{grid_size}, '
              f'{scores[metric]:.4f} ({best_scores[metric]:.4f})')

    return best_params, best_scores

# Settings

In [None]:
from sklearn.model_selection import ParameterGrid

# seed and device
global_seed = 30
# Use the GPU when present; fall back to the CPU instead of failing at the
# first .to('cuda') on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# settings
dataset = 'PubMed'          # Planetoid dataset name
lr = 0.01                   # learning rate shared by every parameter group
num_epochs = 2000           # maximum number of training epochs
n_iter_no_improve = 200     # early-stopping patience, in epochs
num_runs = 10               # independent runs with the selected hyper-parameters
metric = 'elbo'             # model-selection metric: 'elbo', 'mean_ll' or 'accuracy'
val_with_train = True       # if True, select on the training nodes instead of the validation nodes
gnn = 'appnp'               # amortizer: 'gcn', 'gat', 'appnp' or 'gcnii'

grid = {
    'gnn_wdecay': [5e-3],
    'var_wdecay': [5e-4],
    'dropout': [0.6],
    'hidden_dim': [64],
    'power': [3],
    'num_induc': [20],
}
filtered_grid = pd.DataFrame(list(ParameterGrid(grid)))
# keep only combinations where the variational head is regularized no more than the GNN
filtered_grid = filtered_grid.query('`var_wdecay` <= `gnn_wdecay`')
filtered_grid.reset_index(inplace=True)

# Setting Seed, Settings and Reports

In [None]:
# setting the seed
import random

# seed every RNG the experiments draw from: Python, NumPy, PyTorch CPU and all GPUs
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(global_seed)

# run settings, later written to ./reports/settings.json
settings = dict(
    gnn=gnn,
    grid=grid,
    metric=metric,
    val_with_train=val_with_train,
    learning_rate=lr,
    num_epochs=num_epochs,
    n_iter_no_improve=n_iter_no_improve,
    num_runs=num_runs,
)

# per-dataset results, filled in by the experiments
reports = {}

# pre-load the dataset
data = load_dataset(dataset, device)

# Experiments

In [None]:
import warnings

os.makedirs('./models/', exist_ok=True)
log = prepare_logging(dataset)
log.info(f'Dataset: {dataset}')
log.info(f'Model: IDSVGP with {gnn.upper()}\n')

# reloaded here so this cell picks up the current `dataset` setting
data = load_dataset(dataset, device)
hyperparams, _ = grid_search(data, filtered_grid, gnn, metric, val_with_train, device)

dataset_report = reports[dataset] = {'hyperparams': hyperparams}
for run_id in range(num_runs):
    dataset_report[run_id] = run_report = {}

    # fresh models for every run, built from the selected hyper-parameters
    amortizer = get_amortizer(gnn, data.num_features, data.num_classes, hyperparams)
    gp, likelihood = get_gp(data.num_features, data.num_classes, hyperparams)
    optimizer, scheduler, criterion = get_optimizer(amortizer, gp, likelihood, lr,
                                                    num_epochs, hyperparams)

    amortizer.to(device).reset_parameters()
    gp.to(device)
    likelihood.to(device)

    # nodes used for early stopping / model selection
    selection_mask = data.train_mask if val_with_train else data.val_mask
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        train_curve, val_scores, model = train(
            amortizer, gp, likelihood, optimizer, scheduler, criterion,
            data.x, data.y, data.edge_index, data.train_mask, selection_mask, metric
        )

    # evaluate the best checkpoint on the test nodes
    for module, key in ((amortizer, 'amortizer'), (gp, 'gp'), (likelihood, 'likelihood')):
        module.load_state_dict(model[key])

    test_scores = evaluate(data.x, data.y, data.edge_index, data.test_mask,
                           amortizer, gp, likelihood, criterion, 300)

    run_report['train_curve'] = train_curve
    run_report['val_scores'] = val_scores
    run_report['test_scores'] = test_scores

    summary = (f'Run: {run_id}, # Iterations: {len(train_curve)}, '
               f'Accuracy: {test_scores["accuracy"] * 100:.1f}, '
               f'Validation: {val_scores[metric]:.4f}')
    log.info(summary)
    print(summary)

# Reports

In [None]:
# json cannot serialize a pandas Series, so convert the selected hyper-parameters.
# The guard keeps the cell re-runnable: a second run, or a plain-dict value, would
# otherwise fail on `.to_dict()` with AttributeError.
hyperparams_report = reports[dataset]['hyperparams']
if hasattr(hyperparams_report, 'to_dict'):
    reports[dataset]['hyperparams'] = hyperparams_report.to_dict()

with open('./reports/settings.json', 'w') as f:
    json.dump(settings, f)

with open('./reports/reports.json', 'w') as f:
    json.dump(reports, f)

print(f'Mean and standard deviation of {num_runs} runs:')
accuracy = [reports[dataset][i]['test_scores']['accuracy'] for i in range(num_runs)]
print(f'{dataset}: {np.mean(accuracy) * 100:.1f} ({np.std(accuracy) * 100:.1f})')

In [None]:
# show the hyper-parameters chosen by the grid search
print('Best Params:', reports[dataset]["hyperparams"])

In [None]:
import matplotlib.pyplot as plt

# training ELBO of the last run, one point per epoch
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(train_curve, 'k')
ax.grid()
ax.set_xlabel('Num. of Epochs')
ax.set_ylabel('ELBO')
fig.savefig(f'./reports/idsvgp_{gnn}.png', bbox_inches='tight')