Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename references to "K" as "backend" for consistency. #15242

Merged
merged 1 commit into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 2 additions & 3 deletions keras/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,9 +1661,8 @@ def zeros_like(x, dtype=None, name=None):
Example:
```python
from tensorflow.keras import backend as K
kvar = K.variable(np.random.random((2,3)))
kvar_zeros = K.zeros_like(kvar)
kvar = tf.keras.backend.variable(np.random.random((2,3)))
kvar_zeros = tf.keras.backend.zeros_like(kvar)
K.eval(kvar_zeros)
# array([[ 0., 0., 0.], [ 0., 0., 0.]], dtype=float32)
```
Expand Down
26 changes: 13 additions & 13 deletions keras/callbacks_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import os
import numpy as np
from keras import backend as K
from keras import backend
from keras import callbacks
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
Expand Down Expand Up @@ -160,9 +160,10 @@ def _init_writer(self, model):
self.writer = tf.summary.create_file_writer(self.log_dir)
if not model.run_eagerly and self.write_graph:
with self.writer.as_default():
tf.summary.graph(K.get_graph())
tf.summary.graph(backend.get_graph())
elif self.write_graph:
self.writer = tf.compat.v1.summary.FileWriter(self.log_dir, K.get_graph())
self.writer = tf.compat.v1.summary.FileWriter(
self.log_dir, backend.get_graph())
else:
self.writer = tf.compat.v1.summary.FileWriter(self.log_dir)

Expand All @@ -176,27 +177,26 @@ def _make_histogram_ops(self, model):
tf.compat.v1.summary.histogram(mapped_weight_name, weight)
if self.write_images:
w_img = tf.compat.v1.squeeze(weight)
shape = K.int_shape(w_img)
shape = tuple(w_img.shape)
if len(shape) == 2: # dense layer kernel case
if shape[0] > shape[1]:
w_img = tf.compat.v1.transpose(w_img)
shape = K.int_shape(w_img)
shape = tuple(w_img.shape)
w_img = tf.reshape(w_img, [1, shape[0], shape[1], 1])
elif len(shape) == 3: # convnet case
if K.image_data_format() == 'channels_last':
if backend.image_data_format() == 'channels_last':
# switch to channels_first to display
# every kernel as a separate image
w_img = tf.compat.v1.transpose(w_img, perm=[2, 0, 1])
shape = K.int_shape(w_img)
w_img = tf.reshape(w_img,
[shape[0], shape[1], shape[2], 1])
shape = tuple(w_img.shape)
w_img = tf.reshape(w_img, [shape[0], shape[1], shape[2], 1])
elif len(shape) == 1: # bias case
w_img = tf.reshape(w_img, [1, shape[0], 1, 1])
else:
# not possible to handle 3D convnets etc.
continue

shape = K.int_shape(w_img)
shape = tuple(w_img.shape)
assert len(shape) == 4 and shape[-1] in [1, 3, 4]
tf.compat.v1.summary.image(mapped_weight_name, w_img)

Expand Down Expand Up @@ -421,7 +421,7 @@ def on_epoch_end(self, epoch, logs=None):
embeddings_data = self.embeddings_data
n_samples = embeddings_data[0].shape[0]
i = 0
sess = K.get_session()
sess = backend.get_session()
while i < n_samples:
step = min(self.batch_size, n_samples - i)
batch = slice(i, i + step)
Expand All @@ -436,8 +436,8 @@ def on_epoch_end(self, epoch, logs=None):

feed_dict.update({self.batch_id: i, self.step: step})

if not isinstance(K.learning_phase(), int):
feed_dict[K.learning_phase()] = False
if not isinstance(backend.learning_phase(), int):
feed_dict[backend.learning_phase()] = False

