-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy patharraydecl.py
880 lines (745 loc) · 31 KB
/
arraydecl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
import numpy as np
import operator
from collections import namedtuple
from numba.core import types, utils
from numba.core.typing.templates import (AttributeTemplate, AbstractTemplate,
infer, infer_global, infer_getattr,
signature, bound_function)
# import time side effect: array operations requires typing support of sequence
# defined in collections: e.g. array.shape[i]
from numba.core.typing import collections
from numba.core.errors import (TypingError, RequireLiteralValue, NumbaTypeError,
NumbaNotImplementedError, NumbaAssertionError,
NumbaKeyError, NumbaIndexError, NumbaValueError)
from numba.core.cgutils import is_nonelike
# (major, minor) of the installed NumPy, used to gate version-specific typing.
numpy_version = tuple(int(part) for part in np.__version__.split('.')[:2])
# Result record of get_array_index_type(): normalized index type, result type
# of ``array[index]``, and whether advanced indexing is involved.
Indexing = namedtuple("Indexing", ["index", "result", "advanced"])
def get_array_index_type(ary, idx):
    """
    Type the indexing operation ``ary[idx]``.

    Returns None when *ary* is not a Buffer type, or when the indexing is not
    typed here: advanced indexing on a non-Array buffer, or a view result on a
    buffer type whose ``slice_is_copy`` is set.  Otherwise returns an
    ``Indexing`` namedtuple of:

    - ``index``: the normalized index type (integer indices widened to
      intp/uintp; re-wrapped as a tuple type if *idx* was a tuple type),
    - ``result``: the type of ``ary[idx]`` (the scalar dtype or an Array type),
    - ``advanced``: True if an integer/boolean array index was used.

    Raises NumbaTypeError for unsupported index types, more than one
    ellipsis, too many indices, or unsupported advanced-indexing patterns.

    Note: This is shared logic for ndarray getitem and setitem.
    """
    if not isinstance(ary, types.Buffer):
        return

    # Dimensionality of the result; adjusted as each index is walked.
    ndim = ary.ndim

    # Indices seen before / after the ellipsis.  The ellipsis itself is
    # appended first to right_indices (ellipsis_met is set before the append).
    left_indices = []
    right_indices = []
    ellipsis_met = False
    advanced = False
    num_newaxis = 0

    # Treat a single (non-tuple) index as a one-element index list.
    if not isinstance(idx, types.BaseTuple):
        idx = [idx]

    # Here, a subspace is considered as a contiguous group of advanced indices.
    # num_subspaces keeps track of the number of such
    # contiguous groups.
    in_subspace = False
    num_subspaces = 0
    array_indices = 0

    # Walk indices
    for ty in idx:
        if ty is types.ellipsis:
            if ellipsis_met:
                raise NumbaTypeError(
                    "Only one ellipsis allowed in array indices "
                    "(got %s)" % (idx,))
            ellipsis_met = True
            in_subspace = False
        elif isinstance(ty, types.SliceType):
            # If we encounter a non-advanced index while in a
            # subspace then that subspace ends.
            in_subspace = False
        # In advanced indexing, any index broadcastable to an
        # array is considered an advanced index. Hence all the
        # branches below are considered as advanced indices.
        elif isinstance(ty, types.Integer):
            # Normalize integer index
            ty = types.intp if ty.signed else types.uintp
            # Integer indexing removes the given dimension
            ndim -= 1
            # If we're within a subspace/contiguous group of
            # advanced indices then no action is necessary
            # since we've already counted that subspace once.
            if not in_subspace:
                # If we're not within a subspace and we encounter
                # this branch then we have a new subspace/group.
                num_subspaces += 1
                in_subspace = True
        elif (isinstance(ty, types.Array) and ty.ndim == 0
              and isinstance(ty.dtype, types.Integer)):
            # 0-d array used as integer index
            ndim -= 1
            if not in_subspace:
                num_subspaces += 1
                in_subspace = True
        elif (isinstance(ty, types.Array)
              and isinstance(ty.dtype, (types.Integer, types.Boolean))):
            if ty.ndim > 1:
                # Advanced indexing limitation # 1
                raise NumbaTypeError(
                    "Multi-dimensional indices are not supported.")
            array_indices += 1
            # Any (non 0-d) integer or boolean array index activates
            # advanced indexing.
            advanced = True
            if not in_subspace:
                num_subspaces += 1
                in_subspace = True
        elif (is_nonelike(ty)):
            # A None index (newaxis) adds a dimension to the result.
            ndim += 1
            num_newaxis += 1
        else:
            raise NumbaTypeError("Unsupported array index type %s in %s"
                                 % (ty, idx))
        (right_indices if ellipsis_met else left_indices).append(ty)

    if advanced:
        if array_indices > 1:
            # Advanced indexing limitation # 2
            msg = "Using more than one non-scalar array index is unsupported."
            raise NumbaTypeError(msg)

        if num_subspaces > 1:
            # Advanced indexing limitation # 3
            msg = ("Using more than one indexing subspace is unsupported."
                   " An indexing subspace is a group of one or more"
                   " consecutive indices comprising integer or array types.")
            raise NumbaTypeError(msg)

    # Only Numpy arrays support advanced indexing
    if advanced and not isinstance(ary, types.Array):
        return

    # Check indices and result dimensionality
    # all_indices is built before the ellipsis is removed from right_indices,
    # so it still contains the ellipsis; the trimmed left/right lists feed the
    # layout inference below.
    all_indices = left_indices + right_indices
    if ellipsis_met:
        assert right_indices[0] is types.ellipsis
        del right_indices[0]

    # The ellipsis and None entries do not consume an array dimension.
    n_indices = len(all_indices) - ellipsis_met - num_newaxis
    if n_indices > ary.ndim:
        raise NumbaTypeError("cannot index %s with %d indices: %s"
                             % (ary, n_indices, idx))
    if n_indices == ary.ndim and ndim == 0 and not ellipsis_met:
        # Full integer indexing => scalar result
        # (note if ellipsis is present, a 0-d view is returned instead)
        res = ary.dtype

    elif advanced:
        # Result is a copy
        res = ary.copy(ndim=ndim, layout='C', readonly=False)

    else:
        # Result is a view
        if ary.slice_is_copy:
            # Avoid view semantics when the original type creates a copy
            # when slicing.
            return

        # Infer layout
        layout = ary.layout

        def keeps_contiguity(ty, is_innermost):
            # A slice can only keep an array contiguous if it is the
            # innermost index and it is not strided
            return (ty is types.ellipsis or isinstance(ty, types.Integer)
                    or (is_innermost and isinstance(ty, types.SliceType)
                        and not ty.has_step))

        def check_contiguity(outer_indices):
            """
            Whether indexing with the given indices (from outer to inner in
            physical layout order) can keep an array contiguous.
            """
            for ty in outer_indices[:-1]:
                if not keeps_contiguity(ty, False):
                    return False
            if outer_indices and not keeps_contiguity(outer_indices[-1], True):
                return False
            return True

        if layout == 'C':
            # Integer indexing on the left keeps the array C-contiguous
            if n_indices == ary.ndim:
                # If all indices are there, ellipsis's place is indifferent
                left_indices = left_indices + right_indices
                right_indices = []
            if right_indices:
                layout = 'A'
            elif not check_contiguity(left_indices):
                layout = 'A'
        elif layout == 'F':
            # Integer indexing on the right keeps the array F-contiguous
            if n_indices == ary.ndim:
                # If all indices are there, ellipsis's place is indifferent
                right_indices = left_indices + right_indices
                left_indices = []
            if left_indices:
                layout = 'A'
            elif not check_contiguity(right_indices[::-1]):
                layout = 'A'

        if ndim == 0:
            # Implicitly convert to a scalar if the output ndim==0
            res = ary.dtype
        else:
            res = ary.copy(ndim=ndim, layout=layout)

    # Re-wrap indices
    if isinstance(idx, types.BaseTuple):
        idx = types.BaseTuple.from_types(all_indices)
    else:
        idx, = all_indices

    return Indexing(idx, res, advanced)
