Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/core/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func init() {
// TODO: Remove this. The main program should be responsible for configuring logging.
// This is just a convenience during development.
h := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelDebug,
Level: slog.LevelInfo,
}))
slog.SetDefault(h)
}
Expand Down
20 changes: 19 additions & 1 deletion go/core/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func newRegistry() (*registry, error) {
if err != nil {
return nil, err
}
r.registerTraceStore(EnvironmentDev, tstore)
r.setTraceStore(EnvironmentDev, tstore)
r.tstate = tracing.NewState()
r.tstate.AddTraceStoreImmediate(tstore)
return r, nil
Expand Down Expand Up @@ -178,6 +178,24 @@ func (r *registry) registerTraceStore(env Environment, ts tracing.Store) {
r.traceStores[env] = ts
}

// SetDevTraceStore establishes its argument as the [tracing.Store] for the dev environment.
// It is intended for testing.
// Call the returned function to remove the store when the test ends.
func SetDevTraceStore(ts tracing.Store) (remove func()) {
globalRegistry.setTraceStore(EnvironmentDev, ts)
rem := globalRegistry.tstate.AddTraceStoreImmediate(ts)
return func() {
rem()
globalRegistry.setTraceStore(EnvironmentDev, nil)
}
}

func (r *registry) setTraceStore(env Environment, ts tracing.Store) {
r.mu.Lock()
defer r.mu.Unlock()
r.traceStores[env] = ts
}

func (r *registry) lookupTraceStore(env Environment) tracing.Store {
r.mu.Lock()
defer r.mu.Unlock()
Expand Down
9 changes: 7 additions & 2 deletions go/core/tracing/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,18 @@ func (ts *State) RegisterSpanProcessor(sp sdktrace.SpanProcessor) {
// Traces are saved immediately as they are finshed.
// Use this for a gtrace.Store with a fast Save method,
// such as one that writes to a file.
func (ts *State) AddTraceStoreImmediate(tstore Store) {
// The return value is a function that will remove the trace store when called.
func (ts *State) AddTraceStoreImmediate(tstore Store) (remove func()) {
e := newTraceStoreExporter(tstore)
// Adding a SimpleSpanProcessor is like using the WithSyncer option.
ts.RegisterSpanProcessor(sdktrace.NewSimpleSpanProcessor(e))
spp := sdktrace.NewSimpleSpanProcessor(e)
ts.RegisterSpanProcessor(spp)
// Ignore tracerProvider.Shutdown. It shouldn't be needed when using WithSyncer.
// Confirmed for OTel packages as of v1.24.0.
// Also requires traceStoreExporter.Shutdown to be a no-op.
return func() {
ts.tp.UnregisterSpanProcessor(spp)
}
}

// AddTraceStoreBatch adds ts to the tracingState.
Expand Down
6 changes: 3 additions & 3 deletions go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ require (
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0
github.com/aymerick/raymond v2.0.2+incompatible
github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add
github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/invopop/jsonschema v0.12.0
github.com/jba/slog v0.2.0
github.com/jba/xltest/go/xltest v0.1.2
github.com/wk8/go-ordered-map/v2 v2.1.8
github.com/xeipuuv/gojsonschema v1.2.0
go.opentelemetry.io/otel v1.26.0
Expand All @@ -29,7 +30,7 @@ require (

require (
cloud.google.com/go v0.114.0 // indirect
cloud.google.com/go/ai v0.5.0 // indirect
cloud.google.com/go/ai v0.6.0 // indirect
cloud.google.com/go/auth v0.5.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
Expand All @@ -48,7 +49,6 @@ require (
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
Expand Down
11 changes: 6 additions & 5 deletions go/go.sum
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.114.0 h1:OIPFAdfrFDFO2ve2U7r/H5SwSbBzEdrBdE7xkgwc+kY=
cloud.google.com/go v0.114.0/go.mod h1:ZV9La5YYxctro1HTPug5lXH/GefROyW8PPD4T8n9J8E=
cloud.google.com/go/ai v0.5.0 h1:x8s4rDn5t9OVZvBCgtr5bZTH5X0O7JdE6zYo+O+MpRw=
cloud.google.com/go/ai v0.5.0/go.mod h1:96VBphk70e0zdXZrbtgPuKYRZsQ3UktSUXhuojwiKA8=
cloud.google.com/go/ai v0.6.0 h1:QWjb2UoaM15e51IMeLuIUFyWxooKOKDb66Mk47zZ2/g=
cloud.google.com/go/ai v0.6.0/go.mod h1:6/mrRq6aJdK7MZH76ZvcMpESiAiha5aRvurmroiOrgI=
cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rBPcG8U=
cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME=
cloud.google.com/go/auth v0.5.1 h1:0QNO7VThG54LUzKiQxv8C6x1YX7lUrzlAa1nVLF8CIw=
Expand Down Expand Up @@ -41,7 +41,6 @@ github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx2
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -72,8 +71,8 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add h1:ppPgQNwv1OidlzYyoeN6AEfcPJ5f2cO0hK2UzKDnoXc=
github.com/google/generative-ai-go v0.13.1-0.20240530125111-8decc9df4add/go.mod h1:Pmy+JWGfZt1kjjKPpufz2uunTIOy+dhWA3aOIC7ub3Q=
github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e h1:qSpEdwyYXID3hhX5g7/atrUGj5SDOFiAK4TvzgwvKJs=
github.com/google/generative-ai-go v0.14.1-0.20240618073058-6b2b0ac5749e/go.mod h1:hOzbW3cB5hRV2x05McOwJS4GsqSluYwejjk5tSfb6YY=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
Expand All @@ -95,6 +94,8 @@ github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10C
github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/jba/slog v0.2.0 h1:jI0U5NRR3EJKGsbeEVpItJNogk0c4RMeCl7vJmogCJI=
github.com/jba/slog v0.2.0/go.mod h1:0Dh7Vyz3Td68Z1OwzadfincHwr7v+PpzadrS2Jua338=
github.com/jba/xltest/go/xltest v0.1.2 h1:FS3vj2h+rLchD8h2JQsNm5rj13HQHcQFBr94v4sdq0s=
github.com/jba/xltest/go/xltest v0.1.2/go.mod h1:1t1ZLeFuJj14cypt9MzCna/qdRYBwaObid9dP2cpi7Q=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
Expand Down
12 changes: 9 additions & 3 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ const provider = "googleai"

// Config provides configuration options for the Init function.
type Config struct {
// API key. Required.
// API key. Required unless ClientOptions is set.
APIKey string
// Generative models to provide.
// If empty, a complete list will be obtained from the service.
Models []string
// Embedding models to provide.
// If empty, a complete list will be obtained from the service.
Embedders []string
// Options to pass to the client. If non-empty, APIKey is not used.
ClientOptions []option.ClientOption
}

func Init(ctx context.Context, cfg Config) (err error) {
Expand All @@ -49,11 +51,15 @@ func Init(ctx context.Context, cfg Config) (err error) {
}
}()

if cfg.APIKey == "" {
if cfg.APIKey == "" && len(cfg.ClientOptions) == 0 {
return errors.New("missing API key")
}

client, err := genai.NewClient(ctx, option.WithAPIKey(cfg.APIKey))
opts := cfg.ClientOptions
if len(opts) == 0 {
opts = []option.ClientOption{option.WithAPIKey(cfg.APIKey)}
}
client, err := genai.NewClient(ctx, opts...)
if err != nil {
return err
}
Expand Down
165 changes: 165 additions & 0 deletions go/tests/googleai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tests

import (
"cmp"
"context"
"flag"
"fmt"
"net/http"
"path/filepath"
"slices"
"testing"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/plugins/googleai"
gocmp "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/jba/xltest/go/xltest"
"google.golang.org/api/option"
)

var update = flag.Bool("update", false, "update golden files with results")

type embedderResult struct {
Values []float32
TraceFile string `yaml:"traceFile"`
trace *tracing.Data
}

func TestGoogleAI(t *testing.T) {
ctx := context.Background()
err := googleai.Init(ctx, googleai.Config{
ClientOptions: []option.ClientOption{option.WithHTTPClient(mockClient())},
Models: []string{"gemini-1.5-pro"},
Embedders: []string{"embedding-001"},
})
if err != nil {
t.Fatal(err)
}

t.Run("embed", func(t *testing.T) {
test, err := xltest.ReadFile[string, embedderResult](filepath.Join("testdata", "googleai-embedder.yaml"))
if err != nil {
t.Fatal(err)
}
testfunc := func(text string) (embedderResult, error) {
ts := &testTraceStore{}
remove := core.SetDevTraceStore(ts)
defer remove()
vals, err := ai.Embed(ctx, googleai.Embedder("embedding-001"), &ai.EmbedRequest{Document: ai.DocumentFromText(text, nil)})
if err != nil {
return embedderResult{}, err
}
return embedderResult{Values: vals, trace: ts.td}, nil
}
valfunc := validateEmbedder
if *update {
valfunc = updateEmbedder
}
test.Run(t, testfunc, valfunc, nil)
})
}

// validateEmbedder compares the results from running an embedder with
// the desired results. The latter are taken from a YAML file and so are in raw
// unmarshalled form.
func validateEmbedder(got, want embedderResult) error {
if err := internal.ReadJSONFile(filepath.Join("testdata", want.TraceFile), &want.trace); err != nil {
return err
}
if !slices.Equal(got.Values, want.Values) {
return fmt.Errorf("values: got %v, want %v", got.Values, want.Values)
}
renameSpans(got.trace)
renameSpans(want.trace)
opts := []gocmp.Option{
cmpopts.IgnoreFields(tracing.Data{}, "TraceID", "StartTime", "EndTime"),
cmpopts.IgnoreFields(tracing.SpanData{}, "TraceID", "StartTime", "EndTime"),
}
if diff := gocmp.Diff(want.trace, got.trace, opts...); diff != "" {
return fmt.Errorf("traces: %s", diff)
}
return nil
}

// renameSpans changes the keys of td.Spans to s000, s001, ... in order of the span start time,
// as well as references to those IDs within the spans.
// This makes it possible to compare two span maps with different span IDs.
func renameSpans(td *tracing.Data) {
type item struct {
id string
t tracing.Milliseconds
}

var items []item
startTimes := map[tracing.Milliseconds]bool{}
for id, span := range td.Spans {
if startTimes[span.StartTime] {
panic("duplicate start times")
}
startTimes[span.StartTime] = true
items = append(items, item{id, span.StartTime})
}
slices.SortFunc(items, func(i1, i2 item) int {
return cmp.Compare(i1.t, i2.t)
})
oldIDToNew := map[string]string{}
for i, item := range items {
oldIDToNew[item.id] = fmt.Sprintf("s%03d", i)
}
// Re-create the span map with the new span IDs.
m := map[string]*tracing.SpanData{}
for oldID, span := range td.Spans {
newID := oldIDToNew[oldID]
if newID == "" {
panic(fmt.Sprintf("missing id: %q", oldID))
}
m[newID] = span
// A span references it own span ID and possibly its parent's.
span.SpanID = oldIDToNew[span.SpanID]
if span.ParentSpanID != "" {
span.ParentSpanID = oldIDToNew[span.ParentSpanID]
}
}
td.Spans = m
}
Comment on lines +105 to +142
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pavelgj This code makes it possible to compare two traces completely, except for times and the trace ID.


func updateEmbedder(got, want embedderResult) error {
fmt.Printf("writing %s\n", want.TraceFile)
return internal.WriteJSONFile(want.TraceFile, got.trace)
}

func mockClient() *http.Client {
mrt := &MockRoundTripper{}
mrt.Handle("POST /v1beta/models/embedding-001:embedContent", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"embedding": {"values": [0.25, 0.25, 0.5]}}`)
}))
return &http.Client{Transport: mrt}
}

type testTraceStore struct {
tracing.Store
td *tracing.Data
}

func (ts *testTraceStore) Save(ctx context.Context, id string, td *tracing.Data) error {
ts.td = td
return nil
}
48 changes: 48 additions & 0 deletions go/tests/mock_round_tripper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tests

import (
"errors"
"net/http"
"net/http/httptest"
)

// MockRoundTripper is a RoundTripper whose responses can be set by the program.
// It is configured exactly like a ServeMux: by registering http.Handlers with
// patterns. When a request is received, the matching handler is run and
// its response is returned.
//
// This type is a convenience; the same result can be achieved by registering
// handlers with an httptest.Server and using its client. This avoids the
// (local) network round trip.
type MockRoundTripper struct {
mux http.ServeMux
}

// Handle registers a handle with the MockRoundTripper associated with the given pattern.
func (rt *MockRoundTripper) Handle(pattern string, handler http.Handler) {
rt.mux.Handle(pattern, handler)
}

func (rt *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
h, pat := rt.mux.Handler(req)
if pat == "" {
return nil, errors.New("no matching handler matches")
}
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
return w.Result(), nil
}
Loading