Namespace diffmpm.cli

Sub-modules
-----------

diffmpm.cli.mpm
diff --git a/README.md b/README.md index f32775d..a369e58 100644 --- a/README.md +++ b/README.md @@ -1 +1,34 @@ -# DiffMPM +# Differentiable Material Point Method (DiffMPM) + +MPM simulations are applied in various fields such as computer graphics, geotechnical engineering, computational mechanics and more. `diffmpm` is a differentiable MPM simulation library written entirely in JAX which means it also has all the niceties that come with JAX. It is a highly parallel, Just-In-Time compiled code that can run on CPUs, GPUs or TPUs. It aims to be a fast solver that can be used in various problems like optimization and inverse problems. Having a differentiable MPM simulation opens up several advantages - +- **Efficient Gradient-based Optimization:** Since the entire simulation model is differentiable, it can be used in conjunction with various gradient-based optimization techniques such as stochastic gradient descent (SGD), ADAM etc. +- **Inverse Problems:** It also enables us to solve inverse problems to determine material properties by formulating an inverse problem as an optimization task. +- **Integration with Deep Learning:** It can be seamlessly integrated with other Neural Network models to enable training physics-informed neural networks. + +## Installation +`diffmpm` can be installed directly from PyPI using `pip` + +``` shell +pip install diffmpm +``` + +#### ToDo +Add separate installation commands for CPU/GPU. + +## Usage +Once installed, `diffmpm` can be used as a CLI tool or can be imported as a library in Python. Example input files can be found in the `benchmarks/` directory. + +``` +Usage: mpm [OPTIONS] + + CLI utility for DiffMPM. + +Options: + -f, --file TEXT Input TOML file [required] + --version Show the version and exit. + --help Show this message and exit. +``` + +Further documentation about the input file can be found in the documentation _[INSERT LINK HERE]_. `diffmpm` can write the output to various file types like `.npz`, `.vtk` etc. 
that can then be used to visualize the output of the simulations. + +## Examples diff --git a/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py b/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py index 248034a..ae72923 100644 --- a/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py +++ b/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py @@ -1,6 +1,8 @@ import os from pathlib import Path + import jax.numpy as jnp + from diffmpm import MPM diff --git a/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py b/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py index b880ffa..356d0a3 100644 --- a/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py +++ b/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py @@ -1,6 +1,8 @@ import os from pathlib import Path + import jax.numpy as jnp + from diffmpm import MPM diff --git a/benchmarks/2d/uniaxial_stress/test_benchmark.py b/benchmarks/2d/uniaxial_stress/test_benchmark.py index 0a6d10c..f04e820 100644 --- a/benchmarks/2d/uniaxial_stress/test_benchmark.py +++ b/benchmarks/2d/uniaxial_stress/test_benchmark.py @@ -1,6 +1,8 @@ import os from pathlib import Path + import jax.numpy as jnp + from diffmpm import MPM diff --git a/diffmpm/__init__.py b/diffmpm/__init__.py index bf2f251..faa8316 100644 --- a/diffmpm/__init__.py +++ b/diffmpm/__init__.py @@ -40,7 +40,7 @@ def __init__(self, filepath): raise ValueError("Wrong type of solver specified.") def solve(self): - """Solve the MPM simulation.""" + """Solve the MPM simulation using JIT solver.""" arrays = self.solver.solve_jit( self._config.parsed_config["external_loading"]["gravity"], ) diff --git a/diffmpm/cli/mpm.py b/diffmpm/cli/mpm.py index 3e621cf..aebc4ba 100644 --- a/diffmpm/cli/mpm.py +++ b/diffmpm/cli/mpm.py @@ -3,11 +3,12 @@ from diffmpm import MPM -@click.command() +@click.command() # type: ignore @click.option( "-f", "--file", "filepath", required=True, type=str, help="Input TOML file" ) 
@click.version_option(package_name="diffmpm") def mpm(filepath): + """CLI utility for DiffMPM.""" solver = MPM(filepath) solver.solve() diff --git a/diffmpm/constraint.py b/diffmpm/constraint.py index cba836a..93f75bd 100644 --- a/diffmpm/constraint.py +++ b/diffmpm/constraint.py @@ -3,7 +3,18 @@ @register_pytree_node_class class Constraint: - def __init__(self, dir, velocity): + """Generic velocity constraints to apply on nodes or particles.""" + + def __init__(self, dir: int, velocity: float): + """Contains 2 govering parameters. + + Attributes + ---------- + dir : int + Direction in which constraint is applied. + velocity : float + Constrained velocity to be applied. + """ self.dir = dir self.velocity = velocity @@ -16,16 +27,15 @@ def tree_unflatten(cls, aux_data, children): return cls(*aux_data) def apply(self, obj, ids): - """ - Apply constraint values to the passed object. + """Apply constraint values to the passed object. - Arguments - --------- + Parameters + ---------- obj : diffmpm.node.Nodes, diffmpm.particle.Particles Object on which the constraint is applied ids : array_like The indices of the container `obj` on which the constraint - will be applied. + will be applied. 
""" obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity) obj.momentum = obj.momentum.at[ids, :, self.dir].set( diff --git a/diffmpm/element.py b/diffmpm/element.py index f0b9d4f..3eeff67 100644 --- a/diffmpm/element.py +++ b/diffmpm/element.py @@ -1,52 +1,79 @@ +from __future__ import annotations + import abc import itertools -from typing import Sequence, Tuple, List +from typing import TYPE_CHECKING, Optional, Sequence, Tuple + +if TYPE_CHECKING: + from diffmpm.particle import Particles import jax.numpy as jnp -from jax import jacobian, jit, lax, vmap +from jax import Array, jacobian, jit, lax, vmap from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike -from diffmpm.node import Nodes from diffmpm.constraint import Constraint +from diffmpm.node import Nodes + +__all__ = ["_Element", "Linear1D", "Quadrilateral4Node"] class _Element(abc.ABC): + """Base element class that is inherited by all types of Elements.""" + + nodes: Nodes + total_elements: int + concentrated_nodal_forces: Sequence + volume: Array + @abc.abstractmethod - def id_to_node_ids(self): - ... + def id_to_node_ids(self, id: ArrayLike) -> Array: + """Node IDs corresponding to element `id`. + + This method is implemented by each of the subclass. - def id_to_node_loc(self, id: int): + Parameters + ---------- + id : int + Element ID. + + Returns + ------- + ArrayLike + Nodal IDs of the element. """ - Node locations corresponding to element `id`. + ... + + def id_to_node_loc(self, id: ArrayLike) -> Array: + """Node locations corresponding to element `id`. - Arguments - --------- + Parameters + ---------- id : int Element ID. Returns ------- - jax.numpy.ndarray + ArrayLike Nodal locations for the element. 
Shape of returned - array is (nodes_in_element, 1, ndim) + array is `(nodes_in_element, 1, ndim)` """ node_ids = self.id_to_node_ids(id).squeeze() return self.nodes.loc[node_ids] - def id_to_node_vel(self, id: int): - """ - Node velocities corresponding to element `id`. + def id_to_node_vel(self, id: ArrayLike) -> Array: + """Node velocities corresponding to element `id`. - Arguments - --------- + Parameters + ---------- id : int Element ID. Returns ------- - jax.numpy.ndarray + ArrayLike Nodal velocities for the element. Shape of returned - array is (nodes_in_element, 1, ndim) + array is `(nodes_in_element, 1, ndim)` """ node_ids = self.id_to_node_ids(id).squeeze() return self.nodes.velocity[node_ids] @@ -77,29 +104,33 @@ def tree_unflatten(cls, aux_data, children): ) @abc.abstractmethod - def shapefn(self): + def shapefn(self, xi: ArrayLike): + """Evaluate Shape function for element type.""" ... @abc.abstractmethod - def shapefn_grad(self): + def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): + """Evaluate gradient of shape function for element type.""" ... @abc.abstractmethod - def set_particle_element_ids(self): + def set_particle_element_ids(self, particles: Particles): + """Set the element IDs that particles are present in.""" ... # Mapping from particles to nodes (P2G) - def compute_nodal_mass(self, particles): - r""" - Compute the nodal mass based on particle mass. + def compute_nodal_mass(self, particles: Particles): + r"""Compute the nodal mass based on particle mass. The nodal mass is updated as a sum of particle mass for all particles mapped to the node. - :math:`(m)_i = \sum_p N_i(x_p) m_p` + \[ + (m)_i = \sum_p N_i(x_p) m_p + \] - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. 
""" @@ -120,17 +151,18 @@ def _step(pid, args): ) _, self.nodes.mass, _, _ = lax.fori_loop(0, len(particles), _step, args) - def compute_nodal_momentum(self, particles): - r""" - Compute the nodal mass based on particle mass. + def compute_nodal_momentum(self, particles: Particles): + r"""Compute the nodal mass based on particle mass. The nodal mass is updated as a sum of particle mass for all particles mapped to the node. - :math:`(mv)_i = \sum_p N_i(x_p) (mv)_p` + \[ + (mv)_i = \sum_p N_i(x_p) (mv)_p + \] - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -156,7 +188,8 @@ def _step(pid, args): self.nodes.momentum, ) - def compute_velocity(self, particles): + def compute_velocity(self, particles: Particles): + """Compute velocity using momentum.""" self.nodes.velocity = jnp.where( self.nodes.mass == 0, self.nodes.velocity, @@ -168,17 +201,18 @@ def compute_velocity(self, particles): self.nodes.velocity, ) - def compute_external_force(self, particles): - r""" - Update the nodal external force based on particle f_ext. + def compute_external_force(self, particles: Particles): + r"""Update the nodal external force based on particle f_ext. The nodal force is updated as a sum of particle external force for all particles mapped to the node. - :math:`(f_{ext})_i = \sum_p N_i(x_p) f_{ext}` + \[ + f_{ext})_i = \sum_p N_i(x_p) f_{ext} + \] - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -199,17 +233,18 @@ def _step(pid, args): ) self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - def compute_body_force(self, particles, gravity: float | jnp.ndarray): - r""" - Update the nodal external force based on particle mass. + def compute_body_force(self, particles: Particles, gravity: ArrayLike): + r"""Update the nodal external force based on particle mass. 
The nodal force is updated as a sum of particle body force for all particles mapped to th - :math:`(f_{ext})_i += \sum_p N_i(x_p) m_p g` + \[ + (f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g + \] - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -232,14 +267,31 @@ def _step(pid, args): ) self.nodes.f_ext, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - def apply_concentrated_nodal_forces(self, particles, curr_time): + def apply_concentrated_nodal_forces(self, particles: Particles, curr_time: float): + """Apply concentrated nodal forces. + + Parameters + ---------- + particles: Particles + Particles in the simulation. + curr_time: float + Current time in the simulation. + """ for cnf in self.concentrated_nodal_forces: factor = cnf.function.value(curr_time) self.nodes.f_ext = self.nodes.f_ext.at[cnf.node_ids, 0, cnf.dir].add( factor * cnf.force ) - def apply_particle_traction_forces(self, particles): + def apply_particle_traction_forces(self, particles: Particles): + """Apply concentrated nodal forces. + + Parameters + ---------- + particles: Particles + Particles in the simulation. 
+ """ + def _step(pid, args): f_ext, ptraction, mapped_pos, el_nodes = args f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid]) @@ -250,7 +302,9 @@ def _step(pid, args): args = (self.nodes.f_ext, particles.traction, mapped_positions, mapped_nodes) self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args) - def update_nodal_acceleration_velocity(self, particles, dt: float, *args): + def update_nodal_acceleration_velocity( + self, particles: Particles, dt: float, *args + ): """Update the nodal momentum based on total force on nodes.""" total_force = self.nodes.get_total_force() self.nodes.acceleration = self.nodes.acceleration.at[:].set( @@ -288,38 +342,54 @@ def apply_force_boundary_constraints(self, *args): @register_pytree_node_class class Linear1D(_Element): - """ - Container for 1D line elements (and nodes). + """Container for 1D line elements (and nodes). + + Element ID: 0 1 2 3 + Mesh: +-----+-----+-----+-----+ + Node IDs: 0 1 2 3 4 - Element ID: 0 1 2 3 - Mesh: +-----+-----+-----+-----+ - Node IDs: 0 1 2 3 4 + where + + + : Nodes + +-----+ : An element - + : Nodes - +-----+ : An element """ def __init__( self, nelements: int, - total_elements, + total_elements: int, el_len: float, - constraints: List[Tuple[jnp.ndarray, Constraint]], - nodes: Nodes = None, - concentrated_nodal_forces=[], - initialized=None, - volume=None, + constraints: Sequence[Tuple[ArrayLike, Constraint]], + nodes: Optional[Nodes] = None, + concentrated_nodal_forces: Sequence = [], + initialized: Optional[bool] = None, + volume: Optional[ArrayLike] = None, ): """Initialize Linear1D. - Arguments - --------- + Parameters + ---------- nelements : int Number of elements. + total_elements : int + Total number of elements (same as `nelements` for 1D) el_len : float Length of each element. - boundary_nodes : Sequence - IDs of nodes that are supposed to be fixed (boundary). 
+ constraints: list + A list of constraints where each element is a tuple of type + `(node_ids, diffmpm.Constraint)`. Here, `node_ids` correspond to + the node IDs where `diffmpm.Constraint` should be applied. + nodes : Nodes, Optional + Nodes in the element object. + concentrated_nodal_forces: list + A list of `diffmpm.forces.NodalForce`s that are to be + applied. + initialized: bool, None + `True` if the class has been initialized, `None` if not. + This is required like this for using JAX flattening. + volume: ArrayLike + Volume of the elements. """ self.nelements = nelements self.total_elements = nelements @@ -338,72 +408,71 @@ def __init__( if initialized is None: self.volume = jnp.ones((self.total_elements, 1, 1)) else: - self.volume = volume + self.volume = jnp.asarray(volume) self.initialized = True - def id_to_node_ids(self, id: int): - """ - Node IDs corresponding to element `id`. + def id_to_node_ids(self, id: ArrayLike): + """Node IDs corresponding to element `id`. - Arguments - --------- + Parameters + ---------- id : int Element ID. Returns ------- - jax.numpy.ndarray + ArrayLike Nodal IDs of the element. Shape of returned - array is (2, 1) + array is `(2, 1)` """ return jnp.array([id, id + 1]).reshape(2, 1) - def shapefn(self, xi: float | jnp.ndarray): - """ - Evaluate linear shape function. + def shapefn(self, xi: ArrayLike): + """Evaluate linear shape function. - Arguments - --------- + Parameters + ---------- xi : float, array_like Locations of particles in natural coordinates to evaluate - the function at. Expected shape is (npoints, 1, ndim) + the function at. Expected shape is `(npoints, 1, ndim)` Returns ------- array_like Evaluated shape function values. The shape of the returned - array will depend on the input shape. For example, in the linear - case, if the input is a scalar, the returned array will be of - the shape (1, 2, 1) but if the input is a vector then the output will - be of the shape (len(x), 2, 1). 
+ array will depend on the input shape. For example, in the linear + case, if the input is a scalar, the returned array will be of + the shape `(1, 2, 1)` but if the input is a vector then the output will + be of the shape `(len(x), 2, 1)`. """ - if len(xi.shape) != 3: + xi = jnp.asarray(xi) + if xi.ndim != 3: raise ValueError( f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}" ) result = jnp.array([0.5 * (1 - xi), 0.5 * (1 + xi)]).transpose(1, 0, 2, 3) return result - def _shapefn_natural_grad(self, xi: float | jnp.ndarray): - """ - Calculate the gradient of shape function. + def _shapefn_natural_grad(self, xi: ArrayLike): + """Calculate the gradient of shape function. This calculation is done in the natural coordinates. - Arguments - --------- + Parameters + ---------- x : float, array_like Locations of particles in natural coordinates to evaluate - the function at. + the function at. Returns ------- array_like Evaluated gradient values of the shape function. The shape of - the returned array will depend on the input shape. For example, - in the linear case, if the input is a scalar, the returned array - will be of the shape (2, 1). + the returned array will depend on the input shape. For example, + in the linear case, if the input is a scalar, the returned array + will be of the shape `(2, 1)`. """ + xi = jnp.asarray(xi) result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze() # TODO: The following code tries to evaluate vmap even if @@ -416,25 +485,26 @@ def _shapefn_natural_grad(self, xi: float | jnp.ndarray): # ) return result.reshape(2, 1) - def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray): - """ - Gradient of shape function in physical coordinates. + def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): + """Gradient of shape function in physical coordinates. - Arguments - --------- + Parameters + ---------- xi : float, array_like Locations of particles to evaluate in natural coordinates. 
- Expected shape (npoints, 1, ndim). + Expected shape `(npoints, 1, ndim)`. coords : array_like Nodal coordinates to transform by. Expected shape - (npoints, 1, ndim) + `(npoints, 1, ndim)` Returns ------- array_like Gradient of the shape function in physical coordinates at `xi` """ - if len(xi.shape) != 3: + xi = jnp.asarray(xi) + coords = jnp.asarray(coords) + if xi.ndim != 3: raise ValueError( f"`x` should be of size (npoints, 1, ndim); found {xi.shape}" ) @@ -445,8 +515,7 @@ def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray): return result def set_particle_element_ids(self, particles): - """ - Set the element IDs for the particles. + """Set the element IDs for the particles. If the particle doesn't lie between the boundaries of any element, it sets the element index to -1. @@ -472,20 +541,24 @@ def f(x): ) def compute_volume(self, *args): + """Compute volume of all elements.""" vol = jnp.ediff1d(self.nodes.loc) self.volume = jnp.ones((self.total_elements, 1, 1)) * vol def compute_internal_force(self, particles): - r""" - Update the nodal internal force based on particle mass. + r"""Update the nodal internal force based on particle mass. The nodal force is updated as a sum of internal forces for all particles mapped to the node. - :math:`(f_{int})_i = -\sum_p V_p * stress_p * \nabla N_i(x_p)` + \[ + (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p) + \] - Arguments - --------- + where \(\sigma_p\) is the stress at particle \(p\). + + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. """ @@ -529,45 +602,63 @@ def _step(pid, args): @register_pytree_node_class class Quadrilateral4Node(_Element): - """ - Container for 2D quadrilateral elements with 4 nodes. + r"""Container for 2D quadrilateral elements with 4 nodes. 
Nodes and elements are numbered as - 15 0---0---0---0---0 19 + 15 +---+---+---+---+ 19 | 8 | 9 | 10| 11| - 10 0---0---0---0---0 14 + 10 +---+---+---+---+ 14 | 4 | 5 | 6 | 7 | - 5 0---0---0---0---0 9 + 5 +---+---+---+---+ 9 | 0 | 1 | 2 | 3 | - 0---0---0---0---0 + +---+---+---+---+ 0 1 2 3 4 - + : Nodes - +---+ - | | : An element - +---+ + where + + + : Nodes + +---+ + | | : An element + +---+ """ def __init__( self, - nelements: Tuple[int, int], + nelements: int, total_elements: int, - el_len: Tuple[float, float], - constraints: List[Tuple[jnp.ndarray, Constraint]], - nodes: Nodes = None, - concentrated_nodal_forces=[], - initialized: bool = None, - volume: jnp.ndarray = None, - ): - """Initialize Quadrilateral4Node. - - Arguments - --------- - nelements : (int, int) - Number of elements in X and Y direction. - el_len : (float, float) - Length of each element in X and Y direction. + el_len: float, + constraints: Sequence[Tuple[ArrayLike, Constraint]], + nodes: Optional[Nodes] = None, + concentrated_nodal_forces: Sequence = [], + initialized: Optional[bool] = None, + volume: Optional[ArrayLike] = None, + ) -> None: + """Initialize Linear1D. + + Parameters + ---------- + nelements : int + Number of elements. + total_elements : int + Total number of elements (product of all elements of `nelements`) + el_len : float + Length of each element. + constraints: list + A list of constraints where each element is a tuple of + type `(node_ids, diffmpm.Constraint)`. Here, `node_ids` + correspond to the node IDs where `diffmpm.Constraint` + should be applied. + nodes : Nodes, Optional + Nodes in the element object. + concentrated_nodal_forces: list + A list of `diffmpm.forces.NodalForce`s that are to be + applied. + initialized: bool, None + `True` if the class has been initialized, `None` if not. + This is required like this for using JAX flattening. + volume: ArrayLike + Volume of the elements. 
""" self.nelements = jnp.asarray(nelements) self.el_len = jnp.asarray(el_len) @@ -578,15 +669,15 @@ def __init__( coords = jnp.asarray( list( itertools.product( - jnp.arange(nelements[1] + 1), - jnp.arange(nelements[0] + 1), + jnp.arange(self.nelements[1] + 1), + jnp.arange(self.nelements[0] + 1), ) ) ) node_locations = ( jnp.asarray([coords[:, 1], coords[:, 0]]).T * self.el_len ).reshape(-1, 1, 2) - self.nodes = Nodes(total_nodes, node_locations) + self.nodes = Nodes(int(total_nodes), node_locations) else: self.nodes = nodes @@ -595,12 +686,11 @@ def __init__( if initialized is None: self.volume = jnp.ones((self.total_elements, 1, 1)) else: - self.volume = volume + self.volume = jnp.asarray(volume) self.initialized = True - def id_to_node_ids(self, id: int): - """ - Node IDs corresponding to element `id`. + def id_to_node_ids(self, id: ArrayLike): + """Node IDs corresponding to element `id`. 3----2 | | @@ -608,16 +698,16 @@ def id_to_node_ids(self, id: int): Node ids are returned in the order as shown in the figure. - Arguments - --------- + Parameters + ---------- id : int Element ID. Returns ------- - jax.numpy.ndarray + ArrayLike Nodal IDs of the element. Shape of returned - array is (4, 1) + array is (4, 1) """ lower_left = (id // self.nelements[0]) * ( self.nelements[0] + 1 @@ -632,26 +722,26 @@ def id_to_node_ids(self, id: int): ) return result.reshape(4, 1) - def shapefn(self, xi: Sequence[float]): - """ - Evaluate linear shape function. + def shapefn(self, xi: ArrayLike): + """Evaluate linear shape function. - Arguments - --------- + Parameters + ---------- xi : float, array_like Locations of particles in natural coordinates to evaluate - the function at. Expected shape is (npoints, 1, ndim) + the function at. Expected shape is (npoints, 1, ndim) Returns ------- array_like Evaluated shape function values. The shape of the returned - array will depend on the input shape. 
For example, in the linear - case, if the input is a scalar, the returned array will be of - the shape (1, 4, 1) but if the input is a vector then the output will - be of the shape (len(x), 4, 1). + array will depend on the input shape. For example, in the linear + case, if the input is a scalar, the returned array will be of + the shape `(1, 4, 1)` but if the input is a vector then the output will + be of the shape `(len(x), 4, 1)`. """ - if len(xi.shape) != 3: + xi = jnp.asarray(xi) + if xi.ndim != 3: raise ValueError( f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}" ) @@ -666,27 +756,27 @@ def shapefn(self, xi: Sequence[float]): result = result.transpose(1, 0, 2)[..., jnp.newaxis] return result - def _shapefn_natural_grad(self, xi: float | jnp.ndarray): - """ - Calculate the gradient of shape function. + def _shapefn_natural_grad(self, xi: ArrayLike): + """Calculate the gradient of shape function. This calculation is done in the natural coordinates. - Arguments - --------- + Parameters + ---------- x : float, array_like Locations of particles in natural coordinates to evaluate - the function at. + the function at. Returns ------- array_like Evaluated gradient values of the shape function. The shape of - the returned array will depend on the input shape. For example, - in the linear case, if the input is a scalar, the returned array - will be of the shape (4, 2). + the returned array will depend on the input shape. For example, + in the linear case, if the input is a scalar, the returned array + will be of the shape `(4, 2)`. """ # result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze() + xi = jnp.asarray(xi) xi = xi.squeeze() result = jnp.array( [ @@ -698,25 +788,26 @@ def _shapefn_natural_grad(self, xi: float | jnp.ndarray): ) return result - def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray): - """ - Gradient of shape function in physical coordinates. 
+ def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike): + """Gradient of shape function in physical coordinates. - Arguments - --------- + Parameters + ---------- xi : float, array_like Locations of particles to evaluate in natural coordinates. - Expected shape (npoints, 1, ndim). + Expected shape `(npoints, 1, ndim)`. coords : array_like Nodal coordinates to transform by. Expected shape - (npoints, 1, ndim) + `(npoints, 1, ndim)` Returns ------- array_like Gradient of the shape function in physical coordinates at `xi` """ - if len(xi.shape) != 3: + xi = jnp.asarray(xi) + coords = jnp.asarray(coords) + if xi.ndim != 3: raise ValueError( f"`x` should be of size (npoints, 1, ndim); found {xi.shape}" ) @@ -726,9 +817,8 @@ def shapefn_grad(self, xi: float | jnp.ndarray, coords: jnp.ndarray): result = grad_sf @ jnp.linalg.inv(_jacobian).T return result - def set_particle_element_ids(self, particles): - """ - Set the element IDs for the particles. + def set_particle_element_ids(self, particles: Particles): + """Set the element IDs for the particles. If the particle doesn't lie between the boundaries of any element, it sets the element index to -1. @@ -749,17 +839,20 @@ def f(x): ids = vmap(f)(particles.loc) particles.element_ids = ids - def compute_internal_force(self, particles): - r""" - Update the nodal internal force based on particle mass. + def compute_internal_force(self, particles: Particles): + r"""Update the nodal internal force based on particle mass. The nodal force is updated as a sum of internal forces for all particles mapped to the node. - :math:`(f_{int})_i = -\sum_p V_p * stress_p * \nabla N_i(x_p)` + \[ + (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p) + \] + + where \(\sigma_p\) is the stress at particle \(p\). - Arguments - --------- + Parameters + ---------- particles: diffmpm.particle.Particles Particles to map to the nodal values. 
""" @@ -808,14 +901,9 @@ def _step(pid, args): self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args) def compute_volume(self, *args): + """Compute volume of all elements.""" a = c = self.el_len[1] b = d = self.el_len[0] p = q = jnp.sqrt(a**2 + b**2) vol = 0.25 * jnp.sqrt(4 * p * p * q * q - (a * a + c * c - b * b - d * d) ** 2) self.volume = self.volume.at[:].set(vol) - - -if __name__ == "__main__": - from diffmpm.utils import _show_example - - _show_example(Linear1D(2, 1, jnp.array([0]))) diff --git a/diffmpm/forces.py b/diffmpm/forces.py index eb6d27f..6740462 100644 --- a/diffmpm/forces.py +++ b/diffmpm/forces.py @@ -1,17 +1,61 @@ -from collections import namedtuple +from typing import Annotated, NamedTuple, get_type_hints + +from jax import Array from jax.tree_util import register_pytree_node -NodalForce = namedtuple("NodalForce", ("node_ids", "function", "dir", "force")) -ParticleTraction = namedtuple( - "ParticleTraction", ("pset", "pids", "function", "dir", "traction") -) +from diffmpm.functions import Function + + +class NodalForce(NamedTuple): + """Nodal Force being applied constantly on a set of nodes.""" + + node_ids: Annotated[Array, "Array of Node IDs to which force is applied."] + function: Annotated[ + Function, + "Mathematical function that governs time-varying changes in the force.", + ] + dir: Annotated[int, "Direction in which force is applied."] + force: Annotated[float, "Amount of force to be applied."] + + +nfhints = get_type_hints(NodalForce, include_extras=True) +for attr in nfhints: + getattr(NodalForce, attr).__doc__ = "".join(nfhints[attr].__metadata__) + + +class ParticleTraction(NamedTuple): + """Traction being applied on a set of particles.""" + + pset: Annotated[ + int, "The particle set in which traction is applied to the particles." 
+ ] + pids: Annotated[ + Array, + "Array of Particle IDs to which traction is applied inside the particle set.", + ] + function: Annotated[ + Function, + "Mathematical function that governs time-varying changes in the traction.", + ] + dir: Annotated[int, "Direction in which traction is applied."] + traction: Annotated[float, "Amount of traction to be applied."] + + +pthints = get_type_hints(ParticleTraction, include_extras=True) +for attr in pthints: + getattr(ParticleTraction, attr).__doc__ = "".join(pthints[attr].__metadata__) + register_pytree_node( NodalForce, - lambda xs: (tuple(xs), None), # tell JAX how to unpack to an iterable - lambda _, xs: NodalForce(*xs), # tell JAX how to pack back into a NodalForce + # tell JAX how to unpack to an iterable + lambda xs: (tuple(xs), None), # type: ignore + # tell JAX how to pack back into a NodalForce + lambda _, xs: NodalForce(*xs), # type: ignore ) register_pytree_node( ParticleTraction, - lambda xs: (tuple(xs), None), # tell JAX how to unpack to an iterable - lambda _, xs: ParticleTraction(*xs), # tell JAX how to pack back + # tell JAX how to unpack to an iterable + lambda xs: (tuple(xs), None), # type: ignore + # tell JAX how to pack back + lambda _, xs: ParticleTraction(*xs), # type: ignore ) diff --git a/diffmpm/functions.py b/diffmpm/functions.py index 44880dc..90b55c4 100644 --- a/diffmpm/functions.py +++ b/diffmpm/functions.py @@ -1,4 +1,5 @@ import abc + import jax.numpy as jnp from jax.tree_util import register_pytree_node_class diff --git a/diffmpm/io.py b/diffmpm/io.py index 358b142..d6e4573 100644 --- a/diffmpm/io.py +++ b/diffmpm/io.py @@ -9,7 +9,7 @@ from diffmpm import mesh as mpmesh from diffmpm.constraint import Constraint from diffmpm.forces import NodalForce, ParticleTraction -from diffmpm.functions import Unit, Linear +from diffmpm.functions import Linear, Unit from diffmpm.particle import Particles diff --git a/diffmpm/material.py b/diffmpm/material.py index 2dd8487..09230d4 100644 --- 
a/diffmpm/material.py +++ b/diffmpm/material.py @@ -1,19 +1,20 @@ -from jax.tree_util import register_pytree_node_class import abc +from typing import Tuple + import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class class Material(abc.ABC): """Base material class.""" - _props = () + _props: Tuple[str, ...] def __init__(self, material_properties): - """ - Initialize material properties. + """Initialize material properties. - Arguments - --------- + Parameters + ---------- material_properties: dict A key-value map for various material properties. """ @@ -57,11 +58,10 @@ class LinearElastic(Material): _props = ("density", "youngs_modulus", "poisson_ratio") def __init__(self, material_properties): - """ - Create a Linear Elastic material. + """Create a Linear Elastic material. - Arguments - --------- + Parameters + ---------- material_properties: dict Dictionary with material properties. For linear elastic materials, 'density' and 'youngs_modulus' are required keys. @@ -111,9 +111,7 @@ def _compute_elastic_tensor(self): ) def compute_stress(self, dstrain): - """ - Compute material stress. - """ + """Compute material stress.""" dstress = self.de @ dstrain return dstress @@ -131,9 +129,3 @@ def __repr__(self): def compute_stress(self, dstrain): return dstrain * self.properties["E"] - - -if __name__ == "__main__": - from diffmpm.utils import _show_example - - _show_example(SimpleMaterial({"E": 2, "density": 1})) diff --git a/diffmpm/mesh.py b/diffmpm/mesh.py index 6fa6abd..23bc6de 100644 --- a/diffmpm/mesh.py +++ b/diffmpm/mesh.py @@ -1,5 +1,5 @@ import abc -from typing import Iterable +from typing import Callable, Sequence, Tuple import jax.numpy as jnp from jax.tree_util import register_pytree_node_class @@ -7,40 +7,64 @@ from diffmpm.element import _Element from diffmpm.particle import Particles +__all__ = ["_MeshBase", "Mesh1D", "Mesh2D"] + class _MeshBase(abc.ABC): - """ - Base class for Meshes. + """Base class for Meshes. 
- Note: If attributes other than elements and particles are added - then the child class should also implement `tree_flatten` and - `tree_unflatten` correctly or that information will get lost. + .. note:: + If attributes other than elements and particles are added + then the child class should also implement `tree_flatten` and + `tree_unflatten` correctly or that information will get lost. """ + ndim: int + def __init__(self, config: dict): """Initialize mesh using configuration.""" - self.particles: Iterable[Particles, ...] = config["particles"] + self.particles: Sequence[Particles] = config["particles"] self.elements: _Element = config["elements"] self.particle_tractions = config["particle_surface_traction"] - @property - @abc.abstractmethod - def ndim(self): - ... - # TODO: Convert to using jax directives for loop - def apply_on_elements(self, function, args=()): + def apply_on_elements(self, function: str, args: Tuple = ()): + """Apply a given function to elements. + + Parameters + ---------- + function: str + A string corresponding to a function name in `_Element`. + args: tuple + Parameters to be passed to the function. + """ f = getattr(self.elements, function) for particle_set in self.particles: f(particle_set, *args) # TODO: Convert to using jax directives for loop - def apply_on_particles(self, function, args=()): + def apply_on_particles(self, function: str, args: Tuple = ()): + """Apply a given function to particles. + + Parameters + ---------- + function: str + A string corresponding to a function name in `Particles`. + args: tuple + Parameters to be passed to the function. + """ for particle_set in self.particles: f = getattr(particle_set, function) f(self.elements, *args) - def apply_traction_on_particles(self, curr_time): + def apply_traction_on_particles(self, curr_time: float): + """Apply tractions on particles. + + Parameters + ---------- + curr_time: float + Current time in the simulation. 
+ """ self.apply_on_particles("zero_traction") for ptraction in self.particle_tractions: factor = ptraction.function.value(curr_time) @@ -50,7 +74,6 @@ def apply_traction_on_particles(self, curr_time): ptraction.pids, ptraction.dir, traction_val ) - # breakpoint() self.apply_on_elements("apply_particle_traction_forces") def tree_flatten(self): @@ -74,52 +97,30 @@ class Mesh1D(_MeshBase): """1D Mesh class with nodes, elements, and particles.""" def __init__(self, config: dict): - """ - Initialize a 1D Mesh. + """Initialize a 1D Mesh. - Arguments - --------- + Parameters + ---------- config: dict Configuration to be used for initialization. It _should_ - contain `elements` and `particles` keys. + contain `elements` and `particles` keys. """ + self.ndim = 1 super().__init__(config) - @property - def ndim(self): - return 1 - @register_pytree_node_class class Mesh2D(_MeshBase): """1D Mesh class with nodes, elements, and particles.""" def __init__(self, config: dict): - """ - Initialize a 2D Mesh. + """Initialize a 2D Mesh. - Arguments - --------- + Parameters + ---------- config: dict Configuration to be used for initialization. It _should_ - contain `elements` and `particles` keys. + contain `elements` and `particles` keys. 
""" + self.ndim = 2 super().__init__(config) - - @property - def ndim(self): - return 2 - - -if __name__ == "__main__": - from diffmpm.element import Linear1D - from diffmpm.material import SimpleMaterial - from diffmpm.utils import _show_example - - particles = Particles( - jnp.array([[[1]]]), - SimpleMaterial({"E": 2, "density": 1}), - jnp.array([0]), - ) - elements = Linear1D(2, 1, jnp.array([0])) - _show_example(Mesh1D({"particles": [particles], "elements": elements})) diff --git a/diffmpm/node.py b/diffmpm/node.py index 14396a4..46e2a60 100644 --- a/diffmpm/node.py +++ b/diffmpm/node.py @@ -1,13 +1,13 @@ -from typing import Tuple +from typing import Optional, Sized, Tuple import jax.numpy as jnp from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike @register_pytree_node_class -class Nodes: - """ - Nodes container class. +class Nodes(Sized): + """Nodes container class. Keeps track of all values required for nodal points. @@ -15,50 +15,51 @@ class Nodes: ---------- nnodes : int Number of nodes stored. - loc : array_like + loc : ArrayLike Location of all the nodes. velocity : array_like Velocity of all the nodes. - mass : array_like + mass : ArrayLike Mass of all the nodes. momentum : array_like Momentum of all the nodes. - f_int : array_like + f_int : ArrayLike Internal forces on all the nodes. - f_ext : array_like + f_ext : ArrayLike External forces present on all the nodes. - f_damp : array_like + f_damp : ArrayLike Damping forces on the nodes. """ def __init__( self, nnodes: int, - loc: jnp.ndarray, - initialized: bool = None, - data: Tuple[jnp.ndarray, ...] = tuple(), + loc: ArrayLike, + initialized: Optional[bool] = None, + data: Tuple[ArrayLike, ...] = tuple(), ): - """ - Initialize container for Nodes. + """Initialize container for Nodes. Parameters ---------- nnodes : int Number of nodes stored. - loc : array_like + loc : ArrayLike Locations of all the nodes. 
Expected shape (nnodes, 1, ndim) initialized: bool - False if node property arrays like mass need to be initialized. - If True, they are set to values from `data`. + `False` if node property arrays like mass need to be initialized. + If `True`, they are set to values from `data`. data: tuple Tuple of length 7 that sets arrays for mass, density, volume, + and forces. Mainly used by JAX while unflattening. """ self.nnodes = nnodes - if len(loc.shape) != 3: + loc = jnp.asarray(loc, dtype=jnp.float32) + if loc.ndim != 3: raise ValueError( f"`loc` should be of size (nnodes, 1, ndim); found {loc.shape}" ) - self.loc = jnp.asarray(loc, dtype=jnp.float32) + self.loc = loc if initialized is None: self.velocity = jnp.zeros_like(self.loc, dtype=jnp.float32) @@ -77,11 +78,11 @@ def __init__( self.f_int, self.f_ext, self.f_damp, - ) = data + ) = data # type: ignore self.initialized = True def tree_flatten(self): - """Helper method for registering class as Pytree type.""" + """Flatten class as Pytree type.""" children = ( self.loc, self.initialized, @@ -98,9 +99,8 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): - return cls( - aux_data[0], children[0], initialized=children[1], data=children[2:] - ) + """Unflatten class from Pytree type.""" + return cls(aux_data[0], children[0], initialized=children[1], data=children[2:]) def reset_values(self): """Reset nodal parameter values except location.""" @@ -123,9 +123,3 @@ def __repr__(self): def get_total_force(self): """Calculate total force on the nodes.""" return self.f_int + self.f_ext + self.f_damp - - -if __name__ == "__main__": - from diffmpm.utils import _show_example - - _show_example(Nodes(2, jnp.array([1, 2]).reshape(2, 1, 1))) diff --git a/diffmpm/particle.py b/diffmpm/particle.py index 0df598c..1bb3d70 100644 --- a/diffmpm/particle.py +++ b/diffmpm/particle.py @@ -1,49 +1,50 @@ -from typing import Tuple +from typing import Optional, Sized, Tuple import jax.numpy as jnp -from jax import 
vmap, lax +from jax import lax, vmap from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike from diffmpm.element import _Element from diffmpm.material import Material @register_pytree_node_class -class Particles: +class Particles(Sized): """Container class for a set of particles.""" def __init__( self, - loc: jnp.ndarray, + loc: ArrayLike, material: Material, - element_ids: jnp.ndarray, - initialized: bool = None, - data: Tuple[jnp.ndarray, ...] = None, + element_ids: ArrayLike, + initialized: Optional[bool] = None, + data: Optional[Tuple[ArrayLike, ...]] = None, ): - """ - Initialize a container of particles. + """Initialize a container of particles. - Arguments - --------- - loc: jax.numpy.ndarray + Parameters + ---------- + loc: ArrayLike Location of the particles. Expected shape (nparticles, 1, ndim) material: diffmpm.material.Material Type of material for the set of particles. - element_ids: jax.numpy.ndarray + element_ids: ArrayLike The element ids that the particles belong to. This contains - information that will make sense only with the information of - the mesh that is being considered. + information that will make sense only with the information of + the mesh that is being considered. initialized: bool - False if particle property arrays like mass need to be initialized. - If True, they are set to values from `data`. + `False` if particle property arrays like mass need to be initialized. + If `True`, they are set to values from `data`. data: tuple Tuple of length 13 that sets arrays for mass, density, volume, - velocity, acceleration, momentum, strain, stress, strain_rate, - dstrain, f_ext, reference_loc and volumetric_strain_centroid. + velocity, acceleration, momentum, strain, stress, strain_rate, + dstrain, f_ext, reference_loc and volumetric_strain_centroid. 
""" self.material = material self.element_ids = element_ids - if len(loc.shape) != 3: + loc = jnp.asarray(loc, dtype=jnp.float32) + if loc.ndim != 3: raise ValueError( f"`loc` should be of size (nparticles, 1, ndim); " f"found {loc.shape}" ) @@ -86,10 +87,11 @@ def __init__( self.reference_loc, self.dvolumetric_strain, self.volumetric_strain_centroid, - ) = data + ) = data # type: ignore self.initialized = True def tree_flatten(self): + """Flatten class as Pytree type.""" children = ( self.loc, self.element_ids, @@ -116,6 +118,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): + """Unflatten class from Pytree type.""" return cls( children[0], aux_data[0], @@ -124,24 +127,24 @@ def tree_unflatten(cls, aux_data, children): data=children[3:], ) - def __len__(self): + def __len__(self) -> int: """Set length of the class as number of particles.""" return self.loc.shape[0] - def __repr__(self): + def __repr__(self) -> str: """Informative repr showing number of particles.""" return f"Particles(nparticles={len(self)})" - def set_mass_volume(self, m: float | jnp.ndarray): - """ - Set particle mass. + def set_mass_volume(self, m: ArrayLike): + """Set particle mass. - Arguments - --------- + Parameters + ---------- m: float, array_like Mass to be set for particles. If scalar, mass for all - particles is set to this value. + particles is set to this value. """ + m = jnp.asarray(m) if jnp.isscalar(m): self.mass = jnp.ones_like(self.loc) * m elif m.shape == self.mass.shape: @@ -152,12 +155,22 @@ def set_mass_volume(self, m: float | jnp.ndarray): ) self.volume = jnp.divide(self.mass, self.material.properties["density"]) - def compute_volume(self, elements, total_elements): + def compute_volume(self, elements: _Element, total_elements: int): + """Compute volume of all particles. + + Parameters + ---------- + elements: diffmpm._Element + Elements that the particles are present in, and are used to + compute the particles' volumes. 
+ total_elements: int + Total elements present in `elements`. + """ particles_per_element = jnp.bincount( self.element_ids, length=elements.total_elements ) vol = ( - elements.volume.squeeze((1, 2))[self.element_ids] + elements.volume.squeeze((1, 2))[self.element_ids] # type: ignore / particles_per_element[self.element_ids] ) self.volume = self.volume.at[:, 0, 0].set(vol) @@ -165,24 +178,26 @@ def compute_volume(self, elements, total_elements): self.mass = self.mass.at[:, 0, 0].set(vol * self.density.squeeze()) def update_natural_coords(self, elements: _Element): - """ - Update natural coordinates for the particles. + r"""Update natural coordinates for the particles. Whenever the particles' physical coordinates change, their natural coordinates need to be updated. This function updates the natural coordinates of the particles based on the element a particle is a part of. The update formula is - :math:`xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e)` + \[ + \xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e) + \] - If a particle is not in any element (element_id = -1), its - natural coordinate is set to 0. + where \(x_i^e\) are the nodal coordinates of the element the + particle is in. If a particle is not in any element + (element_id = -1), its natural coordinate is set to 0. - Arguments - --------- + Parameters + ---------- elements: diffmpm.element._Element Elements based on which to update the natural coordinates - of the particles. + of the particles. """ t = vmap(elements.id_to_node_loc)(self.element_ids) xi_coords = (self.loc - (t[:, 0, ...] + t[:, 2, ...]) / 2) * ( @@ -193,21 +208,20 @@ def update_natural_coords(self, elements: _Element): def update_position_velocity( self, elements: _Element, dt: float, velocity_update: bool ): - """ - Transfer nodal velocity to particles and update particle position. + """Transfer nodal velocity to particles and update particle position. The velocity is calculated based on the total force at nodes. 
- Arguments - --------- + Parameters + ---------- elements: diffmpm.element._Element Elements whose nodes are used to transfer the velocity. dt: float Timestep. velocity_update: bool If True, velocity is directly used as nodal velocity, else - velocity is calculated is interpolated nodal acceleration - multiplied by dt. Default is False. + velocity is calculated is interpolated nodal acceleration + multiplied by dt. Default is False. """ mapped_positions = elements.shapefn(self.reference_loc) mapped_ids = vmap(elements.id_to_node_ids)(self.element_ids).squeeze(-1) @@ -233,14 +247,13 @@ def update_position_velocity( self.momentum = self.momentum.at[:].set(self.mass * self.velocity) def compute_strain(self, elements: _Element, dt: float): - """ - Compute the strain on all particles. + """Compute the strain on all particles. This is done by first calculating the strain rate for the particles - and then calculating strain as strain += strain rate * dt. + and then calculating strain as `strain += strain rate * dt`. - Arguments - --------- + Parameters + ---------- elements: diffmpm.element._Element Elements whose nodes are used to calculate the strain. dt : float @@ -265,18 +278,18 @@ def compute_strain(self, elements: _Element, dt: float): self.dvolumetric_strain ) - def _compute_strain_rate(self, dn_dx: jnp.ndarray, elements: _Element): - """ - Compute the strain rate for particles. + def _compute_strain_rate(self, dn_dx: ArrayLike, elements: _Element): + """Compute the strain rate for particles. - Arguments - --------- - dn_dx: jnp.ndarray - The gradient of the shape function. - Expected shape (nparticles, 1, ndim) + Parameters + ---------- + dn_dx: ArrayLike + The gradient of the shape function. Expected shape + `(nparticles, 1, ndim)` elements: diffmpm.element._Element Elements whose nodes are used to calculate the strain rate. 
""" + dn_dx = jnp.asarray(dn_dx) strain_rate = jnp.zeros((dn_dx.shape[0], 6, 1)) # (nparticles, 6, 1) mapped_vel = vmap(elements.id_to_node_vel)( self.element_ids @@ -300,8 +313,7 @@ def _step(pid, args): return strain_rate def compute_stress(self, *args): - """ - Compute the strain on all particles. + """Compute the strain on all particles. This calculation is governed by the material of the particles. The stress calculated by the material is then @@ -314,23 +326,22 @@ def update_volume(self, *args): self.volume = self.volume.at[:, 0, :].multiply(1 + self.dvolumetric_strain) self.density = self.density.at[:, 0, :].divide(1 + self.dvolumetric_strain) - def assign_traction(self, pids, dir, traction_): + def assign_traction(self, pids: ArrayLike, dir: int, traction_: float): + """Assign traction to particles. + + Parameters + ---------- + pids: ArrayLike + IDs of the particles to which traction should be applied. + dir: int + The direction in which traction should be applied. + traction_: float + Traction value to be applied in the direction. 
+ """ self.traction = self.traction.at[pids, 0, dir].add( traction_ * self.volume[pids, 0, 0] / self.size[pids, 0, dir] ) def zero_traction(self, *args): + """Set all traction values to 0.""" self.traction = self.traction.at[:].set(0) - - -if __name__ == "__main__": - from diffmpm.material import SimpleMaterial - from diffmpm.utils import _show_example - - _show_example( - Particles( - jnp.array([[[1]]]), - SimpleMaterial({"E": 2, "density": 1}), - jnp.array([0]), - ) - ) diff --git a/diffmpm/scheme.py b/diffmpm/scheme.py index 83e35ca..61a062e 100644 --- a/diffmpm/scheme.py +++ b/diffmpm/scheme.py @@ -1,3 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from jax.typing import ArrayLike + +if TYPE_CHECKING: + import jax.numpy as jnp + from diffmpm.mesh import _MeshBase + import abc _schemes = ("usf", "usl") @@ -10,6 +20,7 @@ def __init__(self, mesh, dt, velocity_update): self.dt = dt def compute_nodal_kinematics(self): + """Compute nodal kinematics - map mass and momentum to mesh nodes.""" self.mesh.apply_on_elements("set_particle_element_ids") self.mesh.apply_on_particles("update_natural_coords") self.mesh.apply_on_elements("compute_nodal_mass") @@ -18,11 +29,23 @@ def compute_nodal_kinematics(self): self.mesh.apply_on_elements("apply_boundary_constraints") def compute_stress_strain(self): + """Compute stress and strain on the particles.""" self.mesh.apply_on_particles("compute_strain", args=(self.dt,)) self.mesh.apply_on_particles("update_volume") self.mesh.apply_on_particles("compute_stress") - def compute_forces(self, gravity, step): + def compute_forces(self, gravity: ArrayLike, step: int): + """Compute the forces acting in the system. + + Parameters + ---------- + gravity: ArrayLike + Gravity present in the system. This should be an array equal + with shape `(1, ndim)` where `ndim` is the dimension of the + simulation. + step: int + Current step being simulated. 
+ """ self.mesh.apply_on_elements("compute_external_force") self.mesh.apply_on_elements("compute_body_force", args=(gravity,)) self.mesh.apply_traction_on_particles(step * self.dt) @@ -33,6 +56,7 @@ def compute_forces(self, gravity, step): # self.mesh.apply_on_elements("apply_force_boundary_constraints") def compute_particle_kinematics(self): + """Compute particle location, acceleration and velocity.""" self.mesh.apply_on_elements( "update_nodal_acceleration_velocity", args=(self.dt,) ) @@ -43,24 +67,40 @@ def compute_particle_kinematics(self): # TODO: Apply particle velocity constraints. @abc.abstractmethod - def precompute_stress_strain(): + def precompute_stress_strain(self): ... @abc.abstractmethod - def postcompute_stress_strain(): + def postcompute_stress_strain(self): ... class USF(_MPMScheme): """USF Scheme solver.""" - def __init__(self, mesh, dt, velocity_update): + def __init__(self, mesh: _MeshBase, dt: float, velocity_update: bool): + """Initialize USF Scheme solver. + + Parameters + ---------- + mesh: _MeshBase + A `diffmpm.Mesh` object that contains the elements that form + the underlying mesh used to solve the simulation. + dt: float + Timestep used in the simulation. + velocity_update: bool + Flag to control if velocity should be updated using nodal + velocity or interpolated nodal acceleration. If `True`, nodal + velocity is used, else nodal acceleration. Default `False`. + """ super().__init__(mesh, dt, velocity_update) def precompute_stress_strain(self): + """Compute stress and strain on particles.""" self.compute_stress_strain() def postcompute_stress_strain(self): + """Compute stress and strain on particles. (Empty call for USF).""" pass @@ -68,10 +108,26 @@ class USL(_MPMScheme): """USL Scheme solver.""" def __init__(self, mesh, dt, velocity_update): + """Initialize USL Scheme solver. + + Parameters + ---------- + mesh: _MeshBase + A `diffmpm.Mesh` object that contains the elements that form + the underlying mesh used to solve the simulation. 
+ dt: float + Timestep used in the simulation. + velocity_update: bool + Flag to control if velocity should be updated using nodal + velocity or interpolated nodal acceleration. If `True`, nodal + velocity is used, else nodal acceleration. Default `False`. + """ super().__init__(mesh, dt, velocity_update) def precompute_stress_strain(self): + """Compute stress and strain on particles. (Empty call for USL).""" pass def postcompute_stress_strain(self): + """Compute stress and strain on particles.""" self.compute_stress_strain() diff --git a/diffmpm/solver.py b/diffmpm/solver.py index de4624e..3b1ae01 100644 --- a/diffmpm/solver.py +++ b/diffmpm/solver.py @@ -1,33 +1,73 @@ +from __future__ import annotations + import functools -from pathlib import Path +from typing import TYPE_CHECKING, Callable, Optional import jax.numpy as jnp from jax import lax from jax.experimental.host_callback import id_tap from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike + +from diffmpm.scheme import USF, USL, _MPMScheme, _schemes -from diffmpm.scheme import USF, USL, _schemes +if TYPE_CHECKING: + from diffmpm.mesh import _MeshBase @register_pytree_node_class class MPMExplicit: + """A class to implement the fully explicit MPM.""" + __particle_props = ("loc", "velocity", "stress", "strain") def __init__( self, - mesh, - dt, - scheme="usf", - velocity_update=False, - sim_steps=1, - out_steps=1, - out_dir="results/", - writer_func=None, - ): + mesh: _MeshBase, + dt: float, + scheme: str = "usf", + velocity_update: bool = False, + sim_steps: int = 1, + out_steps: int = 1, + out_dir: str = "results/", + writer_func: Optional[Callable] = None, + ) -> None: + """Create an `MPMExplicit` object. + + This can be used to solve a given configuration of an MPM + problem. + + Parameters + ---------- + mesh: _MeshBase + A `diffmpm.Mesh` object that contains the elements that form + the underlying mesh used to solve the simulation. 
+ dt: float + Timestep used in the simulation. + scheme: str + The MPM Scheme type used for the simulation. Can be one of + `"usl"` or `"usf"`. Default set to `"usf"`. + velocity_update: bool + Flag to control if velocity should be updated using nodal + velocity or interpolated nodal acceleration. If `True`, nodal + velocity is used, else nodal acceleration. Default `False`. + sim_steps: int + Number of steps to run the simulation for. Default set to 1. + out_steps: int + Frequency with which to store the results. For example, if + set to 5, the result at every 5th step will be stored. Default + set to 1. + out_dir: str + Path to the output directory where results are stored. + writer_func: Callable, None + Function that is used to write the state in the output + directory. + """ + if scheme == "usf": - self.mpm_scheme = USF(mesh, dt, velocity_update) + self.mpm_scheme: _MPMScheme = USF(mesh, dt, velocity_update) # type: ignore elif scheme == "usl": - self.mpm_scheme = USL(mesh, dt, velocity_update) + self.mpm_scheme: _MPMScheme = USL(mesh, dt, velocity_update) # type: ignore else: raise ValueError(f"Please select scheme from {_schemes}. Found {scheme}") self.mesh = mesh @@ -70,13 +110,35 @@ def tree_unflatten(cls, aux_data, children): writer_func=aux_data["writer_func"], ) - def jax_writer(self, func, args): + def _jax_writer(self, func, args): id_tap(func, args) - def solve(self, gravity: float | jnp.ndarray): + def solve(self, gravity: ArrayLike): + """Non-JIT solve method. + + This method runs the entire simulation for the defined number + of steps. + + .. note:: + This is mainly used for debugging and might be removed in + future versions or moved to the JIT solver. + + Parameters + ---------- + gravity: ArrayLike + Gravity present in the system. This should be an array equal + with shape `(1, ndim)` where `ndim` is the dimension of the + simulation. 
+ + Returns + ------- + dict + A dictionary of `ArrayLike` arrays corresponding to the + all states of the simulation after completing all steps. + """ from collections import defaultdict - from tqdm import tqdm + from tqdm import tqdm # type: ignore result = defaultdict(list) for step in tqdm(range(self.sim_steps)): @@ -91,10 +153,29 @@ def solve(self, gravity: float | jnp.ndarray): result["stress"].append(pset.stress[:, :2, 0]) result["strain"].append(pset.strain[:, :2, 0]) - result = {k: jnp.asarray(v) for k, v in result.items()} - return result + result_arr = {k: jnp.asarray(v) for k, v in result.items()} + return result_arr + + def solve_jit(self, gravity: ArrayLike) -> dict: + """Solver method that runs the simulation. + + This method runs the entire simulation for the defined number + of steps. + + Parameters + ---------- + gravity: ArrayLike + Gravity present in the system. This should be an array equal + with shape `(1, ndim)` where `ndim` is the dimension of the + simulation. + + Returns + ------- + dict + A dictionary of `jax.numpy` arrays corresponding to the + final state of the simulation after completing all steps. 
+ """ - def solve_jit(self, gravity: float | jnp.ndarray): def _step(i, data): self = data self.mpm_scheme.compute_nodal_kinematics() @@ -112,7 +193,7 @@ def _write(self, i): for j in range(len(self.mesh.particles)) ] ) - self.jax_writer( + self._jax_writer( functools.partial( self.writer_func, out_dir=self.out_dir, max_steps=self.sim_steps ), diff --git a/diffmpm/utils.py b/diffmpm/utils.py deleted file mode 100644 index 6559036..0000000 --- a/diffmpm/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -from jax.tree_util import tree_flatten, tree_unflatten - - -def _show_example(structured): - flat, tree = tree_flatten(structured) - unflattened = tree_unflatten(tree, flat) - print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}") diff --git a/diffmpm/writers.py b/diffmpm/writers.py index 4038b52..fdc5cd2 100644 --- a/diffmpm/writers.py +++ b/diffmpm/writers.py @@ -1,24 +1,45 @@ import abc import logging -import numpy as np from pathlib import Path +from typing import Tuple, Annotated, Any +from jax.typing import ArrayLike +import numpy as np + logger = logging.getLogger(__file__) +__all__ = ["_Writer", "EmptyWriter", "NPZWriter"] + + +class _Writer(abc.ABC): + """Base writer class.""" -class Writer(abc.ABC): @abc.abstractmethod def write(self): ... 
-class EmptyWriter(Writer): +class EmptyWriter(_Writer): + """Empty writer used when output is not to be written.""" + def write(self, args, transforms, **kwargs): + """Empty function.""" pass -class NPZWriter(Writer): - def write(self, args, transforms, **kwargs): +class NPZWriter(_Writer): + """Writer to write output in `.npz` format.""" + + def write( + self, + args: Tuple[ + Annotated[ArrayLike, "JAX arrays to be written"], + Annotated[int, "step number of the simulation"], + ], + transforms: Any, + **kwargs, + ): + """Writes the output arrays as `.npz` files.""" arrays, step = args max_digits = int(np.log10(kwargs["max_steps"])) + 1 if step == 0: diff --git a/docs/diffmpm/cli/index.html b/docs/diffmpm/cli/index.html new file mode 100644 index 0000000..493c83c --- /dev/null +++ b/docs/diffmpm/cli/index.html @@ -0,0 +1,66 @@ + + +
+ + + +diffmpm.clidiffmpm.cli.mpmdiffmpm.cli.mpmimport click
+
+from diffmpm import MPM
+
+
+@click.command() # type: ignore
+@click.option(
+ "-f", "--file", "filepath", required=True, type=str, help="Input TOML file"
+)
+@click.version_option(package_name="diffmpm")
+def mpm(filepath):
+ """CLI utility for DiffMPM."""
+ solver = MPM(filepath)
+ solver.solve()
+diffmpm.constraintfrom jax.tree_util import register_pytree_node_class
+
+
+@register_pytree_node_class
+class Constraint:
+ """Generic velocity constraints to apply on nodes or particles."""
+
+ def __init__(self, dir: int, velocity: float):
+ """Contains 2 governing parameters.
+
+ Attributes
+ ----------
+ dir : int
+ Direction in which constraint is applied.
+ velocity : float
+ Constrained velocity to be applied.
+ """
+ self.dir = dir
+ self.velocity = velocity
+
+ def tree_flatten(self):
+ return ((), (self.dir, self.velocity))
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ del children
+ return cls(*aux_data)
+
+ def apply(self, obj, ids):
+ """Apply constraint values to the passed object.
+
+ Parameters
+ ----------
+ obj : diffmpm.node.Nodes, diffmpm.particle.Particles
+ Object on which the constraint is applied
+ ids : array_like
+ The indices of the container `obj` on which the constraint
+ will be applied.
+ """
+ obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity)
+ obj.momentum = obj.momentum.at[ids, :, self.dir].set(
+ obj.mass[ids, :, 0] * self.velocity
+ )
+ obj.acceleration = obj.acceleration.at[ids, :, self.dir].set(0)
+
+class Constraint
+(dir: int, velocity: float)
+Generic velocity constraints to apply on nodes or particles.
+Contains 2 governing parameters.
+dir : intvelocity : float@register_pytree_node_class
+class Constraint:
+ """Generic velocity constraints to apply on nodes or particles."""
+
+ def __init__(self, dir: int, velocity: float):
+ """Contains 2 governing parameters.
+
+ Attributes
+ ----------
+ dir : int
+ Direction in which constraint is applied.
+ velocity : float
+ Constrained velocity to be applied.
+ """
+ self.dir = dir
+ self.velocity = velocity
+
+ def tree_flatten(self):
+ return ((), (self.dir, self.velocity))
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ del children
+ return cls(*aux_data)
+
+ def apply(self, obj, ids):
+ """Apply constraint values to the passed object.
+
+ Parameters
+ ----------
+ obj : diffmpm.node.Nodes, diffmpm.particle.Particles
+ Object on which the constraint is applied
+ ids : array_like
+ The indices of the container `obj` on which the constraint
+ will be applied.
+ """
+ obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity)
+ obj.momentum = obj.momentum.at[ids, :, self.dir].set(
+ obj.mass[ids, :, 0] * self.velocity
+ )
+ obj.acceleration = obj.acceleration.at[ids, :, self.dir].set(0)
+
+def tree_unflatten(aux_data, children)
+@classmethod
+def tree_unflatten(cls, aux_data, children):
+ del children
+ return cls(*aux_data)
+
+def apply(self, obj, ids)
+Apply constraint values to the passed object.
+def apply(self, obj, ids):
+ """Apply constraint values to the passed object.
+
+ Parameters
+ ----------
+ obj : diffmpm.node.Nodes, diffmpm.particle.Particles
+ Object on which the constraint is applied
+ ids : array_like
+ The indices of the container `obj` on which the constraint
+ will be applied.
+ """
+ obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity)
+ obj.momentum = obj.momentum.at[ids, :, self.dir].set(
+ obj.mass[ids, :, 0] * self.velocity
+ )
+ obj.acceleration = obj.acceleration.at[ids, :, self.dir].set(0)
+
+def tree_flatten(self)
+def tree_flatten(self):
+ return ((), (self.dir, self.velocity))
+diffmpm.elementfrom __future__ import annotations
+
+import abc
+import itertools
+from typing import TYPE_CHECKING, Optional, Sequence, Tuple
+
+if TYPE_CHECKING:
+ from diffmpm.particle import Particles
+
+import jax.numpy as jnp
+from jax import Array, jacobian, jit, lax, vmap
+from jax.tree_util import register_pytree_node_class
+from jax.typing import ArrayLike
+
+from diffmpm.constraint import Constraint
+from diffmpm.node import Nodes
+
+__all__ = ["_Element", "Linear1D", "Quadrilateral4Node"]
+
+
+class _Element(abc.ABC):
+ """Base element class that is inherited by all types of Elements."""
+
+ nodes: Nodes
+ total_elements: int
+ concentrated_nodal_forces: Sequence
+ volume: Array
+
+ @abc.abstractmethod
+ def id_to_node_ids(self, id: ArrayLike) -> Array:
+ """Node IDs corresponding to element `id`.
+
+ This method is implemented by each of the subclass.
+
+ Parameters
+ ----------
+ id : int
+ Element ID.
+
+ Returns
+ -------
+ ArrayLike
+ Nodal IDs of the element.
+ """
+ ...
+
+ def id_to_node_loc(self, id: ArrayLike) -> Array:
+ """Node locations corresponding to element `id`.
+
+ Parameters
+ ----------
+ id : int
+ Element ID.
+
+ Returns
+ -------
+ ArrayLike
+ Nodal locations for the element. Shape of returned
+ array is `(nodes_in_element, 1, ndim)`
+ """
+ node_ids = self.id_to_node_ids(id).squeeze()
+ return self.nodes.loc[node_ids]
+
+ def id_to_node_vel(self, id: ArrayLike) -> Array:
+ """Node velocities corresponding to element `id`.
+
+ Parameters
+ ----------
+ id : int
+ Element ID.
+
+ Returns
+ -------
+ ArrayLike
+ Nodal velocities for the element. Shape of returned
+ array is `(nodes_in_element, 1, ndim)`
+ """
+ node_ids = self.id_to_node_ids(id).squeeze()
+ return self.nodes.velocity[node_ids]
+
+ def tree_flatten(self):
+ children = (self.nodes, self.volume)
+ aux_data = (
+ self.nelements,
+ self.total_elements,
+ self.el_len,
+ self.constraints,
+ self.concentrated_nodal_forces,
+ self.initialized,
+ )
+ return children, aux_data
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ return cls(
+ aux_data[0],
+ aux_data[1],
+ aux_data[2],
+ aux_data[3],
+ nodes=children[0],
+ concentrated_nodal_forces=aux_data[4],
+ initialized=aux_data[5],
+ volume=children[1],
+ )
+
+ @abc.abstractmethod
+ def shapefn(self, xi: ArrayLike):
+ """Evaluate Shape function for element type."""
+ ...
+
+ @abc.abstractmethod
+ def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
+ """Evaluate gradient of shape function for element type."""
+ ...
+
+ @abc.abstractmethod
+ def set_particle_element_ids(self, particles: Particles):
+ """Set the element IDs that particles are present in."""
+ ...
+
+ # Mapping from particles to nodes (P2G)
+ def compute_nodal_mass(self, particles: Particles):
+ r"""Compute the nodal mass based on particle mass.
+
+ The nodal mass is updated as a sum of particle mass for
+ all particles mapped to the node.
+
+ \[
+ (m)_i = \sum_p N_i(x_p) m_p
+ \]
+
+ Parameters
+ ----------
+ particles: diffmpm.particle.Particles
+ Particles to map to the nodal values.
+ """
+
+ def _step(pid, args):
+ pmass, mass, mapped_pos, el_nodes = args
+ mass = mass.at[el_nodes[pid]].add(pmass[pid] * mapped_pos[pid])
+ return pmass, mass, mapped_pos, el_nodes
+
+ self.nodes.mass = self.nodes.mass.at[:].set(0)
+ mapped_positions = self.shapefn(particles.reference_loc)
+ mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
+ args = (
+ particles.mass,
+ self.nodes.mass,
+ mapped_positions,
+ mapped_nodes,
+ )
+ _, self.nodes.mass, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+ def compute_nodal_momentum(self, particles: Particles):
+ r"""Compute the nodal momentum based on particle momentum.
+
+ The nodal momentum is updated as a sum of particle momentum for
+ all particles mapped to the node.
+
+ \[
+ (mv)_i = \sum_p N_i(x_p) (mv)_p
+ \]
+
+ Parameters
+ ----------
+ particles: diffmpm.particle.Particles
+ Particles to map to the nodal values.
+ """
+
+ def _step(pid, args):
+ pmom, mom, mapped_pos, el_nodes = args
+ mom = mom.at[el_nodes[pid]].add(mapped_pos[pid] @ pmom[pid])
+ return pmom, mom, mapped_pos, el_nodes
+
+ self.nodes.momentum = self.nodes.momentum.at[:].set(0)
+ mapped_positions = self.shapefn(particles.reference_loc)
+ mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
+ args = (
+ particles.mass * particles.velocity,
+ self.nodes.momentum,
+ mapped_positions,
+ mapped_nodes,
+ )
+ _, self.nodes.momentum, _, _ = lax.fori_loop(0, len(particles), _step, args)
+ self.nodes.momentum = jnp.where(
+ jnp.abs(self.nodes.momentum) < 1e-12,
+ jnp.zeros_like(self.nodes.momentum),
+ self.nodes.momentum,
+ )
+
+ def compute_velocity(self, particles: Particles):
+ """Compute velocity using momentum."""
+ self.nodes.velocity = jnp.where(
+ self.nodes.mass == 0,
+ self.nodes.velocity,
+ self.nodes.momentum / self.nodes.mass,
+ )
+ self.nodes.velocity = jnp.where(
+ jnp.abs(self.nodes.velocity) < 1e-12,
+ jnp.zeros_like(self.nodes.velocity),
+ self.nodes.velocity,
+ )
+
    def compute_external_force(self, particles: Particles):
        r"""Update the nodal external force based on particle f_ext.

        The nodal force is updated as a sum of particle external
        force for all particles mapped to the node.

        \[
            (f_{ext})_i = \sum_p N_i(x_p) (f_{ext})_p
        \]

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        """

        def _step(pid, args):
            # Scatter-add particle `pid`'s external force to its element's
            # nodes, weighted by the shape-function values at the particle.
            f_ext, pf_ext, mapped_pos, el_nodes = args
            f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ pf_ext[pid])
            return f_ext, pf_ext, mapped_pos, el_nodes

        # Zero out nodal external force before accumulating.
        self.nodes.f_ext = self.nodes.f_ext.at[:].set(0)
        mapped_positions = self.shapefn(particles.reference_loc)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        args = (
            self.nodes.f_ext,
            particles.f_ext,
            mapped_positions,
            mapped_nodes,
        )
        self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
    def compute_body_force(self, particles: Particles, gravity: ArrayLike):
        r"""Update the nodal external force based on particle mass.

        The nodal force is updated as a sum of particle body
        force for all particles mapped to the node.

        \[
            (f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g
        \]

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        gravity: ArrayLike
            Acceleration due to gravity used to compute the body force.
        """

        def _step(pid, args):
            # Scatter-add the body force (mass * gravity) of particle `pid`
            # to its element's nodes, weighted by shape-function values.
            f_ext, pmass, mapped_pos, el_nodes, gravity = args
            f_ext = f_ext.at[el_nodes[pid]].add(
                mapped_pos[pid] @ (pmass[pid] * gravity)
            )
            return f_ext, pmass, mapped_pos, el_nodes, gravity

        # NOTE: unlike the other P2G maps, f_ext is NOT reset here — the body
        # force accumulates on top of the already-computed external force.
        mapped_positions = self.shapefn(particles.reference_loc)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        args = (
            self.nodes.f_ext,
            particles.mass,
            mapped_positions,
            mapped_nodes,
            gravity,
        )
        self.nodes.f_ext, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+ def apply_concentrated_nodal_forces(self, particles: Particles, curr_time: float):
+ """Apply concentrated nodal forces.
+
+ Parameters
+ ----------
+ particles: Particles
+ Particles in the simulation.
+ curr_time: float
+ Current time in the simulation.
+ """
+ for cnf in self.concentrated_nodal_forces:
+ factor = cnf.function.value(curr_time)
+ self.nodes.f_ext = self.nodes.f_ext.at[cnf.node_ids, 0, cnf.dir].add(
+ factor * cnf.force
+ )
+
    def apply_particle_traction_forces(self, particles: Particles):
        """Apply particle traction forces to the nodal external force.

        Parameters
        ----------
        particles: Particles
            Particles in the simulation.
        """

        def _step(pid, args):
            # Scatter-add particle `pid`'s traction to its element's nodes,
            # weighted by the shape-function values at the particle.
            f_ext, ptraction, mapped_pos, el_nodes = args
            f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid])
            return f_ext, ptraction, mapped_pos, el_nodes

        # Accumulates on top of the current f_ext (no reset here).
        mapped_positions = self.shapefn(particles.reference_loc)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        args = (self.nodes.f_ext, particles.traction, mapped_positions, mapped_nodes)
        self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
    def update_nodal_acceleration_velocity(
        self, particles: Particles, dt: float, *args
    ):
        """Update the nodal momentum based on total force on nodes.

        Computes nodal acceleration from the total force, integrates the
        velocity over one step, re-applies velocity boundary constraints,
        and recomputes the nodal momentum from the constrained velocity.

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles in the simulation (not used by this method; kept for
            a uniform update interface).
        dt: float
            Timestep size.
        """
        total_force = self.nodes.get_total_force()
        # nan_to_num guards against 0/0 on nodes that carry no mass.
        self.nodes.acceleration = self.nodes.acceleration.at[:].set(
            jnp.nan_to_num(jnp.divide(total_force, self.nodes.mass))
        )
        # Explicit (forward Euler) velocity update.
        self.nodes.velocity = self.nodes.velocity.at[:].add(
            self.nodes.acceleration * dt
        )
        # Constraints are applied before recomputing momentum so the momentum
        # below reflects the constrained velocity.
        self.apply_boundary_constraints()
        self.nodes.momentum = self.nodes.momentum.at[:].set(
            self.nodes.mass * self.nodes.velocity
        )
        # Flush tiny magnitudes to exactly zero to suppress numerical noise.
        self.nodes.velocity = jnp.where(
            jnp.abs(self.nodes.velocity) < 1e-12,
            jnp.zeros_like(self.nodes.velocity),
            self.nodes.velocity,
        )
        self.nodes.acceleration = jnp.where(
            jnp.abs(self.nodes.acceleration) < 1e-12,
            jnp.zeros_like(self.nodes.acceleration),
            self.nodes.acceleration,
        )
+
+ def apply_boundary_constraints(self, *args):
+ """Apply boundary conditions for nodal velocity."""
+ for ids, constraint in self.constraints:
+ constraint.apply(self.nodes, ids)
+
+ def apply_force_boundary_constraints(self, *args):
+ """Apply boundary conditions for nodal forces."""
+ self.nodes.f_int = self.nodes.f_int.at[self.constraints[0][0]].set(0)
+ self.nodes.f_ext = self.nodes.f_ext.at[self.constraints[0][0]].set(0)
+ self.nodes.f_damp = self.nodes.f_damp.at[self.constraints[0][0]].set(0)
+
+
@register_pytree_node_class
class Linear1D(_Element):
    """Container for 1D line elements (and nodes).

    Element ID:    0     1     2     3
    Mesh:       +-----+-----+-----+-----+
    Node IDs:   0     1     2     3     4

    where

        +       : Nodes
        +-----+ : An element

    """

    def __init__(
        self,
        nelements: int,
        total_elements: int,
        el_len: float,
        constraints: Sequence[Tuple[ArrayLike, Constraint]],
        nodes: Optional[Nodes] = None,
        concentrated_nodal_forces: Sequence = [],
        initialized: Optional[bool] = None,
        volume: Optional[ArrayLike] = None,
    ):
        """Initialize Linear1D.

        Parameters
        ----------
        nelements : int
            Number of elements.
        total_elements : int
            Total number of elements (same as `nelements` for 1D)
        el_len : float
            Length of each element.
        constraints: list
            A list of constraints where each element is a tuple of type
            `(node_ids, diffmpm.Constraint)`. Here, `node_ids` correspond to
            the node IDs where `diffmpm.Constraint` should be applied.
        nodes : Nodes, Optional
            Nodes in the element object.
        concentrated_nodal_forces: list
            A list of `diffmpm.forces.NodalForce`s that are to be
            applied.
        initialized: bool, None
            `True` if the class has been initialized, `None` if not.
            This is required like this for using JAX flattening.
        volume: ArrayLike
            Volume of the elements.
        """
        self.nelements = nelements
        # NOTE(review): the `total_elements` argument is ignored and
        # `nelements` is stored instead (the two coincide for 1D) —
        # confirm this is intentional.
        self.total_elements = nelements
        self.el_len = el_len
        if nodes is None:
            # Evenly spaced nodes along the line; locations have shape
            # (nnodes, 1, 1).
            self.nodes = Nodes(
                nelements + 1,
                jnp.arange(nelements + 1).reshape(-1, 1, 1) * el_len,
            )
        else:
            self.nodes = nodes

        # self.boundary_nodes = boundary_nodes
        self.constraints = constraints
        self.concentrated_nodal_forces = concentrated_nodal_forces
        if initialized is None:
            self.volume = jnp.ones((self.total_elements, 1, 1))
        else:
            self.volume = jnp.asarray(volume)
        self.initialized = True

    def id_to_node_ids(self, id: ArrayLike):
        """Node IDs corresponding to element `id`.

        Parameters
        ----------
        id : int
            Element ID.

        Returns
        -------
        ArrayLike
            Nodal IDs of the element. Shape of returned
            array is `(2, 1)`
        """
        # Element i spans nodes i and i + 1.
        return jnp.array([id, id + 1]).reshape(2, 1)

    def shapefn(self, xi: ArrayLike):
        """Evaluate linear shape function.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles in natural coordinates to evaluate
            the function at. Expected shape is `(npoints, 1, ndim)`

        Returns
        -------
        array_like
            Evaluated shape function values. The shape of the returned
            array will depend on the input shape. For example, in the linear
            case, if the input is a scalar, the returned array will be of
            the shape `(1, 2, 1)` but if the input is a vector then the output will
            be of the shape `(len(x), 2, 1)`.
        """
        xi = jnp.asarray(xi)
        if xi.ndim != 3:
            raise ValueError(
                f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}"
            )
        # Standard linear (hat) functions N1 = (1 - xi)/2, N2 = (1 + xi)/2;
        # transpose brings the per-point axis to the front.
        result = jnp.array([0.5 * (1 - xi), 0.5 * (1 + xi)]).transpose(1, 0, 2, 3)
        return result

    def _shapefn_natural_grad(self, xi: ArrayLike):
        """Calculate the gradient of shape function.

        This calculation is done in the natural coordinates.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles in natural coordinates to evaluate
            the function at.

        Returns
        -------
        array_like
            Evaluated gradient values of the shape function. The shape of
            the returned array will depend on the input shape. For example,
            in the linear case, if the input is a scalar, the returned array
            will be of the shape `(2, 1)`.
        """
        xi = jnp.asarray(xi)
        # Autodiff of the shape functions w.r.t. xi, vectorized per point.
        result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze()

        # TODO: The following code tries to evaluate vmap even if
        # the predicate condition is true, not sure why.
        # result = lax.cond(
        #     jnp.isscalar(x),
        #     jacobian(self.shapefn),
        #     vmap(jacobian(self.shapefn)),
        #     xi
        # )
        # NOTE(review): the reshape to (2, 1) assumes a single evaluation
        # point — confirm behavior for batched `xi`.
        return result.reshape(2, 1)

    def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
        """Gradient of shape function in physical coordinates.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles to evaluate in natural coordinates.
            Expected shape `(npoints, 1, ndim)`.
        coords : array_like
            Nodal coordinates to transform by. Expected shape
            `(npoints, 1, ndim)`

        Returns
        -------
        array_like
            Gradient of the shape function in physical coordinates at `xi`
        """
        xi = jnp.asarray(xi)
        coords = jnp.asarray(coords)
        if xi.ndim != 3:
            raise ValueError(
                f"`x` should be of size (npoints, 1, ndim); found {xi.shape}"
            )
        grad_sf = self._shapefn_natural_grad(xi)
        # Jacobian of the natural -> physical coordinate mapping.
        _jacobian = grad_sf.T @ coords

        result = grad_sf @ jnp.linalg.inv(_jacobian).T
        return result

    def set_particle_element_ids(self, particles):
        """Set the element IDs for the particles.

        If the particle doesn't lie between the boundaries of any
        element, it sets the element index to -1.
        """

        @jit
        def f(x):
            # Index of the last node with coordinate <= x (search from the
            # right so `nonzero(size=1)` can find it in one pass).
            idl = (
                len(self.nodes.loc)
                - 1
                - jnp.asarray(self.nodes.loc[::-1] <= x).nonzero(size=1, fill_value=-1)[
                    0
                ][-1]
            )
            # Index of the element just below the first node with
            # coordinate > x.
            idg = (
                jnp.asarray(self.nodes.loc > x).nonzero(size=1, fill_value=-1)[0][0] - 1
            )
            return (idl, idg)

        ids = vmap(f)(particles.loc)
        # A particle is inside an element only when both searches agree;
        # otherwise it is out of bounds and marked with -1.
        particles.element_ids = jnp.where(
            ids[0] == ids[1], ids[0], jnp.ones_like(ids[0]) * -1
        )

    def compute_volume(self, *args):
        """Compute volume of all elements."""
        # 1D element "volume" is simply the spacing between adjacent nodes.
        vol = jnp.ediff1d(self.nodes.loc)
        self.volume = jnp.ones((self.total_elements, 1, 1)) * vol

    def compute_internal_force(self, particles):
        r"""Update the nodal internal force based on particle volume and stress.

        The nodal force is updated as a sum of internal forces for
        all particles mapped to the node.

        \[
            (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p)
        \]

        where \(\sigma_p\) is the stress at particle \(p\).

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        """

        def _step(pid, args):
            (
                f_int,
                pvol,
                mapped_grads,
                el_nodes,
                pstress,
            ) = args
            # TODO: correct matrix multiplication for n-d
            # update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
            # 1D: only the first stress component contributes.
            update = -pvol[pid] * pstress[pid][0] * mapped_grads[pid]
            f_int = f_int.at[el_nodes[pid]].add(update[..., jnp.newaxis])
            return (
                f_int,
                pvol,
                mapped_grads,
                el_nodes,
                pstress,
            )

        # Zero out nodal internal force before accumulating.
        self.nodes.f_int = self.nodes.f_int.at[:].set(0)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2)
        mapped_grads = vmap(self.shapefn_grad)(
            particles.reference_loc[:, jnp.newaxis, ...],
            mapped_coords,
        )
        args = (
            self.nodes.f_int,
            particles.volume,
            mapped_grads,
            mapped_nodes,
            particles.stress,
        )
        self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+
@register_pytree_node_class
class Quadrilateral4Node(_Element):
    r"""Container for 2D quadrilateral elements with 4 nodes.

    Nodes and elements are numbered as

        15 +---+---+---+---+ 19
           | 8 | 9 | 10| 11|
        10 +---+---+---+---+ 14
           | 4 | 5 | 6 | 7 |
         5 +---+---+---+---+ 9
           | 0 | 1 | 2 | 3 |
           +---+---+---+---+
           0   1   2   3   4

    where

        +   : Nodes
        +---+
        |   | : An element
        +---+
    """

    def __init__(
        self,
        nelements: int,
        total_elements: int,
        el_len: float,
        constraints: Sequence[Tuple[ArrayLike, Constraint]],
        nodes: Optional[Nodes] = None,
        concentrated_nodal_forces: Sequence = [],
        initialized: Optional[bool] = None,
        volume: Optional[ArrayLike] = None,
    ) -> None:
        """Initialize Quadrilateral4Node.

        Parameters
        ----------
        nelements : int
            Number of elements (per direction).
        total_elements : int
            Total number of elements (product of all elements of `nelements`)
        el_len : float
            Length of each element (per direction).
        constraints: list
            A list of constraints where each element is a tuple of
            type `(node_ids, diffmpm.Constraint)`. Here, `node_ids`
            correspond to the node IDs where `diffmpm.Constraint`
            should be applied.
        nodes : Nodes, Optional
            Nodes in the element object.
        concentrated_nodal_forces: list
            A list of `diffmpm.forces.NodalForce`s that are to be
            applied.
        initialized: bool, None
            `True` if the class has been initialized, `None` if not.
            This is required like this for using JAX flattening.
        volume: ArrayLike
            Volume of the elements.
        """
        self.nelements = jnp.asarray(nelements)
        self.el_len = jnp.asarray(el_len)
        self.total_elements = total_elements

        if nodes is None:
            total_nodes = jnp.prod(self.nelements + 1)
            # itertools.product varies the second range fastest, so node
            # numbering runs along the x-direction first (see class diagram).
            coords = jnp.asarray(
                list(
                    itertools.product(
                        jnp.arange(self.nelements[1] + 1),
                        jnp.arange(self.nelements[0] + 1),
                    )
                )
            )
            # Swap columns to (x, y) order and scale by element length;
            # locations have shape (nnodes, 1, 2).
            node_locations = (
                jnp.asarray([coords[:, 1], coords[:, 0]]).T * self.el_len
            ).reshape(-1, 1, 2)
            self.nodes = Nodes(int(total_nodes), node_locations)
        else:
            self.nodes = nodes

        self.constraints = constraints
        self.concentrated_nodal_forces = concentrated_nodal_forces
        if initialized is None:
            self.volume = jnp.ones((self.total_elements, 1, 1))
        else:
            self.volume = jnp.asarray(volume)
        self.initialized = True

    def id_to_node_ids(self, id: ArrayLike):
        """Node IDs corresponding to element `id`.

            3----2
            |    |
            0----1

        Node ids are returned in the order as shown in the figure.

        Parameters
        ----------
        id : int
            Element ID.

        Returns
        -------
        ArrayLike
            Nodal IDs of the element. Shape of returned
            array is (4, 1)
        """
        # Each row of elements spans (nelements[0] + 1) nodes, hence the
        # row/column arithmetic below.
        lower_left = (id // self.nelements[0]) * (
            self.nelements[0] + 1
        ) + id % self.nelements[0]
        # Counter-clockwise order: lower-left, lower-right, upper-right,
        # upper-left (matches the figure above).
        result = jnp.asarray(
            [
                lower_left,
                lower_left + 1,
                lower_left + self.nelements[0] + 2,
                lower_left + self.nelements[0] + 1,
            ]
        )
        return result.reshape(4, 1)

    def shapefn(self, xi: ArrayLike):
        """Evaluate linear shape function.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles in natural coordinates to evaluate
            the function at. Expected shape is (npoints, 1, ndim)

        Returns
        -------
        array_like
            Evaluated shape function values. The shape of the returned
            array will depend on the input shape. For example, in the linear
            case, if the input is a scalar, the returned array will be of
            the shape `(1, 4, 1)` but if the input is a vector then the output will
            be of the shape `(len(x), 4, 1)`.
        """
        xi = jnp.asarray(xi)
        if xi.ndim != 3:
            raise ValueError(
                f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}"
            )
        # Standard bilinear shape functions on [-1, 1] x [-1, 1], in the
        # same counter-clockwise node order as `id_to_node_ids`.
        result = jnp.array(
            [
                0.25 * (1 - xi[:, :, 0]) * (1 - xi[:, :, 1]),
                0.25 * (1 + xi[:, :, 0]) * (1 - xi[:, :, 1]),
                0.25 * (1 + xi[:, :, 0]) * (1 + xi[:, :, 1]),
                0.25 * (1 - xi[:, :, 0]) * (1 + xi[:, :, 1]),
            ]
        )
        result = result.transpose(1, 0, 2)[..., jnp.newaxis]
        return result

    def _shapefn_natural_grad(self, xi: ArrayLike):
        """Calculate the gradient of shape function.

        This calculation is done in the natural coordinates.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles in natural coordinates to evaluate
            the function at.

        Returns
        -------
        array_like
            Evaluated gradient values of the shape function. The shape of
            the returned array will depend on the input shape. For example,
            in the linear case, if the input is a scalar, the returned array
            will be of the shape `(4, 2)`.
        """
        # result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze()
        xi = jnp.asarray(xi)
        xi = xi.squeeze()
        # Analytical derivatives of the bilinear shape functions
        # w.r.t. (xi_0, xi_1); rows follow the node order of `shapefn`.
        result = jnp.array(
            [
                [-0.25 * (1 - xi[1]), -0.25 * (1 - xi[0])],
                [0.25 * (1 - xi[1]), -0.25 * (1 + xi[0])],
                [0.25 * (1 + xi[1]), 0.25 * (1 + xi[0])],
                [-0.25 * (1 + xi[1]), 0.25 * (1 - xi[0])],
            ],
        )
        return result

    def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
        """Gradient of shape function in physical coordinates.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles to evaluate in natural coordinates.
            Expected shape `(npoints, 1, ndim)`.
        coords : array_like
            Nodal coordinates to transform by. Expected shape
            `(npoints, 1, ndim)`

        Returns
        -------
        array_like
            Gradient of the shape function in physical coordinates at `xi`
        """
        xi = jnp.asarray(xi)
        coords = jnp.asarray(coords)
        if xi.ndim != 3:
            raise ValueError(
                f"`x` should be of size (npoints, 1, ndim); found {xi.shape}"
            )
        grad_sf = self._shapefn_natural_grad(xi)
        # Jacobian of the natural -> physical coordinate mapping.
        _jacobian = grad_sf.T @ coords.squeeze()

        result = grad_sf @ jnp.linalg.inv(_jacobian).T
        return result

    def set_particle_element_ids(self, particles: Particles):
        """Set the element IDs for the particles.

        If the particle doesn't lie between the boundaries of any
        element, it sets the element index to -1.
        """

        @jit
        def f(x):
            # Node indices whose x (resp. y) coordinate is <= the particle's.
            xidl = (self.nodes.loc[:, :, 0] <= x[0, 0]).nonzero(
                size=len(self.nodes.loc), fill_value=-1
            )[0]
            yidl = (self.nodes.loc[:, :, 1] <= x[0, 1]).nonzero(
                size=len(self.nodes.loc), fill_value=-1
            )[0]
            # The element's lower-left node is the largest node index that
            # satisfies both conditions (-1 when none does).
            lower_left = jnp.where(jnp.isin(xidl, yidl), xidl, -1).max()
            # Convert node index to element index: each node row has one
            # more node than there are elements per row.
            element_id = lower_left - lower_left // (self.nelements[0] + 1)
            return element_id

        ids = vmap(f)(particles.loc)
        particles.element_ids = ids

    def compute_internal_force(self, particles: Particles):
        r"""Update the nodal internal force based on particle volume and stress.

        The nodal force is updated as a sum of internal forces for
        all particles mapped to the node.

        \[
            (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p)
        \]

        where \(\sigma_p\) is the stress at particle \(p\).

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        """

        def _step(pid, args):
            (
                f_int,
                pvol,
                mapped_grads,
                el_nodes,
                pstress,
            ) = args
            # Per-node force contribution from this particle's stress.
            # NOTE(review): assumes a Voigt-style stress vector with sigma_xx
            # at [0], sigma_yy at [1] and tau_xy at [3] — confirm against the
            # particle stress layout.
            force = jnp.zeros((mapped_grads.shape[1], 1, 2))
            force = force.at[:, 0, 0].set(
                mapped_grads[pid][:, 0] * pstress[pid][0]
                + mapped_grads[pid][:, 1] * pstress[pid][3]
            )
            force = force.at[:, 0, 1].set(
                mapped_grads[pid][:, 1] * pstress[pid][1]
                + mapped_grads[pid][:, 0] * pstress[pid][3]
            )
            update = -pvol[pid] * force
            f_int = f_int.at[el_nodes[pid]].add(update)
            return (
                f_int,
                pvol,
                mapped_grads,
                el_nodes,
                pstress,
            )

        # Zero out nodal internal force before accumulating.
        self.nodes.f_int = self.nodes.f_int.at[:].set(0)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2)
        mapped_grads = vmap(self.shapefn_grad)(
            particles.reference_loc[:, jnp.newaxis, ...],
            mapped_coords,
        )
        args = (
            self.nodes.f_int,
            particles.volume,
            mapped_grads,
            mapped_nodes,
            particles.stress,
        )
        self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)

    def compute_volume(self, *args):
        """Compute volume of all elements."""
        # General quadrilateral area formula with sides a, b, c, d and
        # diagonals p, q; for the rectangular elements built here it
        # reduces to el_len[0] * el_len[1].
        a = c = self.el_len[1]
        b = d = self.el_len[0]
        p = q = jnp.sqrt(a**2 + b**2)
        vol = 0.25 * jnp.sqrt(4 * p * p * q * q - (a * a + c * c - b * b - d * d) ** 2)
        self.volume = self.volume.at[:].set(vol)
+
class Linear1D(nelements: int, total_elements: int, el_len: float, constraints: Sequence[Tuple[ArrayLike, Constraint]], nodes: Optional[Nodes] = None, concentrated_nodal_forces: Sequence = [], initialized: Optional[bool] = None, volume: Optional[ArrayLike] = None)

Container for 1D line elements (and nodes).

    Element ID:   0     1     2     3
    Mesh:       +-----+-----+-----+-----+
    Node IDs:   0     1     2     3     4

where `+` marks a node and `+-----+` an element.

Initialize Linear1D.

Parameters:
- `nelements` (int): Number of elements.
- `total_elements` (int): Total number of elements (same as `nelements` for 1D).
- `el_len` (float): Length of each element.
- `constraints` (list): A list of tuples `(node_ids, diffmpm.Constraint)`; `node_ids` are the node IDs where the `diffmpm.Constraint` should be applied.
- `nodes` (Nodes, optional): Nodes in the element object.
- `concentrated_nodal_forces` (list): A list of `diffmpm.forces.NodalForce`s that are to be applied.
- `initialized` (bool or None): `True` if the class has been initialized, `None` if not. This is required like this for using JAX flattening.
- `volume` (ArrayLike): Volume of the elements.

@register_pytree_node_class
+class Linear1D(_Element):
+ """Container for 1D line elements (and nodes).
+
+ Element ID: 0 1 2 3
+ Mesh: +-----+-----+-----+-----+
+ Node IDs: 0 1 2 3 4
+
+ where
+
+ + : Nodes
+ +-----+ : An element
+
+ """
+
+ def __init__(
+ self,
+ nelements: int,
+ total_elements: int,
+ el_len: float,
+ constraints: Sequence[Tuple[ArrayLike, Constraint]],
+ nodes: Optional[Nodes] = None,
+ concentrated_nodal_forces: Sequence = [],
+ initialized: Optional[bool] = None,
+ volume: Optional[ArrayLike] = None,
+ ):
+ """Initialize Linear1D.
+
+ Parameters
+ ----------
+ nelements : int
+ Number of elements.
+ total_elements : int
+ Total number of elements (same as `nelements` for 1D)
+ el_len : float
+ Length of each element.
+ constraints: list
+ A list of constraints where each element is a tuple of type
+ `(node_ids, diffmpm.Constraint)`. Here, `node_ids` correspond to
+ the node IDs where `diffmpm.Constraint` should be applied.
+ nodes : Nodes, Optional
+ Nodes in the element object.
+ concentrated_nodal_forces: list
+ A list of `diffmpm.forces.NodalForce`s that are to be
+ applied.
+ initialized: bool, None
+ `True` if the class has been initialized, `None` if not.
+ This is required like this for using JAX flattening.
+ volume: ArrayLike
+ Volume of the elements.
+ """
+ self.nelements = nelements
+ self.total_elements = nelements
+ self.el_len = el_len
+ if nodes is None:
+ self.nodes = Nodes(
+ nelements + 1,
+ jnp.arange(nelements + 1).reshape(-1, 1, 1) * el_len,
+ )
+ else:
+ self.nodes = nodes
+
+ # self.boundary_nodes = boundary_nodes
+ self.constraints = constraints
+ self.concentrated_nodal_forces = concentrated_nodal_forces
+ if initialized is None:
+ self.volume = jnp.ones((self.total_elements, 1, 1))
+ else:
+ self.volume = jnp.asarray(volume)
+ self.initialized = True
+
+ def id_to_node_ids(self, id: ArrayLike):
+ """Node IDs corresponding to element `id`.
+
+ Parameters
+ ----------
+ id : int
+ Element ID.
+
+ Returns
+ -------
+ ArrayLike
+ Nodal IDs of the element. Shape of returned
+ array is `(2, 1)`
+ """
+ return jnp.array([id, id + 1]).reshape(2, 1)
+
+ def shapefn(self, xi: ArrayLike):
+ """Evaluate linear shape function.
+
+ Parameters
+ ----------
+ xi : float, array_like
+ Locations of particles in natural coordinates to evaluate
+ the function at. Expected shape is `(npoints, 1, ndim)`
+
+ Returns
+ -------
+ array_like
+ Evaluated shape function values. The shape of the returned
+ array will depend on the input shape. For example, in the linear
+ case, if the input is a scalar, the returned array will be of
+ the shape `(1, 2, 1)` but if the input is a vector then the output will
+ be of the shape `(len(x), 2, 1)`.
+ """
+ xi = jnp.asarray(xi)
+ if xi.ndim != 3:
+ raise ValueError(
+ f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}"
+ )
+ result = jnp.array([0.5 * (1 - xi), 0.5 * (1 + xi)]).transpose(1, 0, 2, 3)
+ return result
+
+ def _shapefn_natural_grad(self, xi: ArrayLike):
+ """Calculate the gradient of shape function.
+
+ This calculation is done in the natural coordinates.
+
+ Parameters
+ ----------
+ x : float, array_like
+ Locations of particles in natural coordinates to evaluate
+ the function at.
+
+ Returns
+ -------
+ array_like
+ Evaluated gradient values of the shape function. The shape of
+ the returned array will depend on the input shape. For example,
+ in the linear case, if the input is a scalar, the returned array
+ will be of the shape `(2, 1)`.
+ """
+ xi = jnp.asarray(xi)
+ result = vmap(jacobian(self.shapefn))(xi[..., jnp.newaxis]).squeeze()
+
+ # TODO: The following code tries to evaluate vmap even if
+ # the predicate condition is true, not sure why.
+ # result = lax.cond(
+ # jnp.isscalar(x),
+ # jacobian(self.shapefn),
+ # vmap(jacobian(self.shapefn)),
+ # xi
+ # )
+ return result.reshape(2, 1)
+
+ def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
+ """Gradient of shape function in physical coordinates.
+
+ Parameters
+ ----------
+ xi : float, array_like
+ Locations of particles to evaluate in natural coordinates.
+ Expected shape `(npoints, 1, ndim)`.
+ coords : array_like
+ Nodal coordinates to transform by. Expected shape
+ `(npoints, 1, ndim)`
+
+ Returns
+ -------
+ array_like
+ Gradient of the shape function in physical coordinates at `xi`
+ """
+ xi = jnp.asarray(xi)
+ coords = jnp.asarray(coords)
+ if xi.ndim != 3:
+ raise ValueError(
+ f"`x` should be of size (npoints, 1, ndim); found {xi.shape}"
+ )
+ grad_sf = self._shapefn_natural_grad(xi)
+ _jacobian = grad_sf.T @ coords
+
+ result = grad_sf @ jnp.linalg.inv(_jacobian).T
+ return result
+
+ def set_particle_element_ids(self, particles):
+ """Set the element IDs for the particles.
+
+ If the particle doesn't lie between the boundaries of any
+ element, it sets the element index to -1.
+ """
+
+ @jit
+ def f(x):
+ idl = (
+ len(self.nodes.loc)
+ - 1
+ - jnp.asarray(self.nodes.loc[::-1] <= x).nonzero(size=1, fill_value=-1)[
+ 0
+ ][-1]
+ )
+ idg = (
+ jnp.asarray(self.nodes.loc > x).nonzero(size=1, fill_value=-1)[0][0] - 1
+ )
+ return (idl, idg)
+
+ ids = vmap(f)(particles.loc)
+ particles.element_ids = jnp.where(
+ ids[0] == ids[1], ids[0], jnp.ones_like(ids[0]) * -1
+ )
+
+ def compute_volume(self, *args):
+ """Compute volume of all elements."""
+ vol = jnp.ediff1d(self.nodes.loc)
+ self.volume = jnp.ones((self.total_elements, 1, 1)) * vol
+
+ def compute_internal_force(self, particles):
+ r"""Update the nodal internal force based on particle mass.
+
+ The nodal force is updated as a sum of internal forces for
+ all particles mapped to the node.
+
+ \[
+ (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p)
+ \]
+
+ where \(\sigma_p\) is the stress at particle \(p\).
+
+ Parameters
+ ----------
+ particles: diffmpm.particle.Particles
+ Particles to map to the nodal values.
+ """
+
+ def _step(pid, args):
+ (
+ f_int,
+ pvol,
+ mapped_grads,
+ el_nodes,
+ pstress,
+ ) = args
+ # TODO: correct matrix multiplication for n-d
+ # update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
+ update = -pvol[pid] * pstress[pid][0] * mapped_grads[pid]
+ f_int = f_int.at[el_nodes[pid]].add(update[..., jnp.newaxis])
+ return (
+ f_int,
+ pvol,
+ mapped_grads,
+ el_nodes,
+ pstress,
+ )
+
+ self.nodes.f_int = self.nodes.f_int.at[:].set(0)
+ mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
+ mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2)
+ mapped_grads = vmap(self.shapefn_grad)(
+ particles.reference_loc[:, jnp.newaxis, ...],
+ mapped_coords,
+ )
+ args = (
+ self.nodes.f_int,
+ particles.volume,
+ mapped_grads,
+ mapped_nodes,
+ particles.stress,
+ )
+ self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
def compute_internal_force(self, particles)

Update the nodal internal force based on particle volume and stress.

The nodal force is updated as a sum of internal forces for all particles mapped to the node.

\[
(f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p)
\]

where \(\sigma_p\) is the stress at particle \(p\).

Parameters: `particles` (diffmpm.particle.Particles) — particles to map to the nodal values.

def compute_internal_force(self, particles):
+ r"""Update the nodal internal force based on particle mass.
+
+ The nodal force is updated as a sum of internal forces for
+ all particles mapped to the node.
+
+ \[
+ (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p)
+ \]
+
+ where \(\sigma_p\) is the stress at particle \(p\).
+
+ Parameters
+ ----------
+ particles: diffmpm.particle.Particles
+ Particles to map to the nodal values.
+ """
+
+ def _step(pid, args):
+ (
+ f_int,
+ pvol,
+ mapped_grads,
+ el_nodes,
+ pstress,
+ ) = args
+ # TODO: correct matrix multiplication for n-d
+ # update = -(pvol[pid]) * pstress[pid] @ mapped_grads[pid]
+ update = -pvol[pid] * pstress[pid][0] * mapped_grads[pid]
+ f_int = f_int.at[el_nodes[pid]].add(update[..., jnp.newaxis])
+ return (
+ f_int,
+ pvol,
+ mapped_grads,
+ el_nodes,
+ pstress,
+ )
+
+ self.nodes.f_int = self.nodes.f_int.at[:].set(0)
+ mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
+ mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2)
+ mapped_grads = vmap(self.shapefn_grad)(
+ particles.reference_loc[:, jnp.newaxis, ...],
+ mapped_coords,
+ )
+ args = (
+ self.nodes.f_int,
+ particles.volume,
+ mapped_grads,
+ mapped_nodes,
+ particles.stress,
+ )
+ self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+def compute_volume(self, *args)
+Compute volume of all elements.
def compute_volume(self, *args):
+ """Compute volume of all elements."""
+ vol = jnp.ediff1d(self.nodes.loc)
+ self.volume = jnp.ones((self.total_elements, 1, 1)) * vol
+
def id_to_node_ids(self, id: ArrayLike)

Node IDs corresponding to element `id`.

Parameters: `id` (int) — element ID.
Returns: ArrayLike — nodal IDs of the element; shape of the returned array is `(2, 1)`.

def id_to_node_ids(self, id: ArrayLike):
+ """Node IDs corresponding to element `id`.
+
+ Parameters
+ ----------
+ id : int
+ Element ID.
+
+ Returns
+ -------
+ ArrayLike
+ Nodal IDs of the element. Shape of returned
+ array is `(2, 1)`
+ """
+ return jnp.array([id, id + 1]).reshape(2, 1)
+
def set_particle_element_ids(self, particles)

Set the element IDs for the particles.

If the particle doesn't lie between the boundaries of any element, it sets the element index to -1.
def set_particle_element_ids(self, particles):
    """Set the element IDs for the particles.

    If the particle doesn't lie between the boundaries of any
    element, it sets the element index to -1.
    """

    @jit
    def f(x):
        # Search from the right for the last node whose location is
        # <= x. `size=1` keeps the result shape static under JIT and
        # -1 is the not-found sentinel.
        idl = (
            len(self.nodes.loc)
            - 1
            - jnp.asarray(self.nodes.loc[::-1] <= x).nonzero(size=1, fill_value=-1)[
                0
            ][-1]
        )
        # Element index implied by the first node strictly greater
        # than x (its left neighbour's element).
        idg = (
            jnp.asarray(self.nodes.loc > x).nonzero(size=1, fill_value=-1)[0][0] - 1
        )
        return (idl, idg)

    ids = vmap(f)(particles.loc)
    # Both searches agree only when the particle lies inside an
    # element; any disagreement marks it as outside the mesh (-1).
    particles.element_ids = jnp.where(
        ids[0] == ids[1], ids[0], jnp.ones_like(ids[0]) * -1
    )
+
+def shapefn(self, xi: ArrayLike)
+Evaluate linear shape function.
+xi : float, array_like(npoints, 1, ndim)array_like(1, 2, 1) but if the input is a vector then the output will
def shapefn(self, xi: ArrayLike):
    """Evaluate linear shape function.

    Parameters
    ----------
    xi : float, array_like
        Locations of particles in natural coordinates to evaluate
        the function at. Expected shape is `(npoints, 1, ndim)`

    Returns
    -------
    array_like
        Evaluated shape function values. The shape of the returned
        array will depend on the input shape. For example, in the linear
        case, if the input is a scalar, the returned array will be of
        the shape `(1, 2, 1)` but if the input is a vector then the output will
        be of the shape `(len(x), 2, 1)`.
    """
    xi = jnp.asarray(xi)
    if xi.ndim != 3:
        raise ValueError(
            f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}"
        )
    # Hat functions of the two element nodes in natural coordinates.
    left_node = 0.5 * (1 - xi)
    right_node = 0.5 * (1 + xi)
    return jnp.array([left_node, right_node]).transpose(1, 0, 2, 3)
+
+def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike)
+Gradient of shape function in physical coordinates.
def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
    """Gradient of shape function in physical coordinates.

    Parameters
    ----------
    xi : float, array_like
        Locations of particles to evaluate in natural coordinates.
        Expected shape `(npoints, 1, ndim)`.
    coords : array_like
        Nodal coordinates to transform by. Expected shape
        `(npoints, 1, ndim)`

    Returns
    -------
    array_like
        Gradient of the shape function in physical coordinates at `xi`
    """
    xi = jnp.asarray(xi)
    coords = jnp.asarray(coords)
    if xi.ndim != 3:
        raise ValueError(
            f"`x` should be of size (npoints, 1, ndim); found {xi.shape}"
        )
    # Map the natural-coordinate gradient through the element Jacobian
    # to obtain physical-coordinate derivatives.
    natural_grad = self._shapefn_natural_grad(xi)
    jacobian = natural_grad.T @ coords
    return natural_grad @ jnp.linalg.inv(jacobian).T
+_Element:
+apply_boundary_constraintsapply_concentrated_nodal_forcesapply_force_boundary_constraintsapply_particle_traction_forcescompute_body_forcecompute_external_forcecompute_nodal_masscompute_nodal_momentumcompute_velocityid_to_node_locid_to_node_velupdate_nodal_acceleration_velocity
+class Quadrilateral4Node
+(nelements: int, total_elements: int, el_len: float, constraints: Sequence[Tuple[ArrayLike, Constraint]], nodes: Optional[Nodes] = None, concentrated_nodal_forces: Sequence = [], initialized: Optional[bool] = None, volume: Optional[ArrayLike] = None)
+Container for 2D quadrilateral elements with 4 nodes.
+Nodes and elements are numbered as
+ 15 +---+---+---+---+ 19
+ | 8 | 9 | 10| 11|
+ 10 +---+---+---+---+ 14
+ | 4 | 5 | 6 | 7 |
+ 5 +---+---+---+---+ 9
+ | 0 | 1 | 2 | 3 |
+ +---+---+---+---+
+ 0 1 2 3 4
+
+where
+ + : Nodes
+ +---+
+ | | : An element
+ +---+
+
+Initialize Linear1D.
+nelements : inttotal_elements : intnelements)el_len : floatconstraints : list(node_ids, diffmpm.Constraint). Here, node_ids
+correspond to the node IDs where diffmpm.Constraint
+should be applied.nodes : Nodes, Optionalconcentrated_nodal_forces : listNodalForces that are to be
+applied.initialized : bool, NoneTrue if the class has been initialized, None if not.
@register_pytree_node_class
class Quadrilateral4Node(_Element):
    r"""Container for 2D quadrilateral elements with 4 nodes.

    Nodes and elements are numbered as

            15 +---+---+---+---+ 19
               | 8 | 9 | 10| 11|
            10 +---+---+---+---+ 14
               | 4 | 5 | 6 | 7 |
             5 +---+---+---+---+ 9
               | 0 | 1 | 2 | 3 |
               +---+---+---+---+
               0   1   2   3   4

    where

            + : Nodes
            +---+
            |   | : An element
            +---+
    """

    def __init__(
        self,
        nelements: int,
        total_elements: int,
        el_len: float,
        constraints: Sequence[Tuple[ArrayLike, Constraint]],
        nodes: Optional[Nodes] = None,
        concentrated_nodal_forces: Optional[Sequence] = None,
        initialized: Optional[bool] = None,
        volume: Optional[ArrayLike] = None,
    ) -> None:
        """Initialize Quadrilateral4Node.

        Parameters
        ----------
        nelements : int
            Number of elements in each direction.
        total_elements : int
            Total number of elements (product of all elements of `nelements`)
        el_len : float
            Length of each element in each direction.
        constraints: list
            A list of constraints where each element is a tuple of
            type `(node_ids, diffmpm.Constraint)`. Here, `node_ids`
            correspond to the node IDs where `diffmpm.Constraint`
            should be applied.
        nodes : Nodes, Optional
            Nodes in the element object.
        concentrated_nodal_forces: list, Optional
            A list of `diffmpm.forces.NodalForce`s that are to be
            applied. Defaults to an empty list (a mutable default
            argument is deliberately avoided here).
        initialized: bool, None
            `True` if the class has been initialized, `None` if not.
            This is required like this for using JAX flattening.
        volume: ArrayLike
            Volume of the elements.
        """
        self.nelements = jnp.asarray(nelements)
        self.el_len = jnp.asarray(el_len)
        self.total_elements = total_elements

        if nodes is None:
            total_nodes = jnp.prod(self.nelements + 1)
            # Node coordinates laid out row-major: x varies fastest,
            # matching the numbering in the class docstring.
            coords = jnp.asarray(
                list(
                    itertools.product(
                        jnp.arange(self.nelements[1] + 1),
                        jnp.arange(self.nelements[0] + 1),
                    )
                )
            )
            node_locations = (
                jnp.asarray([coords[:, 1], coords[:, 0]]).T * self.el_len
            ).reshape(-1, 1, 2)
            self.nodes = Nodes(int(total_nodes), node_locations)
        else:
            self.nodes = nodes

        self.constraints = constraints
        # Avoid a shared mutable default list across instances.
        self.concentrated_nodal_forces = (
            [] if concentrated_nodal_forces is None else concentrated_nodal_forces
        )
        if initialized is None:
            self.volume = jnp.ones((self.total_elements, 1, 1))
        else:
            self.volume = jnp.asarray(volume)
        self.initialized = True

    def id_to_node_ids(self, id: ArrayLike):
        """Node IDs corresponding to element `id`.

            3----2
            |    |
            0----1

        Node ids are returned in the order as shown in the figure.

        Parameters
        ----------
        id : int
            Element ID.

        Returns
        -------
        ArrayLike
            Nodal IDs of the element. Shape of returned
            array is (4, 1)
        """
        # Each node row has (nelements[0] + 1) nodes, hence the offset.
        lower_left = (id // self.nelements[0]) * (
            self.nelements[0] + 1
        ) + id % self.nelements[0]
        result = jnp.asarray(
            [
                lower_left,
                lower_left + 1,
                lower_left + self.nelements[0] + 2,
                lower_left + self.nelements[0] + 1,
            ]
        )
        return result.reshape(4, 1)

    def shapefn(self, xi: ArrayLike):
        """Evaluate linear shape function.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles in natural coordinates to evaluate
            the function at. Expected shape is (npoints, 1, ndim)

        Returns
        -------
        array_like
            Evaluated shape function values. The shape of the returned
            array will depend on the input shape. For example, in the linear
            case, if the input is a scalar, the returned array will be of
            the shape `(1, 4, 1)` but if the input is a vector then the output will
            be of the shape `(len(x), 4, 1)`.
        """
        xi = jnp.asarray(xi)
        if xi.ndim != 3:
            raise ValueError(
                f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}"
            )
        # Bilinear shape functions in node order 0..3 (see figure above).
        result = jnp.array(
            [
                0.25 * (1 - xi[:, :, 0]) * (1 - xi[:, :, 1]),
                0.25 * (1 + xi[:, :, 0]) * (1 - xi[:, :, 1]),
                0.25 * (1 + xi[:, :, 0]) * (1 + xi[:, :, 1]),
                0.25 * (1 - xi[:, :, 0]) * (1 + xi[:, :, 1]),
            ]
        )
        result = result.transpose(1, 0, 2)[..., jnp.newaxis]
        return result

    def _shapefn_natural_grad(self, xi: ArrayLike):
        """Calculate the gradient of shape function.

        This calculation is done in the natural coordinates.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles in natural coordinates to evaluate
            the function at.

        Returns
        -------
        array_like
            Evaluated gradient values of the shape function. The shape of
            the returned array will depend on the input shape. For example,
            in the linear case, if the input is a scalar, the returned array
            will be of the shape `(4, 2)`.
        """
        xi = jnp.asarray(xi)
        xi = xi.squeeze()
        # Analytic derivatives of the bilinear shape functions;
        # column 0 is d/dxi, column 1 is d/deta.
        result = jnp.array(
            [
                [-0.25 * (1 - xi[1]), -0.25 * (1 - xi[0])],
                [0.25 * (1 - xi[1]), -0.25 * (1 + xi[0])],
                [0.25 * (1 + xi[1]), 0.25 * (1 + xi[0])],
                [-0.25 * (1 + xi[1]), 0.25 * (1 - xi[0])],
            ],
        )
        return result

    def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
        """Gradient of shape function in physical coordinates.

        Parameters
        ----------
        xi : float, array_like
            Locations of particles to evaluate in natural coordinates.
            Expected shape `(npoints, 1, ndim)`.
        coords : array_like
            Nodal coordinates to transform by. Expected shape
            `(npoints, 1, ndim)`

        Returns
        -------
        array_like
            Gradient of the shape function in physical coordinates at `xi`
        """
        xi = jnp.asarray(xi)
        coords = jnp.asarray(coords)
        if xi.ndim != 3:
            # Error message corrected to name the actual parameter `xi`.
            raise ValueError(
                f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}"
            )
        grad_sf = self._shapefn_natural_grad(xi)
        _jacobian = grad_sf.T @ coords.squeeze()

        result = grad_sf @ jnp.linalg.inv(_jacobian).T
        return result

    def set_particle_element_ids(self, particles: Particles):
        """Set the element IDs for the particles.

        If the particle doesn't lie between the boundaries of any
        element, it sets the element index to -1.
        """

        @jit
        def f(x):
            # Node indices whose x (resp. y) coordinate is <= the
            # particle coordinate; static `size=` keeps JIT shapes fixed.
            xidl = (self.nodes.loc[:, :, 0] <= x[0, 0]).nonzero(
                size=len(self.nodes.loc), fill_value=-1
            )[0]
            yidl = (self.nodes.loc[:, :, 1] <= x[0, 1]).nonzero(
                size=len(self.nodes.loc), fill_value=-1
            )[0]
            # The containing element's lower-left node satisfies both.
            lower_left = jnp.where(jnp.isin(xidl, yidl), xidl, -1).max()
            element_id = lower_left - lower_left // (self.nelements[0] + 1)
            return element_id

        ids = vmap(f)(particles.loc)
        particles.element_ids = ids

    def compute_internal_force(self, particles: Particles):
        r"""Update the nodal internal force based on particle stress.

        The nodal force is updated as a sum of internal forces for
        all particles mapped to the node.

        \[
            (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p)
        \]

        where \(\sigma_p\) is the stress at particle \(p\).

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        """

        def _step(pid, args):
            (
                f_int,
                pvol,
                mapped_grads,
                el_nodes,
                pstress,
            ) = args
            force = jnp.zeros((mapped_grads.shape[1], 1, 2))
            # NOTE(review): stress is indexed [0]=xx, [1]=yy, [3]=xy,
            # which looks like Voigt ordering — confirm against
            # diffmpm.particle.Particles.stress layout.
            force = force.at[:, 0, 0].set(
                mapped_grads[pid][:, 0] * pstress[pid][0]
                + mapped_grads[pid][:, 1] * pstress[pid][3]
            )
            force = force.at[:, 0, 1].set(
                mapped_grads[pid][:, 1] * pstress[pid][1]
                + mapped_grads[pid][:, 0] * pstress[pid][3]
            )
            update = -pvol[pid] * force
            f_int = f_int.at[el_nodes[pid]].add(update)
            return (
                f_int,
                pvol,
                mapped_grads,
                el_nodes,
                pstress,
            )

        # Reset the accumulator before the particle-to-grid scatter.
        self.nodes.f_int = self.nodes.f_int.at[:].set(0)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2)
        mapped_grads = vmap(self.shapefn_grad)(
            particles.reference_loc[:, jnp.newaxis, ...],
            mapped_coords,
        )
        args = (
            self.nodes.f_int,
            particles.volume,
            mapped_grads,
            mapped_nodes,
            particles.stress,
        )
        self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)

    def compute_volume(self, *args):
        """Compute volume of all elements.

        Uses the general quadrilateral area formula; with both element
        diagonals equal it reduces to `el_len[0] * el_len[1]`.
        """
        a = c = self.el_len[1]
        b = d = self.el_len[0]
        p = q = jnp.sqrt(a**2 + b**2)
        vol = 0.25 * jnp.sqrt(4 * p * p * q * q - (a * a + c * c - b * b - d * d) ** 2)
        self.volume = self.volume.at[:].set(vol)
+
+def compute_internal_force(self, particles: Particles)
+Update the nodal internal force based on particle mass.
+The nodal force is updated as a sum of internal forces for +all particles mapped to the node.
++(f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p) +
+where \sigma_p is the stress at particle p.
def compute_internal_force(self, particles: Particles):
    r"""Update the nodal internal force based on particle stress.

    The nodal force is updated as a sum of internal forces for
    all particles mapped to the node.

    \[
        (f_{int})_i = -\sum_p V_p \sigma_p \nabla N_i(x_p)
    \]

    where \(\sigma_p\) is the stress at particle \(p\).

    Parameters
    ----------
    particles: diffmpm.particle.Particles
        Particles to map to the nodal values.
    """

    def _step(pid, args):
        (
            f_int,
            pvol,
            mapped_grads,
            el_nodes,
            pstress,
        ) = args
        force = jnp.zeros((mapped_grads.shape[1], 1, 2))
        # NOTE(review): stress is indexed [0]=xx, [1]=yy, [3]=xy, which
        # looks like Voigt ordering — confirm against the layout of
        # diffmpm.particle.Particles.stress.
        force = force.at[:, 0, 0].set(
            mapped_grads[pid][:, 0] * pstress[pid][0]
            + mapped_grads[pid][:, 1] * pstress[pid][3]
        )
        force = force.at[:, 0, 1].set(
            mapped_grads[pid][:, 1] * pstress[pid][1]
            + mapped_grads[pid][:, 0] * pstress[pid][3]
        )
        update = -pvol[pid] * force
        f_int = f_int.at[el_nodes[pid]].add(update)
        return (
            f_int,
            pvol,
            mapped_grads,
            el_nodes,
            pstress,
        )

    # Reset the accumulator before the particle-to-grid scatter.
    self.nodes.f_int = self.nodes.f_int.at[:].set(0)
    mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
    mapped_coords = vmap(self.id_to_node_loc)(particles.element_ids).squeeze(2)
    mapped_grads = vmap(self.shapefn_grad)(
        particles.reference_loc[:, jnp.newaxis, ...],
        mapped_coords,
    )
    args = (
        self.nodes.f_int,
        particles.volume,
        mapped_grads,
        mapped_nodes,
        particles.stress,
    )
    self.nodes.f_int, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+def compute_volume(self, *args)
+Compute volume of all elements.
def compute_volume(self, *args):
    """Compute the volume of every element.

    Uses the general quadrilateral area formula; with both diagonals
    equal it reduces to width * height for rectangular elements.
    """
    side_y = side_y_opp = self.el_len[1]
    side_x = side_x_opp = self.el_len[0]
    diag = diag_opp = jnp.sqrt(side_y**2 + side_x**2)
    area = 0.25 * jnp.sqrt(
        4 * diag * diag * diag_opp * diag_opp
        - (side_y * side_y + side_y_opp * side_y_opp
           - side_x * side_x - side_x_opp * side_x_opp) ** 2
    )
    self.volume = self.volume.at[:].set(area)
+
+def id_to_node_ids(self, id: ArrayLike)
+Node IDs corresponding to element id.
3----2
+| |
+0----1
+
+Node ids are returned in the order as shown in the figure.
+id : intArrayLikedef id_to_node_ids(self, id: ArrayLike):
+ """Node IDs corresponding to element `id`.
+
+ 3----2
+ | |
+ 0----1
+
+ Node ids are returned in the order as shown in the figure.
+
+ Parameters
+ ----------
+ id : int
+ Element ID.
+
+ Returns
+ -------
+ ArrayLike
+ Nodal IDs of the element. Shape of returned
+ array is (4, 1)
+ """
+ lower_left = (id // self.nelements[0]) * (
+ self.nelements[0] + 1
+ ) + id % self.nelements[0]
+ result = jnp.asarray(
+ [
+ lower_left,
+ lower_left + 1,
+ lower_left + self.nelements[0] + 2,
+ lower_left + self.nelements[0] + 1,
+ ]
+ )
+ return result.reshape(4, 1)
+
+def set_particle_element_ids(self, particles: Particles)
+Set the element IDs for the particles.
+If the particle doesn't lie between the boundaries of any +element, it sets the element index to -1.
def set_particle_element_ids(self, particles: Particles):
    """Set the element IDs for the particles.

    If the particle doesn't lie between the boundaries of any
    element, it sets the element index to -1.
    """

    @jit
    def f(x):
        # Node indices whose x-coordinate is <= the particle's x.
        # Static `size=` keeps shapes JIT-compatible; -1 pads misses.
        xidl = (self.nodes.loc[:, :, 0] <= x[0, 0]).nonzero(
            size=len(self.nodes.loc), fill_value=-1
        )[0]
        # Node indices whose y-coordinate is <= the particle's y.
        yidl = (self.nodes.loc[:, :, 1] <= x[0, 1]).nonzero(
            size=len(self.nodes.loc), fill_value=-1
        )[0]
        # The containing element's lower-left node is the largest node
        # index satisfying both coordinate conditions.
        lower_left = jnp.where(jnp.isin(xidl, yidl), xidl, -1).max()
        # Convert node index -> element index (one extra node per row).
        element_id = lower_left - lower_left // (self.nelements[0] + 1)
        return element_id

    ids = vmap(f)(particles.loc)
    particles.element_ids = ids
+
+def shapefn(self, xi: ArrayLike)
+Evaluate linear shape function.
+xi : float, array_likearray_like(1, 4, 1) but if the input is a vector then the output will
def shapefn(self, xi: ArrayLike):
    """Evaluate linear shape function.

    Parameters
    ----------
    xi : float, array_like
        Locations of particles in natural coordinates to evaluate
        the function at. Expected shape is (npoints, 1, ndim)

    Returns
    -------
    array_like
        Evaluated shape function values. The shape of the returned
        array will depend on the input shape. For example, in the linear
        case, if the input is a scalar, the returned array will be of
        the shape `(1, 4, 1)` but if the input is a vector then the output will
        be of the shape `(len(x), 4, 1)`.
    """
    xi = jnp.asarray(xi)
    if xi.ndim != 3:
        raise ValueError(
            f"`xi` should be of size (npoints, 1, ndim); found {xi.shape}"
        )
    # Bilinear shape functions, nodes ordered counter-clockwise from
    # the lower-left corner.
    minus_x = 1 - xi[:, :, 0]
    plus_x = 1 + xi[:, :, 0]
    minus_y = 1 - xi[:, :, 1]
    plus_y = 1 + xi[:, :, 1]
    stacked = jnp.array(
        [
            0.25 * minus_x * minus_y,
            0.25 * plus_x * minus_y,
            0.25 * plus_x * plus_y,
            0.25 * minus_x * plus_y,
        ]
    )
    return stacked.transpose(1, 0, 2)[..., jnp.newaxis]
+
+def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike)
+Gradient of shape function in physical coordinates.
def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
    """Gradient of shape function in physical coordinates.

    Parameters
    ----------
    xi : float, array_like
        Locations of particles to evaluate in natural coordinates.
        Expected shape `(npoints, 1, ndim)`.
    coords : array_like
        Nodal coordinates to transform by. Expected shape
        `(npoints, 1, ndim)`

    Returns
    -------
    array_like
        Gradient of the shape function in physical coordinates at `xi`
    """
    xi = jnp.asarray(xi)
    coords = jnp.asarray(coords)
    if xi.ndim != 3:
        raise ValueError(
            f"`x` should be of size (npoints, 1, ndim); found {xi.shape}"
        )
    # Transform natural-coordinate derivatives through the element
    # Jacobian to get physical-coordinate derivatives.
    natural_grad = self._shapefn_natural_grad(xi)
    jacobian = natural_grad.T @ coords.squeeze()
    return natural_grad @ jnp.linalg.inv(jacobian).T
+_Element:
+apply_boundary_constraintsapply_concentrated_nodal_forcesapply_force_boundary_constraintsapply_particle_traction_forcescompute_body_forcecompute_external_forcecompute_nodal_masscompute_nodal_momentumcompute_velocityid_to_node_locid_to_node_velupdate_nodal_acceleration_velocity
+class _Element
+Base element class that is inherited by all types of Elements.
class _Element(abc.ABC):
    """Base element class that is inherited by all types of Elements."""

    # Declared here for typing; concrete subclasses assign these
    # (along with nelements, el_len, constraints and initialized,
    # which tree_flatten below relies on).
    nodes: Nodes
    total_elements: int
    concentrated_nodal_forces: Sequence
    volume: Array

    @abc.abstractmethod
    def id_to_node_ids(self, id: ArrayLike) -> Array:
        """Node IDs corresponding to element `id`.

        This method is implemented by each of the subclass.

        Parameters
        ----------
        id : int
            Element ID.

        Returns
        -------
        ArrayLike
            Nodal IDs of the element.
        """
        ...

    def id_to_node_loc(self, id: ArrayLike) -> Array:
        """Node locations corresponding to element `id`.

        Parameters
        ----------
        id : int
            Element ID.

        Returns
        -------
        ArrayLike
            Nodal locations for the element. Shape of returned
            array is `(nodes_in_element, 1, ndim)`
        """
        node_ids = self.id_to_node_ids(id).squeeze()
        return self.nodes.loc[node_ids]

    def id_to_node_vel(self, id: ArrayLike) -> Array:
        """Node velocities corresponding to element `id`.

        Parameters
        ----------
        id : int
            Element ID.

        Returns
        -------
        ArrayLike
            Nodal velocities for the element. Shape of returned
            array is `(nodes_in_element, 1, ndim)`
        """
        node_ids = self.id_to_node_ids(id).squeeze()
        return self.nodes.velocity[node_ids]

    def tree_flatten(self):
        """Flatten for JAX pytree registration.

        Traced arrays go into `children`; static configuration goes
        into `aux_data`.
        """
        children = (self.nodes, self.volume)
        aux_data = (
            self.nelements,
            self.total_elements,
            self.el_len,
            self.constraints,
            self.concentrated_nodal_forces,
            self.initialized,
        )
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct an element from `tree_flatten` output."""
        return cls(
            aux_data[0],
            aux_data[1],
            aux_data[2],
            aux_data[3],
            nodes=children[0],
            concentrated_nodal_forces=aux_data[4],
            initialized=aux_data[5],
            volume=children[1],
        )

    @abc.abstractmethod
    def shapefn(self, xi: ArrayLike):
        """Evaluate Shape function for element type."""
        ...

    @abc.abstractmethod
    def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
        """Evaluate gradient of shape function for element type."""
        ...

    @abc.abstractmethod
    def set_particle_element_ids(self, particles: Particles):
        """Set the element IDs that particles are present in."""
        ...

    # Mapping from particles to nodes (P2G)
    def compute_nodal_mass(self, particles: Particles):
        r"""Compute the nodal mass based on particle mass.

        The nodal mass is updated as a sum of particle mass for
        all particles mapped to the node.

        \[
            (m)_i = \sum_p N_i(x_p) m_p
        \]

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        """

        def _step(pid, args):
            pmass, mass, mapped_pos, el_nodes = args
            mass = mass.at[el_nodes[pid]].add(pmass[pid] * mapped_pos[pid])
            return pmass, mass, mapped_pos, el_nodes

        # Reset before accumulating this step's contributions.
        self.nodes.mass = self.nodes.mass.at[:].set(0)
        mapped_positions = self.shapefn(particles.reference_loc)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        args = (
            particles.mass,
            self.nodes.mass,
            mapped_positions,
            mapped_nodes,
        )
        _, self.nodes.mass, _, _ = lax.fori_loop(0, len(particles), _step, args)

    def compute_nodal_momentum(self, particles: Particles):
        r"""Compute the nodal momentum based on particle momentum.

        The nodal momentum is updated as a sum of particle momentum for
        all particles mapped to the node.

        \[
            (mv)_i = \sum_p N_i(x_p) (mv)_p
        \]

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        """

        def _step(pid, args):
            pmom, mom, mapped_pos, el_nodes = args
            mom = mom.at[el_nodes[pid]].add(mapped_pos[pid] @ pmom[pid])
            return pmom, mom, mapped_pos, el_nodes

        # Reset before accumulating this step's contributions.
        self.nodes.momentum = self.nodes.momentum.at[:].set(0)
        mapped_positions = self.shapefn(particles.reference_loc)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        args = (
            particles.mass * particles.velocity,
            self.nodes.momentum,
            mapped_positions,
            mapped_nodes,
        )
        _, self.nodes.momentum, _, _ = lax.fori_loop(0, len(particles), _step, args)
        # Flush numerical noise below 1e-12 to exactly zero.
        self.nodes.momentum = jnp.where(
            jnp.abs(self.nodes.momentum) < 1e-12,
            jnp.zeros_like(self.nodes.momentum),
            self.nodes.momentum,
        )

    def compute_velocity(self, particles: Particles):
        """Compute velocity using momentum."""
        # Leave velocity untouched where the nodal mass is zero to
        # avoid dividing by zero.
        self.nodes.velocity = jnp.where(
            self.nodes.mass == 0,
            self.nodes.velocity,
            self.nodes.momentum / self.nodes.mass,
        )
        # Flush numerical noise below 1e-12 to exactly zero.
        self.nodes.velocity = jnp.where(
            jnp.abs(self.nodes.velocity) < 1e-12,
            jnp.zeros_like(self.nodes.velocity),
            self.nodes.velocity,
        )

    def compute_external_force(self, particles: Particles):
        r"""Update the nodal external force based on particle f_ext.

        The nodal force is updated as a sum of particle external
        force for all particles mapped to the node.

        \[
            (f_{ext})_i = \sum_p N_i(x_p) (f_{ext})_p
        \]

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        """

        def _step(pid, args):
            f_ext, pf_ext, mapped_pos, el_nodes = args
            f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ pf_ext[pid])
            return f_ext, pf_ext, mapped_pos, el_nodes

        # Reset before accumulating this step's contributions.
        self.nodes.f_ext = self.nodes.f_ext.at[:].set(0)
        mapped_positions = self.shapefn(particles.reference_loc)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        args = (
            self.nodes.f_ext,
            particles.f_ext,
            mapped_positions,
            mapped_nodes,
        )
        self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args)

    def compute_body_force(self, particles: Particles, gravity: ArrayLike):
        r"""Update the nodal external force based on particle mass.

        The nodal force is updated as a sum of particle body
        force for all particles mapped to the node.

        \[
            (f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g
        \]

        Parameters
        ----------
        particles: diffmpm.particle.Particles
            Particles to map to the nodal values.
        """

        def _step(pid, args):
            f_ext, pmass, mapped_pos, el_nodes, gravity = args
            f_ext = f_ext.at[el_nodes[pid]].add(
                mapped_pos[pid] @ (pmass[pid] * gravity)
            )
            return f_ext, pmass, mapped_pos, el_nodes, gravity

        # Note: f_ext is NOT reset here — body force accumulates on top
        # of the external force computed earlier in the step.
        mapped_positions = self.shapefn(particles.reference_loc)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        args = (
            self.nodes.f_ext,
            particles.mass,
            mapped_positions,
            mapped_nodes,
            gravity,
        )
        self.nodes.f_ext, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)

    def apply_concentrated_nodal_forces(self, particles: Particles, curr_time: float):
        """Apply concentrated nodal forces.

        Parameters
        ----------
        particles: Particles
            Particles in the simulation.
        curr_time: float
            Current time in the simulation.
        """
        for cnf in self.concentrated_nodal_forces:
            # The force's math function scales its magnitude over time.
            factor = cnf.function.value(curr_time)
            self.nodes.f_ext = self.nodes.f_ext.at[cnf.node_ids, 0, cnf.dir].add(
                factor * cnf.force
            )

    def apply_particle_traction_forces(self, particles: Particles):
        """Apply particle traction forces to the nodal external force.

        Parameters
        ----------
        particles: Particles
            Particles in the simulation.
        """

        def _step(pid, args):
            f_ext, ptraction, mapped_pos, el_nodes = args
            f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid])
            return f_ext, ptraction, mapped_pos, el_nodes

        mapped_positions = self.shapefn(particles.reference_loc)
        mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
        args = (self.nodes.f_ext, particles.traction, mapped_positions, mapped_nodes)
        self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args)

    def update_nodal_acceleration_velocity(
        self, particles: Particles, dt: float, *args
    ):
        """Update the nodal momentum based on total force on nodes."""
        total_force = self.nodes.get_total_force()
        # nan_to_num guards the zero-mass nodes (0/0 divisions).
        self.nodes.acceleration = self.nodes.acceleration.at[:].set(
            jnp.nan_to_num(jnp.divide(total_force, self.nodes.mass))
        )
        self.nodes.velocity = self.nodes.velocity.at[:].add(
            self.nodes.acceleration * dt
        )
        self.apply_boundary_constraints()
        self.nodes.momentum = self.nodes.momentum.at[:].set(
            self.nodes.mass * self.nodes.velocity
        )
        # Flush numerical noise below 1e-12 to exactly zero.
        self.nodes.velocity = jnp.where(
            jnp.abs(self.nodes.velocity) < 1e-12,
            jnp.zeros_like(self.nodes.velocity),
            self.nodes.velocity,
        )
        self.nodes.acceleration = jnp.where(
            jnp.abs(self.nodes.acceleration) < 1e-12,
            jnp.zeros_like(self.nodes.acceleration),
            self.nodes.acceleration,
        )

    def apply_boundary_constraints(self, *args):
        """Apply boundary conditions for nodal velocity."""
        for ids, constraint in self.constraints:
            constraint.apply(self.nodes, ids)

    def apply_force_boundary_constraints(self, *args):
        """Apply boundary conditions for nodal forces."""
        # Zeroes all force components at the first constraint's nodes.
        self.nodes.f_int = self.nodes.f_int.at[self.constraints[0][0]].set(0)
        self.nodes.f_ext = self.nodes.f_ext.at[self.constraints[0][0]].set(0)
        self.nodes.f_damp = self.nodes.f_damp.at[self.constraints[0][0]].set(0)