@infer_global(operator.getitem)
class GetItemBuffer(AbstractTemplate):
    """Type ``buffer[index]`` via the shared indexing logic."""

    def generic(self, args, kws):
        assert not kws
        ary, idx = args
        indexing = get_array_index_type(ary, idx)
        if indexing is None:
            # Not a buffer, or an indexing form not handled here.
            return None
        return signature(indexing.result, ary, indexing.index)
@infer_global(operator.setitem)
class SetItemBuffer(AbstractTemplate):
    """Type ``buffer[index] = value`` for mutable buffers."""

    def generic(self, args, kws):
        """
        Return ``signature(none, ary, idx, valty)`` when the assignment can be
        typed, else None.

        ``idx`` is the normalized index from get_array_index_type().  The
        value slot of the signature is the type the RHS is converted to.
        Raises NumbaTypeError when the target buffer is not mutable.
        """
        assert not kws
        ary, idx, val = args
        if not isinstance(ary, types.Buffer):
            return
        if not ary.mutable:
            msg = f"Cannot modify readonly array of type: {ary}"
            raise NumbaTypeError(msg)
        out = get_array_index_type(ary, idx)
        if out is None:
            return

        idx = out.index
        res = out.result # res is the result type of the access ary[idx]
        if isinstance(res, types.Array):
            # Indexing produces an array
            if isinstance(val, types.Array):
                if not self.context.can_convert(val.dtype, res.dtype):
                    # DType conversion not possible
                    return
                else:
                    res = val
            elif isinstance(val, types.Sequence):
                if (res.ndim == 1 and
                        self.context.can_convert(val.dtype, res.dtype)):
                    # Allow assignment of sequence to 1d array
                    res = val
                else:
                    # NOTE: sequence-to-array broadcasting is unsupported
                    return
            else:
                # Allow scalar broadcasting
                if self.context.can_convert(val, res.dtype):
                    res = res.dtype
                else:
                    # Incompatible scalar type
                    return
        elif not isinstance(val, types.Array):
            # Single item assignment
            if not self.context.can_convert(val, res):
                # if the array dtype is not yet defined
                if not res.is_precise():
                    # set the array type to use the dtype of value (RHS)
                    newary = ary.copy(dtype=val)
                    return signature(types.none, newary, idx, res)
                else:
                    return
            res = val
        elif (isinstance(val, types.Array) and val.ndim == 0
              and self.context.can_convert(val.dtype, res)):
            # val is an array(T, 0d, O), where T is the type of res, O is order
            res = val
        else:
            return
        return signature(types.none, ary, idx, res)
