-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
custom_derivatives.py
1369 lines (1162 loc) · 57.5 KB
/
custom_derivatives.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import dataclasses
from functools import update_wrapper, reduce, partial
import inspect
from typing import Any, Callable, Generic, TypeVar
from jax._src import config
from jax._src import core
from jax._src import custom_api_util
from jax._src.custom_transpose import custom_transpose
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import traceback_util
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.core import raise_to_shaped
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves)
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
traceback_util.register_exclusion(__file__)
map = safe_map
zip = safe_zip
### util
def _resolve_kwargs(fun, args, kwargs):
if isinstance(fun, partial):
# functools.partial should have an opaque signature.
fun = lambda *args, **kwargs: None
ba = inspect.signature(fun).bind(*args, **kwargs)
ba.apply_defaults()
if ba.kwargs:
raise TypeError("keyword arguments could not be resolved to positions")
else:
return ba.args
def _initial_style_jaxpr(fun, in_avals):
  """Trace ``fun`` at the given input avals; return ``(jaxpr, consts)``."""
  result = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  # The final element must be an empty tuple; unpacking with () enforces it.
  jaxpr, _, consts, () = result
  return jaxpr, consts
def _close_jaxpr(jaxpr):
  """Convert a jaxpr's constvars into invars, then close over it."""
  without_consts = pe.convert_constvars_jaxpr(jaxpr)
  return pe.close_jaxpr(without_consts)
def _sum_tangents(_, x, *xs):
  """Sum one or more tangent values; the first (primal) argument is ignored."""
  total = x
  for tangent in xs:
    total = ad.add_tangents(total, tangent)
  return total
def _zeros_like_pytree(x):
return tree_map(Zero.from_value, x)
@partial(partial, tree_map)
def _stop_gradient(x):
if isinstance(x, core.Tracer):
return stop_gradient_p.bind(x)
else:
return x
# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux
def _flatten_fun_nokwargs(in_tree, *args_flat):
  # Rebuild the pytree arguments, call the wrapped function, then flatten the
  # result.  The aux output carries (out_tree, out_avals) so callers can check
  # the output structure/types of the primal function against a rule's output.
  py_args = tree_unflatten(in_tree, args_flat)
  ans = yield py_args, {}
  ans_flat, ans_tree = tree_flatten(ans)
  ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat]
  yield ans_flat, (ans_tree, ans_avals)
### JVPs
ReturnValue = TypeVar('ReturnValue')


