Can we check statement-level annotations? #7
Comments
So checking intermediate annotations like this is really the job of a runtime type checker (typeguard/beartype) -- just like how checking the argument/return annotations is already the job of a runtime type checker. In either case all jaxtyping does is provide
Probably what would be needed here would be a decorator that parses the abstract syntax tree for a function, detects annotations, and then inserts manual
import ast
import inspect
import beartype
import jaxtyping

def check_intermediate_annotations(fn):
    # Bind the parsed tree to a name other than `ast`, so the module is not shadowed.
    tree = ast.parse(inspect.getsource(fn))
    # rewrite tree using ast.NodeVisitor and ast.NodeTransformer
    # (the decorator list should also be stripped here, or exec re-applies it)
    namespace = {}
    # A `def` is a statement, so it has to be exec'd rather than eval'd.
    exec(ast.unparse(tree), fn.__globals__, namespace)
    return namespace[fn.__name__]

@jaxtyping.jaxtyped
@beartype.beartype
@check_intermediate_annotations
def bar(y):
    x: LxDk = foo()
    return x
Although I note that this kind of source-file parsing and unparsing is a bit fraught with edge cases (e.g. it won't work in a REPL). Unfortunately Python just doesn't provide a good way to handle this kind of thing. Off the top of my head I don't know of a runtime type checker that does this, though. (CC @leycec for interest.)
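The "rewrite ast" step in the sketch above could look roughly like the following. This is a minimal, stdlib-only illustration rather than what jaxtyping or beartype actually do: the names `AnnAssignChecker`, `_check`, and `compile_with_checks` are hypothetical, the check is a plain `isinstance` stand-in for a real runtime type checker, and the function is compiled from a source string instead of `inspect.getsource` to sidestep the REPL problem mentioned above.

```python
import ast


def _check(value, hint, name):
    # Stand-in checker (assumption): plain isinstance. A real implementation
    # would delegate to a runtime type checker instead.
    if not isinstance(value, hint):
        raise TypeError(f"{name} = {value!r} is not an instance of {hint!r}")


class AnnAssignChecker(ast.NodeTransformer):
    """Insert `_check(name, hint, "name")` after each `name: hint = value`."""

    def visit_AnnAssign(self, node):
        # Skip bare declarations (`x: int`) and non-simple targets (`a.b: int = 1`).
        if node.value is None or not isinstance(node.target, ast.Name):
            return node
        call = ast.Call(
            func=ast.Name(id="_check", ctx=ast.Load()),
            args=[
                ast.Name(id=node.target.id, ctx=ast.Load()),
                node.annotation,
                ast.Constant(value=node.target.id),
            ],
            keywords=[],
        )
        # Returning a list replaces the one statement with two.
        return [node, ast.copy_location(ast.Expr(value=call), node)]


def compile_with_checks(source, func_name, namespace=None):
    """Compile `source` with checks inserted and return the named function."""
    tree = AnnAssignChecker().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    ns = {"_check": _check}
    ns.update(namespace or {})
    exec(compile(tree, "<checked>", "exec"), ns)
    return ns[func_name]


SOURCE = """
def scale(x, y):
    z: float = x * y
    return z
"""

scale = compile_with_checks(SOURCE, "scale")
scale(2, 1.5)  # passes: 3.0 is a float
# scale(2, 3) would raise TypeError, because 6 is an int rather than a float.
```

Because the stand-in uses `isinstance`, it only handles hints that are plain classes; subscripted hints such as `List[int]` would need a real runtime type checker in place of `_check`.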
Fascinating feature request intensifies. Thanks so much for pinging me into the fray, @patrick-kidger. Exactly as you suspect, no actively maintained runtime type-checker that I know of currently performs static type-checking at runtime. Sad cat emoji is sad. 😿
That said, @beartype does have an open feature request encouraging us to eventually do this. This is fun stuff, because it's hard stuff. Actually, it's not too hard to naively perform static type-checking at runtime by combining the fearsome powers of import hooks + abstract syntax tree (AST) inspection. But it's really hard to do so without destroying runtime performance, especially in pure Python. Extremely aggressive on-disk caching (e.g., like the sordid pile of JSON files that
I Don't Like What I'm Hearing
Until then, @beartype provides a reasonably well-documented procedural API for type-checking arbitrary things against arbitrary type hints at any time:
query : LxDk = linear(head.query, t1)
key : LxDk = linear(head.key, t1)
# Runtime type-check everything above.
from beartype.abby import die_if_unbearable
die_if_unbearable(query, LxDk)
die_if_unbearable(key, LxDk)
If that's a bit too much egregious boilerplate, consider wrapping the above calls to
from beartype.abby import die_if_unbearable

# "hint" is keyword-only: no parameter may follow **kwargs in a signature.
def linear_typed(*args, hint: object, **kwargs):
    '''
    Linear JAX array runtime type-checked by the passed type hint.
    '''
    linear_array = linear(*args, **kwargs)
    die_if_unbearable(linear_array, hint)
    return linear_array

query : LxDk = linear_typed(head.query, t1, hint=LxDk)
key : LxDk = linear_typed(head.key, t1, hint=LxDk)
Technically, that still violates DRY a bit by duplicating the
Let's pray I actually do something and make static runtime type-checking happen, everybody. 😮‍💨
Thanks @leycec! Just to clarify on "destroying runtime performance", do you mean making it worse than
query : LxDk = linear(head.query, t1)
die_if_unbearable(query, LxDk)
key : LxDk = linear(head.key, t1)
die_if_unbearable(key, LxDk)
?
Because the above seems acceptable to me, particularly under a jax-style define-by-run scheme. And if it's not I might wrap
die_if_unbearable(query, LxDk) if rand() > 0.93
or
die_if_unbearable(query, LxDk) if (beartype_time/total_program_time < bearable_beartype_overhead or
    rand() > (beartype_time/total_program_time / bearable_beartype_overhead))
[I realise there are simplifications/corrections of the last logic, hope the sentiment is clear]
Right, so JAX sits in a lovely spot for applicability of runtime type checking, because Python is only ever being used as a metaprogramming language for XLA. In this context I wouldn't worry about the extra runtime overhead.
Agh! I should be more explicit in my jargon, especially when slinging around suspicious phrases like "destroying runtime performance." So...
query : LxDk = linear(head.query, t1)
die_if_unbearable(query, LxDk)
key : LxDk = linear(head.key, t1)
die_if_unbearable(key, LxDk)
Yes. Much, much worse than that. By "destroy runtime performance," I was instead referring to the hypothetical crippling performance burden of doing static type-checking analysis at runtime via import hooks and AST inspection. @beartype doesn't do that yet; nobody does. The practical difficulties of optimizing static analysis at runtime are a Big Reason™ why. But... someday @beartype or somebody else will go there. A runtime type-checker that efficiently performs static analysis at runtime would effectively obsolete standard static type-checkers (e.g.,
Until then, we collectively wish upon a rainbow. 🌈
Absolutely. @beartype has been profiled to be disgustingly fast. That's the whole point, really. @beartype is actually two orders of magnitude faster than even
You should never need to conditionally disable @beartype. If you do, bang on our issue tracker and we'll promptly resolve the performance regressions you are seeing. Until then, the best way to use @beartype is to just always use @beartype.
die_if_unbearable(query, LxDk) if rand() > 0.93
...heh. Probabilistic runtime type-checking. Love it! I must acknowledge cleverness when I see it. Admittedly, that also makes my eye twitch spasmodically. If you do end up profiling
Now I know. And knowledge is half the battle.
These are sweet, soothing words. Please say more relieving things like this. 😌
Thanks both for this discussion - I implemented something quick based on @patrick-kidger's spike above, and it seems to work quite nicely. Next step is to integrate with jaxtyping, but I thought I would put it out here... https://github.com/awf/awfutils#typecheck
A fairly direct copy of your suggestion above... https://github.com/awf/awfutils/blob/7359acb6528325f6770fc9c28aab86f548d22ad4/typecheck.py#L133
Extremely impressive. Your current approach is outrageously useful, but appears to currently only support
# I suspect this fails hard, but am lazy and thus did not test.
from typing import List

@typecheck
def foo(x : List[int], y : int):
    z : List[int] = x * y
    w : float = z[0] * 3.2
    return w

foo([3, 2, 1], 1.3)
Is that right? If so, that's still impressive tech for a several hundred-line decorator. I'll open up a feature request on your issue tracker to see if we can't trivially generalize that to support all (...or at least most) PEP-compliant type hints, @awf. In short, this is so good. \o/
Not really documented yet, but for anyone coming across this issue: this now exists in beartype (beartype/beartype#7 (comment))!
Indeed. As @patrick-kidger notes, our new
# In your top-level "{your_package}.__init__" submodule:
from beartype.claw import beartype_this_package
beartype_this_package()
That's it. @beartype will now type-check statement-level annotations in concert with
...yeah. Noticed that, huh? I've intentionally left
Until then, one-liners for great QA justice! 💪 🐻
One of my dreams for this package was to turn code like this
into this
where we have written
earlier in the @jaxtyped function.
But it looks as if these annotations aren't checked?
I haven't looked into how hard that might be - is it a lot of work?