In [2]:
from __future__ import print_function
import tensorflow as tf

In [3]:
tf.reset_default_graph()

pred = tf.placeholder(tf.bool, name="pred")
x = tf.cond(pred, 
            lambda: tf.constant(3, dtype=tf.float32), 
            lambda: tf.constant(1, dtype=tf.float32))

In [4]:
with tf.Session() as sess:
    print(sess.run(x, feed_dict={pred: True}))
    print(sess.run(x, feed_dict={pred: False}))

3.0
1.0


In [8]:
print(tf.__version__)

1.4.0


## The Tricky Things about `tf.cond`

- Based on the version printed above (should be 1.4.0)
- [source code](https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/control_flow_ops.py#L1747)
- Make sure you read lines 10 to 16 of the `tf.cond` docstring printed below

In [9]:
for i, l in enumerate(tf.cond.__doc__.split("\n"), 1):
    print("{:>2}| {}".format(i, l))

 1| Return `true_fn()` if the predicate `pred` is true else `false_fn()`. (deprecated arguments)
 2| 
 3| SOME ARGUMENTS ARE DEPRECATED. They will be removed in a future version.
 4| Instructions for updating:
 5| fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.
 6| 
 7| `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
 8| `false_fn` must have the same non-zero number and type of outputs.
 9| 
10| Note that the conditional execution applies only to the operations defined in
11| `true_fn` and `false_fn`. Consider the following simple program:
12| 
13| ```python
14| z = tf.multiply(a, b)
15| result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
16| ```
17| 
18| If `x < y`, the `tf.add` operation will be executed and `tf.square`
19| operation will not be executed. Since `z` is needed for at least one
20| branch of the `cond`, the `tf.multiply` operation is always executed,
21| unconditionally.
22| Although this behavior is consist

In [13]:
# let's try the example in the doc
with tf.Graph().as_default():
    a = tf.constant(2.0)
    b = tf.constant(3.0)
    z = tf.multiply(a, b)
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
    
with tf.Session(graph=a.graph) as sess:
    feed_dict = {x: 1, y: 2}
    r, zz = sess.run([result, z], feed_dict=feed_dict)
    print(r, zz)

with tf.Session(graph=a.graph) as sess:
    feed_dict = {x: 2, y: 1}
    r, zz = sess.run([result, z], feed_dict=feed_dict)
    print(r, zz)

7.0 6.0
1.0 6.0


Well, nothing special so far.

OK, let's try something more exciting: a side effect.

**side-effect is wonderful, fuck yeah~\~\~!**

In [80]:
side_graph = tf.Graph()
with side_graph.as_default():
    pred = tf.placeholder(tf.bool, name="pred")
    x = tf.Variable(1.0, trainable=False, name="x")
    add_x = tf.assign_add(x, 1.0, name="add_x")
    with tf.control_dependencies([add_x]):
        y = tf.constant(3.0, name="y")
    result = tf.cond(pred, true_fn=lambda: x, false_fn=lambda: y)

In [81]:
# Before you run the following cell, what do you expect to see?
# Make a guess!

In [82]:
with tf.Session(graph=side_graph) as sess:
    tf.global_variables_initializer().run()
    feed_dict = {pred: False}
    xx, yy = sess.run([x, result], feed_dict=feed_dict)
    print(xx, yy)

2.0 3.0


In [83]:
with tf.Session(graph=side_graph) as sess:
    tf.global_variables_initializer().run()
    feed_dict = {pred: True}
    xx, yy = sess.run([x, result], feed_dict=feed_dict)
    print(xx, yy)
## xx is still 2.0, WTF!?

2.0 2.0


So... why is `x` still updated even when `pred` is `True`?

After all, the `True` branch does not need `y` at all.

In [50]:
tf.summary.FileWriter(graph=side_graph, logdir='log/side_graph').close()

## The Graph

![cond-side-eff](images/cond_side_eff.png)

As you can see here, `x`, `y` and `add_x` are all needed by `cond`.

According to TensorFlow's dataflow model, these tensors are evaluated before `cond` itself is evaluated.

In this case, `add_x` is created outside both branch functions and `y` carries a control dependency on it, so `add_x` is evaluated no matter what `pred` is!
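The picture can also be checked in code. The sketch below rebuilds the same graph and inspects `y.op.control_inputs` and the op name of `add_x`. The names `check_graph`, `control_input_names` and `add_x_inside_cond` are made up for this illustration, and the `try`/`except` import is an assumption added so the cell also runs where only TensorFlow 2.x is installed (on 1.4.0 it falls back to the plain import). `tf.identity(x)` is used in the true branch only to read the variable inside the branch.

```python
try:
    # Assumption: on newer installs, use the TF 1.x compatibility API.
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
except (ImportError, AttributeError):
    import tensorflow as tf  # TensorFlow 1.4, as used in this notebook

check_graph = tf.Graph()  # hypothetical name, fresh graph for the check
with check_graph.as_default():
    pred = tf.placeholder(tf.bool, name="pred")
    x = tf.Variable(1.0, trainable=False, name="x")
    add_x = tf.assign_add(x, 1.0, name="add_x")
    with tf.control_dependencies([add_x]):
        y = tf.constant(3.0, name="y")
    result = tf.cond(pred,
                     true_fn=lambda: tf.identity(x),
                     false_fn=lambda: y)

# Ops that must run before `y` can be produced.
control_input_names = [op.name for op in y.op.control_inputs]
# Ops created inside a branch function live under the "cond/" name scope.
add_x_inside_cond = add_x.op.name.startswith("cond/")

print(control_input_names)
print(add_x_inside_cond)
```

If `add_x` shows up as a control input of `y` and is not under the `cond/` scope, it lives outside both branches, which matches the graph above.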

## Summary

1. `true_fn` and `false_fn` are called only **once**, when the `cond` node is created in the graph, not at run time (in the `Session`).
2. Make sure any op with a side effect is created inside `true_fn` or `false_fn`.
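Both points can be sketched in one cell. This is a minimal illustration, not the only way to do it: the `calls` list, `fixed_graph` and the result variable names are made up for this example, and the `try`/`except` import is an assumption added so the cell also runs where only TensorFlow 2.x is installed. The side effect `add_x` is now created inside `false_fn`, and `x` is read in a separate `sess.run` call so the order of the read and the update is not in question.

```python
try:
    # Assumption: on newer installs, use the TF 1.x compatibility API.
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
except (ImportError, AttributeError):
    import tensorflow as tf  # TensorFlow 1.4, as used in this notebook

calls = []  # records every Python-level call of a branch function

fixed_graph = tf.Graph()  # hypothetical name
with fixed_graph.as_default():
    pred = tf.placeholder(tf.bool, name="pred")
    x = tf.Variable(1.0, trainable=False, name="x")

    def true_fn():
        calls.append("true_fn")
        return tf.identity(x)

    def false_fn():
        calls.append("false_fn")
        # The side effect is created inside the branch, so it should
        # only run when this branch is selected.
        add_x = tf.assign_add(x, 1.0, name="add_x")
        with tf.control_dependencies([add_x]):
            return tf.constant(3.0, name="y")

    result = tf.cond(pred, true_fn=true_fn, false_fn=false_fn)
    init = tf.global_variables_initializer()

with tf.Session(graph=fixed_graph) as sess:
    sess.run(init)
    r_true = sess.run(result, feed_dict={pred: True})
    x_after_true = sess.run(x)    # read x in its own run
    r_false = sess.run(result, feed_dict={pred: False})
    x_after_false = sess.run(x)

print(calls)
print(r_true, x_after_true)
print(r_false, x_after_false)
```

Each branch function should appear in `calls` exactly once even though `result` is run twice. With `pred` set to `True`, `x` should stay at 1.0; it should only become 2.0 after the run with `pred` set to `False`.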

## References

- [Stackoverflow](https://stackoverflow.com/questions/37063952/confused-by-the-behavior-of-tf-cond)
  - Excellent explanation, must read.
- [Issue #3287](https://github.com/tensorflow/tensorflow/issues/3287)