Skip to content

Commit

Permalink
feat(frontend/lean/notation_cmds.cpp): notation (name := ...) syntax (
Browse files Browse the repository at this point in the history
#754)

This is an attempt to solve the issues in leanprover-community/mathport#158 once and for all. The main new user-facing behavior is that notations have names, and if you have an overlapping name the definition is rejected. That means that the following is now rejected:
```lean
notation `foo` := nat
notation `foo` := nat
```
To fix the issue, name one or the other using the new `(name := ...)` syntax:
```lean
notation `foo` := nat
notation (name := bar) `foo` := nat
```
Unlike declaration names, notation names are only required to be distinct within a scope. So the following is legal:
```lean
section
local notation (name := foo) `foo` := nat
end
local notation (name := foo) `foo` := nat
```

Reserved notations do not have names / do not cause name conflicts with regular notations, although you are syntactically allowed to name them.
  • Loading branch information
digama0 committed Aug 17, 2022
1 parent 741670c commit 31f3a46
Show file tree
Hide file tree
Showing 45 changed files with 164 additions and 95 deletions.
6 changes: 3 additions & 3 deletions library/init/algebra/classes.lean
Expand Up @@ -155,7 +155,7 @@ is, `is_preorder X r` and `is_symm X r`. -/
is, `is_symm X r` and `is_trans X r`. -/
@[algebra] class is_per (α : Type u) (r : α → α → Prop) extends is_symm α r, is_trans α r : Prop.

/-- `is_strict_order X r` means that the binary relation `r` on `X` is a strict order, that is,
/-- `is_strict_order X r` means that the binary relation `r` on `X` is a strict order, that is,
`is_irrefl X r` and `is_trans X r`. -/
@[algebra] class is_strict_order (α : Type u) (r : α → α → Prop) extends
is_irrefl α r, is_trans α r : Prop.
Expand All @@ -170,7 +170,7 @@ that is, `is_strict_order X lt` and `is_incomp_trans X lt`. -/
@[algebra] class is_strict_weak_order (α : Type u) (lt : α → α → Prop) extends
is_strict_order α lt, is_incomp_trans α lt : Prop.

/-- `is_trichotomous X lt` means that the binary relation `lt` on `X` is trichotomous, that is,
/-- `is_trichotomous X lt` means that the binary relation `lt` on `X` is trichotomous, that is,
either `lt a b` or `a = b` or `lt b a` for any `a` and `b`. -/
@[algebra] class is_trichotomous (α : Type u) (lt : α → α → Prop) : Prop :=
(trichotomous : ∀ a b, lt a b ∨ a = b ∨ lt b a)
Expand Down Expand Up @@ -258,7 +258,7 @@ def equiv (a b : α) : Prop :=

parameter [is_strict_weak_order α r]

local infix ` ≈ `:50 := equiv
local infix (name := equiv) ` ≈ `:50 := equiv

lemma erefl (a : α) : a ≈ a :=
⟨irrefl a, irrefl a⟩
Expand Down
6 changes: 3 additions & 3 deletions library/init/logic.lean
Expand Up @@ -1066,10 +1066,10 @@ variables {α : Type u} {β : Type v}
variable f : α → α → α
variable inv : α → α
variable one : α
local notation a * b := f a b
local notation a ⁻¹ := inv a
local notation (name := f) a * b := f a b
local notation (name := inv) a ⁻¹ := inv a
variable g : α → α → α
local notation a + b := g a b
local notation (name := g) a + b := g a b

def commutative := ∀ a b, a * b = b * a
def associative := ∀ a b c, (a * b) * c = a * (b * c)
Expand Down
84 changes: 59 additions & 25 deletions src/frontends/lean/notation_cmd.cpp
Expand Up @@ -111,10 +111,35 @@ static optional<unsigned> get_precedence(environment const & env, char const * t
return get_expr_precedence(get_token_table(env), tk);
}

void check_notation_name(environment const & env, notation_entry_group grp,
const pos_info & pos, name const & name, bool was_anon) {
if (grp == notation_entry_group::Reserve || !has_notation(env, name)) return;
if (was_anon)
throw parser_error(sstream() <<
"invalid notation: notation already declared. Consider using 'notation (name := ...)'", pos);
else
throw parser_error(sstream() <<
"invalid notation: notation '" << name << "' already declared", pos);
}

static pair<ast_id, name> parse_optional_name(parser & p) {
if (!p.curr_is_token(get_lparen_tk())) return {0, {}};
p.next();
auto tk = p.check_id_next("invalid notation declaration, expected 'name'");
if (tk.second != get_name_tk())
p.maybe_throw_error({"invalid notation declaration, expected 'name'", p.get_ast(tk.first).m_start});
p.check_token_next(get_assign_tk(), "invalid notation declaration, expected ':='");
auto r = p.check_id_next("invalid notation declaration, expected identifier");
p.check_token_next(get_rparen_tk(), "invalid notation declaration, expected ')'");
return r;
}

static auto parse_mixfix_notation(parser & p, ast_data & parent, mixfix_kind k, bool overload, notation_entry_group grp, bool parse_only,
unsigned priority)
-> pair<notation_entry, optional<token_entry>> {
bool explicit_pp = p.curr_is_quoted_symbol();
auto name = parse_optional_name(p);
parent.push(name.first);
pos_info tk_pos = p.pos();
std::string pp_tk = parse_symbol(p, parent, "invalid notation declaration, quoted symbol or identifier expected");
std::string tk = utf8_trim(pp_tk);
Expand Down Expand Up @@ -211,26 +236,29 @@ static auto parse_mixfix_notation(parser & p, ast_data & parent, mixfix_kind k,
if (reserved_action && !explicit_pp)
pp_tk = reserved_transition->get_pp_token().to_string_unescaped();

bool is_nud = k == mixfix_kind::prefix;
list<transition> ts;
switch (k) {
case mixfix_kind::infixl:
ts = to_list(transition(tks, mk_expr_action(*prec), pp_tk));
break;
case mixfix_kind::infixr:
ts = to_list(transition(tks, mk_expr_action(*prec), pp_tk));
break;
case mixfix_kind::postfix:
ts = to_list(transition(tks, mk_skip_action(), pp_tk));
break;
case mixfix_kind::prefix:
ts = to_list(transition(tks, mk_expr_action(*prec), pp_tk));
break;
}
expr e;
if (grp == notation_entry_group::Reserve) {
// reserve notation commands do not have a denotation
parent.push(0);
expr dummy = mk_Prop();
e = mk_Prop();
if (p.curr_is_token(get_assign_tk()))
throw parser_error("invalid reserve notation, found `:=`", p.pos());
switch (k) {
case mixfix_kind::infixl:
return mk_pair(notation_entry(false, to_list(transition(tks, mk_expr_action(*prec), pp_tk)),
dummy, overload, priority, grp, parse_only), new_token);
case mixfix_kind::infixr:
return mk_pair(notation_entry(false, to_list(transition(tks, mk_expr_action(*prec), pp_tk)),
dummy, overload, priority, grp, parse_only), new_token);
case mixfix_kind::postfix:
return mk_pair(notation_entry(false, to_list(transition(tks, mk_skip_action(), pp_tk)),
dummy, overload, priority, grp, parse_only), new_token);
case mixfix_kind::prefix:
return mk_pair(notation_entry(true, to_list(transition(tks, mk_expr_action(*prec), pp_tk)),
dummy, overload, priority, grp, parse_only), new_token);
}
} else {
p.check_token_next(get_assign_tk(), "invalid notation declaration, ':=' expected");
auto f_pos = p.pos();
Expand All @@ -242,20 +270,22 @@ static auto parse_mixfix_notation(parser & p, ast_data & parent, mixfix_kind k,
#if defined(__GNUC__) && !defined(__CLANG__)
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
return mk_pair(notation_entry(false, to_list(transition(tks, mk_expr_action(*prec), pp_tk)),
mk_app(f, Var(1), Var(0)), overload, priority, grp, parse_only), new_token);
e = mk_app(f, Var(1), Var(0));
break;
case mixfix_kind::infixr:
return mk_pair(notation_entry(false, to_list(transition(tks, mk_expr_action(*prec), pp_tk)),
mk_app(f, Var(1), Var(0)), overload, priority, grp, parse_only), new_token);
e = mk_app(f, Var(1), Var(0));
break;
case mixfix_kind::postfix:
return mk_pair(notation_entry(false, to_list(transition(tks, mk_skip_action(), pp_tk)),
mk_app(f, Var(0)), overload, priority, grp, parse_only), new_token);
e = mk_app(f, Var(0));
break;
case mixfix_kind::prefix:
return mk_pair(notation_entry(true, to_list(transition(tks, mk_expr_action(*prec), pp_tk)),
mk_app(f, Var(0)), overload, priority, grp, parse_only), new_token);
e = mk_app(f, Var(0));
break;
}
}
lean_unreachable(); // LCOV_EXCL_LINE
notation_entry entry(is_nud, ts, e, overload, priority, grp, parse_only, name.second);
check_notation_name(p.env(), grp, tk_pos, entry.get_name(), name.second.is_anonymous());
return mk_pair(entry, new_token);
}

static notation_entry parse_mixfix_notation(parser & p, ast_data & data, mixfix_kind k, bool overload, notation_entry_group grp,
Expand Down Expand Up @@ -521,6 +551,8 @@ static notation_entry parse_notation_core(parser & p, ast_data & parent, bool ov
bool is_nud = true;
optional<parse_table> pt;
optional<parse_table> reserved_pt;
auto notation_name = parse_optional_name(p);
parent.push(notation_name.first);
auto& args = p.new_ast("args", p.pos());
parent.push(args.m_id);
if (p.curr_is_numeral()) {
Expand Down Expand Up @@ -661,7 +693,9 @@ static notation_entry parse_notation_core(parser & p, ast_data & parent, bool ov
std::tie(id, n) = parse_notation_expr(p, locals);
parent.push(id);
}
return notation_entry(is_nud, to_list(ts.begin(), ts.end()), n, overload, priority, grp, parse_only);
notation_entry entry(is_nud, to_list(ts.begin(), ts.end()), n, overload, priority, grp, parse_only, notation_name.second);
check_notation_name(p.env(), grp, parent.m_start, entry.get_name(), notation_name.second.is_anonymous());
return entry;
}

bool curr_is_notation_decl(parser & p) {
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/lean/parser.h
Expand Up @@ -429,7 +429,7 @@ class parser : public abstract_parser, public parser_info {
/** \brief Lookahead version of \c curr_is_token. See \c ahead(). */
bool ahead_is_token(name const & tk, int n = 0);

/** \brief Check current token, and move to next characther, throw exception if current token is not \c tk. Returns true if succesful. */
/** \brief Check current token, and move to next character, throw exception if current token is not \c tk. Returns true if succesful. */
bool check_token_next(name const & tk, char const * msg);
void check_token_or_id_next(name const & tk, char const * msg);
/** \brief Check if the current token is an identifier, if it is return it and move to next token,
Expand Down
13 changes: 12 additions & 1 deletion src/frontends/lean/parser_config.cpp
Expand Up @@ -62,7 +62,7 @@ notation_entry::notation_entry(bool is_nud, list<transition> const & ts, expr co
m_name(n), m_expr(e), m_overload(overload), m_group(g), m_parse_only(parse_only), m_priority(priority) {
new (&m_transitions) list<transition>(ts);
m_safe_ascii = std::all_of(ts.begin(), ts.end(), [](transition const & t) { return t.is_safe_ascii(); });
if (n.is_anonymous()) m_name = heuristic_name();
if (g == notation_entry_group::Main && n.is_anonymous()) m_name = heuristic_name();
}
notation_entry::notation_entry(notation_entry const & e, bool overload):
notation_entry(e) {
Expand Down Expand Up @@ -264,9 +264,11 @@ struct notation_prio_fn { unsigned operator()(notation_entry const & v) const {
struct notation_state {
typedef rb_map<mpz, list<expr>, mpz_cmp_fn> num_map;
typedef head_map_prio<notation_entry, notation_prio_fn> head_to_entries;
typedef name_set notation_names;
parse_table m_nud;
parse_table m_led;
num_map m_num_map;
name_set m_notation_names;
head_to_entries m_inv_map;
// The following two tables are used to implement `reserve notation` commands
parse_table m_reserved_nud;
Expand Down Expand Up @@ -306,6 +308,11 @@ struct notation_config {
}

static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
if (e.group() == notation_entry_group::Main) {
if (s.m_notation_names.contains(e.get_name()))
throw exception(sstream() << "invalid notation, a notation named '" << e.get_name() << "' has already been declared");
s.m_notation_names.insert(e.get_name());
}
buffer<transition> ts;
switch (e.kind()) {
case notation_entry_kind::NuD: {
Expand Down Expand Up @@ -452,6 +459,10 @@ parse_table const & get_reserved_led_table(environment const & env) {
return notation_ext::get_state(env).m_reserved_led;
}

bool has_notation(environment const & env, name const & n) {
return notation_ext::get_state(env).m_notation_names.contains(n);
}

environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload, bool parse_only) {
return add_notation(env, notation_entry(n, e, overload, parse_only));
}
Expand Down
1 change: 1 addition & 0 deletions src/frontends/lean/parser_config.h
Expand Up @@ -80,6 +80,7 @@ parse_table const & get_led_table(environment const & env);
parse_table const & get_reserved_nud_table(environment const & env);
parse_table const & get_reserved_led_table(environment const & env);
cmd_table const & get_cmd_table(environment const & env);
bool has_notation(environment const & env, name const & n);
environment add_command(environment const & env, name const & n, cmd_info const & info);

/** \brief Add \c n as notation for \c e */
Expand Down
4 changes: 4 additions & 0 deletions src/frontends/lean/tokens.cpp
Expand Up @@ -115,6 +115,7 @@ static name const * g_infixr_tk = nullptr;
static name const * g_postfix_tk = nullptr;
static name const * g_prefix_tk = nullptr;
static name const * g_notation_tk = nullptr;
static name const * g_name_tk = nullptr;
static name const * g_calc_tk = nullptr;
static name const * g_root_tk = nullptr;
static name const * g_fields_tk = nullptr;
Expand Down Expand Up @@ -241,6 +242,7 @@ void initialize_tokens() {
g_postfix_tk = new name{"postfix"};
g_prefix_tk = new name{"prefix"};
g_notation_tk = new name{"notation"};
g_name_tk = new name{"name"};
g_calc_tk = new name{"calc"};
g_root_tk = new name{"_root_"};
g_fields_tk = new name{"fields"};
Expand Down Expand Up @@ -368,6 +370,7 @@ void finalize_tokens() {
delete g_postfix_tk;
delete g_prefix_tk;
delete g_notation_tk;
delete g_name_tk;
delete g_calc_tk;
delete g_root_tk;
delete g_fields_tk;
Expand Down Expand Up @@ -494,6 +497,7 @@ name const & get_infixr_tk() { return *g_infixr_tk; }
name const & get_postfix_tk() { return *g_postfix_tk; }
name const & get_prefix_tk() { return *g_prefix_tk; }
name const & get_notation_tk() { return *g_notation_tk; }
name const & get_name_tk() { return *g_name_tk; }
name const & get_calc_tk() { return *g_calc_tk; }
name const & get_root_tk() { return *g_root_tk; }
name const & get_fields_tk() { return *g_fields_tk; }
Expand Down
1 change: 1 addition & 0 deletions src/frontends/lean/tokens.h
Expand Up @@ -117,6 +117,7 @@ name const & get_infixr_tk();
name const & get_postfix_tk();
name const & get_prefix_tk();
name const & get_notation_tk();
name const & get_name_tk();
name const & get_calc_tk();
name const & get_root_tk();
name const & get_fields_tk();
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/lean/tokens.txt
Expand Up @@ -110,6 +110,7 @@ infixr infixr
postfix postfix
prefix prefix
notation notation
name name
calc calc
root _root_
fields fields
Expand All @@ -118,6 +119,7 @@ inductive inductive
instance instance
this this
noncomputable noncomputable
exclam !
theory theory
key_equivalences key_equivalences
using using
Expand Down
5 changes: 3 additions & 2 deletions src/library/tactic/backward/backward_lemmas.cpp
Expand Up @@ -20,6 +20,7 @@ Author: Leonardo de Moura
#include "library/tactic/tactic_state.h"
#include "library/tactic/backward/backward_lemmas.h"
#include "frontends/lean/parser.h"
#include "frontends/lean/tokens.h"

namespace lean {
static optional<head_index> get_backward_target(type_context_old & ctx, expr type) {
Expand Down Expand Up @@ -54,10 +55,10 @@ struct intro_attr_data : public attr_data {

ast_id parse(abstract_parser & p) override {
ast_id r = 0;
if (p.curr_is_token("!")) {
if (p.curr_is_token(get_exclam_tk())) {
lean_assert(dynamic_cast<parser *>(&p));
auto& p2 = *static_cast<parser *>(&p);
r = p2.new_ast("!", p2.pos()).m_id;
r = p2.new_ast(get_exclam_tk(), p2.pos()).m_id;
p2.next();
m_eager = true;
}
Expand Down
4 changes: 2 additions & 2 deletions tests/lean/712.lean
Expand Up @@ -6,11 +6,11 @@ local infix `~~~` := eq

#print notation ~~~

local infix `~~~`:50 := eq
local infix (name := eq2) `~~~`:50 := eq

#print notation ~~~

local infix `~~~`:100 := eq
local infix (name := eq3) `~~~`:100 := eq

infix `~~~`:100 := eq -- FAIL

Expand Down
2 changes: 1 addition & 1 deletion tests/lean/assertion1.lean
Expand Up @@ -17,7 +17,7 @@ structure Functor (C : Category.{ u1 v1 }) (D : Category.{ u2 v2 }) :=
}

namespace ProductCategory
notation C `×` D := ProductCategory C D
notation (name := prod) C `×` D := ProductCategory C D
end ProductCategory

@[reducible] definition TensorProduct ( C: Category ) := Functor ( C × C ) C
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/bad_quoted_symbol.lean
@@ -1,4 +1,4 @@
notation a ` \/ ` b := a ∨ b
notation (name := or2) a ` \/ ` b := a ∨ b
notation a `1\/` b := a ∨ b
notation a ` 1\/` b := a ∨ b
notation a ` \ / ` b := a ∨ b
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/calc1.lean
Expand Up @@ -41,7 +41,7 @@ attribute [trans] le_lt_trans
... < d : H5

constant le2 : A → A → bool
infixl ` ≤ `:50 := le2
infixl (name := le2) ` ≤ `:50 := le2
constant le2_trans (a b c : A) (H1 : le2 a b) (H2 : le2 b c) : le2 a c
attribute [trans] le2_trans
-- print raw calc b ≤ c : H2
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/hole_issue2.lean
Expand Up @@ -13,7 +13,7 @@ noncomputable definition count {A} (a : A) (b : bag A) : nat :=
quotient.lift_on b (λ l, list.count a l)
(λ l₁ l₂ h, sorry)
definition subbag {A} (b₁ b₂ : bag A) := ∀ a, count a b₁ ≤ count a b₂
infix ` ⊆ ` := subbag
infix (name := subbag) ` ⊆ ` := subbag

noncomputable definition decidable_subbag_1 {A} (b₁ b₂ : bag A) : decidable (b₁ ⊆ b₂) :=
quotient.rec_on_subsingleton₂ b₁ b₂ (λ l₁ l₂,
Expand Down
10 changes: 5 additions & 5 deletions tests/lean/local_notation_meta_bug.lean
@@ -1,5 +1,5 @@
local infix ` + ` := nat.add
@[class] local infix ` + ` := nat.add
noncomputable local infix ` + ` := nat.add
@[class] noncomputable local infix ` + ` := nat.add
/-- foo -/ local infix ` + ` := nat.add
local infix (name := plus1) ` + ` := nat.add
@[class] local infix (name := plus2) ` + ` := nat.add
noncomputable local infix (name := plus3) ` + ` := nat.add
@[class] noncomputable local infix (name := plus4) ` + ` := nat.add
/-- foo -/ local infix (name := plus5) ` + ` := nat.add
4 changes: 2 additions & 2 deletions tests/lean/nary_overload.lean
Expand Up @@ -7,8 +7,8 @@ constant lst.nil {A : Type} : lst A
constant vec.cons {A : Type} : A → vec A → vec A
constant lst.cons {A : Type} : A → lst A → lst A

notation `[` l:(foldr `, ` (h t, vec.cons h t) vec.nil `]`) := l
notation `[` l:(foldr `, ` (h t, lst.cons h t) lst.nil `]`) := l
notation (name := list1) `[` l:(foldr `, ` (h t, vec.cons h t) vec.nil `]`) := l
notation (name := list2) `[` l:(foldr `, ` (h t, lst.cons h t) lst.nil `]`) := l

constant A : Type
variables a b c : A
Expand Down
9 changes: 7 additions & 2 deletions tests/lean/notation2.lean
@@ -1,6 +1,11 @@
--
inductive List (T : Type) : Type | nil {} : List | cons : T → List → List open List notation h :: t := cons h t notation `[` l:(foldr `,` (h t, cons h t) nil) `]` := l
infixr `::` := cons
inductive List (T : Type) : Type
| nil {} : List
| cons : T → List → List
open List
notation (name := cons2) h :: t := cons h t
notation (name := list2) `[` l:(foldr `,` (h t, cons h t) nil) `]` := l
infixr (name := cons) `::` := cons
#check (1:nat) :: 2 :: nil
#check (1:nat) :: 2 :: 3 :: 4 :: 5 :: nil
#print ]

0 comments on commit 31f3a46

Please sign in to comment.