From bcdf8b122f92154f6ba0a5f7d3e8c59d08c169a8 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 6 Jun 2024 15:34:21 -0400 Subject: [PATCH 1/6] [Go] cross-language test: googleai embedder Add a test designed to be run in different languages. This uses a cross-language test framework I'm developing, github.com/jba/xltest. At its heart is the file testdata/googleai-embedder.yaml, which defines the test in a language-neutral way. The test refers to a json file with the expected trace. The code in tests/googleai_test.go reads the test file and runs the test. The main challenge is making two traces comparable, which requires canonicalizing their span IDs. To make sure we are testing the complete logic of an embedder action, we pass a mock HTTP client to the googleai plugin, so we control the interaction with the service. Capturing traces cleanly also required some changes to the tracing system, to allow installing a trace store temporarily and then removing it. --- go/core/logger/logger.go | 2 +- go/core/registry.go | 20 +- go/core/tracing/tracing.go | 9 +- go/go.mod | 6 +- go/go.sum | 9 +- go/plugins/googleai/googleai.go | 10 +- go/tests/googleai_test.go | 176 ++++++++++++++++++ go/tests/mock_round_tripper.go | 48 +++++ .../testdata/googleai-embedder-trace.json | 36 ++++ go/tests/testdata/googleai-embedder.yaml | 22 +++ 10 files changed, 325 insertions(+), 13 deletions(-) create mode 100644 go/tests/googleai_test.go create mode 100644 go/tests/mock_round_tripper.go create mode 100644 go/tests/testdata/googleai-embedder-trace.json create mode 100644 go/tests/testdata/googleai-embedder.yaml diff --git a/go/core/logger/logger.go b/go/core/logger/logger.go index 22e8fb7bca..d99c1c4c15 100644 --- a/go/core/logger/logger.go +++ b/go/core/logger/logger.go @@ -27,7 +27,7 @@ func init() { // TODO: Remove this. The main program should be responsible for configuring logging. // This is just a convenience during development. 
h := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: slog.LevelDebug, + Level: slog.LevelInfo, })) slog.SetDefault(h) } diff --git a/go/core/registry.go b/go/core/registry.go index de5b278643..5a7a92ac6c 100644 --- a/go/core/registry.go +++ b/go/core/registry.go @@ -66,7 +66,7 @@ func newRegistry() (*registry, error) { if err != nil { return nil, err } - r.registerTraceStore(EnvironmentDev, tstore) + r.setTraceStore(EnvironmentDev, tstore) r.tstate = tracing.NewState() r.tstate.AddTraceStoreImmediate(tstore) return r, nil @@ -178,6 +178,24 @@ func (r *registry) registerTraceStore(env Environment, ts tracing.Store) { r.traceStores[env] = ts } +// SetDevTraceStore establishes its argument as the [tracing.Store] for the dev environment. +// It is intended for testing. +// Call the returned function to remove the store when the test ends. +func SetDevTraceStore(ts tracing.Store) (remove func()) { + globalRegistry.setTraceStore(EnvironmentDev, ts) + rem := globalRegistry.tstate.AddTraceStoreImmediate(ts) + return func() { + rem() + globalRegistry.setTraceStore(EnvironmentDev, nil) + } +} + +func (r *registry) setTraceStore(env Environment, ts tracing.Store) { + r.mu.Lock() + defer r.mu.Unlock() + r.traceStores[env] = ts +} + func (r *registry) lookupTraceStore(env Environment) tracing.Store { r.mu.Lock() defer r.mu.Unlock() diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index 4464bdb1f9..853567957a 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -49,13 +49,18 @@ func (ts *State) RegisterSpanProcessor(sp sdktrace.SpanProcessor) { // Traces are saved immediately as they are finshed. // Use this for a gtrace.Store with a fast Save method, // such as one that writes to a file. -func (ts *State) AddTraceStoreImmediate(tstore Store) { +// The return value is a function that will remove the trace store when called. 
+func (ts *State) AddTraceStoreImmediate(tstore Store) (remove func()) { e := newTraceStoreExporter(tstore) // Adding a SimpleSpanProcessor is like using the WithSyncer option. - ts.RegisterSpanProcessor(sdktrace.NewSimpleSpanProcessor(e)) + spp := sdktrace.NewSimpleSpanProcessor(e) + ts.RegisterSpanProcessor(spp) // Ignore tracerProvider.Shutdown. It shouldn't be needed when using WithSyncer. // Confirmed for OTel packages as of v1.24.0. // Also requires traceStoreExporter.Shutdown to be a no-op. + return func() { + ts.tp.UnregisterSpanProcessor(spp) + } } // AddTraceStoreBatch adds ts to the tracingState. diff --git a/go/go.mod b/go/go.mod index 60f1a33dcb..540195c13d 100644 --- a/go/go.mod +++ b/go/go.mod @@ -14,6 +14,7 @@ require ( github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.12.0 github.com/jba/slog v0.2.0 + github.com/jba/xltest/go/xltest v0.1.1 github.com/wk8/go-ordered-map/v2 v2.1.8 github.com/xeipuuv/gojsonschema v1.2.0 go.opentelemetry.io/otel v1.26.0 @@ -29,7 +30,7 @@ require ( require ( cloud.google.com/go v0.114.0 // indirect - cloud.google.com/go/ai v0.5.0 // indirect + cloud.google.com/go/ai v0.6.0 // indirect cloud.google.com/go/auth v0.5.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect @@ -48,7 +49,6 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect - github.com/kr/text v0.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect @@ -68,3 +68,5 @@ require ( google.golang.org/grpc v1.64.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) + +replace github.com/google/generative-ai-go => /usr/local/google/home/jba/repos/github.com/google/generative-ai-go diff --git 
a/go/go.sum b/go/go.sum index ea68f9f6e6..9592ecf008 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,8 +1,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.114.0 h1:OIPFAdfrFDFO2ve2U7r/H5SwSbBzEdrBdE7xkgwc+kY= cloud.google.com/go v0.114.0/go.mod h1:ZV9La5YYxctro1HTPug5lXH/GefROyW8PPD4T8n9J8E= -cloud.google.com/go/ai v0.5.0 h1:x8s4rDn5t9OVZvBCgtr5bZTH5X0O7JdE6zYo+O+MpRw= -cloud.google.com/go/ai v0.5.0/go.mod h1:96VBphk70e0zdXZrbtgPuKYRZsQ3UktSUXhuojwiKA8= +cloud.google.com/go/ai v0.6.0 h1:QWjb2UoaM15e51IMeLuIUFyWxooKOKDb66Mk47zZ2/g= +cloud.google.com/go/ai v0.6.0/go.mod h1:6/mrRq6aJdK7MZH76ZvcMpESiAiha5aRvurmroiOrgI= cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rBPcG8U= cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME= cloud.google.com/go/auth v0.5.1 h1:0QNO7VThG54LUzKiQxv8C6x1YX7lUrzlAa1nVLF8CIw= @@ -41,7 +41,6 @@ github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx2 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -72,8 +71,6 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= 
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add h1:ppPgQNwv1OidlzYyoeN6AEfcPJ5f2cO0hK2UzKDnoXc= -github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add/go.mod h1:Pmy+JWGfZt1kjjKPpufz2uunTIOy+dhWA3aOIC7ub3Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -95,6 +92,8 @@ github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10C github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jba/slog v0.2.0 h1:jI0U5NRR3EJKGsbeEVpItJNogk0c4RMeCl7vJmogCJI= github.com/jba/slog v0.2.0/go.mod h1:0Dh7Vyz3Td68Z1OwzadfincHwr7v+PpzadrS2Jua338= +github.com/jba/xltest/go/xltest v0.1.1 h1:HV6454xz+kYvycBw/CHgk4Baj4klInZ0jichK0CaNhs= +github.com/jba/xltest/go/xltest v0.1.1/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 9b8128b823..a48d0c5ac6 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -40,6 +40,8 @@ type Config struct { // Embedding models to provide. // If empty, a complete list will be obtained from the service. 
Embedders []string + + ClientOptions []option.ClientOption } func Init(ctx context.Context, cfg Config) (err error) { @@ -49,11 +51,15 @@ func Init(ctx context.Context, cfg Config) (err error) { } }() - if cfg.APIKey == "" { + if cfg.APIKey == "" && len(cfg.ClientOptions) == 0 { return errors.New("missing API key") } - client, err := genai.NewClient(ctx, option.WithAPIKey(cfg.APIKey)) + opts := cfg.ClientOptions + if len(opts) == 0 { + opts = []option.ClientOption{option.WithAPIKey(cfg.APIKey)} + } + client, err := genai.NewClient(ctx, opts...) if err != nil { return err } diff --git a/go/tests/googleai_test.go b/go/tests/googleai_test.go new file mode 100644 index 0000000000..06743eb07e --- /dev/null +++ b/go/tests/googleai_test.go @@ -0,0 +1,176 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tests + +import ( + "cmp" + "context" + "flag" + "fmt" + "net/http" + "path/filepath" + "slices" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/plugins/googleai" + gocmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/jba/xltest/go/xltest" + "google.golang.org/api/option" +) + +var update = flag.Bool("update", false, "update golden files with results") + +func TestGoogleAI(t *testing.T) { + ctx := context.Background() + err := googleai.Init(ctx, googleai.Config{ + ClientOptions: []option.ClientOption{option.WithHTTPClient(mockClient())}, + Models: []string{"gemini-1.5-pro"}, + Embedders: []string{"embedding-001"}, + }) + if err != nil { + t.Fatal(err) + } + + t.Run("embed", func(t *testing.T) { + test, err := xltest.ReadFile(filepath.Join("testdata", "googleai-embedder.yaml")) + if err != nil { + t.Fatal(err) + } + valfunc := validateEmbedder + if *update { + valfunc = updateEmbedder + } + test.Run(t, func(text string) (embedderResult, error) { + ts := &testTraceStore{} + remove := core.SetDevTraceStore(ts) + defer remove() + vals, err := ai.Embed(ctx, googleai.Embedder("embedding-001"), &ai.EmbedRequest{Document: ai.DocumentFromText(text, nil)}) + if err != nil { + return embedderResult{}, err + } + return embedderResult{vals, ts.td}, nil + }, valfunc) + }) +} + +type embedderResult struct { + values []float32 + trace *tracing.Data +} + +// validateEmbedder compares the results from running an embedder with +// the desired results. The latter are taken from a YAML file and so are in raw +// unmarshalled form. 
+func validateEmbedder(got embedderResult, rawWant map[string]any) error { + var want embedderResult + rawVals := rawWant["values"].([]any) + f32s := make([]float32, len(rawVals)) + for i, v := range rawVals { + f32s[i] = float32(v.(float64)) + } + want.values = f32s + traceFile := rawWant["traceFile"].(string) + if err := internal.ReadJSONFile(filepath.Join("testdata", traceFile), &want.trace); err != nil { + return err + } + return compareEmbedderResults(got, want) +} + +func compareEmbedderResults(got, want embedderResult) error { + if !slices.Equal(got.values, want.values) { + return fmt.Errorf("values: got %v, want %v", got.values, want.values) + } + renameSpans(got.trace) + renameSpans(want.trace) + opts := []gocmp.Option{ + cmpopts.IgnoreFields(tracing.Data{}, "TraceID", "StartTime", "EndTime"), + cmpopts.IgnoreFields(tracing.SpanData{}, "TraceID", "StartTime", "EndTime"), + } + if diff := gocmp.Diff(want.trace, got.trace, opts...); diff != "" { + return fmt.Errorf("traces: %s", diff) + } + return nil +} + +// renameSpans changes the keys of td.Spans to s0, s1, ... in order of the span start time, +// as well as references to those IDs within the spans. +// This makes it possible to compare two span maps with different span IDs. +func renameSpans(td *tracing.Data) { + type item struct { + id string + t tracing.Milliseconds + } + + var items []item + startTimes := map[tracing.Milliseconds]bool{} + for id, span := range td.Spans { + if startTimes[span.StartTime] { + panic("duplicate start times") + } + startTimes[span.StartTime] = true + items = append(items, item{id, span.StartTime}) + } + slices.SortFunc(items, func(i1, i2 item) int { + return cmp.Compare(i1.t, i2.t) + }) + oldIDToNew := map[string]string{} + for i, item := range items { + oldIDToNew[item.id] = fmt.Sprintf("s%03d", i) + } + // Re-create the span map with the new span IDs. 
+ m := map[string]*tracing.SpanData{} + for oldID, span := range td.Spans { + newID := oldIDToNew[oldID] + if newID == "" { + panic(fmt.Sprintf("missing id: %q", oldID)) + } + m[newID] = span + // A span references its own span ID and possibly its parent's. + span.SpanID = oldIDToNew[span.SpanID] + if span.ParentSpanID != "" { + span.ParentSpanID = oldIDToNew[span.ParentSpanID] + } + } + td.Spans = m +} + +func updateEmbedder(got embedderResult, rawWant map[string]any) error { + filename := rawWant["traceFile"].(string) + fmt.Printf("writing %s\n", filename) + return internal.WriteJSONFile(filename, got.trace) +} + +func mockClient() *http.Client { + mrt := &MockRoundTripper{} + mrt.Handle("POST /v1beta/models/embedding-001:embedContent", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, `{"embedding": {"values": [0.25, 0.25, 0.5]}}`) + })) + return &http.Client{Transport: mrt} +} + +type testTraceStore struct { + tracing.Store + td *tracing.Data +} + +func (ts *testTraceStore) Save(ctx context.Context, id string, td *tracing.Data) error { + ts.td = td + return nil +} diff --git a/go/tests/mock_round_tripper.go b/go/tests/mock_round_tripper.go new file mode 100644 index 0000000000..20c76b4227 --- /dev/null +++ b/go/tests/mock_round_tripper.go @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tests + +import ( + "errors" + "net/http" + "net/http/httptest" +) + +// MockRoundTripper is a RoundTripper whose responses can be set by the program. +// It is configured exactly like a ServeMux: by registering http.Handlers with +// patterns. When a request is received, the matching handler is run and +// its response is returned. +// +// This type is a convenience; the same result can be achieved by registering +// handlers with an httptest.Server and using its client. This avoids the +// (local) network round trip. +type MockRoundTripper struct { + mux http.ServeMux +} + +// Handle registers a handler with the MockRoundTripper associated with the given pattern. +func (rt *MockRoundTripper) Handle(pattern string, handler http.Handler) { + rt.mux.Handle(pattern, handler) +} + +func (rt *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + h, pat := rt.mux.Handler(req) + if pat == "" { + return nil, errors.New("no matching handler matches") + } + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + return w.Result(), nil +} diff --git a/go/tests/testdata/googleai-embedder-trace.json b/go/tests/testdata/googleai-embedder-trace.json new file mode 100644 index 0000000000..2a88ad8ad1 --- /dev/null +++ b/go/tests/testdata/googleai-embedder-trace.json @@ -0,0 +1,36 @@ +{ + "traceId": "", + "displayName": "embedding-001", + "startTime": 1718521125809.9194, + "endTime": 1718521125811.092, + "spans": { + "edb1aba2f4e66854": { + "spanId": "edb1aba2f4e66854", + "traceId": "5a59c7e04a2abbe463d4e617009f0779", + "startTime": 1718521125809.9194, + "endTime": 1718521125811.092, + "attributes": { + "genkit:input": "{\"input\":{\"content\":[{\"text\":\"banana\"}]}}", + "genkit:metadata:subtype": "embedder", + "genkit:name": "embedding-001", + "genkit:output": "[0.25,0.25,0.5]", + "genkit:path": "/embedding-001", + "genkit:state": "success", + "genkit:type": "action" + }, + "displayName": "embedding-001", + "instrumentationLibrary": { + "name": 
"genkit-tracer", + "version": "v1" + }, + "spanKind": "INTERNAL", + "sameProcessAsParentSpan": { + "value": true + }, + "status": { + "code": 0 + }, + "timeEvents": {} + } + } +} diff --git a/go/tests/testdata/googleai-embedder.yaml b/go/tests/testdata/googleai-embedder.yaml new file mode 100644 index 0000000000..3597063c1f --- /dev/null +++ b/go/tests/testdata/googleai-embedder.yaml @@ -0,0 +1,22 @@ +# This is a cross-language test for embedders in the googleai plugin. +# Its main purpose is to validate the trace resulting from +# calling an embedder. +# +# Test function: call a mocked version of the googleai embedder +# that returns the given output. +# +# Validation function: +# - values: compare using equality +# - traceFile: +# - read the file +# - canonicalize the span IDs +# - ignore or remove trace ids, start and end times +# - then compare for equality. + +name: googleai-embedder +description: googleai plugin, embedder +in: 'banana' +want: + values: [0.25, 0.25, 0.5] + traceFile: 'googleai-embedder-trace.json' + From 4433ff53e2c562ff6d1e07390d87d5e75ddd9d6c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 17 Jun 2024 02:57:12 -0400 Subject: [PATCH 2/6] reviewer comments --- go/plugins/googleai/googleai.go | 4 ++-- go/tests/googleai_test.go | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index a48d0c5ac6..7844063b37 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -32,7 +32,7 @@ const provider = "googleai" // Config provides configuration options for the Init function. type Config struct { - // API key. Required. + // API key. Required unless ClientOptions is set. APIKey string // Generative models to provide. // If empty, a complete list will be obtained from the service. @@ -40,7 +40,7 @@ type Config struct { // Embedding models to provide. // If empty, a complete list will be obtained from the service. 
Embedders []string - + // Options to pass to the client. If non-empty, APIKey is not used. ClientOptions []option.ClientOption } diff --git a/go/tests/googleai_test.go b/go/tests/googleai_test.go index 06743eb07e..65eb12ee8d 100644 --- a/go/tests/googleai_test.go +++ b/go/tests/googleai_test.go @@ -53,11 +53,7 @@ func TestGoogleAI(t *testing.T) { if err != nil { t.Fatal(err) } - valfunc := validateEmbedder - if *update { - valfunc = updateEmbedder - } - test.Run(t, func(text string) (embedderResult, error) { + testfunc := func(text string) (embedderResult, error) { ts := &testTraceStore{} remove := core.SetDevTraceStore(ts) defer remove() @@ -66,7 +62,12 @@ func TestGoogleAI(t *testing.T) { return embedderResult{}, err } return embedderResult{vals, ts.td}, nil - }, valfunc) + } + valfunc := validateEmbedder + if *update { + valfunc = updateEmbedder + } + test.Run(t, testfunc, valfunc) }) } @@ -109,7 +110,7 @@ func compareEmbedderResults(got, want embedderResult) error { return nil } -// renameSpans changes the keys of td.Spans to s0, s1, ... in order of the span start time, +// renameSpans changes the keys of td.Spans to s000, s001, ... in order of the span start time, // as well as references to those IDs within the spans. // This makes it possible to compare two span maps with different span IDs. 
func renameSpans(td *tracing.Data) { From d199ec7c9b34b23a7e3348b24195899c91067883 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 18 Jun 2024 03:13:23 -0400 Subject: [PATCH 3/6] use typed xltest --- go/go.mod | 2 +- go/go.sum | 4 ++++ go/tests/googleai_test.go | 44 ++++++++++++++------------------------- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/go/go.mod b/go/go.mod index 540195c13d..f5dead6f32 100644 --- a/go/go.mod +++ b/go/go.mod @@ -14,7 +14,7 @@ require ( github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.12.0 github.com/jba/slog v0.2.0 - github.com/jba/xltest/go/xltest v0.1.1 + github.com/jba/xltest/go/xltest v0.1.2 github.com/wk8/go-ordered-map/v2 v2.1.8 github.com/xeipuuv/gojsonschema v1.2.0 go.opentelemetry.io/otel v1.26.0 diff --git a/go/go.sum b/go/go.sum index 9592ecf008..19e81c44e7 100644 --- a/go/go.sum +++ b/go/go.sum @@ -94,6 +94,10 @@ github.com/jba/slog v0.2.0 h1:jI0U5NRR3EJKGsbeEVpItJNogk0c4RMeCl7vJmogCJI= github.com/jba/slog v0.2.0/go.mod h1:0Dh7Vyz3Td68Z1OwzadfincHwr7v+PpzadrS2Jua338= github.com/jba/xltest/go/xltest v0.1.1 h1:HV6454xz+kYvycBw/CHgk4Baj4klInZ0jichK0CaNhs= github.com/jba/xltest/go/xltest v0.1.1/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q= +github.com/jba/xltest/go/xltest v0.1.2-0.20240618065056-84e8917dfe8c h1:dpi0pnYxG5MCmYPuqPQQ1vowgi62uyA3d8seJE5iNHA= +github.com/jba/xltest/go/xltest v0.1.2-0.20240618065056-84e8917dfe8c/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q= +github.com/jba/xltest/go/xltest v0.1.2 h1:FS3vj2h+rLchD8h2JQsNm5rj13HQHcQFBr94v4sdq0s= +github.com/jba/xltest/go/xltest v0.1.2/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= diff --git a/go/tests/googleai_test.go b/go/tests/googleai_test.go index 
65eb12ee8d..b4c7ce900c 100644 --- a/go/tests/googleai_test.go +++ b/go/tests/googleai_test.go @@ -37,6 +37,12 @@ import ( var update = flag.Bool("update", false, "update golden files with results") +type embedderResult struct { + Values []float32 + TraceFile string `yaml:"traceFile"` + trace *tracing.Data +} + func TestGoogleAI(t *testing.T) { ctx := context.Background() err := googleai.Init(ctx, googleai.Config{ @@ -49,7 +55,7 @@ func TestGoogleAI(t *testing.T) { } t.Run("embed", func(t *testing.T) { - test, err := xltest.ReadFile(filepath.Join("testdata", "googleai-embedder.yaml")) + test, err := xltest.ReadFile[string, embedderResult](filepath.Join("testdata", "googleai-embedder.yaml")) if err != nil { t.Fatal(err) } @@ -61,42 +67,25 @@ func TestGoogleAI(t *testing.T) { if err != nil { return embedderResult{}, err } - return embedderResult{vals, ts.td}, nil + return embedderResult{Values: vals, trace: ts.td}, nil } valfunc := validateEmbedder if *update { valfunc = updateEmbedder } - test.Run(t, testfunc, valfunc) + test.Run(t, testfunc, valfunc, nil) }) } -type embedderResult struct { - values []float32 - trace *tracing.Data -} - // validateEmbedder compares the results from running an embedder with // the desired results. The latter are taken from a YAML file and so are in raw // unmarshalled form. 
-func validateEmbedder(got embedderResult, rawWant map[string]any) error { - var want embedderResult - rawVals := rawWant["values"].([]any) - f32s := make([]float32, len(rawVals)) - for i, v := range rawVals { - f32s[i] = float32(v.(float64)) - } - want.values = f32s - traceFile := rawWant["traceFile"].(string) - if err := internal.ReadJSONFile(filepath.Join("testdata", traceFile), &want.trace); err != nil { +func validateEmbedder(got, want embedderResult) error { + if err := internal.ReadJSONFile(filepath.Join("testdata", want.TraceFile), &want.trace); err != nil { return err } - return compareEmbedderResults(got, want) -} - -func compareEmbedderResults(got, want embedderResult) error { - if !slices.Equal(got.values, want.values) { - return fmt.Errorf("values: got %v, want %v", got.values, want.values) + if !slices.Equal(got.Values, want.Values) { + return fmt.Errorf("values: got %v, want %v", got.Values, want.Values) } renameSpans(got.trace) renameSpans(want.trace) @@ -152,10 +141,9 @@ func renameSpans(td *tracing.Data) { td.Spans = m } -func updateEmbedder(got embedderResult, rawWant map[string]any) error { - filename := rawWant["traceFile"].(string) - fmt.Printf("writing %s\n", filename) - return internal.WriteJSONFile(filename, got.trace) +func updateEmbedder(got, want embedderResult) error { + fmt.Printf("writing %s\n", want.TraceFile) + return internal.WriteJSONFile(want.TraceFile, got.trace) } func mockClient() *http.Client { From 2efb9d356726b26d9722d6d2fb7bd77b44e0cfae Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 18 Jun 2024 03:24:05 -0400 Subject: [PATCH 4/6] go mod tidy --- go/go.sum | 4 ---- 1 file changed, 4 deletions(-) diff --git a/go/go.sum b/go/go.sum index 19e81c44e7..ec048f0d23 100644 --- a/go/go.sum +++ b/go/go.sum @@ -92,10 +92,6 @@ github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10C github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jba/slog 
v0.2.0 h1:jI0U5NRR3EJKGsbeEVpItJNogk0c4RMeCl7vJmogCJI= github.com/jba/slog v0.2.0/go.mod h1:0Dh7Vyz3Td68Z1OwzadfincHwr7v+PpzadrS2Jua338= -github.com/jba/xltest/go/xltest v0.1.1 h1:HV6454xz+kYvycBw/CHgk4Baj4klInZ0jichK0CaNhs= -github.com/jba/xltest/go/xltest v0.1.1/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q= -github.com/jba/xltest/go/xltest v0.1.2-0.20240618065056-84e8917dfe8c h1:dpi0pnYxG5MCmYPuqPQQ1vowgi62uyA3d8seJE5iNHA= -github.com/jba/xltest/go/xltest v0.1.2-0.20240618065056-84e8917dfe8c/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q= github.com/jba/xltest/go/xltest v0.1.2 h1:FS3vj2h+rLchD8h2JQsNm5rj13HQHcQFBr94v4sdq0s= github.com/jba/xltest/go/xltest v0.1.2/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= From f9bc51848133f055f78cd096feee3c39f00343af Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 18 Jun 2024 03:30:07 -0400 Subject: [PATCH 5/6] remove replace from go.mod --- go/go.mod | 4 +--- go/go.sum | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index f5dead6f32..906a9bb66f 100644 --- a/go/go.mod +++ b/go/go.mod @@ -9,7 +9,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0 github.com/aymerick/raymond v2.0.2+incompatible - github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add + github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.12.0 @@ -68,5 +68,3 @@ require ( google.golang.org/grpc v1.64.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) - -replace github.com/google/generative-ai-go => /usr/local/google/home/jba/repos/github.com/google/generative-ai-go diff --git a/go/go.sum b/go/go.sum index ec048f0d23..014789aa50 100644 --- 
a/go/go.sum +++ b/go/go.sum @@ -71,6 +71,8 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e h1:qSpEdwyYXID3hhX5g7/atrUGj5SDOFiAK4TvzgwvKJs= +github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e/go.mod h1:hOzbW3cB5hRV2x05McOwJS4GsqSluYwejjk5tSfb6YY= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= From c77f519d7868da1bce6ca58aba29283a0afeb1aa Mon Sep 17 00:00:00 2001 From: Samuel Bushi Date: Mon, 24 Jun 2024 16:06:31 +0000 Subject: [PATCH 6/6] Update test trace to match what JS SDK --- .../testdata/googleai-embedder-trace.json | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/go/tests/testdata/googleai-embedder-trace.json b/go/tests/testdata/googleai-embedder-trace.json index 2a88ad8ad1..e4831fc4fc 100644 --- a/go/tests/testdata/googleai-embedder-trace.json +++ b/go/tests/testdata/googleai-embedder-trace.json @@ -1,24 +1,23 @@ { - "traceId": "", - "displayName": "embedding-001", - "startTime": 1718521125809.9194, - "endTime": 1718521125811.092, + "traceId": "d6b5a3f8a48f1277445c6c6c205d2f81", "spans": { - "edb1aba2f4e66854": { - "spanId": "edb1aba2f4e66854", - "traceId": "5a59c7e04a2abbe463d4e617009f0779", - "startTime": 1718521125809.9194, - "endTime": 1718521125811.092, + "5b80287f0f49c6e9": { + "spanId": "5b80287f0f49c6e9", + "traceId": "d6b5a3f8a48f1277445c6c6c205d2f81", + "startTime": 1719244014336, + "endTime": 1719244014841.5574, + 
"parentSpanId": "6327e0cd498e6e52", "attributes": { - "genkit:input": "{\"input\":{\"content\":[{\"text\":\"banana\"}]}}", + "genkit:input": "{\"input\":[{\"content\":[{\"text\":\"banana\"}]}]}", "genkit:metadata:subtype": "embedder", - "genkit:name": "embedding-001", - "genkit:output": "[0.25,0.25,0.5]", - "genkit:path": "/embedding-001", + "genkit:name": "googleai/embedding-001", + "genkit:output": "{\"embeddings\":[{\"embedding\":[0.25,0.25,0.5]}]}", + "genkit:path": "/{dev-run-action-wrapper}/{googleai/embedding-001,t:action,s:embedder}", "genkit:state": "success", "genkit:type": "action" }, - "displayName": "embedding-001", + "displayName": "googleai/embedding-001", + "links": [], "instrumentationLibrary": { "name": "genkit-tracer", "version": "v1" @@ -30,7 +29,9 @@ "status": { "code": 0 }, - "timeEvents": {} + "timeEvents": { + "timeEvent": [] + } } } -} +} \ No newline at end of file