+var concentrated_nodal_forces : Sequencevar nodes : Nodesvar total_elements : intvar volume : jax.Array
+def tree_unflatten(aux_data, children)
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Reconstruct an element from JAX pytree flattening output.

    Inverse of `tree_flatten`: `children` carries the traced arrays
    (nodes, volume) and `aux_data` the static configuration.
    """
    return cls(
        aux_data[0],  # nelements
        aux_data[1],  # total_elements
        aux_data[2],  # el_len
        aux_data[3],  # constraints
        nodes=children[0],
        concentrated_nodal_forces=aux_data[4],
        initialized=aux_data[5],
        volume=children[1],
    )
+
+def apply_boundary_constraints(self, *args)
+Apply boundary conditions for nodal velocity.
def apply_boundary_constraints(self, *args):
    """Apply boundary conditions for nodal velocity."""
    # Each entry pairs the node IDs with the constraint to enforce.
    for node_ids, constraint in self.constraints:
        constraint.apply(self.nodes, node_ids)
+
+def apply_concentrated_nodal_forces(self, particles: Particles, curr_time: float)
+Apply concentrated nodal forces.
def apply_concentrated_nodal_forces(self, particles: Particles, curr_time: float):
    """Apply concentrated nodal forces.

    Parameters
    ----------
    particles: Particles
        Particles in the simulation.
    curr_time: float
        Current time in the simulation.
    """
    for nodal_force in self.concentrated_nodal_forces:
        # The attached math function scales the force over time.
        scale = nodal_force.function.value(curr_time)
        self.nodes.f_ext = self.nodes.f_ext.at[
            nodal_force.node_ids, 0, nodal_force.dir
        ].add(scale * nodal_force.force)
+
+def apply_force_boundary_constraints(self, *args)
+Apply boundary conditions for nodal forces.
def apply_force_boundary_constraints(self, *args):
    """Apply boundary conditions for nodal forces."""
    # Zero every force component at the first constraint's nodes.
    constrained_ids = self.constraints[0][0]
    self.nodes.f_int = self.nodes.f_int.at[constrained_ids].set(0)
    self.nodes.f_ext = self.nodes.f_ext.at[constrained_ids].set(0)
    self.nodes.f_damp = self.nodes.f_damp.at[constrained_ids].set(0)
+
+def apply_particle_traction_forces(self, particles: Particles)
+Apply concentrated nodal forces.
def apply_particle_traction_forces(self, particles: Particles):
    """Apply particle traction forces to the nodal external force.

    Parameters
    ----------
    particles: Particles
        Particles in the simulation.
    """

    def _step(pid, args):
        f_ext, ptraction, mapped_pos, el_nodes = args
        f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ ptraction[pid])
        return f_ext, ptraction, mapped_pos, el_nodes

    # Note: f_ext is not reset — traction accumulates on top of the
    # forces computed earlier in the step.
    mapped_positions = self.shapefn(particles.reference_loc)
    mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
    args = (self.nodes.f_ext, particles.traction, mapped_positions, mapped_nodes)
    self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+def compute_body_force(self, particles: Particles, gravity: ArrayLike)
+Update the nodal external force based on particle mass.
+The nodal force is updated as a sum of particle body +force for all particles mapped to the node.
++(f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g +
def compute_body_force(self, particles: Particles, gravity: ArrayLike):
    r"""Update the nodal external force based on particle mass.

    The nodal force is updated as a sum of particle body
    force for all particles mapped to the node.

    \[
        (f_{ext})_i = (f_{ext})_i + \sum_p N_i(x_p) m_p g
    \]

    Parameters
    ----------
    particles: diffmpm.particle.Particles
        Particles to map to the nodal values.
    """

    def _step(pid, args):
        f_ext, pmass, mapped_pos, el_nodes, gravity = args
        f_ext = f_ext.at[el_nodes[pid]].add(
            mapped_pos[pid] @ (pmass[pid] * gravity)
        )
        return f_ext, pmass, mapped_pos, el_nodes, gravity

    # Note: f_ext is NOT reset here — body force accumulates on top of
    # the external force computed earlier in the step.
    mapped_positions = self.shapefn(particles.reference_loc)
    mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
    args = (
        self.nodes.f_ext,
        particles.mass,
        mapped_positions,
        mapped_nodes,
        gravity,
    )
    self.nodes.f_ext, _, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+def compute_external_force(self, particles: Particles)
+Update the nodal external force based on particle f_ext.
+The nodal force is updated as a sum of particle external +force for all particles mapped to the node.
+(f_{ext})_i = \sum_p N_i(x_p) (f_{ext})_p +
def compute_external_force(self, particles: Particles):
    r"""Update the nodal external force based on particle f_ext.

    The nodal force is updated as a sum of particle external
    force for all particles mapped to the node.

    \[
        (f_{ext})_i = \sum_p N_i(x_p) (f_{ext})_p
    \]

    Parameters
    ----------
    particles: diffmpm.particle.Particles
        Particles to map to the nodal values.
    """

    def _step(pid, args):
        f_ext, pf_ext, mapped_pos, el_nodes = args
        f_ext = f_ext.at[el_nodes[pid]].add(mapped_pos[pid] @ pf_ext[pid])
        return f_ext, pf_ext, mapped_pos, el_nodes

    # Reset before accumulating this step's contributions.
    self.nodes.f_ext = self.nodes.f_ext.at[:].set(0)
    mapped_positions = self.shapefn(particles.reference_loc)
    mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
    args = (
        self.nodes.f_ext,
        particles.f_ext,
        mapped_positions,
        mapped_nodes,
    )
    self.nodes.f_ext, _, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+def compute_nodal_mass(self, particles: Particles)
+Compute the nodal mass based on particle mass.
+The nodal mass is updated as a sum of particle mass for +all particles mapped to the node.
++(m)_i = \sum_p N_i(x_p) m_p +
def compute_nodal_mass(self, particles: Particles):
    r"""Compute the nodal mass based on particle mass.

    The nodal mass is updated as a sum of particle mass for
    all particles mapped to the node.

    \[
        (m)_i = \sum_p N_i(x_p) m_p
    \]

    Parameters
    ----------
    particles: diffmpm.particle.Particles
        Particles to map to the nodal values.
    """

    def _step(pid, args):
        pmass, mass, mapped_pos, el_nodes = args
        # Scatter this particle's mass onto its element's nodes,
        # weighted by the shape-function values.
        mass = mass.at[el_nodes[pid]].add(pmass[pid] * mapped_pos[pid])
        return pmass, mass, mapped_pos, el_nodes

    # Reset before accumulating this step's contributions.
    self.nodes.mass = self.nodes.mass.at[:].set(0)
    mapped_positions = self.shapefn(particles.reference_loc)
    mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
    args = (
        particles.mass,
        self.nodes.mass,
        mapped_positions,
        mapped_nodes,
    )
    _, self.nodes.mass, _, _ = lax.fori_loop(0, len(particles), _step, args)
+
+def compute_nodal_momentum(self, particles: Particles)
+Compute the nodal momentum based on particle momentum.
+The nodal momentum is updated as a sum of particle momentum for +all particles mapped to the node.
++(mv)_i = \sum_p N_i(x_p) (mv)_p +
def compute_nodal_momentum(self, particles: Particles):
    r"""Compute the nodal momentum based on particle momentum.

    The nodal momentum is updated as a sum of particle momentum for
    all particles mapped to the node.

    \[
        (mv)_i = \sum_p N_i(x_p) (mv)_p
    \]

    Parameters
    ----------
    particles: diffmpm.particle.Particles
        Particles to map to the nodal values.
    """

    def _step(pid, args):
        pmom, mom, mapped_pos, el_nodes = args
        mom = mom.at[el_nodes[pid]].add(mapped_pos[pid] @ pmom[pid])
        return pmom, mom, mapped_pos, el_nodes

    # Reset before accumulating this step's contributions.
    self.nodes.momentum = self.nodes.momentum.at[:].set(0)
    mapped_positions = self.shapefn(particles.reference_loc)
    mapped_nodes = vmap(self.id_to_node_ids)(particles.element_ids).squeeze(-1)
    args = (
        particles.mass * particles.velocity,
        self.nodes.momentum,
        mapped_positions,
        mapped_nodes,
    )
    _, self.nodes.momentum, _, _ = lax.fori_loop(0, len(particles), _step, args)
    # Flush numerical noise below 1e-12 to exactly zero.
    self.nodes.momentum = jnp.where(
        jnp.abs(self.nodes.momentum) < 1e-12,
        jnp.zeros_like(self.nodes.momentum),
        self.nodes.momentum,
    )
+
+def compute_velocity(self, particles: Particles)
+Compute velocity using momentum.
def compute_velocity(self, particles: Particles):
    """Compute nodal velocity from nodal momentum (v = p / m)."""
    nodes = self.nodes
    # Divide only where nodal mass is non-zero; keep the old velocity
    # at massless nodes to avoid division by zero.
    updated = jnp.where(
        nodes.mass == 0,
        nodes.velocity,
        nodes.momentum / nodes.mass,
    )
    # Snap numerically-tiny velocities to exactly zero.
    self.nodes.velocity = jnp.where(
        jnp.abs(updated) < 1e-12, jnp.zeros_like(updated), updated
    )
+
+def id_to_node_ids(self, id: ArrayLike) ‑> jax.Array
+Node IDs corresponding to element id.
This method is implemented by each of the subclass.
@abc.abstractmethod
def id_to_node_ids(self, id: ArrayLike) -> Array:
    """Return the node IDs that make up element `id`.

    Each concrete element subclass implements this mapping.

    Parameters
    ----------
    id : int
        Element ID.

    Returns
    -------
    ArrayLike
        Nodal IDs of the element.
    """
    ...
+
+def id_to_node_loc(self, id: ArrayLike) ‑> jax.Array
+Node locations corresponding to element id.
def id_to_node_loc(self, id: ArrayLike) -> Array:
    """Return node locations for element `id`.

    Parameters
    ----------
    id : int
        Element ID.

    Returns
    -------
    ArrayLike
        Nodal locations for the element, with shape
        `(nodes_in_element, 1, ndim)`.
    """
    element_node_ids = self.id_to_node_ids(id).squeeze()
    return self.nodes.loc[element_node_ids]
+
+def id_to_node_vel(self, id: ArrayLike) ‑> jax.Array
+Node velocities corresponding to element id.
def id_to_node_vel(self, id: ArrayLike) -> Array:
    """Return node velocities for element `id`.

    Parameters
    ----------
    id : int
        Element ID.

    Returns
    -------
    ArrayLike
        Nodal velocities for the element, with shape
        `(nodes_in_element, 1, ndim)`.
    """
    element_node_ids = self.id_to_node_ids(id).squeeze()
    return self.nodes.velocity[element_node_ids]
+
+def set_particle_element_ids(self, particles: Particles)
+Set the element IDs that particles are present in.
@abc.abstractmethod
def set_particle_element_ids(self, particles: Particles):
    """Assign to each particle the ID of the element containing it."""
    ...
+
+def shapefn(self, xi: ArrayLike)
+Evaluate Shape function for element type.
@abc.abstractmethod
def shapefn(self, xi: ArrayLike):
    """Evaluate the shape function for this element type at `xi`."""
    ...
+
+def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike)
+Evaluate gradient of shape function for element type.
@abc.abstractmethod
def shapefn_grad(self, xi: ArrayLike, coords: ArrayLike):
    """Evaluate the shape-function gradient for this element type."""
    ...
+
+def tree_flatten(self)
def tree_flatten(self):
    """Flatten this element as a JAX PyTree node.

    `nodes` and `volume` are traced children; the remaining fields are
    returned as static auxiliary data so JAX does not trace them.
    """
    children = (self.nodes, self.volume)
    aux_data = (
        self.nelements,
        self.total_elements,
        self.el_len,
        self.constraints,
        self.concentrated_nodal_forces,
        self.initialized,
    )
    return children, aux_data
+
+def update_nodal_acceleration_velocity(self, particles: Particles, dt: float, *args)
+Update the nodal momentum based on total force on nodes.
def update_nodal_acceleration_velocity(
    self, particles: Particles, dt: float, *args
):
    """Update nodal acceleration, velocity and momentum from total force."""
    nodes = self.nodes
    total_force = nodes.get_total_force()
    # a = F / m; nan_to_num guards nodes with zero mass.
    nodes.acceleration = nodes.acceleration.at[:].set(
        jnp.nan_to_num(jnp.divide(total_force, nodes.mass))
    )
    # Explicit integration of velocity, then re-apply boundary conditions.
    nodes.velocity = nodes.velocity.at[:].add(nodes.acceleration * dt)
    self.apply_boundary_constraints()
    nodes.momentum = nodes.momentum.at[:].set(nodes.mass * nodes.velocity)
    # Snap numerically-tiny velocities/accelerations to exactly zero.
    cleaned_vel = jnp.where(
        jnp.abs(nodes.velocity) < 1e-12,
        jnp.zeros_like(nodes.velocity),
        nodes.velocity,
    )
    nodes.velocity = cleaned_vel
    cleaned_acc = jnp.where(
        jnp.abs(nodes.acceleration) < 1e-12,
        jnp.zeros_like(nodes.acceleration),
        nodes.acceleration,
    )
    nodes.acceleration = cleaned_acc
+diffmpm.forcesfrom typing import Annotated, NamedTuple, get_type_hints
+
+from jax import Array
+from jax.tree_util import register_pytree_node
+
+from diffmpm.functions import Function
+
+
class NodalForce(NamedTuple):
    """Nodal Force being applied constantly on a set of nodes."""

    node_ids: Annotated[Array, "Array of Node IDs to which force is applied."]
    function: Annotated[
        Function,
        "Mathematical function that governs time-varying changes in the force.",
    ]
    dir: Annotated[int, "Direction in which force is applied."]
    force: Annotated[float, "Amount of force to be applied."]


# Copy each field's `Annotated` metadata onto the generated field accessor's
# docstring so it shows up in help() and generated documentation.
nfhints = get_type_hints(NodalForce, include_extras=True)
for attr in nfhints:
    getattr(NodalForce, attr).__doc__ = "".join(nfhints[attr].__metadata__)


class ParticleTraction(NamedTuple):
    """Traction being applied on a set of particles."""

    pset: Annotated[
        int, "The particle set in which traction is applied to the particles."
    ]
    pids: Annotated[
        Array,
        "Array of Particle IDs to which traction is applied inside the particle set.",
    ]
    function: Annotated[
        Function,
        "Mathematical function that governs time-varying changes in the traction.",
    ]
    dir: Annotated[int, "Direction in which traction is applied."]
    traction: Annotated[float, "Amount of traction to be applied."]


# Same docstring propagation for ParticleTraction fields.
pthints = get_type_hints(ParticleTraction, include_extras=True)
for attr in pthints:
    getattr(ParticleTraction, attr).__doc__ = "".join(pthints[attr].__metadata__)

# Register both NamedTuples explicitly as JAX PyTree nodes so their fields
# participate in JAX tracing and transformations.
register_pytree_node(
    NodalForce,
    # tell JAX how to unpack to an iterable
    lambda xs: (tuple(xs), None),  # type: ignore
    # tell JAX how to pack back into a NodalForce
    lambda _, xs: NodalForce(*xs),  # type: ignore
)
register_pytree_node(
    ParticleTraction,
    # tell JAX how to unpack to an iterable
    lambda xs: (tuple(xs), None),  # type: ignore
    # tell JAX how to pack back
    lambda _, xs: ParticleTraction(*xs),  # type: ignore
)
+
+class NodalForce
+(node_ids: typing.Annotated[jax.Array, 'Array of Node IDs to which force is applied.'], function: typing.Annotated[Function, 'Mathematical function that governs time-varying changes in the force.'], dir: typing.Annotated[int, 'Direction in which force is applied.'], force: typing.Annotated[float, 'Amount of force to be applied.'])
+Nodal Force being applied constantly on a set of nodes.
class NodalForce(NamedTuple):
+ """Nodal Force being applied constantly on a set of nodes."""
+
+ node_ids: Annotated[Array, "Array of Node IDs to which force is applied."]
+ function: Annotated[
+ Function,
+ "Mathematical function that governs time-varying changes in the force.",
+ ]
+ dir: Annotated[int, "Direction in which force is applied."]
+ force: Annotated[float, "Amount of force to be applied."]
+var dir : intDirection in which force is applied.
var force : floatAmount of force to be applied.
var function : FunctionMathematical function that governs time-varying changes in the force.
var node_ids : jax.ArrayArray of Node IDs to which force is applied.
+class ParticleTraction
+(pset: typing.Annotated[int, 'The particle set in which traction is applied to the particles.'], pids: typing.Annotated[jax.Array, 'Array of Particle IDs to which traction is applied inside the particle set.'], function: typing.Annotated[Function, 'Mathematical function that governs time-varying changes in the traction.'], dir: typing.Annotated[int, 'Direction in which traction is applied.'], traction: typing.Annotated[float, 'Amount of traction to be applied.'])
+Traction being applied on a set of particles.
class ParticleTraction(NamedTuple):
+ """Traction being applied on a set of particles."""
+
+ pset: Annotated[
+ int, "The particle set in which traction is applied to the particles."
+ ]
+ pids: Annotated[
+ Array,
+ "Array of Particle IDs to which traction is applied inside the particle set.",
+ ]
+ function: Annotated[
+ Function,
+ "Mathematical function that governs time-varying changes in the traction.",
+ ]
+ dir: Annotated[int, "Direction in which traction is applied."]
+ traction: Annotated[float, "Amount of traction to be applied."]
+var dir : intDirection in which traction is applied.
var function : FunctionMathematical function that governs time-varying changes in the traction.
var pids : jax.ArrayArray of Particle IDs to which traction is applied inside the particle set.
var pset : intThe particle set in which traction is applied to the particles.
var traction : floatAmount of traction to be applied.
diffmpm.functionsimport abc
+
+import jax.numpy as jnp
+from jax.tree_util import register_pytree_node_class
+
+
class Function(abc.ABC):
    """Base class for time-varying scalar functions used in loading."""

    def __init__(self, id):
        # Identifier used to refer to this function (e.g. from config files).
        self.id = id

    @abc.abstractmethod
    def value(self, x):
        """Evaluate the function at `x`."""
        ...


@register_pytree_node_class
class Unit(Function):
    """Constant function that always evaluates to 1."""

    def __init__(self, id):
        super().__init__(id)

    def value(self, x):
        """Return 1.0 for any input `x`."""
        return 1.0

    def tree_flatten(self):
        """Flatten as a PyTree node: no children, `id` as aux data.

        The trailing comma is essential: aux_data must be a tuple so that
        `tree_unflatten` can unpack it with `cls(*aux_data)`. The original
        `(self.id)` was a bare int and broke unflattening.
        """
        return ((), (self.id,))

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a `Unit` from its aux data."""
        del children
        return cls(*aux_data)


