Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to Mesh.element_finder #667

Merged
merged 26 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ab23b0e
Optimize the case where the number of input points is large
kinnala Jul 13, 2021
096629a
Remove unused variable
kinnala Jul 13, 2021
bc02ee9
Support only more recent gmsh in the example
kinnala Jul 13, 2021
f49c93d
Simplify the implementation
kinnala Jul 13, 2021
5e976bd
Fix case where the number of elements is smaller than 5
kinnala Jul 13, 2021
e8517b5
Wrap line
kinnala Jul 13, 2021
d073eb3
Optimize MeshTet1.element_finder for large number of points
kinnala Jul 13, 2021
68f91ce
Add tests for MeshQuad interpolation
kinnala Jul 14, 2021
ccc06b7
Compare also the values in the test
kinnala Jul 14, 2021
b1aa748
Add tests for quadratic elements and MeshHex
kinnala Jul 14, 2021
99641b0
Add a changelog entry
kinnala Jul 14, 2021
592208c
Compare the number of points and the number of elements
kinnala Jul 16, 2021
f696dc1
Fall back to exhaustive search if the original search fails
kinnala Jul 16, 2021
fb719be
Run several rounds of randomized tests
kinnala Jul 16, 2021
1973693
Add a changelog entry
kinnala Jul 16, 2021
5e5cc27
Use the candidate reduction strategy from #668
kinnala Jul 20, 2021
7fdbe88
Remove unused imports
kinnala Jul 20, 2021
184203e
Remove ncandidates parameter
kinnala Jul 21, 2021
407fe18
Small fixes
kinnala Jul 21, 2021
25712e8
Take unique elements always
kinnala Jul 22, 2021
018a738
Fix class name
kinnala Jul 22, 2021
bf288e2
Add a warning if the correct element is not found
kinnala Jul 23, 2021
e506151
Search all elements if the correct one is not found
kinnala Jul 23, 2021
53a7ce8
Raise ValueError if the element is not found
kinnala Jul 23, 2021
9f6a830
Raise exception also for MeshLine1.element_finder
kinnala Jul 23, 2021
287bd16
Update changelog
kinnala Jul 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ with respect to documented and/or tested features.
- Added: `ElementTriCCR` and `ElementTetCCR`, conforming Crouzeix-Raviart finite elements
- Fixed: `Mesh.mirrored` returned a wrong mesh when a point other than the origin was used
- Fixed: `MeshLine` constructor accepted only NumPy arrays and not plain Python lists
- Fixed: `Mesh.element_finder` (and `CellBasis.probes` and `CellBasis.interpolator`) was not working properly for a small number of elements (<5) or a large number of input points (>1000)
- Fixed: `Mesh.element_finder` is now more robust against degenerate triangles and tetrahedra
kinnala marked this conversation as resolved.
Show resolved Hide resolved

### [3.1.0] - 2021-06-18

Expand Down
21 changes: 3 additions & 18 deletions docs/examples/ex28.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,7 @@
import pygmsh


if version.parse(pygmsh.__version__) < version.parse('7.0.0'):
class NullContextManager():
def __enter__(self):
return None
def __exit__(self, *args):
pass
geometrycontext = NullContextManager()
else:
geometrycontext = pygmsh.geo.Geometry()
geometrycontext = pygmsh.geo.Geometry()

halfheight = 1.
length = 10.
Expand All @@ -111,11 +103,7 @@ def make_mesh(halfheight: float, # mm
length: float,
thickness: float) -> MeshTri:
with geometrycontext as g:
if version.parse(pygmsh.__version__) < version.parse('7.0.0'):
geom = pygmsh.built_in.Geometry()
geom.add_curve_loop = geom.add_line_loop
else:
geom = g
geom = g
kinnala marked this conversation as resolved.
Show resolved Hide resolved

points = []
lines = []
Expand Down Expand Up @@ -155,10 +143,7 @@ def make_mesh(halfheight: float, # mm
geom.add_physical(geom.add_plane_surface(geom.add_curve_loop(
[*lines[-3:], -lines[1]])), 'solid')

if version.parse(pygmsh.__version__) < version.parse('7.0.0'):
return from_meshio(pygmsh.generate_mesh(geom, dim=2))
else:
return from_meshio(geom.generate_mesh(dim=2))
return from_meshio(geom.generate_mesh(dim=2))

mesh = from_file(Path(__file__).parent / 'meshes' / 'ex28.json')
element = ElementTriP1()
Expand Down
30 changes: 20 additions & 10 deletions skfem/mesh/mesh_tet_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,28 @@ def element_finder(self, mapping=None):
if mapping is None:
mapping = self._mapping()

tree = cKDTree(np.mean(self.p[:, self.t], axis=1).T)
if not hasattr(self, '_cached_tree'):
self._cached_tree = cKDTree(np.mean(self.p[:, self.t], axis=1).T)

tree = self._cached_tree
nelems = self.t.shape[1]

def finder(x, y, z, ncandidates=8):

ix = tree.query(np.array([x, y, z]).T,
min(ncandidates, nelems))[1].flatten()
if len(ix) > nelems:
_, ix_ind = np.unique(ix, return_index=True)
ix = ix[np.sort(ix_ind)]

def finder(x, y, z):
ix = tree.query(np.array([x, y, z]).T, 5)[1].flatten()
X = mapping.invF(np.array([x, y, z])[:, None], ix)
inside = (
(X[0] >= 0) *
(X[1] >= 0) *
(X[2] >= 0) *
(1 - X[0] - X[1] - X[2] >= 0)
)
return np.array([ix[np.argmax(inside, axis=0)]]).flatten()
inside = ((X[0] >= 0) *
(X[1] >= 0) *
(X[2] >= 0) *
(1 - X[0] - X[1] - X[2] >= 0))
elems = np.argmax(inside, axis=0)

return np.array([ix[elems]]).flatten()

return finder

Expand Down
28 changes: 19 additions & 9 deletions skfem/mesh/mesh_tri_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,16 +325,26 @@ def element_finder(self, mapping=None):
if mapping is None:
mapping = self._mapping()

tree = cKDTree(np.mean(self.p[:, self.t], axis=1).T)
if not hasattr(self, '_cached_tree'):
self._cached_tree = cKDTree(np.mean(self.p[:, self.t], axis=1).T)

tree = self._cached_tree
nelems = self.t.shape[1]

def finder(x, y, ncandidates=5):

ix = tree.query(np.array([x, y]).T,
min(ncandidates, nelems))[1].flatten()
if len(ix) > nelems:
_, ix_ind = np.unique(ix, return_index=True)
ix = ix[np.sort(ix_ind)]

def finder(x, y):
ix = tree.query(np.array([x, y]).T, 5)[1].flatten()
X = mapping.invF(np.array([x, y])[:, None], ix)
inside = (
(X[0] >= 0) *
(X[1] >= 0) *
(1 - X[0] - X[1] >= 0)
)
return np.array([ix[np.argmax(inside, axis=0)]]).flatten()
inside = ((X[0] >= 0) *
(X[1] >= 0) *
(1 - X[0] - X[1] >= 0))
elems = np.argmax(inside, axis=0)

return np.array([ix[elems]]).flatten()

return finder
47 changes: 39 additions & 8 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from skfem import BilinearForm, asm, solve, condense, projection
from skfem.mesh import MeshTri, MeshTet, MeshHex, MeshQuad, MeshLine
from skfem.assembly import InteriorBasis, FacetBasis, Dofs, Functional
from skfem.assembly import CellBasis, FacetBasis, Dofs, Functional
from skfem.element import (ElementVectorH1, ElementTriP2, ElementTriP1,
ElementTetP2, ElementHexS2, ElementHex2,
ElementQuad2, ElementLineP2, ElementTriP0,
Expand All @@ -28,7 +28,7 @@ def runTest(self):

e = ElementVectorH1(ElementTriP2()) * ElementTriP1()

basis = InteriorBasis(m, e)
basis = CellBasis(m, e)

@BilinearForm
def bilinf(u, p, v, q, w):
Expand Down Expand Up @@ -93,7 +93,7 @@ def runTest(self):

m = self.mesh_type().refined(2)

basis = InteriorBasis(m, self.elem_type())
basis = CellBasis(m, self.elem_type())

for fun in [lambda x: x[0] == 0,
lambda x: x[0] == 1,
Expand Down Expand Up @@ -128,7 +128,7 @@ class TestInterpolatorTet(TestCase):

def runTest(self):
m = self.mesh_type().refined(self.nrefs)
basis = InteriorBasis(m, self.element_type())
basis = CellBasis(m, self.element_type())
x = projection(lambda x: x[0] ** 2, basis)
fun = basis.interpolator(x)
X = np.linspace(0, 1, 10)
Expand Down Expand Up @@ -187,7 +187,38 @@ def runTest(self):
with self.assertRaises(ValueError):
m = MeshTri()
e = ElementTetP2()
basis = InteriorBasis(m, e)
basis = CellBasis(m, e)


@pytest.mark.parametrize(
"mtype,e,nrefs,npoints",
[
(MeshTri, ElementTriP1(), 0, 10),
(MeshTri, ElementTriP2(), 1, 10),
(MeshTri, ElementTriP1(), 5, 10),
(MeshTri, ElementTriP1(), 1, 3e5),
(MeshTet, ElementTetP2(), 1, 10),
(MeshTet, ElementTetP1(), 5, 10),
(MeshTet, ElementTetP1(), 1, 3e5),
(MeshQuad, ElementQuad1(), 1, 10),
(MeshQuad, ElementQuad1(), 1, 3e5),
(MeshHex, ElementHex1(), 1, 1e5),
]
)
def test_interpolator_probes(mtype, e, nrefs, npoints):

m = mtype().refined(nrefs)

np.random.seed(0)
X = np.random.rand(m.p.shape[0], int(npoints))

basis = CellBasis(m, e)

y = projection(lambda x: x[0] ** 2, basis)

assert_allclose(basis.probes(X) @ y, basis.interpolator(y)(X))
atol = 1e-1 if nrefs <= 1 else 1e-3
assert_allclose(basis.probes(X) @ y, X[0] ** 2, atol=atol)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -215,7 +246,7 @@ def test_trace(mtype, e1, e2):
# use the boundary where last coordinate is zero
basis = FacetBasis(m, e1,
facets=m.facets_satisfying(lambda x: x[x.shape[0] - 1] == 0.0))
xfun = projection(lambda x: x[0], InteriorBasis(m, e1))
xfun = projection(lambda x: x[0], CellBasis(m, e1))
nbasis, y = basis.trace(xfun, lambda p: p[0:(p.shape[0] - 1)], target_elem=e2)

@Functional
Expand All @@ -235,8 +266,8 @@ def test_point_source(etype):
from skfem.models.poisson import laplace

mesh = MeshLine().refined()
basis = InteriorBasis(mesh, etype())
basis = CellBasis(mesh, etype())
source = np.array([0.7])
u = solve(*condense(asm(laplace, basis), basis.point_source(source), D=basis.find_dofs()))
exact = np.stack([(1 - source) * mesh.p, (1 - mesh.p) * source]).min(0)
assert_almost_equal(u[basis.nodal_dofs], exact)
assert_almost_equal(u[basis.nodal_dofs], exact)
32 changes: 30 additions & 2 deletions tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from pathlib import Path

import numpy as np
from numpy.testing import assert_array_equal
import pytest
from scipy.spatial import Delaunay
from numpy.testing import assert_array_equal

from skfem.mesh import Mesh, MeshHex, MeshLine, MeshQuad, MeshTet, MeshTri, MeshTri2, MeshQuad2, MeshTet2, MeshHex2
from skfem.mesh import (Mesh, MeshHex, MeshLine, MeshQuad, MeshTet, MeshTri,
MeshTri2, MeshQuad2, MeshTet2, MeshHex2)
from skfem.io.meshio import to_meshio, from_meshio


Expand Down Expand Up @@ -241,6 +243,32 @@ def runTest(self):
self.assertEqual(finder(np.array([0.001]))[0], 0)



@pytest.mark.parametrize(
"m",
[
MeshTri(),
MeshTet(),
]
)
def test_finder_simplex(m):

m = m.refined(3)

for seed in range(10):
np.random.seed(seed)
points = np.hstack((m.p, np.random.rand(m.p.shape[0], 100)))
tri = Delaunay(points.T)
M = type(m)(points, tri.simplices.T)
finder = M.element_finder()

query_pts = np.random.rand(m.p.shape[0], 500)
assert_array_equal(
tri.find_simplex(query_pts.T),
finder(*query_pts, ncandidates=15)
)


@pytest.mark.parametrize(
"m",
[
Expand Down