@@ -164,6 +164,9 @@ structure LinarithConfig : Type where
164164 splitHypotheses : Bool := true
165165 /-- Split `≠` in hypotheses, by branching in cases `<` and `>`. -/
166166 splitNe : Bool := false
167+ /-- If true, `linarith?` attempts to greedily remove unused hypotheses from its
168+ suggestion. -/
169+ minimize : Bool := true
167170 /-- Override the list of preprocessors. -/
168171 preprocessors : List GlobalBranchingPreprocessor := defaultPreprocessors
169172 /-- Specify an oracle for identifying candidate contradictions.
@@ -238,66 +241,84 @@ def ExprMultiMap.insert {α : Type} (self : ExprMultiMap α) (k : Expr) (v : α)
238241 return self.push (k, [v])
239242
240243/--
241- `partitionByType l` takes a list `l` of proofs of comparisons. It sorts these proofs by
242- the type of the variables in the comparison, e.g. `(a : ℚ) < 1` and `(b : ℤ) > c` will be separated.
243- Returns a map from a type to a list of comparisons over that type.
244+ `partitionByTypeIdx l` takes a list `l` of pairs `(h, i)` where `h` is a proof of a
245+ comparison and `i` records the original position of `h`. The proofs are grouped by the
246+ type of the variables appearing in the comparison, e.g. `(a : ℚ) < 1` and
247+ `(b : ℤ) > c` will be separated. The resulting map associates each type with the
248+ list of `(h, i)` pairs over that type.
244249-/
245- def partitionByType (l : List Expr) : MetaM (ExprMultiMap Expr) :=
246- l.foldlM (fun m h => do m.insert (← typeOfIneqProof h) h ) #[]
250+ def partitionByTypeIdx (l : List ( Expr × Nat)) : MetaM (ExprMultiMap ( Expr × Nat) ) :=
251+ l.foldlM (fun m ⟨h, i⟩ => do m.insert (← typeOfIneqProof h) (h, i) ) #[]
247252
248253/--
249- Given a list `ls` of lists of proofs of comparisons, `findLinarithContradiction cfg ls` will try to
250- prove `False` by calling `linarith` on each list in succession. It will stop at the first proof of
251- `False`, and fail if no contradiction is found with any list.
254+ Given a list `ls` of pairs `(α, L)` where each `L` is a list of indexed proofs of
255+ comparisons over the type `α`, `findLinarithContradiction cfg g ls` tries each list in
256+ succession, invoking `linarith` until one produces a contradiction. It returns the
257+ resulting proof of `False` together with the indices of the hypotheses that had
258+ nonzero coefficients in the final certificate.
252259-/
253- def findLinarithContradiction (cfg : LinarithConfig) (g : MVarId) (ls : List (Expr × List Expr)) :
254- MetaM Expr :=
260+ def findLinarithContradiction (cfg : LinarithConfig) (g : MVarId)
261+ (ls : List (Expr × List (Expr × Nat))) : MetaM ( Expr × List Nat) :=
255262 try
256263 ls.firstM (fun ⟨α, L⟩ =>
257- withTraceNode `linarith (return m!"{exceptEmoji ·} running on type {α}" ) <|
258- proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g L)
264+ withTraceNode `linarith (return m!"{exceptEmoji ·} running on type {α}" ) do
265+ let (pf, idxs) ←
266+ proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g (L.map Prod.fst)
267+ let idxs := idxs.map fun i => L[i]!.2
268+ return (pf, idxs))
259269 catch e => throwError "linarith failed to find a contradiction\n {g}\n {e.toMessageData}"
260270
261271/--
262- Given a list `hyps` of proofs of comparisons, `runLinarith cfg hyps prefType`
263- preprocesses `hyps` according to the list of preprocessors in `cfg`.
264- This results in a list of branches (typically only one),
265- each of which must succeed in order to close the goal.
266-
267- In each branch, we partition the list of hypotheses by type, and run `linarith` on each class
268- in the partition; one of these must succeed in order for `linarith` to succeed on this branch.
269- If `prefType` is given, it will first use the class of proofs of comparisons over that type.
272+ Given a list `hyps` of proofs of comparisons, `runLinarith cfg prefType g hyps` preprocesses
273+ `hyps` according to the list of preprocessors in `cfg`. This results in a list of branches
274+ (typically only one), each of which must succeed in order to close the goal.
275+
276+ In each branch, the hypotheses are partitioned by type and `linarith` is run on each class in
277+ turn; one of these must succeed in order for `linarith` to succeed on the branch. If `prefType`
278+ is provided, the corresponding class is tried first.
279+
280+ On success, the metavariable `g` is assigned and the function returns the indices of the
281+ original hypotheses that were used with nonzero coefficient in the final proof.
270282-/
271283-- If it succeeds, the passed metavariable should have been assigned.
272284def runLinarith (cfg : LinarithConfig) (prefType : Option Expr) (g : MVarId)
273- (hyps : List Expr) : MetaM Unit := do
274- let singleProcess (g : MVarId) (hyps : List Expr) : MetaM Expr := g.withContext do
275- linarithTraceProofs s! "after preprocessing, linarith has { hyps.length} facts:" hyps
276- let mut hyp_set ← partitionByType hyps
277- trace[linarith] "hypotheses appear in {hyp_set.size} different types"
278- -- If we have a preferred type, strip it from `hyp_set` and prepare a handler with a custom
279- -- trace message
280- let pref : MetaM _ ← do
281- if let some t := prefType then
282- let (i, vs) ← hyp_set.find t
283- hyp_set := hyp_set.eraseIdxIfInBounds i
284- pure <|
285- withTraceNode `linarith (return m!"{exceptEmoji ·} running on preferred type {t}" ) <|
286- proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g vs
287- else
288- pure failure
289- pref <|> findLinarithContradiction cfg g hyp_set.toList
285+ (hyps : List Expr) : MetaM (List Nat) := do
286+ let singleProcess (g : MVarId) (hyps : List (Expr × Nat)) : MetaM (Expr × List Nat) :=
287+ g.withContext do
288+ linarithTraceProofs
289+ s! "after preprocessing, linarith has { hyps.length} facts:" (hyps.map Prod.fst)
290+ let mut hyp_set ← partitionByTypeIdx hyps
291+ trace[linarith] "hypotheses appear in {hyp_set.size} different types"
292+ -- If we have a preferred type, strip it from `hyp_set` and prepare a handler with a custom
293+ -- trace message
294+ let pref : MetaM _ ← do
295+ if let some t := prefType then
296+ let (i, vs) ← hyp_set.find t
297+ hyp_set := hyp_set.eraseIdxIfInBounds i
298+ pure <|
299+ withTraceNode `linarith (return m!"{exceptEmoji ·} running on preferred type {t}" ) do
300+ let (pf, idxs) ←
301+ proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g (vs.map Prod.fst)
302+ let idxs := idxs.map fun j => vs[j]!.2
303+ return (pf, idxs)
304+ else
305+ pure failure
306+ pref <|> findLinarithContradiction cfg g hyp_set.toList
290307 let mut preprocessors := cfg.preprocessors
291308 if cfg.splitNe then
292309 preprocessors := Linarith.removeNe :: preprocessors
293310 if cfg.splitHypotheses then
294311 preprocessors := Linarith.splitConjunctions.globalize.branching :: preprocessors
295312 let branches ← preprocess preprocessors g hyps
313+ let mut used : List Nat := []
296314 for (g, es) in branches do
297- let r ← singleProcess g es
315+ let esIdx := es.zipIdx
316+ let (r, idxs) ← singleProcess g esIdx
298317 g.assign r
318+ used := idxs ++ used
299319 -- Verify that we closed the goal. Failure here should only result from a bad `Preprocessor`.
300320 (Expr.mvar g).ensureHasNoMVars
321+ return used.eraseDups
301322
302323-- /--
303324-- `filterHyps restr_type hyps` takes a list of proofs of comparisons `hyps`, and filters it
@@ -311,24 +332,27 @@ def runLinarith (cfg : LinarithConfig) (prefType : Option Expr) (g : MVarId)
311332-- | none => return false)
312333
313334/--
314- `linarith only_on hyps cfg` tries to close the goal using linear arithmetic. It fails
315- if it does not succeed at doing this.
335+ `linarithUsedHyps only_on hyps cfg g` runs `linarith` with the supplied hypotheses. It
336+ fails if the goal cannot be closed. When successful, it returns the subset of `hyps` that
337+ were actually used (i.e. had a nonzero coefficient) in the final certificate.
316338
317339* `hyps` is a list of proofs of comparisons to include in the search.
318340* If `only_on` is true, the search will be restricted to `hyps`. Otherwise it will use all
319341 comparisons in the local context.
320342* If `cfg.transparency := semireducible`,
321343 it will unfold semireducible definitions when trying to match atomic expressions.
322344 -/
323- partial def linarith (only_on : Bool) (hyps : List Expr) (cfg : LinarithConfig := {} )
324- (g : MVarId) : MetaM Unit := g.withContext do
345+ partial def linarithUsedHyps (only_on : Bool) (hyps : List Expr)
346+ (cfg : LinarithConfig := {}) ( g : MVarId) : MetaM (List Expr) := g.withContext do
325347 -- if the target is an equality, we run `linarith` twice, to prove ≤ and ≥.
326348 if (← whnfR (← instantiateMVars (← g.getType))).isEq then
327349 trace[linarith] "target is an equality: splitting"
328350 if let some [g₁, g₂] ← try ? (g.apply (← mkConst' ``eq_of_not_lt_of_not_gt)) then
329- withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≥" ) <| linarith only_on hyps cfg g₁
330- withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≤" ) <| linarith only_on hyps cfg g₂
331- return
351+ let h₁ ← withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≥" ) <|
352+ linarithUsedHyps only_on hyps cfg g₁
353+ let h₂ ← withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≤" ) <|
354+ linarithUsedHyps only_on hyps cfg g₂
355+ return h₁ ++ h₂
332356
333357 /- If we are proving a comparison goal (and not just `False`), we consider the type of the
334358 elements in the comparison to be the "preferred" type. That is, if we find comparison
@@ -347,17 +371,33 @@ partial def linarith (only_on : Bool) (hyps : List Expr) (cfg : LinarithConfig :
347371 | (some (t, v), g) => pure (g, some t, some v)
348372
349373 g.withContext do
350- -- set up the list of hypotheses, considering the `only_on` and `restrict_type` options
351- let hyps ← (if only_on then return new_var.toList ++ hyps
352- else return (← getLocalHyps).toList ++ hyps)
374+ -- set up the list of hypotheses, considering the `only_on` and `restrict_type` options
375+ let hyps ←
376+ (if only_on then return new_var.toList ++ hyps
377+ else return (← getLocalHyps).toList ++ hyps)
353378
354379 -- TODO in mathlib3 we could specify a restriction to a single type.
355380 -- I haven't done that here because I don't know how to store a `Type` in `LinarithConfig`.
356381 -- There's only one use of the `restrict_type` configuration option in mathlib3,
357382 -- and it can be avoided just by using `linarith only`.
358383
359384 linarithTraceProofs "linarith is running on the following hypotheses:" hyps
360- runLinarith cfg target_type g hyps
385+ let usedIdxs ← runLinarith cfg target_type g hyps
386+ let used := usedIdxs.filterMap (hyps[·]?)
387+ let used := match new_var with
388+ | some nv => used.filter (fun h => !(h == nv))
389+ | none => used
390+ return used
391+
392+ /--
393+ Run the core `linarith` procedure on the goal `g` using the hypotheses `hyps`.
394+ If `only_on` is true, the search is restricted to `hyps`; otherwise all suitable
395+ local hypotheses are considered. This is the workhorse behind the user-facing
396+ `linarith` tactic.
397+ -/
398+ partial def linarith (only_on : Bool) (hyps : List Expr) (cfg : LinarithConfig := {})
399+ (g : MVarId) : MetaM Unit := do
400+ discard <| linarithUsedHyps only_on hyps cfg g
361401
362402end Linarith
363403
@@ -416,6 +456,8 @@ optional arguments:
416456 disequality hypotheses. (`false` by default.)
417457* If `exfalso` is `false`, `linarith` will fail when the goal is neither an inequality nor `False`.
418458 (`true` by default.)
459+ * If `minimize` is `false`, `linarith?` will report all hypotheses appearing in its initial
460+ proof without attempting to drop redundancies. (`true` by default.)
419461* `restrict_type` (not yet implemented in mathlib4)
420462 will only use hypotheses that are inequalities over `tp`. This is useful
421463 if you have e.g. both integer- and rational-valued inequalities in the local context, which can
@@ -428,8 +470,19 @@ routine.
428470-/
429471syntax (name := linarith) "linarith" "!" ? linarithArgsRest : tactic
430472
473+ /--
474+ `linarith?` behaves like `linarith` but, on success, it prints a suggestion of
475+ the form `linarith only [...]` listing a minimized set of hypotheses used in the
476+ final proof. Use `linarith?!` for the higher-reducibility variant and set the
477+ `minimize` flag in the configuration to control whether greedy minimization is
478+ performed.
479+ -/
480+ syntax (name := linarith?) "linarith?" "!" ? linarithArgsRest : tactic
481+
431482@[inherit_doc linarith] macro "linarith!" rest:linarithArgsRest : tactic =>
432483 `(tactic| linarith ! $rest:linarithArgsRest)
484+ @[inherit_doc linarith?] macro "linarith?!" rest:linarithArgsRest : tactic =>
485+ `(tactic| linarith? ! $rest:linarithArgsRest)
433486
434487/--
435488An extension of `linarith` with some preprocessing to allow it to solve some nonlinear arithmetic
@@ -458,6 +511,46 @@ elab_rules : tactic
458511 let cfg := (← elabLinarithConfig cfg).updateReducibility bang.isSome
459512 commitIfNoEx do liftMetaFinishingTactic <| Linarith.linarith o.isSome args.toList cfg
460513
514+ elab_rules : tactic
515+ | `(tactic| linarith?%$tk $[!%$bang]? $cfg:optConfig $[only%$o]? $[[$args,*]]?) =>
516+ withMainContext do
517+ let args ←
518+ ((args.map (TSepArray.getElems)).getD {}).mapM (elabTermWithoutNewMVars `linarith)
519+ let cfg := (← elabLinarithConfig cfg).updateReducibility bang.isSome
520+ let g ← getMainGoal
521+ let st ← saveState
522+ try
523+ let used₀ ← Linarith.linarithUsedHyps o.isSome args.toList cfg g
524+ -- Check that all used hypotheses are fvars (not arbitrary terms)
525+ if used₀.any (fun e => e.fvarId?.isNone) then
526+ throwError "linarith? currently only supports named hypothesis, not terms"
527+ let used ←
528+ if cfg.minimize then
529+ let rec minimize (hs : List Expr) (i : Nat) : TacticM (List Expr) := do
530+ if _h : i < hs.length then
531+ let rest := hs.eraseIdx i
532+ st.restore
533+ try
534+ let _ ← Linarith.linarith true rest cfg g
535+ minimize rest i
536+ catch _ => minimize hs (i+1 )
537+ else
538+ return hs
539+ minimize used₀ 0
540+ else
541+ pure used₀
542+ st.restore
543+ discard <| Linarith.linarith true used cfg g
544+ replaceMainGoal []
545+ -- TODO: we should check for, and deal with, shadowed names here.
546+ let idsList ← used.mapM fun e => do
547+ pure (Lean.mkIdent (← e.fvarId!.getUserName))
548+ let sugg ← `(tactic| linarith only [$(idsList.toArray),*])
549+ Lean.Meta.Tactic.TryThis.addSuggestion tk sugg
550+ catch e =>
551+ discard <| st.restore
552+ throw e
553+
461554-- TODO restore this when `add_tactic_doc` is ported
462555-- add_tactic_doc
463556-- { name := "linarith",
0 commit comments