Simple script error #108

Closed
PabloAMC opened this issue Sep 12, 2023 · 2 comments
Labels: question (User queries)

Comments

PabloAMC commented Sep 12, 2023

Hi, sorry to post this here, but I am not sure whether there is a better place for questions about jaxtyping. I have the following simple script, which follows the API documentation at https://docs.kidger.site/jaxtyping/api/runtime-type-checking/

import jax
from jax import numpy as jnp

from jaxtyping import jaxtyped, Float, Array
from typeguard import typechecked
from jax.lax import Precision

@jaxtyped
@typechecked
def func1(arg1: Float[Array, "s o o"], 
            arg2: Float[Array, "g o"], 
            precision: Precision = Precision.HIGHEST
) -> Float[Array, "g s"]:

    return jnp.einsum("...ab,ra,rb->r...", arg1, arg2, arg2, precision=precision)

key = jax.random.PRNGKey(42)
s = 2
o = 5
g = 10
# Generate random arg1
arg1 = jax.random.uniform(key, shape=(s, o, o))
# Generate random arg2
arg2 = jax.random.uniform(key, shape=(g, o))

# Compute func1
rho = func1(arg1, arg2)

When I run it, I get a strange error while the function is being defined:

Traceback (most recent call last):
  File "/Users/user/Developer/folder/subfolder/file.py", line 12, in <module>
    def func1(arg1: Float[Array, 's o o'], 
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_decorators.py", line 213, in typechecked
    retval = instrument(target)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_decorators.py", line 54, in instrument
    instrumentor.visit(module_ast)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 570, in visit_Module
    self.generic_visit(node)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 494, in generic_visit
    value = self.visit(value)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 644, in visit_FunctionDef
    with self._use_memo(node):
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 525, in _use_memo
    new_memo.return_annotation = self._convert_annotation(
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 555, in _convert_annotation
    new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 343, in visit
    new_node = super().visit(node)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 400, in visit_Subscript
    items = [self.visit(item) for item in slice_value.elts]
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 400, in <listcomp>
    items = [self.visit(item) for item in slice_value.elts]
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 343, in visit
    new_node = super().visit(node)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/site-packages/typeguard/_transformer.py", line 451, in visit_Constant
    expression = ast.parse(node.value, mode="eval")
  File "/Users/user/miniforge3/envs/myenv/lib/python3.10/ast.py", line 50, in parse
    return compile(source, filename, mode, flags,
  File "<unknown>", line 1
    g s
      ^

This is on macOS on an M2 Air, in a conda environment with python=3.10 and the following packages:

# Name                    Version                   Build  Channel
jax                       0.4.14                   pypi_0    pypi
jaxlib                    0.4.14                   pypi_0    pypi
jaxtyping                 0.2.21                   pypi_0    pypi
typeguard                 4.0.0                    pypi_0    pypi

Updating typeguard to version 4.1.5 does not seem to make a difference.
Any suggestions about what might be going on?
Thanks a lot

@patrick-kidger
Owner

Use typeguard v2. The later releases are known to be buggy.
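
For context, the last frames of the traceback above show typeguard passing the string from the annotation to `ast.parse(..., mode="eval")`, and a jaxtyping shape string such as "g s" is not a valid Python expression. A minimal sketch that reproduces only that step with the standard library (`parses_as_expression` is a hypothetical helper name, not part of typeguard or jaxtyping):

```python
import ast


def parses_as_expression(text: str) -> bool:
    """Return True if `text` can be parsed as a single Python expression."""
    try:
        # Same call as the last typeguard frame in the traceback.
        ast.parse(text, mode="eval")
    except SyntaxError:
        return False
    return True


# Shape strings taken from the script in the question, plus a single name.
for shape in ["s o o", "g o", "g s", "g"]:
    print(repr(shape), parses_as_expression(shape))
```

Only the single-name string parses. The space-separated shapes used in the script do not, which matches the caret under `g s` in the traceback.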

patrick-kidger added the question (User queries) label on Sep 12, 2023
@PabloAMC
Author

Downgrading typeguard to v2.13.3 does indeed solve the issue, thanks a lot @patrick-kidger!
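
To keep the incompatible combination from coming back after an environment update, a script could check the installed typeguard version up front. A minimal sketch, assuming the major version number is enough to tell the releases apart; `major_version` and `typeguard_is_v2` are hypothetical helper names:

```python
from importlib.metadata import PackageNotFoundError, version


def major_version(version_string: str) -> int:
    """Extract the leading major number from a version such as '2.13.3'."""
    return int(version_string.split(".")[0])


def typeguard_is_v2(version_string: str) -> bool:
    """Return True for the 2.x series recommended in this thread."""
    return major_version(version_string) == 2


try:
    installed = version("typeguard")
except PackageNotFoundError:
    print("typeguard is not installed")
else:
    status = "ok" if typeguard_is_v2(installed) else "not a 2.x release"
    print(f"typeguard {installed}: {status}")
```

Pinning the release that worked here, for example `typeguard==2.13.3` in a requirements file, achieves the same thing at install time.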
