# Replication of DeepMind's official NP/ANP implementation in PyTorch

## Differences between DeepMind's CNP and ANP implementations
- Data Generator
  - The lengthscale (`l1_scale`) is 0.6 here.
  - `random_kernel_parameters = True`: the kernel parameters are sampled uniformly, between 0.1 and 0.6 for the lengthscale. They are resampled for every curve, so they vary even within a single batch.

- `batch_mlp` function:
  - _TODO: describe the differences._

## Resources

**Conditional Neural Processes**:   
Garnelo M, Rosenbaum D, Maddison CJ, Ramalho T, Saxton D, Shanahan M, Teh YW, Rezende DJ, Eslami SM. Conditional Neural Processes. In International Conference on Machine Learning 2018.

**Neural Processes**:  
Garnelo, M., Schwarz, J., Rosenbaum, D., Viola, F., Rezende, D.J., Eslami, S.M. and Teh, Y.W. Neural processes. ICML Workshop on Theoretical Foundations and Applications of Deep Generative Models 2018.

**Attentive Neural Processes**:   
Kim, H., Mnih, A., Schwarz, J., Garnelo, M., Eslami, A., Rosenbaum, D., Vinyals, O. and Teh, Y.W. Attentive Neural Processes. In International Conference on Learning Representations 2019.

