# Chex : 信頼性のあるJAXコードを書くためのライブラリ

Chex includes utilities that help you:

- Instrument your code (e.g. with assertions).
- Debug (e.g. by transforming pmaps into vmaps within a context manager).
- Test JAX code across many variants (e.g. jitted vs. non-jitted).

In [7]:
import jax.numpy as jnp
import jax
import chex

In [3]:
@chex.dataclass
class Parameters:
    """Example chex dataclass bundling two array fields, `x` and `y`.

    Instances are built with keyword arguments; this notebook records that
    the generated constructor raises
    "ValueError: Mappable dataclass constructor doesn't support positional
    args." when the fields are passed positionally.
    """
    x: chex.ArrayDevice  # first array field
    y: chex.ArrayDevice  # second array field

In [5]:
# Build an example instance; both fields are supplied by keyword.
parameters = Parameters(x=jnp.ones((2, 2)), y=jnp.ones((1, 2)))



In [10]:
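# NOTE: passing the fields positionally fails with
#   ValueError: Mappable dataclass constructor doesn't support positional args.
# The call below is kept commented out as an example of what NOT to do:
# parameters = Parameters(
#     jnp.ones((2, 2)),
#     jnp.ones((1, 2)),
# )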

# Assertions

In [11]:
from chex import assert_shape, assert_rank

In [13]:
# Sample 2x3 float32 array of ones used by the assertion examples.
x = jnp.ones(shape=(2, 3), dtype=jnp.float32)

In [16]:
# Check that `x` has exactly the shape (2, 3).
assert_shape(x, (2, 3))

# Test variants

`chex.variants` を使うと、同じテストを jit コンパイルされたバージョンとされていないバージョンの両方で実行できる。

In [22]:
from absl.testing import parameterized

In [23]:
def fn(x, y):
    """Return the result of `x + y`."""
    total = x + y
    return total

In [27]:
class ExampleTest(chex.TestCase):
    """Exercises `fn` under the variants requested via `chex.variants`."""

    @chex.variants(with_jit=True, without_jit=True)
    def test_ex1(self):
        # Wrap `fn` with whichever variant is active for this test run.
        wrapped_fn = self.variant(fn)

        # Sanity-check the plain function first.
        self.assertEqual(fn(1, 2), 3)

        # The wrapped function must agree with the plain one.
        wrapped_result = wrapped_fn(1, 2)
        plain_result = fn(1, 2)
        self.assertEqual(wrapped_result, plain_result)
    
    
class ExampleParameterizedTest(parameterized.TestCase):
    """Combines `chex.variants` with absl named parameters on one test."""

    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.named_parameters(
        ('case_positive', 1, 2, 3),
        ('case_negative', -1, -2, -3),
    )
    def test(self, arg_1, arg_2, expected):
        # Define the function under test locally, then wrap it with the
        # variant that is active for this test run.
        def add(x, y):
            return x + y

        variant_add = self.variant(add)

        result = variant_add(arg_1, arg_2)
        self.assertEqual(result, expected)
    