@custom_api_util.register_custom_decorator_type
class custom_jvp(Generic[ReturnValue]):
  """Set up a JAX-transformable function for a custom JVP rule definition.

  This class is meant to be used as a function decorator. Instances are
  callables that behave similarly to the underlying function to which the
  decorator was applied, except when a differentiation transformation (like
  :py:func:`jax.jvp` or :py:func:`jax.grad`) is applied, in which case a custom
  user-supplied JVP rule function is used instead of tracing into and
  performing automatic differentiation of the underlying function's
  implementation.

  There are two instance methods available for defining the custom JVP rule:
  :py:func:`~jax.custom_jvp.defjvp` for defining a *single* custom JVP rule for
  all the function's inputs, and for convenience
  :py:func:`~jax.custom_jvp.defjvps`, which wraps
  :py:func:`~jax.custom_jvp.defjvp`, and allows you to provide separate
  definitions for the partial derivatives of the function w.r.t. each of its
  arguments.

  For example::

    @jax.custom_jvp
    def f(x, y):
      return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
      x, y = primals
      x_dot, y_dot = tangents
      primal_out = f(x, y)
      tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
      return primal_out, tangent_out

  For a more detailed introduction, see the tutorial_.

  .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
  """
  # Class-level annotations for the attributes filled in by __init__/defjvp.
  fun: Callable[..., ReturnValue]
  nondiff_argnums: tuple[int, ...]
  jvp: Callable[..., tuple[ReturnValue, ReturnValue]] | None = None
  symbolic_zeros: bool = False

  def __init__(self,
               fun: Callable[..., ReturnValue],
               nondiff_argnums: tuple[int, ...] = (),
               ):
    update_wrapper(self, fun)  # copy __name__/__doc__ etc. from fun
    self.fun = fun
    self.nondiff_argnums = nondiff_argnums

  # Forward other attribute accesses to the wrapped function.
  __getattr__ = custom_api_util.forward_attr

  def defjvp(self,
             jvp: Callable[..., tuple[ReturnValue, ReturnValue]],
             symbolic_zeros: bool = False,
             ) -> Callable[..., tuple[ReturnValue, ReturnValue]]:
    """Define a custom JVP rule for the function represented by this instance.

    Args:
      jvp: a Python callable representing the custom JVP rule. When there are no
        ``nondiff_argnums``, the ``jvp`` function should accept two arguments,
        where the first is a tuple of primal inputs and the second is a tuple of
        tangent inputs. The lengths of both tuples are equal to the number of
        parameters of the ``custom_jvp`` function. The ``jvp`` function should
        produce as output a pair where the first element is the primal output
        and the second element is the tangent output. Elements of the input and
        output tuples may be arrays or any nested tuples/lists/dicts thereof.
      symbolic_zeros: boolean, indicating whether the rule should be passed
        objects representing static symbolic zeros in its tangent argument in
        correspondence with unperturbed values; otherwise, only standard JAX
        types (e.g. array-likes) are passed. Setting this option to ``True``
        allows a JVP rule to detect whether certain inputs are not involved in
        differentiation, but at the cost of needing special handling for these
        objects (which e.g. can't be passed into jax.numpy functions). Default
        ``False``.

    Returns:
      The ``jvp`` callable, unchanged, so that ``defjvp`` can be used as a
      decorator.

    Example::

      @jax.custom_jvp
      def f(x, y):
        return jnp.sin(x) * y

      @f.defjvp
      def f_jvp(primals, tangents):
        x, y = primals
        x_dot, y_dot = tangents
        primal_out = f(x, y)
        tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
        return primal_out, tangent_out
    """
    self.jvp = jvp
    self.symbolic_zeros = symbolic_zeros
    return jvp

  def defjvps(self, *jvps: Callable[..., ReturnValue] | None):
    """Convenience wrapper for defining JVPs for each argument separately.

    This convenience wrapper cannot be used together with ``nondiff_argnums``.

    Args:
      *jvps: a sequence of functions, one for each positional argument of the
        ``custom_jvp`` function. Each function takes as arguments the tangent
        value for the corresponding primal input, the primal output, and the
        primal inputs. See the example below.

    Returns:
      None.

    Example::

      @jax.custom_jvp
      def f(x, y):
        return jnp.sin(x) * y

      f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
                lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
    """
    if self.nondiff_argnums:
      raise TypeError("Can't use ``defjvps`` with ``nondiff_argnums``.")

    def jvp(primals, tangents):
      primal_out = self(*primals)
      zeros = _zeros_like_pytree(primal_out)
      # One rule per argument; a None entry contributes a zero tangent.
      all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros
                          for t, jvp in zip(tangents, jvps)]
      tangent_out = tree_map(_sum_tangents, primal_out, *all_tangents_out)
      return primal_out, tangent_out

    self.defjvp(jvp)

  @traceback_util.api_boundary
  def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
    primal_name = getattr(self.fun, '__name__', str(self.fun))
    if not self.jvp:
      msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
      raise AttributeError(msg)
    jvp_name = getattr(self.jvp, '__name__', str(self.jvp))
    # Resolve keyword arguments to purely positional form.
    args = _resolve_kwargs(self.fun, args, kwargs)
    if self.nondiff_argnums:
      nondiff_argnums = set(self.nondiff_argnums)
      # Stop gradients on nondiff arguments, then partial them out of both the
      # primal function and the jvp rule.
      args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
                   for i, x in enumerate(args))
      diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
      f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args,
                                     require_static_args_hashable=False)
      static_args = [args[i] for i in self.nondiff_argnums]
      jvp = _add_args(lu.wrap_init(self.jvp), static_args)
    else:
      f_, dyn_args = lu.wrap_init(self.fun), args  # type: ignore
      jvp = lu.wrap_init(self.jvp)
    args_flat, in_tree = tree_flatten(dyn_args)
    flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree)
    flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree,
                                       out_type1)
    out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat,
                                      symbolic_zeros=self.symbolic_zeros)
    # Whichever of fun/jvp actually ran populated its aux store; use it.
    _, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2)
    return tree_unflatten(out_tree, out_flat)
def _add_args(f, extra_args):
  """Wrap each extra argument (to keep it out of hashing) and prepend them."""
  wrapped = tuple(Unhashable(a) for a in extra_args)
  return _add_args_(f, wrapped)
@lu.transformation
def _add_args_(extra_args, *args, **kwargs):
  """Transformation that unwraps stored extras and prepends them to the call."""
  unwrapped = tuple(wrapped.val for wrapped in extra_args)
  ans = yield unwrapped + args, kwargs
  yield ans
