Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

examples: Boundary reconstruction #1926

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,14 @@ class Injection(UnevaluatedSparseOperation):
Evaluates to a list of Eq objects.
"""

def __new__(cls, field, expr, offset, interpolator, callback):
def __new__(cls, field, expr, offset, increment, interpolator, callback):
obj = super().__new__(cls, interpolator, callback)

# TODO: unused now, but will be necessary to compute the adjoint
obj.field = field
obj.expr = expr
obj.offset = offset
obj.increment = increment

return obj

Expand Down Expand Up @@ -255,7 +256,7 @@ def callback():

return Interpolation(expr, offset, increment, self_subs, self, callback)

def inject(self, field, expr, offset=0):
def inject(self, field, expr, offset=0, increment=True):
"""
Generate equations injecting an arbitrary expression into a field.

Expand All @@ -267,6 +268,8 @@ def inject(self, field, expr, offset=0):
Injected expression.
offset : int, optional
Additional offset from the boundary.
increment: bool, optional
If True, generate increments (Inc) rather than assignments (Eq).
"""
def callback():
# Derivatives must be evaluated before the introduction of indirect accesses
Expand All @@ -285,13 +288,18 @@ def callback():
field_offset=field_offset)

# Substitute coordinate base symbols into the interpolation coefficients
eqns = [Inc(field.xreplace(vsub), _expr.xreplace(vsub) * b,
implicit_dims=self.sfunction.dimensions)
for b, vsub in zip(self._interpolation_coeffs, idx_subs)]
if increment:
eqns = [Inc(field.xreplace(vsub), _expr.xreplace(vsub) * b,
implicit_dims=self.sfunction.dimensions)
for b, vsub in zip(self._interpolation_coeffs, idx_subs)]
else:
eqns = [Eq(field.xreplace(idx_subs[0]), _expr.xreplace(idx_subs[0]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are you picking idx_subs[0]? should you not traverse the whole list of idx_subs like in the increment case?

Could you do something like:

if increment:
   cls: Inc
else:
   cls: Eq
eqns = [cls(field.xreplace(....... .....) for b, vsub in zip(....)]

?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because I thought in this case very specifically to apply boundary conditions, not source injection, for example. The ` increment=False option as I suggested should not be used for off-grid points.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this is true for your case, the generic case is still a standard potentially off the grid point injection without increment. So there is two option

  • Use the injection as is knowing the coefficient will still be zero for the interpolation if the point is on the grid
  • Define somehow an "on-the-grid" sparse function that ignores the interpolation.

The first case is definitely easier and should still work for your case and would boil down to what @FabioLuporini wrote earlier

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the first option is not possible for the case increment = False. Because even if the point is exactly on the grid, interpolation coefficients equal to zero will vanish the fields at points that should not be vanished. Which is not the case for injecting the source (when increment = True), for example, since accumulating zeros does not change the field at the other points involved in the interpolation.

image

The way I see it, we are left with the second option that @mloubout suggested, or use subdomains that are under development in another PR if I'm not mistaken.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed good point I missed that gotta think about it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mloubout prodding

Copy link
Contributor

@mloubout mloubout Jun 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting back here since been looking into sparse, I think that the "best" way to go for increment=Flase is to do what interpolation does i.e

  • Define sum and do the standard sum += inject(expr)
  • assign sum afterward field = sum

This would cover the off-the-grid case as well

implicit_dims=self.sfunction.dimensions)]


return temps + eqns

return Injection(field, expr, offset, self, callback)
return Injection(field, expr, offset, increment, self, callback)


class PrecomputedInterpolator(GenericInterpolator):
Expand Down Expand Up @@ -388,3 +396,4 @@ def callback():
return [Eq(_field, _field + rhs.subs(dim_subs))]

return Injection(field, expr, offset, self, callback)

6 changes: 4 additions & 2 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ def interpolate(self, expr, offset=0, u_t=None, p_t=None, increment=False):
increment=increment,
self_subs=subs)

def inject(self, field, expr, offset=0, u_t=None, p_t=None):
def inject(self, field, expr, offset=0, u_t=None, p_t=None, increment=True):
"""
Generate equations injecting an arbitrary expression into a field.

Expand All @@ -864,14 +864,16 @@ def inject(self, field, expr, offset=0, u_t=None, p_t=None):
Time index at which the interpolation is performed.
p_t : expr-like, optional
Time index at which the result of the interpolation is stored.
increment: bool, optional
If True, generate increments (Inc) rather than assignments (Eq).
"""
# Apply optional time symbol substitutions to field and expr
if u_t is not None:
field = field.subs({field.time_dim: u_t})
if p_t is not None:
expr = expr.subs({self.time_dim: p_t})

return super(SparseTimeFunction, self).inject(field, expr, offset=offset)
return super(SparseTimeFunction, self).inject(field, expr, offset=offset, increment=increment)

# Pickling support
_pickle_kwargs = AbstractSparseTimeFunction._pickle_kwargs +\
Expand Down
1,069 changes: 1,069 additions & 0 deletions examples/seismic/tutorials/16_source_wavefield_reconstruction.ipynb

Large diffs are not rendered by default.

Binary file added examples/seismic/tutorials/boundary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
39 changes: 38 additions & 1 deletion tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from math import sin, floor
from math import sin, floor, prod

import numpy as np
import pytest
Expand Down Expand Up @@ -53,6 +53,18 @@ def time_points(grid, ranges, npoints, name='points', nt=10):
return points


def time_grid_points(grid, name='points', nt=10):
"""Create a SparseTimeFunction field with coordinates
filled in by all grid points"""
npoints = prod(grid.shape)
a = SparseTimeFunction(name=name, grid=grid, npoint=npoints, nt=nt)
dims = tuple([np.linspace(0., 1., d) for d in grid.shape])
for i in range(len(grid.shape)):
a.coordinates.data[:,i] = np.meshgrid(*dims)[i].flatten()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can probably compute meshgrid only once before the for loop


return a


def a(shape=(11, 11)):
grid = Grid(shape=shape)
a = Function(name='a', grid=grid)
Expand Down Expand Up @@ -417,6 +429,31 @@ def test_inject_time_shift(shape, coords, result, npoints=19):
assert np.allclose(a.data[indices], result, rtol=1.e-5)


@pytest.mark.parametrize('shape, result, increment', [
((10, 10), 1., False),
((10, 10), 5., True),
((10, 10, 10), 1., False),
((10, 10, 10), 5., True)
])
def test_inject_time_increment(shape, result, increment):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for attaching tests to this PR!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth having another test to check that increment=False specified on SparseFunction with non-grid-aligned points is caught and handled appropriately?

"""Test the increment option in the SparseTimeFunction's
injection method. The increment=False option is
expected to work only at points located on the grid,
where no interpolation needed.
"""
a = unit_box_time(shape=shape)
a.data[:] = 0.
p = time_grid_points(a.grid, name='points', nt=10)

expr = p.inject(a, Float(1.), increment=increment)

Operator(expr)(a=a)

assert np.allclose(a.data, result*np.ones(a.grid.shape), rtol=1.e-5)




@pytest.mark.parametrize('shape, coords, result', [
((11, 11), [(.05, .95), (.45, .45)], 1.),
((11, 11, 11), [(.05, .95), (.45, .45), (.45, .45)], 0.5)
Expand Down