In [None]:
import numpy as np
import numpy.random as npr
from tqdm import tqdm
import tensorflow as tf

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import tensorflow.contrib.eager as tfe
# TF 1.x requires eager mode to be enabled before any ops are created; later
# cells rely on it (.numpy() on results, GradientTape, tfe.defun).
tf.enable_eager_execution()

In [None]:
# Re-import modified modules before each cell executes, so edits to the local
# cnf / neural_ode modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from cnf import CNF
from neural_ode import NeuralODE


In [None]:
# Experiment 1: train the flow so that forward-integrated standard-normal
# samples x0 have high log-likelihood under the shifted Gaussian p0 = N([1, 0], I).
num_samples = 512
data_dim = 2

# Fix the global seed so x0 (and, presumably, the CNF weight initialisation)
# is reproducible under Restart & Run All.
tf.set_random_seed(0)

p0 = tf.distributions.Normal(loc=[1.0, 0.0], scale=[1.0, 1.0])
logdet0 = tf.zeros([num_samples, 1])
x0 = tf.random_normal([num_samples, data_dim])
# Augmented state: the data coordinates followed by a log-determinant column
# initialised to zero.
h0 = tf.concat([x0, logdet0], axis=1)

cnf_net = CNF(input_dim=data_dim, hidden_dim=32, n_ensemble=16)
# Integrate from t=1 down to t=0 on 10 time points.
ode = NeuralODE(model=cnf_net, t=np.linspace(1, 0.0, 10))

In [None]:
@tfe.defun
def compute_gradients_and_update():
    """Run one optimisation step of the CNF and return the loss before the update.

    Integrates ``h0`` forward with ``ode``, uses ``-mean(log p0(xN) - logdetN)``
    as the loss, back-propagates it through the ODE with ``ode.backward`` and
    applies the resulting weight gradients via the global ``optimizer``.

    NOTE: ``optimizer`` is a global looked up when the function body is first
    traced (i.e. on the first call), so it only needs to exist by then.
    """
    h_final = ode.forward(inputs=h0)
    with tf.GradientTape() as tape:
        tape.watch(h_final)
        z_final = h_final[:, :2]
        logdet_final = h_final[:, 2]
        # Per-sample log-density of the transformed points under p0.
        log_pz = tf.reduce_sum(p0.log_prob(z_final), axis=-1)
        loss = -tf.reduce_mean(log_pz - logdet_final)

    # Gradient w.r.t. the final augmented state; ode.backward turns it into
    # gradients for the CNF weights (third element of its result).
    dloss_dh = tape.gradient(loss, h_final)
    _, _, weight_grads = ode.backward(h_final, dloss_dh)
    optimizer.apply_gradients(zip(weight_grads, cnf_net.weights))
    return loss

In [None]:
# Adam with step size 1e-2 for the Gaussian-target experiment.
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)

In [None]:
# Train for 1000 steps. Every 200 steps, show the loss curve, then the current
# transformed samples (black) over the starting samples (red).
loss_history = []
for step in tqdm(range(1000)):
    loss = compute_gradients_and_update()
    loss_history.append(loss.numpy())
    if step % 200 != 0:
        continue
    plt.plot(loss_history)
    plt.show()
    hN = ode.forward(h0)
    xN = hN[:, :2]
    logdetN = hN[:, 2]
    plt.scatter(*xN.numpy().T, color='k', alpha=0.5)
    plt.scatter(*x0.numpy().T, color='r', alpha=0.5)
    plt.axis("equal")
    plt.show()

In [None]:
# Most recent loss value, displayed as a one-element list (empty if no step ran).
loss_history[-1:]

In [None]:
# Loss curve for the Gaussian-target run. The trailing ';' suppresses the
# matplotlib object repr that would otherwise be shown as the cell output.
plt.plot(loss_history)
plt.xlabel("training step")
plt.ylabel("loss")
plt.title("Training loss");

In [None]:
# Final transformed samples (black) on top of the starting samples (red).
hN = ode.forward(h0)
xN = hN[:, :2]
logdetN = hN[:, 2]
plt.scatter(*xN.numpy().T, color='k', alpha=0.5)
plt.scatter(*x0.numpy().T, color='r', alpha=0.5)
plt.axis("equal")
plt.show()

In [None]:
from sklearn.datasets import make_moons

# Experiment 2: two-moons data x0 pulled towards a standard-normal base density.
# One sample count is shared by the data and its log-det column (previously the
# literal 256 was repeated); random_state makes the dataset reproducible.
num_samples = 256

