@@ -18,6 +18,7 @@ package wait
1818
1919import (
2020 "errors"
21+ "sync/atomic"
2122 "testing"
2223 "time"
2324
@@ -45,38 +46,99 @@ DRAIN:
4546 }
4647}
4748
48- func fakeTicker (count int ) WaitFunc {
49+ func fakeTicker (max int , used * int32 ) WaitFunc {
4950 return func () <- chan struct {} {
5051 ch := make (chan struct {})
5152 go func () {
52- for i := 0 ; i < count ; i ++ {
53+ for i := 0 ; i < max ; i ++ {
5354 ch <- struct {}{}
55+ if used != nil {
56+ atomic .AddInt32 (used , 1 )
57+ }
5458 }
5559 close (ch )
5660 }()
5761 return ch
5862 }
5963}
6064
65+ type fakePoller struct {
66+ max int
67+ used int32 // accessed with atomics
68+ }
69+
70+ func (fp * fakePoller ) GetWaitFunc (interval , timeout time.Duration ) WaitFunc {
71+ return fakeTicker (fp .max , & fp .used )
72+ }
73+
6174func TestPoll (t * testing.T ) {
6275 invocations := 0
6376 f := ConditionFunc (func () (bool , error ) {
6477 invocations ++
6578 return true , nil
6679 })
67- if err := Poll (time .Microsecond , time .Microsecond , f ); err != nil {
80+ fp := fakePoller {max : 1 }
81+ if err := pollInternal (fp .GetWaitFunc (time .Microsecond , time .Microsecond ), f ); err != nil {
82+ t .Fatalf ("unexpected error %v" , err )
83+ }
84+ if invocations != 1 {
85+ t .Errorf ("Expected exactly one invocation, got %d" , invocations )
86+ }
87+ used := atomic .LoadInt32 (& fp .used )
88+ if used != 1 {
89+ t .Errorf ("Expected exactly one tick, got %d" , used )
90+ }
91+
92+ expectedError := errors .New ("Expected error" )
93+ f = ConditionFunc (func () (bool , error ) {
94+ return false , expectedError
95+ })
96+ fp = fakePoller {max : 1 }
97+ if err := pollInternal (fp .GetWaitFunc (time .Microsecond , time .Microsecond ), f ); err == nil || err != expectedError {
98+ t .Fatalf ("Expected error %v, got none %v" , expectedError , err )
99+ }
100+ if invocations != 1 {
101+ t .Errorf ("Expected exactly one invocation, got %d" , invocations )
102+ }
103+ used = atomic .LoadInt32 (& fp .used )
104+ if used != 1 {
105+ t .Errorf ("Expected exactly one tick, got %d" , used )
106+ }
107+ }
108+
109+ func TestPollImmediate (t * testing.T ) {
110+ invocations := 0
111+ f := ConditionFunc (func () (bool , error ) {
112+ invocations ++
113+ return true , nil
114+ })
115+ fp := fakePoller {max : 0 }
116+ if err := pollImmediateInternal (fp .GetWaitFunc (time .Microsecond , time .Microsecond ), f ); err != nil {
68117 t .Fatalf ("unexpected error %v" , err )
69118 }
70- if invocations == 0 {
71- t .Errorf ("Expected at least one invocation, got zero" )
119+ if invocations != 1 {
120+ t .Errorf ("Expected exactly one invocation, got %d" , invocations )
121+ }
122+ used := atomic .LoadInt32 (& fp .used )
123+ if used != 0 {
124+ t .Errorf ("Expected exactly zero ticks, got %d" , used )
72125 }
126+
73127 expectedError := errors .New ("Expected error" )
74128 f = ConditionFunc (func () (bool , error ) {
75129 return false , expectedError
76130 })
77- if err := Poll (time .Microsecond , time .Microsecond , f ); err == nil || err != expectedError {
131+ fp = fakePoller {max : 0 }
132+ if err := pollImmediateInternal (fp .GetWaitFunc (time .Microsecond , time .Microsecond ), f ); err == nil || err != expectedError {
78133 t .Fatalf ("Expected error %v, got none %v" , expectedError , err )
79134 }
135+ if invocations != 1 {
136+ t .Errorf ("Expected exactly one invocation, got %d" , invocations )
137+ }
138+ used = atomic .LoadInt32 (& fp .used )
139+ if used != 0 {
140+ t .Errorf ("Expected exactly zero ticks, got %d" , used )
141+ }
80142}
81143
82144func TestPollForever (t * testing.T ) {
@@ -154,7 +216,7 @@ func TestWaitFor(t *testing.T) {
154216 return false , nil
155217 }),
156218 2 ,
157- 3 ,
219+ 3 , // the contract of WaitFor() says the func is called once more at the end of the wait
158220 true ,
159221 },
160222 "returns immediately on error" : {
@@ -169,7 +231,7 @@ func TestWaitFor(t *testing.T) {
169231 }
170232 for k , c := range testCases {
171233 invocations = 0
172- ticker := fakeTicker (c .Ticks )
234+ ticker := fakeTicker (c .Ticks , nil )
173235 err := WaitFor (ticker , c .F )
174236 switch {
175237 case c .Err && err == nil :
@@ -180,7 +242,7 @@ func TestWaitFor(t *testing.T) {
180242 continue
181243 }
182244 if invocations != c .Invoked {
183- t .Errorf ("%s: Expected %d invocations, called %d" , k , c .Invoked , invocations )
245+ t .Errorf ("%s: Expected %d invocations, got %d" , k , c .Invoked , invocations )
184246 }
185247 }
186248}
0 commit comments