-
Notifications
You must be signed in to change notification settings - Fork 34
/
one_of.py
97 lines (83 loc) · 4.41 KB
/
one_of.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright 2019 The FastEstimator Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, Dict, List, Optional, Set, TypeVar, Union
import numpy as np
import tensorflow as tf
import torch
from fastestimator.backend._cast import cast
from fastestimator.op.tensorop.tensorop import TensorOp
from fastestimator.util.traceability_util import traceable
Tensor = TypeVar('Tensor', tf.Tensor, torch.Tensor)
Model = TypeVar('Model', tf.keras.Model, torch.nn.Module)
@traceable()
class OneOf(TensorOp):
    """Perform one of several possible TensorOps.

    On each forward pass exactly one of the wrapped ops is selected at random (according to `probs`) and applied
    to the data. All wrapped ops must therefore share an identical I/O signature so that they are interchangeable.

    Args:
        *tensor_ops: Ops to choose between with a specified (or uniform) probability.
        probs: List of probabilities, must sum to 1. When None, the probabilities will be equally distributed.

    Raises:
        AssertionError: If no ops are provided, if the ops disagree on inputs/outputs/mode/ds_id, or if `probs`
            has the wrong length or does not sum to 1.
    """
    def __init__(self, *tensor_ops: TensorOp, probs: Optional[List[float]] = None) -> None:
        assert tensor_ops, "OneOf requires at least one TensorOp"
        # Adopt the I/O signature of the first op; every other op must match it exactly.
        inputs = tensor_ops[0].inputs
        outputs = tensor_ops[0].outputs
        mode = tensor_ops[0].mode
        ds_id = tensor_ops[0].ds_id
        super().__init__(inputs=inputs, outputs=outputs, mode=mode, ds_id=ds_id)
        self.in_list = tensor_ops[0].in_list
        self.out_list = tensor_ops[0].out_list
        for op in tensor_ops[1:]:
            assert inputs == op.inputs, "All ops within a OneOf must share the same inputs"
            assert self.in_list == op.in_list, "All ops within OneOf must share the same input configuration"
            assert outputs == op.outputs, "All ops within a OneOf must share the same outputs"
            assert self.out_list == op.out_list, "All ops within OneOf must share the same output configuration"
            assert mode == op.mode, "All ops within a OneOf must share the same mode"
            assert ds_id == op.ds_id, "All ops within a OneOf must share the same ds_id"
        if probs:
            assert len(tensor_ops) == len(probs), "The number of probabilities do not match with number of Operators"
            assert abs(sum(probs) - 1) < 1e-8, "Probabilities must sum to 1"
        else:
            # No probabilities given: choose uniformly among the wrapped ops.
            probs = [1 / len(tensor_ops) for _ in tensor_ops]
        self.ops = tensor_ops
        self.probs = probs
        # Set by build(); decides whether forward() samples via tf or numpy/torch.
        self.framework = None

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        """Record the backend framework and propagate the build call to every wrapped op.

        Args:
            framework: Either "tf" or "torch".
            device: The torch device onto which ops should be built (ignored by tf ops).
        """
        assert framework in {"tf", "torch"}, "unrecognized framework: {}".format(framework)
        self.framework = framework
        for op in self.ops:
            op.build(framework, device)

    def get_fe_loss_keys(self) -> Set[str]:
        """Return the union of the loss keys reported by all wrapped ops."""
        return set.union(*[op.get_fe_loss_keys() for op in self.ops])

    def get_fe_models(self) -> Set[Model]:
        """Return the union of the models reported by all wrapped ops."""
        return set.union(*[op.get_fe_models() for op in self.ops])

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        """Forward the retain-graph setting to every wrapped op.

        Args:
            retain: The value to set, or None to merely query the current state.

        Returns:
            The first non-falsy response from the wrapped ops (None if every op responds with None/False).
        """
        resp = None
        for op in self.ops:
            resp = resp or op.fe_retain_graph(retain)
        return resp

    def __getstate__(self) -> Dict[str, List[Dict[Any, Any]]]:
        """Serialize by delegating to each wrapped op's __getstate__ (empty dict when an op defines none)."""
        return {'ops': [elem.__getstate__() if hasattr(elem, '__getstate__') else {} for elem in self.ops]}

    def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> Union[Tensor, List[Tensor]]:
        """Execute a randomly selected op from the list of `tensor_ops`.

        Args:
            data: The information to be passed to one of the wrapped operators.
            state: Information about the current execution context, for example {"mode": "train"}.

        Returns:
            The `data` after application of one of the available TensorOps.
        """
        if self.framework == 'tf':
            # Sample an index in-graph so the choice works under tf.function tracing; the log of probs is the
            # unnormalized logits input expected by tf.random.categorical.
            idx = cast(tf.random.categorical(tf.math.log([self.probs]), 1), dtype='int32')[0, 0]
            # op=op binds each op eagerly, avoiding the late-binding-closure pitfall inside the branch lambdas.
            results = tf.switch_case(idx, [lambda op=op: op.forward(data, state) for op in self.ops])
        else:
            # torch path runs eagerly, so plain numpy sampling over the op objects suffices.
            results = np.random.choice(self.ops, p=self.probs).forward(data, state)
        return results