In this notebook, we'll show you how to optimize particle shapes in order to design bulk materials properties. In particular, we will design patchy particles that stabilize an octahedral cluster.

We define patchy particles as isotropic, mutually repulsive particles with attractive patches rigidly attached to their surfaces. We will optimize over the positions of those patches.



# Imports

The first thing we need to do is import the relevant packages.


In [None]:
!pip install jax-md
!pip install --upgrade "jax[cuda]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
import jax.numpy as jnp
from jax.config import config
config.update('jax_enable_x64', True)
from jax import jit, random, grad, value_and_grad, remat, jacfwd, vmap, lax
from jax.example_libraries import optimizers
from jax_md import space, smap, energy, minimize, quantity, simulate, partition, rigid_body, util

import numpy as np


# Define the octahedron

The goal of this section is to design particles that can stabilize an octahedron. We're going to initialize particles into an octahedral configuration. The initial parameter values for those particles will cause the cluster to melt. We will optimize to find parameters that can stabilize the cluster, so it doesn't melt!


The first step is to define the octahedron. We'll need to specify particle coordinates and orientations. We'll need to consider both the central particles and the patches on those central particles.


In [None]:
'''
We first define the features of an octahedral cluster without
considering the patches on the patchy particles.
'''

#Radius of a particle without the patches
RADIUS = 2.5

# Coordinates of an octahedron
OCT_COORDS = 2.0/jnp.sqrt(2.0) * RADIUS \
                * jnp.array([[1.0, 0.0, 0.0],
                            [-1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0],
                            [0.0, -1.0, 0.0],
                            [0.0, 0.0, 1.0],
                            [0.0, 0.0, -1.0]])
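
As a quick sanity check on the scaling factor, the sketch below rebuilds the same six vertices in plain NumPy and measures the distances between them. It assumes nothing beyond the coordinates above; the variable names are illustrative. With this scaling, each particle sits exactly one diameter (`2 * RADIUS`) from its four adjacent neighbors, so neighboring spheres just touch.

```python
import numpy as np

RADIUS = 2.5  # same value as in the notebook

# Six vertices of an octahedron, scaled as in OCT_COORDS
vertices = 2.0 / np.sqrt(2.0) * RADIUS * np.array([
    [1.0, 0.0, 0.0], [-1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0], [0.0, -1.0, 0.0],
    [0.0, 0.0, 1.0], [0.0, 0.0, -1.0],
])

# All pairwise distances between vertices
diffs = vertices[:, None, :] - vertices[None, :, :]
pair_dists = np.linalg.norm(diffs, axis=-1)

# Sort each row: entry 0 is the particle itself, entries 1-4 are the
# adjacent vertices, entry 5 is the opposite vertex
sorted_dists = np.sort(pair_dists, axis=1)
adjacent = sorted_dists[:, 1:5]
opposite = sorted_dists[:, 5]
print(adjacent[0], opposite[0])
```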

Before we get any further, let's visualize the octahedron itself:

In [None]:
from jax_md.colab_tools import renderer


diameter = jnp.ones((len(OCT_COORDS))) * 2 * RADIUS
renderer.render(20.,
                {'particle':renderer.Sphere(OCT_COORDS + 10., diameter=diameter)},
                resolution=(512, 512)
                )

We can see that the central particles are in an ideal octahedral cluster. Let's now add in the patch information.

In [None]:
'''
Next, let's add the patches.
'''
# Number of patches. We will have two types of patches
# that form two interacting rings on the central particle
NUM_PATCHES = 10 #number of patches per ring
TYPE_PATCH = 2 #number of patch species

'''
This function takes in two angles that represent the locations of the two
interacting rings of patches. It returns a rigid-body shape object (built by
`rigid_body.point_union_shape') that stores the locations of each of the
patches, their species, and their masses.

In essence, this function lets us go from a simple, low-dimensional
representation of the patchy particle to the full complexity needed
for the simulation.
'''

@jit  # JIT-compile this function so that repeated calls run faster
def thetas_to_shape(thetas, radius=0.5):
    # space the patches evenly along each ring
    phis = jnp.linspace(0, jnp.pi*2, NUM_PATCHES+1)[:-1]
    theta_B = thetas[0]
    theta_C = thetas[1]

    # Convert angles to xyz coordinates, and store coordinates in `patch_positions'
    patch_positions = jnp.zeros((NUM_PATCHES*2, 3), dtype = jnp.float64)
    patch_positions = patch_positions.at[:NUM_PATCHES,0].set(radius*jnp.cos(phis)*jnp.sin(theta_B))
    patch_positions = patch_positions.at[:NUM_PATCHES,1].set(radius*jnp.sin(phis)*jnp.sin(theta_B))
    patch_positions = patch_positions.at[:NUM_PATCHES,2].set(radius*jnp.cos(theta_B))
    patch_positions = patch_positions.at[NUM_PATCHES:,0].set(radius*jnp.cos(phis)*jnp.sin(theta_C))
    patch_positions = patch_positions.at[NUM_PATCHES:,1].set(radius*jnp.sin(phis)*jnp.sin(theta_C))
    patch_positions = patch_positions.at[NUM_PATCHES:,2].set(radius*jnp.cos(theta_C))
    positions = jnp.concatenate((jnp.array([[0.0, 0.0, 0.0]]), patch_positions), axis=0)

    # Define two species: one for each patch type. Each species has its own ring.
    species = jnp.zeros(NUM_PATCHES*2 + 1, dtype = jnp.int32)
    species = species.at[1:NUM_PATCHES+1].set(1)
    species = species.at[NUM_PATCHES+1:].set(2)
    species = jnp.array(species, dtype = jnp.int32)

    # We give the patches very slightly different masses to break the symmetry
    # of the initial patch positions. Highly symmetric structures can lead to
    # poorly defined gradients, because the initial configuration
    # is at a saddle point.
    patch_mass = jnp.linspace(0.1, 1.0, NUM_PATCHES*2)*1e-8
    mass = jnp.concatenate((jnp.array([1.0]), patch_mass), axis = 0)

    shape = rigid_body.point_union_shape(positions, mass).set(point_species=species)
    return shape
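
To make the angle-to-coordinate conversion concrete, here is a minimal NumPy sketch of the ring construction on its own. The helper `ring_positions` is hypothetical (it is not part of JAX-MD or of this notebook); it only mirrors the spherical-coordinate formulas used in `thetas_to_shape`, where the polar angle fixes the height of the ring and the azimuthal angles space the patches evenly around it.

```python
import numpy as np

def ring_positions(theta, num_patches, radius):
    """Hypothetical helper: evenly spaced points on a ring at polar angle theta."""
    # Drop the last point of the linspace so 0 and 2*pi are not both included
    phis = np.linspace(0.0, 2.0 * np.pi, num_patches + 1)[:-1]
    x = radius * np.cos(phis) * np.sin(theta)
    y = radius * np.sin(phis) * np.sin(theta)
    z = np.full_like(phis, radius * np.cos(theta))
    return np.stack([x, y, z], axis=1)

# The two ring angles used later in the notebook, in radians
ring_B = ring_positions(np.deg2rad(42.0), num_patches=10, radius=2.5)
ring_C = ring_positions(np.deg2rad(53.7), num_patches=10, radius=2.5)

# Every patch lies on the surface of the central sphere
print(np.linalg.norm(ring_B, axis=1))
```

Because all patches in a ring share one polar angle, each ring is described by a single number, which is why the optimization only needs two parameters.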


Before we simulate anything, let's visualize the octahedron with its patches:

In [None]:
# Specify angles (in radians) of the two rings of patches on the central particle
thetas = jnp.array([42.0*jnp.pi/180, 53.7*jnp.pi/180])
oct_shape = thetas_to_shape(thetas, radius=RADIUS)

# Now that we have patches, we have to specify the orientations of the octahedron
# particles. The quaternion below orients the particles to be in a perfect
# octahedral configuration
OCT_ORIENTATION = rigid_body.Quaternion(jnp.array([[ 6.12323400e-17, -1.02905911e-01,  6.03345922e-01,  7.90812286e-01],
                  [6.12323400e-17, 9.94691095e-01, 6.24192394e-02, 8.18135993e-02],
                  [ 6.12323400e-17,  7.39570669e-01, -6.61749302e-01,  1.22975964e-01],
                  [6.12323400e-17, 6.52786354e-01, 7.49725191e-01, 1.08545451e-01],
                  [ 6.12323400e-17,  7.56419645e-01,  9.59613460e-02, -6.47009074e-01],
                  [6.12323400e-17, 6.41864569e-01, 8.14285939e-02, 7.62482301e-01]]))

# A `RigidBody' object consists of a center of mass position and an orientation.
# We can store our octahedron information in a `RigidBody' object, and transform
# from the `RigidBody' to a set of positions of all the component particles in order
# to visualize it.
octahedron_rigid_body = rigid_body.RigidBody(center=OCT_COORDS, orientation=OCT_ORIENTATION)
octahedron_positions = vmap(rigid_body.transform, (0, None))(octahedron_rigid_body, oct_shape).reshape(-1, 3)


**Note about how orientation is defined in JAX-MD.**

*What you think is the default orientation might not be the one in the simulation*

JAX-MD allows a quite general definition of a RigidBody, so the shape you define may have a non-diagonal moment of inertia tensor. By default, JAX-MD uses a helper function to check for this and diagonalize the moment of inertia tensor if needed. As a result, the RigidBody you originally defined may be rotated in its body frame.

If you want to randomize your initial particle orientation, this will not cause any problem.  If you want to specify your initial orientation and the end result is not what you think, you might want to check the diagonalization helper function.
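
The following NumPy sketch illustrates why diagonalizing the inertia tensor can rotate a shape. It is not JAX-MD's helper function, only a generic illustration under the usual point-mass definition of the inertia tensor; the points, masses, and function name are made up for the example.

```python
import numpy as np

def inertia_tensor(points, masses):
    """Inertia tensor of point masses about their center of mass."""
    r = points - np.average(points, axis=0, weights=masses)
    r2 = np.sum(r * r, axis=1)
    outer = r[:, :, None] * r[:, None, :]
    return np.sum(masses[:, None, None] * (r2[:, None, None] * np.eye(3) - outer), axis=0)

# An arbitrary, asymmetric set of point masses (illustrative values)
points = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 3.0]])
masses = np.array([1.0, 2.0, 1.0, 0.5])

I_lab = inertia_tensor(points, masses)

# The eigenvectors of the inertia tensor are the principal axes
evals, evecs = np.linalg.eigh(I_lab)
if np.linalg.det(evecs) < 0:
    evecs[:, 0] *= -1.0  # make it a proper rotation (determinant +1)

# Express the points in the principal-axis frame: the shape is the same,
# but its coordinates are rotated relative to the original definition
com = np.average(points, axis=0, weights=masses)
rotated = (points - com) @ evecs
I_body = inertia_tensor(rotated, masses)
print(np.round(I_body, 6))
```

In the rotated frame the inertia tensor is diagonal, but the point coordinates no longer match the ones originally written down, which is the effect described in the note above.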

In [None]:
species = jnp.array(list(oct_shape.point_species) * len(OCT_COORDS)).flatten()
diameters = jnp.where(species==0, 1., 0.2) * RADIUS * 2

renderer.render(20.,
                {'particle':renderer.Sphere(octahedron_positions, diameter=diameters)},
                resolution=(512, 512)
                )

# Simulate the octahedron

Now that we have an octahedron, we can run a simulation and see whether that octahedron is stable.

We'll now need to specify all the parameters you typically need for a molecular dynamics simulation: the temperature, the step size, the number of simulation steps, the integrator, and so on.

In [None]:
# Parameters of the simulation

# Sets the box size. Despite its name, NUMBER_DENSITY is used by get_box_size
# as the fraction of the box volume occupied by the particles.
NUMBER_DENSITY = 0.05
get_box_size = lambda phi, N, rad: ( N * jnp.pi * 4 * rad**3 / phi / 3.0) ** (1/3)
BOX_SIZE = get_box_size(NUMBER_DENSITY, len(OCT_COORDS), RADIUS)
displacement, shift = space.periodic(BOX_SIZE) # define a periodic box

kT = 0.8 #Temperature
dt = 0.0001 #timestep of the simulation
NUM_STEPS = 200000
GAMMA = 5.0 # friction coefficient for the integrator
SAVE_EVERY = 100 #view every SAVE_EVERY frame in the trajectory when visualizing

# Parameters of the energy function
D0 = 4.0 #well depth of the interaction energy

# Set a random key
key = random.PRNGKey(0)
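
The box-size arithmetic can be checked by hand. The sketch below repeats the formula from `get_box_size` in plain Python (the function name here is illustrative) and confirms that the six spheres occupy the requested fraction of the resulting cubic box.

```python
import math

def box_size_from_fraction(phi, n_particles, radius):
    """Side length of a cubic box in which the spheres fill a fraction phi."""
    particle_volume = n_particles * 4.0 / 3.0 * math.pi * radius**3
    return (particle_volume / phi) ** (1.0 / 3.0)

# Values used in the notebook: 6 particles of radius 2.5, fraction 0.05
box = box_size_from_fraction(0.05, 6, 2.5)

# Recompute the occupied fraction from the box side as a consistency check
occupied_fraction = 6 * 4.0 / 3.0 * math.pi * 2.5**3 / box**3
print(f"box side: {box:.2f}, occupied fraction: {occupied_fraction:.3f}")
```

The resulting side length is a little under 20, which is consistent with the box size of `20.` passed to the renderer in the visualization cells.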

With these parameters specified, we can run a simulation starting from the initial octahedron we defined above.

The first thing we need is an energy function. We consider an energy with two components: an attractive Morse potential that defines how the patches interact with other patches, and a WCA potential that causes central particles to behave similarly to hard spheres.

In [None]:
# Define the energy function

# The two types of patches interact with the same interaction strength, D0.
# The patches only interact with other patches of the same type.
morse_interaction_matrix = jnp.array([[D0, 0.0],
                                      [0.0, D0]])
morse_eps = jnp.pad(morse_interaction_matrix, pad_width=(1, 0)) #center particles don't interact via morse potential

# Interaction matrix for the LJ(WCA) interaction is only nonzero for the
# central particle-central particle term.
lj_eps = jnp.zeros((len(thetas) + 1, len(thetas) + 1))
lj_eps = lj_eps.at[0, 0].set(1.0)

pair_energy_lj = energy.lennard_jones_pair(displacement, species=1+len(thetas), sigma=RADIUS*2.0, epsilon=lj_eps, r_cutoff =RADIUS*2.0*2.0**(1/6.0))
pair_energy_morse = energy.morse_pair(displacement, species=1+len(thetas), sigma=0.0, epsilon=morse_eps, alpha=5.0, r_cutoff=1.2)
pair_energy_fn = lambda R, **kwargs: pair_energy_lj(R, **kwargs) + pair_energy_morse(R, **kwargs)

# Convert the energy function to a form that acts on `RigidBody' objects
energy_fn = rigid_body.point_energy(pair_energy_fn, oct_shape)

# Confirm that we can compute the energy of our initial state
eng = energy_fn(octahedron_rigid_body)
print('Energy of the initial state: {}'.format(eng))
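
To see what the two pair potentials look like with these parameters, here is a small NumPy sketch. The `morse` and `lennard_jones` functions below are the textbook forms, written out for illustration; they are assumed, not verified, to match what `energy.morse_pair` and `energy.lennard_jones_pair` evaluate in the installed JAX-MD version. The sketch also shows how padding the 2x2 well-depth matrix produces a 3x3 species matrix whose first row and column (the central particle) are zero.

```python
import numpy as np

def morse(r, epsilon, alpha, sigma=0.0):
    """Textbook Morse potential with well depth epsilon and minimum at r = sigma."""
    return epsilon * (1.0 - np.exp(-alpha * (r - sigma))) ** 2 - epsilon

def lennard_jones(r, epsilon, sigma):
    """Textbook 12-6 Lennard-Jones potential."""
    s6 = (sigma / r) ** 6
    return 4.0 * epsilon * (s6 ** 2 - s6)

D0, ALPHA, RADIUS = 4.0, 5.0, 2.5   # values from the notebook
sigma_lj = 2.0 * RADIUS
r_cut = sigma_lj * 2.0 ** (1.0 / 6.0)

# Morse: deepest at zero separation, nearly zero at the cutoff of 1.2
print(morse(0.0, D0, ALPHA), morse(1.2, D0, ALPHA))

# Lennard-Jones: the cutoff sits at the potential minimum, so only the
# repulsive branch is kept (the idea behind WCA)
print(lennard_jones(r_cut, 1.0, sigma_lj))

# Species-indexed well depths: index 0 is the central particle
eps = np.pad(np.array([[D0, 0.0], [0.0, D0]]), pad_width=(1, 0))
print(eps)
```

With `sigma=0.0`, two patches of the same species are most strongly bound when they overlap completely, so the Morse term effectively rewards aligning matching rings on neighboring particles.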


Now that we have a working energy function, we can run a short simulation and visualize it.

In [None]:
# The mass and friction coefficient need to be in the form of a `RigidBody' object.
# Here the rotational friction coefficient is set to one third of the
# translational coefficient.
gamma = rigid_body.RigidBody(jnp.array([GAMMA]), jnp.array([GAMMA/3.0]))

# Here, we simulate using Langevin dynamics
init_fn, step_fn = simulate.nvt_langevin(energy_fn, shift, dt, kT, gamma = gamma)
step_fn = jit(step_fn)
state = init_fn(key, octahedron_rigid_body, mass=oct_shape.mass())

do_step = lambda state, t: (step_fn(state), state.position)
do_step = jit(do_step)
final_state, trajectory = lax.scan(do_step, state, jnp.arange(NUM_STEPS))
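
The `lax.scan` pattern used here returns two things: the final carry (the last state) and a stacked array of whatever each step emits. The toy sketch below shows the same pattern with a one-line update rule standing in for the MD integrator; `step_fn` and the constants are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

NUM_STEPS = 1000
SAVE_EVERY = 100
DT = 0.01

def step_fn(x):
    # Toy update standing in for one integrator step
    return x - DT * x

def do_step(x, _):
    # Carry the updated state forward; record the state *before* the update
    return step_fn(x), x

x0 = jnp.asarray(1.0, dtype=jnp.float32)
final_x, trajectory = lax.scan(do_step, x0, jnp.arange(NUM_STEPS))

# Keep every SAVE_EVERY-th recorded frame, as in the visualization cell
frames = trajectory[::SAVE_EVERY]
print(trajectory.shape, frames.shape)
```

Because each step records the state before updating it, the first frame is the initial configuration and the final state is returned separately as the carry.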

In [None]:
species = jnp.array(list(oct_shape.point_species) * len(OCT_COORDS)).flatten()
diameters = jnp.where(species==0, 1., 0.2) * RADIUS * 2

#transform the trajectory from RigidBody objects to just xyz positions
trajectory_positions = (vmap(vmap(rigid_body.transform, (0, None)), (0, None))(trajectory[::SAVE_EVERY], oct_shape)).reshape(-1, len(OCT_COORDS)*(1 + 2*NUM_PATCHES), 3)

renderer.render(20.,
                {'particle':renderer.Sphere(trajectory_positions, diameter=diameters)},
                resolution=(512, 512)
                )

To make optimization simpler, we're going to combine the steps we needed for the simulation into a single function. At the same time, we'll add gradient rematerialization to make the optimization more efficient.

In [None]:
# Define how often to do gradient rematerialization over the course of the simulation
INNER_STEPS = 100

'''
function to run a simulation of an octahedron composed of patchy particles.

Arguments:
  - thetas: polar angles (in radians) of the two rings of patches
  - initial_position: configuration of the patchy particles at the start of the simulation
  - num_steps: number of simulation steps
  - key: random key
  - kT: temperature

Returns:
  - state: the final state of the simulation
'''
def run_sim(thetas, initial_position, num_steps, key, kT = 1.0):
    eng_mat = jnp.array([[D0, 0.0],
                          [0.0, D0]])
    shape = thetas_to_shape(thetas, radius=RADIUS)
    displacement, shift = space.periodic(BOX_SIZE)

    morse_eps = jnp.pad(eng_mat, pad_width=(1, 0)) #center particles don't interact via morse potential
    lj_eps = jnp.zeros((len(thetas) + 1, len(thetas) + 1))
    lj_eps = lj_eps.at[0, 0].set(1.0)

    pair_energy_lj = energy.lennard_jones_pair(displacement, species=1+len(thetas), sigma=RADIUS*2.0, epsilon=lj_eps, r_cutoff =RADIUS*2.0*2.0**(1/6.0))
    pair_energy_morse = energy.morse_pair(displacement, species=1+len(thetas), sigma=0.0, epsilon=morse_eps, alpha=5.0, r_cutoff=1.2)
    pair_energy_fn = lambda R, **kwargs: pair_energy_lj(R, **kwargs) + pair_energy_morse(R, **kwargs)
    energy_fn = rigid_body.point_energy(pair_energy_fn, shape)

    gamma = rigid_body.RigidBody(jnp.array([GAMMA]), jnp.array([GAMMA/3.0]))
    init_fn, step_fn = simulate.nvt_langevin(energy_fn, shift, dt, kT, gamma = gamma)
    step_fn = jit(step_fn)
    state = init_fn(key, initial_position, mass=shape.mass())

    do_step = lambda state, t: (step_fn(state), 0.)
    do_step = jit(do_step)

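    # Split the run into an outer loop of rematerialized blocks, each containing
    # INNER_STEPS inner steps. During reverse-mode AD only the state at the start
    # of each block is stored; the inner steps are recomputed when needed.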
    inner_steps = jnp.arange(INNER_STEPS)
    outer_steps = jnp.arange(int(num_steps/INNER_STEPS))

    def do_outer_step(state, i):
        state, _ = lax.scan(do_step, state, inner_steps)
        return state, 0.
    do_outer_step = jit(remat(do_outer_step))

    state, losses = lax.scan(do_outer_step, state, outer_steps)
    return state

# vmap over the random keys in order to easily run
# ensembles of simulations
run_sim = jit(run_sim, static_argnums=2)
v_run_sim = jit(vmap(run_sim, in_axes=(None, None, None, 0)), static_argnums=2)

# Similar to v_run_sim but allows for an ensemble of starting configurations
many_states_run_sim = jit(vmap(run_sim, in_axes=(None, 0, None, 0)), static_argnums=2)
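
The nested-scan rematerialization pattern can be tried on a toy problem. The sketch below uses `jax.checkpoint` (the function that `remat` refers to) around the outer step of a simple decay equation, and compares the gradient with and without checkpointing. The dynamics and constants are illustrative assumptions, not part of the notebook.

```python
import jax
import jax.numpy as jnp
from jax import lax

INNER_STEPS = 10   # steps recomputed inside each checkpointed block
OUTER_STEPS = 5    # number of blocks whose starting state is stored
DT = 0.01

def final_state(k, use_checkpoint):
    def inner(x, _):
        # One explicit Euler step of dx/dt = -k * x
        return x - DT * k * x, None

    def outer(x, _):
        x, _ = lax.scan(inner, x, None, length=INNER_STEPS)
        return x, None

    if use_checkpoint:
        # Store only the block inputs; recompute inner steps in the backward pass
        outer = jax.checkpoint(outer)

    x0 = jnp.asarray(1.0, dtype=jnp.float32)
    x, _ = lax.scan(outer, x0, None, length=OUTER_STEPS)
    return x

grad_plain = jax.grad(lambda k: final_state(k, False))(2.0)
grad_remat = jax.grad(lambda k: final_state(k, True))(2.0)
print(grad_plain, grad_remat)
```

Checkpointing changes how much is stored during differentiation, not the result: both gradients agree. The trade-off is extra computation in the backward pass in exchange for lower memory use, which matters when differentiating through many simulation steps.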

In [None]:
## Now, instead of doing the series of steps we did earlier, we can just call this function
# to run a simulation:
final_state = run_sim(thetas, octahedron_rigid_body, NUM_STEPS, key, kT = kT)

# Define a loss

We can now start from a set of parameters and run a simulation of our octahedron. In order to optimize those parameters, we need a metric that defines how stable our octahedron is.

We choose a loss function that compares the sorted neighbor distances of each particle in the final state of the system with those of an ideal reference octahedron. The loss sums the absolute differences between the two sets of distances (the code computes this as the square root of the squared difference).

The first component we need for the loss is the list of reference nearest neighbor distances.

In [None]:
'''
function to compute nearest neighbor distances for a structure.
Arguments:
  - Coordinates that define a reference structure
Returns:
  - List of sorted nearest neighbor distances
'''
def get_desired_dists(ref_shape):
    displacement, shift = space.periodic(BOX_SIZE)
    vdisp = space.map_product(displacement)
    ds = vdisp(ref_shape, ref_shape)
    dists = jnp.sort(space.distance(ds))
    return dists

# OCT_COORDS contains the position coordinates of an ideal octahedron.
# We use these to compute our reference nearest neighbor distances
REF_DIST = get_desired_dists(OCT_COORDS)

In [None]:
'''
function to compute the loss (i.e. summed absolute difference from the
reference distances) for a given particle.

Arguments:
  - center_particle: position of the particle of interest
  - R: coordinates of the particles in the system
Returns:
  - sum of absolute differences between the sorted neighbor distances of
    this particle and those of the reference structure
'''
@jit
def loss_per_particle(center_particle, R):
    displacement, shift = space.periodic(BOX_SIZE)
    vdisp = vmap(displacement, (None, 0))
    ds = vdisp(center_particle, R)
    dists = jnp.sort(space.distance(ds))
    nearest_nbrs_square = jnp.sum(jnp.sqrt((dists[:len(OCT_COORDS)] - REF_DIST)**2))
    return nearest_nbrs_square

# We vmap this loss over all the particles in the system,
# vmap again over the ensemble of simulations, and finally we
# average the loss over the ensemble of simulations
sys_loss = vmap(loss_per_particle, (0, None))
v_loss = vmap(sys_loss, (0, 0))
avg_loss = lambda R_batched: jnp.mean(v_loss(R_batched, R_batched))
avg_loss = jit(avg_loss) #jit to compile the loss function
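
The behavior of this loss can be explored without running a simulation. The NumPy sketch below is a simplified stand-in: it uses a single reference row of sorted distances rather than the full `REF_DIST` array, an explicit minimum-image displacement in place of `space.periodic`, and an illustrative box side of 20. All function names are hypothetical.

```python
import numpy as np

RADIUS = 2.5
BOX = 20.0  # illustrative box side, not the notebook's BOX_SIZE

coords = 2.0 / np.sqrt(2.0) * RADIUS * np.array([
    [1.0, 0.0, 0.0], [-1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0], [0.0, -1.0, 0.0],
    [0.0, 0.0, 1.0], [0.0, 0.0, -1.0],
])

def sorted_dists_from(center, R, box):
    # Minimum-image displacement in a cubic periodic box
    dr = R - center
    dr = dr - box * np.round(dr / box)
    return np.sort(np.linalg.norm(dr, axis=-1))

# Reference: sorted distances seen from one vertex of the ideal octahedron
ref = sorted_dists_from(coords[0], coords, BOX)

def particle_loss(center, R, box):
    return np.sum(np.abs(sorted_dists_from(center, R, box) - ref))

def cluster_loss(R, box):
    return np.mean([particle_loss(c, R, box) for c in R])

ideal_loss = cluster_loss(coords, BOX)
# Same cluster, translated so that it straddles the periodic boundary
wrapped_loss = cluster_loss((coords + 18.0) % BOX, BOX)
# A cluster that has expanded away from the ideal geometry
expanded_loss = cluster_loss(1.3 * coords, BOX)
print(ideal_loss, wrapped_loss, expanded_loss)
```

Because the loss depends only on sorted distances, it is unchanged by translating or rotating the whole cluster and by relabeling particles, and it grows as the cluster deforms or melts.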


In [None]:
# Compute the loss for the final state of the simulation we observed.
# Note that because we ran one simulation and not an ensemble, we use
# sys_loss, which was only vmapped once.
loss = jnp.mean(sys_loss(final_state.position.center, final_state.position.center))
print(f'Sample loss: {loss}')

Sample loss: 3.717162776238019


# Set up optimization

We now have all the pieces we need to set up our optimization. Given a set of parameters, we can run a simulation and compute a loss. Now, we need to compute the gradient of the loss with respect to the parameters, and update the parameters accordingly.

To reduce memory requirements, we only differentiate through the final portion of the simulation. So before we set up the optimization loop, we define two functions: one that runs the initial part of the simulation (which we do not differentiate through), and one that takes in the result of that partial simulation, continues running the simulation, and outputs a loss. This second function is the one we'll optimize over.

In [None]:
# Parameters of the optimization

# Number of simulation steps to include in the optimization
NUM_STEPS_TO_OPT = 1000
NUM_STEPS_TO_RUN = NUM_STEPS - NUM_STEPS_TO_OPT

# Number of simulations in the ensemble
BATCH_SIZE = 1

# Learning rate for the Adam optimizer
LEARNING_RATE = 0.01

# Number of steps of the optimization loop to run
OPT_STEPS = 10

# Print out the loss, gradient and parameters every SAVE_EVERY steps
SAVE_EVERY = 1

In [None]:
'''
function that takes in the input parameters and runs the first portion of
the simulation (the part we do not differentiate through)
'''
def run_partial_sim(params, key, batch_size):
    # This portion of the simulation is long and may be run for many replicas,
    # so storing every time step would use a large amount of memory. We therefore
    # run all of these steps but keep only the final positions.
    init_positions = octahedron_rigid_body #initialize in an ideal octahedron configuration
    sim_keys = random.split(key, batch_size)
    states = v_run_sim(params, init_positions, NUM_STEPS_TO_RUN, sim_keys)
    return states.position

def get_mean_loss(params, initial_positions, keys):
    states = many_states_run_sim(params, initial_positions, NUM_STEPS_TO_OPT, keys)
    return avg_loss(states.position.center)
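
The effect of differentiating through only the last part of a run can be seen on a toy system with a known answer. In the sketch below, a state relaxes toward a parameter `p`; the first `STEPS_TO_RUN` steps are executed without gradients, and only the last `STEPS_TO_OPT` steps are differentiated. The dynamics, names, and constants are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax import lax

A = 0.9             # relaxation factor of the toy dynamics
TOTAL_STEPS = 50
STEPS_TO_OPT = 5
STEPS_TO_RUN = TOTAL_STEPS - STEPS_TO_OPT

def run(p, x_start, num_steps):
    def step(x, _):
        # Each step moves x a fraction (1 - A) of the way toward p
        return A * x + (1.0 - A) * p, None
    x, _ = lax.scan(step, x_start, None, length=num_steps)
    return x

def loss_from(p, x_start):
    # Only these steps are differentiated; x_start is treated as a constant
    return run(p, x_start, STEPS_TO_OPT)

p = 2.0
x0 = jnp.asarray(0.0, dtype=jnp.float32)

x_mid = run(p, x0, STEPS_TO_RUN)  # first portion, no gradient
loss, truncated_grad = jax.value_and_grad(loss_from)(p, x_mid)

# For comparison: gradient through the entire run
full_grad = jax.grad(lambda q: run(q, x0, TOTAL_STEPS))(p)
print(truncated_grad, full_grad)
```

The truncated gradient ignores how the parameter influenced the first portion of the run, so it is an approximation to the full gradient. In exchange, memory use scales with the number of differentiated steps rather than with the length of the whole simulation.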

Now that we've defined the function we want to optimize over, we can compute a gradient!

In [None]:
# g_mean_loss computes the gradient of the function `get_mean_loss' via reverse mode automatic differentiation
# It also returns the values of the losses
g_mean_loss = jit(value_and_grad(get_mean_loss))

# We first run the partial simulations to get our initial configurations for the optimization
partial_sim_positions = run_partial_sim(thetas, key, BATCH_SIZE)

# Starting from those initial configurations, we do our gradient computation
keys = random.split(key, BATCH_SIZE)
losses, grads = g_mean_loss(thetas, partial_sim_positions, keys)

In [None]:
print(f'Loss values: {losses}')
print(f'Gradient values: {grads}')

Loss values: 5.492759755061078
Gradient values: [-5.34337462 10.38204392]


The last piece we need is an optimizer that takes our gradient and updates the parameters. We use an Adam optimizer.

To make this procedure simpler, we wrap our optimization loop in a function that takes in initial parameters and other relevant information.

In [None]:
'''
function that runs an optimization loop

Arguments:
  - input_params: initial values of the parameters to optimize over
  - key: random key
  - opt_steps: number of steps of the optimization loop to run
  - batch_size: number of simulations in the ensemble to optimize over
  - save_every: prints the loss, gradient, and parameter values every save_every steps
  - learning_rate: learning rate for the Adam optimizer

Returns:
  - min_loss: minimum loss value observed over the course of the optimization
  - min_loss_params: parameters associated with the minimum loss value
'''

def optimize(input_params, key, opt_steps, batch_size, save_every, learning_rate=0.1):

    # Define a learning rate schedule that starts at learning_rate and decreases every opt_steps / 3
    learning_rate_schedule = jnp.ones(opt_steps)*learning_rate
    ind = int(opt_steps / 3)
    learning_rate_schedule = learning_rate_schedule.at[ind:2*ind].set(learning_rate * 0.5)
    learning_rate_schedule = learning_rate_schedule.at[2*ind:].set(learning_rate * 0.1)
    learning_rate_fn = lambda i: learning_rate_schedule[i]

    # Initialize the Adam optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate_fn)

    # If the gradient is too large, clip it to size 'clip', and preserve the sign
    def clip_gradient(g, clip=10000.0):
        return jnp.array(jnp.where(jnp.abs(g) > clip, jnp.sign(g)*clip, g))

    # Run a single optimization step
    def step(i, opt_state, key, batch_size=10, save_every=10, cmd='w'):

        params = get_params(opt_state)

        # Run the partial simulations to get initial configurations
        key, split = random.split(key)
        simulation_keys = random.split(split, batch_size)
        initial_positions = run_partial_sim(params, key, batch_size)

        # Compute losses and gradients
        loss, gs = g_mean_loss(params, initial_positions, simulation_keys)
        gs = vmap(clip_gradient)(gs)

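        # get_mean_loss already averages the loss over the ensemble, so the
        # gradient has the same shape as the parameters and needs no further
        # averaging over the batch.
        g = gs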

        # Print current state
        if(i%save_every==0):
            print(f'Loss: {loss}')
            print(f'Gradient: {g}')
            print(f'Parameters: {params}')

        # Update parameters based on the gradient and return
        return opt_update(i, g, opt_state), loss

    opt_state = opt_init(input_params)
    min_loss_params = input_params
    min_loss = 1e6

    # Run the optimization
    for i in range(0, opt_steps):
        # Run one optimization step
        key, split=random.split(key)
        new_opt_state, loss = step(i, opt_state, split, batch_size=batch_size, save_every=save_every)

        # If the new loss is lower than any previous loss, save it
        if loss < min_loss:
            min_loss = loss
            min_loss_params = get_params(opt_state)
        opt_state = new_opt_state

    # return the minimum loss and the parameters associated with that loss
    # Note that this is a choice: we could have returned the final parameters rather than the minimum loss parameters
    return min_loss, min_loss_params
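
Two small pieces of this loop are easy to inspect in isolation: the stepwise learning-rate schedule and the gradient clipping. The NumPy sketch below reproduces both with the same arithmetic; the function name `piecewise_schedule` is illustrative.

```python
import numpy as np

def piecewise_schedule(opt_steps, learning_rate):
    """Full rate for the first third, half for the second, a tenth afterwards."""
    schedule = np.ones(opt_steps) * learning_rate
    ind = opt_steps // 3
    schedule[ind:2 * ind] = learning_rate * 0.5
    schedule[2 * ind:] = learning_rate * 0.1
    return schedule

def clip_gradient(g, clip=10000.0):
    # Cap the magnitude of each component at `clip`, preserving its sign
    return np.where(np.abs(g) > clip, np.sign(g) * clip, g)

# Schedule for the settings used in the notebook: 10 steps, rate 0.01
schedule = piecewise_schedule(10, 0.01)
print(schedule)

# Components within the bound pass through unchanged
clipped = clip_gradient(np.array([-5.0, 2.0e4, -3.0e4]))
print(clipped)
```

One detail worth knowing: with fewer than three optimization steps the first two segments are empty, so the whole schedule runs at one tenth of the requested learning rate. The clipping is applied per component and is equivalent to `np.clip(g, -clip, clip)`.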

In [None]:
min_loss, min_loss_params = optimize(thetas, key, OPT_STEPS, BATCH_SIZE, SAVE_EVERY, learning_rate=LEARNING_RATE)

Loss: 73.31760111046208
Gradient: [-0.68197723 -1.31525269]
Parameters: [0.79 0.96]
Loss: 29.040727302684818
Gradient: [12.07233128 28.7577191 ]
Parameters: [0.791 0.961]


Now we have our optimal parameters! As we discuss in the associated paper, we can now run forward simulations with these parameters either in JAX-MD or in another MD engine that allows for longer/larger MD simulations.