@register_pytree_node_class
class Linear(Function):
    """Piecewise-linear function interpolating `(xvalues, fxvalues)`."""

    def __init__(self, id, xvalues, fxvalues):
        self.xvalues = xvalues
        self.fxvalues = fxvalues
        super().__init__(id)

    def value(self, x):
        """Linearly interpolate the function value at `x`."""
        return jnp.interp(x, self.xvalues, self.fxvalues)

    def tree_flatten(self):
        """Flatten as a PyTree node: no children, parameters as aux data."""
        return ((), (self.id, self.xvalues, self.fxvalues))

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a `Linear` from its aux data."""
        del children
        return cls(*aux_data)
+
+class Function
+(id)
+Abstract base class for time-varying scalar functions used in loading.
class Function(abc.ABC):
+ def __init__(self, id):
+ self.id = id
+
+ @abc.abstractmethod
+ def value(self):
+ ...
+
+def value(self)
+@abc.abstractmethod
+def value(self):
+ ...
+
+class Linear
+(id, xvalues, fxvalues)
+Helper class that provides a standard way to create an ABC using +inheritance.
@register_pytree_node_class
+class Linear(Function):
+ def __init__(self, id, xvalues, fxvalues):
+ self.xvalues = xvalues
+ self.fxvalues = fxvalues
+ super().__init__(id)
+
+ def value(self, x):
+ return jnp.interp(x, self.xvalues, self.fxvalues)
+
+ def tree_flatten(self):
+ return ((), (self.id, self.xvalues, self.fxvalues))
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ del children
+ return cls(*aux_data)
+
+def tree_unflatten(aux_data, children)
+@classmethod
+def tree_unflatten(cls, aux_data, children):
+ del children
+ return cls(*aux_data)
+
+def tree_flatten(self)
+def tree_flatten(self):
+ return ((), (self.id, self.xvalues, self.fxvalues))
+
+def value(self, x)
+def value(self, x):
+ return jnp.interp(x, self.xvalues, self.fxvalues)
+
+class Unit
+(id)
+Helper class that provides a standard way to create an ABC using +inheritance.
@register_pytree_node_class
+class Unit(Function):
+ def __init__(self, id):
+ super().__init__(id)
+
+ def value(self, x):
+ return 1.0
+
+ def tree_flatten(self):
+ return ((), (self.id))
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ del children
+ return cls(*aux_data)
+
+def tree_unflatten(aux_data, children)
+@classmethod
+def tree_unflatten(cls, aux_data, children):
+ del children
+ return cls(*aux_data)
+
+def tree_flatten(self)
+def tree_flatten(self):
+ return ((), (self.id))
+
+def value(self, x)
+def value(self, x):
+ return 1.0
+diffmpmfrom importlib.metadata import version
+from pathlib import Path
+
+import diffmpm.writers as writers
+from diffmpm.io import Config
+from diffmpm.solver import MPMExplicit
+
+__all__ = ["MPM", "__version__"]
+
+__version__ = version("diffmpm")
+
+
class MPM:
    """Entry point tying a TOML input file to an MPM solver."""

    def __init__(self, filepath):
        """Parse the configuration at `filepath` and build the solver.

        Parameters
        ----------
        filepath:
            Path to the input TOML file.
        """
        self._config = Config(filepath)
        mesh = self._config.parse()
        # Output goes under <output.folder>/<meta.title>/.
        out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath(
            self._config.parsed_config["meta"]["title"],
        )

        # Missing or "none" format disables output writing entirely.
        write_format = self._config.parsed_config["output"].get("format", None)
        if write_format is None or write_format.lower() == "none":
            writer_func = None
        elif write_format == "npz":
            writer_func = writers.NPZWriter().write
        else:
            raise ValueError(f"Specified output format not supported: {write_format}")

        # Only the explicit time-integration solver is supported.
        if self._config.parsed_config["meta"]["type"] == "MPMExplicit":
            self.solver = MPMExplicit(
                mesh,
                self._config.parsed_config["meta"]["dt"],
                velocity_update=self._config.parsed_config["meta"]["velocity_update"],
                sim_steps=self._config.parsed_config["meta"]["nsteps"],
                out_steps=self._config.parsed_config["output"]["step_frequency"],
                out_dir=out_dir,
                writer_func=writer_func,
            )
        else:
            raise ValueError("Wrong type of solver specified.")

    def solve(self):
        """Solve the MPM simulation using JIT solver."""
        arrays = self.solver.solve_jit(
            self._config.parsed_config["external_loading"]["gravity"],
        )
        return arrays
+diffmpm.clidiffmpm.constraintdiffmpm.elementdiffmpm.forcesdiffmpm.functionsdiffmpm.iodiffmpm.materialdiffmpm.meshdiffmpm.nodediffmpm.particlediffmpm.schemediffmpm.solverdiffmpm.writers
+class MPM
+(filepath)
+class MPM:
+ def __init__(self, filepath):
+ self._config = Config(filepath)
+ mesh = self._config.parse()
+ out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath(
+ self._config.parsed_config["meta"]["title"],
+ )
+
+ write_format = self._config.parsed_config["output"].get("format", None)
+ if write_format is None or write_format.lower() == "none":
+ writer_func = None
+ elif write_format == "npz":
+ writer_func = writers.NPZWriter().write
+ else:
+ raise ValueError(f"Specified output format not supported: {write_format}")
+
+ if self._config.parsed_config["meta"]["type"] == "MPMExplicit":
+ self.solver = MPMExplicit(
+ mesh,
+ self._config.parsed_config["meta"]["dt"],
+ velocity_update=self._config.parsed_config["meta"]["velocity_update"],
+ sim_steps=self._config.parsed_config["meta"]["nsteps"],
+ out_steps=self._config.parsed_config["output"]["step_frequency"],
+ out_dir=out_dir,
+ writer_func=writer_func,
+ )
+ else:
+ raise ValueError("Wrong type of solver specified.")
+
+ def solve(self):
+ """Solve the MPM simulation using JIT solver."""
+ arrays = self.solver.solve_jit(
+ self._config.parsed_config["external_loading"]["gravity"],
+ )
+ return arrays
+
+def solve(self)
+Solve the MPM simulation using JIT solver.
def solve(self):
+ """Solve the MPM simulation using JIT solver."""
+ arrays = self.solver.solve_jit(
+ self._config.parsed_config["external_loading"]["gravity"],
+ )
+ return arrays
+diffmpm.ioimport json
+import tomllib as tl
+from collections import namedtuple
+
+import jax.numpy as jnp
+
+from diffmpm import element as mpel
+from diffmpm import material as mpmat
+from diffmpm import mesh as mpmesh
+from diffmpm.constraint import Constraint
+from diffmpm.forces import NodalForce, ParticleTraction
+from diffmpm.functions import Linear, Unit
+from diffmpm.particle import Particles
+
+
class Config:
    """Parses a TOML input file into solver-ready objects.

    Parsed results are accumulated in `self.parsed_config`.
    """

    def __init__(self, filepath):
        # Path of the TOML input file.
        self._filepath = filepath
        self.parsed_config = {}
        # NOTE(review): `parse()` runs here and callers (e.g. MPM.__init__)
        # often call it again — the second parse looks redundant; confirm.
        self.parse()

    def parse(self):
        """Parse the input file, populate `parsed_config` and return the mesh."""
        with open(self._filepath, "rb") as f:
            self._fileconfig = tl.load(f)

        self._parse_meta(self._fileconfig)
        self._parse_output(self._fileconfig)
        self._parse_materials(self._fileconfig)
        self._parse_particles(self._fileconfig)
        # math_functions is optional; external loading may reference them.
        if "math_functions" in self._fileconfig:
            self._parse_math_functions(self._fileconfig)
        self._parse_external_loading(self._fileconfig)
        mesh = self._parse_mesh(self._fileconfig)
        return mesh

    def _parse_meta(self, config):
        """Copy the [meta] table verbatim."""
        self.parsed_config["meta"] = config["meta"]

    def _parse_output(self, config):
        """Copy the [output] table verbatim."""
        self.parsed_config["output"] = config["output"]

    def _parse_materials(self, config):
        """Instantiate material objects from [[materials]] entries."""
        materials = []
        for mat_config in config["materials"]:
            # `type` names a class in diffmpm.material; remaining keys are
            # passed through as that material's properties.
            mat_type = mat_config.pop("type")
            mat_cls = getattr(mpmat, mat_type)
            mat = mat_cls(mat_config)
            materials.append(mat)
        self.parsed_config["materials"] = materials

    def _parse_particles(self, config):
        """Build particle sets from [[particles]] entries.

        Particle locations are loaded from the JSON file named by each
        entry's `file` key; materials are resolved by `material_id` into
        the already-parsed materials list.
        """
        particle_sets = []
        for pset_config in config["particles"]:
            pmat = self.parsed_config["materials"][pset_config["material_id"]]
            with open(pset_config["file"], "r") as f:
                ploc = jnp.asarray(json.load(f))
            # Element IDs are initialized to 0 for every particle here.
            peids = jnp.zeros(ploc.shape[0], dtype=jnp.int32)
            pset = Particles(ploc, pmat, peids)
            pset.velocity = pset.velocity.at[:].set(pset_config["init_velocity"])
            particle_sets.append(pset)
        self.parsed_config["particles"] = particle_sets

    def _parse_math_functions(self, config):
        """Build Function objects from [[math_functions]] entries.

        Raises
        ------
        NotImplementedError
            For any function type other than `Linear`.
        """
        flist = []
        for i, fnconfig in enumerate(config["math_functions"]):
            if fnconfig["type"] == "Linear":
                fn = Linear(
                    i,
                    jnp.array(fnconfig["xvalues"]),
                    jnp.array(fnconfig["fxvalues"]),
                )
                flist.append(fn)
            else:
                raise NotImplementedError(
                    "Function type other than `Linear` not yet supported"
                )
        self.parsed_config["math_functions"] = flist

    def _parse_external_loading(self, config):
        """Parse gravity, nodal forces and particle surface tractions."""
        external_loading = {}
        external_loading["gravity"] = jnp.array(config["external_loading"]["gravity"])
        external_loading["concentrated_nodal_forces"] = []
        particle_surface_traction = []
        if "concentrated_nodal_forces" in config["external_loading"]:
            cnf_list = []
            for cnfconfig in config["external_loading"]["concentrated_nodal_forces"]:
                # Without a math_function_id the force is constant in time
                # (Unit function with sentinel id -1).
                if "math_function_id" in cnfconfig:
                    fn = self.parsed_config["math_functions"][
                        cnfconfig["math_function_id"]
                    ]
                else:
                    fn = Unit(-1)
                cnf = NodalForce(
                    node_ids=jnp.array(cnfconfig["node_ids"]),
                    function=fn,
                    dir=cnfconfig["dir"],
                    force=cnfconfig["force"],
                )
                cnf_list.append(cnf)
            external_loading["concentrated_nodal_forces"] = cnf_list

        if "particle_surface_traction" in config["external_loading"]:
            pst_list = []
            for pstconfig in config["external_loading"]["particle_surface_traction"]:
                if "math_function_id" in pstconfig:
                    fn = self.parsed_config["math_functions"][
                        pstconfig["math_function_id"]
                    ]
                else:
                    fn = Unit(-1)
                pst = ParticleTraction(
                    pset=pstconfig["pset"],
                    pids=jnp.array(pstconfig["pids"]),
                    function=fn,
                    dir=pstconfig["dir"],
                    traction=pstconfig["traction"],
                )
                pst_list.append(pst)
            particle_surface_traction.extend(pst_list)
        self.parsed_config["external_loading"] = external_loading
        self.parsed_config["particle_surface_traction"] = particle_surface_traction

    def _parse_mesh(self, config):
        """Build the element container and mesh from the [mesh] table.

        Raises
        ------
        NotImplementedError
            For mesh types other than `generator`.
        """
        element_cls = getattr(mpel, config["mesh"]["element"])
        # Mesh class is chosen by dimension, e.g. Mesh1D / Mesh2D.
        mesh_cls = getattr(mpmesh, f"Mesh{config['meta']['dimension']}D")
        constraints = [
            (jnp.asarray(c["node_ids"]), Constraint(c["dir"], c["velocity"]))
            for c in config["mesh"]["constraints"]
        ]
        if config["mesh"]["type"] == "generator":
            elements = element_cls(
                config["mesh"]["nelements"],
                jnp.prod(jnp.array(config["mesh"]["nelements"])),
                config["mesh"]["element_length"],
                constraints,
                concentrated_nodal_forces=self.parsed_config["external_loading"][
                    "concentrated_nodal_forces"
                ],
            )
        else:
            raise NotImplementedError(
                "Mesh type other than `generator` not yet supported."
            )
        self.parsed_config["elements"] = elements
        mesh = mesh_cls(self.parsed_config)
        return mesh
+
+class Config
+(filepath)
+class Config:
+ def __init__(self, filepath):
+ self._filepath = filepath
+ self.parsed_config = {}
+ self.parse()
+
+ def parse(self):
+ with open(self._filepath, "rb") as f:
+ self._fileconfig = tl.load(f)
+
+ self._parse_meta(self._fileconfig)
+ self._parse_output(self._fileconfig)
+ self._parse_materials(self._fileconfig)
+ self._parse_particles(self._fileconfig)
+ if "math_functions" in self._fileconfig:
+ self._parse_math_functions(self._fileconfig)
+ self._parse_external_loading(self._fileconfig)
+ mesh = self._parse_mesh(self._fileconfig)
+ return mesh
+
+ def _parse_meta(self, config):
+ self.parsed_config["meta"] = config["meta"]
+
+ def _parse_output(self, config):
+ self.parsed_config["output"] = config["output"]
+
+ def _parse_materials(self, config):
+ materials = []
+ for mat_config in config["materials"]:
+ mat_type = mat_config.pop("type")
+ mat_cls = getattr(mpmat, mat_type)
+ mat = mat_cls(mat_config)
+ materials.append(mat)
+ self.parsed_config["materials"] = materials
+
+ def _parse_particles(self, config):
+ particle_sets = []
+ for pset_config in config["particles"]:
+ pmat = self.parsed_config["materials"][pset_config["material_id"]]
+ with open(pset_config["file"], "r") as f:
+ ploc = jnp.asarray(json.load(f))
+ peids = jnp.zeros(ploc.shape[0], dtype=jnp.int32)
+ pset = Particles(ploc, pmat, peids)
+ pset.velocity = pset.velocity.at[:].set(pset_config["init_velocity"])
+ particle_sets.append(pset)
+ self.parsed_config["particles"] = particle_sets
+
+ def _parse_math_functions(self, config):
+ flist = []
+ for i, fnconfig in enumerate(config["math_functions"]):
+ if fnconfig["type"] == "Linear":
+ fn = Linear(
+ i,
+ jnp.array(fnconfig["xvalues"]),
+ jnp.array(fnconfig["fxvalues"]),
+ )
+ flist.append(fn)
+ else:
+ raise NotImplementedError(
+ "Function type other than `Linear` not yet supported"
+ )
+ self.parsed_config["math_functions"] = flist
+
+ def _parse_external_loading(self, config):
+ external_loading = {}
+ external_loading["gravity"] = jnp.array(config["external_loading"]["gravity"])
+ external_loading["concentrated_nodal_forces"] = []
+ particle_surface_traction = []
+ if "concentrated_nodal_forces" in config["external_loading"]:
+ cnf_list = []
+ for cnfconfig in config["external_loading"]["concentrated_nodal_forces"]:
+ if "math_function_id" in cnfconfig:
+ fn = self.parsed_config["math_functions"][
+ cnfconfig["math_function_id"]
+ ]
+ else:
+ fn = Unit(-1)
+ cnf = NodalForce(
+ node_ids=jnp.array(cnfconfig["node_ids"]),
+ function=fn,
+ dir=cnfconfig["dir"],
+ force=cnfconfig["force"],
+ )
+ cnf_list.append(cnf)
+ external_loading["concentrated_nodal_forces"] = cnf_list
+
+ if "particle_surface_traction" in config["external_loading"]:
+ pst_list = []
+ for pstconfig in config["external_loading"]["particle_surface_traction"]:
+ if "math_function_id" in pstconfig:
+ fn = self.parsed_config["math_functions"][
+ pstconfig["math_function_id"]
+ ]
+ else:
+ fn = Unit(-1)
+ pst = ParticleTraction(
+ pset=pstconfig["pset"],
+ pids=jnp.array(pstconfig["pids"]),
+ function=fn,
+ dir=pstconfig["dir"],
+ traction=pstconfig["traction"],
+ )
+ pst_list.append(pst)
+ particle_surface_traction.extend(pst_list)
+ self.parsed_config["external_loading"] = external_loading
+ self.parsed_config["particle_surface_traction"] = particle_surface_traction
+
+ def _parse_mesh(self, config):
+ element_cls = getattr(mpel, config["mesh"]["element"])
+ mesh_cls = getattr(mpmesh, f"Mesh{config['meta']['dimension']}D")
+ constraints = [
+ (jnp.asarray(c["node_ids"]), Constraint(c["dir"], c["velocity"]))
+ for c in config["mesh"]["constraints"]
+ ]
+ if config["mesh"]["type"] == "generator":
+ elements = element_cls(
+ config["mesh"]["nelements"],
+ jnp.prod(jnp.array(config["mesh"]["nelements"])),
+ config["mesh"]["element_length"],
+ constraints,
+ concentrated_nodal_forces=self.parsed_config["external_loading"][
+ "concentrated_nodal_forces"
+ ],
+ )
+ else:
+ raise NotImplementedError(
+ "Mesh type other than `generator` not yet supported."
+ )
+ self.parsed_config["elements"] = elements
+ mesh = mesh_cls(self.parsed_config)
+ return mesh
+
+def parse(self)
+def parse(self):
+ with open(self._filepath, "rb") as f:
+ self._fileconfig = tl.load(f)
+
+ self._parse_meta(self._fileconfig)
+ self._parse_output(self._fileconfig)
+ self._parse_materials(self._fileconfig)
+ self._parse_particles(self._fileconfig)
+ if "math_functions" in self._fileconfig:
+ self._parse_math_functions(self._fileconfig)
+ self._parse_external_loading(self._fileconfig)
+ mesh = self._parse_mesh(self._fileconfig)
+ return mesh
+diffmpm.materialimport abc
+from typing import Tuple
+
+import jax.numpy as jnp
+from jax.tree_util import register_pytree_node_class
+
+
class Material(abc.ABC):
    """Base material class."""

    # Property names that each subclass requires in `material_properties`.
    _props: Tuple[str, ...]

    def __init__(self, material_properties):
        """Initialize material properties.

        Parameters
        ----------
        material_properties: dict
            A key-value map for various material properties.
        """
        self.properties = material_properties

    def tree_flatten(self):
        """Flatten this class as PyTree Node.

        No traced children; all properties are static auxiliary data.
        """
        return (tuple(), self.properties)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Unflatten this class as PyTree Node."""
        del children
        return cls(aux_data)

    @abc.abstractmethod
    def __repr__(self):
        """Repr for Material class."""
        ...

    @abc.abstractmethod
    def compute_stress(self):
        """Compute stress for the material."""
        ...

    def validate_props(self, material_properties):
        """Raise `KeyError` if any key in `_props` is missing."""
        for key in self._props:
            if key not in material_properties:
                raise KeyError(
                    f"'{key}' should be present in `material_properties` "
                    f"for {self.__class__.__name__} materials."
                )