p0 = tf.distributions.Normal(loc=[0.0, 0.0], scale=[1.0, 1.0])
x0 = tf.to_float(make_moons(n_samples=num_samples, noise=0.05, random_state=0)[0])
logdet0 = tf.zeros([num_samples, 1])
h0 = tf.concat([x0, logdet0], axis=1)

cnf_net = CNF(input_dim=2, hidden_dim=32, n_ensemble=16)
# Integrate from t=1 down to t=0 on 10 time points.
ode = NeuralODE(model=cnf_net, t=np.linspace(1, 0.0, 10))

In [None]:
# RMSProp with a small step size for the two-moons experiment.
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.0001)

In [None]:
@tfe.defun
def compute_gradients_and_update():
    """One optimisation step for the two-moons experiment; returns the loss.

    Re-defined here (and re-wrapped with ``tfe.defun``) so that the traced
    function picks up this experiment's ``ode``, ``h0``, ``p0``, ``cnf_net``
    and ``optimizer`` globals rather than those of the previous experiment.
    The loss is ``-mean(log p0(xN) - logdetN)``.
    """
    h_final = ode.forward(inputs=h0)
    with tf.GradientTape() as tape:
        tape.watch(h_final)
        z_final = h_final[:, :2]
        logdet_final = h_final[:, 2]
        # Per-sample log-density of the transformed points under p0.
        log_pz = tf.reduce_sum(p0.log_prob(z_final), axis=-1)
        loss = -tf.reduce_mean(log_pz - logdet_final)

    dloss_dh = tape.gradient(loss, h_final)
    # Only the weight gradients (third result of ode.backward) are applied.
    _, _, weight_grads = ode.backward(h_final, dloss_dh)
    optimizer.apply_gradients(zip(weight_grads, cnf_net.weights))
    return loss

In [None]:
# Train for 1100 steps. Every 200 steps, draw the loss curve (left) next to
# the transformed samples (black) over the moons data (red) (right).
loss_history = []
for step in tqdm(range(1100)):
    loss = compute_gradients_and_update()
    loss_history.append(loss.numpy())
    if step % 200 != 0:
        continue
    plt.subplot(121)
    plt.plot(loss_history)
    plt.subplot(122)
    hN = ode.forward(h0)
    xN = hN[:, :2]
    logdetN = hN[:, 2]
    plt.scatter(*xN.numpy().T, color='k', alpha=0.5)
    plt.scatter(*x0.numpy().T, color='r', alpha=0.5)
    plt.axis("equal")
    plt.show()

In [None]:
# Transformed moons samples (black) vs. the original data (red).
hN = ode.forward(h0)
xN = hN[:, :2]
logdetN = hN[:, 2]
plt.scatter(*xN.numpy().T, color='k', alpha=0.5)
plt.scatter(*x0.numpy().T, color='r', alpha=0.5)
_ = plt.axis("equal")

In [None]:
# Rebuild the ODE with time running the other way (0 -> 1, 20 points) and
# integrate the transformed points hN with it, to try to recover the data.
# Note: this rebinds `ode`, so later cells use the reversed direction.
ode = NeuralODE(model=cnf_net, t=np.linspace(0.0, 1.0, num=20))
h0_reconstruction = ode.forward(inputs=hN)

In [None]:
# Points recovered by the reverse integration (black) vs. the original data (red).
x0_rec = h0_reconstruction[:, 0:2]
plt.scatter(*x0_rec.numpy().T, color='k', alpha=0.5)
plt.scatter(*x0.numpy().T, color='r', alpha=0.5)
_ = plt.axis("equal")

In [None]:
# Draw fresh points from the base density p0 and integrate them with the ODE
# built in the previous cell (t: 0 -> 1) to generate new samples. The number of
# draws is taken from logdet0 instead of a hard-coded 256, so the concat always
# lines up with its rows.
num_base_samples = int(logdet0.shape[0])
hN_sample = tf.concat([p0.sample(num_base_samples), logdet0], axis=1)
h0_reconstruction = ode.forward(inputs=hN_sample)

In [None]:
# Generated samples (black): base-density draws mapped through the reversed ODE,
# shown against the moons data (red).
x0_rec = h0_reconstruction[:, 0:2]
plt.scatter(*x0_rec.numpy().T, color='k', alpha=0.5)
plt.scatter(*x0.numpy().T, color='r', alpha=0.5)
_ = plt.axis("equal")

In [None]:
# Experiment 3: fresh standard-normal samples and a fresh CNF, trained against
# the unnormalised energy `potential_energy` defined in the next cell.
num_samples = 512

x0 = tf.random_normal([num_samples, 2])
logdet0 = tf.zeros([num_samples, 1])
h0 = tf.concat([x0, logdet0], axis=1)

