# Ladder GNP
For PyTorch 1.0.1, Python 3.7

In [1]:
import numpy as np

import torch
from torch.distributions import multivariate_normal, normal
import torch.nn as nn

import collections, copy
from datetime import datetime

import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
class MLP(nn.Module):
    """Fully connected network: Linear+ReLU per hidden size, then a final Linear.

    Args:
      input_size: size of the last axis of the input tensor.
      hidden_sizes: sequence of hidden-layer widths. May be empty, in which
          case the network is a single ``Linear(input_size, output_size)``.
      output_size: size of the last axis of the output tensor.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        layers = []
        # Track the width feeding the next layer. Starting from input_size
        # (rather than output_size) keeps the final layer correct even when
        # there are no hidden layers at all.
        in_dim = input_size
        for size in hidden_sizes:
            layers += [nn.Linear(in_dim, size), nn.ReLU()]
            in_dim = size
        # Last layer without a ReLU
        layers.append(nn.Linear(in_dim, output_size))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to the last axis of ``x``."""
        return self.mlp(x)

In [6]:
class Encoder(nn.Module):
    """Encode a set of (x, y) points into a diagonal Gaussian over z.

    Every point is embedded by ``a_mlp``, the embeddings are averaged over
    the point axis, and ``z_mlp`` maps the averaged vector to the mean and a
    bounded standard deviation of the latent distribution.
    """

    def __init__(self, input_size, hidden_sizes, a_dim, z_dim):
        """
          input_size: raw input size = d_x + d_y
          hidden_sizes: The dimensions of hidden layers of the encoding MLPs.
          a_dim: aggregation vector dimension.
          z_dim: dimension of the latent variable z.

          a_mlp: input_size -> hidden_sizes -> a_dim = per-point embedding
          z_mlp: a_dim -> hidden_sizes -> 2 * z_dim = (z_mu, raw sigma)
        """
        super().__init__()
        self.a_mlp = MLP(input_size, hidden_sizes, a_dim)
        # NOTE: g_mlp is not used by forward(). It belonged to a disabled
        # gated/masked aggregation experiment and is kept so the module's
        # parameter layout and initialisation order stay unchanged.
        self.g_mlp = MLP(input_size, hidden_sizes, a_dim)
        self.z_dim = z_dim
        self.z_mlp = MLP(a_dim, hidden_sizes, z_dim*2)
        # Aggregated embedding of the most recent forward() call.
        self.embeddings = None
        self._softplus = nn.Softplus()

    def _to_one_hot(self, y, n_dims, dtype=None):
        """One-hot encode integer indices along a new trailing axis.

        Not called by forward(); retained as a helper for the disabled
        masking experiment.

        Args:
          y: Tensor of class indices in [0, n_dims), of any shape [...].
          n_dims: number of classes, i.e. the size of the new last axis.
          dtype: optional tensor type passed to ``Tensor.type`` (for example
              the legacy ``torch.cuda.FloatTensor``). When None the result is
              a float32 tensor on the same device as ``y``, so the helper
              also works on machines without CUDA.

        Returns:
          Tensor of shape [..., n_dims] holding 1 at each index and 0
          elsewhere.
        """
        index = y.long().unsqueeze(-1)
        zeros = torch.zeros(*y.size(), n_dims, dtype=torch.float32, device=y.device)
        if dtype is not None:
            zeros = zeros.type(dtype)
            # scatter needs the index on the same device as the destination.
            index = index.to(zeros.device)

        return zeros.scatter(y.dim(), index, 1)

    def forward(self, x, y):
        """Encodes the inputs into one representation.

        Args:
          x: Tensor of shape [Batches,contexts,d_x]. For this 1D regression
              task this corresponds to the x-values.
          y: Tensor of shape [B,contexts,d_y]. For this 1D regression
              task this corresponds to the y-values.

        Returns:
          z_dist: A normal distribution over tensors of shape [B, z_dim].
          z_mu: The mean of z_dist. Tensor of shape [B, z_dim].
          z_sigma: The standard deviation of z_dist. Tensor of shape
              [B, z_dim].

        Side effect: stores the averaged embedding ([B, a_dim]) in
        ``self.embeddings``.
        """
        # Concatenate x and y along the filter axes
        encoder_input = torch.cat([x, y], dim=-1)

        # Pass final axis through MLP
        es = self.a_mlp(encoder_input)

        # Order-independent aggregation: average over the point axis.
        hidden = torch.mean(es, dim=1)
        self.embeddings = hidden

        # Produce mu & sigma
        hidden = self.z_mlp(hidden)
        z_mu, log_sigma = torch.split(hidden, self.z_dim, dim=-1)

        # Softplus keeps sigma positive; the affine map floors it at 0.1.
        z_sigma = 0.1 + 0.9 * self._softplus(log_sigma)

        z_dist = normal.Normal(z_mu, z_sigma)

        return z_dist, z_mu, z_sigma

