Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4cd50dd
Add `uses_shape_dsl` to `shape_extensions`
stroxler May 22, 2026
95e72e2
Add `shape_extensions.dsl` submodule with `shape_dsl_function`
stroxler May 22, 2026
2986239
Add DSL builtins to `shape_extensions.dsl` and update builtin prefix
stroxler May 22, 2026
5d0d7d4
Add `import shape_extensions.dsl` to `DSL_SOURCE`
stroxler May 22, 2026
8afde14
Add public shape-DSL wrapper API in pyrefly_types
stroxler May 22, 2026
7fdf1da
Dogfood Phase 2 wrapper API in tensor_ops_registry
stroxler May 22, 2026
28b7426
Add `capture_init` plumbing to class metadata
stroxler May 22, 2026
8ca7008
Add `FunctionKind::ShapeDsl` variant
stroxler May 22, 2026
0747d77
Detect `@shape_dsl_function` and produce `FunctionKind::ShapeDsl`
stroxler May 22, 2026
967bc22
Extract `@uses_shape_dsl(ir_fn)` argument at binding time
stroxler May 22, 2026
a7990e6
Add `shape_transform` field to `FuncFlags`
stroxler May 22, 2026
6dbc7a4
Wire up `@uses_shape_dsl` decorator recognition and `shape_transform`…
stroxler May 22, 2026
67627db
Add overload regression tests for `@uses_shape_dsl`
stroxler May 22, 2026
1fe03d3
Document why `val_to_type` Int/Bool branches use `Literal[n]`
stroxler May 22, 2026
c12ccd5
Wire up solver consumption with legacy fallback
stroxler May 22, 2026
bad31c9
Create `torch/_shapes.pyi` with all DSL functions
stroxler May 22, 2026
79bed43
Add fn_lookup infrastructure for DSL helper resolution
stroxler May 22, 2026
8bc14e5
Add `@uses_shape_dsl` decorators to torch fixture stubs
stroxler May 22, 2026
56d9837
Add `capture_init` annotations to nn.Module forward methods
stroxler May 22, 2026
4e4e795
Delete `TensorOpsRegistry` and all legacy fallback paths
stroxler May 22, 2026
ef5afd8
Replace all-siblings fn_lookup with per-caller transitive-callee reso…
stroxler May 22, 2026
2774b74
Add per-function `type_check_program` validation
stroxler May 22, 2026
606c170
Emit diagnostic for invalid `@uses_shape_dsl` arguments
stroxler May 22, 2026
c3b9a18
Replace `convert_fndef` panic with diagnostic
stroxler May 22, 2026
53dc247
Convert `type_check_program` from panics to collected errors (#3487)
stroxler May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ MARSHMALLOW_TEST_PATH = { value = "pyrefly/lib/test/marshmallow/third-party", re
GLEAN_SNAPSHOTS_PATH = { value = "pyrefly/lib/report/glean/snapshots", relative = true }
REPORT_TEST_PATH = { value = "pyrefly/lib/test/report/test_files", relative = true }
STUBGEN_TEST_PATH = { value = "pyrefly/lib/test/stubgen", relative = true }
SHAPE_DSL_TEST_PATH = { value = "test/tensor_shapes/fixtures", relative = true }
81 changes: 80 additions & 1 deletion crates/pyrefly_types/src/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::fmt;
use std::fmt::Display;
use std::hash::Hash;
use std::hash::Hasher;
use std::ops::Deref;
use std::sync::Arc;

use dupe::Dupe;
Expand All @@ -36,10 +37,68 @@ use crate::display::TypeDisplayContext;
use crate::equality::TypeEq;
use crate::equality::TypeEqCtx;
use crate::keywords::DataclassTransformMetadata;
use crate::meta_shape_dsl::ShapeDslFunction;
use crate::meta_shape_dsl::ShapeTransformRef;
use crate::type_output::TypeOutput;
use crate::types::AnyStyle;
use crate::types::Type;

/// A wrapper for derived/cached data that should not participate in
/// equality, hashing, or ordering comparisons. `Derived<T>` always
/// compares as equal, hashes as a no-op, and orders as `Equal`.
///
/// This is useful for attaching auxiliary data to types that derive
/// `PartialEq`, `Hash`, `Ord`, etc. without affecting their identity.
#[derive(Debug, Clone)]
pub struct Derived<T>(pub T);

impl<T> PartialEq for Derived<T> {
fn eq(&self, _other: &Self) -> bool {
true
}
}

impl<T> Eq for Derived<T> {}

impl<T> Hash for Derived<T> {
fn hash<H: Hasher>(&self, _state: &mut H) {}
}

impl<T> PartialOrd for Derived<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl<T> Ord for Derived<T> {
fn cmp(&self, _other: &Self) -> Ordering {
Ordering::Equal
}
}

impl<T> Visit<Type> for Derived<T> {
const RECURSE_CONTAINS: bool = false;
fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {}
}

impl<T> VisitMut<Type> for Derived<T> {
const RECURSE_CONTAINS: bool = false;
fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {}
}

impl<T> TypeEq for Derived<T> {
fn type_eq(&self, _other: &Self, _ctx: &mut TypeEqCtx) -> bool {
true
}
}

impl<T> Deref for Derived<T> {
type Target = T;
fn deref(&self) -> &T {
&self.0
}
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Visit, VisitMut, TypeEq)]
pub struct Callable {
Expand Down Expand Up @@ -645,6 +704,9 @@ pub struct FuncFlags {
/// `dataclass_transform` call. See
/// https://typing.python.org/en/latest/spec/dataclasses.html#specification.
pub dataclass_transform_metadata: Option<DataclassTransformMetadata>,
/// A function decorated with `@uses_shape_dsl`, whose return type should be
/// refined by evaluating the referenced shape-DSL function at call sites.
pub shape_transform: Option<Arc<ShapeTransformRef>>,
}

impl FuncFlags {
Expand Down Expand Up @@ -810,6 +872,16 @@ pub enum FunctionKind {
NumbaJit,
/// `numba.njit()`
NumbaNjit,
/// A function whose return type is computed by a shape DSL definition.
/// The `FuncId` provides identity (module, class, name) for display and
/// lookup; the `ShapeDslFunction` carries the parsed DSL IR.
ShapeDsl(
Arc<FuncId>,
Arc<ShapeDslFunction>,
Derived<Arc<Vec<Arc<ShapeDslFunction>>>>,
),
/// The `shape_extensions.uses_shape_dsl` decorator function itself.
UsesShapeDsl,
}

impl Callable {
Expand Down Expand Up @@ -1185,6 +1257,7 @@ impl FunctionKind {
("typing" | "typing_extensions", None, "disjoint_base") => Self::DisjointBase,
("numba.core.decorators", None, "jit") => Self::NumbaJit,
("numba.core.decorators", None, "njit") => Self::NumbaNjit,
("shape_extensions", None, "uses_shape_dsl") => Self::UsesShapeDsl,
_ => Self::Def(Arc::new(FuncId {
module,
cls,
Expand Down Expand Up @@ -1218,6 +1291,8 @@ impl FunctionKind {
Self::NumbaJit => ModuleName::from_str("numba"),
Self::NumbaNjit => ModuleName::from_str("numba"),
Self::Def(func_id) => func_id.module.name().dupe(),
Self::ShapeDsl(id, _, _) => id.module.name().dupe(),
Self::UsesShapeDsl => ModuleName::from_str("shape_extensions"),
}
}

Expand All @@ -1244,6 +1319,8 @@ impl FunctionKind {
Self::NumbaJit => Cow::Owned(Name::new_static("jit")),
Self::NumbaNjit => Cow::Owned(Name::new_static("njit")),
Self::Def(func_id) => Cow::Borrowed(&func_id.name),
Self::ShapeDsl(id, _, _) => Cow::Borrowed(&id.name),
Self::UsesShapeDsl => Cow::Owned(Name::new_static("uses_shape_dsl")),
}
}

Expand All @@ -1270,12 +1347,14 @@ impl FunctionKind {
Self::TotalOrdering => None,
Self::DisjointBase => None,
Self::Def(func_id) => func_id.cls.clone(),
Self::ShapeDsl(id, _, _) => id.cls.clone(),
Self::UsesShapeDsl => None,
}
}

pub fn outer_funcs(&self) -> Option<&Name> {
match self {
Self::Def(func_id) => func_id.outer_funcs.as_ref(),
Self::Def(func_id) | Self::ShapeDsl(func_id, _, _) => func_id.outer_funcs.as_ref(),
_ => None,
}
}
Expand Down
1 change: 0 additions & 1 deletion crates/pyrefly_types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ pub mod simplify;
pub mod special_form;
pub mod stdlib;
pub mod tensor;
pub mod tensor_ops_registry;
pub mod tuple;
pub mod type_alias;
pub mod type_info;
Expand Down
Loading
Loading