/
derived.py
154 lines (135 loc) · 5.99 KB
/
derived.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright 2022 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Derived value from other hyper primitives."""
import abc
import copy
from typing import Any, Callable, List, Optional, Tuple, Union
from pyglove.core import object_utils
from pyglove.core import symbolic
from pyglove.core import typing as pg_typing
@symbolic.members([
    ('reference_paths', pg_typing.List(pg_typing.Object(object_utils.KeyPath)),
     ('Paths of referenced values, which are relative paths searched from '
      'current node to root.'))
])
class DerivedValue(symbolic.Object, pg_typing.CustomTyping):
  """Base class of value that references to other values in object tree.

  A derived value holds a list of relative `reference_paths`. Calling it
  resolves each path against the enclosing symbolic tree, queries the value
  found there, and passes those values (in `reference_paths` order) to
  `derive`, which subclasses implement.
  """

  @abc.abstractmethod
  def derive(self, *args: Any) -> Any:
    """Derive the value from referenced values.

    Args:
      *args: The values located at `reference_paths`, in the same order.

    Returns:
      The derived value.
    """

  def resolve(
      self,
      reference_path_or_paths: Optional[
          Union[str, object_utils.KeyPath,
                List[Union[str, object_utils.KeyPath]]]] = None
  ) -> Union[Tuple[symbolic.Symbolic, object_utils.KeyPath],
             List[Tuple[symbolic.Symbolic, object_utils.KeyPath]]]:
    """Resolve reference paths based on the location of this node.

    For each reference path, the search starts at this node's `sym_parent`
    and walks upward through `sym_parent` links until reaching an ancestor in
    which the path exists.

    Args:
      reference_path_or_paths: (Optional) a string or KeyPath as a reference
        path or a list of strings or KeyPath objects as a list of
        reference paths. Strings are parsed with `KeyPath.parse`.
        If this argument is not provided, prebound reference paths of this
        object will be used.

    Returns:
      A single (resolved parent, resolved full path) tuple when a single
      string or KeyPath is given; otherwise a list of such tuples, one per
      reference path.

    Raises:
      ValueError: If `reference_path_or_paths` (or one of its list items) has
        an unsupported type, or if no ancestor contains a reference path.
    """
    if reference_path_or_paths is None:
      single_input = False
      reference_paths = self.reference_paths
    elif isinstance(reference_path_or_paths, (str, object_utils.KeyPath)):
      single_input = True
      reference_paths = [self._to_key_path(reference_path_or_paths)]
    elif isinstance(reference_path_or_paths, list):
      single_input = False
      reference_paths = [
          self._to_key_path(path) for path in reference_path_or_paths]
    else:
      raise self._invalid_reference_paths_error()

    resolved_paths = []
    for reference_path in reference_paths:
      parent = self.sym_parent
      # `getattr` with a default ends the search at an ancestor that has no
      # `sym_parent` attribute instead of raising AttributeError.
      while parent is not None and not reference_path.exists(parent):
        parent = getattr(parent, 'sym_parent', None)
      if parent is None:
        raise ValueError(
            f'Cannot resolve \'{reference_path}\': parent not found.')
      resolved_paths.append((parent, parent.sym_path + reference_path))
    return resolved_paths[0] if single_input else resolved_paths

  @staticmethod
  def _invalid_reference_paths_error() -> ValueError:
    """Builds the error raised for unsupported reference path arguments."""
    return ValueError('Argument \'reference_path_or_paths\' must be None, '
                      'a string, KeyPath object, a list of strings, or a '
                      'list of KeyPath objects.')

  @staticmethod
  def _to_key_path(path: Any) -> object_utils.KeyPath:
    """Parses a string into a KeyPath; passes KeyPath objects through.

    Raises:
      ValueError: If `path` is neither a string nor a KeyPath.
    """
    if isinstance(path, str):
      return object_utils.KeyPath.parse(path)
    if isinstance(path, object_utils.KeyPath):
      return path
    raise DerivedValue._invalid_reference_paths_error()

  def __call__(self):
    """Generate value by deriving values from reference paths.

    Returns:
      The result of `derive` applied to the values queried at
      `reference_paths`.

    Raises:
      ValueError: If a reference path cannot be resolved, or if a referenced
        value is (or contains) a `DerivedValue`.
    """
    referenced_values = []
    for reference_path, (parent, _) in zip(
        self.reference_paths, self.resolve()):
      referenced_value = reference_path.query(parent)
      # Make sure referenced value does not have referenced value.
      # NOTE(daiyip): We can support dependencies between derived values
      # in future if needed.
      if not object_utils.traverse(
          referenced_value, self._contains_not_derived_value):
        raise ValueError(
            f'Derived value (path={referenced_value.sym_path}) should not '
            f'reference derived values. '
            f'Encountered: {referenced_value}, '
            f'Referenced at path {self.sym_path}.')
      referenced_values.append(referenced_value)
    return self.derive(*referenced_values)

  def _contains_not_derived_value(
      self, path: object_utils.KeyPath, value: Any) -> bool:
    """Returns True if `value` neither is nor contains a `DerivedValue`.

    Used as a visitor for `object_utils.traverse`. The children of a
    `symbolic.Object` are checked explicitly here through `sym_items`, with
    each child's key appended to `path`. Other container types are presumably
    descended into by `traverse` itself (TODO confirm).

    Args:
      path: Key path of `value` within the traversed tree.
      value: The value being visited.

    Returns:
      False if `value` is a `DerivedValue` or a `symbolic.Object` with a
      `DerivedValue` among its (recursively traversed) children; True
      otherwise.
    """
    if isinstance(value, DerivedValue):
      return False
    elif isinstance(value, symbolic.Object):
      for k, v in value.sym_items():
        if not object_utils.traverse(
            v, self._contains_not_derived_value,
            root_path=object_utils.KeyPath(k, path)):
          return False
    return True
class ValueReference(DerivedValue):
  """A derived value that evaluates to a copy of exactly one other value.

  Exactly one entry in `reference_paths` is required. Calling the object
  returns a shallow copy (`copy.copy`) of the value found at that path.
  """

  def _on_bound(self):
    """Checks that exactly one reference path was bound.

    Hook from the symbolic base class (presumably invoked once fields are
    bound — TODO confirm). The parent hook runs first.

    Raises:
      ValueError: If `reference_paths` does not have exactly one item.
    """
    super()._on_bound()
    path_count = len(self.reference_paths)
    if path_count == 1:
      return
    raise ValueError(
        f'Argument \'reference_paths\' should have exact 1 '
        f'item. Encountered: {self.reference_paths}')

  def derive(self, referenced_value: Any) -> Any:
    """Returns a shallow copy of the single referenced value.

    Args:
      referenced_value: The value located at the only reference path.

    Returns:
      `copy.copy(referenced_value)`.
    """
    duplicate = copy.copy(referenced_value)
    return duplicate

  def custom_apply(
      self,
      path: object_utils.KeyPath,
      value_spec: pg_typing.ValueSpec,
      allow_partial: bool,
      child_transform: Optional[Callable[
          [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None
  ) -> Tuple[bool, 'DerivedValue']:
    """Implement pg_typing.CustomTyping interface.

    All arguments are ignored. The meaning of the returned flag is defined by
    `pg_typing.CustomTyping`.

    Returns:
      The tuple `(False, self)`.
    """
    # TODO(daiyip): perform possible static analysis on referenced paths.
    del path, value_spec, allow_partial, child_transform  # Unused.
    result = (False, self)
    return result
def reference(reference_path: str) -> ValueReference:
  """Builds a `ValueReference` pointing at a single path.

  Args:
    reference_path: Relative path of the value to reference. It is resolved
      by searching from the reference's parent toward the root of the tree.

  Returns:
    A `ValueReference` whose `reference_paths` holds only `reference_path`.
  """
  paths = [reference_path]
  return ValueReference(reference_paths=paths)