def normalize_shape(shape):
    """
    Canonicalize a shape tuple type.

    A homogeneous integer tuple becomes ``UniTuple(intp, n)`` (or
    ``UniTuple(uintp, n)`` for unsigned elements); an empty heterogeneous
    tuple becomes ``UniTuple(intp, 0)``.  Anything else yields None.
    """
    if isinstance(shape, types.UniTuple):
        elem = shape.dtype
        if not isinstance(elem, types.Integer):
            return None
        dimtype = types.intp if elem.signed else types.uintp
        return types.UniTuple(dimtype, len(shape))
    if isinstance(shape, types.Tuple) and shape.count == 0:
        # Force (0 x intp) for consistency with other shapes
        return types.UniTuple(types.intp, 0)
    return None
@infer_getattr
class ArrayAttribute(AttributeTemplate):
    """
    Typing of attributes and methods of ``types.Array`` values.

    ``resolve_<name>`` methods return the type of the attribute
    ``array.<name>``.  Methods decorated with ``@bound_function`` type a call
    ``array.<name>(...)`` and return a ``signature``, or None when the call
    cannot be typed.  ``generic_resolve`` handles field access on record
    arrays.
    """
    key = types.Array

    def resolve_dtype(self, ary):
        return types.DType(ary.dtype)

    def resolve_nbytes(self, ary):
        return types.intp

    def resolve_itemsize(self, ary):
        return types.intp

    def resolve_shape(self, ary):
        return types.UniTuple(types.intp, ary.ndim)

    def resolve_strides(self, ary):
        return types.UniTuple(types.intp, ary.ndim)

    def resolve_ndim(self, ary):
        return types.intp

    def resolve_size(self, ary):
        return types.intp

    def resolve_flat(self, ary):
        return types.NumpyFlatType(ary)

    def resolve_ctypes(self, ary):
        return types.ArrayCTypes(ary)

    def resolve_flags(self, ary):
        return types.ArrayFlags(ary)

    def resolve_T(self, ary):
        # 0-d and 1-d arrays keep their type.  Otherwise transposing swaps
        # C and F layouts; any other layout becomes 'A'.
        if ary.ndim <= 1:
            retty = ary
        else:
            layout = {"C": "F", "F": "C"}.get(ary.layout, "A")
            retty = ary.copy(layout=layout)
        return retty

    def resolve_real(self, ary):
        return self._resolve_real_imag(ary, attr='real')

    def resolve_imag(self, ary):
        return self._resolve_real_imag(ary, attr='imag')

    def _resolve_real_imag(self, ary, attr):
        """
        Shared typing for ``.real`` / ``.imag`` (*attr* names which one).

        Complex arrays yield a layout-'A' array of the underlying float type.
        Other numeric arrays keep their dtype, and ``.imag`` is typed as
        readonly.  Raises TypingError for non-numeric dtypes.
        """
        if ary.dtype in types.complex_domain:
            return ary.copy(dtype=ary.dtype.underlying_float, layout='A')
        elif ary.dtype in types.number_domain:
            res = ary.copy(dtype=ary.dtype)
            if attr == 'imag':
                res = res.copy(readonly=True)
            return res
        else:
            msg = "cannot access .{} of array of {}"
            raise TypingError(msg.format(attr, ary.dtype))

    @bound_function("array.transpose")
    def resolve_transpose(self, ary, args, kws):
        """
        Type ``array.transpose()``, ``array.transpose(None)``,
        ``array.transpose(axes_tuple)`` and ``array.transpose(*axes)``.

        Raises TypingError for non-integer numeric axis arguments or
        unsupported argument types.
        """
        def sentry_shape_scalar(ty):
            # True for integer scalars, False for non-numbers; raises for
            # non-integer numbers.
            if ty in types.number_domain:
                # Guard against non integer type
                if not isinstance(ty, types.Integer):
                    msg = "transpose() arg cannot be {0}".format(ty)
                    raise TypingError(msg)
                return True
            else:
                return False

        assert not kws
        if len(args) == 0:
            return signature(self.resolve_T(ary))

        if len(args) == 1:
            shape, = args

            if sentry_shape_scalar(shape):
                assert ary.ndim == 1
                return signature(ary, *args)

            if isinstance(shape, types.NoneType):
                return signature(self.resolve_T(ary))

            shape = normalize_shape(shape)
            if shape is None:
                return

            assert ary.ndim == shape.count
            return signature(self.resolve_T(ary).copy(layout="A"), shape)

        else:
            if any(not sentry_shape_scalar(a) for a in args):
                # args are Numba type objects, not strings: stringify them
                # before joining, otherwise str.join raises TypeError and
                # masks the intended TypingError.
                msg = "transpose({0}) is not supported".format(
                    ', '.join(map(str, args)))
                raise TypingError(msg)
            assert ary.ndim == len(args)
            return signature(self.resolve_T(ary).copy(layout="A"), *args)

    @bound_function("array.copy")
    def resolve_copy(self, ary, args, kws):
        assert not args
        assert not kws
        # A copy is always C-contiguous and writable.
        retty = ary.copy(layout="C", readonly=False)
        return signature(retty)

    @bound_function("array.item")
    def resolve_item(self, ary, args, kws):
        assert not kws
        # We don't support explicit arguments as that's exactly equivalent
        # to regular indexing.  The no-argument form is interesting to
        # allow some degree of genericity when writing functions.
        if not args:
            return signature(ary.dtype)

    if numpy_version < (2, 0):
        # ndarray.itemset is only typed when running against NumPy < 2.0.
        @bound_function("array.itemset")
        def resolve_itemset(self, ary, args, kws):
            assert not kws
            # Only the single-argument form ``array.itemset(value)`` (no
            # explicit index) is typed; explicit indices are equivalent to
            # regular indexing.
            if len(args) == 1:
                return signature(types.none, ary.dtype)

    @bound_function("array.nonzero")
    def resolve_nonzero(self, ary, args, kws):
        assert not args
        assert not kws
        if ary.ndim == 0 and numpy_version >= (2, 1):
            raise NumbaValueError(
                "Calling nonzero on 0d arrays is not allowed."
                " Use np.atleast_1d(scalar).nonzero() instead."
            )
        # 0-dim arrays return one result array
        ndim = max(ary.ndim, 1)
        retty = types.UniTuple(types.Array(types.intp, 1, 'C'), ndim)
        return signature(retty)

    @bound_function("array.reshape")
    def resolve_reshape(self, ary, args, kws):
        """
        Type ``array.reshape(shape_tuple)`` / ``array.reshape(*dims)``.

        Only C or F layout arrays are supported; raises TypingError otherwise,
        when called without arguments, or for unsupported argument types.
        """
        def sentry_shape_scalar(ty):
            if ty in types.number_domain:
                # Guard against non integer type
                if not isinstance(ty, types.Integer):
                    raise TypingError("reshape() arg cannot be {0}".format(ty))
                return True
            else:
                return False

        assert not kws
        if ary.layout not in 'CF':
            # only work for contiguous array
            raise TypingError("reshape() supports contiguous array only")

        if len(args) == 1:
            # single arg
            shape, = args

            if sentry_shape_scalar(shape):
                ndim = 1
            else:
                shape = normalize_shape(shape)
                if shape is None:
                    return
                ndim = shape.count
            retty = ary.copy(ndim=ndim)
            return signature(retty, shape)

        elif len(args) == 0:
            # no arg
            raise TypingError("reshape() take at least one arg")

        else:
            # vararg case
            if any(not sentry_shape_scalar(a) for a in args):
                raise TypingError("reshape({0}) is not supported".format(
                    ', '.join(map(str, args))))

            retty = ary.copy(ndim=len(args))
            return signature(retty, *args)

    @bound_function("array.sort")
    def resolve_sort(self, ary, args, kws):
        assert not args
        assert not kws
        return signature(types.none)

    @bound_function("array.argsort")
    def resolve_argsort(self, ary, args, kws):
        """
        Type ``array.argsort(kind=...)`` for 1-d arrays only (other ndims
        resolve to None).  ``kind`` must be a string literal; any other
        keyword raises TypingError.
        """
        assert not args
        kwargs = dict(kws)
        kind = kwargs.pop('kind', types.StringLiteral('quicksort'))
        if not isinstance(kind, types.StringLiteral):
            raise TypingError('"kind" must be a string literal')
        if kwargs:
            msg = "Unsupported keywords: {!r}"
            raise TypingError(msg.format(list(kwargs)))
        if ary.ndim == 1:
            def argsort_stub(kind='quicksort'):
                pass
            pysig = utils.pysignature(argsort_stub)
            sig = signature(types.Array(types.intp, 1, 'C'), kind).replace(pysig=pysig)
            return sig

    @bound_function("array.view")
    def resolve_view(self, ary, args, kws):
        from .npydecl import parse_dtype
        assert not kws
        dtype, = args
        dtype = parse_dtype(dtype)
        if dtype is None:
            return
        retty = ary.copy(dtype=dtype)
        return signature(retty, *args)

    @bound_function("array.astype")
    def resolve_astype(self, ary, args, kws):
        from .npydecl import parse_dtype
        assert not kws
        dtype, = args
        if isinstance(dtype, types.UnicodeType):
            raise RequireLiteralValue(("array.astype if dtype is a string it "
                                       "must be constant"))
        dtype = parse_dtype(dtype)
        if dtype is None:
            return
        if not self.context.can_convert(ary.dtype, dtype):
            raise TypingError("astype(%s) not supported on %s: "
                              "cannot convert from %s to %s"
                              % (dtype, ary, ary.dtype, dtype))
        # C and F layouts are kept; any other layout becomes C.
        layout = ary.layout if ary.layout in 'CF' else 'C'
        # reset the write bit irrespective of whether the cast type is the same
        # as the current dtype, this replicates numpy
        retty = ary.copy(dtype=dtype, layout=layout, readonly=False)
        return signature(retty, *args)

    @bound_function("array.ravel")
    def resolve_ravel(self, ary, args, kws):
        # Only support no argument version (default order='C')
        assert not kws
        assert not args
        # For a C-layout input no copy is made, so the result keeps the
        # input's mutability; otherwise the result is a writable copy.
        copy_will_be_made = ary.layout != 'C'
        readonly = not (copy_will_be_made or ary.mutable)
        return signature(ary.copy(ndim=1, layout='C', readonly=readonly))

    @bound_function("array.flatten")
    def resolve_flatten(self, ary, args, kws):
        # Only support no argument version (default order='C')
        assert not kws
        assert not args
        # To ensure that Numba behaves exactly like NumPy,
        # we also clear the read-only flag when doing a "flatten"
        # Why? Two reasons:
        # Because flatten always returns a copy. (see NumPy docs for "flatten")
        # And because a copy always returns a writeable array.
        # ref: https://numpy.org/doc/stable/reference/generated/numpy.copy.html
        return signature(ary.copy(ndim=1, layout='C', readonly=False))

    def generic_resolve(self, ary, attr):
        # Resolution of other attributes, for record arrays: ``arr.<field>``
        # yields a layout-'A' array of the field type (nested-array fields
        # add their own dimensions).  Unknown attributes resolve to None.
        if isinstance(ary.dtype, types.Record):
            if attr in ary.dtype.fields:
                attr_dtype = ary.dtype.typeof(attr)
                if isinstance(attr_dtype, types.NestedArray):
                    return ary.copy(
                        dtype=attr_dtype.dtype,
                        ndim=ary.ndim + attr_dtype.ndim,
                        layout='A'
                    )
                else:
                    return ary.copy(dtype=attr_dtype, layout='A')
