In [None]:
import numpy as np

from a_nice_mc.models.discriminator import MLPDiscriminator
from a_nice_mc.models.generator import create_nice_network
from a_nice_mc.train.wgan_nll import Trainer

from hepmc.core.densities.camel import UnconstrainedCamel as Camel

In [None]:
import tensorflow as tf
from a_nice_mc.objectives import Energy
from tensorflow.python.framework import ops

class Camel3d(Energy):
    """Three-dimensional camel energy backed by hepmc's ``UnconstrainedCamel``.

    The energy and its gradient are evaluated in NumPy by ``self.camel`` and
    wrapped into TensorFlow with ``tf.py_func``. A custom gradient is
    registered for the energy op so it is differentiable w.r.t. its inputs.
    """

    def __init__(self):
        super(Camel3d, self).__init__()
        self.name = "Camel3d"
        # Batch of 3-D points, shape (batch, 3).
        self.z = tf.placeholder(tf.float32, [None, 3], name='z')
        # Camel parameters (two mode centres and a per-dimension width). They
        # are not read elsewhere in this class; the energy comes from
        # self.camel. Float literals keep 1/3 and 2/3 non-zero under Python 2.
        self.mu_a = np.array(3 * [1. / 3], dtype=np.float32)
        self.mu_b = np.array(3 * [2. / 3], dtype=np.float32)
        self.stddev = np.array(3 * [.1 / np.sqrt(2)], dtype=np.float32)
        self.camel = Camel(3)

    def __call__(self, z):
        """Return the energy of a (batch, 3) tensor ``z`` as a TF tensor."""
        # Split into three 1-D column tensors, one per coordinate.
        z1, z2, z3 = z[:, 0], z[:, 1], z[:, 2]
        return self.tf_energy(z1, z2, z3)

    def energy(self, z1, z2, z3):
        """NumPy energy: stack the coordinates into (N, 3) and call ``camel.pot``.

        The result is cast to float32 to match the ``tf.float32`` output dtype
        declared in ``tf_energy``.
        """
        z = np.array([z1, z2, z3]).transpose()
        return self.camel.pot(z).astype(np.float32, copy=False)

    def d_energy(self, z1, z2, z3):
        """NumPy gradient of ``energy`` as three per-coordinate float32 arrays.

        ``camel.pot_gradient`` is indexed as an (N, 3) array; column i holds
        the partial derivative with respect to coordinate i.
        """
        z = np.array([z1, z2, z3]).transpose()
        grad = self.camel.pot_gradient(z).astype(np.float32, copy=False)
        return grad[:, 0], grad[:, 1], grad[:, 2]

    def tf_energy(self, z1, z2, z3, name=None):
        """TensorFlow op evaluating ``energy`` with ``energy_grad`` as its gradient."""
        with tf.name_scope(name, "energy", [z1, z2, z3]) as name:
            y = self.py_func(self.energy,
                             [z1, z2, z3],
                             [tf.float32],
                             name=name,
                             grad=self.energy_grad)
            return y[0]

    def tf_d_energy(self, z1, z2, z3, name=None):
        """TensorFlow op evaluating ``d_energy``; returns three tensors."""
        with tf.name_scope(name, "d_energy", [z1, z2, z3]) as name:
            y = tf.py_func(self.d_energy,
                           [z1, z2, z3],
                           [tf.float32, tf.float32, tf.float32],
                           name=name,
                           stateful=False)
            return y

    def py_func(self, func, inp, Tout, stateful=True, name=None, grad=None):
        """``tf.py_func`` whose created op uses ``grad`` as its gradient.

        Args:
            func: Python callable evaluated on NumPy arrays.
            inp: list of input tensors.
            Tout: list of output dtypes.
            stateful: forwarded to ``tf.py_func``.
            name: optional op name.
            grad: gradient function ``grad(op, upstream_grad)``. If None, a
                plain ``tf.py_func`` is returned without a gradient override.

        Returns:
            The output of ``tf.py_func``.
        """
        if grad is None:
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

        import uuid

        # Gradient names are process-global and RegisterGradient raises on a
        # duplicate, so use a collision-free name. uuid also avoids drawing
        # from the global NumPy RNG, which would perturb seeded runs.
        grad_name = 'PyFuncGrad' + uuid.uuid4().hex
        tf.RegisterGradient(grad_name)(grad)
        g = tf.get_default_graph()
        # stateful=True creates a "PyFunc" op and stateful=False a
        # "PyFuncStateless" op; override both so `grad` applies either way.
        with g.gradient_override_map({"PyFunc": grad_name,
                                      "PyFuncStateless": grad_name}):
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

    def energy_grad(self, op, grad):
        """Gradient registered for the energy op (chain rule).

        Multiplies the upstream gradient ``grad`` by the partial derivatives
        of the energy evaluated at the op's three inputs.
        """
        z1 = op.inputs[0]
        z2 = op.inputs[1]
        z3 = op.inputs[2]
        n_gr = self.tf_d_energy(z1, z2, z3)
        return grad * n_gr[0], grad * n_gr[1], grad * n_gr[2]

    @staticmethod
    def mean():
        """Reference per-dimension mean: 0.5, the midpoint of mu_a and mu_b."""
        return np.array(3 * [.5])

    @staticmethod
    def std():
        """Reference per-dimension standard deviation (0.181).

        This appears to equal sqrt(stddev**2 + (1/6)**2), the std of an
        equal-weight two-Gaussian mixture at 1/3 and 2/3. TODO confirm.
        """
        return np.array(3 * [.181])

    def evaluate(self, zv, path=None):
        """No evaluation is implemented for this energy; returns None."""
        pass

