Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,44 @@ import (

// An Embedder is used to convert a document to a
// multidimensional vector.
type Embedder core.Action[*EmbedRequest, []float32, struct{}]
type Embedder core.Action[*EmbedRequest, *EmbedResponse, struct{}]

// EmbedRequest is the data we pass to convert a document
// EmbedRequest is the data we pass to convert one or more documents
// to a multidimensional vector.
type EmbedRequest struct {
Document *Document `json:"input"`
Options any `json:"options,omitempty"`
Documents []*Document `json:"input"`
Options any `json:"options,omitempty"`
}

// EmbedResponse is the value an [Embedder] returns for an [EmbedRequest].
// It is serialized to JSON with its embeddings under the "embeddings" key.
type EmbedResponse struct {
// One embedding for each Document in the request, in the same order.
Embeddings []*DocumentEmbedding `json:"embeddings"`
}

// DocumentEmbedding holds embedding information about a single document.
type DocumentEmbedding struct {
// The vector for the embedding.
Embedding []float32 `json:"embedding"`
}

// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] that runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *Embedder {
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) (*EmbedResponse, error)) *Embedder {
return (*Embedder)(core.DefineAction(provider, name, atype.Embedder, nil, embed))
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(provider, name string) *Embedder {
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](atype.Embedder, provider, name)
action := core.LookupActionFor[*EmbedRequest, *EmbedResponse, struct{}](atype.Embedder, provider, name)
if action == nil {
return nil
}
return (*Embedder)(action)
}

// Embed runs the given [Embedder].
func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) {
a := (*core.Action[*EmbedRequest, []float32, struct{}])(e)
func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
a := (*core.Action[*EmbedRequest, *EmbedResponse, struct{}])(e)
return a.Run(ctx, req, nil)
}
14 changes: 9 additions & 5 deletions go/internal/fakeembedder/fakeembedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ func (e *Embedder) Register(d *ai.Document, vals []float32) {
e.registry[d] = vals
}

