From ca52dfa4788423c3b37b536e15a1018309db02dc Mon Sep 17 00:00:00 2001 From: aothelal Date: Wed, 1 Oct 2025 21:41:38 +0300 Subject: [PATCH] Handle Gemini maxOutputTokens attribute properly --- lib/ruby_llm/provider.rb | 29 ++++++- lib/ruby_llm/providers/gemini.rb | 8 ++ ...lly_maps_max_tokens_to_maxoutputtokens.yml | 81 +++++++++++++++++++ spec/ruby_llm/chat_request_options_spec.rb | 19 +++++ spec/ruby_llm/providers/gemini/chat_spec.rb | 45 +++++++++++ 5 files changed, 181 insertions(+), 1 deletion(-) create mode 100644 spec/fixtures/vcr_cassettes/chat_with_params_gemini_gemini-2_5-flash_automatically_maps_max_tokens_to_maxoutputtokens.yml diff --git a/lib/ruby_llm/provider.rb b/lib/ruby_llm/provider.rb index f3344e57d..5e0d3ac4c 100644 --- a/lib/ruby_llm/provider.rb +++ b/lib/ruby_llm/provider.rb @@ -21,6 +21,10 @@ def headers {} end + def parameter_mappings + {} + end + def slug self.class.slug end @@ -39,6 +43,7 @@ def configuration_requirements def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists normalized_temperature = maybe_normalize_temperature(temperature, model) + transformed_params = apply_parameter_mappings(params) payload = Utils.deep_merge( render_payload( @@ -49,7 +54,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc stream: block_given?, schema: schema ), - params + transformed_params ) if block_given? @@ -192,6 +197,28 @@ def configured_remote_providers(config) private + def apply_parameter_mappings(params) + return params if parameter_mappings.empty? + + transformed = params.dup + + parameter_mappings.each do |source_key, target_path| + next unless transformed.key?(source_key) + + value = transformed.delete(source_key) + *keys, last_key = target_path + + target = keys.inject(transformed) do |hash, key| + hash[key] = {} unless hash[key].is_a?(Hash) + hash[key] + end + + target[last_key] = value + end + + transformed + end + def try_parse_json(maybe_json) return maybe_json unless maybe_json.is_a?(String) diff --git a/lib/ruby_llm/providers/gemini.rb b/lib/ruby_llm/providers/gemini.rb index 30e90b449..dd19808c3 100644 --- a/lib/ruby_llm/providers/gemini.rb +++ b/lib/ruby_llm/providers/gemini.rb @@ -22,6 +22,14 @@ def headers } end + private + + def parameter_mappings + { + max_tokens: %i[generationConfig maxOutputTokens] + } + end + class << self def capabilities Gemini::Capabilities diff --git a/spec/fixtures/vcr_cassettes/chat_with_params_gemini_gemini-2_5-flash_automatically_maps_max_tokens_to_maxoutputtokens.yml b/spec/fixtures/vcr_cassettes/chat_with_params_gemini_gemini-2_5-flash_automatically_maps_max_tokens_to_maxoutputtokens.yml new file mode 100644 index 000000000..3c2eb4cb1 --- /dev/null +++ b/spec/fixtures/vcr_cassettes/chat_with_params_gemini_gemini-2_5-flash_automatically_maps_max_tokens_to_maxoutputtokens.yml @@ -0,0 +1,81 @@ +--- +http_interactions: +- request: + method: post + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent + body: + encoding: UTF-8 + string: '{"contents":[{"role":"user","parts":[{"text":"Say hello in 3 words."}]}],"generationConfig":{"maxOutputTokens":100}}' + headers: + User-Agent: + - Faraday v2.13.4 + X-Goog-Api-Key: + - "" + Content-Type: + - application/json + Accept-Encoding: + - gzip;q=1.0,deflate;q=0.6,identity;q=0.3 + Accept: + - "*/*" + response: + status: + code: 200 + message: OK + headers: + Content-Type: + - application/json; charset=UTF-8 + Vary: + - Origin + - 
Referer + - X-Origin + Date: + - Wed, 01 Oct 2025 19:13:17 GMT + Server: + - scaffolding on HTTPServer2 + X-Xss-Protection: + - '0' + X-Frame-Options: + - SAMEORIGIN + X-Content-Type-Options: + - nosniff + Server-Timing: + - gfet4t7; dur=6012 + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + body: + encoding: ASCII-8BIT + string: | + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Well, hello there!" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "usageMetadata": { + "promptTokenCount": 8, + "candidatesTokenCount": 5, + "totalTokenCount": 555, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 8 + } + ], + "thoughtsTokenCount": 542 + }, + "modelVersion": "gemini-2.5-flash", + "responseId": "TX3daKXSCOvnvdIP-63HUA" + } + recorded_at: Wed, 01 Oct 2025 19:13:17 GMT +recorded_with: VCR 6.3.1 diff --git a/spec/ruby_llm/chat_request_options_spec.rb b/spec/ruby_llm/chat_request_options_spec.rb index 11dc7f82f..2e9b285d4 100644 --- a/spec/ruby_llm/chat_request_options_spec.rb +++ b/spec/ruby_llm/chat_request_options_spec.rb @@ -106,5 +106,24 @@ expect(json_response).to eq({ 'result' => 8 }) end end + + # Provider [:gemini] automatically maps max_tokens to generationConfig.maxOutputTokens + CHAT_MODELS.select { |model_info| model_info[:provider] == :gemini }.each do |model_info| + model = model_info[:model] + provider = model_info[:provider] + it "#{provider}/#{model} automatically maps max_tokens to maxOutputTokens" do + chat = RubyLLM + .chat(model: model, provider: provider) + .with_params(max_tokens: 100) + + response = chat.ask('Say hello in 3 words.') + + request_body = JSON.parse(response.raw.env.request_body) + expect(request_body.dig('generationConfig', 'maxOutputTokens')).to eq(100) + expect(request_body).not_to have_key('max_tokens') + + expect(response.content).to be_present + end + end end end diff --git a/spec/ruby_llm/providers/gemini/chat_spec.rb b/spec/ruby_llm/providers/gemini/chat_spec.rb index 88eb6ec24..5a558fbf4 100644 --- a/spec/ruby_llm/providers/gemini/chat_spec.rb +++ b/spec/ruby_llm/providers/gemini/chat_spec.rb @@ -233,4 +233,49 @@ # Verify our implementation correctly sums both token types expect(response.output_tokens).to eq(candidates_tokens + thoughts_tokens) end + + describe 'parameter mapping' do + let(:provider) do + config = RubyLLM::Configuration.new + config.gemini_api_key = 'test_key' + RubyLLM::Providers::Gemini.new(config) + end + + it 'maps max_tokens to generationConfig.maxOutputTokens' do + params = { max_tokens: 1000 } + result = provider.send(:apply_parameter_mappings, params) + + expect(result).to eq({ generationConfig: { maxOutputTokens: 1000 } }) + end + + it 'removes max_tokens from the params after mapping' do + params = { max_tokens: 500 } + result = provider.send(:apply_parameter_mappings, params) + + expect(result).not_to have_key(:max_tokens) + end + + it 'preserves other params while mapping max_tokens' do + params = { max_tokens: 1000, other_param: 'value' } + result = provider.send(:apply_parameter_mappings, params) + + expect(result[:other_param]).to eq('value') + expect(result.dig(:generationConfig, :maxOutputTokens)).to eq(1000) + end + + it 'merges with existing generationConfig hash' do + params = { max_tokens: 500, generationConfig: { temperature: 0.7 } } + result = provider.send(:apply_parameter_mappings, params) + + expect(result.dig(:generationConfig, :temperature)).to eq(0.7) + expect(result.dig(:generationConfig, 
:maxOutputTokens)).to eq(500) + end + + it 'handles params without max_tokens' do + params = { other: 'value', custom: 123 } + result = provider.send(:apply_parameter_mappings, params) + + expect(result).to eq({ other: 'value', custom: 123 }) + end + end end
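
A minimal usage sketch of the behavior this patch adds, condensed from spec/ruby_llm/chat_request_options_spec.rb above; the model name, prompt, and configure snippet are illustrative only and assume a Gemini API key is available.

    require 'ruby_llm'
    require 'json'

    # Assumes the usual RubyLLM configuration with a Gemini key, e.g.
    #   RubyLLM.configure { |config| config.gemini_api_key = ENV['GEMINI_API_KEY'] }

    # A plain :max_tokens param is accepted for the Gemini provider; the new
    # parameter_mappings hook rewrites it into generationConfig.maxOutputTokens
    # before the user params are deep-merged into the rendered payload.
    chat = RubyLLM
           .chat(model: 'gemini-2.5-flash', provider: :gemini)
           .with_params(max_tokens: 100)

    response = chat.ask('Say hello in 3 words.')

    # The outgoing request body nests the limit where the Gemini API expects it
    # and carries no top-level "max_tokens" key:
    request_body = JSON.parse(response.raw.env.request_body)
    request_body.dig('generationConfig', 'maxOutputTokens') # => 100
    request_body.key?('max_tokens')                         # => false

    # Internally, apply_parameter_mappings performs
    #   { max_tokens: 100 }  =>  { generationConfig: { maxOutputTokens: 100 } }
    # and, per the provider specs above, preserves any other params as well as
    # any keys already present in a caller-supplied generationConfig hash.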