In [1]:
import jax
import jax.numpy as jnp
import jax.scipy.special
import dataclasses
from functools import partial

import vtk
from vtk.util import numpy_support

import jaxdem as jd

jax.config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np

In [2]:
# build a set of states and systems using vmap

N_systems = 1
N = 100
phi = 0.4
dim = 2
e_int = 1.0
dt = 1e-2

def build_microstate(i):
    # assign bidisperse radii
    rad = jnp.ones(N)
    rad = rad.at[: N // 2].set(0.5)
    rad = rad.at[N // 2:].set(0.7)
    
    # set the box size for the packing fraction and the radii
    volume = jnp.sum((jnp.pi ** (dim / 2) / jax.scipy.special.gamma(dim / 2 + 1)) * rad ** dim)
    L = (volume / phi) ** (1 / dim)
    box_size = jnp.ones(dim) * L

    # create microstate
    key = jax.random.PRNGKey(np.random.randint(0, 1000000))
    pos = jax.random.uniform(key, (N, dim), minval=0.0, maxval=L)
    mass = jnp.ones(N)
    mats = [jd.Material.create("elastic", young=e_int, poisson=0.5, density=1.0)]
    matcher = jd.MaterialMatchmaker.create("harmonic")
    mat_table = jd.MaterialTable.from_materials(mats, matcher=matcher)
    
    # create system and state
    state = jd.State.create(pos=pos, rad=rad, mass=mass, volume=volume)
    system = jd.System.create(
        state_shape=state.shape,
        dt=dt,
        linear_integrator_type="linearfire",
        domain_type="periodic",
        force_model_type="spring",
        collider_type="naive",
        mat_table=mat_table,
        domain_kw=dict(
            box_size=box_size,
        ),
    )
    return state, system

# state, system = jax.vmap(build_microstate)(jnp.arange(N_systems))
# state, system = jax.vmap(build_microstate)(jnp.arange(N_systems))

state, system = build_microstate(0)
state, system, final_pf, final_pe = jd.utils.jamming.bisection_jam(state, system)

Step: 1 -  phi=0.3999999999999999, PE=0.0
Step: 2 -  phi=0.4009999999999999, PE=0.0
Step: 3 -  phi=0.4019999999999999, PE=0.0
Step: 4 -  phi=0.4029999999999999, PE=0.0
Step: 5 -  phi=0.4039999999999999, PE=0.0
Step: 6 -  phi=0.4049999999999999, PE=0.0
Step: 7 -  phi=0.4059999999999999, PE=0.0
Step: 8 -  phi=0.4069999999999999, PE=0.0
Step: 9 -  phi=0.4079999999999999, PE=0.0
Step: 10 -  phi=0.4089999999999999, PE=0.0
Step: 11 -  phi=0.4099999999999999, PE=0.0
Step: 12 -  phi=0.4109999999999999, PE=0.0
Step: 13 -  phi=0.4119999999999999, PE=0.0
Step: 14 -  phi=0.4129999999999999, PE=0.0
Step: 15 -  phi=0.4139999999999999, PE=0.0
Step: 16 -  phi=0.4149999999999999, PE=0.0
Step: 17 -  phi=0.4159999999999999, PE=0.0
Step: 18 -  phi=0.4169999999999999, PE=0.0
Step: 19 -  phi=0.4179999999999999, PE=0.0
Step: 20 -  phi=0.41899999999999993, PE=0.0
Step: 21 -  phi=0.41999999999999993, PE=0.0
Step: 22 -  phi=0.42099999999999993, PE=0.0
Step: 23 -  phi=0.42199999999999993, PE=0.0
Step: 24 -  phi=

In [4]:
r_ij = system.domain.displacement(state.pos[None, :, :], state.pos[:, None, :], system)
d_ij = jnp.linalg.norm(r_ij, axis=-1)
sigma_ij = state.rad[None, :] + state.rad[:, None]
contact_mask = (d_ij < sigma_ij) & (d_ij > 0)
h_ij = (sigma_ij - d_ij) * contact_mask

jnp.sum(system.collider.compute_potential_energy(state, system)) / state.N

jnp.sum(h_ij ** 2) / state.N / 2


Array(9.99989306e-17, dtype=float64)

In [None]:
# count contacts

# remove rattlers

# calculate hessian