Skip to content
This repository was archived by the owner on Feb 8, 2021. It is now read-only.

Commit 1973d23

Browse files
committed
add wait.PollImmediate() and retool wait tests
1 parent e330b11 commit 1973d23

File tree

3 files changed

+98
-20
lines changed

3 files changed

+98
-20
lines changed

pkg/kubectl/rolling_updater_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ func TestRollingUpdater_cleanupWithClients(t *testing.T) {
817817
t.Errorf("unexpected error: %v", err)
818818
}
819819
if len(fake.Actions()) != len(test.expected) {
820-
t.Fatalf("%s: unexpected actions: %v, expected %v", test.name, fake.Actions, test.expected)
820+
t.Fatalf("%s: unexpected actions: %v, expected %v", test.name, fake.Actions(), test.expected)
821821
}
822822
for j, action := range fake.Actions() {
823823
if e, a := test.expected[j], action.GetVerb(); e != a {

pkg/util/wait/wait.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,25 @@ type ConditionFunc func() (done bool, err error)
4545
// may be missed if the condition takes too long or the time window is too short.
4646
// If you want to Poll something forever, see PollInfinite.
4747
// Poll always waits the interval before the first check of the condition.
48-
// TODO: create a separate PollImmediate function that does not wait.
4948
func Poll(interval, timeout time.Duration, condition ConditionFunc) error {
50-
return WaitFor(poller(interval, timeout), condition)
49+
return pollInternal(poller(interval, timeout), condition)
50+
}
51+
func pollInternal(wait WaitFunc, condition ConditionFunc) error {
52+
return WaitFor(wait, condition)
53+
}
54+
55+
func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error {
56+
return pollImmediateInternal(poller(interval, timeout), condition)
57+
}
58+
func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error {
59+
done, err := condition()
60+
if err != nil {
61+
return err
62+
}
63+
if done {
64+
return nil
65+
}
66+
return pollInternal(wait, condition)
5167
}
5268

5369
// PollInfinite polls forever.
@@ -59,16 +75,16 @@ func PollInfinite(interval time.Duration, condition ConditionFunc) error {
5975
// should be executed and is closed when the last test should be invoked.
6076
type WaitFunc func() <-chan struct{}
6177

62-
// WaitFor gets a channel from wait(), and then invokes c once for every value
63-
// placed on the channel and once more when the channel is closed. If c
64-
// returns an error the loop ends and that error is returned, and if c returns
78+
// WaitFor gets a channel from wait(), and then invokes fn once for every value
79+
// placed on the channel and once more when the channel is closed. If fn
80+
// returns an error the loop ends and that error is returned, and if fn returns
6581
// true the loop ends and nil is returned. ErrWaitTimeout will be returned if
66-
// the channel is closed without c ever returning true.
67-
func WaitFor(wait WaitFunc, c ConditionFunc) error {
68-
w := wait()
82+
// the channel is closed without fn ever returning true.
83+
func WaitFor(wait WaitFunc, fn ConditionFunc) error {
84+
c := wait()
6985
for {
70-
_, open := <-w
71-
ok, err := c()
86+
_, open := <-c
87+
ok, err := fn()
7288
if err != nil {
7389
return err
7490
}

pkg/util/wait/wait_test.go

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package wait
1818

1919
import (
2020
"errors"
21+
"sync/atomic"
2122
"testing"
2223
"time"
2324

@@ -45,38 +46,99 @@ DRAIN:
4546
}
4647
}
4748

48-
func fakeTicker(count int) WaitFunc {
49+
func fakeTicker(max int, used *int32) WaitFunc {
4950
return func() <-chan struct{} {
5051
ch := make(chan struct{})
5152
go func() {
52-
for i := 0; i < count; i++ {
53+
for i := 0; i < max; i++ {
5354
ch <- struct{}{}
55+
if used != nil {
56+
atomic.AddInt32(used, 1)
57+
}
5458
}
5559
close(ch)
5660
}()
5761
return ch
5862
}
5963
}
6064

65+
type fakePoller struct {
66+
max int
67+
used int32 // accessed with atomics
68+
}
69+
70+
func (fp *fakePoller) GetWaitFunc(interval, timeout time.Duration) WaitFunc {
71+
return fakeTicker(fp.max, &fp.used)
72+
}
73+
6174
func TestPoll(t *testing.T) {
6275
invocations := 0
6376
f := ConditionFunc(func() (bool, error) {
6477
invocations++
6578
return true, nil
6679
})
67-
if err := Poll(time.Microsecond, time.Microsecond, f); err != nil {
80+
fp := fakePoller{max: 1}
81+
if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil {
82+
t.Fatalf("unexpected error %v", err)
83+
}
84+
if invocations != 1 {
85+
t.Errorf("Expected exactly one invocation, got %d", invocations)
86+
}
87+
used := atomic.LoadInt32(&fp.used)
88+
if used != 1 {
89+
t.Errorf("Expected exactly one tick, got %d", used)
90+
}
91+
92+
expectedError := errors.New("Expected error")
93+
f = ConditionFunc(func() (bool, error) {
94+
return false, expectedError
95+
})
96+
fp = fakePoller{max: 1}
97+
if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError {
98+
t.Fatalf("Expected error %v, got none %v", expectedError, err)
99+
}
100+
if invocations != 1 {
101+
t.Errorf("Expected exactly one invocation, got %d", invocations)
102+
}
103+
used = atomic.LoadInt32(&fp.used)
104+
if used != 1 {
105+
t.Errorf("Expected exactly one tick, got %d", used)
106+
}
107+
}
108+
109+
func TestPollImmediate(t *testing.T) {
110+
invocations := 0
111+
f := ConditionFunc(func() (bool, error) {
112+
invocations++
113+
return true, nil
114+
})
115+
fp := fakePoller{max: 0}
116+
if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil {
68117
t.Fatalf("unexpected error %v", err)
69118
}
70-
if invocations == 0 {
71-
t.Errorf("Expected at least one invocation, got zero")
119+
if invocations != 1 {
120+
t.Errorf("Expected exactly one invocation, got %d", invocations)
121+
}
122+
used := atomic.LoadInt32(&fp.used)
123+
if used != 0 {
124+
t.Errorf("Expected exactly zero ticks, got %d", used)
72125
}
126+
73127
expectedError := errors.New("Expected error")
74128
f = ConditionFunc(func() (bool, error) {
75129
return false, expectedError
76130
})
77-
if err := Poll(time.Microsecond, time.Microsecond, f); err == nil || err != expectedError {
131+
fp = fakePoller{max: 0}
132+
if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError {
78133
t.Fatalf("Expected error %v, got none %v", expectedError, err)
79134
}
135+
if invocations != 1 {
136+
t.Errorf("Expected exactly one invocation, got %d", invocations)
137+
}
138+
used = atomic.LoadInt32(&fp.used)
139+
if used != 0 {
140+
t.Errorf("Expected exactly zero ticks, got %d", used)
141+
}
80142
}
81143

82144
func TestPollForever(t *testing.T) {
@@ -154,7 +216,7 @@ func TestWaitFor(t *testing.T) {
154216
return false, nil
155217
}),
156218
2,
157-
3,
219+
3, // the contract of WaitFor() says the func is called once more at the end of the wait
158220
true,
159221
},
160222
"returns immediately on error": {
@@ -169,7 +231,7 @@ func TestWaitFor(t *testing.T) {
169231
}
170232
for k, c := range testCases {
171233
invocations = 0
172-
ticker := fakeTicker(c.Ticks)
234+
ticker := fakeTicker(c.Ticks, nil)
173235
err := WaitFor(ticker, c.F)
174236
switch {
175237
case c.Err && err == nil:
@@ -180,7 +242,7 @@ func TestWaitFor(t *testing.T) {
180242
continue
181243
}
182244
if invocations != c.Invoked {
183-
t.Errorf("%s: Expected %d invocations, called %d", k, c.Invoked, invocations)
245+
t.Errorf("%s: Expected %d invocations, got %d", k, c.Invoked, invocations)
184246
}
185247
}
186248
}

0 commit comments

Comments
 (0)