Skip to content

Commit

Permalink
perf(linarith): don't repeat nonneg proofs for nat-to-int casts (#3226)
Browse files Browse the repository at this point in the history
This performance issue showed up particularly when using `nlinarith` over `nat`s. Proofs that `(n : int) >= 0` were being repeated many times, which led to quadratic blowup in the `nlinarith` preprocessing.



Co-authored-by: Rob Lewis <rob.y.lewis@gmail.com>
  • Loading branch information
robertylewis and robertylewis committed Jun 30, 2020
1 parent 791744b commit 056a72a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 43 deletions.
38 changes: 19 additions & 19 deletions src/analysis/special_functions/trigonometric.lean
Expand Up @@ -731,8 +731,8 @@ by rw [← sin_pi_div_two_sub, ← sin_pi_div_two_sub, sin_sub_sin, sub_sub_sub_
add_sub, sub_add_eq_add_sub, add_halves, sub_sub, sub_div π, cos_pi_div_two_sub,
← neg_sub, neg_div, sin_neg, ← neg_mul_eq_mul_neg, neg_mul_eq_neg_mul, mul_right_comm]

lemma cos_lt_cos_of_nonneg_of_le_pi_div_two {x y : ℝ} (hx₁ : 0 ≤ x) (hx₂ : x ≤ π / 2)
(hy₁ : 0 ≤ y) (hy₂ : y ≤ π / 2) (hxy : x < y) : cos y < cos x :=
lemma cos_lt_cos_of_nonneg_of_le_pi_div_two {x y : ℝ} (hx₁ : 0 ≤ x) (hy₂ : y ≤ π / 2) (hxy : x < y) :
cos y < cos x :=
calc cos y = cos x * cos (y - x) - sin x * sin (y - x) :
by rw [← cos_add, add_sub_cancel'_right]
... < (cos x * 1) - sin x * sin (y - x) :
Expand All @@ -746,10 +746,10 @@ calc cos y = cos x * cos (y - x) - sin x * sin (y - x) :
exact sub_le_self _ (mul_nonneg (sin_nonneg_of_nonneg_of_le_pi hx₁ (by linarith))
(sin_nonneg_of_nonneg_of_le_pi (by linarith) (by linarith)))

lemma cos_lt_cos_of_nonneg_of_le_pi {x y : ℝ} (hx₁ : 0 ≤ x) (hx₂ : x ≤ π)
(hy₁ : 0 ≤ y) (hy₂ : y ≤ π) (hxy : x < y) : cos y < cos x :=
lemma cos_lt_cos_of_nonneg_of_le_pi {x y : ℝ} (hx₁ : 0 ≤ x) (hy₂ : y ≤ π) (hxy : x < y) :
cos y < cos x :=
match (le_total x (π / 2) : x ≤ π / 2 ∨ π / 2 ≤ x), le_total y (π / 2) with
| or.inl hx, or.inl hy := cos_lt_cos_of_nonneg_of_le_pi_div_two hx₁ hx hy₁ hy hxy
| or.inl hx, or.inl hy := cos_lt_cos_of_nonneg_of_le_pi_div_two hx₁ hy hxy
| or.inl hx, or.inr hy := (lt_or_eq_of_le hx).elim
(λ hx, calc cos y ≤ 0 : cos_nonpos_of_pi_div_two_le_of_le hy (by linarith [pi_pos])
... < cos x : cos_pos_of_neg_pi_div_two_lt_of_lt_pi_div_two (by linarith) hx)
Expand All @@ -760,10 +760,10 @@ match (le_total x (π / 2) : x ≤ π / 2 ∨ π / 2 ≤ x), le_total y (π / 2)
apply cos_lt_cos_of_nonneg_of_le_pi_div_two; linarith)
end

lemma cos_le_cos_of_nonneg_of_le_pi {x y : ℝ} (hx₁ : 0 ≤ x) (hx₂ : x ≤ π)
(hy₁ : 0 ≤ y) (hy₂ : y ≤ π) (hxy : x ≤ y) : cos y ≤ cos x :=
lemma cos_le_cos_of_nonneg_of_le_pi {x y : ℝ} (hx₁ : 0 ≤ x) (hy₂ : y ≤ π) (hxy : x ≤ y) :
cos y ≤ cos x :=
(lt_or_eq_of_le hxy).elim
(le_of_lt ∘ cos_lt_cos_of_nonneg_of_le_pi hx₁ hx₂ hy₁ hy₂)
(le_of_lt ∘ cos_lt_cos_of_nonneg_of_le_pi hx₁ hy₂)
(λ h, h ▸ le_refl _)

