In [21]:

import jax.numpy as jnp
from jax import grad, jit, vmap, lax
import jax.scipy as jsp
import jax.scipy.optimize as jsp_opt
import optax 
import jaxopt
from jaxopt import ScipyBoundedMinimize
import matplotlib.pyplot as plt
import jax

def jdprint(obj, string = ""):
    jax.debug.print(string+': {x}', x=obj)

@jit
def mpm(E):
    # nsteps
    nsteps = 10
    
    # mom tolerance
    # tol = 1e-12

    # Domain length
    Lx = 1
    Ly = 1

    # Material properties
    rho = 1
    poisson_ratio = 0
    bulk_modulus = E / (3 * (1 - 2 * poisson_ratio))
    G = E / (2 * (1 + poisson_ratio))
    a1 = bulk_modulus + (4 / 3) * G
    a2 = bulk_modulus - (2 / 3) * G

    # Computational grid
    nelementsx = 1
    nelementsy = 1
    nelements = nelementsx * nelementsy
    
    dx = Lx / nelementsx
    dy = Ly / nelementsy
    
    # Create equally spaced nodes
    x_n, y_n = jnp.meshgrid(jnp.linspace(0, Lx, nelementsx + 1), jnp.linspace(0, Ly, nelementsy + 1))
    x_n = x_n.flatten()
    y_n = y_n.flatten()
    x_n_1D = jnp.linspace(0, Lx, nelementsx + 1)
    y_n_1D = jnp.linspace(0, Ly, nelementsy + 1)
    nnodes = len(x_n)
    
    # jdprint(x_n, 'xn')
    # jdprint(y_n, 'yn')
    # jdprint(x_n_1D, 'xn1d')
    # jdprint(y_n_1D, 'yn1d')
    
    # Set-up a 2D array of elements with node ids
    elements = jnp.zeros((nelements, 4), dtype=int)
    for eidx in range(nelementsx):
        for eidy in range(nelementsy):
            eid = eidx + eidy * nelementsx
            elements = elements.at[eid, 0].set(eidx + eidy * (nelementsx + 1))
            elements = elements.at[eid, 1].set(elements[eid, 0] + 1)
            elements = elements.at[eid, 2].set(elements[eid, 0] + nelementsx + 1)
            elements = elements.at[eid, 3].set(elements[eid, 0] + nelementsx + 2)

    jdprint(elements)
            
    # Loading conditions
    v0x = 0              # initial velocity
    v0y = 0              # initial velocity
    # c  = jnp.sqrt(E/rho)   # speed of sound
    b1 = jnp.pi / (2 * Lx) # beta1
    b2 = jnp.pi / (2 * Ly) # beta2
    # w1 = b1 * c            # omega1
    # w2 = b2 * c            # omega2
    
    # Create material points at the center of each element
    nparticles = 4 * nelements  # number of particles, 4 per element
    
    # Id of the particle in the central element
    pmid = nelements // 2  # Midpoint of the material points
    
    # Material point properties
    x_p        = jnp.zeros(nparticles)                  # positions
    y_p        = jnp.zeros(nparticles)                  # positions
    vol_p      = jnp.ones(nparticles) * dx * dy * 0.25  # volume
    mass_p     = vol_p * rho                            # mass
    stress_px  = jnp.zeros(nparticles)                  # stress in xx
    stress_py  = jnp.zeros(nparticles)                  # stress in yy
    stress_pxy = jnp.zeros(nparticles)                  # stress in xy
    vel_px     = jnp.zeros(nparticles)                  # velocity
    vel_py     = jnp.zeros(nparticles)                  # velocity
    # stress_p  = jnp.zeros((nparticles, 3))       # stress
    # disp_p    = jnp.zeros((nparticles, 2))       # displacement
    # vel_p     = jnp.zeros((nparticles, 2))       # velocity
    # strain_p  = jnp.zeros((nparticles, 3))       # strain

    # Set-up a 2D array of elements with particle ids
    elements_particles = elements
    

    # Initialize particle positions
    # need to fix this to generalize to more than 1 element
    x_p_1D      = jnp.array([0.25, 0.75])
    y_p_1D      = jnp.array([0.25, 0.75])
    x_p, y_p = jnp.meshgrid(x_p_1D, y_p_1D)
    x_p = x_p.flatten()
    y_p = y_p.flatten()


    # set initial velocities
    vel_p   = jnp.array([v0x * jnp.sin(b1 * x_p), v0y * jnp.sin(b2 * y_p)])
    

    # Time steps and duration
    # dt_crit = jnp.max(jnp.array([dx / c, dy / c]))
    dt = 0.01
    
    # results
    # tt = jnp.zeros(nsteps)
    vt = jnp.zeros((nsteps, 2))
    xt = jnp.zeros((nsteps, 2))
    max_stress = jnp.zeros((nsteps, 2))



    
    def step(i, carry):
        x_p, y_p, mass_p, vel_px, vel_py, vol_p, stress_px, stress_py, vt, xt, max_stress = carry
        # reset nodal values
        mass_n  = jnp.zeros(nnodes)   # mass
        mom_nx   = jnp.zeros(nnodes)  # momentum
        mom_ny   = jnp.zeros(nnodes)  # momentum
        fint_nx  = jnp.zeros(nnodes)  # internal force
        fint_ny  = jnp.zeros(nnodes)  # internal force

        # iterate through each element
        for eid in range(nelements):


            # get nodal ids
            nid1, nid2, nid3, nid4 = elements[eid]

            # get particle ids
            for pid in elements_particles[eid]:
                # compute shape functions and derivatives
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/elements/2d/quadrilateral_element.tcc
                N1x = 1 - abs(x_p[pid] - x_n[nid1]) / dx
                N2x = 1 - abs(x_p[pid] - x_n[nid2]) / dx
                N3x = 1 - abs(x_p[pid] - x_n[nid3]) / dx
                N4x = 1 - abs(x_p[pid] - x_n[nid4]) / dx
                N1y = 1 - abs(y_p[pid] - y_n[nid1]) / dy
                N2y = 1 - abs(y_p[pid] - y_n[nid2]) / dy
                N3y = 1 - abs(y_p[pid] - y_n[nid3]) / dy
                N4y = 1 - abs(y_p[pid] - y_n[nid4]) / dy
                
                dN1x = -1/dx
                dN2x = 1/dx
                dN3x = -1/dx
                dN4x = 1/dx
                dN1y = -1/dy
                dN2y = -1/dy
                dN3y = 1/dy
                dN4y = 1/dy


                # map particle mass and momentum to nodes
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/particles/particle.tcc#L521-L532
                mass_n = mass_n.at[nid1].set(mass_n[nid1] + (N1x * N1y) * mass_p[pid])
                mass_n = mass_n.at[nid2].set(mass_n[nid2] + (N2x * N2y) * mass_p[pid])
                mass_n = mass_n.at[nid3].set(mass_n[nid3] + (N3x * N3y) * mass_p[pid])
                mass_n = mass_n.at[nid4].set(mass_n[nid4] + (N4x * N4y) * mass_p[pid])

                mom_nx = mom_nx.at[nid1].set(mom_nx[nid1] + N1x * mass_p[pid] * vel_px[pid])
                mom_nx = mom_nx.at[nid2].set(mom_nx[nid2] + N2x * mass_p[pid] * vel_px[pid])
                mom_nx = mom_nx.at[nid3].set(mom_nx[nid3] + N3x * mass_p[pid] * vel_px[pid])
                mom_nx = mom_nx.at[nid4].set(mom_nx[nid4] + N4x * mass_p[pid] * vel_px[pid])
                mom_ny = mom_ny.at[nid1].set(mom_ny[nid1] + N1y * mass_p[pid] * vel_py[pid])
                mom_ny = mom_ny.at[nid2].set(mom_ny[nid2] + N2y * mass_p[pid] * vel_py[pid])
                mom_ny = mom_ny.at[nid3].set(mom_ny[nid3] + N3y * mass_p[pid] * vel_py[pid])
                mom_ny = mom_ny.at[nid4].set(mom_ny[nid4] + N4y * mass_p[pid] * vel_py[pid])

                # compute nodal internal force
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/solvers/mpm_scheme/mpm_scheme.tcc#L119-L141
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/particles/particle.tcc#L675-L682

                fint_nx = fint_nx.at[nid1].set(fint_nx[nid1] - vol_p[pid] * (stress_px[pid] * dN1x + stress_pxy[pid] * stress_pxy[pid] * dN1y))
                fint_nx = fint_nx.at[nid2].set(fint_nx[nid2] - vol_p[pid] * (stress_px[pid] * dN2x + stress_pxy[pid] * stress_pxy[pid] * dN2y))
                fint_nx = fint_nx.at[nid3].set(fint_nx[nid3] - vol_p[pid] * (stress_px[pid] * dN3x + stress_pxy[pid] * stress_pxy[pid] * dN3y))
                fint_nx = fint_nx.at[nid4].set(fint_nx[nid4] - vol_p[pid] * (stress_px[pid] * dN4x + stress_pxy[pid] * stress_pxy[pid] * dN4y))
                fint_ny = fint_ny.at[nid1].set(fint_ny[nid1] - vol_p[pid] * (stress_py[pid] * dN1y + stress_pxy[pid] * stress_pxy[pid] * dN1x))
                fint_ny = fint_ny.at[nid2].set(fint_ny[nid2] - vol_p[pid] * (stress_py[pid] * dN2y + stress_pxy[pid] * stress_pxy[pid] * dN2x))
                fint_ny = fint_ny.at[nid3].set(fint_ny[nid3] - vol_p[pid] * (stress_py[pid] * dN3y + stress_pxy[pid] * stress_pxy[pid] * dN3x))
                fint_ny = fint_ny.at[nid4].set(fint_ny[nid4] - vol_p[pid] * (stress_py[pid] * dN4y + stress_pxy[pid] * stress_pxy[pid] * dN4x))

                # apply tractions?
                # https://github.com/cb-geo/mpm/blob/adfe6f126b1166a711271da0dfcc32e3d973ea0b/include/mesh.tcc#L1322-L1336
        
        # apply boundary conditions
        # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/particles/particle.tcc#L857-L861
        # momentum conditions
        # mom_nx = mom_nx.at[::nelementsx+1].set(jnp.zeros(nelementsy + 1))  # Nodal velocity v = 0 in m * v at node 0.
        mom_ny = mom_ny.at[::nelementsy+1].set(jnp.zeros(nelementsx + 1))  # Nodal velocity v = 0 in m * v at node 0.

        # force conditions
        # fint_nx = fint_nx.at[::nelementsx+1].set(jnp.zeros(nelementsy + 1))  # Nodal force f = m * a, where a = 0 at node 0.
        fint_ny = fint_ny.at[::nelementsy+1].set(jnp.zeros(nelementsx + 1))  # Nodal force f = m * a, where a = 0 at node 0.

        mom_ny = mom_ny.at[-(nelementsy+1)::].set(-0.01 * jnp.ones((nelementsx + 1)))  # Nodal velocity v = 0 in m * v at node 0.

        # update nodal momentum
        mom_nx = mom_nx + fint_nx * dt
        mom_ny = mom_ny + fint_ny * dt


        # update particle velocity, position, and stress
        # iterate through each element
        for eid in range(nelements):
            # get nodal ids
            nid1, nid2, nid3, nid4 = elements[eid]

            # get particle ids
            for pid in elements_particles:
                
            

                # compute shape functions and derivatives
                N1x = 1 - abs(x_p[pid] - x_n[nid1]) / dx
                N2x = 1 - abs(x_p[pid] - x_n[nid2]) / dx
                N3x = 1 - abs(x_p[pid] - x_n[nid3]) / dx
                N4x = 1 - abs(x_p[pid] - x_n[nid4]) / dx
                N1y = 1 - abs(y_p[pid] - y_n[nid1]) / dy
                N2y = 1 - abs(y_p[pid] - y_n[nid2]) / dy
                N3y = 1 - abs(y_p[pid] - y_n[nid3]) / dy
                N4y = 1 - abs(y_p[pid] - y_n[nid4]) / dy
                
                dN1x = -1/dx
                dN2x = 1/dx
                dN3x = -1/dx
                dN4x = 1/dx
                dN1y = -1/dy
                dN2y = -1/dy
                dN3y = 1/dy
                dN4y = 1/dy
                
                

                # compute particle velocity
                # compute_acceleration_velocity:
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/node.tcc#L223-L253
                
                vel_px = vel_px.at[pid].set(vel_px[pid] + dt * N1x * fint_nx[nid1] / mass_n[nid1])
                vel_px = vel_px.at[pid].set(vel_px[pid] + dt * N2x * fint_nx[nid2] / mass_n[nid2])
                vel_px = vel_px.at[pid].set(vel_px[pid] + dt * N3x * fint_nx[nid3] / mass_n[nid3])
                vel_px = vel_px.at[pid].set(vel_px[pid] + dt * N4x * fint_nx[nid4] / mass_n[nid4])
                vel_py = vel_py.at[pid].set(vel_py[pid] + dt * N1y * fint_ny[nid1] / mass_n[nid1])
                vel_py = vel_py.at[pid].set(vel_py[pid] + dt * N2y * fint_ny[nid2] / mass_n[nid2])
                vel_py = vel_py.at[pid].set(vel_py[pid] + dt * N3y * fint_ny[nid3] / mass_n[nid3])
                vel_py = vel_py.at[pid].set(vel_py[pid] + dt * N4y * fint_ny[nid4] / mass_n[nid4])
                
                
                
                # update particle position based on nodal momentum
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/particles/particle.tcc#L778-L810
                
                # nodal velocity
                nv1x = mom_nx[nid1]/mass_n[nid1]
                nv2x = mom_nx[nid2]/mass_n[nid2]
                nv3x = mom_nx[nid3]/mass_n[nid3]
                nv4x = mom_nx[nid4]/mass_n[nid4]
                nv1y = mom_ny[nid1]/mass_n[nid1]
                nv2y = mom_ny[nid2]/mass_n[nid2]
                nv3y = mom_ny[nid3]/mass_n[nid3]
                nv4y = mom_ny[nid4]/mass_n[nid4]
                
                # compute_updated_position
                x_p = x_p.at[pid].set(x_p[pid] + dt * (N1x * nv1x + N2x * nv2x + N3x * nv3x + N4x * nv4x))
                y_p = y_p.at[pid].set(y_p[pid] + dt * (N1y * nv1y + N2y * nv2y + N3y * nv3y + N4y * nv4y))

                

                # rate of strain increment
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/particles/particle.tcc#L603-L619
                grad_vx = (dN1x * nv1x + dN2x * nv2x + dN3x * nv3x + dN4x * nv4x)
                grad_vy = (dN1y * nv1y + dN2y * nv2y + dN3y * nv3y + dN4y * nv4y)
                grad_vxy = (dN1x * nv1y + dN2x * nv2y + dN3x * nv3y + dN4x * nv4y) + (dN1y * nv1x + dN2y * nv2x + dN3y * nv3x + dN4y * nv4x)
                # particle dstrain
                dstrainx = grad_vx * dt
                dstrainy = grad_vy * dt
                dstrainxy = grad_vxy * dt
                # particle volume
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/particles/particle.tcc#L498-L505
                vol_p = vol_p.at[pid].set((1 + dstrainx + dstrainy) * vol_p[pid])
                # update stress using linear elastic model
                # https://github.com/cb-geo/mpm/blob/86ba10eeca3badba31b37c49e11a3930ee6f2c16/include/materials/linear_elastic.tcc#L60-L66
                stress_px = stress_px.at[pid].set(stress_px[pid] + a1 * dstrainx + a2 * dstrainy + G * dstrainxy)
                stress_py = stress_py.at[pid].set(stress_py[pid] + a1 * dstrainy + a2 * dstrainx + G * dstrainxy)

        # results
        vt = vt.at[i].set(vel_px[pmid], vel_py[pmid])
        xt = xt.at[i].set(x_p[pmid], y_p[pmid])
        max_stress = max_stress.at[i].set([jnp.max(stress_px), jnp.max(stress_py)])
        # max_stress = max_stress.at[i, 1].set()


        return (x_p, y_p, mass_p, vel_px, vel_py, vol_p, stress_px, stress_py, vt, xt, max_stress)

    x_p, y_p, mass_p, vel_px, vel_py, vol_p, stress_px, stress_py, vt, xt, max_stress = lax.fori_loop(0, nsteps, step, (x_p, y_p, mass_p, vel_px, vel_py, vol_p, stress_px, stress_py, vt, xt, max_stress))


    
    return vt, max_stress

In [22]:
# Assign target
Etarget = 1000
target, max_stress = mpm(Etarget)

max_stress

: [[0 1 2 3]]


Array([[ 0.0000000e+00, -7.9999995e-01],
       [ 2.5578309e-04, -9.6089631e-01],
       [-7.9206586e-02, -9.1391587e-01],
       [-1.2723728e-01, -8.0932498e-01],
       [-6.2189415e-02, -6.9803214e-01],
       [ 2.2585779e-02, -5.9602398e-01],
       [ 2.6012063e-03, -5.0726247e-01],
       [-8.0139548e-02, -4.3166220e-01],
       [-8.3924621e-02, -3.6783326e-01],
       [ 5.6146979e-03, -3.1410712e-01]], dtype=float32)

In [4]:

#############################################################
#  NOTE: Uncomment the line only for TFP optimizer and 
#        jaxopt value_and_grad = True
#############################################################
# @jax.value_and_grad
@jit
def compute_loss(E):
    vt = mpm(E)
    return jnp.linalg.norm(vt - target)

# BFGS Optimizer
# TODO: Implement box constrained optimizer
def jaxopt_bfgs(params, niter):
  opt= jaxopt.BFGS(fun=compute_loss, value_and_grad=True, tol=1e-5, implicit_diff=False, maxiter=niter)
  res = opt.run(init_params=params)
  result, _ = res
  return result

