Skip to content

Commit

Permalink
Merge pull request #681 from evalf/lowerargs
Browse files Browse the repository at this point in the history
Lowerargs
  • Loading branch information
gertjanvanzwieten committed May 18, 2022
2 parents 280e6d7 + 5988f51 commit 53e6fef
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 187 deletions.
2 changes: 1 addition & 1 deletion nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def wrapped(target, *funcargs, **funckwargs):
cache[obj] = rstack[-1] if rstack[-1] is not obj else identity
continue

if isinstance(obj, (tuple, list, dict, set, frozenset)):
if obj.__class__ in (tuple, list, dict, set, frozenset):
if not obj:
rstack.append(obj) # shortcut to avoid recreation of empty container
else:
Expand Down
239 changes: 132 additions & 107 deletions nutils/function.py

Large diffs are not rendered by default.

91 changes: 38 additions & 53 deletions nutils/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_evaluable_indices(self, __ielem: evaluable.Array) -> evaluable.Array:
def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
raise NotImplementedError

def update_lower_args(self, __ielem: evaluable.Array, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> Tuple[_PointsShape, _TransformChainsMap, _CoordinatesMap]:
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
raise NotImplementedError

@util.positional_only
Expand Down Expand Up @@ -396,21 +396,8 @@ def __init__(self, space: str, transforms: Tuple[Transforms, ...], points: Point
def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
return self.points.get_evaluable_weights(__ielem)

def update_lower_args(self, __ielem: evaluable.Array, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> Tuple[_PointsShape, _TransformChainsMap, _CoordinatesMap]:
if self.space in transform_chains or self.space in coordinates:
raise ValueError('Nested integrals or samples in the same space are not supported.')

transform_chains = dict(transform_chains)
transform_chains[self.space] = space_transform_chains = tuple(t.get_evaluable(__ielem) for t in (self.transforms*2)[:2])

space_coordinates = self.points.get_evaluable_coords(__ielem)
assert space_coordinates.ndim == 2 # axes: points, coord dim
coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, space_coordinates.shape[:-1]), coords.ndim - 1) for space, coords in coordinates.items()}
coordinates[self.space] = evaluable.prependaxes(space_coordinates, points_shape)

points_shape = points_shape + space_coordinates.shape[:-1]

return points_shape, transform_chains, coordinates
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
return function.LowerArgs.for_space(self.space, tuple(t.get_evaluable(__ielem) for t in (self.transforms*2)[:2]), self.points.get_evaluable_coords(__ielem))

def basis(self) -> function.Array:
return _Basis(self)
Expand Down Expand Up @@ -502,8 +489,8 @@ def get_evaluable_indices(self, __ielem: evaluable.Array) -> evaluable.Array:
def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
raise SkipTest('`{}` does not implement `Sample.get_evaluable_weights`'.format(type(self).__qualname__))