@register_pytree_node_class
class LinearElastic(Material):
    """Linear Elastic Material."""

    _props = ("density", "youngs_modulus", "poisson_ratio")

    def __init__(self, material_properties):
        """Create a Linear Elastic material.

        Parameters
        ----------
        material_properties: dict
            Dictionary with material properties. For linear elastic
            materials, 'density', 'youngs_modulus' and 'poisson_ratio'
            are required keys.
        """
        self.validate_props(material_properties)
        youngs_modulus = material_properties["youngs_modulus"]
        poisson_ratio = material_properties["poisson_ratio"]
        density = material_properties["density"]
        # Standard isotropic elastic moduli derived from (E, nu).
        bulk_modulus = youngs_modulus / (3 * (1 - 2 * poisson_ratio))
        constrained_modulus = (
            youngs_modulus
            * (1 - poisson_ratio)
            / ((1 + poisson_ratio) * (1 - 2 * poisson_ratio))
        )
        shear_modulus = youngs_modulus / (2 * (1 + poisson_ratio))
        # Wave velocities
        vp = jnp.sqrt(constrained_modulus / density)  # P-wave
        vs = jnp.sqrt(shear_modulus / density)  # S-wave
        self.properties = {
            **material_properties,
            "bulk_modulus": bulk_modulus,
            "pwave_velocity": vp,
            "swave_velocity": vs,
        }
        self._compute_elastic_tensor()

    def __repr__(self):
        return f"LinearElastic(props={self.properties})"

    def _compute_elastic_tensor(self):
        """Build the 6x6 isotropic elastic stiffness matrix (Voigt notation)."""
        G = self.properties["youngs_modulus"] / (
            2 * (1 + self.properties["poisson_ratio"])
        )

        a1 = self.properties["bulk_modulus"] + (4 * G / 3)
        a2 = self.properties["bulk_modulus"] - (2 * G / 3)

        self.de = jnp.array(
            [
                [a1, a2, a2, 0, 0, 0],
                [a2, a1, a2, 0, 0, 0],
                [a2, a2, a1, 0, 0, 0],
                [0, 0, 0, G, 0, 0],
                [0, 0, 0, 0, G, 0],
                [0, 0, 0, 0, 0, G],
            ]
        )

    def compute_stress(self, dstrain):
        """Compute the stress increment for a given strain increment."""
        dstress = self.de @ dstrain
        return dstress


