Skip to content

Commit

Permalink
fix requests tracker concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
atiratree committed May 13, 2024
1 parent 2956e29 commit 1c8b0f6
Showing 1 changed file with 55 additions and 56 deletions.
111 changes: 55 additions & 56 deletions pkg/controller/statefulset/stateful_set_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2423,51 +2423,53 @@ type requestTracker struct {
err error
after int

parallelLock sync.Mutex
parallel int
maxParallel int

delay time.Duration
// this block should be updated consistently
parallelLock sync.Mutex
shouldTrackParallelRequests bool
parallelRequests int
maxParallelRequests int
parallelRequestDelay time.Duration
}

func (rt *requestTracker) errorReady() bool {
rt.Lock()
defer rt.Unlock()
return rt.err != nil && rt.requests >= rt.after
}

func (rt *requestTracker) inc() {
rt.parallelLock.Lock()
rt.parallel++
if rt.maxParallel < rt.parallel {
rt.maxParallel = rt.parallel
func (rt *requestTracker) trackParallelRequests() {
if !rt.shouldTrackParallelRequests {
// do not track parallel requests unless specifically enabled
return
}
rt.parallelLock.Unlock()

rt.Lock()
defer rt.Unlock()
rt.requests++
if rt.delay != 0 {
time.Sleep(rt.delay)
if rt.parallelLock.TryLock() {
// lock acquired: we are the only or the first concurrent request
// initialize the next set of parallel requests
rt.parallelRequests = 1
} else {
// lock is held by other requests
// now wait for the lock to increase the parallelRequests
rt.parallelLock.Lock()
rt.parallelRequests++
}
defer rt.parallelLock.Unlock()
// update the local maximum of parallel collisions
if rt.maxParallelRequests < rt.parallelRequests {
rt.maxParallelRequests = rt.parallelRequests
}
// increase the chance of collisions
if rt.parallelRequestDelay > 0 {
time.Sleep(rt.parallelRequestDelay)
}
}

func (rt *requestTracker) reset() {
rt.parallelLock.Lock()
rt.parallel = 0
rt.parallelLock.Unlock()

rt.Lock()
defer rt.Unlock()
rt.err = nil
rt.after = 0
rt.delay = 0
}

func (rt *requestTracker) getErr() error {
func (rt *requestTracker) incWithOptionalError() error {
rt.Lock()
defer rt.Unlock()
return rt.err
rt.requests++
if rt.err != nil && rt.requests >= rt.after {
// reset and pass the error
defer func() {
rt.err = nil
rt.after = 0
}()
return rt.err
}
return nil
}

func newRequestTracker(requests int, err error, after int) requestTracker {
Expand Down Expand Up @@ -2512,10 +2514,9 @@ func newFakeObjectManager(informerFactory informers.SharedInformerFactory) *fake
}

func (om *fakeObjectManager) CreatePod(ctx context.Context, pod *v1.Pod) error {
defer om.createPodTracker.inc()
if om.createPodTracker.errorReady() {
defer om.createPodTracker.reset()
return om.createPodTracker.getErr()
defer om.createPodTracker.trackParallelRequests()
if err := om.createPodTracker.incWithOptionalError(); err != nil {
return err
}
pod.SetUID(types.UID(pod.Name + "-uid"))
return om.podsIndexer.Update(pod)
Expand All @@ -2526,19 +2527,17 @@ func (om *fakeObjectManager) GetPod(namespace, podName string) (*v1.Pod, error)
}

func (om *fakeObjectManager) UpdatePod(pod *v1.Pod) error {
defer om.updatePodTracker.inc()
if om.updatePodTracker.errorReady() {
defer om.updatePodTracker.reset()
return om.updatePodTracker.getErr()
defer om.updatePodTracker.trackParallelRequests()
if err := om.updatePodTracker.incWithOptionalError(); err != nil {
return err
}
return om.podsIndexer.Update(pod)
}

func (om *fakeObjectManager) DeletePod(pod *v1.Pod) error {
defer om.deletePodTracker.inc()
if om.deletePodTracker.errorReady() {
defer om.deletePodTracker.reset()
return om.deletePodTracker.getErr()
defer om.deletePodTracker.trackParallelRequests()
if err := om.deletePodTracker.incWithOptionalError(); err != nil {
return err
}
if key, err := controller.KeyFunc(pod); err != nil {
return err
Expand Down Expand Up @@ -2733,10 +2732,9 @@ func newFakeStatefulSetStatusUpdater(setInformer appsinformers.StatefulSetInform
}

func (ssu *fakeStatefulSetStatusUpdater) UpdateStatefulSetStatus(ctx context.Context, set *apps.StatefulSet, status *apps.StatefulSetStatus) error {
defer ssu.updateStatusTracker.inc()
if ssu.updateStatusTracker.errorReady() {
defer ssu.updateStatusTracker.reset()
return ssu.updateStatusTracker.err
defer ssu.updateStatusTracker.trackParallelRequests()
if err := ssu.updateStatusTracker.incWithOptionalError(); err != nil {
return err
}
set.Status = *status
ssu.setsIndexer.Update(set)
Expand Down Expand Up @@ -2985,7 +2983,8 @@ func parallelScale(t *testing.T, set *apps.StatefulSet, replicas, desiredReplica
diff := desiredReplicas - replicas
client := fake.NewSimpleClientset(set)
om, _, ssc := setupController(client)
om.createPodTracker.delay = time.Millisecond
om.createPodTracker.shouldTrackParallelRequests = true
om.createPodTracker.parallelRequestDelay = time.Millisecond

*set.Spec.Replicas = replicas
if err := parallelScaleUpStatefulSetControl(set, ssc, om, invariants); err != nil {
Expand Down Expand Up @@ -3017,8 +3016,8 @@ func parallelScale(t *testing.T, set *apps.StatefulSet, replicas, desiredReplica
t.Errorf("Failed to scale statefulset to %v replicas, got %v replicas", desiredReplicas, set.Status.Replicas)
}

if (diff < -1 || diff > 1) && om.createPodTracker.maxParallel <= 1 {
t.Errorf("want max parallel requests > 1, got %v", om.createPodTracker.maxParallel)
if (diff < -1 || diff > 1) && om.createPodTracker.maxParallelRequests <= 1 {
t.Errorf("want max parallel requests > 1, got %v", om.createPodTracker.maxParallelRequests)
}
}

Expand Down

0 comments on commit 1c8b0f6

Please sign in to comment.