@infer_getattr
class DTypeAttr(AttributeTemplate):
    """Typing of ``dtype.type`` and ``dtype.kind`` on ``types.DType``."""
    key = types.DType

    def resolve_type(self, ary):
        # Wrap the numeric type in NumberClass
        return types.NumberClass(ary.dtype)

    def resolve_kind(self, ary):
        # Float dtypes give 'f' and Integer dtypes give 'i'; anything else is
        # unsupported and resolves to None.
        # NOTE(review): unsigned integers also map to 'i' here, whereas NumPy
        # reports 'u' -- confirm whether that is intended.
        scalar = ary.key
        if isinstance(scalar, types.scalars.Float):
            return types.StringLiteral('f')
        if isinstance(scalar, types.scalars.Integer):
            return types.StringLiteral('i')
        return None  # other types not supported yet
@infer
class StaticGetItemArray(AbstractTemplate):
    """Type ``record_array['field']`` with a compile-time field name."""
    key = "static_getitem"

    def generic(self, args, kws):
        # Resolution of members for record and structured arrays
        ary, idx = args
        is_record_array = (isinstance(ary, types.Array)
                           and isinstance(idx, str)
                           and isinstance(ary.dtype, types.Record))
        if not is_record_array or idx not in ary.dtype.fields:
            return None

        field_ty = ary.dtype.typeof(idx)
        if isinstance(field_ty, types.NestedArray):
            # A nested-array field contributes its own dimensions.
            ret = ary.copy(dtype=field_ty.dtype,
                           ndim=ary.ndim + field_ty.ndim,
                           layout='A')
        else:
            ret = ary.copy(dtype=field_ty, layout='A')
        return signature(ret, *args)
