Skip to content

Commit

Permalink
Make TypeFoldable implementors short-circuit on error
Browse files Browse the repository at this point in the history
Co-authored-by: Alan Egerton <eggyal@gmail.com>
  • Loading branch information
LeSeulArtichaut and eggyal committed Nov 26, 2021
1 parent c5f0d0e commit 6e3fa20
Show file tree
Hide file tree
Showing 11 changed files with 299 additions and 204 deletions.
70 changes: 69 additions & 1 deletion compiler/rustc_data_structures/src/functor.rs
Expand Up @@ -2,12 +2,16 @@ use rustc_index::vec::{Idx, IndexVec};
use std::mem;
use std::ptr;

pub trait IdFunctor {
pub trait IdFunctor: Sized {
type Inner;

fn map_id<F>(self, f: F) -> Self
where
F: FnMut(Self::Inner) -> Self::Inner;

fn try_map_id<F, E>(self, f: F) -> Result<Self, E>
where
F: FnMut(Self::Inner) -> Result<Self::Inner, E>;
}

impl<T> IdFunctor for Box<T> {
Expand All @@ -31,6 +35,25 @@ impl<T> IdFunctor for Box<T> {
raw.assume_init()
}
}

#[inline]
fn try_map_id<F, E>(self, mut f: F) -> Result<Self, E>
where
F: FnMut(Self::Inner) -> Result<Self::Inner, E>,
{
let raw = Box::into_raw(self);
Ok(unsafe {
// SAFETY: The raw pointer points to a valid value of type `T`.
let value = ptr::read(raw);
// SAFETY: Converts `Box<T>` to `Box<MaybeUninit<T>>` which is the
// inverse of `Box::assume_init()` and should be safe.
let mut raw: Box<mem::MaybeUninit<T>> = Box::from_raw(raw.cast());
// SAFETY: Write the mapped value back into the `Box`.
ptr::write(raw.as_mut_ptr(), f(value)?);
// SAFETY: We just initialized `raw`.
raw.assume_init()
})
}
}

impl<T> IdFunctor for Vec<T> {
Expand All @@ -55,6 +78,35 @@ impl<T> IdFunctor for Vec<T> {
}
self
}

#[inline]
fn try_map_id<F, E>(mut self, mut f: F) -> Result<Self, E>
where
F: FnMut(Self::Inner) -> Result<Self::Inner, E>,
{
// FIXME: We don't really care about panics here and leak
// far more than we should, but that should be fine for now.
let len = self.len();
let mut error = Ok(());
unsafe {
self.set_len(0);
let start = self.as_mut_ptr();
for i in 0..len {
let p = start.add(i);
match f(ptr::read(p)) {
Ok(value) => ptr::write(p, value),
Err(err) => {
error = Err(err);
break;
}
}
}
// Even if we encountered an error, set the len back
// so we don't leak memory.
self.set_len(len);
}
error.map(|()| self)
}
}