@partial(lu.transformation_with_aux, use_eq_store=True)
def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
  # Flattening wrapper around a user JVP rule: unflatten the primal and
  # tangent halves of *args, call the rule, validate its output (structure and
  # avals) against the primal function's recorded output type when available,
  # and reflatten.  Aux output: (out_tree, primal_avals).
  primals_in, tangents_in = split_list(args, [len(args) // 2])
  py_primals = tree_unflatten(in_tree, primals_in)
  py_tangents = tree_unflatten(in_tree, tangents_in)
  pair_out = yield (py_primals, py_tangents), {}
  if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
    msg = (f"Custom JVP rule {jvp_name} for function {primal_name} "
           "must produce a pair (list or tuple of length two) representing "
           f"primal and tangent outputs, but got {pair_out}.")
    raise TypeError(msg)
  py_primals_out, py_tangents_out = pair_out
  primals_out, out_tree = tree_flatten(py_primals_out)
  tangents_out, out_tree2 = tree_flatten(py_tangents_out)
  primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
  if out_tree != out_tree2:
    msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
           "produce primal and tangent outputs with equal container (pytree) "
           f"structures, but got {out_tree} and {out_tree2} respectively.")
    raise TypeError(msg)
  # If the primal function already ran, check out_tree agreement.
  try: out_type_ = maybe_out_type()
  except lu.StoreException: out_type_ = None
  if out_type_ is not None:
    out_tree_, primal_avals_ = out_type_
    # Render both structures with human-readable aval strings for error text.
    ty_tree  = tree_unflatten(out_tree , [a.str_short() for a in primal_avals])
    ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_])
    if out_tree_ != out_tree:
      m = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
           "produce a pair (list or tuple of length two) "
           "where the first element represents the primal output "
           "(equal in value to the output of the custom_jvp-decorated function "
           f"{primal_name}, "
           "and in particular of the same container/pytree structure), but "
           "instead the JVP rule output's first element had container/pytree "
           "structure:\n"
           f"""    {str(ty_tree ).replace("'", "")}\n"""
           f"while the custom_jvp-decorated function {primal_name} had output "
           "container/pytree structure:\n"
           f"""    {str(ty_tree_).replace("'", "")}.""")
      raise TypeError(m)
    if not all(map(core.typematch, primal_avals, primal_avals_)):
      m = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
           "produce a pair (list or tuple of length two) "
           "where the first element represents the primal output "
           "(equal in value to the output of the custom_jvp-decorated function "
           f"{primal_name}, "
           "and in particular with leaves of the same shape/dtype), but "
           "instead the JVP rule output's first element had shapes/dtypes of:\n"
           f"""    {str(ty_tree ).replace("'", "")}\n"""
           f"while the custom_jvp-decorated function {primal_name} had output "
           "shapes/dtypes of:\n"
           f"""    {str(ty_tree_).replace("'", "")}""")
      raise TypeError(m)
  # TODO(mattjj): compare primals' tangent types to tangent objects' types
  primal_avals_out = [
      raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
      for x in primals_out]
  tangent_avals_out = [
      raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
      if type(t) is not SymbolicZero else t.aval.strip_weak_type()
      for t in tangents_out]
  if primal_avals_out != tangent_avals_out:
    if len(primal_avals_out) == 1:
      (av1,), (av2,) = primal_avals_out, tangent_avals_out
      msg = ("Custom JVP rule must produce primal and tangent outputs with "
             "equal shapes and dtypes, but got {} and {} respectively.")
      raise TypeError(msg.format(av1.str_short(), av2.str_short()))
    else:
      msg = ("Custom JVP rule must produce primal and tangent outputs with "
             "equal shapes and dtypes, but got:\n{}")
      disagreements = (
          f"  primal {av1.str_short()} for tangent {av2.str_short()}"
          for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
      raise TypeError(msg.format('\n'.join(disagreements)))
  yield primals_out + tangents_out, (out_tree, primal_avals)
class CustomJVPCallPrimitive(core.Primitive):
  """Primitive bound by ``custom_jvp.__call__``; carries fun and jvp rules."""
  multiple_results = True

  def bind(self, fun, jvp, *args, symbolic_zeros):
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    # Wrap both fun and jvp so that tracers escaping to outer traces are
    # post-processed (only one of the two will actually run).
    fun, env_trace_todo1 = process_env_traces(
        fun, self, top_trace and top_trace.level, False)
    jvp, env_trace_todo2 = process_env_traces(
        jvp, self, top_trace and top_trace.level, True)
    tracers = map(top_trace.full_raise, args)  # type: ignore
    outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers,  # type: ignore
                                             symbolic_zeros=symbolic_zeros)  # type: ignore
    _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    return core.apply_todos(env_trace_todo, map(core.full_lower, outs))

  def impl(self, fun, _, *args):
    # Evaluation just runs the primal function; the jvp rule is ignored.
    with core.new_sublevel():
      return fun.call_wrapped(*args)

  def post_process(self, trace, out_tracers, jvp_was_run: bool):
    return trace.post_process_custom_jvp_call(out_tracers, jvp_was_run)

  def get_bind_params(self, params):
    # Invert staging: rebuild the callable fun/jvp arguments from jaxpr params.
    new_params = dict(params)
    call_jaxpr = new_params.pop('call_jaxpr')
    num_consts = new_params.pop('num_consts')
    jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
    fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr))
    jvp = lift_jvp(num_consts, jvp_jaxpr_thunk)
    return [fun, jvp], new_params
def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun:
  """Wrap a staged jvp-jaxpr thunk as a WrappedFun over (primals, tangents).

  The returned function takes the flat sequence
  ``[consts + primals] + [const/primal tangents]`` (each half of equal
  length), drops the leading ``num_consts`` entries of each half, and
  evaluates the jvp jaxpr, reinserting SymbolicZero outputs where the thunk
  reported zero tangents.
  """
  @lu.wrap_init
  def jvp(*xs):
    n, ragged = divmod(len(xs), 2)
    assert not ragged
    primals, tangents = xs[num_consts:n], xs[n+num_consts:]
    zeros = [type(t) is SymbolicZero for t in tangents]
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
    # The jaxpr only takes the nonzero tangents as inputs.
    nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
    out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
    out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
    # Interleave SymbolicZeros back into the tangent outputs per out_zeros.
    nz_out_tangents_ = iter(nz_out_tangents)
    out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
                    if z else next(nz_out_tangents_)
                    for p, z in zip(out_primals, out_zeros)]
    assert next(nz_out_tangents_, None) is None
    return [*out_primals, *out_tangents]
  return jvp
@partial(lu.transformation_with_aux, use_eq_store=True)
def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):
  # Repeatedly post-process outputs that are tracers from traces above `level`
  # (environment traces), from the highest level downward, accumulating the
  # todo callbacks to be applied by core.apply_todos in bind.
  outs = yield args, {}
  todo = []
  while True:
    tracers = [x for x in outs if isinstance(x, core.Tracer)
               and (level is None or x._trace.level > level)]
    if tracers:
      ans = max(tracers, key=lambda x: x._trace.level)
    else:
      break
    trace = ans._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo = primitive.post_process(trace, outs, jvp_was_run)
    todo.append(cur_todo)
  yield outs, tuple(todo)  # Ensure the aux output is immutable
