Skip to content

Commit

Permalink
combin: Add helpers for dealing with permutations (#1076)
Browse files Browse the repository at this point in the history
* combin: Add helpers for dealing with permutations
  • Loading branch information
btracey committed Sep 11, 2019
1 parent 1d8f8b2 commit 40d3308
Show file tree
Hide file tree
Showing 3 changed files with 475 additions and 42 deletions.
287 changes: 275 additions & 12 deletions stat/combin/combin.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func LogGeneralizedBinomial(n, k float64) float64 {
return a - b - c
}

// CombinationGenerator generates combinations iteratively. Combinations may be
// called to generate all combinations collectively.
// CombinationGenerator generates combinations iteratively. The Combinations
// function may be called to generate all combinations collectively.
type CombinationGenerator struct {
n int
k int
Expand Down Expand Up @@ -118,29 +118,28 @@ func (c *CombinationGenerator) Next() bool {
return true
}

// Combination generates the current combination. If next is non-nil, it must have
// length k and the result will be stored in-place into combination. If combination
// Combination returns the current combination. If dst is non-nil, it must have
// length k and the result will be stored in-place into dst. If dst
// is nil a new slice will be allocated and returned. If all of the combinations
// have already been constructed (Next() returns false), Combination will panic.
//
// Next must be called to initialize the first value before calling Combination
// or Combination will panic. The value returned by Combination is only changed
// during calls to Next.
func (c *CombinationGenerator) Combination(combination []int) []int {
func (c *CombinationGenerator) Combination(dst []int) []int {
if c.remaining == -1 {
panic("combin: all combinations have been generated")
}
if c.previous == nil {
panic("combin: Combination called before Next")
}
if combination == nil {
combination = make([]int, c.k)
}
if len(combination) != c.k {
if dst == nil {
dst = make([]int, c.k)
} else if len(dst) != c.k {
panic(badInput)
}
copy(combination, c.previous)
return combination
copy(dst, c.previous)
return dst
}

// Combinations generates all of the combinations of k elements from a
Expand All @@ -150,7 +149,7 @@ func (c *CombinationGenerator) Combination(combination []int) []int {
// n and k must be non-negative with n >= k, otherwise Combinations will panic.
//
// CombinationGenerator may alternatively be used to generate the combinations
// iteratively instead of collectively.
// iteratively instead of collectively, or IndexToCombination for random access.
func Combinations(n, k int) [][]int {
combins := Binomial(n, k)
data := make([][]int, combins)
Expand Down Expand Up @@ -368,3 +367,267 @@ func SubFor(sub []int, idx int, dims []int) []int {
sub[len(sub)-1] = idx
return sub
}

// NumPermutations returns the number of permutations when selecting k
// objects from a set of n objects when the selection order matters.
// No check is made for overflow.
//
// NumPermutations panics if either n or k is negative, or if k is
// greater than n.
func NumPermutations(n, k int) int {
if n < 0 {
panic("combin: n is negative")
}
if k < 0 {
panic("combin: k is negative")
}
if k > n {
panic("combin: k is greater than n")
}
p := 1
for i := n - k + 1; i <= n; i++ {
p *= i
}
return p
}

// Permutations generates all of the permutations of k elements from a
// set of size n. The returned slice has length NumPermutations(n, k)
// and each inner slice has length k.
//
// n and k must be non-negative with n >= k, otherwise Permutations will panic.
//
// PermutationGenerator may alternatively be used to generate the permutations
// iteratively instead of collectively, or IndexToPermutation for random access.
func Permutations(n, k int) [][]int {
nPerms := NumPermutations(n, k)
data := make([][]int, nPerms)
if len(data) == 0 {
return data
}
for i := 0; i < nPerms; i++ {
data[i] = IndexToPermutation(nil, i, n, k)
}
return data
}

// PermutationGenerator generates permutations iteratively. The Permutations
// function may be called to generate all permutations collectively.
type PermutationGenerator struct {
n int
k int
nPerm int
idx int
permutation []int
}

// NewPermutationGenerator returns a PermutationGenerator for generating the
// permutations of k elements from a set of size n.
//
// n and k must be non-negative with n >= k, otherwise NewPermutationGenerator
// will panic.
func NewPermutationGenerator(n, k int) *PermutationGenerator {
return &PermutationGenerator{
n: n,
k: k,
nPerm: NumPermutations(n, k),
idx: -1,
permutation: make([]int, k),
}
}

// Next advances the iterator if there are permutations remaining to be generated,
// and returns false if all permutations have been generated. Next must be called
// to initialize the first value before calling Permutation or Permutation will
// panic. The value returned by Permutation is only changed during calls to Next.
func (p *PermutationGenerator) Next() bool {
if p.idx >= p.nPerm-1 {
p.idx = p.nPerm // so Permutation can panic.
return false
}
p.idx++
IndexToPermutation(p.permutation, p.idx, p.n, p.k)
return true
}

// Permutation returns the current permutation. If dst is non-nil, it must have
// length k and the result will be stored in-place into dst. If dst
// is nil a new slice will be allocated and returned. If all of the permutations
// have already been constructed (Next() returns false), Permutation will panic.
//
// Next must be called to initialize the first value before calling Permutation
// or Permutation will panic. The value returned by Permutation is only changed
// during calls to Next.
func (p *PermutationGenerator) Permutation(dst []int) []int {
if p.idx == p.nPerm {
panic("combin: all permutations have been generated")
}
if p.idx == -1 {
panic("combin: Permutation called before Next")
}
if dst == nil {
dst = make([]int, p.k)
} else if len(dst) != p.k {
panic(badInput)
}
copy(dst, p.permutation)
return dst
}

// PermutationIndex returns the index of the given permutation.
//
// The functions PermutationIndex and IndexToPermutation define a bijection
// between the integers and the NumPermutations(n, k) number of possible permutations.
// PermutationIndex returns the inverse of IndexToPermutation.
//
// PermutationIndex panics if perm is not a permutation of k of the first
// [0,n) integers, if n or k are non-negative, or if k is greater than n.
func PermutationIndex(perm []int, n, k int) int {
if n < 0 || k < 0 {
panic(badNegInput)
}
if n < k {
panic(badSetSize)
}
if len(perm) != k {
panic("combin: bad length permutation")
}
contains := make(map[int]struct{}, k)
for _, v := range perm {
if v < 0 || v >= n {
panic("combin: bad element")
}
contains[v] = struct{}{}
}
if len(contains) != k {
panic("combin: perm contains non-unique elements")
}
if n == k {
// The permutation is the ordering of the elements.
return equalPermutationIndex(perm)
}

// The permutation index is found by finding the combination index and the
// equalPermutation index. The combination index is found by just sorting
// the elements, and the the permutation index is the ordering of the size
// of the elements.
tmp := make([]int, len(perm))
copy(tmp, perm)
idx := make([]int, len(perm))
for i := range idx {
idx[i] = i
}
s := sortInts{tmp, idx}
sort.Sort(s)
order := make([]int, len(perm))
for i, v := range idx {
order[v] = i
}
combIdx := CombinationIndex(tmp, n, k)
permIdx := equalPermutationIndex(order)
return combIdx*NumPermutations(k, k) + permIdx
}

type sortInts struct {
data []int
idx []int
}

func (s sortInts) Len() int {
return len(s.data)
}

func (s sortInts) Less(i, j int) bool {
return s.data[i] < s.data[j]
}

func (s sortInts) Swap(i, j int) {
s.data[i], s.data[j] = s.data[j], s.data[i]
s.idx[i], s.idx[j] = s.idx[j], s.idx[i]
}

// IndexToPermutation returns the permutation corresponding to the given index.
//
// The functions PermutationIndex and IndexToPermutation define a bijection
// between the integers and the NumPermutations(n, k) number of possible permutations.
// IndexToPermutation returns the inverse of PermutationIndex.
//
// The permutation is stored in-place into dst if dst is non-nil, otherwise
// a new slice is allocated and returned.
//
// IndexToPermutation panics if n or k are non-negative, if k is greater than n,
// or if idx is not in [0, NumPermutations(n,k)-1]. IndexToPermutation will also panic
// if dst is non-nil and len(dst) is not k.
func IndexToPermutation(dst []int, idx, n, k int) []int {
nPerm := NumPermutations(n, k)
if idx < 0 || idx >= nPerm {
panic("combin: invalid index")
}
if dst == nil {
dst = make([]int, k)
} else if len(dst) != k {
panic(badInput)
}
if n == k {
indexToEqualPermutation(dst, idx)
return dst
}

// First, we index into the combination (which of the k items to choose)
// and then we index into the n == k permutation of those k items. The
// indexing acts like a matrix with nComb rows and factorial(k) columns.
kPerm := NumPermutations(k, k)
combIdx := idx / kPerm
permIdx := idx % kPerm
comb := IndexToCombination(nil, combIdx, n, k) // Gives us the set of integers.
perm := make([]int, len(dst))
indexToEqualPermutation(perm, permIdx) // Gives their order.
for i, v := range perm {
dst[i] = comb[v]
}
return dst
}

// equalPermutationIndex returns the index of the given permutation of the
// first k integers.
func equalPermutationIndex(perm []int) int {
// Note(btracey): This is an n^2 algorithm, but factorial increases
// very quickly (25! overflows int64) so this is not a problem in
// practice.
idx := 0
for i, u := range perm {
less := 0
for _, v := range perm[i:] {
if v < u {
less++
}
}
idx += less * factorial(len(perm)-i-1)
}
return idx
}

// indexToEqualPermutation returns the permutation for the first len(dst)
// integers for the given index.
func indexToEqualPermutation(dst []int, idx int) {
for i := range dst {
dst[i] = i
}
for i := range dst {
f := factorial(len(dst) - i - 1)
r := idx / f
v := dst[i+r]
copy(dst[i+1:i+r+1], dst[i:i+r])
dst[i] = v
idx %= f
}
}

// factorial returns a!.
func factorial(a int) int {
f := 1
for i := 2; i <= a; i++ {
f *= i
}
return f
}
42 changes: 42 additions & 0 deletions stat/combin/combin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,45 @@ func TestCartesian(t *testing.T) {
t.Errorf("cartesian data mismatch.\nwant:\n%v\ngot:\n%v", want, got)
}
}

func TestPermutationIndex(t *testing.T) {
for cas, s := range []struct {
n, k int
}{
{6, 3},
{4, 4},
{10, 1},
{8, 2},
} {
n := s.n
k := s.k
perms := make(map[string]struct{})
for i := 0; i < NumPermutations(n, k); i++ {
perm := IndexToPermutation(nil, i, n, k)
idx := PermutationIndex(perm, n, k)
if idx != i {
t.Errorf("Cas %d: permutation mismatch. Want %d, got %d", cas, i, idx)
}
perms[intSliceToKey(perm)] = struct{}{}
}
if len(perms) != NumPermutations(n, k) {
t.Errorf("Case %d: not all generated combinations were unique", cas)
}
}
}

func TestPermutationGenerator(t *testing.T) {
for n := 0; n <= 7; n++ {
for k := 1; k <= n; k++ {
permutations := Permutations(n, k)
pg := NewPermutationGenerator(n, k)
genPerms := make([][]int, 0, len(permutations))
for pg.Next() {
genPerms = append(genPerms, pg.Permutation(nil))
}
if !intSosMatch(permutations, genPerms) {
t.Errorf("Permutations and generated permutations do not match. n = %v, k = %v", n, k)
}
}
}
}

0 comments on commit 40d3308

Please sign in to comment.