Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - feat(tactic/ring): recursive ring_nf #14429

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,7 @@ begin
simp only [stream_nth_fr_ne_zero, conts_eq.symm, pred_conts_eq.symm] at tmp,
rw tmp,
simp only [denom'],
ring_nf,
ac_refl },
ring_nf },
rwa this },
-- derive some tedious inequalities that we need to rewrite our goal
have nextConts_b_ineq : (fib (n + 2) : K) ≤ (pred_conts.b + gp.b * conts.b), by
Expand Down
2 changes: 1 addition & 1 deletion src/ring_theory/trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ begin
trace_form_apply, algebra.smul_mul_assoc],
rw [mul_comm (b x), ← smul_def],
ring_nf,
simp,
simp [mul_comm],
end

lemma trace_matrix_of_matrix_mul_vec [fintype κ] (b : κ → B) (P : matrix κ κ A) :
Expand Down
76 changes: 54 additions & 22 deletions src/tactic/ring.lean
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,17 @@ do c ← get_cache,
return (xadd' c (const α1 1) (e, i) (`(1), 1) (const α0 0),
c.cs_app ``horner_atom [e])

/-- Evaluate `a` where `a` is an atom. -/
meta def eval_norm_atom (norm_atom : expr → tactic (expr × expr))
(e : expr) : ring_m (horner_expr × expr) :=
do o ← lift $ try_core (guard (e.get_app_args.length > 0) >> norm_atom e),
match o with
| none := eval_atom e
| some (e', p) := do
(e₂, p₂) ← eval_atom e',
prod.mk e₂ <$> lift (mk_eq_trans p p₂)
end

lemma subst_into_pow {α} [monoid α] (l r tl tr t)
(prl : (l : α) = tl) (prr : (r : ℕ) = tr) (prt : tl ^ tr = t) : l ^ r = t :=
by rw [prl, prr, prt]
Expand All @@ -445,7 +456,7 @@ by rw [div_eq_mul_inv, h]

/-- Evaluate a ring expression `e` recursively to normal form, together with a proof of
equality. -/
meta def eval : expr → ring_m (horner_expr × expr)
meta def eval (norm_atom : expr → tactic (expr × expr)) : expr → ring_m (horner_expr × expr)
| `(%%e₁ + %%e₂) := do
(e₁', p₁) ← eval e₁,
(e₂', p₂) ← eval e₂,
Expand All @@ -460,7 +471,7 @@ meta def eval : expr → ring_m (horner_expr × expr)
(e', p) ← eval e,
p' ← ic_lift $ λ ic, ic.mk_app ``unfold_sub [e₁, e₂, e', p],
return (e', p'))
(eval_atom e)
(eval_norm_atom norm_atom e)
| `(- %%e) := do
(e₁, p₁) ← eval e,
(e₂, p₂) ← eval_neg e₁,
Expand All @@ -475,7 +486,7 @@ meta def eval : expr → ring_m (horner_expr × expr)
| e@`(has_inv.inv %%_) := (do
(e', p) ← lift $ norm_num.derive e <|> refl_conv e,
n ← lift $ e'.to_rat,
return (const e' n, p)) <|> eval_atom e
return (const e' n, p)) <|> eval_norm_atom norm_atom e
| e@`(@has_div.div _ %%inst %%e₁ %%e₂) := mcond
(succeeds (do
inst' ← ic_lift $ λ ic, ic.mk_app ``div_inv_monoid.to_has_div [],
Expand All @@ -486,7 +497,7 @@ meta def eval : expr → ring_m (horner_expr × expr)
(e', p) ← eval e,
p' ← ic_lift $ λ ic, ic.mk_app ``unfold_div [e₁, e₂, e', p],
return (e', p'))
(eval_atom e)
(eval_norm_atom norm_atom e)
| e@`(@has_pow.pow _ _ %%P %%e₁ %%e₂) := do
(e₂', p₂) ← lift $ norm_num.derive e₂ <|> refl_conv e₂,
match e₂'.to_nat, P with
Expand All @@ -495,18 +506,18 @@ meta def eval : expr → ring_m (horner_expr × expr)
(e', p') ← eval_pow e₁' (e₂, k),
p ← ic_lift $ λ ic, ic.mk_app ``subst_into_pow [e₁, e₂, e₁', e₂', e', p₁, p₂, p'],
return (e', p)
| _, _ := eval_atom e
| _, _ := eval_norm_atom norm_atom e
end
| e := match e.to_nat with
| some n := (const e (rat.of_int n)).refl_conv
| none := eval_atom e
| none := eval_norm_atom norm_atom e
end

/-- Evaluate a ring expression `e` recursively to normal form, together with a proof of
equality. -/
meta def eval' (red : transparency) (atoms : ref (buffer expr))
(e : expr) : tactic (expr × expr) :=
ring_m.run' red atoms e $ do (e', p) ← eval e, return (e', p)
(norm_atom : expr → tactic (expr × expr)) (e : expr) : tactic (expr × expr) :=
ring_m.run' red atoms e $ do (e', p) ← eval norm_atom e, return (e', p)

theorem horner_def' {α} [comm_semiring α] (a x n b) : @horner α _ a x n b = x ^ n * a + b :=
by simp [horner, mul_comm]
Expand Down Expand Up @@ -540,10 +551,22 @@ inductive normalize_mode | raw | SOP | horner
instance : inhabited normalize_mode := ⟨normalize_mode.horner⟩

/-- A `ring`-based normalization simplifier that rewrites ring expressions into the specified mode.
See `normalize`. This version takes a list of atoms to persist across multiple calls. -/
See `normalize`. This version takes a list of atoms to persist across multiple calls.

* `atoms`: a mutable reference containing the atom set from the previous call
* `red`: the reducibility setting to use when comparing atoms for defeq
* `mode`: the normalization style (see `normalize_mode`)
* `recursive`: if true, atoms will be reduced recursively using `normalize'`
* `e`: the expression to normalize
* `inner`: This should be set to `ff`. It is used internally to disable normalization
at the top level when called from `eval` in order to prevent an infinite loop
`eval' -> eval_atom -> normalize' -> eval'` when called on something that can't
be simplified like `x`.
-/
meta def normalize' (atoms : ref (buffer expr))
(red : transparency) (mode := normalize_mode.horner) (e : expr) : tactic (expr × expr) :=
do
(red : transparency) (mode := normalize_mode.horner) (recursive := tt) :
expr → opt_param _ ff → tactic (expr × expr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add to the doc what the new opt_param does?

| e inner := do
pow_lemma ← simp_lemmas.mk.add_simp ``pow_one,
let lemmas := match mode with
| normalize_mode.SOP :=
Expand All @@ -563,10 +586,12 @@ do
pure (e', pr))
(λ e, do
a ← read_ref atoms,
let norm_rec := if recursive then λ e, normalize' e tt else λ _, failed,
(a, e', pr) ← ext_simplify_core a {}
simp_lemmas.mk (λ _, failed) (λ a _ _ _ e, do
simp_lemmas.mk (λ _, failed) (λ a _ _ p e, do
guard (inner → p.is_some),
write_ref atoms a,
(new_e, pr) ← eval' red atoms e,
(new_e, pr) ← eval' red atoms norm_rec e,
(new_e, pr) ← match mode with
| normalize_mode.raw := λ _, pure (new_e, pr)
| normalize_mode.horner := trans_conv (λ _, pure (new_e, pr))
Expand Down Expand Up @@ -594,9 +619,15 @@ do
This results in terms like `(3 * x ^ 2 * y + 1) * x + y`.
* `SOP` means sum of products form, expanding everything to monomials.
This results in terms like `3 * x ^ 3 * y + x + y`. -/
meta def normalize (red : transparency) (mode := normalize_mode.horner) (e : expr) :
tactic (expr × expr) :=
using_new_ref mk_buffer $ λ atoms, normalize' atoms red mode e
meta def normalize (red : transparency) (mode := normalize_mode.horner)
(recursive := tt) (e : expr) : tactic (expr × expr) :=
using_new_ref mk_buffer $ λ atoms, normalize' atoms red mode recursive e

/-- Configuration for `ring_nf`.

* `recursive`: if true, atoms inside ring expressions will be reduced recursively
-/
@[derive inhabited] structure ring_nf_cfg := (recursive := tt)

end ring

Expand All @@ -613,7 +644,7 @@ meta def ring1 (red : parse (tk "!")?) : tactic unit :=
let transp := if red.is_some then semireducible else reducible in
do `(%%e₁ = %%e₂) ← target >>= instantiate_mvars,
((e₁', p₁), (e₂', p₂)) ← ring_m.run transp e₁ $
prod.mk <$> eval e₁ <*> eval e₂,
prod.mk <$> eval (λ _, failed) e₁ <*> eval (λ _, failed) e₂,
is_def_eq e₁' e₂',
p ← mk_eq_symm p₂ >>= mk_eq_trans p₁,
tactic.exact p
Expand All @@ -636,12 +667,12 @@ which rewrites all ring expressions into a normal form. When writing a normal fo
`ring_nf SOP` will use sum-of-products form instead of horner form.
`ring_nf!` will use a more aggressive reducibility setting to identify atoms.
-/
meta def ring_nf (red : parse (tk "!")?) (SOP : parse ring.mode) (loc : parse location) :
tactic unit :=
meta def ring_nf (red : parse (tk "!")?) (SOP : parse ring.mode) (loc : parse location)
(cfg : ring_nf_cfg := {}) : tactic unit :=
do ns ← loc.get_locals,
let transp := if red.is_some then semireducible else reducible,
tt ← using_new_ref mk_buffer $ λ atoms,
tactic.replace_at (normalize' atoms transp SOP) ns loc.include_goal
tactic.replace_at (normalize' atoms transp SOP cfg.recursive) ns loc.include_goal
| fail "ring_nf failed to simplify",
when loc.include_goal $ try tactic.reflexivity

Expand Down Expand Up @@ -682,9 +713,10 @@ local postfix `?`:9001 := optional
/--
Normalises expressions in commutative (semi-)rings inside of a `conv` block using the tactic `ring`.
-/
meta def ring_nf (red : parse (lean.parser.tk "!")?) (SOP : parse ring.mode) : conv unit :=
meta def ring_nf (red : parse (lean.parser.tk "!")?) (SOP : parse ring.mode)
(cfg : ring.ring_nf_cfg := {}) : conv unit :=
let transp := if red.is_some then semireducible else reducible in
replace_lhs (normalize transp SOP)
replace_lhs (normalize transp SOP cfg.recursive)
<|> fail "ring_nf failed to simplify"

/--
Expand Down
19 changes: 19 additions & 0 deletions test/ring.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ begin
ring
end

example {A : ℤ} (f : ℤ → ℤ) : f 0 = f (A - A) := by ring_nf
example {A : ℤ} (f : ℤ → ℤ) : f 0 = f (A + -A) := by ring_nf

example {a b c : ℝ} (h : 0 < a ^ 4 + b ^ 4 + c ^ 4) :
a ^ 4 / (a ^ 4 + b ^ 4 + c ^ 4) +
b ^ 4 / (b ^ 4 + c ^ 4 + a ^ 4) +
c ^ 4 / (c ^ 4 + a ^ 4 + b ^ 4)
= 1 :=
begin
ring_nf at ⊢ h,
field_simp [h.ne'],
end

example (a b c d x y : ℚ) (hx : x ≠ 0) (hy : y ≠ 0) :
a + b / x - c / x^2 + d / x^3 = a + x⁻¹ * (y * b / y + (d / x - c) / x) :=
begin
Expand Down Expand Up @@ -76,3 +89,9 @@ by transitivity; [exact h, ring]

-- `ring_nf` should descend into the subexpressions `x * -a` and `-a * x`:
example {a x : ℚ} : x * -a = - a * x := by ring_nf

example (f : ℤ → ℤ) (a b : ℤ) : f (2 * a + b) + b = b + f (b + a + a) :=
begin
success_if_fail {{ ring_nf {recursive := ff} }},
ring_nf
end