-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
_shape_poly.py
2177 lines (1862 loc) · 83.8 KB
/
_shape_poly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shape polymorphism support.
We introduce a set of dimension variables at the top-level of a `jit` function.
They are introduced implicitly by way of specifying for each dimension of each
argument a symbolic dimension expression in terms of some dimension variables.
All dimension variables are assumed to range over integers greater or equal to 1.
Symbolic dimensions overload some integer operations, such as
add, multiply, divide, equality, etc. The JAX NumPy layer and the LAX layers have been
touched up to be sensitive to handling shapes that contain symbolic dimensions.
This enables many JAX programs to be traced with symbolic dimensions
in some dimensions. A priority has been to enable the batch
dimension in neural network examples to be polymorphic.
This was built initially for jax2tf, but it is now
independent of TF. The best documentation at the moment is in the
jax2tf.convert docstring, and the
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
"""
from __future__ import annotations
import collections
import enum
from collections.abc import Iterable, Sequence
import dataclasses
from enum import Enum
import functools
import itertools
import io
import copy
import operator as op
import threading
import tokenize
from typing import Any, Callable, Union
import warnings
import numpy as np
import opt_einsum
import jax
from jax.interpreters import xla
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src.lax import lax
from jax._src.interpreters import mlir
from jax._src.numpy import lax_numpy
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
# A dimension size: either a concrete Python integer or a symbolic expression.
DimSize = Union["_DimExpr", int]
# NOTE(review): not referenced in the visible part of this file; presumably a
# TensorFlow value kept from the jax2tf origins of this module — confirm.
TfVal = Any
# Maps dimension variable names to their values when evaluating symbolic
# expressions (see the `evaluate` methods); the values may be JAX tracers.
DimVarEnv = dict[str, jax.Array]
DType = Any
# Sequence of (monomial, coefficient) pairs, sorted with the largest term first.
SortedTerms = Sequence[tuple["_DimMon", int]]
# Normalization rules represent the explicit constraint `t*tk == e` as
# a mapping of `t` to `(e, tk)`.
NormalizationRules = dict["_DimMon", tuple["_DimExpr", int]]
class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
  """Raised when we cannot conclusively compute with symbolic dimensions.

  The caller-supplied message is followed by a fixed help text pointing to the
  jax2tf documentation on comparing symbolic dimensions.
  """
  _help_msg = """
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.
Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison_of_symbolic_dimensions_is_partially_supported
for more details.
"""

  def __init__(self, message: str):
    full_message = f"{message}{InconclusiveDimensionOperation._help_msg}"
    # mypy does not accept the base-class constructor call here; see
    # https://github.com/python/mypy/issues/5887
    super().__init__(full_message)  # type: ignore
class _ShapePolyThreadLocalState(threading.local):
  """Per-thread settings for shape polymorphism.

  Every thread that touches `thread_local_state` sees its own copy, initialized
  by `__init__` on first access from that thread.
  """

  def __init__(self):
    # TODO(necula): this does not play well with some lowering caches, because
    # this state is not part of the cache key.
    self.enable_shape_assertions = True


thread_local_state = _ShapePolyThreadLocalState()
class Comparator(Enum):
  """The relation asserted by a symbolic constraint on its `diff` expression."""
  EQ = 1   # diff == 0
  GEQ = 2  # diff >= 0
@dataclasses.dataclass(frozen=True)
class _SymbolicConstraint:
  """An immutable constraint on dimension variables."""
  cmp: Comparator
  # The constraint as the user spelled it; only used in error messages.
  debug_str: str
  # Normalized form: `diff >= 0` for GEQ and `diff == 0` for EQ.
  diff: _DimExpr

  def __repr__(self):
    return "Constraint({}: {})".format(self.debug_str, self.diff)
class _DimAtom:
  """Represents an atom in a symbolic dimension expression.

  Atoms are either variables, or expressions of the form floordiv(E1, E2),
  mod(E1, E2), max(E1, E2) or min(E1, E2). Atoms are multiplied to form
  monomials (see _DimMon), and monomials are added to form symbolic
  expressions (see _DimExpr).

  Args:
    * var: if specified then the atom is a dimension variable. `operation`
      must be `None`.
    * operation: if specified then the atom is an operation applied to
      `operands`. One of `FLOORDIV`, `MOD`, `MAX` or `MIN` (`NON_NEGATIVE`
      is a legacy name that `evaluate` does not handle). `var` must be `None`.
    * operands: the operands to which the operation is applied.
  """
  # The supported operations
  # FLOORDIV(e1, e2) and MOD(e1, e2) have the same semantics as in Python:
  #   FLOORDIV(e1, e2) = e1 // e2 = floor(e1 / e2)
  #   if e2 > 0 then 0 <= MOD(e1, e2) < e2
  #   if e2 < 0 then e2 < MOD(e1, e2) <= 0
  #   e1 = e2 * FLOORDIV(e1, e2) + MOD(e1, e2)
  # MAX(e1, e2) and MIN(e1, e2) are the larger and the smaller operand.
  #
  FLOORDIV = "floordiv"
  MOD = "mod"
  MAX = "max"
  MIN = "min"
  NON_NEGATIVE = "non_negative"  # The max of the operand and 0. Replaced with
                                 # max but kept here for backwards compatibility.

  def __init__(self, *operands: _DimExpr,
               var: str | None = None,
               operation: str | None = None):
    """Builds an atom: either a variable (`var`) or `operation(*operands)`.

    The `from_var` and `from_operation` class methods wrap this constructor.
    """
    if var is not None:
      assert operation is None
      assert not operands
    else:
      assert operation is not None
    self.var = var
    self.operation = operation
    self.operands = operands
    # Precompute the hash (used extensively because these are kept in
    # dictionaries) and the size (used for sorting, which is important for
    # some of the reasoning). It is important for the size of _DimAtom,
    # _DimMon, or _DimExpr, to be strictly larger than the size of any
    # constituent sub-item.
    self._hash = hash((self.var, self.operation, *self.operands))
    self._size = 1 if var is not None else 1 + sum(o._size for o in operands)

  @classmethod
  def from_var(cls, v: str) -> _DimAtom:
    """Returns the atom for the dimension variable named `v`."""
    return _DimAtom(var=v)

  @classmethod
  @functools.cache  # use caching to ensure we can compare _DimAtom by id
  def from_operation(cls, operation: str, *operands: DimSize,
                     scope: SymbolicScope) -> _DimAtom:
    """Returns the atom `operation(*operands)`, converting operands to _DimExpr.

    Memoized by `functools.cache`: all arguments (including `scope`) must be
    hashable, and the cache is unbounded, so it keeps them alive.
    """
    return _DimAtom(*(_ensure_poly(o, operation, scope) for o in operands),
                    operation=operation)

  def to_var(self) -> str | None:
    """Returns the variable name, or None if this is an operation atom."""
    return self.var

  def get_vars(self) -> set[str]:
    # All the vars that appear, including those inside the operands.
    if self.var is not None:
      return {self.var}
    else:
      acc = set()
      for opnd in self.operands:
        acc.update(opnd.get_vars())
      return acc

  def __str__(self):
    if self.var is not None:
      return self.var
    opnd_str = ", ".join([str(opnd) for opnd in self.operands])
    return f"{self.operation}({opnd_str})"
  __repr__ = __str__

  def __hash__(self):
    return self._hash

  def _syntactic_cmp(self, other: _DimAtom) -> int:
    """Returns -1 if self < other, 0 if self == other, 1 if self > other.

    The comparison is done lexicographically (syntactic), to be used for sorting.
    The result is not related to the semantic value.
    """
    # Order by size first, then by variable name, then by operation name, and
    # finally by the operands.
    if c := cmp_comparable(self._size, other._size): return c
    # NOTE(review): if `self` is a variable and `other` an operation atom of the
    # same size (only possible for an operation with no operands), this
    # compares a str with None — confirm that case cannot arise.
    if self.var is not None:
      return cmp_comparable(self.var, other.var)
    if c := cmp_comparable(self.operation, other.operation): return c  # type: ignore
    return cmp_sequence(self.operands, other.operands,
                        lambda s_o, o_o: s_o._syntactic_cmp(o_o))

  def __eq__(self, other: Any):
    """Lexicographic comparison."""
    if not isinstance(other, _DimAtom): return False
    return self._syntactic_cmp(other) == 0

  def __lt__(self, other: _DimAtom):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) < 0

  def __le__(self, other: _DimAtom):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) <= 0

  def __gt__(self, other: _DimAtom):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) > 0

  def __ge__(self, other: _DimAtom):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) >= 0

  def evaluate(self, env: DimVarEnv):
    """Evaluates the atom given values for the dimension variables in `env`.

    Raises KeyError (with an explanatory message) if a variable is not in `env`.
    Raises AssertionError for an operation it does not handle.
    """
    if self.var is not None:
      try:
        return env[self.var]
      except KeyError:
        err_msg = (
            f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n"
            "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
        raise KeyError(err_msg)
    else:
      operand_values = [opnd.evaluate(env) for opnd in self.operands]
      if self.operation == _DimAtom.FLOORDIV:
        return divmod(*operand_values)[0]  # type: ignore
      elif self.operation == _DimAtom.MOD:
        return divmod(*operand_values)[1]  # type: ignore
      elif self.operation == _DimAtom.MAX:
        op1, op2 = operand_values
        # Use Python `max` for constants, `core.max_dim` when either side is
        # still symbolic, and `lax.max` otherwise.
        if core.is_constant_dim(op1) and core.is_constant_dim(op2):
          return max(op1, op2)
        if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
          return core.max_dim(op1, op2)
        # In the context of `evaluate` dimension variables may be mapped to
        # JAX Tracers.
        return lax.max(op1, op2)
      elif self.operation == _DimAtom.MIN:
        op1, op2 = operand_values
        if core.is_constant_dim(op1) and core.is_constant_dim(op2):
          return min(op1, op2)
        if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
          return core.min_dim(op1, op2)
        # In the context of `evaluate` dimension variables may be mapped to
        # JAX Tracers.
        return lax.min(op1, op2)
      else:
        assert False, self.operation

  def __deepcopy__(self, memo):
    # Builds a fresh atom directly (bypassing the `from_operation` cache).
    return _DimAtom(*copy.deepcopy(self.operands, memo),
                    var=copy.deepcopy(self.var, memo),
                    operation=copy.deepcopy(self.operation, memo))
class _DimMon(dict):
  """Represents a multiplication of atoms.

  The representation is a dictionary mapping _DimAtom to exponent.
  The exponents are integers >= 1. The hash and size are computed once in
  `__init__`, so an instance must not be mutated after construction.
  """
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Monomials are used heavily as dictionary keys; hash them only once.
    self._hash = hash(frozenset(self.items()))
    # Size used by `_syntactic_cmp`; strictly larger than the size of any atom.
    self._size = sum((1 + a_exp * a._size) for a, a_exp in self.items())

  def __hash__(self):
    return self._hash

  def __str__(self):
    return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key)
                    for key, exponent in sorted(self.items()))
  __repr__ = __str__

  @classmethod
  def from_var(cls, v: str) -> _DimMon:
    """Returns the monomial `v^1` for the dimension variable `v`."""
    return _DimMon({_DimAtom.from_var(v): 1})

  @classmethod
  def from_atom(cls, a: _DimAtom, aexp: int) -> _DimMon:
    """Returns the monomial `a^aexp`."""
    return _DimMon({a: aexp})

  @classmethod
  def from_operation(cls, operation: str, *operands: DimSize,
                     scope: SymbolicScope) -> _DimMon:
    """Returns the monomial consisting of the single atom `operation(*operands)`."""
    return _DimMon({_DimAtom.from_operation(operation, *operands,
                                            scope=scope): 1})

  def to_var(self) -> str | None:
    """Extract the variable name from a monomial.

    Return None if the monomial is not a single variable."""
    a = self.to_atom()
    return a.to_var() if a is not None else None

  def to_atom(self) -> _DimAtom | None:
    """Extract the single atom from a monomial.

    Return None if the monomial is not a single atom (with exponent 1)."""
    items = self.items()
    if len(items) != 1:
      return None
    (a, aexp), = items
    if aexp != 1:
      return None
    return a

  def get_vars(self) -> set[str]:
    # All the vars that appear in the monomial
    acc = set()
    for a in self.keys():
      acc.update(a.get_vars())
    return acc

  @property
  def degree(self):
    """Sum of the exponents of all atoms."""
    return sum(self.values())

  def _syntactic_cmp(self, other: _DimMon) -> int:
    """Returns -1 if self < other, 0 if self == other, 1 if self > other.

    The comparison is done lexicographically (syntactic), to be used for sorting.
    The result is not related to the semantic value.
    """
    if c := cmp_comparable(self._size, other._size): return c
    if c := cmp_comparable(self.degree, other.degree): return c
    def cmp_atom(s_a: tuple[_DimAtom, int], o_a: tuple[_DimAtom, int]) -> int:
      if c := s_a[0]._syntactic_cmp(o_a[0]): return c
      # Consider the monomials with exponents to be expanded as multiplications.
      # Then a higher exponent for a "small" atom should lead to a "smaller" monomial.
      return - cmp_comparable(s_a[1], o_a[1])
    return cmp_sequence(sorted(self.items()), sorted(other.items()), cmp_atom)

  def __lt__(self, other: _DimMon):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) < 0

  def __le__(self, other: _DimMon):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) <= 0

  def __gt__(self, other: _DimMon):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) > 0

  def __ge__(self, other: _DimMon):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) >= 0

  def mul(self, other: _DimMon) -> _DimMon:
    """
    Returns the product with another monomial. Example: (n^2*m) * n == n^3 * m.
    """
    return _DimMon(collections.Counter(self) + collections.Counter(other))

  def divide(self, divisor: _DimMon) -> _DimMon:
    """
    Divides by another monomial. Raises an InconclusiveDimensionOperation
    if the result is not a monomial.
    For example, (n^3 * m) // n == n^2*m, but n // m fails.
    """
    d = collections.Counter(self)
    for key, exponent in divisor.items():
      diff = self.get(key, 0) - exponent
      if diff < 0:
        raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.")
      elif diff == 0: del d[key]
      elif diff > 0: d[key] = diff
    return _DimMon(d)

  def evaluate(self, env: DimVarEnv):
    """Evaluates the product of the atoms' values; the empty monomial is 1."""
    prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
    def pow_opt(v, p: int):
      return v if p == 1 else prod([v] * p)
    return prod([pow_opt(a.evaluate(env), deg) for a, deg in self.items()])

  def __deepcopy__(self, memo):
    return _DimMon(copy.deepcopy(dict(self), memo))
