From 5bdb32a799022a7016e66b4d48602df1cebf6ed0 Mon Sep 17 00:00:00 2001 From: Peng Wang Date: Mon, 21 Sep 2020 18:15:39 -0700 Subject: [PATCH] [TF-numpy/extensions] Lets `scan` allow None in `init`. PiperOrigin-RevId: 332979305 --- trax/tf_numpy/extensions/extensions.py | 20 +++++++++++++---- trax/tf_numpy/extensions/extensions_test.py | 25 ++++++++++++++------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py index 1b2646253..b469c239f 100644 --- a/trax/tf_numpy/extensions/extensions.py +++ b/trax/tf_numpy/extensions/extensions.py @@ -1001,7 +1001,16 @@ def get_length(x): lambda t: (tf.TensorArray(t.dtype, size=0, dynamic_size=True).unstack(t) # pylint: disable=g-long-lambda if t is not None else None), xs) - def body(i, carry, ys_ta): + # tf.while_loop doesn't allow None in loop_vars, so we mask them. + is_init_none = tf.nest.map_structure(lambda x: x is None, init) + def to_safe(carry): + return tf.nest.map_structure( + lambda x, is_none: tf.zeros([]) if is_none else x, carry, is_init_none) + def from_safe(safe_carry): + return tf.nest.map_structure( + lambda x, is_none: None if is_none else x, safe_carry, is_init_none) + def body(i, safe_carry, ys_ta): + carry = from_safe(safe_carry) if reverse: i_ = length - 1 - i else: @@ -1013,7 +1022,8 @@ def body(i, carry, ys_ta): lambda y_ta, y: (y_ta.write(i_, y) if y is not None else y_ta), ys_ta, ys) i = i + 1 - return i, carry, ys_ta + safe_carry = to_safe(carry) + return i, safe_carry, ys_ta xs_spec = tf.nest.map_structure( lambda t: tf.TensorSpec(t.shape[1:], t.dtype) if t is not None else None, xs) @@ -1024,8 +1034,10 @@ def body(i, carry, ys_ta): lambda y: tf.TensorArray(y.dtype if y is not None else tf.float32, size=0, # pylint: disable=g-long-lambda dynamic_size=True), ys_spec) - _, carry, ys_ta = tf.while_loop( - lambda i, *_: i < length, body, (0, init, ys_ta)) + safe_init = to_safe(init) + _, safe_carry, ys_ta = tf.while_loop( + lambda i, *_: i < length, body, (0, safe_init, ys_ta)) + carry = from_safe(safe_carry) def _stack(a, spec): if spec is None: return None diff --git a/trax/tf_numpy/extensions/extensions_test.py b/trax/tf_numpy/extensions/extensions_test.py index 9b4d13f94..72d05b835 100644 --- a/trax/tf_numpy/extensions/extensions_test.py +++ b/trax/tf_numpy/extensions/extensions_test.py @@ -552,13 +552,17 @@ def testScanStruct(self): rng = np.random.RandomState(0) d = rng.randn(2) - def f(c_g, a_e_h): + def f(c_g_i, a_e_h): + c_g, i = c_g_i c, g = c_g - a, e, h = a_e_h + a, e_h = a_e_h + e, h = e_h assert a.shape == (3,) assert e.shape == () # pylint: disable=g-explicit-bool-comparison assert c.shape == (4,) assert g.shape == (2,) + assert i is None + assert h is None b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) + tf_np.sum(tf_np.tan(d))) f = tf_np.cos(a) @@ -566,18 +570,23 @@ def f(c_g, a_e_h): g = tf_np.sin(g * b) assert b.shape == () # pylint: disable=g-explicit-bool-comparison assert f.shape == (3,) - return [c, g], (b, f, h) + return [(c, g), i], (b, [f, h]) - xs = (rng.randn(5, 3), rng.randn(5), None) - init = [rng.randn(4), rng.randn(2)] + xs = (rng.randn(5, 3), [rng.randn(5), None]) + init = [(rng.randn(4), rng.randn(2)), None] - c_g, b_f_h = extensions.scan(f, init, xs) - self.assertIsInstance(c_g, list) + c_g_i, b_f_h = extensions.scan(f, init, xs) + self.assertIsInstance(c_g_i, list) self.assertIsInstance(b_f_h, tuple) + c_g, i = c_g_i c, g = c_g - b, f, h = b_f_h + self.assertIsInstance(c_g, tuple) self.assertEqual((4,), c.shape) self.assertEqual((2,), g.shape) + self.assertIsNone(i) + b, f_h = b_f_h + f, h = f_h + self.assertIsInstance(f_h, list) self.assertEqual((5,), b.shape) self.assertEqual((5, 3), f.shape) self.assertIsNone(h)