def update_lower_args(self, __ielem: evaluable.Array, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> Tuple[_PointsShape, _TransformChainsMap, _CoordinatesMap]:
raise SkipTest('`{}` does not implement `Sample.update_lower_args`'.format(type(self).__qualname__))
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
raise SkipTest('`{}` does not implement `Sample.get_lower_args`'.format(type(self).__qualname__))

@property
def transforms(self) -> Tuple[Transforms, ...]:
Expand Down Expand Up @@ -534,8 +521,8 @@ def get_evaluable_indices(self, __ielem: evaluable.Array) -> evaluable.Array:
def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
return evaluable.Zeros((0,) * len(self.spaces), dtype=float)

def update_lower_args(self, __ielem: evaluable.Array, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> Tuple[_PointsShape, _TransformChainsMap, _CoordinatesMap]:
return points_shape, transform_chains, coordinates
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
return function.LowerArgs((), {}, {})

def get_element_tri(self, ielem: int) -> numpy.ndarray:
raise IndexError('index out of range')
Expand Down Expand Up @@ -631,10 +618,9 @@ def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
weights2 = self._sample2.get_evaluable_weights(ielem2)
return evaluable.einsum('A,B->AB', weights1, weights2)

def update_lower_args(self, __ielem: evaluable.Array, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> Tuple[_PointsShape, _TransformChainsMap, _CoordinatesMap]:
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
ielem1, ielem2 = evaluable.divmod(__ielem, self._sample2.nelems)
points_shape, transform_chains, coordinates = self._sample1.update_lower_args(ielem1, points_shape, transform_chains, coordinates)
return self._sample2.update_lower_args(ielem2, points_shape, transform_chains, coordinates)
return self._sample1.get_lower_args(ielem1) | self._sample2.get_lower_args(ielem2)

def get_element_tri(self, ielem: int) -> numpy.ndarray:
if self._sample1.ndims == 1:
Expand Down Expand Up @@ -730,16 +716,16 @@ def _getslice(self, array, ielem):
s = evaluable.Take(self._offsets, ielem) + evaluable.Range(evaluable.Take(self._sizes, ielem))
return evaluable.Take(array, s)

def update_lower_args(self, ielem, points_shape, transform_chains, coordinates):
if set(self.spaces) & set(transform_chains):
raise ValueError('Nested integrals or samples in the same space are not supported.')
size = evaluable.Take(self._sizes, ielem)
coordinates = {space: evaluable.insertaxis(coords, -2, size) for space, coords in coordinates.items()}
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
points_shape = evaluable.Take(self._sizes, __ielem),
coordinates = {}
transform_chains = {}
for samplei, ielemsi, ilocalsi in zip(self._samples, self._ielems, self._ilocals):
_, transform_chains, coordinatesi = samplei.update_lower_args(evaluable.Take(ielemsi, ielem), (), transform_chains, {})
for space, coords in coordinatesi.items():
coordinates[space] = evaluable.prependaxes(evaluable._take(coords, self._getslice(ilocalsi, ielem), axis=0), points_shape)
return (*points_shape, size), transform_chains, coordinates
argsi = samplei.get_lower_args(evaluable.Take(ielemsi, __ielem))
slicei = self._getslice(ilocalsi, __ielem)
transform_chains.update(argsi.transform_chains)
coordinates.update({space: evaluable._take(coords, slicei, axis=0) for space, coords in argsi.coordinates.items()})
return function.LowerArgs(points_shape, transform_chains, coordinates)

def get_evaluable_indices(self, ielem):
return self._getslice(self._indices, ielem)
Expand Down Expand Up @@ -787,8 +773,8 @@ def get_evaluable_indices(self, __ielem: evaluable.Array) -> evaluable.Array:
def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
return self._parent.get_evaluable_weights(evaluable.Take(self._indices, __ielem))

def update_lower_args(self, __ielem: evaluable.Array, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> Tuple[_PointsShape, _TransformChainsMap, _CoordinatesMap]:
return self._parent.update_lower_args(evaluable.Take(self._indices, __ielem), points_shape, transform_chains, coordinates)
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
return self._parent.get_lower_args(evaluable.Take(self._indices, __ielem))

def get_element_tri(self, __ielem: int) -> numpy.ndarray:
if not 0 <= __ielem < self.nelems:
Expand Down Expand Up @@ -878,11 +864,10 @@ def __init__(self, integrand: function.Array, sample: Sample) -> None:
self._sample = sample
super().__init__(shape=integrand.shape, dtype=float if integrand.dtype in (bool, int) else integrand.dtype, spaces=integrand.spaces - frozenset(sample.spaces), arguments=integrand.arguments)

def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
def lower(self, args: function.LowerArgs) -> evaluable.Array:
ielem = evaluable.loop_index('_sample_' + '_'.join(self._sample.spaces), self._sample.nelems)
points_shape, transform_chains, coordinates = self._sample.update_lower_args(ielem, points_shape, transform_chains, coordinates)
weights = self._sample.get_evaluable_weights(ielem)
integrand = self._integrand.lower(points_shape, transform_chains, coordinates)
integrand = self._integrand.lower(args | self._sample.get_lower_args(ielem))
elem_integral = evaluable.einsum('B,ABC->AC', weights, integrand, B=weights.ndim, C=self.ndim)
return evaluable.loop_sum(elem_integral, ielem)

Expand All @@ -894,13 +879,13 @@ def __init__(self, func: function.Array, sample: _TransformChainsSample) -> None
self._sample = sample
super().__init__(shape=(sample.npoints, *func.shape), dtype=func.dtype, spaces=func.spaces - frozenset(sample.spaces), arguments=func.arguments)

def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
axis = len(points_shape)
def lower(self, args: function.LowerArgs) -> evaluable.Array:
axis = len(args.points_shape)
ielem = evaluable.loop_index('_sample_' + '_'.join(self._sample.spaces), self._sample.nelems)
points_shape, transform_chains, coordinates = self._sample.update_lower_args(ielem, points_shape, transform_chains, coordinates)
func = self._func.lower(points_shape, transform_chains, coordinates)
func = evaluable.Transpose.to_end(func, *range(axis, len(points_shape)))
for i in range(len(points_shape) - axis - 1):
args |= self._sample.get_lower_args(ielem)
func = self._func.lower(args)
func = evaluable.Transpose.to_end(func, *range(axis, len(args.points_shape)))
for i in range(len(args.points_shape) - axis - 1):
func = evaluable.Ravel(func)
func = evaluable.loop_concatenate(func, ielem)
return evaluable.Transpose.from_end(func, axis)
Expand All @@ -914,9 +899,9 @@ def __init__(self, func: function.Array, indices: evaluable.Array) -> None:
assert indices.ndim == 1 and func.shape[0] == indices.shape[0].__index__()
super().__init__(shape=func.shape, dtype=func.dtype, spaces=func.spaces, arguments=func.arguments)

def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
func = self._func.lower(points_shape, transform_chains, coordinates)
axis = len(points_shape)
def lower(self, args: function.LowerArgs) -> evaluable.Array:
func = self._func.lower(args)
axis = len(args.points_shape)
return evaluable.Transpose.from_end(evaluable.Inflate(evaluable.Transpose.to_end(func, axis), self._indices, self._indices.shape[0]), axis)


Expand All @@ -926,20 +911,20 @@ def __init__(self, sample: _TransformChainsSample) -> None:
self._sample = sample
super().__init__(shape=(sample.npoints,), dtype=float, spaces=frozenset({sample.space}), arguments={})

def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
aligned_space_coords = coordinates[self._sample.space]
assert aligned_space_coords.ndim == len(points_shape) + 1
def lower(self, args: function.LowerArgs) -> evaluable.Array:
aligned_space_coords = args.coordinates[self._sample.space]
assert aligned_space_coords.ndim == len(args.points_shape) + 1
space_coords, where = evaluable.unalign(aligned_space_coords)
# Reinsert the coordinate axis, the last axis of `aligned_space_coords`, or
# make sure this is the last axis of `space_coords`.
if len(points_shape) not in where:
if len(args.points_shape) not in where:
space_coords = evaluable.InsertAxis(space_coords, aligned_space_coords.shape[-1])
where += len(points_shape),
elif where[-1] != len(points_shape):
elif where[-1] != len(args.points_shape):
space_coords = evaluable.Transpose(space_coords, numpy.argsort(where))
where = tuple(sorted(where))

chain = transform_chains[self._sample.space][0]
chain = args.transform_chains[self._sample.space][0]
index, tail = chain.index_with_tail_in(self._sample.transforms[0])
coords = tail.apply(space_coords)
expect = self._sample.points.get_evaluable_coords(index)
Expand All @@ -949,7 +934,7 @@ def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMa

# Realign the points axes. The coordinate axis of `aligned_space_coords` is
# replaced by a dofs axis in the aligned basis, hence we can reuse `where`.
return evaluable.align(basis, where, (*points_shape, self._sample.npoints))
return evaluable.align(basis, where, (*args.points_shape, self._sample.npoints))


def _offsets(pointsseq: PointsSequence) -> evaluable.Array:
Expand Down
10 changes: 6 additions & 4 deletions nutils/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def locate(self, geom, coords, *, tol=0, eps=0, maxiter=0, arguments=None, weigh

class _EmptyUnlowerable(function.Array):

def lower(self, points_shape, transform_chains, coordinates) -> evaluable.Array:
def lower(self, args: function.LowerArgs) -> evaluable.Array:
raise ValueError('cannot lower')


Expand Down Expand Up @@ -1419,7 +1419,8 @@ def trim(self, levelset, maxrefine, ndivisions=8, name='trimmed', leveltopo=None
if leveltopo is None:
ielem_arg = evaluable.Argument('_trim_index', (), dtype=int)
coordinates = self.references.getpoints('vertex', maxrefine).get_evaluable_coords(ielem_arg)
levelset = levelset.lower(coordinates.shape[:-1], {self.space: (self.transforms.get_evaluable(ielem_arg), self.opposites.get_evaluable(ielem_arg))}, {self.space: coordinates}).optimized_for_numpy
transform_chains = self.transforms.get_evaluable(ielem_arg), self.opposites.get_evaluable(ielem_arg)
levelset = levelset.lower(function.LowerArgs.for_space(self.space, transform_chains, coordinates)).optimized_for_numpy
with log.iter.percentage('trimming', range(len(self)), self.references) as items:
for ielem, ref in items:
levels = levelset.eval(_trim_index=ielem, **arguments)
Expand All @@ -1428,7 +1429,7 @@ def trim(self, levelset, maxrefine, ndivisions=8, name='trimmed', leveltopo=None
log.info('collecting leveltopo elements')
coordinates = evaluable.Points(evaluable.NPoints(), self.ndims)
transform_chain = transform.EvaluableTransformChain.from_argument('trans', self.transforms.todims, self.transforms.fromdims)
levelset = levelset.lower(coordinates.shape[:-1], {self.space: (transform_chain, transform_chain)}, {self.space: coordinates}).optimized_for_numpy
levelset = levelset.lower(function.LowerArgs.for_space(self.space, (transform_chain, transform_chain), coordinates)).optimized_for_numpy
bins = [set() for ielem in range(len(self))]
for trans in leveltopo.transforms:
ielem, tail = self.transforms.index_with_tail(trans)
Expand Down Expand Up @@ -1496,7 +1497,8 @@ def locate(self, geom, coords, *, tol=0, eps=0, maxiter=0, arguments=None, weigh
points = parallel.shempty((len(coords), len(geom)), dtype=float)
_ielem = evaluable.Argument('_locate_ielem', shape=(), dtype=int)
_point = evaluable.Argument('_locate_point', shape=(self.ndims,))
egeom = geom.lower((), {self.space: (self.transforms.get_evaluable(_ielem), self.opposites.get_evaluable(_ielem))}, {self.space: _point})
transform_chains = self.transforms.get_evaluable(_ielem), self.opposites.get_evaluable(_ielem)
egeom = geom.lower(function.LowerArgs.for_space(self.space, transform_chains, _point))
xJ = evaluable.Tuple((egeom, evaluable.derivative(egeom, _point))).simplified
if skip_missing:
if weights is not None:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,13 +377,13 @@ class Unlower(TestCase):
def test(self):
e = evaluable.Argument('arg', (2, 3, 4, 5), int)
arguments = {'arg': ((2, 3), int)}
f = function._Unlower(e, frozenset(), arguments, (2, 3), {}, {})
f = function._Unlower(e, frozenset(), arguments, function.LowerArgs((2, 3), {}, {}))
self.assertEqual(f.shape, (4, 5))
self.assertEqual(f.dtype, int)
self.assertEqual(f.arguments, arguments)
self.assertEqual(f.lower((2, 3), {}, {}), e)
self.assertEqual(f.lower(function.LowerArgs((2, 3), {}, {})), e)
with self.assertRaises(ValueError):
f.lower((3, 4), {}, {})
f.lower(function.LowerArgs((3, 4), {}, {}))


class Custom(TestCase):
Expand All @@ -394,12 +394,12 @@ def assertEvalAlmostEqual(self, factual, fdesired, **args):
transform_chains = dict(test=(transform.EvaluableTransformChain.from_argument('test', 2, 2),)*2)
with self.subTest('1d-points'):
coords = evaluable.Zeros((5, 2), float)
lower_args = coords.shape[:-1], transform_chains, dict(test=coords)
self.assertAllAlmostEqual(factual.lower(*lower_args).eval(**args), fdesired.lower(*lower_args).eval(**args))
lower_args = function.LowerArgs(coords.shape[:-1], transform_chains, dict(test=coords))
self.assertAllAlmostEqual(factual.lower(lower_args).eval(**args), fdesired.lower(lower_args).eval(**args))
with self.subTest('2d-points'):
coords = evaluable.Zeros((5, 6, 2), float)
lower_args = coords.shape[:-1], transform_chains, dict(test=coords)
self.assertAllAlmostEqual(factual.lower(*lower_args).eval(**args), fdesired.lower(*lower_args).eval(**args))
lower_args = function.LowerArgs(coords.shape[:-1], transform_chains, dict(test=coords))
self.assertAllAlmostEqual(factual.lower(lower_args).eval(**args), fdesired.lower(lower_args).eval(**args))

def assertMultipy(self, leftval, rightval):

Expand Down Expand Up @@ -1148,7 +1148,7 @@ def test_lower(self):
ref = element.PointReference() if self.basis.coords.shape[0] == 0 else element.LineReference()**self.basis.coords.shape[0]
points = ref.getpoints('bezier', 4)
coordinates = evaluable.Constant(points.coords)
lowered = self.basis.lower(coordinates.shape[:-1], dict(X=(self.checktransforms.get_evaluable(evaluable.Argument('ielem', (), int)),)*2), dict(X=coordinates))
lowered = self.basis.lower(function.LowerArgs(coordinates.shape[:-1], dict(X=(self.checktransforms.get_evaluable(evaluable.Argument('ielem', (), int)),)*2), dict(X=coordinates)))
with _builtin_warnings.catch_warnings():
_builtin_warnings.simplefilter('ignore', category=evaluable.ExpensiveEvaluationWarning)
for ielem in range(self.checknelems):
Expand Down

0 comments on commit 53e6fef

Please sign in to comment.