@@ -2,6 +2,7 @@ package hnsw
22
33import (
44 "bufio"
5+ "cmp"
56 "encoding/binary"
67 "fmt"
78 "io"
@@ -43,6 +44,16 @@ func binaryRead(r io.Reader, data interface{}) (int, error) {
4344 * v = string (s )
4445 return len (s ), err
4546
47+ case * []float32 :
48+ var ln int
49+ _ , err := binaryRead (r , & ln )
50+ if err != nil {
51+ return 0 , err
52+ }
53+
54+ * v = make ([]float32 , ln )
55+ return binary .Size (* v ), binary .Read (r , byteOrder , * v )
56+
4657 case io.ReaderFrom :
4758 n , err := v .ReadFrom (r )
4859 return int (n ), err
@@ -73,6 +84,12 @@ func binaryWrite(w io.Writer, data any) (int, error) {
7384 }
7485
7586 return n + n2 , nil
87+ case []float32 :
88+ n , err := binaryWrite (w , len (v ))
89+ if err != nil {
90+ return n , err
91+ }
92+ return n + binary .Size (v ), binary .Write (w , byteOrder , v )
7693
7794 default :
7895 sz := binary .Size (data )
@@ -113,7 +130,7 @@ const encodingVersion = 1
113130// Export writes the graph to a writer.
114131//
115132// T must implement io.WriterTo.
116- func (h * Graph [T ]) Export (w io.Writer ) error {
133+ func (h * Graph [K ]) Export (w io.Writer ) error {
117134 distFuncName , ok := distanceFuncToName (h .Distance )
118135 if ! ok {
119136 return fmt .Errorf ("distance function %v must be registered with RegisterDistanceFunc" , h .Distance )
@@ -134,24 +151,20 @@ func (h *Graph[T]) Export(w io.Writer) error {
134151 return fmt .Errorf ("encode number of layers: %w" , err )
135152 }
136153 for _ , layer := range h .layers {
137- _ , err = binaryWrite (w , len (layer .Nodes ))
154+ _ , err = binaryWrite (w , len (layer .nodes ))
138155 if err != nil {
139156 return fmt .Errorf ("encode number of nodes: %w" , err )
140157 }
141- for _ , node := range layer .Nodes {
142- _ , err = binaryWrite (w , node .Point )
158+ for _ , node := range layer .nodes {
159+ _ , err = multiBinaryWrite (w , node .ID , node . Vec , len ( node . neighbors ) )
143160 if err != nil {
144- return fmt .Errorf ("encode node point: %w" , err )
145- }
146-
147- if _ , err = binaryWrite (w , len (node .neighbors )); err != nil {
148- return fmt .Errorf ("encode number of neighbors: %w" , err )
161+ return fmt .Errorf ("encode node data: %w" , err )
149162 }
150163
151164 for neighbor := range node .neighbors {
152165 _ , err = binaryWrite (w , neighbor )
153166 if err != nil {
154- return fmt .Errorf ("encode neighbor %q : %w" , neighbor , err )
167+ return fmt .Errorf ("encode neighbor %v : %w" , neighbor , err )
155168 }
156169 }
157170 }
@@ -164,7 +177,7 @@ func (h *Graph[T]) Export(w io.Writer) error {
164177// T must implement io.ReaderFrom.
165178// The imported graph does not have to match the exported graph's parameters (except for
166179// dimensionality). The graph will converge onto the new parameters.
167- func (h * Graph [T ]) Import (r io.Reader ) error {
180+ func (h * Graph [K ]) Import (r io.Reader ) error {
168181 var (
169182 version int
170183 dist string
@@ -195,44 +208,43 @@ func (h *Graph[T]) Import(r io.Reader) error {
195208 return err
196209 }
197210
198- h .layers = make ([]* layer [T ], nLayers )
211+ h .layers = make ([]* layer [K ], nLayers )
199212 for i := 0 ; i < nLayers ; i ++ {
200213 var nNodes int
201214 _ , err = binaryRead (r , & nNodes )
202215 if err != nil {
203216 return err
204217 }
205218
206- nodes := make (map [string ]* layerNode [T ], nNodes )
219+ nodes := make (map [K ]* layerNode [K ], nNodes )
207220 for j := 0 ; j < nNodes ; j ++ {
208- var point T
209- _ , err = binaryRead (r , & point )
210- if err != nil {
211- return fmt .Errorf ("decoding node %d: %w" , j , err )
212- }
213-
221+ var id K
222+ var vec Vector
214223 var nNeighbors int
215- _ , err = binaryRead ( r , & nNeighbors )
224+ _ , err = multiBinaryRead ( r , & id , & vec , & nNeighbors )
216225 if err != nil {
217- return fmt .Errorf ("decoding number of neighbors for node %d: %w" , j , err )
226+ return fmt .Errorf ("decoding node %d: %w" , j , err )
218227 }
219228
220- neighbors := make ([]string , nNeighbors )
229+ neighbors := make ([]K , nNeighbors )
221230 for k := 0 ; k < nNeighbors ; k ++ {
222- var neighbor string
231+ var neighbor K
223232 _ , err = binaryRead (r , & neighbor )
224233 if err != nil {
225234 return fmt .Errorf ("decoding neighbor %d for node %d: %w" , k , j , err )
226235 }
227236 neighbors [k ] = neighbor
228237 }
229238
230- node := & layerNode [T ]{
231- vec : point ,
232- neighbors : make (map [string ]* layerNode [T ]),
239+ node := & layerNode [K ]{
240+ Node : Node [K ]{
241+ ID : id ,
242+ Vec : vec ,
243+ },
244+ neighbors : make (map [K ]* layerNode [K ]),
233245 }
234246
235- nodes [point . ID () ] = node
247+ nodes [id ] = node
236248 for _ , neighbor := range neighbors {
237249 node .neighbors [neighbor ] = nil
238250 }
@@ -243,7 +255,7 @@ func (h *Graph[T]) Import(r io.Reader) error {
243255 node .neighbors [id ] = nodes [id ]
244256 }
245257 }
246- h .layers [i ] = & layer [T ]{nodes : nodes }
258+ h .layers [i ] = & layer [K ]{nodes : nodes }
247259 }
248260
249261 return nil
@@ -253,8 +265,8 @@ func (h *Graph[T]) Import(r io.Reader) error {
253265// changes to a file upon calls to Save. It is more convenient
254266// but less powerful than calling Graph.Export and Graph.Import
255267// directly.
256- type SavedGraph [T Embeddable ] struct {
257- * Graph [T ]
268+ type SavedGraph [K cmp. Ordered ] struct {
269+ * Graph [K ]
258270 Path string
259271}
260272
@@ -265,7 +277,7 @@ type SavedGraph[T Embeddable] struct {
265277//
266278// It does not hold open a file descriptor, so SavedGraph can be forgotten
267279// without ever calling Save.
268- func LoadSavedGraph [T Embeddable ](path string ) (* SavedGraph [T ], error ) {
280+ func LoadSavedGraph [K cmp. Ordered ](path string ) (* SavedGraph [K ], error ) {
269281 f , err := os .OpenFile (path , os .O_RDWR | os .O_CREATE , 0o600 )
270282 if err != nil {
271283 return nil , err
@@ -276,15 +288,15 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
276288 return nil , err
277289 }
278290
279- g := NewGraph [T ]()
291+ g := NewGraph [K ]()
280292 if info .Size () > 0 {
281293 err = g .Import (bufio .NewReader (f ))
282294 if err != nil {
283295 return nil , fmt .Errorf ("import: %w" , err )
284296 }
285297 }
286298
287- return & SavedGraph [T ]{Graph : g , Path : path }, nil
299+ return & SavedGraph [K ]{Graph : g , Path : path }, nil
288300}
289301
290302// Save writes the graph to the file.
0 commit comments