In [1]:
import jax
import jax.numpy as jnp

In [2]:
from jaxtyping import Float
from typing import no_type_check

In [3]:

# Activate jaxtyping's IPython extension and pick beartype as the runtime
# checker. Functions defined in later cells are then checked against their
# jaxtyping annotations without needing an explicit decorator (see the
# undecorated `bad_fun_typed` below, which raises on a dtype mismatch).
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

### Test type errors

In [4]:
@no_type_check
def bad_fun(arg: Float[jax.Array, " 1"]):
    """Return ``arg`` unchanged; opted out of runtime checking.

    The annotation asks for a float array, but ``@no_type_check`` exempts this
    function from the notebook-wide typechecker, so the int32 array passed in
    this notebook is accepted without error.
    """
    return arg

def bad_fun_typed(arg: Float[jax.Array, " 1"]):
    """Return ``arg`` unchanged; subject to the notebook-wide typechecker.

    It carries no decorator of its own. While the jaxtyping extension is
    loaded, calling it with the int32 array used in this notebook raises a
    ``TypeError`` because the annotation requires a float array.
    """
    return arg


# An int32 array deliberately violates the Float[...] annotation on both
# functions; only the one that is not opted out should reject it.
int_arr = jnp.array([1], dtype=jnp.int32)
print("No error here", bad_fun(int_arr))
try:
    bad_fun_typed(int_arr)
except TypeError as e:
    print("Expected error:", e)
else:
    # Without this branch the cell would pass silently if type checking
    # were not active, and the test could never fail.
    raise AssertionError(
        "bad_fun_typed accepted an int32 array; runtime type checking is not active"
    )


@no_type_check
def bad_fun_wrapped(arg: Float[jax.Array, " 1"]):
    """Forward ``arg`` to ``bad_fun_typed`` from an unchecked wrapper.

    ``@no_type_check`` exempts only this wrapper. The inner call to
    ``bad_fun_typed`` is still checked, so a dtype mismatch surfaces as a
    ``TypeError`` raised on behalf of ``bad_fun_typed``.
    """
    return bad_fun_typed(arg)

# The wrapper itself is exempt from checking, but the checked function it
# delegates to must still reject the int32 array.
try:
    bad_fun_wrapped(int_arr)
except TypeError as e:
    print("Expected error for wrapped function:", e)
else:
    # Fail loudly instead of passing silently when no error is raised.
    raise AssertionError(
        "bad_fun_wrapped accepted an int32 array; the inner call was not type-checked"
    )



No error here [1]
Expected error: Type-check error whilst checking the parameters of __main__.bad_fun_typed.
The problem arose whilst typechecking parameter 'arg'.
Actual value: i32[1]
Expected type: <class 'Float[Array, '1']'>.
----------------------
Called with parameters: {'arg': i32[1]}
Parameter annotations: (arg: Float[Array, '1']) -> Any.

Expected error for wrapped function: Type-check error whilst checking the parameters of __main__.bad_fun_typed.
The problem arose whilst typechecking parameter 'arg'.
Actual value: i32[1]
Expected type: <class 'Float[Array, '1']'>.
----------------------
Called with parameters: {'arg': i32[1]}
Parameter annotations: (arg: Float[Array, '1']) -> Any.



### Try unloading the extension

In [5]:

# Unload the jaxtyping extension so the next cell can confirm that runtime
# checking is switched off again.
%unload_ext jaxtyping


In [6]:
# With the extension unloaded, the int32 array is expected to pass through
# unchecked: the call returns the array instead of raising TypeError.
bad_fun_typed(int_arr)

Array([1], dtype=int32)

In [7]:
# Re-load the extension and select beartype again, to verify that runtime
# checking can be switched back on after an unload.

%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

In [8]:
# After re-loading, the already-defined function must reject the int32
# array again.
try:
    bad_fun_typed(int_arr)
except TypeError as e:
    print("Expected error:", e)
else:
    # Fail loudly if re-loading the extension did not restore checking.
    raise AssertionError(
        "bad_fun_typed accepted an int32 array after re-loading the extension"
    )

Expected error: Type-check error whilst checking the parameters of __main__.bad_fun_typed.
The problem arose whilst typechecking parameter 'arg'.
Actual value: i32[1]
Expected type: <class 'Float[Array, '1']'>.
----------------------
Called with parameters: {'arg': i32[1]}
Parameter annotations: (arg: Float[Array, '1']) -> Any.