In [None]:
# Seed NumPy (used by noise_sampler) and the default TF graph before anything
# random is built, so Restart & Run All is repeatable. Some GPU kernels may
# still be nondeterministic.
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)

energy_fn = Camel3d()

In [None]:
def noise_sampler(bs):
    """Draw a (bs, 3) batch of i.i.d. standard-normal noise for the Trainer."""
    noise_shape = (bs, 3)
    return np.random.normal(loc=0.0, scale=1.0, size=noise_shape)

In [None]:
# NICE architecture spec, passed unchanged to create_nice_network.
# Each entry is presumably (hidden layer sizes, variable the block updates,
# flag) — confirm against a_nice_mc.models.generator.
nice_arch = [
    ([400], 'v1', False),
    ([400, 400], 'x1', True),
    ([400], 'v2', False),
]

discriminator = MLPDiscriminator([400, 400, 400])
# 3 and 10 are presumably the x (data) and auxiliary v dimensions — TODO confirm.
generator = create_nice_network(3, 10, nice_arch)

In [None]:
# WGAN/NLL trainer wiring the generator, target energy, discriminator and
# noise source together; b and m are trainer hyperparameters.
trainer = Trainer(
    generator,
    energy_fn,
    discriminator,
    noise_sampler,
    b=16,
    m=4,
)

In [None]:
n_train_iters = 1000  # training iterations passed to Trainer.train
trainer.train(max_iters=n_train_iters)

In [None]:
sample_batch_size = 32
sample_steps = 1000
sample = trainer.sample(batch_size=sample_batch_size, steps=sample_steps)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# 2-D histogram of coordinates z1 (column 0) and z2 (column 1) of the samples.
fig, ax = plt.subplots()
ax.hist2d(sample[0][0][:, 0], sample[0][0][:, 1], range=[[0, 1], [0, 1]], bins=20)
ax.set(xlabel='z1', ylabel='z2', title='Camel3d samples: z1 vs z2')
plt.show()

In [None]:
# 2-D histogram of coordinates z1 (column 0) and z3 (column 2) of the samples.
fig, ax = plt.subplots()
ax.hist2d(sample[0][0][:, 0], sample[0][0][:, 2], range=[[0, 1], [0, 1]], bins=20)
ax.set(xlabel='z1', ylabel='z3', title='Camel3d samples: z1 vs z3')
plt.show()

In [None]:
# 2-D histogram of coordinates z2 (column 1) and z3 (column 2) of the samples.
fig, ax = plt.subplots()
ax.hist2d(sample[0][0][:, 1], sample[0][0][:, 2], range=[[0, 1], [0, 1]], bins=20)
ax.set(xlabel='z2', ylabel='z3', title='Camel3d samples: z2 vs z3')
plt.show()