|
| 1 | +/- |
| 2 | +Copyright (c) 2018 Mario Carneiro. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Mario Carneiro |
| 5 | +
|
| 6 | +Evaluate expressions in the language of commutative monoids and groups. |
| 7 | +-/ |
| 8 | +import algebra.group_power tactic.norm_num |
| 9 | + |
| 10 | +namespace tactic |
| 11 | +namespace abel |
| 12 | + |
| 13 | +meta structure cache := |
| 14 | +(α : expr) |
| 15 | +(univ : level) |
| 16 | +(α0 : expr) |
| 17 | +(is_group : bool) |
| 18 | +(inst : expr) |
| 19 | + |
| 20 | +meta def mk_cache (e : expr) : tactic cache := |
| 21 | +do α ← infer_type e, |
| 22 | + c ← mk_app ``add_comm_monoid [α] >>= mk_instance, |
| 23 | + cg ← try_core (mk_app ``add_comm_group [α] >>= mk_instance), |
| 24 | + u ← mk_meta_univ, |
| 25 | + infer_type α >>= unify (expr.sort (level.succ u)), |
| 26 | + u ← get_univ_assignment u, |
| 27 | + α0 ← expr.of_nat α 0, |
| 28 | + match cg with |
| 29 | + | (some cg) := return ⟨α, u, α0, tt, cg⟩ |
| 30 | + | _ := return ⟨α, u, α0, ff, c⟩ |
| 31 | + end |
| 32 | + |
| 33 | +meta def cache.app (c : cache) (n : name) (inst : expr) : list expr → expr := |
| 34 | +(@expr.const tt n [c.univ] c.α inst).mk_app |
| 35 | + |
| 36 | +meta def cache.mk_app (c : cache) (n inst : name) (l : list expr) : tactic expr := |
| 37 | +do m ← mk_instance ((expr.const inst [c.univ] : expr) c.α), return $ c.app n m l |
| 38 | + |
| 39 | +meta def add_g : name → name |
| 40 | +| (name.mk_string s p) := name.mk_string (s ++ "g") p |
| 41 | +| n := n |
| 42 | + |
| 43 | +meta def cache.iapp (c : cache) (n : name) : list expr → expr := |
| 44 | +c.app (if c.is_group then add_g n else n) c.inst |
| 45 | + |
| 46 | +def term {α} [add_comm_monoid α] (n : ℕ) (x a : α) : α := add_monoid.smul n x + a |
| 47 | +def termg {α} [add_comm_group α] (n : ℤ) (x a : α) : α := gsmul n x + a |
| 48 | + |
| 49 | +meta def cache.mk_term (c : cache) (n x a : expr) : expr := c.iapp ``term [n, x, a] |
| 50 | + |
| 51 | +meta def cache.int_to_expr (c : cache) (n : ℤ) : tactic expr := |
| 52 | +expr.of_int (if c.is_group then `(ℤ) else `(ℕ)) n |
| 53 | + |
| 54 | +meta inductive normal_expr : Type |
| 55 | +| zero (e : expr) : normal_expr |
| 56 | +| nterm (e : expr) (n : expr × ℤ) (x : expr) (a : normal_expr) : normal_expr |
| 57 | + |
| 58 | +meta def normal_expr.e : normal_expr → expr |
| 59 | +| (normal_expr.zero e) := e |
| 60 | +| (normal_expr.nterm e _ _ _) := e |
| 61 | + |
| 62 | +meta instance : has_coe normal_expr expr := ⟨normal_expr.e⟩ |
| 63 | + |
| 64 | +meta def normal_expr.term' (c : cache) (n : expr × ℤ) (x : expr) (a : normal_expr) : normal_expr := |
| 65 | +normal_expr.nterm (c.mk_term n.1 x a) n x a |
| 66 | + |
| 67 | +meta def normal_expr.zero' (c : cache) : normal_expr := normal_expr.zero c.α0 |
| 68 | + |
| 69 | +meta def normal_expr.to_list : normal_expr → list (ℤ × expr) |
| 70 | +| (normal_expr.zero _) := [] |
| 71 | +| (normal_expr.nterm _ (_, n) x a) := (n, x) :: a.to_list |
| 72 | + |
| 73 | +open normal_expr |
| 74 | + |
| 75 | +meta def normal_expr.to_string (e : normal_expr) : string := |
| 76 | +" + ".intercalate $ (to_list e).map $ |
| 77 | +λ ⟨n, e⟩, to_string n ++ " • (" ++ to_string e ++ ")" |
| 78 | + |
| 79 | +meta def normal_expr.pp (e : normal_expr) : tactic format := |
| 80 | +do l ← (to_list e).mmap (λ ⟨n, e⟩, do |
| 81 | + pe ← pp e, return (to_fmt n ++ " • (" ++ pe ++ ")")), |
| 82 | + return $ format.join $ l.intersperse ↑" + " |
| 83 | + |
| 84 | +meta instance : has_to_tactic_format normal_expr := ⟨normal_expr.pp⟩ |
| 85 | + |
| 86 | +meta def normal_expr.refl_conv (e : normal_expr) : tactic (normal_expr × expr) := |
| 87 | +do p ← mk_eq_refl e, return (e, p) |
| 88 | + |
| 89 | +theorem const_add_term {α} [add_comm_monoid α] (k n x a a') (h : k + a = a') : |
| 90 | + k + @term α _ n x a = term n x a' := by simp [h.symm, term] |
| 91 | + |
| 92 | +theorem const_add_termg {α} [add_comm_group α] (k n x a a') (h : k + a = a') : |
| 93 | + k + @termg α _ n x a = termg n x a' := by simp [h.symm, termg] |
| 94 | + |
| 95 | +theorem term_add_const {α} [add_comm_monoid α] (n x a k a') (h : a + k = a') : |
| 96 | + @term α _ n x a + k = term n x a' := by simp [h.symm, term] |
| 97 | + |
| 98 | +theorem term_add_constg {α} [add_comm_group α] (n x a k a') (h : a + k = a') : |
| 99 | + @termg α _ n x a + k = termg n x a' := by simp [h.symm, termg] |
| 100 | + |
| 101 | +theorem term_add_term {α} [add_comm_monoid α] (n₁ x a₁ n₂ a₂ n' a') |
| 102 | + (h₁ : n₁ + n₂ = n') (h₂ : a₁ + a₂ = a') : |
| 103 | + @term α _ n₁ x a₁ + @term α _ n₂ x a₂ = term n' x a' := |
| 104 | +by simp [h₁.symm, h₂.symm, term, add_monoid.add_smul] |
| 105 | + |
| 106 | +theorem term_add_termg {α} [add_comm_group α] (n₁ x a₁ n₂ a₂ n' a') |
| 107 | + (h₁ : n₁ + n₂ = n') (h₂ : a₁ + a₂ = a') : |
| 108 | + @termg α _ n₁ x a₁ + @termg α _ n₂ x a₂ = termg n' x a' := |
| 109 | +by simp [h₁.symm, h₂.symm, termg, add_gsmul] |
| 110 | + |
| 111 | +theorem zero_term {α} [add_comm_monoid α] (x a) : @term α _ 0 x a = a := |
| 112 | +by simp [term] |
| 113 | + |
| 114 | +theorem zero_termg {α} [add_comm_group α] (x a) : @termg α _ 0 x a = a := |
| 115 | +by simp [termg] |
| 116 | + |
| 117 | +meta def eval_add (c : cache) : normal_expr → normal_expr → tactic (normal_expr × expr) |
| 118 | +| (zero _) e₂ := do |
| 119 | + p ← mk_app ``zero_add [e₂], |
| 120 | + return (e₂, p) |
| 121 | +| e₁ (zero _) := do |
| 122 | + p ← mk_app ``add_zero [e₁], |
| 123 | + return (e₁, p) |
| 124 | +| he₁@(nterm e₁ n₁ x₁ a₁) he₂@(nterm e₂ n₂ x₂ a₂) := |
| 125 | + if expr.lex_lt x₁ x₂ then do |
| 126 | + (a', h) ← eval_add a₁ he₂, |
| 127 | + return (term' c n₁ x₁ a', c.iapp ``term_add_const [n₁.1, x₁, a₁, e₂, a', h]) |
| 128 | + else if x₁ ≠ x₂ then do |
| 129 | + (a', h) ← eval_add he₁ a₂, |
| 130 | + return (term' c n₂ x₂ a', c.iapp ``const_add_term [e₁, n₂.1, x₂, a₂, a', h]) |
| 131 | + else do |
| 132 | + (n', h₁) ← mk_app ``has_add.add [n₁.1, n₂.1] >>= norm_num, |
| 133 | + (a', h₂) ← eval_add a₁ a₂, |
| 134 | + let k := n₁.2 + n₂.2, |
| 135 | + let p₁ := c.iapp ``term_add_term [n₁.1, x₁, a₁, n₂.1, a₂, n', a', h₁, h₂], |
| 136 | + if k = 0 then do |
| 137 | + p ← mk_eq_trans p₁ (c.iapp ``zero_term [x₁, a']), |
| 138 | + return (a', p) |
| 139 | + else return (term' c (n', k) x₁ a', p₁) |
| 140 | + |
| 141 | +theorem term_neg {α} [add_comm_group α] (n x a n' a') |
| 142 | + (h₁ : -n = n') (h₂ : -a = a') : |
| 143 | + -@termg α _ n x a = termg n' x a' := |
| 144 | +by simp [h₂.symm, h₁.symm, termg] |
| 145 | + |
| 146 | +meta def eval_neg (c : cache) : normal_expr → tactic (normal_expr × expr) |
| 147 | +| (zero e) := do |
| 148 | + p ← c.mk_app ``neg_zero ``add_group [], |
| 149 | + return (zero' c, p) |
| 150 | +| (nterm e n x a) := do |
| 151 | + (n', h₁) ← mk_app ``has_neg.neg [n.1] >>= norm_num, |
| 152 | + (a', h₂) ← eval_neg a, |
| 153 | + return (term' c (n', -n.2) x a', |
| 154 | + c.app ``term_neg c.inst [n.1, x, a, n', a', h₁, h₂]) |
| 155 | + |
| 156 | +def smul {α} [add_comm_monoid α] (n : ℕ) (x : α) : α := add_monoid.smul n x |
| 157 | +def smulg {α} [add_comm_group α] (n : ℤ) (x : α) : α := gsmul n x |
| 158 | + |
| 159 | +theorem zero_smul {α} [add_comm_monoid α] (c) : smul c (0 : α) = 0 := |
| 160 | +by simp [smul] |
| 161 | + |
| 162 | +theorem zero_smulg {α} [add_comm_group α] (c) : smulg c (0 : α) = 0 := |
| 163 | +by simp [smulg] |
| 164 | + |
| 165 | +theorem term_smul {α} [add_comm_monoid α] (c n x a n' a') |
| 166 | + (h₁ : c * n = n') (h₂ : smul c a = a') : |
| 167 | + smul c (@term α _ n x a) = term n' x a' := |
| 168 | +by simp [h₂.symm, h₁.symm, term, smul, add_monoid.smul_add, add_monoid.mul_smul] |
| 169 | + |
| 170 | +theorem term_smulg {α} [add_comm_group α] (c n x a n' a') |
| 171 | + (h₁ : c * n = n') (h₂ : smulg c a = a') : |
| 172 | + smulg c (@termg α _ n x a) = termg n' x a' := |
| 173 | +by simp [h₂.symm, h₁.symm, termg, smulg, gsmul_add, gsmul_mul] |
| 174 | + |
| 175 | +meta def eval_smul (c : cache) (k : expr × ℤ) : |
| 176 | + normal_expr → tactic (normal_expr × expr) |
| 177 | +| (zero _) := return (zero' c, c.iapp ``zero_smul [k.1]) |
| 178 | +| (nterm e n x a) := do |
| 179 | + (n', h₁) ← mk_app ``has_mul.mul [k.1, n.1] >>= norm_num, |
| 180 | + (a', h₂) ← eval_smul a, |
| 181 | + return (term' c (n', k.2 * n.2) x a', |
| 182 | + c.iapp ``term_smul [k.1, n.1, x, a, n', a', h₁, h₂]) |
| 183 | + |
| 184 | +theorem term_atom {α} [add_comm_monoid α] (x : α) : x = term 1 x 0 := |
| 185 | +by simp [term] |
| 186 | + |
| 187 | +theorem term_atomg {α} [add_comm_group α] (x : α) : x = termg 1 x 0 := |
| 188 | +by simp [termg] |
| 189 | + |
| 190 | +meta def eval_atom (c : cache) (e : expr) : tactic (normal_expr × expr) := |
| 191 | +do n1 ← c.int_to_expr 1, |
| 192 | + return (term' c (n1, 1) e (zero' c), c.iapp ``term_atom [e]) |
| 193 | + |
| 194 | +lemma unfold_sub {α} [add_group α] (a b c : α) |
| 195 | + (h : a + -b = c) : a - b = c := h |
| 196 | + |
| 197 | +theorem unfold_smul {α} [add_comm_monoid α] (n) (x y : α) |
| 198 | + (h : smul n x = y) : add_monoid.smul n x = y := h |
| 199 | + |
| 200 | +theorem unfold_smulg {α} [add_comm_group α] (n : ℕ) (x y : α) |
| 201 | + (h : smulg (int.of_nat n) x = y) : add_monoid.smul n x = y := h |
| 202 | + |
| 203 | +theorem unfold_gsmul {α} [add_comm_group α] (n : ℤ) (x y : α) |
| 204 | + (h : smulg n x = y) : gsmul n x = y := h |
| 205 | + |
| 206 | +lemma subst_into_smul {α} [add_comm_monoid α] |
| 207 | + (l r tl tr t) (prl : l = tl) (prr : r = tr) |
| 208 | + (prt : @smul α _ tl tr = t) : smul l r = t := |
| 209 | +by simp [prl, prr, prt] |
| 210 | + |
| 211 | +lemma subst_into_smulg {α} [add_comm_group α] |
| 212 | + (l r tl tr t) (prl : l = tl) (prr : r = tr) |
| 213 | + (prt : @smulg α _ tl tr = t) : smulg l r = t := |
| 214 | +by simp [prl, prr, prt] |
| 215 | + |
| 216 | +meta def eval (c : cache) : expr → tactic (normal_expr × expr) |
| 217 | +| `(%%e₁ + %%e₂) := do |
| 218 | + (e₁', p₁) ← eval e₁, |
| 219 | + (e₂', p₂) ← eval e₂, |
| 220 | + (e', p') ← eval_add c e₁' e₂', |
| 221 | + p ← c.mk_app ``norm_num.subst_into_sum ``has_add [e₁, e₂, e₁', e₂', e', p₁, p₂, p'], |
| 222 | + return (e', p) |
| 223 | +| `(%%e₁ - %%e₂) := do |
| 224 | + e₂' ← mk_app ``has_neg.neg [e₂], |
| 225 | + e ← mk_app ``has_add.add [e₁, e₂'], |
| 226 | + (e', p) ← eval e, |
| 227 | + p' ← c.mk_app ``unfold_sub ``add_group [e₁, e₂, e', p], |
| 228 | + return (e', p') |
| 229 | +| `(- %%e) := do |
| 230 | + (e₁, p₁) ← eval e, |
| 231 | + (e₂, p₂) ← eval_neg c e₁, |
| 232 | + p ← c.mk_app ``norm_num.subst_into_neg ``has_neg [e, e₁, e₂, p₁, p₂], |
| 233 | + return (e₂, p) |
| 234 | +| `(add_monoid.smul %%e₁ %%e₂) := do |
| 235 | + n ← if c.is_group then mk_app ``int.of_nat [e₁] else return e₁, |
| 236 | + (e', p) ← eval $ c.iapp ``smul [n, e₂], |
| 237 | + return (e', c.iapp ``unfold_smul [e₁, e₂, e', p]) |
| 238 | +| `(gsmul %%e₁ %%e₂) := do |
| 239 | + guardb c.is_group, |
| 240 | + (e', p) ← eval $ c.iapp ``smul [e₁, e₂], |
| 241 | + return (e', c.app ``unfold_gsmul c.inst [e₁, e₂, e', p]) |
| 242 | +| `(smul %%e₁ %%e₂) := do |
| 243 | + guard (¬ c.is_group), |
| 244 | + (e₁', p₁) ← norm_num.derive e₁ <|> refl_conv e₁, n ← e₁'.to_nat, |
| 245 | + (e₂', p₂) ← eval e₂, |
| 246 | + (e', p) ← eval_smul c (e₁', n) e₂', |
| 247 | + return (e', c.iapp ``subst_into_smul [e₁, e₂, e₁', e₂', e', p₁, p₂, p]) |
| 248 | +| `(smulg %%e₁ %%e₂) := do |
| 249 | + guardb c.is_group, |
| 250 | + (e₁', p₁) ← norm_num.derive e₁ <|> refl_conv e₁, n ← e₁'.to_int, |
| 251 | + (e₂', p₂) ← eval e₂, |
| 252 | + (e', p) ← eval_smul c (e₁', n) e₂', |
| 253 | + return (e', c.iapp ``subst_into_smul [e₁, e₂, e₁', e₂', e', p₁, p₂, p]) |
| 254 | +| e := eval_atom c e |
| 255 | + |
| 256 | +meta def eval' (c : cache) (e : expr) : tactic (expr × expr) := |
| 257 | +do (e', p) ← eval c e, return (e', p) |
| 258 | + |
| 259 | +@[derive has_reflect] |
| 260 | +inductive normalize_mode | raw | term |
| 261 | + |
| 262 | +meta def normalize (mode := normalize_mode.term) (e : expr) : tactic (expr × expr) := do |
| 263 | +pow_lemma ← simp_lemmas.mk.add_simp ``pow_one, |
| 264 | +let lemmas := match mode with |
| 265 | +| normalize_mode.term := |
| 266 | + [``term.equations._eqn_1, ``termg.equations._eqn_1, |
| 267 | + ``add_zero, ``add_monoid.one_smul, ``one_gsmul] |
| 268 | +| _ := [] |
| 269 | +end, |
| 270 | +lemmas ← lemmas.mfoldl simp_lemmas.add_simp simp_lemmas.mk, |
| 271 | +(_, e', pr) ← ext_simplify_core () {} |
| 272 | + simp_lemmas.mk (λ _, failed) (λ _ _ _ _ e, do |
| 273 | + c ← mk_cache e, |
| 274 | + (new_e, pr) ← match mode with |
| 275 | + | normalize_mode.raw := eval' c |
| 276 | + | normalize_mode.term := trans_conv (eval' c) (simplify lemmas []) |
| 277 | + end e, |
| 278 | + guard (¬ new_e =ₐ e), |
| 279 | + return ((), new_e, some pr, ff)) |
| 280 | + (λ _ _ _ _ _, failed) `eq e, |
| 281 | +return (e', pr) |
| 282 | + |
| 283 | +end abel |
| 284 | + |
| 285 | +namespace interactive |
| 286 | +open interactive interactive.types lean.parser |
| 287 | +open tactic.abel |
| 288 | + |
| 289 | +local postfix `?`:9001 := optional |
| 290 | + |
| 291 | +/-- Tactic for solving equations in the language of abels. |
| 292 | + This version of `abel` fails if the target is not an equality |
| 293 | + that is provable by the axioms of commutative (semi)abels. -/ |
| 294 | +meta def abel1 : tactic unit := |
| 295 | +do `(%%e₁ = %%e₂) ← target, |
| 296 | + c ← mk_cache e₁, |
| 297 | + (e₁', p₁) ← eval c e₁, |
| 298 | + (e₂', p₂) ← eval c e₂, |
| 299 | + is_def_eq e₁' e₂', |
| 300 | + p ← mk_eq_symm p₂ >>= mk_eq_trans p₁, |
| 301 | + tactic.exact p |
| 302 | + |
| 303 | +meta def abel.mode : lean.parser abel.normalize_mode := |
| 304 | +with_desc "(raw|term)?" $ |
| 305 | +do mode ← ident?, match mode with |
| 306 | +| none := return abel.normalize_mode.term |
| 307 | +| some `term := return abel.normalize_mode.term |
| 308 | +| some `raw := return abel.normalize_mode.raw |
| 309 | +| _ := failed |
| 310 | +end |
| 311 | + |
| 312 | +/-- Tactic for solving equations in the language of |
| 313 | + commutative monoids and groups. |
| 314 | + Attempts to prove the goal outright if there is no `at` |
| 315 | + specifier and the target is an equality, but if this |
| 316 | + fails it falls back to rewriting all monoid expressions |
| 317 | + into a normal form. -/ |
| 318 | +meta def abel (SOP : parse abel.mode) (loc : parse location) : tactic unit := |
| 319 | +match loc with |
| 320 | +| interactive.loc.ns [none] := abel1 |
| 321 | +| _ := failed |
| 322 | +end <|> |
| 323 | +do ns ← loc.get_locals, |
| 324 | + tt ← tactic.replace_at (normalize SOP) ns loc.include_goal |
| 325 | + | fail "abel failed to simplify", |
| 326 | + when loc.include_goal $ try tactic.reflexivity |
| 327 | + |
| 328 | +end interactive |
| 329 | +end tactic |
0 commit comments