@infer_getattr
class RecordAttribute(AttributeTemplate):
    """Type ``record.<field>`` as the type of that field."""
    key = types.Record

    def generic_resolve(self, record, attr):
        field_ty = record.typeof(attr)
        assert field_ty
        return field_ty
@infer
class StaticGetItemRecord(AbstractTemplate):
    """Type ``record['field']`` with a compile-time field name.

    Raises NumbaKeyError when the record has no such field.
    """
    key = "static_getitem"

    def generic(self, args, kws):
        # Resolution of members for records
        record, idx = args
        if not (isinstance(record, types.Record) and isinstance(idx, str)):
            return None
        if idx not in record.fields:
            raise NumbaKeyError(f"Field '{idx}' was not found in record "
                                "with fields "
                                f"{tuple(record.fields.keys())}")
        field_ty = record.typeof(idx)
        assert field_ty
        return signature(field_ty, *args)
@infer_global(operator.getitem)
class StaticGetItemLiteralRecord(AbstractTemplate):
    """
    Type ``record[key]`` where *key* is a string or integer literal.

    Raises NumbaKeyError for an unknown field name and NumbaIndexError for an
    integer literal outside ``[-nfields, nfields)``.
    """

    def generic(self, args, kws):
        # Resolution of members for records
        record, idx = args
        if isinstance(record, types.Record):
            if isinstance(idx, types.StringLiteral):
                if idx.literal_value not in record.fields:
                    msg = (f"Field '{idx.literal_value}' was not found in "
                           f"record with fields {tuple(record.fields.keys())}")
                    raise NumbaKeyError(msg)
                ret = record.typeof(idx.literal_value)
                assert ret
                return signature(ret, *args)
            elif isinstance(idx, types.IntegerLiteral):
                field_names = list(record.fields)
                nfields = len(field_names)
                # Negative literals index from the end, as in Python.  Check
                # both bounds so an out-of-range negative literal raises
                # NumbaIndexError instead of a bare IndexError from the list
                # lookup below.
                if not -nfields <= idx.literal_value < nfields:
                    msg = f"Requested index {idx.literal_value} is out of range"
                    raise NumbaIndexError(msg)
                ret = record.typeof(field_names[idx.literal_value])
                assert ret
                return signature(ret, *args)
