Skip to content

Commit 52ae936

Browse files
committed
WIP: graph_test.go compiles!
1 parent 4875617 commit 52ae936

File tree

2 files changed

+84
-72
lines changed

2 files changed

+84
-72
lines changed

analyzer.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package hnsw
22

3+
import "cmp"
4+
35
// Analyzer is a struct that holds a graph and provides
46
// methods for analyzing it. It offers no compatibility guarantee
57
// as the methods of measuring the graph's health will change
68
// with the implementation.
7-
type Analyzer[T Embeddable] struct {
8-
Graph *Graph[T]
9+
type Analyzer[K cmp.Ordered] struct {
10+
Graph *Graph[K]
911
}
1012

1113
func (a *Analyzer[T]) Height() int {
@@ -17,16 +19,16 @@ func (a *Analyzer[T]) Height() int {
1719
func (a *Analyzer[T]) Connectivity() []float64 {
1820
var layerConnectivity []float64
1921
for _, layer := range a.Graph.layers {
20-
if len(layer.Nodes) == 0 {
22+
if len(layer.nodes) == 0 {
2123
continue
2224
}
2325

2426
var sum float64
25-
for _, node := range layer.Nodes {
27+
for _, node := range layer.nodes {
2628
sum += float64(len(node.neighbors))
2729
}
2830

29-
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.Nodes)))
31+
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.nodes)))
3032
}
3133

