Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions pyrefly/lib/alt/class/class_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4135,6 +4135,45 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
self.get_dunder_init_helper(&Instance::of_class(cls), get_object_init)
}

/// Get the class's `__init_subclass__` method, excluding `object.__init_subclass__`.
pub fn get_dunder_init_subclass(
&self,
cls: &ClassType,
include_ancestors: bool,
) -> Option<Type> {
if cls.class_object().is_builtin("object") {
return None;
}
let init_subclass_member = if let Some(field) = self
.get_non_synthesized_field_from_current_class_only(
cls.class_object(),
&dunder::INIT_SUBCLASS,
) {
WithDefiningClass {
value: field,
defining_class: cls.class_object().dupe(),
}
} else if !include_ancestors {
return None;
} else {
let mro = self.get_mro_for_class(cls.class_object());
self.get_field_from_ancestors(
cls.class_object(),
mro.ancestors_no_object().iter(),
&dunder::INIT_SUBCLASS,
&|cls, name| self.get_non_synthesized_field_from_current_class_only(cls, name),
)?
};
if init_subclass_member.value.is_init_var() {
return None;
}
Arc::unwrap_or_clone(init_subclass_member.value)
.as_raw_special_method_type(self.heap, &Instance::of_class(cls))
.and_then(|ty| {
make_bound_classmethod(self.heap, &ClassBase::ClassType(cls.clone()), ty).ok()
})
}

