Skip to content

Commit

Permalink
@bearytype + __call__() + __wrapped__ x 7.
Browse files Browse the repository at this point in the history
This commit is the last in a commit chain generalizing the `@beartype`
decorator to support **pseudo-callable wrapper objects** (i.e., objects
defining both the `__call__()` and `__wrapped__` dunder attributes),
resolving feature request #368 kindly submitted by @danielward27 (Daniel
Ward). Specifically, this commit:

* Adds `equinox`, `jax[cpu]`, and `jaxtyping` as optional test-time
  dependencies to GitHub Actions-based continuous integration (CI)
  workflows. For robustness, installation of these dependencies is
  currently confined to the platform most likely to fully support these
  dependencies with minimal pain and anguishing heartache: *Linux.*
* Adds integration tests explicitly exercising that the `@beartype`
  decorator is now order-invariant with respect to the third-party
  `@equinox.filter_jit` *and* `@jax.jit` decorators – both of which
  dynamically generate pseudo-callable pseudo-wrapper objects.

The `@beartype` decorator may now be chained (i.e., listed) either
below or above the third-party `@equinox.filter_jit` and `@jax.jit`
decorators. Since `beartype.claw` import hooks (e.g.,
`beartype.claw.beartype_this_package()`) forcefully chain `@beartype`
above all other decorators, these hooks now transparently support:

* The third-party `@equinox.filter_jit` decorator.
* The third-party `@jax.jit` decorator.
* All other third-party decorators creating and returning similar
  pseudo-callable wrapper objects... *probably*. :grimace:

Examples or it only happened in the DMT hyperplane:

```python
from beartype import beartype
from jax import (
    jit,
    numpy as jax_numpy,
)
from jaxtyping import (
    Array,
    Float,
)

@beartype  # <-- *GOOD*. @beartype goes last! patiently suffer in silence, @beartype.
@jit       # <-- *GOOD*. @jax.jit goes first! yoink.
def what_would_chat_gpt_do(
    probably_hallucinate_everything: Float[Array, '']) -> Float[Array, '']:

    # One-liner: "Do what I say, not what I code."
    return probably_hallucinate_everything + 1

assert what_would_chat_gpt_do(jax_numpy.array(1.0)) == jax_numpy.array(2.0)
what_would_chat_gpt_do('If this is a JAX array, we all have serious problems.')
```

...which raises the expected type-checking violation:

```python
Traceback (most recent call last):
  File "/home/leycec/tmp/mopy.py", line 22, in <module>
    what_would_chat_gpt_do('If this is a JAX array, we all have serious problems.')
  File "<@beartype(PjitFunction.__call__) at 0x7fd9afc96140>", line 29, in __call__
beartype.roar.BeartypeCallHintParamViolation: Object
PjitFunction.__call__() parameter probably_hallucinate_everything='If
this is a JAX array, we all have serious problems.' violates type hint
<class 'jaxtyping.Float[Array, '']'>, as str 'If this is a JAX array, we
all have serious problems.' not instance of <protocol
"jaxtyping.Float[Array, '']">.
```

**kk.** That was only one example. Just pretend we repeated that *ad
naseum* with `@jax.jit` replaced by `@equinox.filter_jit`. Things are
getting boring here. In a desperate bid to stay awake, pretend we did
more than we did.

The perspicacious user may now be thinking: "WAIT. What is a
`PjitFunction.__call__()`? That's ambiguous and means less than my cat
licking itself. Your type-checking violation message sucks, huh?"

You're *not* wrong. But we're tired. At least @beartype works now for
various definitions of "works." If you just hit this ambiguous
type-checking violation message in your workflow and want @beartype to
justifiably do something about it, bang on our issue tracker until the
cats start squalling and biting @leycec in the face. *Works every time.*

Beartype: we broke our sanity for your security. (*Bland land, man!*)
  • Loading branch information
leycec committed May 3, 2024
1 parent 8fc6e2a commit 1b3b589
Show file tree
Hide file tree
Showing 6 changed files with 406 additions and 123 deletions.
95 changes: 66 additions & 29 deletions beartype/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,26 +414,24 @@ def _convert_version_str_to_tuple(version_str: str): # -> _Tuple[int, ...]:
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