3234
return layerConnectivity
@@ -36,7 +38,7 @@ func (a *Analyzer[T]) Connectivity() []float64 {
3638
func (a *Analyzer[T]) Topography() []int {
3739
var topography []int
3840
for _, layer := range a.Graph.layers {
39-
topography = append(topography, len(layer.Nodes))
41+
topography = append(topography, len(layer.nodes))
4042
}
4143
return topography
4244
}

graph_test.go

Lines changed: 76 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package hnsw
22

33
import (
4+
"cmp"
45
"math/rand"
56
"strconv"
67
"testing"
@@ -18,34 +19,42 @@ func Test_maxLevel(t *testing.T) {
1819
require.Equal(t, 11, m)
1920
}
2021

21-
type basicPoint float32
22-
23-
func (n basicPoint) ID() string {
24-
return strconv.FormatFloat(float64(n), 'f', -1, 32)
25-
}
26-
27-
func (n basicPoint) Embedding() []float32 {
28-
return []float32{float32(n)}
29-
}
30-
3122
func Test_layerNode_search(t *testing.T) {
32-
entry := &layerNode[basicPoint]{
33-
vec: basicPoint(0),
34-
neighbors: map[string]*layerNode[basicPoint]{
35-
"1": {
36-
vec: basicPoint(1),
23+
entry := &layerNode[int]{
24+
Node: Node[int]{
25+
Vec: Vector{0},
26+
ID: 0,
27+
},
28+
neighbors: map[int]*layerNode[int]{
29+
1: {
30+
Node: Node[int]{
31+
Vec: Vector{1},
32+
ID: 1,
33+
},
3734
},
38-
"2": {
39-
vec: basicPoint(2),
35+
2: {
36+
Node: Node[int]{
37+
Vec: Vector{2},
38+
ID: 2,
39+
},
4040
},
41-
"3": {
42-
vec: basicPoint(3),
43-
neighbors: map[string]*layerNode[basicPoint]{
44-
"3.8": {
45-
vec: basicPoint(3.8),
41+
3: {
42+
Node: Node[int]{
43+
Vec: Vector{3},
44+
ID: 3,
45+
},
46+
neighbors: map[int]*layerNode[int]{
47+
4: {
48+
Node: Node[int]{
49+
Vec: Vector{4},
50+
ID: 5,
51+
},
4652
},
47-
"4.3": {
48-
vec: basicPoint(4.3),
53+
5: {
54+
Node: Node[int]{
55+
Vec: Vector{5},
56+
ID: 5,
57+
},
4958
},
5059
},
5160
},
@@ -54,13 +63,13 @@ func Test_layerNode_search(t *testing.T) {
5463

5564
best := entry.search(2, 4, []float32{4}, EuclideanDistance)
5665

57-
require.Equal(t, "3.8", best[0].node.Point.ID())
58-
require.Equal(t, "4.3", best[1].node.Point.ID())
66+
require.Equal(t, 4, best[0].node.ID)
67+
require.Equal(t, 3, best[1].node.ID)
5968
require.Len(t, best, 2)
6069
}
6170

62-
func newTestGraph[T Embeddable]() *Graph[T] {
63-
return &Graph[T]{
71+
func newTestGraph[K cmp.Ordered]() *Graph[K] {
72+
return &Graph[K]{
6473
M: 6,
6574
Distance: EuclideanDistance,
6675
Ml: 0.5,
@@ -72,13 +81,18 @@ func newTestGraph[T Embeddable]() *Graph[T] {
7281
func TestGraph_AddSearch(t *testing.T) {
7382
t.Parallel()
7483

75-
g := newTestGraph[basicPoint]()
84+
g := newTestGraph[int]()
7685

7786
for i := 0; i < 128; i++ {
78-
g.Add(basicPoint(float32(i)))
87+
g.Add(
88+
Node[int]{
89+
ID: i,
90+
Vec: Vector{float32(i)},
91+
},
92+
)
7993
}
8094

81-
al := Analyzer[basicPoint]{Graph: g}
95+
al := Analyzer[int]{Graph: g}
8296

8397
// Layers should be approximately log2(128) = 7
8498
// Look for an approximate doubling of the number of nodes in each layer.
@@ -101,11 +115,11 @@ func TestGraph_AddSearch(t *testing.T) {
101115
require.Len(t, nearest, 4)
102116
require.EqualValues(
103117
t,
104-
[]basicPoint{
105-
(64),
106-
(65),
107-
(62),
108-
(63),
118+
[]Node[int]{
119+
{64, Vector{64}},
120+
{65, Vector{65}},
121+
{62, Vector{62}},
122+
{63, Vector{63}},
109123
},
110124
nearest,
111125
)
@@ -114,19 +128,22 @@ func TestGraph_AddSearch(t *testing.T) {
114128
func TestGraph_AddDelete(t *testing.T) {
115129
t.Parallel()
116130

117-
g := newTestGraph[basicPoint]()
131+
g := newTestGraph[int]()
118132
for i := 0; i < 128; i++ {
119-
g.Add(basicPoint(i))
133+
g.Add(Node[int]{
134+
ID: i,
135+
Vec: Vector{float32(i)},
136+
})
120137
}
121138

122139
require.Equal(t, 128, g.Len())
123-
an := Analyzer[basicPoint]{Graph: g}
140+
an := Analyzer[int]{Graph: g}
124141

125142
preDeleteConnectivity := an.Connectivity()
126143

127144
// Delete every even node.
128145
for i := 0; i < 128; i += 2 {
129-
ok := g.Delete(basicPoint(i).ID())
146+
ok := g.Delete(i)
130147
require.True(t, ok)
131148
}
132149

@@ -141,7 +158,7 @@ func TestGraph_AddDelete(t *testing.T) {
141158
)
142159

143160
t.Run("DeleteNotFound", func(t *testing.T) {
144-
ok := g.Delete("not found")
161+
ok := g.Delete(-1)
145162
require.False(t, ok)
146163
})
147164
}
@@ -154,11 +171,14 @@ func Benchmark_HSNW(b *testing.B) {
154171
// Use this to ensure that complexity is O(log n) where n = g.Len().
155172
for _, size := range sizes {
156173
b.Run(strconv.Itoa(size), func(b *testing.B) {
157-
g := Graph[basicPoint]{}
174+
g := Graph[int]{}
158175
g.Ml = 0.5
159176
g.Distance = EuclideanDistance
160177
for i := 0; i < size; i++ {
161-
g.Add(basicPoint(i))
178+
g.Add(Node[int]{
179+
ID: i,
180+
Vec: Vector{float32(i)},
181+
})
162182
}
163183
b.ResetTimer()
164184

@@ -174,19 +194,6 @@ func Benchmark_HSNW(b *testing.B) {
174194
}
175195
}
176196

177-
type genericPoint struct {
178-
id string
179-
x []float32
180-
}
181-
182-
func (n genericPoint) ID() string {
183-
return n.id
184-
}
185-
186-
func (n genericPoint) Embedding() []float32 {
187-
return n.x
188-
}
189-
190197
func randFloats(n int) []float32 {
191198
x := make([]float32, n)
192199
for i := range x {
@@ -198,31 +205,34 @@ func randFloats(n int) []float32 {
198205
func Benchmark_HNSW_1536(b *testing.B) {
199206
b.ReportAllocs()
200207

201-
g := newTestGraph[genericPoint]()
208+
g := newTestGraph[int]()
202209
const size = 1000
203-
points := make([]genericPoint, size)
210+
points := make([]Node[int], size)
204211
for i := 0; i < size; i++ {
205-
points[i] = genericPoint{x: randFloats(1536), id: strconv.Itoa(i)}
212+
points[i] = Node[int]{
213+
ID: i,
214+
Vec: Vector(randFloats(1536)),
215+
}
206216
g.Add(points[i])
207217
}
208218
b.ResetTimer()
209219

210220
b.Run("Search", func(b *testing.B) {
211221
for i := 0; i < b.N; i++ {
212222
g.Search(
213-
points[i%size].x,
223+
points[i%size].Vec,
214224
4,
215225
)
216226
}
217227
})
218228
}
219229

220230
func TestGraph_DefaultCosine(t *testing.T) {
221-
g := NewGraph[Vector]()
231+
g := NewGraph[int]()
222232
g.Add(
223-
MakeVector("1", []float32{1, 1}),
224-
MakeVector("2", []float32{0, 1}),
225-
MakeVector("3", []float32{1, -1}),
233+
Node[int]{ID: 1, Vec: Vector{1, 1}},
234+
Node[int]{ID: 2, Vec: Vector{0, 1}},
235+
Node[int]{ID: 3, Vec: Vector{1, -1}},
226236
)
227237

228238
neighbors := g.Search(
@@ -232,8 +242,8 @@ func TestGraph_DefaultCosine(t *testing.T) {
232242

233243
require.Equal(
234244
t,
235-
[]Vector{
236-
MakeVector("1", []float32{1, 1}),
245+
[]Node[int]{
246+
{1, Vector{1, 1}},
237247
},
238248
neighbors,
239249
)

0 commit comments

Comments
 (0)