diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py
index c8aa86fbddad..a95a31c7ed88 100644
--- a/jax/experimental/rnn.py
+++ b/jax/experimental/rnn.py
@@ -397,8 +397,7 @@ def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
   if seq_lengths.dtype != jnp.dtype("int32"):
     raise NotImplementedError("`seq_lengths` can only be int32.")
   cudnn_allow_tf32 = _lstm_cudnn_allow_tf32(precision)
-  if jax._src.lib.version < (0, 4, 9):
-    y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
+  y, h_n, c_n, reserve_space = rnn_fwd_p.bind(
       x,
       h_0,
       c_0,
@@ -410,22 +409,7 @@ def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
       dropout=dropout,
       bidirectional=bidirectional,
       cudnn_allow_tf32=cudnn_allow_tf32)
-    return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, workspace,
-                           reserve_space)
-  else:
-    y, h_n, c_n, reserve_space = rnn_fwd_p.bind(
-        x,
-        h_0,
-        c_0,
-        w,
-        seq_lengths,
-        input_size=input_size,
-        hidden_size=hidden_size,
-        num_layers=num_layers,
-        dropout=dropout,
-        bidirectional=bidirectional,
-        cudnn_allow_tf32=cudnn_allow_tf32)
-    return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, reserve_space)
+  return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, reserve_space)
 
 
 def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
@@ -436,94 +420,71 @@ def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
   num_directions = 2 if bidirectional else 1
   output_shape = (batch_size, max_seq_length, num_directions * hidden_size)
   output_aval = core.ShapedArray(output_shape, x_aval.dtype)
-  if jax._src.lib.version < (0, 4, 9):
-    workspace_size, reserve_space_size = (
-        gpu_rnn.compute_rnn_workspace_reserve_space_sizes(  # pytype: disable=attribute-error
-            input_size, hidden_size, num_layers, batch_size, max_seq_length,
-            dropout, bidirectional, cudnn_allow_tf32))
-    workspace_aval = core.ShapedArray((workspace_size,), jnp.float32)
-    reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
-    return output_aval, h_0_aval, c_0_aval, workspace_aval, reserve_space_aval
-  else:
+  if jax._src.lib.version >= (0, 4, 17):
     _, reserve_space_size = (
         gpu_rnn.compute_rnn_workspace_reserve_space_sizes(  # pytype: disable=attribute-error
            input_size, hidden_size, num_layers, batch_size, max_seq_length,
            dropout, bidirectional, cudnn_allow_tf32))
-    reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
-    return output_aval, h_0_aval, c_0_aval, reserve_space_aval
+  else:
+    _, reserve_space_size = (
+        gpu_rnn.compute_rnn_workspace_reserve_space_sizes(  # pytype: disable=attribute-error
+            input_size, hidden_size, num_layers, batch_size, max_seq_length,
+            dropout, bidirectional))
+  reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
+  return output_aval, h_0_aval, c_0_aval, reserve_space_aval
+
+def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
+  del cudnn_allow_tf32
+  return fn(*args, **kw)
 
 
 rnn_fwd_p = core.Primitive('rnn_fwd')
 rnn_fwd_p.multiple_results = True
 rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
 rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
 if gpu_rnn:
-  mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
+  if jax._src.lib.version >= (0, 4, 17):
+    mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
+  else:
+    mlir.register_lowering(
+        rnn_fwd_p,
+        partial(_gpu_lowering_strip_tf32, gpu_rnn.cudnn_rnn_lowering),
+        platform='cuda'
+    )
 
 
 def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
              bidirectional: bool, precision: lax.PrecisionLike,
             residuals, gradients):
   cudnn_allow_tf32 = _lstm_cudnn_allow_tf32(precision)
-  if jax._src.lib.version < (0, 4, 9):
-    x, h_0, c_0, w, seq_lengths, y, workspace, reserve_space = residuals
-    dy, dh_n, dc_n = gradients
-    dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
-        dy,
-        dh_n,
-        dc_n,
-        x,
-        h_0,
-        c_0,
-        w,
-        y,
-        workspace,
-        reserve_space,
-        seq_lengths,
-        input_size=input_size,
-        hidden_size=hidden_size,
-        num_layers=num_layers,
-        dropout=dropout,
-        bidirectional=bidirectional,
-        cudnn_allow_tf32=cudnn_allow_tf32)
-    return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
-  else:
-    x, h_0, c_0, w, seq_lengths, y, reserve_space = residuals
-    dy, dh_n, dc_n = gradients
-    dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
-        dy,
-        dh_n,
-        dc_n,
-        x,
-        h_0,
-        c_0,
-        w,
-        y,
-        reserve_space,
-        seq_lengths,
-        input_size=input_size,
-        hidden_size=hidden_size,
-        num_layers=num_layers,
-        dropout=dropout,
-        bidirectional=bidirectional,
-        cudnn_allow_tf32=cudnn_allow_tf32)
-    return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
-
-
-if jax._src.lib.version < (0, 4, 9):
-  def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
-                            w_aval, y_aval, workspace_aval, reserve_space_aval,
+  x, h_0, c_0, w, seq_lengths, y, reserve_space = residuals
+  dy, dh_n, dc_n = gradients
+  dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
+      dy,
+      dh_n,
+      dc_n,
+      x,
+      h_0,
+      c_0,
+      w,
+      y,
+      reserve_space,
+      seq_lengths,
+      input_size=input_size,
+      hidden_size=hidden_size,
+      num_layers=num_layers,
+      dropout=dropout,
+      bidirectional=bidirectional,
+      cudnn_allow_tf32=cudnn_allow_tf32)
+  return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
+
+
+def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,  # type: ignore
+                          w_aval, y_aval, reserve_space_aval,
                           seq_lengths_aval, input_size: int, hidden_size: int,
                           num_layers: int, dropout: float, bidirectional: bool,
                           cudnn_allow_tf32: bool):
-    return x_aval, h0_aval, c0_aval, w_aval
-else:
-  def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,  # type: ignore
-                            w_aval, y_aval, reserve_space_aval,
-                            seq_lengths_aval, input_size: int, hidden_size: int,
-                            num_layers: int, dropout: float, bidirectional: bool,
-                            cudnn_allow_tf32: bool):
-    return x_aval, h0_aval, c0_aval, w_aval
+  return x_aval, h0_aval, c0_aval, w_aval
 
 
 rnn_bwd_p = core.Primitive('rnn_bwd')
@@ -531,7 +492,14 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
 rnn_bwd_p.def_impl(partial(xla.apply_primitive, rnn_bwd_p))
 rnn_bwd_p.def_abstract_eval(rnn_bwd_abstract_eval)
 if gpu_rnn:
-  mlir.register_lowering(
-      rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
+  if jax._src.lib.version >= (0, 4, 17):
+    mlir.register_lowering(
+        rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
+  else:
+    mlir.register_lowering(
+        rnn_bwd_p,
+        partial(_gpu_lowering_strip_tf32, gpu_rnn.cudnn_rnn_bwd_lowering),
+        platform='cuda'
+    )
 
 lstm.defvjp(lstm_fwd, lstm_bwd)
diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py
index 5c0e46ab6617..1384dd58d0ea 100644
--- a/tests/experimental_rnn_test.py
+++ b/tests/experimental_rnn_test.py
@@ -16,7 +16,6 @@
 import numpy as np
 import jax
 import jax.numpy as jnp
-from jax._src import lib
 from jax._src import test_util as jtu
 from jax.experimental import rnn
 
@@ -40,11 +39,6 @@ class RnnTest(jtu.JaxTestCase):
   @jax.default_matmul_precision("float32")
   def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
                 hidden_size: int, num_layers: int, bidirectional: bool):
-    if lib.version < (0, 4, 7):
-      # TODO(sharadmv, zhangqiaorjc): remove this when minimum jaxlib version is
-      # bumped
-      self.skipTest("Need latest jaxlib for this test to pass.")
-
     # TODO(phawkins): Partially disable this on cudnn version per b/281071013
     if (batch_size == 1 and seq_len == 4 and input_size == 1 and
         hidden_size == 6 and num_layers == 4 and bidirectional == False):
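
Note on the pattern: `_gpu_lowering_strip_tf32`, introduced in the rnn.py hunk above, is a keyword-stripping compatibility shim. On jaxlib older than 0.4.17 the cuDNN lowering rules do not accept the `cudnn_allow_tf32` keyword that the primitives now carry, so the wrapper discards it before forwarding the call. Below is a minimal, self-contained sketch of the same idea; the names `legacy_lowering` and `rule` are illustrative placeholders, not part of the JAX API.

    from functools import partial

    def legacy_lowering(x, *, dropout):
      # Stands in for an older lowering rule that predates the
      # `cudnn_allow_tf32` keyword.
      return ('lowered', x, dropout)

    def _strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
      # Discard the keyword the legacy rule cannot accept and forward
      # everything else unchanged.
      del cudnn_allow_tf32
      return fn(*args, **kw)

    # partial() pre-binds the wrapped rule; the result tolerates callers
    # that pass the new keyword.
    rule = partial(_strip_tf32, legacy_lowering)
    assert rule(1.0, dropout=0.0, cudnn_allow_tf32=True) == ('lowered', 1.0, 0.0)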