Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,11 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
for (idx, param) in self.body.params.iter().enumerate() {
let param_idx = rty::FunctionParamIdx::from(idx);
let mir_ty = self.pat_ty(param.pat);
let ty = self.type_builder.build(mir_ty);
let term = if !self.is_fn_param_wrapper_ty(mir_ty) && ty.to_sort().is_singleton() {
// `at_entry()` yields the `Inner` of a `FnParam<Inner>`; classify by it so
// a singleton wrapped argument collapses like any other singleton below.
let repr_ty = self.fn_param_wrapper_inner_ty(mir_ty).unwrap_or(mir_ty);
let ty = self.type_builder.build(repr_ty);
let term = if ty.to_sort().is_singleton() {
// the analyzer don't expect params with singleton sorts to be used in formula...
// FIXME: fix the analyzer side to uniformly accept all params
Self::singleton_term_for_ty(&ty).unwrap()
Expand Down Expand Up @@ -259,9 +262,16 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
}
}

fn is_fn_param_wrapper_ty(&self, ty: mir_ty::Ty<'tcx>) -> bool {
ty.ty_adt_def()
.is_some_and(|def| Some(def.did()) == self.def_ids.fn_param_wrapper())
/// The `Inner` of a `thrust_models::FnParam<Inner>` wrapper type, if `ty` is one.
fn fn_param_wrapper_inner_ty(&self, ty: mir_ty::Ty<'tcx>) -> Option<mir_ty::Ty<'tcx>> {
match ty.kind() {
mir_ty::TyKind::Adt(def, args)
if Some(def.did()) == self.def_ids.fn_param_wrapper() =>
{
Some(args.type_at(0))
}
_ => None,
}
}

fn build_env_from_pat(
Expand Down
50 changes: 50 additions & 0 deletions tests/ui/fail/loop_invariant_fn_param_closure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

// A loop invariant refers to a closure parameter via `FnParam<F>`, whose
// `f.at_entry()` yields `Closure<F>`. Here the invariant relates `acc` to the
// entry closure's postcondition, from which the postcondition below is proven.
#[thrust_macros::ensures((n > 0) ==> thrust_macros::post!(f(n - 1), result))]
#[thrust_macros::invariant_context]
fn last_apply<F>(f: F, n: i64) -> i64
where
F: Fn(i64) -> i64,
{
let mut acc = 0_i64;
let mut i = 0_i64;
while i < n {
thrust_macros::invariant!(
|i: i64, acc: i64, f: thrust_models::FnParam<F>|
(i > 0) ==> thrust_macros::post!(f.at_entry()(i - 1), acc)
);
acc = f(i);
i += 1;
}
acc
}

// A capture-free closure is null (singleton) sorted; comparing its identity
// must collapse to a canonical value rather than ICE during clause building.
#[thrust_macros::invariant_context]
fn unchanged<F>(mut f: F)
where
F: FnMut(i64) -> i64,
{
let _ = &mut f;
let mut i = 0_i64;
while i < 10 {
thrust_macros::invariant!(
|i: i64, f: thrust_models::FnParam<F>|
f.at_entry() == f.at_entry() && i <= 10
);
i += 1;
}
assert!(i == 11); // Unsat: the loop exits with i == 10
}

fn main() {
let c = 7_i64;
let _ = last_apply(|x| x + c, 10);

unchanged(|x| x + 1);
}
50 changes: 50 additions & 0 deletions tests/ui/pass/loop_invariant_fn_param_closure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//@check-pass
//@compile-flags: -C debug-assertions=off

// A loop invariant refers to a closure parameter via `FnParam<F>`, whose
// `f.at_entry()` yields `Closure<F>`. Here the invariant relates `acc` to the
// entry closure's postcondition, from which the postcondition below is proven.
#[thrust_macros::ensures((n > 0) ==> thrust_macros::post!(f(n - 1), result))]
#[thrust_macros::invariant_context]
fn last_apply<F>(f: F, n: i64) -> i64
where
F: Fn(i64) -> i64,
{
let mut acc = 0_i64;
let mut i = 0_i64;
while i < n {
thrust_macros::invariant!(
|i: i64, acc: i64, f: thrust_models::FnParam<F>|
(i > 0) ==> thrust_macros::post!(f.at_entry()(i - 1), acc)
);
acc = f(i);
i += 1;
}
acc
}

// A capture-free closure is null (singleton) sorted; comparing its identity
// must collapse to a canonical value rather than ICE during clause building.
#[thrust_macros::invariant_context]
fn unchanged<F>(mut f: F)
where
F: FnMut(i64) -> i64,
{
let _ = &mut f;
let mut i = 0_i64;
while i < 10 {
thrust_macros::invariant!(
|i: i64, f: thrust_models::FnParam<F>|
f.at_entry() == f.at_entry() && i <= 10
);
i += 1;
}
assert!(i == 10);
}

fn main() {
let c = 7_i64;
let _ = last_apply(|x| x + c, 10);

unchanged(|x| x + 1);
}
13 changes: 13 additions & 0 deletions thrust-macros/src/formula_fn_type_lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ impl<'a> FormulaFnTypeLowering<'a> {
.collect();
syn::Type::Tuple(tt)
}
syn::Type::Path(tp) => {
let mut tp = tp.clone();
for segment in &mut tp.path.segments {
if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
for arg in &mut args.args {
if let syn::GenericArgument::Type(elem) = arg {
*elem = self.lower_closure_type_params_in_ty(elem);
}
}
}
}
syn::Type::Path(tp)
}
// TODO: support more types including ADT
_ => ty.clone(),
}
Expand Down