In [None]:
import numpy as np
import pandas as pd
import os
import glob
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.distributions as D
import torch.nn.functional as F

# Machine epsilon of the double-precision float type.
eps = np.finfo(float).eps

# Default (width, height) used for every matplotlib figure in this notebook.
plt.rcParams['figure.figsize'] = 10, 10
%matplotlib inline
%load_ext autoreload
%autoreload 2

#### Target distribution

$$p(x_1, x_2) = \mathcal{N}\Big(x_1 \,\Big|\, \mu = \tfrac{1}{4} x_2^2,\ \sigma = 1\Big) \cdot \mathcal{N}\Big(x_2 \,\Big|\, \mu = 0,\ \sigma = 4\Big)$$

In [None]:
batch_size = 512

# Sample the second coordinate first, because x1 is conditioned on it:
# x2 is normal with mean 0 and standard deviation 4.
loc2 = torch.tensor([0.0])
scale2 = torch.tensor([4.0])
dist2 = D.Normal(loc2, scale2)
x2 = dist2.sample((batch_size,))

# x1 given x2 is normal with mean x2^2 / 4 and unit standard deviation,
# so every x2 draw gets its own conditional mean.
loc1 = 0.25 * x2 ** 2
scale1 = torch.ones_like(loc1)
dist1 = D.Normal(loc1, scale1)
x1 = dist1.sample()

# Pair the coordinates column-wise: column 0 holds x1, column 1 holds x2.
x_samples = torch.cat((x1, x2), dim=-1)

In [None]:
# Scatter the target samples: x1 on the horizontal axis, x2 on the vertical axis.
plt.scatter(*x_samples.T, s=10, color='red')
plt.xlim(-5, 30)
plt.ylim(-10, 10)
plt.show()

In [None]:
# Affine map y = loc + scale * x with a zero shift and a uniform scale of 1.5.
AT = D.transforms.AffineTransform(
    loc=torch.zeros((1, 2)),
    scale=torch.tensor(1.5),
)

In [None]:
# Push every target sample through the affine transform AT.
y = AT(x_samples)

In [None]:
# Overlay the affinely transformed samples (blue) on the original samples (red).
plt.scatter(*y.T, s=10, color='blue')
plt.scatter(*x_samples.T, s=10, color='red')
plt.xlim(-5, 60)
plt.ylim(-20, 20)
plt.show()

#### Construct the flow

The flow resembles a standard fully-connected network: it alternates matrix multiplications with non-linearities.

Determinants are computationally expensive, so we use the matrix determinant lemma together with a structured affine transform. The latter is parameterized as a lower triangular matrix $M$ plus a low-rank update:
$$M + VDV^T$$

Next, we need an invertible nonlinearity. Sigmoid and tanh functions are very unstable to invert, because small changes in the output near their saturation limits correspond to massive changes in the input. ReLU is stable, but it is not invertible for $x \le 0$.

Hence we use PReLU (Parametric ReLU), which is like leaky ReLU but with a learnable slope in the negative regime.

In [None]:
# ## Custom non linearity.
# class ParLeakyReluFunction(torch.autograd.Function):

#     @staticmethod
#     def forward(ctx, input, alpha):
#         ctx.save_for_backward(input, alpha)
#         output = torch.where(input >= 0, input, input * alpha)
#         return output
    
#     @staticmethod
#     def backward(ctx, grad_output):
#         input, alpha = ctx.saved_tensors
#         grad_input = grad_weight = grad_bias = None
        

In [None]:
# Base distribution of the flow: a two-dimensional normal with zero mean
# and identity covariance.
base_dist = D.MultivariateNormal(
    loc=torch.zeros(2),
    covariance_matrix=torch.eye(2),
)

In [None]:
class ParLeakyRelu(nn.Module):
    """Element-wise leaky-ReLU bijector with negative-side slope ``alpha``.

    Forward map:  y = x          if x >= 0
                  y = alpha * x  otherwise

    The map is only invertible for ``alpha > 0``; in that case the sign of
    the output equals the sign of the input, which is what lets ``inverse``
    and ``inverse_log_det_jacobian`` branch on the sign of their argument.

    Args:
        alpha: Slope applied to negative inputs. It is stored as given;
            pass an ``nn.Parameter`` to have it registered as a learnable
            parameter of the module.
        event_dims: Dimension(s) summed over when reducing the element-wise
            log-Jacobian to one value per event. Defaults to the last
            dimension.
    """

    def __init__(self, alpha, event_dims=-1):
        super(ParLeakyRelu, self).__init__()
        self.alpha = alpha
        # The constructor argument is called ``event_dims`` while the
        # attribute is ``event_dim``; the original code mixed the two names up.
        self.event_dim = event_dims

    def forward(self, x):
        """Apply the leaky ReLU: identity for ``x >= 0``, ``alpha * x`` otherwise."""
        return torch.where(x >= 0, x, x * self.alpha)

    def inverse(self, x):
        """Invert ``forward``.

        Args:
            x: A tensor of *outputs* of ``forward``.

        Returns:
            The pre-image: ``x`` where ``x >= 0`` and ``x / alpha`` otherwise.
        """
        return torch.where(x >= 0, x, x / self.alpha)

    def inverse_log_det_jacobian(self, x):
        """Return log|det J| of the inverse map evaluated at ``x``.

        The inverse is element-wise, so its Jacobian is diagonal with
        entries ``1`` (non-negative side) or ``1 / alpha`` (negative side).
        The log-determinant is the sum of the log-absolute diagonal entries
        over ``self.event_dim``.

        Args:
            x: A tensor of *outputs* of ``forward``.
        """
        I = torch.ones_like(x)
        # Diagonal of the inverse Jacobian, chosen per element by sign.
        J_inv = torch.where(x >= 0, I, I / self.alpha)
        log_abs_det_J_inv = torch.log(torch.abs(J_inv))
        return torch.sum(log_abs_det_J_inv, dim=self.event_dim)

