diff --git a/go/core/logger/logger.go b/go/core/logger/logger.go index 22e8fb7bca..d99c1c4c15 100644 --- a/go/core/logger/logger.go +++ b/go/core/logger/logger.go @@ -27,7 +27,7 @@ func init() { // TODO: Remove this. The main program should be responsible for configuring logging. // This is just a convenience during development. h := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: slog.LevelDebug, + Level: slog.LevelInfo, })) slog.SetDefault(h) } diff --git a/go/core/registry.go b/go/core/registry.go index de5b278643..5a7a92ac6c 100644 --- a/go/core/registry.go +++ b/go/core/registry.go @@ -66,7 +66,7 @@ func newRegistry() (*registry, error) { if err != nil { return nil, err } - r.registerTraceStore(EnvironmentDev, tstore) + r.setTraceStore(EnvironmentDev, tstore) r.tstate = tracing.NewState() r.tstate.AddTraceStoreImmediate(tstore) return r, nil @@ -178,6 +178,24 @@ func (r *registry) registerTraceStore(env Environment, ts tracing.Store) { r.traceStores[env] = ts } +// SetDevTraceStore establishes its argument as the [tracing.Store] for the dev environment. +// It is intended for testing. +// Call the returned function to remove the store when the test ends. 
+func SetDevTraceStore(ts tracing.Store) (remove func()) { + globalRegistry.setTraceStore(EnvironmentDev, ts) + rem := globalRegistry.tstate.AddTraceStoreImmediate(ts) + return func() { + rem() + globalRegistry.setTraceStore(EnvironmentDev, nil) + } +} + +func (r *registry) setTraceStore(env Environment, ts tracing.Store) { + r.mu.Lock() + defer r.mu.Unlock() + r.traceStores[env] = ts +} + func (r *registry) lookupTraceStore(env Environment) tracing.Store { r.mu.Lock() defer r.mu.Unlock() diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index 4464bdb1f9..853567957a 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -49,13 +49,18 @@ func (ts *State) RegisterSpanProcessor(sp sdktrace.SpanProcessor) { // Traces are saved immediately as they are finshed. // Use this for a gtrace.Store with a fast Save method, // such as one that writes to a file. -func (ts *State) AddTraceStoreImmediate(tstore Store) { +// The return value is a function that will remove the trace store when called. +func (ts *State) AddTraceStoreImmediate(tstore Store) (remove func()) { e := newTraceStoreExporter(tstore) // Adding a SimpleSpanProcessor is like using the WithSyncer option. - ts.RegisterSpanProcessor(sdktrace.NewSimpleSpanProcessor(e)) + spp := sdktrace.NewSimpleSpanProcessor(e) + ts.RegisterSpanProcessor(spp) // Ignore tracerProvider.Shutdown. It shouldn't be needed when using WithSyncer. // Confirmed for OTel packages as of v1.24.0. // Also requires traceStoreExporter.Shutdown to be a no-op. + return func() { + ts.tp.UnregisterSpanProcessor(spp) + } } // AddTraceStoreBatch adds ts to the tracingState. 
diff --git a/go/go.mod b/go/go.mod index 60f1a33dcb..906a9bb66f 100644 --- a/go/go.mod +++ b/go/go.mod @@ -9,11 +9,12 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0 github.com/aymerick/raymond v2.0.2+incompatible - github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add + github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.12.0 github.com/jba/slog v0.2.0 + github.com/jba/xltest/go/xltest v0.1.2 github.com/wk8/go-ordered-map/v2 v2.1.8 github.com/xeipuuv/gojsonschema v1.2.0 go.opentelemetry.io/otel v1.26.0 @@ -29,7 +30,7 @@ require ( require ( cloud.google.com/go v0.114.0 // indirect - cloud.google.com/go/ai v0.5.0 // indirect + cloud.google.com/go/ai v0.6.0 // indirect cloud.google.com/go/auth v0.5.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect @@ -48,7 +49,6 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect - github.com/kr/text v0.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect diff --git a/go/go.sum b/go/go.sum index ea68f9f6e6..014789aa50 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,8 +1,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.114.0 h1:OIPFAdfrFDFO2ve2U7r/H5SwSbBzEdrBdE7xkgwc+kY= cloud.google.com/go v0.114.0/go.mod h1:ZV9La5YYxctro1HTPug5lXH/GefROyW8PPD4T8n9J8E= -cloud.google.com/go/ai v0.5.0 h1:x8s4rDn5t9OVZvBCgtr5bZTH5X0O7JdE6zYo+O+MpRw= 
-cloud.google.com/go/ai v0.5.0/go.mod h1:96VBphk70e0zdXZrbtgPuKYRZsQ3UktSUXhuojwiKA8= +cloud.google.com/go/ai v0.6.0 h1:QWjb2UoaM15e51IMeLuIUFyWxooKOKDb66Mk47zZ2/g= +cloud.google.com/go/ai v0.6.0/go.mod h1:6/mrRq6aJdK7MZH76ZvcMpESiAiha5aRvurmroiOrgI= cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rBPcG8U= cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME= cloud.google.com/go/auth v0.5.1 h1:0QNO7VThG54LUzKiQxv8C6x1YX7lUrzlAa1nVLF8CIw= @@ -41,7 +41,6 @@ github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx2 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -72,8 +71,8 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add h1:ppPgQNwv1OidlzYyoeN6AEfcPJ5f2cO0hK2UzKDnoXc= -github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add/go.mod h1:Pmy+JWGfZt1kjjKPpufz2uunTIOy+dhWA3aOIC7ub3Q= +github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e 
h1:qSpEdwyYXID3hhX5g7/atrUGj5SDOFiAK4TvzgwvKJs= +github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e/go.mod h1:hOzbW3cB5hRV2x05McOwJS4GsqSluYwejjk5tSfb6YY= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -95,6 +94,8 @@ github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10C github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jba/slog v0.2.0 h1:jI0U5NRR3EJKGsbeEVpItJNogk0c4RMeCl7vJmogCJI= github.com/jba/slog v0.2.0/go.mod h1:0Dh7Vyz3Td68Z1OwzadfincHwr7v+PpzadrS2Jua338= +github.com/jba/xltest/go/xltest v0.1.2 h1:FS3vj2h+rLchD8h2JQsNm5rj13HQHcQFBr94v4sdq0s= +github.com/jba/xltest/go/xltest v0.1.2/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 9b8128b823..7844063b37 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -32,7 +32,7 @@ const provider = "googleai" // Config provides configuration options for the Init function. type Config struct { - // API key. Required. + // API key. Required unless ClientOptions is set. APIKey string // Generative models to provide. // If empty, a complete list will be obtained from the service. @@ -40,6 +40,8 @@ type Config struct { // Embedding models to provide. // If empty, a complete list will be obtained from the service. Embedders []string + // Options to pass to the client. If non-empty, APIKey is not used. 
+ ClientOptions []option.ClientOption } func Init(ctx context.Context, cfg Config) (err error) { @@ -49,11 +51,15 @@ func Init(ctx context.Context, cfg Config) (err error) { } }() - if cfg.APIKey == "" { + if cfg.APIKey == "" && len(cfg.ClientOptions) == 0 { return errors.New("missing API key") } - client, err := genai.NewClient(ctx, option.WithAPIKey(cfg.APIKey)) + opts := cfg.ClientOptions + if len(opts) == 0 { + opts = []option.ClientOption{option.WithAPIKey(cfg.APIKey)} + } + client, err := genai.NewClient(ctx, opts...) if err != nil { return err } diff --git a/go/tests/googleai_test.go b/go/tests/googleai_test.go new file mode 100644 index 0000000000..b4c7ce900c --- /dev/null +++ b/go/tests/googleai_test.go @@ -0,0 +1,165 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tests + +import ( + "cmp" + "context" + "flag" + "fmt" + "net/http" + "path/filepath" + "slices" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/plugins/googleai" + gocmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/jba/xltest/go/xltest" + "google.golang.org/api/option" +) + +var update = flag.Bool("update", false, "update golden files with results") + +type embedderResult struct { + Values []float32 + TraceFile string `yaml:"traceFile"` + trace *tracing.Data +} + +func TestGoogleAI(t *testing.T) { + ctx := context.Background() + err := googleai.Init(ctx, googleai.Config{ + ClientOptions: []option.ClientOption{option.WithHTTPClient(mockClient())}, + Models: []string{"gemini-1.5-pro"}, + Embedders: []string{"embedding-001"}, + }) + if err != nil { + t.Fatal(err) + } + + t.Run("embed", func(t *testing.T) { + test, err := xltest.ReadFile[string, embedderResult](filepath.Join("testdata", "googleai-embedder.yaml")) + if err != nil { + t.Fatal(err) + } + testfunc := func(text string) (embedderResult, error) { + ts := &testTraceStore{} + remove := core.SetDevTraceStore(ts) + defer remove() + vals, err := ai.Embed(ctx, googleai.Embedder("embedding-001"), &ai.EmbedRequest{Document: ai.DocumentFromText(text, nil)}) + if err != nil { + return embedderResult{}, err + } + return embedderResult{Values: vals, trace: ts.td}, nil + } + valfunc := validateEmbedder + if *update { + valfunc = updateEmbedder + } + test.Run(t, testfunc, valfunc, nil) + }) +} + +// validateEmbedder compares the results from running an embedder with +// the desired results. The latter are taken from a YAML file and so are in raw +// unmarshalled form. 
+func validateEmbedder(got, want embedderResult) error {
+	if err := internal.ReadJSONFile(filepath.Join("testdata", want.TraceFile), &want.trace); err != nil {
+		return err
+	}
+	if !slices.Equal(got.Values, want.Values) {
+		return fmt.Errorf("values: got %v, want %v", got.Values, want.Values)
+	}
+	renameSpans(got.trace)
+	renameSpans(want.trace)
+	opts := []gocmp.Option{
+		cmpopts.IgnoreFields(tracing.Data{}, "TraceID", "StartTime", "EndTime"),
+		cmpopts.IgnoreFields(tracing.SpanData{}, "TraceID", "StartTime", "EndTime"),
+	}
+	if diff := gocmp.Diff(want.trace, got.trace, opts...); diff != "" {
+		return fmt.Errorf("traces: %s", diff)
+	}
+	return nil
+}
+
+// renameSpans changes the keys of td.Spans to s000, s001, ... in order of the span start time,
+// as well as references to those IDs within the spans.
+// This makes it possible to compare two span maps with different span IDs.
+func renameSpans(td *tracing.Data) {
+	type item struct {
+		id string
+		t  tracing.Milliseconds
+	}
+
+	var items []item
+	startTimes := map[tracing.Milliseconds]bool{}
+	for id, span := range td.Spans {
+		if startTimes[span.StartTime] {
+			panic("duplicate start times")
+		}
+		startTimes[span.StartTime] = true
+		items = append(items, item{id, span.StartTime})
+	}
+	slices.SortFunc(items, func(i1, i2 item) int {
+		return cmp.Compare(i1.t, i2.t)
+	})
+	oldIDToNew := map[string]string{}
+	for i, item := range items {
+		oldIDToNew[item.id] = fmt.Sprintf("s%03d", i)
+	}
+	// Re-create the span map with the new span IDs.
+	m := map[string]*tracing.SpanData{}
+	for oldID, span := range td.Spans {
+		newID := oldIDToNew[oldID]
+		if newID == "" {
+			panic(fmt.Sprintf("missing id: %q", oldID))
+		}
+		m[newID] = span
+		// A span references its own span ID and possibly its parent's.
+		span.SpanID = oldIDToNew[span.SpanID]
+		if span.ParentSpanID != "" {
+			span.ParentSpanID = oldIDToNew[span.ParentSpanID]
+		}
+	}
+	td.Spans = m
+}
+
+func updateEmbedder(got, want embedderResult) error {
+	fmt.Printf("writing %s\n", filepath.Join("testdata", want.TraceFile))
+	return internal.WriteJSONFile(filepath.Join("testdata", want.TraceFile), got.trace) // must be the same path validateEmbedder reads
+}
+
+func mockClient() *http.Client {
+	mrt := &MockRoundTripper{}
+	mrt.Handle("POST /v1beta/models/embedding-001:embedContent", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		fmt.Fprint(w, `{"embedding": {"values": [0.25, 0.25, 0.5]}}`)
+	}))
+	return &http.Client{Transport: mrt}
+}
+
+type testTraceStore struct {
+	tracing.Store
+	td *tracing.Data
+}
+
+func (ts *testTraceStore) Save(ctx context.Context, id string, td *tracing.Data) error {
+	ts.td = td
+	return nil
+}
diff --git a/go/tests/mock_round_tripper.go b/go/tests/mock_round_tripper.go
new file mode 100644
index 0000000000..20c76b4227
--- /dev/null
+++ b/go/tests/mock_round_tripper.go
@@ -0,0 +1,48 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+	"errors"
+	"net/http"
+	"net/http/httptest"
+)
+
+// MockRoundTripper is a RoundTripper whose responses can be set by the program.
+// It is configured exactly like a ServeMux: by registering http.Handlers with
+// patterns. When a request is received, the matching handler is run and
+// its response is returned.
+// +// This type is a convenience; the same result can be achieved by registering +// handlers with an httptest.Server and using its client. This avoids the +// (local) network round trip. +type MockRoundTripper struct { + mux http.ServeMux +} + +// Handle registers a handler with the MockRoundTripper associated with the given pattern. +func (rt *MockRoundTripper) Handle(pattern string, handler http.Handler) { + rt.mux.Handle(pattern, handler) +} + +func (rt *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + h, pat := rt.mux.Handler(req) + if pat == "" { + return nil, errors.New("no matching handler matches") + } + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + return w.Result(), nil +} diff --git a/go/tests/testdata/googleai-embedder-trace.json b/go/tests/testdata/googleai-embedder-trace.json new file mode 100644 index 0000000000..e4831fc4fc --- /dev/null +++ b/go/tests/testdata/googleai-embedder-trace.json @@ -0,0 +1,37 @@ +{ + "traceId": "d6b5a3f8a48f1277445c6c6c205d2f81", + "spans": { + "5b80287f0f49c6e9": { + "spanId": "5b80287f0f49c6e9", + "traceId": "d6b5a3f8a48f1277445c6c6c205d2f81", + "startTime": 1719244014336, + "endTime": 1719244014841.5574, + "parentSpanId": "6327e0cd498e6e52", + "attributes": { + "genkit:input": "{\"input\":[{\"content\":[{\"text\":\"banana\"}]}]}", + "genkit:metadata:subtype": "embedder", + "genkit:name": "googleai/embedding-001", + "genkit:output": "{\"embeddings\":[{\"embedding\":[0.25,0.25,0.5]}]}", + "genkit:path": "/{dev-run-action-wrapper}/{googleai/embedding-001,t:action,s:embedder}", + "genkit:state": "success", + "genkit:type": "action" + }, + "displayName": "googleai/embedding-001", + "links": [], + "instrumentationLibrary": { + "name": "genkit-tracer", + "version": "v1" + }, + "spanKind": "INTERNAL", + "sameProcessAsParentSpan": { + "value": true + }, + "status": { + "code": 0 + }, + "timeEvents": { + "timeEvent": [] + } + } + } +} \ No newline at end of file diff --git
a/go/tests/testdata/googleai-embedder.yaml b/go/tests/testdata/googleai-embedder.yaml new file mode 100644 index 0000000000..3597063c1f --- /dev/null +++ b/go/tests/testdata/googleai-embedder.yaml @@ -0,0 +1,22 @@ +# This is a cross-language test for embedders in the googleai plugin. +# Its main purpose is to validate the trace resulting from +# calling an embedder. +# +# Test function: call a mocked version of the googleai embedder +# that returns the given output. +# +# Validation function: +# - values: compare using equality +# - traceFile: +# - read the file +# - canonicalize the span IDs +# - ignore or remove trace ids, start and end times +# - then compare for equality. + +name: googleai-embedder +description: googleai plugin, embedder +in: 'banana' +want: + values: [0.25, 0.25, 0.5] + traceFile: 'googleai-embedder-trace.json' +