sess.run(self.assign_embeddings, feed_dict=feed_dict)
self.saver.save(sess,
Expand Down
4 changes: 2 additions & 2 deletions keras/layers/core/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# pylint: disable=g-classes-have-attributes,g-direct-tensorflow-import

from keras import activations
from keras import backend as K
from keras import backend
from keras import constraints
from keras import initializers
from keras import regularizers
Expand Down Expand Up @@ -128,7 +128,7 @@ def __init__(self,
self.supports_masking = True

def build(self, input_shape):
dtype = tf.as_dtype(self.dtype or K.floatx())
dtype = tf.as_dtype(self.dtype or backend.floatx())
if not (dtype.is_floating or dtype.is_complex):
raise TypeError('A Dense layer can only be built with a floating-point '
f'dtype. Received: dtype={dtype}')
Expand Down
4 changes: 2 additions & 2 deletions keras/layers/core/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Contains the dropout layer."""
# pylint: disable=g-classes-have-attributes,g-direct-tensorflow-import

from keras import backend as K
from keras import backend
from keras.engine.base_layer import Layer
from keras.utils import control_flow_util
import tensorflow.compat.v2 as tf
Expand Down Expand Up @@ -110,7 +110,7 @@ def _get_noise_shape(self, inputs):

def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
training = backend.learning_phase()

def dropped_inputs():
return tf.nn.dropout(
Expand Down
5 changes: 2 additions & 3 deletions keras/layers/core/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Contains the Masking layer."""
# pylint: disable=g-classes-have-attributes,g-direct-tensorflow-import

from keras import backend as K
from keras.engine.base_layer import Layer
import tensorflow.compat.v2 as tf
from tensorflow.python.util.tf_export import keras_export
Expand Down Expand Up @@ -69,10 +68,10 @@ def __init__(self, mask_value=0., **kwargs):
self._compute_output_and_mask_jointly = True

def compute_mask(self, inputs, mask=None):
return K.any(tf.not_equal(inputs, self.mask_value), axis=-1)
return tf.reduce_any(tf.not_equal(inputs, self.mask_value), axis=-1)

def call(self, inputs):
boolean_mask = K.any(
boolean_mask = tf.reduce_any(
tf.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
outputs = inputs * tf.cast(boolean_mask, inputs.dtype)
# Compute the mask and outputs simultaneously.
Expand Down
4 changes: 2 additions & 2 deletions keras/layers/core/repeat_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Contains the RepeatVector layer."""
# pylint: disable=g-classes-have-attributes,g-direct-tensorflow-import

from keras import backend as K
from keras import backend
from keras.engine.base_layer import Layer
from keras.engine.input_spec import InputSpec
import tensorflow.compat.v2 as tf
Expand Down Expand Up @@ -57,7 +57,7 @@ def compute_output_shape(self, input_shape):
return tf.TensorShape([input_shape[0], self.n, input_shape[1]])

def call(self, inputs):
return K.repeat(inputs, self.n)
return backend.repeat(inputs, self.n)

def get_config(self):
config = {'n': self.n}
Expand Down
8 changes: 4 additions & 4 deletions keras/layers/core/tf_op_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order

from keras import backend as K
from keras import backend
from keras.engine import keras_tensor
from keras.engine.base_layer import Layer

Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self, cls_ref, method_name, **kwargs):
get_canonical_name_for_symbol(
self.cls_ref, api_name='keras', add_prefix_to_v1_names=True))
if 'name' not in kwargs:
kwargs['name'] = K.unique_object_name(
kwargs['name'] = backend.unique_object_name(
'tf.' + self.cls_symbol + '.' + self.method_name,
zero_based=True,
avoid_observed_names=True)
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(self, attr_name, **kwargs):
self.attr_name = attr_name

if 'name' not in kwargs:
kwargs['name'] = K.unique_object_name(
kwargs['name'] = backend.unique_object_name(
'input.' + self.attr_name, zero_based=True, avoid_observed_names=True)
kwargs['autocast'] = False

Expand Down Expand Up @@ -217,7 +217,7 @@ def __init__(self, function, **kwargs):
name = 'tf.' + self.symbol
else:
name = self.function.__name__
kwargs['name'] = K.unique_object_name(
kwargs['name'] = backend.unique_object_name(
name, zero_based=True, avoid_observed_names=True)
kwargs['autocast'] = False

Expand Down
4 changes: 2 additions & 2 deletions keras/saving/saved_model/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import os

from keras import backend as K
from keras import backend
from keras.protobuf import saved_metadata_pb2
from keras.protobuf import versions_pb2
from keras.saving import saving_utils
Expand Down Expand Up @@ -85,7 +85,7 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
# already-set learning phase placeholder.
# This is needed for compatibility reasons until learning phase setting
# is removed from the public apis.
with K.deprecated_internal_learning_phase_scope(0):
with backend.deprecated_internal_learning_phase_scope(0):
with utils.keras_option_scope(save_traces):
saved_nodes, node_paths = save_lib.save_and_return_nodes(
model, filepath, signatures, options)
Expand Down
6 changes: 3 additions & 3 deletions keras/saving/saved_model/save_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import threading
import weakref

from keras import backend as K
from keras import backend
from keras.engine import base_layer_utils
from keras.engine import input_spec
from keras.mixed_precision import autocast_variable
Expand Down Expand Up @@ -355,7 +355,7 @@ def tracing_scope():
while _thread_local_data.trace_queue:
fn, args, kwargs, training = _thread_local_data.trace_queue.pop()
if training is not None:
with K.deprecated_internal_learning_phase_scope(training):
with backend.deprecated_internal_learning_phase_scope(training):
fn.get_concrete_function(*args, **kwargs)
else:
fn.get_concrete_function(*args, **kwargs)
Expand Down Expand Up @@ -694,7 +694,7 @@ def _wrap_activity_regularizer(layer):
layer._activity_regularizer,
'{}_activity_regularizer'.format(layer.name),
input_signature=[
tf.TensorSpec(None, layer._compute_dtype or K.floatx())
tf.TensorSpec(None, layer._compute_dtype or backend.floatx())
])
# pylint: enable=protected-access

Expand Down
9 changes: 5 additions & 4 deletions keras/saving/saved_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import itertools
import threading
import types
from keras import backend as K
from keras import backend
from keras.engine import base_layer_utils
from keras.utils import control_flow_util
from keras.utils import tf_contextlib
Expand All @@ -43,7 +43,7 @@ def use_wrapped_call(layer, call_fn, default_training_value=None,
call_fn: tf.function that takes layer inputs (and possibly a training arg),
and returns a tuple of (outputs, list of losses).
default_training_value: Default value of the training kwarg. If `None`, the
default is `K.learning_phase()`.
default is `tf.keras.backend.learning_phase()`.
return_method: Whether to return a method bound to the layer.
Returns:
Expand Down Expand Up @@ -138,7 +138,8 @@ def maybe_add_training_arg(
wrapped_call: Wrapped call function.
expects_training_arg: Whether to include 'training' argument.
default_training_value: Default value of the training kwarg to include in
the arg spec. If `None`, the default is `K.learning_phase()`.
the arg spec. If `None`, the default is
`tf.keras.backend.learning_phase()`.
Returns:
Tuple of (
Expand All @@ -152,7 +153,7 @@ def wrap_with_training_arg(*args, **kwargs):
training_arg_index = get_training_arg_index(original_call)
training = get_training_arg(training_arg_index, args, kwargs)
if training is None:
training = default_training_value or K.learning_phase()
training = default_training_value or backend.learning_phase()

args = list(args)
kwargs = kwargs.copy()
Expand Down
4 changes: 2 additions & 2 deletions keras/saving/saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import copy
import os
from keras import backend as K
from keras import backend
from keras import losses
from keras import optimizer_v1
from keras import optimizers
Expand Down Expand Up @@ -149,7 +149,7 @@ def model_metadata(model, include_optimizer=True, require_config=True):

metadata = dict(
keras_version=str(keras_version),
backend=K.backend(),
backend=backend.backend(),
model_config=model_config)
if model.optimizer and include_optimizer:
if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
Expand Down
7 changes: 4 additions & 3 deletions keras/utils/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import copy
import numpy as np
from tensorflow.python.framework import ops
from keras import backend as K
from keras import backend
from keras.engine import keras_tensor
from keras.utils import object_identity
from keras.utils import tf_contextlib
Expand Down Expand Up @@ -438,7 +438,7 @@ def maybe_init_scope(layer):
def graph_context_for_symbolic_tensors(*args, **kwargs):
"""Returns graph context manager if any of the inputs is a symbolic tensor."""
if any(is_symbolic_tensor(v) for v in list(args) + list(kwargs.values())):
with K.get_graph().as_default():
with backend.get_graph().as_default():
yield
else:
yield
Expand All @@ -450,7 +450,8 @@ def dataset_is_infinite(dataset):
return tf.equal(
tf.data.experimental.cardinality(dataset), tf.data.experimental.INFINITE_CARDINALITY)
else:
dataset_size = K.get_session().run(tf.data.experimental.cardinality(dataset))
dataset_size = backend.get_session().run(
tf.data.experimental.cardinality(dataset))
return dataset_size == tf.data.experimental.INFINITE_CARDINALITY


Expand Down