pub fn get_typed_dict_dunder_init(&self, td: &TypedDictInner) -> Type {
// We synthesize `__init__`, so the lookup will never entirely fail.
//
Expand Down
83 changes: 75 additions & 8 deletions pyrefly/lib/alt/class/class_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use pyrefly_util::display::DisplayWithCtx;
use pyrefly_util::prelude::SliceExt;
use pyrefly_util::prelude::VecExt;
use ruff_python_ast::Expr;
use ruff_python_ast::Identifier;
use ruff_python_ast::name::Name;
use ruff_text_size::Ranged;
use ruff_text_size::TextRange;
Expand All @@ -37,7 +38,10 @@ use starlark_map::small_set::SmallSet;

use crate::alt::answers::LookupAnswer;
use crate::alt::answers_solver::AnswersSolver;
use crate::alt::call::CallStyle;
use crate::alt::callable::CallKeyword;
use crate::alt::class::django::is_django_choices_subclass;
use crate::alt::expr::TypeOrExpr;
use crate::alt::solve::TypeFormContext;
use crate::alt::types::abstract_class::AbstractClassMembers;
use crate::alt::types::class_metadata::ClassMetadata;
Expand All @@ -64,6 +68,7 @@ use crate::binding::binding::Key;
use crate::binding::binding::KeyClassField;
use crate::binding::binding::KeyDecorator;
use crate::binding::django::DjangoFieldInfo;
use crate::binding::pydantic::EXTRA;
use crate::binding::pydantic::PydanticConfigDict;
use crate::binding::pydantic::VALIDATION_ALIAS;
use crate::config::error_kind::ErrorKind;
Expand Down Expand Up @@ -123,7 +128,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&self,
cls: &Class,
bases: &[BaseClass],
keywords: &[(Name, Expr)],
keywords: &[(Identifier, Expr)],
decorators: &[Idx<KeyDecorator>],
is_new_type: bool,
pydantic_config_dict: &PydanticConfigDict,
Expand Down Expand Up @@ -162,11 +167,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let bases_with_metadata = self.bases_with_metadata(parsed_results, is_new_type, errors);

// Compute class keywords, including the metaclass.
let (metaclasses, keywords): (Vec<_>, Vec<(_, _)>) =
keywords.iter().partition_map(|(n, x)| match n.as_str() {
let (metaclasses, keyword_annotations): (Vec<_>, Vec<(_, _)>) =
keywords.iter().partition_map(|(n, x)| match n.id.as_str() {
"metaclass" => Either::Left(x),
_ => Either::Right((n.clone(), self.expr_class_keyword(x, errors))),
});
let keyword_annotations = keyword_annotations.into_map(|(name, annot)| (name.id, annot));

let base_metaclasses = bases_with_metadata
.iter()
Expand Down Expand Up @@ -227,6 +233,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}
let metaclass = calculated_metaclass.get();
self.check_init_subclass_keywords(cls, &bases_with_metadata, metaclass, keywords, errors);

let mut directly_inherits_model = false;
let mut inherited_django_metadata: Option<&DjangoModelMetadata> = None;
Expand Down Expand Up @@ -313,7 +320,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let pydantic_config = self.pydantic_config(
&bases_with_metadata,
pydantic_config_dict,
&keywords,
&keyword_annotations,
&decorators,
errors,
cls.range(),
Expand All @@ -332,8 +339,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
format!("`{}` is not a typed dictionary. Typed dictionary definitions may only extend other typed dictionaries.", bad.0.name()),
);
}
let typed_dict_metadata =
self.typed_dict_metadata(cls, &bases_with_metadata, &keywords, is_typed_dict, errors);
let typed_dict_metadata = self.typed_dict_metadata(
cls,
&bases_with_metadata,
&keyword_annotations,
is_typed_dict,
errors,
);
if metaclass.is_some() && is_typed_dict {
self.error(
errors,
Expand Down Expand Up @@ -385,7 +397,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
dataclass_defaults_from_base_class.clone(),
);
let dataclass_from_dataclass_transform = self.dataclass_from_dataclass_transform(
&keywords,
&keyword_annotations,
&decorators,
dataclass_defaults_from_base_class,
pydantic_config.as_ref(),
Expand Down Expand Up @@ -464,7 +476,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
};

// Get types of class keywords.
let keywords = keywords.into_map(|(name, annot)| {
let keywords = keyword_annotations.into_map(|(name, annot)| {
(
name,
annot.ty.unwrap_or_else(|| self.heap.mk_any_implicit()),
Expand Down Expand Up @@ -510,6 +522,61 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
)
}

fn check_init_subclass_keywords(
&self,
cls: &Class,
bases_with_metadata: &[(Class, Arc<ClassMetadata>)],
metaclass: Option<&ClassType>,
keywords: &[(Identifier, Expr)],
errors: &ErrorCollector,
) {
let is_pydantic_model = bases_with_metadata.iter().any(|(base, metadata)| {
base.has_toplevel_qname(ModuleName::pydantic().as_str(), "BaseModel")
|| metadata.is_pydantic_model()
});
let keywords = keywords
.iter()
.filter(|(name, _)| name.id != "metaclass" && !(is_pydantic_model && name.id == EXTRA))
.collect::<Vec<_>>();
if !keywords.is_empty() && metaclass.is_some() {
return;
}
let Some((base, _)) = bases_with_metadata.first() else {
return;
};
let include_init_subclass_ancestors = !keywords.is_empty();
let base = self.promote_nontypeddict_silently_to_classtype(base);
let Some(init_subclass) =
self.get_dunder_init_subclass(&base, include_init_subclass_ancestors)
else {
return;
};
let keywords = keywords
.into_iter()
.map(|(name, value)| CallKeyword {
range: name.range(),
arg: Some(name),
value: TypeOrExpr::Expr(value),
})
.collect::<Vec<_>>();
self.call_infer(
self.as_call_target_or_error(
init_subclass,
CallStyle::Method(&dunder::INIT_SUBCLASS),
cls.range(),
errors,
None,
),
&[],
&keywords,
cls.range(),
errors,
None,
None,
None,
);
}

fn extract_slots_info(&self, cls: &Class) -> Option<SlotsInfo> {
let key = KeyClassField(cls.index(), dunder::SLOTS.clone());
let idx = self.bindings().key_to_idx_hashed_opt(Hashed::new(&key))?;
Expand Down
2 changes: 1 addition & 1 deletion pyrefly/lib/binding/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3110,7 +3110,7 @@ pub struct BindingClassMetadata {
/// The class keywords (these are keyword args that appear in the base class list, the
/// Python runtime will dispatch most of them to the metaclass, but the metaclass
/// itself can also potentially be one of these).
pub keywords: Box<[(Name, Expr)]>,
pub keywords: Box<[(Identifier, Expr)]>,
/// The class decorators.
pub decorators: Box<[Idx<KeyDecorator>]>,
/// Is this a new type? True only for synthesized classes created from a `NewType` call.
Expand Down
12 changes: 8 additions & 4 deletions pyrefly/lib/binding/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ impl<'a> BindingsBuilder<'a> {
args.keywords.iter_mut().for_each(|keyword| {
if let Some(name) = &keyword.arg {
self.ensure_expr(&mut keyword.value, class_object.usage());
keywords.push((name.id.clone(), keyword.value.clone()));
keywords.push((name.clone(), keyword.value.clone()));
} else {
self.error(
keyword.range(),
Expand Down Expand Up @@ -960,7 +960,7 @@ impl<'a> BindingsBuilder<'a> {
class_indices: ClassIndices,
parent: &NestingContext,
base: Option<Expr>,
keywords: Box<[(Name, Expr)]>,
keywords: Box<[(Identifier, Expr)]>,
// name, position, annotation, value
member_definitions: Vec<(String, TextRange, Option<Expr>, Option<ExprOrBinding>)>,
illegal_identifier_handling: IllegalIdentifierHandling,
Expand Down Expand Up @@ -1396,8 +1396,12 @@ impl<'a> BindingsBuilder<'a> {
(Some(name), _) if name == "extra_items" => Some(name),
_ => None,
};
if let Some(kw_name) = recognized_kw {
base_class_keywords.push((kw_name.clone(), kw.value.clone()));
if recognized_kw.is_some() {
let kw_name = kw
.arg
.clone()
.expect("recognized TypedDict keyword must have a name");
base_class_keywords.push((kw_name, kw.value.clone()));
} else {
let msg = if let Some(name) = &kw.arg {
format!("Unrecognized keyword argument `{name}`")
Expand Down
25 changes: 25 additions & 0 deletions pyrefly/lib/test/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,31 @@ assert_type(C(), C[Any]) # Correct, because invalid metaclass.
"#,
);

testcase!(
test_init_subclass_class_keywords,
r#"
class Foo:
def __init_subclass__(cls, asdf: int) -> None:
pass

class Bar(Foo, asdf=1): ...
class Baz(Foo, asdf=""): ... # E: Argument `Literal['']` is not assignable to parameter `asdf` with type `int`
class Qux(Foo): ... # E: Missing argument `asdf`
"#,
);

testcase!(
test_init_subclass_skips_custom_metaclass_keywords,
r#"
class Meta(type):
def __new__(cls, name, bases, namespace, abstract: bool = False):
return super().__new__(cls, name, bases, namespace)

class Base(metaclass=Meta): ...
class Child(Base, abstract=True): ...
"#,
);

testcase!(
test_metaclass_invalid_generic_legacy_typevar,
r#"
Expand Down
Loading