Skip to content

Commit

Permalink
ai: add embeddings basic support
Browse files Browse the repository at this point in the history
- add Embeddings service and its local implementation
- add Embedding type and proto message
- add nodeEmbeddingCollector tracking nodes
- add NodeEmbeddingWatcher watching for events and sending them to the
  collector
- add the Embedder interface and its openai implementation
  • Loading branch information
hugoShaka committed Jun 7, 2023
1 parent d670eb9 commit 7afeb85
Show file tree
Hide file tree
Showing 9 changed files with 1,013 additions and 0 deletions.
208 changes: 208 additions & 0 deletions api/gen/proto/go/teleport/embedding/v1/embedding.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 40 additions & 0 deletions api/proto/teleport/embedding/v1/embedding.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package teleport.embedding.v1;

import "teleport/legacy/types/types.proto";

option go_package = "github.com/gravitational/teleport/api/gen/proto/go/teleport/embedding/v1;embeddingv1";

// Embedding contains a Teleport resource embedding. Embeddings are small
// semantic representations of larger and more complex data. Embeddings can be
// compared: the smaller the distance between two vectors, the closer the
// concepts they represent. Teleport Assist embeds resources to perform
// semantic search.
//
// Field numbers 3 and 4 are not used by this message.
// NOTE(review): presumably left free on purpose — confirm before assigning them.
message Embedding {
// Metadata is the embedding metadata.
types.Metadata metadata = 1;

// Version is the embedding resource version.
string version = 2;

// EmbeddedHash is the hash of the embedded resource after serialization.
// Comparing it with the hash of the current resource tells whether the
// resource has changed and needs a new embedding.
bytes embedded_hash = 5;

// Vector is the embedding itself, as provided by the model. Each component
// is a 32-bit float.
repeated float vector = 6;
}
115 changes: 115 additions & 0 deletions lib/ai/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package ai

import (
"context"
"strings"

"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"

embeddingpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/embedding/v1"
"github.com/gravitational/teleport/api/types"
)

// maxOpenAIEmbeddingPerRequest is the maximum number of input strings sent in
// a single embedding request. Larger inputs are split into pages of at most
// this size.
// NOTE(review): presumably chosen to stay under an OpenAI API limit — confirm.
const maxOpenAIEmbeddingPerRequest = 1000

// Embedding contains a Teleport resource embedding. Embeddings are small
// semantic representations of larger and more complex data. Embeddings can be
// compared: the smaller the distance between two vectors, the closer the
// concepts they represent. Teleport Assist embeds resources to perform
// semantic search.
//
// The Embedding is named after the embedded resource kind and ID, joined by a
// "/". For example the SSH node "bastion-01" has the embedding
// "node/bastion-01".
type Embedding struct {
// Embedding is the underlying protobuf message; its fields and getters are
// promoted onto this wrapper.
*embeddingpb.Embedding
}

// GetEmbeddedKind returns the kind of the resource that was embedded, i.e. the
// part of the embedding name that precedes the first "/". When the name
// contains no "/", the whole name is returned.
func (e Embedding) GetEmbeddedKind() string {
	name := e.GetName()
	if sep := strings.Index(name, "/"); sep >= 0 {
		return name[:sep]
	}
	return name
}

// GetName returns the name stored in the embedding metadata. The name is made
// of the embedded resource kind and the embedded resource ID.
func (e Embedding) GetName() string {
	metadata := e.GetMetadata()
	return metadata.GetName()
}

// GetEmbeddedID returns the ID of the resource that was embedded, i.e. the
// part of the embedding name that follows the first "/".
//
// Only the first "/" is treated as the kind/ID separator, so an ID that itself
// contains "/" is returned whole. If the name contains no "/" (malformed or
// empty name), an empty string is returned instead of panicking.
func (e Embedding) GetEmbeddedID() string {
	parts := strings.SplitN(e.GetName(), "/", 2)
	if len(parts) < 2 {
		return ""
	}
	return parts[1]
}

