diff --git a/edward/util/random_variables.py b/edward/util/random_variables.py index 1a388ffdc..4a81e5fde 100644 --- a/edward/util/random_variables.py +++ b/edward/util/random_variables.py @@ -264,31 +264,28 @@ def copy(org_instance, dict_swap=None, scope="copied", else: new_original_op = None - # Make a copy of the node def. - # As an instance of tensorflow.core.framework.graph_pb2.NodeDef, it - # stores string-based info such as name, device, and type of the op. - # It is unique to every Operation instance. + # Copy the node def. new_node_def = deepcopy(op.node_def) new_node_def.name = new_name # Copy the other inputs needed for initialization. output_types = op._output_types[:] - # Make a copy of the op def. + # Copy the op def. # It is unique to every Operation type. op_def = deepcopy(op.op_def) - ret = tf.Operation(new_node_def, - graph, - [], - output_types, - [], - [], - new_original_op, - op_def) + new_op = tf.Operation(new_node_def, + graph, + [], + output_types, + [], + [], + new_original_op, + op_def) # advertise op early to break recursions - graph._add_op(ret) + graph._add_op(new_op) # If it has control inputs, copy them. elems = [] @@ -299,7 +296,7 @@ def copy(org_instance, dict_swap=None, scope="copied", elems.append(elem) - ret._add_control_inputs(elems) + new_op._add_control_inputs(elems) # If it has inputs, copy them. for x in op.inputs: @@ -307,7 +304,7 @@ def copy(org_instance, dict_swap=None, scope="copied", if not isinstance(elem, tf.Operation): elem = tf.convert_to_tensor(elem) - ret._add_input(elem) + new_op._add_input(elem) # Use Graph's private methods to add the op, following # implementation of `tf.Graph().create_op()`. @@ -316,11 +313,11 @@ def copy(org_instance, dict_swap=None, scope="copied", op_type = new_name if compute_shapes: - set_shapes_for_outputs(ret) - graph._record_op_seen_by_control_dependencies(ret) + set_shapes_for_outputs(new_op) + graph._record_op_seen_by_control_dependencies(new_op) if compute_device: - graph._apply_device_functions(ret) + graph._apply_device_functions(new_op) if graph._colocation_stack: all_colocation_groups = [] @@ -330,17 +327,17 @@ def copy(org_instance, dict_swap=None, scope="copied", # Make this device match the device of the colocated op, to # provide consistency between the device and the colocation # property. - if ret.device and ret.device != colocation_op.device: + if new_op.device and new_op.device != colocation_op.device: logging.warning("Tried to colocate %s with an op %s that had " "a different device: %s vs %s. " "Ignoring colocation property.", - name, colocation_op.name, ret.device, + name, colocation_op.name, new_op.device, colocation_op.device) else: - ret._set_device(colocation_op.device) + new_op._set_device(colocation_op.device) all_colocation_groups = sorted(set(all_colocation_groups)) - ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue( + new_op.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) # Sets "container" attribute if @@ -351,12 +348,12 @@ def copy(org_instance, dict_swap=None, scope="copied", if (graph._container and op_type in graph._registered_ops and graph._registered_ops[op_type].is_stateful and - "container" in ret.node_def.attr and - not ret.node_def.attr["container"].s): - ret.node_def.attr["container"].CopyFrom( + "container" in new_op.node_def.attr and + not new_op.node_def.attr["container"].s): + new_op.node_def.attr["container"].CopyFrom( attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container))) - return ret + return new_op else: raise TypeError("Could not copy instance: " + str(org_instance))