# Invertible Neural Networks (for generation)

In [76]:
import jax 
import jax.numpy as np
import jax.scipy as sp
import numpy as onp
import matplotlib.pyplot as plt

Assume $f$ satisfies the following Lipschitz bound for some constant $q < 1$:
$$\frac{|| f(x_0) - f(x_1) ||}{||x_0 - x_1||} < q \qquad \forall_{x_0, x_1 \in X,\; x_0 \neq x_1}$$
$$ z = f(x) + x $$
$$ x = z - f(x) $$
$$ x = z - f(z - f(x))$$
$$ x = z - f(z - f(\dots))$$

Is $g(x) = z - f(x)$ a contraction? 

$$|| g(x_0) - g(x_1) || = || z - f(x_0) - z + f(x_1) || $$
$$ = || f(x_1) - f(x_0) || = || f(x_0) - f(x_1) || $$
$$ < q \, ||x_0 - x_1|| $$

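Therefore, $g$ is a contraction, so by the Banach fixed-point theorem (on a complete space such as $\mathbb{R}^n$) the iteration $x_{k+1} = z - f(x_k)$ converges to the unique $x$ satisfying $z = f(x) + x$.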

Let $X \sim \mathcal{X}$. 
$$ p(f | D) = \frac{p(D | f)p(f)}{p(D)} \propto p(D | f)p(f) $$
$$ \mathbb{D}_{KL}[q(f)\;||\;p(f | D)] = -\mathbb{E}_{f \sim q, X \sim \mathcal{X}}\log f(X) + \mathbb{D}_{KL}[q(f) \;||\;p(f)] + \text{const} $$
An easier, but less accurate, alternative is constrained maximum likelihood:
$$ \max_{\theta \in \Theta} \mathbb{E}_{X\sim\mathcal{X}}\text{log}\,m_\theta(X) $$
$$ \text{s.t. } \forall_{w \in W} ||w||_2 < q $$
where $0 \leq q < 1$ and $W \subset \theta$ are the weights of the neural network.

In [98]:
def create_layer(key, input_dim, output_dim, activation_function=jax.nn.relu):
    """Build the parameters of a single dense layer.

    Note that ``np`` in this file is ``jax.numpy``, not NumPy.

    Args:
        key: JAX PRNG key used to sample the weight matrix.
        input_dim: Number of rows of the weight matrix.
        output_dim: Number of columns of the weight matrix and length of
            the bias vector.
        activation_function: Callable stored alongside the parameters;
            defaults to ``jax.nn.relu``.

    Returns:
        A ``(weight, bias, activation_function)`` tuple. ``weight`` is drawn
        from a standard normal with shape ``(input_dim, output_dim)`` and
        ``bias`` is a zero vector of shape ``(output_dim,)``.
    """
    weight_shape = (input_dim, output_dim)
    return (
        jax.random.normal(key, shape=weight_shape),
        np.zeros((output_dim,)),
        activation_function,
    )

def create_mlp(key, inout):
    """Create a list of dense layers from a sequence of layer widths.

    Args:
        key: JAX PRNG key; it is split into one sub-key per layer so every
            layer is initialised from its own key.
        inout: Sequence of widths. Layer ``i`` maps ``inout[i]`` features
            to ``inout[i + 1]`` features, giving ``len(inout) - 1`` layers.

    Returns:
        A list of ``(weight, bias, activation_function)`` tuples as produced
        by ``create_layer`` with its default activation.
    """
    layer_keys = jax.random.split(key, len(inout) - 1)
    return [
        create_layer(layer_key, inout[idx], inout[idx + 1])
        for idx, layer_key in enumerate(layer_keys)
    ]

def layer_forward(layer, X):
    """Apply one dense layer to ``X``.

    Args:
        layer: ``(weight, bias, activation_function)`` tuple.
        X: Input that is matrix-multiplied with ``weight``.

    Returns:
        ``activation_function(X @ weight + bias)``.
    """
    weight, bias, activation_function = layer
    pre_activation = X @ weight + bias
    return activation_function(pre_activation)

def residual_forward(layer, X):
    """Apply ``layer`` to ``X`` and add the skip connection.

    Computes ``layer_forward(layer, X) + X``, so the layer output must be
    addable to (broadcast-compatible with) ``X``.

    Args:
        layer: ``(weight, bias, activation_function)`` tuple.
        X: Input to the residual block.

    Returns:
        The layer output plus the unchanged input.
    """
    branch = layer_forward(layer, X)
    return branch + X

def forward(model, X):
    """Run ``X`` through every residual block of ``model`` in order.

    Args:
        model: Iterable of ``(weight, bias, activation_function)`` layers.
        X: Input to the first block.

    Returns:
        The output of the last residual block, or ``X`` unchanged when
        ``model`` is empty.
    """
    activations = X
    for block in model:
        activations = residual_forward(block, activations)
    return activations

def layer_constraint(layer):
    """Return the order-2 norm of a layer's weight.

    For a 2-D weight matrix (as built by ``create_layer``) the ``ord=2``
    norm is the spectral norm, i.e. the largest singular value.

    Args:
        layer: ``(weight, bias, activation_function)`` tuple; only the
            weight is used.

    Returns:
        ``np.linalg.norm(weight, ord=2)`` as a scalar array.
    """
    weight, _bias, _activation = layer
    spectral_norm = np.linalg.norm(weight, ord=2)
    return spectral_norm

def constraint(model):
    """Sum the per-layer weight norms over the whole model.

    Args:
        model: Iterable of ``(weight, bias, activation_function)`` layers.

    Returns:
        The sum of ``layer_constraint(layer)`` over all layers; the integer
        ``0`` when ``model`` is empty.
    """
    return sum(map(layer_constraint, model))

def loss(model):
    