-
-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
@bearytype
+ __call__()
+ __wrapped__
x 7.
This commit is the last in a commit chain generalizing the `@beartype` decorator to support **pseudo-callable wrapper objects** (i.e., objects defining both the `__call__()` and `__wrapped__` dunder attributes), resolving feature request #368 kindly submitted by @danielward27 (Daniel Ward). Specifically, this commit: * Adds `equinox`, `jax[cpu]`, and `jaxtyping` as optional test-time dependencies to GitHub Actions-based continuous integration (CI) workflows. For robustness, installation of these dependencies is currently confined to the platform most likely to fully support these dependencies with minimal pain and anguishing heartache: *Linux.* * Adds integration tests explicitly exercising that the `@beartype` decorator is now order-invariant with respect to the third-party `@equinox.filter_jit` *and* `@jax.jit` decorators – both of which dynamically generate pseudo-callable pseudo-wrapper objects. The `@beartype` decorator may now be chained (i.e., listed) either below or above the third-party `@equinox.filter_jit` and `@jax.jit` decorators. Since `beartype.claw` import hooks (e.g., `beartype.claw.beartype_this_package()`) forcefully chain `@beartype` above all other decorators, these hooks now transparently support: * The third-party `@equinox.filter_jit` decorator. * The third-party `@jax.jit` decorator. * All other third-party decorators creating and returning similar pseudo-callable wrapper objects... *probably*. :grimace: Examples or it only happened in the DMT hyperplane: ```python from beartype import beartype from jax import ( jit, numpy as jax_numpy, ) from jaxtyping import ( Array, Float, ) @beartype # <-- *GOOD*. @beartype goes last! patiently suffer in silence, @beartype. @jit # <-- *GOOD*. @jax.jit goes first! yoink. def what_would_chat_gpt_do( probably_hallucinate_everything: Float[Array, '']) -> Float[Array, '']: # One-liner: "Do what I say, not what I code." return probably_hallucinate_everything + 1 assert what_would_chat_gpt_do(jax_numpy.array(1.0)) == jax_numpy.array(2.0) what_would_chat_gpt_do('If this is a JAX array, we all have serious problems.') ``` ...which raises the expected type-checking violation: ```python Traceback (most recent call last): File "/home/leycec/tmp/mopy.py", line 22, in <module> what_would_chat_gpt_do('If this is a JAX array, we all have serious problems.') File "<@beartype(PjitFunction.__call__) at 0x7fd9afc96140>", line 29, in __call__ beartype.roar.BeartypeCallHintParamViolation: Object PjitFunction.__call__() parameter probably_hallucinate_everything='If this is a JAX array, we all have serious problems.' violates type hint <class 'jaxtyping.Float[Array, '']'>, as str 'If this is a JAX array, we all have serious problems.' not instance of <protocol "jaxtyping.Float[Array, '']">. ``` **kk.** That was only one example. Just pretend we repeated that *ad naseum* with `@jax.jit` replaced by `@equinox.filter_jit`. Things are getting boring here. In a desperate bid to stay awake, pretend we did more than we did. The perspicacious user may now be thinking: "WAIT. What is a `PjitFunction.__call__()`? That's ambiguous and means less than my cat licking itself. Your type-checking violation message sucks, huh?" You're *not* wrong. But we're tired. At least @beartype works now for various definitions of "works." If you just hit this ambiguous type-checking violation message in your workflow and want @beartype to justifiably do something about it, bang on our issue tracker until the cats start squalling and biting @leycec in the face. *Works every time.* Beartype: we broke our sanity for your security. (*Bland land, man!*)
- Loading branch information
Showing
6 changed files
with
406 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
#!/usr/bin/env python3 | ||
# --------------------( LICENSE )-------------------- | ||
# Copyright (c) 2014-2024 Beartype authors. | ||
# See "LICENSE" for further details. | ||
|
||
''' | ||
Project-wide **Equinox integration tests.** | ||
This submodule functionally tests the :mod:`beartype` package against the | ||
third-party :mod:`equinox` package. | ||
''' | ||
|
||
# ....................{ IMPORTS }.................... | ||
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! | ||
# WARNING: To raise human-readable test errors, avoid importing from | ||
# package-specific submodules at module scope. | ||
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! | ||
# WARNING: To avoid inscrutable issues with previously run unit tests requiring | ||
# forked subprocess isolation, avoid attempting to conditionally skip | ||
# JAX-dependent integration tests with standard mark decorators: e.g., | ||
# @skip_unless_package('jax') # <-- *NEVER DO THIS* srsly. never. | ||
# @skip_unless_package('jaxtyping') # <-- *NEVER DO THIS EITHER* bad is bad | ||
# | ||
# Why? Because even the mere act of attempting to decide whether JAX is | ||
# importable at early test collection time causes the first unit test isolated | ||
# to a forked subprocess to raise the following suspicious exception: | ||
# E pytest.PytestUnraisableExceptionWarning: Exception ignored in: | ||
# <function _at_fork at 0x7f3865738310> | ||
# E Traceback (most recent call last): | ||
# E File "/home/leycec/py/conda/envs/ionyou_dev/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 112, in _at_fork | ||
# E warnings.warn( | ||
# E RuntimeWarning: os.fork() was called. os.fork() is incompatible | ||
# with multithreaded code, and JAX is multithreaded, so this will | ||
# likely lead to a deadlock. | ||
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! | ||
|
||
# ....................{ TESTS }.................... | ||
def test_equinox_filter_jit() -> None: | ||
''' | ||
Functional test validating that the :mod:`beartype` package successfully | ||
type-checks callables decorated by the third-party | ||
:func:`equinox.filter_jit` decorator annotated by type hints by the | ||
third-party :mod:`jaxtyping` package. | ||
See Also | ||
-------- | ||
https://github.com/beartype/beartype/issues/368 | ||
Issue strongly inspiring this functional test. | ||
''' | ||
|
||
# ....................{ IMPORTS ~ early }.................... | ||
# Defer test-specific imports. | ||
# | ||
# Note that JAX-dependent packages are *NOT* safely importable from here. | ||
from beartype import beartype | ||
from beartype.roar import BeartypeCallHintParamViolation | ||
from beartype._util.module.utilmodtest import is_package | ||
from pytest import raises | ||
|
||
#FIXME: *EVEN THIS ISN"T SAFE.* Any importation whatsoever from JAX is | ||
#dangerous and *MUST* be isolated to a subprocess. Honestly, what a pain. | ||
#See similar logic in "test_jax" also requiring a similar resolution. | ||
# If any requisite JAX package is unimportable, silently reduce to a noop. | ||
if not ( | ||
is_package('equinox') and | ||
is_package('jax') and | ||
is_package('jaxtyping') | ||
): | ||
return | ||
# Else, all requisite JAX packages is importable. | ||
|
||
# ....................{ IMPORTS ~ late }.................... | ||
# Defer JAX-dependent imports. | ||
from equinox import filter_jit | ||
from jax import numpy as jax_numpy | ||
from jaxtyping import ( | ||
Array, | ||
Float, | ||
) | ||
|
||
# ....................{ LOCALS }.................... | ||
# Type hint matching a JAX array of floating-point numbers. | ||
JaxArrayOfFloats = Float[Array, ''] | ||
|
||
# JAX array of arbitrary floating-point numbers. | ||
of_those_beloved_eyes = jax_numpy.array(1.0) | ||
|
||
# ....................{ CALLABLES }.................... | ||
@beartype | ||
@filter_jit | ||
def as_if_their_genii( | ||
were_the_ministers: JaxArrayOfFloats) -> JaxArrayOfFloats: | ||
''' | ||
Arbitrary callable decorated first by :func:`equinox.filter_jit` and | ||
then by :func:`beartype.beartype`, exercising a well-known edge case. | ||
''' | ||
|
||
# Do it because it feels good, one-liner. | ||
return were_the_ministers + 1 | ||
|
||
|
||
@beartype | ||
@filter_jit | ||
def appointed_to_conduct_him( | ||
to_the_light: JaxArrayOfFloats) -> JaxArrayOfFloats: | ||
''' | ||
Arbitrary callable decorated first by :func:`equinox.filter_jit` and | ||
then by :func:`beartype.beartype`, exercising a well-known edge case. | ||
Note that fully exercising this edge case requires defining not merely | ||
one but *TWO* callables decorated in this order. Before this issue was | ||
resolved, :func:`beartype.beartype` *literally* swapped the code objects | ||
of these two callables. Believe it or not, it's beartype! O_o | ||
''' | ||
|
||
# Do it because it feels right, one-liner. | ||
return to_the_light - 1 | ||
|
||
# ....................{ PASS }.................... | ||
# Assert that these callables return the expected values when passed valid | ||
# parameters satisfying the type hints annotating these callables. | ||
assert as_if_their_genii(of_those_beloved_eyes) == ( | ||
jax_numpy.array(2.0)) | ||
assert appointed_to_conduct_him(of_those_beloved_eyes) == ( | ||
jax_numpy.array(0.0)) | ||
|
||
# ....................{ FAIL }.................... | ||
# Assert that these callables raise the expected exceptions when passed | ||
# invalid parameters violating the type hints annotating these callables. | ||
with raises(BeartypeCallHintParamViolation): | ||
as_if_their_genii('Appointed to conduct him to the light') | ||
with raises(BeartypeCallHintParamViolation): | ||
appointed_to_conduct_him('Of those belovèd eyes, the Poet sate') | ||
|
||
|
||
def test_equinox_module_subclass() -> None: | ||
''' | ||
Functional test validating that the :mod:`beartype` package successfully | ||
type-checks subclasses of superclasses defined by the third-party Equinox | ||
package annotated by type hints by the third-party :mod:`jaxtyping` package. | ||
See Also | ||
-------- | ||
https://github.com/patrick-kidger/equinox/issues/584 | ||
Upstream Equinox issue strongly inspiring this functional test. | ||
''' | ||
|
||
# ....................{ IMPORTS ~ early }.................... | ||
# Defer test-specific imports. | ||
# | ||
# Note that JAX-dependent packages are *NOT* safely importable from here. | ||
from beartype import beartype | ||
from beartype.roar import BeartypeCallHintParamViolation | ||
from beartype._util.module.utilmodtest import is_package | ||
from pytest import raises | ||
|
||
#FIXME: *EVEN THIS ISN"T SAFE.* Any importation whatsoever from JAX is | ||
#dangerous and *MUST* be isolated to a subprocess. Honestly, what a pain. | ||
#See similar logic in "test_jax" also requiring a similar resolution. | ||
# If any requisite JAX package is unimportable, silently reduce to a noop. | ||
if not ( | ||
is_package('equinox') and | ||
is_package('jax') and | ||
is_package('jaxtyping') | ||
): | ||
return | ||
# Else, all requisite JAX packages is importable. | ||
|
||
# ....................{ IMPORTS ~ late }.................... | ||
# Defer JAX-dependent imports. | ||
from equinox import Module | ||
from jax import numpy as jax_numpy | ||
from jaxtyping import ( | ||
Array, | ||
Float, | ||
) | ||
|
||
# ....................{ CLASSES }.................... | ||
@beartype | ||
class EquinoxModule(Module): | ||
''' | ||
Arbitrary subclass of the :class:`equinox.Module` superclass decorated | ||
by the :func:`beartype.beartype` decorator. | ||
''' | ||
|
||
# ....................{ CLASS VARS }.................... | ||
float_array: Float[Array, ''] | ||
''' | ||
Arbitrary class variable annotated by a :mod:`jaxtyping` type hint. | ||
''' | ||
|
||
# ....................{ METHODS }.................... | ||
def munge_array(self, python_bool: bool) -> Float[Array, '']: | ||
''' | ||
Arbitrary method accepting a trivial object to be type-checked. | ||
''' | ||
|
||
# Arbitrary one-liner is arbitrary. | ||
return self.float_array + 1 | ||
|
||
# ....................{ LOCALS }.................... | ||
# JAX-based NumPy array containing arbitrary data. | ||
jax_array = jax_numpy.array(1.0) | ||
|
||
# Arbitrary instance of the above Equinox subclass. | ||
equinox_module = EquinoxModule(jax_array) | ||
|
||
# ....................{ FAIL }.................... | ||
# Assert that this @beartype-decorated method of this Equinox instance | ||
# raises the expected type-checking violation exception when passed an | ||
# invalid parameter violating the type hint annotating that parameter. | ||
with raises(BeartypeCallHintParamViolation): | ||
equinox_module.munge_array('A string is not a boolean.') |
Oops, something went wrong.