## Import packages

In [83]:
import random
import numpy as np
import torch 
import plotly.graph_objects as go
import plotly.express as px

In [84]:
# Seed every random-number source used in this notebook (Python's `random`,
# PyTorch, and NumPy's legacy global RNG) so Restart & Run All reproduces the
# same samples. torch.manual_seed returns a Generator object, so it is kept
# off the last line to stop Jupyter from echoing it as cell output.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

## Visualize double well potential
The energy landscape of the double-well model system is given by Eq. 18 of Noé et al.:

$$E(x, y) = \frac{a}{4}x^4 - \frac{b}{2}x^2 + c\,x + \frac{d}{2}y^2$$

In [85]:
def E(x, y, a=1, b=6, c=1, d=1):
    """Energy of the two-dimensional double-well model (eq. 18 of Noé et al.).

    E(x, y) = a*x**4/4 - b*x**2/2 + c*x + d*y**2/2

    The quartic-minus-quadratic terms in ``x`` form the double well, the
    linear ``c*x`` term tilts it, and ``y`` is a plain harmonic direction.
    The defaults (a=1, b=6, c=1, d=1) are the values given on p. 4 of the SI.
    Operates element-wise, so scalars or broadcastable NumPy arrays (e.g. the
    output of ``np.meshgrid``) may be passed.

    Parameters
    ----------
    x, y : float or array_like
        Coordinates at which to evaluate the energy.
    a, b, c, d : float
        Coefficients of the quartic, quadratic, linear and harmonic terms.

    Returns
    -------
    float or ndarray
        The energy, broadcast over ``x`` and ``y``.
    """
    quartic = a * x**4 / 4
    quadratic = b * x**2 / 2
    tilt = c * x
    harmonic_y = d * y**2 / 2
    return quartic - quadratic + tilt + harmonic_y

We can plot this energy landscape as in Fig. 2 of Noé et al.:

In [97]:
# Evaluate the energy on a 100 x 100 grid covering x in [-4, 4], y in [-8, 8].
x_illustrate = np.linspace(-4, 4, 100)
y_illustrate = np.linspace(-8, 8, 100)
# sparse=True yields a row vector (xx) and a column vector (yy); E broadcasts
# them to a (len(y), len(x)) array, i.e. rows follow y and columns follow x,
# which is the layout go.Contour uses for z.
xx, yy = np.meshgrid(x_illustrate, y_illustrate, sparse=True)
E_illustrate = E(xx, yy)

fig = go.Figure(
    data=go.Contour(
        x=x_illustrate,
        y=y_illustrate,
        z=E_illustrate,
        colorscale="viridis",
        reversescale=True,
        # Contour levels from -10 to 10 in steps of 2.
        contours={"start": -10, "end": 10, "size": 2},
    )
)
fig.update_layout(xaxis_title="x", yaxis_title="y")
fig.show()

This looks slightly different from the figure in the paper; in particular, the minima seem to be located at different positions. To check, let's plot the one-dimensional cut $E(x, y=0)$:

In [98]:
# One-dimensional cut through the landscape at y = 0: the harmonic y term
# vanishes, leaving only the double-well profile along x, so the positions of
# the two minima can be read off directly.
fig = go.Figure(
    data=go.Scatter(
        x=x_illustrate,
        y=E(x_illustrate, 0),
    )
)
fig.update_layout(xaxis_title="x", yaxis_title="E")
fig.show()

Let's now generate the sample configurations we will use as our input data set: 500 points drawn from an isotropic Gaussian around each of the two wells.

In [96]:
# Draw 500 configurations per state from an isotropic Gaussian with
# covariance 0.15 * I (standard deviation ~0.39) centred at x = -2.5 (state A)
# and x = +2.5 (state B). Both calls use NumPy's global RNG seeded at the top of
# the notebook; keep them in this order to reproduce the same data set.
# NOTE(review): for the default parameters dE/dx = x**3 - 6x + 1 vanishes near
# x ~ -2.53 and x ~ 2.36, so the state-B centre (2.5) lies slightly to the right
# of its minimum -- confirm this offset is intended.
x_state_a, y_state_a = np.random.multivariate_normal([-2.5, 0], 0.15 * np.eye(2), 500).T
x_state_b, y_state_b = np.random.multivariate_normal([2.5, 0], 0.15 * np.eye(2), 500).T

fig = go.Figure()

# Background: energy contours on the grid computed for the landscape plot.
fig.add_trace(
    go.Contour(
        z=E_illustrate,
        x=x_illustrate,
        y=y_illustrate,
        reversescale=True,
        colorscale="viridis",
        contours=dict(start=-10, end=10, size=2),
    )
)

# Overlay the samples; each trace is named so the legend identifies its state.
fig.add_trace(
    go.Scatter(
        x=x_state_a,
        y=y_state_a,
        mode="markers",
        name="state A",
        marker_color="blue",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_state_b,
        y=y_state_b,
        mode="markers",
        name="state B",
        marker_color="white",
    )
)

# Restrict the view to the grid on which E_illustrate was evaluated, so the
# axis limits cannot drift out of sync with the contour data.
fig.update_xaxes(range=[float(x_illustrate[0]), float(x_illustrate[-1])])
fig.update_yaxes(range=[float(y_illustrate[0]), float(y_illustrate[-1])])

fig.update_layout(
    xaxis_title="x",
    yaxis_title="y",
    # Horizontal legend above the plot area so it does not sit on top of the
    # contour colour bar at the right-hand edge.
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
)

fig.show()

## Define network architecture
As specified on p. 4 of the SI:

In [89]:
# Architecture hyperparameters as listed on p. 4 of the SI.
# NOTE(review): n_hidden / l_hidden presumably set the width and the number of
# hidden layers of the networks -- confirm where they are consumed.
n_hidden, l_hidden = 100, 3
temperature = 1.0

The SI also specifies the training schedule. In the following, "1" denotes the first set of iterations (where only the ML loss is used), while "2" denotes the second set of iterations (where both the ML and KL losses are used):

In [None]:
# Two-stage training schedule from the SI.
# Stage 1: ML loss only (KL weight 0), small batches, larger learning rate.
iter1, batch1, lr1, w_kl_1 = 200, 128, 0.01, 0
# Stage 2: ML + KL loss (KL weight 1), large batches, smaller learning rate.
iter2, batch2, lr2, w_kl_2 = 500, 1000, 0.001, 1