Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,34 @@
# DiffMPM
# Differentiable Material Point Method (DiffMPM)

MPM simulations are applied in various fields such as computer graphics, geotechnical engineering, computational mechanics and more. `diffmpm` is a differentiable MPM simulation library written entirely in JAX which means it also has all the niceties that come with JAX. It is a highly parallel, Just-In-Time compiled code that can run on CPUs, GPUs or TPUs. It aims to be a fast solver that can be used in various problems like optimization and inverse problems. Having a differentiable MPM simulation opens up several advantages -
- **Efficient Gradient-based Optimization:** Since the entire simulation model is differentiable, it can be used in conjunction with various gradient-based optimization techniques such as stochastic gradient descent (SGD), ADAM etc.
- **Inverse Problems:** It also enables us to solve inverse problems to determine material properties by formulating an inverse problem as an optimization task.
- **Integration with Deep Learning:** It can be seamlessly integrated with other Neural Network models to enable training physics-informed neural networks.

## Installation
`diffmpm` can be installed directly from PyPI using `pip`

``` shell
pip install diffmpm
```

#### ToDo
Add separate installation commands for CPU/GPU.

## Usage
Once installed, `diffmpm` can be used as a CLI tool or can be imported as a library in Python. Example input files can be found in the `benchmarks/` directory.

```
Usage: mpm [OPTIONS]

CLI utility for DiffMPM.

Options:
-f, --file TEXT Input TOML file [required]
--version Show the version and exit.
--help Show this message and exit.
```

Further documentation about the input file can be found in the documentation _[INSERT LINK HERE]_. `diffmpm` can write the output to various file types like `.npz`, `.vtk` etc. that can then be used to visualize the output of the simulations.

## Examples
2 changes: 2 additions & 0 deletions benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from pathlib import Path

import jax.numpy as jnp

from diffmpm import MPM


Expand Down
2 changes: 2 additions & 0 deletions benchmarks/2d/uniaxial_particle_traction/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from pathlib import Path

import jax.numpy as jnp

from diffmpm import MPM


Expand Down
2 changes: 2 additions & 0 deletions benchmarks/2d/uniaxial_stress/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from pathlib import Path

import jax.numpy as jnp

from diffmpm import MPM


Expand Down
2 changes: 1 addition & 1 deletion diffmpm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, filepath):
raise ValueError("Wrong type of solver specified.")

def solve(self):
"""Solve the MPM simulation."""
"""Solve the MPM simulation using JIT solver."""
arrays = self.solver.solve_jit(
self._config.parsed_config["external_loading"]["gravity"],
)
Expand Down
3 changes: 2 additions & 1 deletion diffmpm/cli/mpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from diffmpm import MPM


@click.command()
@click.command() # type: ignore
@click.option(
"-f", "--file", "filepath", required=True, type=str, help="Input TOML file"
)
@click.version_option(package_name="diffmpm")
def mpm(filepath):
"""CLI utility for DiffMPM."""
solver = MPM(filepath)
solver.solve()
22 changes: 16 additions & 6 deletions diffmpm/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@

@register_pytree_node_class
class Constraint:
def __init__(self, dir, velocity):
"""Generic velocity constraints to apply on nodes or particles."""

def __init__(self, dir: int, velocity: float):
"""Contains 2 govering parameters.

Attributes
----------
dir : int
Direction in which constraint is applied.
velocity : float
Constrained velocity to be applied.
"""
self.dir = dir
self.velocity = velocity

Expand All @@ -16,16 +27,15 @@ def tree_unflatten(cls, aux_data, children):
return cls(*aux_data)

def apply(self, obj, ids):
"""
Apply constraint values to the passed object.
"""Apply constraint values to the passed object.

Arguments
---------
Parameters
----------
obj : diffmpm.node.Nodes, diffmpm.particle.Particles
Object on which the constraint is applied
ids : array_like
The indices of the container `obj` on which the constraint
will be applied.
will be applied.
"""
obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity)
obj.momentum = obj.momentum.at[ids, :, self.dir].set(
Expand Down
Loading