-
-
Notifications
You must be signed in to change notification settings - Fork 422
/
tensorflow.py
45 lines (33 loc) · 1.13 KB
/
tensorflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from typing import cast, Any
import eagerpy as ep
from ..types import BoundsInput, Preprocessing
from .base import ModelWithPreprocessing
def get_device(device: Any) -> Any:
    """Resolve ``device`` into a TensorFlow device context.

    Args:
        device: ``None`` to choose a default device (``"/GPU:0"`` when a GPU
            is visible to TensorFlow, otherwise ``"/CPU:0"``), a device name
            string passed to ``tf.device``, or any other object, which is
            returned unchanged (presumably an existing ``tf.device`` context —
            confirm against callers).

    Returns:
        ``tf.device(...)`` for ``None`` or string input; otherwise ``device``
        itself.
    """
    import tensorflow as tf

    if device is None:
        # tf.test.is_gpu_available() is deprecated; TensorFlow documents
        # tf.config.list_physical_devices("GPU") as its replacement.
        has_gpu = bool(tf.config.list_physical_devices("GPU"))
        device = tf.device("/GPU:0" if has_gpu else "/CPU:0")
    if isinstance(device, str):
        device = tf.device(device)
    return device
class TensorFlowModel(ModelWithPreprocessing):
    """Wrap a TensorFlow model (eager mode only) as a ModelWithPreprocessing.

    Args:
        model: The TensorFlow model; forwarded unchanged to the base class.
        bounds: Input bounds; forwarded to the base class.
        device: ``None``, a device name string, or a device context. It is
            resolved through ``get_device`` and stored as ``self.device``.
        preprocessing: Optional preprocessing spec; forwarded to the base
            class.

    Raises:
        ValueError: If TensorFlow is not executing eagerly.
    """

    def __init__(
        self,
        model: Any,
        bounds: BoundsInput,
        device: Any = None,
        preprocessing: Preprocessing = None,
    ):
        import tensorflow as tf

        if not tf.executing_eagerly():
            raise ValueError("TensorFlowModel requires TensorFlow Eager Mode")  # pragma: no cover

        resolved_device = get_device(device)
        # Build the empty dummy tensor inside the resolved device context; the
        # base class receives it alongside model/bounds (presumably to select
        # the eagerpy backend — confirm in ModelWithPreprocessing).
        with resolved_device:
            placeholder = ep.tensorflow.zeros(0)
        super().__init__(model, bounds, placeholder, preprocessing=preprocessing)
        self.device = resolved_device

    @property
    def data_format(self) -> str:
        """Return the value of ``tf.keras.backend.image_data_format()``."""
        import tensorflow as tf

        image_format = tf.keras.backend.image_data_format()
        return cast(str, image_format)