@register_pytree_node_class
class SimpleMaterial(Material):
    """Minimal linear material defined by a single stiffness `E`."""

    _props = ("E", "density")

    def __init__(self, material_properties):
        """Create a SimpleMaterial; 'E' and 'density' are required keys."""
        self.validate_props(material_properties)
        self.properties = material_properties

    def __repr__(self):
        return f"SimpleMaterial(props={self.properties})"

    def compute_stress(self, dstrain):
        """Stress increment is simply `E * dstrain`."""
        return dstrain * self.properties["E"]
+
+class LinearElastic
+(material_properties)
+Linear Elastic Material.
+Create a Linear Elastic material.
+material_properties : dictmaterials, 'density' and 'youngs_modulus' are required keys.
@register_pytree_node_class
+class LinearElastic(Material):
+ """Linear Elastic Material."""
+
+ _props = ("density", "youngs_modulus", "poisson_ratio")
+
+ def __init__(self, material_properties):
+ """Create a Linear Elastic material.
+
+ Parameters
+ ----------
+ material_properties: dict
+ Dictionary with material properties. For linear elastic
+ materials, 'density' and 'youngs_modulus' are required keys.
+ """
+ self.validate_props(material_properties)
+ youngs_modulus = material_properties["youngs_modulus"]
+ poisson_ratio = material_properties["poisson_ratio"]
+ density = material_properties["density"]
+ bulk_modulus = youngs_modulus / (3 * (1 - 2 * poisson_ratio))
+ constrained_modulus = (
+ youngs_modulus
+ * (1 - poisson_ratio)
+ / ((1 + poisson_ratio) * (1 - 2 * poisson_ratio))
+ )
+ shear_modulus = youngs_modulus / (2 * (1 + poisson_ratio))
+ # Wave velocities
+ vp = jnp.sqrt(constrained_modulus / density)
+ vs = jnp.sqrt(shear_modulus / density)
+ self.properties = {
+ **material_properties,
+ "bulk_modulus": bulk_modulus,
+ "pwave_velocity": vp,
+ "swave_velocity": vs,
+ }
+ self._compute_elastic_tensor()
+
+ def __repr__(self):
+ return f"LinearElastic(props={self.properties})"
+
+ def _compute_elastic_tensor(self):
+ G = self.properties["youngs_modulus"] / (
+ 2 * (1 + self.properties["poisson_ratio"])
+ )
+
+ a1 = self.properties["bulk_modulus"] + (4 * G / 3)
+ a2 = self.properties["bulk_modulus"] - (2 * G / 3)
+
+ self.de = jnp.array(
+ [
+ [a1, a2, a2, 0, 0, 0],
+ [a2, a1, a2, 0, 0, 0],
+ [a2, a2, a1, 0, 0, 0],
+ [0, 0, 0, G, 0, 0],
+ [0, 0, 0, 0, G, 0],
+ [0, 0, 0, 0, 0, G],
+ ]
+ )
+
+ def compute_stress(self, dstrain):
+ """Compute material stress."""
+ dstress = self.de @ dstrain
+ return dstress
+
+def compute_stress(self, dstrain)
+Compute material stress.
def compute_stress(self, dstrain):
+ """Compute material stress."""
+ dstress = self.de @ dstrain
+ return dstress
+Material:
+
+
+class Material
+(material_properties)
+Base material class.
+Initialize material properties.
+material_properties : dictclass Material(abc.ABC):
+ """Base material class."""
+
+ _props: Tuple[str, ...]
+
+ def __init__(self, material_properties):
+ """Initialize material properties.
+
+ Parameters
+ ----------
+ material_properties: dict
+ A key-value map for various material properties.
+ """
+ self.properties = material_properties
+
+ # @abc.abstractmethod
+ def tree_flatten(self):
+ """Flatten this class as PyTree Node."""
+ return (tuple(), self.properties)
+
+ # @abc.abstractmethod
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ """Unflatten this class as PyTree Node."""
+ del children
+ return cls(aux_data)
+
+ @abc.abstractmethod
+ def __repr__(self):
+ """Repr for Material class."""
+ ...
+
+ @abc.abstractmethod
+ def compute_stress(self):
+ """Compute stress for the material."""
+ ...
+
+ def validate_props(self, material_properties):
+ for key in self._props:
+ if key not in material_properties:
+ raise KeyError(
+ f"'{key}' should be present in `material_properties` "
+ f"for {self.__class__.__name__} materials."
+ )
+
+def tree_unflatten(aux_data, children)
+Unflatten this class as PyTree Node.
@classmethod
+def tree_unflatten(cls, aux_data, children):
+ """Unflatten this class as PyTree Node."""
+ del children
+ return cls(aux_data)
+
+def compute_stress(self)
+Compute stress for the material.
@abc.abstractmethod
+def compute_stress(self):
+ """Compute stress for the material."""
+ ...
+
+def tree_flatten(self)
+Flatten this class as PyTree Node.
def tree_flatten(self):
+ """Flatten this class as PyTree Node."""
+ return (tuple(), self.properties)
+
+def validate_props(self, material_properties)
+def validate_props(self, material_properties):
+ for key in self._props:
+ if key not in material_properties:
+ raise KeyError(
+ f"'{key}' should be present in `material_properties` "
+ f"for {self.__class__.__name__} materials."
+ )
+
+class SimpleMaterial
+(material_properties)
+Simple material whose stress response is proportional to strain via the modulus `E`.
+Initialize material properties.
@register_pytree_node_class
class SimpleMaterial(Material):
    """Material whose stress response is `dstrain * E`."""

    # Required property keys for this material.
    _props = ("E", "density")

    def __init__(self, material_properties):
        """Validate and store the material properties.

        Parameters
        ----------
        material_properties: dict
            Must contain the keys `E` and `density`.
        """
        self.validate_props(material_properties)
        self.properties = material_properties

    def __repr__(self):
        """Repr showing the stored properties."""
        return f"SimpleMaterial(props={self.properties})"

    def compute_stress(self, dstrain):
        """Return the stress increment for a strain increment."""
        return dstrain * self.properties["E"]