@infer
class StaticSetItemRecord(AbstractTemplate):
    """
    Type ``record[key] = value`` where *key* is a compile-time ``str`` field
    name or ``int`` field position.

    Returns None when *value* cannot be converted to the field type.  Raises
    NumbaIndexError for an integer position outside ``[-nfields, nfields)``.
    """
    key = "static_setitem"

    def generic(self, args, kws):
        # Resolution of members for record and structured arrays
        record, idx, value = args
        if isinstance(record, types.Record):
            if isinstance(idx, str):
                expectedty = record.typeof(idx)
                if self.context.can_convert(value, expectedty) is not None:
                    return signature(types.void, record, types.literal(idx),
                                     value)
            elif isinstance(idx, int):
                field_names = list(record.fields)
                nfields = len(field_names)
                # Check both bounds so an out-of-range negative position
                # raises NumbaIndexError rather than a bare IndexError.
                if not -nfields <= idx < nfields:
                    msg = f"Requested index {idx} is out of range"
                    raise NumbaIndexError(msg)
                str_field = field_names[idx]
                expectedty = record.typeof(str_field)
                if self.context.can_convert(value, expectedty) is not None:
                    return signature(types.void, record, types.literal(idx),
                                     value)
@infer_global(operator.setitem)
class StaticSetItemLiteralRecord(AbstractTemplate):
    """Type ``record['field'] = value`` with a string-literal field name.

    Raises NumbaKeyError for an unknown field; returns None when *value*
    cannot be converted to the field type.
    """

    def generic(self, args, kws):
        # Resolution of members for records
        target, idx, value = args
        if not isinstance(target, types.Record):
            return None
        if not isinstance(idx, types.StringLiteral):
            return None
        field = idx.literal_value
        if field not in target.fields:
            msg = (f"Field '{field}' was not found in record "
                   f"with fields {tuple(target.fields.keys())}")
            raise NumbaKeyError(msg)
        expectedty = target.typeof(field)
        if self.context.can_convert(value, expectedty) is None:
            return None
        return signature(types.void, target, idx, value)
