Skip to content

Commit 27fa680

Browse files
committed
feat(linarith): add linarith? mode which suggests a linarith only call (#28533)
Adds `linarith?`, which traces the internals to watch which inequalities are being used, and suggests a `linarith only` call via the "try this:" mechanism. Also adds a `+minimize` flag, on by default, which then greedily tries to drop hypotheses from the used set, to see if the problem is still solvable with a smaller set. Currently I don't have a test case showing this is worthwhile, however. Contributions welcome.
1 parent e931189 commit 27fa680

File tree

3 files changed

+235
-68
lines changed

3 files changed

+235
-68
lines changed

Mathlib/Tactic/Linarith/Frontend.lean

Lines changed: 142 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ structure LinarithConfig : Type where
164164
splitHypotheses : Bool := true
165165
/-- Split `≠` in hypotheses, by branching in cases `<` and `>`. -/
166166
splitNe : Bool := false
167+
/-- If true, `linarith?` attempts to greedily remove unused hypotheses from its
168+
suggestion. -/
169+
minimize : Bool := true
167170
/-- Override the list of preprocessors. -/
168171
preprocessors : List GlobalBranchingPreprocessor := defaultPreprocessors
169172
/-- Specify an oracle for identifying candidate contradictions.
@@ -238,66 +241,84 @@ def ExprMultiMap.insert {α : Type} (self : ExprMultiMap α) (k : Expr) (v : α)
238241
return self.push (k, [v])
239242

240243
/--
241-
`partitionByType l` takes a list `l` of proofs of comparisons. It sorts these proofs by
242-
the type of the variables in the comparison, e.g. `(a : ℚ) < 1` and `(b : ℤ) > c` will be separated.
243-
Returns a map from a type to a list of comparisons over that type.
244+
`partitionByTypeIdx l` takes a list `l` of pairs `(h, i)` where `h` is a proof of a
245+
comparison and `i` records the original position of `h`. The proofs are grouped by the
246+
type of the variables appearing in the comparison, e.g. `(a : ℚ) < 1` and
247+
`(b : ℤ) > c` will be separated. The resulting map associates each type with the
248+
list of `(h, i)` pairs over that type.
244249
-/
245-
def partitionByType (l : List Expr) : MetaM (ExprMultiMap Expr) :=
246-
l.foldlM (fun m h => do m.insert (← typeOfIneqProof h) h) #[]
250+
def partitionByTypeIdx (l : List (Expr × Nat)) : MetaM (ExprMultiMap (Expr × Nat)) :=
251+
l.foldlM (fun m ⟨h, i⟩ => do m.insert (← typeOfIneqProof h) (h, i)) #[]
247252

248253
/--
249-
Given a list `ls` of lists of proofs of comparisons, `findLinarithContradiction cfg ls` will try to
250-
prove `False` by calling `linarith` on each list in succession. It will stop at the first proof of
251-
`False`, and fail if no contradiction is found with any list.
254+
Given a list `ls` of pairs `(α, L)` where each `L` is a list of indexed proofs of
255+
comparisons over the type `α`, `findLinarithContradiction cfg g ls` tries each list in
256+
succession, invoking `linarith` until one produces a contradiction. It returns the
257+
resulting proof of `False` together with the indices of the hypotheses that had
258+
nonzero coefficients in the final certificate.
252259
-/
253-
def findLinarithContradiction (cfg : LinarithConfig) (g : MVarId) (ls : List (Expr × List Expr)) :
254-
MetaM Expr :=
260+
def findLinarithContradiction (cfg : LinarithConfig) (g : MVarId)
261+
(ls : List (Expr × List (Expr × Nat))) : MetaM (Expr × List Nat) :=
255262
try
256263
ls.firstM (fun ⟨α, L⟩ =>
257-
withTraceNode `linarith (return m!"{exceptEmoji ·} running on type {α}") <|
258-
proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g L)
264+
withTraceNode `linarith (return m!"{exceptEmoji ·} running on type {α}") do
265+
let (pf, idxs) ←
266+
proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g (L.map Prod.fst)
267+
let idxs := idxs.map fun i => L[i]!.2
268+
return (pf, idxs))
259269
catch e => throwError "linarith failed to find a contradiction\n{g}\n{e.toMessageData}"
260270

261271
/--
262-
Given a list `hyps` of proofs of comparisons, `runLinarith cfg hyps prefType`
263-
preprocesses `hyps` according to the list of preprocessors in `cfg`.
264-
This results in a list of branches (typically only one),
265-
each of which must succeed in order to close the goal.
266-
267-
In each branch, we partition the list of hypotheses by type, and run `linarith` on each class
268-
in the partition; one of these must succeed in order for `linarith` to succeed on this branch.
269-
If `prefType` is given, it will first use the class of proofs of comparisons over that type.
272+
Given a list `hyps` of proofs of comparisons, `runLinarith cfg prefType g hyps` preprocesses
273+
`hyps` according to the list of preprocessors in `cfg`. This results in a list of branches
274+
(typically only one), each of which must succeed in order to close the goal.
275+
276+
In each branch, the hypotheses are partitioned by type and `linarith` is run on each class in
277+
turn; one of these must succeed in order for `linarith` to succeed on the branch. If `prefType`
278+
is provided, the corresponding class is tried first.
279+
280+
On success, the metavariable `g` is assigned and the function returns the indices of the
281+
original hypotheses that were used with nonzero coefficient in the final proof.
270282
-/
271283
-- If it succeeds, the passed metavariable should have been assigned.
272284
def runLinarith (cfg : LinarithConfig) (prefType : Option Expr) (g : MVarId)
273-
(hyps : List Expr) : MetaM Unit := do
274-
let singleProcess (g : MVarId) (hyps : List Expr) : MetaM Expr := g.withContext do
275-
linarithTraceProofs s!"after preprocessing, linarith has {hyps.length} facts:" hyps
276-
let mut hyp_set ← partitionByType hyps
277-
trace[linarith] "hypotheses appear in {hyp_set.size} different types"
278-
-- If we have a preferred type, strip it from `hyp_set` and prepare a handler with a custom
279-
-- trace message
280-
let pref : MetaM _ ← do
281-
if let some t := prefType then
282-
let (i, vs) ← hyp_set.find t
283-
hyp_set := hyp_set.eraseIdxIfInBounds i
284-
pure <|
285-
withTraceNode `linarith (return m!"{exceptEmoji ·} running on preferred type {t}") <|
286-
proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g vs
287-
else
288-
pure failure
289-
pref <|> findLinarithContradiction cfg g hyp_set.toList
285+
(hyps : List Expr) : MetaM (List Nat) := do
286+
let singleProcess (g : MVarId) (hyps : List (Expr × Nat)) : MetaM (Expr × List Nat) :=
287+
g.withContext do
288+
linarithTraceProofs
289+
s!"after preprocessing, linarith has {hyps.length} facts:" (hyps.map Prod.fst)
290+
let mut hyp_set ← partitionByTypeIdx hyps
291+
trace[linarith] "hypotheses appear in {hyp_set.size} different types"
292+
-- If we have a preferred type, strip it from `hyp_set` and prepare a handler with a custom
293+
-- trace message
294+
let pref : MetaM _ ← do
295+
if let some t := prefType then
296+
let (i, vs) ← hyp_set.find t
297+
hyp_set := hyp_set.eraseIdxIfInBounds i
298+
pure <|
299+
withTraceNode `linarith (return m!"{exceptEmoji ·} running on preferred type {t}") do
300+
let (pf, idxs) ←
301+
proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g (vs.map Prod.fst)
302+
let idxs := idxs.map fun j => vs[j]!.2
303+
return (pf, idxs)
304+
else
305+
pure failure
306+
pref <|> findLinarithContradiction cfg g hyp_set.toList
290307
let mut preprocessors := cfg.preprocessors
291308
if cfg.splitNe then
292309
preprocessors := Linarith.removeNe :: preprocessors
293310
if cfg.splitHypotheses then
294311
preprocessors := Linarith.splitConjunctions.globalize.branching :: preprocessors
295312
let branches ← preprocess preprocessors g hyps
313+
let mut used : List Nat := []
296314
for (g, es) in branches do
297-
let r ← singleProcess g es
315+
let esIdx := es.zipIdx
316+
let (r, idxs) ← singleProcess g esIdx
298317
g.assign r
318+
used := idxs ++ used
299319
-- Verify that we closed the goal. Failure here should only result from a bad `Preprocessor`.
300320
(Expr.mvar g).ensureHasNoMVars
321+
return used.eraseDups
301322

302323
-- /--
303324
-- `filterHyps restr_type hyps` takes a list of proofs of comparisons `hyps`, and filters it
@@ -311,24 +332,27 @@ def runLinarith (cfg : LinarithConfig) (prefType : Option Expr) (g : MVarId)
311332
-- | none => return false)
312333

