improvements to autodidax reduce_sum
* generalize reduce_sum to handle multiple axes
* add reduce_sum transpose rule

also fix bug in AD jaxpr formation related to deduplicating consts
mattjj committed Mar 16, 2022
1 parent c35a3ca commit 43036e1
Showing 3 changed files with 60 additions and 21 deletions.
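To make the first change concrete before the per-file diffs: the new wrapper canonicalizes `axis` to a tuple of ints before binding the primitive. A minimal sketch of that normalization, using a hypothetical `normalize_axis` helper and plain NumPy in place of autodidax's `bind1` machinery:

import numpy as np

# Hypothetical helper mirroring the normalization the new reduce_sum wrapper
# performs (assumption: np.sum matches the primitive's shape semantics).
def normalize_axis(x, axis=None):
  if axis is None:                  # default: sum over every axis
    axis = tuple(range(np.ndim(x)))
  if type(axis) is int:             # promote a bare int to a 1-tuple
    axis = (axis,)
  return axis

x = np.ones((2, 3, 4))
assert normalize_axis(x) == (0, 1, 2)        # axis=None -> all axes
assert normalize_axis(x, 1) == (1,)          # int -> tuple
assert np.sum(x, normalize_axis(x, (0, 2))).shape == (3,)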
docs/autodidax.ipynb: 27 changes (20 additions, 7 deletions)
@@ -113,11 +113,16 @@
 "def neg(x): return bind1(neg_p, x)\n",
 "def sin(x): return bind1(sin_p, x)\n",
 "def cos(x): return bind1(cos_p, x)\n",
-"def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)\n",
 "def greater(x, y): return bind1(greater_p, x, y)\n",
 "def less(x, y): return bind1(less_p, x, y)\n",
 "def transpose(x, perm): return bind1(transpose_p, x, perm=perm)\n",
 "def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)\n",
+"def reduce_sum(x, axis=None):\n",
+"  if axis is None:\n",
+"    axis = tuple(range(np.ndim(x)))\n",
+"  if type(axis) is int:\n",
+"    axis = (axis,)\n",
+"  return bind1(reduce_sum_p, x, axis=axis)\n",
 "\n",
 "def bind1(prim, *args, **params):\n",
 "  out, = bind(prim, *args, **params)\n",
@@ -1120,8 +1125,8 @@
 "\n",
 "def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):\n",
 "  (x,), (x_bdim,) = vals_in, dims_in\n",
-"  new_axis = axis + (x_bdim <= axis)\n",
-"  out_bdim = x_bdim - (new_axis < x_bdim)\n",
+"  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)\n",
+"  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)\n",
 "  return [reduce_sum(x, new_axis)], [out_bdim]\n",
 "vmap_rules[reduce_sum_p] = reduce_sum_batching_rule"
 ]
@@ -1615,8 +1620,10 @@
 "abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval\n",
 "abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval\n",
 "\n",
-"def reduce_sum_abstract_eval(x: ShapedArray, *, axis: int) -> List[ShapedArray]:\n",
-"  new_shape = [d for i, d in enumerate(x.shape) if i != axis]\n",
+"def reduce_sum_abstract_eval(x: ShapedArray, *, axis: Tuple[int, ...]\n",
+"                             ) -> List[ShapedArray]:\n",
+"  axis_ = set(axis)\n",
+"  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]\n",
 "  return [ShapedArray(tuple(new_shape), x.dtype)]\n",
 "abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval\n",
 "\n",
@@ -2097,7 +2104,7 @@
 "  subc = xc.XlaBuilder('add')\n",
 "  shape = _xla_shape(ShapedArray((), x_aval.dtype))\n",
 "  xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))\n",
-"  return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]\n",
+"  return [xops.Reduce(c, [x], [zero], subc.build(), axis)]\n",
 "xla_translations[reduce_sum_p] = reduce_sum_translation\n",
 "\n",
 "def broadcast_translation(c, in_avals, in_vals, *, shape, axes):\n",
@@ -2825,8 +2832,9 @@
 "      var = constid_to_var.get(id(val))\n",
 "      if var is None:\n",
 "        aval = raise_to_shaped(get_aval(val))\n",
-"        var = tracer_to_var[id(t)] = constid_to_var[id(val)] = Var(aval)\n",
+"        var = constid_to_var[id(val)] = Var(aval)\n",
 "        constvar_to_val[var] = val\n",
+"      tracer_to_var[id(t)] = var\n",
 "    elif isinstance(t.recipe, JaxprEqnRecipe):\n",
 "      if id(t.recipe) not in processed_eqns:\n",
 "        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))\n",
@@ -3242,6 +3250,11 @@
 "  return [z_bar, z_bar]\n",
 "transpose_rules[add_p] = add_transpose_rule\n",
 "\n",
+"def reduce_sum_transpose_rule(cts, x, *, axis):\n",
+"  y_bar, = cts\n",
+"  return [broadcast(y_bar, x.aval.shape, axis)]\n",
+"transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule\n",
+"\n",
 "def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):\n",
 "  del num_consts  # Unused\n",
 "  undef_primals = [type(x) is UndefPrimal for x in invals]\n",
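The new transpose rule in the diff above uses the fact that summation's transpose is broadcasting: the output cotangent is replicated along each summed-out axis. A sketch of that identity with plain NumPy, where the hypothetical `broadcast_like` stands in for autodidax's `broadcast(y_bar, x.aval.shape, axis)`:

import numpy as np

# Reinsert each summed-out axis as size 1, then replicate the cotangent
# along it (a sketch of the transpose rule, not the real primitive).
def broadcast_like(y_bar, in_shape, axis):
  y_bar = np.expand_dims(y_bar, axis)        # (3,) -> (1, 3, 1) for axis=(0, 2)
  return np.broadcast_to(y_bar, in_shape)    # (1, 3, 1) -> (2, 3, 4)

in_shape, axis = (2, 3, 4), (0, 2)
y_bar = np.arange(3.0)                       # cotangent of the (3,)-shaped sum
x_bar = broadcast_like(y_bar, in_shape, axis)
assert x_bar.shape == in_shape               # every summand shares its output's cotangent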
docs/autodidax.md: 27 changes (20 additions, 7 deletions)
@@ -104,11 +104,16 @@ def mul(x, y): return bind1(mul_p, x, y)
 def neg(x): return bind1(neg_p, x)
 def sin(x): return bind1(sin_p, x)
 def cos(x): return bind1(cos_p, x)
-def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)
 def greater(x, y): return bind1(greater_p, x, y)
 def less(x, y): return bind1(less_p, x, y)
 def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
 def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
+def reduce_sum(x, axis=None):
+  if axis is None:
+    axis = tuple(range(np.ndim(x)))
+  if type(axis) is int:
+    axis = (axis,)
+  return bind1(reduce_sum_p, x, axis=axis)
 
 def bind1(prim, *args, **params):
   out, = bind(prim, *args, **params)
@@ -873,8 +878,8 @@ vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)
 
 def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
   (x,), (x_bdim,) = vals_in, dims_in
-  new_axis = axis + (x_bdim <= axis)
-  out_bdim = x_bdim - (new_axis < x_bdim)
+  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
+  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
   return [reduce_sum(x, new_axis)], [out_bdim]
 vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
 ```
@@ -1269,8 +1274,10 @@ abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
 abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
 abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
 
-def reduce_sum_abstract_eval(x: ShapedArray, *, axis: int) -> List[ShapedArray]:
-  new_shape = [d for i, d in enumerate(x.shape) if i != axis]
+def reduce_sum_abstract_eval(x: ShapedArray, *, axis: Tuple[int, ...]
+                             ) -> List[ShapedArray]:
+  axis_ = set(axis)
+  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
   return [ShapedArray(tuple(new_shape), x.dtype)]
 abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
 
@@ -1647,7 +1654,7 @@ def reduce_sum_translation(c, in_avals, in_vals, *, axis):
   subc = xc.XlaBuilder('add')
   shape = _xla_shape(ShapedArray((), x_aval.dtype))
   xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
-  return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]
+  return [xops.Reduce(c, [x], [zero], subc.build(), axis)]
 xla_translations[reduce_sum_p] = reduce_sum_translation
 
 def broadcast_translation(c, in_avals, in_vals, *, shape, axes):
@@ -2207,8 +2214,9 @@ def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
       var = constid_to_var.get(id(val))
       if var is None:
         aval = raise_to_shaped(get_aval(val))
-        var = tracer_to_var[id(t)] = constid_to_var[id(val)] = Var(aval)
+        var = constid_to_var[id(val)] = Var(aval)
         constvar_to_val[var] = val
+      tracer_to_var[id(t)] = var
     elif isinstance(t.recipe, JaxprEqnRecipe):
       if id(t.recipe) not in processed_eqns:
         eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
@@ -2557,6 +2565,11 @@ def add_transpose_rule(cts, x, y):
   return [z_bar, z_bar]
 transpose_rules[add_p] = add_transpose_rule
 
+def reduce_sum_transpose_rule(cts, x, *, axis):
+  y_bar, = cts
+  return [broadcast(y_bar, x.aval.shape, axis)]
+transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
+
 def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
   del num_consts  # Unused
   undef_primals = [type(x) is UndefPrimal for x in invals]
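The batching rule's index arithmetic can be checked independently of the tracing machinery. This sketch (hypothetical `shift_axes` helper, NumPy only) reproduces the two changed lines: every reduced axis at or past the batch dimension shifts right by one, and the batch dimension's output position drops by one for each reduced axis in front of it:

import numpy as np

# Hypothetical helper reproducing the batching rule's arithmetic.
def shift_axes(axis, x_bdim):
  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)   # skip over the batch dim
  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)    # account for removed axes
  return new_axis, out_bdim

axis, x_bdim = (0, 2), 1                 # reduce axes 0 and 2; batch dim sits at 1
new_axis, out_bdim = shift_axes(axis, x_bdim)
assert new_axis == (0, 3) and out_bdim == 0

xb = np.ones((5, 7, 2, 3))               # 7 batched copies of a (5, 2, 3) array
out = np.sum(xb, new_axis)               # reduce the shifted axes
assert out.shape[out_bdim] == 7          # the batch dim survives, now leading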
docs/autodidax.py: 27 changes (20 additions, 7 deletions)
@@ -92,11 +92,16 @@ def mul(x, y): return bind1(mul_p, x, y)
 def neg(x): return bind1(neg_p, x)
 def sin(x): return bind1(sin_p, x)
 def cos(x): return bind1(cos_p, x)
-def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)
 def greater(x, y): return bind1(greater_p, x, y)
 def less(x, y): return bind1(less_p, x, y)
 def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
 def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
+def reduce_sum(x, axis=None):
+  if axis is None:
+    axis = tuple(range(np.ndim(x)))
+  if type(axis) is int:
+    axis = (axis,)
+  return bind1(reduce_sum_p, x, axis=axis)
 
 def bind1(prim, *args, **params):
   out, = bind(prim, *args, **params)
@@ -873,8 +878,8 @@ def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
 
 def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
   (x,), (x_bdim,) = vals_in, dims_in
-  new_axis = axis + (x_bdim <= axis)
-  out_bdim = x_bdim - (new_axis < x_bdim)
+  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
+  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
   return [reduce_sum(x, new_axis)], [out_bdim]
 vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
 
@@ -1271,8 +1276,10 @@ def vectorized_unop_abstract_eval(x: ShapedArray) -> List[ShapedArray]:
 abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
 abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
 
-def reduce_sum_abstract_eval(x: ShapedArray, *, axis: int) -> List[ShapedArray]:
-  new_shape = [d for i, d in enumerate(x.shape) if i != axis]
+def reduce_sum_abstract_eval(x: ShapedArray, *, axis: Tuple[int, ...]
+                             ) -> List[ShapedArray]:
+  axis_ = set(axis)
+  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
   return [ShapedArray(tuple(new_shape), x.dtype)]
 abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
 
@@ -1643,7 +1650,7 @@ def reduce_sum_translation(c, in_avals, in_vals, *, axis):
   subc = xc.XlaBuilder('add')
   shape = _xla_shape(ShapedArray((), x_aval.dtype))
   xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
-  return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]
+  return [xops.Reduce(c, [x], [zero], subc.build(), axis)]
 xla_translations[reduce_sum_p] = reduce_sum_translation
 
 def broadcast_translation(c, in_avals, in_vals, *, shape, axes):
@@ -2201,8 +2208,9 @@ def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],
       var = constid_to_var.get(id(val))
       if var is None:
         aval = raise_to_shaped(get_aval(val))
-        var = tracer_to_var[id(t)] = constid_to_var[id(val)] = Var(aval)
+        var = constid_to_var[id(val)] = Var(aval)
         constvar_to_val[var] = val
+      tracer_to_var[id(t)] = var
     elif isinstance(t.recipe, JaxprEqnRecipe):
       if id(t.recipe) not in processed_eqns:
         eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
@@ -2556,6 +2564,11 @@ def add_transpose_rule(cts, x, y):
   return [z_bar, z_bar]
 transpose_rules[add_p] = add_transpose_rule
 
+def reduce_sum_transpose_rule(cts, x, *, axis):
+  y_bar, = cts
+  return [broadcast(y_bar, x.aval.shape, axis)]
+transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
+
 def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
   del num_consts  # Unused
   undef_primals = [type(x) is UndefPrimal for x in invals]
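On the const-deduplication fix in tracers_to_jaxpr: previously `tracer_to_var[id(t)]` was only assigned when a constant was first interned, so a second tracer carrying the same constant never got an entry. Moving the assignment out of the `if` records every tracer while still reusing one Var per constant. A stripped-down sketch of that control flow, with stand-in dicts and plain objects rather than real tracers and Vars:

constid_to_var, tracer_to_var = {}, {}

def register_const(t, val):
  var = constid_to_var.get(id(val))
  if var is None:                              # first time this constant is seen
    var = constid_to_var[id(val)] = object()   # fresh Var stand-in
  tracer_to_var[id(t)] = var                   # the fix: record on every visit

const = [1.0]                    # one constant shared by two tracers
t1, t2 = object(), object()
register_const(t1, const)
register_const(t2, const)        # the old code skipped t2's entry here
assert tracer_to_var[id(t1)] is tracer_to_var[id(t2)]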
