Skip to content

Commit

Permalink
Make StrgRef a type in rty (#638)
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann committed Jun 16, 2024
1 parent da9fb4c commit a6baf28
Show file tree
Hide file tree
Showing 28 changed files with 664 additions and 575 deletions.
41 changes: 19 additions & 22 deletions crates/flux-desugar/src/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,13 +504,12 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> {
for surface_requires in &fn_sig.requires {
let params = self.desugar_refine_params(&surface_requires.params);
let pred = self.desugar_expr(&surface_requires.pred)?;
requires.push(fhir::Constraint::Pred(params, pred));
requires.push(fhir::Requires { params, pred });
}

// Bail out if there's an error in the arguments to avoid confusing error messages
let args = try_alloc_slice!(self.genv, &fn_sig.args, |arg| {
self.desugar_fun_arg(arg, &mut requires)
})?;
let inputs =
try_alloc_slice!(self.genv, &fn_sig.inputs, |arg| self.desugar_fn_input(arg))?;

let output = self.desugar_fn_output(fn_sig.asyncness, &fn_sig.output)?;

Expand All @@ -519,7 +518,7 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> {
let decl = fhir::FnDecl {
generics,
requires: self.genv.alloc_slice(&requires),
args,
inputs,
output,
span: fn_sig.span,
lifted: false,
Expand Down Expand Up @@ -580,16 +579,15 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> {
) -> Result<fhir::FnOutput<'genv>> {
let ret = self.desugar_asyncness(asyncness, &output.returns);

let ensures =
try_alloc_slice!(self.genv, &output.ensures, |cstr| self.desugar_constraint(cstr))?;
let ensures = try_alloc_slice!(self.genv, &output.ensures, |it| self.desugar_ensures(it))?;

let params = self
.genv
.alloc_slice_fill_iter(self.implicit_params_to_params(output.node_id));
Ok(fhir::FnOutput { params, ret: ret?, ensures })
}

fn desugar_constraint(&mut self, cstr: &surface::Ensures) -> Result<fhir::Constraint<'genv>> {
fn desugar_ensures(&mut self, cstr: &surface::Ensures) -> Result<fhir::Ensures<'genv>> {
match cstr {
surface::Ensures::Type(loc, ty, node_id) => {
let res = self.desugar_loc(*loc, *node_id)?;
Expand All @@ -600,22 +598,18 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> {
span: loc.span,
};
let ty = self.desugar_ty(ty)?;
Ok(fhir::Constraint::Type(path, ty))
Ok(fhir::Ensures::Type(path, ty))
}
surface::Ensures::Pred(e) => {
let pred = self.desugar_expr(e)?;
Ok(fhir::Constraint::Pred(&[], pred))
Ok(fhir::Ensures::Pred(pred))
}
}
}

fn desugar_fun_arg(
&mut self,
arg: &surface::Arg,
requires: &mut Vec<fhir::Constraint<'genv>>,
) -> Result<fhir::Ty<'genv>> {
match arg {
surface::Arg::Constr(bind, path, pred, node_id) => {
fn desugar_fn_input(&mut self, input: &surface::FnInput) -> Result<fhir::Ty<'genv>> {
match input {
surface::FnInput::Constr(bind, path, pred, node_id) => {
let bty = self.desugar_path_to_bty(None, path)?;

let pred = self.desugar_expr(pred)?;
Expand All @@ -630,7 +624,7 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> {
let kind = fhir::TyKind::Constr(pred, self.genv.alloc(ty));
Ok(fhir::Ty { kind, span })
}
surface::Arg::StrgRef(loc, ty, node_id) => {
surface::FnInput::StrgRef(loc, ty, node_id) => {
let span = loc.span;
let (id, kind) = self.resolve_implicit_param(*node_id).unwrap();
let path = fhir::PathExpr {
Expand All @@ -640,11 +634,14 @@ impl<'a, 'genv, 'tcx: 'genv> RustItemCtxt<'a, 'genv, 'tcx> {
span: loc.span,
};
let ty = self.desugar_ty(ty)?;
requires.push(fhir::Constraint::Type(path, ty));
let kind = fhir::TyKind::Ptr(self.mk_lft_hole(), path);
let kind = fhir::TyKind::StrgRef(
self.mk_lft_hole(),
self.genv.alloc(path),
self.genv.alloc(ty),
);
Ok(fhir::Ty { kind, span })
}
surface::Arg::Ty(bind, ty, node_id) => {
surface::FnInput::Ty(bind, ty, node_id) => {
if let Some(bind) = bind
&& let surface::TyKind::Base(bty) = &ty.kind
{
Expand Down Expand Up @@ -830,7 +827,7 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> {
#[track_caller]
fn desugar_loc(&self, ident: surface::Ident, node_id: NodeId) -> Result<ExprRes> {
let res = self.resolver_output().path_expr_res_map[&node_id];
if let ExprRes::Param(fhir::ParamKind::Loc(_), _) = res {
if let ExprRes::Param(fhir::ParamKind::Loc, _) = res {
Ok(res)
} else {
let span = ident.span;
Expand Down
14 changes: 7 additions & 7 deletions crates/flux-desugar/src/resolver/refinement_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,15 @@ impl<V: ScopedVisitor> surface::visit::Visitor for ScopedVisitorWrapper<V> {
});
}

fn visit_fun_arg(&mut self, arg: &surface::Arg, idx: usize) {
fn visit_fn_input(&mut self, arg: &surface::FnInput) {
match arg {
surface::Arg::Constr(bind, _, _, node_id) => {
surface::FnInput::Constr(bind, _, _, node_id) => {
self.on_implicit_param(*bind, fhir::ParamKind::Colon, *node_id);
}
surface::Arg::StrgRef(loc, _, node_id) => {
self.on_implicit_param(*loc, fhir::ParamKind::Loc(idx), *node_id);
surface::FnInput::StrgRef(loc, _, node_id) => {
self.on_implicit_param(*loc, fhir::ParamKind::Loc, *node_id);
}
surface::Arg::Ty(bind, ty, node_id) => {
surface::FnInput::Ty(bind, ty, node_id) => {
if let &Some(bind) = bind {
let param_kind = if let surface::TyKind::Base(_) = &ty.kind {
fhir::ParamKind::Colon
Expand All @@ -195,7 +195,7 @@ impl<V: ScopedVisitor> surface::visit::Visitor for ScopedVisitorWrapper<V> {
}
}
}
surface::visit::walk_fun_arg(self, arg);
surface::visit::walk_fn_input(self, arg);
}

fn visit_ensures(&mut self, constraint: &surface::Ensures) {
Expand Down Expand Up @@ -837,7 +837,7 @@ impl ScopedVisitor for IllegalBinderVisitor<'_, '_, '_> {
(matches!(scope_kind, ScopeKind::FnOutput), surface::BindKind::Pound)
}
fhir::ParamKind::Colon
| fhir::ParamKind::Loc(_)
| fhir::ParamKind::Loc
| fhir::ParamKind::Error
| fhir::ParamKind::Explicit => return,
};
Expand Down
23 changes: 13 additions & 10 deletions crates/flux-fhir-analysis/src/annot_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,19 +188,18 @@ impl<'zip, 'genv, 'tcx> Zipper<'zip, 'genv, 'tcx> {
fn_decl: &fhir::FnDecl,
expected_fn_sig: &fhir::FnDecl<'genv>,
) -> Result<(), ErrorGuaranteed> {
if fn_decl.args.len() != expected_fn_sig.args.len() {
if fn_decl.inputs.len() != expected_fn_sig.inputs.len() {
return Err(self.emit_err(errors::FunArgCountMismatch::new(fn_decl, expected_fn_sig)));
}
self.zip_tys(fn_decl.args, expected_fn_sig.args)?;
self.zip_constraints(fn_decl.requires)?;
self.zip_tys(fn_decl.inputs, expected_fn_sig.inputs)?;

self.zip_ty(&fn_decl.output.ret, &expected_fn_sig.output.ret)?;
self.zip_constraints(fn_decl.output.ensures)
self.zip_ensures(fn_decl.output.ensures)
}

fn zip_constraints(&mut self, constrs: &[fhir::Constraint]) -> Result<(), ErrorGuaranteed> {
constrs.iter().try_for_each_exhaust(|constr| {
if let fhir::Constraint::Type(loc, ty) = constr {
fn zip_ensures(&mut self, ensures: &[fhir::Ensures]) -> Result<(), ErrorGuaranteed> {
ensures.iter().try_for_each_exhaust(|ensures| {
if let fhir::Ensures::Type(loc, ty) = ensures {
let ExprRes::Param(_, id) = loc.res else {
span_bug!(loc.span, "unexpected path in loc position")
};
Expand Down Expand Up @@ -234,12 +233,16 @@ impl<'zip, 'genv, 'tcx> Zipper<'zip, 'genv, 'tcx> {
fhir::TyKind::BaseTy(bty) | fhir::TyKind::Indexed(bty, _),
fhir::TyKind::BaseTy(expected_bty),
) => self.zip_bty(&bty, &expected_bty),
(fhir::TyKind::Ptr(lft, loc), fhir::TyKind::Ref(expected_lft, expected_mut_ty)) => {
(
fhir::TyKind::StrgRef(lft, loc, ty),
fhir::TyKind::Ref(expected_lft, expected_mut_ty),
) => {
if expected_mut_ty.mutbl.is_mut() {
let ExprRes::Param(_, id) = loc.res else {
span_bug!(loc.span, "unexpected path in loc position")
};
self.zip_lifetime(lft, expected_lft);
self.zip_ty(ty, expected_mut_ty.ty)?;
self.locs.insert(id, *expected_mut_ty.ty);
Ok(())
} else {
Expand Down Expand Up @@ -512,9 +515,9 @@ mod errors {
pub(super) fn new(decl: &fhir::FnDecl, expected_decl: &fhir::FnDecl) -> Self {
Self {
span: decl.span,
args: decl.args.len(),
args: decl.inputs.len(),
expected_span: expected_decl.span,
expected_args: expected_decl.args.len(),
expected_args: expected_decl.inputs.len(),
}
}
}
Expand Down
64 changes: 28 additions & 36 deletions crates/flux-fhir-analysis/src/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use rustc_hir::{
};
use rustc_middle::{
middle::resolve_bound_vars::ResolvedArg,
mir::Local,
ty::{self, AssocItem, AssocKind, BoundVar},
};
use rustc_span::{
Expand Down Expand Up @@ -329,13 +328,13 @@ pub(crate) fn conv_fn_decl<'genv>(
env.push_layer(Layer::list(&cx, late_bound_regions.len() as u32, &[])?);

let mut requires = vec![];
for constr in decl.requires {
requires.push(cx.conv_constr(&mut env, constr)?);
for req in decl.requires {
requires.push(cx.conv_requires(&mut env, req)?);
}

let mut args = vec![];
for ty in decl.args {
args.push(cx.conv_ty(&mut env, ty)?);
let mut inputs = vec![];
for ty in decl.inputs {
inputs.push(cx.conv_ty(&mut env, ty)?);
}

let output = cx.conv_fn_output(&mut env, &decl.output)?;
Expand All @@ -346,7 +345,7 @@ pub(crate) fn conv_fn_decl<'genv>(
.cloned()
.collect();

let res = rty::PolyFnSig::new(rty::FnSig::new(requires, args, output), vars);
let res = rty::PolyFnSig::new(rty::FnSig::new(requires.into(), inputs.into(), output), vars);
Ok(rty::EarlyBinder(res))
}

Expand Down Expand Up @@ -555,10 +554,10 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> {
env.push_layer(Layer::list(self, 0, output.params)?);

let ret = self.conv_ty(env, &output.ret)?;
let ensures: List<rty::Constraint> = output
let ensures: List<rty::Ensures> = output
.ensures
.iter()
.map(|constr| self.conv_constr(env, constr))
.map(|ens| self.conv_ensures(env, ens))
.try_collect()?;
let output = rty::FnOutput::new(ret, ensures);

Expand Down Expand Up @@ -648,31 +647,23 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> {
}
}

fn conv_constr(
&self,
env: &mut Env,
constr: &fhir::Constraint,
) -> QueryResult<rty::Constraint> {
match constr {
fhir::Constraint::Type(loc, ty) => {
let (idx, _) = loc.res.expect_loc_param();
Ok(rty::Constraint::Type(
env.lookup(loc).to_path(),
self.conv_ty(env, ty)?,
Local::from_usize(idx + 1),
))
}
fhir::Constraint::Pred(params, pred) => {
let pred = if params.is_empty() {
self.conv_expr(env, pred)?
} else {
env.push_layer(Layer::list(self, 0, params)?);
let pred = self.conv_expr(env, pred)?;
let sorts = env.pop_layer().into_bound_vars(self.genv)?;
rty::Expr::forall(rty::Binder::new(pred, sorts))
};
Ok(rty::Constraint::Pred(pred))
fn conv_requires(&self, env: &mut Env, requires: &fhir::Requires) -> QueryResult<rty::Expr> {
if requires.params.is_empty() {
self.conv_expr(env, &requires.pred)
} else {
env.push_layer(Layer::list(self, 0, requires.params)?);
let pred = self.conv_expr(env, &requires.pred)?;
let sorts = env.pop_layer().into_bound_vars(self.genv)?;
Ok(rty::Expr::forall(rty::Binder::new(pred, sorts)))
}
}

fn conv_ensures(&self, env: &mut Env, ensures: &fhir::Ensures) -> QueryResult<rty::Ensures> {
match ensures {
fhir::Ensures::Type(loc, ty) => {
Ok(rty::Ensures::Type(env.lookup(loc).to_path(), self.conv_ty(env, ty)?))
}
fhir::Ensures::Pred(pred) => Ok(rty::Ensures::Pred(self.conv_expr(env, pred)?)),
}
}

Expand Down Expand Up @@ -725,9 +716,10 @@ impl<'a, 'genv, 'tcx> ConvCtxt<'a, 'genv, 'tcx> {
Ok(rty::Ty::exists(rty::Binder::new(ty, sorts)))
}
}
fhir::TyKind::Ptr(lft, loc) => {
let region = self.conv_lifetime(env, *lft);
Ok(rty::Ty::ptr(rty::PtrKind::Mut(region), env.lookup(loc).to_path()))
fhir::TyKind::StrgRef(lft, loc, ty) => {
let re = self.conv_lifetime(env, *lft);
let ty = self.conv_ty(env, ty)?;
Ok(rty::Ty::strg_ref(re, env.lookup(loc).to_path(), ty))
}
fhir::TyKind::Ref(lft, fhir::MutTy { ty, mutbl }) => {
let region = self.conv_lifetime(env, *lft);
Expand Down
Loading

0 comments on commit a6baf28

Please sign in to comment.