Skip to content

Commit

Permalink
[jax2tf] Add backward compatibility tests for linalg.eig on CPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula committed Jul 14, 2023
1 parent d66a4e3 commit c649b38
Show file tree
Hide file tree
Showing 2 changed files with 348 additions and 4 deletions.
67 changes: 63 additions & 4 deletions jax/experimental/jax2tf/tests/back_compat_test.py
Expand Up @@ -33,6 +33,7 @@

from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_cholesky_lapack_potrf
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_eig_lapack_geev
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_eigh_cusolver_syev
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_eigh_lapack_syev
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lu_lapack_getrf
Expand Down Expand Up @@ -97,6 +98,7 @@ def test_custom_call_coverage(self):
covering_testdatas = [
cpu_ducc_fft.data_2023_03_17, cpu_ducc_fft.data_2023_06_14,
cpu_cholesky_lapack_potrf.data_2023_06_19,
cpu_eig_lapack_geev.data_2023_06_19,
cpu_eigh_lapack_syev.data_2023_03_17,
cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
cpu_lu_lapack_getrf.data_2023_06_14,
Expand All @@ -118,8 +120,6 @@ def test_custom_call_coverage(self):
covered_targets = covered_targets.union(data.custom_call_targets)

covered_targets = covered_targets.union({
# TODO(necula): add tests for eig on CPU
"lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev",
# TODO(necula): add tests for svd on CPU
"lapack_sgesdd", "lapack_dsesdd", "lapack_cgesdd", "lapack_zgesdd",
# TODO(necula): add tests for triangular_solve on CPU
Expand Down Expand Up @@ -162,13 +162,72 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"):
del input # Input is in the testdata, here for readability
func = lax.linalg.cholesky

# data = self.load_testdata(cpu_eig_lapack_geev.data_2023_06_19[dtype_name])
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol)

@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_eig_lapack_geev(self, dtype_name="f32"):
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")

dtype = dict(f32=np.float32, f64=np.float64,
c64=np.complex64, c128=np.complex128)[dtype_name]
shape = (4, 4)
def func():
# Compute the inputs to simplify the harness
input = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape)
return lax.linalg.eig(input,
compute_left_eigenvectors=True,
compute_right_eigenvectors=True)

data = self.load_testdata(cpu_eig_lapack_geev.data_2023_06_19[dtype_name])
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

def check_eig_results(res_run, res_expected, *, rtol, atol):
# Test ported from tests.linlag_test.testEig
# Norm, adjusted for dimension and type.
inner_dimension = shape[-1]
operand = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
def norm(x):
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

def check_right_eigenvectors(a, w, vr):
self.assertTrue(
np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

def check_left_eigenvectors(a, w, vl):
rank = len(a.shape)
aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
wC = jnp.conj(w)
check_right_eigenvectors(aH, wC, vl)

def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
closest_diff = min(abs(eigenvalues_array - eigenvalue))
self.assertAllClose(
closest_diff,
np.array(0., closest_diff.dtype),
atol=atol, rtol=rtol)

all_w_run, all_w_exp = res_run[0], res_expected[0]
for idx in itertools.product(*map(range, operand.shape[:-2])):
w_run, w_exp = all_w_run[idx], all_w_exp[idx]
for i in range(inner_dimension):
check_eigenvalue_is_in_array(w_run[i], w_exp)
check_eigenvalue_is_in_array(w_exp[i], w_run)

check_left_eigenvectors(operand, all_w_run, res_run[1])
check_right_eigenvectors(operand, all_w_run, res_run[2])

self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)

@staticmethod
def eigh_input(shape, dtype):
# In order to keep inputs small, we construct the input programmatically
Expand Down Expand Up @@ -441,7 +500,7 @@ def func():
data = self.load_testdata(tpu_ApproxTopK.data_2023_05_16)
self.run_one_test(func, data)

def test_cu_threefry2x32(self):
def test_cuda_threefry2x32(self):
def func(x):
return jax.random.uniform(x, (2, 4), dtype=np.float32)

Expand Down

0 comments on commit c649b38

Please sign in to comment.