Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to Mesh.element_finder #667

Merged
merged 26 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ab23b0e
Optimize the case where the number of input points is large
kinnala Jul 13, 2021
096629a
Remove unused variable
kinnala Jul 13, 2021
bc02ee9
Support only more recent gmsh in the example
kinnala Jul 13, 2021
f49c93d
Simplify the implementation
kinnala Jul 13, 2021
5e976bd
Fix case where the number of elements is smaller than 5
kinnala Jul 13, 2021
e8517b5
Wrap line
kinnala Jul 13, 2021
d073eb3
Optimize MeshTet1.element_finder for large number of points
kinnala Jul 13, 2021
68f91ce
Add tests for MeshQuad interpolation
kinnala Jul 14, 2021
ccc06b7
Compare also the values in the test
kinnala Jul 14, 2021
b1aa748
Add tests for quadratic elements and MeshHex
kinnala Jul 14, 2021
99641b0
Add a changelog entry
kinnala Jul 14, 2021
592208c
Compare the number of points and the number of elements
kinnala Jul 16, 2021
f696dc1
Fall back to exhaustive search if the original search fails
kinnala Jul 16, 2021
fb719be
Run several rounds of randomized tests
kinnala Jul 16, 2021
1973693
Add a changelog entry
kinnala Jul 16, 2021
5e5cc27
Use the candidate reduction strategy from #668
kinnala Jul 20, 2021
7fdbe88
Remove unused imports
kinnala Jul 20, 2021
184203e
Remove ncandidates parameter
kinnala Jul 21, 2021
407fe18
Small fixes
kinnala Jul 21, 2021
25712e8
Take unique elements always
kinnala Jul 22, 2021
018a738
Fix class name
kinnala Jul 22, 2021
bf288e2
Add a warning if the correct element is not found
kinnala Jul 23, 2021
e506151
Search all elements if the correct one is not found
kinnala Jul 23, 2021
53a7ce8
Raise ValueError if the element is not found
kinnala Jul 23, 2021
9f6a830
Raise exception also for MeshLine1.element_finder
kinnala Jul 23, 2021
287bd16
Update changelog
kinnala Jul 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ with respect to documented and/or tested features.
- Added: `ElementTriCCR` and `ElementTetCCR`, conforming Crouzeix-Raviart finite elements
- Fixed: `Mesh.mirrored` returned a wrong mesh when a point other than the origin was used
- Fixed: `MeshLine` constructor accepted only NumPy arrays and not plain Python lists
- Fixed: `Mesh.element_finder` (and `CellBasis.probes`, `CellBasis.interpolator`) was not working properly for a small number of elements (<5) or a large number of input points (>1000)
- Fixed: `MeshTet` and `MeshTri.element_finder` is are now more robust against degenerate elements
- Fixed: `Mesh.element_finder` (and `CellBasis.probes`, `CellBasis.interpolator`) raises exception if the query point is outside of the domain

### [3.1.0] - 2021-06-18

Expand Down
24 changes: 4 additions & 20 deletions docs/examples/ex28.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,7 @@
import pygmsh


if version.parse(pygmsh.__version__) < version.parse('7.0.0'):
class NullContextManager():
def __enter__(self):
return None
def __exit__(self, *args):
pass
geometrycontext = NullContextManager()
else:
geometrycontext = pygmsh.geo.Geometry()
geometrycontext = pygmsh.geo.Geometry()

halfheight = 1.
length = 10.
Expand All @@ -110,17 +102,12 @@ def __exit__(self, *args):
def make_mesh(halfheight: float, # mm
length: float,
thickness: float) -> MeshTri:
with geometrycontext as g:
if version.parse(pygmsh.__version__) < version.parse('7.0.0'):
geom = pygmsh.built_in.Geometry()
geom.add_curve_loop = geom.add_line_loop
else:
geom = g
with geometrycontext as geom:

points = []
lines = []

lcar = halfheight / 2**2
lcar = halfheight / 2 ** 2

