Skip to content

Commit

Permalink
remove apply_annotations from transformseq module (#771)
Browse files Browse the repository at this point in the history
  • Loading branch information
joostvanzwieten committed Jan 30, 2023
2 parents 0b2acb1 + 64d5f6f commit 8128a17
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 75 deletions.
12 changes: 6 additions & 6 deletions nutils/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,16 +540,16 @@ def simplex(nodes, cnodes, coords, tags, btags, ptags, name='simplex', *, space=
opposites.append(topo.transforms[ioppelem] + (transform.SimplexEdge(ndims, tuple(connectivity[ioppelem]).index(ielem)),))
for groups, (simplices, transforms, opposites) in (bgroups, bitems), (igroups, iitems):
if simplices:
transforms = transformseq.PlainTransforms(transforms, ndims, ndims-1)
opposites = transforms if opposites is None else transformseq.PlainTransforms(opposites, ndims, ndims-1)
transforms = transformseq.PlainTransforms(tuple(transforms), ndims, ndims-1)
opposites = transforms if opposites is None else transformseq.PlainTransforms(tuple(opposites), ndims, ndims-1)
groups[name] = topology.SimplexTopology(space, numpy.asarray(simplices), transforms, opposites)

pgroups = {}
if ptags:
ptrans = [transform.Point(types.arraydata(offset)) for offset in numpy.eye(ndims+1)[:, 1:]]
pmap = {inode: numpy.array(numpy.equal(nodes, inode).nonzero()).T for inode in set.union(*map(set, ptags.values()))}
for pname, inodes in ptags.items():
ptransforms = transformseq.PlainTransforms([topo.transforms[ielem] + (ptrans[ivertex],) for inode in inodes for ielem, ivertex in pmap[inode]], ndims, 0)
ptransforms = transformseq.PlainTransforms(tuple((*topo.transforms[ielem], ptrans[ivertex]) for inode in inodes for ielem, ivertex in pmap[inode]), ndims, 0)
preferences = References.uniform(element.getsimplex(0), len(ptransforms))
pgroups[pname] = topology.TransformChainsTopology(space, preferences, ptransforms, ptransforms)

Expand Down Expand Up @@ -588,12 +588,12 @@ def simplex(nodes, cnodes, coords, tags, btags, ptags, name='simplex', *, space=
opposites.append(topo.transforms[ioppelem] + (transform.SimplexEdge(ndims, ioppedge),))
for groups, (simplices, transforms, opposites) in (vbgroups, bitems), (vigroups, iitems):
if simplices:
transforms = transformseq.PlainTransforms(transforms, ndims, ndims-1)
opposites = transformseq.PlainTransforms(opposites, ndims, ndims-1) if len(opposites) == len(transforms) else transforms
transforms = transformseq.PlainTransforms(tuple(transforms), ndims, ndims-1)
opposites = transformseq.PlainTransforms(tuple(opposites), ndims, ndims-1) if len(opposites) == len(transforms) else transforms
groups[bname] = topology.SimplexTopology(space, numpy.asarray(simplices), transforms, opposites)
vpgroups = {}
for pname, inodes in ptags.items():
ptransforms = transformseq.PlainTransforms([topo.transforms[ielem] + (ptrans[ivertex],) for inode in inodes for ielem, ivertex in pmap[inode] if keep[ielem]], ndims, 0)
ptransforms = transformseq.PlainTransforms(tuple((*topo.transforms[ielem], ptrans[ivertex]) for inode in inodes for ielem, ivertex in pmap[inode] if keep[ielem]), ndims, 0)
preferences = References.uniform(element.getsimplex(0), len(ptransforms))
vpgroups[pname] = topology.TransformChainsTopology(space, preferences, ptransforms, ptransforms)
vgroups[name] = vtopo.withgroups(bgroups=vbgroups, igroups=vigroups, pgroups=vpgroups)
Expand Down
12 changes: 6 additions & 6 deletions nutils/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,7 @@ def __init__(self, space: str, root: transform.TransformItem, axes: Sequence[tra
if nbounds == 0:
opposites = transforms
else:
axes = [axis.opposite(nbounds-1) for axis in self.axes]
axes = tuple(axis.opposite(nbounds-1) for axis in self.axes)
opposites = transformseq.StructuredTransforms(self.root, axes, self.nrefine)

super().__init__(space, references, transforms, opposites)
Expand Down Expand Up @@ -2548,22 +2548,22 @@ def boundary(self):
edgeref -= oppref.edge_refs[ioppedge]
if edgeref:
trimmedreferences.append(edgeref)
trimmedtransforms.append(elemtrans+(edgetrans,))
trimmedopposites.append(self.basetopo.transforms[ioppelem]+(oppref.edge_transforms[ioppedge],))
trimmedtransforms.append(transform.canonical((*elemtrans, edgetrans)))
trimmedopposites.append(transform.canonical((*self.basetopo.transforms[ioppelem], oppref.edge_transforms[ioppedge])))
# The last edges of newref (beyond the number of edges of the original)
# cannot have opposites and are added to the trimmed group directly.
for edgetrans, edgeref in newref.edges[len(ioppelems):]:
trimmedreferences.append(edgeref)
trimmedtransforms.append(elemtrans+(edgetrans,))
trimmedopposites.append(elemtrans+(edgetrans.flipped,))
trimmedtransforms.append(transform.canonical((*elemtrans, edgetrans)))
trimmedopposites.append(transform.canonical((*elemtrans, edgetrans.flipped)))
origboundary = SubsetTopology(baseboundary, brefs)
if isinstance(self.newboundary, TransformChainsTopology):
trimmedbrefs = [ref.empty for ref in self.newboundary.references]
for ref, trans in zip(trimmedreferences, trimmedtransforms):
trimmedbrefs[self.newboundary.transforms.index(trans)] = ref
trimboundary = SubsetTopology(self.newboundary, trimmedbrefs)
else:
trimboundary = TransformChainsTopology(self.space, References.from_iter(trimmedreferences, self.ndims-1), transformseq.PlainTransforms(trimmedtransforms, self.transforms.todims, self.ndims-1), transformseq.PlainTransforms(trimmedopposites, self.transforms.todims, self.ndims-1))
trimboundary = TransformChainsTopology(self.space, References.from_iter(trimmedreferences, self.ndims-1), transformseq.PlainTransforms(tuple(trimmedtransforms), self.transforms.todims, self.ndims-1), transformseq.PlainTransforms(tuple(trimmedopposites), self.transforms.todims, self.ndims-1))
return DisjointUnionTopology([trimboundary, origboundary], names=[self.newboundary] if isinstance(self.newboundary, str) else [])

@cached_property
Expand Down
4 changes: 4 additions & 0 deletions nutils/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def canonical(chain):
return tuple(items)


def iscanonical(chain):
return all(b.swapdown(a) == None for a, b in util.pairwise(chain))


def uppermost(chain):
# bring to highest ndims possible
n = len(chain)
Expand Down
83 changes: 51 additions & 32 deletions nutils/transformseq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The transformseq module."""

from typing import Tuple
from numbers import Integral
from . import types, numeric, _util as util, transform, element
from .elementseq import References
from .transform import TransformChain
Expand Down Expand Up @@ -44,8 +45,9 @@ class supports indexing, iterating and has an :meth:`index` method. In

__slots__ = 'todims', 'fromdims'

@types.apply_annotations
def __init__(self, todims: types.strictint, fromdims: types.strictint):
def __init__(self, todims: Integral, fromdims: Integral):
assert isinstance(todims, Integral), f'todims={todims!r}'
assert isinstance(fromdims, Integral), f'fromdims={fromdims!r}'
if not 0 <= fromdims <= todims:
raise ValueError('invalid dimensions')
self.todims = todims
Expand All @@ -69,7 +71,7 @@ def __getitem__(self, index):
return self
if index.step < 0:
raise NotImplementedError('reordering the sequence is not yet implemented')
return MaskedTransforms(self, numpy.arange(index.start, index.stop, index.step))
return MaskedTransforms(self, types.arraydata(numpy.arange(index.start, index.stop, index.step)))
elif numeric.isintarray(index):
if index.ndim != 1:
raise IndexError('invalid index')
Expand All @@ -82,12 +84,12 @@ def __getitem__(self, index):
raise ValueError('repeating an element is not allowed')
if not numpy.all(numpy.greater(dindex, 0)):
s = numpy.argsort(index)
return ReorderedTransforms(self[index[s]], numpy.argsort(s))
return ReorderedTransforms(self[index[s]], types.arraydata(numpy.argsort(s)))
if len(index) == 0:
return EmptyTransforms(self.todims, self.fromdims)
if len(index) == len(self):
return self
return MaskedTransforms(self, index)
return MaskedTransforms(self, types.arraydata(index))
elif numeric.isboolarray(index):
if index.shape != (len(self),):
raise IndexError('mask has invalid shape')
Expand All @@ -96,7 +98,7 @@ def __getitem__(self, index):
if numpy.all(index):
return self
index, = numpy.where(index)
return MaskedTransforms(self, index)
return MaskedTransforms(self, types.arraydata(index))
else:
raise IndexError('invalid index')

Expand Down Expand Up @@ -133,7 +135,7 @@ def index_with_tail(self, trans):
Consider the following plain sequence of two index transforms:
>>> from nutils.transform import Index, SimplexChild
>>> transforms = PlainTransforms([(Index(1, 0),), (Index(1, 1),)], 1, 1)
>>> transforms = PlainTransforms(((Index(1, 0),), (Index(1, 1),)), 1, 1)
Calling :meth:`index_with_tail` with the first transform gives index ``0``
and no tail:
Expand Down Expand Up @@ -178,7 +180,7 @@ def index(self, trans):
Consider the following plain sequence of two index transforms:
>>> from nutils.transform import Index, SimplexChild
>>> transforms = PlainTransforms([(Index(1, 0),), (Index(1, 1),)], 1, 1)
>>> transforms = PlainTransforms(((Index(1, 0),), (Index(1, 1),)), 1, 1)
Calling :meth:`index` with the first transform gives index ``0``:
Expand Down Expand Up @@ -344,15 +346,17 @@ class PlainTransforms(Transforms):
Parameters
----------
transforms : :class:`tuple` of :class:`~nutils.transform.TransformItem` objects
The sequence of transforms.
The sequence of transforms in canonical order.
fromdims : :class:`int`
The number of dimensions all ``transforms`` map from.
'''

__slots__ = '_transforms', '_sorted', '_indices'

@types.apply_annotations
def __init__(self, transforms: types.tuple[transform.canonical], todims: types.strictint, fromdims: types.strictint):
def __init__(self, transforms: Tuple[Tuple[transform.TransformItem, ...], ...], todims: Integral, fromdims: Integral):
assert isinstance(transforms, tuple) and all(isinstance(items, tuple) and all(isinstance(item, transform.TransformItem) for item in items) and transform.iscanonical(items) for items in transforms), f'transforms={transforms!r}'
assert isinstance(todims, Integral), f'todims={todims!r}'
assert isinstance(fromdims, Integral), f'fromdims={fromdims!r}'
transforms_todims = set(trans[0].todims for trans in transforms)
transforms_fromdims = set(trans[-1].fromdims for trans in transforms)
if not (transforms_todims <= {todims}):
Expand Down Expand Up @@ -406,8 +410,10 @@ class IndexTransforms(Transforms):

__slots__ = '_length', '_offset'

@types.apply_annotations
def __init__(self, ndims: types.strictint, length: int, offset: int = 0):
def __init__(self, ndims: Integral, length: Integral, offset: Integral = 0):
assert isinstance(ndims, Integral), f'ndims={ndims!r}'
assert isinstance(length, Integral), f'length={length!r}'
assert isinstance(offset, Integral), f'offset={offset!r}'
self._length = length
self._offset = offset
super().__init__(ndims, ndims)
Expand All @@ -432,7 +438,10 @@ class Axis(types.Singleton):

__slots__ = 'i', 'j', 'mod'

def __init__(self, i: types.strictint, j: types.strictint, mod: types.strictint):
def __init__(self, i: Integral, j: Integral, mod: Integral):
assert isinstance(i, Integral), f'i={i!r}'
assert isinstance(j, Integral), f'j={j!r}'
assert isinstance(mod, Integral), f'mod={mod!r}'
assert i <= j
self.i = i
self.j = j
Expand Down Expand Up @@ -462,8 +471,8 @@ class DimAxis(Axis):
__slots__ = 'isperiodic'
isdim = True

@types.apply_annotations
def __init__(self, i: types.strictint, j: types.strictint, mod: types.strictint, isperiodic: bool):
def __init__(self, i: Integral, j: Integral, mod: Integral, isperiodic: bool):
assert isinstance(isperiodic, bool), f'isperiodic={isperiodic!r}'
super().__init__(i, j, mod)
self.isperiodic = isperiodic

Expand Down Expand Up @@ -498,8 +507,9 @@ class IntAxis(Axis):
__slots__ = 'ibound', 'side'
isdim = False

@types.apply_annotations
def __init__(self, i: types.strictint, j: types.strictint, mod: types.strictint, ibound: types.strictint, side: bool):
def __init__(self, i: Integral, j: Integral, mod: Integral, ibound: Integral, side: bool):
assert isinstance(ibound, Integral), f'ibound={ibound!r}'
assert isinstance(side, Integral), f'side={side!r}'
super().__init__(i, j, mod)
self.ibound = ibound
self.side = side
Expand Down Expand Up @@ -530,8 +540,11 @@ class StructuredTransforms(Transforms):

__slots__ = '_root', '_axes', '_nrefine', '_etransforms', '_ctransforms', '_cindices'

@types.apply_annotations
def __init__(self, root: transform.stricttransformitem, axes: types.tuple[types.strict[Axis]], nrefine: types.strictint):
def __init__(self, root: transform.TransformItem, axes: Tuple[Axis, ...], nrefine: Integral):
assert isinstance(root, transform.TransformItem), f'root={root!r}'
assert isinstance(axes, tuple) and all(isinstance(axis, Axis) for axis in axes), f'axes={axes!r}'
assert isinstance(nrefine, Integral), f'nrefine={nrefine!r}'

self._root = root
self._axes = axes
self._nrefine = nrefine
Expand Down Expand Up @@ -619,9 +632,9 @@ class MaskedTransforms(Transforms):

__slots__ = '_parent', '_mask', '_indices'

@types.apply_annotations
def __init__(self, parent: stricttransforms, indices: types.arraydata):
assert indices.dtype == int
def __init__(self, parent: Transforms, indices: types.arraydata):
assert isinstance(parent, Transforms), f'parent={parent!r}'
assert isinstance(indices, types.arraydata) and indices.dtype == int, f'indices={indices!r}'
self._parent = parent
self._indices = numpy.asarray(indices)
super().__init__(parent.todims, parent.fromdims)
Expand Down Expand Up @@ -661,9 +674,9 @@ class ReorderedTransforms(Transforms):
__slots__ = '_parent', '_mask', '_indices'
__cache__ = '_rindices'

@types.apply_annotations
def __init__(self, parent: stricttransforms, indices: types.arraydata):
assert indices.dtype == int
def __init__(self, parent: Transforms, indices: types.arraydata):
assert isinstance(parent, Transforms), f'parent={parent!r}'
assert isinstance(indices, types.arraydata) and indices.dtype == int, f'indices={indices!r}'
self._parent = parent
self._indices = numpy.asarray(indices)
super().__init__(parent.todims, parent.fromdims)
Expand Down Expand Up @@ -713,8 +726,11 @@ class DerivedTransforms(Transforms):
__slots__ = '_parent', '_parent_references', '_derived_transforms'
__cache__ = '_offsets'

@types.apply_annotations
def __init__(self, parent: stricttransforms, parent_references: types.strict[References], derived_attribute: types.strictstr, fromdims: types.strictint):
def __init__(self, parent: Transforms, parent_references: References, derived_attribute: str, fromdims: Integral):
assert isinstance(parent, Transforms), f'parent={parent!r}'
assert isinstance(parent_references, References), f'parent_references={parent_references!r}'
assert isinstance(derived_attribute, str), f'derived_attribute={derived_attribute!r}'
assert isinstance(fromdims, Integral), f'fromdims={fromdims!r}'
if len(parent) != len(parent_references):
raise ValueError('`parent` and `parent_references` should have the same length')
if parent.fromdims != parent_references.ndims:
Expand Down Expand Up @@ -780,8 +796,11 @@ class UniformDerivedTransforms(Transforms):

__slots__ = '_parent', '_derived_transforms'

@types.apply_annotations
def __init__(self, parent: stricttransforms, parent_reference: element.strictreference, derived_attribute: types.strictstr, fromdims: types.strictint):
def __init__(self, parent: Transforms, parent_reference: element.Reference, derived_attribute: str, fromdims: Integral):
assert isinstance(parent, Transforms), f'parent={parent!r}'
assert isinstance(parent_reference, element.Reference), f'parent_reference={parent_reference!r}'
assert isinstance(derived_attribute, str), f'derived_attribute={derived_attribute!r}'
assert isinstance(fromdims, Integral), f'fromdims={fromdims!r}'
if parent.fromdims != parent_reference.ndims:
raise ValueError('`parent` and `parent_reference` have different dimensions')
self._parent = parent
Expand Down Expand Up @@ -826,8 +845,8 @@ class ChainedTransforms(Transforms):
__slots__ = '_items'
__cache__ = '_offsets'

@types.apply_annotations
def __init__(self, items: types.tuple[stricttransforms]):
def __init__(self, items: Tuple[Transforms, ...]):
assert isinstance(items, tuple) and all(isinstance(item, Transforms) for item in items), f'items={items!r}'
if len(items) == 0:
raise ValueError('Empty chain.')
if len(set(item.todims for item in items)) != 1:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,32 @@ def setUp(self):
super().setUp(trans=transform.Point(types.arraydata([1., 2., 3.])), linear=numpy.zeros((3, 0)), offset=[1., 2., 3.])


class swaps(TestCase):

def setUp(self):
self.chain = transform.SimplexChild(3, 2), transform.SimplexEdge(3, 0), transform.SimplexChild(2, 1), transform.SimplexChild(2, 1), transform.SimplexEdge(2, 0)

def assertMidpoint(self, chain):
midpoint = transform.apply(self.chain, numpy.array([.5]))
self.assertEqual(midpoint.tolist(), [0, 0.9375, 0.0625])

def test_canonical(self):
canonical = transform.SimplexEdge(3, 0), transform.SimplexEdge(2, 0), transform.SimplexChild(1, 0), transform.SimplexChild(1, 0), transform.SimplexChild(1, 0)
self.assertEqual(transform.canonical(self.chain), canonical)
self.assertMidpoint(canonical)
self.assertTrue(transform.iscanonical(canonical))

def test_promote(self):
promote = transform.SimplexEdge(3, 0), transform.SimplexChild(2, 1), transform.SimplexChild(2, 1), transform.SimplexChild(2, 1), transform.SimplexEdge(2, 0)
self.assertEqual(transform.promote(self.chain, 2), promote)
self.assertMidpoint(promote)
self.assertFalse(transform.iscanonical(promote))

def test_uppermost(self):
uppermost = transform.SimplexChild(3, 2), transform.SimplexChild(3, 2), transform.SimplexChild(3, 2), transform.SimplexEdge(3, 0), transform.SimplexEdge(2, 0)
self.assertEqual(transform.uppermost(self.chain), uppermost)
self.assertMidpoint(uppermost)
self.assertFalse(transform.iscanonical(uppermost))


del TestTransform, TestInvertible, TestUpdim

0 comments on commit 8128a17

Please sign in to comment.