diff --git a/vertexai/genai/aiplatformpb_veneer.gen.go b/vertexai/genai/aiplatformpb_veneer.gen.go index 42b7591af1a..22fb0892db4 100644 --- a/vertexai/genai/aiplatformpb_veneer.gen.go +++ b/vertexai/genai/aiplatformpb_veneer.gen.go @@ -485,7 +485,7 @@ type GenerationConfig struct { // Optional. If specified, nucleus sampling will be used. TopP *float32 // Optional. If specified, top-k sampling will be used. - TopK *float32 + TopK *int32 // Optional. Number of candidates to generate. CandidateCount *int32 // Optional. The maximum number of output tokens to generate per message. @@ -501,7 +501,7 @@ func (v *GenerationConfig) toProto() *pb.GenerationConfig { return &pb.GenerationConfig{ Temperature: v.Temperature, TopP: v.TopP, - TopK: v.TopK, + TopK: int32pToFloat32p(v.TopK), CandidateCount: v.CandidateCount, MaxOutputTokens: v.MaxOutputTokens, StopSequences: v.StopSequences, @@ -515,7 +515,7 @@ func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig { return &GenerationConfig{ Temperature: p.Temperature, TopP: p.TopP, - TopK: p.TopK, + TopK: float32pToInt32p(p.TopK), CandidateCount: p.CandidateCount, MaxOutputTokens: p.MaxOutputTokens, StopSequences: p.StopSequences, diff --git a/vertexai/genai/client.go b/vertexai/genai/client.go index 0dc6f5bf2a7..1e781ee9516 100644 --- a/vertexai/genai/client.go +++ b/vertexai/genai/client.go @@ -324,3 +324,19 @@ func mergeTexts(in []Part) []Part { } return out } + +func int32pToFloat32p(x *int32) *float32 { + if x == nil { + return nil + } + f := float32(*x) + return &f +} + +func float32pToInt32p(x *float32) *int32 { + if x == nil { + return nil + } + i := int32(*x) + return &i +} diff --git a/vertexai/genai/client_test.go b/vertexai/genai/client_test.go index e1e08a048aa..17b06d5d544 100644 --- a/vertexai/genai/client_test.go +++ b/vertexai/genai/client_test.go @@ -507,3 +507,30 @@ func TestTemperature(t *testing.T) { t.Errorf("got %v, want 0", g) } } + +func TestIntFloatConversions(t *testing.T) { + for n, test := range []struct { + i *int32 + f *float32 + }{ + {nil, nil}, + {Ptr[int32](1), Ptr[float32](1)}, + } { + t.Run(fmt.Sprintf("int-to-float-%d", n), func(t *testing.T) { + gotf := int32pToFloat32p(test.i) + if !reflect.DeepEqual(gotf, test.f) { + t.Errorf("got %v, want %v", gotf, test.f) + } + }) + t.Run(fmt.Sprintf("float-to-int-%d", n), func(t *testing.T) { + goti := float32pToInt32p(test.f) + if !reflect.DeepEqual(goti, test.i) { + t.Errorf("got %v, want %v", goti, test.i) + } + }) + } + goti := float32pToInt32p(Ptr[float32](1.5)) + if !reflect.DeepEqual(goti, Ptr[int32](1)) { + t.Errorf("got %v, want *1", goti) + } +} diff --git a/vertexai/genai/config.yaml b/vertexai/genai/config.yaml index 1520d0e92c4..7c2d9b03226 100644 --- a/vertexai/genai/config.yaml +++ b/vertexai/genai/config.yaml @@ -59,6 +59,10 @@ types: FunctionResponse: GenerationConfig: + fields: + TopK: + type: '*int32' + convertToFrom: int32pToFloat32p, float32pToInt32p SafetyRating: docVerb: 'is the' diff --git a/vertexai/genai/content.go b/vertexai/genai/content.go index 6d94d10904f..22d4e8ed8a3 100644 --- a/vertexai/genai/content.go +++ b/vertexai/genai/content.go @@ -133,4 +133,4 @@ func (c *GenerationConfig) SetTemperature(x float32) { c.Temperature = &x } func (c *GenerationConfig) SetTopP(x float32) { c.TopP = &x } // SetTopK sets the TopK field. -func (c *GenerationConfig) SetTopK(x float32) { c.TopK = &x } +func (c *GenerationConfig) SetTopK(x int32) { c.TopK = &x }