+Material:
+
+diffmpm.meshimport abc
+from typing import Callable, Sequence, Tuple
+
+import jax.numpy as jnp
+from jax.tree_util import register_pytree_node_class
+
+from diffmpm.element import _Element
+from diffmpm.particle import Particles
+
+__all__ = ["_MeshBase", "Mesh1D", "Mesh2D"]
+
+
class _MeshBase(abc.ABC):
    """Abstract base class for meshes.

    .. note::
        If attributes other than elements and particles are added
        then the child class should also implement `tree_flatten` and
        `tree_unflatten` correctly or that information will get lost.
    """

    # Spatial dimension; set by concrete subclasses.
    ndim: int

    def __init__(self, config: dict):
        """Initialize mesh using configuration.

        Parameters
        ----------
        config: dict
            Should contain `particles`, `elements` and
            `particle_surface_traction` keys.
        """
        self.particles: Sequence[Particles] = config["particles"]
        self.elements: _Element = config["elements"]
        self.particle_tractions = config["particle_surface_traction"]

    # TODO: Convert to using jax directives for loop
    def apply_on_elements(self, function: str, args: Tuple = ()):
        """Call an `_Element` method once for each particle set.

        Parameters
        ----------
        function: str
            Name of a method defined on `_Element`.
        args: tuple
            Extra positional arguments forwarded to the method.
        """
        element_fn = getattr(self.elements, function)
        for pset in self.particles:
            element_fn(pset, *args)

    # TODO: Convert to using jax directives for loop
    def apply_on_particles(self, function: str, args: Tuple = ()):
        """Call a `Particles` method on every particle set.

        Parameters
        ----------
        function: str
            Name of a method defined on `Particles`.
        args: tuple
            Extra positional arguments forwarded to the method.
        """
        for pset in self.particles:
            getattr(pset, function)(self.elements, *args)

    def apply_traction_on_particles(self, curr_time: float):
        """Apply surface tractions to particles at the current time.

        Parameters
        ----------
        curr_time: float
            Current time in the simulation; used to evaluate each
            traction's time-dependent scaling function.
        """
        self.apply_on_particles("zero_traction")
        for ptraction in self.particle_tractions:
            # Scale the prescribed traction by its time function.
            traction_val = ptraction.function.value(curr_time) * ptraction.traction
            for i, pset_id in enumerate(ptraction.pset):
                self.particles[pset_id].assign_traction(
                    ptraction.pids[i], ptraction.dir, traction_val
                )

        self.apply_on_elements("apply_particle_traction_forces")

    def tree_flatten(self):
        """Flatten this mesh as a JAX PyTree node."""
        return ((self.particles, self.elements), self.particle_tractions)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct a mesh from its PyTree representation."""
        config = {
            "particles": children[0],
            "elements": children[1],
            "particle_surface_traction": aux_data,
        }
        return cls(config)
+
+
@register_pytree_node_class
class Mesh1D(_MeshBase):
    """One-dimensional mesh of nodes, elements, and particles."""

    def __init__(self, config: dict):
        """Create a 1D mesh from a configuration mapping.

        Parameters
        ----------
        config: dict
            Configuration to be used for initialization; should
            contain `elements` and `particles` keys.
        """
        self.ndim = 1  # fixed spatial dimension for this mesh type
        super().__init__(config)
+
+
@register_pytree_node_class
class Mesh2D(_MeshBase):
    """2D Mesh class with nodes, elements, and particles."""

    def __init__(self, config: dict):
        """Initialize a 2D Mesh.

        Parameters
        ----------
        config: dict
            Configuration to be used for initialization. It _should_
            contain `elements` and `particles` keys.
        """
        self.ndim = 2  # fixed spatial dimension for this mesh type
        super().__init__(config)
+
+class Mesh1D
+(config: dict)
+1D Mesh class with nodes, elements, and particles.
+Initialize a 1D Mesh.
@register_pytree_node_class
class Mesh1D(_MeshBase):
    """One-dimensional mesh of nodes, elements, and particles."""

    def __init__(self, config: dict):
        """Create a 1D mesh from a configuration mapping.

        Parameters
        ----------
        config: dict
            Configuration to be used for initialization; should
            contain `elements` and `particles` keys.
        """
        self.ndim = 1  # fixed spatial dimension for this mesh type
        super().__init__(config)
+
+class Mesh2D
+(config: dict)
+2D Mesh class with nodes, elements, and particles.
+Initialize a 2D Mesh.
@register_pytree_node_class
class Mesh2D(_MeshBase):
    """2D Mesh class with nodes, elements, and particles."""

    def __init__(self, config: dict):
        """Initialize a 2D Mesh.

        Parameters
        ----------
        config: dict
            Configuration to be used for initialization. It _should_
            contain `elements` and `particles` keys.
        """
        self.ndim = 2  # fixed spatial dimension for this mesh type
        super().__init__(config)
+
+class _MeshBase
+(config: dict)
+Base class for Meshes.
+Note
+If attributes other than elements and particles are added
+then the child class should also implement tree_flatten and
+tree_unflatten correctly or that information will get lost.
Initialize mesh using configuration.
class _MeshBase(abc.ABC):
    """Abstract base class for meshes.

    .. note::
        If attributes other than elements and particles are added
        then the child class should also implement `tree_flatten` and
        `tree_unflatten` correctly or that information will get lost.
    """

    # Spatial dimension; set by concrete subclasses.
    ndim: int

    def __init__(self, config: dict):
        """Initialize mesh using configuration.

        Parameters
        ----------
        config: dict
            Should contain `particles`, `elements` and
            `particle_surface_traction` keys.
        """
        self.particles: Sequence[Particles] = config["particles"]
        self.elements: _Element = config["elements"]
        self.particle_tractions = config["particle_surface_traction"]

    # TODO: Convert to using jax directives for loop
    def apply_on_elements(self, function: str, args: Tuple = ()):
        """Call an `_Element` method once for each particle set.

        Parameters
        ----------
        function: str
            Name of a method defined on `_Element`.
        args: tuple
            Extra positional arguments forwarded to the method.
        """
        element_fn = getattr(self.elements, function)
        for pset in self.particles:
            element_fn(pset, *args)

    # TODO: Convert to using jax directives for loop
    def apply_on_particles(self, function: str, args: Tuple = ()):
        """Call a `Particles` method on every particle set.

        Parameters
        ----------
        function: str
            Name of a method defined on `Particles`.
        args: tuple
            Extra positional arguments forwarded to the method.
        """
        for pset in self.particles:
            getattr(pset, function)(self.elements, *args)

    def apply_traction_on_particles(self, curr_time: float):
        """Apply surface tractions to particles at the current time.

        Parameters
        ----------
        curr_time: float
            Current time in the simulation; used to evaluate each
            traction's time-dependent scaling function.
        """
        self.apply_on_particles("zero_traction")
        for ptraction in self.particle_tractions:
            # Scale the prescribed traction by its time function.
            traction_val = ptraction.function.value(curr_time) * ptraction.traction
            for i, pset_id in enumerate(ptraction.pset):
                self.particles[pset_id].assign_traction(
                    ptraction.pids[i], ptraction.dir, traction_val
                )

        self.apply_on_elements("apply_particle_traction_forces")

    def tree_flatten(self):
        """Flatten this mesh as a JAX PyTree node."""
        return ((self.particles, self.elements), self.particle_tractions)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct a mesh from its PyTree representation."""
        config = {
            "particles": children[0],
            "elements": children[1],
            "particle_surface_traction": aux_data,
        }
        return cls(config)
+var ndim : int
+def tree_unflatten(aux_data, children)
+@classmethod
+def tree_unflatten(cls, aux_data, children):
+ return cls(
+ {
+ "particles": children[0],
+ "elements": children[1],
+ "particle_surface_traction": aux_data,
+ }
+ )
+
+def apply_on_elements(self, function: str, args: Tuple = ())
+Apply a given function to elements.
+function : str_Element.args : tupledef apply_on_elements(self, function: str, args: Tuple = ()):
+ """Apply a given function to elements.
+
+ Parameters
+ ----------
+ function: str
+ A string corresponding to a function name in `_Element`.
+ args: tuple
+ Parameters to be passed to the function.
+ """
+ f = getattr(self.elements, function)
+ for particle_set in self.particles:
+ f(particle_set, *args)
+
+def apply_on_particles(self, function: str, args: Tuple = ())
+Apply a given function to particles.
+function : strParticles.args : tupledef apply_on_particles(self, function: str, args: Tuple = ()):
+ """Apply a given function to particles.
+
+ Parameters
+ ----------
+ function: str
+ A string corresponding to a function name in `Particles`.
+ args: tuple
+ Parameters to be passed to the function.
+ """
+ for particle_set in self.particles:
+ f = getattr(particle_set, function)
+ f(self.elements, *args)
+
+def apply_traction_on_particles(self, curr_time: float)
+Apply tractions on particles.
+curr_time : floatdef apply_traction_on_particles(self, curr_time: float):
+ """Apply tractions on particles.
+
+ Parameters
+ ----------
+ curr_time: float
+ Current time in the simulation.
+ """
+ self.apply_on_particles("zero_traction")
+ for ptraction in self.particle_tractions:
+ factor = ptraction.function.value(curr_time)
+ traction_val = factor * ptraction.traction
+ for i, pset_id in enumerate(ptraction.pset):
+ self.particles[pset_id].assign_traction(
+ ptraction.pids[i], ptraction.dir, traction_val
+ )
+
+ self.apply_on_elements("apply_particle_traction_forces")
+
+def tree_flatten(self)
+def tree_flatten(self):
+ children = (self.particles, self.elements)
+ aux_data = self.particle_tractions
+ return (children, aux_data)
+diffmpm.nodefrom typing import Optional, Sized, Tuple
+
+import jax.numpy as jnp
+from jax.tree_util import register_pytree_node_class
+from jax.typing import ArrayLike
+
+
@register_pytree_node_class
class Nodes(Sized):
    """Nodes container class.

    Keeps track of all values required for nodal points.

    Attributes
    ----------
    nnodes : int
        Number of nodes stored.
    loc : ArrayLike
        Location of all the nodes.
    velocity : ArrayLike
        Velocity of all the nodes.
    acceleration : ArrayLike
        Acceleration of all the nodes.
    mass : ArrayLike
        Mass of all the nodes.
    momentum : ArrayLike
        Momentum of all the nodes.
    f_int : ArrayLike
        Internal forces on all the nodes.
    f_ext : ArrayLike
        External forces present on all the nodes.
    f_damp : ArrayLike
        Damping forces on the nodes.
    """

    def __init__(
        self,
        nnodes: int,
        loc: ArrayLike,
        initialized: Optional[bool] = None,
        data: Tuple[ArrayLike, ...] = tuple(),
    ):
        """Initialize container for Nodes.

        Parameters
        ----------
        nnodes : int
            Number of nodes stored.
        loc : ArrayLike
            Locations of all the nodes. Expected shape (nnodes, 1, ndim)
        initialized: bool
            `None` if node property arrays like mass need to be
            initialized to defaults. Otherwise they are set to values
            from `data`.
        data: tuple
            Tuple of length 7 that sets arrays for velocity,
            acceleration, mass, momentum, f_int, f_ext and f_damp (in
            that order). Mainly used by JAX while unflattening.

        Raises
        ------
        ValueError
            If `loc` is not a 3-dimensional array.
        """
        self.nnodes = nnodes
        loc = jnp.asarray(loc, dtype=jnp.float32)
        if loc.ndim != 3:
            raise ValueError(
                f"`loc` should be of size (nnodes, 1, ndim); found {loc.shape}"
            )
        self.loc = loc

        if initialized is None:
            # Default state: zero kinematics and forces, unit mass.
            self.velocity = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.acceleration = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.mass = jnp.ones((self.loc.shape[0], 1, 1), dtype=jnp.float32)
            self.momentum = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.f_int = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.f_ext = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.f_damp = jnp.zeros_like(self.loc, dtype=jnp.float32)
        else:
            (
                self.velocity,
                self.acceleration,
                self.mass,
                self.momentum,
                self.f_int,
                self.f_ext,
                self.f_damp,
            ) = data  # type: ignore
        self.initialized = True

    def tree_flatten(self):
        """Flatten class as Pytree type."""
        children = (
            self.loc,
            self.initialized,
            self.velocity,
            self.acceleration,
            self.mass,
            self.momentum,
            self.f_int,
            self.f_ext,
            self.f_damp,
        )
        aux_data = (self.nnodes,)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Unflatten class from Pytree type."""
        return cls(aux_data[0], children[0], initialized=children[1], data=children[2:])

    def reset_values(self):
        """Reset nodal parameter values except location."""
        self.velocity = self.velocity.at[:].set(0)
        # Fix: previously assigned a zeroed copy of `velocity` to
        # `acceleration` instead of zeroing `acceleration` itself.
        self.acceleration = self.acceleration.at[:].set(0)
        self.mass = self.mass.at[:].set(0)
        self.momentum = self.momentum.at[:].set(0)
        self.f_int = self.f_int.at[:].set(0)
        self.f_ext = self.f_ext.at[:].set(0)
        self.f_damp = self.f_damp.at[:].set(0)

    def __len__(self):
        """Set length of class as number of nodes."""
        return self.nnodes

    def __repr__(self):
        """Repr containing number of nodes."""
        return f"Nodes(n={self.nnodes})"

    def get_total_force(self):
        """Calculate total force on the nodes."""
        return self.f_int + self.f_ext + self.f_damp
+
+class Nodes
+(nnodes: int, loc: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], initialized: Optional[bool] = None, data: Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], ...] = ())
+Nodes container class.
+Keeps track of all values required for nodal points.
+nnodes : intloc : ArrayLikevelocity : array_likemass : ArrayLikemomentum : array_likef_int : ArrayLikef_ext : ArrayLikef_damp : ArrayLikeInitialize container for Nodes.
+nnodes : intloc : ArrayLikeinitialized : boolFalse if node property arrays like mass need to be initialized.
@register_pytree_node_class
class Nodes(Sized):
    """Nodes container class.

    Keeps track of all values required for nodal points.

    Attributes
    ----------
    nnodes : int
        Number of nodes stored.
    loc : ArrayLike
        Location of all the nodes.
    velocity : ArrayLike
        Velocity of all the nodes.
    acceleration : ArrayLike
        Acceleration of all the nodes.
    mass : ArrayLike
        Mass of all the nodes.
    momentum : ArrayLike
        Momentum of all the nodes.
    f_int : ArrayLike
        Internal forces on all the nodes.
    f_ext : ArrayLike
        External forces present on all the nodes.
    f_damp : ArrayLike
        Damping forces on the nodes.
    """

    def __init__(
        self,
        nnodes: int,
        loc: ArrayLike,
        initialized: Optional[bool] = None,
        data: Tuple[ArrayLike, ...] = tuple(),
    ):
        """Initialize container for Nodes.

        Parameters
        ----------
        nnodes : int
            Number of nodes stored.
        loc : ArrayLike
            Locations of all the nodes. Expected shape (nnodes, 1, ndim)
        initialized: bool
            `None` if node property arrays like mass need to be
            initialized to defaults. Otherwise they are set to values
            from `data`.
        data: tuple
            Tuple of length 7 that sets arrays for velocity,
            acceleration, mass, momentum, f_int, f_ext and f_damp (in
            that order). Mainly used by JAX while unflattening.

        Raises
        ------
        ValueError
            If `loc` is not a 3-dimensional array.
        """
        self.nnodes = nnodes
        loc = jnp.asarray(loc, dtype=jnp.float32)
        if loc.ndim != 3:
            raise ValueError(
                f"`loc` should be of size (nnodes, 1, ndim); found {loc.shape}"
            )
        self.loc = loc

        if initialized is None:
            # Default state: zero kinematics and forces, unit mass.
            self.velocity = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.acceleration = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.mass = jnp.ones((self.loc.shape[0], 1, 1), dtype=jnp.float32)
            self.momentum = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.f_int = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.f_ext = jnp.zeros_like(self.loc, dtype=jnp.float32)
            self.f_damp = jnp.zeros_like(self.loc, dtype=jnp.float32)
        else:
            (
                self.velocity,
                self.acceleration,
                self.mass,
                self.momentum,
                self.f_int,
                self.f_ext,
                self.f_damp,
            ) = data  # type: ignore
        self.initialized = True

    def tree_flatten(self):
        """Flatten class as Pytree type."""
        children = (
            self.loc,
            self.initialized,
            self.velocity,
            self.acceleration,
            self.mass,
            self.momentum,
            self.f_int,
            self.f_ext,
            self.f_damp,
        )
        aux_data = (self.nnodes,)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Unflatten class from Pytree type."""
        return cls(aux_data[0], children[0], initialized=children[1], data=children[2:])

    def reset_values(self):
        """Reset nodal parameter values except location."""
        self.velocity = self.velocity.at[:].set(0)
        # Fix: previously assigned a zeroed copy of `velocity` to
        # `acceleration` instead of zeroing `acceleration` itself.
        self.acceleration = self.acceleration.at[:].set(0)
        self.mass = self.mass.at[:].set(0)
        self.momentum = self.momentum.at[:].set(0)
        self.f_int = self.f_int.at[:].set(0)
        self.f_ext = self.f_ext.at[:].set(0)
        self.f_damp = self.f_damp.at[:].set(0)

    def __len__(self):
        """Set length of class as number of nodes."""
        return self.nnodes

    def __repr__(self):
        """Repr containing number of nodes."""
        return f"Nodes(n={self.nnodes})"

    def get_total_force(self):
        """Calculate total force on the nodes."""
        return self.f_int + self.f_ext + self.f_damp