In [7]:
class Decoder(nn.Module):
    """Map a latent representation plus target locations to a Normal over y."""

    def __init__(self, input_size, x_input_size, hidden_sizes, output_size):
        """
          input_size: dimension of the representation (z) fed to forward()
          x_input_size: d_x, the size of the last axis of target_x
          hidden_sizes: The dimensions of hidden layers of the decoding MLP.
          output_size: d_y
        """
        super().__init__()
        self.output_size = output_size
        # The MLP sees [representation, target_x] concatenated and emits
        # twice d_y values: one half for the mean, one half for the raw scale.
        mlp_in = input_size + x_input_size
        self._mlp = MLP(mlp_in, hidden_sizes, output_size * 2)
        self._softplus = nn.Softplus()

    def forward(self, target_x, representation):
        """Predict a distribution over y at every target location.

        Args:
          target_x: The x locations for the target query.
              Tensor of shape [B,targets,d_x].
          representation: The representation of the context.
              Tensor of shape [B,z_dim]; it is tiled across the targets axis.

        Returns:
          dist: A Normal with independent components (diagonal covariance)
              over tensors of shape [B,targets,d_y].
          mu: The mean of dist. Tensor of shape [B,targets,d_y].
          sigma: The standard deviation of dist.
              Tensor of shape [B,targets,d_y].
        """
        n_targets = target_x.size(1)

        # Give every target point its own copy of the representation.
        tiled = representation.unsqueeze(1).repeat(1, n_targets, 1)
        decoder_in = torch.cat((tiled, target_x), dim=-1)

        raw = self._mlp(decoder_in)
        mu, raw_sigma = raw.split(self.output_size, dim=-1)

        # Bound the standard deviation from below at 0.1.
        sigma = self._softplus(raw_sigma) * 0.9 + 0.1

        dist = normal.Normal(mu, sigma)
        return dist, mu, sigma

In [None]:
class ZMLP(nn.Module):
    """MLP head that turns a vector into the (mu, sigma) of a Gaussian."""

    def __init__(self, input_size, hidden_sizes, output_size):
        """
          input_size: size of the last axis of the input z
          hidden_sizes: The dimensions of hidden layers of the MLP.
          output_size: size of the last axis of each of mu and sigma
        """
        super().__init__()
        self.output_size = output_size
        # Twice the width: first half is the mean, second half the raw scale.
        self._mlp = MLP(input_size, hidden_sizes, output_size * 2)

    def forward(self, z):
        """Compute Gaussian parameters from ``z``.

        Args:
          z: Tensor whose last axis has size input_size.

        Returns:
          mu: Tensor with last axis of size output_size.
          sigma: Tensor with last axis of size output_size; the sigmoid
              mapping confines every entry to the open interval (0.1, 1.0).
        """
        raw = self._mlp(z)
        mean, raw_scale = raw.split(self.output_size, dim=-1)

        # Bound the standard deviation between 0.1 and 1.0.
        std = torch.sigmoid(raw_scale) * 0.9 + 0.1
        return mean, std