LIBS_TESTTIME_OPTIONAL = (
# Required by optional Equinox-specific integration tests.
'equinox',
# ....................{ DOCOS }....................
# Documentation-centric optional test-time dependencies.

# Require a reasonably recent version of mypy known to behave well. Less
# recent versions are significantly deficient with respect to error
# reporting and *MUST* thus be blacklisted.
# Required by optional Sphinx-specific integration tests.
#
# Note that PyPy currently fails to support mypy. See also this official
# documentation discussing this regrettable incompatibility:
# https://mypy.readthedocs.io/en/stable/faq.html#does-it-run-on-pypy
'mypy >=0.800; platform_python_implementation != "PyPy"',
# Note that Sphinx currently provokes unrelated test failures under Python
# 3.7 with obscure deprecation warnings. Since *ALL* of this only applies
# to Python 3.7, we crudely circumvent this nonsense by simply avoiding
# installing Sphinx under Python 3.7. The exception resembles:
# FAILED
# ../../../beartype_test/a00_unit/a20_util/test_utilobject.py::test_is_object_hashable
# - beartype.roar.BeartypeModuleUnimportableWarning: Ignoring module
# "pkg_resources.__init__" importation exception DeprecationWarning:
# Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`.
'sphinx; python_version >= "3.8.0"',

#FIXME: Let's avoid attempting to remotely compile with nuitka under GitHub
#Actions-hosted continuous integration (CI) for the moment. Doing so is
#non-trivial enough under local testing workflows. *sigh*
# Require a reasonably recent version of nuitka if the current platform is a
# Linux distribution *AND* the active Python interpreter targets Python >=
# 3.8. For questionable reasons best ignored, nuitka fails to compile
# beartype under Python <= 3.7.
# 'nuitka >=1.2.6; sys_platform == "linux" and python_version >= "3.8.0"',
# ....................{ SCIENCE ~ data }....................
# Data science-centric optional test-time dependencies.

#FIXME: Consider dropping the 'and platform_python_implementation != "PyPy"'
#clause now that "tox.ini" installs NumPy wheels from a third-party vendor
Expand Down Expand Up @@ -470,31 +468,70 @@ def _convert_version_str_to_tuple(version_str: str): # -> _Tuple[int, ...]:
# Required by optional Pandera-specific integration tests.
'pandera',

# Required by optional Sphinx-specific integration tests.
#
# Note that Sphinx currently provokes unrelated test failures under Python
# 3.7 with obscure deprecation warnings. Since *ALL* of this only applies
# to Python 3.7, we crudely circumvent this nonsense by simply avoiding
# installing Sphinx under Python 3.7. The exception resembles:
# FAILED
# ../../../beartype_test/a00_unit/a20_util/test_utilobject.py::test_is_object_hashable
# - beartype.roar.BeartypeModuleUnimportableWarning: Ignoring module
# "pkg_resources.__init__" importation exception DeprecationWarning:
# Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`.
'sphinx; python_version >= "3.8.0"',
# ....................{ SCIENCE ~ data : ml }....................
# Machine learning-centric optional test-time dependencies. These
# dependencies are well-known to be extremely non-trivial to install,
# typically due to conditionally depending on low-level C(++)-driven
# hardware GPU and TPU compute APIs (e.g., Nvidia CUDA, AMD OpenCL). To
# improve the likelihood of success on both local and remote workflows,
# these dependencies are intentionally confined to Linux.

# Required by optional Equinox-specific integration tests. Note that Equinox
# requires JAX.
'equinox; sys_platform == "linux"',

# Required by optional JAX-specific integration tests and JAX-dependent
# packages (e.g., Equinox). Note that JAX *MUST* be installed with one or
# more subscripted extras. Omitting extras installs only the high-level
# pure-Python "jax" package *WITHOUT* also installing a low-level
# hardware-specific variant of the typically C-based "jaxlib" package, which
# results in the "jax" package being unimportable and thus non-working.
# In this case, specifying the "cpu" extra also installs a low-level
# CPU-specific variant of the typically C-based "jaxlib" package. Since
# GitHub Actions-based continuous integration (CI) workflows are unlikely to
# reliably provide GPU or TPU compute hardware or APIs, the only safe and
# reliable alternative is CPU-specific.
'jax[cpu]; sys_platform == "linux"',

# Required by optional JAX- and Equinox-specific integration tests.
'jaxtyping; sys_platform == "linux"',

#FIXME: Temporarily disabled for sanity.
# Required by optional PyTorch-specific integration tests.
#
# Note that PyTorch has yet to release a Python >= 3.12-compatible version.
# 'torch; python_version < "3.12.0"',

# ....................{ TESTING }....................
# Testing-centric optional test-time dependencies.

# ....................{ TYPING }....................
# Typing-centric optional test-time dependencies.

# Require a reasonably recent version of mypy known to behave well. Less
# recent versions are significantly deficient with respect to error
# reporting and *MUST* thus be blacklisted.
#
# Note that PyPy currently fails to support mypy. See also this official
# documentation discussing this regrettable incompatibility:
# https://mypy.readthedocs.io/en/stable/faq.html#does-it-run-on-pypy
'mypy >=0.800; platform_python_implementation != "PyPy"',

# Required to exercise third-party backports of type hint factories
# published by the standard "typing" module under newer versions of Python.
(
f'typing-extensions >='
f'{_LIB_RUNTIME_OPTIONAL_VERSION_MINIMUM_TYPING_EXTENSIONS}'
),

#FIXME: Let's avoid attempting to remotely compile with nuitka under GitHub
#Actions-hosted continuous integration (CI) for the moment. Doing so is
#non-trivial enough under local testing workflows. *sigh*
# Require a reasonably recent version of nuitka if the current platform is a
# Linux distribution *AND* the active Python interpreter targets Python >=
# 3.8. For questionable reasons best ignored, nuitka fails to compile
# beartype under Python <= 3.7.
# 'nuitka >=1.2.6; sys_platform == "linux" and python_version >= "3.8.0"',
)
'''
**Optional developer test-time package dependencies** (i.e., dependencies
Expand Down
Empty file.
213 changes: 213 additions & 0 deletions beartype_test/a90_func/z90_lib/a80_jax/test_equinox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
#!/usr/bin/env python3
# --------------------( LICENSE )--------------------
# Copyright (c) 2014-2024 Beartype authors.
# See "LICENSE" for further details.

'''
Project-wide **Equinox integration tests.**
This submodule functionally tests the :mod:`beartype` package against the
third-party :mod:`equinox` package.
'''

# ....................{ IMPORTS }....................
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# WARNING: To raise human-readable test errors, avoid importing from
# package-specific submodules at module scope.
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# WARNING: To avoid inscrutable issues with previously run unit tests requiring
# forked subprocess isolation, avoid attempting to conditionally skip
# JAX-dependent integration tests with standard mark decorators: e.g.,
# @skip_unless_package('jax') # <-- *NEVER DO THIS* srsly. never.
# @skip_unless_package('jaxtyping') # <-- *NEVER DO THIS EITHER* bad is bad
#
# Why? Because even the mere act of attempting to decide whether JAX is
# importable at early test collection time causes the first unit test isolated
# to a forked subprocess to raise the following suspicious exception:
# E pytest.PytestUnraisableExceptionWarning: Exception ignored in:
# <function _at_fork at 0x7f3865738310>
# E Traceback (most recent call last):
# E File "/home/leycec/py/conda/envs/ionyou_dev/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 112, in _at_fork
# E warnings.warn(
# E RuntimeWarning: os.fork() was called. os.fork() is incompatible
# with multithreaded code, and JAX is multithreaded, so this will
# likely lead to a deadlock.
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

# ....................{ TESTS }....................
def test_equinox_filter_jit() -> None:
'''
Functional test validating that the :mod:`beartype` package successfully
type-checks callables decorated by the third-party
:func:`equinox.filter_jit` decorator annotated by type hints by the
third-party :mod:`jaxtyping` package.
See Also
--------
https://github.com/beartype/beartype/issues/368
Issue strongly inspiring this functional test.
'''

# ....................{ IMPORTS ~ early }....................
# Defer test-specific imports.
#
# Note that JAX-dependent packages are *NOT* safely importable from here.
from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation
from beartype._util.module.utilmodtest import is_package
from pytest import raises

#FIXME: *EVEN THIS ISN"T SAFE.* Any importation whatsoever from JAX is
#dangerous and *MUST* be isolated to a subprocess. Honestly, what a pain.
#See similar logic in "test_jax" also requiring a similar resolution.
# If any requisite JAX package is unimportable, silently reduce to a noop.
if not (
is_package('equinox') and
is_package('jax') and
is_package('jaxtyping')
):
return
# Else, all requisite JAX packages is importable.

# ....................{ IMPORTS ~ late }....................
# Defer JAX-dependent imports.
from equinox import filter_jit
from jax import numpy as jax_numpy
from jaxtyping import (
Array,
Float,
)

# ....................{ LOCALS }....................
# Type hint matching a JAX array of floating-point numbers.
JaxArrayOfFloats = Float[Array, '']

# JAX array of arbitrary floating-point numbers.
of_those_beloved_eyes = jax_numpy.array(1.0)

# ....................{ CALLABLES }....................
@beartype
@filter_jit
def as_if_their_genii(
were_the_ministers: JaxArrayOfFloats) -> JaxArrayOfFloats:
'''
Arbitrary callable decorated first by :func:`equinox.filter_jit` and
then by :func:`beartype.beartype`, exercising a well-known edge case.
'''

# Do it because it feels good, one-liner.
return were_the_ministers + 1


@beartype
@filter_jit
def appointed_to_conduct_him(
to_the_light: JaxArrayOfFloats) -> JaxArrayOfFloats:
'''
Arbitrary callable decorated first by :func:`equinox.filter_jit` and
then by :func:`beartype.beartype`, exercising a well-known edge case.
Note that fully exercising this edge case requires defining not merely
one but *TWO* callables decorated in this order. Before this issue was
resolved, :func:`beartype.beartype` *literally* swapped the code objects
of these two callables. Believe it or not, it's beartype! O_o
'''

# Do it because it feels right, one-liner.
return to_the_light - 1

# ....................{ PASS }....................
# Assert that these callables return the expected values when passed valid
# parameters satisfying the type hints annotating these callables.
assert as_if_their_genii(of_those_beloved_eyes) == (
jax_numpy.array(2.0))
assert appointed_to_conduct_him(of_those_beloved_eyes) == (
jax_numpy.array(0.0))

# ....................{ FAIL }....................
# Assert that these callables raise the expected exceptions when passed
# invalid parameters violating the type hints annotating these callables.
with raises(BeartypeCallHintParamViolation):
as_if_their_genii('Appointed to conduct him to the light')
with raises(BeartypeCallHintParamViolation):
appointed_to_conduct_him('Of those belovèd eyes, the Poet sate')


def test_equinox_module_subclass() -> None:
'''
Functional test validating that the :mod:`beartype` package successfully
type-checks subclasses of superclasses defined by the third-party Equinox
package annotated by type hints by the third-party :mod:`jaxtyping` package.
See Also
--------
https://github.com/patrick-kidger/equinox/issues/584
Upstream Equinox issue strongly inspiring this functional test.
'''

# ....................{ IMPORTS ~ early }....................
# Defer test-specific imports.
#
# Note that JAX-dependent packages are *NOT* safely importable from here.
from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation
from beartype._util.module.utilmodtest import is_package
from pytest import raises

#FIXME: *EVEN THIS ISN"T SAFE.* Any importation whatsoever from JAX is
#dangerous and *MUST* be isolated to a subprocess. Honestly, what a pain.
#See similar logic in "test_jax" also requiring a similar resolution.
# If any requisite JAX package is unimportable, silently reduce to a noop.
if not (
is_package('equinox') and
is_package('jax') and
is_package('jaxtyping')
):
return
# Else, all requisite JAX packages is importable.

# ....................{ IMPORTS ~ late }....................
# Defer JAX-dependent imports.
from equinox import Module
from jax import numpy as jax_numpy
from jaxtyping import (
Array,
Float,
)

# ....................{ CLASSES }....................
@beartype
class EquinoxModule(Module):
'''
Arbitrary subclass of the :class:`equinox.Module` superclass decorated
by the :func:`beartype.beartype` decorator.
'''

# ....................{ CLASS VARS }....................
float_array: Float[Array, '']
'''
Arbitrary class variable annotated by a :mod:`jaxtyping` type hint.
'''

# ....................{ METHODS }....................
def munge_array(self, python_bool: bool) -> Float[Array, '']:
'''
Arbitrary method accepting a trivial object to be type-checked.
'''

# Arbitrary one-liner is arbitrary.
return self.float_array + 1

# ....................{ LOCALS }....................
# JAX-based NumPy array containing arbitrary data.
jax_array = jax_numpy.array(1.0)

# Arbitrary instance of the above Equinox subclass.
equinox_module = EquinoxModule(jax_array)

# ....................{ FAIL }....................
# Assert that this @beartype-decorated method of this Equinox instance
# raises the expected type-checking violation exception when passed an
# invalid parameter violating the type hint annotating that parameter.
with raises(BeartypeCallHintParamViolation):
equinox_module.munge_array('A string is not a boolean.')
Loading

0 comments on commit 1b3b589

Please sign in to comment.