# Optimizers
def optax_adam(params, niter):
  # Initialize parameters of the model + optimizer.
  start_learning_rate = 1e-1
  optimizer = optax.adam(start_learning_rate)
  opt_state = optimizer.init(params)

  # A simple update loop.
  for i in range(niter):
    print('iteration: ' + i)
    grads = grad(compute_loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
  return params
  
# Tensor Flow Probability Optimization library
def tfp_lbfgs(params):
  results = tfp.optimizer.lbfgs_minimize(
        jax.jit(compute_loss), initial_position=params, tolerance=1e-5)
  return results.position

# Initial model - Young's modulus 
params = 95.0

# vt = tfp_lbfgs(params)               # LBFGS optimizer
result = optax_adam(params, 1000)     # ADAM optimizer

"""
f = jax.jit(compute_loss)
df = jax.jit(jax.grad(compute_loss))
E = 95.0
print(0, E)
for i in range(10):
    E = E - f(E)/df(E)
    print(i, E)
"""
print("E: {}".format(result))
vel = mpm(result)
# update time steps
dt = 0.02
nsteps = 10
tt = jnp.arange(0, nsteps) * dt


vel = vel.reshape((2,10))
target = target.reshape((2,10))


# Plot results
plt.plot(tt, vel[0,:], 'r', markersize=1, label='mpm')
plt.plot(tt, target[0,:], 'ob', markersize=1, label='mpm-target')
plt.xlabel('time (s)')
plt.ylabel('x velocity (m/s)')
plt.legend()
plt.show()

plt.plot(tt, vel[1,:], 'r', markersize=1, label='mpm')
plt.plot(tt, target[1,:], 'ob', markersize=1, label='mpm-target')
plt.xlabel('time (s)')
plt.ylabel('y velocity (m/s)')
plt.legend()
plt.show()

2023-03-06 22:21:52.496733: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_compute_loss] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-06 22:33:48.296126: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 13m55.797362408s

********************************
[Compiling module jit_compute_loss] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


: 

: 

In [None]:
result

Array(100.002594, dtype=float32)