11package hnsw
22
33import (
4+ "cmp"
45 "math/rand"
56 "strconv"
67 "testing"
@@ -18,34 +19,42 @@ func Test_maxLevel(t *testing.T) {
1819 require .Equal (t , 11 , m )
1920}
2021
21- type basicPoint float32
22-
23- func (n basicPoint ) ID () string {
24- return strconv .FormatFloat (float64 (n ), 'f' , - 1 , 32 )
25- }
26-
27- func (n basicPoint ) Embedding () []float32 {
28- return []float32 {float32 (n )}
29- }
30-
3122func Test_layerNode_search (t * testing.T ) {
32- entry := & layerNode [basicPoint ]{
33- vec : basicPoint (0 ),
34- neighbors : map [string ]* layerNode [basicPoint ]{
35- "1" : {
36- vec : basicPoint (1 ),
23+ entry := & layerNode [int ]{
24+ Node : Node [int ]{
25+ Vec : Vector {0 },
26+ ID : 0 ,
27+ },
28+ neighbors : map [int ]* layerNode [int ]{
29+ 1 : {
30+ Node : Node [int ]{
31+ Vec : Vector {1 },
32+ ID : 1 ,
33+ },
3734 },
38- "2" : {
39- vec : basicPoint (2 ),
35+ 2 : {
36+ Node : Node [int ]{
37+ Vec : Vector {2 },
38+ ID : 2 ,
39+ },
4040 },
41- "3" : {
42- vec : basicPoint (3 ),
43- neighbors : map [string ]* layerNode [basicPoint ]{
44- "3.8" : {
45- vec : basicPoint (3.8 ),
41+ 3 : {
42+ Node : Node [int ]{
43+ Vec : Vector {3 },
44+ ID : 3 ,
45+ },
46+ neighbors : map [int ]* layerNode [int ]{
47+ 4 : {
48+ Node : Node [int ]{
49+ Vec : Vector {4 },
50+ ID : 5 ,
51+ },
4652 },
47- "4.3" : {
48- vec : basicPoint (4.3 ),
53+ 5 : {
54+ Node : Node [int ]{
55+ Vec : Vector {5 },
56+ ID : 5 ,
57+ },
4958 },
5059 },
5160 },
@@ -54,13 +63,13 @@ func Test_layerNode_search(t *testing.T) {
5463
5564 best := entry .search (2 , 4 , []float32 {4 }, EuclideanDistance )
5665
57- require .Equal (t , "3.8" , best [0 ].node .Point . ID () )
58- require .Equal (t , "4.3" , best [1 ].node .Point . ID () )
66+ require .Equal (t , 4 , best [0 ].node .ID )
67+ require .Equal (t , 3 , best [1 ].node .ID )
5968 require .Len (t , best , 2 )
6069}
6170
62- func newTestGraph [T Embeddable ]() * Graph [T ] {
63- return & Graph [T ]{
71+ func newTestGraph [K cmp. Ordered ]() * Graph [K ] {
72+ return & Graph [K ]{
6473 M : 6 ,
6574 Distance : EuclideanDistance ,
6675 Ml : 0.5 ,
@@ -72,13 +81,18 @@ func newTestGraph[T Embeddable]() *Graph[T] {
7281func TestGraph_AddSearch (t * testing.T ) {
7382 t .Parallel ()
7483
75- g := newTestGraph [basicPoint ]()
84+ g := newTestGraph [int ]()
7685
7786 for i := 0 ; i < 128 ; i ++ {
78- g .Add (basicPoint (float32 (i )))
87+ g .Add (
88+ Node [int ]{
89+ ID : i ,
90+ Vec : Vector {float32 (i )},
91+ },
92+ )
7993 }
8094
81- al := Analyzer [basicPoint ]{Graph : g }
95+ al := Analyzer [int ]{Graph : g }
8296
8397 // Layers should be approximately log2(128) = 7
8498 // Look for an approximate doubling of the number of nodes in each layer.
@@ -101,11 +115,11 @@ func TestGraph_AddSearch(t *testing.T) {
101115 require .Len (t , nearest , 4 )
102116 require .EqualValues (
103117 t ,
104- []basicPoint {
105- ( 64 ) ,
106- ( 65 ) ,
107- ( 62 ) ,
108- ( 63 ) ,
118+ []Node [ int ] {
119+ { 64 , Vector { 64 }} ,
120+ { 65 , Vector { 65 }} ,
121+ { 62 , Vector { 62 }} ,
122+ { 63 , Vector { 63 }} ,
109123 },
110124 nearest ,
111125 )
@@ -114,19 +128,22 @@ func TestGraph_AddSearch(t *testing.T) {
114128func TestGraph_AddDelete (t * testing.T ) {
115129 t .Parallel ()
116130
117- g := newTestGraph [basicPoint ]()
131+ g := newTestGraph [int ]()
118132 for i := 0 ; i < 128 ; i ++ {
119- g .Add (basicPoint (i ))
133+ g .Add (Node [int ]{
134+ ID : i ,
135+ Vec : Vector {float32 (i )},
136+ })
120137 }
121138
122139 require .Equal (t , 128 , g .Len ())
123- an := Analyzer [basicPoint ]{Graph : g }
140+ an := Analyzer [int ]{Graph : g }
124141
125142 preDeleteConnectivity := an .Connectivity ()
126143
127144 // Delete every even node.
128145 for i := 0 ; i < 128 ; i += 2 {
129- ok := g .Delete (basicPoint ( i ). ID () )
146+ ok := g .Delete (i )
130147 require .True (t , ok )
131148 }
132149
@@ -141,7 +158,7 @@ func TestGraph_AddDelete(t *testing.T) {
141158 )
142159
143160 t .Run ("DeleteNotFound" , func (t * testing.T ) {
144- ok := g .Delete ("not found" )
161+ ok := g .Delete (- 1 )
145162 require .False (t , ok )
146163 })
147164}
@@ -154,11 +171,14 @@ func Benchmark_HSNW(b *testing.B) {
154171 // Use this to ensure that complexity is O(log n) where n = h.Len().
155172 for _ , size := range sizes {
156173 b .Run (strconv .Itoa (size ), func (b * testing.B ) {
157- g := Graph [basicPoint ]{}
174+ g := Graph [int ]{}
158175 g .Ml = 0.5
159176 g .Distance = EuclideanDistance
160177 for i := 0 ; i < size ; i ++ {
161- g .Add (basicPoint (i ))
178+ g .Add (Node [int ]{
179+ ID : i ,
180+ Vec : Vector {float32 (i )},
181+ })
162182 }
163183 b .ResetTimer ()
164184
@@ -174,19 +194,6 @@ func Benchmark_HSNW(b *testing.B) {
174194 }
175195}
176196
177- type genericPoint struct {
178- id string
179- x []float32
180- }
181-
182- func (n genericPoint ) ID () string {
183- return n .id
184- }
185-
186- func (n genericPoint ) Embedding () []float32 {
187- return n .x
188- }
189-
190197func randFloats (n int ) []float32 {
191198 x := make ([]float32 , n )
192199 for i := range x {
@@ -198,31 +205,34 @@ func randFloats(n int) []float32 {
198205func Benchmark_HNSW_1536 (b * testing.B ) {
199206 b .ReportAllocs ()
200207
201- g := newTestGraph [genericPoint ]()
208+ g := newTestGraph [int ]()
202209 const size = 1000
203- points := make ([]genericPoint , size )
210+ points := make ([]Node [ int ] , size )
204211 for i := 0 ; i < size ; i ++ {
205- points [i ] = genericPoint {x : randFloats (1536 ), id : strconv .Itoa (i )}
212+ points [i ] = Node [int ]{
213+ ID : i ,
214+ Vec : Vector (randFloats (1536 )),
215+ }
206216 g .Add (points [i ])
207217 }
208218 b .ResetTimer ()
209219
210220 b .Run ("Search" , func (b * testing.B ) {
211221 for i := 0 ; i < b .N ; i ++ {
212222 g .Search (
213- points [i % size ].x ,
223+ points [i % size ].Vec ,
214224 4 ,
215225 )
216226 }
217227 })
218228}
219229
220230func TestGraph_DefaultCosine (t * testing.T ) {
221- g := NewGraph [Vector ]()
231+ g := NewGraph [int ]()
222232 g .Add (
223- MakeVector ( "1" , [] float32 {1 , 1 }) ,
224- MakeVector ( "2" , [] float32 {0 , 1 }) ,
225- MakeVector ( "3" , [] float32 {1 , - 1 }) ,
233+ Node [ int ]{ ID : 1 , Vec : Vector {1 , 1 }} ,
234+ Node [ int ]{ ID : 2 , Vec : Vector {0 , 1 }} ,
235+ Node [ int ]{ ID : 3 , Vec : Vector {1 , - 1 }} ,
226236 )
227237
228238 neighbors := g .Search (
@@ -232,8 +242,8 @@ func TestGraph_DefaultCosine(t *testing.T) {
232242
233243 require .Equal (
234244 t ,
235- []Vector {
236- MakeVector ( "1" , [] float32 {1 , 1 }) ,
245+ []Node [ int ] {
246+ { 1 , Vector {1 , 1 }} ,
237247 },
238248 neighbors ,
239249 )
0 commit comments