@infer_getattr
class ArrayCTypesAttribute(AttributeTemplate):
    """Typing of attributes on ``array.ctypes``."""
    key = types.ArrayCTypes

    def resolve_data(self, ctinfo):
        # ``array.ctypes.data`` is typed as ``uintp``.
        data_type = types.uintp
        return data_type
@infer_getattr
class ArrayFlagsAttribute(AttributeTemplate):
    """Typing of the contiguity flags on ``array.flags`` (all booleans)."""
    key = types.ArrayFlags

    def resolve_contiguous(self, ctflags):
        return types.boolean

    # The C- and F-contiguity flags share the same boolean typing.
    resolve_c_contiguous = resolve_contiguous
    resolve_f_contiguous = resolve_contiguous
@infer_getattr
class NestedArrayAttribute(ArrayAttribute):
    """Reuse all ``ArrayAttribute`` typing for ``types.NestedArray``."""
    key = types.NestedArray
def _expand_integer(ty):
    """
    If *ty* is an integer, expand it to a machine int (like Numpy).

    Signed integers widen to at least ``intp``, unsigned ones to at least
    ``uintp``; booleans become ``intp``.  Other types are returned unchanged.
    """
    if isinstance(ty, types.Boolean):
        return types.intp
    if isinstance(ty, types.Integer):
        machine_int = types.intp if ty.signed else types.uintp
        return max(machine_int, ty)
    return ty
def generic_homog(self, args, kws):
    """Signature for a no-argument array method returning the array dtype.

    Raises NumbaAssertionError if any positional or keyword args are given.
    """
    if args:
        raise NumbaAssertionError("args not supported")
    if kws:
        raise NumbaAssertionError("kws not supported")
    receiver = self.this
    return signature(receiver.dtype, recvr=receiver)
def generic_expand(self, args, kws):
    """Signature for a no-argument array method whose result is the array
    dtype widened by ``_expand_integer``."""
    assert not args
    assert not kws
    widened = _expand_integer(self.this.dtype)
    return signature(widened, recvr=self.this)
def sum_expand(self, args, kws):
    """
    sum can be called with or without an axis parameter, and with or without
    a dtype parameter

    ``self.this`` is the array type receiving the call.  Keyword ``axis`` /
    ``dtype`` values are appended to *args*, and a matching ``pysig`` is
    attached to the returned signature.  Result types:

    - no axis/dtype: the scalar ``_expand_integer(dtype)``;
    - axis only: that scalar for 1-d input, else a C-layout array of
      ``ndim - 1`` dimensions;
    - dtype only: the parsed dtype (a scalar);
    - axis and dtype: the parsed dtype, or a C-layout ``ndim - 1`` array of
      it when the input is not 1-d.
    """
    pysig = None
    if 'axis' in kws and 'dtype' not in kws:
        def sum_stub(axis):
            pass
        pysig = utils.pysignature(sum_stub)
        # rewrite args
        args = list(args) + [kws['axis']]
    elif 'dtype' in kws and 'axis' not in kws:
        # NOTE(review): if axis is also passed positionally, args ends up
        # with two entries while this pysig declares one parameter -- confirm
        # that combination is handled as intended.
        def sum_stub(dtype):
            pass
        pysig = utils.pysignature(sum_stub)
        # rewrite args
        args = list(args) + [kws['dtype']]
    elif 'dtype' in kws and 'axis' in kws:
        def sum_stub(axis, dtype):
            pass
        pysig = utils.pysignature(sum_stub)
        # rewrite args
        args = list(args) + [kws['axis'], kws['dtype']]

    args_len = len(args)
    assert args_len <= 2
    if args_len == 0:
        # No axis or dtype parameter so the return type of the summation is a scalar
        # of the type of the array.
        out = signature(_expand_integer(self.this.dtype), *args,
                        recvr=self.this)
    elif args_len == 1 and 'dtype' not in kws:
        # There is an axis parameter, either arg or kwarg
        if self.this.ndim == 1:
            # 1d reduces to a scalar
            return_type = _expand_integer(self.this.dtype)
        else:
            # the return type of this summation is an array of dimension one
            # less than the input array.
            return_type = types.Array(dtype=_expand_integer(self.this.dtype),
                                      ndim=self.this.ndim-1, layout='C')
        out = signature(return_type, *args, recvr=self.this)

    elif args_len == 1 and 'dtype' in kws:
        # No axis parameter so the return type of the summation is a scalar
        # of the dtype parameter.
        from .npydecl import parse_dtype
        dtype, = args
        dtype = parse_dtype(dtype)
        out = signature(dtype, *args, recvr=self.this)

    elif args_len == 2:
        # There is an axis and dtype parameter, either arg or kwarg
        from .npydecl import parse_dtype
        dtype = parse_dtype(args[1])
        return_type = dtype
        if self.this.ndim != 1:
            # 1d reduces to a scalar, 2d and above reduce dim by 1
            # the return type of this summation is an array of dimension one
            # less than the input array.
            return_type = types.Array(dtype=return_type,
                                      ndim=self.this.ndim-1, layout='C')
        out = signature(return_type, *args, recvr=self.this)
    else:
        # Unreachable: args_len is 0, 1 or 2 and every case is handled above.
        pass
    return out.replace(pysig=pysig)
