Skip to content

Commit

Permalink
feat(tactic/linarith): treat expr atoms up to defeq
Browse files Browse the repository at this point in the history
  • Loading branch information
robertylewis committed Apr 19, 2019
1 parent 3c1dce1 commit 38dc0a7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
44 changes: 25 additions & 19 deletions src/tactic/linarith.lean
Original file line number Diff line number Diff line change
Expand Up @@ -282,21 +282,26 @@ match c1.keys, c2.keys with
| _, _ := none
end

meta def list.mfind {α} (tac : α → tactic unit) : list α → tactic α
| [] := failed
| (h::t) := tac h >> return h <|> list.mfind t

meta def rb_map.find_defeq {v} (m : expr_map v) (e : expr) : tactic v :=
prod.snd <$> list.mfind (λ p, is_def_eq e p.1) m.to_list

/--
Turns an expression into a map from ℕ to ℤ, for use in a comp object.
The expr_map ℕ argument identifies which expressions have already been assigned numbers.
Returns a new map.
-/
meta def map_of_expr : expr_map ℕ → expr → option (expr_map ℕ × rb_map ℕ ℤ)
meta def map_of_expr : expr_map ℕ → expr → tactic (expr_map ℕ × rb_map ℕ ℤ)
| m e@`(%%e1 * %%e2) :=
(do (m', comp1) ← map_of_expr m e1,
(m', comp2) ← map_of_expr m' e2,
mp ← map_of_expr_mul_aux comp1 comp2,
return (m', mp)) <|>
(match m.find e with
| some k := return (m, mk_rb_map.insert k 1)
| none := let n := m.size + 1 in return (m.insert e n, mk_rb_map.insert n 1)
end)
(do k ← rb_map.find_defeq m e, return (m, mk_rb_map.insert k 1)) <|>
(let n := m.size + 1 in return (m.insert e n, mk_rb_map.insert n 1))
| m `(%%e1 + %%e2) :=
do (m', comp1) ← map_of_expr m e1,
(m', comp2) ← map_of_expr m' e2,
Expand All @@ -307,12 +312,12 @@ meta def map_of_expr : expr_map ℕ → expr → option (expr_map ℕ × rb_map
return (m', comp1.add (comp2.scale (-1)))
| m `(-%%e) := do (m', comp) ← map_of_expr m e, return (m', comp.scale (-1))
| m e :=
match e.to_int, m.find e with
| some 0, _ := return ⟨m, mk_rb_map⟩
| some z, _ := return ⟨m, mk_rb_map.insert 0 z⟩
| none, some k := return (m, mk_rb_map.insert k 1)
| none, none := let n := m.size + 1 in
return (m.insert e n, mk_rb_map.insert n 1)
match e.to_int with
| some 0 := return ⟨m, mk_rb_map⟩
| some z := return ⟨m, mk_rb_map.insert 0 z⟩
| none :=
(do k ← rb_map.find_defeq m e, return (m, mk_rb_map.insert k 1)) <|>
(let n := m.size + 1 in return (m.insert e n, mk_rb_map.insert n 1))
end

meta def parse_into_comp_and_expr : expr → option (ineq × expr)
Expand All @@ -321,26 +326,27 @@ meta def parse_into_comp_and_expr : expr → option (ineq × expr)
| `(%%e = 0) := (ineq.eq, e)
| _ := none

meta def to_comp (e : expr) (m : expr_map ℕ) : option (comp × expr_map ℕ) :=
meta def to_comp (e : expr) (m : expr_map ℕ) : tactic (comp × expr_map ℕ) :=
do (iq, e) ← parse_into_comp_and_expr e,
(m', comp') ← map_of_expr m e,
return ⟨⟨iq, comp'⟩, m'⟩

meta def to_comp_fold : expr_map ℕ → list expr →
(list (option comp) × expr_map ℕ)
| m [] := ([], m)
tactic (list (option comp) × expr_map ℕ)
| m [] := return ([], m)
| m (h::t) :=
match to_comp h m with
| some (c, m') := let (l, mp) := to_comp_fold m' t in (c::l, mp)
| none := let (l, mp) := to_comp_fold m t in (none::l, mp)
end
(do (c, m') ← to_comp h m,
(l, mp) ← to_comp_fold m' t,
return (c::l, mp)) <|>
(do (l, mp) ← to_comp_fold m t,
return (none::l, mp))

/--
Takes a list of proofs of props of the form t {<, ≤, =} 0, and creates a linarith_structure.
-/
meta def mk_linarith_structure (l : list expr) : tactic (linarith_structure × rb_map ℕ (expr × expr)) :=
do pftps ← l.mmap infer_type,
let (l', map) := to_comp_fold mk_rb_map pftps,
(l', map) to_comp_fold mk_rb_map pftps,
let lz := list.enum $ ((l.zip pftps).zip l').filter_map (λ ⟨a, b⟩, prod.mk a <$> b),
let prmap := rb_map.of_list $ lz.map (λ ⟨n, x⟩, (n, x.1)),
let vars : rb_set ℕ := rb_map.set_of_list $ list.range map.size.succ,
Expand Down
3 changes: 3 additions & 0 deletions test/linarith.lean
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,6 @@ by linarith

example (a : ℚ) (ha : 0 ≤ a): 0 * 02 * a :=
by linarith

example (x : ℚ) : id x ≥ x :=
by linarith

0 comments on commit 38dc0a7

Please sign in to comment.