diff --git a/jax/BUILD b/jax/BUILD index 70df4fb3de7a..d473fbb62c10 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -110,6 +110,7 @@ py_library_providing_imports_info( "experimental/compilation_cache/cache_interface.py", ], lib_rule = pytype_library, + pytype_srcs = glob(["_src/**/*.pyi"]), visibility = ["//visibility:public"], deps = select({ ":enable_jaxlib_build": [":jaxlib_deps"],