func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
vals, ok := e.registry[req.Document]
if !ok {
return nil, errors.New("fake embedder called with unregistered document")
func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
res := &ai.EmbedResponse{}
for _, doc := range req.Documents {
vals, ok := e.registry[doc]
if !ok {
return nil, errors.New("fake embedder called with unregistered document")
}
res.Embeddings = append(res.Embeddings, &ai.DocumentEmbedding{Embedding: vals})
}
return vals, nil
return res, nil
}
7 changes: 4 additions & 3 deletions go/internal/fakeembedder/fakeembedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,19 @@ func TestFakeEmbedder(t *testing.T) {
embed.Register(d, vals)

req := &ai.EmbedRequest{
Document: d,
Documents: []*ai.Document{d},
}
ctx := context.Background()
got, err := emb.Embed(ctx, req)
res, err := emb.Embed(ctx, req)
if err != nil {
t.Fatal(err)
}
got := res.Embeddings[0].Embedding
if !slices.Equal(got, vals) {
t.Errorf("lookup returned %v, want %v", got, vals)
}

req.Document = ai.DocumentFromText("missing document", nil)
req.Documents[0] = ai.DocumentFromText("missing document", nil)
if _, err = emb.Embed(ctx, req); err == nil {
t.Error("embedding unknown document succeeded unexpectedly")
}
Expand Down
23 changes: 14 additions & 9 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,22 @@ func DefineEmbedder(name string) *ai.Embedder {

// requires state.mu
func defineEmbedder(name string) *ai.Embedder {
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) {
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
// TODO: use the batch embedding API.
em := state.client.EmbeddingModel(name)
parts, err := convertParts(input.Document.Content)
if err != nil {
return nil, err
}
res, err := em.EmbedContent(ctx, parts...)
if err != nil {
return nil, err
var res ai.EmbedResponse
for _, doc := range input.Documents {
parts, err := convertParts(doc.Content)
if err != nil {
return nil, err
}
eres, err := em.EmbedContent(ctx, parts...)
if err != nil {
return nil, err
}
res.Embeddings = append(res.Embeddings, &ai.DocumentEmbedding{Embedding: eres.Embedding.Values})
}
return res.Embedding.Values, nil
return &res, nil
})
}

Expand Down
14 changes: 7 additions & 7 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ func TestLive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
embedder := googleai.DefineEmbedder("embedding-001")
model, err := googleai.DefineModel("gemini-1.0-pro", nil)
embedder := googleai.Embedder("embedding-001")
model := googleai.Model("gemini-1.0-pro")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -85,13 +85,13 @@ func TestLive(t *testing.T) {
},
)
t.Run("embedder", func(t *testing.T) {
out, err := embedder.Embed(ctx, &ai.EmbedRequest{
Document: ai.DocumentFromText("yellow banana", nil),
res, err := embedder.Embed(ctx, &ai.EmbedRequest{
Documents: []*ai.Document{ai.DocumentFromText("yellow banana", nil)},
})
if err != nil {
t.Fatal(err)
}

out := res.Embeddings[0].Embedding
// There's not a whole lot we can test about the result.
// Just do a few sanity checks.
if len(out) < 100 {
Expand Down Expand Up @@ -137,7 +137,7 @@ func TestLive(t *testing.T) {
Candidates: 1,
Messages: []*ai.Message{
{
Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")},
Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the North Pole.")},
Role: ai.RoleUser,
},
},
Expand All @@ -160,7 +160,7 @@ func TestLive(t *testing.T) {
if out != out2 {
t.Errorf("streaming and final should contain the same text.\nstreaming:%s\nfinal:%s", out, out2)
}
const want = "Golden"
const want = "North"
if !strings.Contains(out, want) {
t.Errorf("got %q, expecting it to contain %q", out, want)
}
Expand Down
33 changes: 16 additions & 17 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,19 @@ func newDocStore(dir, name string, embedder *ai.Embedder, embedderOptions any) (

// index indexes a document.
func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error {
for _, doc := range req.Documents {
ereq := &ai.EmbedRequest{
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("localvec index embedding failed: %v", err)
}

id, err := docID(doc)
ereq := &ai.EmbedRequest{
Documents: req.Documents,
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("localvec index embedding failed: %v", err)
}
for i, de := range eres.Embeddings {
id, err := docID(req.Documents[i])
if err != nil {
return err
}

if _, ok := ds.data[id]; ok {
logger.FromContext(ctx).Debug("localvec skipping document because already present", "id", id)
continue
Expand All @@ -144,8 +142,8 @@ func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error {
}

ds.data[id] = dbValue{
Doc: doc,
Embedding: vals,
Doc: req.Documents[i],
Embedding: de.Embedding,
}
}

Expand Down Expand Up @@ -183,13 +181,14 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Document: req.Document,
Options: ds.embedderOptions,
Documents: []*ai.Document{req.Document},
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
}
vals := eres.Embeddings[0].Embedding

type scoredDoc struct {
score float64
Expand Down
30 changes: 15 additions & 15 deletions go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,16 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {

// Use the embedder to convert each Document into a vector.
vecs := make([]vector, 0, len(req.Documents))
for _, doc := range req.Documents {
ereq := &ai.EmbedRequest{
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("pinecone index embedding failed: %v", err)
}

ereq := &ai.EmbedRequest{
Documents: req.Documents,
Options: ds.embedderOptions,
}
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("pinecone index embedding failed: %v", err)
}
for i, de := range eres.Embeddings {
doc := req.Documents[i]
id, err := docID(doc)
if err != nil {
return err
Expand All @@ -216,7 +216,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {

v := vector{
ID: id,
Values: vals,
Values: de.Embedding,
Metadata: metadata,
}
vecs = append(vecs, v)
Expand Down Expand Up @@ -282,15 +282,15 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Document: req.Document,
Options: ds.embedderOptions,
Documents: []*ai.Document{req.Document},
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
eres, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err)
}

results, err := ds.index.query(ctx, vals, count, wantMetadata, namespace)
results, err := ds.index.query(ctx, eres.Embeddings[0].Embedding, count, wantMetadata, namespace)
if err != nil {
return nil, err
}
Expand Down
45 changes: 30 additions & 15 deletions go/plugins/vertexai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ package vertexai

import (
"context"
"errors"
"fmt"
"strings"

aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
Expand All @@ -34,7 +35,7 @@ type EmbedOptions struct {
TaskType string `json:"task_type,omitempty"`
}

func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) ([]float32, error) {
func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
preq, err := newPredictRequest(reqEndpoint, req)
if err != nil {
return nil, err
Expand All @@ -44,32 +45,34 @@ func embed(ctx context.Context, reqEndpoint string, client *aiplatform.Predictio
return nil, err
}

// TODO(ianlancetaylor): This can return multiple vectors.
// We just use the first one for now.

if len(resp.Predictions) < 1 {
return nil, errors.New("vertexai: embed request returned no values")
if g, w := len(resp.Predictions), len(req.Documents); g != w {
return nil, fmt.Errorf("vertexai: got %d embeddings, expected %d", g, w)
}

values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
ret := make([]float32, len(values))
for i, value := range values {
ret[i] = float32(value.GetNumberValue())
ret := &ai.EmbedResponse{}
for _, pred := range resp.Predictions {
values := pred.GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
vals := make([]float32, len(values))
for i, value := range values {
vals[i] = float32(value.GetNumberValue())
}
ret.Embeddings = append(ret.Embeddings, &ai.DocumentEmbedding{Embedding: vals})
}

return ret, nil
}

// newPredictRequest creates a PredictRequest from an EmbedRequest.
// Each Document in the EmbedRequest becomes a separate instance in the PredictRequest.
func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.PredictRequest, error) {
var title, taskType string
if options, _ := req.Options.(*EmbedOptions); options != nil {
title = options.Title
taskType = options.TaskType
}
instances := make([]*structpb.Value, 0, len(req.Document.Content))
for _, part := range req.Document.Content {
instances := make([]*structpb.Value, 0, len(req.Documents))
for _, doc := range req.Documents {
fields := map[string]any{
"content": part.Text,
"content": text(doc),
}
if title != "" {
fields["title"] = title
Expand All @@ -90,3 +93,15 @@ func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.Pre
Instances: instances,
}, nil
}

// text returns the textual content of d as a single string.
// It walks d.Content in order, appends the Text of every part for which
// IsText reports true, and skips all other parts. Nothing is inserted
// between consecutive text parts.
func text(d *ai.Document) string {
	var sb strings.Builder
	for _, part := range d.Content {
		if !part.IsText() {
			// Non-text parts contribute nothing to the result.
			continue
		}
		sb.WriteString(part.Text)
	}
	return sb.String()
}
2 changes: 1 addition & 1 deletion go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func DefineEmbedder(name string) *ai.Embedder {
panic("vertexai.Init not called")
}
fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name)
return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
return embed(ctx, fullName, state.pclient, req)
})
}
Expand Down
Loading