In [None]:
# Automatically re-import edited modules (e.g. the local `nmjmc` package)
# before each cell executes, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from nmjmc.systems import GaussianDoublewell, GaussianTripleWell
from nmjmc.sampling import MCSampler, VoronoiMixture
from nmjmc.nn import NeuralMJMCNetwork
import tensorflow as tf
from functools import partial, update_wrapper
from matplotlib.colors import LogNorm
from tqdm import tqdm_notebook as tqdm
import scipy.integrate as si
import matplotlib as mpl

In [None]:
# Reference minima of the triple well, one (x, y) row each. They serve as the
# Voronoi centers for the split_* functions and as starting points below.
minima = np.array([(-2.2, -1.0), (0.0, 2.0), (2.0, -0.8)])
# NOTE(review): `factors` is not used in the visible cells — confirm it is needed.
factors = np.array([10.0] * 3 + [0.2])

In [None]:
# The potential being sampled; its `energy` is handed to VoronoiMixture below.
triple_well = GaussianTripleWell()
# Reference MC sampler on the same system, presumably initialised at minima[0].
# NOTE(review): the meaning of the positional `2` is not visible here (possibly
# the dimension or a step size) — confirm against MCSampler. `sampler` is not
# used in the visible cells.
sampler = MCSampler(triple_well, minima[0], 2)

In [None]:
# Session config that exposes zero GPU devices, i.e. CPU-only execution.
# NOTE(review): tf.ConfigProto / tf.Session are TF1 APIs (tf.compat.v1.* under
# TF2). `sess` is not referenced in the visible cells, nor set as the default /
# Keras session here, so it is unclear whether it affects the networks — confirm.
config = tf.ConfigProto(device_count={"GPU": 0})
sess = tf.Session(config=config)

In [None]:
def _sq_dist_to_minimum(x, k):
    """Squared Euclidean distance from each row of `x` to ``minima[k]``.

    Only the first two coordinates are used. The notebook-level ``minima`` is
    read at call time, not captured at definition time.
    """
    delta = x - minima[k]
    return delta[:, 0] ** 2 + delta[:, 1] ** 2


def split_01(x, _):
    """Voronoi split between minima 0 and 1.

    Returns a boolean mask that is True where a row of `x` is strictly closer
    to ``minima[0]`` than to ``minima[1]``; ties go to minimum 1. The second
    argument is ignored; it is presumably required by the split_cond
    interface of NeuralMJMCNetwork.
    """
    return _sq_dist_to_minimum(x, 0) < _sq_dist_to_minimum(x, 1)


def split_02(x, _):
    """Voronoi split between minima 0 and 2.

    Returns True where a row of `x` is strictly closer to ``minima[0]`` than to
    ``minima[2]``. The second argument is ignored.
    """
    return _sq_dist_to_minimum(x, 0) < _sq_dist_to_minimum(x, 2)


def split_12(x, _):
    """Voronoi split between minima 1 and 2.

    Returns True where a row of `x` is strictly closer to ``minima[1]`` than to
    ``minima[2]``. The second argument is ignored.
    """
    return _sq_dist_to_minimum(x, 1) < _sq_dist_to_minimum(x, 2)

In [None]:
# Architecture hyperparameters for the neural MJMC kernels.
dim = 2  # configuration-space dimension of the triple-well system
nnodes = [20] * 3  # presumably hidden-layer widths — confirm in NeuralMJMCNetwork
nnodes_small = [factor * dim for factor in (8, 4, 2)]
nintermediates = 0
block_length = 10
nnodes_sigma = []
# NOTE(review): nnodes_small, nintermediates and nnodes_sigma are not used in
# the visible cells.

In [None]:
def _build_network(split_cond):
    """Build one neural MJMC kernel for the triple-well system.

    All three kernels share the same architecture: ``nnodes`` and
    ``block_length`` are passed for both positional slots, and the
    notebook-level ``dim`` is used rather than a hard-coded 2, so the
    hyperparameter cell stays the single source of truth. Only the Voronoi
    split condition differs between kernels.

    Parameters
    ----------
    split_cond : callable
        ``split_cond(x, _)`` returning a boolean mask, e.g. ``split_01``.

    Returns
    -------
    NeuralMJMCNetwork
    """
    return NeuralMJMCNetwork(
        nnodes,
        nnodes,
        block_length,
        block_length,
        dim=dim,
        system=triple_well,
        split_cond=split_cond,
    )