+
+def tree_unflatten(aux_data, children)
+Unflatten class from Pytree type.
@classmethod
+def tree_unflatten(cls, aux_data, children):
+ """Unflatten class from Pytree type."""
+ return cls(aux_data[0], children[0], initialized=children[1], data=children[2:])
+
+def get_total_force(self)
+Calculate total force on the nodes.
def get_total_force(self):
+ """Calculate total force on the nodes."""
+ return self.f_int + self.f_ext + self.f_damp
+
+def reset_values(self)
+Reset nodal parameter values except location.
def reset_values(self):
+ """Reset nodal parameter values except location."""
+ self.velocity = self.velocity.at[:].set(0)
+ self.acceleration = self.velocity.at[:].set(0)
+ self.mass = self.mass.at[:].set(0)
+ self.momentum = self.momentum.at[:].set(0)
+ self.f_int = self.f_int.at[:].set(0)
+ self.f_ext = self.f_ext.at[:].set(0)
+ self.f_damp = self.f_damp.at[:].set(0)
+
+def tree_flatten(self)
+Flatten class as Pytree type.
def tree_flatten(self):
+ """Flatten class as Pytree type."""
+ children = (
+ self.loc,
+ self.initialized,
+ self.velocity,
+ self.acceleration,
+ self.mass,
+ self.momentum,
+ self.f_int,
+ self.f_ext,
+ self.f_damp,
+ )
+ aux_data = (self.nnodes,)
+ return (children, aux_data)
+diffmpm.particlefrom typing import Optional, Sized, Tuple
+
+import jax.numpy as jnp
+from jax import lax, vmap
+from jax.tree_util import register_pytree_node_class
+from jax.typing import ArrayLike
+
+from diffmpm.element import _Element
+from diffmpm.material import Material
+
+
@register_pytree_node_class
class Particles(Sized):
    """Container class for a set of particles."""

    def __init__(
        self,
        loc: ArrayLike,
        material: Material,
        element_ids: ArrayLike,
        initialized: Optional[bool] = None,
        data: Optional[Tuple[ArrayLike, ...]] = None,
    ):
        """Initialize a container of particles.

        Parameters
        ----------
        loc: ArrayLike
            Location of the particles. Expected shape (nparticles, 1, ndim)
        material: diffmpm.material.Material
            Type of material for the set of particles.
        element_ids: ArrayLike
            The element ids that the particles belong to. This contains
            information that will make sense only with the information of
            the mesh that is being considered.
        initialized: bool
            `None` if particle property arrays like mass need to be
            initialized. Otherwise, they are set to values from `data`.
        data: tuple
            Tuple of length 16 that sets arrays for mass, density, volume,
            size, velocity, acceleration, momentum, strain, stress,
            strain_rate, dstrain, f_ext, traction, reference_loc,
            dvolumetric_strain and volumetric_strain_centroid (in that
            order). Mainly used by JAX while unflattening.

        Raises
        ------
        ValueError
            If `loc` is not a 3-dimensional array.
        """
        self.material = material
        self.element_ids = element_ids
        loc = jnp.asarray(loc, dtype=jnp.float32)
        if loc.ndim != 3:
            raise ValueError(
                f"`loc` should be of size (nparticles, 1, ndim); " f"found {loc.shape}"
            )
        self.loc = loc

        if initialized is None:
            self.mass = jnp.ones((self.loc.shape[0], 1, 1))
            self.density = (
                jnp.ones_like(self.mass) * self.material.properties["density"]
            )
            self.volume = jnp.ones_like(self.mass)
            self.size = jnp.zeros_like(self.loc)
            self.velocity = jnp.zeros_like(self.loc)
            self.acceleration = jnp.zeros_like(self.loc)
            self.momentum = jnp.zeros_like(self.loc)
            # Strain quantities are stored in 6-component Voigt notation.
            self.strain = jnp.zeros((self.loc.shape[0], 6, 1))
            self.stress = jnp.zeros((self.loc.shape[0], 6, 1))
            self.strain_rate = jnp.zeros((self.loc.shape[0], 6, 1))
            self.dstrain = jnp.zeros((self.loc.shape[0], 6, 1))
            self.f_ext = jnp.zeros_like(self.loc)
            self.traction = jnp.zeros_like(self.loc)
            self.reference_loc = jnp.zeros_like(self.loc)
            self.dvolumetric_strain = jnp.zeros((self.loc.shape[0], 1))
            self.volumetric_strain_centroid = jnp.zeros((self.loc.shape[0], 1))
        else:
            (
                self.mass,
                self.density,
                self.volume,
                self.size,
                self.velocity,
                self.acceleration,
                self.momentum,
                self.strain,
                self.stress,
                self.strain_rate,
                self.dstrain,
                self.f_ext,
                self.traction,
                self.reference_loc,
                self.dvolumetric_strain,
                self.volumetric_strain_centroid,
            ) = data  # type: ignore
        self.initialized = True

    def tree_flatten(self):
        """Flatten class as Pytree type."""
        children = (
            self.loc,
            self.element_ids,
            self.initialized,
            self.mass,
            self.density,
            self.volume,
            self.size,
            self.velocity,
            self.acceleration,
            self.momentum,
            self.strain,
            self.stress,
            self.strain_rate,
            self.dstrain,
            self.f_ext,
            self.traction,
            self.reference_loc,
            self.dvolumetric_strain,
            self.volumetric_strain_centroid,
        )
        aux_data = (self.material,)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Unflatten class from Pytree type."""
        return cls(
            children[0],
            aux_data[0],
            children[1],
            initialized=children[2],
            data=children[3:],
        )

    def __len__(self) -> int:
        """Set length of the class as number of particles."""
        return self.loc.shape[0]

    def __repr__(self) -> str:
        """Informative repr showing number of particles."""
        return f"Particles(nparticles={len(self)})"

    def set_mass_volume(self, m: ArrayLike):
        """Set particle mass.

        Parameters
        ----------
        m: float, array_like
            Mass to be set for particles. If scalar, mass for all
            particles is set to this value.

        Raises
        ------
        ValueError
            If `m` is neither a scalar nor an array matching the
            shape of `self.mass`.
        """
        m = jnp.asarray(m)
        # Fix: `jnp.isscalar` is always False after `asarray` (a scalar
        # becomes a 0-d array), which made the scalar branch unreachable;
        # check the array rank instead.
        if m.ndim == 0:
            # Broadcast the scalar mass to all particles, preserving the
            # (nparticles, 1, 1) shape used by `self.mass` elsewhere.
            self.mass = jnp.ones_like(self.mass) * m
        elif m.shape == self.mass.shape:
            self.mass = m
        else:
            raise ValueError(
                f"Incompatible shapes. Expected {self.mass.shape}, " f"found {m.shape}."
            )
        self.volume = jnp.divide(self.mass, self.material.properties["density"])

    def compute_volume(self, elements: _Element, total_elements: int):
        """Compute volume of all particles.

        Parameters
        ----------
        elements: diffmpm._Element
            Elements that the particles are present in, and are used to
            compute the particles' volumes.
        total_elements: int
            Total elements present in `elements`. (Currently unused;
            `elements.total_elements` is read instead — kept for
            interface compatibility.)
        """
        # Each particle gets an equal share of its element's volume.
        particles_per_element = jnp.bincount(
            self.element_ids, length=elements.total_elements
        )
        vol = (
            elements.volume.squeeze((1, 2))[self.element_ids]  # type: ignore
            / particles_per_element[self.element_ids]
        )
        self.volume = self.volume.at[:, 0, 0].set(vol)
        # Particle size per dimension ~ volume^(1/ndim).
        self.size = self.size.at[:].set(self.volume ** (1 / self.size.shape[-1]))
        self.mass = self.mass.at[:, 0, 0].set(vol * self.density.squeeze())

    def update_natural_coords(self, elements: _Element):
        r"""Update natural coordinates for the particles.

        Whenever the particles' physical coordinates change, their
        natural coordinates need to be updated. This function updates
        the natural coordinates of the particles based on the element
        a particle is a part of. The update formula is

        \[
            \xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e)
        \]

        where \(x_i^e\) are the nodal coordinates of the element the
        particle is in. If a particle is not in any element
        (element_id = -1), its natural coordinate is set to 0.

        Parameters
        ----------
        elements: diffmpm.element._Element
            Elements based on which to update the natural coordinates
            of the particles.
        """
        t = vmap(elements.id_to_node_loc)(self.element_ids)
        xi_coords = (self.loc - (t[:, 0, ...] + t[:, 2, ...]) / 2) * (
            2 / (t[:, 2, ...] - t[:, 0, ...])
        )
        self.reference_loc = xi_coords

    def update_position_velocity(
        self, elements: _Element, dt: float, velocity_update: bool
    ):
        """Transfer nodal velocity to particles and update particle position.

        The velocity is calculated based on the total force at nodes.

        Parameters
        ----------
        elements: diffmpm.element._Element
            Elements whose nodes are used to transfer the velocity.
        dt: float
            Timestep.
        velocity_update: bool
            If True, velocity is directly used as nodal velocity, else
            velocity is calculated as interpolated nodal acceleration
            multiplied by dt. Default is False.
        """
        mapped_positions = elements.shapefn(self.reference_loc)
        mapped_ids = vmap(elements.id_to_node_ids)(self.element_ids).squeeze(-1)
        nodal_velocity = jnp.sum(
            mapped_positions * elements.nodes.velocity[mapped_ids], axis=1
        )
        nodal_acceleration = jnp.sum(
            mapped_positions * elements.nodes.acceleration[mapped_ids],
            axis=1,
        )
        # Branch under `lax.cond` so the choice stays traceable by JAX.
        self.velocity = self.velocity.at[:].set(
            lax.cond(
                velocity_update,
                lambda sv, nv, na, t: nv,
                lambda sv, nv, na, t: sv + na * t,
                self.velocity,
                nodal_velocity,
                nodal_acceleration,
                dt,
            )
        )
        self.loc = self.loc.at[:].add(nodal_velocity * dt)
        self.momentum = self.momentum.at[:].set(self.mass * self.velocity)

    def compute_strain(self, elements: _Element, dt: float):
        """Compute the strain on all particles.

        This is done by first calculating the strain rate for the particles
        and then calculating strain as `strain += strain rate * dt`.

        Parameters
        ----------
        elements: diffmpm.element._Element
            Elements whose nodes are used to calculate the strain.
        dt : float
            Timestep.
        """
        mapped_coords = vmap(elements.id_to_node_loc)(self.element_ids).squeeze(2)
        dn_dx_ = vmap(elements.shapefn_grad)(
            self.reference_loc[:, jnp.newaxis, ...], mapped_coords
        )
        self.strain_rate = self._compute_strain_rate(dn_dx_, elements)
        self.dstrain = self.dstrain.at[:].set(self.strain_rate * dt)

        self.strain = self.strain.at[:].add(self.dstrain)
        # Volumetric strain is evaluated at the element centroid
        # (natural coordinates all zero).
        centroids = jnp.zeros_like(self.loc)
        dn_dx_centroid_ = vmap(elements.shapefn_grad)(
            centroids[:, jnp.newaxis, ...], mapped_coords
        )
        strain_rate_centroid = self._compute_strain_rate(dn_dx_centroid_, elements)
        ndim = self.loc.shape[-1]
        self.dvolumetric_strain = dt * strain_rate_centroid[:, :ndim].sum(axis=1)
        self.volumetric_strain_centroid = self.volumetric_strain_centroid.at[:].add(
            self.dvolumetric_strain
        )

    def _compute_strain_rate(self, dn_dx: ArrayLike, elements: _Element):
        """Compute the strain rate for particles.

        Parameters
        ----------
        dn_dx: ArrayLike
            The gradient of the shape function. Expected shape
            `(nparticles, 1, ndim)`
        elements: diffmpm.element._Element
            Elements whose nodes are used to calculate the strain rate.
        """
        dn_dx = jnp.asarray(dn_dx)
        strain_rate = jnp.zeros((dn_dx.shape[0], 6, 1))  # (nparticles, 6, 1)
        mapped_vel = vmap(elements.id_to_node_vel)(
            self.element_ids
        )  # (nparticles, 2, 1)

        temp = mapped_vel.squeeze(2)

        def _step(pid, args):
            # Accumulate normal (xx, yy) and shear (xy) components for
            # one particle; indices follow Voigt ordering.
            dndx, nvel, strain_rate = args
            matmul = dndx[pid].T @ nvel[pid]
            strain_rate = strain_rate.at[pid, 0].add(matmul[0, 0])
            strain_rate = strain_rate.at[pid, 1].add(matmul[1, 1])
            strain_rate = strain_rate.at[pid, 3].add(matmul[0, 1] + matmul[1, 0])
            return dndx, nvel, strain_rate

        args = (dn_dx, temp, strain_rate)
        _, _, strain_rate = lax.fori_loop(0, self.loc.shape[0], _step, args)
        # Clamp numerical noise to exactly zero.
        strain_rate = jnp.where(
            jnp.abs(strain_rate) < 1e-12, jnp.zeros_like(strain_rate), strain_rate
        )
        return strain_rate

    def compute_stress(self, *args):
        """Compute the stress on all particles.

        This calculation is governed by the material of the
        particles. The stress calculated by the material is then
        added to the particles current stress values.
        """
        self.stress = self.stress.at[:].add(self.material.compute_stress(self.dstrain))

    def update_volume(self, *args):
        """Update volume based on central strain rate."""
        self.volume = self.volume.at[:, 0, :].multiply(1 + self.dvolumetric_strain)
        self.density = self.density.at[:, 0, :].divide(1 + self.dvolumetric_strain)

    def assign_traction(self, pids: ArrayLike, dir: int, traction_: float):
        """Assign traction to particles.

        Parameters
        ----------
        pids: ArrayLike
            IDs of the particles to which traction should be applied.
        dir: int
            The direction in which traction should be applied.
        traction_: float
            Traction value to be applied in the direction.
        """
        # Convert surface traction to an equivalent nodal force scale:
        # traction * volume / size along the chosen direction.
        self.traction = self.traction.at[pids, 0, dir].add(
            traction_ * self.volume[pids, 0, 0] / self.size[pids, 0, dir]
        )

    def zero_traction(self, *args):
        """Set all traction values to 0."""
        self.traction = self.traction.at[:].set(0)
+
+class Particles
+(loc: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], material: Material, element_ids: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], initialized: Optional[bool] = None, data: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], ...]] = None)
+Container class for a set of particles.
+Initialize a container of particles.
+loc : ArrayLikematerial : Materialelement_ids : ArrayLikeinitialized : boolFalse if particle property arrays like mass need to be initialized.
+If True, they are set to values from data.data : tuple@register_pytree_node_class
+class Particles(Sized):
+ """Container class for a set of particles."""
+
+ def __init__(
+ self,
+ loc: ArrayLike,
+ material: Material,
+ element_ids: ArrayLike,
+ initialized: Optional[bool] = None,
+ data: Optional[Tuple[ArrayLike, ...]] = None,
+ ):
+ """Initialize a container of particles.
+
+ Parameters
+ ----------
+ loc: ArrayLike
+ Location of the particles. Expected shape (nparticles, 1, ndim)
+ material: diffmpm.material.Material
+ Type of material for the set of particles.
+ element_ids: ArrayLike
+ The element ids that the particles belong to. This contains
+ information that will make sense only with the information of
+ the mesh that is being considered.
+ initialized: bool
+ `False` if particle property arrays like mass need to be initialized.
+ If `True`, they are set to values from `data`.
+ data: tuple
+ Tuple of length 13 that sets arrays for mass, density, volume,
+ velocity, acceleration, momentum, strain, stress, strain_rate,
+ dstrain, f_ext, reference_loc and volumetric_strain_centroid.
+ """
+ self.material = material
+ self.element_ids = element_ids
+ loc = jnp.asarray(loc, dtype=jnp.float32)
+ if loc.ndim != 3:
+ raise ValueError(
+ f"`loc` should be of size (nparticles, 1, ndim); " f"found {loc.shape}"
+ )
+ self.loc = loc
+
+ if initialized is None:
+ self.mass = jnp.ones((self.loc.shape[0], 1, 1))
+ self.density = (
+ jnp.ones_like(self.mass) * self.material.properties["density"]
+ )
+ self.volume = jnp.ones_like(self.mass)
+ self.size = jnp.zeros_like(self.loc)
+ self.velocity = jnp.zeros_like(self.loc)
+ self.acceleration = jnp.zeros_like(self.loc)
+ self.momentum = jnp.zeros_like(self.loc)
+ self.strain = jnp.zeros((self.loc.shape[0], 6, 1))
+ self.stress = jnp.zeros((self.loc.shape[0], 6, 1))
+ self.strain_rate = jnp.zeros((self.loc.shape[0], 6, 1))
+ self.dstrain = jnp.zeros((self.loc.shape[0], 6, 1))
+ self.f_ext = jnp.zeros_like(self.loc)
+ self.traction = jnp.zeros_like(self.loc)
+ self.reference_loc = jnp.zeros_like(self.loc)
+ self.dvolumetric_strain = jnp.zeros((self.loc.shape[0], 1))
+ self.volumetric_strain_centroid = jnp.zeros((self.loc.shape[0], 1))
+ else:
+ (
+ self.mass,
+ self.density,
+ self.volume,
+ self.size,
+ self.velocity,
+ self.acceleration,
+ self.momentum,
+ self.strain,
+ self.stress,
+ self.strain_rate,
+ self.dstrain,
+ self.f_ext,
+ self.traction,
+ self.reference_loc,
+ self.dvolumetric_strain,
+ self.volumetric_strain_centroid,
+ ) = data # type: ignore
+ self.initialized = True
+
+ def tree_flatten(self):
+ """Flatten class as Pytree type."""
+ children = (
+ self.loc,
+ self.element_ids,
+ self.initialized,
+ self.mass,
+ self.density,
+ self.volume,
+ self.size,
+ self.velocity,
+ self.acceleration,
+ self.momentum,
+ self.strain,
+ self.stress,
+ self.strain_rate,
+ self.dstrain,
+ self.f_ext,
+ self.traction,
+ self.reference_loc,
+ self.dvolumetric_strain,
+ self.volumetric_strain_centroid,
+ )
+ aux_data = (self.material,)
+ return (children, aux_data)
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ """Unflatten class from Pytree type."""
+ return cls(
+ children[0],
+ aux_data[0],
+ children[1],
+ initialized=children[2],
+ data=children[3:],
+ )
+
+ def __len__(self) -> int:
+ """Set length of the class as number of particles."""
+ return self.loc.shape[0]
+
+ def __repr__(self) -> str:
+ """Informative repr showing number of particles."""
+ return f"Particles(nparticles={len(self)})"
+
+ def set_mass_volume(self, m: ArrayLike):
+ """Set particle mass.
+
+ Parameters
+ ----------
+ m: float, array_like
+ Mass to be set for particles. If scalar, mass for all
+ particles is set to this value.
+ """
+ m = jnp.asarray(m)
+ if jnp.isscalar(m):
+ self.mass = jnp.ones_like(self.loc) * m
+ elif m.shape == self.mass.shape:
+ self.mass = m
+ else:
+ raise ValueError(
+ f"Incompatible shapes. Expected {self.mass.shape}, " f"found {m.shape}."
+ )
+ self.volume = jnp.divide(self.mass, self.material.properties["density"])
+
    def compute_volume(self, elements: _Element, total_elements: int):
        """Compute volume of all particles.

        Each particle receives an equal share of its parent element's
        volume; size and mass are then derived from that volume.

        Parameters
        ----------
        elements: diffmpm._Element
            Elements that the particles are present in, and are used to
            compute the particles' volumes.
        total_elements: int
            Total elements present in `elements`.
            NOTE(review): this parameter is currently unused — the count
            is read from `elements.total_elements` below; confirm whether
            the two can ever differ.
        """
        # How many particles fall inside each element of the mesh.
        particles_per_element = jnp.bincount(
            self.element_ids, length=elements.total_elements
        )
        # Equal split of the parent element's volume among its particles.
        vol = (
            elements.volume.squeeze((1, 2))[self.element_ids]  # type: ignore
            / particles_per_element[self.element_ids]
        )
        self.volume = self.volume.at[:, 0, 0].set(vol)
        # Particle size per dimension is the ndim-th root of its volume,
        # broadcast to every component of `size`.
        self.size = self.size.at[:].set(self.volume ** (1 / self.size.shape[-1]))
        # Mass follows from density: m = rho * V.
        self.mass = self.mass.at[:, 0, 0].set(vol * self.density.squeeze())
+
    def update_natural_coords(self, elements: _Element):
        r"""Update natural coordinates for the particles.

        Whenever the particles' physical coordinates change, their
        natural coordinates need to be updated. This function updates
        the natural coordinates of the particles based on the element
        a particle is a part of. The update formula is

        \[
            \xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e)
        \]

        where \(x_i^e\) are the nodal coordinates of the element the
        particle is in. If a particle is not in any element
        (element_id = -1), its natural coordinate is set to 0.

        Parameters
        ----------
        elements: diffmpm.element._Element
            Elements based on which to update the natural coordinates
            of the particles.
        """
        # Nodal locations of each particle's parent element, gathered
        # per-particle via vmap over the element ids.
        t = vmap(elements.id_to_node_loc)(self.element_ids)
        # Map physical coordinates to the [-1, 1] reference interval per
        # dimension using nodes 0 and 2 — presumably opposite corners of
        # the element (TODO confirm node ordering in `id_to_node_loc`).
        xi_coords = (self.loc - (t[:, 0, ...] + t[:, 2, ...]) / 2) * (
            2 / (t[:, 2, ...] - t[:, 0, ...])
        )
        self.reference_loc = xi_coords
+
    def update_position_velocity(
        self, elements: _Element, dt: float, velocity_update: bool
    ):
        """Transfer nodal velocity to particles and update particle position.

        The velocity is calculated based on the total force at nodes.

        Parameters
        ----------
        elements: diffmpm.element._Element
            Elements whose nodes are used to transfer the velocity.
        dt: float
            Timestep.
        velocity_update: bool
            If True, velocity is taken directly as the interpolated
            nodal velocity, else the particle velocity is incremented
            by the interpolated nodal acceleration multiplied by dt.
        """
        # Shape-function weights of each particle w.r.t. its element's nodes.
        mapped_positions = elements.shapefn(self.reference_loc)
        mapped_ids = vmap(elements.id_to_node_ids)(self.element_ids).squeeze(-1)
        # Interpolate nodal velocity/acceleration to the particle locations.
        nodal_velocity = jnp.sum(
            mapped_positions * elements.nodes.velocity[mapped_ids], axis=1
        )
        nodal_acceleration = jnp.sum(
            mapped_positions * elements.nodes.acceleration[mapped_ids],
            axis=1,
        )
        # `lax.cond` keeps both update rules traceable under JIT.
        self.velocity = self.velocity.at[:].set(
            lax.cond(
                velocity_update,
                lambda sv, nv, na, t: nv,
                lambda sv, nv, na, t: sv + na * t,
                self.velocity,
                nodal_velocity,
                nodal_acceleration,
                dt,
            )
        )
        # Positions always advance with the interpolated nodal velocity,
        # regardless of the velocity-update mode.
        self.loc = self.loc.at[:].add(nodal_velocity * dt)
        self.momentum = self.momentum.at[:].set(self.mass * self.velocity)
+
    def compute_strain(self, elements: _Element, dt: float):
        """Compute the strain on all particles.

        This is done by first calculating the strain rate for the particles
        and then calculating strain as `strain += strain rate * dt`.
        The volumetric strain increment is evaluated at the element
        centroid.

        Parameters
        ----------
        elements: diffmpm.element._Element
            Elements whose nodes are used to calculate the strain.
        dt : float
            Timestep.
        """
        # Nodal coordinates of each particle's parent element.
        mapped_coords = vmap(elements.id_to_node_loc)(self.element_ids).squeeze(2)
        # Shape-function gradients at the particles' natural coordinates.
        dn_dx_ = vmap(elements.shapefn_grad)(
            self.reference_loc[:, jnp.newaxis, ...], mapped_coords
        )
        self.strain_rate = self._compute_strain_rate(dn_dx_, elements)
        self.dstrain = self.dstrain.at[:].set(self.strain_rate * dt)

        self.strain = self.strain.at[:].add(self.dstrain)
        # Re-evaluate the strain rate at the element centroid (natural
        # coordinates all zero) for the volumetric part.
        centroids = jnp.zeros_like(self.loc)
        dn_dx_centroid_ = vmap(elements.shapefn_grad)(
            centroids[:, jnp.newaxis, ...], mapped_coords
        )
        strain_rate_centroid = self._compute_strain_rate(dn_dx_centroid_, elements)
        ndim = self.loc.shape[-1]
        # Volumetric strain increment: sum of the first `ndim` normal
        # components of the centroid strain rate, times dt.
        self.dvolumetric_strain = dt * strain_rate_centroid[:, :ndim].sum(axis=1)
        self.volumetric_strain_centroid = self.volumetric_strain_centroid.at[:].add(
            self.dvolumetric_strain
        )
+
    def _compute_strain_rate(self, dn_dx: ArrayLike, elements: _Element):
        """Compute the strain rate for particles.

        Parameters
        ----------
        dn_dx: ArrayLike
            The gradient of the shape function. Expected shape
            `(nparticles, 1, ndim)`
        elements: diffmpm.element._Element
            Elements whose nodes are used to calculate the strain rate.
        """
        dn_dx = jnp.asarray(dn_dx)
        # 6-component strain-rate vector per particle; index 3 appears to
        # hold the xy shear component (Voigt-style ordering — confirm).
        strain_rate = jnp.zeros((dn_dx.shape[0], 6, 1))  # (nparticles, 6, 1)
        mapped_vel = vmap(elements.id_to_node_vel)(
            self.element_ids
        )  # (nparticles, 2, 1)

        temp = mapped_vel.squeeze(2)

        # NOTE(review): only components 0, 1 and 3 are ever written, so
        # this implementation assumes a 2D problem — TODO confirm before
        # using in 3D.
        def _step(pid, args):
            dndx, nvel, strain_rate = args
            # Velocity gradient for this particle: dN/dx^T @ nodal velocities.
            matmul = dndx[pid].T @ nvel[pid]
            strain_rate = strain_rate.at[pid, 0].add(matmul[0, 0])
            strain_rate = strain_rate.at[pid, 1].add(matmul[1, 1])
            strain_rate = strain_rate.at[pid, 3].add(matmul[0, 1] + matmul[1, 0])
            return dndx, nvel, strain_rate

        args = (dn_dx, temp, strain_rate)
        _, _, strain_rate = lax.fori_loop(0, self.loc.shape[0], _step, args)
        # Clamp numerical noise below 1e-12 to exactly zero.
        strain_rate = jnp.where(
            jnp.abs(strain_rate) < 1e-12, jnp.zeros_like(strain_rate), strain_rate
        )
        return strain_rate
+
+ def compute_stress(self, *args):
+ """Compute the strain on all particles.
+
+ This calculation is governed by the material of the
+ particles. The stress calculated by the material is then
+ added to the particles current stress values.
+ """
+ self.stress = self.stress.at[:].add(self.material.compute_stress(self.dstrain))
+
+ def update_volume(self, *args):
+ """Update volume based on central strain rate."""
+ self.volume = self.volume.at[:, 0, :].multiply(1 + self.dvolumetric_strain)
+ self.density = self.density.at[:, 0, :].divide(1 + self.dvolumetric_strain)
+
+ def assign_traction(self, pids: ArrayLike, dir: int, traction_: float):
+ """Assign traction to particles.
+
+ Parameters
+ ----------
+ pids: ArrayLike
+ IDs of the particles to which traction should be applied.
+ dir: int
+ The direction in which traction should be applied.
+ traction_: float
+ Traction value to be applied in the direction.
+ """
+ self.traction = self.traction.at[pids, 0, dir].add(
+ traction_ * self.volume[pids, 0, 0] / self.size[pids, 0, dir]
+ )
+
+ def zero_traction(self, *args):
+ """Set all traction values to 0."""
+ self.traction = self.traction.at[:].set(0)
+
+def tree_unflatten(aux_data, children)
+Unflatten class from Pytree type.
@classmethod
+def tree_unflatten(cls, aux_data, children):
+ """Unflatten class from Pytree type."""
+ return cls(
+ children[0],
+ aux_data[0],
+ children[1],
+ initialized=children[2],
+ data=children[3:],
+ )
+
+def assign_traction(self, pids: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], dir: int, traction_: float)
+Assign traction to particles.
+pids : ArrayLikedir : inttraction_ : floatdef assign_traction(self, pids: ArrayLike, dir: int, traction_: float):
+ """Assign traction to particles.
+
+ Parameters
+ ----------
+ pids: ArrayLike
+ IDs of the particles to which traction should be applied.
+ dir: int
+ The direction in which traction should be applied.
+ traction_: float
+ Traction value to be applied in the direction.
+ """
+ self.traction = self.traction.at[pids, 0, dir].add(
+ traction_ * self.volume[pids, 0, 0] / self.size[pids, 0, dir]
+ )
+
+def compute_strain(self, elements: _Element, dt: float)
+Compute the strain on all particles.
+This is done by first calculating the strain rate for the particles
+and then calculating strain as strain += strain rate * dt.
elements : _Elementdt : floatdef compute_strain(self, elements: _Element, dt: float):
+ """Compute the strain on all particles.
+
+ This is done by first calculating the strain rate for the particles
+ and then calculating strain as `strain += strain rate * dt`.
+
+ Parameters
+ ----------
+ elements: diffmpm.element._Element
+ Elements whose nodes are used to calculate the strain.
+ dt : float
+ Timestep.
+ """
+ mapped_coords = vmap(elements.id_to_node_loc)(self.element_ids).squeeze(2)
+ dn_dx_ = vmap(elements.shapefn_grad)(
+ self.reference_loc[:, jnp.newaxis, ...], mapped_coords
+ )
+ self.strain_rate = self._compute_strain_rate(dn_dx_, elements)
+ self.dstrain = self.dstrain.at[:].set(self.strain_rate * dt)
+
+ self.strain = self.strain.at[:].add(self.dstrain)
+ centroids = jnp.zeros_like(self.loc)
+ dn_dx_centroid_ = vmap(elements.shapefn_grad)(
+ centroids[:, jnp.newaxis, ...], mapped_coords
+ )
+ strain_rate_centroid = self._compute_strain_rate(dn_dx_centroid_, elements)
+ ndim = self.loc.shape[-1]
+ self.dvolumetric_strain = dt * strain_rate_centroid[:, :ndim].sum(axis=1)
+ self.volumetric_strain_centroid = self.volumetric_strain_centroid.at[:].add(
+ self.dvolumetric_strain
+ )
+
+def compute_stress(self, *args)
+Compute the stress on all particles.
+This calculation is governed by the material of the +particles. The stress calculated by the material is then +added to the particles current stress values.
def compute_stress(self, *args):
+ """Compute the strain on all particles.
+
+ This calculation is governed by the material of the
+ particles. The stress calculated by the material is then
+ added to the particles current stress values.
+ """
+ self.stress = self.stress.at[:].add(self.material.compute_stress(self.dstrain))
+
+def compute_volume(self, elements: _Element, total_elements: int)
+Compute volume of all particles.
+elements : diffmpm._Elementtotal_elements : intelements.def compute_volume(self, elements: _Element, total_elements: int):
+ """Compute volume of all particles.
+
+ Parameters
+ ----------
+ elements: diffmpm._Element
+ Elements that the particles are present in, and are used to
+ compute the particles' volumes.
+ total_elements: int
+ Total elements present in `elements`.
+ """
+ particles_per_element = jnp.bincount(
+ self.element_ids, length=elements.total_elements
+ )
+ vol = (
+ elements.volume.squeeze((1, 2))[self.element_ids] # type: ignore
+ / particles_per_element[self.element_ids]
+ )
+ self.volume = self.volume.at[:, 0, 0].set(vol)
+ self.size = self.size.at[:].set(self.volume ** (1 / self.size.shape[-1]))
+ self.mass = self.mass.at[:, 0, 0].set(vol * self.density.squeeze())
+
+def set_mass_volume(self, m: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex])
+Set particle mass.
+m : float, array_likedef set_mass_volume(self, m: ArrayLike):
+ """Set particle mass.
+
+ Parameters
+ ----------
+ m: float, array_like
+ Mass to be set for particles. If scalar, mass for all
+ particles is set to this value.
+ """
+ m = jnp.asarray(m)
+ if jnp.isscalar(m):
+ self.mass = jnp.ones_like(self.loc) * m
+ elif m.shape == self.mass.shape:
+ self.mass = m
+ else:
+ raise ValueError(
+ f"Incompatible shapes. Expected {self.mass.shape}, " f"found {m.shape}."
+ )
+ self.volume = jnp.divide(self.mass, self.material.properties["density"])
+
+def tree_flatten(self)
+Flatten class as Pytree type.
def tree_flatten(self):
+ """Flatten class as Pytree type."""
+ children = (
+ self.loc,
+ self.element_ids,
+ self.initialized,
+ self.mass,
+ self.density,
+ self.volume,
+ self.size,
+ self.velocity,
+ self.acceleration,
+ self.momentum,
+ self.strain,
+ self.stress,
+ self.strain_rate,
+ self.dstrain,
+ self.f_ext,
+ self.traction,
+ self.reference_loc,
+ self.dvolumetric_strain,
+ self.volumetric_strain_centroid,
+ )
+ aux_data = (self.material,)
+ return (children, aux_data)
+
+def update_natural_coords(self, elements: _Element)
+Update natural coordinates for the particles.
+Whenever the particles' physical coordinates change, their +natural coordinates need to be updated. This function updates +the natural coordinates of the particles based on the element +a particle is a part of. The update formula is
++\xi = (2x - (x_1^e + x_2^e)) +/ (x_2^e - x_1^e) +
+where x_i^e are the nodal coordinates of the element the +particle is in. If a particle is not in any element +(element_id = -1), its natural coordinate is set to 0.
+elements : _Elementdef update_natural_coords(self, elements: _Element):
+ r"""Update natural coordinates for the particles.
+
+ Whenever the particles' physical coordinates change, their
+ natural coordinates need to be updated. This function updates
+ the natural coordinates of the particles based on the element
+ a particle is a part of. The update formula is
+
+ \[
+ \xi = (2x - (x_1^e + x_2^e)) / (x_2^e - x_1^e)
+ \]
+
+ where \(x_i^e\) are the nodal coordinates of the element the
+ particle is in. If a particle is not in any element
+ (element_id = -1), its natural coordinate is set to 0.
+
+ Parameters
+ ----------
+ elements: diffmpm.element._Element
+ Elements based on which to update the natural coordinates
+ of the particles.
+ """
+ t = vmap(elements.id_to_node_loc)(self.element_ids)
+ xi_coords = (self.loc - (t[:, 0, ...] + t[:, 2, ...]) / 2) * (
+ 2 / (t[:, 2, ...] - t[:, 0, ...])
+ )
+ self.reference_loc = xi_coords
+
+def update_position_velocity(self, elements: _Element, dt: float, velocity_update: bool)
+Transfer nodal velocity to particles and update particle position.
+The velocity is calculated based on the total force at nodes.
+elements : _Elementdt : floatvelocity_update : booldef update_position_velocity(
+ self, elements: _Element, dt: float, velocity_update: bool
+):
+ """Transfer nodal velocity to particles and update particle position.
+
+ The velocity is calculated based on the total force at nodes.
+
+ Parameters
+ ----------
+ elements: diffmpm.element._Element
+ Elements whose nodes are used to transfer the velocity.
+ dt: float
+ Timestep.
+ velocity_update: bool
+ If True, velocity is directly used as nodal velocity, else
+        velocity is calculated as the interpolated nodal acceleration
+ multiplied by dt. Default is False.
+ """
+ mapped_positions = elements.shapefn(self.reference_loc)
+ mapped_ids = vmap(elements.id_to_node_ids)(self.element_ids).squeeze(-1)
+ nodal_velocity = jnp.sum(
+ mapped_positions * elements.nodes.velocity[mapped_ids], axis=1
+ )
+ nodal_acceleration = jnp.sum(
+ mapped_positions * elements.nodes.acceleration[mapped_ids],
+ axis=1,
+ )
+ self.velocity = self.velocity.at[:].set(
+ lax.cond(
+ velocity_update,
+ lambda sv, nv, na, t: nv,
+ lambda sv, nv, na, t: sv + na * t,
+ self.velocity,
+ nodal_velocity,
+ nodal_acceleration,
+ dt,
+ )
+ )
+ self.loc = self.loc.at[:].add(nodal_velocity * dt)
+ self.momentum = self.momentum.at[:].set(self.mass * self.velocity)
+
+def update_volume(self, *args)
+Update volume based on central strain rate.
def update_volume(self, *args):
+ """Update volume based on central strain rate."""
+ self.volume = self.volume.at[:, 0, :].multiply(1 + self.dvolumetric_strain)
+ self.density = self.density.at[:, 0, :].divide(1 + self.dvolumetric_strain)
+
+def zero_traction(self, *args)
+Set all traction values to 0.
def zero_traction(self, *args):
+ """Set all traction values to 0."""
+ self.traction = self.traction.at[:].set(0)
+diffmpm.pbar#!/usr/bin/env python3
+import typing
+
+import jax
+from jax.experimental import host_callback
+from tqdm.auto import tqdm
+
+
def scan_tqdm(
    n: int,
    print_rate: typing.Optional[int] = None,
    message: typing.Optional[str] = None,
) -> typing.Callable:
    """
    tqdm progress bar for a JAX scan

    Parameters
    ----------
    n : int
        Number of scan steps/iterations.
    print_rate: int
        Optional integer rate at which the progress bar will be updated,
        by default the print rate will 1/20th of the total number of steps.
    message : str
        Optional string to prepend to tqdm progress bar.

    Returns
    -------
    typing.Callable:
        Progress bar wrapping function.
    """
    _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, message)

    def _scan_tqdm(func):
        """Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`.
        Note that `body_fun` must either be looping over `jnp.arange(n)`,
        or be looping over a tuple whose first element is `jnp.arange(n)`
        This means that `iter_num` is the current iteration number
        """

        def wrapper_progress_bar(carry, x):
            # `type(x) is tuple` (not isinstance) is deliberate: named
            # tuples carried through the scan must be treated as plain x.
            iter_num = x[0] if type(x) is tuple else x
            _update_progress_bar(iter_num)
            return close_tqdm(func(carry, x), iter_num)

        return wrapper_progress_bar

    return _scan_tqdm
+
+
def loop_tqdm(
    n: int,
    print_rate: typing.Optional[int] = None,
    message: typing.Optional[str] = None,
) -> typing.Callable:
    """
    tqdm progress bar for a JAX fori_loop

    Parameters
    ----------
    n : int
        Number of iterations.
    print_rate: int
        Optional integer rate at which the progress bar will be updated,
        by default the print rate will 1/20th of the total number of steps.
    message : str
        Optional string to prepend to tqdm progress bar.

    Returns
    -------
    typing.Callable:
        Progress bar wrapping function.
    """
    _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, message)

    def _loop_tqdm(func):
        """
        Decorator that adds a tqdm progress bar to `body_fun`
        used in `jax.lax.fori_loop`.
        """

        def wrapper_progress_bar(i, val):
            _update_progress_bar(i)
            return close_tqdm(func(i, val), i)

        return wrapper_progress_bar

    return _loop_tqdm
+
+
def build_tqdm(
    n: int,
    print_rate: typing.Optional[int],
    message: typing.Optional[str] = None,
) -> typing.Tuple[typing.Callable, typing.Callable]:
    """Build the tqdm progress bar on the host.

    Parameters
    ----------
    n : int
        Total number of iterations the bar tracks.
    print_rate : int, optional
        Update interval in iterations; defaults to 1/20th of `n`.
    message : str, optional
        Description prepended to the progress bar.

    Returns
    -------
    typing.Tuple[typing.Callable, typing.Callable]
        `(update_progress_bar, close_tqdm)`, both safe to call from
        JIT-compiled code (they use `host_callback.id_tap` to reach the
        host).

    Raises
    ------
    ValueError
        If `print_rate` is not in the range [1, n].
    """
    if message is None:
        message = f"Running for {n:,} iterations"
    tqdm_bars = {}

    if print_rate is None:
        # Default: update roughly 20 times over the whole run.
        print_rate = int(n / 20) if n > 20 else 1
    else:
        if print_rate < 1:
            raise ValueError(f"Print rate should be > 0 got {print_rate}")
        elif print_rate > n:
            raise ValueError(
                "Print rate should be less than the "
                f"number of steps {n}, got {print_rate}"
            )

    remainder = n % print_rate

    def _define_tqdm(arg, transform):
        # Runs on the host: lazily create the bar on the first iteration.
        tqdm_bars[0] = tqdm(range(n), leave=False)
        tqdm_bars[0].set_description(message, refresh=False)

    def _update_tqdm(arg, transform):
        tqdm_bars[0].update(arg)

    def _update_progress_bar(iter_num):
        "Updates tqdm from a JAX scan or loop"
        # BUG FIX: this previously read `jax.jax.lax.cond`, which raises
        # AttributeError the first time a traced loop body calls it.
        _ = jax.lax.cond(
            iter_num == 0,
            lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )

        _ = jax.lax.cond(
            # update tqdm every multiple of `print_rate` except at the end
            (iter_num % print_rate == 0) & (iter_num != n - remainder),
            lambda _: host_callback.id_tap(
                _update_tqdm, print_rate, result=iter_num
            ),
            lambda _: iter_num,
            operand=None,
        )

        _ = jax.lax.cond(
            # update tqdm by `remainder`
            iter_num == n - remainder,
            lambda _: host_callback.id_tap(
                _update_tqdm, remainder, result=iter_num
            ),
            lambda _: iter_num,
            operand=None,
        )

    def _close_tqdm(arg, transform):
        tqdm_bars[0].close()

    def close_tqdm(result, iter_num):
        # NOTE(review): `jax.experimental.host_callback` is deprecated in
        # recent JAX releases — consider `jax.debug.callback` / `io_callback`.
        return jax.lax.cond(
            iter_num == n - 1,
            lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
            lambda _: result,
            operand=None,
        )

    return _update_progress_bar, close_tqdm
+
+def build_tqdm(n: int, print_rate: Optional[int], message: Optional[str] = None) ‑> Tuple[Callable, Callable]
+Build the tqdm progress bar on the host
def build_tqdm(
+ n: int,
+ print_rate: typing.Optional[int],
+ message: typing.Optional[str] = None,
+) -> typing.Tuple[typing.Callable, typing.Callable]:
+ """
+ Build the tqdm progress bar on the host
+ """
+
+ if message is None:
+ message = f"Running for {n:,} iterations"
+ tqdm_bars = {}
+
+ if print_rate is None:
+ if n > 20:
+ print_rate = int(n / 20)
+ else:
+ print_rate = 1
+ else:
+ if print_rate < 1:
+ raise ValueError(f"Print rate should be > 0 got {print_rate}")
+ elif print_rate > n:
+ raise ValueError(
+ "Print rate should be less than the "
+ f"number of steps {n}, got {print_rate}"
+ )
+
+ remainder = n % print_rate
+
+ def _define_tqdm(arg, transform):
+ tqdm_bars[0] = tqdm(range(n), leave=False)
+ tqdm_bars[0].set_description(message, refresh=False)
+
+ def _update_tqdm(arg, transform):
+ tqdm_bars[0].update(arg)
+
+ def _update_progress_bar(iter_num):
+ "Updates tqdm from a JAX scan or loop"
+ _ = jax.jax.lax.cond(
+ iter_num == 0,
+ lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num),
+ lambda _: iter_num,
+ operand=None,
+ )
+
+ _ = jax.lax.cond(
+ # update tqdm every multiple of `print_rate` except at the end
+ (iter_num % print_rate == 0) & (iter_num != n - remainder),
+ lambda _: host_callback.id_tap(
+ _update_tqdm, print_rate, result=iter_num
+ ),
+ lambda _: iter_num,
+ operand=None,
+ )
+
+ _ = jax.lax.cond(
+ # update tqdm by `remainder`
+ iter_num == n - remainder,
+ lambda _: host_callback.id_tap(
+ _update_tqdm, remainder, result=iter_num
+ ),
+ lambda _: iter_num,
+ operand=None,
+ )
+
+ def _close_tqdm(arg, transform):
+ tqdm_bars[0].close()
+
+ def close_tqdm(result, iter_num):
+ return jax.lax.cond(
+ iter_num == n - 1,
+ lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
+ lambda _: result,
+ operand=None,
+ )
+
+ return _update_progress_bar, close_tqdm
+
+def loop_tqdm(n: int, print_rate: Optional[int] = None, message: Optional[str] = None) ‑> Callable
+tqdm progress bar for a JAX fori_loop
+n : intprint_rate : intmessage : strtyping.Callable:def loop_tqdm(
+ n: int,
+ print_rate: typing.Optional[int] = None,
+ message: typing.Optional[str] = None,
+) -> typing.Callable:
+ """
+ tqdm progress bar for a JAX fori_loop
+
+ Parameters
+ ----------
+ n : int
+ Number of iterations.
+ print_rate: int
+ Optional integer rate at which the progress bar will be updated,
+ by default the print rate will 1/20th of the total number of steps.
+ message : str
+ Optional string to prepend to tqdm progress bar.
+
+ Returns
+ -------
+ typing.Callable:
+ Progress bar wrapping function.
+ """
+
+ _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, message)
+
+ def _loop_tqdm(func):
+ """
+ Decorator that adds a tqdm progress bar to `body_fun`
+ used in `jax.lax.fori_loop`.
+ """
+
+ def wrapper_progress_bar(i, val):
+ _update_progress_bar(i)
+ result = func(i, val)
+ return close_tqdm(result, i)
+
+ return wrapper_progress_bar
+
+ return _loop_tqdm
+
+def scan_tqdm(n: int, print_rate: Optional[int] = None, message: Optional[str] = None) ‑> Callable
+tqdm progress bar for a JAX scan
+n : intprint_rate : intmessage : strtyping.Callable:def scan_tqdm(
+ n: int,
+ print_rate: typing.Optional[int] = None,
+ message: typing.Optional[str] = None,
+) -> typing.Callable:
+ """
+ tqdm progress bar for a JAX scan
+
+ Parameters
+ ----------
+ n : int
+ Number of scan steps/iterations.
+ print_rate: int
+ Optional integer rate at which the progress bar will be updated,
+ by default the print rate will 1/20th of the total number of steps.
+ message : str
+ Optional string to prepend to tqdm progress bar.
+
+ Returns
+ -------
+ typing.Callable:
+ Progress bar wrapping function.
+ """
+
+ _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, message)
+
+ def _scan_tqdm(func):
+ """Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`.
+ Note that `body_fun` must either be looping over `jnp.arange(n)`,
+    or be looping over a tuple whose first element is `jnp.arange(n)`
+ This means that `iter_num` is the current iteration number
+ """
+
+ def wrapper_progress_bar(carry, x):
+ if type(x) is tuple:
+ iter_num, *_ = x
+ else:
+ iter_num = x
+ _update_progress_bar(iter_num)
+ result = func(carry, x)
+ return close_tqdm(result, iter_num)
+
+ return wrapper_progress_bar
+
+ return _scan_tqdm
+diffmpm.plotimport os
+from collections import defaultdict
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.animation import FuncAnimation
+
+
class NPZPlotter:
    """Plot `.npz` simulation output as matplotlib scatter animations."""

    def __init__(self, out_dir):
        """Initialize the plotter.

        Parameters
        ----------
        out_dir : str
            Directory containing per-step `.npz` output files.
        """
        self.out_dir = out_dir

    def _parse_data(self, nsteps: int, keys: list = ["loc"]):
        """Load arrays for `keys` from the first `nsteps` output files.

        Returns a mapping `{key: {particle_set_index: [array_per_step]}}`.

        Raises
        ------
        KeyError
            If a requested key is missing from any output file.
        """
        out_files = []
        for file in os.listdir(self.out_dir):
            if os.path.splitext(file)[-1] == ".npz":
                out_files.append(file)
        # Lexicographic sort; assumes filenames encode zero-padded step
        # numbers so this matches temporal order — TODO confirm naming.
        out_files = sorted(out_files)[:nsteps]

        pset_vals = defaultdict(list)
        for file in out_files:
            data = np.load(os.path.join(self.out_dir, file))
            for key in keys:
                if key not in data:
                    raise KeyError(f"{key} not found in {file}.")
                if key not in pset_vals:
                    pset_vals[key] = defaultdict(list)
                for pset, pset_val in enumerate(data[key]):
                    pset_vals[key][pset].append(pset_val)

        return pset_vals

    def _create_animation_2d(self, pset_vals, nsteps):
        """Build a 2D scatter animation for one parsed key."""
        # BUG FIX: a leftover `breakpoint()` debugging call was removed
        # from this loop — it froze any non-interactive run.
        fig, ax = plt.subplots()
        pset_scat_list = []
        for pval in pset_vals.values():
            scat = ax.scatter(pval[0][:, 0], pval[0][:, 1])
            pset_scat_list.append(scat)

        def _update(i, pvals, scat_list):
            # Move each particle set's scatter points to step `i`.
            for pset, pval in pvals.items():
                x = pval[i][:, 0]
                y = pval[i][:, 1]
                data = np.stack([x, y]).T
                scat_list[pset].set_offsets(data)
            return (*scat_list,)

        anim = FuncAnimation(
            fig=fig,
            func=_update,
            frames=nsteps,
            interval=300,
            fargs=(pset_vals, pset_scat_list),
        )
        return anim

    def _create_animation_1d(self, pset_vals, nsteps):
        """Build a 1D scatter animation (values plotted against themselves)."""
        fig, ax = plt.subplots()
        pset_scat_list = []
        for pval in pset_vals.values():
            scat = ax.scatter(pval[0], pval[0])
            pset_scat_list.append(scat)

        def _update(i, pvals, scat_list):
            for pset, pval in pvals.items():
                x = pval[i]
                y = pval[i]
                data = np.stack([x, y]).T
                scat_list[pset].set_offsets(data)
            return (*scat_list,)

        anim = FuncAnimation(
            fig=fig,
            func=_update,
            frames=nsteps,
            interval=30,
            fargs=(pset_vals, pset_scat_list),
        )
        return anim

    def plot_animations(self, nsteps: int, keys: list = ["loc"]):
        """Create one animation per key over the first `nsteps` steps.

        BUG FIX: a leftover `breakpoint()` debugging call was removed here.

        NOTE(review): this always uses the 1D animation builder;
        `_create_animation_2d` is never called — confirm intended
        behavior for 2D data.
        """
        pset_dict = self._parse_data(nsteps, keys)
        animations = {}
        for key in keys:
            animations[key] = self._create_animation_1d(pset_dict[key], nsteps)

        return animations
+
+class NPZPlotter
+(out_dir)
+class NPZPlotter:
+ def __init__(self, out_dir):
+ self.out_dir = out_dir
+
+ def _parse_data(self, nsteps: int, keys: list = ["loc"]):
+ out_files = []
+ for file in os.listdir(self.out_dir):
+ if os.path.splitext(file)[-1] == ".npz":
+ out_files.append(file)
+ out_files = sorted(out_files)[:nsteps]
+
+ pset_vals = defaultdict(list)
+ for file in out_files:
+ data = np.load(os.path.join(self.out_dir, file))
+ for key in keys:
+ if key not in data:
+ raise KeyError(f"{key} not found in {file}.")
+ if key not in pset_vals:
+ pset_vals[key] = defaultdict(list)
+ for pset, pset_val in enumerate(data[key]):
+ pset_vals[key][pset].append(pset_val)
+
+ return pset_vals
+
+ def _create_animation_2d(self, pset_vals, nsteps):
+ fig, ax = plt.subplots()
+ pset_scat_list = []
+ for pval in pset_vals.values():
+ breakpoint()
+ scat = ax.scatter(pval[0][:, 0], pval[0][:, 1])
+ pset_scat_list.append(scat)
+
+ def _update(i, pvals, scat_list):
+ for pset, pval in pvals.items():
+ x = pval[i][:, 0]
+ y = pval[i][:, 1]
+ data = np.stack([x, y]).T
+ scat_list[pset].set_offsets(data)
+ return (*scat_list,)
+
+ anim = FuncAnimation(
+ fig=fig,
+ func=_update,
+ frames=nsteps,
+ interval=300,
+ fargs=(pset_vals, pset_scat_list),
+ )
+ return anim
+
+ def _create_animation_1d(self, pset_vals, nsteps):
+ fig, ax = plt.subplots()
+ pset_scat_list = []
+ for pval in pset_vals.values():
+ scat = ax.scatter(pval[0], pval[0])
+ pset_scat_list.append(scat)
+
+ def _update(i, pvals, scat_list):
+ for pset, pval in pvals.items():
+ x = pval[i]
+ y = pval[i]
+ data = np.stack([x, y]).T
+ scat_list[pset].set_offsets(data)
+ return (*scat_list,)
+
+ anim = FuncAnimation(
+ fig=fig,
+ func=_update,
+ frames=nsteps,
+ interval=30,
+ fargs=(pset_vals, pset_scat_list),
+ )
+ return anim
+
+ def plot_animations(self, nsteps: int, keys: list = ["loc"]):
+ pset_dict = self._parse_data(nsteps, keys)
+ breakpoint()
+ animations = {}
+ for key in keys:
+ animations[key] = self._create_animation_1d(pset_dict[key], nsteps)
+
+ return animations
+
+def plot_animations(self, nsteps: int, keys: list = ['loc'])
+def plot_animations(self, nsteps: int, keys: list = ["loc"]):
+ pset_dict = self._parse_data(nsteps, keys)
+ breakpoint()
+ animations = {}
+ for key in keys:
+ animations[key] = self._create_animation_1d(pset_dict[key], nsteps)
+
+ return animations
+diffmpm.schemefrom __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from jax.typing import ArrayLike
+
+if TYPE_CHECKING:
+ import jax.numpy as jnp
+ from diffmpm.mesh import _MeshBase
+
+import abc
+
+_schemes = ("usf", "usl")
+
+
class _MPMScheme(abc.ABC):
    """Abstract base class for MPM solution schemes.

    Subclasses decide *when* stress and strain are updated relative to
    the force computation by implementing `precompute_stress_strain`
    and `postcompute_stress_strain` (see `_schemes`: "usf", "usl").
    """

    def __init__(self, mesh, dt, velocity_update):
        # mesh: diffmpm.mesh._MeshBase — mesh the scheme operates on.
        # dt: float — simulation timestep.
        # velocity_update: bool — if True, particle velocity is taken
        # directly from nodal velocity instead of integrating nodal
        # acceleration (forwarded to the particles' update).
        self.mesh = mesh
        self.velocity_update = velocity_update
        self.dt = dt

    def compute_nodal_kinematics(self):
        """Compute nodal kinematics - map mass and momentum to mesh nodes."""
        # Order matters: particles must first be located in elements and
        # their natural coordinates refreshed before nodal quantities are
        # mapped and constrained.
        self.mesh.apply_on_elements("set_particle_element_ids")
        self.mesh.apply_on_particles("update_natural_coords")
        self.mesh.apply_on_elements("compute_nodal_mass")
        self.mesh.apply_on_elements("compute_nodal_momentum")
        self.mesh.apply_on_elements("compute_velocity")
        self.mesh.apply_on_elements("apply_boundary_constraints")

    def compute_stress_strain(self):
        """Compute stress and strain on the particles."""
        # Strain first, then volume update, then the material stress law.
        self.mesh.apply_on_particles("compute_strain", args=(self.dt,))
        self.mesh.apply_on_particles("update_volume")
        self.mesh.apply_on_particles("compute_stress")

    def compute_forces(self, gravity: ArrayLike, step: int):
        """Compute the forces acting in the system.

        Parameters
        ----------
        gravity: ArrayLike
            Gravity present in the system. This should be an array
            with shape `(1, ndim)` where `ndim` is the dimension of the
            simulation.
        step: int
            Current step being simulated.
        """
        self.mesh.apply_on_elements("compute_external_force")
        self.mesh.apply_on_elements("compute_body_force", args=(gravity,))
        # Tractions and concentrated nodal forces may be time-dependent,
        # so the current simulation time `step * dt` is passed along.
        self.mesh.apply_traction_on_particles(step * self.dt)
        self.mesh.apply_on_elements(
            "apply_concentrated_nodal_forces", args=(step * self.dt,)
        )
        self.mesh.apply_on_elements("compute_internal_force")
        # self.mesh.apply_on_elements("apply_force_boundary_constraints")

    def compute_particle_kinematics(self):
        """Compute particle location, acceleration and velocity."""
        self.mesh.apply_on_elements(
            "update_nodal_acceleration_velocity", args=(self.dt,)
        )
        self.mesh.apply_on_particles(
            "update_position_velocity",
            args=(self.dt, self.velocity_update),
        )
        # TODO: Apply particle velocity constraints.

    @abc.abstractmethod
    def precompute_stress_strain(self):
        """Update stress/strain before force computation (scheme-specific)."""
        ...

    @abc.abstractmethod
    def postcompute_stress_strain(self):
        """Update stress/strain after particle kinematics (scheme-specific)."""
        ...
+
+
class USF(_MPMScheme):
    """USF (Update Stress First) scheme solver."""

    def __init__(self, mesh: _MeshBase, dt: float, velocity_update: bool):
        """Initialize USF Scheme solver.

        Parameters
        ----------
        mesh: _MeshBase
            A `diffmpm.Mesh` object that contains the elements that form
            the underlying mesh used to solve the simulation.
        dt: float
            Timestep used in the simulation.
        velocity_update: bool
            Flag to control if velocity should be updated using nodal
            velocity or interpolated nodal acceleration. If `True`, nodal
            velocity is used, else nodal acceleration. Default `False`.
        """
        super().__init__(mesh, dt, velocity_update)

    def precompute_stress_strain(self):
        """Update stress and strain on particles before force computation."""
        self.compute_stress_strain()

    def postcompute_stress_strain(self):
        """Do nothing — USF updates stress before force computation."""
        pass
+
+
class USL(_MPMScheme):
    """USL (Update Stress Last) scheme solver."""

    def __init__(self, mesh, dt, velocity_update):
        """Initialize USL Scheme solver.

        Parameters
        ----------
        mesh: _MeshBase
            A `diffmpm.Mesh` object that contains the elements that form
            the underlying mesh used to solve the simulation.
        dt: float
            Timestep used in the simulation.
        velocity_update: bool
            Flag to control if velocity should be updated using nodal
            velocity or interpolated nodal acceleration. If `True`, nodal
            velocity is used, else nodal acceleration. Default `False`.
        """
        super().__init__(mesh, dt, velocity_update)

    def precompute_stress_strain(self):
        """Do nothing — USL updates stress after particle kinematics."""
        pass

    def postcompute_stress_strain(self):
        """Update stress and strain on particles after the kinematic update."""
        self.compute_stress_strain()
+
+class USF
+(mesh: _MeshBase, dt: float, velocity_update: bool)
+USF Scheme solver.
+Initialize USF Scheme solver.
+mesh : _MeshBasediffmpm.Mesh object that contains the elements that form
+the underlying mesh used to solve the simulation.dt : floatvelocity_update : boolTrue, nodal
+velocity is used, else nodal acceleration. Default False.class USF(_MPMScheme):
+ """USF Scheme solver."""
+
+ def __init__(self, mesh: _MeshBase, dt: float, velocity_update: bool):
+ """Initialize USF Scheme solver.
+
+ Parameters
+ ----------
+ mesh: _MeshBase
+ A `diffmpm.Mesh` object that contains the elements that form
+ the underlying mesh used to solve the simulation.
+ dt: float
+ Timestep used in the simulation.
+ velocity_update: bool
+ Flag to control if velocity should be updated using nodal
+ velocity or interpolated nodal acceleration. If `True`, nodal
+ velocity is used, else nodal acceleration. Default `False`.
+ """
+ super().__init__(mesh, dt, velocity_update)
+
+ def precompute_stress_strain(self):
+ """Compute stress and strain on particles."""
+ self.compute_stress_strain()
+
+ def postcompute_stress_strain(self):
+ """Compute stress and strain on particles. (Empty call for USF)."""
+ pass
+
+def postcompute_stress_strain(self)
+Compute stress and strain on particles. (Empty call for USF).
def postcompute_stress_strain(self):
    """No-op for USF: stress and strain were already updated pre-forces."""
    return None
+
+def precompute_stress_strain(self)
+Compute stress and strain on particles.
def precompute_stress_strain(self):
    """Update particle stress and strain before forces are computed (USF)."""
    self.compute_stress_strain()
+
class USL(mesh, dt, velocity_update)

USL Scheme solver. Initializes the USL scheme.

Parameters: `mesh` (_MeshBase) — a `diffmpm.Mesh` object containing the elements
that form the underlying mesh used to solve the simulation; `dt` (float) — the
timestep used in the simulation; `velocity_update` (bool) — if `True`, nodal
velocity is used to update particle velocity, else interpolated nodal
acceleration. Default `False`.

class USL(_MPMScheme):
+ """USL Scheme solver."""
+
+ def __init__(self, mesh, dt, velocity_update):
+ """Initialize USL Scheme solver.
+
+ Parameters
+ ----------
+ mesh: _MeshBase
+ A `diffmpm.Mesh` object that contains the elements that form
+ the underlying mesh used to solve the simulation.
+ dt: float
+ Timestep used in the simulation.
+ velocity_update: bool
+ Flag to control if velocity should be updated using nodal
+ velocity or interpolated nodal acceleration. If `True`, nodal
+ velocity is used, else nodal acceleration. Default `False`.
+ """
+ super().__init__(mesh, dt, velocity_update)
+
+ def precompute_stress_strain(self):
+ """Compute stress and strain on particles. (Empty call for USL)."""
+ pass
+
+ def postcompute_stress_strain(self):
+ """Compute stress and strain on particles."""
+ self.compute_stress_strain()
+
+def postcompute_stress_strain(self)
+Compute stress and strain on particles.
def postcompute_stress_strain(self):
    """Update particle stress and strain at the end of the step (USL)."""
    self.compute_stress_strain()
+
+def precompute_stress_strain(self)
+Compute stress and strain on particles. (Empty call for USL).
def precompute_stress_strain(self):
    """No-op for USL: stress and strain are updated after kinematics."""
    return None
diffmpm.solver

from __future__ import annotations
+
+import functools
+from typing import TYPE_CHECKING, Callable, Optional
+
+import jax.numpy as jnp
+from jax import lax
+from jax.experimental.host_callback import id_tap
+from jax.tree_util import register_pytree_node_class
+from jax.typing import ArrayLike
+
+from diffmpm.scheme import USF, USL, _MPMScheme, _schemes
+
+if TYPE_CHECKING:
+ from diffmpm.mesh import _MeshBase
+
+
@register_pytree_node_class
class MPMExplicit:
    """A class to implement the fully explicit MPM."""

    # Particle attributes that the solvers gather into the output arrays.
    __particle_props = ("loc", "velocity", "stress", "strain")

    def __init__(
        self,
        mesh: _MeshBase,
        dt: float,
        scheme: str = "usf",
        velocity_update: bool = False,
        sim_steps: int = 1,
        out_steps: int = 1,
        out_dir: str = "results/",
        writer_func: Optional[Callable] = None,
    ) -> None:
        """Create an `MPMExplicit` object.

        This can be used to solve a given configuration of an MPM
        problem.

        Parameters
        ----------
        mesh: _MeshBase
            A `diffmpm.Mesh` object that contains the elements that form
            the underlying mesh used to solve the simulation.
        dt: float
            Timestep used in the simulation.
        scheme: str
            The MPM Scheme type used for the simulation. Can be one of
            `"usl"` or `"usf"`. Default set to `"usf"`.
        velocity_update: bool
            Flag to control if velocity should be updated using nodal
            velocity or interpolated nodal acceleration. If `True`, nodal
            velocity is used, else nodal acceleration. Default `False`.
        sim_steps: int
            Number of steps to run the simulation for. Default set to 1.
        out_steps: int
            Frequency with which to store the results. For example, if
            set to 5, the result at every 5th step will be stored. Default
            set to 1.
        out_dir: str
            Path to the output directory where results are stored.
        writer_func: Callable, None
            Function that is used to write the state in the output
            directory.

        Raises
        ------
        ValueError
            If `scheme` is neither `"usf"` nor `"usl"`.
        """
        # Dispatch on the scheme name; only USF and USL are implemented.
        if scheme == "usf":
            self.mpm_scheme: _MPMScheme = USF(mesh, dt, velocity_update)  # type: ignore
        elif scheme == "usl":
            self.mpm_scheme: _MPMScheme = USL(mesh, dt, velocity_update)  # type: ignore
        else:
            raise ValueError(f"Please select scheme from {_schemes}. Found {scheme}")
        self.mesh = mesh
        self.dt = dt
        self.scheme = scheme
        self.velocity_update = velocity_update
        self.sim_steps = sim_steps
        self.out_steps = out_steps
        self.out_dir = out_dir
        self.writer_func = writer_func
        # Precompute the particle-element mapping and element/particle
        # volumes once, before any simulation steps are run.
        self.mpm_scheme.mesh.apply_on_elements("set_particle_element_ids")
        self.mpm_scheme.mesh.apply_on_elements("compute_volume")
        self.mpm_scheme.mesh.apply_on_particles(
            "compute_volume", args=(self.mesh.elements.total_elements,)
        )

    def tree_flatten(self):
        # Only the mesh participates in JAX transformations; everything
        # else is carried as static auxiliary data.
        children = (self.mesh,)
        aux_data = {
            "dt": self.dt,
            "scheme": self.scheme,
            "velocity_update": self.velocity_update,
            "sim_steps": self.sim_steps,
            "out_steps": self.out_steps,
            "out_dir": self.out_dir,
            "writer_func": self.writer_func,
        }
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Inverse of `tree_flatten`: rebuild the solver from the traced
        # children (the mesh) and the static auxiliary data.
        return cls(
            *children,
            aux_data["dt"],
            scheme=aux_data["scheme"],
            velocity_update=aux_data["velocity_update"],
            sim_steps=aux_data["sim_steps"],
            out_steps=aux_data["out_steps"],
            out_dir=aux_data["out_dir"],
            writer_func=aux_data["writer_func"],
        )

    def _jax_writer(self, func, args):
        # id_tap runs the (impure) writer function on the host from inside
        # traced code. NOTE(review): jax.experimental.host_callback is
        # deprecated in newer JAX releases; consider jax.debug.callback —
        # confirm against the pinned JAX version.
        id_tap(func, args)

    def solve(self, gravity: ArrayLike):
        """Non-JIT solve method.

        This method runs the entire simulation for the defined number
        of steps.

        .. note::
            This is mainly used for debugging and might be removed in
            future versions or moved to the JIT solver.

        Parameters
        ----------
        gravity: ArrayLike
            Gravity present in the system. This should be an array equal
            with shape `(1, ndim)` where `ndim` is the dimension of the
            simulation.

        Returns
        -------
        dict
            A dictionary of `ArrayLike` arrays corresponding to the
            all states of the simulation after completing all steps.
        """
        from collections import defaultdict

        from tqdm import tqdm  # type: ignore

        result = defaultdict(list)
        for step in tqdm(range(self.sim_steps)):
            # One explicit MPM step: nodal kinematics -> (USF stress
            # update) -> forces -> particle kinematics -> (USL stress
            # update).
            self.mpm_scheme.compute_nodal_kinematics()
            self.mpm_scheme.precompute_stress_strain()
            self.mpm_scheme.compute_forces(gravity, step)
            self.mpm_scheme.compute_particle_kinematics()
            self.mpm_scheme.postcompute_stress_strain()
            # Record the particle state after every step.
            for pset in self.mesh.particles:
                result["position"].append(pset.loc)
                result["velocity"].append(pset.velocity)
                result["stress"].append(pset.stress[:, :2, 0])
                result["strain"].append(pset.strain[:, :2, 0])

        result_arr = {k: jnp.asarray(v) for k, v in result.items()}
        return result_arr

    def solve_jit(self, gravity: ArrayLike) -> dict:
        """Solver method that runs the simulation.

        This method runs the entire simulation for the defined number
        of steps.

        Parameters
        ----------
        gravity: ArrayLike
            Gravity present in the system. This should be an array equal
            with shape `(1, ndim)` where `ndim` is the dimension of the
            simulation.

        Returns
        -------
        dict
            A dictionary of `jax.numpy` arrays corresponding to the
            final state of the simulation after completing all steps.
        """

        def _step(i, data):
            # The fori_loop carry is the solver object itself — valid
            # because the class is registered as a pytree above.
            self = data
            self.mpm_scheme.compute_nodal_kinematics()
            self.mpm_scheme.precompute_stress_strain()
            self.mpm_scheme.compute_forces(gravity, i)
            self.mpm_scheme.compute_particle_kinematics()
            self.mpm_scheme.postcompute_stress_strain()

            def _write(self, i):
                # Gather each tracked particle property over all particle
                # sets and hand the arrays to the host-side writer.
                arrays = {}
                for name in self.__particle_props:
                    arrays[name] = jnp.array(
                        [
                            getattr(self.mesh.particles[j], name).squeeze()
                            for j in range(len(self.mesh.particles))
                        ]
                    )
                self._jax_writer(
                    functools.partial(
                        self.writer_func, out_dir=self.out_dir, max_steps=self.sim_steps
                    ),
                    (arrays, i),
                )

            if self.writer_func is not None:
                # Write output only every `out_steps` steps; the
                # else-branch is a no-op.
                lax.cond(
                    i % self.out_steps == 0,
                    _write,
                    lambda s, i: None,
                    self,
                    i,
                )
            return self

        self = lax.fori_loop(0, self.sim_steps, _step, self)
        # Collect the final particle state into plain arrays.
        arrays = {}
        for name in self.__particle_props:
            arrays[name] = jnp.array(
                [
                    getattr(self.mesh.particles[j], name)
                    for j in range(len(self.mesh.particles))
                ]
            ).squeeze()
        return arrays
+
class MPMExplicit(mesh: _MeshBase, dt: float, scheme: str = 'usf', velocity_update: bool = False, sim_steps: int = 1, out_steps: int = 1, out_dir: str = 'results/', writer_func: Optional[Callable] = None)

A class to implement the fully explicit MPM. Creates an `MPMExplicit` object
that can be used to solve a given configuration of an MPM problem.

Parameters: `mesh` (_MeshBase) — a `diffmpm.Mesh` object containing the elements
that form the underlying mesh used to solve the simulation; `dt` (float) —
simulation timestep; `scheme` (str) — one of `"usl"` or `"usf"`, default
`"usf"`; `velocity_update` (bool) — if `True`, nodal velocity is used to update
particles, else nodal acceleration, default `False`; `sim_steps` (int) — number
of simulation steps; `out_steps` (int) — frequency with which results are
stored; `out_dir` (str) — output directory path; `writer_func` (Callable or
None) — function used to write the state to the output directory.

@register_pytree_node_class
class MPMExplicit:
    """A class to implement the fully explicit MPM."""

    # Particle attributes that the solvers gather into the output arrays.
    __particle_props = ("loc", "velocity", "stress", "strain")

    def __init__(
        self,
        mesh: _MeshBase,
        dt: float,
        scheme: str = "usf",
        velocity_update: bool = False,
        sim_steps: int = 1,
        out_steps: int = 1,
        out_dir: str = "results/",
        writer_func: Optional[Callable] = None,
    ) -> None:
        """Create an `MPMExplicit` object.

        This can be used to solve a given configuration of an MPM
        problem.

        Parameters
        ----------
        mesh: _MeshBase
            A `diffmpm.Mesh` object that contains the elements that form
            the underlying mesh used to solve the simulation.
        dt: float
            Timestep used in the simulation.
        scheme: str
            The MPM Scheme type used for the simulation. Can be one of
            `"usl"` or `"usf"`. Default set to `"usf"`.
        velocity_update: bool
            Flag to control if velocity should be updated using nodal
            velocity or interpolated nodal acceleration. If `True`, nodal
            velocity is used, else nodal acceleration. Default `False`.
        sim_steps: int
            Number of steps to run the simulation for. Default set to 1.
        out_steps: int
            Frequency with which to store the results. For example, if
            set to 5, the result at every 5th step will be stored. Default
            set to 1.
        out_dir: str
            Path to the output directory where results are stored.
        writer_func: Callable, None
            Function that is used to write the state in the output
            directory.

        Raises
        ------
        ValueError
            If `scheme` is neither `"usf"` nor `"usl"`.
        """
        # Dispatch on the scheme name; only USF and USL are implemented.
        if scheme == "usf":
            self.mpm_scheme: _MPMScheme = USF(mesh, dt, velocity_update)  # type: ignore
        elif scheme == "usl":
            self.mpm_scheme: _MPMScheme = USL(mesh, dt, velocity_update)  # type: ignore
        else:
            raise ValueError(f"Please select scheme from {_schemes}. Found {scheme}")
        self.mesh = mesh
        self.dt = dt
        self.scheme = scheme
        self.velocity_update = velocity_update
        self.sim_steps = sim_steps
        self.out_steps = out_steps
        self.out_dir = out_dir
        self.writer_func = writer_func
        # Precompute the particle-element mapping and element/particle
        # volumes once, before any simulation steps are run.
        self.mpm_scheme.mesh.apply_on_elements("set_particle_element_ids")
        self.mpm_scheme.mesh.apply_on_elements("compute_volume")
        self.mpm_scheme.mesh.apply_on_particles(
            "compute_volume", args=(self.mesh.elements.total_elements,)
        )

    def tree_flatten(self):
        # Only the mesh participates in JAX transformations; everything
        # else is carried as static auxiliary data.
        children = (self.mesh,)
        aux_data = {
            "dt": self.dt,
            "scheme": self.scheme,
            "velocity_update": self.velocity_update,
            "sim_steps": self.sim_steps,
            "out_steps": self.out_steps,
            "out_dir": self.out_dir,
            "writer_func": self.writer_func,
        }
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Inverse of `tree_flatten`: rebuild the solver from the traced
        # children (the mesh) and the static auxiliary data.
        return cls(
            *children,
            aux_data["dt"],
            scheme=aux_data["scheme"],
            velocity_update=aux_data["velocity_update"],
            sim_steps=aux_data["sim_steps"],
            out_steps=aux_data["out_steps"],
            out_dir=aux_data["out_dir"],
            writer_func=aux_data["writer_func"],
        )

    def _jax_writer(self, func, args):
        # id_tap runs the (impure) writer function on the host from inside
        # traced code. NOTE(review): jax.experimental.host_callback is
        # deprecated in newer JAX releases; consider jax.debug.callback —
        # confirm against the pinned JAX version.
        id_tap(func, args)

    def solve(self, gravity: ArrayLike):
        """Non-JIT solve method.

        This method runs the entire simulation for the defined number
        of steps.

        .. note::
            This is mainly used for debugging and might be removed in
            future versions or moved to the JIT solver.

        Parameters
        ----------
        gravity: ArrayLike
            Gravity present in the system. This should be an array equal
            with shape `(1, ndim)` where `ndim` is the dimension of the
            simulation.

        Returns
        -------
        dict
            A dictionary of `ArrayLike` arrays corresponding to the
            all states of the simulation after completing all steps.
        """
        from collections import defaultdict

        from tqdm import tqdm  # type: ignore

        result = defaultdict(list)
        for step in tqdm(range(self.sim_steps)):
            # One explicit MPM step: nodal kinematics -> (USF stress
            # update) -> forces -> particle kinematics -> (USL stress
            # update).
            self.mpm_scheme.compute_nodal_kinematics()
            self.mpm_scheme.precompute_stress_strain()
            self.mpm_scheme.compute_forces(gravity, step)
            self.mpm_scheme.compute_particle_kinematics()
            self.mpm_scheme.postcompute_stress_strain()
            # Record the particle state after every step.
            for pset in self.mesh.particles:
                result["position"].append(pset.loc)
                result["velocity"].append(pset.velocity)
                result["stress"].append(pset.stress[:, :2, 0])
                result["strain"].append(pset.strain[:, :2, 0])

        result_arr = {k: jnp.asarray(v) for k, v in result.items()}
        return result_arr

    def solve_jit(self, gravity: ArrayLike) -> dict:
        """Solver method that runs the simulation.

        This method runs the entire simulation for the defined number
        of steps.

        Parameters
        ----------
        gravity: ArrayLike
            Gravity present in the system. This should be an array equal
            with shape `(1, ndim)` where `ndim` is the dimension of the
            simulation.

        Returns
        -------
        dict
            A dictionary of `jax.numpy` arrays corresponding to the
            final state of the simulation after completing all steps.
        """

        def _step(i, data):
            # The fori_loop carry is the solver object itself — valid
            # because the class is registered as a pytree.
            self = data
            self.mpm_scheme.compute_nodal_kinematics()
            self.mpm_scheme.precompute_stress_strain()
            self.mpm_scheme.compute_forces(gravity, i)
            self.mpm_scheme.compute_particle_kinematics()
            self.mpm_scheme.postcompute_stress_strain()

            def _write(self, i):
                # Gather each tracked particle property over all particle
                # sets and hand the arrays to the host-side writer.
                arrays = {}
                for name in self.__particle_props:
                    arrays[name] = jnp.array(
                        [
                            getattr(self.mesh.particles[j], name).squeeze()
                            for j in range(len(self.mesh.particles))
                        ]
                    )
                self._jax_writer(
                    functools.partial(
                        self.writer_func, out_dir=self.out_dir, max_steps=self.sim_steps
                    ),
                    (arrays, i),
                )

            if self.writer_func is not None:
                # Write output only every `out_steps` steps; the
                # else-branch is a no-op.
                lax.cond(
                    i % self.out_steps == 0,
                    _write,
                    lambda s, i: None,
                    self,
                    i,
                )
            return self

        self = lax.fori_loop(0, self.sim_steps, _step, self)
        # Collect the final particle state into plain arrays.
        arrays = {}
        for name in self.__particle_props:
            arrays[name] = jnp.array(
                [
                    getattr(self.mesh.particles[j], name)
                    for j in range(len(self.mesh.particles))
                ]
            ).squeeze()
        return arrays
+
+def tree_unflatten(aux_data, children)
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Rebuild the solver from pytree children and static aux data."""
    static_keys = (
        "scheme",
        "velocity_update",
        "sim_steps",
        "out_steps",
        "out_dir",
        "writer_func",
    )
    kwargs = {key: aux_data[key] for key in static_keys}
    return cls(*children, aux_data["dt"], **kwargs)