class _DimExpr:
  """Symbolic expression in terms of dimension variables.

  A dimension expression is an addition of terms (_DimMon), which themselves
  are products of factors (_DimAtom).

  The representation of a _DimExpr is as sequence of pairs `(_DimMon, coeff)`,
  representing the linear combination of terms with the given coefficients.
  The sequence is sorted by lexicographic (syntactic) ordering of `_DimMon`,
  with the largest terms first. The special monomial `_DimMon()` is mapped
  to the free integer coefficient of the expression. Each expression also
  carries the SymbolicScope it was built in; normalized expressions that turn
  out to be constant are represented as plain Python ints instead.

  We overload integer operations, but we do that soundly, raising
  :class:`InconclusiveDimensionOperation` when the result is not
  representable as a _DimExpr.
  """
  __array_priority__ = 1000  # Same as tracer, for __radd__ and others on ndarray
def __init__(self, terms: SortedTerms,
             scope: SymbolicScope):
  """Wraps already-normalized, sorted `terms`; use `normalize_coeffs` otherwise.

  An empty `terms` is stored as the single constant term `0`.
  """
  monomials = tuple(terms)
  if not monomials:
    monomials = ((_DimMon(), 0),)
  self._monomials_sorted = monomials
  self._scope = scope
  self._hash = hash((self._monomials_sorted, self.scope))
  # Cached because _syntactic_cmp (used for sorting and hashing) starts by
  # comparing sizes.
  self._size = sum(1 + abs(coeff) * mon._size
                   for mon, coeff in self._monomials_sorted)
