/
pad.py
219 lines (179 loc) · 8.32 KB
/
pad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""
.. autofunction:: pad
"""
__copyright__ = "Copyright (C) 2023 Kaushik Kulkarni"
from pytato.array import Array, IndexLambda
from pytato.scalar_expr import IntegralT, INT_CLASSES, ScalarType
from typing import Union, Sequence, Any, Tuple, List, Dict
from pytools import UniqueNameGenerator
import collections.abc as abc
import pymbolic.primitives as prim
import numpy as np
def _get_constant_padded_idx_lambda(
        array: Array,
        pad_widths: Sequence[Tuple[IntegralT, IntegralT]],
        constant_vals: Sequence[Tuple[ScalarType, ScalarType]]
) -> IndexLambda:
    """
    Build the :class:`~pytato.array.IndexLambda` implementing constant-mode
    padding of *array*; used internally by :func:`pad`.

    :param array: The array being padded.
    :param pad_widths: One ``(before, after)`` pair of pad widths per axis of
        *array*.
    :param constant_vals: One ``(before, after)`` pair of fill values per axis
        of *array*.
    :returns: An index lambda whose extent along axis *i* is
        ``array.shape[i] + before_i + after_i`` and whose dtype is that of
        *array*.
    """
    from pytato.array import make_index_lambda

    assert array.ndim == len(pad_widths) == len(constant_vals)

    in_name = "in_0"
    bindings: Dict[str, Array] = {in_name: array}
    name_gen = UniqueNameGenerator()
    name_gen.add_name(in_name)

    # Interior points read the input at the output index shifted back by the
    # leading pad width of each axis.
    shifted_indices = tuple(
        prim.Variable(f"_{iaxis}") - widths[0]
        for iaxis, widths in enumerate(pad_widths))
    expr = prim.Variable(in_name)[shifted_indices]

    # Wrap the read in per-axis conditionals: indices at or past the shifted
    # axis end take the trailing fill value, indices below the leading pad
    # width take the leading fill value.
    for iaxis, (widths, fills) in enumerate(zip(pad_widths, constant_vals)):
        idx = prim.Variable(f"_{iaxis}")
        extent = array.shape[iaxis]

        if isinstance(extent, Array):
            # A symbolic axis length is not a scalar literal, so the upper
            # bound is passed in as an extra binding and referenced by name.
            bound_name = name_gen("in_0")
            bindings[bound_name] = extent + widths[0]
            past_end = prim.Comparison(idx, ">=", prim.Variable(bound_name))
        else:
            assert isinstance(extent, INT_CLASSES)
            past_end = prim.Comparison(idx, ">=", extent + widths[0])

        expr = prim.If(past_end, fills[1], expr)
        expr = prim.If(prim.Comparison(idx, "<", widths[0]), fills[0], expr)

    padded_shape = tuple(
        extent + widths[0] + widths[1]
        for extent, widths in zip(array.shape, pad_widths))

    return make_index_lambda(expr, bindings,
                             shape=padded_shape, dtype=array.dtype)


def _normalize_pad_width(
        array: Array,
        pad_width: Union[IntegralT, Sequence[IntegralT]],
        ) -> Sequence[Tuple[IntegralT, IntegralT]]:
    """
    Normalize the *pad_width* argument of :func:`pad` into one
    ``(before, after)`` pair per axis of *array*.

    Accepted forms, following :func:`numpy.pad`:

    - ``pad`` or ``(pad,)``: ``before = after = pad`` along every axis.
    - ``(before, after)`` or ``((before, after),)``: the same pair along
      every axis.
    - ``((before_0, after_0), ..., (before_{n-1}, after_{n-1}))`` with
      ``n == array.ndim``: one pair per axis.

    :returns: A list of ``array.ndim`` tuples ``(before_i, after_i)``.
    :raises ValueError: If a per-axis sequence does not have ``array.ndim``
        entries, an entry is not a pair of integers, or a width is negative.
    :raises TypeError: If *pad_width* is neither an integer nor a sequence.
    """
    processed_pad_widths: List[Tuple[IntegralT, IntegralT]]

    if isinstance(pad_width, INT_CLASSES):
        processed_pad_widths = [(pad_width, pad_width)
                                for _ in range(array.ndim)]
    elif (isinstance(pad_width, abc.Sequence)
          and len(pad_width) == 1
          and isinstance(pad_width[0], INT_CLASSES)):
        # ``(pad,)``: the lone *element* must be an integer.
        processed_pad_widths = [(pad_width[0], pad_width[0])
                                for _ in range(array.ndim)]
    elif (isinstance(pad_width, abc.Sequence)
          and len(pad_width) == 2
          and isinstance(pad_width[0], INT_CLASSES)
          and isinstance(pad_width[1], INT_CLASSES)):
        processed_pad_widths = [(pad_width[0], pad_width[1])] * array.ndim
    elif isinstance(pad_width, abc.Sequence):
        per_axis_widths = list(pad_width)
        if len(per_axis_widths) == 1:
            # ``((before, after),)``: reuse the single entry for every axis.
            # It is validated below like any other per-axis entry.
            per_axis_widths = per_axis_widths * array.ndim

        if len(per_axis_widths) != array.ndim:
            raise ValueError(f"Number of pad widths != {array.ndim}"
                             " (the array's dimension)")

        processed_pad_widths = []
        for k in per_axis_widths:
            if (isinstance(k, tuple)
                    and len(k) == 2
                    and isinstance(k[0], INT_CLASSES)
                    and isinstance(k[1], INT_CLASSES)):
                processed_pad_widths.append(k)
            else:
                raise ValueError("Elements of pad_width must be of type"
                                 f" `Tuple[int, int]`, got '{k}'.")
    else:
        raise TypeError("'pad_width' can be an int or"
                        " sequence of pad widths along each"
                        " direction.")

    # Negative widths would silently shrink the output shape instead of
    # padding it; numpy.pad rejects them as well.
    if any(before < 0 or after < 0
           for before, after in processed_pad_widths):
        raise ValueError("'pad_width' must not contain negative values,"
                         f" got '{pad_width}'.")

    return processed_pad_widths


def _normalize_constant_values(
        array: Array,
        constant_vals: Any,
        ) -> Sequence[Tuple[ScalarType, ScalarType]]:
    """
    Normalize the *constant_values* argument of :func:`pad` into one
    ``(before, after)`` pair of fill values per axis of *array*.

    Accepted forms, following :func:`numpy.pad`:

    - ``constant`` or ``(constant,)``: ``before = after = constant`` along
      every axis.
    - ``(before, after)`` or ``((before, after),)``: the same pair along
      every axis.
    - ``((before_0, after_0), ..., (before_{n-1}, after_{n-1}))`` with
      ``n == array.ndim``: one pair per axis.

    Scalars are recognized with :func:`numpy.isscalar`.

    :raises ValueError: If a per-axis sequence does not have ``array.ndim``
        entries or an entry is not a pair of scalars.
    :raises TypeError: If *constant_vals* is neither a scalar nor a sequence.
    """
    processed_constant_vals: List[Tuple[ScalarType, ScalarType]]

    if np.isscalar(constant_vals):
        # type-ignore-reason: mypy does not understand the guarding
        # predicate
        processed_constant_vals = [
            (constant_vals, constant_vals)  # type: ignore[misc]
            for _ in range(array.ndim)]
    elif (isinstance(constant_vals, tuple)
          and len(constant_vals) == 2
          and np.isscalar(constant_vals[0])
          and np.isscalar(constant_vals[1])):
        processed_constant_vals = [constant_vals for _ in range(array.ndim)]
    elif (isinstance(constant_vals, abc.Sequence)
          and len(constant_vals) == 1
          and np.isscalar(constant_vals[0])):
        # ``(constant,)``: a single fill value for every padded element.
        fill_value = constant_vals[0]
        processed_constant_vals = [(fill_value, fill_value)
                                   for _ in range(array.ndim)]
    elif isinstance(constant_vals, abc.Sequence):
        per_axis_vals = list(constant_vals)
        if len(per_axis_vals) == 1:
            # ``((before, after),)``: reuse the single entry for every axis.
            # It is validated below like any other per-axis entry.
            per_axis_vals = per_axis_vals * array.ndim

        if len(per_axis_vals) != array.ndim:
            raise ValueError(f"Number of constant values != {array.ndim}"
                             " (the array's dimension)")

        processed_constant_vals = []
        for constant_val in per_axis_vals:
            if (isinstance(constant_val, tuple)
                    and len(constant_val) == 2
                    and np.isscalar(constant_val[0])
                    and np.isscalar(constant_val[1])):
                processed_constant_vals.append(constant_val)
            else:
                raise ValueError(
                    "Elements of `constant_values` must be of type"
                    f" `Tuple[ScalarType, ScalarType]`, got '{constant_val}'.")
    else:
        raise TypeError("`constant_values` must be a scalar or a sequence of"
                        f" (pairs of) scalars, got '{constant_vals}'.")

    return processed_constant_vals


def pad(array: Array,
        pad_width: Union[IntegralT, Sequence[IntegralT]],
        mode: str = "constant",
        **kwargs: Any) -> Array:
    r"""
    Returns an array with padded elements along each axis.

    :param array: The array to be padded.
    :param pad_width: Number of elements to be padded along each axis. Can be
        one of:

        - An instance of :class:`int` denoting the constant number of elements
          to pad before and after each axis.
        - A tuple of the form ``(before, after)`` denoting that *before* number
          of padded elements must precede each axis and *after* number of
          padded elements must succeed each axis.
        - A sequence with i-th element as the tuple ``(before_i, after_i)``
          denoting that *before_i* number of padded elements must precede the
          i-th axis and *after_i* number of padded elements must succeed the
          i-th axis.

    :param mode: An instance of :class:`str` denoting the values of the padded
        elements in the returned array. It can be one of:

        - ``"constant"`` denoting that the padded elements must be filled with
          constant entries. See *constant_values*.

    :param constant_values: Optional argument when operating under
        ``"constant"`` *mode*. Can be one of:

        - A scalar (or a one-element sequence ``(constant,)``) denoting the
          value of every padded element.
        - A :class:`tuple` of the form ``(before, after)`` (or
          ``((before, after),)``) denoting that every padded element that
          precedes *array*'s axes must be set to *before* and every padded
          element that succeeds *array*'s axes must be set to *after*.
        - A sequence with the i-th element of the form ``(before_i, after_i)``
          denoting that the padded elements preceding *array*'s i-th axis must
          be set to *before_i* and the padded elements succeeding *array*'s
          i-th axis must be set to *after_i*.

        Defaults to *0*.

    :raises NotImplementedError: If *mode* is not ``"constant"``.
    :raises ValueError: If *pad_width* or *constant_values* is malformed, or
        if keyword arguments other than *constant_values* are passed.
    :raises TypeError: If *pad_width* or *constant_values* has an unsupported
        type.

    .. note::

        As of March, 2023 the values of the padded elements that are preceding
        wrt certain axes and succeeding wrt other axes is undefined as per
        :func:`numpy.pad`\ 's spec.
    """
    processed_pad_widths = _normalize_pad_width(array, pad_width)

    if mode == "constant":
        # An omitted *constant_values* means every padded element is 0.
        processed_constant_vals = _normalize_constant_values(
            array, kwargs.pop("constant_values", 0))
        idx_lambda = _get_constant_padded_idx_lambda(
            array, processed_pad_widths, processed_constant_vals)
    else:
        raise NotImplementedError(f"Mode: '{mode}'")

    if kwargs:
        # Report plain option names rather than a ``dict_keys(...)`` repr.
        unused_options = ", ".join(sorted(kwargs))
        raise ValueError(f"Some options ('{unused_options}')"
                         " were left unused.")

    return idx_lambda
# vim: fdm=marker