-
Notifications
You must be signed in to change notification settings - Fork 258
/
Preprocessing.lean
427 lines (375 loc) · 16.7 KB
/
Preprocessing.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
/-
Copyright (c) 2020 Robert Y. Lewis. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Robert Y. Lewis
-/
import Mathlib.Tactic.Linarith.Datatypes
import Mathlib.Tactic.Zify
import Mathlib.Tactic.CancelDenoms.Core
import Batteries.Data.RBMap.Basic
import Mathlib.Data.HashMap
import Mathlib.Control.Basic
/-!
# Linarith preprocessing
This file contains methods used to preprocess inputs to `linarith`.
In particular, `linarith` works over comparisons of the form `t R 0`, where `R ∈ {<,≤,=}`.
It assumes that expressions in `t` have integer coefficients and that the type of `t` has
well-behaved subtraction.
## Implementation details
A `GlobalPreprocessor` is a function `List Expr → TacticM (List Expr)`. Users can add custom
preprocessing steps by adding them to the `LinarithConfig` object. `Linarith.defaultPreprocessors`
is the main list, and generally none of these should be skipped unless you know what you're doing.
-/
set_option autoImplicit true
namespace Linarith
/-! ### Preprocessing -/
open Lean hiding Rat
open Elab Tactic Meta
open Qq
/-- Processor that recursively replaces `P ∧ Q` hypotheses with the pair `P` and `Q`. -/
partial def splitConjunctions : Preprocessor where
name := "split conjunctions"
transform := aux
where
/-- Implementation of the `splitConjunctions` preprocessor. -/
aux (proof : Expr) : MetaM (List Expr) := do
match (← instantiateMVars (← inferType proof)).getAppFnArgs with
| (``And, #[_, _]) =>
pure ((← aux (← mkAppM ``And.left #[proof])) ++
(← aux (← mkAppM ``And.right #[proof])))
| _ => pure [proof]
/--
Removes any expressions that are not proofs of inequalities, equalities, or negations thereof.
-/
partial def filterComparisons : Preprocessor where
name := "filter terms that are not proofs of comparisons"
transform h := do
let tp ← whnfR (← instantiateMVars (← inferType h))
if (← isProp tp) && (← aux tp) then pure [h]
else pure []
where
/-- Implementation of the `filterComparisons` preprocessor. -/
aux (e : Expr) : MetaM Bool := do
match e.getAppFnArgs with
| (``Eq, _) | (``LE.le, _) | (``LT.lt, _) => pure true
| (``Not, #[e]) => match (← whnfR e).getAppFnArgs with
| (``LE.le, _) | (``LT.lt, _) => pure true
| _ => pure false
| _ => pure false
section removeNegations
/--
If `prf` is a proof of `¬ e`, where `e` is a comparison,
`flipNegatedComparison prf e` flips the comparison in `e` and returns a proof.
For example, if `prf : ¬ a < b`, ``flipNegatedComparison prf q(a < b)`` returns a proof of `a ≥ b`.
-/
def flipNegatedComparison (prf : Expr) (e : Expr) : MetaM (Option Expr) :=
match e.getAppFnArgs with
| (``LE.le, #[_, _, _, _]) => try? <| mkAppM ``lt_of_not_ge #[prf]
| (``LT.lt, #[_, _, _, _]) => try? <| mkAppM ``le_of_not_gt #[prf]
| _ => throwError "Not a comparison (flipNegatedComparison): {e}"
/--
Replaces proofs of negations of comparisons with proofs of the reversed comparisons.
For example, a proof of `¬ a < b` will become a proof of `a ≥ b`.
-/
def removeNegations : Preprocessor where
name := "replace negations of comparisons"
transform h := do
let t : Q(Prop) ← whnfR (← inferType h)
match t with
| ~q(¬ $p) =>
match ← flipNegatedComparison h (← whnfR p) with
| some h' =>
trace[linarith] "removing negation in {h}"
return [h']
| _ => return [h]
| _ => return [h]
end removeNegations
section natToInt
open Mathlib.Tactic.Zify
/--
`isNatProp tp` is true iff `tp` is an inequality or equality between natural numbers
or the negation thereof.
-/
partial def isNatProp (e : Expr) : Bool :=
match e.getAppFnArgs with
| (``Eq, #[.const ``Nat [], _, _]) => true
| (``LE.le, #[.const ``Nat [], _, _, _]) => true
| (``LT.lt, #[.const ``Nat [], _, _, _]) => true
| (``GE.ge, #[.const ``Nat [], _, _, _]) => true
| (``GT.gt, #[.const ``Nat [], _, _, _]) => true
| (``Not, #[e]) => isNatProp e
| _ => false
/-- If `e` is of the form `((n : ℕ) : C)`, `isNatCoe e` returns `⟨n, C⟩`. -/
def isNatCoe (e : Expr) : Option (Expr × Expr) :=
match e.getAppFnArgs with
| (``Nat.cast, #[target, _, n]) => some ⟨n, target⟩
| _ => none
/--
`getNatComparisons e` returns a list of all subexpressions of `e` of the form `((t : ℕ) : C)`.
-/
partial def getNatComparisons (e : Expr) : List (Expr × Expr) :=
match isNatCoe e with
| some x => [x]
| none => match e.getAppFnArgs with
| (``HAdd.hAdd, #[_, _, _, _, a, b]) => getNatComparisons a ++ getNatComparisons b
| (``HMul.hMul, #[_, _, _, _, a, b]) => getNatComparisons a ++ getNatComparisons b
| (``HSub.hSub, #[_, _, _, _, a, b]) => getNatComparisons a ++ getNatComparisons b
| (``Neg.neg, #[_, _, a]) => getNatComparisons a
| _ => []
/-- If `e : ℕ`, returns a proof of `0 ≤ (e : C)`. -/
def mk_natCast_nonneg_prf (p : Expr × Expr) : MetaM (Option Expr) :=
match p with
| ⟨e, target⟩ => try commitIfNoEx (mkAppM ``natCast_nonneg #[target, e])
catch e => do
trace[linarith] "Got exception when using cast {e.toMessageData}"
return none
open Batteries
/-- Ordering on `Expr`. -/
def Expr.Ord : Ord Expr :=
⟨fun a b => if Expr.lt a b then .lt else if a.equal b then .eq else .gt⟩
attribute [local instance] Expr.Ord
/--
If `h` is an equality or inequality between natural numbers,
`natToInt` lifts this inequality to the integers.
It also adds the facts that the integers involved are nonnegative.
To avoid adding the same nonnegativity facts many times, it is a global preprocessor.
-/
def natToInt : GlobalBranchingPreprocessor where
name := "move nats to ints"
transform g l := do
let l ← l.mapM fun h => do
let t ← whnfR (← instantiateMVars (← inferType h))
if isNatProp t then
let (some (h', t'), _) ← Term.TermElabM.run' (run_for g (zifyProof none h t))
| throwError "zifyProof failed on {h}"
if ← filterComparisons.aux t' then
pure h'
else
-- `zifyProof` turned our comparison into something that wasn't a comparison
-- probably replacing `n = n` with `True`, because of
-- https://github.com/leanprover-community/mathlib4/issues/741
-- so we just keep the original hypothesis.
pure h
else
pure h
let nonnegs ← l.foldlM (init := ∅) fun (es : RBSet (Expr × Expr) lexOrd.compare) h => do
try
let (a, b) ← getRelSides (← inferType h)
pure <| (es.insertList (getNatComparisons a)).insertList (getNatComparisons b)
catch _ => pure es
pure [(g, ((← nonnegs.toList.filterMapM mk_natCast_nonneg_prf) ++ l : List Expr))]
end natToInt
section strengthenStrictInt
/--
If `pf` is a proof of a strict inequality `(a : ℤ) < b`,
`mkNonstrictIntProof pf` returns a proof of `a + 1 ≤ b`,
and similarly if `pf` proves a negated weak inequality.
-/
def mkNonstrictIntProof (pf : Expr) : MetaM (Option Expr) := do
match (← instantiateMVars (← inferType pf)).getAppFnArgs with
| (``LT.lt, #[.const ``Int [], _, a, b]) =>
return mkApp (← mkAppM ``Iff.mpr #[← mkAppOptM ``Int.add_one_le_iff #[a, b]]) pf
| (``GT.gt, #[.const ``Int [], _, a, b]) =>
return mkApp (← mkAppM ``Iff.mpr #[← mkAppOptM ``Int.add_one_le_iff #[b, a]]) pf
| (``Not, #[P]) => match P.getAppFnArgs with
| (``LE.le, #[.const ``Int [], _, a, b]) =>
return mkApp (← mkAppM ``Iff.mpr #[← mkAppOptM ``Int.add_one_le_iff #[b, a]])
(← mkAppM ``lt_of_not_ge #[pf])
| (``GE.ge, #[.const ``Int [], _, a, b]) =>
return mkApp (← mkAppM ``Iff.mpr #[← mkAppOptM ``Int.add_one_le_iff #[a, b]])
(← mkAppM ``lt_of_not_ge #[pf])
| _ => return none
| _ => return none
/-- `strengthenStrictInt h` turns a proof `h` of a strict integer inequality `t1 < t2`
into a proof of `t1 ≤ t2 + 1`. -/
def strengthenStrictInt : Preprocessor where
name := "strengthen strict inequalities over int"
transform h := return [(← mkNonstrictIntProof h).getD h]
end strengthenStrictInt
section compWithZero
/--
`rearrangeComparison e` takes a proof `e` of an equality, inequality, or negation thereof,
and turns it into a proof of a comparison `_ R 0`, where `R ∈ {=, ≤, <}`.
-/
partial def rearrangeComparison (e : Expr) : MetaM (Option Expr) := do
aux e (← instantiateMVars (← inferType e))
where
/-- Implementation of `rearrangeComparison`, after type inference. -/
aux (proof e : Expr) : MetaM (Option Expr) :=
let isZero (e : Expr) := e.getAppFnArgs matches (``OfNat.ofNat, #[_, .lit (.natVal 0), _])
match e.getAppFnArgs with
| (``LE.le, #[_, _, a, b]) => match isZero a, isZero b with
| _, true => return proof
| true, _ => try? <| mkAppM ``neg_nonpos_of_nonneg #[proof]
| _, _ => try? <| mkAppM ``sub_nonpos_of_le #[proof]
| (``LT.lt, #[_, _, a, b]) => match isZero a, isZero b with
| _, true => return proof
| true, _ => try? <| mkAppM ``neg_neg_of_pos #[proof]
| _, _ => try? <| mkAppM ``sub_neg_of_lt #[proof]
| (``Eq, #[_, a, b]) => match isZero a, isZero b with
| _, true => return proof
| true, _ => try? <| mkAppM ``Eq.symm #[proof]
| _, _ => try? <| mkAppM ``sub_eq_zero_of_eq #[proof]
| (``GT.gt, #[_, _, a, b]) => match isZero a, isZero b with
| _, true => try? <| mkAppM ``neg_neg_of_pos #[proof]
| true, _ => try? <| mkAppM ``lt_zero_of_zero_gt #[proof]
| _, _ => try? <| mkAppM ``sub_neg_of_lt #[proof]
| (``GE.ge, #[_, _, a, b]) => match isZero a, isZero b with
| _, true => try? <| mkAppM ``neg_nonpos_of_nonneg #[proof]
| true, _ => try? <| mkAppM ``le_zero_of_zero_ge #[proof]
| _, _ => try? <| mkAppM ``sub_nonpos_of_le #[proof]
| (``Not, #[a]) => do
let some nproof ← flipNegatedComparison proof a | return none
aux nproof (← inferType nproof)
| a => throwError "couldn't rearrange comparison {a}"
/--
`compWithZero h` takes a proof `h` of an equality, inequality, or negation thereof,
and turns it into a proof of a comparison `_ R 0`, where `R ∈ {=, ≤, <}`.
-/
def compWithZero : Preprocessor where
name := "make comparisons with zero"
transform e := return (← rearrangeComparison e).toList
end compWithZero
section cancelDenoms
theorem without_one_mul [MulOneClass M] {a b : M} (h : 1 * a = b) : a = b := by rwa [one_mul] at h
/--
`normalizeDenominatorsLHS h lhs` assumes that `h` is a proof of `lhs R 0`.
It creates a proof of `lhs' R 0`, where all numeric division in `lhs` has been cancelled.
-/
def normalizeDenominatorsLHS (h lhs : Expr) : MetaM Expr := do
let mut (v, lhs') ← CancelDenoms.derive lhs
if v = 1 then
-- `lhs'` has a `1 *` out front, but `mkSingleCompZeroOf` has a special case
-- where it does not produce `1 *`. We strip it off here:
lhs' ← mkAppM ``without_one_mul #[lhs']
let (_, h'') ← mkSingleCompZeroOf v h
try
h''.rewriteType lhs'
catch e =>
dbg_trace
s!"Error in Linarith.normalizeDenominatorsLHS: {← e.toMessageData.toString}"
throw e
/--
`cancelDenoms pf` assumes `pf` is a proof of `t R 0`. If `t` contains the division symbol `/`,
it tries to scale `t` to cancel out division by numerals.
-/
def cancelDenoms : Preprocessor where
name := "cancel denominators"
transform := fun pf => (do
let (_, lhs) ← parseCompAndExpr (← inferType pf)
guard <| lhs.containsConst (fun n => n = ``HDiv.hDiv || n = ``Div.div)
pure [← normalizeDenominatorsLHS pf lhs])
<|> return [pf]
end cancelDenoms
section nlinarith
/--
`findSquares s e` collects all terms of the form `a ^ 2` and `a * a` that appear in `e`
and adds them to the set `s`.
A pair `(a, true)` is added to `s` when `a^2` appears in `e`,
and `(a, false)` is added to `s` when `a*a` appears in `e`. -/
partial def findSquares (s : HashSet (Expr × Bool)) (e : Expr) : MetaM (HashSet (Expr × Bool)) :=
-- Completely traversing the expression is non-ideal,
-- as we can descend into expressions that could not possibly be seen by `linarith`.
-- As a result we visit expressions with bvars, which then cause panics.
-- Ideally this preprocessor would be reimplemented so it only visits things that could be atoms.
-- In the meantime we just bail out if we ever encounter loose bvars.
if e.hasLooseBVars then return s else
match e.getAppFnArgs with
| (``HPow.hPow, #[_, _, _, _, a, b]) => match b.numeral? with
| some 2 => do
let s ← findSquares s a
return (s.insert (a, true))
| _ => e.foldlM findSquares s
| (``HMul.hMul, #[_, _, _, _, a, b]) => if a.equal b then do
let s ← findSquares s a
return (s.insert (a, false))
else
e.foldlM findSquares s
| _ => e.foldlM findSquares s
/--
`nlinarithExtras` is the preprocessor corresponding to the `nlinarith` tactic.
* For every term `t` such that `t^2` or `t*t` appears in the input, adds a proof of `t^2 ≥ 0`
or `t*t ≥ 0`.
* For every pair of comparisons `t1 R1 0` and `t2 R2 0`, adds a proof of `t1*t2 R 0`.
This preprocessor is typically run last, after all inputs have been canonized.
-/
def nlinarithExtras : GlobalPreprocessor where
name := "nonlinear arithmetic extras"
transform ls := do
let s ← ls.foldrM (fun h s' => do findSquares s' (← instantiateMVars (← inferType h)))
HashSet.empty
let new_es ← s.foldM (fun new_es (⟨e, is_sq⟩ : Expr × Bool) =>
((do
let p ← mkAppM (if is_sq then ``sq_nonneg else ``mul_self_nonneg) #[e]
pure <| p::new_es) <|> pure new_es)) ([] : List Expr)
let new_es ← compWithZero.globalize.transform new_es
trace[linarith] "nlinarith preprocessing found squares"
trace[linarith] "{s.toList}"
linarithTraceProofs "so we added proofs" new_es
let with_comps ← (new_es ++ ls).mapM (fun e => do
let tp ← inferType e
try
let ⟨ine, _⟩ ← parseCompAndExpr tp
pure (ine, e)
catch _ => pure (Ineq.lt, e))
let products ← with_comps.mapDiagM fun (⟨posa, a⟩ : Ineq × Expr) ⟨posb, b⟩ =>
try
(some <$> match posa, posb with
| Ineq.eq, _ => mkAppM ``zero_mul_eq #[a, b]
| _, Ineq.eq => mkAppM ``mul_zero_eq #[a, b]
| Ineq.lt, Ineq.lt => mkAppM ``mul_pos_of_neg_of_neg #[a, b]
| Ineq.lt, Ineq.le => do
let a ← mkAppM ``le_of_lt #[a]
mkAppM ``mul_nonneg_of_nonpos_of_nonpos #[a, b]
| Ineq.le, Ineq.lt => do
let b ← mkAppM ``le_of_lt #[b]
mkAppM ``mul_nonneg_of_nonpos_of_nonpos #[a, b]
| Ineq.le, Ineq.le => mkAppM ``mul_nonneg_of_nonpos_of_nonpos #[a, b])
catch _ => pure none
let products ← compWithZero.globalize.transform products.reduceOption
return (new_es ++ ls ++ products)
end nlinarith
section removeNe
/--
`removeNe_aux` case splits on any proof `h : a ≠ b` in the input,
turning it into `a < b ∨ a > b`.
This produces `2^n` branches when there are `n` such hypotheses in the input.
-/
partial def removeNe_aux : MVarId → List Expr → MetaM (List Branch) := fun g hs => do
let some (e, α, a, b) ← hs.findSomeM? (fun e : Expr => do
let some (α, a, b) := (← instantiateMVars (← inferType e)).ne?' | return none
return some (e, α, a, b)) | return [(g, hs)]
let [ng1, ng2] ← g.apply (← mkAppOptM ``Or.elim #[none, none, ← g.getType,
← mkAppOptM ``lt_or_gt_of_ne #[α, none, a, b, e]]) | failure
let do_goal : MVarId → MetaM (List Branch) := fun g => do
let (f, h) ← g.intro1
h.withContext do
let ls ← removeNe_aux h <| hs.removeAll [e]
return ls.map (fun b : Branch => (b.1, (.fvar f)::b.2))
return ((← do_goal ng1) ++ (← do_goal ng2))
/--
`removeNe` case splits on any proof `h : a ≠ b` in the input, turning it into `a < b ∨ a > b`,
by calling `linarith.removeNe_aux`.
This produces `2^n` branches when there are `n` such hypotheses in the input.
-/
def removeNe : GlobalBranchingPreprocessor where
name := "removeNe"
transform := removeNe_aux
end removeNe
/--
The default list of preprocessors, in the order they should typically run.
-/
def defaultPreprocessors : List GlobalBranchingPreprocessor :=
[filterComparisons, removeNegations, natToInt, strengthenStrictInt,
compWithZero, cancelDenoms]
/--
`preprocess pps l` takes a list `l` of proofs of propositions.
It maps each preprocessor `pp ∈ pps` over this list.
The preprocessors are run sequentially: each receives the output of the previous one.
Note that a preprocessor may produce multiple or no expressions from each input expression,
so the size of the list may change.
-/
def preprocess (pps : List GlobalBranchingPreprocessor) (g : MVarId) (l : List Expr) :
MetaM (List Branch) := g.withContext <|
pps.foldlM (fun ls pp => return (← ls.mapM fun (g, l) => do pp.process g l).join) [(g, l)]
end Linarith