// NewEmbedding builds the Embedding of the resource identified by kind and id.
// The embedding is named "<kind>/<id>", carries version "1", and stores the
// given vector together with the hash of the embedded resource.
func NewEmbedding(kind, id string, vector []float32, hash [32]byte) Embedding {
	name := kind + "/" + id
	pb := &embeddingpb.Embedding{
		Version:      "1",
		Vector:       vector,
		EmbeddedHash: hash[:],
		Metadata:     &types.Metadata{Name: name},
	}
	return Embedding{Embedding: pb}
}

// Embedder is implemented by anything able to embed text in batches. The
// embedding can be computed in place (with a local embedding model for
// example) or be delegated to a remote embedding service like OpenAI.
type Embedder interface {
// ComputeEmbeddings computes the embeddings of multiple strings.
// The returned list follows the input order (i.e. result[i] is the
// embedding of input[i]).
ComputeEmbeddings(ctx context.Context, input []string) ([][]float32, error)
}

// ComputeEmbeddings takes a list of strings and calls OpenAI to generate one
// embedding per string. The input is split into pages of at most
// maxOpenAIEmbeddingPerRequest strings, each page being sent in its own
// request.
//
// The returned list follows the input order: result[i] is the embedding of
// input[i]. To keep that guarantee, the first failing page aborts the whole
// computation and no partial result is returned; skipping a failed page would
// shift every following embedding onto the wrong input.
//
// An empty input yields a nil result and no error.
//
// TODO: no retry mechanism is implemented yet; a transient failure of any page
// fails the whole call.
func (client *Client) ComputeEmbeddings(ctx context.Context, input []string) ([][]float32, error) {
	if len(input) == 0 {
		return nil, nil
	}

	results := make([][]float32, 0, len(input))
	for i := 0; maxOpenAIEmbeddingPerRequest*i < len(input); i++ {
		page := paginateInput(input, i, maxOpenAIEmbeddingPerRequest)
		result, err := client.computeEmbeddings(ctx, page)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		results = append(results, result...)
	}
	return results, nil
}

// paginateInput returns the page-th chunk (zero-based) of input when input is
// cut into chunks of pageSize elements. The last page may be shorter than
// pageSize. The returned slice shares its backing array with input.
//
// A nil slice is returned when the requested page does not exist: negative
// page, non-positive pageSize, or a page starting at or past the end of input.
func paginateInput(input []string, page, pageSize int) []string {
	if page < 0 || pageSize <= 0 {
		return nil
	}

	begin := page * pageSize
	// begin < 0 guards against integer overflow of the multiplication.
	if begin < 0 || begin >= len(input) {
		return nil
	}

	end := begin + pageSize
	// end < begin guards against integer overflow of the addition.
	if end > len(input) || end < begin {
		end = len(input)
	}
	return input[begin:end]
}

// computeEmbeddings embeds a single page of strings with one OpenAI request,
// using the openai.AdaEmbeddingV2 model.
//
// input must not hold more than maxOpenAIEmbeddingPerRequest strings, otherwise
// a BadParameter error is returned without calling the API. On success the
// result has exactly len(input) entries, result[i] being the i-th embedding of
// the response. An error is returned when the API call fails or when the
// response does not contain exactly one embedding per input string.
func (client *Client) computeEmbeddings(ctx context.Context, input []string) ([][]float32, error) {
	if len(input) > maxOpenAIEmbeddingPerRequest {
		return nil, trace.BadParameter(
			"too many strings to embed (%d), maximum is %d",
			len(input), maxOpenAIEmbeddingPerRequest,
		)
	}
	req := openai.EmbeddingRequest{
		Input: input,
		Model: openai.AdaEmbeddingV2,
	}

	// TODO: measure if this is flaky and if we need to implement a retry mechanism
	// Execute the query
	resp, err := client.svc.CreateEmbeddings(ctx, req)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// A longer response would index past the end of the result, and a shorter
	// one would silently leave nil embeddings: reject both.
	if len(resp.Data) != len(input) {
		return nil, trace.Errorf(
			"expected %d embeddings in the response, got %d",
			len(input), len(resp.Data),
		)
	}

	result := make([][]float32, len(input))
	for i, item := range resp.Data {
		result[i] = item.Embedding
	}
	return result, nil
}

0 comments on commit 7afeb85

Please sign in to comment.