@property
def scope(self):
  """The SymbolicScope of this expression (exposed read-only)."""
  scope_value = self._scope
  return scope_value
@classmethod
def _coeff_dict_to_terms(cls, coeffs: dict[_DimMon, int]) -> SortedTerms:
  """Returns the (monomial, coefficient) pairs of `coeffs`, largest first."""
  terms = list(coeffs.items())
  terms.sort(reverse=True)
  return terms
@classmethod
def from_dict(cls, coeffs: dict[_DimMon, int], scope: SymbolicScope) -> _DimExpr:
  """Builds an expression from a coefficient mapping, without normalizing it."""
  sorted_terms = _DimExpr._coeff_dict_to_terms(coeffs)
  return _DimExpr(sorted_terms, scope)
@classmethod
def from_monomial(cls, mon: _DimMon, exp: int, scope: SymbolicScope):
  """Returns the normalized form of `exp * mon` (an int if it is constant)."""
  single_term = {mon: exp}
  return _DimExpr.normalize_coeffs(single_term, scope)
@classmethod
def from_constant(cls, c: int, scope: SymbolicScope):
  """Returns the (non-normalized) expression holding only the constant `c`."""
  constant_term = (_DimMon(), op.index(c))
  return _DimExpr((constant_term,), scope)
@classmethod
def from_var(cls, v: str, scope: SymbolicScope) -> _DimExpr:
  """Returns the expression consisting of the dimension variable `v`."""
  var_term = (_DimMon.from_var(v), 1)
  return _DimExpr((var_term,), scope)
