Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions pyrefly/lib/stubgen/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use pyrefly_python::module::Module;
use pyrefly_python::short_identifier::ShortIdentifier;
use pyrefly_python::sys_info::SysInfo;
use pyrefly_types::callable::Param;
use pyrefly_types::class::ClassDefIndex;
use pyrefly_types::display::TypeDisplayContext;
use pyrefly_types::types::Type;
use ruff_python_ast::Expr;
Expand All @@ -29,10 +30,14 @@ use ruff_python_ast::StmtFunctionDef;
use ruff_python_ast::name::Name;
use ruff_text_size::Ranged;
use ruff_text_size::TextRange;
use starlark_map::Hashed;

use crate::alt::answers::Answers;
use crate::alt::types::decorated_function::DecoratedFunction;
use crate::binding::binding::BindingClass;
use crate::binding::binding::Key;
use crate::binding::binding::KeyClass;
use crate::binding::binding::KeyClassField;
use crate::binding::binding::KeyDecoratedFunction;
use crate::binding::bindings::Bindings;
use crate::export::definitions::Definitions;
Expand Down Expand Up @@ -485,6 +490,101 @@ fn extract_return_type(
None
}

/// Resolves the per-module class index for a class statement (used to look up
/// `KeyClassField` and `ClassFields` metadata).
fn class_def_index(class_def: &StmtClassDef, ctx: &ExtractionContext) -> Option<ClassDefIndex> {
let key = KeyClass(ShortIdentifier::new(&class_def.name));
let idx = ctx.bindings.key_to_idx_hashed_opt(Hashed::new(&key))?;
match ctx.bindings.get(idx) {
BindingClass::ClassDef(c) => Some(c.def_index),
BindingClass::FunctionalClassDef(d, ..) => Some(*d),
}
}

/// Names already represented as variables in the extracted class body.
fn stub_class_level_variable_names(body: &[StubItem]) -> HashSet<String> {
let mut out = HashSet::new();
for item in body {
if let StubItem::Variable(v) = item {
out.insert(v.name.clone());
}
}
out
}

/// Instance attributes inferred from methods (e.g. `self.name` in `__init__`) that
/// are not already declared in the class body, materialized as stub `name: T` lines.
fn extract_instance_attr_stubs_from_class_fields(
def_index: ClassDefIndex,
class_body: &[StubItem],
ctx: &mut ExtractionContext,
) -> Vec<StubVariable> {
let already = stub_class_level_variable_names(class_body);
let class_fields = match ctx.bindings.get_class_fields(def_index) {
Some(f) => f,
None => return Vec::new(),
};

let mut out = Vec::new();
for (name, _) in class_fields.iter() {
if already.contains(name.as_str()) {
continue;
}
if !should_include_name(name.as_str(), ctx.config, true, ctx.dunder_all) {
continue;
}
let key = KeyClassField(def_index, name.clone());
let Some(field_idx) = ctx.bindings.key_to_idx_hashed_opt(Hashed::new(&key)) else {
continue;
};
let Some(field) = ctx.answers.get_idx(field_idx) else {
continue;
};
if !field.is_simple_instance_attribute() {
continue;
}
let Some(ann) = format_type(&field.ty(), ctx) else {
continue;
};
out.push(StubVariable {
name: name.to_string(),
annotation: Some(ann),
value: None,
});
}
out.sort_by(|a, b| a.name.cmp(&b.name));
out
}

/// Inserts synthesized instance attribute stubs before `__init__` when present,
/// otherwise at the start of the class body.
fn merge_instance_field_stubs(
synthesized: Vec<StubVariable>,
mut body: Vec<StubItem>,
) -> Vec<StubItem> {
if synthesized.is_empty() {
return body;
}
let init_idx = body
.iter()
.position(|item| matches!(item, StubItem::Function(f) if f.name == "__init__"));
let synth: Vec<StubItem> = synthesized.into_iter().map(StubItem::Variable).collect();
match init_idx {
Some(i) => {
let mut tail = body.split_off(i);
let mut out = body;
out.extend(synth);
out.append(&mut tail);
out
}
None => {
let mut out = synth;
out.append(&mut body);
out
}
}
}

fn extract_class(class_def: &StmtClassDef, ctx: &mut ExtractionContext) -> Option<StubClass> {
let name = class_def.name.id.as_str();
if !should_include_name(name, ctx.config, false, ctx.dunder_all) {
Expand Down Expand Up @@ -535,6 +635,12 @@ fn extract_class(class_def: &StmtClassDef, ctx: &mut ExtractionContext) -> Optio
.map(|tp| source_text(ctx.module_info, tp.range()).to_owned());

let body = extract_stmts(&class_def.body, ctx, true);
let body = if let Some(def_index) = class_def_index(class_def, &*ctx) {
let extra = extract_instance_attr_stubs_from_class_fields(def_index, &body, ctx);
merge_instance_field_stubs(extra, body)
} else {
body
};

Some(StubClass {
name: name.to_owned(),
Expand Down
23 changes: 23 additions & 0 deletions pyrefly/lib/stubgen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,29 @@ from typing import Self

class C:
def __new__(cls) -> Self: ...
"#
.trim(),
actual.trim(),
);
}

/// Instance fields assigned in `__init__` (without class-level annotations) appear in the stub
/// using inferred types. See <https://github.com/facebook/pyrefly/issues/3208>.
#[test]
fn test_stubgen_instance_fields_from_init() {
let actual = run_stubgen(
r#"
class A:
def __init__(self, name: str) -> None:
self.name = name
"#,
);
pretty_assertions::assert_str_eq!(
r#"
class A:
name: str

def __init__(self, name: str) -> None: ...
"#
.trim(),
actual.trim(),
Expand Down
Loading