-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- add Embeddings service and its local implementation - add Embedding type and proto message - add nodeEmbeddingCollector tracking nodes - add NodeEmbeddingWatcher watching for events and sending them to the collector - add the Embedder interface and its openai implementation
- Loading branch information
Showing
9 changed files
with
1,013 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright 2022 Gravitational, Inc | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
syntax = "proto3"; | ||
|
||
package teleport.embedding.v1; | ||
|
||
import "teleport/legacy/types/types.proto"; | ||
|
||
option go_package = "github.com/gravitational/teleport/api/gen/proto/go/teleport/embedding/v1;embeddingv1"; | ||
|
||
// Embedding contains a Teleport resource embedding. Embeddings are small semantic
// representations of larger and more complex data. Embeddings can be compared,
// the smaller the distance between two vectors, the closer the concepts are.
// Teleport Assist embeds resources to perform semantic search.
message Embedding {
  // Metadata is the embedding metadata.
  types.Metadata metadata = 1;

  // Version is the embedding resource version.
  string version = 2;

  // NOTE(review): field numbers 3 and 4 are skipped — presumably removed or
  // held for planned fields; confirm and declare them `reserved` if so.

  // EmbeddedHash is the hash of the embedded resource after serialization.
  // This helps checking if the resource has changed and needs a new embedding.
  // (The Go constructor passes a [32]byte, consistent with a SHA-256 digest —
  // NOTE(review): confirm the hash algorithm.)
  bytes embedded_hash = 5;

  // Vector is the embedding itself, as provided by the model.
  repeated float vector = 6;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
package ai | ||
|
||
import ( | ||
"context" | ||
"strings" | ||
|
||
"github.com/gravitational/trace" | ||
"github.com/sashabaranov/go-openai" | ||
|
||
embeddingpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/embedding/v1" | ||
"github.com/gravitational/teleport/api/types" | ||
) | ||
|
||
// maxOpenAIEmbeddingPerRequest is the maximum number of input strings sent in
// a single OpenAI embeddings API call; larger inputs are paginated into
// batches of this size by ComputeEmbeddings.
const maxOpenAIEmbeddingPerRequest = 1000
|
||
// Embedding contains a Teleport resource embedding. Embeddings are small semantic
// representations of larger and more complex data. Embeddings can be compared,
// the smaller the distance between two vectors, the closer the concepts are.
// Teleport Assist embeds resources to perform semantic search.
// The Embedding is named after the embedded resource id and kind. For example
// the SSH node "bastion-01" has the embedding "node/bastion-01".
type Embedding struct {
	// Embedded proto message carrying the metadata, version, embedded-resource
	// hash and the embedding vector itself.
	*embeddingpb.Embedding
}
|
||
// GetEmbeddedKind returns the kind of the resource that was embedded. | ||
func (e Embedding) GetEmbeddedKind() string { | ||
return strings.Split(e.GetName(), "/")[0] | ||
} | ||
|
||
// GetName returns the Embedding name, composed of the embedded resource kind | ||
// and the embedded resource ID. | ||
func (e Embedding) GetName() string { | ||
return e.GetMetadata().GetName() | ||
} | ||
|
||
// GetEmbeddedID returns the ID of the resource that was embedded. | ||
func (e Embedding) GetEmbeddedID() string { | ||
return strings.Split(e.GetName(), "/")[1] | ||
} | ||
|
||
// NewEmbedding is an Embedding constructor. | ||
func NewEmbedding(kind, id string, vector []float32, hash [32]byte) Embedding { | ||
return Embedding{ | ||
Embedding: &embeddingpb.Embedding{ | ||
Metadata: &types.Metadata{ | ||
Name: kind + "/" + id, | ||
}, | ||
Version: "1", | ||
EmbeddedHash: hash[:], | ||
Vector: vector, | ||
}, | ||
} | ||
} | ||
|
||
// Embedder is implemented for batch text embedding. Embedding can happen in
// place (with an embedding model for example) or be done by a remote embedding
// service like OpenAI.
type Embedder interface {
	// ComputeEmbeddings computes the embeddings of multiple strings.
	// The embedding list follows the input order (e.g. result[i] is the
	// embedding of input[i]).
	ComputeEmbeddings(ctx context.Context, input []string) ([][]float32, error)
}
|
||
// ComputeEmbeddings taxes a map of nodes and calls openAI to generate | ||
// embeddings for those nodes. ComputeEmbeddings is responsible for | ||
// implementing a retry mechanism if the embedding computation is flaky. | ||
func (client *Client) ComputeEmbeddings(ctx context.Context, input []string) ([][]float32, error) { | ||
var errors []error | ||
var results [][]float32 | ||
for i := 0; maxOpenAIEmbeddingPerRequest*i < len(input); i++ { | ||
result, err := client.computeEmbeddings(ctx, paginateInput(input, i, maxOpenAIEmbeddingPerRequest)) | ||
if err != nil { | ||
errors = append(errors, trace.Wrap(err)) | ||
} | ||
if result != nil { | ||
results = append(results, result...) | ||
} | ||
} | ||
return results, trace.NewAggregate(errors...) | ||
} | ||
|
||
// paginateInput returns the page-th slice of input of at most pageSize
// elements, clamping the final page to the end of the input.
func paginateInput(input []string, page, pageSize int) []string {
	start := page * pageSize
	end := start + pageSize
	if end > len(input) {
		end = len(input)
	}
	return input[start:end]
}
|
||
func (client *Client) computeEmbeddings(ctx context.Context, input []string) ([][]float32, error) { | ||
if len(input) > maxOpenAIEmbeddingPerRequest { | ||
return nil, trace.BadParameter("too many strings to embed (%s), maximum is %s", len(input)) | ||
} | ||
req := openai.EmbeddingRequest{ | ||
Input: input, | ||
Model: openai.AdaEmbeddingV2, | ||
} | ||
|
||
// TODO: measure if this is flaky and if we need to implement a retry mechanism | ||
// Execute the query | ||
resp, err := client.svc.CreateEmbeddings(ctx, req) | ||
if err != nil { | ||
return nil, trace.Wrap(err) | ||
} | ||
result := make([][]float32, len(input)) | ||
for i, item := range resp.Data { | ||
result[i] = item.Embedding | ||
} | ||
return result, nil | ||
} |
Oops, something went wrong.