
Merge pull request #16480 from gnecula:rename_bc
PiperOrigin-RevId: 541803861
jax authors committed Jun 20, 2023
2 parents 000e6b8 + 7a46383 commit c2935bf
Showing 5 changed files with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions jax/experimental/jax2tf/tests/back_compat_test.py
@@ -89,10 +89,10 @@ def func(...): ...
 from jax.experimental import jax2tf
 from jax.experimental.jax2tf import jax_export
 from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
-from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lapack_geqrf
-from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lapack_syev
-from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_cusolver_geqrf
-from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_cusolver_syev
+from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_eigh_cusolver_syev
+from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_eigh_lapack_syev
+from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_qr_cusolver_geqrf
+from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_qr_lapack_geqrf
 from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_threefry2x32
 from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
 from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Eigh
@@ -395,9 +395,9 @@ def test_custom_call_coverage(self):
     # stable
     covering_testdatas = [
         cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14,
-        cpu_lapack_syev.data_2023_03_17,
-        cpu_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
-        cuda_cusolver_geqrf.data_2023_03_18, cuda_cusolver_syev.data_2023_03_17,
+        cpu_eigh_lapack_syev.data_2023_03_17,
+        cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
+        cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17,
         tf_call_tf_function.data_2023_06_02,
         tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
         tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
@@ -465,7 +465,7 @@ def check_eigh_results(self, operand, res_now, res_expected, *,
   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
       for dtype_name in ("f32", "f64", "c64", "c128"))
-  def test_cpu_lapack_syevd(self, dtype_name="f32"):
+  def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
     # For lax.linalg.eigh
     if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
       self.skipTest("Test disabled for x32 mode")
@@ -475,7 +475,7 @@ def test_cpu_lapack_syevd(self, dtype_name="f32"):
     size = 8
     operand = CompatTest.eigh_input((size, size), dtype)
     func = lambda: CompatTest.eigh_harness((8, 8), dtype)
-    data = self.load_testdata(cpu_lapack_syev.data_2023_03_17[dtype_name])
+    data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name])
     rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
     atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
     self.run_one_test(func, data, rtol=rtol, atol=atol,
@@ -487,15 +487,15 @@ def test_cpu_lapack_syevd(self, dtype_name="f32"):
       for dtype_name in ("f32", "f64")
       # We use different custom calls for sizes <= 32
       for variant in ["syevj", "syevd"])
-  def test_gpu_cusolver_syev(self, dtype_name="f32", variant="syevj"):
+  def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
     # For lax.linalg.eigh
     dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
     size = dict(syevj=8, syevd=36)[variant]
     rtol = dict(f32=1e-3, f64=1e-5)[dtype_name]
     atol = dict(f32=1e-2, f64=1e-10)[dtype_name]
     operand = CompatTest.eigh_input((size, size), dtype)
     func = lambda: CompatTest.eigh_harness((size, size), dtype)
-    data = self.load_testdata(cuda_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"])
+    data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"])
     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=partial(self.check_eigh_results, operand))

@@ -521,15 +521,15 @@ def qr_harness(shape, dtype):
   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
       for dtype_name in ("f32", "f64", "c64", "c128"))
-  def test_cpu_lapack_geqrf(self, dtype_name="f32"):
+  def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
     # For lax.linalg.qr
     if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
       self.skipTest("Test disabled for x32 mode")

     dtype = dict(f32=np.float32, f64=np.float64,
                  c64=np.complex64, c128=np.complex128)[dtype_name]
     func = lambda: CompatTest.qr_harness((3, 3), dtype)
-    data = self.load_testdata(cpu_lapack_geqrf.data_2023_03_17[dtype_name])
+    data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
     rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
     self.run_one_test(func, data, rtol=rtol)

@@ -539,13 +539,13 @@ def test_cpu_lapack_geqrf(self, dtype_name="f32"):
       for dtype_name in ("f32",)
       # For batched qr we use cublas_geqrf_batched
       for batched in ("batched", "unbatched"))
-  def test_gpu_cusolver_geqrf(self, dtype_name="f32", batched="unbatched"):
+  def test_cuda_qr_cusolver_geqrf(self, dtype_name="f32", batched="unbatched"):
     # For lax.linalg.qr
     dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
     rtol = dict(f32=1e-3, f64=1e-5)[dtype_name]
     shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched]
     func = lambda: CompatTest.qr_harness(shape, dtype)
-    data = self.load_testdata(cuda_cusolver_geqrf.data_2023_03_18[batched])
+    data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched])
     self.run_one_test(func, data, rtol=rtol)

   def test_tpu_Qr(self):

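Note: the snippet below is a summary sketch, not part of the commit. It restates the testdata module renames visible in the import changes above as a Python mapping; the dict name is illustrative only, and the comments follow the "# For lax.linalg.qr" / "# For lax.linalg.eigh" annotations in the diff.

# Old back-compat testdata module names -> new names that include the JAX
# primitive being covered, as applied by this commit.
RENAMED_TESTDATA_MODULES = {
    "cpu_lapack_geqrf": "cpu_qr_lapack_geqrf",        # lax.linalg.qr on CPU (LAPACK geqrf)
    "cpu_lapack_syev": "cpu_eigh_lapack_syev",        # lax.linalg.eigh on CPU (LAPACK syev)
    "cuda_cusolver_geqrf": "cuda_qr_cusolver_geqrf",  # lax.linalg.qr on GPU (cuSOLVER geqrf)
    "cuda_cusolver_syev": "cuda_eigh_cusolver_syev",  # lax.linalg.eigh on GPU (cuSOLVER syev)
}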