In [2]:
import os, sys

sys.path.append(os.getcwd())
from diffmpm.material import SimpleMaterial,LinearElastic
from diffmpm.particle import Particles
from diffmpm.element import Quadrilateral4Node
from diffmpm.constraint import Constraint
from diffmpm.mesh import Mesh2D
from diffmpm.solver import MPMExplicit
import jax.numpy as jnp
import numpy as np

mesh_config = {}
density = 1
# poisson_ratio = 0
youngs_modulus = 1000
material = LinearElastic(
    {
        "id":0,
        "youngs_modulus": youngs_modulus,
        "density": density,
        "poisson_ratio": 0.0,
    }
)
particle_loc = jnp.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [0.5, 0.5]]).reshape(
    4, 1, 2
)
particles = Particles(particle_loc, material, jnp.zeros(particle_loc.shape[0],dtype=jnp.int32))
particles.velocity=particles.velocity.at[:].set(0.0)
constraints = [(0, Constraint(1, 0.0))]
external_loading = jnp.array([0.0, -9.8]).reshape(1,2)
element = Quadrilateral4Node([1, 1], 1, [1,1], constraints)
mesh_config["particles"] = [particles]
mesh_config["elements"] = element
mesh_config["particle_surface_traction"] = []
mesh = Mesh2D(mesh_config)
solver = MPMExplicit(mesh, 0.01,sim_steps=10)

real_ans = solver.solve_jit(external_loading)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
from jax import jit
@jit
def compute_loss(E,solver,target_stress):
    material_props=solver.mesh.particles[0].material.properties
    material_props["youngs_modulus"]=E
    solver.mesh.particles[0].material=LinearElastic(material_props)
    external_loading_local=jnp.array([0.0, -9.8]).reshape(1,2)
    solver.mesh.particles[0].velocity = mesh.particles[0].velocity.at[:].set(0.0)
    result = solver.solve_jit(external_loading_local)
    stress = result["stress"]
    loss = jnp.linalg.norm(stress - target_stress)
    return loss

In [10]:
import optax
from tqdm import tqdm
from jax import jit, value_and_grad

def optax_adam(params,niter,mpm,target_vel):
    start_alpha=0.1
    optimizer=optax.adam(start_alpha)
    opt_state=optimizer.init(params)
    param_list=[]
    loss_list=[]
    t=tqdm(range(niter),desc=f"E: {params}")
    for _ in t:
        lo,grads=value_and_grad(compute_loss)(params,mpm,target_vel)
        updates,opt_state=optimizer.update(grads,opt_state)
        params=optax.apply_updates(params,updates)
        t.set_description(f"E: {params}")
        param_list.append(params)
        loss_list.append(lo)
    return param_list,loss_list
params=1050.0
parameter_list,loss_list=optax_adam(params,500,solver,real_ans["stress"])

E: 1000.3665161132812: 100%|██████████| 500/500 [00:11<00:00, 43.43it/s]


In [11]:
parameter_list

[Array(1049.9, dtype=float32),
 Array(1049.8, dtype=float32),
 Array(1049.7001, dtype=float32),
 Array(1049.6001, dtype=float32),
 Array(1049.5001, dtype=float32),
 Array(1049.4001, dtype=float32),
 Array(1049.3002, dtype=float32),
 Array(1049.2002, dtype=float32),
 Array(1049.1002, dtype=float32),
 Array(1049.0002, dtype=float32),
 Array(1048.9003, dtype=float32),
 Array(1048.8003, dtype=float32),
 Array(1048.7003, dtype=float32),
 Array(1048.6003, dtype=float32),
 Array(1048.5004, dtype=float32),
 Array(1048.4004, dtype=float32),
 Array(1048.3004, dtype=float32),
 Array(1048.2004, dtype=float32),
 Array(1048.1005, dtype=float32),
 Array(1048.0005, dtype=float32),
 Array(1047.9005, dtype=float32),
 Array(1047.8005, dtype=float32),
 Array(1047.7006, dtype=float32),
 Array(1047.6006, dtype=float32),
 Array(1047.5006, dtype=float32),
 Array(1047.4006, dtype=float32),
 Array(1047.3007, dtype=float32),
 Array(1047.2007, dtype=float32),
 Array(1047.1007, dtype=float32),
 Array(1047.0007, dt