-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
270 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from sympy import IndexedBase, Symbol, Add, Indexed, latex | ||
from sympy import Rational, factorial, linsolve, Matrix | ||
|
||
|
||
class SymbolicMesh: | ||
""" | ||
Represents the mesh on which to evaluate finite difference approximations. | ||
""" | ||
|
||
def __init__(self, coord, equidistant=True): | ||
"""Constructor. | ||
Parameters | ||
---------- | ||
coord: str | ||
A comma-separated string of coordinate names for the mesh, | ||
e.g. "x,y" or simply "x" | ||
equidistant: bool | ||
Flag indicating whether the mesh is equidistant. | ||
""" | ||
assert isinstance(coord, str) | ||
|
||
self.equidistant = equidistant | ||
self._coord_names = [n.replace(" ", "") for n in coord.split(",")] | ||
self._coord = [IndexedBase(name) for name in self._coord_names] | ||
|
||
@property | ||
def ndims(self): | ||
"""The number of dimensions of the mesh.""" | ||
return len(self._coord) | ||
|
||
@property | ||
def coord(self): | ||
""" | ||
Returns a tuple with the symbols for the coordinates. | ||
""" | ||
return self._coord | ||
|
||
@property | ||
def spacing(self): | ||
""" | ||
Returns a tuple with the spacing of the mesh along all axes. | ||
Only makes sense for equidistant grid. Raises exception in | ||
case of non-equidistant grids. | ||
""" | ||
if self.equidistant: | ||
spacings = tuple(Symbol(f"\\Delta {x}") for x in self._coord_names) | ||
return spacings | ||
raise Exception("Non-equidistant mesh does not have spacing property.") | ||
|
||
@staticmethod | ||
def create_symbol(name): | ||
""" | ||
Creates a *sympy* symbol of a given name which can carry as many | ||
indices as the mesh has dimensions. | ||
Parameters | ||
---------- | ||
name: str | ||
The name of the meshed symbol. | ||
Returns | ||
------- | ||
An index-carrying *sympy* symbol (IndexedBase). | ||
""" | ||
return IndexedBase(name) | ||
|
||
|
||
class SymbolicDiff: | ||
""" | ||
A symbolic representation of the finite difference approximation | ||
of a partial derivative. Based on *sympy*. | ||
""" | ||
|
||
def __init__(self, mesh, axis=0, degree=1): | ||
"""Constructor | ||
Parameters | ||
---------- | ||
mesh: SymbolicMesh | ||
The symbolic grid on which to evaluate the derivative. | ||
axis: int | ||
The index of the axis with respect to which to differentiate. | ||
degree: int > 0 | ||
The degree of the partial derivative. | ||
""" | ||
self.mesh = mesh | ||
self.axis = axis | ||
self.degree = degree | ||
|
||
def __call__(self, f, at, offsets): | ||
if not isinstance(at, tuple) and not isinstance(at, list): | ||
at = [at] | ||
|
||
if self.mesh.ndims != len(at): | ||
raise ValueError("Index tuple must match the number of dimensions!") | ||
|
||
coefs = self._compute_coefficients(f, at, offsets) | ||
terms = [] | ||
for coef, off in zip(coefs, offsets): | ||
inds = list(at) | ||
inds[self.axis] += off | ||
inds = tuple(inds) | ||
terms.append(coef * f[inds]) | ||
|
||
return Add(*terms).simplify() | ||
|
||
def _compute_coefficients(self, f, at, offsets): | ||
|
||
n = len(offsets) | ||
# the first row always contains 1s: | ||
matrix = [[1] * n] | ||
|
||
def spac(off): | ||
"""A helper function to get the spacing between grid points.""" | ||
if self.mesh.equidistant: | ||
h = self.mesh.spacing[self.axis] | ||
else: | ||
x = self.mesh.coord[self.axis] | ||
h = x[at[self.axis] + off] - x[at[self.axis]] | ||
return h | ||
|
||
# build up the matrix incrementally: | ||
for i in range(1, n): | ||
ifac = Rational(1, factorial(i)) | ||
row = [ifac * (off * spac(off)) ** i for off in offsets] | ||
matrix.append(row) | ||
|
||
# only the entry corresponding to the requested derivative degree | ||
# is non-zero: | ||
rhs = [0] * n | ||
rhs[self.degree] = 1 | ||
|
||
# solve the equation system | ||
matrix = Matrix(matrix) | ||
rhs = Matrix(rhs) | ||
sol = linsolve((matrix, rhs)) | ||
return list(sol)[0].simplify() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import unittest | ||
|
||
from sympy import IndexedBase, Symbol, Expr, Eq, symbols, latex | ||
|
||
from findiff.symbolic import SymbolicMesh, SymbolicDiff | ||
|
||
|
||
class TestSymbolicMesh(unittest.TestCase): | ||
|
||
def test_parse_symbolic_mesh(self): | ||
# 1D | ||
mesh = SymbolicMesh(coord="x", equidistant=True) | ||
x, = mesh.coord | ||
dx, = mesh.spacing | ||
|
||
self.assertEqual(IndexedBase, type(x)) | ||
self.assertEqual(Symbol, type(dx)) | ||
|
||
# 2D | ||
mesh = SymbolicMesh(coord="x,y", equidistant=True) | ||
x, y = mesh.coord | ||
dx, dy = mesh.spacing | ||
|
||
self.assertEqual(IndexedBase, type(x)) | ||
self.assertEqual(IndexedBase, type(y)) | ||
self.assertEqual(Symbol, type(dx)) | ||
self.assertEqual(Symbol, type(dy)) | ||
|
||
# ignores whitespace | ||
mesh = SymbolicMesh(coord="x, y", equidistant=True) | ||
x, y = mesh.coord | ||
|
||
self.assertEqual("x", str(x)) | ||
self.assertEqual("y", str(y)) | ||
|
||
def test_create_symbol(self): | ||
# defaults | ||
mesh = SymbolicMesh(coord="x", equidistant=True) | ||
actual = mesh.create_symbol("u") | ||
expected = IndexedBase("u") | ||
self.assertEqual(latex(actual), latex(expected)) | ||
|
||
# both indices down | ||
mesh = SymbolicMesh(coord="x,y", equidistant=True) | ||
n, m = symbols("n, m") | ||
actual = latex(mesh.create_symbol("u")[n, m]) | ||
#expected = "u{}_{n}{}_{m}" | ||
expected = "{u}_{n,m}" | ||
self.assertEqual(latex(actual), latex(expected)) | ||
|
||
# both indices up | ||
#mesh = SymbolicMesh(coord="x,y", equidistant=True) | ||
#n, m = symbols("n, m") | ||
#u = mesh.create_symbol("u", pos="u,u") | ||
#actual = latex(u[n, m]) | ||
#expected = "u{}^{n}{}^{m}" | ||
#self.assertEqual(actual, expected) | ||
|
||
|
||
class TestDiff(unittest.TestCase): | ||
|
||
def test_init(self): | ||
mesh = SymbolicMesh("x") | ||
d = SymbolicDiff(mesh) | ||
|
||
self.assertEqual(d.axis, 0) | ||
self.assertEqual(d.degree, 1) | ||
self.assertEqual(id(mesh), id(d.mesh)) | ||
|
||
def test_call(self): | ||
# 1D | ||
mesh = SymbolicMesh("x") | ||
u = mesh.create_symbol("u") | ||
d = SymbolicDiff(mesh) | ||
n = Symbol("n") | ||
|
||
actual = d(u, at=n, offsets=[-1, 0, 1]) | ||
|
||
expected = (u[n + 1] - u[n - 1]) / (2 * mesh.spacing[0]) | ||
|
||
self.assertEqual( | ||
0, (expected - actual).simplify() | ||
) | ||
|
||
# 2D | ||
mesh = SymbolicMesh("x, y") | ||
u = mesh.create_symbol("u") | ||
d = SymbolicDiff(mesh, axis=1) | ||
n, m = symbols("n, m") | ||
|
||
actual = d(u, at=(n, m), offsets=[-1, 0, 1]) | ||
|
||
expected = (u[n, m + 1] - u[n, m - 1]) / (2 * mesh.spacing[1]) | ||
|
||
self.assertEqual( | ||
0, (expected - actual).simplify() | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |