Skip to content

Commit

Permalink
core: api changes, allow to survive runtime gosched
Browse files Browse the repository at this point in the history
  • Loading branch information
shadowspore committed Jan 31, 2024
1 parent c2040e9 commit 135e39a
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 61 deletions.
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func processInput(input []int) (processing, output []int) {

wg.Add(len(input))
for _, v := range input {
ticket := route.TakeTicket()
ticket := route.Ticket()
go func(t *flightorder.Ticket, v int) {
defer wg.Done()
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
Expand All @@ -52,11 +52,15 @@ func processInput(input []int) (processing, output []int) {
processing = append(processing, v)
mux.Unlock()

_ = route.CompleteTicket(context.TODO(), t)

mux.Lock()
output = append(output, v)
mux.Unlock()
_ = route.CompleteTicket(context.TODO(), flightorder.CompleteTicketParams{
Ticket: t,
Completion: func(ctx context.Context) error {
mux.Lock()
output = append(output, v)
mux.Unlock()
return nil
},
})
}(ticket, v)
}

Expand Down
39 changes: 27 additions & 12 deletions examples/basic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,50 @@ func main() {
route := flightorder.NewRoute(flightorder.RouteParams{})

// Take some tickets.
t1 := route.TakeTicket()
t2 := route.TakeTicket()
t3 := route.TakeTicket()
t1 := route.Ticket()
t2 := route.Ticket()
t3 := route.Ticket()

// Perform parallel processing.
var wg sync.WaitGroup
wg.Add(3)
go func() {
defer wg.Done()
time.Sleep(time.Millisecond * 20)
fmt.Println("Task 1 started")
route.CompleteTicket(context.TODO(), t1)
fmt.Println("Task 1 completed")
wg.Done()
route.CompleteTicket(context.TODO(), flightorder.CompleteTicketParams{
Ticket: t1,
Completion: func(ctx context.Context) error {
fmt.Println("Task 1 completed")
return nil
},
})
}()

go func() {
defer wg.Done()
time.Sleep(time.Millisecond * 30)
fmt.Println("Task 2 started")
route.CompleteTicket(context.TODO(), t2)
fmt.Println("Task 2 completed")
wg.Done()
route.CompleteTicket(context.TODO(), flightorder.CompleteTicketParams{
Ticket: t2,
Completion: func(ctx context.Context) error {
fmt.Println("Task 2 completed")
return nil
},
})
}()

go func() {
defer wg.Done()
time.Sleep(time.Millisecond * 10)
fmt.Println("Task 3 started")
route.CompleteTicket(context.TODO(), t3)
fmt.Println("Task 3 completed")
wg.Done()
route.CompleteTicket(context.TODO(), flightorder.CompleteTicketParams{
Ticket: t3,
Completion: func(ctx context.Context) error {
fmt.Println("Task 3 completed")
return nil
},
})
}()

wg.Wait()
Expand Down
16 changes: 10 additions & 6 deletions examples/numbers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func processInput(input []int) (processing, output []int) {

wg.Add(len(input))
for _, v := range input {
ticket := route.TakeTicket()
ticket := route.Ticket()
go func(t *flightorder.Ticket, v int) {
defer wg.Done()
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
Expand All @@ -37,11 +37,15 @@ func processInput(input []int) (processing, output []int) {
processing = append(processing, v)
mux.Unlock()

_ = route.CompleteTicket(context.TODO(), t)

mux.Lock()
output = append(output, v)
mux.Unlock()
_ = route.CompleteTicket(context.TODO(), flightorder.CompleteTicketParams{
Ticket: t,
Completion: func(ctx context.Context) error {
mux.Lock()
output = append(output, v)
mux.Unlock()
return nil
},
})
}(ticket, v)
}

Expand Down
31 changes: 24 additions & 7 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ func NewRoute(params RouteParams) *Route {
}
}

// TakeTicket takes a new ticket.
func (r *Route) TakeTicket() *Ticket {
// Ticket takes a new ticket.
func (r *Route) Ticket() *Ticket {
r.mux.Lock()
defer r.mux.Unlock()

Expand All @@ -48,15 +48,28 @@ func (r *Route) TakeTicket() *Ticket {
return ticket
}

// CompleteTicketParams is a parameters for CompleteTicket method.
type CompleteTicketParams struct {
// Ticket to complete.
Ticket *Ticket
// Completion function, will be called in order tickets are taken. Optional.
Completion func(ctx context.Context) error
}

// CompleteTicket completes a ticket.
// Waits for previous taken tickets to complete first, if any.
func (r *Route) CompleteTicket(ctx context.Context, t *Ticket) error {
func (r *Route) CompleteTicket(ctx context.Context, params CompleteTicketParams) error {
if params.Completion == nil {
params.Completion = func(ctx context.Context) error { return nil }
}

t := params.Ticket
if r.recorder != nil {
r.recorder.completeCall(t)
}

if t.prev == nil {
return r.completeTail(t)
return r.completeTail(ctx, t, params.Completion)
}

if err := r.waitFor(ctx, t.prev); err != nil {
Expand All @@ -66,10 +79,10 @@ func (r *Route) CompleteTicket(ctx context.Context, t *Ticket) error {
r.allocator.ReleaseTicket(t.prev)
t.prev = nil

return r.completeTail(t)
return r.completeTail(ctx, t, params.Completion)
}

func (r *Route) completeTail(t *Ticket) error {
func (r *Route) completeTail(ctx context.Context, t *Ticket, f func(ctx context.Context) error) error {
r.mux.Lock()
defer r.mux.Unlock()

Expand All @@ -81,7 +94,11 @@ func (r *Route) completeTail(t *Ticket) error {
r.recorder.recordCompleted(t)
}

return nil
return f(ctx)
}

if err := f(ctx); err != nil {
return err
}

// There is a ticket ahead.
Expand Down
60 changes: 30 additions & 30 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ func TestRoute(t *testing.T) {
TicketAllocator: alloc,
})

t1 := route.TakeTicket()
require.NoError(t, route.CompleteTicket(context.TODO(), t1))
t1 := route.Ticket()
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t1}))
require.Nil(t, route.last)
require.Equal(t, []*Ticket{t1}, alloc.released)
})
Expand All @@ -28,10 +28,10 @@ func TestRoute(t *testing.T) {
TicketAllocator: alloc,
})

t1 := route.TakeTicket()
t2 := route.TakeTicket()
require.NoError(t, route.CompleteTicket(context.TODO(), t1))
require.NoError(t, route.CompleteTicket(context.TODO(), t2))
t1 := route.Ticket()
t2 := route.Ticket()
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t1}))
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t2}))
require.Nil(t, route.last)
require.Equal(t, []*Ticket{t1, t2}, alloc.released)
})
Expand All @@ -45,27 +45,27 @@ func TestRoute(t *testing.T) {
rec := newRecorder()
route.recorder = rec

t1 := route.TakeTicket()
t2 := route.TakeTicket()
t3 := route.TakeTicket()
t1 := route.Ticket()
t2 := route.Ticket()
t3 := route.Ticket()

var wg sync.WaitGroup
wg.Add(3)
go func() {
time.Sleep(time.Millisecond * 10)
require.NoError(t, route.CompleteTicket(context.TODO(), t3))
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t3}))
wg.Done()
}()

go func() {
time.Sleep(time.Millisecond * 20)
require.NoError(t, route.CompleteTicket(context.TODO(), t2))
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t2}))
wg.Done()
}()

go func() {
time.Sleep(time.Millisecond * 30)
require.NoError(t, route.CompleteTicket(context.TODO(), t1))
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t1}))
wg.Done()
}()

Expand All @@ -85,27 +85,27 @@ func TestRoute(t *testing.T) {
rec := newRecorder()
route.recorder = rec

t1 := route.TakeTicket()
t2 := route.TakeTicket()
t3 := route.TakeTicket()
t1 := route.Ticket()
t2 := route.Ticket()
t3 := route.Ticket()

var wg sync.WaitGroup
wg.Add(3)
go func() {
time.Sleep(time.Millisecond * 10)
require.NoError(t, route.CompleteTicket(context.TODO(), t2))
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t2}))
wg.Done()
}()

go func() {
time.Sleep(time.Millisecond * 20)
require.NoError(t, route.CompleteTicket(context.TODO(), t3))
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t3}))
wg.Done()
}()

go func() {
time.Sleep(time.Millisecond * 30)
require.NoError(t, route.CompleteTicket(context.TODO(), t1))
require.NoError(t, route.CompleteTicket(context.TODO(), CompleteTicketParams{Ticket: t1}))
wg.Done()
}()

Expand All @@ -123,14 +123,14 @@ func BenchmarkRoute(b *testing.B) {
route := NewRoute(RouteParams{})
for i := 0; i < b.N; i++ {
var (
t1 = route.TakeTicket()
t2 = route.TakeTicket()
t3 = route.TakeTicket()
t1 = route.Ticket()
t2 = route.Ticket()
t3 = route.Ticket()
)

require.NoError(b, route.CompleteTicket(ctx, t1))
require.NoError(b, route.CompleteTicket(ctx, t2))
require.NoError(b, route.CompleteTicket(ctx, t3))
require.NoError(b, route.CompleteTicket(ctx, CompleteTicketParams{Ticket: t1}))
require.NoError(b, route.CompleteTicket(ctx, CompleteTicketParams{Ticket: t2}))
require.NoError(b, route.CompleteTicket(ctx, CompleteTicketParams{Ticket: t3}))
}
})

Expand All @@ -140,14 +140,14 @@ func BenchmarkRoute(b *testing.B) {
})
for i := 0; i < b.N; i++ {
var (
t1 = route.TakeTicket()
t2 = route.TakeTicket()
t3 = route.TakeTicket()
t1 = route.Ticket()
t2 = route.Ticket()
t3 = route.Ticket()
)

require.NoError(b, route.CompleteTicket(ctx, t1))
require.NoError(b, route.CompleteTicket(ctx, t2))
require.NoError(b, route.CompleteTicket(ctx, t3))
require.NoError(b, route.CompleteTicket(ctx, CompleteTicketParams{Ticket: t1}))
require.NoError(b, route.CompleteTicket(ctx, CompleteTicketParams{Ticket: t2}))
require.NoError(b, route.CompleteTicket(ctx, CompleteTicketParams{Ticket: t3}))
}
})
}

0 comments on commit 135e39a

Please sign in to comment.