Skip to content

Commit e32962b

Browse files
committed
Add persistence
Solved #3
1 parent 4c5496e commit e32962b

File tree

10 files changed

+683
-43
lines changed

10 files changed

+683
-43
lines changed

README.md

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,40 @@ And, if you're struggling with excess memory usage, consider:
6060
* Reducing $m_L$ a.k.a `Graph.Ml` (the level generation parameter)
6161

6262

63-
## Roadmap
63+
## Persistence
6464

65-
- [ ] [#3](https://github.com/coder/hnsw/issues/3) Persistence / serialization
65+
While all graph operations are in-memory, `hnsw` provides facilities for loading/saving from persistent storage.
66+
67+
For an `io.Reader`/`io.Writer` interface use `Graph.Export` and `Graph.Import`.
68+
69+
If you're storing within a filesystem, you can use the more convenient `SavedGraph` instead:
70+
71+
```go
72+
path := "some.graph"
73+
g1, err := LoadSavedGraph[hnsw.Vector](path)
74+
if err != nil {
75+
panic(err)
76+
}
77+
// Insert some vectors
78+
for i := 0; i < 128; i++ {
79+
g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)}))
80+
}
81+
82+
// Save to disk
83+
err = g1.Save()
84+
if err != nil {
85+
panic(err)
86+
}
87+
88+
// Later...
89+
// g2 is a copy of g1
90+
g2, err := LoadSavedGraph[Vector](path)
91+
if err != nil {
92+
panic(err)
93+
}
94+
```
95+
96+
See more:
97+
* [Export](https://pkg.go.dev/github.com/coder/hnsw#Graph.Export)
98+
* [Import](https://pkg.go.dev/github.com/coder/hnsw#Graph.Import)
99+
* [SavedGraph](https://pkg.go.dev/github.com/coder/hnsw#SavedGraph)

analyzer.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ func (a *Analyzer[T]) Height() int {
1717
func (a *Analyzer[T]) Connectivity() []float64 {
1818
var layerConnectivity []float64
1919
for _, layer := range a.Graph.layers {
20-
if len(layer.nodes) == 0 {
20+
if len(layer.Nodes) == 0 {
2121
continue
2222
}
2323

2424
var sum float64
25-
for _, node := range layer.nodes {
25+
for _, node := range layer.Nodes {
2626
sum += float64(len(node.neighbors))
2727
}
2828

29-
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.nodes)))
29+
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.Nodes)))
3030
}
3131

3232
return layerConnectivity
@@ -36,7 +36,7 @@ func (a *Analyzer[T]) Connectivity() []float64 {
3636
func (a *Analyzer[T]) Topography() []int {
3737
var topography []int
3838
for _, layer := range a.Graph.layers {
39-
topography = append(topography, len(layer.nodes))
39+
topography = append(topography, len(layer.Nodes))
4040
}
4141
return topography
4242
}

distance.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package hnsw
22

3-
import "math"
3+
import (
4+
"math"
5+
"reflect"
6+
)
47

58
// DistanceFunc is a function that computes the distance between two vectors.
69
type DistanceFunc func(a, b []float32) float32
@@ -34,3 +37,26 @@ func EuclideanDistance(a, b []float32) float32 {
3437
}
3538
return float32(math.Sqrt(float64(sum)))
3639
}
40+
41+
var distanceFuncs = map[string]DistanceFunc{
42+
"euclidean": EuclideanDistance,
43+
"cosine": CosineDistance,
44+
}
45+
46+
func distanceFuncToName(fn DistanceFunc) (string, bool) {
47+
for name, f := range distanceFuncs {
48+
fnptr := reflect.ValueOf(fn).Pointer()
49+
fptr := reflect.ValueOf(f).Pointer()
50+
if fptr == fnptr {
51+
return name, true
52+
}
53+
}
54+
return "", false
55+
}
56+
57+
// RegisterDistanceFunc registers a distance function with a name.
58+
// A distance function must be registered here before a graph can be
59+
// exported and imported.
60+
func RegisterDistanceFunc(name string, fn DistanceFunc) {
61+
distanceFuncs[name] = fn
62+
}

0 commit comments

Comments
 (0)