jaxtyping not working with threads #23

Closed
ayaka14732 opened this issue Sep 16, 2022 · 5 comments · Fixed by #27
Comments

@ayaka14732

Type checking fails when it is performed in another thread:

import jax.numpy as np
from jaxtyping import Array, Float as F, jaxtyped
import threading
from typeguard import typechecked as typechecker

@jaxtyped
@typechecker
def add(x: F[Array, 'a b'], y: F[Array, 'a b']) -> F[Array, 'a b']:
    return x + y

def run():
    a = np.array([[1., 2.]])
    b = np.array([[2., 3.]])
    c = add(a, b)
    print(c)

thread = threading.Thread(target=run)
thread.start()
thread.join()

Error:

Exception in thread Thread-1 (run):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ayaka/Projects/test/1.py", line 14, in run
    c = add(a, b)
  File "/home/ayaka/Projects/test/venv/lib/python3.10/site-packages/jaxtyping/decorator.py", line 31, in wrapper
    storage.memo_stack.append(({}, {}, {}))
AttributeError: '_thread._local' object has no attribute 'memo_stack'

Environment:

  • Python 3.10.6
  • jax 0.3.17
  • jaxlib 0.3.15+cuda11.cudnn82
  • jaxtyping 0.2.1
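Below is a stdlib-only sketch of the likely failure mode. Judging from the traceback, it assumes that jaxtyping keeps its memo stack on a module-level `threading.local` whose `memo_stack` attribute is set once, at import time. The name `storage` mirrors the traceback; `worker` and `errors` are hypothetical. Attributes set on a `threading.local` are visible only in the thread that set them, so a fresh thread sees no `memo_stack` at all:

```python
import threading

# Hypothetical stand-in for the module-level storage named in the traceback.
# Attributes set on a threading.local here exist only in the thread running
# this code (normally the main thread, at import time).
storage = threading.local()
storage.memo_stack = []


def worker(errors):
    try:
        # In a new thread the attribute was never set, so this raises.
        storage.memo_stack.append(({}, {}, {}))
    except AttributeError as e:
        errors.append(e)


errors = []
thread = threading.Thread(target=worker, args=(errors,))
thread.start()
thread.join()

print(type(errors[0]).__name__)  # AttributeError
print(storage.memo_stack)  # [] : the main thread's copy is untouched
```

The main thread never hits this, which would explain why the decorated function works when it is called outside a thread.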
ayaka14732 added a commit to ayaka14732/bart-base-jax that referenced this issue Sep 16, 2022
patrick-kidger/jaxtyping#23

`@jaxtyped` -> `# @jaxtyped`
`@typechecker` -> `# @typechecker`
`^( +)check_type` -> `$1### check_type`
@patrick-kidger
Owner

Should be fixed in #24!

@wookayin

wookayin commented Sep 20, 2022

I'm still hitting a similar issue with jaxtyping==0.2.2 and jaxtyping==0.2.3, but it is raised from a different place. Reverting to jaxtyping==0.2.0 makes the problem go away.

Client code:

assert isinstance(observations, types.FloatArray)

Where it crashes:

File "...python3.10/site-packages/jaxtyping/array_types.py", line 153, in __instancecheck__
    if len(storage.memo_stack) == 0:

@patrick-kidger
Owner

Can you provide a full MWE that I can run?

@wookayin

It happens with any isinstance call made inside a thread.

Here is a minimal working example that fails with jaxtyping==0.2.3:

import jax.numpy as jnp
from jaxtyping import Array, Float as F
import threading
from typeguard import typechecked as typechecker

FloatArray = F[jnp.ndarray, '...']

def run():
    a = jnp.array([[1., 2.]])
    assert isinstance(a, FloatArray)

thread = threading.Thread(target=run)
thread.start()
thread.join()

Error:

  File "test.py", line 10, in run
    assert isinstance(a, FloatArray)
  File ".../jaxtyping/array_types.py", line 153, in __instancecheck__
    if len(storage.memo_stack) == 0:
AttributeError: '_thread._local' object has no attribute 'memo_stack'
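Both call sites in the tracebacks read the same per-thread attribute: the `append` in `decorator.py` and the `len(...)` check in `__instancecheck__`. One option is therefore to route every access through a helper that creates the stack on first use. The sketch below assumes a module-level `threading.local` named `storage`, as the tracebacks suggest. `get_memo_stack` and `is_outermost_check` are hypothetical names, not jaxtyping API:

```python
import threading

storage = threading.local()


def get_memo_stack():
    """Return this thread's memo stack, creating it on first use."""
    # Hypothetical helper: getattr with a default avoids AttributeError in
    # threads that have never touched `storage` before.
    stack = getattr(storage, "memo_stack", None)
    if stack is None:
        stack = storage.memo_stack = []
    return stack


def is_outermost_check():
    # Mirrors the failing line `if len(storage.memo_stack) == 0:` from the
    # traceback, but goes through the lazy helper instead.
    return len(get_memo_stack()) == 0


results = []
thread = threading.Thread(target=lambda: results.append(is_outermost_check()))
thread.start()
thread.join()
print(results)  # [True]
```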

@patrick-kidger
Owner

Thanks! This should be fixed in #27.
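Here is one common pattern for this kind of fix, sketched under the assumption that the storage is a `threading.local`. It is not necessarily how #27 implements the fix, and `_Storage` and `push_and_report` are hypothetical names. If you subclass `threading.local` and define `__init__`, Python runs that initializer separately in each thread the first time the thread touches the object:

```python
import threading


class _Storage(threading.local):
    # threading.local calls __init__ the first time each thread uses the
    # object, so every thread (not just the importing one) starts with its
    # own empty memo_stack.
    def __init__(self):
        super().__init__()
        self.memo_stack = []


storage = _Storage()


def push_and_report(out):
    storage.memo_stack.append(({}, {}, {}))
    out.append(len(storage.memo_stack))


out = []
threads = [threading.Thread(target=push_and_report, args=(out,)) for _ in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(out)  # [1, 1, 1]: each thread saw its own fresh stack
print(storage.memo_stack)  # []: the main thread's stack is separate
```

Compared with a lazy `getattr` helper at each call site, the subclass keeps the initialization in one place, so a newly added call site cannot forget it.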
