11package hnsw
22
33import (
4+ "cmp"
45 "fmt"
56 "math"
67 "math/rand"
@@ -14,28 +15,28 @@ import (
1415type Embedding = []float32
1516
1617// Embeddable describes a type that can be embedded in an HNSW graph.
17- type Embeddable interface {
18+ type Embeddable [ K cmp. Ordered ] interface {
1819 // ID returns a unique identifier for the object.
19- ID () string
20+ ID () K
2021 // Embedding returns the embedding of the object.
2122 // float32 is used for compatibility with OpenAI embeddings.
2223 Embedding () Embedding
2324}
2425
2526// layerNode is a node in a layer of the graph.
26- type layerNode [T Embeddable ] struct {
27- Point Embeddable
27+ type layerNode [K cmp. Ordered , V Embeddable [ K ] ] struct {
28+ Point Embeddable [ K ]
2829 // neighbors is a map of neighbor IDs to neighbor nodes.
2930 // It is a map and not a slice to allow for efficient deletes, especially
3031 // when M is high.
31- neighbors map [string ]* layerNode [T ]
32+ neighbors map [K ]* layerNode [K , V ]
3233}
3334
3435// addNeighbor adds a neighbor to the node, replacing the neighbor
3536// with the worst distance if the neighbor set is full.
36- func (n * layerNode [T ]) addNeighbor (newNode * layerNode [T ], m int , dist DistanceFunc ) {
37+ func (n * layerNode [K , V ]) addNeighbor (newNode * layerNode [K , V ], m int , dist DistanceFunc ) {
3738 if n .neighbors == nil {
38- n .neighbors = make (map [string ]* layerNode [T ], m )
39+ n .neighbors = make (map [K ]* layerNode [K , V ], m )
3940 }
4041
4142 n .neighbors [newNode .Point .ID ()] = newNode
@@ -46,7 +47,7 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu
4647 // Find the neighbor with the worst distance.
4748 var (
4849 worstDist = float32 (math .Inf (- 1 ))
49- worst * layerNode [T ]
50+ worst * layerNode [K , V ]
5051 )
5152 for _ , neighbor := range n .neighbors {
5253 d := dist (neighbor .Point .Embedding (), n .Point .Embedding ())
@@ -64,39 +65,39 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu
6465 worst .replenish (m )
6566}
6667
67- type searchCandidate [T Embeddable ] struct {
68- node * layerNode [T ]
68+ type searchCandidate [K cmp. Ordered , V Embeddable [ K ] ] struct {
69+ node * layerNode [K , V ]
6970 dist float32
7071}
7172
72- func (s searchCandidate [T ]) Less (o searchCandidate [T ]) bool {
73+ func (s searchCandidate [K , V ]) Less (o searchCandidate [K , V ]) bool {
7374 return s .dist < o .dist
7475}
7576
7677// search returns the layer nodes closest to the target embedding
7778// within the same layer.
78- func (n * layerNode [T ]) search (
79+ func (n * layerNode [K , V ]) search (
7980 // k is the number of candidates in the result set.
8081 k int ,
8182 efSearch int ,
8283 target Embedding ,
8384 distance DistanceFunc ,
84- ) []searchCandidate [T ] {
85+ ) []searchCandidate [K , V ] {
8586 // This is a basic greedy algorithm to find the entry point at the given level
8687 // that is closest to the target node.
87- candidates := heap.Heap [searchCandidate [T ]]{}
88- candidates .Init (make ([]searchCandidate [T ], 0 , efSearch ))
88+ candidates := heap.Heap [searchCandidate [K , V ]]{}
89+ candidates .Init (make ([]searchCandidate [K , V ], 0 , efSearch ))
8990 candidates .Push (
90- searchCandidate [T ]{
91+ searchCandidate [K , V ]{
9192 node : n ,
9293 dist : distance (n .Point .Embedding (), target ),
9394 },
9495 )
9596 var (
96- result = heap.Heap [searchCandidate [T ]]{}
97- visited = make (map [string ]bool )
97+ result = heap.Heap [searchCandidate [K , V ]]{}
98+ visited = make (map [K ]bool )
9899 )
99- result .Init (make ([]searchCandidate [T ], 0 , k ))
100+ result .Init (make ([]searchCandidate [K , V ], 0 , k ))
100101
101102 // Begin with the entry node in the result set.
102103 result .Push (candidates .Min ())
@@ -122,13 +123,13 @@ func (n *layerNode[T]) search(
122123 dist := distance (neighbor .Point .Embedding (), target )
123124 improved = improved || dist < result .Min ().dist
124125 if result .Len () < k {
125- result .Push (searchCandidate [T ]{node : neighbor , dist : dist })
126+ result .Push (searchCandidate [K , V ]{node : neighbor , dist : dist })
126127 } else if dist < result .Max ().dist {
127128 result .PopLast ()
128- result .Push (searchCandidate [T ]{node : neighbor , dist : dist })
129+ result .Push (searchCandidate [K , V ]{node : neighbor , dist : dist })
129130 }
130131
131- candidates .Push (searchCandidate [T ]{node : neighbor , dist : dist })
132+ candidates .Push (searchCandidate [K , V ]{node : neighbor , dist : dist })
132133 // Always store candidates if we haven't reached the limit.
133134 if candidates .Len () > efSearch {
134135 candidates .PopLast ()
@@ -145,7 +146,7 @@ func (n *layerNode[T]) search(
145146 return result .Slice ()
146147}
147148
148- func (n * layerNode [T ]) replenish (m int ) {
149+ func (n * layerNode [K , V ]) replenish (m int ) {
149150 if len (n .neighbors ) >= m {
150151 return
151152 }
@@ -172,7 +173,7 @@ func (n *layerNode[T]) replenish(m int) {
172173
173174// isolate removes the node from the graph by removing all connections
174175// to its neighbors.
175- func (n * layerNode [T ]) isolate (m int ) {
176+ func (n * layerNode [K , V ]) isolate (m int ) {
176177 for _ , neighbor := range n .neighbors {
177178 delete (neighbor .neighbors , n .Point .ID ())
178179 neighbor .replenish (m )
0 commit comments