Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 25e414d

Browse files
digama0robertylewis
andcommitted
feat(tactic/linarith): nlinarith tactic (#2637)
Based on Coq's [nra](https://coq.inria.fr/refman/addendum/micromega.html#coq:tacn.nra) tactic, and requested on [Zulip](https://leanprover.zulipchat.com/#narrow/stream/113488-general/topic/nonlinear.20linarith). Co-authored-by: Rob Lewis <rob.y.lewis@gmail.com> Co-authored-by: Rob Lewis <Rob.y.lewis@gmail.com>
1 parent 2e752e1 commit 25e414d

File tree

4 files changed

+159
-1
lines changed

4 files changed

+159
-1
lines changed

src/data/bool.lean

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,11 @@ lemma bxor_iff_ne : ∀ {x y : bool}, bxor x y = tt ↔ x ≠ y := dec_trivial
145145
lemma bnot_inj : ∀ {a b : bool}, !a = !b → a = b := dec_trivial
146146

147147
end bool
148+
149+
instance : decidable_linear_order bool :=
150+
begin
151+
constructor,
152+
show bool → bool → Prop,
153+
{ exact λ a b, a = ff ∨ b = tt },
154+
all_goals {apply_instance <|> exact dec_trivial}
155+
end

src/data/list/defs.lean

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,19 @@ def mmap_filter {m : Type → Type v} [monad m] {α β} (f : α → m (option β
524524
| (h :: t) := do b ← f h, t' ← t.mmap_filter, return $
525525
match b with none := t' | (some x) := x::t' end
526526

527+
/--
528+
`mmap'_diag f l` calls `f` on all elements in the "upper diagonal" of `l × l`.
529+
That is, for each `e ∈ l`, it will run `f e e` and then `f e e'`
530+
for each `e'` that appears after `e` in `l`.
531+
532+
Example: suppose `l = [1, 2, 3]`. `mmap'_diag f l` will evaluate, in this order,
533+
`f 1 1`, `f 1 2`, `f 1 3`, `f 2 2`, `f 2 3`, `f 3 3`.
534+
-/
535+
def mmap'_diag {m} [monad m] {α} (f : α → α → m unit) : list α → m unit
536+
| [] := return ()
537+
| (h::t) := f h h >> t.mmap' (f h) >> t.mmap'_diag
538+
539+
527540
protected def traverse {F : Type u → Type v} [applicative F] {α β : Type*} (f : α → F β) :
528541
list α → F (list β)
529542
| [] := pure []

src/tactic/linarith.lean

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,27 @@ do l' ← replace_nat_pfs l,
929929

930930
end normalize
931931

932+
/--
933+
`find_squares m e` collects all terms of the form `a ^ 2` and `a * a` that appear in `e`
934+
and adds them to the set `m`.
935+
A pair `(a, tt)` is added to `m` when `a^2` appears in `e`, and `(a, ff)` is added to `m`
936+
when `a*a` appears in `e`. -/
937+
meta def find_squares : rb_set (expr × bool) → expr → tactic (rb_set (expr × bool))
938+
| s `(%%a ^ 2) := do s ← find_squares s a, return (s.insert (a, tt))
939+
| s e@`(%%e1 * %%e2) := if e1 = e2 then do s ← find_squares s e1, return (s.insert (e1, ff)) else e.mfoldl find_squares s
940+
| s e := e.mfoldl find_squares s
941+
942+
-- used in the `nlinarith` normalization steps. The `_` argument is for uniformity.
943+
@[nolint unused_arguments]
944+
lemma mul_zero_eq {α} {R : α → α → Prop} [semiring α] {a b : α} (_ : R a 0) (h : b = 0) : a * b = 0 :=
945+
by simp [h]
946+
947+
-- used in the `nlinarith` normalization steps. The `_` argument is for uniformity.
948+
@[nolint unused_arguments]
949+
lemma zero_mul_eq {α} {R : α → α → Prop} [semiring α] {a b : α} (h : a = 0) (_ : R b 0) : a * b = 0 :=
950+
by simp [h]
951+
952+
932953
end linarith
933954

934955
section
@@ -1045,11 +1066,78 @@ optional arguments:
10451066
hypotheses.
10461067
* If `exfalso` is false, `linarith` will fail when the goal is neither an inequality nor `false`.
10471068
(True by default.)
1069+
1070+
A variant, `nlinarith`, does some basic preprocessing to handle some nonlinear goals.
10481071
-/
10491072
add_tactic_doc
10501073
{ name := "linarith",
10511074
category := doc_category.tactic,
10521075
decl_names := [`tactic.interactive.linarith],
10531076
tags := ["arithmetic", "decision procedure", "finishing"] }
10541077

1078+
/--
1079+
An extension of `linarith` with some preprocessing to allow it to solve some nonlinear arithmetic
1080+
problems. (Based on Coq's `nra` tactic.) See `linarith` for the available syntax of options,
1081+
which are inherited by `nlinarith`; that is, `nlinarith!` and `nlinarith only [h1, h2]` all work as
1082+
in `linarith`. The preprocessing is as follows:
1083+
1084+
* For every subterm `a ^ 2` or `a * a` in a hypothesis or the goal,
1085+
the assumption `0 ≤ a ^ 2` or `0 ≤ a * a` is added to the context.
1086+
* For every pair of hypotheses `a1 R1 b1`, `a2 R2 b2` in the context, `R1, R2 ∈ {<, ≤, =}`,
1087+
the assumption `0 R' (b1 - a1) * (b2 - a2)` is added to the context (non-recursively),
1088+
where `R ∈ {<, ≤, =}` is the appropriate comparison derived from `R1, R2`.
1089+
-/
1090+
meta def tactic.interactive.nlinarith (red : parse ((tk "!")?))
1091+
(restr : parse ((tk "only")?)) (hyps : parse pexpr_list?)
1092+
(cfg : linarith_config := {}) : tactic unit := do
1093+
ls ← match hyps with
1094+
| none := if restr.is_some then return [] else local_context
1095+
| some hyps := do
1096+
ls ← hyps.mmap i_to_expr,
1097+
if restr.is_some then return ls else (++ ls) <$> local_context
1098+
end,
1099+
(s, ge0) ← (list.mfoldr (λ h ⟨s, l⟩, do
1100+
h ← infer_type h >>= rearr_comp h <|> return h,
1101+
t ← infer_type h,
1102+
s ← find_squares s t,
1103+
return (s, match t with
1104+
| `(%%a ≤ 0) := (ineq.le, h) :: l
1105+
| `(%%a < 0) := (ineq.lt, h) :: l
1106+
| `(%%a = 0) := (ineq.eq, h) :: l
1107+
| _ := l end))
1108+
(mk_rb_set, []) ls : tactic (rb_set (expr × bool) × list (ineq × expr))),
1109+
s ← target >>= find_squares s,
1110+
(hyps, ge0) ← s.fold (return (hyps, ge0)) (λ ⟨e, is_sq⟩ tac, do
1111+
(hyps, ge0) ← tac,
1112+
(do
1113+
t ← infer_type e,
1114+
when cfg.restrict_type.is_some
1115+
(is_def_eq `(some %%t : option Type) cfg.restrict_type_reflect),
1116+
p ← mk_app (if is_sq then ``pow_two_nonneg else ``mul_self_nonneg) [e],
1117+
p ← infer_type p >>= rearr_comp p <|> return p,
1118+
t ← infer_type p,
1119+
h ← assertv `h t p,
1120+
return (hyps.map (λ l, pexpr.of_expr h :: l), (ineq.le, h) :: ge0)) <|>
1121+
return (hyps, ge0)),
1122+
ge0.mmap'_diag (λ ⟨posa, a⟩ ⟨posb, b⟩, do
1123+
p ← match posa, posb with
1124+
| ineq.eq, _ := mk_app ``zero_mul_eq [a, b]
1125+
| _, ineq.eq := mk_app ``mul_zero_eq [a, b]
1126+
| ineq.lt, ineq.lt := mk_app ``mul_pos_of_neg_of_neg [a, b]
1127+
| ineq.lt, ineq.le := do a ← mk_app ``le_of_lt [a], mk_app ``mul_nonneg_of_nonpos_of_nonpos [a, b]
1128+
| ineq.le, ineq.lt := do b ← mk_app ``le_of_lt [b], mk_app ``mul_nonneg_of_nonpos_of_nonpos [a, b]
1129+
| ineq.le, ineq.le := mk_app ``mul_nonneg_of_nonpos_of_nonpos [a, b]
1130+
end,
1131+
t ← infer_type p,
1132+
assertv `h t p, skip),
1133+
tactic.interactive.linarith red restr hyps cfg
1134+
1135+
add_hint_tactic "nlinarith"
1136+
1137+
add_tactic_doc
1138+
{ name := "nlinarith",
1139+
category := doc_category.tactic,
1140+
decl_names := [`tactic.interactive.nlinarith],
1141+
tags := ["arithmetic", "decision procedure", "finishing"] }
1142+
10551143
end

test/linarith.lean

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,30 @@ example (x y : ℚ) (h : x < y) : x ≠ y := by linarith
170170

171171
example (x y : ℚ) (h : x < y) : ¬ x = y := by linarith
172172

173-
lemma test6 (u v x y A B : ℚ)
173+
example (u v x y A B : ℚ)
174+
(a : 0 < A)
175+
(a_1 : 0 <= 1 - A)
176+
(a_2 : 0 <= B - 1)
177+
(a_3 : 0 <= B - x)
178+
(a_4 : 0 <= B - y)
179+
(a_5 : 0 <= u)
180+
(a_6 : 0 <= v)
181+
(a_7 : 0 < A - u)
182+
(a_8 : 0 < A - v) :
183+
u * y + v * x + u * v < 3 * A * B :=
184+
by nlinarith
185+
186+
example (u v x y A B : ℚ) : (0 < A) → (A ≤ 1) → (1 ≤ B)
187+
→ (x ≤ B) → ( y ≤ B)
188+
→ (0 ≤ u ) → (0 ≤ v )
189+
→ (u < A) → ( v < A)
190+
→ (u * y + v * x + u * v < 3 * A * B) :=
191+
begin
192+
intros,
193+
nlinarith
194+
end
174195

196+
example (u v x y A B : ℚ)
175197
(a : 0 < A)
176198
(a_1 : 0 <= 1 - A)
177199
(a_2 : 0 <= B - 1)
@@ -268,3 +290,30 @@ lemma test6 (u v x y A B : ℚ)
268290
intros,
269291
linarith
270292
end
293+
294+
example (A B : ℚ) : (0 < A) → (1 ≤ B) → (0 < A / 8 * B) :=
295+
begin
296+
intros, nlinarith
297+
end
298+
299+
example (x y : ℚ) : 0 ≤ x ^2 + y ^2 :=
300+
by nlinarith
301+
302+
example (x y : ℚ) : 0 ≤ x*x + y*y :=
303+
by nlinarith
304+
305+
example (x y : ℚ) : x = 0 → y = 0 → x*x + y*y = 0 :=
306+
by intros; nlinarith
307+
308+
/- lemma norm_eq_zero_iff {x y : ℚ} : x * x + y * y = 0 ↔ x = 0 ∧ y = 0 :=
309+
begin
310+
split,
311+
{ intro h, split; sorry }, -- should be solved after refactor
312+
{ rintro ⟨⟩, nlinarith }
313+
end -/
314+
315+
-- should be solved after refactor
316+
/- lemma norm_nonpos_right {x y : ℚ} (h1 : x * x + y * y ≤ 0) : y = 0 :=
317+
by nlinarith
318+
lemma norm_nonpos_left (x y : ℚ) (h1 : x * x + y * y ≤ 0) : x = 0 :=
319+
by nlinarith -/

0 commit comments

Comments
 (0)