Skip to content

Commit

Permalink
feat(vertexai/genai): change TopK to int (#9522)
Browse files Browse the repository at this point in the history
GenerationConfig.TopK is logically an integer. Represent it as such
in the client.
  • Loading branch information
jba committed Mar 8, 2024
1 parent c250928 commit 29d2c7d
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 4 deletions.
6 changes: 3 additions & 3 deletions vertexai/genai/aiplatformpb_veneer.gen.go
Expand Up @@ -485,7 +485,7 @@ type GenerationConfig struct {
// Optional. If specified, nucleus sampling will be used.
TopP *float32
// Optional. If specified, top-k sampling will be used.
TopK *float32
TopK *int32
// Optional. Number of candidates to generate.
CandidateCount *int32
// Optional. The maximum number of output tokens to generate per message.
Expand All @@ -501,7 +501,7 @@ func (v *GenerationConfig) toProto() *pb.GenerationConfig {
return &pb.GenerationConfig{
Temperature: v.Temperature,
TopP: v.TopP,
TopK: v.TopK,
TopK: int32pToFloat32p(v.TopK),
CandidateCount: v.CandidateCount,
MaxOutputTokens: v.MaxOutputTokens,
StopSequences: v.StopSequences,
Expand All @@ -515,7 +515,7 @@ func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig {
return &GenerationConfig{
Temperature: p.Temperature,
TopP: p.TopP,
TopK: p.TopK,
TopK: float32pToInt32p(p.TopK),
CandidateCount: p.CandidateCount,
MaxOutputTokens: p.MaxOutputTokens,
StopSequences: p.StopSequences,
Expand Down
16 changes: 16 additions & 0 deletions vertexai/genai/client.go
Expand Up @@ -324,3 +324,19 @@ func mergeTexts(in []Part) []Part {
}
return out
}

func int32pToFloat32p(x *int32) *float32 {
if x == nil {
return nil
}
f := float32(*x)
return &f
}

func float32pToInt32p(x *float32) *int32 {
if x == nil {
return nil
}
i := int32(*x)
return &i
}
27 changes: 27 additions & 0 deletions vertexai/genai/client_test.go
Expand Up @@ -507,3 +507,30 @@ func TestTemperature(t *testing.T) {
t.Errorf("got %v, want 0", g)
}
}

func TestIntFloatConversions(t *testing.T) {
for n, test := range []struct {
i *int32
f *float32
}{
{nil, nil},
{Ptr[int32](1), Ptr[float32](1)},
} {
t.Run(fmt.Sprintf("int-to-float-%d", n), func(t *testing.T) {
gotf := int32pToFloat32p(test.i)
if !reflect.DeepEqual(gotf, test.f) {
t.Errorf("got %v, want %v", gotf, test.f)
}
})
t.Run(fmt.Sprintf("float-to-int-%d", n), func(t *testing.T) {
goti := float32pToInt32p(test.f)
if !reflect.DeepEqual(goti, test.i) {
t.Errorf("got %v, want %v", goti, test.i)
}
})
}
goti := float32pToInt32p(Ptr[float32](1.5))
if !reflect.DeepEqual(goti, Ptr[int32](1)) {
t.Errorf("got %v, want *1", goti)
}
}
4 changes: 4 additions & 0 deletions vertexai/genai/config.yaml
Expand Up @@ -59,6 +59,10 @@ types:
FunctionResponse:

GenerationConfig:
fields:
TopK:
type: '*int32'
convertToFrom: int32pToFloat32p, float32pToInt32p

SafetyRating:
docVerb: 'is the'
Expand Down
2 changes: 1 addition & 1 deletion vertexai/genai/content.go
Expand Up @@ -133,4 +133,4 @@ func (c *GenerationConfig) SetTemperature(x float32) { c.Temperature = &x }
func (c *GenerationConfig) SetTopP(x float32) { c.TopP = &x }

// SetTopK sets the TopK field.
func (c *GenerationConfig) SetTopK(x float32) { c.TopK = &x }
func (c *GenerationConfig) SetTopK(x int32) { c.TopK = &x }

0 comments on commit 29d2c7d

Please sign in to comment.