diff --git a/pyrefly/lib/alt/class/class_field.rs b/pyrefly/lib/alt/class/class_field.rs index c92f5188f4..0e113ab765 100644 --- a/pyrefly/lib/alt/class/class_field.rs +++ b/pyrefly/lib/alt/class/class_field.rs @@ -4135,6 +4135,45 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.get_dunder_init_helper(&Instance::of_class(cls), get_object_init) } + /// Get the class's `__init_subclass__` method, excluding `object.__init_subclass__`. + pub fn get_dunder_init_subclass( + &self, + cls: &ClassType, + include_ancestors: bool, + ) -> Option { + if cls.class_object().is_builtin("object") { + return None; + } + let init_subclass_member = if let Some(field) = self + .get_non_synthesized_field_from_current_class_only( + cls.class_object(), + &dunder::INIT_SUBCLASS, + ) { + WithDefiningClass { + value: field, + defining_class: cls.class_object().dupe(), + } + } else if !include_ancestors { + return None; + } else { + let mro = self.get_mro_for_class(cls.class_object()); + self.get_field_from_ancestors( + cls.class_object(), + mro.ancestors_no_object().iter(), + &dunder::INIT_SUBCLASS, + &|cls, name| self.get_non_synthesized_field_from_current_class_only(cls, name), + )? + }; + if init_subclass_member.value.is_init_var() { + return None; + } + Arc::unwrap_or_clone(init_subclass_member.value) + .as_raw_special_method_type(self.heap, &Instance::of_class(cls)) + .and_then(|ty| { + make_bound_classmethod(self.heap, &ClassBase::ClassType(cls.clone()), ty).ok() + }) + } + pub fn get_typed_dict_dunder_init(&self, td: &TypedDictInner) -> Type { // We synthesize `__init__`, so the lookup will never entirely fail. // diff --git a/pyrefly/lib/alt/class/class_metadata.rs b/pyrefly/lib/alt/class/class_metadata.rs index cea0ed329a..cd21e75342 100644 --- a/pyrefly/lib/alt/class/class_metadata.rs +++ b/pyrefly/lib/alt/class/class_metadata.rs @@ -27,6 +27,7 @@ use pyrefly_util::display::DisplayWithCtx; use pyrefly_util::prelude::SliceExt; use pyrefly_util::prelude::VecExt; use ruff_python_ast::Expr; +use ruff_python_ast::Identifier; use ruff_python_ast::name::Name; use ruff_text_size::Ranged; use ruff_text_size::TextRange; @@ -37,7 +38,10 @@ use starlark_map::small_set::SmallSet; use crate::alt::answers::LookupAnswer; use crate::alt::answers_solver::AnswersSolver; +use crate::alt::call::CallStyle; +use crate::alt::callable::CallKeyword; use crate::alt::class::django::is_django_choices_subclass; +use crate::alt::expr::TypeOrExpr; use crate::alt::solve::TypeFormContext; use crate::alt::types::abstract_class::AbstractClassMembers; use crate::alt::types::class_metadata::ClassMetadata; @@ -64,6 +68,7 @@ use crate::binding::binding::Key; use crate::binding::binding::KeyClassField; use crate::binding::binding::KeyDecorator; use crate::binding::django::DjangoFieldInfo; +use crate::binding::pydantic::EXTRA; use crate::binding::pydantic::PydanticConfigDict; use crate::binding::pydantic::VALIDATION_ALIAS; use crate::config::error_kind::ErrorKind; @@ -123,7 +128,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, cls: &Class, bases: &[BaseClass], - keywords: &[(Name, Expr)], + keywords: &[(Identifier, Expr)], decorators: &[Idx], is_new_type: bool, pydantic_config_dict: &PydanticConfigDict, @@ -162,11 +167,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let bases_with_metadata = self.bases_with_metadata(parsed_results, is_new_type, errors); // Compute class keywords, including the metaclass. - let (metaclasses, keywords): (Vec<_>, Vec<(_, _)>) = - keywords.iter().partition_map(|(n, x)| match n.as_str() { + let (metaclasses, keyword_annotations): (Vec<_>, Vec<(_, _)>) = + keywords.iter().partition_map(|(n, x)| match n.id.as_str() { "metaclass" => Either::Left(x), _ => Either::Right((n.clone(), self.expr_class_keyword(x, errors))), }); + let keyword_annotations = keyword_annotations.into_map(|(name, annot)| (name.id, annot)); let base_metaclasses = bases_with_metadata .iter() @@ -227,6 +233,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } let metaclass = calculated_metaclass.get(); + self.check_init_subclass_keywords(cls, &bases_with_metadata, metaclass, keywords, errors); let mut directly_inherits_model = false; let mut inherited_django_metadata: Option<&DjangoModelMetadata> = None; @@ -313,7 +320,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let pydantic_config = self.pydantic_config( &bases_with_metadata, pydantic_config_dict, - &keywords, + &keyword_annotations, &decorators, errors, cls.range(), @@ -332,8 +339,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { format!("`{}` is not a typed dictionary. Typed dictionary definitions may only extend other typed dictionaries.", bad.0.name()), ); } - let typed_dict_metadata = - self.typed_dict_metadata(cls, &bases_with_metadata, &keywords, is_typed_dict, errors); + let typed_dict_metadata = self.typed_dict_metadata( + cls, + &bases_with_metadata, + &keyword_annotations, + is_typed_dict, + errors, + ); if metaclass.is_some() && is_typed_dict { self.error( errors, @@ -385,7 +397,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { dataclass_defaults_from_base_class.clone(), ); let dataclass_from_dataclass_transform = self.dataclass_from_dataclass_transform( - &keywords, + &keyword_annotations, &decorators, dataclass_defaults_from_base_class, pydantic_config.as_ref(), @@ -464,7 +476,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { }; // Get types of class keywords. - let keywords = keywords.into_map(|(name, annot)| { + let keywords = keyword_annotations.into_map(|(name, annot)| { ( name, annot.ty.unwrap_or_else(|| self.heap.mk_any_implicit()), @@ -510,6 +522,61 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) } + fn check_init_subclass_keywords( + &self, + cls: &Class, + bases_with_metadata: &[(Class, Arc)], + metaclass: Option<&ClassType>, + keywords: &[(Identifier, Expr)], + errors: &ErrorCollector, + ) { + let is_pydantic_model = bases_with_metadata.iter().any(|(base, metadata)| { + base.has_toplevel_qname(ModuleName::pydantic().as_str(), "BaseModel") + || metadata.is_pydantic_model() + }); + let keywords = keywords + .iter() + .filter(|(name, _)| name.id != "metaclass" && !(is_pydantic_model && name.id == EXTRA)) + .collect::>(); + if !keywords.is_empty() && metaclass.is_some() { + return; + } + let Some((base, _)) = bases_with_metadata.first() else { + return; + }; + let include_init_subclass_ancestors = !keywords.is_empty(); + let base = self.promote_nontypeddict_silently_to_classtype(base); + let Some(init_subclass) = + self.get_dunder_init_subclass(&base, include_init_subclass_ancestors) + else { + return; + }; + let keywords = keywords + .into_iter() + .map(|(name, value)| CallKeyword { + range: name.range(), + arg: Some(name), + value: TypeOrExpr::Expr(value), + }) + .collect::>(); + self.call_infer( + self.as_call_target_or_error( + init_subclass, + CallStyle::Method(&dunder::INIT_SUBCLASS), + cls.range(), + errors, + None, + ), + &[], + &keywords, + cls.range(), + errors, + None, + None, + None, + ); + } + fn extract_slots_info(&self, cls: &Class) -> Option { let key = KeyClassField(cls.index(), dunder::SLOTS.clone()); let idx = self.bindings().key_to_idx_hashed_opt(Hashed::new(&key))?; diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index e9cd3196a1..07aa8cadf8 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -3110,7 +3110,7 @@ pub struct BindingClassMetadata { /// The class keywords (these are keyword args that appear in the base class list, the /// Python runtime will dispatch most of them to the metaclass, but the metaclass /// itself can also potentially be one of these). - pub keywords: Box<[(Name, Expr)]>, + pub keywords: Box<[(Identifier, Expr)]>, /// The class decorators. pub decorators: Box<[Idx]>, /// Is this a new type? True only for synthesized classes created from a `NewType` call. diff --git a/pyrefly/lib/binding/class.rs b/pyrefly/lib/binding/class.rs index 8faa1ce0c8..1fe57c0527 100644 --- a/pyrefly/lib/binding/class.rs +++ b/pyrefly/lib/binding/class.rs @@ -363,7 +363,7 @@ impl<'a> BindingsBuilder<'a> { args.keywords.iter_mut().for_each(|keyword| { if let Some(name) = &keyword.arg { self.ensure_expr(&mut keyword.value, class_object.usage()); - keywords.push((name.id.clone(), keyword.value.clone())); + keywords.push((name.clone(), keyword.value.clone())); } else { self.error( keyword.range(), @@ -960,7 +960,7 @@ impl<'a> BindingsBuilder<'a> { class_indices: ClassIndices, parent: &NestingContext, base: Option, - keywords: Box<[(Name, Expr)]>, + keywords: Box<[(Identifier, Expr)]>, // name, position, annotation, value member_definitions: Vec<(String, TextRange, Option, Option)>, illegal_identifier_handling: IllegalIdentifierHandling, @@ -1396,8 +1396,12 @@ impl<'a> BindingsBuilder<'a> { (Some(name), _) if name == "extra_items" => Some(name), _ => None, }; - if let Some(kw_name) = recognized_kw { - base_class_keywords.push((kw_name.clone(), kw.value.clone())); + if recognized_kw.is_some() { + let kw_name = kw + .arg + .clone() + .expect("recognized TypedDict keyword must have a name"); + base_class_keywords.push((kw_name, kw.value.clone())); } else { let msg = if let Some(name) = &kw.arg { format!("Unrecognized keyword argument `{name}`") diff --git a/pyrefly/lib/test/constructors.rs b/pyrefly/lib/test/constructors.rs index 7502f9f92f..7b204ecd49 100644 --- a/pyrefly/lib/test/constructors.rs +++ b/pyrefly/lib/test/constructors.rs @@ -187,6 +187,31 @@ assert_type(C(), C[Any]) # Correct, because invalid metaclass. "#, ); +testcase!( + test_init_subclass_class_keywords, + r#" +class Foo: + def __init_subclass__(cls, asdf: int) -> None: + pass + +class Bar(Foo, asdf=1): ... +class Baz(Foo, asdf=""): ... # E: Argument `Literal['']` is not assignable to parameter `asdf` with type `int` +class Qux(Foo): ... # E: Missing argument `asdf` + "#, +); + +testcase!( + test_init_subclass_skips_custom_metaclass_keywords, + r#" +class Meta(type): + def __new__(cls, name, bases, namespace, abstract: bool = False): + return super().__new__(cls, name, bases, namespace) + +class Base(metaclass=Meta): ... +class Child(Base, abstract=True): ... + "#, +); + testcase!( test_metaclass_invalid_generic_legacy_typevar, r#"