Skip to content

Commit

Permalink
feat(tactic/ring): recursive ring_nf (#14429)
Browse files Browse the repository at this point in the history
As [reported on Zulip](https://leanprover.zulipchat.com/#narrow/stream/113488-general/topic/.60ring_nf.60.20not.20consistently.20normalizing.3F). This allows `ring_nf` to rewrite inside the atoms of a ring expression, meaning that things like `f (a + b) + c` can simplify in both `+` expressions.
  • Loading branch information
digama0 committed Jun 15, 2022
1 parent 6e0e270 commit ea97606
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 25 deletions.
Expand Up @@ -515,8 +515,7 @@ begin
simp only [stream_nth_fr_ne_zero, conts_eq.symm, pred_conts_eq.symm] at tmp,
rw tmp,
simp only [denom'],
ring_nf,
ac_refl },
ring_nf },
rwa this },
-- derive some tedious inequalities that we need to rewrite our goal
have nextConts_b_ineq : (fib (n + 2) : K) ≤ (pred_conts.b + gp.b * conts.b), by
Expand Down
2 changes: 1 addition & 1 deletion src/ring_theory/trace.lean
Expand Up @@ -396,7 +396,7 @@ begin
trace_form_apply, algebra.smul_mul_assoc],
rw [mul_comm (b x), ← smul_def],
ring_nf,
simp,
simp [mul_comm],
end

lemma trace_matrix_of_matrix_mul_vec [fintype κ] (b : κ → B) (P : matrix κ κ A) :
Expand Down
76 changes: 54 additions & 22 deletions src/tactic/ring.lean
Expand Up @@ -431,6 +431,17 @@ do c ← get_cache,
return (xadd' c (const α1 1) (e, i) (`(1), 1) (const α0 0),
c.cs_app ``horner_atom [e])

