In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from functools import partial, update_wrapper

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib.colors import LogNorm
from tensorflow.nn import relu

from nmjmc.nn import NeuralMJMCNetwork
from nmjmc.sampling import MCSampler
from nmjmc.systems import GaussianTripleWell

In [None]:
# Starting points for the MC sampler and reference points for the split
# conditions below (presumably the minima of the triple-well potential; they
# are overlaid on its contour plot as a visual check).
minima = np.array([[-2.2, -1.0], [0.0, 2], [2, -0.8]])
# NOTE(review): `factors` is not referenced in the cells shown — remove or document.
factors = np.array([10.0, 10.0, 10.0, 0.2])

# Seed numpy's global RNG so Restart & Run All reproduces the same trajectories.
# Assumes MCSampler draws from numpy's global RNG — TODO confirm; TensorFlow
# weight initialisation needs its own seed.
SEED = 0
np.random.seed(SEED)

In [None]:
# The potential under study and an MC sampler started from the first minimum.
triple_well = GaussianTripleWell()
sampler = MCSampler(
    triple_well,
    minima[0],
    2,  # NOTE(review): positional arg, presumably the dimension — confirm against MCSampler
    stride=10,
)

In [None]:
# Contour plot of the energy landscape with the reference minima overlaid.
triple_well.plot_contour(bounds=[-4, 4, -4, 4, -4, 4])
plt.colorbar()
plt.scatter(*minima.T)

In [None]:
# Run N_RUNS_PER_MINIMUM MC trajectories of N_MC_STEPS steps from each minimum.
# Result: allTrajs[minimum_index][run_index] holds one trajectory.
N_RUNS_PER_MINIMUM = 1
N_MC_STEPS = 2000

allTrajs = []
for minimum in minima:
    trajs = []
    for _ in range(N_RUNS_PER_MINIMUM):
        # Start point wrapped to shape (1, 2) — presumably a batch of one
        # configuration; TODO confirm against MCSampler.reset.
        sampler.reset(np.array([minimum]))
        sampler.run(N_MC_STEPS)
        # Copy the trajectory so later resets cannot alias/overwrite earlier runs
        # (matters once N_RUNS_PER_MINIMUM > 1; TODO confirm whether MCSampler
        # reuses its traj buffer).
        trajs.append(np.array(sampler.traj))
    allTrajs.append(np.array(trajs))
allTrajs = np.array(allTrajs)

In [None]:
# One flattened (n_samples, n_coords) block of samples per minimum, with all
# runs of that minimum concatenated; derived from `minima` instead of literals.
data = allTrajs.reshape((len(minima), -1, minima.shape[1]))

In [None]:
# Sampled trajectories (one per starting minimum) over the energy contours.
triple_well.plot_contour(bounds=[-3, 3, -3, 3, -3, 3])
plt.colorbar()
for j, traj in enumerate(data):
    plt.plot(traj[:, 0], traj[:, 1], label=f"start: minimum {j}")
plt.xlabel("coordinate 0")
plt.ylabel("coordinate 1")
plt.legend()
plt.title("MC trajectories from each minimum");

In [None]:
# Difference vectors between pairs of minima (first index minus second).
# NOTE(review): not used in the cells shown.
rc_01, rc_02, rc_12 = (minima[a] - minima[b] for a, b in ((0, 1), (0, 2), (1, 2)))

In [None]:
def split_01(x, _):
    d0 = x - minima[0]
    d0 = d0[:, 0] ** 2 + d0[:, 1] ** 2
    d1 = x - minima[1]
    d1 = d1[:, 0] ** 2 + d1[:, 1] ** 2
    return d0 < d1


def split_02(x, _):
    d0 = x - minima[0]
    d0 = d0[:, 0] ** 2 + d0[:, 1] ** 2
    d2 = x - minima[2]
    d2 = d2[:, 0] ** 2 + d2[:, 1] ** 2
    return d0 < d2


