You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
import jax
from jax import numpy as jnp
from jaxtyping import jaxtyped, Float, Array
from typeguard import typechecked
from jax.lax import Precision
@jaxtyped
@typechecked
def func1(
    arg1: Float[Array, "s o o"],
    arg2: Float[Array, "g o"],
    precision: Precision = Precision.HIGHEST,
) -> Float[Array, "g s"]:
    """Evaluate a quadratic form of ``arg1`` for every row of ``arg2``.

    For each row ``r`` of ``arg2`` and each leading index of ``arg1`` the
    result is ``sum_ab arg1[..., a, b] * arg2[r, a] * arg2[r, b]``; the row
    axis of ``arg2`` becomes the first axis of the output.

    Args:
        arg1: Stack of square matrices; the last two axes are contracted.
        arg2: One vector per row, applied on both sides of each matrix.
        precision: Precision forwarded to ``jnp.einsum``.

    Returns:
        Array whose first axis indexes the rows of ``arg2`` and whose
        remaining axes are the leading (batch) axes of ``arg1``.
    """
    # `arg2` is passed twice on purpose: once per contracted axis of `arg1`.
    subscripts = "...ab,ra,rb->r..."
    contracted = jnp.einsum(subscripts, arg1, arg2, arg2, precision=precision)
    return contracted
# Seed the PRNG once, then derive an independent subkey for each draw.
# Passing the same key to two sampling calls makes them consume the same
# random stream, so the resulting arrays would be correlated.
key = jax.random.PRNGKey(42)
key_arg1, key_arg2 = jax.random.split(key)
# Problem sizes used to build the example inputs.
s = 2
o = 5
g = 10
# Generate random arg1 with shape (s, o, o)
arg1 = jax.random.uniform(key_arg1, shape=(s, o, o))
# Generate random arg2 with shape (g, o)
arg2 = jax.random.uniform(key_arg2, shape=(g, o))
# Compute func1
rho = func1(arg1, arg2)
Upon running it, I am getting a weird error:
Traceback (most recent call last):
File "/Users/user/Developer/folder/subfolder/file.py", line 12, in <module>
def func1(arg1: Float[Array, 's o o'],
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_decorators.py", line 213, in typechecked
retval = instrument(target)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_decorators.py", line 54, in instrument
instrumentor.visit(module_ast)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 570, in visit_Module
self.generic_visit(node)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 494, in generic_visit
value = self.visit(value)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 644, in visit_FunctionDef
with self._use_memo(node):
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/contextlib.py", line 135, in __enter__
return next(self.gen)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 525, in _use_memo
new_memo.return_annotation = self._convert_annotation(
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 555, in _convert_annotation
new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 343, in visit
new_node = super().visit(node)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 400, in visit_Subscript
items = [self.visit(item) for item in slice_value.elts]
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 400, in <listcomp>
items = [self.visit(item) for item in slice_value.elts]
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 343, in visit
new_node = super().visit(node)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 451, in visit_Constant
expression = ast.parse(node.value, mode="eval")
File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 50, in parse
return compile(source, filename, mode, flags,
File "<unknown>", line 1
g s
^
This is on macOS (M2 Air), in a conda environment with python=3.10.
Hi, sorry to post this here, but I am not sure whether there is a better place to ask questions about jaxtyping. I have the following simple script, which follows the API documentation at https://docs.kidger.site/jaxtyping/api/runtime-type-checking/:
Upon running it, I am getting a weird error:
This is on macOS (M2 Air), in a conda environment with python=3.10. Updating typeguard to version 4.1.5 does not seem to matter.
Any suggestions about what might be going on?
Thanks a lot
The text was updated successfully, but these errors were encountered: