-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
utils.py
92 lines (80 loc) · 3.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas utility functions."""
import math
import numpy as np
from jax import lax
from jax._src import core as jax_core
import jax.numpy as jnp
def when(condition):
  """Build a decorator that runs a zero-argument function only if `condition` holds.

  If `condition` is a Python `bool`, the decision is made immediately: the
  decorated function is called right away when it is `True` and is skipped
  when it is `False`. Any other kind of condition is handed to `lax.cond`,
  with the decorated function as the true branch and a no-op as the false
  branch.

  Args:
    condition: A Python `bool`, or a predicate value accepted by `lax.cond`.

  Returns:
    A decorator taking a zero-argument callable. The decorator itself returns
    `None`, so the decorated name is bound to `None` afterwards.
  """
  def _conditional_runner(f):
    if not isinstance(condition, bool):
      # Not a plain Python bool: let `lax.cond` select the branch.
      lax.cond(condition, f, lambda: None)
      return
    if condition:
      f()
  return _conditional_runner
def cdiv(a: int, b: int) -> int:
  """Divide `a` by `b`, rounding the quotient up.

  For a positive divisor `b`, adding `b - 1` to the numerator before floor
  division yields the ceiling of `a / b` using integer arithmetic only.

  Args:
    a: The numerator.
    b: The divisor.

  Returns:
    `(a + b - 1) // b`.
  """
  bumped_numerator = a + b - 1
  return bumped_numerator // b
def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
  """Compute one stride per dimension of `shape`, counted in elements.

  For a shape without zero-sized dimensions, the stride of a dimension is the
  product of the sizes of all dimensions that follow it, so the last dimension
  has stride 1. If any dimension is zero-sized, the total element count is
  zero and every stride is reported as 0.

  Args:
    shape: The dimension sizes.

  Returns:
    A tuple of Python ints with the same length as `shape` (`()` for an empty
    shape).
  """
  # Use exact Python-int arithmetic: a fixed-width NumPy product silently
  # wraps around for very large shapes and would yield wrong strides.
  remaining = math.prod(shape)
  strides = []
  for dim in shape:
    # When a dimension is zero the running size is already zero; skip the
    # division (which would be 0 // 0) and keep the stride at 0.
    remaining = remaining // dim if dim else 0
    strides.append(int(remaining))
  return tuple(strides)
def next_power_of_2(x: int) -> int:
  """Return the smallest power of two that is greater than or equal to `x`.

  Args:
    x: A non-negative integer.

  Returns:
    The smallest power of two `p` with `p >= x`; 1 when `x` is 0.

  Raises:
    ValueError: If `x` is negative.
  """
  if x < 0:
    raise ValueError("`next_power_of_2` requires a non-negative integer.")
  if x == 0:
    return 1
  # `(x - 1).bit_length()` is the exponent of the smallest power of two that
  # is >= x. This stays in integer arithmetic, so it is exact for arbitrarily
  # large `x`, whereas a floating-point log2 rounds and can return a result
  # smaller than `x`.
  return 1 << (x - 1).bit_length()
def pattern_match_scan_to_fori_loop(
jaxpr: jax_core.Jaxpr, num_consts: int, num_carry: int
) -> tuple[jax_core.Jaxpr, bool]:
if num_carry > 0:
# Pattern match onto fori_loop:
# We expect the first carry argument to the jaxpr to be the loop index and
# for the loop index + 1 to be returned as the first value out of the loop.
in_index_var = jaxpr.invars[num_consts]
out_index_var = jaxpr.outvars[0]
# Check that the loop index argument is an int32 scalar
if (in_index_var.aval.shape or
in_index_var.aval.dtype not in (jnp.int32, jnp.int64)):
raise NotImplementedError(
f"not a fori_loop index in: {in_index_var.aval} {jaxpr=}")
if (out_index_var.aval.shape or
out_index_var.aval.dtype not in (jnp.int32, jnp.int64)):
raise NotImplementedError(
f"not a fori_loop index out: {out_index_var.aval} {jaxpr=}")
# Look for the equation that increments the loop index
for i, eqn in enumerate(jaxpr.eqns):
if eqn.primitive == lax.add_p:
if eqn.invars[0] == in_index_var:
if isinstance(eqn.invars[1], jax_core.Literal):
if eqn.invars[1].val == 1:
if eqn.outvars[0] == out_index_var:
eqn_index = i
break
else:
raise NotImplementedError("Unable to match fori_loop pattern")
# Delete the equation that increments and remove the loop index from the
# output. Incrementing the loop index will be done implicitly.
jaxpr = jaxpr.replace(
eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1:],
outvars=jaxpr.outvars[1:])
has_loop_index = True
else:
# If there's no carry, the loop index has been DCEd and the body does *not*
# expect a loop index as an argument.
has_loop_index = False
return jaxpr, has_loop_index