[Link to DeepMind's Google Colab notebook for (A)NPs](https://colab.research.google.com/github/deepmind/neural-processes/blob/master/attentive_neural_process.ipynb)  
[Link to DeepMind's Neural Processes GitHub repository](https://github.com/deepmind/neural-processes)

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import collections
import torch

## Data Generator

In [None]:
# Container passed to the CNP. A `CNPRegressionDescription` namedtuple holds:
#   `query`: the tuple ((context_x, context_y), target_x)
#   `target_y`: a tensor with the ground-truth values for the targets that
#     should be predicted
#   `num_total_points`: the total number of datapoints used (context + target)
#   `num_context_points`: the number of datapoints used as context
# GPCurvesReader.generate_curves packs each freshly sampled batch into this
# structure.
CNPRegressionDescription = collections.namedtuple(
    "CNPRegressionDescription",
    ["query", "target_y", "num_total_points", "num_context_points"],
)

class GPCurvesReader(object):
  """Generates curves using a Gaussian Process (GP).

  Supports vector inputs (x) and vector outputs (y). The kernel is a
  squared-exponential kernel on the x-values, with one lengthscale per
  (output dimension, input dimension) pair and one signal scale per output
  dimension. Output dimensions are independent Gaussian processes.
  """

  def __init__(self,
               batch_size,
               max_num_context,
               x_size = 1,
               y_size = 1,
               l1_scale = 0.6,
               sigma_scale = 1.0,
               random_kernel_parameters = True,
               testing = False):
    """Creates a regression dataset of functions sampled from a GP.

    Args:
      batch_size: An integer.
      max_num_context: The max number of observations in the context. Must be
          > 3 because the context size is drawn from [3, max_num_context).
      x_size: Integer >= 1 for length of "x values" vector.
      y_size: Integer >= 1 for length of "y values" vector.
      l1_scale: Float; typical scale for kernel distance function.
      sigma_scale: Float; typical scale for variance.
      random_kernel_parameters: If `True`, the kernel parameters (l1 and sigma)
          are sampled uniformly within [0.1, l1_scale) and [0.1, sigma_scale),
          independently for every curve in the batch.
      testing: Boolean that indicates whether we are testing. If so there are
          more targets for visualization.
    """
    self._batch_size = batch_size
    self._max_num_context = max_num_context
    self._x_size = x_size
    self._y_size = y_size
    self._l1_scale = l1_scale
    self._sigma_scale = sigma_scale
    self._random_kernel_parameters = random_kernel_parameters
    self._testing = testing

  def _gaussian_kernel(self, xdata, l1, sigma_f, sigma_noise = 2e-2):
    """Applies the Gaussian kernel to generate curve data.

    Args:
      xdata: Tensor with shape `[batch_size, num_total_points, x_size]` with
          the values of the x-axis data.
      l1: Tensor with shape `[batch_size, y_size, x_size]`, the scale
          parameter of the Gaussian kernel.
      sigma_f: Float tensor with shape `[batch_size, y_size]`; the magnitude
          of the std.
      sigma_noise: Float, std of the noise that we add for stability.

    Returns:
      The kernel, a float tensor with shape
      `[batch_size, y_size, num_total_points, num_total_points]`.
    """
    # Read the number of points from the tensor so this works for both the
    # testing grid and the random training points.
    num_total_points = xdata.shape[1]

    # Pairwise differences: [B, num_total_points, num_total_points, x_size]
    # TF: tf.expand_dims(xdata, axis=1) - tf.expand_dims(xdata, axis=2)
    diff = xdata.unsqueeze(dim = 1) - xdata.unsqueeze(dim = 2)

    # Insert the y_size axis, scale each x dimension by its lengthscale and
    # square: [B, y_size, num_total_points, num_total_points, x_size]
    # TF: tf.square(diff[:, None, :, :, :] / l1[:, :, None, None, :])
    norm = torch.square(diff[:, None, :, :, :] / l1[:, :, None, None, :])

    # Sum over x_size: [B, y_size, num_total_points, num_total_points]
    norm = torch.sum(norm, dim = -1)

    # TF: tf.square(sigma_f)[:, :, None, None] * tf.exp(-0.5 * norm)
    kernel = torch.square(sigma_f)[:, :, None, None] * torch.exp(-0.5 * norm)

    # Add some noise to the diagonal to make the Cholesky decomposition work.
    # The identity is built on the kernel's dtype/device so this also works
    # for non-CPU inputs.
    # TF: kernel += (sigma_noise**2) * tf.eye(num_total_points)
    kernel = kernel + (sigma_noise ** 2) * torch.eye(
        num_total_points, dtype = kernel.dtype, device = kernel.device)

    return kernel

  def generate_curves(self):
    """Samples a batch of GP curves and splits them into context and targets.

    Generated functions are `float32` with x values between -2 and 2.

    Returns:
      A `CNPRegressionDescription` namedtuple whose `query` is
      ((context_x, context_y), target_x), whose `num_total_points` is the
      Python int `target_x.shape[1]`, and whose `num_context_points` is the
      sampled context size as an integer tensor of shape (1,).
    """
    # Number of context points in [3, max_num_context); torch.randint's `high`
    # is exclusive, like tf.random_uniform's `maxval`.
    num_context_points = torch.randint(low = 3, high = self._max_num_context, size = (1,))
    # TF: tf.random_uniform(shape=[], minval=3, maxval=self._max_num_context, dtype=tf.int32)
    # Use a plain Python int for slicing and for tensor sizes below.
    num_context = int(num_context_points)

    ### X-VALUES ###
    if self._testing:
      # TESTING: an evenly spaced grid on [-2, 2) with step 0.01, shared by
      # every curve in the batch, so the whole function can be plotted.
      grid = torch.arange(start = -2., end = 2., step = 1. / 100)
      x_values = torch.tile(grid.unsqueeze(dim = 0), dims = (self._batch_size, 1))
      # Add the explicit last axis: [batch_size, num_target, 1]
      x_values = x_values.unsqueeze(dim = -1)
      # TF: tf.tile(tf.expand_dims(tf.range(-2., 2., 1. / 100), axis=0), [self._batch_size, 1])
      num_target = x_values.shape[1]  # 400 grid points
      num_total_points = num_target
    else:
      # TRAINING: random number of targets at random, unordered x positions.
      num_target = int(torch.randint(low = 2, high = self._max_num_context, size = (1,)))
      # TF: tf.random_uniform(shape=(), minval=2, maxval=self._max_num_context, dtype=tf.int32)
      num_total_points = num_context + num_target
      # Uniform in [0, 1), then scaled and shifted to [-2, 2).
      x_values = torch.rand(size = (self._batch_size, num_total_points, self._x_size)) * 4 - 2
      # TF: tf.random_uniform([self._batch_size, num_total_points, self._x_size], -2, 2)

    ### KERNEL PARAMETERS ###
    if self._random_kernel_parameters:
      # Resampled for every curve: l1 ~ U[0.1, l1_scale), sigma_f ~ U[0.1, sigma_scale).
      l1 = torch.rand(size = (self._batch_size, self._y_size, self._x_size)) * (self._l1_scale - 0.1) + 0.1
      # sigma_f has one value per output dimension, [batch_size, y_size],
      # because _gaussian_kernel indexes it as sigma_f[:, :, None, None].
      sigma_f = torch.rand(size = (self._batch_size, self._y_size)) * (self._sigma_scale - 0.1) + 0.1
    else:
      # Same fixed parameters for every curve.
      l1 = torch.ones(size = (self._batch_size, self._y_size, self._x_size)) * self._l1_scale
      sigma_f = torch.ones(size = (self._batch_size, self._y_size)) * self._sigma_scale

    ### Y-VALUES ###
    # [batch_size, y_size, num_total_points, num_total_points]
    kernel = self._gaussian_kernel(x_values, l1, sigma_f)

    # Cholesky in double precision for numerical stability, then cast back.
    # TF: tf.cast(tf.cholesky(tf.cast(kernel, tf.float64)), tf.float32)
    cholesky = torch.linalg.cholesky(kernel.double()).to(kernel.dtype)

    # Sample a curve: L @ z with z ~ N(0, I) -> [batch_size, y_size, num_total_points, 1]
    noise = torch.randn(size = (self._batch_size, self._y_size, num_total_points, 1))
    y_values = torch.matmul(cholesky, noise)
    # TF: tf.matmul(cholesky, tf.random_normal([self._batch_size, self._y_size, num_total_points, 1]))

    # Drop the last axis and swap the point / y axes: [batch_size, num_total_points, y_size]
    y_values = torch.transpose(y_values.squeeze(dim = -1), 1, 2)
    # TF: tf.transpose(tf.squeeze(y_values, 3), [0, 2, 1])

    if self._testing:
      # Targets are the whole grid.
      target_x = x_values
      target_y = y_values

      # Context: a random subset (num_context points) of the grid.
      idx = torch.randperm(num_target)
      # TF: idx = tf.random_shuffle(tf.range(num_target))
      context_x = x_values[:, idx[:num_context], :]
      context_y = y_values[:, idx[:num_context], :]
    else:
      # Targets contain the context points plus the new target points, i.e.
      # all sampled values.
      target_x = x_values[:, :num_target + num_context, :]
      target_y = y_values[:, :num_target + num_context, :]

      # The first num_context points are the observations.
      context_x = x_values[:, :num_context, :]
      context_y = y_values[:, :num_context, :]

    query = ((context_x, context_y), target_x)

    return CNPRegressionDescription(
        query = query,
        target_y = target_y,
        num_total_points = target_x.shape[1],
        num_context_points = num_context_points)

## Encoder: Latent Path:

I have left off here.

In [None]:
from torch import nn

class LatentEncoder(nn.Module):
  """The Latent Encoder.

  Encodes every (x, y) context pair with an MLP and averages the per-point
  encodings into one representation per batch element.
  """

  def __init__(self, output_sizes, num_latents, input_size = 2):
    """(A)NP latent encoder.

    Args:
      output_sizes: An iterable containing the output sizes of the encoding MLP.
      num_latents: The latent dimensionality. NOTE(review): stored but not yet
          used by `forward`.
      input_size: Size of the concatenated [x, y] features (x_size + y_size).
          Defaults to 2, the value that used to be hard-coded (1D x, 1D y).

    Raises:
      ValueError: if `output_sizes` is empty.
    """
    super().__init__()
    self._output_sizes = list(output_sizes)
    if not self._output_sizes:
      raise ValueError("output_sizes must contain at least one layer size.")
    self._num_latents = num_latents
    self._input_size = input_size

    # PyTorch: layers must be created in __init__ (not forward) so that they
    # are registered as sub-modules. Each Linear maps one size to the next and
    # is followed by a ReLU, except the last layer, which has no activation.
    sizes = [input_size] + self._output_sizes
    layers = []
    for in_features, out_features in zip(sizes[:-1], sizes[1:]):
      layers.append(nn.Linear(in_features = in_features, out_features = out_features))
      layers.append(nn.ReLU(inplace = True))
    self.module_list = nn.ModuleList(layers[:-1])

  def forward(self, context_x, context_y):
    """Encodes the inputs into one representation.

    Args:
      context_x: Tensor of size bs x observations x m_ch. For this 1D regression
          task this corresponds to the x-values.
      context_y: Tensor of size bs x observations x d_ch. For this 1D regression
          task this corresponds to the y-values.

    Returns:
      representation: Tensor of size bs x output_sizes[-1], the encoded
          representation averaged over all context points.
    """
    # Concatenate x and y along the feature (last) axis.
    # TF: encoder_input = tf.concat([context_x, context_y], axis=-1)
    encoder_input = torch.cat((context_x, context_y), dim = -1)

    # nn.Linear acts on the last axis, so the MLP is applied to every
    # observation of every batch element at once. This is equivalent to the TF
    # version's reshape to (batch_size * num_context_points, -1) and back.
    hidden = encoder_input
    for module in self.module_list:
      hidden = module(hidden)
    # hidden: [batch_size, num_context_points, output_sizes[-1]]

    # Aggregator: mean over the context points (dim 1), one representation per
    # batch element.
    # TF: representation = tf.reduce_mean(hidden, axis=1)
    representation = torch.mean(input = hidden, dim = 1)

    return representation

In [None]:
# utility methods

# PyTorch stand-in for TF's `tf.variable_scope(variable_scope, reuse=tf.AUTO_REUSE)`:
# the first call to `batch_mlp` with a given scope name creates its MLP, and
# later calls with the same name reuse that MLP (and therefore its weights).
# Gather parameters from these modules when building an optimizer.
_BATCH_MLP_SCOPES = {}

def batch_mlp(input, output_sizes, variable_scope):
  """Apply MLP to the final axis of a 3D tensor (reusing already defined MLPs).

  Args:
    input: input tensor of shape [B,n,d_in].
    output_sizes: An iterable containing the output sizes of the MLP. Every
        layer except the last is followed by a ReLU.
    variable_scope: String naming the MLP. If this equals a previously used
        name, the MLP created for that name (and its weights) is reused.

  Returns:
    tensor of shape [B, n, d_out] where d_out = output_sizes[-1]

  Raises:
    ValueError: if `output_sizes` is empty.
  """
  output_sizes = list(output_sizes)
  if not output_sizes:
    raise ValueError("output_sizes must contain at least one layer size.")

  # Get the shapes of the input and reshape to parallelise across observations
  batch_size, _, filter_size = input.shape

  mlp = _BATCH_MLP_SCOPES.get(variable_scope)
  if mlp is None:
    # Build Linear layers d_in -> output_sizes[0] -> ... -> output_sizes[-1].
    sizes = [filter_size] + output_sizes
    layers = []
    for in_features, out_features in zip(sizes[:-1], sizes[1:]):
      layers.append(nn.Linear(in_features, out_features))
      layers.append(nn.ReLU())
    # Last layer without a ReLU; create the weights on the input's device/dtype.
    mlp = nn.Sequential(*layers[:-1]).to(device = input.device, dtype = input.dtype)
    _BATCH_MLP_SCOPES[variable_scope] = mlp

  # Flatten batch and observation axes, pass through the MLP, restore shape.
  output = torch.reshape(input, (-1, filter_size))
  output = mlp(output)
  output = torch.reshape(output, (batch_size, -1, output_sizes[-1]))
  return output