# Permit infeed/outfeed effects inside custom-derivative rules.
effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)

# The primitive instance bound by custom_jvp.__call__.
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
                               num_consts, symbolic_zeros):
  # Jaxpr typecheck rule: reject effects not on the allowlist; output avals
  # and effects come straight from the primal call_jaxpr.
  # TODO(mattjj): could do more checking here...
  del in_avals, jvp_jaxpr_thunk, num_consts
  disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(call_jaxpr.effects)
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `custom_jvp`: {disallowed_effects}')
  return call_jaxpr.out_avals, call_jaxpr.effects
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck
def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
                                      num_consts, symbolic_zeros):
  """Lower custom_jvp_call by inlining the primal jaxpr; JVP params unused."""
  del jvp_jaxpr_thunk, num_consts, symbolic_zeros
  operands = map(mlir.wrap_singleton_ir_values, args)
  const_vals = mlir._ir_consts(call_jaxpr.consts)
  results, out_tokens = mlir.jaxpr_subcomp(
      ctx.module_context, call_jaxpr.jaxpr, ctx.name_stack, ctx.tokens_in,
      const_vals, *operands, dim_var_values=ctx.dim_var_values)
  ctx.set_tokens_out(out_tokens)
  return results
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation)
# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_ can appear in jaxprs to be transposed. Since it's already
# been linearized, we can drop the jvp rule.
def _custom_jvp_call_transpose(params, jaxpr, args, ct, _, reduce_axes):
  del params
  # Transpose by running the standard backward pass over the primal jaxpr.
  return ad.backward_pass(jaxpr.jaxpr, reduce_axes, None, jaxpr.consts, args, ct)
ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose
### VJPs
@custom_api_util.register_custom_decorator_type
class custom_vjp(Generic[ReturnValue]):
  """Set up a JAX-transformable function for a custom VJP rule definition.

  This class is meant to be used as a function decorator. Instances are
  callables that behave similarly to the underlying function to which the
  decorator was applied, except when a reverse-mode differentiation
  transformation (like :py:func:`jax.grad`) is applied, in which case a custom
  user-supplied VJP rule function is used instead of tracing into and performing
  automatic differentiation of the underlying function's implementation. There
  is a single instance method, :py:func:`~jax.custom_vjp.defvjp`, which may be
  used to define the custom VJP rule.

  This decorator precludes the use of forward-mode automatic differentiation.

  For example::

    @jax.custom_vjp
    def f(x, y):
      return jnp.sin(x) * y

    def f_fwd(x, y):
      return f(x, y), (jnp.cos(x), jnp.sin(x), y)

    def f_bwd(res, g):
      cos_x, sin_x, y = res
      return (cos_x * g * y, sin_x * g)

    f.defvjp(f_fwd, f_bwd)

  For a more detailed introduction, see the tutorial_.

  .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
  """

  def __init__(self,
               fun: Callable[..., ReturnValue],
               nondiff_argnums: tuple[int, ...] = ()):
    update_wrapper(self, fun)  # copy __name__/__doc__ etc. from fun
    self.fun = fun
    self.nondiff_argnums = nondiff_argnums
    # fwd/bwd are set later by defvjp; calling before that raises.
    self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None
    self.bwd: Callable[..., tuple[Any, ...]] | None = None
    self.symbolic_zeros = False

  # Forward other attribute accesses to the wrapped function.
  __getattr__ = custom_api_util.forward_attr

  def defvjp(self,
             fwd: Callable[..., tuple[ReturnValue, Any]],
             bwd: Callable[..., tuple[Any, ...]],
             symbolic_zeros: bool = False,
             ) -> None:
    """Define a custom VJP rule for the function represented by this instance.

    Args:
      fwd: a Python callable representing the forward pass of the custom VJP
        rule. When there are no ``nondiff_argnums``, the ``fwd`` function has
        the same input signature as the underlying primal function. It should
        return as output a pair, where the first element represents the primal
        output and the second element represents any "residual" values to store
        from the forward pass for use on the backward pass by the function
        ``bwd``. Input arguments and elements of the output pair may be arrays
        or nested tuples/lists/dicts thereof.
      bwd: a Python callable representing the backward pass of the custom VJP
        rule. When there are no ``nondiff_argnums``, the ``bwd`` function takes
        two arguments, where the first is the "residual" values produced on the
        forward pass by ``fwd``, and the second is the output cotangent with the
        same structure as the primal function output. The output of ``bwd`` must
        be a tuple of length equal to the number of arguments of the primal
        function, and the tuple elements may be arrays or nested
        tuples/lists/dicts thereof so as to match the structure of the primal
        input arguments.
      symbolic_zeros: boolean, determining whether to indicate symbolic zeros
        to the ``fwd`` and ``bwd`` rules. Enabling this option allows custom
        derivative rules to detect when certain inputs, and when certain
        output cotangents, are not involved in differentiation. If ``True``:

        * ``fwd`` must accept, in place of each leaf value ``x`` in
          the pytree comprising an argument to the original function,
          an object (of type
          ``jax.custom_derivatives.CustomVJPPrimal``) with two
          attributes instead: ``value`` and ``perturbed``. The
          ``value`` field is the original primal argument, and
          ``perturbed`` is a boolean. The ``perturbed`` bit indicates
          whether the argument is involved in differentiation (i.e.,
          if it is ``False``, then the corresponding Jacobian "column"
          is zero).

        * ``bwd`` will be passed objects representing static symbolic zeros in
          its cotangent argument in correspondence with unperturbed values;
          otherwise, only standard JAX types (e.g. array-likes) are passed.

        Setting this option to ``True`` allows these rules to detect whether
        certain inputs and outputs are not involved in differentiation, but at
        the cost of special handling. For instance:

        * The signature of ``fwd`` changes, and the objects it is passed cannot
          be output from the rule directly.

        * The ``bwd`` rule is passed objects that are not entirely array-like,
          and that cannot be passed to most ``jax.numpy`` functions.

        * Any custom pytree nodes involved in the primal function's arguments
          must accept, in their unflattening functions, the two-field record
          objects that are given as input leaves to the ``fwd`` rule.

        Default ``False``.

    Returns:
      None.

    Example::

      @jax.custom_vjp
      def f(x, y):
        return jnp.sin(x) * y

      def f_fwd(x, y):
        return f(x, y), (jnp.cos(x), jnp.sin(x), y)

      def f_bwd(res, g):
        cos_x, sin_x, y = res
        return (cos_x * g * y, sin_x * g)

      f.defvjp(f_fwd, f_bwd)
    """
    self.fwd = fwd
    self.bwd = bwd
    self.symbolic_zeros = symbolic_zeros

  @traceback_util.api_boundary
  def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
    primal_name = getattr(self.fun, '__name__', str(self.fun))
    if not self.fwd or not self.bwd:
      msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
      raise AttributeError(msg)
    fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
    # Resolve keyword arguments to purely positional form.
    args = _resolve_kwargs(self.fun, args, kwargs)
    if config.enable_custom_vjp_by_custom_transpose.value:
      # Experimental path: implement custom_vjp via custom_transpose.
      if self.nondiff_argnums:
        raise NotImplementedError(
            'nondiff_argnums not implemented for new custom_vjp')
      return custom_vjp_by_custom_transpose(self.fun, self.fwd, self.bwd)(*args)
    else:
      if self.nondiff_argnums:
        # Nondiff arguments must not contain tracers; partial them out of fun
        # and fwd, and prepend them to bwd's arguments.
        for i in self.nondiff_argnums: _check_for_tracers(args[i])
        nondiff_argnums = set(self.nondiff_argnums)
        dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
        f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
                                       args, require_static_args_hashable=False)
        static_args = [args[i] for i in self.nondiff_argnums]
        fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
                                 require_static_args_hashable=False)
        bwd = _add_args(lu.wrap_init(self.bwd), static_args)
      else:
        f_, dyn_args = lu.wrap_init(self.fun), args
        fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
      args_flat, in_tree = tree_flatten(dyn_args)
      in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
      flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
      flat_fwd, out_trees = _flatten_fwd(fwd, self.symbolic_zeros, primal_name,
                                         fwd_name, in_tree, out_type)
      flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
      out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                        *args_flat, out_trees=out_trees,
                                        symbolic_zeros=self.symbolic_zeros)
      # Whichever of fun/fwd actually ran populated its aux store; use it.
      _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
      return tree_unflatten(out_tree, out_flat)
