### Paper Information

* Title: [Attentive Neural Processes](https://arxiv.org/abs/1901.05761) (ICLR 2019)
* Authors: Hyunjik Kim, Andriy Mnih, Jonathan Schwarz, Marta Garnelo, Ali Eslami, Dan Rosenbaum, Oriol Vinyals, Yee Whye Teh (2019)

### Main Ideas

* Neural Processes ([Garnelo et al., 2018](https://arxiv.org/abs/1807.01622)) often underfit the context set. A possible cause is the mean-aggregation step: it summarizes the whole context set into a single finite-dimensional vector, so it acts as a bottleneck in which every context point contributes equally to every target prediction. Ideally, the closer a context point is to a target point, the more it should contribute to that target point's prediction.

    <img src="images/attentive-neural-processes/figure-1.png" alt="Drawing" style="width: 70%;"/>

* While a Gaussian Process uses a parametric kernel to measure the similarity between training and test points, an Attentive Neural Process (ANP) uses a differentiable attention mechanism, which lets the model learn which context points are relevant to each target point. Attention has been widely used in other settings, such as neural machine translation ([Bahdanau et al., 2014](https://arxiv.org/abs/1409.0473), [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)) and image modeling ([Parmar et al., 2018](https://arxiv.org/abs/1802.05751)).
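
To make the bottleneck argument concrete, here is a toy sketch in plain Python, not the paper's model. It compares the uniform weights implied by mean aggregation with softmax weights computed from a fixed, Gaussian-kernel-shaped similarity score. The inputs, the `lengthscale` value, and the direct weighted average of raw context outputs are illustrative assumptions; in an ANP the scores come from learned features and the weighted values are learned representations.

```python
import math


def softmax(scores):
    # Numerically stable softmax over a list of scores
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mean_weights(x_context):
    # NP-style mean aggregation: every context point gets the same weight
    n = len(x_context)
    return [1.0 / n] * n


def distance_weights(x_target, x_context, lengthscale=0.5):
    # Attention-style weights from a fixed score (negative scaled squared distance,
    # i.e. Gaussian-kernel shaped); an ANP would learn this score instead
    scores = [-(x_target - x) ** 2 / (2 * lengthscale ** 2) for x in x_context]
    return softmax(scores)


# Illustrative 1D context set and a single target input
x_context = [-1.5, -0.2, 0.1, 1.8]
y_context = [0.3, -0.5, -0.4, 1.2]
x_target = 0.0

w_mean = mean_weights(x_context)
w_attn = distance_weights(x_target, x_context)

# Weighted summaries of the context outputs seen by this target
pred_mean = sum(w * y for w, y in zip(w_mean, y_context))
pred_attn = sum(w * y for w, y in zip(w_attn, y_context))
print([round(w, 3) for w in w_attn], round(pred_mean, 3), round(pred_attn, 3))
```

Mean aggregation gives the far-away points at -1.5 and 1.8 the same weight as the two neighbours of the target, whereas the distance-based weights concentrate almost all mass on the points at -0.2 and 0.1.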

### Methodology

* Rather than replacing the mean-aggregation step in Neural Processes ([Garnelo et al., 2018](https://arxiv.org/abs/1807.01622)) with an attention layer, the authors keep both in the model via two separate paths: a latent path that mean-aggregates the context into a global latent variable, and a deterministic path that uses cross-attention. More specifically, the conditional distribution of the target points $(\mathbf{X}_T, \mathbf{Y}_T)$ given the context points $(\mathbf{X}_C, \mathbf{Y}_C)$ is specified in terms of a global latent variable $\mathbf{z}$ and a finite-dimensional representation $\mathbf{r}_C$ of the context set:

    $$p(\mathbf{Y}_T \lvert \mathbf{X}_T, \mathbf{X}_C, \mathbf{Y}_C) = \int p(\mathbf{Y}_T \lvert \mathbf{X}_T, \mathbf{r}_C, \mathbf{z}) \, p(\mathbf{z} \lvert  \mathbf{X}_C, \mathbf{Y}_C) \, d\mathbf{z}.$$

    <img src="images/attentive-neural-processes/figure-2.png" alt="Drawing" style="width: 95%;"/>

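* As in NPs, the global latent variable $\mathbf{z}$ is sampled from a distribution that depends only on the mean of the hidden representations of the context points. The deterministic representation, in contrast, is computed separately for each target input via cross-attention: the target inputs serve as queries, the context inputs as keys, and the encoded context pairs as values. Writing $\mathbf{r}_T$ for the resulting target-specific representations (i.e., $\mathbf{r}_C$ evaluated at $\mathbf{X}_T$), scaled dot-product attention gives: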

    $$\mathbf{r}_T = \text{Softmax} \left(\frac{h(\mathbf{X}_T) h(\mathbf{X}_C)^{\mathsf{T}}}{\sqrt{d_k}}  \right) \phi(\mathbf{X}_C, \mathbf{Y}_C),$$

    where $h$ and $\phi$ are mappings parametrized by neural networks, $d_k$ is the dimension of the keys $h(\mathbf{X}_C)$, and the softmax is taken over the context points for each target point. The mapping $\phi$ can also include self-attention to model interactions between the context points. The authors do not use cross-attention in the latent path, in order to preserve the global structure of the stochastic process.

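* Training proceeds via variational inference. The global latent variable $\mathbf{z}$ has a factorized Gaussian conditional prior $q(\mathbf{z} \lvert \mathbf{X}_C, \mathbf{Y}_C)$, which plays the role of $p(\mathbf{z} \lvert \mathbf{X}_C, \mathbf{Y}_C)$ above, and a factorized Gaussian approximate posterior $q(\mathbf{z} \lvert \mathbf{X}_T, \mathbf{Y}_T, \mathbf{X}_C, \mathbf{Y}_C)$. The model is trained by maximizing the (approximate) evidence lower bound:

    $$\log p(\mathbf{Y}_T \lvert \mathbf{X}_T, \mathbf{X}_C, \mathbf{Y}_C) \geq \mathbb{E}_{q(\mathbf{z} \lvert \mathbf{X}_T, \mathbf{Y}_T, \mathbf{X}_C, \mathbf{Y}_C)} \left[ \log p(\mathbf{Y}_T \lvert \mathbf{X}_T, \mathbf{r}_C, \mathbf{z}) \right] - \mathrm{KL}\left( q(\mathbf{z} \lvert \mathbf{X}_T, \mathbf{Y}_T, \mathbf{X}_C, \mathbf{Y}_C) \,\Vert\, q(\mathbf{z} \lvert \mathbf{X}_C, \mathbf{Y}_C) \right).$$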

* Like NPs, ANPs do not necessarily satisfy the consistency condition, which requires the collection of finite-dimensional joint distributions to be consistent under marginalization. They are, however, invariant to permutations of the context points, since neither the mean-aggregation step nor the attention step depends on the ordering of the context points.
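
The cross-attention equation above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the paper's or the official implementation: `h` and `phi` are hypothetical stand-ins (fixed random one-layer maps instead of trained MLPs), and `mean_aggregate` reuses `phi` for brevity, whereas the model below has a separate latent encoder. The script also checks the permutation-invariance claim by shuffling the context set.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v = 8, 16

# Hypothetical stand-ins for the learned networks h and phi:
# fixed random one-layer maps instead of trained MLPs
W_h = rng.normal(size=(1, d_k))
W_phi = rng.normal(size=(2, d_v))


def h(x):
    # Inputs [n, 1] -> query/key features [n, d_k]
    return np.tanh(x @ W_h)


def phi(x, y):
    # Pairs ([n, 1], [n, 1]) -> values [n, d_v]
    return np.tanh(np.concatenate([x, y], axis=-1) @ W_phi)


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(x_target, x_context, y_context):
    # Targets are queries, context inputs are keys, encoded context pairs are values.
    # Softmax over the context axis; returns r_T with shape [num_target, d_v]
    scores = h(x_target) @ h(x_context).T / np.sqrt(d_k)
    return softmax(scores, axis=-1) @ phi(x_context, y_context)


def mean_aggregate(x_context, y_context):
    # Latent-path style summary: a single vector shared by all targets
    return phi(x_context, y_context).mean(axis=0)


x_c = rng.uniform(-2, 2, size=(5, 1))
y_c = np.sin(3 * x_c)
x_t = np.linspace(-2, 2, 7).reshape(-1, 1)

r_t = cross_attention(x_t, x_c, y_c)

# Shuffle the context set: both summaries should be unchanged up to rounding
perm = rng.permutation(len(x_c))
r_t_perm = cross_attention(x_t, x_c[perm], y_c[perm])
print(r_t.shape, np.allclose(r_t, r_t_perm))
print(np.allclose(mean_aggregate(x_c, y_c), mean_aggregate(x_c[perm], y_c[perm])))
```

`r_t` has one row per target input, while `mean_aggregate` returns one vector for the whole context set; this is the difference between the deterministic and latent paths.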

### Related Works

* Like other neural process models ([Garnelo et al., 2018](https://arxiv.org/abs/1807.01613), [Garnelo et al., 2018](https://arxiv.org/abs/1807.01622)), ANPs are related to GPs. Another related approach inspired by Gaussian Processes is Variational Implicit Processes ([Ma et al., 2018](https://arxiv.org/abs/1806.02390)), where both the process and its posterior are approximated by a GP.

### Experiments

* The authors run experiments on 1D function regression showing that, equipped with attention, ANPs have a larger capacity to fit the data than NPs, i.e., lower reconstruction error and lower negative log-likelihood, especially with dot-product and multihead attention. Experiments on 2D function regression confirm these findings.

    <img src="images/attentive-neural-processes/figure-3.png" alt="Drawing" style="width: 70%;"/>

### Thoughts

* ANPs are more expressive than similar models in the same family, so their superior performance is not surprising, especially since attention had already been successfully incorporated into earlier models. In that sense, the novelty of the paper is somewhat limited.

### Implementation

The following implementation is based on the official implementation [deepmind/neural-processes](https://github.com/deepmind/neural-processes). We first attempt to replicate the 1D function regression results and benchmark them against the posterior predictive of a Gaussian Process.

```python
import torch
import torch.nn as nn
import torch.distributions as D
import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm


def gaussian_kernel(X1, X2, l1=0.6, sigma=1.0):
    # Squared-exponential kernel sigma^2 * exp(-||x1 - x2||^2 / (2 * l1^2)), batched over dim 0
    return sigma ** 2 * ((-0.5 / l1 ** 2) * (torch.cdist(X1, X2, p=2) ** 2)).exp()


def sample_gaussian_process(inputs, l1=0.6, sigma=1.0, eps=2e-2):
    # Compute covariance matrix of size B x M x M induced by a Gaussian kernel
    batch_size, num_points, x_dim = inputs.size()
    covariance = gaussian_kernel(inputs, inputs, l1=l1, sigma=sigma)

    # Add diagonal noise to avoid a singular covariance matrix
    covariance += (eps ** 2) * torch.eye(num_points)
    # Cholesky factor in double precision for stability, then sample y = L @ N(0, I)
    cholesky = torch.linalg.cholesky(covariance.double()).float()
    outputs = torch.matmul(cholesky, torch.randn(batch_size, num_points, 1))
    return outputs


def generate_train_data(batch_size=16, max_num_context=50):
    num_context = torch.randint(low=3, high=max_num_context, size=()).item()
    num_target = torch.randint(low=2, high=max_num_context, size=()).item()
    num_points = num_target + num_context

    # Inputs are uniform in [-2, 2]; the target set also contains the context points
    x_values = 4 * torch.rand(batch_size, num_points, 1) - 2
    y_values = sample_gaussian_process(x_values)
    x_target, y_target = x_values, y_values
    x_context, y_context = x_values[:, :num_context], y_values[:, :num_context]
    return x_context, y_context, x_target, y_target


def generate_test_data(batch_size=1, max_num_context=50, num_target=400):
    num_context = torch.randint(low=3, high=max_num_context, size=()).item()
    # Evaluate on an evenly spaced grid of num_target points and pick random context points from it
    x_values = torch.linspace(-2, 2, num_target).view(1, num_target, 1).repeat((batch_size, 1, 1))
    y_values = sample_gaussian_process(x_values)
    x_target, y_target = x_values, y_values
    idx = torch.randperm(num_target)[:num_context]
    x_context, y_context = x_values[:, idx, :], y_values[:, idx, :]
    return x_context, y_context, x_target, y_target


def posterior_predictive(x_target, x_context, y_context, l1=0.6, sigma=1.0, eps=2e-2):
    # Exact GP posterior, using the same kernel hyperparameters as the data generator
    num_context = x_context.size(1)
    K = gaussian_kernel(x_context, x_context, l1=l1, sigma=sigma) + (eps ** 2) * torch.eye(num_context)
    K_s = gaussian_kernel(x_context, x_target, l1=l1, sigma=sigma)
    K_ss = gaussian_kernel(x_target, x_target, l1=l1, sigma=sigma)

    # Solve K @ A = K_s instead of forming K^{-1} explicitly
    A = torch.linalg.solve(K, K_s)
    mu = A.transpose(1, 2).bmm(y_context)
    cov = K_ss - K_s.transpose(1, 2).bmm(A)
    std = (cov.diagonal(dim1=1, dim2=2).clamp_min(0) + 1e-6).sqrt()
    return mu, std
```
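
For reference, `posterior_predictive` follows the standard GP posterior for noisy observations. The formulas below are a sketch in the notation of this note, assuming the squared-exponential kernel of `gaussian_kernel` with lengthscale $l_1$, scale $\sigma$, and noise level $\epsilon$:

$$k(\mathbf{x}, \mathbf{x}') = \sigma^2 \exp\left(-\frac{\lVert \mathbf{x} - \mathbf{x}' \rVert^2}{2 l_1^2}\right), \quad \mathbf{K} = k(\mathbf{X}_C, \mathbf{X}_C) + \epsilon^2 \mathbf{I}, \quad \mathbf{K}_* = k(\mathbf{X}_C, \mathbf{X}_T), \quad \mathbf{K}_{**} = k(\mathbf{X}_T, \mathbf{X}_T),$$

$$\boldsymbol{\mu}_T = \mathbf{K}_*^{\mathsf{T}} \mathbf{K}^{-1} \mathbf{Y}_C, \qquad \boldsymbol{\Sigma}_T = \mathbf{K}_{**} - \mathbf{K}_*^{\mathsf{T}} \mathbf{K}^{-1} \mathbf{K}_*.$$

The function returns $\boldsymbol{\mu}_T$ and the square root of $\operatorname{diag}(\boldsymbol{\Sigma}_T)$ plus a small jitter. Because $\epsilon^2$ is not added to $\mathbf{K}_{**}$, the plotted band describes uncertainty about the latent function rather than about noisy observations.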

```python
class UniformAttention(nn.Module):
    # Mean aggregation written as attention: every query receives the average value
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        return value.mean(dim=1, keepdim=True).repeat(1, query.size(1), 1)


class LaplaceAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, scale=1., normalize=True):
        super().__init__()
        self.scale = scale
        self.normalize = normalize
        self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))

    def forward(self, query, key, value):
        query, key = self.net(query), self.net(key)
        # Negative scaled L1 distance between every query and key: [B, Nq, Nk]
        weights = -((query[:, :, None, :] - key[:, None, :, :]).abs() / self.scale).sum(dim=-1)
        weights = F.softmax(weights, dim=-1) if self.normalize else 1 + torch.tanh(weights)
        return torch.bmm(weights, value)


class DotProductAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, normalize=True):
        super().__init__()
        self.normalize = normalize
        self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))

    def forward(self, query, key, value):
        query, key = self.net(query), self.net(key)
        # Scaled dot products between queries and keys: [B, Nq, Nk]
        weights = torch.bmm(query, key.transpose(1, 2)) / (query.size(-1) ** 0.5)
        weights = F.softmax(weights, dim=-1) if self.normalize else torch.sigmoid(weights)
        return torch.bmm(weights, value)


class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=128, num_heads=8):
        super().__init__()
        assert output_dim % num_heads == 0, "output_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = output_dim // num_heads

        self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
        # Queries and keys are projected from the MLP features (size hidden_dim), not the raw inputs
        self.query_proj = nn.Linear(hidden_dim, output_dim, bias=False)
        self.key_proj = nn.Linear(hidden_dim, output_dim, bias=False)
        # Values are expected to have size output_dim
        self.value_proj = nn.Linear(output_dim, output_dim, bias=False)
        self.output_proj = nn.Linear(output_dim, output_dim, bias=False)

        for layer in [self.query_proj, self.key_proj, self.value_proj, self.output_proj]:
            layer.weight.data.normal_(std=layer.in_features ** -0.5)

    def split_heads(self, x):
        # [B, N, H * d] -> [B * H, N, d]
        batch_size, length, _ = x.size()
        x = x.view(batch_size, length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        return x.reshape(batch_size * self.num_heads, length, self.head_dim)

    def forward(self, query, key, value):
        batch_size, query_len, _ = query.size()
        query, key = self.net(query), self.net(key)

        query = self.split_heads(self.query_proj(query))
        key = self.split_heads(self.key_proj(key))
        value = self.split_heads(self.value_proj(value))

        weights = F.softmax(torch.bmm(query, key.transpose(1, 2)) / (self.head_dim ** 0.5), dim=-1)
        # [B * H, Nq, d] -> [B, Nq, H * d]
        output = torch.bmm(weights, value).view(batch_size, self.num_heads, query_len, self.head_dim)
        output = output.permute(0, 2, 1, 3).reshape(batch_size, query_len, self.num_heads * self.head_dim)
        return self.output_proj(output)
```


```python
class LinearNormal(nn.Module):
    # Linear layer that outputs a factorized Gaussian; the std is bounded below by 0.1
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, 2 * out_features)

    def forward(self, inputs):
        mean, std = self.linear(inputs).chunk(2, dim=-1)
        std = 0.1 + 0.9 * F.softplus(std)
        return D.Normal(loc=mean, scale=std)


def mlp(input_dim, hidden_dim, num_layers=4):
    # num_layers linear layers of width hidden_dim with ReLUs in between (none after the last)
    layers = [nn.Linear(input_dim, hidden_dim)]
    for _ in range(num_layers - 1):
        layers += [nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)]
    return nn.Sequential(*layers)


class AttentiveNeuralProcess(nn.Module):
    def __init__(self, hidden_dim=128, x_dim=1, y_dim=1, z_dim=128):
        super().__init__()
        # Deterministic path: encode (x, y) pairs, then cross-attend from target to context inputs
        self.attention_encoder = mlp(x_dim + y_dim, hidden_dim)
        self.attention_mapping = DotProductAttention(x_dim, hidden_dim)
        # Latent path: encode (x, y) pairs, mean-aggregate, map to a Gaussian over z
        self.latent_encoder = mlp(x_dim + y_dim, hidden_dim)
        self.latent_mapping = LinearNormal(hidden_dim, z_dim)
        # Decoder input: [attentive representation, z sample, target input]
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim + z_dim + x_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            LinearNormal(hidden_dim, y_dim),
        )

    def forward(self, x_context, y_context, x_target, y_target=None):
        # x_context: [batch_size, num_context, x_dim]
        # y_context: [batch_size, num_context, y_dim]
        # x_target: [batch_size, num_points, x_dim]
        # y_target: [batch_size, num_points, y_dim]

        # Compute attentive representations for the target inputs: [batch_size, num_points, hidden_dim]
        h_context = self.attention_encoder(torch.cat([x_context, y_context], dim=-1))
        h_target = self.attention_mapping(x_target, x_context, h_context)

        # Infer a global latent variable from the context input-output pairs and
        # map it to a factorized Gaussian distribution (the conditional prior)
        r_context = self.latent_encoder(torch.cat([x_context, y_context], dim=-1)).mean(dim=1)
        z_prior = self.latent_mapping(r_context)

        # After training, target outputs are not available. We concatenate the deterministic
        # attentive representation and a sample from z_prior to each target input to make predictions.
        if y_target is None:
            z_sample = z_prior.rsample()[:, None, :].repeat(1, x_target.size(1), 1)
            y_pred = self.decoder(torch.cat([h_target, z_sample, x_target], dim=-1))
            return y_pred

        # During training, infer another factorized Gaussian (the approximate posterior) from the
        # target input-output pairs; the KL term in the loss matches it to z_prior
        r_target = self.latent_encoder(torch.cat([x_target, y_target], dim=-1)).mean(dim=1)
        z_posterior = self.latent_mapping(r_target)

        # Concatenate a sample from z_posterior to each target input to make predictions
        z_sample = z_posterior.rsample()[:, None, :].repeat(1, x_target.size(1), 1)
        y_pred = self.decoder(torch.cat([h_target, z_sample, x_target], dim=-1))
        return z_prior, z_posterior, y_pred
```


```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AttentiveNeuralProcess().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
progress_bar = tqdm(range(int(1e5 + 1)), desc="Iteration", leave=False)

for it in progress_bar:
    x_context, y_context, x_target, y_target = [t.to(device) for t in generate_train_data(max_num_context=50)]
    z_prior, z_posterior, y_pred = model(x_context, y_context, x_target, y_target)

    # Negative ELBO: log-likelihood summed over target points and KL summed over
    # latent dimensions, both averaged over the batch
    log_prob = y_pred.log_prob(y_target).mean(dim=0).sum()
    kl_loss = D.kl_divergence(z_posterior, z_prior).mean(dim=0).sum()
    loss = -log_prob + kl_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    progress_bar.set_postfix({"loss": "{:.3f}".format(loss.item())})

    if it % 20000 == 0:
        # Compare ANP and GP predictions on a fresh function with 3 to 9 context points
        x_context, y_context, x_target, y_target = generate_test_data(max_num_context=10)
        with torch.no_grad():
            y_pred = model(x_context.to(device), y_context.to(device), x_target.to(device))
        mu, sigma = y_pred.loc.cpu().numpy(), y_pred.scale.cpu().numpy()

        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        ax.plot(x_target[0], mu[0], c="C0", label="ANP")
        ax.scatter(x_context[0], y_context[0], c="r", s=15, label="Context")
        ax.plot(x_target[0], y_target[0], c="C2", label="Truth")
        ax.set_title("Iteration {}: loss {:.3f}".format(it, loss.item()))
        ax.fill_between(x_target[0, :, 0], mu[0, :, 0] - sigma[0, :, 0], mu[0, :, 0] + sigma[0, :, 0], alpha=0.2, facecolor="C0", interpolate=True)

        mu, std = posterior_predictive(x_target, x_context, y_context)
        ax.plot(x_target[0], mu[0], c="C1", label="GP")
        ax.fill_between(x_target[0, :, 0], mu[0, :, 0] - std[0], mu[0, :, 0] + std[0], facecolor="C1", alpha=0.2)
        ax.legend()
        plt.show()
```
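
The `kl_loss` term in the training loop is the KL divergence between two factorized Gaussians, which `D.kl_divergence` evaluates in closed form for each latent dimension before `.sum()` adds the dimensions up. The plain-Python sketch below (our own illustration, with made-up values for a 2-dimensional $\mathbf{z}$) spells out that per-dimension formula and can help sanity-check the magnitude of the KL term:

```python
import math


def kl_diag_gaussians(mu_q, std_q, mu_p, std_p):
    """KL(q || p) between two factorized Gaussians, summed over dimensions."""
    total = 0.0
    for m_q, s_q, m_p, s_p in zip(mu_q, std_q, mu_p, std_p):
        # Closed-form KL between N(m_q, s_q^2) and N(m_p, s_p^2)
        total += math.log(s_p / s_q) + (s_q ** 2 + (m_q - m_p) ** 2) / (2 * s_p ** 2) - 0.5
    return total


# Illustrative posterior (from the target pairs) vs. prior (from the context pairs)
kl = kl_diag_gaussians(mu_q=[0.0, 0.5], std_q=[1.0, 0.5], mu_p=[0.0, 0.0], std_p=[1.0, 1.0])
print(kl)
```

The first dimension contributes nothing because posterior and prior coincide there; all of the KL comes from the second dimension, where the posterior is shifted and narrower.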
