@@ -98,6 +98,87 @@ by rw [←h3, mul_assoc, mul_div_comm, h2, ←mul_assoc, h1, mul_comm, one_mul]
98
98
99
99
end lemmas
100
100
101
+ /--
102
+ A linear expression is a list of pairs of variable indices and coefficients.
103
+
104
+ Some functions on `linexp` assume that `n : ℕ` occurs at most once as the first element of a pair,
105
+ and that the list is sorted in decreasing order of the first argument.
106
+ This is not enforced by the type but the operations here preserve it.
107
+ -/
108
+ @[reducible]
109
+ def linexp : Type := list (ℕ × ℤ)
110
+ end linarith
111
+
112
+ /--
113
+ A map `ℕ → ℤ` is converted to `list (ℕ × ℤ)` in the obvious way.
114
+ This list is sorted in decreasing order of the first argument.
115
+ -/
116
+ meta def native.rb_map.to_linexp (m : rb_map ℕ ℤ) : linarith.linexp :=
117
+ m.to_list
118
+
119
+ namespace linarith
120
+ namespace linexp
121
+
122
+ /--
123
+ Add two `linexp`s together componentwise.
124
+ Preserves sorting and uniqueness of the first argument.
125
+ -/
126
+ meta def add : linexp → linexp → linexp
127
+ | [] a := a
128
+ | a [] := a
129
+ | (a@(n1,z1)::t1) (b@(n2,z2)::t2) :=
130
+ if n1 < n2 then b::add (a::t1) t2
131
+ else if n2 < n1 then a::add t1 (b::t2)
132
+ else let sum := z1 + z2 in if sum = 0 then add t1 t2 else (n1, sum)::add t1 t2
133
+
134
+ /-- `l.scale c` scales the values in `l` by `c` without modifying the order or keys. -/
135
+ def scale (c : ℤ) (l : linexp) : linexp :=
136
+ if c = 0 then []
137
+ else if c = 1 then l
138
+ else l.map $ λ ⟨n, z⟩, (n, z*c)
139
+
140
+ /--
141
+ `l.get n` returns the value in `l` associated with key `n`, if it exists, and `none` otherwise.
142
+ This function assumes that `l` is sorted in decreasing order of the first argument,
143
+ that is, it will return `none` as soon as it finds a key smaller than `n`.
144
+ -/
145
+ def get (n : ℕ) : linexp → option ℤ
146
+ | [] := none
147
+ | ((a, b)::t) :=
148
+ if a < n then none
149
+ else if a = n then some b
150
+ else get t
151
+
152
+ /--
153
+ `l.contains n` is true iff `n` is the first element of a pair in `l`.
154
+ -/
155
+ def contains (n : ℕ) : linexp → bool := option.is_some ∘ get n
156
+
157
+ /--
158
+ `l.zfind n` returns the value associated with key `n` if there is one, and 0 otherwise.
159
+ -/
160
+ def zfind (n : ℕ) (l : linexp) : ℤ :=
161
+ match l.get n with
162
+ | none := 0
163
+ | some v := v
164
+ end
165
+
166
+ /--
167
+ Defines a lex ordering on `linexp`. This function is performance critical.
168
+ -/
169
+ def cmp : linexp → linexp → ordering
170
+ | [] [] := ordering.eq
171
+ | [] _ := ordering.lt
172
+ | _ [] := ordering.gt
173
+ | ((n1,z1)::t1) ((n2,z2)::t2) :=
174
+ if n1 < n2 then ordering.lt
175
+ else if n2 < n1 then ordering.gt
176
+ else if z1 < z2 then ordering.lt
177
+ else if z2 < z1 then ordering.gt
178
+ else cmp t1 t2
179
+
180
+ end linexp
181
+
101
182
section datatypes
102
183
103
184
@[derive decidable_eq, derive inhabited]
@@ -111,11 +192,14 @@ def ineq.max : ineq → ineq → ineq
111
192
| le a := a
112
193
| lt a := lt
113
194
114
- def ineq.is_lt : ineq → ineq → bool
115
- | eq le := tt
116
- | eq lt := tt
117
- | le lt := tt
118
- | _ _ := ff
195
+ /-- `ineq` is ordered `eq < le < lt`. -/
196
+ def ineq.cmp : ineq → ineq → ordering
197
+ | eq eq := ordering.eq
198
+ | eq _ := ordering.lt
199
+ | le le := ordering.eq
200
+ | le lt := ordering.lt
201
+ | lt lt := ordering.eq
202
+ | _ _ := ordering.gt
119
203
120
204
def ineq.to_string : ineq → string
121
205
| eq := " ="
@@ -128,15 +212,15 @@ instance : has_to_string ineq := ⟨ineq.to_string⟩
128
212
The main datatype for FM elimination.
129
213
Variables are represented by natural numbers, each of which has an integer coefficient.
130
214
Index 0 is reserved for constants, i.e. `coeffs.find 0` is the coefficient of 1.
131
- The represented term is `coeffs.keys. sum (λ i, coeffs.find i * Var[i ])`.
215
+ The represented term is `coeffs.sum (λ ⟨k, v⟩, v * Var[k ])`.
132
216
str determines the direction of the comparison -- is it < 0, ≤ 0, or = 0?
133
217
-/
134
218
@[derive inhabited]
135
- meta structure comp :=
219
+ meta structure comp : Type : =
136
220
(str : ineq)
137
- (coeffs : rb_map ℕ int )
221
+ (coeffs : linexp )
138
222
139
- meta inductive comp_source
223
+ meta inductive comp_source : Type
140
224
| assump : ℕ → comp_source
141
225
| add : comp_source → comp_source → comp_source
142
226
| scale : ℕ → comp_source → comp_source
@@ -158,23 +242,27 @@ meta structure pcomp :=
158
242
(c : comp)
159
243
(src : comp_source)
160
244
161
- meta def map_lt (m1 m2 : rb_map ℕ int) : bool :=
162
- list.lex (prod.lex (<) (<)) m1.to_list m2.to_list
163
-
164
- -- make more efficient
165
- meta def comp.lt (c1 c2 : comp) : bool :=
166
- (c1.str.is_lt c2.str) || (c1.str = c2.str) && map_lt c1.coeffs c2.coeffs
245
+ /-- `comp` has a lex order. First the `ineq`s are compared, then the `coeff`s. -/
246
+ meta def comp.cmp : comp → comp → ordering
247
+ | ⟨str1, coeffs1⟩ ⟨str2, coeffs2⟩ :=
248
+ match str1.cmp str2 with
249
+ | ordering.lt := ordering.lt
250
+ | ordering.gt := ordering.gt
251
+ | ordering.eq := coeffs1.cmp coeffs2
252
+ end
167
253
168
- meta instance comp.has_lt : has_lt comp := ⟨λ a b, comp.lt a b⟩
169
- meta instance pcomp.has_lt : has_lt pcomp := ⟨λ p1 p2, p1.c < p2.c⟩
170
- -- short-circuit type class inference
171
- meta instance pcomp.has_lt_dec : decidable_rel ((<) : pcomp → pcomp → Prop ) := by apply_instance
254
+ /--
255
+ The `comp_source` field is ignored when comparing `pcomp`s. Two `pcomp`s proving the same
256
+ comparison, with different sources, are considered equivalent.
257
+ -/
258
+ meta def pcomp.cmp (p1 p2 : pcomp) : ordering :=
259
+ p1.c.cmp p2.c
172
260
173
261
meta def comp.coeff_of (c : comp) (a : ℕ) : ℤ :=
174
262
c.coeffs.zfind a
175
263
176
264
meta def comp.scale (c : comp) (n : ℕ) : comp :=
177
- { c with coeffs := c.coeffs.map ((*) (n : ℤ)) }
265
+ { c with coeffs := c.coeffs.scale n }
178
266
179
267
meta def comp.add (c1 c2 : comp) : comp :=
180
268
⟨c1.str.max c2.str, c1.coeffs.add c2.coeffs⟩
@@ -191,6 +279,10 @@ meta instance pcomp.to_format : has_to_format pcomp :=
191
279
meta instance comp.to_format : has_to_format comp :=
192
280
⟨λ p, to_fmt p.coeffs ++ to_string p.str ++ " 0" ⟩
193
281
282
+ /-- Creates an empty set of `pcomp`s, sorted using `pcomp.cmp`. -/
283
+ meta def mk_pcomp_set : rb_set pcomp :=
284
+ rb_map.mk_core unit pcomp.cmp
285
+
194
286
end datatypes
195
287
196
288
section fm_elim
@@ -216,8 +308,8 @@ meta def comp.is_contr (c : comp) : bool := c.coeffs.empty ∧ c.str = ineq.lt
216
308
meta def pcomp.is_contr (p : pcomp) : bool := p.c.is_contr
217
309
218
310
meta def elim_with_set (a : ℕ) (p : pcomp) (comps : rb_set pcomp) : rb_set pcomp :=
219
- if ¬ p.c.coeffs.contains a then mk_rb_set .insert p else
220
- comps.fold mk_rb_set $ λ pc s,
311
+ if ¬ p.c.coeffs.contains a then mk_pcomp_set .insert p else
312
+ comps.fold mk_pcomp_set $ λ pc s,
221
313
match pelim_var p pc a with
222
314
| some pc := s.insert pc
223
315
| none := s
@@ -263,7 +355,7 @@ meta def monad.elim_var (a : ℕ) : linarith_monad unit :=
263
355
do vs ← get_vars,
264
356
when (vs.contains a) $
265
357
do comps ← get_comps,
266
- let cs' := comps.fold mk_rb_set (λ p s, s.union (elim_with_set a p comps)),
358
+ let cs' := comps.fold mk_pcomp_set (λ p s, s.union (elim_with_set a p comps)),
267
359
update (vs.erase a) cs'
268
360
269
361
meta def elim_all_vars : linarith_monad unit :=
@@ -400,7 +492,7 @@ meta def to_comp (red : transparency) (e : expr) (m : expr_map ℕ) (mm : rb_map
400
492
do (iq, e) ← parse_into_comp_and_expr e,
401
493
(m', comp') ← map_of_expr red m e,
402
494
let ⟨nm, mm'⟩ := sum_to_lf comp' mm,
403
- return ⟨⟨iq, mm'⟩,m',nm⟩
495
+ return ⟨⟨iq, mm'.to_linexp ⟩,m',nm⟩
404
496
405
497
meta def to_comp_fold (red : transparency) : expr_map ℕ → list expr → rb_map monom ℕ →
406
498
tactic (list (option comp) × expr_map ℕ × rb_map monom ℕ )
@@ -422,9 +514,9 @@ do pftps ← l.mmap infer_type,
422
514
(l', _, map) ← to_comp_fold red mk_rb_map pftps mk_rb_map,
423
515
let lz := list.enum $ ((l.zip pftps).zip l').filter_map (λ ⟨a, b⟩, prod.mk a <$> b),
424
516
let prmap := rb_map.of_list $ lz.map (λ ⟨n, x⟩, (n, x.1 )),
425
- let vars : rb_set ℕ := rb_map.set_of_list $ list.range map.size.succ ,
426
- let pc : rb_set pcomp := rb_map.set_of_list $
427
- lz.map (λ ⟨n, x⟩, ⟨x.2 , comp_source.assump n⟩),
517
+ let vars : rb_set ℕ := rb_map.set_of_list $ list.range map.size,
518
+ let pc : rb_set pcomp :=
519
+ rb_set.of_list_core mk_pcomp_set $ lz.map (λ ⟨n, x⟩, ⟨x.2 , comp_source.assump n⟩),
428
520
return (⟨vars, pc⟩, prmap)
429
521
430
522
meta def linarith_monad.run (red : transparency) {α} (tac : linarith_monad α) (l : list expr) :
0 commit comments