From 47de2013df8c4eafe55c76ce25d1f94f23f48f62 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Fri, 10 May 2024 06:49:30 -0700 Subject: [PATCH] genai: fix live test to not require flags (#110) Requiring flags precludes us from running `go test ./...` in this repository. The model name is open now so we can just place it as a constant in the file - it's no longer the only model used by the tests either, so it's less confusing. The API key can be provided with an env var like in the examples. With this change, `go test ./...` now succeeds (the live test skip if the env var key is not set). Also fix some test expectations to be more resilient. --- CONTRIBUTING.md | 6 ++---- genai/client_test.go | 50 ++++++++++++++++++++------------------------ 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e8867ca..33edb91 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -26,10 +26,8 @@ Guidelines](https://opensource.google/conduct/). ## Contribution process 1. Clone this repo -2. Run tests with `go test ./...` -3. You may need to run "live" tests that talk to a real endpoint; to do so, run - `go test -v ./genai/...` passing it your API key with the `-apikey` flag - and a model name flag like `-model gemini-1.0-pro` +2. Run tests with `go test ./...`; the "live" tests will be skipped + unless a valid API key is set with the `GEMINI_API_KEY` environment variable. ### Code Reviews diff --git a/genai/client_test.go b/genai/client_test.go index cd089c4..722c93d 100644 --- a/genai/client_test.go +++ b/genai/client_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "errors" - "flag" "fmt" "io" "log" @@ -34,24 +33,26 @@ import ( "google.golang.org/api/option" ) -var ( - apiKey = flag.String("apikey", "", "API key") - modelName = flag.String("model", "", "model name without vision suffix") -) - +const defaultModel = "gemini-1.0-pro" const imageFile = "personWorkingOnComputer.jpg" func TestLive(t *testing.T) { - if *apiKey == "" || *modelName == "" { - t.Skip("need -apikey and -model") + apiKey := os.Getenv("GEMINI_API_KEY") + if testing.Short() { + t.Skip("skipping live test in -short mode") + } + + if apiKey == "" { + t.Skip("set a GEMINI_API_KEY env var to run live tests") } + ctx := context.Background() - client, err := NewClient(ctx, option.WithAPIKey(*apiKey)) + client, err := NewClient(ctx, option.WithAPIKey(apiKey)) if err != nil { t.Fatal(err) } defer client.Close() - model := client.GenerativeModel(*modelName) + model := client.GenerativeModel(defaultModel) model.Temperature = Ptr[float32](0) t.Run("GenerateContent", func(t *testing.T) { @@ -66,8 +67,9 @@ func TestLive(t *testing.T) { t.Run("streaming", func(t *testing.T) { iter := model.GenerateContentStream(ctx, Text("Are you hungry?")) got := responsesString(t, iter) - checkMatch(t, got, `(don't|do\s+not|not capable) (have|possess|experiencing) .*(a .* needs|body|sensations|the ability)`) + checkMatch(t, got, `(don't|do\s+not|not capable) (have|possess|experiencing) .*(a .* needs|body|sensations|the ability|living)`) }) + t.Run("streaming-counting", func(t *testing.T) { // Verify only that we don't crash. See #18. iter := model.GenerateContentStream(ctx, Text("count 1 to 100.")) @@ -83,11 +85,6 @@ func TestLive(t *testing.T) { if !errors.As(err, &gerr) { t.Fatalf("does not wrap a googleapi.Error") } - got := gerr.Error() - want := "INVALID_ARGUMENT" - if !strings.Contains(got, want) { - t.Errorf("got %q\n\ndoes not contain %q", got, want) - } }) t.Run("chat", func(t *testing.T) { session := model.StartChat() @@ -131,7 +128,7 @@ func TestLive(t *testing.T) { }) t.Run("image", func(t *testing.T) { - vmodel := client.GenerativeModel(*modelName + "-vision-latest") + vmodel := client.GenerativeModel(defaultModel + "-vision-latest") vmodel.Temperature = Ptr[float32](0) data, err := os.ReadFile(filepath.Join("testdata", imageFile)) @@ -145,7 +142,7 @@ func TestLive(t *testing.T) { t.Fatal(err) } got := responseString(resp) - checkMatch(t, got, "picture", "person", "computer|laptop") + checkMatch(t, got, "man|person", "computer|laptop") }) t.Run("blocked", func(t *testing.T) { @@ -173,7 +170,7 @@ func TestLive(t *testing.T) { } }) t.Run("max-tokens", func(t *testing.T) { - maxModel := client.GenerativeModel(*modelName) + maxModel := client.GenerativeModel(defaultModel) maxModel.Temperature = Ptr(float32(0)) maxModel.SetMaxOutputTokens(10) res, err := maxModel.GenerateContent(ctx, Text("What is a dog?")) @@ -187,7 +184,7 @@ func TestLive(t *testing.T) { } }) t.Run("max-tokens-streaming", func(t *testing.T) { - maxModel := client.GenerativeModel(*modelName) + maxModel := client.GenerativeModel(defaultModel) maxModel.Temperature = Ptr[float32](0) maxModel.MaxOutputTokens = Ptr[int32](10) iter := maxModel.GenerateContentStream(ctx, Text("What is a dog?")) @@ -289,7 +286,7 @@ func TestLive(t *testing.T) { } }) t.Run("get-model", func(t *testing.T) { - modName := *modelName + modName := defaultModel got, err := client.GenerativeModel(modName).Info(ctx) if err != nil { t.Fatal(err) @@ -307,8 +304,8 @@ func TestLive(t *testing.T) { t.Errorf("got name %q, want %q", got.Name, w) } }) - t.Run("tools", func(t *testing.T) { + t.Run("tools", func(t *testing.T) { weatherChat := func(t *testing.T, s *Schema, fcm FunctionCallingMode) { weatherTool := &Tool{ FunctionDeclarations: []*FunctionDeclaration{{ @@ -317,7 +314,7 @@ func TestLive(t *testing.T) { Parameters: s, }}, } - model := client.GenerativeModel(*modelName) + model := client.GenerativeModel(defaultModel) model.SetTemperature(0) model.Tools = []*Tool{weatherTool} model.ToolConfig = &ToolConfig{ @@ -383,11 +380,8 @@ func TestLive(t *testing.T) { weatherChat(t, schema, FunctionCallingNone) }) }) + t.Run("files", func(t *testing.T) { - const validModel = "gemini-1.5-pro-eval" - if *modelName != validModel { - t.Skipf("need model %q", validModel) - } f, err := os.Open(filepath.Join("testdata", imageFile)) if err != nil { t.Fatal(err) @@ -434,12 +428,14 @@ func TestLive(t *testing.T) { } // Use the uploaded file to generate content. + model := client.GenerativeModel("gemini-1.5-pro-latest") resp, err := model.GenerateContent(ctx, FileData{URI: file.URI}) if err != nil { t.Fatal(err) } checkMatch(t, responseString(resp), "picture|image", "person", "computer|laptop") }) + t.Run("JSON", func(t *testing.T) { model := client.GenerativeModel("gemini-1.5-pro-latest") model.SetTemperature(0)