def split_12(x, _):
    d1 = x - minima[1]
    d1 = d1[:, 0] ** 2 + d1[:, 1] ** 2
    d2 = x - minima[2]
    d2 = d2[:, 0] ** 2 + d2[:, 1] ** 2
    return d1 < d2

In [None]:
# Network hyperparameters.
dim = 2  # number of coordinates per configuration (minima are 2-D points)
block_length = 10
nnodes = [20] * 3
# NOTE(review): the three settings below are not used in the cells shown.
nnodes_small = [k * dim for k in (8, 4, 2)]
nintermediates = 0
nnodes_sigma = []

In [None]:
def _make_nmjmc_network(split_cond):
    """Build a NeuralMJMCNetwork with this notebook's shared hyperparameters.

    Only the split condition differs between the three pairwise networks.
    Uses the `dim` constant rather than a literal so the two cannot drift apart.
    """
    # NOTE(review): nnodes and block_length are each passed twice positionally —
    # presumably configuring two sub-networks; confirm against NeuralMJMCNetwork.
    return NeuralMJMCNetwork(
        nnodes,
        nnodes,
        block_length,
        block_length,
        dim=dim,
        system=triple_well,
        split_cond=split_cond,
    )


nn_01 = _make_nmjmc_network(split_01)
nn_02 = _make_nmjmc_network(split_02)
nn_12 = _make_nmjmc_network(split_12)

In [None]:
def _loss_NMJMC(
    y_true, y_pred, energy_function, factor_distance_all, split_output=None
):
    """Squared MJMC training loss, called Keras-style as ``loss(y_true, y_pred)``.

    ``y_pred`` is unpacked into ``(x, y, j_x)`` by ``split_output``. Per sample
    the loss is::

        (E(y) - E(y_true) - sum_k j_x[:, k]
         + factor_distance_all * sum_k (y - y_true)[:, k] ** 2) ** 2

    Parameters
    ----------
    y_true : tf.Tensor
        Target configurations, assumed (batch, coords) — reduced over axis 1.
    y_pred : tf.Tensor
        Packed network output understood by ``split_output``.
    energy_function : callable
        Maps a batch of configurations to per-sample energies as TF ops.
    factor_distance_all : float
        Weight of the squared-distance penalty between ``y`` and ``y_true``.
        Earlier versions of this function applied this weight twice, giving an
        effective weight of ``factor_distance_all ** 2``. To reproduce models
        trained with those versions, pass the squared value.
    split_output : callable, optional
        Unpacks ``y_pred``. Defaults to ``nn_01.split_output``, which was
        previously hard-wired for all three networks. That is presumably fine
        because they share ``dim``, but passing the owning network's method is
        safer.

    Returns
    -------
    tf.Tensor
        Per-sample loss values.

    Raises
    ------
    tf.errors.InvalidArgumentError
        At run time, via ``tf.check_numerics``, if ``energy_function(y)``
        contains NaN or Inf.
    """
    if split_output is None:
        split_output = nn_01.split_output
    _x, y, j_x = split_output(y_pred)

    # NOTE(review): the reference energy is evaluated at y_true (the target),
    # not at the unpacked input `_x`; the name E_x suggests E(x) — confirm.
    E_x = energy_function(y_true)
    energy = tf.check_numerics(energy_function(y), "y") - E_x

    # Squared distance of the proposal from its target, weighted exactly once.
    distance_penalty = factor_distance_all * tf.reduce_sum((y - y_true) ** 2, axis=1)

    log_jacobian = tf.reduce_sum(j_x, axis=1)
    negative_log_acceptance = energy - log_jacobian
    return (negative_log_acceptance + distance_penalty) ** 2

In [None]:
def wrapped_partial(func, *args, **kwargs):
    """Return ``functools.partial(func, *args, **kwargs)`` carrying ``func``'s metadata.

    Plain ``partial`` objects have no ``__name__``. ``update_wrapper`` copies
    ``__name__``, ``__doc__``, ``__module__`` etc. from ``func`` and sets
    ``__wrapped__``. The partial therefore reports itself under the original
    function's name, presumably so the training framework can display or look
    up the loss by name.
    """
    return update_wrapper(partial(func, *args, **kwargs), func)

In [None]:
# Two loss variants that differ only in the distance-penalty weight:
# loss_1 uses 100.0, loss_2 uses 1.0.
loss_1, loss_2 = (
    wrapped_partial(
        _loss_NMJMC,
        energy_function=triple_well.energy_tf,
        factor_distance_all=weight,
    )
    for weight in (100.0, 1.0)
)

In [None]:
# For each pair of minima (i, j), the inputs stack the samples of well i and
# then well j. The labels stack the same two blocks in swapped order, so each
# input row is paired with the same-position row of the other well.
training_01 = data[[0, 1]].reshape(-1, data.shape[-1])
labels_01 = data[[1, 0]].reshape(-1, data.shape[-1])
training_02 = data[[0, 2]].reshape(-1, data.shape[-1])
labels_02 = data[[2, 0]].reshape(-1, data.shape[-1])
training_12 = data[[1, 2]].reshape(-1, data.shape[-1])
labels_12 = data[[2, 1]].reshape(-1, data.shape[-1])

In [None]:
# Phase 1: train each pairwise network with loss_1 (distance weight 100.0)
# at train_pair's default learning rate.
N_EPOCHS_PHASE1 = 100

for network, inputs, targets in (
    (nn_01, training_01, labels_01),
    (nn_02, training_02, labels_02),
    (nn_12, training_12, labels_12),
):
    network.train_pair(inputs, targets, loss_1, nepochs=N_EPOCHS_PHASE1)

In [None]:
# Phase 2: continue training each pairwise network with loss_2
# (distance weight 1.0) at an explicit learning rate.
N_EPOCHS_PHASE2 = 100
LEARNING_RATE_PHASE2 = 0.0001

for network, inputs, targets in (
    (nn_01, training_01, labels_01),
    (nn_02, training_02, labels_02),
    (nn_12, training_12, labels_12),
):
    network.train_pair(
        inputs,
        targets,
        loss_2,
        nepochs=N_EPOCHS_PHASE2,
        learning_rate=LEARNING_RATE_PHASE2,
    )

In [None]:
# Run nn_12 on the last N_EVAL_SAMPLES rows of its training inputs.
# NOTE(review): training_12 is data[1] followed by data[2]. The tail therefore
# comes only from data[2] when that block has at least N_EVAL_SAMPLES rows —
# confirm this is the intended evaluation subset.
N_EVAL_SAMPLES = 10000
out = nn_12.generate_output(training_12[-N_EVAL_SAMPLES:])

In [None]:
# 2-D histogram of the 'x' entry returned by nn_12.generate_output.
plt.hist2d(out['x'][:, 0], out['x'][:, 1], range=[[-3, 3], [-3, 3]], bins=100)
plt.colorbar(label="count")
plt.xlabel("coordinate 0")
plt.ylabel("coordinate 1")
plt.title("nn_12 output 'x'");

In [None]:
# 2-D histogram of the 'y' entry returned by nn_12.generate_output.
plt.hist2d(out['y'][:, 0], out['y'][:, 1], range=[[-3, 3], [-3, 3]], bins=100)
plt.colorbar(label="count")
plt.xlabel("coordinate 0")
plt.ylabel("coordinate 1")
plt.title("nn_12 output 'y'");

In [None]:
# Persist the three pairwise networks under a single, configurable directory.
# The parent directory is created up front so saving does not depend on it
# already existing.
MODEL_DIR = "../local_data/pretrained_models"
os.makedirs(MODEL_DIR, exist_ok=True)

for pair, network in (("01", nn_01), ("02", nn_02), ("12", nn_12)):
    network.save_network(os.path.join(MODEL_DIR, f"rnvp_{pair}_NJMC_full_partition"))