def generic_expand_cumulative(self, args, kws):
    """Signature for a no-argument cumulative array method.

    The result is a 1-d C-layout array of the widened (``_expand_integer``)
    dtype.  Raises NumbaAssertionError if any args or kwargs are given.
    """
    if args:
        raise NumbaAssertionError("args unsupported")
    if kws:
        raise NumbaAssertionError("kwargs unsupported")
    assert isinstance(self.this, types.Array)
    out_dtype = _expand_integer(self.this.dtype)
    return_type = types.Array(dtype=out_dtype, ndim=1, layout='C')
    return signature(return_type, recvr=self.this)
def generic_hetero_real(self, args, kws):
    """Signature for a no-argument array method whose result is float64 for
    integer/boolean arrays and the array dtype otherwise."""
    assert not args
    assert not kws
    result = self.this.dtype
    if isinstance(result, (types.Integer, types.Boolean)):
        result = types.float64
    return signature(result, recvr=self.this)
def generic_hetero_always_real(self, args, kws):
    """Signature for a no-argument array method with a real-valued result.

    Integer/boolean arrays give float64, complex arrays give their underlying
    float type, and any other dtype is returned as-is.
    """
    assert not args
    assert not kws
    dtype = self.this.dtype
    if isinstance(dtype, (types.Integer, types.Boolean)):
        result = types.float64
    elif isinstance(dtype, types.Complex):
        result = dtype.underlying_float
    else:
        result = dtype
    return signature(result, recvr=self.this)
def generic_index(self, args, kws):
    """Signature for a no-argument array method returning an ``intp``."""
    assert not args
    assert not kws
    index_type = types.intp
    return signature(index_type, recvr=self.this)
def install_array_method(name, generic, prefer_literal=True):
    """
    Register typing for the array method ``array.<name>``.

    Builds an ``AbstractTemplate`` subclass named ``Array_<name>`` whose
    ``generic`` is *generic*, and attaches ``resolve_<name>`` to
    ``ArrayAttribute`` so the attribute resolves to a BoundFunction of it.
    """
    template = type("Array_" + name, (AbstractTemplate,),
                    dict(key="array." + name,
                         generic=generic,
                         prefer_literal=prefer_literal))

    def array_attribute_attachment(self, ary):
        return types.BoundFunction(template, ary)

    setattr(ArrayAttribute, "resolve_" + name, array_attribute_attachment)
# Register ``array.sum`` typing.  sum_expand widens integer/boolean results to
# a machine-width type (via _expand_integer), to avoid overflows.
install_array_method("sum", sum_expand, prefer_literal=True)
@infer_global(operator.eq)
class CmpOpEqArray(AbstractTemplate):
    """Type ``a == b`` for two arrays of the same type.

    The result is typed as an array of the operand type with boolean dtype.
    """

    def generic(self, args, kws):
        assert not kws
        lhs, rhs = args
        if isinstance(lhs, types.Array) and lhs == rhs:
            return signature(lhs.copy(dtype=types.boolean), lhs, rhs)
        return None