In [1]:

import jax.numpy as jnp
from jax import grad, jit, vmap, lax
import jax.scipy as jsp
import jax.scipy.optimize as jsp_opt
import optax 
import jaxopt
from jaxopt import ScipyBoundedMinimize
import matplotlib.pyplot as plt
import jax

@jit
def mpm(E):
    """Simulate a 1D elastic bar with the material point method (MPM).

    A bar of length 25 is meshed with 13 equal two-node elements, with one
    material point placed at the centre of each element. The points start
    with velocity ``0.1 * sin(pi * x / (2 * L))`` and zero stress, node 0 is
    held fixed, and the state is advanced for 2500 explicit steps of size
    ``dt = 0.02``.

    Material point ``p`` is always interpolated with the two nodes of
    element ``p`` (nodes ``p`` and ``p + 1``); the code never re-locates a
    point that drifts out of its initial element.

    NOTE(review): this notebook defines ``mpm`` in two cells; a later cell
    rebinds the name to a 100-step variant when both cells are run.

    Args:
        E: Young's modulus (scalar). It enters only through the
            linear-elastic stress update ``stress += E * dstrain``.

    Returns:
        Array of shape ``(2500,)`` with the velocity of the central material
        point (index ``nelements // 2``) recorded at the end of every step.
    """
    # Run length and time step. dt is a fixed constant; it is not checked
    # against the wave-speed limit dx / sqrt(E / rho).
    nsteps = 2500
    dt = 0.02

    # Domain and material
    L = 25
    rho = 1

    # Computational grid: equally spaced nodes
    nelements = 13
    nnodes = nelements + 1
    dx = L / nelements
    x_n = jnp.linspace(0, L, nnodes)

    # Connectivity: element e joins node e (left) and node e + 1 (right)
    nid1 = jnp.arange(nelements)
    nid2 = nid1 + 1

    # Gradients of the linear shape functions are constant on an element
    dN1 = -1 / dx
    dN2 = 1 / dx

    # Material points: one per element, located at the element centre
    nparticles = nelements
    pmid = nelements // 2                    # index of the central point
    x_p = 0.5 * (x_n[:-1] + x_n[1:])         # positions
    vol_p = jnp.ones(nparticles) * dx        # volume
    mass_p = vol_p * rho                     # mass (constant in time)
    stress_p = jnp.zeros(nparticles)         # stress

    # Initial velocity field
    v0 = 0.1
    b1 = jnp.pi / (2 * L)
    vel_p = v0 * jnp.sin(b1 * x_p)

    def step(carry, _):
        x_p, vel_p, vol_p, stress_p = carry

        # Shape functions of every point w.r.t. its left / right node
        N1 = 1 - jnp.abs(x_p - x_n[nid1]) / dx
        N2 = 1 - jnp.abs(x_p - x_n[nid2]) / dx

        # Particle -> node transfer. Each interior node receives one
        # contribution from the element on either side of it.
        zeros_n = jnp.zeros(nnodes)
        mass_n = zeros_n.at[nid1].add(N1 * mass_p).at[nid2].add(N2 * mass_p)
        mom_n = (
            zeros_n.at[nid1].add(N1 * mass_p * vel_p)
            .at[nid2].add(N2 * mass_p * vel_p)
        )
        fint_n = (
            zeros_n.at[nid1].add(-(vol_p * stress_p * dN1))
            .at[nid2].add(-(vol_p * stress_p * dN2))
        )

        # Boundary condition: node 0 is fixed (zero velocity, zero acceleration)
        mom_n = mom_n.at[0].set(0)
        fint_n = fint_n.at[0].set(0)

        # Advance nodal momentum
        mom_n = mom_n + fint_n * dt

        # Node -> particle transfer: velocity from nodal accelerations
        vel_p = vel_p + dt * N1 * fint_n[nid1] / mass_n[nid1]
        vel_p = vel_p + dt * N2 * fint_n[nid2] / mass_n[nid2]

        # Position from the updated nodal velocities
        x_p = x_p + dt * (
            N1 * mom_n[nid1] / mass_n[nid1] + N2 * mom_n[nid2] / mass_n[nid2]
        )

        # Strain increment from the nodal velocity gradient
        nv1 = mom_n[nid1] / mass_n[nid1]
        nv2 = mom_n[nid2] / mass_n[nid2]
        dstrain = (dN1 * nv1 + dN2 * nv2) * dt

        # Volume and linear-elastic stress update
        vol_p = (1 + dstrain) * vol_p
        stress_p = stress_p + E * dstrain

        # Emit the central point's velocity as this step's output
        return (x_p, vel_p, vol_p, stress_p), vel_p[pmid]

    _, vt = lax.scan(step, (x_p, vel_p, vol_p, stress_p), None, length=nsteps)

    return vt




In [None]:

import jax.numpy as jnp
from jax import grad, jit, vmap, lax
import jax.scipy as jsp
import jax.scipy.optimize as jsp_opt
import optax 
import jaxopt
from jaxopt import ScipyBoundedMinimize
import matplotlib.pyplot as plt
import jax

def jdprint(obj, string = ""):
    """Print ``obj`` via ``jax.debug.print`` with a leading label.

    Args:
        obj: Value to print; supplied to the format string as ``x``.
        string: Label placed in front of ``': '``. It becomes part of the
            format string handed to ``jax.debug.print``.
    """
    template = "".join([string, ': {x}'])
    jax.debug.print(template, x=obj)

@jit
def mpm(E):
    """Simulate a 1D elastic bar with the material point method (MPM).

    A bar of length 25 is meshed with 13 equal two-node elements, with one
    material point placed at the centre of each element. The points start
    with velocity ``0.1 * sin(pi * x / (2 * L))`` and zero stress, node 0 is
    held fixed, and the state is advanced for 100 explicit steps of size
    ``dt = 0.02``.

    Material point ``p`` is always interpolated with the two nodes of
    element ``p`` (nodes ``p`` and ``p + 1``); the code never re-locates a
    point that drifts out of its initial element.

    NOTE(review): this notebook defines ``mpm`` in two cells; running this
    cell after the 2500-step definition rebinds the name to this 100-step
    variant.

    Args:
        E: Young's modulus (scalar). It enters only through the
            linear-elastic stress update ``stress += E * dstrain``.

    Returns:
        Array of shape ``(100,)`` with the velocity of the central material
        point (index ``nelements // 2``) recorded at the end of every step.
    """
    # Run length and time step. dt is a fixed constant; it is not checked
    # against the wave-speed limit dx / sqrt(E / rho).
    nsteps = 100
    dt = 0.02

    # Domain and material
    L = 25
    rho = 1

    # Computational grid: equally spaced nodes
    nelements = 13
    nnodes = nelements + 1
    dx = L / nelements
    x_n = jnp.linspace(0, L, nnodes)

    # Connectivity: element e joins node e (left) and node e + 1 (right)
    nid1 = jnp.arange(nelements)
    nid2 = nid1 + 1

    # Gradients of the linear shape functions are constant on an element
    dN1 = -1 / dx
    dN2 = 1 / dx

    # Material points: one per element, located at the element centre
    nparticles = nelements
    pmid = nelements // 2                    # index of the central point
    x_p = 0.5 * (x_n[:-1] + x_n[1:])         # positions
    vol_p = jnp.ones(nparticles) * dx        # volume
    mass_p = vol_p * rho                     # mass (constant in time)
    stress_p = jnp.zeros(nparticles)         # stress

    # Initial velocity field
    v0 = 0.1
    b1 = jnp.pi / (2 * L)
    vel_p = v0 * jnp.sin(b1 * x_p)

    def step(carry, _):
        x_p, vel_p, vol_p, stress_p = carry

        # Shape functions of every point w.r.t. its left / right node
        N1 = 1 - jnp.abs(x_p - x_n[nid1]) / dx
        N2 = 1 - jnp.abs(x_p - x_n[nid2]) / dx

        # Particle -> node transfer. Each interior node receives one
        # contribution from the element on either side of it.
        zeros_n = jnp.zeros(nnodes)
        mass_n = zeros_n.at[nid1].add(N1 * mass_p).at[nid2].add(N2 * mass_p)
        mom_n = (
            zeros_n.at[nid1].add(N1 * mass_p * vel_p)
            .at[nid2].add(N2 * mass_p * vel_p)
        )
        fint_n = (
            zeros_n.at[nid1].add(-(vol_p * stress_p * dN1))
            .at[nid2].add(-(vol_p * stress_p * dN2))
        )

        # Boundary condition: node 0 is fixed (zero velocity, zero acceleration)
        mom_n = mom_n.at[0].set(0)
        fint_n = fint_n.at[0].set(0)

        # Advance nodal momentum
        mom_n = mom_n + fint_n * dt

        # Node -> particle transfer: velocity from nodal accelerations
        vel_p = vel_p + dt * N1 * fint_n[nid1] / mass_n[nid1]
        vel_p = vel_p + dt * N2 * fint_n[nid2] / mass_n[nid2]

        # Position from the updated nodal velocities
        x_p = x_p + dt * (
            N1 * mom_n[nid1] / mass_n[nid1] + N2 * mom_n[nid2] / mass_n[nid2]
        )

        # Strain increment from the nodal velocity gradient
        nv1 = mom_n[nid1] / mass_n[nid1]
        nv2 = mom_n[nid2] / mass_n[nid2]
        dstrain = (dN1 * nv1 + dN2 * nv2) * dt

        # Volume and linear-elastic stress update
        vol_p = (1 + dstrain) * vol_p
        stress_p = stress_p + E * dstrain

        # Emit the central point's velocity as this step's output
        return (x_p, vel_p, vol_p, stress_p), vel_p[pmid]

    _, vt = lax.scan(step, (x_p, vel_p, vol_p, stress_p), None, length=nsteps)

    return vt

In [2]:
# Baseline run: evaluate the simulation at E = 100 and keep the returned
# array as `target`.
E = 100
target = mpm(E)

# Bare expression so the notebook displays the array.
target

Array([0.07071068, 0.07069957, 0.07067733, 0.07064398, 0.07059952,
       0.07054394, 0.07047725, 0.07039945, 0.07031056, 0.07021059,
       0.07009955, 0.06997745, 0.0698443 , 0.06970013, 0.06954496,
       0.06937879, 0.06920167, 0.06901361, 0.06881464, 0.06860477,
       0.06838407, 0.06815253, 0.0679102 , 0.06765713, 0.06739333,
       0.06711886, 0.06683375, 0.06653804, 0.06623179, 0.06591503,
       0.06558781, 0.06525019, 0.06490221, 0.06454393, 0.06417539,
       0.06379667, 0.06340781, 0.06300888, 0.06259994, 0.06218106,
       0.06175229, 0.06131371, 0.06086539, 0.06040739, 0.05993979,
       0.05946267, 0.0589761 , 0.05848015, 0.05797491, 0.05746046,
       0.05693687, 0.05640425, 0.05586267, 0.05531221, 0.05475298,
       0.05418506, 0.05360855, 0.05302354, 0.05243013, 0.05182841,
       0.05121849, 0.05060047, 0.04997444, 0.04934052, 0.04869881,
       0.04804942, 0.04739245, 0.04672803, 0.04605624, 0.04537722,
       0.04469107, 0.04399791, 0.04329786, 0.04259103, 0.04187