diff --git a/utils/priority_mutex.go b/utils/priority_mutex.go new file mode 100644 index 00000000..7802db78 --- /dev/null +++ b/utils/priority_mutex.go @@ -0,0 +1,86 @@ +// Copyright 2020 Coinbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "sync" +) + +// PriorityMutex is a special type of mutex +// that allows callers to request priority +// over other callers. This can be useful +// if there is a "hot path" in an application +// that requires lock access. +// +// WARNING: It is possible to cause lock starvation +// if not careful (i.e. only high priority callers +// ever do work). +type PriorityMutex struct { + high []chan struct{} + low []chan struct{} + + m sync.Mutex + l bool +} + +// Lock attempts to acquire either a high or low +// priority mutex. When priority is true, a lock +// will be granted before other low priority callers. +func (m *PriorityMutex) Lock(priority bool) { + m.m.Lock() + + if !m.l { + m.l = true + m.m.Unlock() + return + } + + c := make(chan struct{}) + if priority { + m.high = append(m.high, c) + } else { + m.low = append(m.low, c) + } + + m.m.Unlock() + <-c +} + +// Unlock selects the next highest priority lock +// to grant. If there are no locks to grant, it +// sets the value of m.l to false. +func (m *PriorityMutex) Unlock() { + m.m.Lock() + defer m.m.Unlock() + + if len(m.high) > 0 { + c := m.high[0] + m.high = m.high[1:] + close(c) + return + } + + if len(m.low) > 0 { + c := m.low[0] + m.low = m.low[1:] + close(c) + return + } + + // We only set m.l to false when there are + // no items to unlock because it could cause + // lock contention for the next lock to fetch it. + m.l = false +} diff --git a/utils/priority_mutex_test.go b/utils/priority_mutex_test.go new file mode 100644 index 00000000..2a1c2fca --- /dev/null +++ b/utils/priority_mutex_test.go @@ -0,0 +1,73 @@ +// Copyright 2020 Coinbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" +) + +func TestPriorityMutex(t *testing.T) { + arr := []bool{} + expected := make([]bool, 60) + l := new(PriorityMutex) + g, _ := errgroup.WithContext(context.Background()) + + // Lock while adding all locks + l.Lock(true) + + // Add a bunch of low prio items + for i := 0; i < 50; i++ { + expected[i+10] = false + g.Go(func() error { + l.Lock(false) + arr = append(arr, false) + l.Unlock() + return nil + }) + } + + // Add a few high prio items + for i := 0; i < 10; i++ { + expected[i] = true + g.Go(func() error { + l.Lock(true) + arr = append(arr, true) + l.Unlock() + return nil + }) + } + + // Wait for all goroutines to ask for lock + time.Sleep(1 * time.Second) + + // Ensure number of expected locks is correct + assert.Len(t, l.high, 10) + assert.Len(t, l.low, 50) + + l.Unlock() + assert.NoError(t, g.Wait()) + + // Check results array to ensure all of the high priority items processed first, + // followed by all of the low priority items. + assert.Equal(t, expected, arr) + + // Ensure lock is no longer occupied + assert.False(t, l.l) +}