cnf_net = CNF(input_dim=2, hidden_dim=32, n_ensemble=16)
# Integrate from t=1 down to t=0 on 10 time points.
ode = NeuralODE(model=cnf_net, t=np.linspace(1, 0.0, 10))

In [None]:
# Helper warp functions. They are not used by potential_energy below; they are
# presumably kept for the other Rezende & Mohamed (2015) test energies.
def w1(z):
    """Return sin(2*pi*z[0]/4)."""
    return tf.sin(2.*np.pi*z[0]/4.)
def w2(z):
    """Return 3*exp(-0.5*((z[0]-1)/0.6)**2)."""
    return 3.*tf.exp(-.5*(((z[0]-1.)/.6))**2)
def w3(z):
    """Return 3*sigmoid((z[0]-1)/0.3), written as 3/(1+exp(-(z[0]-1)/0.3))."""
    return 3.*(1+tf.exp(-(z[0]-1.)/.3))**-1


def potential_energy(z):
    """Ring-with-two-lobes energy U(z) for a batch of 2-D points.

    Args:
        z: tensor of points, assumed (batch, 2) since it is transposed and
           then indexed with z[0] as the first coordinate (TODO confirm).

    Returns:
        One energy per point (shape (batch,)):
        U(z) = 0.5*((||z|| - 2)/0.4)**2
               - log(exp(-0.5*((z0-2)/0.6)**2) + exp(-0.5*((z0+2)/0.6)**2))
    """
    z = tf.transpose(z)
    ring = .5*((tf.norm(z, ord=2, axis=0) - 2.)/.4)**2
    # Evaluate log(exp(a) + exp(b)) with logsumexp. The naive form underflows
    # to log(0) = -inf once |z0| is large enough that both Gaussian bumps
    # vanish in float32, which makes the loss and its gradients inf/NaN.
    lobes = tf.stack([-.5*((z[0]-2.)/.6)**2,
                      -.5*((z[0]+2.)/.6)**2], axis=0)
    return ring - tf.reduce_logsumexp(lobes, axis=0)

In [None]:
def compute_gradients_and_update():
    """One optimisation step against the energy `potential_energy`; returns the loss.

    With log p(xN) = -U(xN) (unnormalised), the loss is
    mean over samples of (U(xN) + logdetN). Gradients are pushed back through
    the ODE with ``ode.backward`` and applied by the global ``optimizer``.
    """
    hN = ode.forward(inputs=h0)
    with tf.GradientTape() as g:
        g.watch(hN)
        xN, logdetN = hN[:, :2], hN[:, 2]
        # L = log(p(zN)), one value per sample. potential_energy already
        # reduces over the coordinate axis, so no extra reduce_sum here: the
        # old reduce_sum(..., -1) summed over the *batch*, weighting the energy
        # term num_samples times more than the averaged log-det term.
        mle = -potential_energy(xN)
        # loss to minimize
        loss = -tf.reduce_mean(mle - logdetN)

    dloss = g.gradient(loss, hN)
    # Only the weight gradients (third result of ode.backward) are applied.
    _, _, dLdW = ode.backward(hN, dloss)
    optimizer.apply_gradients(zip(dLdW, cnf_net.weights))
    return loss

compute_gradients_and_update = tfe.defun(compute_gradients_and_update)

In [None]:
# Reset in its own cell so the training cell below can be re-run to continue
# the same loss curve.
loss_history = list()

In [None]:
# Adam with step size 1e-3 for the energy-target experiment.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

In [None]:
# 3001 steps appended to the existing loss_history. Every 200 steps, draw the
# loss curve (left) next to the transformed samples (black) over the starting
# samples (red) (right).
for step in tqdm(range(3001)):
    loss = compute_gradients_and_update()
    loss_history.append(loss.numpy())
    if step % 200 != 0:
        continue
    plt.subplot(121)
    plt.plot(loss_history)
    plt.subplot(122)
    hN = ode.forward(h0)
    xN = hN[:, :2]
    logdetN = hN[:, 2]
    plt.scatter(*xN.numpy().T, color='k', alpha=0.5)
    plt.scatter(*x0.numpy().T, color='r', alpha=0.5)
    plt.axis("equal")
    plt.show()

In [None]:
# Transformed samples (black) after fitting the energy, vs. the starting samples (red).
hN = ode.forward(h0)
xN = hN[:, :2]
logdetN = hN[:, 2]
plt.scatter(*xN.numpy().T, color='k', alpha=0.5)
plt.scatter(*x0.numpy().T, color='r', alpha=0.5)
_ = plt.axis("equal")