+
def solve(self, gravity: ArrayLike)

Non-JIT solve method.
This method runs the entire simulation for the defined number of steps.

Note
This is mainly used for debugging and might be removed in future versions or moved to the JIT solver.

Parameters: `gravity` (ArrayLike) — gravity present in the system, an array of shape `(1, ndim)` where `ndim` is the dimension of the simulation.
Returns: `dict` — a dictionary of `ArrayLike` arrays corresponding to all states of the simulation after completing all steps.
def solve(self, gravity: ArrayLike):
    """Non-JIT solve method.

    Runs the full simulation loop in plain Python, recording the
    particle state after every step.

    .. note::
        This is mainly used for debugging and might be removed in
        future versions or moved to the JIT solver.

    Parameters
    ----------
    gravity: ArrayLike
        Gravity present in the system, an array of shape `(1, ndim)`
        where `ndim` is the dimension of the simulation.

    Returns
    -------
    dict
        Arrays holding the particle state at every step of the run.
    """
    from collections import defaultdict

    from tqdm import tqdm  # type: ignore

    history = defaultdict(list)
    for step in tqdm(range(self.sim_steps)):
        scheme = self.mpm_scheme
        scheme.compute_nodal_kinematics()
        scheme.precompute_stress_strain()
        scheme.compute_forces(gravity, step)
        scheme.compute_particle_kinematics()
        scheme.postcompute_stress_strain()
        for pset in self.mesh.particles:
            history["position"].append(pset.loc)
            history["velocity"].append(pset.velocity)
            history["stress"].append(pset.stress[:, :2, 0])
            history["strain"].append(pset.strain[:, :2, 0])

    return {key: jnp.asarray(vals) for key, vals in history.items()}
+
def solve_jit(self, gravity: ArrayLike) ‑> dict

Solver method that runs the simulation.
This method runs the entire simulation for the defined number of steps.

Parameters: `gravity` (ArrayLike) — gravity present in the system, an array of shape `(1, ndim)` where `ndim` is the dimension of the simulation.
Returns: `dict` — a dictionary of `jax.numpy` arrays corresponding to the final state of the simulation after completing all steps.
def solve_jit(self, gravity: ArrayLike) -> dict:
    """Solver method that runs the simulation.

    Runs the full simulation as a `lax.fori_loop`, carrying the solver
    object itself as the loop state.

    Parameters
    ----------
    gravity: ArrayLike
        Gravity present in the system, an array of shape `(1, ndim)`
        where `ndim` is the dimension of the simulation.

    Returns
    -------
    dict
        A dictionary of `jax.numpy` arrays corresponding to the final
        state of the simulation after completing all steps.
    """

    def _write_state(solver, step):
        # Gather each tracked particle property across all particle
        # sets and hand the result to the host-side writer.
        arrays = {}
        for prop in solver.__particle_props:
            arrays[prop] = jnp.array(
                [
                    getattr(solver.mesh.particles[j], prop).squeeze()
                    for j in range(len(solver.mesh.particles))
                ]
            )
        solver._jax_writer(
            functools.partial(
                solver.writer_func,
                out_dir=solver.out_dir,
                max_steps=solver.sim_steps,
            ),
            (arrays, step),
        )

    def _step(step, solver):
        # One explicit MPM step; the loop carry is the solver itself.
        solver.mpm_scheme.compute_nodal_kinematics()
        solver.mpm_scheme.precompute_stress_strain()
        solver.mpm_scheme.compute_forces(gravity, step)
        solver.mpm_scheme.compute_particle_kinematics()
        solver.mpm_scheme.postcompute_stress_strain()
        if solver.writer_func is not None:
            # Emit output only every `out_steps` steps.
            lax.cond(
                step % solver.out_steps == 0,
                _write_state,
                lambda s, i: None,
                solver,
                step,
            )
        return solver

    final = lax.fori_loop(0, self.sim_steps, _step, self)
    gathered = {}
    for prop in final.__particle_props:
        gathered[prop] = jnp.array(
            [
                getattr(final.mesh.particles[j], prop)
                for j in range(len(final.mesh.particles))
            ]
        ).squeeze()
    return gathered
+
+def tree_flatten(self)
def tree_flatten(self):
    """Split the solver into pytree children (the mesh) and static aux data."""
    aux_data = dict(
        dt=self.dt,
        scheme=self.scheme,
        velocity_update=self.velocity_update,
        sim_steps=self.sim_steps,
        out_steps=self.out_steps,
        out_dir=self.out_dir,
        writer_func=self.writer_func,
    )
    return (self.mesh,), aux_data
diffmpm.utils

from jax.tree_util import tree_flatten, tree_unflatten
+
+
def _show_example(structured):
    """Round-trip *structured* through JAX pytree flatten/unflatten and print each stage."""
    flat, tree = tree_flatten(structured)
    unflattened = tree_unflatten(tree, flat)
    print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}")
diffmpm.writers

import abc
+import logging
+from pathlib import Path
+
+from typing import Tuple, Annotated, Any
+from jax.typing import ArrayLike
+import numpy as np
+
# Module-level logger for diffmpm.writers.
# NOTE(review): logging.getLogger(__name__) is the usual convention;
# __file__ makes the logger name the file's path — confirm this is intended.
logger = logging.getLogger(__file__)

__all__ = ["_Writer", "EmptyWriter", "NPZWriter"]
+
+
class _Writer(abc.ABC):
    """Abstract base class for output writers."""

    @abc.abstractmethod
    def write(self):
        """Write simulation output; concrete writers must override this."""
+
+
class EmptyWriter(_Writer):
    """No-op writer used when output should not be written."""

    def write(self, args, transforms, **kwargs):
        """Discard the given output payload."""
        return None
+
+
class NPZWriter(_Writer):
    """Writer to write output in `.npz` format."""

    def write(
        self,
        args: Tuple[
            Annotated[ArrayLike, "JAX arrays to be written"],
            Annotated[int, "step number of the simulation"],
        ],
        transforms: Any,
        **kwargs,
    ):
        """Writes the output arrays as `.npz` files.

        Parameters
        ----------
        args : tuple
            Pair `(arrays, step)`: a mapping of named arrays to save and
            the simulation step they belong to.
        transforms : Any
            Unused; present for host-callback compatibility.
        **kwargs
            Must contain `out_dir` (output directory path) and
            `max_steps` (total simulation steps, used for zero-padding).
        """
        arrays, step = args
        # Zero-pad the step number to the width of the largest step so
        # the generated filenames sort lexicographically.
        max_digits = int(np.log10(kwargs["max_steps"])) + 1
        fileno = str(step).zfill(max_digits)
        filepath = Path(kwargs["out_dir"]).joinpath(f"particles_{fileno}.npz")
        # exist_ok avoids a race between an is-dir check and the mkdir.
        filepath.parent.mkdir(parents=True, exist_ok=True)
        np.savez(filepath, **arrays)
        logger.info(f"Saved particle data for step {step} at {filepath}")
+
+class EmptyWriter
+Empty writer used when output is not to be written.
class EmptyWriter(_Writer):
    """No-op writer used when output should not be written."""

    def write(self, args, transforms, **kwargs):
        """Discard the given output payload."""
        return None
+
+def write(self, args, transforms, **kwargs)
+Empty function.
def write(self, args, transforms, **kwargs):
    """Discard the given output payload (no-op)."""
    return None
+
+class NPZWriter
+Writer to write output in .npz format.
class NPZWriter(_Writer):
    """Writer to write output in `.npz` format."""

    def write(
        self,
        args: Tuple[
            Annotated[ArrayLike, "JAX arrays to be written"],
            Annotated[int, "step number of the simulation"],
        ],
        transforms: Any,
        **kwargs,
    ):
        """Writes the output arrays as `.npz` files.

        Parameters
        ----------
        args : tuple
            Pair `(arrays, step)`: a mapping of named arrays to save and
            the simulation step they belong to.
        transforms : Any
            Unused; present for host-callback compatibility.
        **kwargs
            Must contain `out_dir` (output directory path) and
            `max_steps` (total simulation steps, used for zero-padding).
        """
        arrays, step = args
        # Zero-pad the step number to the width of the largest step so
        # the generated filenames sort lexicographically.
        max_digits = int(np.log10(kwargs["max_steps"])) + 1
        fileno = str(step).zfill(max_digits)
        filepath = Path(kwargs["out_dir"]).joinpath(f"particles_{fileno}.npz")
        # exist_ok avoids a race between an is-dir check and the mkdir.
        filepath.parent.mkdir(parents=True, exist_ok=True)
        np.savez(filepath, **arrays)
        logger.info(f"Saved particle data for step {step} at {filepath}")
+
+def write(self, args: Tuple[Annotated[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex], 'JAX arrays to be written'], Annotated[int, 'step number of the simulation']], transforms: Any, **kwargs)
+Writes the output arrays as .npz files.
def write(
    self,
    args: Tuple[
        Annotated[ArrayLike, "JAX arrays to be written"],
        Annotated[int, "step number of the simulation"],
    ],
    transforms: Any,
    **kwargs,
):
    """Writes the output arrays as `.npz` files.

    Parameters
    ----------
    args : tuple
        Pair `(arrays, step)`: a mapping of named arrays to save and
        the simulation step they belong to.
    transforms : Any
        Unused; present for host-callback compatibility.
    **kwargs
        Must contain `out_dir` (output directory path) and `max_steps`
        (total simulation steps, used for zero-padding).
    """
    arrays, step = args
    # Zero-pad the step number to the width of the largest step so the
    # generated filenames sort lexicographically.
    max_digits = int(np.log10(kwargs["max_steps"])) + 1
    fileno = str(step).zfill(max_digits)
    filepath = Path(kwargs["out_dir"]).joinpath(f"particles_{fileno}.npz")
    # exist_ok avoids a race between an is-dir check and the mkdir.
    filepath.parent.mkdir(parents=True, exist_ok=True)
    np.savez(filepath, **arrays)
    logger.info(f"Saved particle data for step {step} at {filepath}")
+
+class _Writer
+Base writer class.
class _Writer(abc.ABC):
    """Abstract base class for output writers."""

    @abc.abstractmethod
    def write(self):
        """Write simulation output; concrete writers must override this."""
+
+def write(self)
@abc.abstractmethod
def write(self):
    """Write simulation output; concrete writers must override this."""
+