Skip to content

Commit

Permalink
Auto merge of rust-lang#122385 - lcnr:analyze-obligations-for-infer, …
Browse files Browse the repository at this point in the history
…r=compiler-errors

`obligations_for_self_ty`: use `ProofTreeVisitor` for nested goals

As always, dealing with proof trees continues to be a hacked together mess. After this PR and rust-lang#124380 the only remaining blocker for core is rust-lang/trait-system-refactor-initiative#90. There is also a `ProofTreeVisitor` issue causing an ICE when compiling `alloc` which I will handle in a separate PR. This issue likely affects coherence diagnostics more generally.

The core idea is to extend the proof tree visitor to support visiting nested candidates without using a `probe`. We then simply recurse into nested candidates if they are the only potentially applicable candidate for a given goal and check whether the self type matches the expected one.

For that to work, we need to improve `CanonicalState` to also handle unconstrained inference variables created inside of the trait solver. This is done by extending the `var_values` of `CanoncalState` with each fresh inference variables. Furthermore, we also store the state of all inference variables at the end of each probe. When recursing into `InspectCandidates` we then unify the values of all these states.

r? `@compiler-errors`
  • Loading branch information
bors committed Apr 26, 2024
2 parents 9adafa7 + 146f637 commit 1b3a329
Show file tree
Hide file tree
Showing 25 changed files with 677 additions and 350 deletions.
6 changes: 4 additions & 2 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
ty::Infer(ty::TyVar(vid)) => self.deduce_closure_signature_from_predicates(
Ty::new_var(self.tcx, self.root_var(vid)),
closure_kind,
self.obligations_for_self_ty(vid).map(|obl| (obl.predicate, obl.cause.span)),
self.obligations_for_self_ty(vid)
.into_iter()
.map(|obl| (obl.predicate, obl.cause.span)),
),
ty::FnPtr(sig) => match closure_kind {
hir::ClosureKind::Closure => {
Expand Down Expand Up @@ -889,7 +891,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

let output_ty = match *ret_ty.kind() {
ty::Infer(ty::TyVar(ret_vid)) => {
self.obligations_for_self_ty(ret_vid).find_map(|obligation| {
self.obligations_for_self_ty(ret_vid).into_iter().find_map(|obligation| {
get_future_output(obligation.predicate, obligation.cause.span)
})?
}
Expand Down
103 changes: 22 additions & 81 deletions compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use crate::errors::CtorIsPrivate;
use crate::method::{self, MethodCallee, SelfSource};
use crate::rvalue_scopes;
use crate::{BreakableCtxt, Diverges, Expectation, FnCtxt, LoweredTy};
use rustc_data_structures::captures::Captures;
use rustc_data_structures::fx::FxHashSet;
use rustc_errors::{Applicability, Diag, ErrorGuaranteed, MultiSpan, StashKey};
use rustc_hir as hir;
Expand Down Expand Up @@ -47,7 +46,7 @@ use std::slice;
impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
/// Produces warning on the given node, if the current point in the
/// function is unreachable, and there hasn't been another warning.
pub(in super::super) fn warn_if_unreachable(&self, id: HirId, span: Span, kind: &str) {
pub(crate) fn warn_if_unreachable(&self, id: HirId, span: Span, kind: &str) {
// FIXME: Combine these two 'if' expressions into one once
// let chains are implemented
if let Diverges::Always { span: orig_span, custom_note } = self.diverges.get() {
Expand Down Expand Up @@ -87,7 +86,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// FIXME(-Znext-solver): A lot of the calls to this method should
// probably be `try_structurally_resolve_type` or `structurally_resolve_type` instead.
#[instrument(skip(self), level = "debug", ret)]
pub(in super::super) fn resolve_vars_with_obligations(&self, mut ty: Ty<'tcx>) -> Ty<'tcx> {
pub(crate) fn resolve_vars_with_obligations(&self, mut ty: Ty<'tcx>) -> Ty<'tcx> {
// No Infer()? Nothing needs doing.
if !ty.has_non_region_infer() {
debug!("no inference var, nothing needs doing");
Expand All @@ -109,7 +108,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
self.resolve_vars_if_possible(ty)
}

pub(in super::super) fn record_deferred_call_resolution(
pub(crate) fn record_deferred_call_resolution(
&self,
closure_def_id: LocalDefId,
r: DeferredCallResolution<'tcx>,
Expand All @@ -118,7 +117,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
deferred_call_resolutions.entry(closure_def_id).or_default().push(r);
}

pub(in super::super) fn remove_deferred_call_resolutions(
pub(crate) fn remove_deferred_call_resolutions(
&self,
closure_def_id: LocalDefId,
) -> Vec<DeferredCallResolution<'tcx>> {
Expand Down Expand Up @@ -172,7 +171,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}

#[instrument(level = "debug", skip(self))]
pub(in super::super) fn write_resolution(
pub(crate) fn write_resolution(
&self,
hir_id: HirId,
r: Result<(DefKind, DefId), ErrorGuaranteed>,
Expand Down Expand Up @@ -336,7 +335,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}

/// Instantiates and normalizes the bounds for a given item
pub(in super::super) fn instantiate_bounds(
pub(crate) fn instantiate_bounds(
&self,
span: Span,
def_id: DefId,
Expand All @@ -349,7 +348,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
result
}

pub(in super::super) fn normalize<T>(&self, span: Span, value: T) -> T
pub(crate) fn normalize<T>(&self, span: Span, value: T) -> T
where
T: TypeFoldable<TyCtxt<'tcx>>,
{
Expand Down Expand Up @@ -537,7 +536,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
self.normalize(span, field.ty(self.tcx, args))
}

pub(in super::super) fn resolve_rvalue_scopes(&self, def_id: DefId) {
pub(crate) fn resolve_rvalue_scopes(&self, def_id: DefId) {
let scope_tree = self.tcx.region_scope_tree(def_id);
let rvalue_scopes = { rvalue_scopes::resolve_rvalue_scopes(self, scope_tree, def_id) };
let mut typeck_results = self.typeck_results.borrow_mut();
Expand All @@ -553,7 +552,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
/// We must not attempt to select obligations after this method has run, or risk query cycle
/// ICE.
#[instrument(level = "debug", skip(self))]
pub(in super::super) fn resolve_coroutine_interiors(&self) {
pub(crate) fn resolve_coroutine_interiors(&self) {
// Try selecting all obligations that are not blocked on inference variables.
// Once we start unifying coroutine witnesses, trying to select obligations on them will
// trigger query cycle ICEs, as doing so requires MIR.
Expand Down Expand Up @@ -594,7 +593,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}

#[instrument(skip(self), level = "debug")]
pub(in super::super) fn report_ambiguity_errors(&self) {
pub(crate) fn report_ambiguity_errors(&self) {
let mut errors = self.fulfillment_cx.borrow_mut().collect_remaining_errors(self);

if !errors.is_empty() {
Expand All @@ -609,7 +608,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}

/// Select as many obligations as we can at present.
pub(in super::super) fn select_obligations_where_possible(
pub(crate) fn select_obligations_where_possible(
&self,
mutate_fulfillment_errors: impl Fn(&mut Vec<traits::FulfillmentError<'tcx>>),
) {
Expand All @@ -625,7 +624,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
/// returns a type of `&T`, but the actual type we assign to the
/// *expression* is `T`. So this function just peels off the return
/// type by one layer to yield `T`.
pub(in super::super) fn make_overloaded_place_return_type(
pub(crate) fn make_overloaded_place_return_type(
&self,
method: MethodCallee<'tcx>,
) -> ty::TypeAndMut<'tcx> {
Expand All @@ -636,67 +635,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
ret_ty.builtin_deref(true).unwrap()
}

#[instrument(skip(self), level = "debug")]
fn self_type_matches_expected_vid(&self, self_ty: Ty<'tcx>, expected_vid: ty::TyVid) -> bool {
let self_ty = self.shallow_resolve(self_ty);
debug!(?self_ty);

match *self_ty.kind() {
ty::Infer(ty::TyVar(found_vid)) => {
let found_vid = self.root_var(found_vid);
debug!("self_type_matches_expected_vid - found_vid={:?}", found_vid);
expected_vid == found_vid
}
_ => false,
}
}

#[instrument(skip(self), level = "debug")]
pub(in super::super) fn obligations_for_self_ty<'b>(
&'b self,
self_ty: ty::TyVid,
) -> impl DoubleEndedIterator<Item = traits::PredicateObligation<'tcx>> + Captures<'tcx> + 'b
{
let ty_var_root = self.root_var(self_ty);
trace!("pending_obligations = {:#?}", self.fulfillment_cx.borrow().pending_obligations());

self.fulfillment_cx.borrow().pending_obligations().into_iter().filter_map(
move |obligation| match &obligation.predicate.kind().skip_binder() {
ty::PredicateKind::Clause(ty::ClauseKind::Projection(data))
if self.self_type_matches_expected_vid(
data.projection_ty.self_ty(),
ty_var_root,
) =>
{
Some(obligation)
}
ty::PredicateKind::Clause(ty::ClauseKind::Trait(data))
if self.self_type_matches_expected_vid(data.self_ty(), ty_var_root) =>
{
Some(obligation)
}

ty::PredicateKind::Clause(ty::ClauseKind::Trait(..))
| ty::PredicateKind::Clause(ty::ClauseKind::Projection(..))
| ty::PredicateKind::Clause(ty::ClauseKind::ConstArgHasType(..))
| ty::PredicateKind::Subtype(..)
| ty::PredicateKind::Coerce(..)
| ty::PredicateKind::Clause(ty::ClauseKind::RegionOutlives(..))
| ty::PredicateKind::Clause(ty::ClauseKind::TypeOutlives(..))
| ty::PredicateKind::Clause(ty::ClauseKind::WellFormed(..))
| ty::PredicateKind::ObjectSafe(..)
| ty::PredicateKind::NormalizesTo(..)
| ty::PredicateKind::AliasRelate(..)
| ty::PredicateKind::Clause(ty::ClauseKind::ConstEvaluatable(..))
| ty::PredicateKind::ConstEquate(..)
| ty::PredicateKind::Ambiguous => None,
},
)
}

pub(in super::super) fn type_var_is_sized(&self, self_ty: ty::TyVid) -> bool {
pub(crate) fn type_var_is_sized(&self, self_ty: ty::TyVid) -> bool {
let sized_did = self.tcx.lang_items().sized_trait();
self.obligations_for_self_ty(self_ty).any(|obligation| {
self.obligations_for_self_ty(self_ty).into_iter().any(|obligation| {
match obligation.predicate.kind().skip_binder() {
ty::PredicateKind::Clause(ty::ClauseKind::Trait(data)) => {
Some(data.def_id()) == sized_did
Expand All @@ -706,15 +647,15 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
})
}

pub(in super::super) fn err_args(&self, len: usize) -> Vec<Ty<'tcx>> {
pub(crate) fn err_args(&self, len: usize) -> Vec<Ty<'tcx>> {
let ty_error = Ty::new_misc_error(self.tcx);
vec![ty_error; len]
}

/// Unifies the output type with the expected type early, for more coercions
/// and forward type information on the input expressions.
#[instrument(skip(self, call_span), level = "debug")]
pub(in super::super) fn expected_inputs_for_expected_output(
pub(crate) fn expected_inputs_for_expected_output(
&self,
call_span: Span,
expected_ret: Expectation<'tcx>,
Expand Down Expand Up @@ -747,7 +688,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
expect_args
}

pub(in super::super) fn resolve_lang_item_path(
pub(crate) fn resolve_lang_item_path(
&self,
lang_item: hir::LangItem,
span: Span,
Expand Down Expand Up @@ -926,7 +867,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
/// but we often want access to the parent function's signature.
///
/// Otherwise, return false.
pub(in super::super) fn get_node_fn_decl(
pub(crate) fn get_node_fn_decl(
&self,
node: Node<'tcx>,
) -> Option<(LocalDefId, &'tcx hir::FnDecl<'tcx>, Ident, bool)> {
Expand Down Expand Up @@ -1004,7 +945,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
})
}

pub(in super::super) fn note_internal_mutation_in_method(
pub(crate) fn note_internal_mutation_in_method(
&self,
err: &mut Diag<'_>,
expr: &hir::Expr<'_>,
Expand Down Expand Up @@ -1549,7 +1490,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}
}

pub(in super::super) fn with_breakable_ctxt<F: FnOnce() -> R, R>(
pub(crate) fn with_breakable_ctxt<F: FnOnce() -> R, R>(
&self,
id: HirId,
ctxt: BreakableCtxt<'tcx>,
Expand All @@ -1575,7 +1516,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

/// Instantiate a QueryResponse in a probe context, without a
/// good ObligationCause.
pub(in super::super) fn probe_instantiate_query_response(
pub(crate) fn probe_instantiate_query_response(
&self,
span: Span,
original_values: &OriginalQueryValues<'tcx>,
Expand All @@ -1590,7 +1531,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}

/// Returns `true` if an expression is contained inside the LHS of an assignment expression.
pub(in super::super) fn expr_in_place(&self, mut expr_id: HirId) -> bool {
pub(crate) fn expr_in_place(&self, mut expr_id: HirId) -> bool {
let mut contained_in_place = false;

while let hir::Node::Expr(parent_expr) = self.tcx.parent_hir_node(expr_id) {
Expand Down
Loading

0 comments on commit 1b3a329

Please sign in to comment.