# base.py
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base interfaces and datatypes."""
from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
import chex
# pylint:disable=no-value-for-parameter
# Error text for a transformation that needs the current parameter values
# when `update` was called without `params`.
NO_PARAMS_MSG = (
    'You are using a transformation that requires the current value of'
    ' parameters, but you are not passing `params` when calling `update`.'
)
# `PyTree` is deliberately left untyped.
PyTree = Any
# A shape is a sequence of integers.
Shape = Sequence[int]
# Transformation states are namedtuples, and may have no fields at all.
OptState = NamedTuple
# Parameters are arbitrary nests of `jnp.ndarrays`.
Params = chex.ArrayTree
# Gradient updates have the same type as the parameters.
Updates = Params

# Signature of an `init` function: it receives the parameters and returns
# either a single state or a sequence of states.
TransformInitFn = Callable[[Params], Union[OptState, Sequence[OptState]]]

# Signature of an `update` function: it receives the updates, the state and
# (optionally) the parameters, and returns an (updates, state) pair.
TransformUpdateFn = Callable[
    [Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]

# A schedule maps one numeric value to another numeric value.
Schedule = Callable[[chex.Numeric], chex.Numeric]
class GradientTransformation(NamedTuple):
  """A pair of functions, `(init, update)`, making up an Optax transformation.

  Attributes:
    init: a callable matching `TransformInitFn`.
    update: a callable matching `TransformUpdateFn`.
  """

  init: TransformInitFn
  update: TransformUpdateFn
class EmptyState(OptState):
  """State with no fields, used by the simplest stateless transformations."""
def identity() -> GradientTransformation:
  """Build a stateless transformation that passes gradients through unchanged.

  Returns:
    A `GradientTransformation` whose `init` returns an `EmptyState` and whose
    `update` hands back the `updates` and `state` it was given.
  """

  def init_fn(params):
    del params  # The parameters are not used to build the state.
    return EmptyState()

  def update_fn(updates, state, params=None):
    del params  # Accepted but never read.
    return updates, state

  return GradientTransformation(init=init_fn, update=update_fn)