Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 85b8bdc

Browse files
committed
perf(tactic/linarith): use key/value lists instead of rb_maps to represent linear expressions (#3004)
This has essentially no effect on the test file, but scales much better. See discussion at https://leanprover.zulipchat.com/#narrow/stream/187764-Lean-for.20teaching/topic/Real.20analysis for an example which is in reach with this change. Co-authored-by: Rob Lewis <rob.y.lewis@gmail.com>
1 parent 7f60a62 commit 85b8bdc

File tree

3 files changed

+229
-27
lines changed

3 files changed

+229
-27
lines changed

src/meta/rb_map.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ s.fold (pure s) (λ a m,
3535
meta def union {key} (s t : rb_set key) : rb_set key :=
3636
s.fold t (λ a t, t.insert a)
3737

38+
/--
39+
`of_list_core empty l` turns a list of keys into an `rb_set`.
40+
It takes a user_provided `rb_set` to use for the base case.
41+
This can be used to pre-seed the set with additional elements,
42+
and/or to use a custom comparison operator.
43+
-/
44+
meta def of_list_core {key} (base : native.rb_set key) : list key → native.rb_map key unit
45+
| [] := base
46+
| (x::xs) := native.rb_set.insert (of_list_core xs) x
47+
3848
end rb_set
3949

4050
namespace rb_map

src/tactic/linarith.lean

Lines changed: 119 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,87 @@ by rw [←h3, mul_assoc, mul_div_comm, h2, ←mul_assoc, h1, mul_comm, one_mul]
9898

9999
end lemmas
100100

101+
/--
102+
A linear expression is a list of pairs of variable indices and coefficients.
103+
104+
Some functions on `linexp` assume that `n : ℕ` occurs at most once as the first element of a pair,
105+
and that the list is sorted in decreasing order of the first argument.
106+
This is not enforced by the type but the operations here preserve it.
107+
-/
108+
@[reducible]
109+
def linexp : Type := list (ℕ × ℤ)
110+
end linarith
111+
112+
/--
113+
A map `ℕ → ℤ` is converted to `list (ℕ × ℤ)` in the obvious way.
114+
This list is sorted in decreasing order of the first argument.
115+
-/
116+
meta def native.rb_map.to_linexp (m : rb_map ℕ ℤ) : linarith.linexp :=
117+
m.to_list
118+
119+
namespace linarith
120+
namespace linexp
121+
122+
/--
123+
Add two `linexp`s together componentwise.
124+
Preserves sorting and uniqueness of the first argument.
125+
-/
126+
meta def add : linexp → linexp → linexp
127+
| [] a := a
128+
| a [] := a
129+
| (a@(n1,z1)::t1) (b@(n2,z2)::t2) :=
130+
if n1 < n2 then b::add (a::t1) t2
131+
else if n2 < n1 then a::add t1 (b::t2)
132+
else let sum := z1 + z2 in if sum = 0 then add t1 t2 else (n1, sum)::add t1 t2
133+
134+
/-- `l.scale c` scales the values in `l` by `c` without modifying the order or keys. -/
135+
def scale (c : ℤ) (l : linexp) : linexp :=
136+
if c = 0 then []
137+
else if c = 1 then l
138+
else l.map $ λ ⟨n, z⟩, (n, z*c)
139+
140+
/--
141+
`l.get n` returns the value in `l` associated with key `n`, if it exists, and `none` otherwise.
142+
This function assumes that `l` is sorted in decreasing order of the first argument,
143+
that is, it will return `none` as soon as it finds a key smaller than `n`.
144+
-/
145+
def get (n : ℕ) : linexp → option ℤ
146+
| [] := none
147+
| ((a, b)::t) :=
148+
if a < n then none
149+
else if a = n then some b
150+
else get t
151+
152+
/--
153+
`l.contains n` is true iff `n` is the first element of a pair in `l`.
154+
-/
155+
def contains (n : ℕ) : linexp → bool := option.is_some ∘ get n
156+
157+
/--
158+
`l.zfind n` returns the value associated with key `n` if there is one, and 0 otherwise.
159+
-/
160+
def zfind (n : ℕ) (l : linexp) : ℤ :=
161+
match l.get n with
162+
| none := 0
163+
| some v := v
164+
end
165+
166+
/--
167+
Defines a lex ordering on `linexp`. This function is performance critical.
168+
-/
169+
def cmp : linexp → linexp → ordering
170+
| [] [] := ordering.eq
171+
| [] _ := ordering.lt
172+
| _ [] := ordering.gt
173+
| ((n1,z1)::t1) ((n2,z2)::t2) :=
174+
if n1 < n2 then ordering.lt
175+
else if n2 < n1 then ordering.gt
176+
else if z1 < z2 then ordering.lt
177+
else if z2 < z1 then ordering.gt
178+
else cmp t1 t2
179+
180+
end linexp
181+
101182
section datatypes
102183

103184
@[derive decidable_eq, derive inhabited]
@@ -111,11 +192,14 @@ def ineq.max : ineq → ineq → ineq
111192
| le a := a
112193
| lt a := lt
113194

114-
def ineq.is_lt : ineq → ineq → bool
115-
| eq le := tt
116-
| eq lt := tt
117-
| le lt := tt
118-
| _ _ := ff
195+
/-- `ineq` is ordered `eq < le < lt`. -/
196+
def ineq.cmp : ineq → ineq → ordering
197+
| eq eq := ordering.eq
198+
| eq _ := ordering.lt
199+
| le le := ordering.eq
200+
| le lt := ordering.lt
201+
| lt lt := ordering.eq
202+
| _ _ := ordering.gt
119203

120204
def ineq.to_string : ineq → string
121205
| eq := "="
@@ -128,15 +212,15 @@ instance : has_to_string ineq := ⟨ineq.to_string⟩
128212
The main datatype for FM elimination.
129213
Variables are represented by natural numbers, each of which has an integer coefficient.
130214
Index 0 is reserved for constants, i.e. `coeffs.find 0` is the coefficient of 1.
131-
The represented term is `coeffs.keys.sum (λ i, coeffs.find i * Var[i])`.
215+
The represented term is `coeffs.sum (λ ⟨k, v⟩, v * Var[k])`.
132216
str determines the direction of the comparison -- is it < 0, ≤ 0, or = 0?
133217
-/
134218
@[derive inhabited]
135-
meta structure comp :=
219+
meta structure comp : Type :=
136220
(str : ineq)
137-
(coeffs : rb_map ℕ int)
221+
(coeffs : linexp)
138222

139-
meta inductive comp_source
223+
meta inductive comp_source : Type
140224
| assump : ℕ → comp_source
141225
| add : comp_source → comp_source → comp_source
142226
| scale : ℕ → comp_source → comp_source
@@ -158,23 +242,27 @@ meta structure pcomp :=
158242
(c : comp)
159243
(src : comp_source)
160244

161-
meta def map_lt (m1 m2 : rb_map ℕ int) : bool :=
162-
list.lex (prod.lex (<) (<)) m1.to_list m2.to_list
163-
164-
-- make more efficient
165-
meta def comp.lt (c1 c2 : comp) : bool :=
166-
(c1.str.is_lt c2.str) || (c1.str = c2.str) && map_lt c1.coeffs c2.coeffs
245+
/-- `comp` has a lex order. First the `ineq`s are compared, then the `coeff`s. -/
246+
meta def comp.cmp : comp → comp → ordering
247+
| ⟨str1, coeffs1⟩ ⟨str2, coeffs2⟩ :=
248+
match str1.cmp str2 with
249+
| ordering.lt := ordering.lt
250+
| ordering.gt := ordering.gt
251+
| ordering.eq := coeffs1.cmp coeffs2
252+
end
167253

168-
meta instance comp.has_lt : has_lt comp := ⟨λ a b, comp.lt a b⟩
169-
meta instance pcomp.has_lt : has_lt pcomp := ⟨λ p1 p2, p1.c < p2.c⟩
170-
-- short-circuit type class inference
171-
meta instance pcomp.has_lt_dec : decidable_rel ((<) : pcomp → pcomp → Prop) := by apply_instance
254+
/--
255+
The `comp_source` field is ignored when comparing `pcomp`s. Two `pcomp`s proving the same
256+
comparison, with different sources, are considered equivalent.
257+
-/
258+
meta def pcomp.cmp (p1 p2 : pcomp) : ordering :=
259+
p1.c.cmp p2.c
172260

173261
meta def comp.coeff_of (c : comp) (a : ℕ) : ℤ :=
174262
c.coeffs.zfind a
175263

176264
meta def comp.scale (c : comp) (n : ℕ) : comp :=
177-
{ c with coeffs := c.coeffs.map ((*) (n : ℤ)) }
265+
{ c with coeffs := c.coeffs.scale n }
178266

179267
meta def comp.add (c1 c2 : comp) : comp :=
180268
⟨c1.str.max c2.str, c1.coeffs.add c2.coeffs⟩
@@ -191,6 +279,10 @@ meta instance pcomp.to_format : has_to_format pcomp :=
191279
meta instance comp.to_format : has_to_format comp :=
192280
⟨λ p, to_fmt p.coeffs ++ to_string p.str ++ "0"
193281

282+
/-- Creates an empty set of `pcomp`s, sorted using `pcomp.cmp`. -/
283+
meta def mk_pcomp_set : rb_set pcomp :=
284+
rb_map.mk_core unit pcomp.cmp
285+
194286
end datatypes
195287

196288
section fm_elim
@@ -216,8 +308,8 @@ meta def comp.is_contr (c : comp) : bool := c.coeffs.empty ∧ c.str = ineq.lt
216308
meta def pcomp.is_contr (p : pcomp) : bool := p.c.is_contr
217309

218310
meta def elim_with_set (a : ℕ) (p : pcomp) (comps : rb_set pcomp) : rb_set pcomp :=
219-
if ¬ p.c.coeffs.contains a then mk_rb_set.insert p else
220-
comps.fold mk_rb_set $ λ pc s,
311+
if ¬ p.c.coeffs.contains a then mk_pcomp_set.insert p else
312+
comps.fold mk_pcomp_set $ λ pc s,
221313
match pelim_var p pc a with
222314
| some pc := s.insert pc
223315
| none := s
@@ -263,7 +355,7 @@ meta def monad.elim_var (a : ℕ) : linarith_monad unit :=
263355
do vs ← get_vars,
264356
when (vs.contains a) $
265357
do comps ← get_comps,
266-
let cs' := comps.fold mk_rb_set (λ p s, s.union (elim_with_set a p comps)),
358+
let cs' := comps.fold mk_pcomp_set (λ p s, s.union (elim_with_set a p comps)),
267359
update (vs.erase a) cs'
268360

269361
meta def elim_all_vars : linarith_monad unit :=
@@ -400,7 +492,7 @@ meta def to_comp (red : transparency) (e : expr) (m : expr_map ℕ) (mm : rb_map
400492
do (iq, e) ← parse_into_comp_and_expr e,
401493
(m', comp') ← map_of_expr red m e,
402494
let ⟨nm, mm'⟩ := sum_to_lf comp' mm,
403-
return ⟨⟨iq, mm'⟩,m',nm⟩
495+
return ⟨⟨iq, mm'.to_linexp⟩,m',nm⟩
404496

405497
meta def to_comp_fold (red : transparency) : expr_map ℕ → list expr → rb_map monom ℕ →
406498
tactic (list (option comp) × expr_map ℕ × rb_map monom ℕ )
@@ -422,9 +514,9 @@ do pftps ← l.mmap infer_type,
422514
(l', _, map) ← to_comp_fold red mk_rb_map pftps mk_rb_map,
423515
let lz := list.enum $ ((l.zip pftps).zip l').filter_map (λ ⟨a, b⟩, prod.mk a <$> b),
424516
let prmap := rb_map.of_list $ lz.map (λ ⟨n, x⟩, (n, x.1)),
425-
let vars : rb_set ℕ := rb_map.set_of_list $ list.range map.size.succ,
426-
let pc : rb_set pcomp := rb_map.set_of_list $
427-
lz.map (λ ⟨n, x⟩, ⟨x.2, comp_source.assump n⟩),
517+
let vars : rb_set ℕ := rb_map.set_of_list $ list.range map.size,
518+
let pc : rb_set pcomp :=
519+
rb_set.of_list_core mk_pcomp_set $ lz.map (λ ⟨n, x⟩, ⟨x.2, comp_source.assump n⟩),
428520
return (⟨vars, pc⟩, prmap)
429521

430522
meta def linarith_monad.run (red : transparency) {α} (tac : linarith_monad α) (l : list expr) :

test/linarith.lean

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import tactic.linarith
22

3+
34
example (e b c a v0 v1 : ℚ) (h1 : v0 = 5*a) (h2 : v1 = 3*b) (h3 : v0 + v1 + c = 10) :
45
v0 + 5 + (v1 - 3) + (c - 2) = 10 :=
56
by linarith
@@ -168,3 +169,102 @@ by linarith only [hx]
168169
example (x y : ℚ) (h : x < y) : x ≠ y := by linarith
169170

170171
example (x y : ℚ) (h : x < y) : ¬ x = y := by linarith
172+
173+
lemma test6 (u v x y A B : ℚ)
174+
175+
(a : 0 < A)
176+
(a_1 : 0 <= 1 - A)
177+
(a_2 : 0 <= B - 1)
178+
(a_3 : 0 <= B - x)
179+
(a_4 : 0 <= B - y)
180+
(a_5 : 0 <= u)
181+
(a_6 : 0 <= v)
182+
(a_7 : 0 < A - u)
183+
(a_8 : 0 < A - v) :
184+
(0 < A * A)
185+
-> (0 <= A * (1 - A))
186+
-> (0 <= A * (B - 1))
187+
-> (0 <= A * (B - x))
188+
-> (0 <= A * (B - y))
189+
-> (0 <= A * u)
190+
-> (0 <= A * v)
191+
-> (0 < A * (A - u))
192+
-> (0 < A * (A - v))
193+
-> (0 <= (1 - A) * A)
194+
-> (0 <= (1 - A) * (1 - A))
195+
-> (0 <= (1 - A) * (B - 1))
196+
-> (0 <= (1 - A) * (B - x))
197+
-> (0 <= (1 - A) * (B - y))
198+
-> (0 <= (1 - A) * u)
199+
-> (0 <= (1 - A) * v)
200+
-> (0 <= (1 - A) * (A - u))
201+
-> (0 <= (1 - A) * (A - v))
202+
-> (0 <= (B - 1) * A)
203+
-> (0 <= (B - 1) * (1 - A))
204+
-> (0 <= (B - 1) * (B - 1))
205+
-> (0 <= (B - 1) * (B - x))
206+
-> (0 <= (B - 1) * (B - y))
207+
-> (0 <= (B - 1) * u)
208+
-> (0 <= (B - 1) * v)
209+
-> (0 <= (B - 1) * (A - u))
210+
-> (0 <= (B - 1) * (A - v))
211+
-> (0 <= (B - x) * A)
212+
-> (0 <= (B - x) * (1 - A))
213+
-> (0 <= (B - x) * (B - 1))
214+
-> (0 <= (B - x) * (B - x))
215+
-> (0 <= (B - x) * (B - y))
216+
-> (0 <= (B - x) * u)
217+
-> (0 <= (B - x) * v)
218+
-> (0 <= (B - x) * (A - u))
219+
-> (0 <= (B - x) * (A - v))
220+
-> (0 <= (B - y) * A)
221+
-> (0 <= (B - y) * (1 - A))
222+
-> (0 <= (B - y) * (B - 1))
223+
-> (0 <= (B - y) * (B - x))
224+
-> (0 <= (B - y) * (B - y))
225+
-> (0 <= (B - y) * u)
226+
-> (0 <= (B - y) * v)
227+
-> (0 <= (B - y) * (A - u))
228+
-> (0 <= (B - y) * (A - v))
229+
-> (0 <= u * A)
230+
-> (0 <= u * (1 - A))
231+
-> (0 <= u * (B - 1))
232+
-> (0 <= u * (B - x))
233+
-> (0 <= u * (B - y))
234+
-> (0 <= u * u)
235+
-> (0 <= u * v)
236+
-> (0 <= u * (A - u))
237+
-> (0 <= u * (A - v))
238+
-> (0 <= v * A)
239+
-> (0 <= v * (1 - A))
240+
-> (0 <= v * (B - 1))
241+
-> (0 <= v * (B - x))
242+
-> (0 <= v * (B - y))
243+
-> (0 <= v * u)
244+
-> (0 <= v * v)
245+
-> (0 <= v * (A - u))
246+
-> (0 <= v * (A - v))
247+
-> (0 < (A - u) * A)
248+
-> (0 <= (A - u) * (1 - A))
249+
-> (0 <= (A - u) * (B - 1))
250+
-> (0 <= (A - u) * (B - x))
251+
-> (0 <= (A - u) * (B - y))
252+
-> (0 <= (A - u) * u)
253+
-> (0 <= (A - u) * v)
254+
-> (0 < (A - u) * (A - u))
255+
-> (0 < (A - u) * (A - v))
256+
-> (0 < (A - v) * A)
257+
-> (0 <= (A - v) * (1 - A))
258+
-> (0 <= (A - v) * (B - 1))
259+
-> (0 <= (A - v) * (B - x))
260+
-> (0 <= (A - v) * (B - y))
261+
-> (0 <= (A - v) * u)
262+
-> (0 <= (A - v) * v)
263+
-> (0 < (A - v) * (A - u))
264+
-> (0 < (A - v) * (A - v))
265+
->
266+
u * y + v * x + u * v < 3 * A * B :=
267+
begin
268+
intros,
269+
linarith
270+
end

0 commit comments

Comments
 (0)