# Neural Processes
Author: Jin Yeom (jinyeom@utexas.edu)

## Contents
- [Introduction](#Introduction)
- [Implementation](#Implementation)
- [Experiments](#Experiments)

In [1]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F

## Introduction

[Neural Processes](https://arxiv.org/abs/1807.01622) (NPs) are a class of function approximation methods introduced by DeepMind in July 2018. Generalizing their earlier work on [Conditional Neural Processes](https://arxiv.org/abs/1807.01613) (CNPs), NPs combine the benefits of neural networks and Gaussian processes (GPs): like GPs, they estimate a distribution over functions and adapt rapidly to new observations, and like neural networks, they are computationally efficient during training.

I'm personally fascinated by this approach because it addresses two shortcomings of common deep learning practice that I have long been unhappy about: 1) it can estimate the uncertainty in its predictions, and 2) it can be trained further on additional data without disrupting what it has already learned.

NPs consist of three main components:
* An **encoder** $h$ that encodes *pairs* of $(x_i, y_i)$ context values into their representations $r_i$
* An **aggregator** $a$ that summarizes the encoded context values into a single representation $r$, which parameterizes the latent distribution $z \sim \mathcal{N}(\mu(r), I\sigma(r))$
* A **conditional decoder** $g$ that takes $z$ and the new target locations $x_T$ and predicts $y_T$
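
To make the role of the latent variable concrete, here is a minimal sketch of how an aggregated representation $r$ could parameterize and sample $z$. The `LatentHead` module, its layer names, `latent_size`, and the softplus-plus-offset parameterization are all assumptions made for illustration, and $\sigma(r)$ is treated as a standard deviation here; none of this is taken from the notebook's own implementation.

```python
import torch
from torch import nn
from torch.nn import functional as F


class LatentHead(nn.Module):
    """Hypothetical helper: maps an aggregated representation r to a sample of z."""

    def __init__(self, repr_size, latent_size):
        super().__init__()
        self.mu = nn.Linear(repr_size, latent_size)
        self.pre_sigma = nn.Linear(repr_size, latent_size)

    def forward(self, r):
        mu = self.mu(r)
        # Softplus keeps the standard deviation positive; the small offset
        # is an assumed lower bound for numerical stability.
        sigma = 0.01 + F.softplus(self.pre_sigma(r))
        # Reparameterization: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients can flow through mu and sigma.
        eps = torch.randn_like(sigma)
        z = mu + sigma * eps
        return z, mu, sigma


torch.manual_seed(0)
head = LatentHead(repr_size=128, latent_size=16)
r = torch.randn(4, 128)  # a batch of 4 aggregated representations
z, mu, sigma = head(r)
```

Each call draws a fresh `z`, so decoding several samples for the same context would give several plausible functions, which is where the uncertainty estimate comes from.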

## Implementation

### Encoder

In [4]:
class Encoder(nn.Module):
    def __init__(self, cx_size, cy_size, repr_size):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(cx_size + cy_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, repr_size)
        
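    def forward(self, context_x, context_y):
        # Inputs are assumed to have shape (batch, num_context, feature_size).
        input_ = torch.cat((context_x, context_y), dim=-1)
        input_ = F.relu(self.fc1(input_), inplace=True)
        input_ = F.relu(self.fc2(input_), inplace=True)
        input_ = F.relu(self.fc3(input_), inplace=True)
        repr_ = self.fc4(input_)  # per-point representations r_i
        # Assumption: aggregate with a mean over the context dimension,
        # which is the aggregator used in the NP paper.
        return repr_.mean(dim=1)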
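
A property worth checking for any aggregator is that it does not depend on the order of the context points. The sketch below assumes mean aggregation over a `(batch, num_context, repr_size)` tensor; `aggregate_mean` is a hypothetical helper written for this illustration, not part of the notebook.

```python
import torch


def aggregate_mean(repr_i):
    """Average per-point representations r_i over the context dimension.

    Assumes repr_i has shape (batch, num_context, repr_size).
    """
    return repr_i.mean(dim=1)


torch.manual_seed(0)
repr_i = torch.randn(2, 5, 8)  # 2 tasks, 5 context points each
perm = torch.randperm(5)       # shuffle the order of the context points
r = aggregate_mean(repr_i)
r_shuffled = aggregate_mean(repr_i[:, perm])
# The two results should agree up to floating-point rounding.
same = torch.allclose(r, r_shuffled, atol=1e-6)
```

Because the mean also works for any number of context points, the same model can be conditioned on 5 observations or 50 without changing its architecture.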

### Conditional decoder

In [6]:
class ConditionalDecoder(nn.Module):
    def __init__(self, tx_size, repr_size):
        super(ConditionalDecoder, self).__init__()
        
    def forward(self, input_):
        raise NotImplementedError
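
For orientation, here is one way a decoder of this kind could look. This is only a sketch under stated assumptions: `SketchDecoder`, its constructor arguments (`latent_size`, `ty_size`, `hidden_size`), the two-layer MLP, and the bounded-softplus standard deviation are all hypothetical choices and do not reflect the signature of the `ConditionalDecoder` stub above.

```python
import torch
from torch import nn
from torch.nn import functional as F


class SketchDecoder(nn.Module):
    """Hypothetical decoder g: maps (z, x_T) to a Gaussian over y_T."""

    def __init__(self, tx_size, latent_size, ty_size, hidden_size=128):
        super().__init__()
        self.fc1 = nn.Linear(tx_size + latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, 2 * ty_size)

    def forward(self, z, target_x):
        # z: (batch, latent_size); target_x: (batch, num_target, tx_size)
        num_target = target_x.size(1)
        # Share the same latent sample across every target location.
        z = z.unsqueeze(1).expand(-1, num_target, -1)
        hidden = torch.cat((target_x, z), dim=-1)
        hidden = F.relu(self.fc1(hidden))
        hidden = F.relu(self.fc2(hidden))
        # Split the output into a mean and an unconstrained scale.
        mu, pre_sigma = self.out(hidden).chunk(2, dim=-1)
        # Assumed parameterization: keep the standard deviation above 0.1.
        sigma = 0.1 + 0.9 * F.softplus(pre_sigma)
        return mu, sigma


torch.manual_seed(0)
decoder = SketchDecoder(tx_size=1, latent_size=16, ty_size=1)
z = torch.randn(4, 16)
target_x = torch.randn(4, 10, 1)  # 10 target locations per task
mu, sigma = decoder(z, target_x)
```

Returning a mean and a standard deviation per target point lets the prediction be scored with a Gaussian log-likelihood during training.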

## Experiments