forked from gravitational/teleport
-
Notifications
You must be signed in to change notification settings - Fork 0
/
embeddings.go
116 lines (104 loc) · 3.94 KB
/
embeddings.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/*
* Teleport
* Copyright (C) 2023 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package local
import (
"context"
"time"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/gravitational/teleport/api/internalutils/stream"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/ai"
"github.com/gravitational/teleport/lib/ai/embedding"
"github.com/gravitational/teleport/lib/backend"
)
// EmbeddingsService implements the services.Embeddings interface.
//
// Embeddings are persisted in the embedded backend.Backend under keys that
// start with embeddingsPrefix.
type EmbeddingsService struct {
	// log reports embeddings that are skipped because they fail to unmarshal
	// while streaming.
	log *logrus.Entry
	// jitter is not referenced by any method in this file.
	// NOTE(review): confirm whether another file in the package still uses it.
	jitter retryutils.Jitter
	// Backend is the key/value store embeddings are read from and written to;
	// it is embedded so its methods (Get, Put, ...) are called directly on the
	// service.
	backend.Backend
	// clock provides the current time used to compute each item's expiry.
	clock clockwork.Clock
}
// Storage layout for embedding items in the backend.
const (
	// embeddingsPrefix is the leading key component shared by every stored
	// embedding item.
	embeddingsPrefix = "embeddings"
	// embeddingExpiry is the lifetime given to an item on each upsert,
	// measured from the service clock's current time: 30 days.
	embeddingExpiry = time.Duration(30*24) * time.Hour
)
// GetEmbedding looks up a single embedding by the kind and ID of the resource
// it was computed for.
//
// The item is read from the backend key (embeddingsPrefix, kind, resourceID)
// and its value is unmarshaled into an embedding. Errors from both the backend
// read and the unmarshal step are returned wrapped with trace, and the
// returned embedding is nil whenever the error is non-nil.
func (e EmbeddingsService) GetEmbedding(ctx context.Context, kind, resourceID string) (*embedding.Embedding, error) {
	item, err := e.Get(ctx, backend.Key(embeddingsPrefix, kind, resourceID))
	if err != nil {
		return nil, trace.Wrap(err)
	}

	emb, err := ai.UnmarshalEmbedding(item.Value)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return emb, nil
}
// GetAllEmbeddings returns a stream of every embedding stored under
// embeddingsPrefix, regardless of kind. Items whose value fails to unmarshal
// are logged and skipped instead of ending the stream.
func (e EmbeddingsService) GetAllEmbeddings(ctx context.Context) stream.Stream[*embedding.Embedding] {
	return e.streamEmbeddings(ctx, backend.ExactKey(embeddingsPrefix))
}

// GetEmbeddings returns a stream of the embeddings stored for the given kind.
// Items whose value fails to unmarshal are logged and skipped instead of
// ending the stream.
func (e EmbeddingsService) GetEmbeddings(ctx context.Context, kind string) stream.Stream[*embedding.Embedding] {
	return e.streamEmbeddings(ctx, backend.ExactKey(embeddingsPrefix, kind))
}

// streamEmbeddings streams every backend item in the range
// [startKey, backend.RangeEnd(startKey)) and unmarshals each into an embedding,
// logging and dropping items that cannot be decoded.
func (e EmbeddingsService) streamEmbeddings(ctx context.Context, startKey []byte) stream.Stream[*embedding.Embedding] {
	// Batch size handed to backend.StreamRange.
	const pageSize = 50

	items := backend.StreamRange(ctx, e, startKey, backend.RangeEnd(startKey), pageSize)
	return stream.FilterMap(items, func(item backend.Item) (*embedding.Embedding, bool) {
		// Named emb (not embedding) so the imported embedding package is not shadowed.
		emb, err := ai.UnmarshalEmbedding(item.Value)
		if err != nil {
			e.log.Warnf("Skipping embedding at %s, failed to unmarshal: %v", item.Key, err)
			return nil, false
		}
		return emb, true
	})
}
// UpsertEmbedding creates or updates a single embedding in the backend and
// returns the same embedding on success.
//
// The item is stored under embeddingItemKey(embedding) and is set to expire
// embeddingExpiry after the service clock's current time, so every upsert
// refreshes the expiry. A nil embedding is rejected with a
// trace.BadParameter error instead of being marshaled and stored.
func (e EmbeddingsService) UpsertEmbedding(ctx context.Context, embedding *embedding.Embedding) (*embedding.Embedding, error) {
	if embedding == nil {
		return nil, trace.BadParameter("missing embedding")
	}

	value, err := ai.MarshalEmbedding(embedding)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	_, err = e.Put(ctx, backend.Item{
		Key:     embeddingItemKey(embedding),
		Value:   value,
		Expires: e.clock.Now().Add(embeddingExpiry),
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return embedding, nil
}
// NewEmbeddingsService returns an EmbeddingsService backed by b. It logs
// under the "Embeddings" component, uses full jitter, and reads time from the
// real wall clock.
func NewEmbeddingsService(b backend.Backend) *EmbeddingsService {
	svc := new(EmbeddingsService)
	svc.log = logrus.WithFields(logrus.Fields{trace.Component: "Embeddings"})
	svc.jitter = retryutils.NewFullJitter()
	svc.Backend = b
	svc.clock = clockwork.NewRealClock()
	return svc
}
// embeddingItemKey returns the backend key under which UpsertEmbedding stores
// emb: embeddingsPrefix followed by emb.GetName().
//
// NOTE(review): GetEmbedding reads items from (embeddingsPrefix, kind,
// resourceID). The two keys only agree if GetName() yields the kind and
// resource ID joined by the backend separator; confirm against lib/ai/embedding.
func embeddingItemKey(emb *embedding.Embedding) []byte {
	name := emb.GetName()
	return backend.Key(embeddingsPrefix, name)
}