Convert type_check_program from panics to collected errors (#3487)
Open
stroxler wants to merge 25 commits into
Contributor
@stroxler has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105783604.
stroxler added a commit that referenced this pull request May 20, 2026
Summary: Pull Request resolved: #3487

**This stack**

Reworks tensor shape operations so that, instead of being hardcoded in `tensor_ops_registry.rs`, only DSL primitives are directly hardcoded into Pyrefly; the actual operations and their association with "normal" (non-DSL) stubs all live in user-space stub files. This allows iterating on the ops without rebuilding Pyrefly, and is essential for building out full stubs for PyTorch (and even more so if we want to extend to other libraries like NumPy and JAX).

The DSL itself is unchanged, but we will use a decorator to indicate when a stub function is a DSL function; we use a different decorator to register a DSL function as the "shape transform" associated with some normal function (e.g. to associate the DSL function `reshape_ir` with a `torch.reshape` function).

Details of the plan are in https://github.com/stroxler/pyrefly-docs/blob/main/tensor-shapes-in-stubs/v2-doc.md

**This commit**

Replace ~20 panic sites in the DSL type checker with error collection, so type errors in `shape_dsl_function` stubs produce diagnostics instead of crashing the type checker.

`type_check_program` now returns `Result<(), Vec<String>>`, threading errors through `check_body`, `check_expr`, `infer_expr`, `infer_call`, and the narrowing/joining helpers. `validate_shape_dsl_functions` propagates these errors to the solver, which emits them as `InvalidArgument` diagnostics on the function definition.

Eval-time panics in `eval_dsl_body` / `eval_dsl_expr` are intentionally left as panics: they are correctness assertions that should be unreachable for type-checked programs.

Differential Revision: D105783604
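The move from panicking to collecting errors can be illustrated with a small Python analog. This is only a sketch, not Pyrefly's Rust code: the toy expression tuples, the body of `infer_expr`, and the error messages are assumptions made for illustration. Only the idea of threading one error list through the checker and reporting everything at once comes from the commit message.

```python
# Hypothetical toy checker: function names echo the commit message,
# but the expression forms and messages are invented for illustration.
from typing import List, Optional, Tuple

Expr = Tuple  # ("int", 3), ("str", "a"), or ("add", lhs, rhs)


def infer_expr(expr: Expr, errors: List[str]) -> Optional[str]:
    """Return the inferred type name, or None after recording an error."""
    kind = expr[0]
    if kind in ("int", "str"):
        return kind
    if kind == "add":
        lhs = infer_expr(expr[1], errors)
        rhs = infer_expr(expr[2], errors)
        if lhs is None or rhs is None:
            return None  # an operand already failed; avoid a cascading error
        if lhs != "int" or rhs != "int":
            errors.append(f"cannot add {lhs} and {rhs}")
            return None
        return "int"
    errors.append(f"unknown expression kind: {kind}")
    return None


def type_check_program(bodies: List[Expr]) -> List[str]:
    """Check every body and return all errors (an empty list means success)."""
    errors: List[str] = []
    for body in bodies:
        infer_expr(body, errors)
    return errors
```

Because the checker keeps going after the first problem, a stub author would see every error in one pass rather than fixing them one crash at a time.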
3e69bd6 to 719e973
Summary: Add the `uses_shape_dsl` decorator to `shape_extensions`, the public API decorator that associates a shape DSL function with an API function in library stubs. At runtime it's a no-op passthrough; Pyrefly will use it at type-checking time to route bound arguments through the shape DSL for return-type refinement. Also register `UsesShapeDsl` as a `SpecialExport` variant in the Rust export system, so that later phases can detect `uses_shape_dsl` decorators during binding. Differential Revision: D105696519
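A runtime no-op passthrough decorator of this kind might look like the sketch below. The signature is an assumption (the `capture_init` keyword is taken from later commits in this stack, and the real `shape_extensions` source is not shown here); `reshape_ir` and `reshape` are stand-in definitions for illustration.

```python
from typing import Any, Callable, Optional, Sequence, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def uses_shape_dsl(
    ir_fn: Callable[..., Any],
    capture_init: Optional[Sequence[str]] = None,
) -> Callable[[F], F]:
    """Assumed signature. At runtime the arguments are ignored entirely;
    they only matter to the type checker."""

    def decorator(fn: F) -> F:
        return fn  # hand the decorated function back unchanged

    return decorator


def reshape_ir(*args: Any) -> None:
    """Stand-in for a DSL function; never called at runtime."""


@uses_shape_dsl(reshape_ir)
def reshape(x: Any, shape: Any) -> Any:
    return ("reshaped", x, shape)
```

Returning the original function object, rather than a wrapper, keeps the decorator free of any runtime cost or behavior change.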
Summary: Add the `shape_extensions.dsl` submodule containing the `shape_dsl_function` decorator. This decorator is DSL-internal — it marks a function whose body should be converted to shape DSL IR during binding. It lives in a separate submodule from the public `shape_extensions` API because it's only used inside DSL definition files (like `torch/_shapes.pyi`), not in normal stubs or user code. Also register `ShapeDslFunction` as a `SpecialExport` variant with `defined_in` matching `shape_extensions.dsl`, so later phases can detect `shape_dsl_function` decorators during binding. Differential Revision: D105696516
Summary: Add `prod`, `sum`, and `parse_einsum_equation` stub definitions to `shape_extensions/dsl.py` so DSL function files can import them from the canonical location. Update all DSL builtin references from `shape_extensions.*` to `shape_extensions.dsl.*` in both the Rust string matching (`convert_call`, `Display for DslBuiltin`) and the `DSL_SOURCE` Python code. Fix `convert_call` to support multi-level dotted names (e.g. `shape_extensions.dsl.prod`) — it previously only handled single-dotted names like `shape_extensions.prod`. Differential Revision: D105698000
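The multi-level dotted-name handling can be sketched in Python with the standard `ast` module. This is an analog of the idea only: the actual `convert_call` is Rust, and the helper names `dotted_name` and `callee_name` below are hypothetical.

```python
import ast
from typing import Optional


def dotted_name(node: ast.expr) -> Optional[str]:
    """Flatten a Name/Attribute chain such as a.b.c into "a.b.c"."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None  # e.g. a call or subscript at the base of the chain
    parts.append(node.id)
    return ".".join(reversed(parts))


def callee_name(source: str) -> Optional[str]:
    """Parse one call expression and return its callee's dotted name."""
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call):
        return None
    return dotted_name(call.func)
```

Walking the attribute chain in a loop handles any depth, so `shape_extensions.prod` and `shape_extensions.dsl.prod` go through the same code path.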
Summary: Add an import statement at the top of `DSL_SOURCE` to make it look more like a normal Python module. The import is silently skipped by `parse_dsl` (which only processes function definitions), but having it there makes the DSL source ready to behave like a real stub file once Phase 5 migrates it out of the Rust string. Differential Revision: D105698001
Summary: Phase 2 of the tensor-shapes-in-stubs migration. Adds a public surface to `pyrefly_types::meta_shape_dsl` that lets later phases (the binder and solver in `pyrefly/lib`) drive the DSL pipeline without exposing the grammar-aligned `DslFnDef` internals.

The new surface:

- `ShapeDslFunction`: an opaque, cheap (one `Arc`) handle to a single DSL function lowered from its Python AST.
- `ShapeDslProgram`: a bundle of `ShapeDslFunction`s that has been validated together as a program. The only way to obtain one is via `build_shape_dsl_program`, which type-checks the bundle.
- `convert_shape_dsl_function`: AST → `ShapeDslFunction`.
- `build_shape_dsl_program`: `Iterator<ShapeDslFunction>` → `ShapeDslProgram`. Panics on type-check failure today, matching the existing `parse_dsl` semantics; Phase 7 of the migration will convert this to a `Result` and surface diagnostics.
- `make_meta_shape_function`: `(&ShapeDslProgram, root_name)` → `Box<dyn MetaShapeFunction>`. Taking a `&ShapeDslProgram` (not raw pieces) enforces that callers cannot build a `MetaShapeFunction` from un-type-checked DSL.

The internal helpers (`convert_fndef`, `type_check_program`, `bind_dsl_params`, `eval_dsl_body`) stay module-private; the wrappers live in the same module and call them directly. This keeps DSL internals fully opaque outside the module.

No call sites change in this commit; the new API is exercised by the follow-up dogfood refactor of `tensor_ops_registry`.

Differential Revision: D105720303
Summary: Refactor `TensorOpsRegistry::new()` to drive DSL construction through the public `pyrefly_types::meta_shape_dsl` wrapper API added in the previous commit, rather than reaching into the now-obsolete `parse_dsl` / `Arc<DslFnDef>` / `DslMetaShapeFunction` internals. This validates that the new API works end-to-end before Phase 3 depends on it, and it shrinks the number of code paths through the DSL engine to one: AST → `convert_shape_dsl_function` → `build_shape_dsl_program` → `make_meta_shape_function`. The registry was the only caller of `parse_dsl`; with the registry converted there are no callers left, so it is deleted (together with the source-text parsing helper it relied on). The pipeline still panics on parse / type-check failure, exactly as before; Phase 7 will revisit error handling end-to-end. Differential Revision: D105720304
Summary: Thread a `capture_init: Option<Vec<Name>>` field from class binding through to solved `ClassMetadata`, following the `pydantic_before_validator_fields` precedent. This field holds `__init__` parameter names extracted from `uses_shape_dsl(..., capture_init=[...])` decorators on `forward` methods — today it is populated but not yet consumed. Phase 4 will wire it into `maybe_wrap_nn_module` to replace the hardcoded `TensorOpsRegistry::get_init_capture` lookup. Differential Revision: D105720302
Summary: Add a new `FunctionKind::ShapeDsl(Arc<FuncId>, Arc<ShapeDslFunction>)` variant that will represent functions whose return types are computed by the shape DSL. The `FuncId` provides identity (module, class, name) for display and lookup; the `ShapeDslFunction` carries the parsed DSL IR. The DSL definition is carried inside the `FunctionKind` variant rather than as a separate `Option` field on `Function` to avoid touching ~30-40 construction sites with `dsl_def: None`. This is semantically equivalent — `dsl_def` is `Some` exactly when `FunctionKind` is `ShapeDsl`, so embedding it in the variant enforces the invariant by construction. Adds `PartialEq`/`Eq`/`Hash`/`Ord`/`Visit`/`VisitMut`/`TypeEq` implementations on `ShapeDslFunction` (pointer-identity semantics, no-op visiting since DSL IR contains no `Type` values). Differential Revision: D105720305
Summary: Wire the binder and solver so that functions decorated with `shape_dsl_function` (from `shape_extensions.dsl`) are converted to DSL IR at binding time and produce `FunctionKind::ShapeDsl` at solve time. The DSL definition is stored as `shape_dsl_def: Option<Arc<ShapeDslFunction>>` on `BindingUndecoratedFunction` rather than as a separate `Binding` variant, since the function still needs its normal name binding and decorator processing chain. When the solver sees `shape_dsl_def` is `Some`, it constructs `FunctionKind::ShapeDsl(func_id, dsl_fn)` instead of `FunctionKind::Def`. The function's name resolves through the normal binding chain, so `from torch._shapes import reshape_ir` works automatically. The conversion must happen before `function_body()` consumes the AST body via `mem::take`. Conversion failures panic (Phase 7 will add structured diagnostics). Differential Revision: D105728837
Summary: When a function is decorated with `uses_shape_dsl(reshape_ir)`, extract the first positional argument's name from the decorator call AST and store it as `uses_shape_dsl_ir_name: Option<Name>` on `BindingUndecoratedFunction`. This is needed because the existing `KwCall` mechanism in the decorator pipeline only captures keyword arguments, not positional ones. Binding-time extraction makes the IR function name available for Phase 4, where the solver will resolve it to a `Type::Function` with `FunctionKind::ShapeDsl` and wire it into `FuncFlags.shape_transform`. Differential Revision: D105728838
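As a rough illustration of binding-time extraction, the sketch below pulls the first positional argument's name out of a `uses_shape_dsl(...)` decorator using Python's `ast` module. It is a simplification under stated assumptions: the real binder is Rust and identifies the decorator through its export machinery, whereas this sketch matches on the bare name, and the function name `uses_shape_dsl_ir_name` is borrowed from the field name only for readability.

```python
import ast
from typing import Optional


def uses_shape_dsl_ir_name(func: ast.FunctionDef) -> Optional[str]:
    """Return the name passed positionally to uses_shape_dsl, if any."""
    for dec in func.decorator_list:
        if not isinstance(dec, ast.Call):
            continue
        callee = dec.func
        # Accept `uses_shape_dsl(...)` and `pkg.uses_shape_dsl(...)`.
        if isinstance(callee, ast.Attribute):
            name = callee.attr
        elif isinstance(callee, ast.Name):
            name = callee.id
        else:
            continue
        if name != "uses_shape_dsl" or not dec.args:
            continue
        first = dec.args[0]
        if isinstance(first, ast.Name):
            return first.id
    return None


def first_function(source: str) -> ast.FunctionDef:
    """Helper for the examples: parse source and return its first def."""
    node = ast.parse(source).body[0]
    assert isinstance(node, ast.FunctionDef)
    return node
```

Keyword arguments live in `dec.keywords`, not `dec.args`, which is why a keyword-only mechanism would miss the IR function name.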
Summary: Add `ShapeTransformRef` type in `meta_shape_dsl.rs` and a new `shape_transform: Option<Arc<ShapeTransformRef>>` field on `FuncFlags`. `ShapeTransformRef` carries an `Arc<ShapeDslFunction>` — the resolved DSL function definition. By the time this field is populated (Phase 4b), the IR function name has been resolved to a `Type::Function` with `FunctionKind::ShapeDsl`, so we store the extracted definition directly rather than a name or binding key. Trait impls follow the same pointer-identity pattern as `ShapeDslFunction`: `PartialEq/Eq/Hash/PartialOrd/Ord` delegate to the inner `ShapeDslFunction`, `Visit/VisitMut` are no-ops (DSL IR contains no `Type` values), and `TypeEq` delegates to `PartialEq`. Differential Revision: D105739362
… population Summary: Add `FunctionKind::UsesShapeDsl` so that calling `uses_shape_dsl(...)` produces a `Type::KwCall`, making the decorator identifiable in `get_special_decorator`. Without this, the decorator's type would be plain `Callable`, indistinguishable from any other callable-returning decorator, and the generic pipeline would produce `Any`. The decorator is consumed via `SpecialDecorator::UsesShapeDsl` → `set_flag_from_special_decorator` (returns `true` to filter it out). The actual `FuncFlags.shape_transform` is populated after the decorator loop in `undecorated_function`, where `uses_shape_dsl_ir_name` (extracted at binding time in Phase 3d) is resolved via `Key::BoundName` to get the IR function's `ShapeDslFunction` from its `FunctionKind::ShapeDsl` variant. To enable solve-time lookup, `uses_shape_dsl_ir_name` is changed from `Option<Name>` to `Option<(Name, ShortIdentifier)>`, carrying the `TextRange` needed for `Key::BoundName` resolution (Pyrefly's binding lookup is range-based, not name-based). Differential Revision: D105739363
Summary: Verify that `uses_shape_dsl` decorator recognition works correctly:

- Plain function: decorator is consumed and function type is preserved
- Overloaded with implementation: `shape_transform` flows through `merge_overload_metadata_with_implementation` via `FuncFlags`
- Stub-only overloads (no implementation): `shape_transform` flows through `merge_overload_metadata_no_implementation` from the first overload

These tests catch regressions if `merge_overload_metadata_*` ever changes in a way that drops `FuncFlags` fields.

Differential Revision: D105739364
Summary: The `DslType::Int` and `DslType::Bool` branches in `val_to_type` synthesize `Literal[n]` / `Literal[bool]` from the DSL's traced runtime value. This looks inconsistent with the other branches (Tensor, List, Tuple, None, Str) which return `expected_return_type.clone()`, but the difference is intentional and load-bearing. Functions like `dim_ir`, `numel_ir`, and `size_ir(dim=N)` trace exact integer results. Downstream consumers (assert_type, reshape validation, shape inference) rely on this literal precision. The fixture return type for these functions is just `int` — the literal value comes solely from DSL evaluation. In contrast, the Tensor/List branches' `expected_return_type` already carries refined structure (e.g. `Tensor[B, C, H, W]` with shape injected), so cloning it is correct there. This commit adds comments explaining the invariant so future readers don't mistake the asymmetry for a bug. Differential Revision: D105758573
Summary: Add `ShapeTransformRef::to_meta_shape_function()` which builds a `DslMetaShapeFunction` from the decorator-carried DSL definition, and wire it into `callable_infer_inner` so the decorator-based `shape_transform` is preferred over the legacy registry lookup. The registry serves as fallback for functions not yet migrated to `uses_shape_dsl`. The decorator path is intentionally not gated by the `tensor_shapes` flag since `uses_shape_dsl` is itself the opt-in. Thread `shape_transform: Option<&ShapeTransformRef>` through the call inference chain (`call_infer_with_callee_range` → `call_infer_inner` → `callable_infer` → `callable_infer_inner`) and the overload path (`call_overloads` → `find_closest_overload` → `call_overload`). Fix ordering in the binder: `convert_shape_dsl_function` must run before `function_header`, which consumes `x.returns` via `mem::take`. Without this, the DSL converter sees no return type annotation and produces `DslFnDef.return_type = None`. Update `maybe_wrap_nn_module` to check `ClassMetadata.capture_init()` first, falling back to `TensorOpsRegistry::get_init_capture` for classes not yet migrated. Differential Revision: D105758771
Summary: Move all 86 shape DSL functions (14 helpers + 72 IR functions) from `DSL_SOURCE` in `tensor_ops_registry.rs` to `test/tensor_shapes/fixtures/torch/_shapes.pyi`, adding `shape_dsl_function` decorators. The function bodies are verbatim copies. No behavioral change — nothing references this file yet. The decorators will be consumed in commit 5c when `uses_shape_dsl` decorators are added to the torch fixture stubs. Differential Revision: D105775130
Summary: Enable DSL functions that call helpers (e.g., `reshape_ir` calling `normalize_dim`) to work through the decorator path. Previously, `ShapeTransformRef::to_meta_shape_function()` used an empty `fn_lookup`, which panicked on any helper call. Introduces a `Derived<T>` wrapper in `pyrefly_types` for attaching auxiliary data to types without affecting identity comparisons. Uses it to carry same-module DSL siblings on `FunctionKind::ShapeDsl` and `ShapeTransformRef`. At `shape_dsl_function` solve time, all siblings from the module's `BindingsMetadata` are collected; at `uses_shape_dsl` consumer sites, the siblings flow through to `to_meta_shape_function` which builds fn_lookup from self + siblings. This is a deliberate all-siblings shortcut matching the registry's flat-namespace behavior. Follow-up #9 replaces it with per-caller transitive-callee resolution. Differential Revision: D105775131
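One way to picture a wrapper that carries auxiliary data without affecting comparisons is the Python dataclass sketch below. It is an analog under assumptions: the real `Derived<T>` is a Rust type whose trait impls are not shown in this commit message, and `ShapeDslKind` is a hypothetical stand-in for the enclosing type.

```python
from dataclasses import dataclass, field
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class Derived(Generic[T]):
    """Payload is excluded from ==, so any two Derived values compare equal."""

    value: T = field(compare=False)


@dataclass(frozen=True)
class ShapeDslKind:
    """Hypothetical owner type: identity is the name, siblings ride along."""

    name: str
    siblings: Derived
```

With this shape, two `ShapeDslKind` values that share a name compare equal even if their sibling lists differ, which is the property the commit relies on when attaching helper functions to an existing type.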
Summary: Wire up ~131 torch fixture stub functions with `uses_shape_dsl(ir_fn)` decorators, matching the `TensorOpsRegistry` mappings. The decorator path is now preferred over the registry for all decorated functions.

Modified fixture files:

- `torch/__init__.pyi`: ~112 decorators on module-level functions and `Tensor` methods (shape manipulation, reductions, creation, linalg, indexing, random, conditional, properties)
- `torch/nn/functional.pyi`: ~30 decorators (conv, pool, loss, pad, interpolate, cosine_similarity)
- `torch/fft.pyi`: 4 decorators (rfft, irfft, hfft, ihfft)
- `torch/linalg.pyi`: 9 decorators (eig, eigvals, solve, slogdet, etc.)

All tensor_shapes tests pass through the decorator path.

Differential Revision: D105775133
Summary: Add `uses_shape_dsl(ir_fn, capture_init=[...])` decorators to the `forward` methods of all 15 nn.Module classes that have shape-aware forward inference: MaxPool1d/2d/3d, AvgPool1d/2d/3d, Flatten, PixelShuffle, GLU, LSTM, Upsample, GRU, LSTMCell, ReflectionPad2d, ReplicationPad2d. The `capture_init` plumbing (binding extraction, ClassMetadata propagation, and `maybe_wrap_nn_module` consumption) was completed in Phases 3a and 4c. This commit is purely fixture annotations — once added, the decorators override the registry path for each annotated class. Differential Revision: D105775129
Summary: Remove the old hardcoded registry now that all ~131 shape functions and 15 nn.Module classes are wired through `uses_shape_dsl` decorators.

Deleted:

- `tensor_ops_registry.rs` (1,022 lines): `DSL_SOURCE` string, registry struct, ~131 `register*()` calls, `OnceLock` statics
- `lookup_meta_shape` in `callable.rs`: registry-based shape function lookup
- Registry fallback in `callable_infer_inner`: now decorator-only
- Registry fallback in `maybe_wrap_nn_module`: now ClassMetadata-only
- Phase 2 unit test in `meta_shape_dsl.rs`: superseded by e2e tests

The `tensor_shapes` config flag remains; it gates non-registry behavior (Tensor subscript support, jaxtyping, operator overloads).

Differential Revision: D105775132
…lution Summary: Each `shape_dsl_function` now carries only its transitive callees instead of every DSL function in the module. A leaf function like `randn_ir` gets `helpers = [self]` (1 entry) instead of all 86 siblings. `reshape_ir` gets `[reshape_ir, normalize_dim]` (2 entries). `movedim_ir` gets its 6 transitive callees. Adds `ShapeDslFunction::call_targets()` which walks `DslBody`/`DslExpr` collecting `DslCallTarget::UserDefined` names, and `compute_transitive_helpers` which resolves those names against the per-module `BindingsMetadata` index and computes the fixed-point closure. Behaviorally identical — same shape inference results, same types. The change reduces per-function memory footprint and will enable finer cache invalidation once solver dependency edges are per-helper. Differential Revision: D105783601
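The fixed-point closure over call targets can be sketched as a simple worklist traversal. This Python version is illustrative only: the real `compute_transitive_helpers` is Rust and resolves names against `BindingsMetadata`, while here the call graph is a plain dictionary and the signature is an assumption.

```python
from typing import Dict, List, Set


def compute_transitive_helpers(
    root: str, call_targets: Dict[str, Set[str]]
) -> List[str]:
    """Return root plus every function reachable from it, in discovery order."""
    seen = [root]
    worklist = [root]
    while worklist:
        current = worklist.pop()
        # Sort for a deterministic result; unknown names have no callees.
        for callee in sorted(call_targets.get(current, set())):
            if callee not in seen:
                seen.append(callee)
                worklist.append(callee)
    return seen


# Call graph using the two examples from the commit message.
CALL_GRAPH: Dict[str, Set[str]] = {
    "randn_ir": set(),
    "reshape_ir": {"normalize_dim"},
    "normalize_dim": set(),
}
```

The `seen` check is what makes the loop terminate on mutually recursive helpers: each function is added at most once, so the traversal reaches a fixed point.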
Summary: Run `type_check_program` on each DSL function's transitive-callee closure at `shape_dsl_function` solve time. This validates that cross-function call signatures are consistent (e.g., `reshape_ir` calling `normalize_dim` with the right argument types). Previously, validation was done once globally by the deleted `TensorOpsRegistry`. With per-caller resolution from 7a-1, each function is now validated against only its actual callees. Panics on type errors for now; Phase 7b will convert to diagnostics. Differential Revision: D105783602
Summary: When `uses_shape_dsl(ir_fn)` references a function that is not decorated with `shape_dsl_function`, emit an `InvalidArgument` diagnostic instead of silently producing a function with no shape inference. The function still falls back to its declared return type. Previously this was a silent no-op — the `if let FunctionKind::ShapeDsl` chain simply didn't match and `shape_transform` stayed `None` with no indication to the user that their decorator was misconfigured. Differential Revision: D105783626
Summary: When a `shape_dsl_function` body uses unsupported Python syntax, emit a diagnostic instead of panicking. The function degrades to a normal `FunctionKind::Def` with no shape inference. Also adds warnings for unsupported parameter kinds (`*args`, `**kwargs`, keyword-only, positional-only) — these are silently dropped by the DSL converter but the user should know they have no effect. Previously, any `convert_fndef` failure crashed the type checker via `.expect()`. Third-party stub authors writing new DSL functions would hit this crash on any syntax mistake. Differential Revision: D105783605
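Detecting the unsupported parameter kinds can be sketched with Python's `ast` module, which exposes each kind as a separate field on the function's arguments. This is an analog only: the real check is in Pyrefly's Rust converter, and the function name and warning wording below are assumptions.

```python
import ast
from typing import List


def unsupported_param_warnings(func: ast.FunctionDef) -> List[str]:
    """List one warning per parameter that a DSL converter would drop."""
    args = func.args
    warnings: List[str] = []
    for p in args.posonlyargs:
        warnings.append(f"positional-only parameter `{p.arg}` has no effect")
    if args.vararg is not None:
        warnings.append(f"`*{args.vararg.arg}` has no effect")
    for p in args.kwonlyargs:
        warnings.append(f"keyword-only parameter `{p.arg}` has no effect")
    if args.kwarg is not None:
        warnings.append(f"`**{args.kwarg.arg}` has no effect")
    return warnings


def parse_def(source: str) -> ast.FunctionDef:
    """Helper for the examples: parse source and return its first def."""
    node = ast.parse(source).body[0]
    assert isinstance(node, ast.FunctionDef)
    return node
```

Ordinary positional-or-keyword parameters live in `args.args` and produce no warning, so a well-formed DSL function yields an empty list.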
719e973 to 53dc247
According to mypy_primer, this change doesn't affect type check results on a corpus of open source code. ✅