genai: fix live test to not require flags (#110)
Requiring flags precludes us from running `go test ./...` in this
repository.
The model name is public now, so we can just place it as a constant in the
file; it's no longer the only model the tests use anyway, so a single
model-name flag was confusing. The API key can be provided with an env var,
as in the examples.

With this change, `go test ./...` now succeeds (the live tests are skipped
if the env var key is not set).
    
Also fix some test expectations to be more resilient.
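
For contributors, the practical effect is that live tests are now gated on an
environment variable rather than flags. Here is a minimal sketch of the
pattern (the `TestLiveSketch` name is hypothetical; the real logic lives in
`TestLive` in the diff below):

```go
package genai_test

import (
	"os"
	"testing"
)

// TestLiveSketch mirrors the gating this commit introduces: read the API
// key from GEMINI_API_KEY and skip the test when it is unset, so a plain
// `go test ./...` succeeds without any flags.
func TestLiveSketch(t *testing.T) {
	apiKey := os.Getenv("GEMINI_API_KEY")
	if apiKey == "" {
		t.Skip("set a GEMINI_API_KEY env var to run live tests")
	}
	_ = apiKey // a real live test would pass the key to the client here
}
```

To run the live tests, set the variable when invoking `go test`, e.g.
`GEMINI_API_KEY=<your-key> go test -v ./genai/...`.
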
eliben committed May 10, 2024
1 parent e36c492 commit 47de201
Showing 2 changed files with 25 additions and 31 deletions.
6 changes: 2 additions & 4 deletions CONTRIBUTING.md
@@ -26,10 +26,8 @@ Guidelines](https://opensource.google/conduct/).
## Contribution process

1. Clone this repo
2. Run tests with `go test ./...`
3. You may need to run "live" tests that talk to a real endpoint; to do so, run
`go test -v ./genai/...` passing it your API key with the `-apikey` flag
and a model name flag like `-model gemini-1.0-pro`
2. Run tests with `go test ./...`; the "live" tests will be skipped
unless a valid API key is set with the `GEMINI_API_KEY` environment variable.

### Code Reviews

50 changes: 23 additions & 27 deletions genai/client_test.go
@@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
@@ -34,24 +33,26 @@ import (
"google.golang.org/api/option"
)

var (
apiKey = flag.String("apikey", "", "API key")
modelName = flag.String("model", "", "model name without vision suffix")
)

const defaultModel = "gemini-1.0-pro"
const imageFile = "personWorkingOnComputer.jpg"

func TestLive(t *testing.T) {
if *apiKey == "" || *modelName == "" {
t.Skip("need -apikey and -model")
apiKey := os.Getenv("GEMINI_API_KEY")
if testing.Short() {
t.Skip("skipping live test in -short mode")
}

if apiKey == "" {
t.Skip("set a GEMINI_API_KEY env var to run live tests")
}

ctx := context.Background()
client, err := NewClient(ctx, option.WithAPIKey(*apiKey))
client, err := NewClient(ctx, option.WithAPIKey(apiKey))
if err != nil {
t.Fatal(err)
}
defer client.Close()
model := client.GenerativeModel(*modelName)
model := client.GenerativeModel(defaultModel)
model.Temperature = Ptr[float32](0)

t.Run("GenerateContent", func(t *testing.T) {
@@ -66,8 +67,9 @@ func TestLive(t *testing.T) {
t.Run("streaming", func(t *testing.T) {
iter := model.GenerateContentStream(ctx, Text("Are you hungry?"))
got := responsesString(t, iter)
checkMatch(t, got, `(don't|do\s+not|not capable) (have|possess|experiencing) .*(a .* needs|body|sensations|the ability)`)
checkMatch(t, got, `(don't|do\s+not|not capable) (have|possess|experiencing) .*(a .* needs|body|sensations|the ability|living)`)
})

t.Run("streaming-counting", func(t *testing.T) {
// Verify only that we don't crash. See #18.
iter := model.GenerateContentStream(ctx, Text("count 1 to 100."))
@@ -83,11 +85,6 @@ func TestLive(t *testing.T) {
if !errors.As(err, &gerr) {
t.Fatalf("does not wrap a googleapi.Error")
}
got := gerr.Error()
want := "INVALID_ARGUMENT"
if !strings.Contains(got, want) {
t.Errorf("got %q\n\ndoes not contain %q", got, want)
}
})
t.Run("chat", func(t *testing.T) {
session := model.StartChat()
@@ -131,7 +128,7 @@ func TestLive(t *testing.T) {
})

t.Run("image", func(t *testing.T) {
vmodel := client.GenerativeModel(*modelName + "-vision-latest")
vmodel := client.GenerativeModel(defaultModel + "-vision-latest")
vmodel.Temperature = Ptr[float32](0)

data, err := os.ReadFile(filepath.Join("testdata", imageFile))
@@ -145,7 +142,7 @@ func TestLive(t *testing.T) {
t.Fatal(err)
}
got := responseString(resp)
checkMatch(t, got, "picture", "person", "computer|laptop")
checkMatch(t, got, "man|person", "computer|laptop")
})

t.Run("blocked", func(t *testing.T) {
@@ -173,7 +170,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("max-tokens", func(t *testing.T) {
maxModel := client.GenerativeModel(*modelName)
maxModel := client.GenerativeModel(defaultModel)
maxModel.Temperature = Ptr(float32(0))
maxModel.SetMaxOutputTokens(10)
res, err := maxModel.GenerateContent(ctx, Text("What is a dog?"))
@@ -187,7 +184,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("max-tokens-streaming", func(t *testing.T) {
maxModel := client.GenerativeModel(*modelName)
maxModel := client.GenerativeModel(defaultModel)
maxModel.Temperature = Ptr[float32](0)
maxModel.MaxOutputTokens = Ptr[int32](10)
iter := maxModel.GenerateContentStream(ctx, Text("What is a dog?"))
@@ -289,7 +286,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("get-model", func(t *testing.T) {
modName := *modelName
modName := defaultModel
got, err := client.GenerativeModel(modName).Info(ctx)
if err != nil {
t.Fatal(err)
@@ -307,8 +304,8 @@ func TestLive(t *testing.T) {
t.Errorf("got name %q, want %q", got.Name, w)
}
})
t.Run("tools", func(t *testing.T) {

t.Run("tools", func(t *testing.T) {
weatherChat := func(t *testing.T, s *Schema, fcm FunctionCallingMode) {
weatherTool := &Tool{
FunctionDeclarations: []*FunctionDeclaration{{
@@ -317,7 +314,7 @@ func TestLive(t *testing.T) {
Parameters: s,
}},
}
model := client.GenerativeModel(*modelName)
model := client.GenerativeModel(defaultModel)
model.SetTemperature(0)
model.Tools = []*Tool{weatherTool}
model.ToolConfig = &ToolConfig{
@@ -383,11 +380,8 @@ func TestLive(t *testing.T) {
weatherChat(t, schema, FunctionCallingNone)
})
})

t.Run("files", func(t *testing.T) {
const validModel = "gemini-1.5-pro-eval"
if *modelName != validModel {
t.Skipf("need model %q", validModel)
}
f, err := os.Open(filepath.Join("testdata", imageFile))
if err != nil {
t.Fatal(err)
@@ -434,12 +428,14 @@ }
}

// Use the uploaded file to generate content.
model := client.GenerativeModel("gemini-1.5-pro-latest")
resp, err := model.GenerateContent(ctx, FileData{URI: file.URI})
if err != nil {
t.Fatal(err)
}
checkMatch(t, responseString(resp), "picture|image", "person", "computer|laptop")
})

t.Run("JSON", func(t *testing.T) {
model := client.GenerativeModel("gemini-1.5-pro-latest")
model.SetTemperature(0)
