perf(Data/Multiset/Powerset): redefine powersetAux #7388

Closed
wants to merge 4 commits into from
Changes from 2 commits
16 changes: 11 additions & 5 deletions Mathlib/Data/Multiset/Powerset.lean
@@ -23,15 +23,21 @@ variable {α : Type*}

/-! ### powerset -/

- --Porting note: TODO: Write a more efficient version
Member

In what sense is the new version more efficient? In terms of algorithmic complexity they're the same, right? Is the problem the intermediate memory?

Contributor

If nothing else, it fixes timeouts and seems to be a more faithful port of the Lean 3 version...

Member

I suspect the cause of the slowdown was the use of Array in List.sublists. In the kernel, an Array is just a List; only the compiler gives it a more efficient representation. List.sublists uses Array.push, which is slow in the kernel because it amounts to l ++ [a] on the underlying list, so it takes linear time instead of constant time. The new implementation therefore has better complexity in the kernel.
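
To make the point about Array.push concrete, here is a minimal Lean sketch of how the kernel effectively sees an array. The names `KArray` and `KArray.push` are hypothetical stand-ins invented for illustration; they are not the core library's definitions, only a model of the "wrapper around a List" behaviour described above.

```lean
/-- Hypothetical model of the kernel's view of an array:
a structure wrapping a plain `List`. -/
structure KArray (α : Type) where
  data : List α

/-- Pushing appends to the end of the underlying list.
`++` walks the whole left-hand list, so each push costs time
linear in the current length when reduced in the kernel. -/
def KArray.push {α : Type} (as : KArray α) (a : α) : KArray α :=
  ⟨as.data ++ [a]⟩

-- The kernel has to unfold the append step by step to check this.
example : (KArray.push (⟨[1, 2]⟩ : KArray Nat) 3).data = [1, 2, 3] := rfl
```

Under this model, building a list of n elements by repeated pushes costs on the order of n² reduction steps in the kernel, whereas compiled code pushes in amortized constant time.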

Member

@eric-wieser, Sep 28, 2023

Does that mean List.sublists is also inefficient in the kernel? Can we change it back to be array-free (using the definition in mathlib3port, for example), so that this definition gets the same improvement?

Member

If we're worried about runtime performance, then we can add a csimp lemma in Std that converts the kernel-friendly version into the execution-friendly version. The advantage of doing that in Std to List.sublists is that it should cause everything downstream to be optimal automatically.
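
As a rough illustration of the csimp pattern mentioned here, the sketch below uses a toy summation function rather than List.sublists. All names (`sumSpec`, `sumLoop`, `sumFast`, and the lemma names) are hypothetical and chosen only for this example; the point is the shape: a kernel-friendly definition, an execution-friendly one, and a `@[csimp]` equation that tells the compiler to substitute the second for the first.

```lean
/-- Kernel-friendly definition (hypothetical): plain structural recursion. -/
def sumSpec : List Nat → Nat
  | [] => 0
  | a :: l => a + sumSpec l

/-- Execution-friendly helper (hypothetical): tail-recursive accumulator loop. -/
def sumLoop : List Nat → Nat → Nat
  | [], acc => acc
  | a :: l, acc => sumLoop l (acc + a)

/-- Execution-friendly entry point (hypothetical). -/
def sumFast (l : List Nat) : Nat := sumLoop l 0

/-- The loop computes the accumulator plus the specification's result. -/
theorem sumLoop_eq (l : List Nat) (acc : Nat) :
    sumLoop l acc = acc + sumSpec l := by
  induction l generalizing acc with
  | nil => simp [sumLoop, sumSpec]
  | cons a l ih => simp [sumLoop, sumSpec, ih, Nat.add_assoc]

/-- With this lemma, compiled code calling `sumSpec` runs `sumFast`,
while proofs and kernel reduction still see `sumSpec`. -/
@[csimp] theorem sumSpec_eq_sumFast : @sumSpec = @sumFast := by
  funext l
  simp [sumFast, sumLoop_eq]
```

Applied to List.sublists, the same arrangement would keep the array-free definition as the one the kernel unfolds and leave the Array-based version for execution only.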

Collaborator Author

For reference, my check for whether alternate definitions are "good enough" is the following example, which essentially comes from the Zulip thread:

import Mathlib.Order.Partition.Finpartition

open Finset

instance Finpartition.fintype_finset
  {α : Type _} [DecidableEq α] (a : Finset α) : Fintype (Finpartition a) where
  elems := a.powerset.powerset.image
    (λ p => if h : p.SupIndep id ∧ p.sup id = a ∧ ⊥ ∉ p then ⟨p, h.1, h.2.1, h.2.2⟩ else ⊥)
  complete := by
    rintro p
    rw [mem_image]
    refine' ⟨p.parts, _, _⟩
    · simp only [mem_powerset]
      intros i hi
      rw [mem_powerset]
      exact p.le hi
    · rw [dif_pos]
      simp only [p.supIndep, p.supParts, p.not_bot_mem, eq_self_iff_true, not_false_iff, and_self]

example : @Fintype.card (Finpartition (range 3)) (Finpartition.fintype_finset _) = 5 := by rfl

Collaborator Author

@collares, Sep 28, 2023

@eric-wieser I checked by reverting my changes and replacing the Std definition of sublists by

  l.foldr (fun a acc => join (acc.map fun x => [x, a :: x])) [[]]

(and sorrying out the affected theorems) and the example in the previous comment didn't time out. Therefore, I fully agree that we should fix it in Std. Unfortunately I don't have the time to prepare and shepherd a Std PR right now, but I'd be happy if you or someone else did it.
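
For readers who want to see what the array-free fold computes, here is a small self-contained sketch. The name `sublistsFold` is hypothetical, and the inner `join`/`map` from the one-liner above is spelled as a second `foldr` so the snippet depends only on core `List.foldr`; it is intended to produce the same list in the same order.

```lean
/-- Hypothetical array-free sublists: for each element `a`, every sublist `x`
built so far is kept and followed by `a :: x`. -/
def sublistsFold {α : Type} (l : List α) : List (List α) :=
  l.foldr (fun a acc => acc.foldr (fun x r => x :: (a :: x) :: r) []) [[]]

-- Step by step: [[]]  →  [[], [2]]  →  [[], [1], [2], [1, 2]]
example : sublistsFold ([1, 2] : List Nat) = [[], [1], [2], [1, 2]] := rfl
```

Because this only uses `List` constructors and `foldr`, the kernel can reduce it without ever appending to the end of a list.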

Member

I've tried to adopt this in #7746

Member

Do recent elaborator changes have any effect on this problem?

Member

The performance difference is still noticeable; see the example in the Std issue.

/-- A helper function for the powerset of a multiset. Given a list `l`, returns a list
of sublists of `l` as multisets. -/
def powersetAux (l : List α) : List (Multiset α) :=
-   (sublists l).map (↑)
+   l.foldr (fun a acc => acc >>= (fun x => [x, a ::ₘ x])) [0]
#align multiset.powerset_aux Multiset.powersetAux

- theorem powersetAux_eq_map_coe {l : List α} : powersetAux l = (sublists l).map (↑) :=
-   rfl
+ theorem powersetAux_eq_map_coe {l : List α} : powersetAux l = (sublists l).map (↑) := by
+   induction l
+   · case nil => rfl
+   · case cons hd tl ih =>
+     have : List.map ofList (List.bind (sublists tl) fun x => [x, hd :: x]) =
+         List.bind (List.map ofList (sublists tl)) fun x => [x, hd ::ₘ x] := by
+       simp [List.bind_map, List.map_bind]
+     simp only [powersetAux, List.foldr, bind_eq_bind, sublists_cons, this]
+     congr
#align multiset.powerset_aux_eq_map_coe Multiset.powersetAux_eq_map_coe

@[simp]
@@ -163,7 +169,7 @@ theorem revzip_powersetAux_perm_aux' {l : List α} :
theorem revzip_powersetAux_perm {l₁ l₂ : List α} (p : l₁ ~ l₂) :
    revzip (powersetAux l₁) ~ revzip (powersetAux l₂) := by
  haveI := Classical.decEq α
-   simp [fun l : List α => revzip_powersetAux_lemma l revzip_powersetAux, coe_eq_coe.2 p]
+   simp only [fun l : List α => revzip_powersetAux_lemma l revzip_powersetAux, coe_eq_coe.2 p]
  exact (powersetAux_perm p).map _
#align multiset.revzip_powerset_aux_perm Multiset.revzip_powersetAux_perm
