Skip to content

Commit

Permalink
Document and fix Group. Add a test for it.
Browse files Browse the repository at this point in the history
  • Loading branch information
bobg committed Jan 5, 2022
1 parent a1e217d commit f88c191
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 19 deletions.
35 changes: 35 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ func Accum[T any](ctx context.Context, inp *Iter[T], f func(T, T) (T, error)) *I
out[i+1] == f(out[i], inp[i+1])
func Chan[T any](ctx context.Context, inp <-chan T) *Iter[T]
Chan creates an iterator reading from a channel.
func Concat[T any](ctx context.Context, inps ...*Iter[T]) *Iter[T]
Concat[T] takes a sequence of iterators and produces an iterator over all
the elements of the input iterators, in sequence.
Expand Down Expand Up @@ -126,6 +129,38 @@ func Gen[T any](ctx context.Context, f func() (T, bool, error)) *Iter[T]
Gen produces an iterator whose members are generated by successive calls to
a given function.
func Group[T any, K comparable](ctx context.Context, inp *Iter[T], f func(T) (K, error)) *Iter[Pair[K, *Iter[T]]]
Group partitions the elements of an input iterator into separate groups
according to a given partitioning function. The output is an iterator of
(K,iterator) pairs, where each K is a value produced by the partitioning
function, and each iterator contains the elements from the input belonging
to that group.
The logic that consumes the input iterator is single-threaded. This means
that a caller consuming the top-level output of Group (i.e., the iterator of
pairs) should launch goroutines to consume the nested iterator in each pair.
Otherwise a reader waiting for a value on the K1 sub-iterator may deadlock
waiting for someone to read the value that Group is trying to supply to the
K2 sub-iterator.
Illustration:
outer := Group(ctx, inp, partitionFunc)
for {
pair, ok, err := outer.Next()
if err != nil { ... }
if !ok { break }
k, inner := pair.X, pair.Y
go func() {
for {
x, ok, err := inner.Next()
if err != nil { ... }
if !ok { break }
...handle value x in the k group...
}
}()
}
func Ints(ctx context.Context, start, delta int) *Iter[int]
Ints produces an infinite iterator of integers, starting at start and
incrementing by delta.
Expand Down
2 changes: 2 additions & 0 deletions chan.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package chit

import "context"

// Chan creates an iterator reading from a channel.
func Chan[T any](ctx context.Context, inp <-chan T) *Iter[T] {
return New(ctx, func(send func(T) error) error {
Expand Down
16 changes: 10 additions & 6 deletions chit.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ func New[T any](ctx context.Context, writer func(send func(T) error) error) *Ite
}
go func() {
iter.Err = writer(func(x T) error {
select {
case ch <- x:
return nil
case <-ctx.Done():
return ctx.Err()
}
return chsend(ctx, ch, x)
})
close(ch)
}()
Expand All @@ -75,3 +70,12 @@ func (it *Iter[T]) Next() (T, bool, error) {
func (it *Iter[T]) Cancel() {
it.cancel()
}

// chsend delivers x on ch unless ctx is done first.
// It blocks until one of the two happens.
// It returns nil when the value was sent,
// and ctx.Err() when the context finished before the send could proceed.
func chsend[T any](ctx context.Context, ch chan<- T, x T) error {
	done := ctx.Done()
	select {
	case <-done:
		// The context ended before anyone could accept the value.
		return ctx.Err()
	case ch <- x:
	}
	return nil
}
53 changes: 40 additions & 13 deletions group.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,43 @@
package chit

import "context"

// Group partitions the elements of an input iterator into separate groups according to a given partitioning function.
// The output is an iterator of (K,iterator) pairs,
// where each K is a value produced by the partitioning function,
// and each iterator contains the elements from the input belonging to that group.
//
// The logic that consumes the input iterator is single-threaded.
// This means that a caller consuming the top-level output of Group
// (i.e., the iterator of pairs)
// should launch goroutines to consume the nested iterator in each pair.
// Otherwise a reader waiting for a value on the K1 sub-iterator
// may deadlock waiting for someone to read the value
// that Group is trying to supply to the K2 sub-iterator.
//
// Illustration:
//
// outer := Group(ctx, inp, partitionFunc)
// for {
// pair, ok, err := outer.Next()
// if err != nil { ... }
// if !ok { break }
// k, inner := pair.X, pair.Y
// go func() {
// for {
// x, ok, err := inner.Next()
// if err != nil { ... }
// if !ok { break }
// ...handle value x in the k group...
// }
// }()
// }
func Group[T any, K comparable](ctx context.Context, inp *Iter[T], f func(T) (K, error)) *Iter[Pair[K, *Iter[T]]] {
// When we discover a new partition (a new K value),
// we create a channel to feed the corresponding Iter[T].
m := make(map[K]chan<- T)

return New(ctx, func(outerSend func(Pair[K, *Iter[T]]) error {
return New(ctx, func(outerSend func(Pair[K, *Iter[T]]) error) error {
defer func() {
for _, ch := range m {
close(ch)
Expand All @@ -29,17 +61,14 @@ func Group[T any, K comparable](ctx context.Context, inp *Iter[T], f func(T) (K,
// This is an existing partition.
// Supply the current value to its iterator.

select {
case ch <- x:
// ok
case <-ctx.Done():
return ctx.Err()
err = chsend(ctx, ch, x)
if err != nil {
return err
}
continue
}

// This is a new partition.
//

ch := make(chan T, 32)
m[k] = ch
Expand All @@ -49,12 +78,10 @@ func Group[T any, K comparable](ctx context.Context, inp *Iter[T], f func(T) (K,
return err
}

select {
case ch <- x:
// ok
case <-ctx.Done():
return ctx.Err()
err = chsend(ctx, ch, x)
if err != nil {
return err
}
}
}))
})
}
49 changes: 49 additions & 0 deletions group_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package chit

import (
"context"
"reflect"
"sync"
"testing"
)

// TestGroup checks that Group splits the integers 1 through 10 into three
// groups keyed by x % 3, and that each group's iterator yields its members
// in input order.
//
// Each nested iterator is drained in its own goroutine, as the Group
// documentation requires, so the shared result map is guarded by a mutex.
func TestGroup(t *testing.T) {
	ctx := context.Background()
	inp := FirstN(ctx, Ints(ctx, 1, 1), 10)
	groups := Group(ctx, inp, func(x int) (int, error) { return x % 3, nil })

	var (
		mu sync.Mutex // guards got, which is written by one goroutine per group
		wg sync.WaitGroup
	)
	got := make(map[int][]int)

	for {
		pair, ok, err := groups.Next()
		if err != nil {
			t.Fatal(err)
		}
		if !ok {
			break
		}
		wg.Add(1)
		// pair is declared inside the loop body, so each goroutine captures its own copy.
		go func() {
			defer wg.Done()
			s, err := ToSlice(ctx, pair.Y)
			if err != nil {
				// t.Errorf is safe to call from a non-test goroutine; t.Fatal and panic are not appropriate here.
				t.Errorf("reading group %v: %v", pair.X, err)
				return
			}
			mu.Lock()
			got[pair.X] = s
			mu.Unlock()
		}()
	}
	wg.Wait()

	want := map[int][]int{
		0: {3, 6, 9},
		1: {1, 4, 7, 10},
		2: {2, 5, 8},
	}
	for k, w := range want {
		if !reflect.DeepEqual(got[k], w) {
			t.Errorf("group %d: got %v, want %v", k, got[k], w)
		}
	}
	if len(got) != len(want) {
		t.Errorf("got %d groups, want %d", len(got), len(want))
	}
}

0 comments on commit f88c191

Please sign in to comment.