Skip to content

Commit

Permalink
Add support for 3D visualization via export.triplot. (#803)
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Jun 9, 2023
2 parents 41f8027 + 763647d commit 1943449
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
python -um pip install --upgrade --upgrade-strategy eager wheel
python -um pip install --upgrade --upgrade-strategy eager coverage numpy$_numpy_version
# Install Nutils from `dist` dir created in job `build-python-package`.
python -um pip install "$_wheel[import_gmsh]"
python -um pip install "$_wheel[import_gmsh,export_mpl]"
- name: Install Scipy
if: ${{ matrix.matrix-backend == 'scipy' }}
run: python -um pip install --upgrade scipy
Expand Down
117 changes: 106 additions & 11 deletions nutils/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,18 @@ def mplfigure(name, kwargs=...):


def plotlines_(ax, xy, lines, **kwargs):
from matplotlib import collections
lc = collections.LineCollection(numpy.asarray(xy).T[lines], **kwargs)
if len(xy) == 2:
from matplotlib.collections import LineCollection
elif len(xy) == 3:
from mpl_toolkits.mplot3d.art3d import Line3DCollection as LineCollection
else:
raise NotImplementedError(f'dimension not supported: {len(xy)}')
lc = LineCollection(numpy.asarray(xy).T[lines], **kwargs)
ax.add_collection(lc)
return lc


def _triplot_1d(ax, points, values=None, *, tri=None, hull=None, cmap=None, clim=None, linewidth=.1, linecolor='k', plabel=None, vlabel=None):
def _triplot_1d(ax, points, values, tri, hull, cmap, clim, linewidth, linecolor, plabel, vlabel):
if plabel:
ax.set_xlabel(plabel)
if vlabel:
Expand All @@ -64,7 +69,7 @@ def _triplot_1d(ax, points, values=None, *, tri=None, hull=None, cmap=None, clim
ax.set_ylim(clim)


def _triplot_2d(ax, points, values=None, *, tri=None, hull=None, cmap=None, clim=None, linewidth=.1, linecolor='k', plabel=None, vlabel=None):
def _triplot_2d(ax, points, values, tri, hull, cmap, clim, linewidth, linecolor, plabel, vlabel):
ax.set_aspect('equal')
if plabel:
ax.set_xlabel(plabel)
Expand All @@ -81,25 +86,115 @@ def _triplot_2d(ax, points, values=None, *, tri=None, hull=None, cmap=None, clim
return im


def triplot(name, points, values=None, **kwargs):
def _triplot_3d(ax, points, values, tri, hull, cmap, clim, linewidth, linecolor, plabel, vlabel):
if tri is not None:
im = ax.plot_trisurf(*points.T, triangles=tri, rasterized=True, antialiased=True, cmap=cmap)
if values is not None:
im.set_array(numpy.nanmean(values[tri], axis=1))
if clim is not None:
im.set_clim(clim)
else:
im = None
if hull is not None:
plotlines_(ax, points.T, hull, colors=linecolor, linewidths=linewidth, alpha=1 if tri is None else .5)
pmin = points.min(axis=0)
pmax = points.max(axis=0)
for d, *lim in zip('xyz', pmin, pmax):
getattr(ax, f'set_{d}lim3d')(lim)
if plabel:
getattr(ax, f'set_{d}label')(plabel)
ax.set_box_aspect(pmax - pmin) # together with set_*lim3d above this results in 1:1:1 aspect ratio
return im


def triplot(name, points, values=None, *, tri=None, hull=None, cmap=None, clim=None, linewidth=.1, linecolor='k', plabel=None, vlabel=None):
'''
Uniform plotting interface to preview 1D/2D/3D results.
This function serves to quickly visualise field data and/or finite element
meshes, with a consistent interface that works across dimensions. If the
provided data is one-dimensional the resulting plot is a graph; if it is
two-dimensional the result is a surface plot; and three-dimensional data is
plotted in a fixed ortholinear projection.
The function can be used in two modes, depending on the first argument:
* Standalone: By providing a filename as the first argument, data is
plotted, a colorbar is added if values are provided, and the image is
saved to the specified file.
* Matplotlib component: By providing a Matplotlib axes object as the
first argument, data is plotted to the axes and a handle to the scalar
mappable (if any) is returned.
Notes on 3D: Due to limitations of the underlying libraries, 3D data can
only be visualised on 2D meshes, i.e. the boundary of the topology.
Plotting of the 3D hull is supported but without occlusion, meaning that a
full wireframe is layed over the (properly occluded) field data. For use as
a Matplotlib component the provided axes must have projection="3d" set.
Args
----
name : :class:`str` or axes object
File name of the destination image (with extension) or Matplotlib axes object.
points : :class:`float` array
Vertex coordinates: a 2D float array shaped as <number of vertices> x
<spatial dimension>.
values : :class:`float` array
Scalar field quantities in the vertices specified by ``points``: a 1D float
array shaped as <number of vertices>.
tri : :class:`int` array
Triangulation of the vertices: a 2D integer array shaped as <number of
simplices> x <mesh dimension + 1>, where the mesh dimension may not
exceed the spatial dimension.
hull : :class:`int` array
Triangulation of the element hulls for the visualisation of mesh lines: a
2D integer array shaped as <number of hull simplices> x <mesh dimension>.
cmap : :class:`str`
Color map used for the visualisation of ``values``. Ignored if the
spatial dimension is one.
clim : :class:`tuple` of floats
Data truncation range.
linewidth : :class:`float`
Mesh line thickness. Ignored if ``hull`` is not specified.
linecolor : :class:`str`
Mesh line color. Ignored if ``hull`` is not specified.
plabel : :class:`str`
Axis label for the coordinates.
vlabel : :class:`str`
Axis label for the values.
'''

if points.ndim != 2:
raise ValueError(f'points must be a 2-dimensional array, received shape={points.shape}')

if points.shape[1] == 1:
nd = points.shape[1]
if nd == 1:
_triplot = _triplot_1d
elif points.shape[1] == 2:
elif nd == 2:
_triplot = _triplot_2d
elif nd == 3:
_triplot = _triplot_3d
else:
raise Exception(f'invalid spatial dimension: {nd}')

if (kwargs.get('tri') is None) != (values is None):
if (tri is None) != (values is None):
raise Exception('tri and values can only be specified jointly')

args = points, values, tri, hull, cmap, clim, linewidth, linecolor, plabel, vlabel
if not isinstance(name, str):
return _triplot(name, points, values, **kwargs)
return _triplot(name, *args)

with mplfigure(name) as fig:
im = _triplot(fig.add_subplot(111), points, values, **kwargs)
if nd < 3:
ax = fig.add_subplot(111)
cbarargs = {}
else:
ax = fig.add_subplot(111, projection='3d')
cbarargs = dict(shrink=.5, pad=.1)
im = _triplot(ax, *args)
if im:
fig.colorbar(im, label=kwargs.get('vlabel'))
fig.colorbar(im, label=vlabel, **cbarargs)


@util.positional_only
Expand Down
37 changes: 37 additions & 0 deletions tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,43 @@ def test_autodetect_imagetype(self):
test(f.read())


@testing.parametrize
class triplot(testing.TestCase):

def setUp(self):
super().setUp()
self.outdir = pathlib.Path(self.enter_context(tempfile.TemporaryDirectory()))
self.enter_context(treelog.set(treelog.DataLog(str(self.outdir))))
self.coords = numpy.zeros([self.ndims + 1, self.ndims])
self.coords[1:] = numpy.eye(self.ndims)
self.tri = numpy.arange(self.ndims + 1)[numpy.newaxis]
self.hull = numpy.array([self.tri[0,~m] for m in numpy.eye(self.ndims+1, dtype=bool)])
if self.ndims == 3:
self.tri = self.hull
self.hull = numpy.array([[i,j] for i in range(4) for j in range(i)])
self.values = numpy.arange(self.ndims+1, dtype=float) * self.ndims

@testing.requires('matplotlib', 'PIL')
def test_filename(self):
export.triplot('test.jpg', self.coords, self.values, tri=self.tri, hull=self.hull)

@testing.requires('matplotlib', 'PIL')
def test_axesobj(self):
with export.mplfigure('test.jpg') as fig:
ax = fig.add_subplot(111, projection='3d' if self.ndims == 3 else None)
im = export.triplot(ax, self.coords, self.values, tri=self.tri, hull=self.hull)
if self.ndims == 1:
self.assertEqual(im, None)
elif self.ndims == 2:
self.assertAllEqual(im.get_array(), self.values)
elif self.ndims == 3:
self.assertAllEqual(im.get_array(), self.values[self.tri].mean(1))

triplot(ndims=1)
triplot(ndims=2)
triplot(ndims=3)


@testing.parametrize
class vtk(testing.TestCase):

Expand Down

0 comments on commit 1943449

Please sign in to comment.