### Paper Information

* Title: [Conditional Neural Processes](https://arxiv.org/abs/1807.01613)
* Authors: Marta Garnelo, Dan Rosenbaum, Chris J. Maddison, Tiago Ramalho, David Saxton, Murray Shanahan, Yee Whye Teh, Danilo J. Rezende, S. M. Ali Eslami (2018)

### Main Ideas

* Traditionally, a deep neural network learns a single deterministic function and lacks the capacity to adapt to changes in the input distribution, which matters in applications such as meta-learning and continual learning.

* A conditional neural process (CNP) models a stochastic process or, more accurately, a family of conditional distributions given a context set. Its architecture is simple: an average representation of the context points is concatenated to each target input, and the network outputs a probabilistic prediction whose negative log-likelihood serves as the training objective.

* CNPs are scalable and capable of uncertainty estimation, although their variance estimates are not as accurate as those of Gaussian Processes.

<div style="display:flex">
     <div style="flex:1;padding-right:5px;">
          <img src="images/conditional-neural-processes/figure-1.png" alt="Drawing" style="width: 80%;"/>
     </div>
     <div style="flex:1;padding-left:5px;">
          <img src="images/conditional-neural-processes/figure-4.png" alt="Drawing" style="width: 90%;"/>
     </div>
</div>

### Methodology

* Inspired by Gaussian Processes, CNPs define a conditional distribution $p(Y \rvert X, C)$ for the target points $T = (X, Y) = \{(x_i, y_i)\}_{i \in \mathcal{I}(T)}$ given a set of context points $C = (X_C, Y_C) = \{(x_i, y_i)\}_{i \in \mathcal{I}(C)}$. The dependency on $C$ is specified via a deterministic representation $r_C = r(X_C, Y_C)$, which in practice is formed by averaging hidden representations of the context points in $C$:

    $$p(Y \rvert X, C) = \prod_{(x, y) \in T} p(y; g_\phi(x, r_C)), \quad r_C = r_1 \oplus r_2 \cdots \oplus r_{|\mathcal{I}(C)|}, \quad r_i = h_\theta(x_i, y_i),$$

    where $h_\theta: X \times Y \rightarrow \mathbb{R}^d$ and $g_\phi: X \times \mathbb{R}^d \rightarrow \mathbb{R}^e$ are neural networks and $\oplus$ denotes a commutative aggregation operation such as elementwise summation or multiplication. For regression tasks, for example, $g_\phi(x, r_C) = (\mu_x, \sigma_x^2)$ provides the mean and variance of a Gaussian distribution.

* Despite the model's simplicity, the commutative aggregation operator $\oplus$ ensures that $p(Y \rvert X, C)$ is invariant under permutations of $C$, while the factorized form guarantees invariance under permutations of $T$. These two properties are built into the model to resemble those of a stochastic process. Still, a CNP is not necessarily a stochastic process, since it generally fails to satisfy the consistency condition under marginalization.

* Training a CNP amounts to minimizing the negative conditional log-probability of the target points, $\mathcal{L}(\theta, \phi) = -\log p(Y \rvert X, C)$. Each training iteration conditions on only a small number of context points, mirroring what happens at test time.

* Compared to Gaussian Processes, CNPs are scalable: a forward pass has a running time complexity of $\mathcal{O}(n + m)$, where $n$ and $m$ are the sizes of the context set and the target set, whereas exact GP inference scales as $\mathcal{O}((n + m)^3)$.
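
The points above (mean aggregation, a factorized Gaussian likelihood, and one network call per point) can be sketched in a few lines of plain Python. This is only an illustration: `encode` and `decode` are hypothetical fixed functions standing in for $h_\theta$ and $g_\phi$, not the networks from the paper, and the data values are made up. The encoder is called once per context point and the decoder once per target point, which is where the $\mathcal{O}(n + m)$ cost comes from.

```python
import math


def encode(x, y):
    # Hypothetical stand-in for h_theta: maps one (x, y) pair to a vector in R^2
    return [math.tanh(x + y), math.tanh(x * y)]


def aggregate(context):
    # r_C: mean of the per-point representations r_i (commutative, so order-free)
    reps = [encode(x, y) for x, y in context]
    dim = len(reps[0])
    return [sum(r[d] for r in reps) / len(reps) for d in range(dim)]


def decode(x, r_c):
    # Hypothetical stand-in for g_phi: returns (mean, variance) for one target input
    mu = r_c[0] * x + r_c[1]
    var = 0.1 + r_c[1] ** 2  # kept strictly positive
    return mu, var


def negative_log_likelihood(context, targets):
    # -log p(Y | X, C): sum of independent Gaussian negative log densities
    r_c = aggregate(context)
    total = 0.0
    for x, y in targets:
        mu, var = decode(x, r_c)
        total += 0.5 * math.log(2 * math.pi * var) + (y - mu) ** 2 / (2 * var)
    return total


context = [(-1.0, 0.3), (0.0, -0.2), (1.5, 0.8)]
targets = [(-0.5, 0.1), (0.5, 0.4)]

loss = negative_log_likelihood(context, targets)
loss_reordered_context = negative_log_likelihood(context[::-1], targets)
loss_reordered_targets = negative_log_likelihood(context, targets[::-1])

# All three values agree up to floating-point rounding
print(loss, loss_reordered_context, loss_reordered_targets)
```

Reordering either set leaves the loss unchanged (up to rounding), which is the permutation invariance described above. Replacing the mean in `aggregate` with an order-dependent operation, such as a recurrent update, would break the invariance with respect to $C$.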

### Experiments

* The first experiment involves function regression, where the target function is drawn from a fixed Gaussian process or a set of Gaussian processes with different kernel parameters. The second experiment involves image completion and binary pixel classification on MNIST and CelebA. Generally, the variance estimates around context points are small, capturing some notion of uncertainty. More context points result in better predictions and changes in the variance estimates, though perfect reconstruction is not attainable due to the bottleneck at the representation level.

* The third experiment involves one-shot classification on Omniglot, where each class has only 20 examples. CNPs do not outperform matching networks, although their results are somewhat comparable.

### Related Works

* CNPs are closely related to Generative Query Networks ([Eslami et al., 2018](https://deepmind.com/blog/article/neural-scene-representation-and-rendering)), a meta-learning approach that models conditional distributions to predict new viewpoints in 3D scenes. Deep generative models for one-shot generalization ([Rezende et al., 2016](https://arxiv.org/abs/1603.05106)) also condition on context during inference.

* Comparing the context points and the target points using some attention-based distance metric in feature space has also been done in Matching Networks ([Vinyals et al., 2016](https://arxiv.org/abs/1606.04080)) and Generative Matching Networks ([Bartunov and Vetrov, 2016](https://arxiv.org/abs/1612.02192)). In contrast, CNPs do not explicitly define a distance kernel.

### Critical Thoughts

* Although the authors mention including latent variables in the experiment section, they are not an essential piece of the model. Latent variables allow for sampling multiple target distributions given the same context set as demonstrated in follow-up works, including Neural Processes ([Garnelo et al., 2018](https://arxiv.org/abs/1807.01622)) and Attentive Neural Processes ([Kim et al., 2019](https://arxiv.org/abs/1901.05761)).

* CNPs seem to overestimate uncertainty around context points (i.e., underfitting) and underestimate uncertainty everywhere else (i.e., overconfidence). Also, the variance estimates are not very smooth, which is probably related to the choice of ReLU activations (see [this blog post](https://kasparmartens.rbind.io/post/np/)). It is also unclear how these variance estimates compare to those generated by other probabilistic approaches such as Bayesian neural networks or variational dropout.

* Modeling a stochastic process with deep neural networks is an interesting idea. The authors consider learning high-level abstractions a grand challenge in machine learning. Indeed, this paper has sparked interest in functional approaches using deep nets, for example Functional Neural Processes ([Louizos et al., 2019](https://arxiv.org/abs/1906.08324)) and Sequential Neural Processes ([Singh et al., 2019](https://arxiv.org/abs/1906.10264)).

### Implementation

The following implementation is based on the official implementation [deepmind/neural-processes](https://github.com/deepmind/neural-processes).

In [None]:
import torch
import torch.nn as nn
import torch.distributions as D
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import random

from tqdm.notebook import tqdm


def gaussian_kernel(X1, X2, l1=0.4, sigma=1.0):
    return sigma ** 2 * ((-0.5 / l1 ** 2) * (torch.cdist(X1, X2, p=2) ** 2)).exp()


def sample_gaussian_process(inputs, l1=0.4, sigma=1.0, eps=2e-3):
    # Compute covariance matrix of size B x M x M induced by a Gaussian kernel
    batch_size, num_points, x_dim = inputs.size()
    covariance = gaussian_kernel(inputs, inputs, l1=l1, sigma=sigma)

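    # Add diagonal noise (in float64) to avoid a singular covariance matrix
    covariance = covariance.double() + (eps ** 2) * torch.eye(num_points, dtype=torch.float64)
    # torch.cholesky is deprecated; torch.linalg.cholesky is its replacement
    cholesky = torch.linalg.cholesky(covariance).float()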
    outputs = torch.matmul(cholesky, torch.randn(batch_size, num_points, 1))
    return outputs


def generate_train_data(batch_size=16, max_num_context=50):
    num_context = torch.randint(low=3, high=max_num_context, size=())
    num_target = torch.randint(low=2, high=max_num_context, size=())
    num_points = num_target + num_context

    x_values = 4 * torch.rand(batch_size, num_points, 1) - 2
    y_values = sample_gaussian_process(x_values)
    x_target, y_target = x_values, y_values
    x_context, y_context = x_values[:, :num_context], y_values[:, :num_context]
    return x_context, y_context, x_target, y_target


def generate_test_data(batch_size=1, max_num_context=50, num_target=400):
    num_context = torch.randint(low=3, high=max_num_context, size=())
    # Use num_target instead of a hard-coded 400 so the argument is respected
    x_values = torch.linspace(-2, 2, num_target).view(1, num_target, 1).repeat((batch_size, 1, 1))
    y_values = sample_gaussian_process(x_values)
    x_target, y_target = x_values, y_values
    idx = torch.randperm(num_target)[:num_context]
    x_context, y_context = x_values[:, idx, :], y_values[:, idx, :]
    return x_context, y_context, x_target, y_target


def posterior_predictive(x_target, x_context, y_context, l1=0.4, sigma=1.0, eps=2e-3):
    # Jitter must match the number of context points (dim 1), not the batch size (dim 0)
    K = gaussian_kernel(x_context, x_context, l1=l1, sigma=sigma) + (eps ** 2) * torch.eye(x_context.size(1))
    K_s = gaussian_kernel(x_context, x_target, l1=l1, sigma=sigma)
    K_ss = gaussian_kernel(x_target, x_target, l1=l1, sigma=sigma)
    K_inv = torch.inverse(K)

    mu = K_s.transpose(1, 2).bmm(K_inv).bmm(y_context)
    cov = K_ss - K_s.transpose(1, 2).bmm(K_inv).bmm(K_s)
    std = (cov.diagonal(dim1=1, dim2=2) + 1e-6).sqrt()
    return mu, std
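
As a sanity check on the posterior formulas in `posterior_predictive`, the single-context-point case can be worked out by hand, because the matrix $K$ collapses to a scalar and no inverse is needed. The sketch below is plain Python; the helper names `gaussian_kernel_1d` and `posterior_single_context` are introduced here for illustration only, and unlike the notebook function it does not add `1e-6` before taking the square root.

```python
import math


def gaussian_kernel_1d(x1, x2, l1=0.4, sigma=1.0):
    # Scalar version of the Gaussian (squared-exponential) kernel
    return sigma ** 2 * math.exp(-0.5 * (x1 - x2) ** 2 / l1 ** 2)


def posterior_single_context(x, x0, y0, l1=0.4, sigma=1.0, eps=2e-3):
    # With one context point, K = k(x0, x0) + eps^2 is a scalar
    k = gaussian_kernel_1d(x0, x0, l1, sigma) + eps ** 2
    k_s = gaussian_kernel_1d(x0, x, l1, sigma)
    k_ss = gaussian_kernel_1d(x, x, l1, sigma)
    mu = k_s / k * y0             # K_s^T K^{-1} y
    var = k_ss - k_s ** 2 / k     # K_ss - K_s^T K^{-1} K_s
    return mu, math.sqrt(max(var, 0.0))


# Query at the context point itself, then far away from it
mu_at, std_at = posterior_single_context(x=0.0, x0=0.0, y0=1.0)
mu_far, std_far = posterior_single_context(x=2.0, x0=0.0, y0=1.0)

print(f"at context:  mu={mu_at:.4f}, std={std_at:.4f}")
print(f"far away:    mu={mu_far:.4f}, std={std_far:.4f}")
```

At the context point the mean reproduces the observed value and the standard deviation shrinks to roughly `eps`; far from it the posterior reverts to the prior (mean 0, standard deviation `sigma`). This is the behaviour the GP band in the plots should show.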

In [None]:
class LinearNormal(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, 2 * out_features)

    def forward(self, inputs):
        mean, std = self.linear(inputs).chunk(2, dim=-1)
        std = 0.1 + 0.9 * F.softplus(std)
        return D.Normal(loc=mean, scale=std)


class ConditionalNeuralProcess(nn.Module):
    def __init__(self, hidden_size=128, x_dim=1, y_dim=1):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(x_dim + y_dim, hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size),)
        self.decoder = nn.Sequential(nn.Linear(hidden_size + x_dim, hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU(), LinearNormal(hidden_size, y_dim),)

    def forward(self, x_context, y_context, x_target):
        # x_context: [batch_size, num_context, x_dim]
        # y_context: [batch_size, num_context, y_dim]
        # x_target: [batch_size, num_points, x_dim]
        # y_target: [batch_size, num_points, y_dim]

        # Compute the deterministic representation r_C by averaging the encoded context pairs
        # (the variable is called `latent`, but nothing is sampled in a CNP)
        latent = self.encoder(torch.cat([x_context, y_context], dim=-1)).mean(dim=1)

        # Concatenate the global representation to each target input
        latent = latent[:, None, :].repeat(1, x_target.size(1), 1)
        inputs = torch.cat([latent, x_target], dim=-1)

        # Output a factorized distribution for target outputs
        return self.decoder(inputs)


model = ConditionalNeuralProcess().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
progress_bar = tqdm(range(int(1e5 + 1)), desc="Iteration", leave=False)

for it in progress_bar:
    x_context, y_context, x_target, y_target = generate_train_data()
    y_pred = model(x_context.cuda(), y_context.cuda(), x_target.cuda())
    loss = -y_pred.log_prob(y_target.cuda()).mean()
    model.zero_grad()
    loss.backward()
    optimizer.step()
    progress_bar.set_postfix({"loss": "{:.3f}".format(loss.item())})

    if it % 20000 == 0:
        x_context, y_context, x_target, y_target = generate_test_data(max_num_context=10)
        y_pred = model(x_context.cuda(), y_context.cuda(), x_target.cuda())
        mu, sigma = y_pred.loc.detach().cpu().numpy(), y_pred.scale.detach().cpu().numpy()

        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        ax.plot(x_target[0], mu[0], c="C0", label="CNP")
        ax.scatter(x_context[0], y_context[0], c="r", s=15, label="Context")
        ax.plot(x_target[0], y_target[0], c="C2", label="Truth")
        ax.set_title("Iteration {}: loss {:.3f}".format(it, loss.item()))
        ax.fill_between(x_target[0, :, 0], mu[0, :, 0] - sigma[0, :, 0], mu[0, :, 0] + sigma[0, :, 0], alpha=0.2, facecolor='C0', interpolate=True)

        mu, std = posterior_predictive(x_target, x_context, y_context)
        ax.plot(x_target[0], mu[0], c='C1', label="GP")
        ax.fill_between(x_target[0].squeeze(), mu[0].squeeze() - std[0], mu[0].squeeze() + std[0], facecolor='C1', alpha=0.2)
        ax.legend()

        fig.savefig('images/conditional-neural-processes/iteration-{}.png'.format(it), dpi=300, bbox_inches='tight')
        plt.show()

<img src="images/conditional-neural-processes/cnp-toy.gif" alt="Drawing" style="width: 60%;"/>
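
One detail of the decoder above is easy to miss: `LinearNormal` maps the raw scale through `0.1 + 0.9 * softplus(raw)`, so the predicted standard deviation can never drop below 0.1, whereas the GP baseline uses a noise level of `eps=2e-3`. This floor may contribute to the CNP band staying wider than the GP band at the context points, though that is a guess and has not been checked here. The plain-Python sketch below only reproduces the transform; `softplus` and `bounded_std` are illustrative helpers, not part of the notebook.

```python
import math


def softplus(z):
    # Numerically stable log(1 + exp(z))
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bounded_std(raw):
    # Same transform as LinearNormal.forward: std = 0.1 + 0.9 * softplus(raw)
    return 0.1 + 0.9 * softplus(raw)


# Even a very negative raw output cannot push the std below the 0.1 floor
for raw in (-20.0, -2.0, 0.0, 2.0):
    print(f"raw={raw:+.1f} -> std={bounded_std(raw):.4f}")
```

Lowering the constant `0.1` would let the model express tighter predictions near context points, at the cost of a less stable log-likelihood early in training.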
