jaxtyping not working with threads #23
patrick-kidger/jaxtyping#23: `@jaxtyped` -> `# @jaxtyped`, `@typechecker` -> `# @typechecker`, `^( +)check_type` -> `$1### check_type`
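The cross-reference above lists three find/replace rules that comment out the `@jaxtyped` and `@typechecker` decorators and any indented `check_type` calls, which effectively turns the runtime checks off. A minimal sketch of applying those rules with Python's `re` module; the helper name `disable_runtime_checks` is hypothetical, and note that Python's replacement syntax writes the captured group as `\1` where the rule uses `$1`:

```python
import re

# The three find/replace rules from the cross-reference, in Python syntax.
RULES = [
    (r"@jaxtyped", r"# @jaxtyped"),
    (r"@typechecker", r"# @typechecker"),
    (r"^( +)check_type", r"\1### check_type"),  # keep the original indentation
]


def disable_runtime_checks(source: str) -> str:
    """Hypothetical helper: comment out decorators and check_type calls."""
    for pattern, replacement in RULES:
        # MULTILINE makes ^ match at the start of every line, not just the string.
        source = re.sub(pattern, replacement, source, flags=re.MULTILINE)
    return source
```

Running it twice would comment the decorator lines a second time (`# # @jaxtyped`), so it is meant to be applied once per file.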
Should be fixed in #24!
I'm still hitting a similar issue with jaxtyping==0.2.2 and jaxtyping==0.2.3, raised from a similar but different place. Reverting to jaxtyping==0.2.0 works fine. Client code:
Where it crashes:
Can you provide a full MWE that I can run?
It's just making any
Here is a minimal working example that fails with jaxtyping==0.2.3:

```python
import jax.numpy as jnp
from jaxtyping import Array, Float as F
import threading
from typeguard import typechecked as typechecker

FloatArray = F[jnp.ndarray, '...']


def run():
    a = jnp.array([[1., 2.]])
    # The jaxtyping isinstance check, executed here inside a worker thread.
    assert isinstance(a, FloatArray)


thread = threading.Thread(target=run)
thread.start()
thread.join()
```

Error:
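In the example, `thread.join()` returns normally even when `run` raises: the worker's exception goes to `threading.excepthook`, which by default only prints the traceback to stderr. To turn a reproduction like this into a regression test, a sketch of a hypothetical helper, `call_in_worker_thread`, that runs a callable in another thread and re-raises its exception in the caller:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, TypeVar

T = TypeVar("T")


def call_in_worker_thread(fn: Callable[[], T]) -> T:
    """Run fn in a separate thread and return its result.

    Unlike threading.Thread + join(), Future.result() re-raises any
    exception raised inside the worker, so a failing check surfaces in
    the calling thread (e.g. inside a test) instead of only being printed.
    """
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(fn).result()
```

With this, `call_in_worker_thread(run)` would raise the same error the worker hits, rather than letting the script exit cleanly.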
Thanks! This should be fixed in #27.
Type checking cannot be performed in another thread:
Error:
Environment:
- Python: 3.10.6
- jax: 0.3.17
- jaxlib: 0.3.15+cuda11.cudnn82
- jaxtyping: 0.2.1