@dataclasses.dataclass
class CustomVJPPrimal:
  """Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
  value: Any       # the primal argument value itself
  perturbed: bool  # whether this input is involved in differentiation
def custom_vjp_primal_tree_values(tree):
  """Strips away perturbation information from forward rule arguments.

  This is a helper function for use with the ``symbolic_zeros`` option to
  the ``defvjp`` method of a ``custom_vjp``-decorated function.

  In ``symbolic_zeros`` mode, the custom forward rule receives arguments
  whose pytree leaves are records with a ``value`` attribute that carries
  the primal argument. This is a way to convert such argument trees back to
  their original form, replacing each such record with its carried value at
  each leaf.
  """
  def unwrap(leaf):
    if type(leaf) is not CustomVJPPrimal:
      raise TypeError(f"unexpected leaf type {type(leaf)}")
    return leaf.value
  return tree_map(unwrap, tree)
def _check_for_tracers(x):
for leaf in tree_leaves(x):
if isinstance(leaf, core.Tracer):
msg = ("Found a JAX Tracer object passed as an argument to a custom_vjp "
"function in a position indicated by nondiff_argnums as "
"non-differentiable. Tracers cannot be passed as non-differentiable "
"arguments to custom_vjp functions; instead, nondiff_argnums should "
"only be used for arguments that can't be or contain JAX tracers, "
"e.g. function-valued arguments. In particular, array-valued "
"arguments should typically not be indicated as nondiff_argnums.")
raise UnexpectedTracerError(msg)
@partial(lu.transformation_with_aux, use_eq_store=True)
def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
                 *args):
  # Flattening wrapper around a user fwd rule.  In symbolic_zeros mode, *args
  # interleaves (value, perturbed) pairs which are packed into CustomVJPPrimal
  # leaves; otherwise the perturbation bits (every other element) are dropped.
  # Validates the rule's (primal_out, residuals) pair against the primal
  # function's recorded output type when available.  Aux: (out_tree, res_tree).
  if symbolic_zeros:
    args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
  else:
    args = args[::2]
  py_args = tree_unflatten(in_tree, args)
  pair_out = yield py_args, {}
  if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
    msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
           "must produce a pair (list or tuple of length two) where the first "
           "element represents the primal output (equal to those of the "
           f"custom_vjp-decorated function {primal_name}) and the "
           "second element represents residuals (i.e. values stored from the "
           "forward pass for use on the backward pass), but "
           f"instead of a pair the fwd rule {fwd_name} produced {pair_out}.")
    raise TypeError(msg)
  py_primals_out, res = pair_out
  primals_out, out_tree = tree_flatten(py_primals_out)
  res, res_tree = tree_flatten(res)
  primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
  # If the primal function already ran, check out_tree agreement.
  try: out_type_ = maybe_out_type()
  except lu.StoreException: out_type_ = None
  if out_type_ is not None:
    out_tree_, primal_avals_ = out_type_
    # Render both structures with human-readable aval strings for error text.
    ty_tree  = tree_unflatten(out_tree , [a.str_short() for a in primal_avals])
    ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_])
    if out_tree_ != out_tree:
      m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
           "must produce a pair (list or tuple of length two) where the first "
           "element represents the primal output "
           "(equal to the output of the custom_vjp-decorated function "
           f"{primal_name}) and the "
           "second element represents residuals (i.e. values stored from the "
           "forward pass for use on the backward pass), but "
           "instead the fwd rule output's first element had container/pytree "
           "structure:\n"
           f"""    {str(ty_tree ).replace("'", "")}\n"""
           f"while the custom_vjp-decorated function {primal_name} had output "
           "container/pytree structure:\n"
           f"""    {str(ty_tree_).replace("'", "")}.""")
      raise TypeError(m)
    if not all(map(core.typematch, primal_avals, primal_avals_)):
      m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} must "
           "produce a pair (list or tuple of length two) "
           "where the first element represents the primal output "
           "(equal to the output of the custom_vjp-decorated function "
           f"{primal_name}) and the second element represents residuals "
           "(i.e. values stored from the forward pass for use on the "
           "backward pass), but "
           "instead the fwd rule output's first element had shapes/dtypes of:\n"
           f"""    {str(ty_tree ).replace("'", "")}\n"""
           f"while the custom_vjp-decorated function {primal_name} had output "
           "shapes/dtypes of:\n"
           f"""    {str(ty_tree_).replace("'", "")}""")
      raise TypeError(m)
  # Residuals come first in the flat output, then primal outputs.
  yield (*res, *primals_out), (out_tree, res_tree)
@lu.transformation
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
  """Adapts a pytree-based user bwd rule to flat lists of values.

  ``args`` is the flat residuals followed by the flat output cotangents; both
  are unflattened into pytrees, the wrapped bwd rule is called on them, and
  its result is flattened back out against ``in_tree`` (the primal argument
  structure), with symbolic zeros substituted for ``None`` cotangents.

  Args:
    in_tree: pytree structure of the primal function's arguments.
    in_avals: flat list of input avals, used to type the returned cotangents.
    out_trees: thunk returning ``(out_tree, res_tree)`` recorded by the fwd rule.
    *args: flat residuals (``res_tree.num_leaves`` of them) followed by flat
      output cotangents (``out_tree.num_leaves`` of them).
  """
  out_tree, res_tree = out_trees()
  assert len(args) == res_tree.num_leaves + out_tree.num_leaves
  res, cts_out = split_list(args, [res_tree.num_leaves])
  py_res = tree_unflatten(res_tree, res)
  py_cts_out = tree_unflatten(out_tree, cts_out)
  # Call the user's bwd rule on pytree-structured residuals and cotangents.
  py_cts_in = yield (py_res, py_cts_out), {}
  # For each None in py_cts_in, indicating an argument for which the rule
  # produces no cotangent, we replace it with a pytree with the structure of the
  # corresponding subtree of in_tree and with leaves of a non-pytree sentinel
  # object, to be replaced with Nones in the final returned result.
  zero = object()  # non-pytree sentinel to replace Nones in py_cts_in
  dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
  cts_in_flat = []
  # Broadcast each returned cotangent over the leaves of the matching input
  # subtree; returning x keeps tree_map's mapped output well-formed.
  append = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0])) or x
  try:
    if not isinstance(py_cts_in, tuple):
      raise ValueError
    tree_map(append,
             tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
  except ValueError:
    # Either not a tuple, or a structure mismatch raised inside tree_map.
    _, in_tree2 = tree_flatten(py_cts_in)
    msg = ("Custom VJP rule must produce an output with the same container "
           "(pytree) structure as the args tuple of the primal function, "
           "and in particular must produce a tuple of length equal to the "
           "number of arguments to the primal function, but got VJP output "
           "structure {} for primal input structure {}.")
    raise TypeError(msg.format(in_tree2, in_tree)) from None
  # Ignore any None cotangents, and any corresponding to inputs for which the
  # type doesn't equal the tangent type (i.e. float0s)
  # TODO(mattjj): change this to check if tangent type represents 0dim vspace
  yield [Zero(a.at_least_vspace()) if ct is zero or a != a.at_least_vspace()
         else ct for a, ct in zip(in_avals, cts_in_flat)]
class CustomVJPCallPrimitive(core.CallPrimitive):
  """Call primitive carrying a primal function plus its fwd and bwd rules.

  Dispatches to each trace's ``process_custom_vjp_call`` and threads
  environment-trace post-processing through both the primal and fwd functions.
  """
  # Set to custom_vjp_call_jaxpr_p after that primitive is defined below.
  initial_style: core.Primitive

  def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros):
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    # Peel off tracers from traces above top_trace. The fwd variant also
    # collects bwd-rule transformations via its auxiliary output.
    fun, env_trace_todo1 = process_env_traces(
        fun, self, top_trace and top_trace.level, False)
    fwd, env_trace_todo2 = process_env_traces_fwd(
        fwd, top_trace and top_trace.level, out_trees)
    tracers = map(top_trace.full_raise, args)  # type: ignore
    # Late-binding wrapper: reassigning the local `bwd` below changes what
    # this lambda calls, even after it has been handed to the trace.
    bwd_ = lambda *args: bwd(*args)
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
                                             out_trees=out_trees,
                                             symbolic_zeros=symbolic_zeros)
    fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    if fst:
      # The primal function's aux was stored: todos apply directly.
      return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
    else:
      # The fwd function's aux was stored: also transform bwd accordingly.
      env_trace_todo, bwd_transform = env_trace_todo
      bwd = _apply_bwd_transform(bwd_transform, bwd)
      return core.apply_todos(env_trace_todo, map(core.full_lower, outs))

  def impl(self, fun, fwd, bwd, *args, out_trees):
    # Primal evaluation ignores the fwd/bwd rules entirely.
    del fwd, bwd, out_trees
    with core.new_sublevel():
      return fun.call_wrapped(*args)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_custom_vjp_call(out_tracers, params)
# Singleton call primitive bound when tracing custom_vjp applications.
custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call')
@partial(lu.transformation_with_aux, use_eq_store=True)
def process_env_traces_fwd(level: int, out_trees, *args):
  """Post-processes fwd-rule outputs through traces above ``level``.

  Repeatedly finds the highest-level tracer among the outputs, lets its trace
  post-process the fwd call, and records both the todo (to re-apply that
  trace on the way out of ``bind``) and the bwd transformation it requires.
  The pair ``(todos, bwd_transforms)`` is the auxiliary output.
  """
  outs = yield args, {}
  todo = []
  bwd_transforms = []
  while True:
    # Find tracers belonging to traces above `level` (all of them if None).
    tracers = [x for x in outs if isinstance(x, core.Tracer)
               and (level is None or x._trace.level > level)]
    if tracers:
      ans = max(tracers, key=lambda x: x._trace.level)
    else:
      break
    trace = ans._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo, bwd_xform = trace.post_process_custom_vjp_call_fwd(outs, out_trees)
    todo.append(cur_todo)
    bwd_transforms.append(bwd_xform)
  yield outs, (tuple(todo), tuple(bwd_transforms))
def _apply_bwd_transform(todos, bwd):
todos_list = list(todos)
while todos_list:
bwd = todos_list.pop()(bwd)
return bwd
def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
  """Evaluate the primal jaxpr directly, ignoring the fwd/bwd parameters."""
  call = core.jaxpr_as_fun(fun_jaxpr)
  return call(*args)
def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
  """Abstract eval: forward the primal jaxpr's output avals and effects,
  rejecting any effects not allowed under custom_vjp."""
  allowed = effects.custom_derivatives_allowed_effects
  bad_effects = allowed.filter_not_in(fun_jaxpr.effects)
  if bad_effects:
    raise NotImplementedError(
        f'Effects not supported in `custom_vjp`: {bad_effects}')
  return fun_jaxpr.out_avals, fun_jaxpr.effects
# Initial-style (jaxpr-carrying) variant of the custom_vjp call primitive.
custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr')
custom_vjp_call_jaxpr_p.multiple_results = True
custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl)
custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
# Link the eager primitive to its initial-style counterpart.
CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p
# Lowering just inlines the primal function (the impl above).
mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
    _custom_vjp_call_jaxpr_impl, multiple_results=True))
def _custom_vjp_call_jaxpr_jvp(
    primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
  """JVP rule: run the fwd jaxpr and stage a ``custom_lin`` primitive whose
  transpose applies the user's bwd rule to the saved residuals."""
  _, args = split_list(primals, [num_consts])
  consts_dot, args_dot = split_list(tangents, [num_consts])
  # Differentiating through closed-over constants is not supported.
  if any(type(t) is not Zero for t in consts_dot):
    raise ad.CustomVJPException()
  # Per-argument "has nonzero tangent" flags, used to specialize the fwd jaxpr.
  zeros = [type(t) is not Zero for t in args_dot]
  fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)  # consts can be tracers!
  _, res_tree = out_trees()
  # fwd outputs are the residuals followed by the primal outputs.
  res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
  avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
  args_dot = map(ad.instantiate_zeros, args_dot)
  # Cast float0 to zeros with the primal dtype because custom vjp rules don't
  # currently handle float0s
  args_dot = map(ad.replace_float0s, args, args_dot)
  tangents_out = ad.custom_lin_p.bind(
      *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
      out_avals=avals_out, symbolic_zeros=symbolic_zeros)
  # NOTE(review): lax.tie_p presumably ties each tangent output to its primal
  # output — confirm against lax's tie_p definition.
  tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
  # Undo the float0 replacement done on the way in.
  tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
  return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