313334
/--
314-
`linarith only_on hyps cfg` tries to close the goal using linear arithmetic. It fails
315-
if it does not succeed at doing this.
335+
`linarithUsedHyps only_on hyps cfg g` runs `linarith` with the supplied hypotheses. It
336+
fails if the goal cannot be closed. When successful, it returns the candidate hypotheses that
337+
were actually used (i.e. had a nonzero coefficient) in the final certificate.
316338
317339
* `hyps` is a list of proofs of comparisons to include in the search.
318340
* If `only_on` is true, the search will be restricted to `hyps`. Otherwise it will use all
319341
comparisons in the local context.
320342
* If `cfg.transparency := semireducible`,
321343
it will unfold semireducible definitions when trying to match atomic expressions.
322344
-/
323-
partial def linarith (only_on : Bool) (hyps : List Expr) (cfg : LinarithConfig := {})
324-
(g : MVarId) : MetaM Unit := g.withContext do
345+
partial def linarithUsedHyps (only_on : Bool) (hyps : List Expr)
346+
(cfg : LinarithConfig := {}) (g : MVarId) : MetaM (List Expr) := g.withContext do
325347
-- if the target is an equality, we run `linarith` twice, to prove ≤ and ≥.
326348
if (← whnfR (← instantiateMVars (← g.getType))).isEq then
327349
trace[linarith] "target is an equality: splitting"
328350
if let some [g₁, g₂] ← try? (g.apply (← mkConst' ``eq_of_not_lt_of_not_gt)) then
329-
withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≥") <| linarith only_on hyps cfg g₁
330-
withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≤") <| linarith only_on hyps cfg g₂
331-
return
351+
let h₁ ← withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≥") <|
352+
linarithUsedHyps only_on hyps cfg g₁
353+
let h₂ ← withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≤") <|
354+
linarithUsedHyps only_on hyps cfg g₂
355+
return h₁ ++ h₂
332356

333357
/- If we are proving a comparison goal (and not just `False`), we consider the type of the
334358
elements in the comparison to be the "preferred" type. That is, if we find comparison
@@ -347,17 +371,33 @@ partial def linarith (only_on : Bool) (hyps : List Expr) (cfg : LinarithConfig :
347371
| (some (t, v), g) => pure (g, some t, some v)
348372

349373
g.withContext do
350-
-- set up the list of hypotheses, considering the `only_on` and `restrict_type` options
351-
let hyps ← (if only_on then return new_var.toList ++ hyps
352-
else return (← getLocalHyps).toList ++ hyps)
374+
-- set up the list of hypotheses, considering the `only_on` and `restrict_type` options
375+
let hyps ←
376+
(if only_on then return new_var.toList ++ hyps
377+
else return (← getLocalHyps).toList ++ hyps)
353378

354379
-- TODO in mathlib3 we could specify a restriction to a single type.
355380
-- I haven't done that here because I don't know how to store a `Type` in `LinarithConfig`.
356381
-- There's only one use of the `restrict_type` configuration option in mathlib3,
357382
-- and it can be avoided just by using `linarith only`.
358383

359384
linarithTraceProofs "linarith is running on the following hypotheses:" hyps
360-
runLinarith cfg target_type g hyps
385+
let usedIdxs ← runLinarith cfg target_type g hyps
386+
let used := usedIdxs.filterMap (hyps[·]?)
387+
let used := match new_var with
388+
| some nv => used.filter (fun h => !(h == nv))
389+
| none => used
390+
return used
391+
392+
/--
393+
Run the core `linarith` procedure on the goal `g` using the hypotheses `hyps`.
394+
If `only_on` is true, the search is restricted to `hyps`; otherwise all suitable
395+
local hypotheses are considered. This is the workhorse behind the user-facing
396+
`linarith` tactic.
397+
-/
398+
partial def linarith (only_on : Bool) (hyps : List Expr) (cfg : LinarithConfig := {})
399+
(g : MVarId) : MetaM Unit := do
400+
discard <| linarithUsedHyps only_on hyps cfg g
361401

362402
end Linarith
363403

@@ -416,6 +456,8 @@ optional arguments:
416456
disequality hypotheses. (`false` by default.)
417457
* If `exfalso` is `false`, `linarith` will fail when the goal is neither an inequality nor `False`.
418458
(`true` by default.)
459+
* If `minimize` is `false`, `linarith?` will report all hypotheses appearing in its initial
460+
proof without attempting to drop redundancies. (`true` by default.)
419461
* `restrict_type` (not yet implemented in mathlib4)
420462
will only use hypotheses that are inequalities over `tp`. This is useful
421463
if you have e.g. both integer- and rational-valued inequalities in the local context, which can
@@ -428,8 +470,19 @@ routine.
428470
-/
429471
syntax (name := linarith) "linarith" "!"? linarithArgsRest : tactic
430472

473+
/--
474+
`linarith?` behaves like `linarith` but, on success, it prints a suggestion of
475+
the form `linarith only [...]` listing a minimized set of hypotheses used in the
476+
final proof. Use `linarith?!` for the higher-reducibility variant and set the
477+
`minimize` flag in the configuration to control whether greedy minimization is
478+
performed.
479+
-/
480+
syntax (name := linarith?) "linarith?" "!"? linarithArgsRest : tactic
481+
431482
@[inherit_doc linarith] macro "linarith!" rest:linarithArgsRest : tactic =>
432483
`(tactic| linarith ! $rest:linarithArgsRest)
484+
@[inherit_doc linarith?] macro "linarith?!" rest:linarithArgsRest : tactic =>
485+
`(tactic| linarith? ! $rest:linarithArgsRest)
433486

434487
/--
435488
An extension of `linarith` with some preprocessing to allow it to solve some nonlinear arithmetic
@@ -458,6 +511,46 @@ elab_rules : tactic
458511
let cfg := (← elabLinarithConfig cfg).updateReducibility bang.isSome
459512
commitIfNoEx do liftMetaFinishingTactic <| Linarith.linarith o.isSome args.toList cfg
460513

514+
elab_rules : tactic
515+
| `(tactic| linarith?%$tk $[!%$bang]? $cfg:optConfig $[only%$o]? $[[$args,*]]?) =>
516+
withMainContext do
517+
let args ←
518+
((args.map (TSepArray.getElems)).getD {}).mapM (elabTermWithoutNewMVars `linarith)
519+
let cfg := (← elabLinarithConfig cfg).updateReducibility bang.isSome
520+
let g ← getMainGoal
521+
let st ← saveState
522+
try
523+
let used₀ ← Linarith.linarithUsedHyps o.isSome args.toList cfg g
524+
-- Check that all used hypotheses are fvars (not arbitrary terms)
525+
if used₀.any (fun e => e.fvarId?.isNone) then
526+
throwError "linarith? currently only supports named hypothesis, not terms"
527+
let used ←
528+
if cfg.minimize then
529+
let rec minimize (hs : List Expr) (i : Nat) : TacticM (List Expr) := do
530+
if _h : i < hs.length then
531+
let rest := hs.eraseIdx i
532+
st.restore
533+
try
534+
let _ ← Linarith.linarith true rest cfg g
535+
minimize rest i
536+
catch _ => minimize hs (i+1)
537+
else
538+
return hs
539+
minimize used₀ 0
540+
else
541+
pure used₀
542+
st.restore
543+
discard <| Linarith.linarith true used cfg g
544+
replaceMainGoal []
545+
-- TODO: we should check for, and deal with, shadowed names here.
546+
let idsList ← used.mapM fun e => do
547+
pure (Lean.mkIdent (← e.fvarId!.getUserName))
548+
let sugg ← `(tactic| linarith only [$(idsList.toArray),*])
549+
Lean.Meta.Tactic.TryThis.addSuggestion tk sugg
550+
catch e =>
551+
discard <| st.restore
552+
throw e
553+
461554
-- TODO restore this when `add_tactic_doc` is ported
462555
-- add_tactic_doc
463556
-- { name := "linarith",

Mathlib/Tactic/Linarith/Verification.lean

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,22 @@ def mkNegOneLtZeroProof (tp : Expr) : MetaM Expr := do
140140
mkAppM `neg_neg_of_pos #[zero_lt_one]
141141

142142
/--
143-
`addNegEqProofs l` inspects the list of proofs `l` for proofs of the form `t = 0`. For each such
144-
proof, it adds a proof of `-t = 0` to the list.
143+
`addNegEqProofsIdx l` inspects a list `l` of pairs `(h, i)` where `h` proves
144+
`tᵢ Rᵢ 0` and `i` records the original index of the hypothesis. For each
145+
equality proof `t = 0` in the list, it appends a proof of `-t = 0` with the
146+
same index `i`. All other entries are preserved.
145147
-/
146-
def addNegEqProofs : List Expr → MetaM (List Expr)
148+
def addNegEqProofsIdx : List (Expr × Nat) → MetaM (List (Expr × Nat))
147149
| [] => return []
148-
| (h::tl) => do
150+
| (⟨h, i⟩::tl) => do
149151
let (iq, t) ← parseCompAndExpr (← inferType h)
150152
match iq with
151153
| Ineq.eq => do
152-
let nep := mkAppN (← mkAppM `Iff.mpr #[← mkAppOptM ``neg_eq_zero #[none, none, t]]) #[h]
153-
let tl ← addNegEqProofs tl
154-
return h::nep::tl
155-
| _ => return h :: (← addNegEqProofs tl)
154+
let nep :=
155+
mkAppN (← mkAppM `Iff.mpr #[← mkAppOptM ``neg_eq_zero #[none, none, t]]) #[h]
156+
let tl ← addNegEqProofsIdx tl
157+
return (h, i)::(nep, i)::tl
158+
| _ => return (h, i) :: (← addNegEqProofsIdx tl)
156159

157160
/--
158161
`proveEqZeroUsing tac e` tries to use `tac` to construct a proof of `e = 0`.
@@ -188,15 +191,19 @@ tactic, which is typically `ring`. We prove (2) by folding over the set of hypot
188191
`transparency : TransparencyMode` controls the transparency level with which atoms are identified.
189192
-/
190193
def proveFalseByLinarith (transparency : TransparencyMode) (oracle : CertificateOracle)
191-
(discharger : TacticM Unit) : MVarId → List Expr → MetaM Expr
194+
(discharger : TacticM Unit) : MVarId → List Expr → MetaM (Expr × List Nat)
192195
| _, [] => throwError "no args to linarith"
193196
| g, l@(h::_) => do
194197
Lean.Core.checkSystem decl_name%.toString
195198
-- for the elimination to work properly, we must add a proof of `-1 < 0` to the list,
196199
-- along with negated equality proofs.
197-
let l' ← detailTrace "addNegEqProofs" <| addNegEqProofs l
198-
let inputs ← detailTrace "mkNegOneLtZeroProof" <|
199-
return (← mkNegOneLtZeroProof (← typeOfIneqProof h))::l'.reverse
200+
let lidx := l.zipIdx
201+
let l' ← detailTrace "addNegEqProofs" <| addNegEqProofsIdx lidx
202+
let inputsTagged : List (Expr × Option Nat) ←
203+
detailTrace "mkNegOneLtZeroProof" <|
204+
return ((← mkNegOneLtZeroProof (← typeOfIneqProof h)), none) ::
205+
(l'.reverse.map fun ⟨e, i⟩ => (e, some i))
206+
let inputs := inputsTagged.map Prod.fst
200207
trace[linarith.detail] "inputs:{indentD <| toMessageData (← inputs.mapM inferType)}"
201208
let (comps, max_var) ← detailTrace "linearFormsAndMaxVar" <|
202209
linearFormsAndMaxVar transparency inputs
@@ -212,28 +219,36 @@ def proveFalseByLinarith (transparency : TransparencyMode) (oracle : Certificate
212219
throwError "linarith failed to find a contradiction"
213220
trace[linarith] "found a contradiction: {certificate.toList}"
214221
return certificate
215-
let (sm, zip) ←
222+
let (sm, zip, idxs) ←
216223
withTraceNode `linarith (return m!"{exceptEmoji ·} Building final expression") do
217-
let enum_inputs := inputs.zipIdx
224+
let enum_inputs := inputsTagged.zipIdx
218225
-- construct a list pairing nonzero coeffs with the proof of their corresponding
219-
-- comparison
220-
let zip := enum_inputs.filterMap fun ⟨e, n⟩ => (certificate[n]?).map (e, ·)
221-
let mls ← zip.mapM fun ⟨e, n⟩ => do mulExpr n (← leftOfIneqProof e)
226+
-- comparison and track the original index
227+
let used := enum_inputs.filterMap fun ⟨⟨e, orig?⟩, n⟩ =>
228+
(certificate[n]?).map fun c => (e, c, orig?)
229+
let zip := used.map fun ⟨e, c, _⟩ => (e, c)
230+
let mls ← used.mapM fun ⟨e, c, _⟩ => do mulExpr c (← leftOfIneqProof e)
222231
-- `sm` is the sum of input terms, scaled to cancel out all variables.
223232
let sm ← addExprs mls
224233
-- let sm ← instantiateMVars sm
225234
trace[linarith] "{indentD sm}\nshould be both 0 and negative"
226-
return (sm, zip)
235+
let idxs :=
236+
(used.foldl (fun acc (_, _, orig?) =>
237+
match orig? with
238+
| some i => i :: acc
239+
| none => acc) []).eraseDups
240+
return (sm, zip, idxs)
227241
-- we prove that `sm = 0`, typically with `ring`.
228242
let sm_eq_zero ← detailTrace "proveEqZeroUsing" <| proveEqZeroUsing discharger sm
229243
-- we also prove that `sm < 0`
230244
let sm_lt_zero ← detailTrace "mkLTZeroProof" <| mkLTZeroProof zip
231-
detailTrace "Linarith.lt_irrefl" do
245+
let pf ← detailTrace "Linarith.lt_irrefl" do
232246
-- this is a contradiction.
233247
let pftp ← inferType sm_lt_zero
234248
let ⟨_, nep, _⟩ ← g.rewrite pftp sm_eq_zero
235249
let pf' ← mkAppM ``Eq.mp #[nep, sm_lt_zero]
236250
mkAppM ``Linarith.lt_irrefl #[pf']
251+
return (pf, idxs)
237252
where
238253
/-- Log `f` under `linarith.detail`, with exception emojis and the provided name. -/
239254
detailTrace {α} (s : String) (f : MetaM α) : MetaM α :=

0 commit comments

Comments
 (0)