Skip to content

Commit

Permalink
rename ret to new_op
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Apr 19, 2017
1 parent 9854a55 commit ee818ee
Showing 1 changed file with 24 additions and 27 deletions.
51 changes: 24 additions & 27 deletions edward/util/random_variables.py
Expand Up @@ -264,31 +264,28 @@ def copy(org_instance, dict_swap=None, scope="copied",
else:
new_original_op = None

# Make a copy of the node def.
# As an instance of tensorflow.core.framework.graph_pb2.NodeDef, it
# stores string-based info such as name, device, and type of the op.
# It is unique to every Operation instance.
# Copy the node def.
new_node_def = deepcopy(op.node_def)
new_node_def.name = new_name

# Copy the other inputs needed for initialization.
output_types = op._output_types[:]

# Make a copy of the op def.
# Copy the op def.
# It is unique to every Operation type.
op_def = deepcopy(op.op_def)

ret = tf.Operation(new_node_def,
graph,
[],
output_types,
[],
[],
new_original_op,
op_def)
new_op = tf.Operation(new_node_def,
graph,
[],
output_types,
[],
[],
new_original_op,
op_def)

# advertise op early to break recursions
graph._add_op(ret)
graph._add_op(new_op)

# If it has control inputs, copy them.
elems = []
Expand All @@ -299,15 +296,15 @@ def copy(org_instance, dict_swap=None, scope="copied",

elems.append(elem)

ret._add_control_inputs(elems)
new_op._add_control_inputs(elems)

# If it has inputs, copy them.
for x in op.inputs:
elem = copy(x, dict_swap, scope, True, copy_q)
if not isinstance(elem, tf.Operation):
elem = tf.convert_to_tensor(elem)

ret._add_input(elem)
new_op._add_input(elem)

# Use Graph's private methods to add the op, following
# implementation of `tf.Graph().create_op()`.
Expand All @@ -316,11 +313,11 @@ def copy(org_instance, dict_swap=None, scope="copied",
op_type = new_name

if compute_shapes:
set_shapes_for_outputs(ret)
graph._record_op_seen_by_control_dependencies(ret)
set_shapes_for_outputs(new_op)
graph._record_op_seen_by_control_dependencies(new_op)

if compute_device:
graph._apply_device_functions(ret)
graph._apply_device_functions(new_op)

if graph._colocation_stack:
all_colocation_groups = []
Expand All @@ -330,17 +327,17 @@ def copy(org_instance, dict_swap=None, scope="copied",
# Make this device match the device of the colocated op, to
# provide consistency between the device and the colocation
# property.
if ret.device and ret.device != colocation_op.device:
if new_op.device and new_op.device != colocation_op.device:
logging.warning("Tried to colocate %s with an op %s that had "
"a different device: %s vs %s. "
"Ignoring colocation property.",
name, colocation_op.name, ret.device,
name, colocation_op.name, new_op.device,
colocation_op.device)
else:
ret._set_device(colocation_op.device)
new_op._set_device(colocation_op.device)

all_colocation_groups = sorted(set(all_colocation_groups))
ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
new_op.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))

# Sets "container" attribute if
Expand All @@ -351,12 +348,12 @@ def copy(org_instance, dict_swap=None, scope="copied",
if (graph._container and
op_type in graph._registered_ops and
graph._registered_ops[op_type].is_stateful and
"container" in ret.node_def.attr and
not ret.node_def.attr["container"].s):
ret.node_def.attr["container"].CopyFrom(
"container" in new_op.node_def.attr and
not new_op.node_def.attr["container"].s):
new_op.node_def.attr["container"].CopyFrom(
attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container)))

return ret
return new_op
else:
raise TypeError("Could not copy instance: " + str(org_instance))

Expand Down

0 comments on commit ee818ee

Please sign in to comment.