def _custom_vjp_call_jaxpr_vmap(
    spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
    fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
  """Batching (vmap) rule: batch the primal jaxpr, fwd jaxpr, and bwd rule."""
  # Canonicalize every mapped argument to batch dimension 0.
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]
  in_batched = [d is not not_mapped for d in in_dims]
  _, args_batched = split_list(in_batched, [num_consts])
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name,
      main_type)
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []  # filled in by batched_fwd_jaxpr_thunk if/when it runs
  @pe._memoize
  def batched_fwd_jaxpr_thunk(*zeros):
    # Batch the fwd jaxpr lazily; record its output dims via the side channel.
    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros))  # consts can be tracers
    batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
        fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name,
        main_type)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts
  fwd_args_batched = [0 if b else not_mapped for b in args_batched]
  fwd_out_dims = lambda: out_dims2[0]
  batched_bwd = batching.batch_custom_vjp_bwd(
      bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type,
      spmd_axis_name)
  batched_outs = custom_vjp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
      num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
  # If the fwd thunk never ran, fall back to the primal jaxpr's output dims.
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
# Register the vmap rule twice: the spmd-aware batcher directly, and the plain
# axis batcher with spmd_axis_name fixed to None.
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
    _custom_vjp_call_jaxpr_vmap
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
    _custom_vjp_call_jaxpr_vmap, None)
xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
# Batching or lowering custom_lin is an error (per the handler's name, it
# corresponds to taking a JVP of a custom_vjp function).
batching.primitive_batchers[ad.custom_lin_p] = ad.raise_custom_vjp_error_on_jvp
mlir.register_lowering(ad.custom_lin_p, ad.raise_custom_vjp_error_on_jvp)
def custom_gradient(fun):
  """Convenience function for defining custom VJP rules (aka custom gradients).

  While the canonical way to define custom VJP rules is via ``jax.custom_vjp``,
  the ``custom_gradient`` convenience wrapper follows TensorFlow's
  ``tf.custom_gradient`` API. The difference here is that ``custom_gradient``
  can be used as a decorator on one function that returns both the primal value
  (representing the output of the mathematical function to be differentiated)
  and the VJP (gradient) function. See
  https://www.tensorflow.org/api_docs/python/tf/custom_gradient.

  If the mathematical function to be differentiated has Haskell-like signature
  ``a -> b``, then the Python callable ``fun`` should have the signature
  ``a -> (b, CT b --o CT a)`` where we use ``CT x`` to denote a cotangent type
  for ``x`` and the ``--o`` arrow to denote a linear function. See the example
  below. That is, ``fun`` should return a pair where the first element
  represents the value of the mathematical function to be differentiated and the
  second element is a function to be called on the backward pass of reverse-mode
  automatic differentiation (i.e. the "custom gradient" function).

  The function returned as the second element of the output of ``fun`` can close
  over intermediate values computed when evaluating the function to be
  differentiated. That is, use lexical closure to share work between the forward
  pass and the backward pass of reverse-mode automatic differentiation. However,
  it cannot perform Python control flow which depends on the values of the
  closed-over intermediate values or its cotangent arguments; if the function
  includes such control flow, an error is raised.

  Args:
    fun: a Python callable specifying both the mathematical function to be
      differentiated and its reverse-mode differentiation rule. It should return
      a pair consisting of an output value and a Python callable that represents
      the custom gradient function.

  Returns:
    A Python callable that accepts the same arguments as ``fun`` and returns the
    output value specified by the first element of ``fun``'s output pair.

  For example:

  >>> @jax.custom_gradient
  ... def f(x):
  ...   return x ** 2, lambda g: (g * x,)
  ...
  >>> print(f(3.))
  9.0
  >>> print(jax.grad(f)(3.))
  3.0

  An example with a function on two arguments, so that the VJP function must
  return a tuple of length two:

  >>> @jax.custom_gradient
  ... def f(x, y):
  ...   return x * y, lambda g: (g * y, g * x)
  ...
  >>> print(f(3., 4.))
  12.0
  >>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
  (Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))
  """
  @custom_vjp
  def wrapped_fun(*args, **kwargs):
    ans, _ = fun(*args, **kwargs)  # drop the rule; keep only the primal value
    return ans

  def fwd(*args, **kwargs):
    ans, rule = fun(*args, **kwargs)
    ans_flat, out_tree = tree_flatten((ans,))
    rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
    ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
    # Trace the user's rule to a jaxpr now so that closed-over intermediates
    # become concrete residuals (the jaxpr's consts) rather than a Python
    # closure; this is also what forbids value-dependent control flow.
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
    return ans, Residuals(jaxpr, in_tree(), out_tree, consts)

  def bwd(res, cts):
    jaxpr, in_tree, out_tree, consts = res
    cts_flat, out_tree_ = tree_flatten((cts,))
    # Cotangent structure must match the structure the rule was traced with.
    if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}')
    cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
    cts_out = tree_unflatten(in_tree, cts_out)
    # A single-leaf rule output must still be returned as a length-1 tuple.
    if treedef_is_leaf(in_tree):
      cts_out = (cts_out,)
    return cts_out

  wrapped_fun.defvjp(fwd, bwd)
  return wrapped_fun
@register_pytree_node_class
class Residuals:
def __init__(self, jaxpr, in_tree, out_tree, consts):
self.jaxpr = jaxpr
self.in_tree = in_tree