Skip to content

Commit 0236acb

Browse files
committed
Tests pass!
1 parent 52ae936 commit 0236acb

File tree

5 files changed

+94
-65
lines changed

5 files changed

+94
-65
lines changed

encode.go

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package hnsw
22

33
import (
44
"bufio"
5+
"cmp"
56
"encoding/binary"
67
"fmt"
78
"io"
@@ -43,6 +44,16 @@ func binaryRead(r io.Reader, data interface{}) (int, error) {
4344
*v = string(s)
4445
return len(s), err
4546

47+
case *[]float32:
48+
var ln int
49+
_, err := binaryRead(r, &ln)
50+
if err != nil {
51+
return 0, err
52+
}
53+
54+
*v = make([]float32, ln)
55+
return binary.Size(*v), binary.Read(r, byteOrder, *v)
56+
4657
case io.ReaderFrom:
4758
n, err := v.ReadFrom(r)
4859
return int(n), err
@@ -73,6 +84,12 @@ func binaryWrite(w io.Writer, data any) (int, error) {
7384
}
7485

7586
return n + n2, nil
87+
case []float32:
88+
n, err := binaryWrite(w, len(v))
89+
if err != nil {
90+
return n, err
91+
}
92+
return n + binary.Size(v), binary.Write(w, byteOrder, v)
7693

7794
default:
7895
sz := binary.Size(data)
@@ -113,7 +130,7 @@ const encodingVersion = 1
113130
// Export writes the graph to a writer.
114131
//
115132
// Node IDs (type K) are written with binaryWrite, so K must be a type it can encode.
116-
func (h *Graph[T]) Export(w io.Writer) error {
133+
func (h *Graph[K]) Export(w io.Writer) error {
117134
distFuncName, ok := distanceFuncToName(h.Distance)
118135
if !ok {
119136
return fmt.Errorf("distance function %v must be registered with RegisterDistanceFunc", h.Distance)
@@ -134,24 +151,20 @@ func (h *Graph[T]) Export(w io.Writer) error {
134151
return fmt.Errorf("encode number of layers: %w", err)
135152
}
136153
for _, layer := range h.layers {
137-
_, err = binaryWrite(w, len(layer.Nodes))
154+
_, err = binaryWrite(w, len(layer.nodes))
138155
if err != nil {
139156
return fmt.Errorf("encode number of nodes: %w", err)
140157
}
141-
for _, node := range layer.Nodes {
142-
_, err = binaryWrite(w, node.Point)
158+
for _, node := range layer.nodes {
159+
_, err = multiBinaryWrite(w, node.ID, node.Vec, len(node.neighbors))
143160
if err != nil {
144-
return fmt.Errorf("encode node point: %w", err)
145-
}
146-
147-
if _, err = binaryWrite(w, len(node.neighbors)); err != nil {
148-
return fmt.Errorf("encode number of neighbors: %w", err)
161+
return fmt.Errorf("encode node data: %w", err)
149162
}
150163

151164
for neighbor := range node.neighbors {
152165
_, err = binaryWrite(w, neighbor)
153166
if err != nil {
154-
return fmt.Errorf("encode neighbor %q: %w", neighbor, err)
167+
return fmt.Errorf("encode neighbor %v: %w", neighbor, err)
155168
}
156169
}
157170
}
@@ -164,7 +177,7 @@ func (h *Graph[T]) Export(w io.Writer) error {
164177
// Node IDs (type K) are read with binaryRead, so K must be a type it can decode.
165178
// The imported graph does not have to match the exported graph's parameters (except for
166179
// dimensionality). The graph will converge onto the new parameters.
167-
func (h *Graph[T]) Import(r io.Reader) error {
180+
func (h *Graph[K]) Import(r io.Reader) error {
168181
var (
169182
version int
170183
dist string
@@ -195,44 +208,43 @@ func (h *Graph[T]) Import(r io.Reader) error {
195208
return err
196209
}
197210

198-
h.layers = make([]*layer[T], nLayers)
211+
h.layers = make([]*layer[K], nLayers)
199212
for i := 0; i < nLayers; i++ {
200213
var nNodes int
201214
_, err = binaryRead(r, &nNodes)
202215
if err != nil {
203216
return err
204217
}
205218

206-
nodes := make(map[string]*layerNode[T], nNodes)
219+
nodes := make(map[K]*layerNode[K], nNodes)
207220
for j := 0; j < nNodes; j++ {
208-
var point T
209-
_, err = binaryRead(r, &point)
210-
if err != nil {
211-
return fmt.Errorf("decoding node %d: %w", j, err)
212-
}
213-
221+
var id K
222+
var vec Vector
214223
var nNeighbors int
215-
_, err = binaryRead(r, &nNeighbors)
224+
_, err = multiBinaryRead(r, &id, &vec, &nNeighbors)
216225
if err != nil {
217-
return fmt.Errorf("decoding number of neighbors for node %d: %w", j, err)
226+
return fmt.Errorf("decoding node %d: %w", j, err)
218227
}
219228

220-
neighbors := make([]string, nNeighbors)
229+
neighbors := make([]K, nNeighbors)
221230
for k := 0; k < nNeighbors; k++ {
222-
var neighbor string
231+
var neighbor K
223232
_, err = binaryRead(r, &neighbor)
224233
if err != nil {
225234
return fmt.Errorf("decoding neighbor %d for node %d: %w", k, j, err)
226235
}
227236
neighbors[k] = neighbor
228237
}
229238

230-
node := &layerNode[T]{
231-
vec: point,
232-
neighbors: make(map[string]*layerNode[T]),
239+
node := &layerNode[K]{
240+
Node: Node[K]{
241+
ID: id,
242+
Vec: vec,
243+
},
244+
neighbors: make(map[K]*layerNode[K]),
233245
}
234246

235-
nodes[point.ID()] = node
247+
nodes[id] = node
236248
for _, neighbor := range neighbors {
237249
node.neighbors[neighbor] = nil
238250
}
@@ -243,7 +255,7 @@ func (h *Graph[T]) Import(r io.Reader) error {
243255
node.neighbors[id] = nodes[id]
244256
}
245257
}
246-
h.layers[i] = &layer[T]{nodes: nodes}
258+
h.layers[i] = &layer[K]{nodes: nodes}
247259
}
248260

249261
return nil
@@ -253,8 +265,8 @@ func (h *Graph[T]) Import(r io.Reader) error {
253265
// changes to a file upon calls to Save. It is more convenient
254266
// but less powerful than calling Graph.Export and Graph.Import
255267
// directly.
256-
type SavedGraph[T Embeddable] struct {
257-
*Graph[T]
268+
type SavedGraph[K cmp.Ordered] struct {
269+
*Graph[K]
258270
Path string
259271
}
260272

@@ -265,7 +277,7 @@ type SavedGraph[T Embeddable] struct {
265277
//
266278
// It does not hold open a file descriptor, so SavedGraph can be forgotten
267279
// without ever calling Save.
268-
func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
280+
func LoadSavedGraph[K cmp.Ordered](path string) (*SavedGraph[K], error) {
269281
f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
270282
if err != nil {
271283
return nil, err
@@ -276,15 +288,15 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
276288
return nil, err
277289
}
278290

279-
g := NewGraph[T]()
291+
g := NewGraph[K]()
280292
if info.Size() > 0 {
281293
err = g.Import(bufio.NewReader(f))
282294
if err != nil {
283295
return nil, fmt.Errorf("import: %w", err)
284296
}
285297
}
286298

287-
return &SavedGraph[T]{Graph: g, Path: path}, nil
299+
return &SavedGraph[K]{Graph: g, Path: path}, nil
288300
}
289301

290302
// Save writes the graph to the file.

encode_test.go

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ package hnsw
22

33
import (
44
"bytes"
5-
"math/rand"
6-
"strconv"
5+
"cmp"
76
"testing"
87

98
"github.com/stretchr/testify/require"
@@ -50,21 +49,21 @@ func Test_binaryWrite_string(t *testing.T) {
5049
require.Empty(t, buf.Bytes())
5150
}
5251

53-
func verifyGraphNodes[T Embeddable](t *testing.T, g *Graph[T]) {
52+
func verifyGraphNodes[K cmp.Ordered](t *testing.T, g *Graph[K]) {
5453
for _, layer := range g.layers {
55-
for _, node := range layer.Nodes {
54+
for _, node := range layer.nodes {
5655
for neighborKey, neighbor := range node.neighbors {
57-
_, ok := layer.Nodes[neighbor.Point.ID()]
56+
_, ok := layer.nodes[neighbor.ID]
5857
if !ok {
5958
t.Errorf(
60-
"node %s has neighbor %s, but neighbor does not exist",
61-
node.Point.ID(), neighbor.Point.ID(),
59+
"node %v has neighbor %v, but neighbor does not exist",
60+
node.ID, neighbor.ID,
6261
)
6362
}
6463

65-
if neighborKey != neighbor.Point.ID() {
66-
t.Errorf("node %s has neighbor %s, but neighbor key is %s", node.Point.ID(),
67-
neighbor.Point.ID(),
64+
if neighborKey != neighbor.ID {
65+
t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.ID,
66+
neighbor.ID,
6867
neighborKey,
6968
)
7069
}
@@ -74,10 +73,10 @@ func verifyGraphNodes[T Embeddable](t *testing.T, g *Graph[T]) {
7473
}
7574

7675
// requireGraphApproxEquals checks that two graphs are approximately equal.
77-
func requireGraphApproxEquals[T Embeddable](t *testing.T, g1, g2 *Graph[T]) {
76+
func requireGraphApproxEquals[K cmp.Ordered](t *testing.T, g1, g2 *Graph[K]) {
7877
require.Equal(t, g1.Len(), g2.Len())
79-
a1 := Analyzer[T]{g1}
80-
a2 := Analyzer[T]{g2}
78+
a1 := Analyzer[K]{g1}
79+
a2 := Analyzer[K]{g2}
8180

8281
require.Equal(
8382
t,
@@ -119,11 +118,13 @@ func requireGraphApproxEquals[T Embeddable](t *testing.T, g1, g2 *Graph[T]) {
119118
}
120119

121120
func TestGraph_ExportImport(t *testing.T) {
122-
rng := rand.New(rand.NewSource(0))
123-
124-
g1 := newTestGraph[Vector]()
121+
g1 := newTestGraph[int]()
125122
for i := 0; i < 128; i++ {
126-
g1.Add(MakeVector(strconv.Itoa(i), []float32{rng.Float32()}))
123+
g1.Add(
124+
Node[int]{
125+
i, randFloats(1),
126+
},
127+
)
127128
}
128129

129130
buf := &bytes.Buffer{}
@@ -132,7 +133,7 @@ func TestGraph_ExportImport(t *testing.T) {
132133

133134
// Don't use newTestGraph to ensure parameters
134135
// are imported.
135-
g2 := &Graph[Vector]{}
136+
g2 := &Graph[int]{}
136137
err = g2.Import(buf)
137138
require.NoError(t, err)
138139

@@ -157,17 +158,21 @@ func TestGraph_ExportImport(t *testing.T) {
157158
func TestSavedGraph(t *testing.T) {
158159
dir := t.TempDir()
159160

160-
g1, err := LoadSavedGraph[Vector](dir + "/graph")
161+
g1, err := LoadSavedGraph[int](dir + "/graph")
161162
require.NoError(t, err)
162163
require.Equal(t, 0, g1.Len())
163164
for i := 0; i < 128; i++ {
164-
g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)}))
165+
g1.Add(
166+
Node[int]{
167+
i, randFloats(1),
168+
},
169+
)
165170
}
166171

167172
err = g1.Save()
168173
require.NoError(t, err)
169174

170-
g2, err := LoadSavedGraph[Vector](dir + "/graph")
175+
g2, err := LoadSavedGraph[int](dir + "/graph")
171176
require.NoError(t, err)
172177

173178
requireGraphApproxEquals(t, g1.Graph, g2.Graph)
@@ -177,9 +182,13 @@ const benchGraphSize = 100
177182

178183
func BenchmarkGraph_Import(b *testing.B) {
179184
b.ReportAllocs()
180-
g := newTestGraph[Vector]()
185+
g := newTestGraph[int]()
181186
for i := 0; i < benchGraphSize; i++ {
182-
g.Add(MakeVector(strconv.Itoa(i), randFloats(100)))
187+
g.Add(
188+
Node[int]{
189+
i, randFloats(256),
190+
},
191+
)
183192
}
184193

185194
buf := &bytes.Buffer{}
@@ -192,7 +201,7 @@ func BenchmarkGraph_Import(b *testing.B) {
192201
for i := 0; i < b.N; i++ {
193202
b.StopTimer()
194203
rdr := bytes.NewReader(buf.Bytes())
195-
g := newTestGraph[Vector]()
204+
g := newTestGraph[int]()
196205
b.StartTimer()
197206
err = g.Import(rdr)
198207
require.NoError(b, err)
@@ -201,9 +210,13 @@ func BenchmarkGraph_Import(b *testing.B) {
201210

202211
func BenchmarkGraph_Export(b *testing.B) {
203212
b.ReportAllocs()
204-
g := newTestGraph[Vector]()
213+
g := newTestGraph[int]()
205214
for i := 0; i < benchGraphSize; i++ {
206-
g.Add(MakeVector(strconv.Itoa(i), randFloats(256)))
215+
g.Add(
216+
Node[int]{
217+
i, randFloats(256),
218+
},
219+
)
207220
}
208221

209222
var buf bytes.Buffer

example/readme/main.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@ import (
77
)
88

99
func main() {
10-
g := hnsw.NewGraph[hnsw.Vector]()
10+
g := hnsw.NewGraph[int]()
1111
g.Add(
12-
hnsw.MakeVector("1", []float32{1, 1, 1}),
13-
hnsw.MakeVector("2", []float32{1, -1, 0.999}),
14-
hnsw.MakeVector("3", []float32{1, 0, -0.5}),
12+
hnsw.MakeNode(1, []float32{1, 1, 1}),
13+
hnsw.MakeNode(2, []float32{1, -1, 0.999}),
14+
hnsw.MakeNode(3, []float32{1, 0, -0.5}),
1515
)
1616

1717
neighbors := g.Search(
1818
[]float32{0.5, 0.5, 0.5},
1919
1,
2020
)
21-
fmt.Printf("best friend: %v\n", neighbors[0].Embedding())
21+
fmt.Printf("best friend: %v\n", neighbors[0].Vec)
2222
}

graph.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ type Node[K cmp.Ordered] struct {
2020
Vec Vector
2121
}
2222

23+
func MakeNode[K cmp.Ordered](id K, vec Vector) Node[K] {
24+
return Node[K]{ID: id, Vec: vec}
25+
}
26+
2327
// layerNode is a node in a layer of the graph.
2428
type layerNode[K cmp.Ordered] struct {
2529
Node[K]

graph_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func Test_layerNode_search(t *testing.T) {
6363

6464
best := entry.search(2, 4, []float32{4}, EuclideanDistance)
6565

66-
require.Equal(t, 4, best[0].node.ID)
66+
require.Equal(t, 5, best[0].node.ID)
6767
require.Equal(t, 3, best[1].node.ID)
6868
require.Len(t, best, 2)
6969
}

0 commit comments

Comments
 (0)