In [None]:
class MLPBijector(nn.Module):
    """Flow shaped like an MLP: affine layers alternating with leaky ReLUs.

    The network is ``num_layers`` repetitions of (affine map, ParLeakyRelu)
    followed by one final affine map. Each repetition has its own weights.

    The affine maps are plain ``nn.Linear(d, d)`` layers. ``nn.ModuleList``
    only accepts ``nn.Module`` instances, and
    ``torch.distributions.transforms.AffineTransform`` is not one (it also
    requires explicit ``loc`` and ``scale`` arguments), so it cannot be used
    here.

    NOTE(review): an unconstrained ``nn.Linear`` is not guaranteed to be
    invertible, and the structured ``M + V D V^T`` parameterization described
    in the notebook text is not implemented yet; ``r`` is stored but unused.

    Args:
        alpha: Negative-side slope handed to every ``ParLeakyRelu``.
        d: Width of every affine layer (input and output size).
        r: Stored on the instance only; presumably the rank of the planned
            low-rank update - TODO confirm.
        num_layers: Number of (affine, nonlinearity) pairs before the
            output layer.
    """

    def __init__(self, alpha, d, r, num_layers):
        super(MLPBijector, self).__init__()
        self.d = d
        self.r = r
        self.alpha = alpha
        self.num_layers = num_layers

        blocks = []
        for _ in range(num_layers):
            blocks.append(nn.Linear(d, d))
            # event_dims is passed explicitly so the log-Jacobian is reduced
            # over the feature dimension.
            blocks.append(ParLeakyRelu(self.alpha, event_dims=-1))
        self.encoder = nn.ModuleList(blocks)
        self.output_layer = nn.Linear(d, d)

    def forward(self, x):
        """Run ``x`` through every encoder block, then the output layer."""
        # A ModuleList is a container, not a callable: apply its members in order.
        for block in self.encoder:
            x = block(x)
        return self.output_layer(x)
    

In [None]:
# AT = D.transforms.AffineTransform(torch.zeros((1, 2)), torch.tensor(1.5))

In [None]:
def loss(x, x_hat):
    """Mean negative log-likelihood of the targets ``x`` under ``x_hat``.

    ``nn.NLLLoss`` is a criterion *module*: it is configured in its
    constructor and then called on ``(input, target)``. The reduction mode
    is the string ``'mean'``.

    NOTE(review): ``nn.NLLLoss`` is a classification criterion (log-probability
    input, class-index target). Training the flow by maximum likelihood
    presumably needs a density-based NLL instead - confirm before use.

    Args:
        x: Target class indices, as expected by ``nn.NLLLoss``.
        x_hat: Log-probabilities, one row per sample.

    Returns:
        A scalar tensor holding the mean loss.
    """
    criterion = nn.NLLLoss(reduction='mean')
    return criterion(x_hat, x)

def get_optimizer(model):
    """Create the Adam optimizer for all parameters of ``model``.

    Uses a learning rate of 1e-3 and the momentum coefficients
    ``betas = (0.9, 0.999)``.
    """
    return torch.optim.Adam(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
    )


def init_weights(module):
    """Initialise every linear / transposed-conv layer inside ``module``.

    Weights get Xavier-uniform initialisation and biases (when present) are
    set to zero. ``module.modules()`` already yields ``module`` itself and
    all of its descendants, including layers nested in ``nn.Sequential``
    containers, so no explicit recursion is needed.

    Args:
        module: Any ``nn.Module``; it is modified in place.
    """
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.ConvTranspose2d)):
            # ``nn.init`` is reached through the already-imported ``nn``;
            # the bare name ``init`` is never imported in this notebook.
            nn.init.xavier_uniform_(m.weight.data)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)


# Hyper-parameters handed to MLPBijector: the leaky-ReLU slope followed by
# the d, r and num_layers constructor arguments.
alpha, d, r, num_layers = 0.8, 2, 2, 6

# Build the flow, initialise its weights, then create its optimizer.
model = MLPBijector(alpha=alpha, d=d, r=r, num_layers=num_layers)
init_weights(model)
opt = get_optimizer(model)


In [None]:
# todo, custom layer, or
# nn.Linear for Affine transform, since its the same thing