-
Notifications
You must be signed in to change notification settings - Fork 48
/
function.py
3835 lines (2941 loc) · 136 KB
/
function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import typing
if typing.TYPE_CHECKING:
from typing_extensions import Protocol
else:
Protocol = object
from typing import Tuple, Union, Type, Callable, Sequence, Any, Optional, Iterator, Iterable, Dict, Mapping, List, FrozenSet
from . import evaluable, numeric, util, types, warnings, debug_flags, sparse
from .transform import EvaluableTransformChain
from .transformseq import Transforms
import builtins
import numpy
import functools
import operator
import numbers
# Values that may be passed wherever an array function is expected; they are
# converted with `Array.cast`.
IntoArray = Union['Array', numpy.ndarray, bool, int, float, complex]
# The shape of an array function: a sequence of integer axis lengths.
Shape = Sequence[int]
# One of the four supported element types.
DType = Type[Union[bool, int, float, complex]]
# The supported element types, in order. `Array.cast` rejects a value whose
# dtype appears later in this tuple than the requested dtype.
_dtypes = bool, int, float, complex
# The leading points axes added during lowering, as evaluable arrays.
_PointsShape = Tuple[evaluable.Array, ...]
# Maps a space name to a pair of evaluable transform chains.
_TransformChainsMap = Mapping[str, Tuple[EvaluableTransformChain, EvaluableTransformChain]]
# Maps a space name to the evaluable coordinates in that space.
_CoordinatesMap = Mapping[str, evaluable.Array]
class Lowerable(Protocol):
    'Protocol for lowering to :class:`nutils.evaluable.Array`.'

    # The names of the spaces this object is defined on.
    @property
    def spaces(self) -> FrozenSet[str]: ...

    # Argument name -> (shape, dtype) for every argument this object depends on.
    @property
    def arguments(self) -> Mapping[str, Tuple[Shape, DType]]: ...

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        '''Lower this object to a :class:`nutils.evaluable.Array`.

        Parameters
        ----------
        points_shape : :class:`tuple` of scalar, integer :class:`nutils.evaluable.Array`
            The shape of the leading points axes that are to be added to the
            lowered :class:`nutils.evaluable.Array`.
        transform_chains : mapping of :class:`str` to :class:`nutils.transform.EvaluableTransformChain` pairs
            For every space name, a pair of evaluable transform chains.
        coordinates : mapping of :class:`str` to :class:`nutils.evaluable.Array` objects
            The coordinates at which the function will be evaluated.

        Returns
        -------
        :class:`nutils.evaluable.Array`
            The lowered array.
        '''
_ArrayMeta = type

if debug_flags.lower:

    def _debug_lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        '''Run the class's own ``lower`` and sanity-check its result.

        The original implementation is stored by the debug metaclass under
        ``_ArrayMeta__debug_lower_orig``. The result must be an evaluable array
        with the leading points axes followed by this array's shape, and it
        must have this array's dtype.
        '''
        lowered = self._ArrayMeta__debug_lower_orig(points_shape, transform_chains, coordinates)
        assert isinstance(lowered, evaluable.Array)
        # Every coordinate array carries the points axes plus one trailing axis.
        assert all(evaluable.equalshape(coords.shape[:-1], points_shape) for coords in coordinates.values())
        assert all(space in transform_chains for space in coordinates)
        # `_WithoutPoints` lowers without the leading points axes.
        if type(self) == _WithoutPoints:
            npoints_axes = 0
        else:
            npoints_axes = len(points_shape)
        assert lowered.ndim == self.ndim + npoints_axes
        assert tuple(int(sh) for sh in lowered.shape[npoints_axes:]) == self.shape, 'shape mismatch'
        assert lowered.dtype == self.dtype, ('dtype mismatch', self.__class__)
        return lowered

    class _ArrayMeta(_ArrayMeta):
        '''Metaclass layer that routes ``lower`` through :func:`_debug_lower`.'''

        def __new__(mcls, name, bases, namespace):
            if 'lower' not in namespace:
                return super().__new__(mcls, name, bases, namespace)
            original_lower = namespace.pop('lower')
            namespace['_ArrayMeta__debug_lower_orig'] = original_lower
            namespace['lower'] = _debug_lower
            return super().__new__(mcls, name, bases, namespace)
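# The lower cache introduced below should stay below the debug wrapper added
# above. Otherwise the cached results are debugged again and again.
# NOTE(review): the cache metaclass below subclasses the debug metaclass, so its
# `__new__` replaces `lower` first and the debug metaclass then wraps the result;
# `_debug_lower` thus appears to end up outermost and to run on cache hits as
# well. Confirm that this wrapping order is the intended one.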
def _cache_lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
    '''Lower ``self``, reusing the previous result for identical arguments.

    Only the most recent ``(points_shape, transform_chains, coordinates)``
    triple and its result are remembered, in ``self._ArrayMeta__cached_lower``.
    On a cache miss, every space of ``self`` must be present in
    ``transform_chains``; otherwise :class:`ValueError` is raised. The class's
    own ``lower``, stored as ``_ArrayMeta__cache_lower_orig``, is then called.
    '''
    lookup_key = points_shape, transform_chains, coordinates
    previous_key, previous_result = getattr(self, '_ArrayMeta__cached_lower', (None, None))
    if previous_key == lookup_key:
        return previous_result
    unspecified = self.spaces - set(transform_chains)
    if unspecified:
        raise ValueError('Cannot lower {} because the following spaces are unspecified: {}.'.format(self, unspecified))
    lowered = self._ArrayMeta__cache_lower_orig(points_shape, transform_chains, coordinates)
    self._ArrayMeta__cached_lower = lookup_key, lowered
    return lowered
class _ArrayMeta(_ArrayMeta):
    '''Metaclass layer that routes ``lower`` through :func:`_cache_lower`.

    When a class body defines ``lower``, that function is moved to
    ``_ArrayMeta__cache_lower_orig`` and ``lower`` is replaced by
    :func:`_cache_lower`.
    '''

    def __new__(mcls, name, bases, namespace):
        if 'lower' not in namespace:
            return super().__new__(mcls, name, bases, namespace)
        original_lower = namespace.pop('lower')
        namespace['_ArrayMeta__cache_lower_orig'] = original_lower
        namespace['lower'] = _cache_lower
        return super().__new__(mcls, name, bases, namespace)
class Array(numpy.lib.mixins.NDArrayOperatorsMixin, metaclass=_ArrayMeta):
    '''Base class for array valued functions.

    Parameters
    ----------
    shape : :class:`tuple` of :class:`int`
        The shape of the array function.
    dtype : :class:`bool`, :class:`int`, :class:`float` or :class:`complex`
        The dtype of the array elements.
    spaces : :class:`frozenset` of :class:`str`
        The spaces this array function is defined on.
    arguments : mapping of :class:`str`
        The mapping of argument names to their shapes and dtypes for all
        arguments of this array function.

    Attributes
    ----------
    shape : :class:`tuple` of :class:`int`
        The shape of this array function.
    ndim : :class:`int`
        The dimension of this array function.
    dtype : :class:`bool`, :class:`int`, :class:`float` or :class:`complex`
        The dtype of the array elements.
    spaces : :class:`frozenset` of :class:`str`
        The spaces this array function is defined on.
    arguments : mapping of :class:`str`
        The mapping of argument names to their shapes and dtypes for all
        arguments of this array function.
    '''

    # Make numpy prefer this class's operators in mixed binary operations, see
    # http://stackoverflow.com/questions/7042496/numpy-coercion-problem-for-left-sided-binary-operator/7057530#7057530
    __array_priority__ = 1.

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        '''Dispatch a numpy ufunc call to the implementation in ``HANDLED_FUNCTIONS``.

        Only plain calls (``method == '__call__'``) of registered ufuncs are
        handled; anything else, or inputs that cannot be cast, yields
        ``NotImplemented``. NOTE(review): ``out`` is accepted but ignored.
        '''
        if method != '__call__' or ufunc not in HANDLED_FUNCTIONS:
            return NotImplemented
        try:
            arrays = [v if isinstance(v, (Array, bool, int, float, complex, numpy.ndarray)) else Array.cast(v) for v in inputs]
        except ValueError:
            return NotImplemented
        return HANDLED_FUNCTIONS[ufunc](*arrays, **kwargs)

    def __array_function__(self, func, types, args, kwargs):
        'Dispatch a numpy function to the implementation in ``HANDLED_FUNCTIONS``, if any.'
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

    @classmethod
    def cast(cls, __value: IntoArray, dtype: Optional[DType] = None, ndim: Optional[int] = None) -> 'Array':
        '''Cast a value to an :class:`Array`.

        Parameters
        ----------
        value : :class:`Array`, or a :class:`numpy.ndarray` or similar
            The value to cast. Lists and tuples that cannot be converted to a
            constant directly are stacked along a new first axis.
        dtype : :class:`bool`, :class:`int`, :class:`float` or :class:`complex`, optional
            If given, the maximal dtype of the result.
        ndim : :class:`int`, optional
            If given, the required dimension of the result.

        Raises
        ------
        :class:`ValueError`
            If the value cannot be converted, or if its dtype or dimension
            does not match ``dtype`` or ``ndim``.
        '''
        if isinstance(__value, Array):
            value = __value
        else:
            try:
                value = _Constant(__value)
            except Exception:
                # Not convertible to a single constant; a sequence may still
                # contain `Array` items that can be stacked.
                if isinstance(__value, (list, tuple)):
                    value = stack(__value, axis=0)
                else:
                    raise ValueError('cannot convert {}.{} to Array'.format(type(__value).__module__, type(__value).__qualname__))
        if dtype is not None and _dtypes.index(value.dtype) > _dtypes.index(dtype):
            raise ValueError('expected an array with dtype `{}` but got `{}`'.format(dtype.__name__, value.dtype.__name__))
        if ndim is not None and value.ndim != ndim:
            raise ValueError('expected an array with dimension `{}` but got `{}`'.format(ndim, value.ndim))
        return value

    @classmethod
    def cast_withscale(cls, __value: IntoArray, dtype: Optional[DType] = None, ndim: Optional[int] = None):
        '''Cast a value to an :class:`Array`, splitting off a reference scale.

        If the type of the value has a ``reference_quantity`` attribute, the
        value is divided by it before casting and the quantity is returned as
        scale; otherwise the scale is ``1`` in the dtype of the cast value.

        Returns
        -------
        :class:`tuple` of :class:`Array` and scale
        '''
        try:
            scale = type(__value).reference_quantity
        except AttributeError:
            value = cls.cast(__value, dtype=dtype, ndim=ndim)
            scale = value.dtype(1)
        else:
            value = cls.cast(__value / scale, dtype=dtype, ndim=ndim)
        return value, scale

    def __init__(self, shape: Shape, dtype: DType, spaces: FrozenSet[str], arguments: Mapping[str, Tuple[Shape, DType]]) -> None:
        # Shape entries are converted to plain ints via `__index__`.
        self.shape = tuple(sh.__index__() for sh in shape)
        self.dtype = dtype
        self.spaces = frozenset(spaces)
        self.arguments = dict(arguments)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        'Lower this function to a :class:`nutils.evaluable.Array`; must be implemented by subclasses.'
        raise NotImplementedError

    @util.cached_property
    def as_evaluable_array(self) -> evaluable.Array:
        'This function lowered without points axes, transform chains or coordinates.'
        return self.lower((), {}, {})

    def __index__(self):
        '''Convert a constant, scalar, integer array to an index.

        Raises :class:`ValueError` if the array has arguments or spaces, is not
        scalar, or does not have dtype :class:`int`.
        '''
        if self.arguments or self.spaces:
            raise ValueError('cannot convert non-constant array to index: arguments={}'.format(','.join(self.arguments)))
        elif self.ndim:
            raise ValueError('cannot convert non-scalar array to index: shape={}'.format(self.shape))
        elif self.dtype != int:
            raise ValueError('cannot convert non-integer array to index: dtype={}'.format(self.dtype.__name__))
        else:
            return self.as_evaluable_array.__index__()

    @property
    def ndim(self) -> int:
        'The number of axes of this array function.'
        return len(self.shape)

    def __getitem__(self, item: Any) -> 'Array':
        '''Index this array.

        Supported index items are integers (the axis is removed), slices,
        :data:`numpy.newaxis`, at most one ``...``, and anything accepted by
        :func:`take`. Missing trailing items are treated as full slices.
        '''
        if not isinstance(item, tuple):
            item = item,
        iell = None
        # Number of axes not addressed explicitly; each newaxis adds one.
        nx = self.ndim - len(item)
        for i, it in enumerate(item):
            if it is ...:
                assert iell is None, 'at most one ellipsis allowed'
                iell = i
            elif it is numpy.newaxis:
                nx += 1
        array = self
        axis = 0
        # Replace the ellipsis (or append at the end) by full slices for all
        # axes not addressed explicitly.
        for it in item + (slice(None),)*nx if iell is None else item[:iell] + (slice(None),)*(nx+1) + item[iell+1:]:
            if isinstance(it, numbers.Integral):
                # Integer indexing removes the axis, so `axis` does not advance.
                array = get(array, axis, it)
            else:
                array = expand_dims(array, axis) if it is numpy.newaxis \
                    else _takeslice(array, it, axis) if isinstance(it, slice) \
                    else take(array, it, axis)
                axis += 1
        assert axis == array.ndim
        return array

    def __bool__(self) -> bool:
        'Always ``True``.'
        return True

    def __len__(self) -> int:
        'Length of the first axis.'
        if self.ndim == 0:
            raise TypeError('len() of unsized object')
        return self.shape[0]

    def __iter__(self) -> Iterator['Array']:
        'Iterator over the first axis.'
        if self.ndim == 0:
            raise TypeError('iteration over a 0-D array')
        return (self[i, ...] for i in range(self.shape[0]))

    @property
    def size(self) -> Union[int, 'Array']:
        'The total number of elements in this array.'
        return util.product(self.shape, 1)

    @property
    def T(self) -> 'Array':
        'The transposed array.'
        return transpose(self)

    def astype(self, dtype):
        'Return this array with the given dtype; ``self`` if the dtype already matches.'
        if dtype == self.dtype:
            return self
        else:
            return _Wrapper(functools.partial(evaluable.astype, dtype=dtype), self, shape=self.shape, dtype=dtype)

    def sum(self, axis: Optional[Union[int, Sequence[int]]] = None) -> 'Array':
        'See :func:`sum`.'
        return sum(self, axis)

    def prod(self, __axis: int) -> 'Array':
        'See :func:`prod`.'
        return product(self, __axis)

    def dot(self, __other: IntoArray, axes: Optional[Union[int, Sequence[int]]] = None) -> 'Array':
        'See :func:`dot`.'
        return dot(self, __other, axes)

    def normalized(self, __axis: int = -1) -> 'Array':
        'See :func:`normalized`.'
        return normalized(self, __axis)

    def normal(self, refgeom: Optional['Array'] = None) -> 'Array':
        'See :func:`normal`.'
        return normal(self, refgeom)

    def curvature(self, ndims: int = -1) -> 'Array':
        'See :func:`curvature`.'
        return curvature(self, ndims)

    def swapaxes(self, __axis1: int, __axis2: int) -> 'Array':
        'See :func:`swapaxes`.'
        return swapaxes(self, __axis1, __axis2)

    def transpose(self, __axes: Optional[Sequence[int]] = None) -> 'Array':
        'See :func:`transpose`. Without ``axes`` the order of the axes is reversed, as for :attr:`T`.'
        return transpose(self, __axes)

    def add_T(self, axes: Tuple[int, int]) -> 'Array':
        'See :func:`add_T`.'
        return add_T(self, axes)

    def grad(self, __geom: IntoArray, ndims: int = 0) -> 'Array':
        'See :func:`grad`.'
        return grad(self, __geom, ndims)

    def laplace(self, __geom: IntoArray, ndims: int = 0) -> 'Array':
        'See :func:`laplace`.'
        return laplace(self, __geom, ndims)

    def symgrad(self, __geom: IntoArray, ndims: int = 0) -> 'Array':
        'See :func:`symgrad`.'
        return symgrad(self, __geom, ndims)

    def div(self, __geom: IntoArray, ndims: int = 0) -> 'Array':
        'See :func:`div`.'
        return div(self, __geom, ndims)

    def curl(self, __geom: IntoArray) -> 'Array':
        'See :func:`curl`.'
        return curl(self, __geom)

    def dotnorm(self, __geom: IntoArray, axis: int = -1) -> 'Array':
        'See :func:`dotnorm`.'
        return dotnorm(self, __geom, axis)

    def tangent(self, __vec: IntoArray) -> 'Array':
        'See :func:`tangent`.'
        return tangent(self, __vec)

    def ngrad(self, __geom: IntoArray, ndims: int = 0) -> 'Array':
        'See :func:`ngrad`.'
        return ngrad(self, __geom, ndims)

    def nsymgrad(self, __geom: IntoArray, ndims: int = 0) -> 'Array':
        'See :func:`nsymgrad`.'
        return nsymgrad(self, __geom, ndims)

    def choose(self, __choices: Sequence[IntoArray]) -> 'Array':
        'See :func:`choose`.'
        return choose(self, __choices)

    def vector(self, ndims):
        '''Return ``ravel(diagonalize(insertaxis(self, 1, ndims), 1), 0)``.

        Raises an :class:`Exception` for scalar arrays.
        '''
        if not self.ndim:
            raise Exception('a scalar function cannot be vectorized')
        return ravel(diagonalize(insertaxis(self, 1, ndims), 1), 0)

    def __repr__(self) -> str:
        return 'Array<{}>'.format(','.join(str(n) for n in self.shape))

    @property
    def simplified(self):
        'Deprecated; returns the array unmodified.'
        warnings.deprecation('`nutils.function.Array.simplified` is deprecated. This property returns the array unmodified and can safely be omitted.')
        return self

    def eval(self, **arguments: Any) -> numpy.ndarray:
        'Evaluate this function.'
        from .sample import eval_integrals
        return eval_integrals(self, **arguments)[0]

    def derivative(self, __var: Union[str, 'Argument']) -> 'Array':
        'See :func:`derivative`.'
        return derivative(self, __var)

    def replace(self, __arguments: Mapping[str, IntoArray]) -> 'Array':
        'Return a copy with arguments applied.'
        return replace_arguments(self, __arguments)

    def contains(self, __name: str) -> bool:
        'Test whether an argument with the given name occurs in this function.'
        return __name in self.arguments

    @property
    def argshapes(self) -> Mapping[str, Tuple[int, ...]]:
        'The mapping of argument names to their shapes.'
        return {name: shape for name, (shape, dtype) in self.arguments.items()}

    def conjugate(self):
        'See :func:`conjugate`.'
        return conjugate(self)

    conj = conjugate

    @property
    def real(self):
        'See :func:`real`.'
        return real(self)

    @property
    def imag(self):
        'See :func:`imag`.'
        return imag(self)
class _Unlower(Array):
    '''Wrap an already lowered :class:`nutils.evaluable.Array` as an :class:`Array`.

    The wrapped array can only be lowered again with exactly the
    ``points_shape``, ``transform_chains`` and ``coordinates`` it was created
    with. In that case the wrapped evaluable array is returned as is.
    '''

    def __init__(self, array: evaluable.Array, spaces: FrozenSet[str], arguments: Mapping[str, Tuple[Shape, DType]], points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> None:
        self._array = array
        self._points_shape = points_shape
        self._transform_chains = transform_chains
        self._coordinates = coordinates
        # The leading points axes of the lowered array are not part of the
        # function shape.
        shape = tuple(n.__index__() for n in array.shape[len(points_shape):])
        super().__init__(shape=shape, dtype=array.dtype, spaces=spaces, arguments=arguments)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        if self._points_shape != points_shape or self._transform_chains != transform_chains or self._coordinates != coordinates:
            raise ValueError('_Unlower must be lowered with the same arguments as those with which it is instantiated.')
        return self._array
class Custom(Array):
    '''Combined :mod:`nutils.function` and :mod:`nutils.evaluable` array base class.

    Ordinary :class:`Array` subclasses should define the ``Array.lower`` method,
    which returns the corresponding :class:`nutils.evaluable.Array` with the
    proper amount of points axes. In many cases the :class:`Array` subclass is
    trivial and the corresponding :class:`nutils.evaluable.Array` contains all
    the specifics. For those situations the :class:`Custom` base class exists.
    Rather than defining the ``Array.lower`` method, this base class allows you
    to define a :meth:`Custom.evalf` and optionally a
    :meth:`Custom.partial_derivative`, which are used to instantiate a generic
    :class:`nutils.evaluable.Array` automatically during lowering.

    By default the :class:`Array` arguments passed to the constructor are
    unmodified. Broadcasting and singleton expansion, if required, should be
    applied before passing the arguments to the constructor of :class:`Custom`.
    It is possible to declare ``npointwise`` leading axes as being pointwise. In
    that case :class:`Custom` applies singleton expansion to the leading
    pointwise axes and the shape of the result passed to :class:`Custom` should
    not include the pointwise axes.

    For internal reasons, both ``evalf`` and ``partial_derivative`` must be
    decorated as ``classmethod`` or ``staticmethod``, meaning that they will not
    receive a reference to ``self`` when called. Instead, all relevant data
    should be passed to ``evalf`` via the constructor argument ``args``. The
    constructor will automatically distinguish between Array and non-Array
    arguments, and pass the latter on to ``evalf`` unchanged. The
    ``partial_derivative`` will not be called for those arguments.

    The lowered array does not have a Nutils hash by default. If this is desired,
    the methods :meth:`evalf` and :meth:`partial_derivative` can be decorated
    with :func:`nutils.types.hashable_function` in addition to ``classmethod`` or
    ``staticmethod``.

    Parameters
    ----------
    args : iterable of :class:`Array` objects or immutable and hashable objects
        The arguments of this array function.
    shape : :class:`tuple` of :class:`int` or :class:`Array`
        The shape of the array function without leading pointwise axes.
    dtype : :class:`bool`, :class:`int`, :class:`float` or :class:`complex`
        The dtype of the array elements.
    npointwise : :class:`int`
        The number of leading pointwise axes.

    Example
    -------

    The following class implements :func:`multiply` using :class:`Custom`
    without broadcasting and for :class:`float` arrays only.

    >>> class Multiply(Custom):
    ...
    ...     def __init__(self, left: IntoArray, right: IntoArray) -> None:
    ...         # Broadcast the arrays. `broadcast_arrays` automatically casts the
    ...         # arguments to `Array`.
    ...         left, right = broadcast_arrays(left, right)
    ...         # Dtype coercion is beyond the scope of this example.
    ...         if left.dtype != float or right.dtype != float:
    ...             raise ValueError('left and right arguments should have dtype float')
    ...         # We treat all axes as pointwise, hence parameter `shape`, the shape
    ...         # of the remainder, is empty and `npointwise` is the dimension of the
    ...         # arrays.
    ...         super().__init__(args=(left, right), shape=(), dtype=float, npointwise=left.ndim)
    ...
    ...     @staticmethod
    ...     def evalf(left: numpy.ndarray, right: numpy.ndarray) -> numpy.ndarray:
    ...         # Because all axes are pointwise, the evaluated `left` and `right`
    ...         # arrays are 1d.
    ...         return left * right
    ...
    ...     @staticmethod
    ...     def partial_derivative(iarg: int, left: Array, right: Array) -> IntoArray:
    ...         # The arguments passed to this function are of type `Array` and the
    ...         # pointwise axes are omitted, hence `left` and `right` are 0d.
    ...         if iarg == 0:
    ...             return right
    ...         elif iarg == 1:
    ...             return left
    ...         else:
    ...             raise NotImplementedError
    ...
    >>> Multiply([1., 2.], [3., 4.]).eval()
    array([ 3.,  8.])
    >>> a = Argument('a', (2,))
    >>> Multiply(a, [3., 4.]).derivative(a).eval(a=numpy.array([1., 2.])).export('dense')
    array([[ 3.,  0.],
           [ 0.,  4.]])

    The following class wraps :func:`numpy.roll`, applied to the last axis of the
    array argument, with constant shift.

    >>> class Roll(Custom):
    ...
    ...     def __init__(self, array: IntoArray, shift: int) -> None:
    ...         array = asarray(array)
    ...         # We are being nit-picky here and cast `shift` to an `int` without
    ...         # truncation.
    ...         shift = shift.__index__()
    ...         # We treat all but the last axis of `array` as pointwise.
    ...         super().__init__(args=(array, shift), shape=array.shape[-1:], dtype=array.dtype, npointwise=array.ndim-1)
    ...
    ...     @staticmethod
    ...     def evalf(array: numpy.ndarray, shift: int) -> numpy.ndarray:
    ...         # `array` is evaluated to a `numpy.ndarray` because we passed `array`
    ...         # as an `Array` to the constructor. `shift`, however, is untouched
    ...         # because it is not an `Array`. The `array` has two axes: a points
    ...         # axis and the axis to be rolled.
    ...         return numpy.roll(array, shift, 1)
    ...
    ...     @staticmethod
    ...     def partial_derivative(iarg, array: Array, shift: int) -> IntoArray:
    ...         if iarg == 0:
    ...             return Roll(eye(array.shape[0]), shift).T
    ...         else:
    ...             # We don't implement the derivative to `shift`, because this is
    ...             # a constant `int`.
    ...             raise NotImplementedError
    ...
    >>> Roll([1, 2, 3], 1).eval()
    array([3, 1, 2])
    >>> b = Argument('b', (3,))
    >>> Roll(b, 1).derivative(b).eval().export('dense')
    array([[ 0.,  0.,  1.],
           [ 1.,  0.,  0.],
           [ 0.,  1.,  0.]])
    '''

    def __init__(self, args: Iterable[Any], shape: Tuple[int, ...], dtype: DType, npointwise: int = 0):
        args = tuple(args)
        if any(isinstance(arg, evaluable.Evaluable) for arg in args):
            raise ValueError('It is not allowed to call this function with a `nutils.evaluable.Evaluable` argument.')
        if npointwise:
            # Apply singleton expansion to the leading points axes.
            points_shapes = tuple(arg.shape[:npointwise] for arg in args if isinstance(arg, Array))
            if not all(len(points_shape) == npointwise for points_shape in points_shapes):
                raise ValueError('All arrays must have at least {} axes.'.format(npointwise))
            if len(points_shapes) == 0:
                raise ValueError('Pointwise axes can only be used in combination with at least one `function.Array` argument.')
            points_shape = broadcast_shapes(*points_shapes)
            args = tuple(broadcast_to(arg, points_shape + arg.shape[npointwise:]) if isinstance(arg, Array) else arg for arg in args)
        else:
            points_shape = ()
        self._args = args
        self._npointwise = npointwise
        spaces = frozenset(space for arg in args if isinstance(arg, Array) for space in arg.spaces)
        arguments = _join_arguments(arg.arguments for arg in args if isinstance(arg, Array))
        super().__init__(shape=(*points_shape, *shape), dtype=dtype, spaces=spaces, arguments=arguments)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        # Lower the `Array` arguments; wrap all others so they pass through the
        # evaluable graph unchanged.
        args = tuple(arg.lower(points_shape, transform_chains, coordinates) if isinstance(arg, Array) else evaluable.EvaluableConstant(arg) for arg in self._args) # type: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...]
        # The pointwise axes of this function become extra points axes.
        add_points_shape = tuple(map(evaluable.asarray, self.shape[:self._npointwise]))
        points_shape += add_points_shape
        # Append the extra points axes to the coordinates, keeping the original
        # last coordinate axis at the end.
        coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, add_points_shape), coords.ndim-1) for space, coords in coordinates.items()}
        return _CustomEvaluable(type(self).__name__, self.evalf, self.partial_derivative, args, self.shape[self._npointwise:], self.dtype, self.spaces, types.frozendict(self.arguments), points_shape, tuple(transform_chains.items()), tuple(coordinates.items()))

    @classmethod
    def evalf(cls, *args: Any) -> numpy.ndarray:
        '''Evaluate this function for the given evaluated arguments.

        This function is called with arguments that correspond to the arguments
        that are passed to the constructor of :class:`Custom`: every instance of
        :class:`Array` is evaluated to a :class:`numpy.ndarray` with one leading
        axis compared to the :class:`Array` and all other instances are passed as
        is. The return value of this method should also include a leading axis with
        the same length as the other array arguments have, or length one if there
        are no array arguments. If constructor argument ``npointwise`` is nonzero,
        the pointwise axes of the :class:`Array` arguments are raveled and included
        in the single leading axis of the evaluated array arguments as well.

        If possible this method should not use ``self``, e.g. by decorating this
        method with :func:`staticmethod`. The result of this function must only
        depend on the arguments and must not mutate the arguments.

        This method is equivalent to ``nutils.evaluable.Array.evalf`` up to
        the treatment of the leading axis.

        Parameters
        ----------
        *args
            The evaluated arguments corresponding to the ``args`` parameter of the
            :class:`Custom` constructor.

        Returns
        -------
        :class:`numpy.ndarray`
            The result of this function with one leading points axis.
        '''
        raise NotImplementedError # pragma: nocover

    @classmethod
    def partial_derivative(cls, iarg: int, *args: Any) -> IntoArray:
        '''Return the partial derivative of this function to :class:`Custom` constructor argument number ``iarg``.

        This method is only called for those arguments that are instances of
        :class:`Array` with dtype :class:`float` or :class:`complex` and have the
        derivative target as a dependency. It is therefore allowed to omit an
        implementation for some or all arguments if the above conditions are
        not met.

        Axes that are declared pointwise via the ``npointwise`` constructor
        argument are omitted.

        Parameters
        ----------
        iarg : :class:`int`
            The index of the argument to compute the derivative for.
        *args
            The arguments as passed to the constructor of :class:`Custom`.

        Returns
        -------
        :class:`Array` or similar
            The partial derivative of this function to the given argument.
        '''
        raise NotImplementedError('The partial derivative of {} to argument {} (counting from 0) is not defined.'.format(cls.__name__, iarg)) # pragma: nocover
class _CustomEvaluable(evaluable.Array):
    '''Evaluable counterpart of :class:`Custom`, created by ``Custom.lower``.

    Evaluation flattens the points axes of the array arguments into a single
    leading axis, calls the user-supplied ``evalf`` and restores the points
    axes. Differentiation combines the user-supplied ``partial_derivative``
    with the derivatives of the array arguments (chain rule).
    '''

    def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...], shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, points_shape: _PointsShape, transform_chains: Tuple[Tuple[str, Tuple[EvaluableTransformChain, EvaluableTransformChain]], ...], coordinates: Tuple[Tuple[str, evaluable.Array], ...]) -> None:
        # `transform_chains` and `coordinates` arrive as tuples of (space, value)
        # items (see `Custom.lower`) and are converted back to dicts below.
        assert all(isinstance(arg, (evaluable.Array, evaluable.EvaluableConstant)) for arg in args)
        self.name = name
        self.custom_evalf = evalf
        self.custom_partial_derivative = partial_derivative
        self.args = args
        self.points_dim = len(points_shape)
        # Stored so that partial derivatives can be lowered with the same
        # arguments (see `_derivative`).
        self.lower_args = points_shape, dict(transform_chains), dict(coordinates)
        self.spaces = spaces
        self.function_arguments = arguments
        # The points shape is passed as the first dependency so that `evalf`
        # receives its evaluated lengths.
        super().__init__((evaluable.Tuple(points_shape), *args), shape=points_shape+shape, dtype=dtype)

    @property
    def _node_details(self) -> str:
        return self.name

    def evalf(self, points_shape: Tuple[numpy.ndarray, ...], *args: Any) -> numpy.ndarray:
        points_shape = tuple(n.__index__() for n in points_shape)
        npoints = util.product(points_shape, 1)
        # Flatten the points axes of the array arguments and call `custom_evalf`.
        flattened = (arg.reshape(npoints, *arg.shape[self.points_dim:]) if isinstance(origarg, evaluable.Array) else arg for arg, origarg in zip(args, self.args))
        result = self.custom_evalf(*flattened)
        # One flattened points axis followed by the non-pointwise shape.
        assert result.ndim == self.ndim + 1 - self.points_dim
        # Unflatten the points axes of the result. If there are no array arguments,
        # the points axis must have length one. Otherwise the length must be
        # `npoints` (checked by `reshape`).
        if not any(isinstance(origarg, evaluable.Array) for origarg in self.args):
            if result.shape[0] != 1:
                raise ValueError('Expected a points axis of length one but got {}.'.format(result.shape[0]))
            return numpy.broadcast_to(result[0], points_shape + result.shape[1:])
        else:
            return result.reshape(points_shape + result.shape[1:])

    def _derivative(self, var: evaluable.Array, seen: Dict[evaluable.Array, evaluable.Array]) -> evaluable.Array:
        # Non-inexact results are handled by the base class.
        if self.dtype in (bool, int):
            return super()._derivative(var, seen)
        result = evaluable.Zeros(self.shape + var.shape, dtype=self.dtype)
        # `partial_derivative` receives function-level arrays: array arguments
        # are wrapped in `_Unlower`, constants are unwrapped to their value.
        unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, *self.lower_args) if isinstance(arg, evaluable.Array) else arg.value for arg in self.args)
        for iarg, arg in enumerate(self.args):
            # Skip constants, non-inexact arrays, and arrays that do not depend
            # on `var` (`and` binds tighter than `or`).
            if not isinstance(arg, evaluable.Array) or arg.dtype in (bool, int) or var not in arg.arguments and var != arg:
                continue
            fpd = Array.cast(self.custom_partial_derivative(iarg, *unlowered_args))
            # Expected: result shape followed by argument shape, both without
            # the points axes.
            fpd_expected_shape = tuple(n.__index__() for n in self.shape[self.points_dim:] + arg.shape[self.points_dim:])
            if fpd.shape != fpd_expected_shape:
                raise ValueError('`partial_derivative` to argument {} returned an array with shape {} but {} was expected.'.format(iarg, fpd.shape, fpd_expected_shape))
            epd = evaluable.appendaxes(fpd.lower(*self.lower_args), var.shape)
            eda = evaluable.derivative(arg, var, seen)
            # Align the derivative of the argument with `epd`: insert the
            # non-points result axes right after the points axes.
            eda = evaluable.Transpose.from_end(evaluable.appendaxes(eda, self.shape[self.points_dim:]), *range(self.points_dim, self.ndim))
            # Chain rule: contract over the non-points axes of the argument.
            result += (epd * eda).sum(range(self.ndim, self.ndim + arg.ndim - self.points_dim))
        return result
class _WithoutPoints:
    '''Lowerable wrapper that lowers its argument without points axes or coordinates.

    The ``spaces`` and ``arguments`` of the wrapped array are exposed
    unchanged; only the transform chains are forwarded on lowering.
    '''

    def __init__(self, __arg: Array) -> None:
        wrapped = __arg
        self.spaces = wrapped.spaces
        self.arguments = wrapped.arguments
        self._arg = wrapped

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        no_points_shape = ()
        no_coordinates = {}
        return self._arg.lower(no_points_shape, transform_chains, no_coordinates)
class _Wrapper(Array):
    '''Array whose lowered form is a callable applied to the lowered arguments.

    ``lower`` receives the lowered ``args``, in order, and returns the
    evaluable result. Spaces and arguments are the union of those of ``args``.
    '''

    @classmethod
    def broadcasted_arrays(cls, lower: Callable[..., evaluable.Array], *args: IntoArray, min_dtype: DType = bool, force_dtype: Optional[DType] = None) -> '_Wrapper':
        'Typecast and broadcast ``args`` before wrapping them; ``force_dtype``, if given, overrides the result dtype.'
        typecast = typecast_arrays(*args, min_dtype=min_dtype)
        broadcasted = broadcast_arrays(*typecast)
        leading = broadcasted[0]
        result_dtype = force_dtype or leading.dtype
        return cls(lower, *broadcasted, shape=leading.shape, dtype=result_dtype)

    def __init__(self, lower: Callable[..., evaluable.Array], *args: Lowerable, shape: Shape, dtype: DType) -> None:
        self._lower = lower
        self._args = args
        assert all(hasattr(arg, 'lower') for arg in self._args)
        combined_spaces = frozenset().union(*(arg.spaces for arg in args))
        combined_arguments = _join_arguments(arg.arguments for arg in self._args)
        super().__init__(shape, dtype, combined_spaces, combined_arguments)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        lowered_args = [arg.lower(points_shape, transform_chains, coordinates) for arg in self._args]
        return self._lower(*lowered_args)
class _Zeros(Array):
    'Array of zeros; lowers to an evaluable zeros array with the points axes prepended.'

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        full_shape = tuple(points_shape) + self.shape
        return evaluable.Zeros(full_shape, self.dtype)
class _Ones(Array):
    '''Array of ones with the shape and dtype given at construction.'''

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        # Points axes come first, followed by the array's own axes.
        lowered_shape = tuple(points_shape) + tuple(self.shape)
        return evaluable.ones(lowered_shape, self.dtype)
class _Constant(Array):
    '''Spatially constant array holding a fixed value.

    The value is converted with ``types.arraydata``, whose shape and dtype
    become those of this array. The array has no spaces and no arguments.
    '''

    def __init__(self, value: Any) -> None:
        data = types.arraydata(value)
        self._value = data
        super().__init__(data.shape, data.dtype, frozenset(), {})

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        constant = evaluable.Constant(self._value)
        # Repeat the constant along the points axes.
        return evaluable.prependaxes(constant, points_shape)
class Argument(Array):
    '''Array valued function argument.

    The argument does not vary in space: it has no spaces and registers
    itself, with its shape and dtype, as its only argument.

    Parameters
    ----------
    name : str
        The name of this argument.
    shape : :class:`tuple` of :class:`int`
        The shape of this argument.
    dtype : :class:`bool`, :class:`int`, :class:`float` or :class:`complex`
        The dtype of the array elements.

    Attributes
    ----------
    name : str
        The name of this argument.
    '''

    def __init__(self, name: str, shape: Shape, *, dtype: DType = float) -> None:
        self.name = name
        signature = tuple(shape), dtype
        super().__init__(shape, dtype, frozenset(), {name: signature})

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        lowered = evaluable.Argument(self.name, self.shape, self.dtype)
        # Arguments are constant over the points; repeat along points axes.
        return evaluable.prependaxes(lowered, points_shape)
class _Replace(Array):
    '''Array ``arg`` with some of its arguments replaced by other arrays.

    Each replacement must match the shape and dtype of the argument it
    replaces; otherwise :class:`ValueError` is raised. The resulting arguments
    are the unreplaced arguments of ``arg`` joined with the arguments of every
    replacement that is actually used.
    '''

    def __init__(self, arg: Array, replacements: Dict[str, Array]) -> None:
        self._arg = arg
        # TODO: verify that the replacements have empty spaces
        self._replacements = replacements

        def describe(dt):
            return dt.__name__ if dt in _dtypes else dt

        unreplaced = {}
        used_replacement_arguments = []
        for name, (shape, dtype) in arg.arguments.items():
            replacement = replacements.get(name)
            if replacement is None:
                unreplaced[name] = shape, dtype
                continue
            if replacement.shape != shape:
                raise ValueError('Argument {!r} has shape {} but the replacement has shape {}.'.format(name, shape, replacement.shape))
            if replacement.dtype != dtype:
                raise ValueError('Argument {!r} has dtype {} but the replacement has dtype {}.'.format(name, describe(dtype), describe(replacement.dtype)))
            used_replacement_arguments.append(replacement.arguments)
        super().__init__(arg.shape, arg.dtype, arg.spaces, _join_arguments([unreplaced, *used_replacement_arguments]))

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        lowered_arg = self._arg.lower(points_shape, transform_chains, coordinates)
        # Replacements are lowered without points axes and coordinates
        # (see `_WithoutPoints`), because they replace whole arguments.
        lowered_replacements = {}
        for name, value in self._replacements.items():
            lowered_replacements[name] = _WithoutPoints(value).lower(points_shape, transform_chains, coordinates)
        return evaluable.replace_arguments(lowered_arg, lowered_replacements)
class _Transpose(Array):
    '''Axis permutation of ``arg``: axis ``i`` of the result is axis ``axes[i]`` of ``arg``.

    Use :meth:`to_end` to move the given axes to the end, and :meth:`from_end`
    for the inverse operation, which moves the trailing axes to the given
    positions.
    '''

    @classmethod
    def _end(cls, array: Array, axes: Tuple[int, ...], invert: bool = False) -> Array:
        '''Build the permutation that moves ``axes`` to the end of ``array``.

        With ``invert`` the inverse permutation is used instead. If the
        permutation is the identity, ``array`` itself is returned.

        Raises
        ------
        ValueError
            If ``axes`` contains duplicates.
        '''
        axes = tuple(numeric.normdim(array.ndim, axis) for axis in axes)
        # Axes that already form the trailing block, in order, need no transpose.
        if all(a == b for a, b in enumerate(axes, start=array.ndim-len(axes))):
            return array
        trans = [i for i in range(array.ndim) if i not in axes]
        trans.extend(axes)
        # Duplicates in `axes` make the permutation longer than ndim.
        if len(trans) != array.ndim:
            raise ValueError('duplicate axes')
        if invert:
            # argsort of a permutation is its inverse; store plain ints rather
            # than numpy integers.
            trans = [int(i) for i in numpy.argsort(trans)]
        return cls(array, tuple(trans))

    @classmethod
    def from_end(cls, array: Array, *axes: int) -> Array:
        return cls._end(array, axes, invert=True)

    @classmethod
    def to_end(cls, array: Array, *axes: int) -> Array:
        return cls._end(array, axes, invert=False)

    def __init__(self, arg: Array, axes: Tuple[int, ...]) -> None:
        self._arg = arg
        self._axes = axes
        super().__init__(tuple(arg.shape[axis] for axis in axes), arg.dtype, arg.spaces, arg.arguments)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        arg = self._arg.lower(points_shape, transform_chains, coordinates)
        # The lowered array has the points axes in front; keep them in place
        # and shift the permutation accordingly.
        offset = len(points_shape)
        axes = (*range(offset), *(i+offset for i in self._axes))
        return evaluable.Transpose(arg, axes)
class _Opposite(Array):
    '''Array ``arg`` evaluated on the opposite side for ``space``.

    Lowering reverses the transform chain pair of ``space`` (elsewhere
    unpacked as ``(chain, opposite)``), so the two chains trade places. The
    caller's mapping is left untouched.
    '''

    def __init__(self, arg: Array, space: str) -> None:
        self._arg = arg
        self._space = space
        super().__init__(arg.shape, arg.dtype, arg.spaces, arg.arguments)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        swapped = dict(transform_chains)
        swapped[self._space] = swapped[self._space][::-1]
        return self._arg.lower(points_shape, swapped, coordinates)
class _RootCoords(Array):
    '''Root coordinates of ``space``: a float vector of length ``ndims``.

    Lowering applies the first transform chain of ``space`` to the tip
    coordinates of ``space``.
    '''

    def __init__(self, space: str, ndims: int) -> None:
        self._space = space
        super().__init__((ndims,), float, frozenset({space}), {})

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        space = self._space
        # The root coordinates differentiate to the identity with respect to
        # the root derivative target (diagonal matrix of ones).
        identity = evaluable.prependaxes(evaluable.diagonalize(evaluable.ones(self.shape)), points_shape)
        tip = coordinates[space]
        tip_target = _tip_derivative_target(space, tip.shape[-1])
        # Likewise the tip coordinates differentiate to the identity with
        # respect to the tip derivative target.
        tip = evaluable.WithDerivative(tip, tip_target, evaluable.Diagonalize(evaluable.ones(tip.shape)))
        chain = transform_chains[space][0]
        root = chain.apply(tip)
        return evaluable.WithDerivative(root, _root_derivative_target(space, self.shape[0]), identity)
class _TransformsIndex(Array):
    '''Integer scalar: index into ``transforms`` for the first chain of ``space``.

    The index is the one returned by ``index_with_tail_in``; the tail it also
    returns is discarded.
    '''

    def __init__(self, space: str, transforms: Transforms) -> None:
        self._space = space
        self._transforms = transforms
        super().__init__((), int, frozenset({space}), {})

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        chain = transform_chains[self._space][0]
        index, _tail = chain.index_with_tail_in(self._transforms)
        return evaluable.prependaxes(index, points_shape)
class _TransformsCoords(Array):
    '''Coordinates relative to the matching transform in ``transforms``.

    The value has shape ``(transforms.fromdims,)``. It is obtained by applying
    the tail left over by ``index_with_tail_in`` to the tip coordinates of
    ``space``. Its derivative with respect to the root target is the
    (pseudo-)inverse of the linear part of the matched head transform.
    '''

    def __init__(self, space: str, transforms: Transforms) -> None:
        self._space = space
        self._transforms = transforms
        super().__init__((transforms.fromdims,), float, frozenset({space}), {})

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        space = self._space
        transforms = self._transforms
        index, tail = transform_chains[space][0].index_with_tail_in(transforms)
        linear = transforms.get_evaluable(index).linear
        if transforms.todims > transforms.fromdims:
            # Non-square linear map: left pseudo-inverse (L^T L)^-1 L^T.
            gram_inverse = evaluable.inverse(evaluable.einsum('ki,kj->ij', linear, linear))
            inverse_linear = evaluable.einsum('ik,jk->ij', gram_inverse, linear)
        else:
            inverse_linear = evaluable.inverse(linear)
        inverse_linear = evaluable.prependaxes(inverse_linear, points_shape)
        tip = coordinates[space]
        tip = evaluable.WithDerivative(tip, _tip_derivative_target(space, tip.shape[-1]), evaluable.Diagonalize(evaluable.ones(tip.shape)))
        local = tail.apply(tip)
        return evaluable.WithDerivative(local, _root_derivative_target(space, transforms.todims), inverse_linear)
class _Derivative(Array):
    '''Derivative of ``arg`` with respect to the function argument ``var``.

    The result has shape ``arg.shape + var.shape``. Its dtype is complex if
    ``var`` is complex and otherwise equals the dtype of ``arg``.
    '''

    def __init__(self, arg: Array, var: Argument) -> None:
        assert isinstance(var, Argument)
        self._arg = arg
        self._var = var
        self._eval_var = evaluable.Argument(var.name, var.shape, var.dtype)
        joined = _join_arguments((arg.arguments, var.arguments))
        result_dtype = complex if var.dtype == complex else arg.dtype
        super().__init__(arg.shape+var.shape, result_dtype, arg.spaces | var.spaces, joined)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        lowered = self._arg.lower(points_shape, transform_chains, coordinates)
        return evaluable.derivative(lowered, self._eval_var)
def _tip_derivative_target(space: str, dim: int) -> evaluable.DerivativeTargetBase:
    '''Return the derivative target identified by ``(space, 'tip')`` with shape ``(dim,)``.'''
    identifier = space, 'tip'
    return evaluable.IdentifierDerivativeTarget(identifier, (dim,))
def _root_derivative_target(space: str, dim: int) -> evaluable.DerivativeTargetBase:
    '''Return the derivative target identified by ``(space, 'root')`` with shape ``(dim,)``.'''
    identifier = space, 'root'
    return evaluable.IdentifierDerivativeTarget(identifier, (dim,))
class _Gradient(Array):
    '''Gradient of ``func`` with respect to ``geom``, using root coordinates as reference.

    ``func`` and all but the last axis of ``geom`` are broadcast to a common
    shape. The result has that shape followed by the last (coordinate) axis
    of ``geom``. Its dtype is complex for a complex ``func`` and float
    otherwise.

    Raises (on lowering)
    --------------------
    ValueError
        If the geometry dimension differs from the summed root dimensions of
        the spaces of ``geom``, so that the jacobian cannot be inverted.
    '''

    def __init__(self, func: Array, geom: Array) -> None:
        assert geom.spaces, '0d array'
        assert geom.dtype == float
        common_shape = broadcast_shapes(func.shape, geom.shape[:-1])
        self._func = broadcast_to(func, common_shape)
        self._geom = broadcast_to(geom, (*common_shape, geom.shape[-1]))
        arguments = _join_arguments((func.arguments, geom.arguments))
        super().__init__(self._geom.shape, complex if func.dtype == complex else float, func.spaces | geom.spaces, arguments)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        func = self._func.lower(points_shape, transform_chains, coordinates)
        geom = self._geom.lower(points_shape, transform_chains, coordinates)
        ref_dim = builtins.sum(transform_chains[space][0].todims for space in self._geom.spaces)
        if self._geom.shape[-1] != ref_dim:
            # ValueError (an Exception subclass, so existing handlers still
            # match), consistent with the dimension check in _SurfaceGradient.
            raise ValueError('cannot invert {}x{} jacobian'.format(self._geom.shape[-1], ref_dim))
        refs = tuple(_root_derivative_target(space, chain.todims) for space, (chain, _) in transform_chains.items() if space in self._geom.spaces)
        dfunc_dref = evaluable.concatenate([evaluable.derivative(func, ref) for ref in refs], axis=-1)
        dgeom_dref = evaluable.concatenate([evaluable.derivative(geom, ref) for ref in refs], axis=-1)
        # Chain rule: dfunc/dgeom = dfunc/dref . (dgeom/dref)^-1 (square jacobian).
        dref_dgeom = evaluable.inverse(dgeom_dref)
        return evaluable.einsum('Ai,Aij->Aj', dfunc_dref, dref_dgeom)
class _SurfaceGradient(Array):
    '''Surface gradient of ``func`` with respect to ``geom``, using tip coordinates as reference.

    For each space of ``geom``, root coordinates are used when the first
    chain does not change dimension; otherwise tip coordinates are used.
    ``geom`` must be exactly one dimension larger than the summed tip
    dimensions of its spaces, otherwise :class:`ValueError` is raised on
    lowering.
    '''

    def __init__(self, func: Array, geom: Array) -> None:
        assert geom.spaces, '0d array'
        assert geom.dtype == float
        shared = broadcast_shapes(func.shape, geom.shape[:-1])
        self._func = broadcast_to(func, shared)
        self._geom = broadcast_to(geom, (*shared, geom.shape[-1]))
        joined = _join_arguments((func.arguments, geom.arguments))
        result_dtype = complex if func.dtype == complex else float
        super().__init__(self._geom.shape, result_dtype, func.spaces | geom.spaces, joined)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        func = self._func.lower(points_shape, transform_chains, coordinates)
        geom = self._geom.lower(points_shape, transform_chains, coordinates)
        spaces = self._geom.spaces
        ref_dim = builtins.sum(transform_chains[space][0].fromdims for space in spaces)
        if self._geom.shape[-1] != ref_dim + 1:
            raise ValueError('expected a {}d geometry but got a {}d geometry'.format(ref_dim + 1, self._geom.shape[-1]))
        refs = []
        for space, (chain, _) in transform_chains.items():
            if space not in spaces:
                continue
            make_target = _root_derivative_target if chain.todims == chain.fromdims else _tip_derivative_target
            refs.append(make_target(space, chain.fromdims))
        dfunc_dref = evaluable.concatenate([evaluable.derivative(func, ref) for ref in refs], axis=-1)
        dgeom_dref = evaluable.concatenate([evaluable.derivative(geom, ref) for ref in refs], axis=-1)
        # Pseudo-inverse of the non-square jacobian J: G^-1 J^T, assuming
        # `grammium` computes the Gram matrix G = J^T J (TODO confirm).
        gram_inverse = evaluable.inverse(evaluable.grammium(dgeom_dref))
        dref_dgeom = evaluable.einsum('Ajk,Aik->Aij', dgeom_dref, gram_inverse)
        return evaluable.einsum('Ai,Aij->Aj', dfunc_dref, dref_dgeom)
class _Jacobian(Array):
    '''Jacobian determinant of ``geom`` with respect to the tip coordinates of its spaces.

    The last axis of ``geom`` is the coordinate axis. The value is computed
    with ``sqrt_abs_det_gram`` of the derivative of ``geom`` to the
    concatenated tip coordinates. A geometry without spaces must have
    dimension zero and yields ones. If ``tip_dim`` is given, the tip
    dimension found on lowering must equal it.
    '''

    def __init__(self, geom: Array, tip_dim: Optional[int] = None) -> None:
        assert geom.ndim >= 1
        assert geom.dtype == float
        if not geom.spaces and geom.shape[-1] != 0:
            raise ValueError('The jacobian of a constant (in space) geometry must have dimension zero.')
        if tip_dim is not None and tip_dim > geom.shape[-1]:
            raise ValueError('Expected a dimension of the tip coordinate system not greater than the dimension of the geometry.')
        self._tip_dim = tip_dim
        self._geom = geom
        super().__init__((), float, geom.spaces, geom.arguments)

    def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
        geom = self._geom.lower(points_shape, transform_chains, coordinates)
        spaces = self._geom.spaces
        actual_tip_dim = builtins.sum(transform_chains[space][0].fromdims for space in spaces)
        if self._tip_dim is not None and self._tip_dim != actual_tip_dim:
            raise ValueError('Expected a tip dimension of {} but got {}.'.format(self._tip_dim, actual_tip_dim))
        if self._geom.shape[-1] < actual_tip_dim:
            raise ValueError('the dimension of the geometry cannot be lower than the dimension of the tip coords')
        if not spaces:
            return evaluable.ones(geom.shape[:-1])
        columns = []
        for space, (chain, _) in transform_chains.items():
            if space in spaces:
                columns.append(evaluable.derivative(geom, _tip_derivative_target(space, chain.fromdims)))
        jacobian = evaluable.concatenate(columns, axis=-1)
        return evaluable.sqrt_abs_det_gram(jacobian)
class _Normal(Array):
def __init__(self, geom: Array) -> None:
self._geom = geom
assert geom.dtype == float
super().__init__(geom.shape, float, geom.spaces, geom.arguments)
def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array:
geom = self._geom.lower(points_shape, transform_chains, coordinates)
spaces_dim = builtins.sum(transform_chains[space][0].todims for space in self._geom.spaces)
normal_dim = spaces_dim - builtins.sum(transform_chains[space][0].fromdims for space in self._geom.spaces)
if self._geom.shape[-1] != spaces_dim:
raise ValueError('The dimension of geometry must equal the sum of the dimensions of the given spaces.')
if normal_dim == 0: