In [71]:
import tensorflow as tf

# Introduction

When writing tensorflow code, I often hit difficult-to-parse errors and highly unexpected behavior. My strategy for resolving them is to:
1. Whittle the code I'm working with down to a minimum example that reproduces the error.
2. If the solution is not clear, ask ChatGPT why I'm getting the unexpected behavior.

This formula has resolved nearly every tensorflow issue I have run into. In addition, I end up with bite-sized examples of tensorflow code that provide counterintuitive knowledge about how the library works.

With this document, I'm sharing these "tensorflow puzzles" for others to learn what I've learned the easy way. 

# Example 1: Not-so-forbidden paths

In the example below, we look at some oddities with the `tf.cond` operator. This is the tensorflow equivalent of an `if` statement, but it doesn't work quite like you might expect. The function below takes one of two paths depending on whether `y` is passed in as a string. We use `y` as an argument for multiplication only if it is not a string.

In [82]:
class Example1:

    def __init__(self, y):

        self.y_is_string = tf.constant(isinstance(y, str), dtype=tf.bool)
        self.y = tf.convert_to_tensor(y)


    @tf.function
    def __call__(self, x):

        return tf.cond(
            self.y_is_string,
            true_fn=lambda: x * tf.constant(2.0, dtype=tf.float32),
            false_fn=lambda: x * self.y,
        )

example = Example1(y=2.0)
print(f'{example(1.0).numpy() = }')

example = Example1(y='string')
print(f'{example(1.0).numpy() = }')

example(1.0).numpy() = 2.0


TypeError: ignored

Why are we getting an error complaining about a datatype issue on an operation that we never expected to execute?

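The issue is that tensorflow traces both branches of a `tf.cond` when it builds the execution graph, regardless of which branch will actually run. The multiplication `x * self.y` is therefore constructed even when `y` is a string, and that is where the `TypeError` comes from.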
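
The tracing behavior can be made visible with a small sketch, assuming TensorFlow 2.x. The names `traced`, `true_branch`, `false_branch`, and `pick` are made up for illustration. Each branch records itself in a plain Python list when it is called, so the list shows which branch functions Python actually entered while the graph was being built.

```python
import tensorflow as tf

traced = []  # filled in by Python while the graph is being traced


def true_branch():
    traced.append("true_fn")
    return tf.constant(1.0)


def false_branch():
    traced.append("false_fn")
    return tf.constant(-1.0)


@tf.function
def pick(pred):
    # pred is a symbolic tensor here, so tf.cond has to build both branches.
    return tf.cond(pred, true_branch, false_branch)


result = pick(tf.constant(True))
print(f'{traced = }')
print(f'{float(result.numpy()) = }')
```

Both branch names should appear in `traced` even though only the true branch contributes to `result`.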

The solution here is to make sure every value used in the tensorflow graph has a single data type, which might require some pre-processing in plain Python. See the code below.

In [84]:
class Example1Fixed:

    def __init__(self, y):

        # Resolve the string case in plain Python so the graph only ever
        # sees a single float32 value.
        if isinstance(y, str):
            self.y_float = tf.constant(2.0, dtype=tf.float32)
        else:
            self.y_float = tf.constant(y, dtype=tf.float32)
        self.y = tf.convert_to_tensor(y)


    @tf.function
    def __call__(self, x):

        return x * self.y_float


example = Example1Fixed(y=2.0)
print(f'{example(1.0).numpy() = }')

example = Example1Fixed(y='string')
print(f'{example(1.0).numpy() = }')

example(1.0).numpy() = 2.0
example(1.0).numpy() = 2.0


*Note: One important characteristic of this example (and all of the examples below) is that if you comment out the `@tf.function`, everything works as you would expect. In this case, we are running regular Python code, not a tensorflow graph.*
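
Rather than commenting the decorator in and out, the same comparison can be made with `tf.config.run_functions_eagerly`. The sketch below assumes TensorFlow 2.x; `calls` and `double` are hypothetical names. It counts how often the Python body runs in graph mode versus eager mode.

```python
import tensorflow as tf

calls = []  # one entry per execution of the Python body


@tf.function
def double(x):
    calls.append("python body ran")
    return x * 2.0


# Graph mode: the Python body runs once, during tracing, and the traced
# graph is reused for the second call with the same dtype and shape.
double(tf.constant(1.0))
double(tf.constant(2.0))
graph_calls = len(calls)

# Eager mode: the decorator is effectively ignored and the body runs every call.
tf.config.run_functions_eagerly(True)
double(tf.constant(1.0))
double(tf.constant(2.0))
eager_calls = len(calls) - graph_calls
tf.config.run_functions_eagerly(False)  # restore the default

print(f'{graph_calls = }, {eager_calls = }')
```

This is only a debugging aid: with eager execution switched on, the graph-specific surprises in this document disappear, so a bug that vanishes under this setting is likely a tracing issue.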

# Example 2: TensorFlow's Magic Cure: Now with Zero Side Effects!

This example deals with how side effects work inside tensorflow graphs. Here we have a function that should update the `side_effect` attribute of the class depending on the branch taken in the `tf.cond`. One might expect that, after running the code, the `side_effect` attribute will match the `pred` value passed into the constructor.

In [85]:
class Example2:

  def __init__(self, pred):

    self.pred = tf.cast(pred, tf.bool)

  @tf.function
  def __call__(self):

    tf.cond(self.pred, self._true_func, self._false_func)

  def _true_func(self):

    self.side_effect = True

  def _false_func(self):

    self.side_effect = False



pred = True
c = Example2(pred=pred)
c()
print(f'{pred = }, {c.side_effect = }\n')

pred = False
c = Example2(pred=pred)
c()
print(f'{pred = }, {c.side_effect = }\n')

pred = True, c.side_effect = False

pred = False, c.side_effect = False



Oddly, it seems that the side effect is set to `False` in both cases.

What's happening here is that the attribute is set while the graph is being constructed, when tensorflow traces both branches as ordinary Python code. Since the `False` branch is traced second, its value is the one that persists. When the graph actually runs, neither Python assignment is executed again.

To get the expected behavior, store `side_effect` in a `tf.Variable` and update it with `assign`. The assignment is recorded as a graph operation, so it only runs in the branch that is actually taken.

In [75]:
import tensorflow as tf

class Example2Fixed:

    def __init__(self, pred):
        self.pred = tf.cast(pred, tf.bool)
        self.side_effect = tf.Variable(False, dtype=tf.bool)

    @tf.function
    def __call__(self):

        tf.cond(self.pred, self._true_func, self._false_func)

    def _true_func(self):
        self.side_effect.assign(True)

    def _false_func(self):
        self.side_effect.assign(False)


pred = True
c = Example2Fixed(pred=pred)
c()
print(f'{pred = }, {c.side_effect.numpy() = }\n')

pred = False
c = Example2Fixed(pred=pred)
c()
print(f'{pred = }, {c.side_effect.numpy() = }\n')

pred = True, c.side_effect.numpy() = True

pred = False, c.side_effect.numpy() = False
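
If the caller only needs to know which branch ran, another option is to have both branches return a tensor and skip stored state entirely. This is a sketch under that assumption, not part of the original example, and `which_branch` is a hypothetical name.

```python
import tensorflow as tf


@tf.function
def which_branch(pred):
    # Both branches return a bool tensor, so tf.cond yields the value of the
    # branch that actually runs instead of relying on a Python side effect.
    return tf.cond(
        pred,
        lambda: tf.constant(True),
        lambda: tf.constant(False),
    )


for pred in (True, False):
    result = which_branch(tf.constant(pred))
    print(f'{pred = }, {bool(result.numpy()) = }')
```

Returning the value keeps the function free of state, which tends to be easier to reason about than a variable that is mutated inside a branch.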



# Example 3: Some things never change (in tensorflow graphs)

This example deals with how tensorflow graphs handle attribute updates after construction. Here you might expect that the function output would change after one of the core parameters of the function was updated. 

In [76]:
class Example3:

  def __init__(self, x):

    self.x = x

  @tf.function
  def __call__(self, y):
    return self.x * y

example = Example3(1)
print(f'{example.x = }, {example(2).numpy() = }')

example.x = 2
print(f'{example.x = }, {example(2).numpy() = }')


example.x = 1, example(2).numpy() = 2
example.x = 2, example(2).numpy() = 2


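As you may have foreseen by this point, tensorflow did not exhibit the expected behavior. The Python value of `self.x` was baked into the graph as a constant when the function was first traced, and the second call with the same argument reuses that graph. Similarly to the last example, the solution in this case is to make `x` a `tf.Variable` and update its value with the `assign` method.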

In [86]:
import tensorflow as tf

class Example3Fixed:

    def __init__(self, x):
        self.x = tf.Variable(x, dtype=tf.float32)

    @tf.function
    def __call__(self, y):
        return self.x * y

example = Example3Fixed(1)
print(f'{example.x.numpy() = }, {example(2).numpy() = }')

example.x.assign(2.0)
print(f'{example.x.numpy() = }, {example(2).numpy() = }')

example.x.numpy() = 1.0, example(2).numpy() = 2.0
example.x.numpy() = 2.0, example(2).numpy() = 4.0
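
The stale value from the unfixed version is easier to see when a retrace is forced. The sketch below assumes TensorFlow 2.x; `scale`, `trace_log`, and `scaled` are illustrative names. It uses a plain Python global in place of the class attribute: the old value survives until an input with a different dtype triggers a new trace.

```python
import tensorflow as tf

scale = 1       # plain Python value, read only while tracing
trace_log = []  # records the value of `scale` seen by each trace


@tf.function
def scaled(y):
    trace_log.append(scale)
    return scale * y


first = scaled(tf.constant(2))     # first trace, bakes in scale = 1
scale = 2
second = scaled(tf.constant(2))    # same signature, the old graph is reused
third = scaled(tf.constant(2.0))   # new dtype forces a retrace with scale = 2

print(f'{int(first.numpy()) = }, {int(second.numpy()) = }')
print(f'{float(third.numpy()) = }')
print(f'{trace_log = }')
```

This is why the bug can appear intermittent: whether the updated value is picked up depends on whether a later call happens to trigger a retrace.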


# Example 4: Be specific about your ambiguity

This example deals with ambiguous shapes of tf.Variables. In the code below, the `variable` attribute is a `tf.Variable` with unspecified shape `tf.TensorShape(None)`. 

The `tf.cond` in the `__call__` method returns either the variable value or the input value. Since the variable shape is completely unspecified, you might reasonably expect that tensorflow can handle any shape for `x`.

In [87]:
class Example4:

  def __init__(self):

    self.variable = tf.Variable(1.0, shape=tf.TensorShape(None))

  @tf.function
  def __call__(self, x):

    variable = tf.cond(
      tf.math.greater(x, 0.0), 
      lambda: self.update_variable(x),
      lambda: self.variable
    )

    return variable * x

  def update_variable(self, x):

    self.variable.assign(x)

    return x

example = Example4()
x = tf.constant(1.0)

with tf.GradientTape() as tape:
  tape.watch(x)
  res = example(x)

grad = tape.gradient(res, x)

print(f'{grad.numpy() = }')

ValueError: ignored

Here we see a complaint about a `partially known TensorShape`. The issue here is that one branch returns the value of a `tf.Variable` whose shape is unspecified, while the other returns a tensor with a fully defined shape.
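
The mismatch between the two values can be inspected directly in eager mode. This is a small sketch assuming TensorFlow 2.x; it only looks at the static shapes and does not reproduce the error itself.

```python
import tensorflow as tf

variable = tf.Variable(1.0, shape=tf.TensorShape(None))
x = tf.constant(1.0)

# The variable was declared with an unknown shape, so its static shape
# carries no rank information at all.
print(f'{variable.shape.rank = }')
print(f'{variable.shape.is_fully_defined() = }')

# The constant is a scalar with a fully known static shape.
print(f'{x.shape.rank = }')
print(f'{x.shape.is_fully_defined() = }')
```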

The fix is surprisingly simple. If `update_variable` returns `self.variable`, which carries the unspecified shape, instead of `x`, both branches agree and the code runs fine.

In [88]:
class Example4Fixed:

  def __init__(self):

    self.variable = tf.Variable(1.0, shape=tf.TensorShape(None))

  @tf.function
  def __call__(self, x):

    variable = tf.cond(
      tf.math.greater(x, 0.0), 
      lambda: self.update_variable(x),
      lambda: self.variable
    )

    return variable * x

  def update_variable(self, x):

    self.variable.assign(x)

    return self.variable

example = Example4Fixed()
x = tf.constant(1.0)

with tf.GradientTape() as tape:
  tape.watch(x)
  res = example(x)

grad = tape.gradient(res, x)

print(f'{grad.numpy() = }')

grad.numpy() = 1.0


*Note: This code actually runs fine outside of the `GradientTape` context manager. This makes it more dangerous, since you may not notice the issue until you've incorporated your function into a model.*

# Conclusion

Tensorflow is a powerful tool, but it often behaves in ways that you might not expect. If you find other nuggets of tensorflow wisdom after spending hours getting to the bottom of a bug, please add them here.