Skip to content

Commit

Permalink
implement ArrayFunc getitem on dof axis
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Feb 6, 2015
1 parent 756c19e commit e99f1f5
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 21 deletions.
70 changes: 65 additions & 5 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ def __init__( self, args, length ):
self.ndim = 1
Evaluable.__init__( self, args=args )

def __str__( self ):
return '%s<%s>' % ( self.__class__.__name__, ','.join( str(n) for n in self.shape ) )

class DofMap( IndexVector ):
'dof axis'

Expand Down Expand Up @@ -407,7 +410,10 @@ def __getitem__( self, item ):
arr = self
while myitem:
it = myitem.pop(0)
if numeric.isint(it): # retrieve one item from axis
if isinstance(it,numpy.ndarray): # numpy first because of 'equals issues'
arr = take( arr, it, n )
n += 1
elif numeric.isint(it): # retrieve one item from axis
arr = get( arr, n, it )
elif it == _: # insert a singleton axis
arr = insert( arr, n )
Expand All @@ -422,7 +428,7 @@ def __getitem__( self, item ):
elif isinstance(it,slice) and it.step in (1,None) and it.stop == ( it.start or 0 ) + 1: # special case: unit length slice
arr = insert( get( arr, n, it.start or 0 ), n )
n += 1
elif isinstance(it,(slice,list,tuple,numpy.ndarray)): # modify axis (shorten, extend or renumber one axis)
elif isinstance(it,(slice,list,tuple)): # modify axis (shorten, extend or renumber one axis)
arr = take( arr, it, n )
n += 1
else:
Expand Down Expand Up @@ -819,6 +825,38 @@ def _localgradient( self, ndims ):
return grad if ndims == self.ndims \
else dot( grad[...,_], Transform( self.ndims, ndims, self.side ), axes=-2 )

def _take( self, indices, axis ):
if axis != 0:
return
assert isinstance( indices, DofMap )
assert indices.shape[0] == self.shape[0]
stdmap = {}
for trans, stdkeep in self.stdmap.items():
ind = indices.dofmap[trans]
assert all( numpy.diff( ind ) > 0 )
nshapes = sum( 0 if not std else std.nshapes if keep is None else sum(keep) for std, keep in stdkeep )
where = numpy.zeros( nshapes, dtype=bool )
where[ind] = True
newstdkeep = []
for std, keep in stdkeep:
if std:
if keep is None:
n = std.nshapes
keep = where[:n]
else:
n = sum(keep)
keep = keep.copy()
keep[keep] = where[:n]
if not keep.any():
std = None
elif keep.all():
keep = None
where = where[n:]
newstdkeep.append(( std, keep ))
assert not where.size
stdmap[trans] = newstdkeep
return Function( self.ndims, stdmap, self.igrad, indices.target, side=self.side )

class Choose( ArrayFunc ):
'piecewise function'

Expand Down Expand Up @@ -1975,9 +2013,27 @@ def _takediag( self ):

def _take( self, index, axis ):
if axis == self.axis:
assert index == self.dofmap
return self.func
return inflate( take( self.func, index, axis ), self.dofmap, self.axis )
if index == self.dofmap:
return self.func
assert numeric.isintarray(index) and index.ndim == 1
if self.dofmap.offset != 0:
raise NotImplementedError
reverse_index = numpy.empty( self.shape[axis], dtype=int )
reverse_index[:] = -1
reverse_index[index] = numpy.arange( len(index) )
globaldofs = {}
localdofs = {}
for trans, dofs in self.dofmap.dofmap.items():
newdofs = reverse_index[dofs]
keep = newdofs != -1
globaldofs[trans] = newdofs[keep]
localdofs[trans], = numpy.where(keep)
strlen = '~%d'%len(index)
dofmap = DofMap( globaldofs, axis=strlen, target=len(index), side=self.dofmap.side )
index = DofMap( localdofs, axis=self.dofmap.shape[0], target=strlen, side=self.dofmap.side )
else:
dofmap = self.dofmap
return inflate( take( self.func, index, axis ), dofmap, self.axis )

def _diagonalize( self ):
assert self.axis < self.ndim-1
Expand Down Expand Up @@ -3055,6 +3111,10 @@ def take( arg, index, axis ):
else:
index = numpy.asarray( index )
assert numpy.all( index >= 0 )

if numeric.isboolarray(index) and index.ndim == 1 and len(index) == arg.shape[axis]:
index, = numpy.where( index )

assert numeric.isintarray(index) and index.ndim == 1 and len(index) > 0

if len(index) == arg.shape[axis] and all( index == numpy.arange(arg.shape[axis]) ):
Expand Down
4 changes: 2 additions & 2 deletions nutils/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,10 @@ def eig( A ):
return L, V

def isbool( a ):
return isboolarray( a ) and a.ndim == 0 or numpy.issubdtype( type(a), numpy.bool )
return isboolarray( a ) and a.ndim == 0 or type(a) == bool

def isboolarray( a ):
return isinstance( a, numpy.ndarray ) and numpy.issubdtype( a.dtype, numpy.bool )
return isinstance( a, numpy.ndarray ) and a.dtype == bool

def isint( a ):
return isintarray( a ) and a.ndim == 0 or numpy.issubdtype( type(a), numpy.integer )
Expand Down
19 changes: 5 additions & 14 deletions nutils/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,21 +1283,12 @@ def __getitem__( self, key ):
return itemtopo

def prune_basis( self, basis ):
((dofaxis,),func), = function.blocks( basis )

used = numpy.zeros( len(basis), dtype=bool )
for elem in self:
used[ dofaxis.dofmap[elem.transform] ] = True

ndofs = used.sum()
renumber = numpy.empty( len(basis), dtype=int )
renumber[:] = ndofs # invalid index
renumber[used] = numpy.arange(ndofs)

nmap = { elem.transform: renumber[ dofaxis.dofmap[elem.transform] ] for elem in self }
fmap = { elem.transform: func.stdmap[elem.transform] for elem in self }

return function.function( fmap=fmap, nmap=nmap, ndofs=ndofs, ndims=self.ndims )
for axes, func in function.blocks( basis ):
dofmap = axes[0].dofmap
for elem in self:
used[ dofmap[elem.transform] ] = True
return basis[used]

@log.title
def basis( self, name, *args, **kwargs ):
Expand Down

0 comments on commit e99f1f5

Please sign in to comment.