@classmethod
def from_operation(cls, operation: str, *operands: DimSize,
                   scope: SymbolicScope) -> _DimExpr:
  """Returns the expression `operation(*operands)` as a single-atom term."""
  if operation == _DimAtom.NON_NEGATIVE:
    # Legacy spelling accepted when parsing: rewritten to max(operands, 0).
    mon = _DimMon.from_operation(_DimAtom.MAX, *operands, 0, scope=scope)
  else:
    mon = _DimMon.from_operation(operation, *operands, scope=scope)
  return _DimExpr.from_monomial(mon, 1, scope=scope)
def monomials(self) -> Iterable[tuple[_DimMon, int]]:
  """Returns the (monomial, coefficient) pairs, largest monomial first.

  Higher-degree monomials come earlier in the order.
  """
  sorted_terms = self._monomials_sorted
  return sorted_terms
@property
def leading_term(self) -> tuple[_DimMon, int]:
  """The first (largest) (monomial, coefficient) pair of the expression."""
  terms = self._monomials_sorted
  return terms[0]
def to_single_term(self) -> tuple[int, int, _DimMon] | None:
  """Extracts the single term: k + c * term.

  Returns `(k, c, term)` when the expression has exactly one non-constant
  monomial (`k` is 0 if there is no constant part). Returns None otherwise,
  including for an expression with no non-constant monomial at all.
  """
  const = 0
  coeff = 0
  term = None
  for mon, c in self.monomials():
    if mon.degree == 0:
      const = c
    elif term is None:
      term, coeff = mon, c
    else:
      return None  # A second non-constant monomial.
  if term is None:
    # Only a constant. This used to trip an assertion (and under `-O` would
    # have returned `(k, 0, None)`); report "not a single term" instead.
    return None
  return (const, coeff, term)
@classmethod
def add_coeff(cls, coeffs: dict[_DimMon, int], t: _DimMon, coeff: int):
  """Does `coeffs[t] += coeff` in place, removing `t` if the sum becomes 0."""
  if coeff != 0:
    updated = coeffs.get(t, 0) + coeff
    if updated != 0:
      coeffs[t] = updated
    else:
      del coeffs[t]