for xy in [(0., halfheight),
(0., -halfheight),
Expand Down Expand Up @@ -155,10 +142,7 @@ def make_mesh(halfheight: float, # mm
geom.add_physical(geom.add_plane_surface(geom.add_curve_loop(
[*lines[-3:], -lines[1]])), 'solid')

if version.parse(pygmsh.__version__) < version.parse('7.0.0'):
return from_meshio(pygmsh.generate_mesh(geom, dim=2))
else:
return from_meshio(geom.generate_mesh(dim=2))
return from_meshio(geom.generate_mesh(dim=2))

mesh = from_file(Path(__file__).parent / 'meshes' / 'ex28.json')
element = ElementTriP1()
Expand Down
8 changes: 6 additions & 2 deletions skfem/mesh/mesh_line_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,19 @@ def param(self):
return np.max(np.abs(self.p[0, self.t[1]] - self.p[0, self.t[0]]))

def element_finder(self, mapping=None):

ix = np.argsort(self.p[0])
maxt = self.t[np.argmax(self.p[0, self.t], 0),
np.arange(self.t.shape[1])]

def finder(x):
xin = x.copy() # bring endpoint inside for np.digitize
xin[x == self.p[0, ix[-1]]] = self.p[0, ix[-2:]].mean()
return np.nonzero(ix[np.digitize(xin, self.p[0, ix])][:, None]
== maxt)[1]
elems = np.nonzero(ix[np.digitize(xin, self.p[0, ix])][:, None]
== maxt)[1]
if len(elems) < len(x):
raise ValueError("Point is outside of the mesh.")
return elems

return finder

Expand Down
36 changes: 26 additions & 10 deletions skfem/mesh/mesh_tet_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,34 @@ def element_finder(self, mapping=None):
if mapping is None:
mapping = self._mapping()

tree = cKDTree(np.mean(self.p[:, self.t], axis=1).T)
if not hasattr(self, '_cached_tree'):
self._cached_tree = cKDTree(np.mean(self.p[:, self.t], axis=1).T)

tree = self._cached_tree
nelems = self.t.shape[1]

def finder(x, y, z, _search_all=False):

if not _search_all:
ix = tree.query(np.array([x, y, z]).T,
min(10, nelems))[1].flatten()
_, ix_ind = np.unique(ix, return_index=True)
ix = ix[np.sort(ix_ind)]
else:
ix = np.arange(nelems, dtype=np.int64)

def finder(x, y, z):
ix = tree.query(np.array([x, y, z]).T, 5)[1].flatten()
X = mapping.invF(np.array([x, y, z])[:, None], ix)
inside = (
(X[0] >= 0) *
(X[1] >= 0) *
(X[2] >= 0) *
(1 - X[0] - X[1] - X[2] >= 0)
)
return np.array([ix[np.argmax(inside, axis=0)]]).flatten()
inside = ((X[0] >= 0) *
(X[1] >= 0) *
(X[2] >= 0) *
(1 - X[0] - X[1] - X[2] >= 0))

if not inside.max(axis=0).all():
if _search_all:
raise ValueError("Point is outside of the mesh.")
return finder(x, y, z, _search_all=True)

return np.array([ix[inside.argmax(axis=0)]]).flatten()

return finder

Expand Down
34 changes: 25 additions & 9 deletions skfem/mesh/mesh_tri_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,16 +325,32 @@ def element_finder(self, mapping=None):
if mapping is None:
mapping = self._mapping()

tree = cKDTree(np.mean(self.p[:, self.t], axis=1).T)
if not hasattr(self, '_cached_tree'):
self._cached_tree = cKDTree(np.mean(self.p[:, self.t], axis=1).T)

tree = self._cached_tree
nelems = self.t.shape[1]

def finder(x, y, _search_all=False):

if not _search_all:
ix = tree.query(np.array([x, y]).T,
min(5, nelems))[1].flatten()
_, ix_ind = np.unique(ix, return_index=True)
ix = ix[np.sort(ix_ind)]
else:
ix = np.arange(nelems, dtype=np.int64)

def finder(x, y):
ix = tree.query(np.array([x, y]).T, 5)[1].flatten()
X = mapping.invF(np.array([x, y])[:, None], ix)
inside = (
(X[0] >= 0) *
(X[1] >= 0) *
(1 - X[0] - X[1] >= 0)
)
return np.array([ix[np.argmax(inside, axis=0)]]).flatten()
inside = ((X[0] >= 0) *
(X[1] >= 0) *
(1 - X[0] - X[1] >= 0))

if not inside.max(axis=0).all():
if _search_all:
raise ValueError("Point is outside of the mesh.")
return finder(x, y, _search_all=True)

return np.array([ix[inside.argmax(axis=0)]]).flatten()

return finder
47 changes: 39 additions & 8 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from skfem import BilinearForm, asm, solve, condense, projection
from skfem.mesh import MeshTri, MeshTet, MeshHex, MeshQuad, MeshLine
from skfem.assembly import InteriorBasis, FacetBasis, Dofs, Functional
from skfem.assembly import CellBasis, FacetBasis, Dofs, Functional
from skfem.element import (ElementVectorH1, ElementTriP2, ElementTriP1,
ElementTetP2, ElementHexS2, ElementHex2,
ElementQuad2, ElementLineP2, ElementTriP0,
Expand All @@ -28,7 +28,7 @@ def runTest(self):

e = ElementVectorH1(ElementTriP2()) * ElementTriP1()

basis = InteriorBasis(m, e)
basis = CellBasis(m, e)

@BilinearForm
def bilinf(u, p, v, q, w):
Expand Down Expand Up @@ -93,7 +93,7 @@ def runTest(self):

m = self.mesh_type().refined(2)

basis = InteriorBasis(m, self.elem_type())
basis = CellBasis(m, self.elem_type())

for fun in [lambda x: x[0] == 0,
lambda x: x[0] == 1,
Expand Down Expand Up @@ -128,7 +128,7 @@ class TestInterpolatorTet(TestCase):

def runTest(self):
m = self.mesh_type().refined(self.nrefs)
basis = InteriorBasis(m, self.element_type())
basis = CellBasis(m, self.element_type())
x = projection(lambda x: x[0] ** 2, basis)
fun = basis.interpolator(x)
X = np.linspace(0, 1, 10)
Expand Down Expand Up @@ -187,7 +187,38 @@ def runTest(self):
with self.assertRaises(ValueError):
m = MeshTri()
e = ElementTetP2()
basis = InteriorBasis(m, e)
basis = CellBasis(m, e)


@pytest.mark.parametrize(
"mtype,e,nrefs,npoints",
[
(MeshTri, ElementTriP1(), 0, 10),
(MeshTri, ElementTriP2(), 1, 10),
(MeshTri, ElementTriP1(), 5, 10),
(MeshTri, ElementTriP1(), 1, 3e5),
(MeshTet, ElementTetP2(), 1, 10),
(MeshTet, ElementTetP1(), 5, 10),
(MeshTet, ElementTetP1(), 1, 3e5),
(MeshQuad, ElementQuad1(), 1, 10),
(MeshQuad, ElementQuad1(), 1, 3e5),
(MeshHex, ElementHex1(), 1, 1e5),
]
)
def test_interpolator_probes(mtype, e, nrefs, npoints):

m = mtype().refined(nrefs)

np.random.seed(0)
X = np.random.rand(m.p.shape[0], int(npoints))

basis = CellBasis(m, e)

y = projection(lambda x: x[0] ** 2, basis)

assert_allclose(basis.probes(X) @ y, basis.interpolator(y)(X))
atol = 1e-1 if nrefs <= 1 else 1e-3
assert_allclose(basis.probes(X) @ y, X[0] ** 2, atol=atol)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -215,7 +246,7 @@ def test_trace(mtype, e1, e2):
# use the boundary where last coordinate is zero
basis = FacetBasis(m, e1,
facets=m.facets_satisfying(lambda x: x[x.shape[0] - 1] == 0.0))
xfun = projection(lambda x: x[0], InteriorBasis(m, e1))
xfun = projection(lambda x: x[0], CellBasis(m, e1))
nbasis, y = basis.trace(xfun, lambda p: p[0:(p.shape[0] - 1)], target_elem=e2)

@Functional
Expand All @@ -235,8 +266,8 @@ def test_point_source(etype):
from skfem.models.poisson import laplace

mesh = MeshLine().refined()
basis = InteriorBasis(mesh, etype())
basis = CellBasis(mesh, etype())
source = np.array([0.7])
u = solve(*condense(asm(laplace, basis), basis.point_source(source), D=basis.find_dofs()))
exact = np.stack([(1 - source) * mesh.p, (1 - mesh.p) * source]).min(0)
assert_almost_equal(u[basis.nodal_dofs], exact)
assert_almost_equal(u[basis.nodal_dofs], exact)
34 changes: 32 additions & 2 deletions tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from pathlib import Path

import numpy as np
from numpy.testing import assert_array_equal
import pytest
from scipy.spatial import Delaunay
from numpy.testing import assert_array_equal

from skfem.mesh import Mesh, MeshHex, MeshLine, MeshQuad, MeshTet, MeshTri, MeshTri2, MeshQuad2, MeshTet2, MeshHex2
from skfem.mesh import (Mesh, MeshHex, MeshLine, MeshQuad, MeshTet, MeshTri,
MeshTri2, MeshQuad2, MeshTet2, MeshHex2)
from skfem.io.meshio import to_meshio, from_meshio


Expand Down Expand Up @@ -241,6 +243,34 @@ def runTest(self):
self.assertEqual(finder(np.array([0.001]))[0], 0)



@pytest.mark.parametrize(
"m,seed",
[
(MeshTri(), 0),
(MeshTri(), 1),
(MeshTri(), 2),
(MeshTet(), 0),
(MeshTet(), 1),
(MeshTet(), 2),
(MeshTet(), 10),
]
)
def test_finder_simplex(m, seed):

np.random.seed(seed)
points = np.hstack((m.p, np.random.rand(m.p.shape[0], 100)))
tri = Delaunay(points.T)
M = type(m)(points, tri.simplices.T)
finder = M.element_finder()

query_pts = np.random.rand(m.p.shape[0], 500)
assert_array_equal(
tri.find_simplex(query_pts.T),
finder(*query_pts),
)


@pytest.mark.parametrize(
"m",
[
Expand Down