# One kernel per pair of minima, built in the same order as before.
nn_01 = _build_network(split_01)
nn_02 = _build_network(split_02)
nn_12 = _build_network(split_12)

In [None]:
# Restore the pretrained weights, one file per pair of minima.
for network, weights_path in (
    (nn_01, "../local_data/pretrained_models/rnvp_01_NJMC_full_partition_weights.h5"),
    (nn_02, "../local_data/pretrained_models/rnvp_02_NJMC_full_partition_weights.h5"),
    (nn_12, "../local_data/pretrained_models/rnvp_12_NJMC_full_partition_weights.h5"),
):
    network.load_weights(weights_path)

In [None]:
# Re-asserts the same minima as the constants cell, so this section also runs
# correctly if that cell was skipped or `minima` was modified interactively.
minima = np.array([(-2.2, -1.0), (0.0, 2.0), (2.0, -0.8)])

In [None]:
# Kernel-selection matrix. Columns 0-2 presumably follow the `networks` order
# (nn_01, nn_02, nn_12) and column 3 is the remaining (local) move. Rows 0-2
# presumably correspond to the cell of each minimum; the role of row 3 is not
# visible here. TODO confirm against VoronoiMixture.
p_global = 0.1
selection_probabilities = np.zeros((4, 4))
# Each minimum gets p_global for the two pair-kernels that involve it.
global_rows = [0, 0, 1, 1, 2, 2]
global_cols = [0, 1, 0, 2, 1, 2]
selection_probabilities[global_rows, global_cols] = p_global
# The leftover mass goes to column 3, so every row sums to 1.
# Row 3 has no global entries and therefore gets 1.0.
selection_probabilities[:, 3] = 1.0 - selection_probabilities[:, :3].sum(axis=1)
# Guard the invariants: e.g. p_global > 0.5 would make column 3 negative.
assert (selection_probabilities >= 0).all(), "p_global too large: negative local probability"
assert np.allclose(selection_probabilities.sum(axis=1), 1.0), "rows must sum to 1"

In [None]:
# Global kernels; this order presumably matches columns 0-2 of selection_probabilities.
networks = [
    nn_01,
    nn_02,
    nn_12,
]

In [None]:
# 100 identical starting configurations, all at minima[2]; shape (100, 2).
x0 = np.tile(minima[2], (100, 1))

In [None]:
# Pair of minima indices joined by each kernel in `networks`.
# NOTE(review): split_01/split_02 are True on minimum 0's side and are listed
# as [0, 1]/[0, 2], but split_12 is True on minimum 1's side and is listed as
# [2, 1]. If VoronoiMixture reads the first index as the split_cond=True side,
# [1, 2] would be expected. Left unchanged because the pretrained nn_12 weights
# may depend on this orientation — confirm.
kernel_connectivity = [
    [0, 1],
    [0, 2],
    [2, 1],
]

In [None]:
# Mixture of the neural pair-kernels with Voronoi-cell-dependent kernel selection.
mm = VoronoiMixture(
    minima,
    networks,
    selection_probabilities,
    triple_well.energy,
    kernel_connectivity,
    dim=dim,
)
# Run from the 100 starting configurations in `x0`. The positional 100 is
# presumably the number of steps — TODO confirm against VoronoiMixture.run.
# Unpack the returned pair *before* converting. Wrapping the whole pair in
# np.array tries to stack two differently-shaped results: that raises
# ValueError in NumPy >= 1.24 and silently yields an object array before that.
samples, global_pacc = mm.run(x0, 100, reporter="notebook")
samples = np.asarray(samples)
global_pacc = np.asarray(global_pacc)

In [None]:
# NOTE(review): np.save appends ".npy" to a file name that lacks it, so this
# writes "../local_data/samples_triple_well.np.npy". Any loader must use that
# name; if you switch to a plain ".npy" suffix, update the loaders at the same time.
np.save(file="../local_data/samples_triple_well.np", arr=samples)