From 90afc0765e5e536af6307b63e1655a38df06e235 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Wed, 22 Jan 2020 19:23:37 -0500 Subject: [PATCH] Use a `ParamEnvAnd` for caching in `ObligationForest` Previously, we used a plain `Predicate` to cache results (e.g. successes and failures) in ObligationForest. However, fulfillment depends on the precise `ParamEnv` used, so this is unsound in general. This commit changes the impl of `ForestObligation` for `PendingPredicateObligation` to use `ParamEnvAnd` instead of `Predicate` for the associated type. The associated type and method are renamed from 'predicate' to 'cache_key' to reflect the fact that type is no longer just a predicate. --- src/librustc/traits/fulfill.rs | 9 ++++-- .../obligation_forest/graphviz.rs | 2 +- .../obligation_forest/mod.rs | 29 +++++++++++-------- .../obligation_forest/tests.rs | 4 +-- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/librustc/traits/fulfill.rs b/src/librustc/traits/fulfill.rs index 0aac6fb81e4a3..07352a3f9478a 100644 --- a/src/librustc/traits/fulfill.rs +++ b/src/librustc/traits/fulfill.rs @@ -18,10 +18,13 @@ use super::{FulfillmentError, FulfillmentErrorCode}; use super::{ObligationCause, PredicateObligation}; impl<'tcx> ForestObligation for PendingPredicateObligation<'tcx> { - type Predicate = ty::Predicate<'tcx>; + /// Note that we include both the `ParamEnv` and the `Predicate`, + /// as the `ParamEnv` can influence whether fulfillment succeeds + /// or fails. + type CacheKey = ty::ParamEnvAnd<'tcx, ty::Predicate<'tcx>>; - fn as_predicate(&self) -> &Self::Predicate { - &self.obligation.predicate + fn as_cache_key(&self) -> Self::CacheKey { + self.obligation.param_env.and(self.obligation.predicate) } } diff --git a/src/librustc_data_structures/obligation_forest/graphviz.rs b/src/librustc_data_structures/obligation_forest/graphviz.rs index ddf89d99621ca..96ee72d187b34 100644 --- a/src/librustc_data_structures/obligation_forest/graphviz.rs +++ b/src/librustc_data_structures/obligation_forest/graphviz.rs @@ -51,7 +51,7 @@ impl<'a, O: ForestObligation + 'a> dot::Labeller<'a> for &'a ObligationForest fn node_label(&self, index: &Self::Node) -> dot::LabelText<'_> { let node = &self.nodes[*index]; - let label = format!("{:?} ({:?})", node.obligation.as_predicate(), node.state.get()); + let label = format!("{:?} ({:?})", node.obligation.as_cache_key(), node.state.get()); dot::LabelText::LabelStr(label.into()) } diff --git a/src/librustc_data_structures/obligation_forest/mod.rs b/src/librustc_data_structures/obligation_forest/mod.rs index 974d9dcfae408..500ce5c71f37a 100644 --- a/src/librustc_data_structures/obligation_forest/mod.rs +++ b/src/librustc_data_structures/obligation_forest/mod.rs @@ -86,9 +86,13 @@ mod graphviz; mod tests; pub trait ForestObligation: Clone + Debug { - type Predicate: Clone + hash::Hash + Eq + Debug; + type CacheKey: Clone + hash::Hash + Eq + Debug; - fn as_predicate(&self) -> &Self::Predicate; + /// Converts this `ForestObligation` suitable for use as a cache key. + /// If two distinct `ForestObligations`s return the same cache key, + /// then it must be sound to use the result of processing one obligation + /// (e.g. success for error) for the other obligation + fn as_cache_key(&self) -> Self::CacheKey; } pub trait ObligationProcessor { @@ -138,12 +142,12 @@ pub struct ObligationForest { nodes: Vec>, /// A cache of predicates that have been successfully completed. - done_cache: FxHashSet, + done_cache: FxHashSet, /// A cache of the nodes in `nodes`, indexed by predicate. Unfortunately, /// its contents are not guaranteed to match those of `nodes`. See the /// comments in `process_obligation` for details. - active_cache: FxHashMap, + active_cache: FxHashMap, /// A vector reused in compress(), to avoid allocating new vectors. node_rewrites: RefCell>, @@ -157,7 +161,7 @@ pub struct ObligationForest { /// See [this][details] for details. /// /// [details]: https://github.com/rust-lang/rust/pull/53255#issuecomment-421184780 - error_cache: FxHashMap>, + error_cache: FxHashMap>, } #[derive(Debug)] @@ -305,11 +309,12 @@ impl ObligationForest { // Returns Err(()) if we already know this obligation failed. fn register_obligation_at(&mut self, obligation: O, parent: Option) -> Result<(), ()> { - if self.done_cache.contains(obligation.as_predicate()) { + if self.done_cache.contains(&obligation.as_cache_key()) { + debug!("register_obligation_at: ignoring already done obligation: {:?}", obligation); return Ok(()); } - match self.active_cache.entry(obligation.as_predicate().clone()) { + match self.active_cache.entry(obligation.as_cache_key().clone()) { Entry::Occupied(o) => { let node = &mut self.nodes[*o.get()]; if let Some(parent_index) = parent { @@ -333,7 +338,7 @@ impl ObligationForest { && self .error_cache .get(&obligation_tree_id) - .map(|errors| errors.contains(obligation.as_predicate())) + .map(|errors| errors.contains(&obligation.as_cache_key())) .unwrap_or(false); if already_failed { @@ -380,7 +385,7 @@ impl ObligationForest { self.error_cache .entry(node.obligation_tree_id) .or_default() - .insert(node.obligation.as_predicate().clone()); + .insert(node.obligation.as_cache_key().clone()); } /// Performs a pass through the obligation list. This must @@ -618,11 +623,11 @@ impl ObligationForest { // `self.nodes`. See the comment in `process_obligation` // for more details. if let Some((predicate, _)) = - self.active_cache.remove_entry(node.obligation.as_predicate()) + self.active_cache.remove_entry(&node.obligation.as_cache_key()) { self.done_cache.insert(predicate); } else { - self.done_cache.insert(node.obligation.as_predicate().clone()); + self.done_cache.insert(node.obligation.as_cache_key().clone()); } if do_completed == DoCompleted::Yes { // Extract the success stories. @@ -635,7 +640,7 @@ impl ObligationForest { // We *intentionally* remove the node from the cache at this point. Otherwise // tests must come up with a different type on every type error they // check against. - self.active_cache.remove(node.obligation.as_predicate()); + self.active_cache.remove(&node.obligation.as_cache_key()); self.insert_into_error_cache(index); node_rewrites[index] = orig_nodes_len; dead_nodes += 1; diff --git a/src/librustc_data_structures/obligation_forest/tests.rs b/src/librustc_data_structures/obligation_forest/tests.rs index e29335aab2808..01652465eea2c 100644 --- a/src/librustc_data_structures/obligation_forest/tests.rs +++ b/src/librustc_data_structures/obligation_forest/tests.rs @@ -4,9 +4,9 @@ use std::fmt; use std::marker::PhantomData; impl<'a> super::ForestObligation for &'a str { - type Predicate = &'a str; + type CacheKey = &'a str; - fn as_predicate(&self) -> &Self::Predicate { + fn as_cache_key(&self) -> Self::CacheKey { self } }