/
evaluable.py
4783 lines (3755 loc) · 170 KB
/
evaluable.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
The function module defines the :class:`Evaluable` class and derived objects,
commonly referred to as nutils functions. They represent mappings from a
:mod:`nutils.topology` onto Python space. The notable class of :class:`Array`
objects map onto the space of Numpy arrays of predefined dimension and shape.
Most functions used in nutils applications are of this latter type, including the
geometry and function bases for analysis.
Nutils functions are essentially postponed python functions, stored in a tree
structure of input/output dependencies. Many :class:`Array` objects have
directly recognizable numpy equivalents, such as :class:`Sin` or
:class:`Inverse`. By not evaluating directly but merely stacking operations,
complex operations can be defined prior to entering a quadrature loop, allowing
for a higher level style programming. It also allows for automatic
differentiation and code optimization.
It is important to realize that nutils functions do not map from a physical
xy-domain but from a topology, where a point is characterized by the combination
of an element and its local coordinate. This is a natural fit for typical finite
element operations such as quadrature. Evaluation from physical coordinates is
possible only via inverting of the geometry function, which is a fundamentally
expensive and currently unsupported operation.
"""
import typing
if typing.TYPE_CHECKING:
from typing_extensions import Protocol
else:
Protocol = object
from . import debug_flags, util, types, numeric, cache, warnings, parallel, sparse
from ._graph import Node, RegularNode, DuplicatedLeafNode, InvisibleNode, Subgraph
import numpy
import sys
import itertools
import functools
import operator
import inspect
import numbers
import builtins
import re
import types as builtin_types
import abc
import collections.abc
import math
import treelog as log
import weakref
import time
import contextlib
import subprocess
import os
graphviz = os.environ.get('NUTILS_GRAPHVIZ')
def isevaluable(arg):
    '''Return ``True`` if ``arg`` is an :class:`Evaluable` instance.'''
    # def instead of a named lambda (PEP 8 E731) so the function has a proper
    # __name__ and docstring; behavior is unchanged.
    return isinstance(arg, Evaluable)
def strictevaluable(value):
    '''Return ``value`` unchanged, raising :class:`ValueError` if it is not an
    :class:`Evaluable`.'''
    if isinstance(value, Evaluable):
        return value
    raise ValueError('expected an object of type {!r} but got {!r} with type {!r}'.format(Evaluable.__qualname__, value, type(value).__qualname__))
def simplified(value):
    '''Coerce ``value`` to a strict :class:`Evaluable` and return its
    simplified form.'''
    checked = strictevaluable(value)
    return checked.simplified
_type_order = bool, int, float, complex
asdtype = lambda arg: arg if any(arg is dtype for dtype in _type_order) else {'f': float, 'i': int, 'b': bool, 'c': complex}[numpy.dtype(arg).kind]
def asarray(arg):
    '''Coerce ``arg`` into an evaluable :class:`Array`.

    Objects exposing ``as_evaluable_array`` lower themselves; sequences
    containing arrays are stacked along a new leading axis; anything else is
    wrapped as a :class:`Constant`.
    '''
    if hasattr(type(arg), 'as_evaluable_array'):
        return arg.as_evaluable_array
    return stack(arg, axis=0) if _containsarray(arg) else Constant(arg)

asarrays = types.tuple[asarray]
def asindex(arg):
    '''Coerce ``arg`` to a scalar, non-negative integer :class:`Array`.

    Raises :class:`ValueError` for non-scalar or non-integer arguments, and
    for arguments whose lower bound is negative.
    '''
    index = asarray(arg)
    if index.ndim or index.dtype != int:
        raise ValueError('argument is not an index: {}'.format(index))
    if index._intbounds[0] < 0:
        raise ValueError('index must be non-negative')
    return index
@types.apply_annotations
def equalindex(n: asindex, m: asindex):
    '''Compare two array indices.

    Returns `True` if the two indices are certainly equal, `False` if they are
    certainly not equal, or `None` if equality cannot be determined at compile
    time.
    '''
    if n is m:
        return True
    n, m = n.simplified, m.simplified
    if n is m:
        return True
    if n.arguments != m.arguments:
        return False
    if n.isconstant:  # implies m.isconstant
        return int(n) == int(m)
    # fall through: equality cannot be decided at compile time -> None
asshape = types.tuple[asindex]

@types.apply_annotations
def equalshape(N: asshape, M: asshape):
    '''Compare two array shapes.

    Returns `True` if all indices are certainly equal, `False` if any indices
    are certainly not equal, or `None` if equality cannot be determined at
    compile time.
    '''
    if N == M:
        return True
    if len(N) != len(M):
        return False
    verdict = True
    for n, m in zip(N, M):
        eq = equalindex(n, m)
        if eq == False:
            return False  # short-circuit on certain inequality
        if eq == None:
            verdict = None  # at least one pair is undecidable
    return verdict
class ExpensiveEvaluationWarning(warnings.NutilsInefficiencyWarning):
    # Warning category for evaluations that are known to be expensive;
    # inherits behavior unchanged from NutilsInefficiencyWarning.
    pass
def replace(func=None, depthfirst=False, recursive=False, lru=4):
    '''decorator for deep object replacement

    Generates a deep replacement method for general objects based on a callable
    that is applied (recursively) on individual constructor arguments.

    Args
    ----
    func
        Callable which maps an object onto a new object, or `None` if no
        replacement is made. It must have one positional argument for the object,
        and may have any number of additional positional and/or keyword
        arguments.
    depthfirst : :class:`bool`
        If `True`, decompose each object as far as possible, then apply `func` to
        all arguments as the objects are reconstructed. Otherwise apply `func`
        directly on each new object that is encountered in the decomposition,
        proceeding only if the return value is `None`.
    recursive : :class:`bool`
        If `True`, repeat replacement for any object returned by `func` until it
        returns `None`. Otherwise perform a single, non-recursive sweep.
    lru : :class:`int`
        Maximum size of the least-recently-used cache. A persistent weak-key
        dictionary is maintained for every unique set of function arguments. When
        the size of `lru` is reached, the least recently used cache is dropped.

    Returns
    -------
    :any:`callable`
        The method that searches the object to perform the replacements.
    '''
    # Support use both as a bare decorator and with keyword arguments.
    if func is None:
        return functools.partial(replace, depthfirst=depthfirst, recursive=recursive, lru=lru)
    signature = inspect.signature(func)
    arguments = []  # list of past function arguments, least recently used last
    caches = []  # list of weak-key dictionaries matching arguments (above)
    # Sentinel tokens interleaved with real objects on the work stack:
    remember = object()  # token to signal that rstack[-1] can be cached as the replacement of fstack[-1]
    recreate = object()  # token to signal that all arguments for object recreation are ready on rstack
    pending = object()  # token to hold the place of a cachable object pending creation
    identity = object()  # token to hold the place of the cache value in case it matches key, to avoid circular references

    @functools.wraps(func)
    def wrapped(target, *funcargs, **funckwargs):
        # retrieve or create a weak-key dictionary keyed on the extra function
        # arguments, so repeated sweeps with identical arguments share a cache
        bound = signature.bind(None, *funcargs, **funckwargs)
        bound.apply_defaults()
        try:
            index = arguments.index(bound.arguments)  # by using index, arguments need not be hashable
        except ValueError:
            index = -1
            cache = weakref.WeakKeyDictionary()
        else:
            cache = caches[index]
        if index != 0:  # function arguments are not the most recent (possibly new)
            if index > 0 or len(arguments) >= lru:
                caches.pop(index)  # pop matching (or oldest) item
                arguments.pop(index)
            caches.insert(0, cache)  # insert popped (or new) item to front
            arguments.insert(0, bound.arguments)
        fstack = [target]  # stack of unprocessed objects and command tokens
        rstack = []  # stack of processed objects
        _stack = fstack if recursive else rstack
        try:
            while fstack:
                obj = fstack.pop()
                if obj is recreate:
                    # all constructor arguments are ready on rstack; rebuild
                    args = [rstack.pop() for obj in range(fstack.pop())]
                    f = fstack.pop()
                    r = f(*args)
                    if depthfirst:
                        newr = func(r, *funcargs, **funckwargs)
                        if newr is not None:
                            _stack.append(newr)
                            continue
                    rstack.append(r)
                    continue
                if obj is remember:
                    # the object below the token is now fully processed; cache it
                    obj = fstack.pop()
                    cache[obj] = rstack[-1] if rstack[-1] is not obj else identity
                    continue
                if isinstance(obj, (tuple, list, dict, set, frozenset)):
                    if not obj:
                        rstack.append(obj)  # shortcut to avoid recreation of empty container
                    else:
                        fstack.append(lambda *x, T=type(obj): T(x))
                        fstack.append(len(obj))
                        fstack.append(recreate)
                        fstack.extend(obj if not isinstance(obj, dict) else obj.items())
                    continue
                try:
                    r = cache[obj]
                except KeyError:  # object can be weakly cached, but isn't
                    cache[obj] = pending
                    fstack.append(obj)
                    fstack.append(remember)
                except TypeError:  # object cannot be referenced or is not hashable
                    pass
                else:  # object is in cache
                    if r is pending:
                        # hitting a pending entry means obj (indirectly) contains itself
                        pending_objs = [k for k, v in cache.items() if v is pending]
                        index = pending_objs.index(obj)
                        raise Exception('{}@replace caught in a circular dependence\n'.format(func.__name__) + Tuple(pending_objs[index:]).asciitree().split('\n', 1)[1])
                    rstack.append(r if r is not identity else obj)
                    continue
                if not depthfirst:
                    newr = func(obj, *funcargs, **funckwargs)
                    if newr is not None:
                        _stack.append(newr)
                        continue
                try:
                    f, args = obj.__reduce__()
                except:  # obj cannot be reduced into a constructor and its arguments
                    rstack.append(obj)
                else:
                    fstack.append(f)
                    fstack.append(len(args))
                    fstack.append(recreate)
                    fstack.extend(args)
            assert len(rstack) == 1
        finally:
            # on error, drop any pending cache entries so the cache stays consistent
            while fstack:
                if fstack.pop() is remember:
                    assert cache.pop(fstack.pop()) is pending
        return rstack[0]
    return wrapped
class Evaluable(types.Singleton):
    '''Base class for all evaluables.

    An evaluable wraps an ``evalf`` implementation together with the evaluable
    arguments it operates on, forming a lazily evaluated dependency tree that
    is resolved by :meth:`eval`.
    '''

    __slots__ = '__args'
    # properties memoized per instance by the types machinery
    __cache__ = 'dependencies', 'arguments', 'ordereddeps', 'dependencytree', 'optimized_for_numpy', '_loop_concatenate_deps'

    @types.apply_annotations
    def __init__(self, args: types.tuple[strictevaluable]):
        super().__init__()
        self.__args = args

    def evalf(self, *args):
        # Subclasses compute their value here from the already-evaluated arguments.
        raise NotImplementedError('Evaluable derivatives should implement the evalf method')

    def evalf_withtimes(self, times, *args):
        # Same as evalf but records wall time per evaluable in `times`.
        with times[self]:
            return self.evalf(*args)

    @property
    def dependencies(self):
        '''collection of all function arguments'''
        # maps each transitive dependency to the size of its own dependency set
        deps = {}
        for func in self.__args:
            funcdeps = func.dependencies
            deps.update(funcdeps)
            deps[func] = len(funcdeps)
        return types.frozendict(deps)

    @property
    def arguments(self):
        'a frozenset of all arguments of this evaluable'
        return frozenset().union(*(child.arguments for child in self.__args))

    @property
    def isconstant(self):
        # constant iff evaluation does not depend on the EVALARGS sentinel
        return EVALARGS not in self.dependencies

    @property
    def ordereddeps(self):
        '''collection of all function arguments such that the arguments to
        dependencies[i] can be found in dependencies[:i]'''
        # sorting by dependency count yields a valid topological order
        deps = self.dependencies.copy()
        deps.pop(EVALARGS, None)
        return tuple([EVALARGS] + sorted(deps, key=deps.__getitem__))

    @property
    def dependencytree(self):
        '''lookup table of function arguments into ordereddeps, such that
        ordereddeps[i].__args[j] == ordereddeps[dependencytree[i][j]], and
        self.__args[j] == ordereddeps[dependencytree[-1][j]]'''
        args = self.ordereddeps
        return tuple(tuple(map(args.index, func.__args)) for func in args+(self,))

    @property
    def serialized(self):
        # (evaluable, argument indices) pairs in evaluation order; index 0 is EVALARGS
        return zip(self.ordereddeps[1:]+(self,), self.dependencytree[1:])

    def _node(self, cache, subgraph, times):
        # Build (and memoize in `cache`) a graph node for visualisation.
        if self in cache:
            return cache[self]
        args = tuple(arg._node(cache, subgraph, times) for arg in self.__args)
        label = '\n'.join(filter(None, (type(self).__name__, self._node_details)))
        cache[self] = node = RegularNode(label, args, {}, (type(self).__name__, times[self]), subgraph)
        return node

    @property
    def _node_details(self):
        # extra per-node text for graph labels; subclasses may override
        return ''

    def asciitree(self, richoutput=False):
        'string representation'
        return self._node({}, None, collections.defaultdict(_Stats)).generate_asciitree(richoutput)

    def __str__(self):
        return self.__class__.__name__

    def eval(self, **evalargs):
        '''Evaluate function on a specified element, point set.'''
        values = [evalargs]
        try:
            values.extend(op.evalf(*[values[i] for i in indices]) for op, indices in self.serialized)
        except KeyboardInterrupt:
            raise
        except Exception as e:
            raise EvaluationError(self, values) from e
        else:
            return values[-1]

    def eval_withtimes(self, times, **evalargs):
        '''Evaluate function on a specified element, point set, while measuring
        the time of each step.'''
        values = [evalargs]
        try:
            values.extend(op.evalf_withtimes(times, *[values[i] for i in indices]) for op, indices in self.serialized)
        except KeyboardInterrupt:
            raise
        except Exception as e:
            raise EvaluationError(self, values) from e
        else:
            return values[-1]

    @contextlib.contextmanager
    def session(self, graphviz):
        # Yield an eval callable; when `graphviz` is set, additionally collect
        # per-node timings and export an annotated graph on exit.
        if graphviz is None:
            yield self.eval
            return
        stats = collections.defaultdict(_Stats)
        def eval(**args):
            return self.eval_withtimes(stats, **args)
        with log.context('eval'):
            yield eval
        node = self._node({}, None, stats)
        maxtime = builtins.max(n.metadata[1].time for n in node.walk(set()))
        tottime = builtins.sum(n.metadata[1].time for n in node.walk(set()))
        aggstats = tuple((key, builtins.sum(v.time for v in values), builtins.sum(v.ncalls for v in values)) for key, values in util.gather(n.metadata for n in node.walk(set())))
        # hue encodes relative time spent in each node
        fill_color = (lambda node: '0,{:.2f},1'.format(node.metadata[1].time/maxtime)) if maxtime else None
        node.export_graphviz(fill_color=fill_color, dot_path=graphviz)
        log.info('total time: {:.0f}ms\n'.format(tottime/1e6) + '\n'.join('{:4.0f} {} ({} calls, avg {:.3f} per call)'.format(t / 1e6, k, n, t / (1e6*n))
                 for k, t, n in sorted(aggstats, reverse=True, key=lambda item: item[1]) if n))

    def _stack(self, values):
        # Human-readable dump of the (partial) evaluation stack for error reports.
        lines = [' %0 = EVALARGS']
        for (op, indices), v in zip(self.serialized, values):
            lines[-1] += ' --> ' + type(v).__name__
            if numeric.isarray(v):
                lines[-1] += '({})'.format(','.join(map(str, v.shape)))
            try:
                # try to recover the evalf parameter names for nicer output
                code = op.evalf.__code__
                offset = 1 if getattr(op.evalf, '__self__', None) is not None else 0
                names = code.co_varnames[offset:code.co_argcount]
                names += tuple('{}[{}]'.format(code.co_varnames[code.co_argcount], n) for n in range(len(indices) - len(names)))
                args = map(' {}=%{}'.format, names, indices)
            except:
                args = map(' %{}'.format, indices)
            lines.append(' %{} = {}:{}'.format(len(lines), op, ','.join(args)))
        return lines

    @property
    @replace(depthfirst=True, recursive=True)
    def simplified(obj):
        # NOTE: parameter is `obj`, not `self`: @replace turns this into a deep
        # rewrite that applies _simplified to every node of the tree.
        if isinstance(obj, Evaluable):
            retval = obj._simplified()
            if retval is not None and isinstance(obj, Array):
                assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval)
            return retval

    def _simplified(self):
        # hook for subclasses; returning None means "no simplification"
        return

    @property
    def optimized_for_numpy(self):
        retval = self._optimized_for_numpy1() or self
        return retval._combine_loop_concatenates(frozenset())

    @types.apply_annotations
    @replace(depthfirst=True, recursive=True)
    def _optimized_for_numpy1(obj: simplified.fget):
        # the annotation simplifies the tree before the optimization sweep
        if isinstance(obj, Evaluable):
            retval = obj._simplified() or obj._optimized_for_numpy()
            if retval is not None and isinstance(obj, Array):
                assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__)
            return retval

    def _optimized_for_numpy(self):
        # hook for subclasses; returning None means "no optimization"
        return

    @property
    def _loop_concatenate_deps(self):
        # ordered, duplicate-free collection of LoopConcatenate dependencies
        deps = []
        for arg in self.__args:
            deps += [dep for dep in arg._loop_concatenate_deps if dep not in deps]
        return tuple(deps)

    def _combine_loop_concatenates(self, outer_exclude):
        while True:
            exclude = set(outer_exclude)
            combine = {}
            # Collect all top-level `LoopConcatenate` instances in `combine` and all
            # their dependent `LoopConcatenate` instances in `exclude`.
            for lc in self._loop_concatenate_deps:
                lcs = combine.setdefault(lc.index, [])
                if lc not in lcs:
                    lcs.append(lc)
                    exclude.update(set(lc._loop_concatenate_deps) - {lc})
            # Combine top-level `LoopConcatenate` instances excluding those in
            # `exclude`.
            replacements = {}
            for index, lcs in combine.items():
                lcs = [lc for lc in lcs if lc not in exclude]
                if not lcs:
                    continue
                # We're extracting data from `LoopConcatenate` in favor of using
                # `loop_concatenate_combined(lcs, ...)` because the latter requires
                # reapplying simplifications that are already applied in the former.
                # For example, in `loop_concatenate_combined` the offsets (used by
                # start, stop and the concatenation length) are formed by
                # `loop_concatenate`-ing `func.shape[-1]`. If the shape is constant,
                # this can be simplified to a `Range`.
                data = Tuple((Tuple(lc.funcdata) for lc in lcs))
                # Combine `LoopConcatenate` instances in `data` excluding
                # `outer_exclude` and those that will be processed in a subsequent loop
                # (the remainder of `exclude`). The latter consists of loops that are
                # invariant w.r.t. the current loop `index`.
                data = data._combine_loop_concatenates(exclude)
                combined = LoopConcatenateCombined(data, index._name, index.length)
                for i, lc in enumerate(lcs):
                    intbounds = dict(zip(('_lower', '_upper'), lc._intbounds)) if lc.dtype == int else {}
                    replacements[lc] = ArrayFromTuple(combined, i, lc.shape, lc.dtype, **intbounds)
            if replacements:
                self = replace(lambda key: replacements.get(key) if isinstance(key, LoopConcatenate) else None, recursive=False, depthfirst=False)(self)
            else:
                return self
class EvaluationError(Exception):
    '''Raised when evaluation of a serialized evaluable fails.

    The message records how many evaluation steps completed and includes a
    dump of the evaluation stack of the failing evaluable.
    '''

    def __init__(self, f, values):
        header = 'evaluation failed in step {}/{}\n'.format(len(values), len(f.dependencies))
        super().__init__(header + '\n'.join(f._stack(values)))
class EVALARGS(Evaluable):
    # Sentinel evaluable representing the keyword arguments passed to
    # `Evaluable.eval`; it forms the root of every dependency tree.
    def __init__(self):
        super().__init__(args=())

    def _node(self, cache, subgraph, times):
        # rendered invisibly in graph output
        return InvisibleNode((type(self).__name__, _Stats()))

EVALARGS = EVALARGS()  # replace the class by its single instance
class EvaluableConstant(Evaluable):
    '''Evaluate to the given constant value.

    Parameters
    ----------
    value
        The return value of ``eval``.
    '''

    __slots__ = 'value'

    def __init__(self, value):
        self.value = value
        super().__init__(())

    def evalf(self):
        '''Return the wrapped constant value.'''
        return self.value

    @property
    def _node_details(self):
        # Single-line, length-capped repr of the value for graph labels.
        head, newline, _ = repr(self.value).partition('\n')
        if newline:
            head += '...'
        if len(head) > 20:
            head = head[:17] + '...'
        return head
class Tuple(Evaluable):
    '''Evaluable that evaluates to the tuple of its evaluated items.'''

    __slots__ = 'items'

    @types.apply_annotations
    def __init__(self, items: types.tuple[strictevaluable]):
        self.items = items
        super().__init__(items)

    def evalf(self, *items):
        # the evaluated arguments are exactly the tuple to return
        return items

    def __iter__(self):
        'iterate over the evaluable items'
        return iter(self.items)

    def __len__(self):
        'number of items'
        return len(self.items)

    def __getitem__(self, item):
        'get item (or slice of items)'
        return self.items[item]

    def __add__(self, other):
        'concatenate the items of ``other`` on the right'
        return Tuple(self.items + tuple(other))

    def __radd__(self, other):
        'concatenate the items of ``other`` on the left'
        return Tuple(tuple(other) + self.items)
class SparseArray(Evaluable):
    'sparse array'

    @types.apply_annotations
    def __init__(self, chunks: types.tuple[asarrays], shape: asarrays, dtype: asdtype):
        # `chunks` is a sequence of (*indices, values) COO chunks; see
        # Array._assparse for the layout — TODO confirm against callers.
        self._shape = shape
        self._dtype = dtype
        super().__init__(args=[Tuple(shape), *map(Tuple, chunks)])

    def evalf(self, shape, *chunks):
        # Assemble all chunks into a single structured array with 'index' and
        # 'value' fields, as laid out by sparse.dtype.
        length = builtins.sum(values.size for *indices, values in chunks)
        data = numpy.empty((length,), dtype=sparse.dtype(tuple(map(int, shape)), self._dtype))
        start = 0
        for *indices, values in chunks:
            stop = start + values.size
            # view the slice in the chunk's shape so indices broadcast in place
            d = data[start:stop].reshape(values.shape)
            d['value'] = values
            for idim, ii in enumerate(indices):
                d['index']['i'+str(idim)] = ii
            start = stop
        return data
# ARRAYFUNC
#
# The main evaluable. Closely mimics a numpy array.
def add(a, b):
    '''Return the evaluable elementwise sum of ``a`` and ``b``.'''
    left, right = _numpy_align(a, b)
    return Add([left, right])
def multiply(a, b):
    '''Return the evaluable elementwise product of ``a`` and ``b``.'''
    left, right = _numpy_align(a, b)
    return Multiply([left, right])
def sum(arg, axis=None):
    '''Sum array elements over a given axis.'''
    if axis is None:
        return Sum(arg)
    axes = (axis,) if numeric.isint(axis) else axis
    # move the summed axes to the end, then sum them off one by one
    result = Transpose.to_end(arg, *axes)
    for _ in axes:
        result = Sum(result)
    return result
def product(arg, axis):
    '''Multiply array elements along ``axis``.'''
    moved = Transpose.to_end(arg, axis)
    return Product(moved)
def power(arg, n):
    '''Raise ``arg`` to the power ``n`` elementwise.'''
    base, exponent = _numpy_align(arg, n)
    return Power(base, exponent)
def dot(a, b, axes):
    '''
    Contract ``a`` and ``b`` along ``axes``.
    '''
    return sum(multiply(a, b), axes)
def conjugate(arg):
    '''Return the complex conjugate of ``arg``; arrays of non-complex dtype
    are returned unchanged.'''
    arg = asarray(arg)
    if arg.dtype == complex:
        return Conjugate(arg)
    else:
        return arg
# NOTE: a stray no-op expression statement `conjugate` followed the function
# definition; it had no effect and has been removed.
def real(arg):
    '''Return the real part of ``arg``; arrays of non-complex dtype are
    returned unchanged.'''
    arg = asarray(arg)
    return Real(arg) if arg.dtype == complex else arg
def imag(arg):
    '''Return the imaginary part of ``arg``; for non-complex dtypes this is
    identically zero.'''
    arg = asarray(arg)
    return Imag(arg) if arg.dtype == complex else zeros_like(arg)
def transpose(arg, trans=None):
    '''Reorder the axes of ``arg``; ``trans=None`` reverses all axes.'''
    arg = asarray(arg)
    if trans is None:
        order = range(arg.ndim-1, -1, -1)
    else:
        order = _normdims(arg.ndim, trans)
    # the order must be a permutation of all axes
    assert sorted(order) == list(range(arg.ndim))
    return Transpose(arg, order)
def swapaxes(arg, axis1, axis2):
    '''Interchange two axes of ``arg``.'''
    arg = asarray(arg)
    order = numpy.arange(arg.ndim)
    order[axis1], order[axis2] = order[axis2], order[axis1]
    return transpose(arg, order)
def align(arg, where, shape):
    '''Align array to target shape.

    The align operation can be considered the opposite of transpose: instead of
    specifying for each axis of the return value the original position in the
    argument, align specifies for each axis of the argument the new position in
    the return value. In addition, the return value may be of higher dimension,
    with new axes being inserted according to the ``shape`` argument.

    Args
    ----
    arg : :class:`Array`
        Original array.
    where : :class:`tuple` of integers
        New axis positions.
    shape : :class:`tuple`
        Shape of the aligned array.

    Returns
    -------
    :class:`Array`
        The aligned array.
    '''
    axes = list(where)
    # append an inserted axis for every target position not covered by `where`
    for i, length in enumerate(shape):
        if i not in axes:
            arg = InsertAxis(arg, length)
            axes.append(i)
    # permute only if the axes are not already in target order
    if axes != list(range(len(shape))):
        arg = Transpose(arg, numpy.argsort(axes))
    assert equalshape(arg.shape, shape)
    return arg
def unalign(*args):
    '''Remove (joint) inserted axes.

    Given one or more equally shaped array arguments, return the shortest common
    axis vector along with function arguments such that the original arrays can
    be recovered by :func:`align`.
    '''
    assert args
    if len(args) == 1:
        return args[0]._unaligned
    if any(arg.ndim != args[0].ndim for arg in args[1:]):
        raise ValueError('varying dimensions in unalign')
    # union of the axes that are genuinely present (not inserted) in any argument
    nonins = functools.reduce(operator.or_, [set(arg._unaligned[1]) for arg in args])
    if len(nonins) == args[0].ndim:
        # no axis is jointly inserted: nothing to remove
        return (*args, tuple(range(args[0].ndim)))
    ret = []
    for arg in args:
        unaligned, where = arg._unaligned
        # re-insert axes that other arguments have but this one removed, so all
        # results end up with the same axis set
        for i in sorted(nonins - set(where)):
            unaligned = InsertAxis(unaligned, args[0].shape[i])
            where += i,
        if not ret:  # first argument defines the common axis order
            commonwhere = where
        elif where != commonwhere:
            unaligned = Transpose(unaligned, map(where.index, commonwhere))
        ret.append(unaligned)
    return (*ret, commonwhere)
# ARRAYS
_ArrayMeta = type(Evaluable)

if debug_flags.sparse:
    # Debug mode: wrap every `_assparse` property so its return value is
    # validated against the sparse-chunk contract documented on Array._assparse.
    def _chunked_assparse_checker(orig):
        assert isinstance(orig, property)
        @property
        def _assparse(self):
            chunks = orig.fget(self)
            assert isinstance(chunks, tuple)
            assert all(isinstance(chunk, tuple) for chunk in chunks)
            assert all(all(isinstance(item, Array) for item in chunk) for chunk in chunks)
            if self.ndim:
                for *indices, values in chunks:
                    assert len(indices) == self.ndim
                    assert all(idx.dtype == int for idx in indices)
                    assert all(equalshape(idx.shape, values.shape) for idx in indices)
            elif chunks:
                # 0d arrays may have at most a single scalar chunk
                assert len(chunks) == 1
                chunk, = chunks
                assert len(chunk) == 1
                values, = chunk
                assert values.shape == ()
            return chunks
        return _assparse

    class _ArrayMeta(_ArrayMeta):
        # metaclass hook that installs the checker on class creation
        def __new__(mcls, name, bases, namespace):
            if '_assparse' in namespace:
                namespace['_assparse'] = _chunked_assparse_checker(namespace['_assparse'])
            return super().__new__(mcls, name, bases, namespace)
if debug_flags.evalf:
    # Debug mode: wrap every `evalf` so its result is validated against the
    # declared dtype, ndim and (known) shape of the evaluable.
    class _evalf_checker:
        def __init__(self, orig):
            # support both plain functions and descriptors (e.g. staticmethod)
            self.evalf_obj = getattr(orig, '__get__', lambda *args: orig)
        def __get__(self, instance, owner):
            evalf = self.evalf_obj(instance, owner)
            @functools.wraps(evalf)
            def evalf_with_check(*args, **kwargs):
                res = evalf(*args, **kwargs)
                assert not hasattr(instance, 'dtype') or asdtype(res.dtype) == instance.dtype, ((instance.dtype, res.dtype), instance, res)
                assert not hasattr(instance, 'ndim') or res.ndim == instance.ndim
                # only constant (int) axis lengths can be checked here
                assert not hasattr(instance, 'shape') or all(m == n for m, n in zip(res.shape, instance.shape) if isinstance(n, int)), 'shape mismatch'
                return res
            return evalf_with_check

    class _ArrayMeta(_ArrayMeta):
        # metaclass hook that installs the checker on class creation
        def __new__(mcls, name, bases, namespace):
            if 'evalf' in namespace:
                namespace['evalf'] = _evalf_checker(namespace['evalf'])
            return super().__new__(mcls, name, bases, namespace)
class AsEvaluableArray(Protocol):
    'Protocol for conversion into an :class:`Array`.'

    # Structural (duck-typed) protocol: any object with this property is
    # accepted by `asarray`.
    @property
    def as_evaluable_array(self) -> 'Array':
        'Lower this object to a :class:`nutils.evaluable.Array`.'
class Array(Evaluable, metaclass=_ArrayMeta):
'''
Base class for array valued functions.
Attributes
----------
shape : :class:`tuple` of :class:`int`\\s
The shape of this array function.
ndim : :class:`int`
The number of dimensions of this array array function. Equal to
``len(shape)``.
dtype : :class:`int`, :class:`float`
The dtype of the array elements.
'''
__slots__ = 'shape', 'dtype', '__index'
__cache__ = 'assparse', '_assparse', '_intbounds'
__array_priority__ = 1. # http://stackoverflow.com/questions/7042496/numpy-coercion-problem-for-left-sided-binary-operator/7057530#7057530
@types.apply_annotations
def __init__(self, args: types.tuple[strictevaluable], shape: asshape, dtype: asdtype):
self.shape = shape
self.dtype = dtype
super().__init__(args=args)
@property
def ndim(self):
return len(self.shape)
def __getitem__(self, item):
    # Numpy-style indexing limited to ints, slices and index arrays; a single
    # ellipsis expands to the appropriate number of full slices.
    if not isinstance(item, tuple):
        item = item,
    if ... in item:
        iell = item.index(...)
        if ... in item[iell+1:]:
            raise IndexError('an index can have only a single ellipsis')
        # replace ellipsis by the appropriate number of slice(None)
        item = item[:iell] + (slice(None),)*(self.ndim-len(item)+1) + item[iell+1:]
    if len(item) > self.ndim:
        raise IndexError('too many indices for array')
    # apply per-axis operations back to front so axis numbers stay valid
    array = self
    for axis, it in reversed(tuple(enumerate(item))):
        array = get(array, axis, item=it) if numeric.isint(it) \
            else _takeslice(array, it, axis) if isinstance(it, slice) \
            else take(array, it, axis)
    return array
def __bool__(self):
return True
def __len__(self):
if self.ndim == 0:
raise TypeError('len() of unsized object')
return self.shape[0]
def __index__(self):
    # Cache the int conversion; only 0-d constant int/bool arrays qualify.
    try:
        index = self.__index
    except AttributeError:
        if self.ndim or self.dtype not in (int, bool) or not self.isconstant:
            raise TypeError('cannot convert {!r} to int'.format(self))
        index = self.__index = int(self.simplified.eval())
    return index
size = property(lambda self: util.product(self.shape) if self.ndim else 1)
T = property(lambda self: transpose(self))
__add__ = __radd__ = add
__sub__ = lambda self, other: subtract(self, other)
__rsub__ = lambda self, other: subtract(other, self)
__mul__ = __rmul__ = multiply
__truediv__ = lambda self, other: divide(self, other)
__rtruediv__ = lambda self, other: divide(other, self)
__pos__ = lambda self: self
__neg__ = lambda self: negative(self)
__pow__ = power
__abs__ = lambda self: abs(self)
__mod__ = lambda self, other: mod(self, other)
__int__ = __index__
__str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, self._shape_str(form=str))
_shape_str = lambda self, form: '{}:{}'.format(self.dtype.__name__[0] if hasattr(self, 'dtype') else '?', ','.join(str(int(length)) if length.isconstant else '?' for length in self.shape) if hasattr(self, 'shape') else '?')
sum = sum
prod = product
dot = dot
swapaxes = swapaxes
transpose = transpose
choose = lambda self, choices: Choose(self, choices)
conjugate = conjugate
@property
def real(self):
return real(self)
@property
def imag(self):
return imag(self)
@property
def assparse(self):
'Convert to a :class:`SparseArray`.'
return SparseArray(self.simplified._assparse, self.shape, self.dtype)
@property
def _assparse(self):
# Convert to a sequence of sparse COO arrays. The returned data is a tuple
# of `(*indices, values)` tuples, where `values` is an `Array` with the
# same dtype as `self`, but this is not enforced yet, and each index in
# `indices` is an `Array` with dtype `int` and the exact same shape as
# `values`. The length of `indices` equals `self.ndim`. In addition, if
# `self` is 0d the length of `self._assparse` is at most one and the
# `values` array must be 0d as well.
#
# The sparse data can be reassembled after evaluation by
#
# dense = numpy.zeros(self.shape)
# for I0,...,Ik,V in self._assparse:
# for i0,...,ik,v in zip(I0.eval().ravel(),...,Ik.eval().ravel(),V.eval().ravel()):
# dense[i0,...,ik] = v
indices = [prependaxes(appendaxes(Range(length), self.shape[i+1:]), self.shape[:i]) for i, length in enumerate(self.shape)]
return (*indices, self),
def _node(self, cache, subgraph, times):
    # Like Evaluable._node, but the label additionally carries the shape
    # string and, for integer arrays, the known value bounds.
    if self in cache:
        return cache[self]
    args = tuple(arg._node(cache, subgraph, times) for arg in self._Evaluable__args)
    bounds = '[{},{}]'.format(*self._intbounds) if self.dtype == int else None
    label = '\n'.join(filter(None, (type(self).__name__, self._node_details, self._shape_str(form=repr), bounds)))
    cache[self] = node = RegularNode(label, args, {}, (type(self).__name__, times[self]), subgraph)
    return node
# simplifications
_multiply = lambda self, other: None
_transpose = lambda self, axes: None
_insertaxis = lambda self, axis, length: None
_power = lambda self, n: None
_add = lambda self, other: None
_sum = lambda self, axis: None
_take = lambda self, index, axis: None
_rtake = lambda self, index, axis: None
_determinant = lambda self, axis1, axis2: None
_inverse = lambda self, axis1, axis2: None
_takediag = lambda self, axis1, axis2: None
_diagonalize = lambda self, axis: None
_product = lambda self: None
_sign = lambda self: None
_eig = lambda self, symmetric: None
_inflate = lambda self, dofmap, length, axis: None
_rinflate = lambda self, func, length, axis: None
_unravel = lambda self, axis, shape: None
_ravel = lambda self, axis: None
_loopsum = lambda self, loop_index: None # NOTE: type of `loop_index` is `_LoopIndex`
_real = lambda self: None
_imag = lambda self: None
_conjugate = lambda self: None
@property
def _unaligned(self):
return self, tuple(range(self.ndim))
_diagonals = ()
_inflations = ()
def _derivative(self, var, seen):
    # Arrays of bool/int dtype, and arrays that do not depend on `var`, have
    # identically zero derivative; any other case must be overridden.
    if self.dtype in (bool, int) or var not in self.dependencies:
        return Zeros(self.shape + var.shape, dtype=self.dtype)
    raise NotImplementedError('derivative not defined for {}'.format(self.__class__.__name__))
@property
def as_evaluable_array(self):
'return self'