/-- Evaluate `a` where `a` is an atom. -/
meta def eval_norm_atom (norm_atom : expr → tactic (expr × expr))
(e : expr) : ring_m (horner_expr × expr) :=
do o ← lift $ try_core (guard (e.get_app_args.length > 0) >> norm_atom e),
match o with
| none := eval_atom e
| some (e', p) := do
(e₂, p₂) ← eval_atom e',
prod.mk e₂ <$> lift (mk_eq_trans p p₂)
end

lemma subst_into_pow {α} [monoid α] (l r tl tr t)
(prl : (l : α) = tl) (prr : (r : ℕ) = tr) (prt : tl ^ tr = t) : l ^ r = t :=
by rw [prl, prr, prt]
Expand All @@ -445,7 +456,7 @@ by rw [div_eq_mul_inv, h]

/-- Evaluate a ring expression `e` recursively to normal form, together with a proof of
equality. -/
meta def eval : expr → ring_m (horner_expr × expr)
meta def eval (norm_atom : expr → tactic (expr × expr)) : expr → ring_m (horner_expr × expr)
| `(%%e₁ + %%e₂) := do
(e₁', p₁) ← eval e₁,
(e₂', p₂) ← eval e₂,
Expand All @@ -460,7 +471,7 @@ meta def eval : expr → ring_m (horner_expr × expr)
(e', p) ← eval e,
p' ← ic_lift $ λ ic, ic.mk_app ``unfold_sub [e₁, e₂, e', p],
return (e', p'))
(eval_atom e)
(eval_norm_atom norm_atom e)
| `(- %%e) := do
(e₁, p₁) ← eval e,
(e₂, p₂) ← eval_neg e₁,
Expand All @@ -475,7 +486,7 @@ meta def eval : expr → ring_m (horner_expr × expr)
| e@`(has_inv.inv %%_) := (do
(e', p) ← lift $ norm_num.derive e <|> refl_conv e,
n ← lift $ e'.to_rat,
return (const e' n, p)) <|> eval_atom e
return (const e' n, p)) <|> eval_norm_atom norm_atom e
| e@`(@has_div.div _ %%inst %%e₁ %%e₂) := mcond
(succeeds (do
inst' ← ic_lift $ λ ic, ic.mk_app ``div_inv_monoid.to_has_div [],
Expand All @@ -486,7 +497,7 @@ meta def eval : expr → ring_m (horner_expr × expr)
(e', p) ← eval e,
p' ← ic_lift $ λ ic, ic.mk_app ``unfold_div [e₁, e₂, e', p],
return (e', p'))
(eval_atom e)
(eval_norm_atom norm_atom e)
| e@`(@has_pow.pow _ _ %%P %%e₁ %%e₂) := do
(e₂', p₂) ← lift $ norm_num.derive e₂ <|> refl_conv e₂,
match e₂'.to_nat, P with
Expand All @@ -495,18 +506,18 @@ meta def eval : expr → ring_m (horner_expr × expr)
(e', p') ← eval_pow e₁' (e₂, k),
p ← ic_lift $ λ ic, ic.mk_app ``subst_into_pow [e₁, e₂, e₁', e₂', e', p₁, p₂, p'],
return (e', p)
| _, _ := eval_atom e
| _, _ := eval_norm_atom norm_atom e
end
| e := match e.to_nat with
| some n := (const e (rat.of_int n)).refl_conv
| none := eval_atom e
| none := eval_norm_atom norm_atom e
end

/-- Evaluate a ring expression `e` recursively to normal form, together with a proof of
equality. -/
meta def eval' (red : transparency) (atoms : ref (buffer expr))
(e : expr) : tactic (expr × expr) :=
ring_m.run' red atoms e $ do (e', p) ← eval e, return (e', p)
(norm_atom : expr → tactic (expr × expr)) (e : expr) : tactic (expr × expr) :=
ring_m.run' red atoms e $ do (e', p) ← eval norm_atom e, return (e', p)

theorem horner_def' {α} [comm_semiring α] (a x n b) : @horner α _ a x n b = x ^ n * a + b :=
by simp [horner, mul_comm]
Expand Down Expand Up @@ -540,10 +551,22 @@ inductive normalize_mode | raw | SOP | horner
instance : inhabited normalize_mode := ⟨normalize_mode.horner⟩

/-- A `ring`-based normalization simplifier that rewrites ring expressions into the specified mode.
See `normalize`. This version takes a list of atoms to persist across multiple calls. -/
See `normalize`. This version takes a list of atoms to persist across multiple calls.
* `atoms`: a mutable reference containing the atom set from the previous call
* `red`: the reducibility setting to use when comparing atoms for defeq
* `mode`: the normalization style (see `normalize_mode`)
* `recursive`: if true, atoms will be reduced recursively using `normalize'`
* `e`: the expression to normalize
* `inner`: This should be set to `ff`. It is used internally to disable normalization
at the top level when called from `eval` in order to prevent an infinite loop
`eval' -> eval_atom -> normalize' -> eval'` when called on something that can't
be simplified like `x`.
-/
meta def normalize' (atoms : ref (buffer expr))
(red : transparency) (mode := normalize_mode.horner) (e : expr) : tactic (expr × expr) :=
do
(red : transparency) (mode := normalize_mode.horner) (recursive := tt) :
expr → opt_param _ ff → tactic (expr × expr)
| e inner := do
pow_lemma ← simp_lemmas.mk.add_simp ``pow_one,
let lemmas := match mode with
| normalize_mode.SOP :=
Expand All @@ -563,10 +586,12 @@ do
pure (e', pr))
(λ e, do
a ← read_ref atoms,
let norm_rec := if recursive then λ e, normalize' e tt else λ _, failed,
(a, e', pr) ← ext_simplify_core a {}
simp_lemmas.mk (λ _, failed) (λ a _ _ _ e, do
simp_lemmas.mk (λ _, failed) (λ a _ _ p e, do
guard (inner → p.is_some),
write_ref atoms a,
(new_e, pr) ← eval' red atoms e,
(new_e, pr) ← eval' red atoms norm_rec e,
(new_e, pr) ← match mode with
| normalize_mode.raw := λ _, pure (new_e, pr)
| normalize_mode.horner := trans_conv (λ _, pure (new_e, pr))
Expand Down Expand Up @@ -594,9 +619,15 @@ do
This results in terms like `(3 * x ^ 2 * y + 1) * x + y`.
* `SOP` means sum of products form, expanding everything to monomials.
This results in terms like `3 * x ^ 3 * y + x + y`. -/
meta def normalize (red : transparency) (mode := normalize_mode.horner) (e : expr) :
tactic (expr × expr) :=
using_new_ref mk_buffer $ λ atoms, normalize' atoms red mode e
meta def normalize (red : transparency) (mode := normalize_mode.horner)
(recursive := tt) (e : expr) : tactic (expr × expr) :=
using_new_ref mk_buffer $ λ atoms, normalize' atoms red mode recursive e

/-- Configuration for `ring_nf`.
* `recursive`: if true, atoms inside ring expressions will be reduced recursively
-/
@[derive inhabited] structure ring_nf_cfg := (recursive := tt)

end ring

Expand All @@ -613,7 +644,7 @@ meta def ring1 (red : parse (tk "!")?) : tactic unit :=
let transp := if red.is_some then semireducible else reducible in
do `(%%e₁ = %%e₂) ← target >>= instantiate_mvars,
((e₁', p₁), (e₂', p₂)) ← ring_m.run transp e₁ $
prod.mk <$> eval e₁ <*> eval e₂,
prod.mk <$> eval (λ _, failed) e₁ <*> eval (λ _, failed) e₂,
is_def_eq e₁' e₂',
p ← mk_eq_symm p₂ >>= mk_eq_trans p₁,
tactic.exact p
Expand All @@ -636,12 +667,12 @@ which rewrites all ring expressions into a normal form. When writing a normal fo
`ring_nf SOP` will use sum-of-products form instead of horner form.
`ring_nf!` will use a more aggressive reducibility setting to identify atoms.
-/
meta def ring_nf (red : parse (tk "!")?) (SOP : parse ring.mode) (loc : parse location) :
tactic unit :=
meta def ring_nf (red : parse (tk "!")?) (SOP : parse ring.mode) (loc : parse location)
(cfg : ring_nf_cfg := {}) : tactic unit :=
do ns ← loc.get_locals,
let transp := if red.is_some then semireducible else reducible,
tt ← using_new_ref mk_buffer $ λ atoms,
tactic.replace_at (normalize' atoms transp SOP) ns loc.include_goal
tactic.replace_at (normalize' atoms transp SOP cfg.recursive) ns loc.include_goal
| fail "ring_nf failed to simplify",
when loc.include_goal $ try tactic.reflexivity

Expand Down Expand Up @@ -682,9 +713,10 @@ local postfix `?`:9001 := optional
/--
Normalises expressions in commutative (semi-)rings inside of a `conv` block using the tactic `ring`.
-/
meta def ring_nf (red : parse (lean.parser.tk "!")?) (SOP : parse ring.mode) : conv unit :=
meta def ring_nf (red : parse (lean.parser.tk "!")?) (SOP : parse ring.mode)
(cfg : ring.ring_nf_cfg := {}) : conv unit :=
let transp := if red.is_some then semireducible else reducible in
replace_lhs (normalize transp SOP)
replace_lhs (normalize transp SOP cfg.recursive)
<|> fail "ring_nf failed to simplify"

/--
Expand Down
19 changes: 19 additions & 0 deletions test/ring.lean
Expand Up @@ -36,6 +36,19 @@ begin
ring
end

example {A : ℤ} (f : ℤ → ℤ) : f 0 = f (A - A) := by ring_nf
example {A : ℤ} (f : ℤ → ℤ) : f 0 = f (A + -A) := by ring_nf

example {a b c : ℝ} (h : 0 < a ^ 4 + b ^ 4 + c ^ 4) :
a ^ 4 / (a ^ 4 + b ^ 4 + c ^ 4) +
b ^ 4 / (b ^ 4 + c ^ 4 + a ^ 4) +
c ^ 4 / (c ^ 4 + a ^ 4 + b ^ 4)
= 1 :=
begin
ring_nf at ⊢ h,
field_simp [h.ne'],
end

example (a b c d x y : ℚ) (hx : x ≠ 0) (hy : y ≠ 0) :
a + b / x - c / x^2 + d / x^3 = a + x⁻¹ * (y * b / y + (d / x - c) / x) :=
begin
Expand Down Expand Up @@ -76,3 +89,9 @@ by transitivity; [exact h, ring]

-- `ring_nf` should descend into the subexpressions `x * -a` and `-a * x`:
example {a x : ℚ} : x * -a = - a * x := by ring_nf

example (f : ℤ → ℤ) (a b : ℤ) : f (2 * a + b) + b = b + f (b + a + a) :=
begin
success_if_fail {{ ring_nf {recursive := ff} }},
ring_nf
end

0 comments on commit ea97606

Please sign in to comment.