In [1]:
%load_ext autoreload
%autoreload 2

In [26]:
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

sys.path.append("../core")
sys.path.append("../core/simulations")

import simulation_library, potentials, integrators, data_logging, visuals

from data import gen_double_moon_samples
from losses import getLoss
from trainer import Trainer
from optimizers import getOpt
from network_base import RealNVP
from network_logging import *

plt.style.use("fivethirtyeight")

In [37]:
def crescent_moon_example(n_samples=10000, learning_rate=1e-4, epochs=25,
                          model_name='double_moon', seed=None):
    """Train a RealNVP flow on samples from the double-moon toy distribution.

    Parameters
    ----------
    n_samples : int
        Number of points requested from ``gen_double_moon_samples``.
    learning_rate : float
        Learning rate handed to the RMSProp optimizer from ``getOpt``.
    epochs : int
        Value passed to ``Trainer.train`` (presumably the number of epochs —
        TODO confirm against ``trainer.Trainer``).
    model_name : str
        ``model_name`` given to the ``RealNVP`` model.
    seed : int or None
        If not None, seeds NumPy's global RNG and TensorFlow's global seed
        before sampling and model construction. Whether the project code
        actually draws from these global RNGs is not visible here — confirm.

    Returns
    -------
    The trained model, wrapped in ``LogLoss`` and then ``LogTargetPlot``
    (presumably logging the loss / plotting samples during training).
    """
    if seed is not None:
        np.random.seed(seed)
        tf.random.set_seed(seed)

    data = gen_double_moon_samples(n_samples)
    loss = getLoss().basic_loss()
    opt = getOpt().rmsprop(learning_rate)

    # Same wrapping order as before: LogLoss innermost, LogTargetPlot outermost.
    model = LogTargetPlot(LogLoss(RealNVP(loss, opt, model_name=model_name)))

    trainer = Trainer(model, data)
    trainer.train(epochs)
    return model

In [38]:
# Fit a RealNVP flow to double-moon samples using the function's default settings.
crescent_moon_example()

In [53]:
def _sample_double_well(start_positions, n_samples, freq, temp):
    """Run one single-particle Metropolis simulation per start position and stack the logged coordinates.

    Parameters
    ----------
    start_positions : iterable of 2-element sequences
        Initial coordinates of the single particle, one simulation per entry.
    n_samples : int
        Each simulation runs ``n_samples * freq`` steps; with a logger
        interval of ``freq`` this presumably yields about ``n_samples``
        logged frames (TODO confirm the logger semantics).
    freq : int
        Interval passed to ``CoordinateLogger`` / ``EnergyLogger``.
    temp : float
        Temperature passed to the Metropolis integrator.

    Returns
    -------
    np.ndarray of float32
        The per-simulation coordinate arrays concatenated along axis 0.

    Raises
    ------
    ValueError
        If ``start_positions`` is empty (there would be nothing to train on).
    """
    chunks = []
    for start in start_positions:
        system = simulation_library.System(dim=2)
        wca = potentials.WCAPotential(1, 1)
        system.central_potential = potentials.DoubleWellPotential(a=1, b=6, c=1, d=1)

        # Float dtype so the starting point (written as integer literals) cannot
        # force integer truncation if the integrator updates positions in place.
        system.add_particle(simulation_library.Particle(wca, np.array(start, dtype=float)))
        system.get_integrator("metropolis", None, temp=temp)

        coords_logger = data_logging.CoordinateLogger(system, freq)
        # NOTE(review): energies are logged but never read below; kept to match
        # the original simulation setup — drop if the extra work is unwanted.
        energy_logger = data_logging.EnergyLogger(system, freq)
        system.registerObserver(coords_logger)
        system.registerObserver(energy_logger)

        system.run(n_samples * freq)

        # squeeze() presumably removes the single-particle axis, i.e.
        # (frames, 1, 2) -> (frames, 2) — TODO confirm the logger's layout.
        # NOTE(review): [:-freq] drops the last `freq` *logged frames* (not
        # steps), so the amount discarded grows with `freq`; confirm intent.
        frames = np.array(coords_logger.coordinates, dtype=np.float32).squeeze()
        chunks.append(frames[:-freq])

    if not chunks:
        raise ValueError("start_positions must contain at least one position")
    # Concatenate once rather than growing an array inside the loop.
    return np.concatenate(chunks, axis=0)


def double_well_potential(freq=100, n_samples=5000, temp=.5,
                          start_positions=((-2, 0), (-1, 0), (1, 0), (2, 0)),
                          learning_rate=1e-4, epochs=20, model_name="double_well",
                          seed=None):
    """Sample a 2-D double-well potential with Metropolis MC, then train a RealNVP flow on the samples.

    Parameters
    ----------
    freq : int
        Logging interval for the simulation loggers; each run lasts
        ``n_samples * freq`` steps.
    n_samples : int
        Per-start-position multiplier for the run length (see ``freq``).
    temp : float
        Temperature for the Metropolis integrator.
    start_positions : iterable of 2-element sequences
        One independent simulation is run from each starting coordinate.
    learning_rate : float
        Learning rate handed to the RMSProp optimizer from ``getOpt``.
    epochs : int
        Value passed to ``Trainer.train`` (presumably the number of epochs —
        TODO confirm).
    model_name : str
        ``model_name`` given to the ``RealNVP`` model.
    seed : int or None
        If not None, seeds NumPy's global RNG and TensorFlow's global seed
        first. Whether the simulation / model draw from these global RNGs is
        not visible here — confirm.

    Returns
    -------
    The trained model, wrapped in ``LogLoss`` and then ``LogTargetPlot``.
    """
    if seed is not None:
        np.random.seed(seed)
        tf.random.set_seed(seed)

    coords = _sample_double_well(start_positions, n_samples, freq, temp)

    loss = getLoss().basic_loss()
    opt = getOpt().rmsprop(learning_rate)
    # Same wrapping order as before: LogLoss innermost, LogTargetPlot outermost.
    model = LogTargetPlot(LogLoss(RealNVP(loss, opt, model_name=model_name)))

    trainer = Trainer(model, coords)
    trainer.train(epochs)
    return model

In [54]:
# Sample the 2-D double-well system with Metropolis MC and fit a RealNVP flow, using default settings.
double_well_potential()