Skip to content

Commit

Permalink
Merge pull request #675 from MartinEssink/refine-memory-fix
Browse files Browse the repository at this point in the history
Fix memory overflow on refinement
  • Loading branch information
gertjanvanzwieten committed Apr 29, 2022
2 parents 520eed8 + f600436 commit da57b58
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
13 changes: 5 additions & 8 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3405,9 +3405,9 @@ def dotarg(__argname: str, *arrays: IntoArray, shape: Tuple[int, ...] = (), dtyp
# BASES


def _int_or_vec(f, self, arg, argname, nargs, nvals):
def _int_or_vec(f, arg, argname, nargs, nvals):
if isinstance(arg, numbers.Integral):
return f(self, int(numeric.normdim(nargs, arg)))
return f(int(numeric.normdim(nargs, arg)))
if numeric.isboolarray(arg):
if arg.shape != (nargs,):
raise IndexError('{} has invalid shape'.format(argname))
Expand All @@ -3420,24 +3420,21 @@ def _int_or_vec(f, self, arg, argname, nargs, nvals):
arg = numpy.unique(arg)
if arg[0] < 0 or arg[-1] >= nargs:
raise IndexError('{} out of bounds'.format(argname))
mask = numpy.zeros(nvals, dtype=bool)
for d in arg:
mask[numpy.asarray(f(self, d))] = True
return mask.nonzero()[0]
return functools.reduce(numpy.union1d, map(f, arg))
raise IndexError('invalid {}'.format(argname))


def _int_or_vec_dof(f):
@functools.wraps(f)
def wrapped(self, dof: Union[numbers.Integral, numpy.ndarray]) -> numpy.ndarray:
return _int_or_vec(f, self, arg=dof, argname='dof', nargs=self.ndofs, nvals=self.nelems)
return _int_or_vec(f.__get__(self), arg=dof, argname='dof', nargs=self.ndofs, nvals=self.nelems)
return wrapped


def _int_or_vec_ielem(f):
@functools.wraps(f)
def wrapped(self, ielem: Union[numbers.Integral, numpy.ndarray]) -> numpy.ndarray:
return _int_or_vec(f, self, arg=ielem, argname='ielem', nargs=self.nelems, nvals=self.ndofs)
return _int_or_vec(f.__get__(self), arg=ielem, argname='ielem', nargs=self.nelems, nvals=self.ndofs)
return wrapped


Expand Down
2 changes: 1 addition & 1 deletion tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ def test_dofs_array(self):
indices, = numpy.where(mask)
for value in mask, indices:
with self.subTest(tuple(value)):
self.assertEqual(self.basis.get_dofs(value).tolist(), list(sorted(set(itertools.chain.from_iterable(self.checkdofs[i] for i in indices)))))
self.assertEqual(sorted(self.basis.get_dofs(value)), sorted(set(itertools.chain.from_iterable(self.checkdofs[i] for i in indices))))

def test_dofs_intarray_outofbounds(self):
for i in [-1, self.checknelems]:
Expand Down

0 comments on commit da57b58

Please sign in to comment.