
Commit f52926e

Fix test breakage in RNN test with old jaxlibs.
Remove some outdated version guards.
hawkinsp committed Sep 20, 2023
1 parent 5122160 commit f52926e
Showing 2 changed files with 56 additions and 94 deletions.
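The version guards in question compare the installed jaxlib version, exposed as an int tuple at jax._src.lib.version, against a release cutoff. Below is a minimal sketch of that pattern; new_path and old_path are made-up stand-ins for the version-dependent behavior, and jax._src.lib is an internal module, so treat this as illustrative only.

import jax._src.lib

def new_path():
  # Behavior that needs a recent jaxlib, e.g. a lowering that accepts cudnn_allow_tf32.
  return "new jaxlib code path"

def old_path():
  # Fallback kept only while older jaxlib releases are still supported.
  return "old jaxlib code path"

# jax._src.lib.version is the jaxlib version as a tuple such as (0, 4, 16),
# so ordinary tuple comparison selects the code path.
if jax._src.lib.version >= (0, 4, 17):
  result = new_path()
else:
  result = old_path()

Once the minimum supported jaxlib passes a cutoff, the corresponding fallback branch can never run; the (0, 4, 7) and (0, 4, 9) checks in the diff below are in that state and are deleted, while the (0, 4, 17) checks keep a fallback for jaxlib releases that predate the cudnn_allow_tf32-aware lowering.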
144 changes: 56 additions & 88 deletions jax/experimental/rnn.py
@@ -397,8 +397,7 @@ def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
if seq_lengths.dtype != jnp.dtype("int32"):
raise NotImplementedError("`seq_lengths` can only be int32.")
cudnn_allow_tf32 = _lstm_cudnn_allow_tf32(precision)
if jax._src.lib.version < (0, 4, 9):
y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
y, h_n, c_n, reserve_space = rnn_fwd_p.bind(
x,
h_0,
c_0,
@@ -410,22 +409,7 @@ def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, workspace,
reserve_space)
else:
y, h_n, c_n, reserve_space = rnn_fwd_p.bind(
x,
h_0,
c_0,
w,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, reserve_space)
return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, reserve_space)


def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
@@ -436,102 +420,86 @@ def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
num_directions = 2 if bidirectional else 1
output_shape = (batch_size, max_seq_length, num_directions * hidden_size)
output_aval = core.ShapedArray(output_shape, x_aval.dtype)
if jax._src.lib.version < (0, 4, 9):
workspace_size, reserve_space_size = (
gpu_rnn.compute_rnn_workspace_reserve_space_sizes( # pytype: disable=attribute-error
input_size, hidden_size, num_layers, batch_size, max_seq_length,
dropout, bidirectional, cudnn_allow_tf32))
workspace_aval = core.ShapedArray((workspace_size,), jnp.float32)
reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
return output_aval, h_0_aval, c_0_aval, workspace_aval, reserve_space_aval
else:
if jax._src.lib.version >= (0, 4, 17):
_, reserve_space_size = (
gpu_rnn.compute_rnn_workspace_reserve_space_sizes( # pytype: disable=attribute-error
input_size, hidden_size, num_layers, batch_size, max_seq_length,
dropout, bidirectional, cudnn_allow_tf32))
reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
return output_aval, h_0_aval, c_0_aval, reserve_space_aval
else:
_, reserve_space_size = (
gpu_rnn.compute_rnn_workspace_reserve_space_sizes( # pytype: disable=attribute-error
input_size, hidden_size, num_layers, batch_size, max_seq_length,
dropout, bidirectional))
reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
return output_aval, h_0_aval, c_0_aval, reserve_space_aval


def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
del cudnn_allow_tf32
return fn(*args, **kw)

rnn_fwd_p = core.Primitive('rnn_fwd')
rnn_fwd_p.multiple_results = True
rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
if gpu_rnn:
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
if jax._src.lib.version >= (0, 4, 17):
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
else:
mlir.register_lowering(
rnn_fwd_p,
partial(_gpu_lowering_strip_tf32, gpu_rnn.cudnn_rnn_lowering),
platform='cuda'
)


def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
bidirectional: bool, precision: lax.PrecisionLike,
residuals, gradients):
cudnn_allow_tf32 = _lstm_cudnn_allow_tf32(precision)
if jax._src.lib.version < (0, 4, 9):
x, h_0, c_0, w, seq_lengths, y, workspace, reserve_space = residuals
dy, dh_n, dc_n = gradients
dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
dy,
dh_n,
dc_n,
x,
h_0,
c_0,
w,
y,
workspace,
reserve_space,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
else:
x, h_0, c_0, w, seq_lengths, y, reserve_space = residuals
dy, dh_n, dc_n = gradients
dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
dy,
dh_n,
dc_n,
x,
h_0,
c_0,
w,
y,
reserve_space,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))


if jax._src.lib.version < (0, 4, 9):
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
w_aval, y_aval, workspace_aval, reserve_space_aval,
x, h_0, c_0, w, seq_lengths, y, reserve_space = residuals
dy, dh_n, dc_n = gradients
dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
dy,
dh_n,
dc_n,
x,
h_0,
c_0,
w,
y,
reserve_space,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
cudnn_allow_tf32=cudnn_allow_tf32)
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))


def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, # type: ignore
w_aval, y_aval, reserve_space_aval,
seq_lengths_aval, input_size: int, hidden_size: int,
num_layers: int, dropout: float, bidirectional: bool,
cudnn_allow_tf32: bool):
return x_aval, h0_aval, c0_aval, w_aval
else:
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, # type: ignore
w_aval, y_aval, reserve_space_aval,
seq_lengths_aval, input_size: int, hidden_size: int,
num_layers: int, dropout: float, bidirectional: bool,
cudnn_allow_tf32: bool):
return x_aval, h0_aval, c0_aval, w_aval
return x_aval, h0_aval, c0_aval, w_aval


rnn_bwd_p = core.Primitive('rnn_bwd')
rnn_bwd_p.multiple_results = True
rnn_bwd_p.def_impl(partial(xla.apply_primitive, rnn_bwd_p))
rnn_bwd_p.def_abstract_eval(rnn_bwd_abstract_eval)
if gpu_rnn:
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
if jax._src.lib.version >= (0, 4, 17):
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
else:
mlir.register_lowering(
rnn_bwd_p,
partial(_gpu_lowering_strip_tf32, gpu_rnn.cudnn_rnn_bwd_lowering),
platform='cuda'
)

lstm.defvjp(lstm_fwd, lstm_bwd)
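The last line above, lstm.defvjp(lstm_fwd, lstm_bwd), is the standard jax.custom_vjp wiring: the forward rule returns the primal outputs plus a residuals tuple, and the backward rule maps those residuals and the output cotangents to input cotangents. A self-contained sketch of the same contract on a toy function, unrelated to the RNN primitives above:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def clipped_square(x):
  return jnp.clip(x, -1.0, 1.0) ** 2

def clipped_square_fwd(x):
  # Forward rule: primal output plus residuals, just as lstm_fwd returns
  # (y, h_n, c_n) together with (x, h_0, c_0, w, seq_lengths, y, reserve_space).
  return clipped_square(x), (x,)

def clipped_square_bwd(residuals, g):
  # Backward rule: residuals and the output cotangent g become input cotangents.
  (x,) = residuals
  dx = jnp.where(jnp.abs(x) <= 1.0, 2.0 * x * g, 0.0)
  return (dx,)

clipped_square.defvjp(clipped_square_fwd, clipped_square_bwd)

print(jax.grad(clipped_square)(0.5))  # 1.0
print(jax.grad(clipped_square)(2.0))  # 0.0, outside the clipped region

In lstm's case the integer and flag parameters are non-differentiable arguments, which is why lstm_bwd receives input_size through precision ahead of the residuals and gradients.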
6 changes: 0 additions & 6 deletions tests/experimental_rnn_test.py
@@ -16,7 +16,6 @@
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import lib
from jax._src import test_util as jtu
from jax.experimental import rnn

@@ -40,11 +39,6 @@ class RnnTest(jtu.JaxTestCase):
@jax.default_matmul_precision("float32")
def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
hidden_size: int, num_layers: int, bidirectional: bool):
if lib.version < (0, 4, 7):
# TODO(sharadmv, zhangqiaorjc): remove this when minimum jaxlib version is
# bumped
self.skipTest("Need latest jaxlib for this test to pass.")

# TODO(phawkins): Partially disable this on cudnn version per b/281071013
if (batch_size == 1 and seq_len == 4 and input_size == 1 and
hidden_size == 6 and num_layers == 4 and bidirectional == False):

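For reference, the API the test exercises is jax.experimental.rnn.lstm, whose argument order can be read off the lstm_fwd and lstm_bwd signatures in the rnn.py diff above and whose output shape follows rnn_abstract_eval. The sketch below is a rough usage example rather than code from the test: init_lstm_weight is assumed to be the module's weight-packing helper, the (num_directions * num_layers, batch_size, hidden_size) state layout and the default precision are assumptions, and the primitive only registers a CUDA lowering, so this needs a GPU build of jaxlib.

import jax
import jax.numpy as jnp
from jax.experimental import rnn

batch_size, seq_len, input_size, hidden_size, num_layers = 8, 10, 4, 16, 2
bidirectional = False
num_directions = 2 if bidirectional else 1

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (batch_size, seq_len, input_size), jnp.float32)
h_0 = jax.random.normal(
    k2, (num_directions * num_layers, batch_size, hidden_size), jnp.float32)
c_0 = jax.random.normal(
    k3, (num_directions * num_layers, batch_size, hidden_size), jnp.float32)
seq_lengths = jnp.full((batch_size,), seq_len, dtype=jnp.int32)  # int32, per lstm_fwd

# Assumed helper that packs all layer weights into the flat array `w` that lstm expects.
w = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers, bidirectional)

# Positional order mirrors lstm_fwd: x, h_0, c_0, w, seq_lengths, then the static
# parameters input_size, hidden_size, num_layers, dropout, bidirectional.
y, h_n, c_n = rnn.lstm(x, h_0, c_0, w, seq_lengths, input_size, hidden_size,
                       num_layers, 0.0, bidirectional)
# y: (batch_size, seq_len, num_directions * hidden_size), per rnn_abstract_eval.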