Skip to content

Commit

Permalink
feat(frontends/lean/elaborator): backport Lean 4 field notation (#757)
Browse files Browse the repository at this point in the history
This backports a number of features from the Lean 4 field notation ("dot notation") and fixes a couple of bugs.

* Field notation not in a function application position (i.e., `x.f` instead of `x.f a b c`) now uses the general resolution procedure rather than a stripped-down one that is unaware of extended field notation. It also now correctly records its source position and correctly allows local recursive calls to have the recursive argument be after the first.
* If `f` is a function, then `f.foo` resolves to `function.foo f` (or more precisely `f` is inserted at the first function argument of `function.foo`).
* If `s` is a structure, then `s.foo` is resolved in the following way:
  * If `foo` is a field of `s`, then it resolves into a projection. (Unchanged.)
  * If the structure's namespace or one of its ancestor's namespaces contains `foo`, then resolves to that. (Change: now considers ancestors too.)
  * If the structure's namespace contains `foo` as an alias, then resolves to that. (New, though note it is not fully functional since Lean 3 alias resolution is not as capable as Lean 4's.)
* We leave the "insufficient number of arguments" error (deferring Lean 4's lambda synthesis to a potential future PR), but we improve error recovery, making it more likely that users can go-to definition on the field.

One breaking change (and a consequence of the first item) is illustrated by the following example:
```lean
structure bar := (f : ∀ {m : ℕ}, m = 0)

variables (s : bar)
#check (s.f : 37 = 0) -- formerly #check (s.f : ∀ {m : ℕ}, m = 0)
```
To get the old behavior, one can generally add lambdas, for example `(λ _, s.f : ∀ {m : ℕ}, m = 0)`.

To preserve reverse compatibility, we (temporarily?) add a feature that if the elaborator fails to find an explicit argument corresponding to the structure, it will insert the structure as the first explicit argument. This matches the previous behavior for non-application dot notation (`x.f`). The reason for including this feature is that mathlib has been misusing this type of dot notation for terms with `has_coe_to_fun` instances. It is not currently compatible with Lean 4. (Note that, when paired with the new aliases feature, this feature allows for extension methods.)
  • Loading branch information
kmill committed Aug 25, 2022
1 parent 141b46b commit f2b2bee
Show file tree
Hide file tree
Showing 12 changed files with 453 additions and 93 deletions.
200 changes: 124 additions & 76 deletions src/frontends/lean/elaborator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1876,41 +1876,85 @@ expr elaborator::visit_app_core(expr fn, buffer<expr> const & args, optional<exp
} else if (is_field_notation(fn) && amask == arg_mask::Default) {
expr s = visit(macro_arg(fn, 0), none_expr());
expr s_type = head_beta_reduce(instantiate_mvars(infer_type(s)));
auto field_res = find_field_fn(fn, s, s_type);
auto field_res = resolve_field_notation(fn, s, s_type);

expr proj, proj_type;
if (field_res.m_ldecl) {
proj = copy_tag(fn, field_res.m_ldecl->mk_ref());
proj_type = field_res.m_ldecl->get_type();
} else {
proj = copy_tag(fn, mk_constant(field_res.get_full_fname()));
proj_type = m_env.get(field_res.get_full_fname()).get_type();
switch (field_res.m_kind) {
case field_resolution::kind::ProjFn: {
// Projections are handled specially since `s` is inserted at a specific argument, whether
// or not it is explicit.
expr coerced_s = *mk_base_projections(m_env, field_res.get_struct_name(), field_res.get_base_struct_name(), mk_as_is(s));
expr proj_app = mk_proj_app(m_env, field_res.get_base_struct_name(), field_res.get_field_name(), coerced_s, fn);
expr new_proj = visit_function(proj_app, has_args, has_args ? none_expr() : expected_type, ref);
return visit_base_app(new_proj, arg_mask::Default, args, expected_type, ref);
}
case field_resolution::kind::LocalRec: {
s = mk_as_is(s);
proj = copy_tag(fn, field_res.m_ldecl.mk_ref());
proj_type = field_res.m_ldecl.get_type();
break;
}
case field_resolution::kind::Const: {
expr coerced_s = *mk_base_projections(m_env, field_res.get_struct_name(), field_res.get_base_struct_name(), mk_as_is(s));
s = copy_tag(s, std::move(coerced_s));
proj = copy_tag(fn, mk_constant(field_res.get_const_name()));
proj_type = m_env.get(field_res.get_full_name()).get_type();
break;
}
default: lean_unreachable();
}

expr new_proj = visit_function(proj, has_args, has_args ? none_expr() : expected_type, ref);

name base_name = field_res.get_base_struct_name();
buffer<expr> new_args;
unsigned i = 0;
optional<elaborator_exception> insufficient;
unsigned i = 0;
while (is_pi(proj_type)) {
if (is_explicit(binding_info(proj_type))) {
if (is_app_of(binding_domain(proj_type), field_res.m_base_S_name)) {
if (is_app_of(binding_domain(proj_type), base_name)
|| (is_pi(binding_domain(proj_type)) && base_name == get_function_name())) {
/* found s location */
expr coerced_s = *mk_base_projections(m_env, field_res.m_S_name, field_res.m_base_S_name, mk_as_is(s));
new_args.push_back(copy_tag(fn, std::move(coerced_s)));

if (insufficient) {
// We defer reporting this to here to prefer the "does not have explicit argument with type" error
report_or_throw(*insufficient);
}

new_args.push_back(s);
for (; i < args.size(); i++)
new_args.push_back(args[i]);
expr new_proj = visit_function(proj, has_args, has_args ? none_expr() : expected_type, ref);
return visit_base_app(new_proj, amask, new_args, expected_type, ref);
} else {
if (i >= args.size()) {
throw elaborator_exception(ref, sstream() << "invalid field notation, insufficient number of arguments for '"
<< field_res.get_full_fname() << "'");
}

return visit_base_app(new_proj, arg_mask::Default, new_args, expected_type, ref);
} else if (i < args.size()) {
new_args.push_back(args[i]);
i++;
} else {
if (!insufficient) {
insufficient = elaborator_exception(ref, sstream() << "invalid field notation, insufficient number of arguments for '"
<< field_res.get_full_name() << "'");
}
new_args.push_back(mk_sorry(none_expr(), fn));
}
}
proj_type = binding_body(proj_type);
}
throw elaborator_exception(ref, sstream() << "invalid field notation, function '"
<< field_res.get_full_fname() << "' does not have explicit argument with type ("
<< field_res.m_base_S_name << " ...)");

// If there is no explicit argument of the right type, we try inserting it as the first explicit argument.
// This gives some ability to use dot notation with terms that have a `has_coe_to_fun` instance.

new_args.clear();
new_args.push_back(s);
new_args.append(args);

try {
return visit_base_app(new_proj, arg_mask::Default, new_args, expected_type, ref);
} catch (elaborator_exception & ex) {
throw nested_elaborator_exception(ref, ex, format("invalid field notation, function '")
+ format(field_res.get_full_name())
+ format("' does not have explicit argument with type (")
+ format(field_res.get_base_struct_name()) + format(" ...)"));
}
} else {
expr new_fn = visit_function(fn, has_args, has_args ? none_expr() : expected_type, ref);
/* Check if we should use a custom elaboration procedure for this application. */
Expand Down Expand Up @@ -2685,42 +2729,47 @@ expr elaborator::visit_inaccessible(expr const & e, optional<expr> const & expec
return copy_tag(e, mk_inaccessible(new_a));
}

elaborator::field_resolution elaborator::field_to_decl(expr const & e, expr const & s, expr const & s_type) {
// prefer 'unknown identifier' error when lhs is a constant of non-value type
if (is_field_notation(e)) {
auto lhs = macro_arg(e, 0);
if (is_constant(lhs)) {
type_context_old::tmp_locals locals(m_ctx);
expr t = whnf(s_type);
while (is_pi(t)) {
t = whnf(instantiate(binding_body(t), locals.push_local_from_binding(t)));
}
if (is_sort(t) && !is_anonymous_field_notation(e)) {
name fname = get_field_notation_field_name(e);
throw elaborator_exception(lhs, format("unknown identifier '") + format(const_name(lhs)) + format(".") +
format(fname) + format("'"));
}
elaborator::field_resolution elaborator::resolve_field_notation_aux(expr const & e, expr const & s, expr const & s_type) {
lean_assert(is_field_notation(e));

// If it's a function, resolve the field as a declaration in the `function` namespace.
if (is_pi(s_type) && !is_anonymous_field_notation(e)) {
auto fname = get_field_notation_field_name(e);
auto full_fname = get_function_name() + fname;
if (m_env.find(full_fname)) {
return field_resolution_const(get_function_name(), get_function_name(), full_fname);
}
}
expr I = get_app_fn(s_type);

expr I = get_app_fn(s_type);

if (!is_constant(I)) {
auto pp_fn = mk_pp_ctx();
// prefer 'unknown identifier' error when lhs (the unelaborated s) is a constant of non-value type
auto lhs = macro_arg(e, 0);
if (!is_anonymous_field_notation(e) && is_constant(lhs)) {
name fname = get_field_notation_field_name(e);
throw elaborator_exception(lhs, format("unknown identifier '") + format(const_name(lhs) + fname) + format("'"));
}
throw elaborator_exception(e, format("invalid field notation, type is not of the form (C ...) where C is a constant") +
pp_indent(pp_fn, s) +
line() + format("has type") +
pp_indent(pp_fn, s_type));
}

auto struct_name = const_name(I);

if (is_anonymous_field_notation(e)) {
if (!is_structure(m_env, const_name(I))) {
if (!is_structure(m_env, struct_name)) {
auto pp_fn = mk_pp_ctx();
throw elaborator_exception(e, format("invalid projection, structure expected") +
pp_indent(pp_fn, s) +
line() + format("has type") +
pp_indent(pp_fn, s_type));
}
auto fnames = get_structure_fields(m_env, const_name(I));
auto fnames = get_structure_fields(m_env, struct_name);
unsigned fidx = get_field_notation_field_idx(e);
if (fidx == 0) {
if (fidx == 0) {
throw elaborator_exception(e, "invalid projection, index must be greater than 0");
}
if (fidx > fnames.size()) {
Expand All @@ -2731,37 +2780,46 @@ elaborator::field_resolution elaborator::field_to_decl(expr const & e, expr cons
line() + format("which has type") +
pp_indent(pp_fn, s_type));
}
return const_name(I) + fnames[fidx-1];
return field_resolution_proj_fn(struct_name, struct_name, fnames[fidx-1]);
} else {
name fname = get_field_notation_field_name(e);
// search for "true" fields first, including in parent structures
if (is_structure_like(m_env, const_name(I)))
if (auto p = find_field(m_env, const_name(I), fname))
return field_resolution(const_name(I), *p, fname);
name full_fname = const_name(I) + fname;
name local_name = full_fname.replace_prefix(get_namespace(env()), {});
if (auto ldecl = m_ctx.lctx().find_if([&](local_decl const & decl) {
return decl.get_info().is_rec() && decl.get_pp_name() == local_name;
})) {
// projection is recursive call
return field_resolution(full_fname, ldecl);
name fname = get_field_notation_field_name(e);

// Search for "true" fields first, including in parent structures
if (is_structure_like(m_env, struct_name))
if (auto p = find_field(m_env, struct_name, fname))
return field_resolution_proj_fn(*p, struct_name, fname);

// Check if field notation is being used to make a "local" recursive call.
name full_fname = struct_name + fname;
name local_name = full_fname.replace_prefix(get_namespace(m_env), {});
if (auto ldecl = m_ctx.lctx().find_if([&](local_decl const & decl) { return decl.get_info().is_rec() && decl.get_pp_name() == local_name; })) {
return field_resolution_local_rec(struct_name, full_fname, *ldecl);
}
if (!m_env.find(full_fname)) {
auto pp_fn = mk_pp_ctx();
throw elaborator_exception(e, format("invalid field notation, '") + format(fname) + format("'") +
format(" is not a valid \"field\" because environment does not contain ") +
format("'") + format(full_fname) + format("'") +
pp_indent(pp_fn, s) +
line() + format("which has type") +
pp_indent(pp_fn, s_type));

// Otherwise we search the environment for this "extended" dot notation

if (auto m = find_method(m_env, struct_name, fname)) {
return field_resolution_const(m->first, struct_name, m->second);
}

if (auto m = find_method_alias(m_env, struct_name, fname)) {
return field_resolution_const(m->first, m->first, m->second);
}
return full_fname;

auto pp_fn = mk_pp_ctx();
throw elaborator_exception(e, format("invalid field notation, '") + format(fname) + format("'") +
format(" is not a valid \"field\" because environment does not contain ") +
format("'") + format(struct_name + fname) + format("'") +
pp_indent(pp_fn, s) +
line() + format("which has type") +
pp_indent(pp_fn, s_type));
}
}

elaborator::field_resolution elaborator::find_field_fn(expr const & e, expr const & s, expr const & s_type) {
elaborator::field_resolution elaborator::resolve_field_notation(expr const & e, expr const & s, expr const & s_type) {
lean_assert(is_field_notation(e));
try {
return field_to_decl(e, s, s_type);
return resolve_field_notation_aux(e, s, s_type);
} catch (elaborator_exception & ex1) {
expr new_s_type = s_type;
if (auto d = unfold_term(env(), new_s_type))
Expand All @@ -2770,7 +2828,7 @@ elaborator::field_resolution elaborator::find_field_fn(expr const & e, expr cons
if (new_s_type == s_type)
throw;
try {
return find_field_fn(e, s, new_s_type);
return resolve_field_notation(e, s, new_s_type);
} catch (elaborator_exception & ex2) {
throw nested_elaborator_exception(ex2.get_pos(), ex1, ex2.pp());
}
Expand All @@ -2779,17 +2837,7 @@ elaborator::field_resolution elaborator::find_field_fn(expr const & e, expr cons

expr elaborator::visit_field(expr const & e, optional<expr> const & expected_type) {
lean_assert(is_field_notation(e));
expr s = visit(macro_arg(e, 0), none_expr());
expr s_type = head_beta_reduce(instantiate_mvars(infer_type(s)));
auto field_res = find_field_fn(e, s, s_type);
expr proj_app;
if (field_res.m_ldecl) {
proj_app = copy_tag(e, mk_app(field_res.m_ldecl->mk_ref(), mk_as_is(s)));
} else {
expr new_e = *mk_base_projections(m_env, field_res.m_S_name, field_res.m_base_S_name, mk_as_is(s));
proj_app = mk_proj_app(m_env, field_res.m_base_S_name, field_res.m_fname, new_e, e);
}
return visit(proj_app, expected_type);
return visit_app_core(e, buffer<expr>(), expected_type, e);
}

class reduce_projections_visitor : public replace_visitor {
Expand Down
Loading

0 comments on commit f2b2bee

Please sign in to comment.