diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 8a89f5db..9de654cc 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,6 +1,7 @@ # 开发中 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) - [queue: API 定义](https://github.com/gotomicro/ekit/pull/109) +- [queue: 基于堆和切片的优先级队列](https://github.com/gotomicro/ekit/pull/110) # v0.0.4 - [slice: 重构 index 和 contains 的方法,直接调用对应Func 版本](https://github.com/gotomicro/ekit/pull/87) diff --git a/internal/queue/priority_queue.go b/internal/queue/priority_queue.go new file mode 100644 index 00000000..28a6be72 --- /dev/null +++ b/internal/queue/priority_queue.go @@ -0,0 +1,136 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "errors" + + "github.com/gotomicro/ekit/internal/slice" + + "github.com/gotomicro/ekit" +) + +var ( + ErrOutOfCapacity = errors.New("ekit: 超出最大容量限制") + ErrEmptyQueue = errors.New("ekit: 队列为空") +) + +// PriorityQueue 是一个基于小顶堆的优先队列 +// 当capacity <= 0时,为无界队列,切片容量会动态扩缩容 +// 当capacity > 0 时,为有界队列,初始化后就固定容量,不会扩缩容 +type PriorityQueue[T any] struct { + // 用于比较前一个元素是否小于后一个元素 + compare ekit.Comparator[T] + // 队列容量 + capacity int + // 队列中的元素,为便于计算父子节点的index,0位置留空,根节点从1开始 + data []T +} + +func (p *PriorityQueue[T]) Len() int { + return len(p.data) - 1 +} + +// Cap 无界队列返回0,有界队列返回创建队列时设置的值 +func (p *PriorityQueue[T]) Cap() int { + return p.capacity +} + +func (p *PriorityQueue[T]) IsBoundless() bool { + return p.capacity <= 0 +} + +func (p *PriorityQueue[T]) isFull() bool { + return p.capacity > 0 && len(p.data)-1 == p.capacity +} + +func (p *PriorityQueue[T]) isEmpty() bool { + return len(p.data) < 2 +} + +func (p *PriorityQueue[T]) Peek() (T, error) { + if p.isEmpty() { + var t T + return t, ErrEmptyQueue + } + return p.data[1], nil +} + +func (p *PriorityQueue[T]) Enqueue(t T) error { + if p.isFull() { + return ErrOutOfCapacity + } + + p.data = append(p.data, t) + node, parent := len(p.data)-1, (len(p.data)-1)/2 + for parent > 0 && p.compare(p.data[node], p.data[parent]) < 0 { + p.data[parent], p.data[node] = p.data[node], p.data[parent] + node = parent + parent = parent / 2 + } + + return nil +} + +func (p *PriorityQueue[T]) Dequeue() (T, error) { + if p.isEmpty() { + var t T + return t, ErrEmptyQueue + } + + pop := p.data[1] + p.data[1] = p.data[len(p.data)-1] + p.data = p.data[:len(p.data)-1] + p.shrinkIfNecessary() + p.heapify(p.data, len(p.data)-1, 1) + return pop, nil +} + +func (p *PriorityQueue[T]) shrinkIfNecessary() { + if p.IsBoundless() { + p.data = slice.Shrink[T](p.data) + } +} + +func (p *PriorityQueue[T]) heapify(data []T, n, i int) { + minPos := i + for { + if left := i * 2; left <= n && p.compare(data[left], data[minPos]) < 0 { + minPos = left + } + if right := i*2 + 1; right <= n && p.compare(data[right], data[minPos]) < 0 { + minPos = right + } + if minPos == i { + break + } + data[i], data[minPos] = data[minPos], data[i] + i = minPos + } +} + +// NewPriorityQueue 创建优先队列 capacity <= 0 时,为无界队列,否则有有界队列 +func NewPriorityQueue[T any](capacity int, compare ekit.Comparator[T]) *PriorityQueue[T] { + sliceCap := capacity + 1 + if capacity < 1 { + capacity = 0 + sliceCap = 64 + } + return &PriorityQueue[T]{ + capacity: capacity, + data: make([]T, 1, sliceCap), + compare: compare, + } +} diff --git a/internal/queue/priority_queue_test.go b/internal/queue/priority_queue_test.go new file mode 100644 index 00000000..40fb2240 --- /dev/null +++ b/internal/queue/priority_queue_test.go @@ -0,0 +1,496 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gotomicro/ekit" + + "github.com/stretchr/testify/assert" +) + +func TestNewPriorityQueue(t *testing.T) { + data := []int{6, 5, 4, 3, 2, 1} + testCases := []struct { + name string + q *PriorityQueue[int] + capacity int + data []int + expected []int + }{ + { + name: "无边界", + q: NewPriorityQueue(0, compare()), + capacity: 0, + data: data, + expected: []int{1, 2, 3, 4, 5, 6}, + }, + { + name: "有边界 ", + q: NewPriorityQueue(len(data), compare()), + capacity: len(data), + data: data, + expected: []int{1, 2, 3, 4, 5, 6}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, 0, tc.q.Len()) + for _, d := range data { + err := tc.q.Enqueue(d) + assert.NoError(t, err) + if err != nil { + return + } + } + assert.Equal(t, tc.capacity, tc.q.Cap()) + assert.Equal(t, len(data), tc.q.Len()) + res := make([]int, 0, len(data)) + for tc.q.Len() > 0 { + el, err := tc.q.Dequeue() + assert.NoError(t, err) + if err != nil { + return + } + res = append(res, el) + } + assert.Equal(t, tc.expected, res) + }) + + } + +} + +func TestPriorityQueue_Peek(t *testing.T) { + testCases := []struct { + name string + capacity int + data []int + wantErr error + }{ + { + name: "有数据", + capacity: 0, + data: []int{6, 5, 4, 3, 2, 1}, + wantErr: ErrEmptyQueue, + }, + { + name: "无数据", + capacity: 0, + data: []int{}, + wantErr: ErrEmptyQueue, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewPriorityQueue[int](tc.capacity, compare()) + for _, el := range tc.data { + err := q.Enqueue(el) + require.NoError(t, err) + } + for q.Len() > 0 { + peek, err := q.Peek() + assert.NoError(t, err) + el, _ := q.Dequeue() + assert.Equal(t, el, peek) + } + _, err := q.Peek() + assert.Equal(t, tc.wantErr, err) + }) + + } +} + +func TestPriorityQueue_Enqueue(t *testing.T) { + testCases := []struct { + name string + capacity int + data []int + element int + wantErr error + }{ + { + name: "有界空队列", + capacity: 10, + data: []int{}, + element: 10, + }, + { + name: "有界满队列", + capacity: 6, + data: []int{6, 5, 4, 3, 2, 1}, + element: 10, + wantErr: ErrOutOfCapacity, + }, + { + name: "有界非空不满队列", + capacity: 12, + data: []int{6, 5, 4, 3, 2, 1}, + element: 10, + }, + { + name: "无界空队列", + capacity: 0, + data: []int{}, + element: 10, + }, + { + name: "无界非空队列", + capacity: 0, + data: []int{6, 5, 4, 3, 2, 1}, + element: 10, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := priorityQueueOf(tc.capacity, tc.data, compare()) + require.NotNil(t, q) + err := q.Enqueue(tc.element) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.capacity, q.Cap()) + }) + + } +} + +func TestPriorityQueue_EnqueueElement(t *testing.T) { + testCases := []struct { + name string + data []int + element int + wantSlice []int + }{ + { + name: "新加入的元素是最大的", + data: []int{10, 8, 7, 6, 2}, + element: 20, + wantSlice: []int{0, 2, 6, 8, 10, 7, 20}, + }, + { + name: "新加入的元素是最小的", + data: []int{10, 8, 7, 6, 2}, + element: 1, + wantSlice: []int{0, 1, 6, 2, 10, 7, 8}, + }, + { + name: "新加入的元素子区间中", + data: []int{10, 8, 7, 6, 2}, + element: 5, + wantSlice: []int{0, 2, 6, 5, 10, 7, 8}, + }, + { + name: "新加入的元素与已有元素相同", + data: []int{10, 8, 7, 6, 2}, + element: 6, + wantSlice: []int{0, 2, 6, 6, 10, 7, 8}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := priorityQueueOf(0, tc.data, compare()) + require.NotNil(t, q) + err := q.Enqueue(tc.element) + require.NoError(t, err) + assert.Equal(t, tc.wantSlice, q.data) + }) + + } +} + +func TestPriorityQueue_EnqueueHeapStruct(t *testing.T) { + data := []int{6, 5, 4, 3, 2, 1} + testCases := []struct { + name string + capacity int + data []int + wantSlice []int + pivot int + pivotData []int + }{ + { + name: "队列满", + capacity: len(data), + data: data, + pivot: 2, + pivotData: []int{0, 4, 6, 5}, + wantSlice: []int{0, 1, 3, 2, 6, 4, 5}, + }, + { + name: "队列不满", + capacity: len(data) * 2, + data: data, + pivot: 3, + pivotData: []int{0, 3, 4, 5, 6}, + wantSlice: []int{0, 1, 3, 2, 6, 4, 5}, + }, + { + name: "无界队列", + capacity: 0, + data: data, + pivot: 3, + pivotData: []int{0, 3, 4, 5, 6}, + wantSlice: []int{0, 1, 3, 2, 6, 4, 5}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewPriorityQueue[int](tc.capacity, compare()) + for i, el := range tc.data { + require.NoError(t, q.Enqueue(el)) + // 检查中途堆结构堆调整,是否符合预期 + if i == tc.pivot { + assert.Equal(t, tc.pivotData, q.data) + } + } + // 检查最终堆结构,是否符合预期 + assert.Equal(t, tc.wantSlice, q.data) + }) + + } +} + +func TestPriorityQueue_Dequeue(t *testing.T) { + testCases := []struct { + name string + data []int + wantErr error + wantVal int + wantSlice []int + }{ + { + name: "空队列", + data: []int{}, + wantErr: ErrEmptyQueue, + }, + { + name: "只有一个元素", + data: []int{10}, + wantVal: 10, + wantSlice: []int{0}, + }, + { + name: "many", + data: []int{6, 5, 4, 3, 2, 1}, + wantVal: 1, + wantSlice: []int{0, 2, 3, 5, 6, 4}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := priorityQueueOf(0, tc.data, compare()) + require.NotNil(t, q) + val, err := q.Dequeue() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantSlice, q.data) + assert.Equal(t, tc.wantVal, val) + }) + } +} + +func TestPriorityQueue_DequeueComplexCheck(t *testing.T) { + testCases := []struct { + name string + capacity int + data []int + pivot int + want []int + }{ + { + name: "无边界", + capacity: 0, + data: []int{6, 5, 4, 3, 2, 1}, + pivot: 2, + want: []int{0, 4, 6, 5}, + }, + { + name: "有边界", + capacity: 6, + data: []int{6, 5, 4, 3, 2, 1}, + pivot: 3, + want: []int{0, 5, 6}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := priorityQueueOf(tc.capacity, tc.data, compare()) + require.NotNil(t, q) + i := 0 + prev := -1 + for q.Len() > 0 { + el, err := q.Dequeue() + require.NoError(t, err) + // 检查中途出队后,堆结构堆调整是否符合预期 + if i == tc.pivot { + assert.Equal(t, tc.want, q.data) + } + // 检查出队是否有序 + assert.LessOrEqual(t, prev, el) + prev = el + i++ + } + }) + + } +} + +func TestPriorityQueue_Shrink(t *testing.T) { + var compare ekit.Comparator[int] = func(a, b int) int { + if a < b { + return -1 + } + if a == b { + return 0 + } + return 1 + } + testCases := []struct { + name string + originCap int + enqueueLoop int + dequeueLoop int + expectCap int + sliceCap int + }{ + { + name: "有界,小于64", + originCap: 32, + enqueueLoop: 6, + dequeueLoop: 5, + expectCap: 32, + sliceCap: 33, + }, + { + name: "有界,小于2048, 不足1/4", + originCap: 1000, + enqueueLoop: 20, + dequeueLoop: 5, + expectCap: 1000, + sliceCap: 1001, + }, + { + name: "有界,小于2048, 超过1/4", + originCap: 1000, + enqueueLoop: 400, + dequeueLoop: 5, + expectCap: 1000, + sliceCap: 1001, + }, + { + name: "有界,大于2048,不足一半", + originCap: 3000, + enqueueLoop: 400, + dequeueLoop: 40, + expectCap: 3000, + sliceCap: 3001, + }, + { + name: "有界,大于2048,大于一半", + originCap: 3000, + enqueueLoop: 2000, + dequeueLoop: 5, + expectCap: 3000, + sliceCap: 3001, + }, + { + name: "无界,小于64", + originCap: 0, + enqueueLoop: 30, + dequeueLoop: 5, + expectCap: 0, + sliceCap: 64, + }, + { + name: "无界,小于2048, 不足1/4", + originCap: 0, + enqueueLoop: 2000, + dequeueLoop: 1990, + expectCap: 0, + sliceCap: 50, + }, + { + name: "无界,小于2048, 超过1/4", + originCap: 0, + enqueueLoop: 2000, + dequeueLoop: 600, + expectCap: 0, + sliceCap: 2560, + }, + { + name: "无界,大于2048,不足一半", + originCap: 0, + enqueueLoop: 3000, + dequeueLoop: 2000, + expectCap: 0, + sliceCap: 1331, + }, + { + name: "无界,大于2048,大于一半", + originCap: 0, + enqueueLoop: 3000, + dequeueLoop: 5, + expectCap: 0, + sliceCap: 3408, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewPriorityQueue[int](tc.originCap, compare) + for i := 0; i < tc.enqueueLoop; i++ { + err := q.Enqueue(i) + if err != nil { + return + } + } + for i := 0; i < tc.dequeueLoop; i++ { + _, err := q.Dequeue() + if err != nil { + return + } + } + assert.Equal(t, tc.expectCap, q.Cap()) + assert.Equal(t, tc.sliceCap, cap(q.data)) + }) + } +} + +func priorityQueueOf(capacity int, data []int, compare ekit.Comparator[int]) *PriorityQueue[int] { + q := NewPriorityQueue[int](capacity, compare) + for _, el := range data { + err := q.Enqueue(el) + if err != nil { + return nil + } + } + return q +} + +func compare() ekit.Comparator[int] { + return func(a, b int) int { + if a < b { + return -1 + } + if a == b { + return 0 + } + return 1 + } +} diff --git a/internal/slice/shink_test.go b/internal/slice/shink_test.go new file mode 100644 index 00000000..366fd207 --- /dev/null +++ b/internal/slice/shink_test.go @@ -0,0 +1,72 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestShrink(t *testing.T) { + testCases := []struct { + name string + originCap int + enqueueLoop int + expectCap int + }{ + { + name: "小于64", + originCap: 32, + enqueueLoop: 6, + expectCap: 32, + }, + { + name: "小于2048, 不足1/4", + originCap: 1000, + enqueueLoop: 20, + expectCap: 500, + }, + { + name: "小于2048, 超过1/4", + originCap: 1000, + enqueueLoop: 400, + expectCap: 1000, + }, + { + name: "大于2048,不足一半", + originCap: 3000, + enqueueLoop: 60, + expectCap: 1875, + }, + { + name: "大于2048,大于一半", + originCap: 3000, + enqueueLoop: 2000, + expectCap: 3000, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + l := make([]int, 0, tc.originCap) + + for i := 0; i < tc.enqueueLoop; i++ { + l = append(l, i) + } + l = Shrink[int](l) + assert.Equal(t, tc.expectCap, cap(l)) + }) + } +} diff --git a/internal/slice/shrink.go b/internal/slice/shrink.go new file mode 100644 index 00000000..911adf3a --- /dev/null +++ b/internal/slice/shrink.go @@ -0,0 +1,40 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package slice + +func calCapacity(c, l int) (int, bool) { + if c <= 64 { + return c, false + } + if c > 2048 && (c/l >= 2) { + factor := 0.625 + return int(float32(c) * float32(factor)), true + } + if c <= 2048 && (c/l >= 4) { + return c / 2, true + } + return c, false +} + +func Shrink[T any](src []T) []T { + c, l := cap(src), len(src) + n, changed := calCapacity(c, l) + if !changed { + return src + } + s := make([]T, 0, n) + s = append(s, src...) + return s +} diff --git a/list/array_list.go b/list/array_list.go index 90c9d084..bf03ce02 100644 --- a/list/array_list.go +++ b/list/array_list.go @@ -92,22 +92,7 @@ func (a *ArrayList[T]) Delete(index int) (T, error) { // shrink 数组缩容 func (a *ArrayList[T]) shrink() { - var newCap int - c, l := a.Cap(), a.Len() - if c <= 64 { - return - } - if c > 2048 && (c/l >= 2) { - newCap = int(float32(c) * float32(0.625)) - } else if c <= 2048 && (c/l >= 4) { - newCap = c / 2 - } else { - // 不满足缩容 - return - } - newSlice := make([]T, 0, newCap) - newSlice = append(newSlice, a.vals...) - a.vals = newSlice + a.vals = slice.Shrink(a.vals) } func (a *ArrayList[T]) Len() int { diff --git a/queue/concurrent_priority_queue.go b/queue/concurrent_priority_queue.go new file mode 100644 index 00000000..ae722691 --- /dev/null +++ b/queue/concurrent_priority_queue.go @@ -0,0 +1,65 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "sync" + + "github.com/gotomicro/ekit" + "github.com/gotomicro/ekit/internal/queue" +) + +type ConcurrentPriorityQueue[T any] struct { + pq queue.PriorityQueue[T] + m sync.RWMutex +} + +func (c *ConcurrentPriorityQueue[T]) Len() int { + c.m.RLock() + defer c.m.RUnlock() + return c.pq.Len() +} + +// Cap 无界队列返回0,有界队列返回创建队列时设置的值 +func (c *ConcurrentPriorityQueue[T]) Cap() int { + c.m.RLock() + defer c.m.RUnlock() + return c.pq.Cap() +} + +func (c *ConcurrentPriorityQueue[T]) Peek() (T, error) { + c.m.RLock() + defer c.m.RUnlock() + return c.pq.Peek() +} + +func (c *ConcurrentPriorityQueue[T]) Enqueue(t T) error { + c.m.Lock() + defer c.m.Unlock() + return c.pq.Enqueue(t) +} + +func (c *ConcurrentPriorityQueue[T]) Dequeue() (T, error) { + c.m.Lock() + defer c.m.Unlock() + return c.pq.Dequeue() +} + +// NewConcurrentPriorityQueue 创建优先队列 capacity <= 0 时,为无界队列 +func NewConcurrentPriorityQueue[T any](capacity int, compare ekit.Comparator[T]) *ConcurrentPriorityQueue[T] { + return &ConcurrentPriorityQueue[T]{ + pq: *queue.NewPriorityQueue[T](capacity, compare), + } +} diff --git a/queue/concurrent_priority_queue_test.go b/queue/concurrent_priority_queue_test.go new file mode 100644 index 00000000..947c61a2 --- /dev/null +++ b/queue/concurrent_priority_queue_test.go @@ -0,0 +1,305 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "sync" + "testing" + + "github.com/gotomicro/ekit" + "github.com/gotomicro/ekit/internal/queue" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + errOutOfCapacity = queue.ErrOutOfCapacity + errEmptyQueue = queue.ErrEmptyQueue +) + +func TestNewConcurrentPriorityQueue(t *testing.T) { + testCases := []struct { + name string + q *ConcurrentPriorityQueue[int] + capacity int + data []int + expect []int + }{ + { + name: "无边界", + q: NewConcurrentPriorityQueue(0, compare()), + capacity: 0, + data: []int{6, 5, 4, 3, 2, 1}, + expect: []int{1, 2, 3, 4, 5, 6}, + }, + { + name: "有边界 ", + q: NewConcurrentPriorityQueue(6, compare()), + capacity: 6, + data: []int{6, 5, 4, 3, 2, 1}, + expect: []int{1, 2, 3, 4, 5, 6}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, 0, tc.q.Len()) + for _, d := range tc.data { + require.NoError(t, tc.q.Enqueue(d)) + } + assert.Equal(t, tc.capacity, tc.q.Cap()) + assert.Equal(t, len(tc.data), tc.q.Len()) + res := make([]int, 0, len(tc.data)) + for tc.q.Len() > 0 { + head, err := tc.q.Peek() + require.NoError(t, err) + el, err := tc.q.Dequeue() + require.NoError(t, err) + assert.Equal(t, head, el) + res = append(res, el) + } + assert.Equal(t, tc.expect, res) + }) + + } + +} + +// 多个go routine 执行入队操作,完成后,主携程把元素逐一出队,只要有序,可以认为并发入队没问题 +func TestConcurrentPriorityQueue_Enqueue(t *testing.T) { + testCases := []struct { + name string + capacity int + concurrency int + perRoutine int + wantSlice []int + remain int + wantErr error + errCount int + }{ + { + name: "不超过capacity", + capacity: 1100, + concurrency: 100, + perRoutine: 10, + }, + { + name: "超过capacity", + capacity: 1000, + concurrency: 101, + perRoutine: 10, + wantErr: errOutOfCapacity, + errCount: 10, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewConcurrentPriorityQueue[int](tc.capacity, compare()) + wg := sync.WaitGroup{} + wg.Add(tc.concurrency) + errChan := make(chan error, tc.capacity) + for i := tc.concurrency; i > 0; i-- { + go func(i int) { + start := i * 10 + for j := 0; j < tc.perRoutine; j++ { + err := q.Enqueue(start + j) + if err != nil { + errChan <- err + } + } + wg.Done() + }(i) + } + wg.Wait() + assert.Equal(t, tc.errCount, len(errChan)) + prev := -1 + for q.Len() > 0 { + el, _ := q.Dequeue() + assert.Less(t, prev, el) + + // 入队元素总数小于capacity时,应该所有元素都入队了,出队顺序应该依次加1 + if prev > -1 && len(errChan) == 0 { + assert.Equal(t, prev+1, el) + } + prev = el + } + }) + + } +} + +// 预先入队一组数据,通过测试多个协程并发出队时,每个协程内出队元素有序,间接确认并发安全 +func TestConcurrentPriorityQueue_Dequeue(t *testing.T) { + testCases := []struct { + name string + total int + concurrency int + perRoutine int + wantSlice []int + remain int + wantErr error + errCount int + }{ + { + name: "入队大于出队", + total: 910, + concurrency: 100, + perRoutine: 9, + remain: 10, + }, + { + name: "入队小于出队", + total: 900, + concurrency: 101, + perRoutine: 9, + remain: 0, + wantErr: errEmptyQueue, + errCount: 9, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + q := NewConcurrentPriorityQueue[int](tc.total, compare()) + for i := tc.total; i > 0; i-- { + require.NoError(t, q.Enqueue(i)) + } + + resultChan := make(chan int, tc.concurrency*tc.perRoutine) + disOrderChan := make(chan bool, tc.concurrency*tc.perRoutine) + errChan := make(chan error, tc.errCount) + wg := sync.WaitGroup{} + wg.Add(tc.concurrency) + + for i := 0; i < tc.concurrency; i++ { + go func() { + prev := -1 + for i := 0; i < tc.perRoutine; i++ { + el, err := q.Dequeue() + if err != nil { + // 如果出队报错,把错误放到error通道,以便后续检查错误的内容和数量是否符合预期 + errChan <- err + } else { + // 如果出队不报错,则检查出队结果是否符合预期 + resultChan <- el + if prev >= el { + disOrderChan <- false + } + prev = el + } + + } + wg.Done() + }() + } + wg.Wait() + close(resultChan) + close(errChan) + close(disOrderChan) + + // 检查并发出队的元素数量,是否符合预期 + assert.Equal(t, tc.remain, q.Len()) + + // 检查所有协程中的执行错误,是否符合预期 + assert.Equal(t, tc.errCount, len(errChan)) + for err := range errChan { + assert.Equal(t, tc.wantErr, err) + } + + // 每个协程内部,出队元素应该有序,检查是否发现无序的情况 + assert.Equal(t, 0, len(disOrderChan)) + + // 每个协程的每次出队操作,出队元素都应该不同,检查是否符合预期 + resultSet := make(map[int]bool) + for el := range resultChan { + _, ok := resultSet[el] + assert.Equal(t, false, ok) + resultSet[el] = true + } + + }) + + } +} + +// 测试同时并发出入队。只要并发安全,并发出入队后的剩余元素数量+报错数量应该符合预期 +// TODO 有待设计更好的并发出入队测试方案 +func TestConcurrentPriorityQueue_EnqueueDequeue(t *testing.T) { + testCases := []struct { + name string + enqueue int + dequeue int + remain int + }{ + { + name: "出队等于入队", + enqueue: 50, + dequeue: 50, + remain: 0, + }, + { + name: "出队小于入队", + enqueue: 50, + dequeue: 40, + remain: 10, + }, + { + name: "出队大于入队", + enqueue: 50, + dequeue: 60, + remain: -10, + }, + } + for _, tt := range testCases { + tc := tt + t.Run(tc.name, func(t *testing.T) { + q := NewConcurrentPriorityQueue[int](0, compare()) + errChan := make(chan error, tc.dequeue) + wg := sync.WaitGroup{} + wg.Add(tc.enqueue + tc.dequeue) + go func() { + for i := 0; i < tc.enqueue; i++ { + go func(i int) { + require.NoError(t, q.Enqueue(i)) + wg.Done() + }(i) + } + }() + go func() { + for i := 0; i < tc.dequeue; i++ { + _, err := q.Dequeue() + if err != nil { + errChan <- err + } + wg.Done() + } + }() + + wg.Wait() + close(errChan) + assert.Equal(t, tc.remain, q.Len()-len(errChan)) + }) + } +} + +func compare() ekit.Comparator[int] { + return func(a, b int) int { + if a < b { + return -1 + } + if a == b { + return 0 + } + return 1 + } +} diff --git a/types.go b/types.go new file mode 100644 index 00000000..70b1fba7 --- /dev/null +++ b/types.go @@ -0,0 +1,18 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ekit + +// Comparator 用于比较两个对象的大小 src < dst, 返回-1,src = dst, 返回0,src > dst, 返回1 +type Comparator[T any] func(src T, dst T) int