impl<T> IdFunctor for Box<[T]> {
Expand All @@ -67,6 +119,14 @@ impl<T> IdFunctor for Box<[T]> {
{
Vec::from(self).map_id(f).into()
}

#[inline]
fn try_map_id<F, E>(self, f: F) -> Result<Self, E>
where
F: FnMut(Self::Inner) -> Result<Self::Inner, E>,
{
Vec::from(self).try_map_id(f).map(Into::into)
}
}

impl<I: Idx, T> IdFunctor for IndexVec<I, T> {
Expand All @@ -79,4 +139,12 @@ impl<I: Idx, T> IdFunctor for IndexVec<I, T> {
{
IndexVec::from_raw(self.raw.map_id(f))
}

#[inline]
fn try_map_id<F, E>(self, f: F) -> Result<Self, E>
where
F: FnMut(Self::Inner) -> Result<Self::Inner, E>,
{
self.raw.try_map_id(f).map(IndexVec::from_raw)
}
}
10 changes: 5 additions & 5 deletions compiler/rustc_infer/src/traits/structural_impls.rs
Expand Up @@ -60,13 +60,13 @@ impl<'tcx> fmt::Debug for traits::MismatchedProjectionTypes<'tcx> {
// TypeFoldable implementations.

impl<'tcx, O: TypeFoldable<'tcx>> TypeFoldable<'tcx> for traits::Obligation<'tcx, O> {
fn super_fold_with<F: TypeFolder<'tcx>>(self, folder: &mut F) -> Self {
traits::Obligation {
fn super_fold_with<F: TypeFolder<'tcx>>(self, folder: &mut F) -> Result<Self, F::Error> {
Ok(traits::Obligation {
cause: self.cause,
recursion_depth: self.recursion_depth,
predicate: self.predicate.fold_with(folder),
param_env: self.param_env.fold_with(folder),
}
predicate: self.predicate.fold_with(folder)?,
param_env: self.param_env.fold_with(folder)?,
})
}

fn super_visit_with<V: TypeVisitor<'tcx>>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_macros/src/type_foldable.rs
Expand Up @@ -17,7 +17,7 @@ pub fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::
vi.construct(|_, index| {
let bind = &bindings[index];
quote! {
::rustc_middle::ty::fold::TypeFoldable::fold_with(#bind, __folder)
::rustc_middle::ty::fold::TypeFoldable::fold_with(#bind, __folder)?
}
})
});
Expand All @@ -28,8 +28,8 @@ pub fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::
fn super_fold_with<__F: ::rustc_middle::ty::fold::TypeFolder<'tcx>>(
self,
__folder: &mut __F
) -> Self {
match self { #body_fold }
) -> Result<Self, __F::Error> {
Ok(match self { #body_fold })
}

fn super_visit_with<__F: ::rustc_middle::ty::fold::TypeVisitor<'tcx>>(
Expand Down
14 changes: 7 additions & 7 deletions compiler/rustc_middle/src/macros.rs
Expand Up @@ -55,8 +55,8 @@ macro_rules! TrivialTypeFoldableImpls {
fn super_fold_with<F: $crate::ty::fold::TypeFolder<$tcx>>(
self,
_: &mut F
) -> $ty {
self
) -> ::std::result::Result<$ty, F::Error> {
Ok(self)
}

fn super_visit_with<F: $crate::ty::fold::TypeVisitor<$tcx>>(
Expand Down Expand Up @@ -98,7 +98,7 @@ macro_rules! EnumTypeFoldableImpl {
fn super_fold_with<V: $crate::ty::fold::TypeFolder<$tcx>>(
self,
folder: &mut V,
) -> Self {
) -> ::std::result::Result<Self, V::Error> {
EnumTypeFoldableImpl!(@FoldVariants(self, folder) input($($variants)*) output())
}

Expand All @@ -112,9 +112,9 @@ macro_rules! EnumTypeFoldableImpl {
};

(@FoldVariants($this:expr, $folder:expr) input() output($($output:tt)*)) => {
match $this {
Ok(match $this {
$($output)*
}
})
};

(@FoldVariants($this:expr, $folder:expr)
Expand All @@ -126,7 +126,7 @@ macro_rules! EnumTypeFoldableImpl {
output(
$variant ( $($variant_arg),* ) => {
$variant (
$($crate::ty::fold::TypeFoldable::fold_with($variant_arg, $folder)),*
$($crate::ty::fold::TypeFoldable::fold_with($variant_arg, $folder)?),*
)
}
$($output)*
Expand All @@ -145,7 +145,7 @@ macro_rules! EnumTypeFoldableImpl {
$variant {
$($variant_arg: $crate::ty::fold::TypeFoldable::fold_with(
$variant_arg, $folder
)),* }
)?),* }
}
$($output)*
)
Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_middle/src/mir/mod.rs
Expand Up @@ -2760,11 +2760,11 @@ impl UserTypeProjection {
TrivialTypeFoldableAndLiftImpls! { ProjectionKind, }

impl<'tcx> TypeFoldable<'tcx> for UserTypeProjection {
fn super_fold_with<F: TypeFolder<'tcx>>(self, folder: &mut F) -> Self {
UserTypeProjection {
base: self.base.fold_with(folder),
projs: self.projs.fold_with(folder),
}
fn super_fold_with<F: TypeFolder<'tcx>>(self, folder: &mut F) -> Result<Self, F::Error> {
Ok(UserTypeProjection {
base: self.base.fold_with(folder)?,
projs: self.projs.fold_with(folder)?,
})
}

fn super_visit_with<Vs: TypeVisitor<'tcx>>(
Expand Down

0 comments on commit 6e3fa20

Please sign in to comment.