@classmethod
def normalize_coeffs(cls,
                     coeffs: dict[_DimMon, int],
                     scope: SymbolicScope) -> DimSize:
  """Constructs a _DimExpr in normal form.

  Ensures that the symbolic dimension is normalized, e.g., does not
  have terms with coefficient 0, it reflects all the scope
  normalization_rules, and it is represented as a Python integer if it is
  known to be a constant.

  Does not attempt to normalize the keys (terms) inside `coeffs`, and applies
  the rules in a single pass: terms introduced by a rewrite are not rewritten
  again.
  """
  # Drop zero coefficients up front rather than deleting them inside the loop:
  # a rewrite of an earlier term may already have accumulated a non-zero
  # coefficient on the same monomial, which an in-loop delete would discard.
  new_coeffs = {t: t_k for t, t_k in coeffs.items() if t_k != 0}
  rules = scope._normalization_rules
  for t, t_k in coeffs.items():
    if t_k == 0:
      continue
    after, t_k_after = rules.get(t, (None, 0))
    if after is not None and t_k % t_k_after == 0:
      # We have t*t_k_after -> after.
      # We subtract `t*t_k` and add `after * (t_k // t_k_after)`.
      _DimExpr.add_coeff(new_coeffs, t, - t_k)
      for t2, tc2 in after._monomials_sorted:
        _DimExpr.add_coeff(new_coeffs, t2, tc2 * (t_k // t_k_after))
    elif t.degree > 1:  # A product of factors; look up individually
      for f, fexp in t.items():
        f_mon = _DimMon({f: fexp})
        f_after, f_k_after = rules.get(f_mon, (None, 0))
        if f_after is not None and t_k % f_k_after == 0:
          # We subtract `t*t_k`.
          _DimExpr.add_coeff(new_coeffs, t, - t_k)
          t_without_f = t.divide(f_mon)
          # And add `(t // f**fexp) * f_after * (t_k // f_k_after)`
          for t2, tc2 in f_after._monomials_sorted:
            _DimExpr.add_coeff(new_coeffs, t2.mul(t_without_f), tc2 * (t_k // f_k_after))
          break
  new_terms = _DimExpr._coeff_dict_to_terms(new_coeffs)
  if not new_terms: return 0
  # Terms are sorted largest first and the constant monomial is the smallest,
  # so a degree-0 leading term means the whole expression is a constant.
  if new_terms[0][0].degree == 0: return new_terms[0][1]
  return _DimExpr(new_terms, scope)
def to_monomial(self) -> _DimMon | None:
  """Returns the monomial if the expression is exactly `1 * monomial`, else None."""
  head, *others = self.monomials()
  mon, coeff = head
  if others or coeff != 1:
    return None
  return mon
def to_atom(self) -> _DimAtom | None:
  """Returns the atom if the expression is a single atom, else None."""
  mon = self.to_monomial()
  if mon is None:
    return None
  return mon.to_atom()
def to_var(self) -> str | None:
  """Returns the variable name if the expression is a single variable, else None."""
  atom = self.to_atom()
  if atom is None:
    return None
  return atom.to_var()
@classmethod
def to_constant(cls, e: DimSize) -> int | None:
  """Returns `e` as an int if it is constant, else None.

  Non-_DimExpr values are converted with `int()`.
  """
  if isinstance(e, _DimExpr):
    mon, coeff = e.leading_term
    return coeff if mon.degree == 0 else None
  return int(e)
@property
def is_constant(self):
  """True if the expression has no non-constant monomial."""
  constant = _DimExpr.to_constant(self)
  return constant is not None
def get_vars(self) -> set[str]:
  """The names of all dimension variables appearing in the expression."""
  return set().union(*(mon.get_vars() for mon, _ in self.monomials()))
@classmethod
def _linear_combination(
    cls,
    e1: SortedTerms, i1: int, f1: int,
    e2: SortedTerms, i2: int, f2: int) -> SortedTerms:
  """Computes e1[i1:] * f1 + e2[i2:] * f2.

  e1, e2, and the result are sorted with largest term first, and the result
  never contains terms whose coefficient is 0 (including when `f1` or `f2` is
  0). This is an optimization for a common operation. The unoptimized code
  would compute each subexpression in turn.
  """
  acc = []
  while i1 < len(e1) and i2 < len(e2):
    m1, m1_c = e1[i1]
    m2, m2_c = e2[i2]
    cmp = m1._syntactic_cmp(m2)  # Pick the largest monomial
    if cmp < 0:
      mon, coeff = m2, m2_c * f2
      i2 += 1
    elif cmp > 0:
      mon, coeff = m1, m1_c * f1
      i1 += 1
    else:  # They are equal, combine them
      mon, coeff = m1, m1_c * f1 + m2_c * f2
      i1 += 1
      i2 += 1
    if coeff != 0:
      acc.append((mon, coeff))
  # At most one of the two tails is non-empty; it is already sorted.
  acc.extend((m1, m1_c * f1) for m1, m1_c in itertools.islice(e1, i1, None)
             if m1_c * f1 != 0)
  acc.extend((m2, m2_c * f2) for m2, m2_c in itertools.islice(e2, i2, None)
             if m2_c * f2 != 0)
  return acc
def _syntactic_cmp(self, other: _DimExpr) -> int:
  """Returns -1 if self < other, 0 if self == other, 1 if self > other.

  The ordering is purely syntactic (sizes first, then the sorted terms), meant
  for sorting; it says nothing about the semantic values.
  """
  size_cmp = cmp_comparable(self._size, other._size)
  if size_cmp:
    return size_cmp
  def cmp_term(lhs: tuple[_DimMon, int], rhs: tuple[_DimMon, int]) -> int:
    mon_cmp = lhs[0]._syntactic_cmp(rhs[0])
    return mon_cmp if mon_cmp else cmp_comparable(lhs[1], rhs[1])
  return cmp_sequence(self._monomials_sorted, other._monomials_sorted, cmp_term)
def eq(self, other: _DimExpr) -> bool:
  """Cheap equality: True iff `self - other` normalizes to the integer 0.

  A precise check via `(self - other).bounds() == (0, 0)` would be too slow
  for this hot path (expressions are cached), and could not be used while
  hashing without infinite recursion.
  """
  diff = self - other
  # Normalization turns constant expressions into Python ints. A still
  # symbolic difference would warrant InconclusiveDimensionOperation, but
  # equality is used for hashing and must not raise, so we (unsoundly) report
  # "not equal". See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported
  return not is_symbolic_dim(diff) and diff == 0
def __hash__(self):
  """Returns the hash precomputed in `__init__`."""
  precomputed = self._hash
  return precomputed
def __str__(self):
  """Renders terms largest first (so the constant comes last), e.g. `2*n - 3`."""
  def fmt_term(mon, c):
    sign = "+" if c > 0 else "-"
    magnitude = abs(c)
    if mon.degree == 0:
      return "0" if magnitude == 0 else f"{sign} {magnitude}"
    if magnitude == 1:
      return f"{sign} {mon}"
    return f"{sign} {magnitude}*{mon}"
  text = " ".join(fmt_term(mon, c) for mon, c in self._monomials_sorted)
  return text[2:] if text.startswith("+ ") else text
def __repr__(self):
  """Same as `str(self)`."""
  return f"{self}"
# We overload +, -, *, because they are fully defined for _DimExpr.
def __add__(self, other):
  """Implements `self + other`; defers to the array form for non-dimension operands."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__add__(other)
  if isinstance(other, int) and other == 0:
    return self
  addend = _ensure_poly(other, "add", self.scope)
  coeffs = dict(self._monomials_sorted)
  for mon, coeff in addend.monomials():
    _DimExpr.add_coeff(coeffs, mon, coeff)
  return _DimExpr.normalize_coeffs(coeffs, self.scope)
def __radd__(self, other):
  """Implements `other + self`."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__radd__(other)
  if isinstance(other, int) and other == 0:
    return self
  lhs = _ensure_poly(other, "add", self.scope)
  return lhs.__add__(self)
def __sub__(self, other):
  """Implements `self - other` as `self + (-other)`."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__sub__(other)
  if isinstance(other, int) and other == 0:
    return self
  negated_other = -_ensure_poly(other, "sub", self.scope)
  return self + negated_other
def __rsub__(self, other):
  """Implements `other - self`."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__rsub__(other)
  lhs = _ensure_poly(other, "sub", self.scope)
  return lhs.__sub__(self)
def __neg__(self) -> _DimExpr:
  """Negates every coefficient; the term order is unchanged."""
  negated = tuple((mon, -coeff) for mon, coeff in self._monomials_sorted)
  return _DimExpr(negated, self.scope)
def __mul__(self, other):
  """Implements `self * other` by multiplying out every pair of terms."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__mul__(other)
  if isinstance(other, int) and other == 1:
    return self
  factor = _ensure_poly(other, "mul", self.scope)
  coeffs: dict[_DimMon, int] = {}
  for (mon1, c1), (mon2, c2) in itertools.product(self.monomials(),
                                                  factor.monomials()):
    _DimExpr.add_coeff(coeffs, mon1.mul(mon2), c1 * c2)
  return _DimExpr.normalize_coeffs(coeffs, self.scope)
def __rmul__(self, other):
  """Implements `other * self`."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__rmul__(other)
  if isinstance(other, int) and other == 1:
    return self
  lhs = _ensure_poly(other, "mul", self.scope)
  return lhs.__mul__(self)
def __pow__(self, power, modulo=None):
  """Raises the expression to a non-negative integer `power`.

  `x ** 0` is the integer 1. Raises InconclusiveDimensionOperation if `power`
  is not an integral value or is negative.
  """
  assert modulo is None
  try:
    int_power = int(power)
  except Exception as e:  # Narrower than a bare except: keeps KeyboardInterrupt.
    raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") from e
  if int_power != power:
    # E.g. 2.5: `int()` would silently truncate it.
    raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'")
  if int_power < 0:
    raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to negative power '{self}' ^ '{power}'")
  # The initial value 1 makes `x ** 0 == 1`; `1 * x` returns `x` via __rmul__.
  return functools.reduce(op.mul, [self] * int_power, 1)
def __floordiv__(self, divisor):
  """Implements `self // divisor` via `divmod`."""
  if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
    return self.__jax_array__().__floordiv__(divisor)
  quotient, _ = self.divmod(_ensure_poly(divisor, "floordiv", self.scope))
  return quotient
def __rfloordiv__(self, other):
  """Implements `other // self`."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__rfloordiv__(other)
  dividend = _ensure_poly(other, "floordiv", self.scope)
  return dividend.__floordiv__(self)
def __truediv__(self, divisor):
  """Implements `self / divisor`; true division yields a float, so this
  always defers to the array form of the expression."""
  as_array = self.__jax_array__()
  return as_array.__truediv__(divisor)
def __rtruediv__(self, dividend):
  """Implements `dividend / self` for a non-_DimExpr dividend, via the array form."""
  as_array = self.__jax_array__()
  return as_array.__rtruediv__(dividend)
def __mod__(self, divisor):
  """Implements `self % divisor` via `divmod`."""
  if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
    return self.__jax_array__().__mod__(divisor)
  _, remainder = self.divmod(_ensure_poly(divisor, "mod", self.scope))
  return remainder
def __rmod__(self, dividend):
  """Implements `dividend % self`."""
  if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend):
    return self.__jax_array__().__rmod__(dividend)
  lhs = _ensure_poly(dividend, "mod", self.scope)
  return lhs.__mod__(self)
def __divmod__(self, divisor):
  """Implements `divmod(self, divisor)`."""
  if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
    return self.__jax_array__().__divmod__(divisor)
  poly_divisor = _ensure_poly(divisor, "divmod", self.scope)
  return self.divmod(poly_divisor)
def __rdivmod__(self, dividend):
  """Implements `divmod(dividend, self)`."""
  if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend):
    return self.__jax_array__().__rdivmod__(dividend)
  lhs = _ensure_poly(dividend, "divmod", self.scope)
  return lhs.__divmod__(self)
def __int__(self):
  """Returns the constant value; raises InconclusiveDimensionOperation if symbolic."""
  constant = _DimExpr.to_constant(self)
  if constant is None:
    raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant")
  return constant
# We must overload __eq__ and __ne__, or else we get unsound defaults.
def __eq__(self, other: Any) -> bool:
  """Syntactic-normal-form equality with a _DimExpr or a constant dimension.

  Expressions from different scopes, and values that are neither a _DimExpr
  nor a constant dimension, are never equal. Otherwise this delegates to
  `eq`. That check is cheap and never raises, but may report "not equal" for
  expressions that are semantically equal (see `eq`).
  """
  if isinstance(other, _DimExpr):
    if self.scope is not other.scope:
      return False
  elif not core.is_constant_dim(other):
    return False
  # `eq` previously had a verbatim copy here; keep a single implementation.
  return self.eq(other)
def __ne__(self, other: Any) -> bool:
  """Negation of `__eq__` (explicit, so Python never uses an unsound default)."""
  is_equal = self.__eq__(other)
  return not is_equal
def __ge__(self, other: DimSize) -> bool:
  """Decides `self >= other` (may raise if undecidable)."""
  describe = lambda: f"'{self}' >= '{other}'"
  return _geq_decision(self, other, describe)
def __le__(self, other: DimSize):
  """Decides `self <= other` as `other >= self`."""
  describe = lambda: f"'{self}' <= '{other}'"
  return _geq_decision(other, self, describe)
def __gt__(self, other: DimSize):
  """Decides `self > other` as `not (other >= self)`."""
  describe = lambda: f"'{self}' > '{other}'"
  other_geq_self = _geq_decision(other, self, describe)
  return not other_geq_self
def __lt__(self, other: DimSize):
  """Decides `self < other` as `not (self >= other)`."""
  describe = lambda: f"'{self}' < '{other}'"
  self_geq_other = _geq_decision(self, other, describe)
  return not self_geq_other
def divmod(self, divisor: _DimExpr) -> tuple[DimSize, DimSize]:
    """Floor division with remainder (divmod) generalized to polynomials.

    If the `divisor` is not a constant, the remainder must be 0.
    If the `divisor` is a constant, the remainder may be non 0, for consistency
    with integer divmod.

    If the division cannot be carried out term by term, both results are
    returned as opaque symbolic FLOORDIV and MOD expressions instead. This
    happens when a term of the dividend is not divisible by the divisor's
    leading term, or when a non-constant divisor leaves a non-zero remainder.

    :return: `(quotient, remainder)` with `self == divisor * quotient + remainder`.
      On the exact path the remainder is an int. On the fallback path it is a
      symbolic MOD expression.
    """
    assert isinstance(divisor, _DimExpr)
    try:
        dmon, dcount = divisor.leading_term
        dividend, quotient = self, 0
        # invariant: self = dividend + divisor * quotient
        # quotient and dividend are changed in the loop; the leading term of
        # dividend decreases at each iteration.
        while is_symbolic_dim(dividend) and not dividend.is_constant:
            mon, count = dividend.leading_term
            # `mon.divide` may raise InconclusiveDimensionOperation when `mon`
            # is not divisible by `dmon`. The outer handler below turns that
            # into the symbolic fallback, so no local handling is needed.
            qmon = mon.divide(dmon)
            qcount, rcount = divmod(count, dcount)
            if rcount != 0:
                raise InconclusiveDimensionOperation("")
            q = _DimExpr.from_monomial(qmon, qcount, self.scope)
            quotient += q
            dividend -= q * divisor  # type: ignore[assignment]

        dividend = int(dividend)  # type: ignore[assignment]
        if divisor.is_constant:
            # A constant divisor may leave a non-zero constant remainder,
            # with Python's floor-division semantics.
            q, remainder = divmod(dividend, int(divisor))  # type: ignore
            quotient += q
        else:
            if dividend != 0:
                raise InconclusiveDimensionOperation("")
            remainder = 0

        if config.enable_checks.value:
            # Check the division invariant once, reporting all the pieces.
            reconstructed = divisor * quotient + remainder
            assert self == reconstructed, (
                self, reconstructed, type(self), type(reconstructed),
                divisor, quotient, remainder)
        return quotient, remainder
    except InconclusiveDimensionOperation:
        return (_DimExpr.from_operation(_DimAtom.FLOORDIV, self, divisor,
                                        scope=self.scope),  # type: ignore
                _DimExpr.from_operation(_DimAtom.MOD, self, divisor,
                                        scope=self.scope))
def evaluate(self, env: DimVarEnv):
    """Evaluates the expression under `env`.

    Each monomial is evaluated, multiplied by its coefficient, and the
    resulting terms are summed. The result is a value of
    dtype=core.dim_value_dtype().
    """
    evaluated = []
    for mon, coeff in self.monomials():
        evaluated.append(
            _evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff)))
    if len(evaluated) <= 1:
        # A single term needs no addition. As before, an empty list raises
        # IndexError here.
        return evaluated[0]
    return functools.reduce(_evaluate_add, evaluated)
def max(self, other: DimSize) -> DimSize:
    """Symbolic maximum of `self` and `other`.

    Returns one operand directly when the bounds of `self - other` decide the
    comparison. Otherwise builds a MAX atom with `self` as the first operand.
    """
    lower, upper = _bounds_decision(self - other,
                                    BoundsPrecision.FOR_GEQ0_OR_LEQ0)
    if lower >= 0:  # self >= other everywhere
        return self
    if upper <= 0:  # self <= other everywhere
        return other
    return _DimExpr.from_operation(_DimAtom.MAX, self, other, scope=self.scope)
def rmax(self, other: DimSize) -> DimSize:
    """Like `max`, but an undecided result is built as MAX(other, self).

    Returns one operand directly when the bounds of `self - other` decide the
    comparison.
    """
    lower, upper = _bounds_decision(self - other,
                                    BoundsPrecision.FOR_GEQ0_OR_LEQ0)
    if lower >= 0:  # self >= other everywhere
        return self
    if upper <= 0:  # self <= other everywhere
        return other
    return _DimExpr.from_operation(_DimAtom.MAX, other, self, scope=self.scope)
def min(self, other: DimSize) -> DimSize:
    """Symbolic minimum of `self` and `other`.

    Returns one operand directly when the bounds of `self - other` decide the
    comparison. Otherwise builds a MIN atom with `self` as the first operand.
    """
    lower, upper = _bounds_decision(self - other,
                                    BoundsPrecision.FOR_GEQ0_OR_LEQ0)
    if lower >= 0:  # self >= other everywhere, so other is the minimum
        return other
    if upper <= 0:  # self <= other everywhere
        return self
    return _DimExpr.from_operation(_DimAtom.MIN, self, other, scope=self.scope)
def rmin(self, other: DimSize) -> DimSize:
    """Like `min`, but an undecided result is built as MIN(other, self).

    Returns one operand directly when the bounds of `self - other` decide the
    comparison.
    """
    lower, upper = _bounds_decision(self - other,
                                    BoundsPrecision.FOR_GEQ0_OR_LEQ0)
    if lower >= 0:  # self >= other everywhere, so other is the minimum
        return other
    if upper <= 0:  # self <= other everywhere
        return self
    return _DimExpr.from_operation(_DimAtom.MIN, other, self, scope=self.scope)
@staticmethod
def get_aval(dim: _DimExpr):
    """Returns `core.dim_value_aval()`; the result does not depend on `dim`."""
    del dim  # unused
    return core.dim_value_aval()
def dimension_as_value(self):
    """Converts this dimension expression into a JAX value for computation."""
    as_value = _dim_as_value(self)
    return as_value
def __jax_array__(self):
    """Hook for implicitly coercing the expression to a JAX array.

    Uses the same conversion as `dimension_as_value`, namely `_dim_as_value`.
    """
    coerced = _dim_as_value(self)
    return coerced
def __deepcopy__(self, memo):
    """Deep-copies the sorted monomials and the scope, sharing `memo`.

    Because `memo` is shared, expressions copied in one `copy.deepcopy` call
    end up referring to the same copied scope object.
    """
    new_monomials = copy.deepcopy(self._monomials_sorted, memo)
    new_scope = copy.deepcopy(self._scope, memo)
    return _DimExpr(new_monomials, new_scope)
def cmp_comparable(i1, i2) -> int:
    """Three-way comparison: -1 if `i1 < i2`, 1 if `i1 > i2`, else 0."""
    if i1 < i2:
        return -1
    return 1 if i1 > i2 else 0
def cmp_sequence(s1, s2, elem_cmp) -> int:
    """Lexicographically compares two sequences using `elem_cmp`.

    Returns the first truthy result of `elem_cmp` on corresponding elements.
    If every compared pair is equal, the shorter sequence is the smaller one
    (1 if `s1` is longer, -1 if `s2` is longer, 0 if the lengths match).
    """
    n2 = len(s2)
    for pos, left in enumerate(s1):
        if pos >= n2:
            # s2 is a proper prefix of s1.
            return 1
        outcome = elem_cmp(left, s2[pos])
        if outcome:
            return outcome
    return -1 if len(s1) < n2 else 0
class SymbolicScope:
"""Identifies a scope for symbolic expressions.
All symbolic expressions that interact (e.g., appear in the argument shapes
for one JAX function invocation, or are involved in arithmetic operations)
must be from the same scope and must share the same SymbolicScope object.
Holds the constraints on symbolic expressions.
See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints)
for more details.
Args:
constraints_str: A sequence of constraints on symbolic dimension expressions,
of the form `e1 >= e2`, `e1 <= e2`, or `e1 == e2`.
"""
def __init__(self,
             constraints_str: Sequence[str] = ()):
    """Creates the scope and processes each constraint in `constraints_str`.

    Raises:
      ValueError: if `constraints_str` is a single string rather than a
        sequence of strings (individual constraints are also checked to be
        strings when they are processed).
    """
    if isinstance(constraints_str, str):
        # A bare string is itself a sequence (of characters); reject it so
        # that its characters are not taken as individual constraints.
        raise ValueError(
            "The symbolic constraints should be a sequence of strings. "
            f"Got {repr(constraints_str)}")
    # Becomes True only after all the constraints below have been processed.
    self._initialized = False
    # Where the scope was created; shown by __str__ ("unknown" if None).
    current_info = source_info_util.current()
    self._location_frame = source_info_util.user_frame(current_info)
    # Explicit constraints, kept in the order in which they were added.
    self._explicit_constraints: list[_SymbolicConstraint] = []
    # Cache for _DimExpr.bounds results, together with the BoundsPrecision
    # they were computed with. The results depend only on the explicit and
    # implicit constraints, so the cache can live in the scope. It must exist
    # before the constraints are parsed.
    self._bounds_cache: dict[_DimExpr,
                             tuple[float, float, BoundsPrecision]] = {}
    # Decision-procedure state initialized with all _explicit_constraints.
    self._decision_initial_state: Any | None = None
    # Equality constraints become normalization rules: an explicit
    # constraint `t*tk == e` is stored as `_normalization_rules[t] = (e, tk)`.
    # While building expressions, a term `t*tk1` with `tk1 % tk == 0` is
    # rewritten to `e*(tk1 // tk)`.
    self._normalization_rules: NormalizationRules = {}
    for constraint_text in constraints_str:
        self._parse_and_process_explicit_constraint(constraint_text)
        # Each new constraint can change bounds, so drop stale cache entries.
        self._bounds_cache.clear()
    self._initialized = True
def __str__(self) -> str:
    """Describes the scope: its id, creation location, and its constraints.

    The location is "unknown" when no user frame was recorded.
    """
    lines = []
    if self._explicit_constraints:
        lines.append(" with constraints:")
        lines.extend(f" {constr.debug_str}"
                     for constr in self._explicit_constraints)
    frame = self._location_frame
    where = (source_info_util._summarize_frame(frame) if frame
             else "unknown")
    header = f"{id(self)} created at {where}"
    return header + "\n".join(lines)
__repr__ = __str__
def _parse_and_process_explicit_constraint(self, c_str: str):
if not isinstance(c_str, str):
raise ValueError(
f"SymbolicScope constraint must be a string: got {repr(c_str)}")
cmp_pos, cmp, is_geq = c_str.find("=="), Comparator.EQ, True