lemma sin_lt_sin_of_le_of_le_pi_div_two {x y : ℝ} (hx₁ : -(π / 2) ≤ x)
Expand Down Expand Up @@ -1085,38 +1085,38 @@ neg_pos.1 (tan_neg x ▸ tan_pos_of_pos_of_lt_pi_div_two (by linarith) (by linar
lemma tan_nonpos_of_nonpos_of_neg_pi_div_two_le {x : ℝ} (hx0 : x ≤ 0) (hpx : -(π / 2) ≤ x) : tan x ≤ 0 :=
neg_nonneg.1 (tan_neg x ▸ tan_nonneg_of_nonneg_of_le_pi_div_two (by linarith) (by linarith [pi_pos]))

lemma tan_lt_tan_of_nonneg_of_lt_pi_div_two {x y : ℝ} (hx₁ : 0 ≤ x) (hx₂ : x < π / 2) (hy₁ : 0 y)
(hy₂ : y < π / 2) (hxy : x < y) : tan x < tan y :=
lemma tan_lt_tan_of_nonneg_of_lt_pi_div_two {x y : ℝ} (hx₁ : 0 ≤ x) (hy₂ : y < π / 2) (hxy : x < y) :
tan x < tan y :=
begin
rw [tan_eq_sin_div_cos, tan_eq_sin_div_cos],
exact div_lt_div
(sin_lt_sin_of_le_of_le_pi_div_two (by linarith) (le_of_lt hy₂) hxy)
(cos_le_cos_of_nonneg_of_le_pi hx₁ (by linarith) hy₁ (by linarith) (le_of_lt hxy))
(sin_nonneg_of_nonneg_of_le_pi hy₁ (by linarith))
(cos_le_cos_of_nonneg_of_le_pi hx₁ (by linarith) (le_of_lt hxy))
(sin_nonneg_of_nonneg_of_le_pi (by linarith) (by linarith))
(cos_pos_of_neg_pi_div_two_lt_of_lt_pi_div_two (by linarith) hy₂)
end

lemma tan_lt_tan_of_lt_of_lt_pi_div_two {x y : ℝ} (hx₁ : -(π / 2) < x) (hx₂ : x < π / 2)
(hy₁ : -(π / 2) < y) (hy₂ : y < π / 2) (hxy : x < y) : tan x < tan y :=
lemma tan_lt_tan_of_lt_of_lt_pi_div_two {x y : ℝ} (hx₁ : -(π / 2) < x)
(hy₂ : y < π / 2) (hxy : x < y) : tan x < tan y :=
match le_total x 0, le_total y 0 with
| or.inl hx0, or.inl hy0 := neg_lt_neg_iff.1 $ by rw [← tan_neg, ← tan_neg]; exact
tan_lt_tan_of_nonneg_of_lt_pi_div_two (neg_nonneg.2 hy0) (neg_lt.2 hy₁)
(neg_nonneg.2 hx0) (neg_lt.2 hx₁) (neg_lt_neg hxy)
tan_lt_tan_of_nonneg_of_lt_pi_div_two (neg_nonneg.2 hy0)
(neg_lt.2 hx₁) (neg_lt_neg hxy)
| or.inl hx0, or.inr hy0 := (lt_or_eq_of_le hy0).elim
(λ hy0, calc tan x ≤ 0 : tan_nonpos_of_nonpos_of_neg_pi_div_two_le hx0 (le_of_lt hx₁)
... < tan y : tan_pos_of_pos_of_lt_pi_div_two hy0 hy₂)
(λ hy0, by rw [← hy0, tan_zero]; exact
tan_neg_of_neg_of_pi_div_two_lt (hy0.symm ▸ hxy) hx₁)
| or.inr hx0, or.inl hy0 := by linarith
| or.inr hx0, or.inr hy0 := tan_lt_tan_of_nonneg_of_lt_pi_div_two hx0 hx₂ hy0 hy₂ hxy
| or.inr hx0, or.inr hy0 := tan_lt_tan_of_nonneg_of_lt_pi_div_two hx0 hy₂ hxy
end

lemma tan_inj_of_lt_of_lt_pi_div_two {x y : ℝ} (hx₁ : -(π / 2) < x) (hx₂ : x < π / 2)
(hy₁ : -(π / 2) < y) (hy₂ : y < π / 2) (hxy : tan x = tan y) : x = y :=
match lt_trichotomy x y with
| or.inl h := absurd (tan_lt_tan_of_lt_of_lt_pi_div_two hx₁ hx₂ hy₁ hy₂ h) (by rw hxy; exact lt_irrefl _)
| or.inl h := absurd (tan_lt_tan_of_lt_of_lt_pi_div_two hx₁ hy₂ h) (by rw hxy; exact lt_irrefl _)
| or.inr (or.inl h) := h
| or.inr (or.inr h) := absurd (tan_lt_tan_of_lt_of_lt_pi_div_two hy₁ hy₂ hx₁ hx₂ h) (by rw hxy; exact lt_irrefl _)
| or.inr (or.inr h) := absurd (tan_lt_tan_of_lt_of_lt_pi_div_two hy₁ hx₂ h) (by rw hxy; exact lt_irrefl _)
end

/-- Inverse of the `tan` function, returns values in the range `-π / 2 < arctan x` and `arctan x < π / 2` -/
Expand Down
6 changes: 6 additions & 0 deletions src/meta/rb_map.lean
Expand Up @@ -64,6 +64,12 @@ it may be worth it to reverse the fold.
meta def sdiff {α} (s1 s2 : rb_set α) : rb_set α :=
s2.fold s1 $ λ v s, s.erase v

/--
`insert_list s l` inserts each element of `l` into `s`.
-/
meta def insert_list {key} (s : rb_set key) (l : list key) : rb_set key :=
l.foldl rb_set.insert s

end rb_set

/-! ### Declarations about `rb_map` -/
Expand Down
36 changes: 12 additions & 24 deletions src/tactic/linarith/preprocessing.lean
Expand Up @@ -25,6 +25,7 @@ is the main list, and generally none of these should be skipped unless you know
-/

open native tactic expr

namespace linarith

/-! ### Preprocessing -/
Expand Down Expand Up @@ -73,7 +74,7 @@ infer_type e >>= rearr_comp_aux e

/-- If `e` is of the form `((n : ℕ) : ℤ)`, `is_nat_int_coe e` returns `n : ℕ`. -/
meta def is_nat_int_coe : expr → option expr
| `((↑(%%n : ℕ) : ℤ)) := some n
| `(@coe ℕ ℤ %%_ %%n) := some n
| _ := none

/-- If `e : ℕ`, returns a proof of `0 ≤ (e : ℤ)`. -/
Expand All @@ -89,23 +90,6 @@ meta def get_nat_comps : expr → list expr
| none := []
end

/--
`mk_coe_nat_nonneg_prfs e` returns a list of proofs of the form `0 ≤ ((t : ℕ) : ℤ)`
for each subexpression of `e` of the form `((t : ℕ) : ℤ)`.
-/
meta def mk_coe_nat_nonneg_prfs (e : expr) : tactic (list expr) :=
(get_nat_comps e).mmap mk_coe_nat_nonneg_prf

/--
If `pf` is a proof of a comparison over `ℕ`, `mk_int_pfs_of_nat_pf pf` returns a proof of the
corresponding inequality over `ℤ`, using `tactic.zify_proof`, along with proofs that the cast
naturals are nonnegative.
-/
meta def mk_int_pfs_of_nat_pf (pf : expr) : tactic (list expr) :=
do pf' ← zify_proof [] pf,
(a, b) ← infer_type pf' >>= get_rel_sides,
list.cons pf' <$> ((++) <$> mk_coe_nat_nonneg_prfs a <*> mk_coe_nat_nonneg_prfs b)

/--
If `pf` is a proof of a strict inequality `(a : ℤ) < b`,
`mk_non_strict_int_pf_of_strict_int_pf pf` returns a proof of `a + 1 ≤ b`,
Expand Down Expand Up @@ -176,14 +160,18 @@ end }

/--
If `h` is an equality or inequality between natural numbers,
`nat_to_int h` lifts this inequality to the integers,
adding the facts that the integers involved are nonnegative.
`nat_to_int` lifts this inequality to the integers.
It also adds the facts that the integers involved are nonnegative.
To avoid adding the same nonnegativity facts many times, it is a global preprocessor.
-/
meta def nat_to_int : preprocessor :=
meta def nat_to_int : global_preprocessor :=
{ name := "move nats to ints",
transform := λ h,
do tp ← infer_type h,
guardb (is_nat_prop tp) >> mk_int_pfs_of_nat_pf h <|> return [h] }
transform := λ l,
do l ← l.mmap (λ h, infer_type h >>= guardb ∘ is_nat_prop >> zify_proof [] h <|> return h),
nonnegs ← l.mfoldl (λ (es : expr_set) h, do
(a, b) ← infer_type h >>= get_rel_sides,
return $ (es.insert_list (get_nat_comps a)).insert_list (get_nat_comps b)) mk_rb_set,
(++) l <$> nonnegs.to_list.mmap mk_coe_nat_nonneg_prf }

/-- `strengthen_strict_int h` turns a proof `h` of a strict integer inequality `t1 < t2`
into a proof of `t1 ≤ t2 + 1`. -/
Expand Down
4 changes: 4 additions & 0 deletions test/linarith.lean
Expand Up @@ -343,3 +343,7 @@ by refine_struct { le := leα }; admit

example (a : α) (ha : a < 2) : a ≤ a :=
by linarith

example (p q r s t u v w : ℕ) (h1 : p + u = q + t) (h2 : r + w = s + v) :
p * r + q * s + (t * w + u * v) = p * s + q * r + (t * v + u * w) :=
by nlinarith

0 comments on commit 056a72a

Please sign in to comment.