In [8]:
class DropoutNeurlProcess(nn.Module):
    """Neural-process model with a three-level chain of latents z3 -> z2 -> z1.

    The context set is encoded into the top-level prior over z3; ``z_mlp``
    maps a sample of each level to the prior of the level below (the same
    ``z_mlp`` weights are used for both steps); the decoder predicts y from
    z1 and the target locations.
    """

    def __init__(self, input_sizes, encoder_hidden_sizes, a_dim, decoder_hidden_sizes):
        """
        Args:
          input_sizes: [d_x, d_y]
          encoder_hidden_sizes: [eh1, eh2]
          a_dim: aggregation vector dimension; also used as the latent size.
          decoder_hidden_sizes: [dh1, dh2, dh3]
        """
        super().__init__()

        z_dim = a_dim
        self.encoder = Encoder(sum(input_sizes), encoder_hidden_sizes, a_dim, z_dim)
        self.decoder = Decoder(z_dim, input_sizes[0], decoder_hidden_sizes, input_sizes[1])
        # Prior transition between latent levels.
        self.z_mlp = ZMLP(z_dim, encoder_hidden_sizes, z_dim)
        # Data-driven Gaussian terms that are multiplied with the priors.
        self.l_mlp = ZMLP(z_dim, encoder_hidden_sizes, z_dim)
        # Deterministic path applied to the target-set embedding.
        self.d_mlp = MLP(z_dim, encoder_hidden_sizes, z_dim)
        self.z_dim = z_dim
        # Informational only; nothing in this class reads it.
        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = "cuda"

    def sample_z(self, mu, sigma, dist=None):
        """Reparametrization trick: return ``mu + sigma * eps`` with eps ~ N(0, 1).

        The noise is created with the same shape, dtype and device as
        ``sigma``, so sampling works wherever the inputs live (including CPU
        tensors on a machine that has CUDA). ``dist`` is accepted for
        interface compatibility and ignored.
        """
        eps = torch.randn_like(sigma)
        return mu + sigma * eps

    def gaussian_product(self, mu1, sigma1, mu2, sigma2):
        """Mean and std of the (normalised) product of two Gaussian densities."""
        sigma = sigma1*sigma2/torch.sqrt(sigma1.pow(2)+sigma2.pow(2))
        mu = (mu1/(sigma1.pow(2)) + mu2/(sigma2.pow(2)))*(sigma.pow(2))
        return mu, sigma

    def _KL(self, mu, sigma, mup, sigmap):
        """Element-wise KL( N(mu, sigma) || N(mup, sigmap) )."""
        t = 2.0*torch.log(sigmap) - 2.0*torch.log(sigma) + (sigma.pow(2))* (1.0/sigmap.pow(2)) + (1.0/sigmap.pow(2))*((mu-mup).pow(2))-1.0
        return t*0.5

    def _summed_kl(self, mu_q, sigma_q, mu_p, sigma_p, num_targets):
        """Sum the KL over the latent axis and tile it to [B, num_targets]."""
        dkl = torch.sum(self._KL(mu_q, sigma_q, mu_p, sigma_p),
                        dim=-1, keepdim=True)
        return dkl.repeat([1, num_targets])

    def forward(self, context_x, context_y, target_x, target_y=None, beta=1.0):
        """Predict at the target points and, when targets are given, the loss.

        Args:
          context_x: Tensor of shape [B,num_contexts,d_x].
              Contains the x values of the context points.
          context_y: Tensor of shape [B,num_contexts,d_y].
              Contains the y values of the context points.
          target_x: Tensor of shape [B,num_targets,d_x].
              Contains the x values of the target points.
          target_y: The ground truth y values of the target y.
              Tensor of shape [B,num_targets,d_y]. When None, z1 is sampled
              from the prior and no loss is computed.
          beta: weight applied to the summed KL terms in the loss.

        Returns:
          mu: The mean of the predicted distribution.
              Tensor of shape [B,num_targets,d_y].
          dkl: Always None (never assigned by this method).
          loss: Scalar tensor, the negated mean of
              ``log_p - beta * (dkl1 + dkl2 + dkl3) / num_targets``,
              or None when target_y is None.
        """
        mu, loss = None, None
        dkl = None

        # Top-down prior pass conditioned on the context set.
        _, mu_p_3, sigma_p_3 = self.encoder(context_x, context_y)
        z3 = self.sample_z(mu_p_3, sigma_p_3)
        mu_p_2, sigma_p_2 = self.z_mlp(z3)
        z2 = self.sample_z(mu_p_2, sigma_p_2)
        mu_p_1, sigma_p_1 = self.z_mlp(z2)

        if target_y is None:
            # Prediction
            z1 = self.sample_z(mu_p_1, sigma_p_1)
            _, mu, _ = self.decoder(target_x, z1)
        else:
            # Inference
            _, mu_q_3, sigma_q_3 = self.encoder(target_x, target_y)
            # Must be read right after the call above: the encoder overwrites
            # this attribute on every forward pass.
            embeddings = self.encoder.embeddings
            d_1 = self.d_mlp(embeddings)
            d_2 = self.d_mlp(d_1)
            mu_h_q_2, sigma_h_q_2 = self.l_mlp(d_2)
            mu_q_2, sigma_q_2 = self.gaussian_product(mu_p_2, sigma_p_2, mu_h_q_2, sigma_h_q_2)
            mu_h_q_1, sigma_h_q_1 = self.l_mlp(d_1)
            mu_q_1, sigma_q_1 = self.gaussian_product(mu_p_1, sigma_p_1, mu_h_q_1, sigma_h_q_1)
            z1 = self.sample_z(mu_q_1, sigma_q_1)

            dist, mu, _ = self.decoder(target_x, z1)
            # Dropping the last axis gives [B, num_targets] when d_y == 1.
            log_p = dist.log_prob(target_y).squeeze(-1)

            num_targets = target_x.size(1)
            dkl3 = self._summed_kl(mu_q_3, sigma_q_3, mu_p_3, sigma_p_3, num_targets)
            dkl2 = self._summed_kl(mu_q_2, sigma_q_2, mu_p_2, sigma_p_2, num_targets)
            dkl1 = self._summed_kl(mu_q_1, sigma_q_1, mu_p_1, sigma_p_1, num_targets)
            loss = - torch.mean(log_p - beta * (dkl1+dkl2+dkl3) / num_targets)

        return mu, dkl, loss

In [None]:
def get_model():
    """Build a DropoutNeurlProcess with the notebook's default sizes.

    Uses d_x = 2 and d_y = 1, three encoder hidden layers and four decoder
    hidden layers, all of width HIDDEN_SIZE, which is also the a_dim.
    """
    HIDDEN_SIZE = 128 #@param {type:"number"}
    d_x, d_y = 2, 1
    return DropoutNeurlProcess(
        input_sizes=[d_x, d_y],
        encoder_hidden_sizes=[HIDDEN_SIZE] * 3,
        a_dim=HIDDEN_SIZE,
        decoder_hidden_sizes=[HIDDEN_SIZE] * 4,
    )