From 6ab8141d79189a0c3d34b4cf9662d4c56f84a7ef Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Jul 2024 15:39:36 -0400 Subject: [PATCH] [Go] align Embedder API with JS - An EmbedRequest takes a slice of Documents instead of a single Document. - An EmbedResponse contains embeddings for each document. The []float32 containing the embedding is inside a struct, to accommodate future additions (and to match the JS). - The googleai embedder works on multiple documents sequentially. It should be changed to use the BatchEmbed RPC. - The vertexai embedder always handled multiple "instances". Now an instance is the concatenated text parts of a document; before it was one text part of the sole document. (This is the only behavioral change.) There is one unrelated change: the prompt of a generation test was changed because the previous prompt is now blocked for the "recitation" reason. --- go/ai/embedder.go | 27 +++++++---- go/internal/fakeembedder/fakeembedder.go | 14 +++--- go/internal/fakeembedder/fakeembedder_test.go | 7 +-- go/plugins/googleai/googleai.go | 23 ++++++---- go/plugins/googleai/googleai_test.go | 14 +++--- go/plugins/localvec/localvec.go | 33 +++++++------- go/plugins/pinecone/genkit.go | 30 ++++++------- go/plugins/vertexai/embed.go | 45 ++++++++++++------- go/plugins/vertexai/vertexai.go | 2 +- go/plugins/vertexai/vertexai_test.go | 36 ++++++++------- 10 files changed, 134 insertions(+), 97 deletions(-) diff --git a/go/ai/embedder.go b/go/ai/embedder.go index df612bd58b..daaa1a5e59 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -23,25 +23,36 @@ import ( // An Embedder is used to convert a document to a // multidimensional vector. -type Embedder core.Action[*EmbedRequest, []float32, struct{}] +type Embedder core.Action[*EmbedRequest, *EmbedResponse, struct{}] -// EmbedRequest is the data we pass to convert a document +// EmbedRequest is the data we pass to convert one or more documents // to a multidimensional vector. 
type EmbedRequest struct { - Document *Document `json:"input"` - Options any `json:"options,omitempty"` + Documents []*Document `json:"input"` + Options any `json:"options,omitempty"` +} + +type EmbedResponse struct { + // One embedding for each Document in the request, in the same order. + Embeddings []*DocumentEmbedding `json:"embeddings"` +} + +// DocumentEmbedding holds embedding information about a single document. +type DocumentEmbedding struct { + // The vector for the embedding. + Embedding []float32 `json:"embedding"` } // DefineEmbedder registers the given embed function as an action, and returns an // [EmbedderAction] that runs it. -func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *Embedder { +func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) (*EmbedResponse, error)) *Embedder { return (*Embedder)(core.DefineAction(provider, name, atype.Embedder, nil, embed)) } // LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder]. // It returns nil if the embedder was not defined. func LookupEmbedder(provider, name string) *Embedder { - action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](atype.Embedder, provider, name) + action := core.LookupActionFor[*EmbedRequest, *EmbedResponse, struct{}](atype.Embedder, provider, name) if action == nil { return nil } @@ -49,7 +60,7 @@ func LookupEmbedder(provider, name string) *Embedder { } // Embed runs the given [Embedder]. 
-func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) { - a := (*core.Action[*EmbedRequest, []float32, struct{}])(e) +func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + a := (*core.Action[*EmbedRequest, *EmbedResponse, struct{}])(e) return a.Run(ctx, req, nil) } diff --git a/go/internal/fakeembedder/fakeembedder.go b/go/internal/fakeembedder/fakeembedder.go index 083a5f8874..11fd2e259b 100644 --- a/go/internal/fakeembedder/fakeembedder.go +++ b/go/internal/fakeembedder/fakeembedder.go @@ -43,10 +43,14 @@ func (e *Embedder) Register(d *ai.Document, vals []float32) { e.registry[d] = vals } -func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) { - vals, ok := e.registry[req.Document] - if !ok { - return nil, errors.New("fake embedder called with unregistered document") +func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { + res := &ai.EmbedResponse{} + for _, doc := range req.Documents { + vals, ok := e.registry[doc] + if !ok { + return nil, errors.New("fake embedder called with unregistered document") + } + res.Embeddings = append(res.Embeddings, &ai.DocumentEmbedding{Embedding: vals}) } - return vals, nil + return res, nil } diff --git a/go/internal/fakeembedder/fakeembedder_test.go b/go/internal/fakeembedder/fakeembedder_test.go index b157463dd6..77b5ca42b2 100644 --- a/go/internal/fakeembedder/fakeembedder_test.go +++ b/go/internal/fakeembedder/fakeembedder_test.go @@ -31,18 +31,19 @@ func TestFakeEmbedder(t *testing.T) { embed.Register(d, vals) req := &ai.EmbedRequest{ - Document: d, + Documents: []*ai.Document{d}, } ctx := context.Background() - got, err := emb.Embed(ctx, req) + res, err := emb.Embed(ctx, req) if err != nil { t.Fatal(err) } + got := res.Embeddings[0].Embedding if !slices.Equal(got, vals) { t.Errorf("lookup returned %v, want %v", got, vals) } - req.Document = ai.DocumentFromText("missing document", 
nil) + req.Documents[0] = ai.DocumentFromText("missing document", nil) if _, err = emb.Embed(ctx, req); err == nil { t.Error("embedding unknown document succeeded unexpectedly") } diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index e257cefd5a..aabcf90805 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -162,17 +162,22 @@ func DefineEmbedder(name string) *ai.Embedder { // requires state.mu func defineEmbedder(name string) *ai.Embedder { - return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) { + return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) { + // TODO: use the batch embedding API. em := state.client.EmbeddingModel(name) - parts, err := convertParts(input.Document.Content) - if err != nil { - return nil, err - } - res, err := em.EmbedContent(ctx, parts...) - if err != nil { - return nil, err + var res ai.EmbedResponse + for _, doc := range input.Documents { + parts, err := convertParts(doc.Content) + if err != nil { + return nil, err + } + eres, err := em.EmbedContent(ctx, parts...) 
+ if err != nil { + return nil, err + } + res.Embeddings = append(res.Embeddings, &ai.DocumentEmbedding{Embedding: eres.Embedding.Values}) } - return res.Embedding.Values, nil + return &res, nil }) } diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 41bd03ed29..b2656d1cb8 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -46,8 +46,8 @@ func TestLive(t *testing.T) { if err != nil { t.Fatal(err) } - embedder := googleai.DefineEmbedder("embedding-001") - model, err := googleai.DefineModel("gemini-1.0-pro", nil) + embedder := googleai.Embedder("embedding-001") + model := googleai.Model("gemini-1.0-pro") if err != nil { t.Fatal(err) } @@ -85,13 +85,13 @@ func TestLive(t *testing.T) { }, ) t.Run("embedder", func(t *testing.T) { - out, err := embedder.Embed(ctx, &ai.EmbedRequest{ - Document: ai.DocumentFromText("yellow banana", nil), + res, err := embedder.Embed(ctx, &ai.EmbedRequest{ + Documents: []*ai.Document{ai.DocumentFromText("yellow banana", nil)}, }) if err != nil { t.Fatal(err) } - + out := res.Embeddings[0].Embedding // There's not a whole lot we can test about the result. // Just do a few sanity checks. 
if len(out) < 100 { @@ -137,7 +137,7 @@ func TestLive(t *testing.T) { Candidates: 1, Messages: []*ai.Message{ { - Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")}, + Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the North Pole.")}, Role: ai.RoleUser, }, }, @@ -160,7 +160,7 @@ func TestLive(t *testing.T) { if out != out2 { t.Errorf("streaming and final should contain the same text.\nstreaming:%s\nfinal:%s", out, out2) } - const want = "Golden" + const want = "North" if !strings.Contains(out, want) { t.Errorf("got %q, expecting it to contain %q", out, want) } diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index ded43584f6..c661538b41 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -119,21 +119,19 @@ func newDocStore(dir, name string, embedder *ai.Embedder, embedderOptions any) ( // index indexes a document. func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error { - for _, doc := range req.Documents { - ereq := &ai.EmbedRequest{ - Document: doc, - Options: ds.embedderOptions, - } - vals, err := ds.embedder.Embed(ctx, ereq) - if err != nil { - return fmt.Errorf("localvec index embedding failed: %v", err) - } - - id, err := docID(doc) + ereq := &ai.EmbedRequest{ + Documents: req.Documents, + Options: ds.embedderOptions, + } + eres, err := ds.embedder.Embed(ctx, ereq) + if err != nil { + return fmt.Errorf("localvec index embedding failed: %v", err) + } + for i, de := range eres.Embeddings { + id, err := docID(req.Documents[i]) if err != nil { return err } - if _, ok := ds.data[id]; ok { logger.FromContext(ctx).Debug("localvec skipping document because already present", "id", id) continue @@ -144,8 +142,8 @@ func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error { } ds.data[id] = dbValue{ - Doc: doc, - Embedding: vals, + Doc: req.Documents[i], + Embedding: de.Embedding, } } @@ -183,13 +181,14 @@ func 
(ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai // Use the embedder to convert the document we want to // retrieve into a vector. ereq := &ai.EmbedRequest{ - Document: req.Document, - Options: ds.embedderOptions, + Documents: []*ai.Document{req.Document}, + Options: ds.embedderOptions, } - vals, err := ds.embedder.Embed(ctx, ereq) + eres, err := ds.embedder.Embed(ctx, ereq) if err != nil { return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err) } + vals := eres.Embeddings[0].Embedding type scoredDoc struct { score float64 diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index ada31bbdd7..496fc8ab26 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -185,16 +185,16 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { // Use the embedder to convert each Document into a vector. vecs := make([]vector, 0, len(req.Documents)) - for _, doc := range req.Documents { - ereq := &ai.EmbedRequest{ - Document: doc, - Options: ds.embedderOptions, - } - vals, err := ds.embedder.Embed(ctx, ereq) - if err != nil { - return fmt.Errorf("pinecone index embedding failed: %v", err) - } - + ereq := &ai.EmbedRequest{ + Documents: req.Documents, + Options: ds.embedderOptions, + } + eres, err := ds.embedder.Embed(ctx, ereq) + if err != nil { + return fmt.Errorf("pinecone index embedding failed: %v", err) + } + for i, de := range eres.Embeddings { + doc := req.Documents[i] id, err := docID(doc) if err != nil { return err @@ -216,7 +216,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { v := vector{ ID: id, - Values: vals, + Values: de.Embedding, Metadata: metadata, } vecs = append(vecs, v) @@ -282,15 +282,15 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai // Use the embedder to convert the document we want to // retrieve into a vector. 
ereq := &ai.EmbedRequest{ - Document: req.Document, - Options: ds.embedderOptions, + Documents: []*ai.Document{req.Document}, + Options: ds.embedderOptions, } - vals, err := ds.embedder.Embed(ctx, ereq) + eres, err := ds.embedder.Embed(ctx, ereq) if err != nil { return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err) } - results, err := ds.index.query(ctx, vals, count, wantMetadata, namespace) + results, err := ds.index.query(ctx, eres.Embeddings[0].Embedding, count, wantMetadata, namespace) if err != nil { return nil, err } diff --git a/go/plugins/vertexai/embed.go b/go/plugins/vertexai/embed.go index 53d04ec0be..ae5bd5bedf 100644 --- a/go/plugins/vertexai/embed.go +++ b/go/plugins/vertexai/embed.go @@ -16,7 +16,8 @@ package vertexai import ( "context" - "errors" + "fmt" + "strings" aiplatform "cloud.google.com/go/aiplatform/apiv1" "cloud.google.com/go/aiplatform/apiv1/aiplatformpb" @@ -34,7 +35,7 @@ type EmbedOptions struct { TaskType string `json:"task_type,omitempty"` } -func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) ([]float32, error) { +func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { preq, err := newPredictRequest(reqEndpoint, req) if err != nil { return nil, err @@ -44,32 +45,34 @@ func embed(ctx context.Context, reqEndpoint string, client *aiplatform.Predictio return nil, err } - // TODO(ianlancetaylor): This can return multiple vectors. - // We just use the first one for now. 
- - if len(resp.Predictions) < 1 { - return nil, errors.New("vertexai: embed request returned no values") + if g, w := len(resp.Predictions), len(req.Documents); g != w { + return nil, fmt.Errorf("vertexai: got %d embeddings, expected %d", g, w) } - values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values - ret := make([]float32, len(values)) - for i, value := range values { - ret[i] = float32(value.GetNumberValue()) + ret := &ai.EmbedResponse{} + for _, pred := range resp.Predictions { + values := pred.GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values + vals := make([]float32, len(values)) + for i, value := range values { + vals[i] = float32(value.GetNumberValue()) + } + ret.Embeddings = append(ret.Embeddings, &ai.DocumentEmbedding{Embedding: vals}) } - return ret, nil } +// newPredictRequest creates a PredictRequest from an EmbedRequest. +// Each Document in the EmbedRequest becomes a separate instance in the PredictRequest. func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.PredictRequest, error) { var title, taskType string if options, _ := req.Options.(*EmbedOptions); options != nil { title = options.Title taskType = options.TaskType } - instances := make([]*structpb.Value, 0, len(req.Document.Content)) - for _, part := range req.Document.Content { + instances := make([]*structpb.Value, 0, len(req.Documents)) + for _, doc := range req.Documents { fields := map[string]any{ - "content": part.Text, + "content": text(doc), } if title != "" { fields["title"] = title @@ -90,3 +93,15 @@ func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.Pre Instances: instances, }, nil } + +// text concatenates all the text parts of the document together, +// with no delimiter. 
+func text(d *ai.Document) string { + var b strings.Builder + for _, p := range d.Content { + if p.IsText() { + b.WriteString(p.Text) + } + } + return b.String() +} diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 4ec920488e..2b5b6f6058 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -177,7 +177,7 @@ func DefineEmbedder(name string) *ai.Embedder { panic("vertexai.Init not called") } fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name) - return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) { + return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { return embed(ctx, fullName, state.pclient, req) }) } diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index 9ade6d89ea..b8675d647c 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -39,16 +39,12 @@ func TestLive(t *testing.T) { } ctx := context.Background() const modelName = "gemini-1.0-pro" - const embedderName = "textembedding-gecko" err := vertexai.Init(ctx, *projectID, *location) if err != nil { t.Fatal(err) } - model, err := vertexai.DefineModel(modelName, nil) - if err != nil { - t.Fatal(err) - } - embedder := vertexai.DefineEmbedder(embedderName) + model := vertexai.Model(modelName) + embedder := vertexai.Embedder("textembedding-gecko@003") toolDef := &ai.ToolDefinition{ Name: "exponentiation", @@ -176,8 +172,11 @@ func TestLive(t *testing.T) { } }) t.Run("embedder", func(t *testing.T) { - out, err := embedder.Embed(ctx, &ai.EmbedRequest{ - Document: ai.DocumentFromText("time flies like an arrow", nil), + res, err := embedder.Embed(ctx, &ai.EmbedRequest{ + Documents: []*ai.Document{ + ai.DocumentFromText("time flies like an arrow", nil), + ai.DocumentFromText("fruit flies like a 
banana", nil), + }, }) if err != nil { t.Fatal(err) @@ -185,15 +184,18 @@ func TestLive(t *testing.T) { // There's not a whole lot we can test about the result. // Just do a few sanity checks. - if len(out) < 100 { - t.Errorf("embedding vector looks too short: len(out)=%d", len(out)) - } - var normSquared float32 - for _, x := range out { - normSquared += x * x - } - if normSquared < 0.9 || normSquared > 1.1 { - t.Errorf("embedding vector not unit length: %f", normSquared) + for _, de := range res.Embeddings { + out := de.Embedding + if len(out) < 100 { + t.Errorf("embedding vector looks too short: len(out)=%d", len(out)) + } + var normSquared float32 + for _, x := range out { + normSquared += x * x + } + if normSquared < 0.9 || normSquared > 1.1 { + t.Errorf("embedding vector not unit length: %f", normSquared) + } } }) }